1212import inspect
1313import os
1414import textwrap
15+ import tokenize
16+ from io import StringIO
1517from typing import ClassVar
1618
# Opt-in debug switch: any non-empty PRINT_DSL_KERNEL environment value makes
# the DSL extraction path print the generated kernel source (see the
# `if _PRINT_DSL_KERNEL:` branch in `_extract_dsl`).  Normalized to a
# stripped, lowercase token at import time.
_PRINT_DSL_KERNEL = os.environ.get("PRINT_DSL_KERNEL", "").strip().lower()
@@ -30,28 +32,129 @@ def _normalize_miniexpr_scalar(value):
3032 raise TypeError ("Unsupported scalar type for miniexpr specialization" )
3133
3234
33- class _MiniexprScalarSpecializer (ast .NodeTransformer ):
34- def __init__ (self , replacements : dict [str , int | float ]):
35- self .replacements = replacements
35+ def _line_starts (text : str ) -> list [int ]:
36+ starts = [0 ]
37+ for i , ch in enumerate (text ):
38+ if ch == "\n " :
39+ starts .append (i + 1 )
40+ return starts
3641
37- def visit_Name (self , node ):
38- if isinstance (node .ctx , ast .Load ) and node .id in self .replacements :
39- return ast .copy_location (ast .Constant (value = self .replacements [node .id ]), node )
40- return node
4142
42- def visit_Call (self , node ):
43- node = self .generic_visit (node )
44- if (
45- isinstance (node .func , ast .Name )
46- and node .func .id in {"float" , "int" }
47- and len (node .args ) == 1
48- and not node .keywords
49- and isinstance (node .args [0 ], ast .Constant )
50- and isinstance (node .args [0 ].value , int | float | bool )
51- ):
52- folded = float (node .args [0 ].value ) if node .func .id == "float" else int (node .args [0 ].value )
53- return ast .copy_location (ast .Constant (value = folded ), node )
54- return node
def _to_abs(line_starts: list[int], line: int, col: int) -> int:
    """Convert a tokenizer position (1-based *line*, 0-based *col*) into an
    absolute character offset, using the per-line offsets produced by
    :func:`_line_starts`."""
    return line_starts[line - 1] + col
45+
46+
47+ def _find_def_signature_span (text : str ):
48+ tokens = list (tokenize .generate_tokens (StringIO (text ).readline ))
49+ for i , tok in enumerate (tokens ):
50+ if tok .type != tokenize .NAME or tok .string != "def" :
51+ continue
52+ lparen = None
53+ rparen = None
54+ colon = None
55+ depth = 0
56+ for j in range (i + 1 , len (tokens )):
57+ t = tokens [j ]
58+ if lparen is None :
59+ if t .type == tokenize .OP and t .string == "(" :
60+ lparen = t
61+ depth = 1
62+ continue
63+ if t .type == tokenize .OP and t .string == "(" :
64+ depth += 1
65+ continue
66+ if t .type == tokenize .OP and t .string == ")" :
67+ depth -= 1
68+ if depth == 0 :
69+ rparen = t
70+ continue
71+ if rparen is not None and t .type == tokenize .OP and t .string == ":" :
72+ colon = t
73+ break
74+ if lparen is not None and rparen is not None :
75+ return lparen , rparen , colon
76+ return None , None , None
77+
78+
def _remove_scalar_params_preserving_source(text: str, scalar_replacements: dict[str, int | float]):
    """Drop the parameters named in *scalar_replacements* from the first
    ``def`` signature in *text*, leaving the rest of the source untouched.

    Returns ``(updated_text, body_start)`` where *body_start* is the absolute
    offset just past the header colon in the UPDATED text (0 when nothing was
    rewritten) — the point from which body-name substitution may start.

    NOTE(review): the rebuilt parameter list keeps only bare names, so any
    annotations/defaults on the surviving parameters are discarded; the DSL
    signature validator rejects defaults anyway.
    """
    if not scalar_replacements:
        return text, 0

    lparen, rparen, colon = _find_def_signature_span(text)
    if lparen is None or rparen is None:
        return text, 0

    try:
        tree = ast.parse(text)
    except Exception:
        # Not parseable Python: leave the source alone.
        return text, 0

    func = next((n for n in tree.body if isinstance(n, ast.FunctionDef)), None)
    if func is None:
        return text, 0

    # Keep every positional parameter that is not being specialized away.
    kept = [a.arg for a in (func.args.posonlyargs + func.args.args) if a.arg not in scalar_replacements]
    line_starts = _line_starts(text)
    pstart = _to_abs(line_starts, lparen.end[0], lparen.end[1])
    pend = _to_abs(line_starts, rparen.start[0], rparen.start[1])
    updated = f"{text[:pstart]}{', '.join(kept)}{text[pend:]}"

    body_start = 0
    if colon is not None:
        # The colon token's (line, col) refer to the ORIGINAL text.  The only
        # edit lies strictly before the colon, so shift its original absolute
        # offset by the net length change.  (Re-mapping the stale coordinates
        # against the updated text pointed PAST the true colon whenever the
        # signature shrank on the colon's own line, skipping body names.)
        colon_end = _to_abs(line_starts, colon.end[0], colon.end[1])
        body_start = colon_end + (len(updated) - len(text))
    return updated, body_start
