108108
109109returns (in order) the correctly escaped:
110110 - `call` with out any type constraints
111- - `setup_stmts`: the content of `@setup` or `nothing ` if that is not provided,
111+ - `setup_stmts`: the content of `@setup` or `[] ` if that is not provided,
112112 - `inputs`: with all args having the constraints removed from call, or
113113 defaulting to `Number`
114114 - `partials`: which are all `Expr{:tuple,...}`
@@ -118,9 +118,9 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
118118 # Setup: normalizing input form etc
119119
120120 if Meta. isexpr (maybe_setup, :macrocall ) && maybe_setup. args[1 ] == Symbol (" @setup" )
121- setup_stmts = map ( esc, maybe_setup. args[3 : end ])
121+ setup_stmts = Any[ esc (ex) for ex in maybe_setup. args[3 : end ]]
122122 else
123- setup_stmts = ( nothing ,)
123+ setup_stmts = []
124124 partials = (maybe_setup, partials... )
125125 end
126126 @assert Meta. isexpr (call, :call )
@@ -185,10 +185,14 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials)
185185 # because this is a pull-back there is one per output of function
186186 Δs = _propagator_inputs (n_outputs)
187187
188+ # Make a projector for each argument
189+ projs, psetup = _make_projectors (call. args[2 : end ])
190+ append! (setup_stmts, psetup)
191+
188192 # 1 partial derivative per input
189193 pullback_returns = map (1 : n_inputs) do input_i
190194 ∂s = [partial. args[input_i] for partial in partials]
191- propagation_expr (Δs, ∂s, true )
195+ propagation_expr (Δs, ∂s, true , projs[input_i] )
192196 end
193197
194198 # Multi-output functions have pullbacks with a tuple input that will be destructured
@@ -215,14 +219,23 @@ end
215219" Declares properly hygenic inputs for propagation expressions"
216220_propagator_inputs (n) = [esc (gensym (Symbol (:Δ , i))) for i in 1 : n]
217221
222+ " given the variable names, escaped but without types, makes setup expressions for projection operators"
223+ function _make_projectors (xs)
224+ projs = map (x -> Symbol (:proj_ , x. args[1 ]), xs)
225+ setups = map ((x,p) -> :($ p = ProjectTo ($ x)), xs, projs)
226+ return projs, setups
227+ end
228+
218229"""
219- propagation_expr(Δs, ∂s, _conj = false)
230+ propagation_expr(Δs, ∂s, [ _conj= false, proj=identity] )
220231
221- Returns the expression for the propagation of
222- the input gradient `Δs` though the partials `∂s`.
223- Specify `_conj = true` to conjugate the partials.
232+ Returns the expression for the propagation of
233+ the input gradient `Δs` though the partials `∂s`.
234+ Specify `_conj = true` to conjugate the partials.
235+ Projector `proj` is a function that will be applied at the end;
236+ for `rrules` it is usually a `ProjectTo(x)`, for `frules` it is `identity`
224237"""
225- function propagation_expr (Δs, ∂s, _conj = false )
238+ function propagation_expr (Δs, ∂s, _conj= false , proj = identity )
226239 # This is basically Δs ⋅ ∂s
227240 _∂s = map (∂s) do ∂s_i
228241 if _conj
@@ -249,7 +262,7 @@ function propagation_expr(Δs, ∂s, _conj = false)
249262 :($ (_∂s[1 ]) * $ (Δs[1 ]))
250263 end
251264
252- return summed_∂_mul_Δs
265+ return :( $ proj ( $ summed_∂_mul_Δs))
253266end
254267
255268"""
0 commit comments