|
| 1 | +# Opting out of rules |
| 2 | + |
| 3 | +It is common to define rules fairly generically, |
| 4 | +often matching (or exceeding) how generic the original primal method is. |
| 5 | +Sometimes this is not the correct behaviour. |
| 6 | +Sometimes the AD can do better than this human-defined rule. |
| 7 | +If this is generally the case, then we should not have the rule defined at all. |
| 8 | +But if it is only the case for a particular set of types, then we want to opt out for just those types. |
| 9 | +This is done with the [`@opt_out`](@ref) macro. |
| 10 | + |
| 11 | +Consider that one might have an `rrule` for `sum` (the following is simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself): |
| 12 | +```julia |
| 13 | +function rrule(::typeof(sum), x::AbstractArray{<:Number}; dims=:) |
| 14 | +    y = sum(x; dims=dims) |
| 15 | +    project = ProjectTo(x) |
| 16 | +    function sum_pullback(ȳ) |
| 17 | +        # broadcasting the two works out the size no matter what `dims` is |
| 18 | +        # project makes sure we stay in the same vector subspace as `x` |
| 19 | +        # (e.g. no off-diagonal entries get put into a Diagonal) |
| 20 | +        x̄ = project(broadcast(last∘tuple, x, ȳ)) |
| 21 | +        return (NoTangent(), x̄) |
| 22 | +    end |
| 23 | +    return y, sum_pullback |
| 24 | +end |
| 25 | +``` |
| 26 | + |
| 27 | +That is a fairly reasonable `rrule` for the vast majority of cases. |
| 28 | + |
| 29 | +You might have a custom array type for which you could write a faster rule. |
| 30 | +For example, the pullback for summing a `SkewSymmetric` matrix can be optimized to basically be `Diagonal(fill(ȳ, size(x,1)))`. |
| 31 | +To do that, you can indeed write another more specific [`rrule`](@ref). |
| 32 | +But another case is where the AD system itself would generate more optimized code than the rule. |
| 33 | + |
| 34 | +For example, a [`NamedDimsArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type. |
| 35 | +Its `sum` method basically just calls `sum` on its parent. |
| 36 | +It is entirely conceivable[^1] that the AD system can do better than our `rrule` here, |
| 37 | +for example by avoiding the overhead of [`project`ing](@ref ProjectTo). |
| 38 | + |
| 39 | +To opt out of using the `rrule` and allow the AD system to do its own thing, we use the |
| 40 | +[`@opt_out`](@ref) macro to say that the rule should not be used for `sum` on this type. |
| 41 | + |
| 42 | +```julia |
| 43 | +@opt_out rrule(::typeof(sum), ::NamedDimsArray) |
| 44 | +``` |
| 45 | + |
| 46 | +We could even opt out for all one-argument functions: |
| 47 | +```julia |
| 48 | +@opt_out rrule(::Any, ::NamedDimsArray) |
| 49 | +``` |
| 50 | +Though this is likely to cause some method ambiguities. |
| 51 | + |
| 52 | +The same can be done for `frule`s, using `@opt_out frule`. |
| 53 | +Opting out also works for signatures that pass in a [`RuleConfig`](@ref config). |
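
The effect of opting out can be seen in a small self-contained sketch. Everything named here (`Wrapped`, `total`, and the rule for `total`) is hypothetical and exists only for illustration; `Wrapped` merely stands in for a thin wrapper such as `NamedDimsArray`. The only ChainRulesCore names assumed are `rrule`, `NoTangent`, `unthunk` and `@opt_out`.

```julia
using ChainRulesCore

# Hypothetical wrapper type, standing in for something like `NamedDimsArray`.
struct Wrapped <: AbstractVector{Float64}
    parent::Vector{Float64}
end
Base.size(w::Wrapped) = size(w.parent)
Base.getindex(w::Wrapped, i::Int) = w.parent[i]

# Hypothetical primal function with a generic hand-written rule.
total(x) = sum(x)
function ChainRulesCore.rrule(::typeof(total), x::AbstractVector{<:Number})
    total_pullback(ȳ) = (NoTangent(), fill(unthunk(ȳ), length(x)))
    return total(x), total_pullback
end

# Opt out of that rule for the wrapper type only.
@opt_out rrule(::typeof(total), ::Wrapped)

generic_result = rrule(total, [1.0, 2.0, 3.0])        # (value, pullback) from the rule
opted_out_result = rrule(total, Wrapped([1.0, 2.0]))  # hits the opted-out method
```

With these definitions, `generic_result` should be the usual `(value, pullback)` tuple, while `opted_out_result` should be `nothing`, which signals to the AD system that it should differentiate `total` itself for `Wrapped` inputs.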
| 54 | + |
| 55 | + |
| 56 | +## How to support this (for AD implementers) |
| 57 | + |
| 58 | +We provide two ways to know that a rule has been opted out of. |
| 59 | + |
| 60 | +## `rrule` / `frule` returns `nothing` |
| 61 | + |
| 62 | +`@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`. |
| 63 | + |
| 64 | +If you are in a position to generate code in response to values returned by function calls, then you can do something like: |
| 65 | +```julia |
| 66 | +res = rrule(f, xs...) |
| 67 | +if res === nothing |
| 68 | +    y, pullback = perform_ad_via_decomposition(f, xs...)  # do AD without hitting the rrule |
| 69 | +else |
| 70 | + y, pullback = res |
| 71 | +end |
| 72 | +``` |
| 73 | +The Julia compiler will specialize based on the inferred return type of `rrule`, and so can remove that branch. |
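
A minimal runnable sketch of this branch-on-`nothing` pattern follows. It is illustrative only: `double`, `triple`, `value_and_pullback` and the toy `perform_ad_via_decomposition` are hypothetical names, and the "AD system" here only knows how to handle `double`. The sketch relies on ChainRulesCore's fallback `rrule` returning `nothing` when no rule is defined, which is also what an opted-out signature returns.

```julia
using ChainRulesCore

# Hypothetical stand-in for an AD system's own machinery. A real system would
# decompose `f` into simpler operations; this toy only handles `double`.
double(x) = 2 * x
function perform_ad_via_decomposition(::typeof(double), x)
    double_pullback(ȳ) = (NoTangent(), 2 * ȳ)
    return double(x), double_pullback
end

# Hypothetical function that does have a hand-written rule.
triple(x) = 3 * x
function ChainRulesCore.rrule(::typeof(triple), x::Number)
    triple_pullback(ȳ) = (NoTangent(), 3 * ȳ)
    return triple(x), triple_pullback
end

# Use the rule when there is one; otherwise fall back to the AD system.
function value_and_pullback(f, xs...)
    res = rrule(f, xs...)
    if res === nothing
        return perform_ad_via_decomposition(f, xs...)
    end
    return res
end

y_rule, pb_rule = value_and_pullback(triple, 2.0)  # goes through the rrule
y_ad, pb_ad = value_and_pullback(double, 2.0)      # rrule returned `nothing`
```

Because the check is an ordinary `=== nothing` comparison on a value whose type can be inferred, the same function handles both missing rules and opted-out rules without any extra bookkeeping.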
| 74 | + |
| 75 | +## `no_rrule` / `no_frule` has a method |
| 76 | + |
| 77 | +`@opt_out` also defines a method for [`ChainRulesCore.no_frule`](@ref) or [`ChainRulesCore.no_rrule`](@ref). |
| 78 | +What this method does is unimportant; what matters is its method table. |
| 79 | +The simplest thing you can do with this is to not support opting out. |
| 80 | +To do this, filter out all methods from the `rrule`/`frule` method table that also occur in the `no_rrule`/`no_frule` table. |
| 81 | +This avoids ever hitting an `rrule`/`frule` that returns `nothing`, which would make your library error. |
| 82 | +This is easily done, though it does mean ignoring the user's stated desire to opt out of the rule. |
| 83 | + |
| 84 | +A more complex option is to use this to generate code that triggers your AD. |
| 85 | +If, for a given signature, there is a more specific method in the `no_rrule`/`no_frule` method table than the one that would be hit from the `rrule`/`frule` table |
| 86 | +(excluding the one that exactly matches, which will return `nothing`), then you know that the rule should not be used. |
| 87 | +You can, likely by looking at the primal method table, work out which method you would have hit if the rule had not been defined, |
| 88 | +and then `invoke` it. |
| 89 | + |
| 90 | + |
| 91 | + |
| 92 | +[^1]: It is also possible that this is not the case. Benchmark your real use cases. |