105+
106+
107+ def _replace_scalar_names_preserving_source (
108+ text : str , scalar_replacements : dict [str , int | float ], body_start : int
109+ ):
110+ if not scalar_replacements :
111+ return text
112+
113+ line_starts = _line_starts (text )
114+ tokens = list (tokenize .generate_tokens (StringIO (text ).readline ))
115+ significant = {
116+ tokenize .NAME ,
117+ tokenize .NUMBER ,
118+ tokenize .STRING ,
119+ tokenize .OP ,
120+ tokenize .INDENT ,
121+ tokenize .DEDENT ,
122+ }
123+ assign_ops = {"=" , "+=" , "-=" , "*=" , "/=" , "//=" , "%=" , "&=" , "|=" , "^=" , "<<=" , ">>=" , ":=" }
124+ edits = []
125+ for i , tok in enumerate (tokens ):
126+ if tok .type != tokenize .NAME or tok .string not in scalar_replacements :
127+ continue
128+ start_abs = _to_abs (line_starts , tok .start [0 ], tok .start [1 ])
129+ if start_abs < body_start :
130+ continue
131+
132+ prev_sig = None
133+ for j in range (i - 1 , - 1 , - 1 ):
134+ if tokens [j ].type in significant :
135+ prev_sig = tokens [j ]
136+ break
137+ if prev_sig is not None and prev_sig .type == tokenize .OP and prev_sig .string == "." :
138+ continue
139+
140+ next_sig = None
141+ for j in range (i + 1 , len (tokens )):
142+ if tokens [j ].type in significant :
143+ next_sig = tokens [j ]
144+ break
145+ if next_sig is not None and next_sig .type == tokenize .OP and next_sig .string in assign_ops :
146+ continue
147+
148+ end_abs = _to_abs (line_starts , tok .end [0 ], tok .end [1 ])
149+ edits .append ((start_abs , end_abs , repr (scalar_replacements [tok .string ])))
150+
151+ if not edits :
152+ return text
153+
154+ out = text
155+ for start , end , repl in sorted (edits , key = lambda e : e [0 ], reverse = True ):
156+ out = f"{ out [:start ]} { repl } { out [end :]} "
157+ return out
55158
56159
57160def specialize_miniexpr_inputs (expr_string : str , operands : dict ):
@@ -73,14 +176,9 @@ def specialize_miniexpr_inputs(expr_string: str, operands: dict):
73176 if not scalar_replacements :
74177 return expr_string , operands
75178
76- tree = ast .parse (expr_string )
77- tree = _MiniexprScalarSpecializer (scalar_replacements ).visit (tree )
78- for node in tree .body :
79- if isinstance (node , ast .FunctionDef ):
80- node .args .posonlyargs = [a for a in node .args .posonlyargs if a .arg not in scalar_replacements ]
81- node .args .args = [a for a in node .args .args if a .arg not in scalar_replacements ]
82- ast .fix_missing_locations (tree )
83- return ast .unparse (tree ), array_operands
179+ rewritten , body_start = _remove_scalar_params_preserving_source (expr_string , scalar_replacements )
180+ rewritten = _replace_scalar_names_preserving_source (rewritten , scalar_replacements , body_start )
181+ return rewritten , array_operands
84182
85183
86184def specialize_dsl_miniexpr_inputs (expr_string : str , operands : dict ):
@@ -141,14 +239,37 @@ def _extract_dsl(self, func):
141239 if func_node is None :
142240 raise ValueError ("No function definition found for DSL extraction" )
143241
144- builder = _DSLBuilder ( )
145- dsl_source , input_names = builder . build (func_node )
242+ dsl_source = self . _slice_function_source ( source , func_node )
243+ input_names = self . _input_names_from_signature (func_node )
146244 if _PRINT_DSL_KERNEL :
147245 func_name = getattr (func , "__name__" , "<dsl_kernel>" )
148246 print (f"[DSLKernel:{ func_name } ] dsl_source (full):" )
149247 print (dsl_source )
150248 return dsl_source , input_names
151249
250+ @staticmethod
251+ def _slice_function_source (source : str , func_node : ast .FunctionDef ) -> str :
252+ lines = source .splitlines ()
253+ start = func_node .lineno - 1
254+ end_lineno = getattr (func_node , "end_lineno" , None )
255+ if end_lineno is None :
256+ end = len (lines )
257+ else :
258+ end = end_lineno
259+ return "\n " .join (lines [start :end ])
260+
261+ @staticmethod
262+ def _input_names_from_signature (func_node : ast .FunctionDef ) -> list [str ]:
263+ args = func_node .args
264+ if args .vararg or args .kwarg or args .kwonlyargs :
265+ raise ValueError ("DSL kernel does not support *args/**kwargs/kwonly args" )
266+ if args .defaults or args .kw_defaults :
267+ raise ValueError ("DSL kernel does not support default arguments" )
268+ names = [a .arg for a in (args .posonlyargs + args .args )]
269+ if not names :
270+ raise ValueError ("DSL kernel must accept at least one argument" )
271+ return names
272+
152273 def __call__ (self , inputs_tuple , output , offset = None ):
153274 if self ._legacy_udf_signature :
154275 return self .func (inputs_tuple , output , offset )
0 commit comments