@@ -924,6 +924,43 @@ def _enrich_memreport(self, args):
924924 # Hook for enriching memory report with additional metadata
925925 return {}
926926
927+ def specialize (self , ** kwargs ):
928+ """
929+ """
930+
931+ specialize = as_tuple (kwargs .pop ('specialize' , []))
932+
933+ if not specialize :
934+ return self , kwargs
935+
936+ # FIXME: Cannot cope with things like sizes/strides yet since it only
937+ # looks at the parameters
938+
939+ # Build the arguments list for specialization
940+ with self ._profiler .timer_on ('specialization' ):
941+ args = self .arguments (** kwargs )
942+ # Uses parameters here since Specializer needs {symbol: sympy value}
943+ specialized_values = {p : sympify (args [p .name ])
944+ for p in self .parameters
945+ if p .name in specialize }
946+
947+ op = Specializer (specialized_values ).visit (self )
948+
949+ with switch_log_level (comm = args .comm ):
950+ self ._emit_args_profiling ('specialization' )
951+
952+ unspecialized_kwargs = {k : v for k , v in kwargs .items ()
953+ if k not in specialize }
954+
955+ return op , unspecialized_kwargs
956+
957+ def apply_specialize (self , ** kwargs ):
958+ """
959+ """
960+
961+ op , unspecialized_kwargs = self .specialize (** kwargs )
962+ return op .apply (** unspecialized_kwargs )
963+
927964 def apply (self , ** kwargs ):
928965 """
929966 Execute the Operator.
@@ -986,29 +1023,6 @@ def apply(self, **kwargs):
9861023 >>> op = Operator(Eq(u3.forward, u3 + 1))
9871024 >>> summary = op.apply(time_M=10)
9881025 """
989- # Get items expected to be specialized
990- specialize = as_tuple (kwargs .pop ('specialize' , []))
991-
992- if specialize :
993- # FIXME: Cannot cope with things like sizes/strides yet since it only
994- # looks at the parameters
995-
996- # Build the arguments list for specialization
997- with self ._profiler .timer_on ('specialization' ):
998- args = self .arguments (** kwargs )
999- # Uses parameters here since Specializer needs {symbol: sympy value}
1000- specialized_values = {p : sympify (args [p .name ])
1001- for p in self .parameters if p .name in specialize }
1002-
1003- op = Specializer (specialized_values ).visit (self )
1004-
1005- with switch_log_level (comm = args .comm ):
1006- self ._emit_args_profiling ('specialization' )
1007-
1008- unspecialized_kwargs = {k : v for k , v in kwargs .items ()
1009- if k not in specialize }
1010-
1011- return op .apply (** unspecialized_kwargs )
10121026
10131027 # Compile the operator before building the arguments list
10141028 # to avoid out of memory with greedy compilers
0 commit comments