Add rrule_from_frule macro #2

Draft
wants to merge 8 commits into
base: main

@niklasschmitz niklasschmitz commented Oct 1, 2021

This draft PR proposes a new helper macro for defining rrules, @rrule_from_frule. It generates an rrule from an already defined frule by calling back into (reverse-mode) AD.

This is also inspired by JAX's approach to decomposing reverse-mode into (forward-mode) linearization + transposition (where transposition only needs rules for linear primitives).

This minimum working example illustrates the idea from a ChainRules.jl perspective:

# rrule from frule (transposition)
using Zygote
using ChainRulesCore
using LinearAlgebra

function f(x)
    a = sin.(x)
    b = sum(a)
    c = b * a
    return c
end

function ChainRulesCore.frule((Δself, Δx,), ::typeof(f), x)
    a, ȧ = sin.(x), cos.(x) .* Δx
    b, ḃ = sum(a), sum(ȧ)
    c, ċ = b * a, ḃ * a + b * ȧ
    return c, ċ
end

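function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(f), x)
    # The pushforward is linear in the tangents, so its pullback is the transpose we want.
    pushforward(Δfx...) = frule(Δfx, f, x)[2]
    # Differentiate the pushforward with respect to its tangent arguments (Δself, Δx).
    # Because it is linear, the point (f, x) at which it is evaluated does not matter.
    _, back = rrule_via_ad(config, pushforward, f, x)
    # Drop the cotangent for `pushforward` itself; keep those for (Δself, Δx).
    f_pullback(Δy) = back(Δy)[2:end]
    return f(x), f_pullback
end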

let x = rand(3)
    v = randn(3)
    w = randn(3)
    jvp(f, x, v) = frule((NoTangent(), v), f, x)[2]
    vjp(f, x, w) = rrule_via_ad(Zygote.ZygoteRuleConfig(), f, x)[2](w)[2]
    dot(w, jvp(f, x, v)) ≈ dot(vjp(f, x, w), v)
end
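
For this particular f, the transpose can also be written out by hand, which gives a dependency-free reference for the dot-product check above. The following is only a sketch: the names jvp_manual and vjp_manual are hypothetical and not part of this PR, and the closed form is derived from the frule shown above (J = a * cos.(x)' + b * Diagonal(cos.(x))).

```julia
using LinearAlgebra

# Same primal as in the example above.
f(x) = sum(sin.(x)) * sin.(x)

# Hand-written pushforward, mirroring the frule: dc = db * a + b * da.
function jvp_manual(x, v)
    a, da = sin.(x), cos.(x) .* v
    b, db = sum(a), sum(da)
    return db * a + b * da
end

# Hand-written transpose of the linear map v -> jvp_manual(x, v).
# With J = a * cos.(x)' + b * Diagonal(cos.(x)), we get
# J' * w = cos.(x) .* (dot(a, w) .+ b .* w).
function vjp_manual(x, w)
    a = sin.(x)
    b = sum(a)
    return cos.(x) .* (dot(a, w) .+ b .* w)
end

# Dot-product (adjoint) test: <w, J v> should equal <J' w, v>.
x, v, w = rand(3), randn(3), randn(3)
@assert dot(w, jvp_manual(x, v)) ≈ dot(vjp_manual(x, w), v)
```

If the generated rrule is correct, its pullback should agree with vjp_manual at the same x and w up to floating-point error.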

This PR aims to provide this approach as a macro, used like @rrule_from_frule f(x::AbstractArray{<:Real}), to minimize boilerplate code.

@oxinabox
Member

oxinabox commented Oct 1, 2021

Nice.

If we resolved JuliaDiff/ChainRulesTestUtils.jl#221
we could drop the Zygote dependency from testing which would be good, but not blocking for this.

It might be nice to add a short explanation of why this works to the docstring.

My notes from slack are

The pushforward at the point is a linear function which has the same derivative as the primal at that point.
So asking for its derivative gives you the derivative of the primal.
And it is probably easier to AD than the primal function, since you have already done the work of removing the nonlinear parts (at least with aggressive constant folding they will be gone; idk that we have that).
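
One way to spell out the argument in these notes, using notation that is assumed here rather than taken from the thread: write $J_f(x)$ for the Jacobian of the primal at $x$, and $p_x$ for the pushforward at $x$ as a function of the tangent $v$.

```latex
p_x(v) = J_f(x)\, v,
\qquad
\frac{\partial p_x}{\partial v}(v) = J_f(x) \quad \text{for every } v,
\qquad
\text{so the pullback of } p_x \text{ is } \; w \mapsto J_f(x)^{\top} w .
```

Since the derivative of $p_x$ does not depend on $v$, the tangent at which reverse-mode AD evaluates the pushforward is irrelevant; the resulting pullback $w \mapsto J_f(x)^{\top} w$ is exactly the pullback of the primal at $x$.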

@niklasschmitz
Author

Thanks @oxinabox! I've now added a short docstring. I have also left two more TODOs in the code:

  1. Allowing more general function signature expressions in the macro, importantly multiple arguments (and possibly varargs).
    Should the handling of these be modelled after the macros in ChainRulesCore.jl, or would you recommend a package like ExprTools.jl or MLStyle.jl? This is my first macro, so I'm very open to guidance.

  2. Avoiding redundant computation of the primal. Currently, the frule is executed (inside the pushforward, discarding the primal value), and then additionally the primal is executed once to get the correct primal value to return from the rrule. I haven't solved it yet, but I think there's likely room for improvement here.

@oxinabox
Member

oxinabox commented Oct 3, 2021

> 2. Avoiding redundant computation of the primal. Currently, the frule is executed (inside the pushforward, discarding the primal value), and then additionally the primal is executed once to get the correct primal value to return from the rrule. I haven't solved it yet, but I think there's likely room for improvement here.

Related to this, I think Zygote (and probably Diffractor) will also generate code for the pullback of that discarded primal.
Unlike JAX, we can't aggressively constant fold that out of the pushforward, though I haven't seen the new tools that are in Julia 1.8 (and maybe 1.7-rc2?) (cc @simeonschaub might know).
