Skip to content

Commit 90fcbae

Browse files
committed
More fixes
1 parent d42dbe6 commit 90fcbae

7 files changed

Lines changed: 21 additions & 24 deletions

ext/AbstractDifferentiationChainRulesCoreExt.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
module AbstractDifferentiationChainRulesCoreExt
22

33
import AbstractDifferentiation as AD
4-
if AD.EXTENSIONS_SUPPORTED
5-
using ChainRulesCore: ChainRulesCore
6-
else
7-
using .ChainRulesCore: ChainRulesCore
8-
end
4+
using ChainRulesCore: ChainRulesCore
95

10-
AD.@primitive function pullback_function(ab::AD.ReverseRuleConfigBackend, f, xs...)
6+
AD.@primitive function pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...)
117
_, back = ChainRulesCore.rrule_via_ad(ab.ruleconfig, f, xs...)
128
pullback(vs) = Base.tail(back(vs))
139
pullback(vs::Tuple{Any}) = Base.tail(back(first(vs)))

ext/AbstractDifferentiationForwardDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function AD.ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing)
2424
return AD.ForwardDiffBackend{getchunksize(chunksize)}()
2525
end
2626

27-
AD.@primitive function pushforward_function(::AD.ForwardDiffBackend, f, xs...)
27+
AD.@primitive function pushforward_function(ba::AD.ForwardDiffBackend, f, xs...)
2828
return function pushforward(vs)
2929
if length(xs) == 1
3030
v = vs isa Tuple ? only(vs) : vs

ext/AbstractDifferentiationReverseDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ AD.primal_value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x)
1313
AD.primal_value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x)
1414
AD.primal_value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x)
1515

16-
AD.@primitive function jacobian(::AD.ReverseDiffBackend, f, xs...)
16+
AD.@primitive function jacobian(ba::AD.ReverseDiffBackend, f, xs...)
1717
xs_arr = map(AD.asarray, xs)
1818
tape = ReverseDiff.JacobianTape(xs_arr) do (xs_arr...)
1919
xs_new = map(xs, xs_arr) do x, x_arr

ext/AbstractDifferentiationTrackerExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ AD.primal_value(x::Tracker.TrackedReal) = Tracker.data(x)
1515
AD.primal_value(x::Tracker.TrackedArray) = Tracker.data(x)
1616
AD.primal_value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x)
1717

18-
AD.@primitive function pullback_function(::AD.TrackerBackend, f, xs...)
18+
AD.@primitive function pullback_function(ba::AD.TrackerBackend, f, xs...)
1919
value, back = Tracker.forward(f, xs...)
2020
function pullback(ws)
2121
if ws isa Tuple && !(value isa Tuple)

src/AbstractDifferentiation.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -506,16 +506,16 @@ macro primitive(expr)
506506
end
507507

508508
function define_pushforward_function_and_friends(fdef)
509-
fdef[:name] = :(AbstractDifferentiation.pushforward_function)
509+
fdef[:name] = :($(AbstractDifferentiation).pushforward_function)
510510
args = fdef[:args]
511511
funcs = quote
512512
$(ExprTools.combinedef(fdef))
513-
function AbstractDifferentiation.jacobian($(args...),)
514-
identity_like = AbstractDifferentiation.identity_matrix_like($(args[3:end]...),)
515-
pff = AbstractDifferentiation.pushforward_function($(args...),)
513+
function $(AbstractDifferentiation).jacobian($(args...),)
514+
identity_like = $(identity_matrix_like)($(args[3:end]...),)
515+
pff = $(pushforward_function)($(args...),)
516516
if eltype(identity_like) <: Tuple{Vararg{Union{AbstractMatrix, Number}}}
517517
return map(identity_like) do identity_like_i
518-
return mapreduce(hcat, AbstractDifferentiation._eachcol.(identity_like_i)...) do (cols...)
518+
return mapreduce(hcat, $(_eachcol).(identity_like_i)...) do (cols...)
519519
pff(cols)
520520
end
521521
end
@@ -542,17 +542,17 @@ function define_pushforward_function_and_friends(fdef)
542542
end
543543

544544
function define_pullback_function_and_friends(fdef)
545-
fdef[:name] = :(AbstractDifferentiation.pullback_function)
545+
fdef[:name] = :($(AbstractDifferentiation).pullback_function)
546546
args = fdef[:args]
547547
funcs = quote
548548
$(ExprTools.combinedef(fdef))
549-
function AbstractDifferentiation.jacobian($(args...),)
550-
value_and_pbf = AbstractDifferentiation.value_and_pullback_function($(args...),)
549+
function $(AbstractDifferentiation).jacobian($(args...),)
550+
value_and_pbf = $(value_and_pullback_function)($(args...),)
551551
value, _ = value_and_pbf(nothing)
552-
identity_like = AbstractDifferentiation.identity_matrix_like(value)
552+
identity_like = $(identity_matrix_like)(value)
553553
if eltype(identity_like) <: Tuple{Vararg{AbstractMatrix}}
554554
return map(identity_like) do identity_like_i
555-
return mapreduce(vcat, AbstractDifferentiation._eachcol.(identity_like_i)...) do (cols...)
555+
return mapreduce(vcat, $(_eachcol).(identity_like_i)...) do (cols...)
556556
value_and_pbf(cols)[2]'
557557
end
558558
end
@@ -575,12 +575,12 @@ _eachcol(a::Number) = (a,)
575575
_eachcol(a) = eachcol(a)
576576

577577
function define_jacobian_and_friends(fdef)
578-
fdef[:name] = :(AbstractDifferentiation.jacobian)
578+
fdef[:name] = :($(AbstractDifferentiation).jacobian)
579579
return ExprTools.combinedef(fdef)
580580
end
581581

582582
function define_primal_value(fdef)
583-
fdef[:name] = :(AbstractDifferentiation.primal_value)
583+
fdef[:name] = :($(AbstractDifferentiation).primal_value)
584584
return ExprTools.combinedef(fdef)
585585
end
586586

src/backends.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ struct TrackerBackend <: AbstractReverseMode end
5454
AD backend that uses reverse mode with any ChainRules-compatible reverse-mode AD package.
5555
5656
!!! note
57-
To be able to use this backend, you have to load ChainRulesCore.
57+
On Julia >= 1.9, you have to load ChainRulesCore (possibly implicitly by loading
58+
a ChainRules-compatible AD package) to be able to use this backend.
5859
"""
5960
struct ReverseRuleConfigBackend{RC} <: AbstractReverseMode
6061
ruleconfig::RC

test/test_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ function test_hessians(backend; multiple_inputs=false, test_types=true)
150150
else
151151
# explicit test that AbstractDifferentiation throws an error
152152
# don't support tuple of Hessians
153-
@test_throws AssertionError H1 = AD.hessian(backend, fgrad, (xvec, yvec))
154-
@test_throws MethodError H1 = AD.hessian(backend, fgrad, xvec, yvec)
153+
@test_throws ArgumentError AD.hessian(backend, fgrad, (xvec, yvec))
154+
@test_throws MethodError AD.hessian(backend, fgrad, xvec, yvec)
155155
end
156156

157157
# @test dfgraddxdx(xvec,yvec) ≈ H1[1] atol=1e-10

0 commit comments

Comments
 (0)