11# These are some macros (and supporting functions) to make it easier to define rules.
2- using Base. Meta
32
4- macro strip_linenos (expr)
5- return esc (Base. remove_linenums! (expr))
6- end
3+ # ###########################################################################################
4+ # ## @scalar_rule
75
86"""
97 @scalar_rule(f(x₁, x₂, ...),
@@ -88,7 +86,6 @@ macro scalar_rule(call, maybe_setup, partials...)
8886 frule_expr = scalar_frule_expr (__source__, f, call, setup_stmts, inputs, partials)
8987 rrule_expr = scalar_rrule_expr (__source__, f, call, setup_stmts, inputs, partials)
9088
91- # ###########################################################################
9289 # Final return: building the expression to insert in the place of this macro
9390 code = quote
9491 if ! ($ f isa Type) && fieldcount (typeof ($ f)) > 0
@@ -114,7 +111,6 @@ returns (in order) the correctly escaped:
114111 - `partials`: which are all `Expr{:tuple,...}`
115112"""
116113function _normalize_scalarrules_macro_input (call, maybe_setup, partials)
117- # ###########################################################################
118114 # Setup: normalizing input form etc
119115
120116 if Meta. isexpr (maybe_setup, :macrocall ) && maybe_setup. args[1 ] == Symbol (" @setup" )
@@ -275,6 +271,9 @@ propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propna
275271propagator_name (fname:: Symbol , propname:: Symbol ) = Symbol (fname, :_ , propname)
276272propagator_name (fname:: QuoteNode , propname:: Symbol ) = propagator_name (fname. value, propname)
277273
274+ # ###########################################################################################
275+ # ## @non_differentiable
276+
278277"""
279278 @non_differentiable(signature_expression)
280279
@@ -324,7 +323,7 @@ macro non_differentiable(sig_expr)
324323 :($ (primal_name)($ (unconstrained_args... )))
325324 else
326325 normal_args = unconstrained_args[1 : end - 1 ]
327- var_arg = unconstrained_args [end ]
326+ var_arg = s [end ]
328327 :($ (primal_name)($ (normal_args... ), $ (var_arg). .. ))
329328 end
330329
@@ -393,10 +392,13 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
393392 end
394393end
395394
396-
397- # ##########
395+ # ###########################################################################################
398396# Helpers
399397
398+ macro strip_linenos (expr)
399+ return esc (Base. remove_linenums! (expr))
400+ end
401+
400402"""
401403 _isvararg(expr)
402404
0 commit comments