Skip to content

Commit 19ce815

Browse files
Merge pull request #68 from JuliaDiff/dw/weakdeps
Use weak dependencies if supported
2 parents b929457 + 9db3b75 commit 19ce815

18 files changed

Lines changed: 346 additions & 270 deletions

.github/workflows/CI.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,15 @@ jobs:
2020
fail-fast: false
2121
matrix:
2222
version:
23-
- '1.0'
2423
- '1.6'
2524
- '1'
26-
#- 'nightly'
25+
- 'nightly'
2726
os:
2827
- ubuntu-latest
2928
arch:
3029
- x64
3130
steps:
32-
- uses: actions/checkout@v2
31+
- uses: actions/checkout@v3
3332
- uses: julia-actions/setup-julia@v1
3433
with:
3534
version: ${{ matrix.version }}
@@ -38,6 +37,6 @@ jobs:
3837
- uses: julia-actions/julia-buildpkg@v1
3938
- uses: julia-actions/julia-runtest@v1
4039
- uses: julia-actions/julia-processcoverage@v1
41-
- uses: codecov/codecov-action@v1
40+
- uses: codecov/codecov-action@v3
4241
with:
43-
file: lcov.info
42+
files: lcov.info

Project.toml

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,46 @@
11
name = "AbstractDifferentiation"
22
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
33
authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
4-
version = "0.4.4"
4+
version = "0.5.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8-
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
98
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
109
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1110
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1211

12+
[weakdeps]
13+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
14+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
15+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
16+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
17+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
18+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
19+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
20+
21+
[extensions]
22+
AbstractDifferentiationChainRulesCoreExt = "ChainRulesCore"
23+
AbstractDifferentiationFiniteDifferencesExt = "FiniteDifferences"
24+
AbstractDifferentiationForwardDiffExt = ["DiffResults", "ForwardDiff"]
25+
AbstractDifferentiationReverseDiffExt = ["DiffResults", "ReverseDiff"]
26+
AbstractDifferentiationTrackerExt = "Tracker"
27+
AbstractDifferentiationZygoteExt = "Zygote"
28+
1329
[compat]
1430
ChainRulesCore = "1"
15-
Compat = "3, 4"
31+
DiffResults = "1"
1632
ExprTools = "0.1"
33+
FiniteDifferences = "0.12"
1734
ForwardDiff = "0.10"
18-
Requires = "0.5, 1"
35+
Requires = "1"
1936
ReverseDiff = "1"
20-
julia = "1"
37+
Tracker = "0.2"
38+
Zygote = "0.6"
39+
julia = "1.6"
2140

2241
[extras]
42+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
43+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
2344
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2445
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2546
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -29,4 +50,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2950
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3051

