Skip to content

Commit 73e3a2a

Browse files
committed
Write docs on converting ZygoteRules
1 parent 3f3019f commit 73e3a2a

7 files changed

Lines changed: 141 additions & 6 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ docs/build
77
docs/site
88
docs/src/assets/chainrules.css
99
docs/src/assets/indigo.css
10+
.vscode/settings.json

docs/Manifest.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
1414
path = ".."
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.10.6"
16+
version = "0.10.7"
1717

1818
[[Compat]]
1919
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
20-
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
20+
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
2121
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
22-
version = "3.30.0"
22+
version = "3.31.0"
2323

2424
[[Dates]]
2525
deps = ["Printf"]

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ makedocs(
5454
"Debug Mode" => "debug_mode.md",
5555
"Gradient Accumulation" => "gradient_accumulation.md",
5656
"Usage in AD" => "use_in_ad_system.md",
57+
"Converting ZygoteRules" => "converting_zygoterules.md",
5758
"Design" => [
5859
"Changing the Primal" => "design/changing_the_primal.md",
5960
"Many Differential Types" => "design/many_differentials.md",

docs/src/FAQ.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ To the best of our knowledge no Julia AD system, with support for the definition
4040
At some point in the future ChainRules may support these. Maybe.
4141

4242

43-
## What is the difference between `ZeroTangent` and `NoTangent` ?
43+
## [What is the difference between `ZeroTangent` and `NoTangent` ?](@id faq_abstract_zero)
4444
`ZeroTangent` and `NoTangent` act almost exactly the same in practice: they result in no change whenever added to anything.
4545
Odds are if you write a rule that returns the wrong one everything will just work fine.
4646
We provide both to allow for clearer writing of rules, and easier debugging.

docs/src/complex.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# How do chain rules work for complex functions?
1+
# [How do chain rules work for complex functions?](@id complexfunctions)
22

33
ChainRules follows the convention that `frule` applied to a function ``f(x + i y) = u(x,y) + i v(x,y)`` with perturbation ``\Delta x + i \Delta y`` returns the value and
44
```math

docs/src/converting_zygoterules.md

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Converting ZygoteRules.@adjoint to `rrule`s
2+
3+
[ZygoteRules.jl](https://github.com/FluxML/ZygoteRules.jl) is a legacy package similar to ChainRules but supporting [Zygote.jl](https://github.com/FluxML/Zygote.jl) only.
4+
5+
If you have some rules written with ZygoteRules it is a good idea to upgrade them to use ChainRules instead.
6+
Zygote will still be able to use them, but so will other AD systems,
7+
and you will get access to some more advanced features.
8+
(Which Zygote may or may not then ignore).
9+
10+
## Example
11+
Consider the function
12+
```julia
13+
struct Foo
14+
a::Float64,
15+
b::Float64
16+
end
17+
18+
f(x, y::Foo, z) = 2*x + y.a
19+
```
20+
21+
### ZygoteRules
22+
```julia
23+
@adjoint function f(x, y::Foo, z)
24+
f_pullback(Ω̄) = (2Ω̄, NamedTuple(;a=Ω̄, b=nothing), nothing)
25+
return f(x, y, z), f_pullback
26+
end
27+
```
28+
29+
### ChainRules
30+
```julia
31+
function rrule(::typeof(f), x, y::Foo, z)
32+
f_pullback(Ω̄) = (NoFields(), 2Ω̄, Tangent{Foo}(;a=Ω̄), ZeroTangent())
33+
return f(x, y, z), f_pullback
34+
end
35+
```
36+
37+
## Write as a `rrule(::typeof(f), ...)`
38+
No magic macro here, `rrule` is the function that it is.
39+
The function it is the rule for is the first argument, or second argument if you need to take a `[RuleConfig`](@ref).
40+
41+
## Include the derivative with respect to the function object itself
42+
The `ZygoteRules.@adjoint` macro automagically[^1] inserts a extra `nothing` in the return for the function it generates to represent the derivative of output with respect to the function object.
43+
ChainRules as a philosophy avoids magic as much as possible, and thus require you to return it explicitly.
44+
If it is a plain function (like `typeof(sin)`), then the differential will be [`NoTangent`](@ref).
45+
46+
47+
[^1]: unless you write it in functor form (i.e. `@adjoint (f::MyType)(args...)=...`), in that case like for `rrule` you need to include it explictly.
48+
49+
## Tangent Type changes
50+
ChainRules uses tangent types that must represent vector spaces (i.e. tangent spaces).
51+
They need to have things like `+` defined on them.
52+
ZygoteRules takes a more adhoc approach to this.
53+
54+
### `nothing` becomes an `AbstractZero`
55+
ZygoteRules uses nothing to represent some sense of zero, in a primal type agnostic way.
56+
There are many senses of zero.
57+
ChainRules represents two of them, as subtypes of [`AbstractZero`](@ref).
58+
59+
[`ZeroTangent`](@ref) for the case that there is no relationship between the primal output and the primal input.
60+
[`NoTangent`](@ref) for the case that conceptually the tangent space don't exist.
61+
e.g. what is the Tangent to a String or an index: those can't be perturbed.
62+
63+
See [FAQ on difference between `ZeroTangent` and `NoTangent`](@ref faq_abstract_zero).
64+
At the end of the day it doesn't matter too much if you get them wrong.
65+
`NoTangent` and `ZeroTangent` more or less act identically.
66+
67+
### `Tuple`s and `NamedTuple`s become `Tangent{T}`s
68+
Zygote uses `Tuple`s and `NamedTuple`s to represent the structural tangents for `Tuple`s and `struct`s respectively.
69+
ChainRules core provides a generic `Tangent{T}`(@ref Tangent) to represent the structural tangent of a primal type `T`.
70+
It takes positional arguments if representing tangent for a `Tuple`.
71+
Or keyword argument to represent the tangent for a `struct`.
72+
When representing a `struct` you only need to list the nonzero fields -- any not given are implicit considered to be [`ZeroTangent`](@ref).
73+
74+
When we say structural tangent we mean tangent types that are based only on the structure of the primal.
75+
This is in contrast to a natural tangent which captures some knowledge based on what the primal type represents.
76+
(E.g. for arrays a natural tangent is often the same kind of array).
77+
For more details see the the [design docs on the many tangent types](@ref manytypes)
78+
79+
80+
## Calling back into AD (`ZygoteRules.pullback`)
81+
Rules that need to call back into the AD system, e.g, for higher order functions like `map(f, xs)`, need to be changed.
82+
In `ZygoteRules` you can use `ZygoteRules.pullback` or `ZygoteRules._pullback`, which will always result in calling into Zygote.
83+
Since ChainRules is AD agnostic, you can't do that.
84+
Instead you use a [`RuleConfig`](@ref) to specify requirements of an AD system e.g `::RuleConfig{>:CanReverseMode}` work for Zygote,
85+
and then use [`rrule_via_ad`](@ref).
86+
87+
See the [docs on calling back into AD](@ref config) for more details.
88+
89+
## Consider adding some thunks
90+
91+
A feature ChainRulesCore offers that ZygoteRules doesn't is support for thunks.
92+
Where work is delayed until it is needed, and avoided if it never is.
93+
See docs on [`@thunk`](@ref), [`Thunk`](@ref), [`InplaceableThunk`](@ref).
94+
95+
You don't have to though.
96+
It is easy to go overboard with using thunks.
97+
98+
## Testing Changes
99+
100+
One of the advantages of using ChainRules is that you can easily and robustly test it with [ChainRulesTestUtils.jl](https://juliadiff.org/ChainRulesTestUtils.jl/stable/).
101+
This both uses finite differencing to test accuracy of derivative, as well as checking the correctness of the API.
102+
It should catch anything you might have gotten wrong referred to in this page.
103+
104+
The test for the above example is `test_rrule(f, 2.5, Foo(9.9, 7.2), 31.0)`.
105+
You can see it looks a lot like an example call to `rrule`, just with the prefix `test_` added to the start.
106+
107+
## `@nograd` becomes `@non_differentiable`
108+
Probably more or less with no changes.
109+
[`@non_differentiable`](@ref) also lets you specify a signature.
110+
111+
## No such thing a `literal_getproperty`
112+
That is just `getproperty`, it takes `Symbol`.
113+
It should constant-fold.
114+
It likely doesn't though as Zygote doesn't play nice with the optimizer.
115+
116+
## Take embedded spaces and types seriously
117+
Traditionally Zygote has taken a very laissez faire attitude towards types and mathematical spaces.
118+
Sometimes treating `Real`s as embedded in the `Complex` plane; some time not.
119+
Sometimes treating sparse and structuredly-sparse matrix as embedded in the space of dense matrixes.
120+
Writing rules that apply to any `Array{T}` which perhaps are only applicable for `Array{<:Real}` and not so much for `Array{Quaternion}`.
121+
Traditionally ChainRules takes a much more considered approach.
122+
123+
See for example our [docs on how to handle complex numbers](@ref complexfunctions) correctly.
124+
(The outcome of several long long long discussions with a number of expert in our community)
125+
126+
Now, I am not here to tell you what to do in your package, but this is a good time to reconsider how seriously you take these things in the rules you are converting.
127+
128+
## What if I miss something
129+
130+
It is not great, but it probably OK.
131+
Zygote's ChainRules interface is fairly forgiving.
132+
Other AD systems may not be.
133+
If you test with [ChainRulesTestUtils.jl](https://juliadiff.org/ChainRulesTestUtils.jl/stable/) then you can be confident that you didn't miss anything.

docs/src/design/many_differentials.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Design Notes: The many-to-many relationship between differential types and primal types.
1+
# [Design Notes: The many-to-many relationship between differential types and primal types](@id manytypes)
22

33
ChainRules has a system where one primal type (the type having its derivative taken) can have multiple possible differential types (the type of the derivative); and where one differential type can correspond to multiple primal types.
44
This is in-contrast to the Swift AD efforts, which has one differential type per primal type (Swift uses the term associated tangent type, rather than differential type).

0 commit comments

Comments
 (0)