diff --git a/Project.toml b/Project.toml
index a5b0374..e708532 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,7 +10,9 @@ ChainRulesCore = "1.0.0"
 julia = "1"
 
 [extras]
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test"]
+test = ["Test", "ChainRulesTestUtils", "Zygote"]
diff --git a/src/ChainRulesDeclarationHelpers.jl b/src/ChainRulesDeclarationHelpers.jl
index c167b2e..0ab7647 100644
--- a/src/ChainRulesDeclarationHelpers.jl
+++ b/src/ChainRulesDeclarationHelpers.jl
@@ -1,3 +1,4 @@
 module ChainRulesDeclarationHelpers
-
+    export @rrule_from_frule
+    include("rrule_from_frule.jl")
 end # module
\ No newline at end of file
diff --git a/src/rrule_from_frule.jl b/src/rrule_from_frule.jl
new file mode 100644
index 0000000..e9c6ba8
--- /dev/null
+++ b/src/rrule_from_frule.jl
@@ -0,0 +1,82 @@
+using ChainRulesCore
+
+"""
+    @rrule_from_frule(signature_expression)
+
+Define an `rrule` for the given signature by calling back into AD on an already defined
+`frule`.
+
+The pushforward at the point is a linear function which has the same derivative as
+the primal at that point. So asking for its derivative gives you the derivative of
+the primal. Moreover, asking for its rrule effectively amounts to transposing the
+pushforward implied by the frule.
+
+`signature_expression` must be a call whose arguments are all written as `name::Type`.
+The generated method takes a `RuleConfig{>:HasReverseMode}` as its first argument,
+because it differentiates the pushforward with `rrule_via_ad`.
+
+# Examples
+```julia
+@rrule_from_frule f(x::AbstractArray{<:Real})
+```
+
+# Further Reading
+[1] Roy Frostig, Matthew J. Johnson, Dougal Maclaurin, Adam Paszke, and Alexey Radul.
+    Decomposing reverse-mode automatic differentiation. LAFI 2021
+"""
+macro rrule_from_frule(signature_expression)
+    # Validate with real exceptions: `@assert` may be compiled out and is not meant for
+    # checking user input.
+    if !Meta.isexpr(signature_expression, :call)
+        throw(ArgumentError(
+            "@rrule_from_frule expects a call signature like `f(x::T)`, " *
+            "got `$(signature_expression)`",
+        ))
+    end
+    f = signature_expression.args[1]
+    args = signature_expression.args[2:end]
+    for arg in args
+        if !Meta.isexpr(arg, :(::), 2)
+            throw(ArgumentError(
+                "@rrule_from_frule requires every argument to be written as " *
+                "`name::Type`, got `$(arg)`",
+            ))
+        end
+    end
+    return rrule_from_frule_expr(__source__, f, args)
+end
+
+"""
+    rrule_from_frule_expr(__source__, f, args)
+
+Return the expression defining `ChainRulesCore.rrule` for the function expression `f`
+and its `name::Type` argument expressions `args`.
+
+`__source__` is spliced into the body of the generated method.
+"""
+function rrule_from_frule_expr(__source__, f, args)
+    f_instance_name = gensym(Symbol(:instance_, Symbol(f)))
+    # Argument names are left to macro hygiene, which renames them consistently in the
+    # signature and the body. Only the type annotations are escaped, so that they resolve
+    # in the caller's module rather than in this one.
+    arg_names = [arg.args[1] for arg in args]
+    typed_args = [Expr(:(::), arg.args[1], esc(arg.args[2])) for arg in args]
+    return quote
+        function ChainRulesCore.rrule(
+            config::RuleConfig{>:HasReverseMode},
+            $f_instance_name::Core.Typeof($(esc(f))),
+            $(typed_args...),
+        )
+            $(__source__)
+            # The pushforward is linear in the tangents, so its pullback is the pullback
+            # of the primal.
+            pushforward(Δfarg...) = frule(Δfarg, $f_instance_name, $(arg_names...))[2]
+            _, back = rrule_via_ad(config, pushforward, $f_instance_name, $(arg_names...))
+            # TODO optimize away redundant primal computation
+            y = $f_instance_name($(arg_names...))
+            # The first entry returned by `back` is the tangent for `pushforward`; drop it.
+            f_pullback(Δy) = back(Δy)[2:end]
+            return y, f_pullback
+        end
+    end
+end
diff --git a/test/rrule_from_frule.jl b/test/rrule_from_frule.jl
new file mode 100644
index 0000000..d431692
--- /dev/null
+++ b/test/rrule_from_frule.jl
@@ -0,0 +1,47 @@
+@testset "rrule_from_frule" begin
+    @testset "single argument" begin
+        function f(x)
+            a = sin.(x)
+            b = sum(a)
+            c = b * a
+            return c
+        end
+
+        function ChainRulesCore.frule((Δself, Δx), ::typeof(f), x)
+            a, ȧ = sin.(x), cos.(x) .* Δx
+            b, ḃ = sum(a), sum(ȧ)
+            c, ċ = b * a, ḃ * a + b * ȧ
+            return c, ċ
+        end
+
+        x = rand(3)
+        test_frule(f, x)
+
+        @rrule_from_frule f(x::AbstractArray{<:Real})
+        test_rrule(Zygote.ZygoteRuleConfig(), f, x; check_inferred=false)
+    end
+    @testset "multiple arguments" begin
+        function f(x, z)
+            a = x + z
+            b = sin.(a)
+            c = sum(b)
+            d = c * b
+            return d
+        end
+
+        function ChainRulesCore.frule((Δself, Δx, Δz), ::typeof(f), x, z)
+            a, ȧ = x + z, Δx + Δz
+            b, ḃ = sin.(a), cos.(a) .* ȧ
+            c, ċ = sum(b), sum(ḃ)
+            d, ḋ = c * b, ċ * b + c * ḃ
+            return d, ḋ
+        end
+
+        x = rand(3)
+        z = rand(3)
+        test_frule(f, x, z)
+
+        @rrule_from_frule f(x::AbstractArray{<:Real}, z::AbstractArray{<:Real})
+        test_rrule(Zygote.ZygoteRuleConfig(), f, x, z; check_inferred=false)
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index ed22140..d22c4b3 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,9 @@
-using ChainRulesDeclarationHelpers
 using Test
+using ChainRulesDeclarationHelpers
+using ChainRulesCore
+using ChainRulesTestUtils
+using Zygote
 
 @testset "ChainRulesDeclarationHelpers" begin
-
-end
\ No newline at end of file
+    include("rrule_from_frule.jl")
+end