@@ -56,7 +56,7 @@ derivative/setup expressions.
5656This macro assumes complex functions are holomorphic. In general, for non-holomorphic
5757functions, the `frule` and `rrule` must be defined manually.
5858
59- If the derivative is one, (e.g. for identity functions) `true` can be used as the most
59+ If the derivative is one, (e.g. for identity functions) `true` can be used as the most
6060general multiplicative identity.
6161
6262The `@setup` argument can be elided if no setup code is need. In other
@@ -244,24 +244,13 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity)
244244 esc (∂s_i)
245245 end
246246 end
247- n∂s = length (_∂s)
248247
249- summed_∂_mul_Δs = if n∂s > 1
250- # Explicit multiplication is only performed for the first pair
251- # of partial and gradient.
252- init_expr = :((* ). ($ (_∂s[1 ]), $ (Δs[1 ])))
253-
254- # Apply `muladd` iteratively.
255- foldl (Iterators. drop (zip (_∂s, Δs), 1 ); init= init_expr) do ex, (∂s_i, Δs_i)
256- :((muladd). ($ ∂s_i, $ Δs_i, $ ex))
257- end
258- else
259- # Note: we don't want to do broadcasting with only 1 multiply (no `+`),
260- # because some arrays overload multiply with scalar. Avoiding
261- # broadcasting saves compilation time.
262- :($ (_∂s[1 ]) * $ (Δs[1 ]))
248+ # Apply `muladd` iteratively.
249+ # Explicit multiplication is only performed for the first pair of partial and gradient.
250+ init_expr = :(* ($ (_∂s[1 ]), $ (Δs[1 ])))
251+ summed_∂_mul_Δs = foldl (Iterators. drop (zip (_∂s, Δs), 1 ); init= init_expr) do ex, (∂s_i, Δs_i)
252+ :(muladd ($ ∂s_i, $ Δs_i, $ ex))
263253 end
264-
265254 return :($ proj ($ summed_∂_mul_Δs))
266255end
267256
0 commit comments