Skip to content

Commit 031f7f4

Browse files
authored
Retrieve pullback only once for ReverseRuleConfigBackend (#51)
* Retrieve pullback only once for ReverseRuleConfigBackend * Update Project.toml * `HasReverseMode` is not exported... * Fix CRC imports * Missing bracket
1 parent 8f0d6db commit 031f7f4

3 files changed

Lines changed: 7 additions & 11 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AbstractDifferentiation"
22
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
33
authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
4-
version = "0.4.1"
4+
version = "0.4.2"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/AbstractDifferentiation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module AbstractDifferentiation
22

33
using LinearAlgebra, ExprTools, Requires, Compat
4-
using ChainRulesCore: RuleConfig, rrule_via_ad
4+
using ChainRulesCore: ChainRulesCore
55

66
export AD
77

src/ruleconfig.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,13 @@
33
44
AD backend that uses reverse mode with any ChainRules-compatible reverse-mode AD package.
55
"""
6-
struct ReverseRuleConfigBackend{RC <: RuleConfig} <: AbstractReverseMode
6+
struct ReverseRuleConfigBackend{RC<:ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}} <: AbstractReverseMode
77
ruleconfig::RC
88
end
99

1010
AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...)
11-
return (vs) -> begin
12-
_, back = rrule_via_ad(ab.ruleconfig, f, xs...)
13-
if vs isa Tuple && length(vs) === 1
14-
return Base.tail(back(vs[1]))
15-
else
16-
return Base.tail(back(vs))
17-
end
18-
end
11+
_, back = ChainRulesCore.rrule_via_ad(ab.ruleconfig, f, xs...)
12+
pullback(vs) = Base.tail(back(vs))
13+
pullback(vs::Tuple{Any}) = Base.tail(back(first(vs)))
14+
return pullback
1915
end

0 commit comments

Comments
 (0)