3152
[targets]
32-
test = ["Test", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Tracker", "Zygote"]
53+
test = ["Test", "ChainRulesCore", "DiffResults", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Tracker", "Zygote"]

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# AbstractDifferentiation
22

3-
[![CI](https://github.com/JuliaDiff/AbstractDifferentiation.jl/workflows/CI/badge.svg?branch=master)](https://github.com/JuliaDiff/AbstractDifferentiation.jl/actions?query=workflow%3ACI)
3+
[![CI](https://github.com/JuliaDiff/AbstractDifferentiation.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/JuliaDiff/AbstractDifferentiation.jl/actions/workflows/CI.yml?query=branch%3Amaster)
44
[![Coverage](https://codecov.io/gh/JuliaDiff/AbstractDifferentiation.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaDiff/AbstractDifferentiation.jl)
55

66
## Motivation
@@ -11,18 +11,18 @@ Julia has more (automatic) differentiation packages than you can count on 2 hand
1111

1212
## Loading `AbstractDifferentiation`
1313

14-
To load `AbstractDifferentiation`, use:
14+
To load `AbstractDifferentiation`, it is recommended to use
1515
```julia
16-
using AbstractDifferentiation
16+
import AbstractDifferentiation as AD
1717
```
18-
`AbstractDifferentiation` exports a single name `AD` which is just an alias for the `AbstractDifferentiation` module itself. You can use this to access names inside `AbstractDifferentiation` using `AD.<>` instead of typing the long name `AbstractDifferentiation`.
18+
With the `AD` alias you can access names inside of `AbstractDifferentiation` using `AD.<>` instead of typing the long name `AbstractDifferentiation`.
1919

2020
## `AbstractDifferentiation` backends
2121

2222
To use `AbstractDifferentiation`, first construct a backend instance `ab::AD.AbstractBackend` using your favorite differentiation package in Julia that supports `AbstractDifferentiation`.
2323
In particular, you may want to use `AD.ReverseRuleConfigBackend(ruleconfig)` for any [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible reverse mode differentiation package.
2424

25-
The following backends are temporarily made available by `AbstractDifferentiation` as soon as their corresponding package is loaded (thanks to [Requires.jl](https://github.com/JuliaPackaging/Requires.jl)):
25+
The following backends are temporarily made available by `AbstractDifferentiation` as soon as their corresponding package is loaded (thanks to [weak dependencies](https://pkgdocs.julialang.org/dev/creating-packages/#Weak-dependencies) on Julia ≥ 1.9 and [Requires.jl](https://github.com/JuliaPackaging/Requires.jl) on older Julia versions):
2626

2727
- `AD.ForwardDiffBackend()` for [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)
2828
- `AD.FiniteDifferencesBackend()` for [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl)
@@ -35,7 +35,7 @@ In the long term, these backend objects (and many more) will be defined within t
3535
Here's an example:
3636

3737
```julia
38-
julia> using AbstractDifferentiation, Zygote
38+
julia> import AbstractDifferentiation as AD, Zygote
3939

4040
julia> ab = AD.ZygoteBackend()
4141
AbstractDifferentiation.ReverseRuleConfigBackend{Zygote.ZygoteRuleConfig{Zygote.Context}}(Zygote.ZygoteRuleConfig{Zygote.Context}(Zygote.Context(nothing)))
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module AbstractDifferentiationChainRulesCoreExt
2+
3+
import AbstractDifferentiation as AD
4+
using ChainRulesCore: ChainRulesCore
5+
6+
AD.@primitive function pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...)
7+
_, back = ChainRulesCore.rrule_via_ad(ba.ruleconfig, f, xs...)
8+
pullback(vs) = Base.tail(back(vs))
9+
pullback(vs::Tuple{Any}) = Base.tail(back(first(vs)))
10+
return pullback
11+
end
12+
13+
end # module
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
module AbstractDifferentiationFiniteDifferencesExt
2+
3+
import AbstractDifferentiation as AD
4+
if AD.EXTENSIONS_SUPPORTED
5+
using FiniteDifferences: FiniteDifferences
6+
else
7+
using ..FiniteDifferences: FiniteDifferences
8+
end
9+
10+
"""
11+
FiniteDifferencesBackend(method=FiniteDifferences.central_fdm(5, 1))
12+
13+
Create an AD backend that uses forward mode with FiniteDifferences.jl.
14+
"""
15+
AD.FiniteDifferencesBackend() = AD.FiniteDifferencesBackend(FiniteDifferences.central_fdm(5, 1))
16+
17+
AD.@primitive function jacobian(ba::AD.FiniteDifferencesBackend, f, xs...)
18+
return FiniteDifferences.jacobian(ba.method, f, xs...)
19+
end
20+
21+
function AD.pushforward_function(ba::AD.FiniteDifferencesBackend, f, xs...)
22+
return function pushforward(vs)
23+
ws = FiniteDifferences.jvp(ba.method, f, tuple.(xs, vs)...)
24+
return length(xs) == 1 ? (ws,) : ws
25+
end
26+
end
27+
28+
function AD.pullback_function(ba::AD.FiniteDifferencesBackend, f, xs...)
29+
function pullback(vs)
30+
return FiniteDifferences.j′vp(ba.method, f, vs, xs...)
31+
end
32+
end
33+
34+
end # module

src/forwarddiff.jl renamed to ext/AbstractDifferentiationForwardDiffExt.jl

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
1-
using .ForwardDiff: ForwardDiff, DiffResults
2-
3-
"""
4-
ForwardDiffBackend{CS}
5-
6-
AD backend that uses forward mode with ForwardDiff.jl.
7-
8-
The type parameter `CS` denotes the chunk size of the differentiation algorithm. If it is
9-
`Nothing`, then ForwardiffDiff uses a heuristic to set the chunk size based on the input.
10-
11-
See also: [ForwardDiff.jl: Configuring Chunk Size](https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Configuring-Chunk-Size)
12-
"""
13-
struct ForwardDiffBackend{CS} <: AbstractForwardMode end
1+
module AbstractDifferentiationForwardDiffExt
2+
3+
import AbstractDifferentiation as AD
4+
if AD.EXTENSIONS_SUPPORTED
5+
using DiffResults: DiffResults
6+
using ForwardDiff: ForwardDiff
7+
else
8+
using ..DiffResults: DiffResults
9+
using ..ForwardDiff: ForwardDiff
10+
end
1411

1512
"""
1613
ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing)
@@ -23,11 +20,11 @@ ForwarddDiff uses a heuristic to set the chunk size based on the input. Alternat
2320
2421
See also: [ForwardDiff.jl: Configuring Chunk Size](https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Configuring-Chunk-Size)
2522
"""
26-
function ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing)
27-
return ForwardDiffBackend{getchunksize(chunksize)}()
23+
function AD.ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing)
24+
return AD.ForwardDiffBackend{getchunksize(chunksize)}()
2825
end
2926

30-
@primitive function pushforward_function(ba::ForwardDiffBackend, f, xs...)
27+
AD.@primitive function pushforward_function(ba::AD.ForwardDiffBackend, f, xs...)
3128
return function pushforward(vs)
3229
if length(xs) == 1
3330
v = vs isa Tuple ? only(vs) : vs
@@ -38,35 +35,35 @@ end
3835
end
3936
end
4037

41-
primal_value(x::ForwardDiff.Dual) = ForwardDiff.value(x)
42-
primal_value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x)
38+
AD.primal_value(x::ForwardDiff.Dual) = ForwardDiff.value(x)
39+
AD.primal_value(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x)
4340

4441
# these implementations are more efficient than the fallbacks
4542

46-
function gradient(ba::ForwardDiffBackend, f, x::AbstractArray)
43+
function AD.gradient(ba::AD.ForwardDiffBackend, f, x::AbstractArray)
4744
cfg = ForwardDiff.GradientConfig(f, x, chunk(ba, x))
4845
return (ForwardDiff.gradient(f, x, cfg),)
4946
end
5047

51-
function jacobian(ba::ForwardDiffBackend, f, x::AbstractArray)
52-
cfg = ForwardDiff.JacobianConfig(asarray f, x, chunk(ba, x))
53-
return (ForwardDiff.jacobian(asarray f, x, cfg),)
48+
function AD.jacobian(ba::AD.ForwardDiffBackend, f, x::AbstractArray)
49+
cfg = ForwardDiff.JacobianConfig(AD.asarray f, x, chunk(ba, x))
50+
return (ForwardDiff.jacobian(AD.asarray f, x, cfg),)
5451
end
55-
jacobian(::ForwardDiffBackend, f, x::Number) = (ForwardDiff.derivative(f, x),)
52+
AD.jacobian(::AD.ForwardDiffBackend, f, x::Number) = (ForwardDiff.derivative(f, x),)
5653

57-
function hessian(ba::ForwardDiffBackend, f, x::AbstractArray)
54+
function AD.hessian(ba::AD.ForwardDiffBackend, f, x::AbstractArray)
5855
cfg = ForwardDiff.HessianConfig(f, x, chunk(ba, x))
5956
return (ForwardDiff.hessian(f, x, cfg),)
6057
end
6158

62-
function value_and_gradient(ba::ForwardDiffBackend, f, x::AbstractArray)
59+
function AD.value_and_gradient(ba::AD.ForwardDiffBackend, f, x::AbstractArray)
6360
result = DiffResults.GradientResult(x)
6461
cfg = ForwardDiff.GradientConfig(f, x, chunk(ba, x))
6562
ForwardDiff.gradient!(result, f, x, cfg)
6663
return DiffResults.value(result), (DiffResults.derivative(result),)
6764
end
6865

69-
function value_and_hessian(ba::ForwardDiffBackend, f, x)
66+
function AD.value_and_hessian(ba::AD.ForwardDiffBackend, f, x)
7067
result = DiffResults.HessianResult(x)
7168
cfg = ForwardDiff.HessianConfig(f, result, x, chunk(ba, x))
7269
ForwardDiff.hessian!(result, f, x, cfg)
@@ -80,5 +77,7 @@ end
8077
getchunksize(::Nothing) = Nothing
8178
getchunksize(::Val{N}) where {N} = N
8279

83-
chunk(::ForwardDiffBackend{Nothing}, x) = ForwardDiff.Chunk(x)
84-
chunk(::ForwardDiffBackend{N}, _) where {N} = ForwardDiff.Chunk{N}()
80+
chunk(::AD.ForwardDiffBackend{Nothing}, x) = ForwardDiff.Chunk(x)
81+
chunk(::AD.ForwardDiffBackend{N}, _) where {N} = ForwardDiff.Chunk{N}()
82+
83+
end # module
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
module AbstractDifferentiationReverseDiffExt
2+
3+
import AbstractDifferentiation as AD
4+
if AD.EXTENSIONS_SUPPORTED
5+
using DiffResults: DiffResults
6+
using ReverseDiff: ReverseDiff
7+
else
8+
using ..DiffResults: DiffResults
9+
using ..ReverseDiff: ReverseDiff
10+
end
11+
12+
AD.primal_value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x)
13+
AD.primal_value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x)
14+
AD.primal_value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x)
15+
16+
AD.@primitive function jacobian(ba::AD.ReverseDiffBackend, f, xs...)
17+
xs_arr = map(AD.asarray, xs)
18+
tape = ReverseDiff.JacobianTape(xs_arr) do (xs_arr...)
19+
xs_new = map(xs, xs_arr) do x, x_arr
20+
return x isa Number ? only(x_arr) : x_arr
21+
end
22+
return AD.asarray(f(xs_new...))
23+
end
24+
results = ReverseDiff.jacobian!(tape, xs_arr)
25+
return map(xs, results) do x, result
26+
return x isa Number ? vec(result) : result
27+
end
28+
end
29+
function AD.jacobian(::AD.ReverseDiffBackend, f, xs::AbstractArray...)
30+
return ReverseDiff.jacobian(AD.asarray f, xs)
31+
end
32+
33+
function AD.derivative(::AD.ReverseDiffBackend, f, xs::Number...)
34+
tape = ReverseDiff.InstructionTape()
35+
xs_tracked = ReverseDiff.TrackedReal.(xs, zero.(xs), Ref(tape))
36+
y_tracked = f(xs_tracked...)
37+
ReverseDiff.seed!(y_tracked)
38+
ReverseDiff.reverse_pass!(tape)
39+
return ReverseDiff.deriv.(xs_tracked)
40+
end
41+
42+
function AD.gradient(::AD.ReverseDiffBackend, f, xs::AbstractArray...)
43+
return ReverseDiff.gradient(f, xs)
44+
end
45+
46+
function AD.hessian(::AD.ReverseDiffBackend, f, x::AbstractArray)
47+
return (ReverseDiff.hessian(f, x),)
48+
end
49+
50+
function AD.value_and_gradient(::AD.ReverseDiffBackend, f, x::AbstractArray)
51+
result = DiffResults.GradientResult(x)
52+
cfg = ReverseDiff.GradientConfig(x)
53+
ReverseDiff.gradient!(result, f, x, cfg)
54+
return DiffResults.value(result), (DiffResults.derivative(result),)
55+
end
56+
57+
function AD.value_and_hessian(::AD.ReverseDiffBackend, f, x)
58+
result = DiffResults.HessianResult(x)
59+
cfg = ReverseDiff.HessianConfig(result, x)
60+
ReverseDiff.hessian!(result, f, x, cfg)
61+
return DiffResults.value(result), (DiffResults.hessian(result),)
62+
end
63+
64+
end # module
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
module AbstractDifferentiationTrackerExt
2+
3+
import AbstractDifferentiation as AD
4+
if AD.EXTENSIONS_SUPPORTED
5+
using Tracker: Tracker
6+
else
7+
using ..Tracker: Tracker
8+
end
9+
10+
function AD.second_lowest(::AD.TrackerBackend)
11+
return throw(ArgumentError("Tracker backend does not support nested differentiation."))
12+
end
13+
14+
AD.primal_value(x::Tracker.TrackedReal) = Tracker.data(x)
15+
AD.primal_value(x::Tracker.TrackedArray) = Tracker.data(x)
16+
AD.primal_value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x)
17+
18+
AD.@primitive function pullback_function(ba::AD.TrackerBackend, f, xs...)
19+
value, back = Tracker.forward(f, xs...)
20+
function pullback(ws)
21+
if ws isa Tuple && !(value isa Tuple)
22+
map(Tracker.data, back(only(ws)))
23+
else
24+
map(Tracker.data, back(ws))
25+
end
26+
end
27+
return pullback
28+
end
29+
30+
function AD.derivative(::AD.TrackerBackend, f, xs::Number...)
31+
return Tracker.data.(Tracker.gradient(f, xs...))
32+
end
33+
34+
function AD.gradient(::AD.TrackerBackend, f, xs::AbstractVector...)
35+
return Tracker.data.(Tracker.gradient(f, xs...))
36+
end
37+
38+
end # module
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module AbstractDifferentiationZygoteExt
2+
3+
import AbstractDifferentiation as AD
4+
if AD.EXTENSIONS_SUPPORTED
5+
using Zygote: Zygote
6+
else
7+
using ..Zygote: Zygote
8+
end
9+
10+
AD.ZygoteBackend() = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
11+
12+
end # module

0 commit comments

Comments
 (0)