
Commit 66bd0a1

Merge pull request #378 from JuliaDiff/ox/convertingZygoteRules
Write docs on converting ZygoteRules
2 parents 4b1a2f6 + 54d0972 commit 66bd0a1

7 files changed

Lines changed: 144 additions & 6 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ docs/build
77
docs/site
88
docs/src/assets/chainrules.css
99
docs/src/assets/indigo.css
10+
.vscode/settings.json

docs/Manifest.toml

Lines changed: 3 additions & 3 deletions
@@ -13,13 +13,13 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
1414
path = ".."
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.10.6"
16+
version = "0.10.7"
1717

1818
[[Compat]]
1919
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
20-
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
20+
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
2121
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
22-
version = "3.30.0"
22+
version = "3.31.0"
2323

2424
[[Dates]]
2525
deps = ["Printf"]

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ makedocs(
5454
"Debug Mode" => "debug_mode.md",
5555
"Gradient Accumulation" => "gradient_accumulation.md",
5656
"Usage in AD" => "use_in_ad_system.md",
57+
"Converting ZygoteRules" => "converting_zygoterules.md",
5758
"Design" => [
5859
"Changing the Primal" => "design/changing_the_primal.md",
5960
"Many Differential Types" => "design/many_differentials.md",

docs/src/FAQ.md

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ To the best of our knowledge no Julia AD system, with support for the definition
4040
At some point in the future ChainRules may support these. Maybe.
4141

4242

43-
## What is the difference between `ZeroTangent` and `NoTangent` ?
43+
## [What is the difference between `ZeroTangent` and `NoTangent` ?](@id faq_abstract_zero)
4444
`ZeroTangent` and `NoTangent` act almost exactly the same in practice: they result in no change whenever added to anything.
4545
Odds are if you write a rule that returns the wrong one everything will just work fine.
4646
We provide both to allow for clearer writing of rules, and easier debugging.

docs/src/complex.md

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
1-
# How do chain rules work for complex functions?
1+
# [How do chain rules work for complex functions?](@id complexfunctions)
22

33
ChainRules follows the convention that `frule` applied to a function ``f(x + i y) = u(x,y) + i v(x,y)`` with perturbation ``\Delta x + i \Delta y`` returns the value and
44
```math

docs/src/converting_zygoterules.md

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
1+
# Converting ZygoteRules.@adjoint to `rrule`s
2+
3+
[ZygoteRules.jl](https://github.com/FluxML/ZygoteRules.jl) is a legacy package similar to ChainRulesCore but supporting [Zygote.jl](https://github.com/FluxML/Zygote.jl) only.
4+
5+
If you have rules written with ZygoteRules, it is a good idea to upgrade them to use ChainRules instead.
6+
Zygote will still be able to use them, but so will other AD systems,
7+
and you will get access to some more advanced features.
8+
Some of these features are currently ignored by Zygote, but could be supported in the future.
9+
10+
## Example
11+
Consider the function
12+
```julia
13+
struct Foo
14+
a::Float64
15+
b::Float64
16+
end
17+
18+
f(x, y::Foo, z) = 2*x + y.a
19+
```
20+
21+
### ZygoteRules
22+
```julia
23+
@adjoint function f(x, y::Foo, z)
24+
f_pullback(Ω̄) = (2Ω̄, (a=Ω̄, b=nothing), nothing)
25+
return f(x, y, z), f_pullback
26+
end
27+
```
28+
29+
### ChainRules
30+
```julia
31+
function rrule(::typeof(f), x, y::Foo, z)
32+
f_pullback(Ω̄) = (NoTangent(), 2Ω̄, Tangent{Foo}(;a=Ω̄), ZeroTangent())
33+
return f(x, y, z), f_pullback
34+
end
35+
```
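To see the converted rule in action, here is a minimal sketch that defines it and calls the pullback by hand. It assumes ChainRulesCore is installed. The method is written as `ChainRulesCore.rrule` so that it extends the package's function instead of creating a new one, and the input values are arbitrary.

```julia
using ChainRulesCore

struct Foo
    a::Float64
    b::Float64
end

f(x, y::Foo, z) = 2 * x + y.a

# Qualify the name so this adds a method to ChainRulesCore's `rrule`.
function ChainRulesCore.rrule(::typeof(f), x, y::Foo, z)
    f_pullback(Ω̄) = (NoTangent(), 2Ω̄, Tangent{Foo}(; a=Ω̄), ZeroTangent())
    return f(x, y, z), f_pullback
end

# Call the rule directly, then seed the pullback with 1.0.
Ω, pb = rrule(f, 1.0, Foo(3.0, 4.0), 5.0)
df, dx, dy, dz = pb(1.0)
# df is the tangent for `f` itself, dx for `x`, dy for the `Foo`, dz for `z`.
```

An AD system such as Zygote would make the same two calls internally; calling them by hand is mainly useful for a quick sanity check.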
36+
37+
## Write as a `rrule(::typeof(f), ...)`
38+
There is no magic macro here: `rrule` is an ordinary function that you add methods to.
39+
The function that the rule applies to is the first argument, or the second argument if the rule needs to take a [`RuleConfig`](@ref).
40+
41+
Note that when writing the rule for a constructor you will need to use `::Type{Foo}`, not `typeof(Foo)`.
42+
See docs on [Constructors](@ref).
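As a rough illustration of the constructor case, the sketch below writes a rule for a hypothetical `Point` struct (the struct and its field names are assumptions made for this example, not part of any package). It dispatches on `::Type{Point}` because `typeof(Point)` is just `DataType`.

```julia
using ChainRulesCore

# `Point` is a hypothetical struct used only for this illustration.
struct Point
    x::Float64
    y::Float64
end

# The constructor is the type itself, so dispatch on `Type{Point}`.
# `typeof(Point)` is `DataType`, which would match any `DataType`.
function ChainRulesCore.rrule(::Type{Point}, x, y)
    # Assumes the incoming tangent is a structural `Tangent{Point}`.
    Point_pullback(Δp) = (NoTangent(), Δp.x, Δp.y)
    return Point(x, y), Point_pullback
end

p, pb = rrule(Point, 1.0, 2.0)
Δ = pb(Tangent{Point}(; x=3.0, y=4.0))
```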
43+
44+
## Include the derivative with respect to the function object itself
45+
The `ZygoteRules.@adjoint` macro automagically[^1] inserts an extra `nothing` in the return for the function it generates to represent the derivative of the output with respect to the function object.
46+
ChainRules as a philosophy avoids magic as much as possible, and thus requires you to return it explicitly.
47+
If it is a plain function (like `typeof(sin)`), then the differential will be [`NoTangent`](@ref).
48+
49+
50+
[^1]: unless you write it in functor form (i.e. `@adjoint (f::MyType)(args...)=...`); in that case, as for `rrule`, you need to include it explicitly.
51+
52+
## Tangent Type changes
53+
ChainRules uses tangent types that must represent vector spaces (i.e. tangent spaces).
54+
They need to have things like `+` defined on them.
55+
ZygoteRules takes a more ad hoc approach to this.
56+
57+
### `nothing` becomes an `AbstractZero`
58+
ZygoteRules uses `nothing` to represent some sense of zero, in a primal-type-agnostic way.
59+
There are many senses of zero.
60+
ChainRules represents two of them, as subtypes of [`AbstractZero`](@ref).
61+
62+
[`ZeroTangent`](@ref) for the case that there is no relationship between the primal output and the primal input.
63+
[`NoTangent`](@ref) for the case where conceptually the tangent space doesn't exist.
64+
For example, there is no meaningful tangent to a `String` or an index, because those can't be perturbed.
65+
66+
See [FAQ on the difference between `ZeroTangent` and `NoTangent`](@ref faq_abstract_zero).
67+
At the end of the day it doesn't matter too much if you get them wrong.
68+
`NoTangent` and `ZeroTangent` more or less act identically.
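A small sketch of what "act identically" means in practice, assuming ChainRulesCore is installed: both zeros are subtypes of `AbstractZero`, and adding either one to a value leaves that value unchanged.

```julia
using ChainRulesCore

# Both kinds of zero share the `AbstractZero` supertype.
both_are_zero = ZeroTangent() isa AbstractZero && NoTangent() isa AbstractZero

# Adding either one returns the other operand unchanged.
a = ZeroTangent() + 1.5
b = NoTangent() + 1.5

# Scaling a zero tangent keeps it a zero tangent.
c = ZeroTangent() * 2.0
```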
69+
70+
### `Tuple`s and `NamedTuple`s become `Tangent{T}`s
71+
Zygote uses `Tuple`s and `NamedTuple`s to represent the structural tangents for `Tuple`s and `struct`s respectively.
72+
ChainRulesCore provides a generic [`Tangent{T}`](@ref Tangent) to represent the structural tangent of a primal type `T`.
73+
It takes positional arguments when representing the tangent for a `Tuple`,
74+
or keyword arguments when representing the tangent for a `struct` or a `NamedTuple`.
75+
When representing a `struct` you only need to list the nonzero fields -- any not given are implicitly considered to be [`ZeroTangent`](@ref).
76+
77+
When we say structural tangent we mean tangent types that are based only on the structure of the primal.
78+
This is in contrast to a natural tangent which captures some knowledge based on what the primal type represents.
79+
(E.g. for arrays a natural tangent is often the same kind of array).
80+
For more details see the [design docs on the many tangent types](@ref manytypes).
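The following sketch shows both ways of constructing a structural tangent, assuming ChainRulesCore is installed. The `Foo` struct mirrors the one from the example earlier on this page, and the numeric values are arbitrary.

```julia
using ChainRulesCore

struct Foo
    a::Float64
    b::Float64
end

# Tuple primal: positional arguments.
t_tuple = Tangent{Tuple{Float64,Float64}}(1.0, 2.0)

# Struct primal: keyword arguments. The omitted field `b` reads back as a ZeroTangent.
t_foo = Tangent{Foo}(; a=1.5)

# Tangents live in a vector space, so they can be added field by field.
t_sum = Tangent{Foo}(; a=1.0, b=2.0) + Tangent{Foo}(; a=3.0, b=4.0)
```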
81+
82+
83+
## Calling back into AD (`ZygoteRules.pullback`)
84+
Rules that need to call back into the AD system, e.g. for higher-order functions like `map(f, xs)`, need to be changed.
85+
In `ZygoteRules` you can use `ZygoteRules.pullback` or `ZygoteRules._pullback`, which will always result in calling into Zygote.
86+
Since ChainRules is AD agnostic, you can't do that.
87+
Instead you use a [`RuleConfig`](@ref) to specify the requirements of an AD system, e.g. `::RuleConfig{>:HasReverseMode}` works for Zygote,
88+
and then use [`rrule_via_ad`](@ref).
89+
90+
See the [docs on calling back into AD](@ref config) for more details.
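As a hedged sketch of the pattern, the rule below differentiates a hypothetical higher-order function `call_it(f, x)` by asking the AD system for the pullback of `f`. To keep the example runnable without Zygote, it uses a stand-in `ToyConfig` whose `rrule_via_ad` simply looks up an existing `rrule`; `ToyConfig`, `call_it`, and `double` are all assumptions made for this illustration. A real AD system supplies its own config type.

```julia
using ChainRulesCore

# Hypothetical higher-order function.
call_it(f, x) = f(x)

# The config comes first; the rule only applies when the AD supports reverse mode.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(call_it), f, x
)
    y, f_pullback = rrule_via_ad(config, f, x)
    # f_pullback returns (tangent for f, tangent for x); prepend the tangent for call_it.
    call_it_pullback(Δy) = (NoTangent(), f_pullback(Δy)...)
    return y, call_it_pullback
end

# Stand-in config for this demo only: it "differentiates" by looking up an rrule.
struct ToyConfig <: RuleConfig{Union{HasReverseMode,NoForwardsMode}} end
ChainRulesCore.rrule_via_ad(::ToyConfig, f, args...) = rrule(f, args...)

double(x) = 2x
function ChainRulesCore.rrule(::typeof(double), x)
    double_pullback(Δy) = (NoTangent(), 2Δy)
    return double(x), double_pullback
end

y, pb = rrule(ToyConfig(), call_it, double, 3.0)
Δ = pb(1.0)
```

Because the rule never names a particular AD package, the same method should work with any system whose config advertises `HasReverseMode`.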
91+
92+
## Consider adding some thunks
93+
94+
A feature ChainRulesCore offers that ZygoteRules doesn't is support for thunks.
95+
Thunks delay work until it is needed, and avoid it if it never is.
96+
See docs on [`@thunk`](@ref), [`Thunk`](@ref), [`InplaceableThunk`](@ref).
97+
98+
You don't have to use thunks, though.
99+
It is easy to go overboard with using thunks.
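For a sense of where a thunk can pay off, here is a minimal sketch for a hypothetical `mymul` function (the name is an assumption for this example). Each input's tangent needs its own matrix product, so wrapping each in `@thunk` lets a caller that only wants one of them skip the other.

```julia
using ChainRulesCore

# `mymul` is a hypothetical function used only for this illustration.
mymul(A, B) = A * B

function ChainRulesCore.rrule(::typeof(mymul), A, B)
    function mymul_pullback(dY)
        # Each product is only computed when the caller unthunks it.
        dA = @thunk(dY * B')
        dB = @thunk(A' * dY)
        return NoTangent(), dA, dB
    end
    return mymul(A, B), mymul_pullback
end

A = [1.0 2.0; 3.0 4.0]
B = [5.0 6.0; 7.0 8.0]
Y, pb = rrule(mymul, A, B)
_, dA, dB = pb(ones(2, 2))

# Only now is the product for `A` actually computed; `dB` is never evaluated.
dA_value = unthunk(dA)
```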
100+
101+
## Testing Changes
102+
103+
One of the advantages of using ChainRules is that you can easily and robustly test your rules with [ChainRulesTestUtils.jl](https://juliadiff.org/ChainRulesTestUtils.jl/stable/).
104+
This uses finite differencing to test the accuracy of the derivative, and also checks the correctness of the API.
105+
It should catch any of the mistakes discussed on this page.
106+
107+
The test for the above example is `test_rrule(f, 2.5, Foo(9.9, 7.2), 31.0)`.
108+
You can see it looks a lot like an example call to `rrule`, just with the prefix `test_` added to the start.
109+
110+
## `@nograd` becomes `@non_differentiable`
111+
The conversion can probably be made with more or less no other changes.
112+
[`@non_differentiable`](@ref) also lets you specify a signature in case you want to restrict non-differentiability to a certain subset of argument types.
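A short sketch of the signature form, assuming ChainRulesCore is installed; `count_positive` is a hypothetical function invented for this example. The macro is given a call signature, here restricted to `AbstractVector` arguments, and it generates a rule whose pullback returns `NoTangent()` for every argument.

```julia
using ChainRulesCore

# Hypothetical function whose integer output cannot be differentiated.
count_positive(xs) = count(>(0), xs)

# Restrict non-differentiability to vector arguments via the signature.
@non_differentiable count_positive(xs::AbstractVector)

y, pb = rrule(count_positive, [1.0, -2.0, 3.0])
Δ = pb(1.0)
```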
113+
114+
## No such thing as `literal_getproperty`
115+
That is just `getproperty`; it takes a `Symbol`.
116+
It should constant-fold.
117+
It likely doesn't though as Zygote doesn't play nice with the optimizer.
118+
119+
## Take embedded spaces and types seriously
120+
Traditionally Zygote has taken a very laissez-faire attitude towards types and mathematical spaces.
121+
Sometimes treating `Real`s as embedded in the `Complex` plane; sometimes not.
122+
Sometimes treating sparse and structurally sparse matrices as embedded in the space of dense matrices.
123+
Writing rules that apply to any `Array{T}` when they are perhaps only applicable to `Array{<:Real}`, and not so much to `Array{Quaternion}`.
124+
Traditionally ChainRules takes a much more considered approach.
125+
126+
See for example our [docs on how to handle complex numbers](@ref complexfunctions) correctly.
127+
(This is the outcome of several long, long, long discussions with a number of experts in our community.)
128+
129+
Now, I am not here to tell you what to do in your package, but this is a good time to reconsider how seriously you take these things in the rules you are converting.
130+
131+
## What if I miss something?
132+
133+
It is not great, but it is probably OK.
134+
Zygote's ChainRules interface is fairly forgiving.
135+
Other AD systems may not be.
136+
If you test with [ChainRulesTestUtils.jl](https://juliadiff.org/ChainRulesTestUtils.jl/stable/) then you can be confident that you didn't miss anything.

docs/src/design/many_differentials.md

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
1-
# Design Notes: The many-to-many relationship between differential types and primal types.
1+
# [Design Notes: The many-to-many relationship between differential types and primal types](@id manytypes)
22

33
ChainRules has a system where one primal type (the type having its derivative taken) can have multiple possible differential types (the type of the derivative); and where one differential type can correspond to multiple primal types.
44
This is in contrast to the Swift AD efforts, which have one differential type per primal type (Swift uses the term associated tangent type, rather than differential type).
