|
| 1 | +# Converting ZygoteRules.@adjoint to `rrule`s |
| 2 | + |
| 3 | +[ZygoteRules.jl](https://github.com/FluxML/ZygoteRules.jl) is a legacy package similar to ChainRulesCore but supporting [Zygote.jl](https://github.com/FluxML/Zygote.jl) only. |
| 4 | + |
| 5 | +If you have some rules written with ZygoteRules it is a good idea to upgrade them to use ChainRules instead. |
| 6 | +Zygote will still be able to use them, but so will other AD systems, |
| 7 | +and you will get access to some more advanced features. |
| 8 | +Some of these features are currently ignored by Zygote, but could be supported in the future. |
| 9 | + |
| 10 | +## Example |
| 11 | +Consider the function |
| 12 | +```julia |
| 13 | +struct Foo |
| 14 | + a::Float64, |
| 15 | + b::Float64 |
| 16 | +end |
| 17 | + |
| 18 | +f(x, y::Foo, z) = 2*x + y.a |
| 19 | +``` |
| 20 | + |
| 21 | +### ZygoteRules |
| 22 | +```julia |
| 23 | +@adjoint function f(x, y::Foo, z) |
| 24 | + f_pullback(Ω̄) = (2Ω̄, NamedTuple(;a=Ω̄, b=nothing), nothing) |
| 25 | + return f(x, y, z), f_pullback |
| 26 | +end |
| 27 | +``` |
| 28 | + |
| 29 | +### ChainRules |
| 30 | +```julia |
| 31 | +function rrule(::typeof(f), x, y::Foo, z) |
| 32 | + f_pullback(Ω̄) = (NoTangent(), 2Ω̄, Tangent{Foo}(;a=Ω̄), ZeroTangent()) |
| 33 | + return f(x, y, z), f_pullback |
| 34 | +end |
| 35 | +``` |
| 36 | + |
| 37 | +## Write as a `rrule(::typeof(f), ...)` |
| 38 | +No magic macro here, `rrule` is the function that it is. |
| 39 | +The function it is the rule for is the first argument, or second argument if you need to take a [`RuleConfig`](@ref). |
| 40 | + |
| 41 | +Note that when writing the rule for constructor you will need to use `::Type{Foo}`, not `typeof(Foo)`. |
| 42 | +See docs on [Constructors](@ref). |
| 43 | + |
| 44 | +## Include the derivative with respect to the function object itself |
| 45 | +The `ZygoteRules.@adjoint` macro automagically[^1] inserts an extra `nothing` in the return for the function it generates to represent the derivative of output with respect to the function object. |
| 46 | +ChainRules as a philosophy avoids magic as much as possible, and thus require you to return it explicitly. |
| 47 | +If it is a plain function (like `typeof(sin)`), then the differential will be [`NoTangent`](@ref). |
| 48 | + |
| 49 | + |
| 50 | +[^1]: unless you write it in functor form (i.e. `@adjoint (f::MyType)(args...)=...`), in that case like for `rrule` you need to include it explictly. |
| 51 | + |
| 52 | +## Tangent Type changes |
| 53 | +ChainRules uses tangent types that must represent vector spaces (i.e. tangent spaces). |
| 54 | +They need to have things like `+` defined on them. |
| 55 | +ZygoteRules takes a more adhoc approach to this. |
| 56 | + |
| 57 | +### `nothing` becomes an `AbstractZero` |
| 58 | +ZygoteRules uses nothing to represent some sense of zero, in a primal type agnostic way. |
| 59 | +There are many senses of zero. |
| 60 | +ChainRules represents two of them, as subtypes of [`AbstractZero`](@ref). |
| 61 | + |
| 62 | +[`ZeroTangent`](@ref) for the case that there is no relationship between the primal output and the primal input. |
| 63 | +[`NoTangent`](@ref) for the case where conceptually the tangent space doesn't exist. |
| 64 | +e.g. what is the Tangent to a String or an index: those can't be perturbed. |
| 65 | + |
| 66 | +See [FAQ on the difference between `ZeroTangent` and `NoTangent`](@ref faq_abstract_zero). |
| 67 | +At the end of the day it doesn't matter too much if you get them wrong. |
| 68 | +`NoTangent` and `ZeroTangent` more or less act identically. |
| 69 | + |
| 70 | +### `Tuple`s and `NamedTuple`s become `Tangent{T}`s |
| 71 | +Zygote uses `Tuple`s and `NamedTuple`s to represent the structural tangents for `Tuple`s and `struct`s respectively. |
| 72 | +ChainRules core provides a generic [`Tangent{T}`](@ref Tangent) to represent the structural tangent of a primal type `T`. |
| 73 | +It takes positional arguments if representing tangent for a `Tuple`. |
| 74 | +Or keyword argument to represent the tangent for a `struct` or a `NamedTuple`. |
| 75 | +When representing a `struct` you only need to list the nonzero fields -- any not given are implicit considered to be [`ZeroTangent`](@ref). |
| 76 | + |
| 77 | +When we say structural tangent we mean tangent types that are based only on the structure of the primal. |
| 78 | +This is in contrast to a natural tangent which captures some knowledge based on what the primal type represents. |
| 79 | +(E.g. for arrays a natural tangent is often the same kind of array). |
| 80 | +For more details see the the [design docs on the many tangent types](@ref manytypes) |
| 81 | + |
| 82 | + |
| 83 | +## Calling back into AD (`ZygoteRules.pullback`) |
| 84 | +Rules that need to call back into the AD system, e.g, for higher order functions like `map(f, xs)`, need to be changed. |
| 85 | +In `ZygoteRules` you can use `ZygoteRules.pullback` or `ZygoteRules._pullback`, which will always result in calling into Zygote. |
| 86 | +Since ChainRules is AD agnostic, you can't do that. |
| 87 | +Instead you use a [`RuleConfig`](@ref) to specify requirements of an AD system e.g `::RuleConfig{>:HasReverseMode}` work for Zygote, |
| 88 | +and then use [`rrule_via_ad`](@ref). |
| 89 | + |
| 90 | +See the [docs on calling back into AD](@ref config) for more details. |
| 91 | + |
| 92 | +## Consider adding some thunks |
| 93 | + |
| 94 | +A feature ChainRulesCore offers that ZygoteRules doesn't is support for thunks. |
| 95 | +Thunks delay work until it is needed, and avoid it if it never is. |
| 96 | +See docs on [`@thunk`](@ref), [`Thunk`](@ref), [`InplaceableThunk`](@ref). |
| 97 | + |
| 98 | +You don't have to use thunks, though. |
| 99 | +It is easy to go overboard with using thunks. |
| 100 | + |
| 101 | +## Testing Changes |
| 102 | + |
| 103 | +One of the advantages of using ChainRules is that you can easily and robustly test your rules with [ChainRulesTestUtils.jl](https://juliadiff.org/ChainRulesTestUtils.jl/stable/). |
| 104 | +This uses finite differencing to test the accuracy of derivative, as well as checks the correctness of the API. |
| 105 | +It should catch anything you might have gotten wrong referred to in this page. |
| 106 | + |
| 107 | +The test for the above example is `test_rrule(f, 2.5, Foo(9.9, 7.2), 31.0)`. |
| 108 | +You can see it looks a lot like an example call to `rrule`, just with the prefix `test_` added to the start. |
| 109 | + |
| 110 | +## `@nograd` becomes `@non_differentiable` |
| 111 | +Probably more or less with no changes. |
| 112 | +[`@non_differentiable`](@ref) also lets you specify a signature in case you want to restrict non-differentiability to a certain subset of argument types. |
| 113 | + |
| 114 | +## No such thing a `literal_getproperty` |
| 115 | +That is just `getproperty`, it takes `Symbol`. |
| 116 | +It should constant-fold. |
| 117 | +It likely doesn't though as Zygote doesn't play nice with the optimizer. |
| 118 | + |
| 119 | +## Take embedded spaces and types seriously |
| 120 | +Traditionally Zygote has taken a very laissez-faire attitude towards types and mathematical spaces. |
| 121 | +Sometimes treating `Real`s as embedded in the `Complex` plane; sometimes not. |
| 122 | +Sometimes treating sparse and structuredly-sparse matrix as embedded in the space of dense matrices. |
| 123 | +Writing rules that apply to any `Array{T}` which perhaps are only applicable for `Array{<:Real}` and not so much for `Array{Quaternion}`. |
| 124 | +Traditionally ChainRules takes a much more considered approach. |
| 125 | + |
| 126 | +See for example our [docs on how to handle complex numbers](@ref complexfunctions) correctly. |
| 127 | +(The outcome of several long long long discussions with a number of experts in our community) |
| 128 | + |
| 129 | +Now, I am not here to tell you what to do in your package, but this is a good time to reconsider how seriously you take these things in the rules you are converting. |
| 130 | + |
| 131 | +## What if I miss something |
| 132 | + |
| 133 | +It is not great, but it probably OK. |
| 134 | +Zygote's ChainRules interface is fairly forgiving. |
| 135 | +Other AD systems may not be. |
| 136 | +If you test with [ChainRulesTestUtils.jl](https://juliadiff.org/ChainRulesTestUtils.jl/stable/) then you can be confident that you didn't miss anything. |
0 commit comments