Skip to content

Commit 3218c08

Browse files
authored
New Zygote context in every call to AD.pullback_function (#77)
* New Zygote context in every call to `AD.pullback_function` * Make fix more modular
1 parent 19ce815 commit 3218c08

5 files changed

Lines changed: 34 additions & 2 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AbstractDifferentiation"
22
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
33
authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
4-
version = "0.5.0"
4+
version = "0.5.1"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

ext/AbstractDifferentiationChainRulesCoreExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import AbstractDifferentiation as AD
44
using ChainRulesCore: ChainRulesCore
55

66
AD.@primitive function pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...)
7-
_, back = ChainRulesCore.rrule_via_ad(ba.ruleconfig, f, xs...)
7+
_, back = ChainRulesCore.rrule_via_ad(AD.ruleconfig(ba), f, xs...)
88
pullback(vs) = Base.tail(back(vs))
99
pullback(vs::Tuple{Any}) = Base.tail(back(first(vs)))
1010
return pullback

ext/AbstractDifferentiationZygoteExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,9 @@ end
99

1010
AD.ZygoteBackend() = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
1111

12+
# Context should not persist between different AD calls: fixes #69
13+
function AD.ruleconfig(::AD.ReverseRuleConfigBackend{<:Zygote.ZygoteRuleConfig})
14+
return Zygote.ZygoteRuleConfig()
15+
end
16+
1217
end # module

src/backends.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ struct ReverseRuleConfigBackend{RC} <: AbstractReverseMode
6161
ruleconfig::RC
6262
end
6363

64+
# internal function for extracting the rule config
65+
# falls back to returning the wrapped `ruleconfig` but can be specialized
66+
# e.g., for Zygote to fix #69
67+
ruleconfig(ba::ReverseRuleConfigBackend) = ba.ruleconfig
68+
6469
"""
6570
ZygoteBackend()
6671

test/ruleconfig.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,26 @@ using Zygote
3030
test_lazy_jacobians(backend)
3131
end
3232
end
33+
34+
# issue #69
35+
@testset "Zygote context" begin
36+
ad = AD.ZygoteBackend()
37+
38+
# example in #69: context is not mutated
39+
@test ad.ruleconfig.context.cache === nothing
40+
@test AD.derivative(ad, exp, 1.0) === (exp(1.0),)
41+
@test ad.ruleconfig.context.cache === nothing
42+
@test AD.derivative(ad, exp, 1.0) === (exp(1.0),)
43+
@test ad.ruleconfig.context.cache === nothing
44+
45+
# Jacobian computation still works
46+
# https://github.com/JuliaDiff/AbstractDifferentiation.jl/pull/70#issuecomment-1449481724
47+
function f(x, a)
48+
r = Ref(x)
49+
r[] = r[] + r[]
50+
r[] = r[] * a
51+
r[]
52+
end
53+
@test AD.jacobian(ad, f, [1, 2, 3], 3) == ([6.0 0.0 0.0; 0.0 6.0 0.0; 0.0 0.0 6.0], [2.0, 4.0, 6.0])
54+
end
3355
end

0 commit comments

Comments
 (0)