Skip to content

Commit 6c94666

Browse files
committed
dsl_kernel respects user's comments and blanks in kernels now
1 parent d386291 commit 6c94666

2 files changed

Lines changed: 159 additions & 35 deletions

File tree

src/blosc2/dsl_kernel.py

Lines changed: 151 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import inspect
1313
import os
1414
import textwrap
15+
import tokenize
16+
from io import StringIO
1517
from typing import ClassVar
1618

1719
_PRINT_DSL_KERNEL = os.environ.get("PRINT_DSL_KERNEL", "").strip().lower()
@@ -30,28 +32,129 @@ def _normalize_miniexpr_scalar(value):
3032
raise TypeError("Unsupported scalar type for miniexpr specialization")
3133

3234

33-
class _MiniexprScalarSpecializer(ast.NodeTransformer):
34-
def __init__(self, replacements: dict[str, int | float]):
35-
self.replacements = replacements
35+
def _line_starts(text: str) -> list[int]:
36+
starts = [0]
37+
for i, ch in enumerate(text):
38+
if ch == "\n":
39+
starts.append(i + 1)
40+
return starts
3641

37-
def visit_Name(self, node):
38-
if isinstance(node.ctx, ast.Load) and node.id in self.replacements:
39-
return ast.copy_location(ast.Constant(value=self.replacements[node.id]), node)
40-
return node
4142

42-
def visit_Call(self, node):
43-
node = self.generic_visit(node)
44-
if (
45-
isinstance(node.func, ast.Name)
46-
and node.func.id in {"float", "int"}
47-
and len(node.args) == 1
48-
and not node.keywords
49-
and isinstance(node.args[0], ast.Constant)
50-
and isinstance(node.args[0].value, int | float | bool)
51-
):
52-
folded = float(node.args[0].value) if node.func.id == "float" else int(node.args[0].value)
53-
return ast.copy_location(ast.Constant(value=folded), node)
54-
return node
43+
def _to_abs(line_starts: list[int], line: int, col: int) -> int:
44+
return line_starts[line - 1] + col
45+
46+
47+
def _find_def_signature_span(text: str):
48+
tokens = list(tokenize.generate_tokens(StringIO(text).readline))
49+
for i, tok in enumerate(tokens):
50+
if tok.type != tokenize.NAME or tok.string != "def":
51+
continue
52+
lparen = None
53+
rparen = None
54+
colon = None
55+
depth = 0
56+
for j in range(i + 1, len(tokens)):
57+
t = tokens[j]
58+
if lparen is None:
59+
if t.type == tokenize.OP and t.string == "(":
60+
lparen = t
61+
depth = 1
62+
continue
63+
if t.type == tokenize.OP and t.string == "(":
64+
depth += 1
65+
continue
66+
if t.type == tokenize.OP and t.string == ")":
67+
depth -= 1
68+
if depth == 0:
69+
rparen = t
70+
continue
71+
if rparen is not None and t.type == tokenize.OP and t.string == ":":
72+
colon = t
73+
break
74+
if lparen is not None and rparen is not None:
75+
return lparen, rparen, colon
76+
return None, None, None
77+
78+
79+
def _remove_scalar_params_preserving_source(text: str, scalar_replacements: dict[str, int | float]):
80+
if not scalar_replacements:
81+
return text, 0
82+
83+
lparen, rparen, colon = _find_def_signature_span(text)
84+
if lparen is None or rparen is None:
85+
return text, 0
86+
87+
try:
88+
tree = ast.parse(text)
89+
except Exception:
90+
return text, 0
91+
92+
func = next((n for n in tree.body if isinstance(n, ast.FunctionDef)), None)
93+
if func is None:
94+
return text, 0
95+
96+
kept = [a.arg for a in (func.args.posonlyargs + func.args.args) if a.arg not in scalar_replacements]
97+
line_starts = _line_starts(text)
98+
pstart = _to_abs(line_starts, lparen.end[0], lparen.end[1])
99+
pend = _to_abs(line_starts, rparen.start[0], rparen.start[1])
100+
updated = f"{text[:pstart]}{', '.join(kept)}{text[pend:]}"
101+
body_start = 0
102+
if colon is not None:
103+
body_start = _to_abs(_line_starts(updated), colon.end[0], colon.end[1])
104+
return updated, body_start
105+
106+
107+
def _replace_scalar_names_preserving_source(
108+
text: str, scalar_replacements: dict[str, int | float], body_start: int
109+
):
110+
if not scalar_replacements:
111+
return text
112+
113+
line_starts = _line_starts(text)
114+
tokens = list(tokenize.generate_tokens(StringIO(text).readline))
115+
significant = {
116+
tokenize.NAME,
117+
tokenize.NUMBER,
118+
tokenize.STRING,
119+
tokenize.OP,
120+
tokenize.INDENT,
121+
tokenize.DEDENT,
122+
}
123+
assign_ops = {"=", "+=", "-=", "*=", "/=", "//=", "%=", "&=", "|=", "^=", "<<=", ">>=", ":="}
124+
edits = []
125+
for i, tok in enumerate(tokens):
126+
if tok.type != tokenize.NAME or tok.string not in scalar_replacements:
127+
continue
128+
start_abs = _to_abs(line_starts, tok.start[0], tok.start[1])
129+
if start_abs < body_start:
130+
continue
131+
132+
prev_sig = None
133+
for j in range(i - 1, -1, -1):
134+
if tokens[j].type in significant:
135+
prev_sig = tokens[j]
136+
break
137+
if prev_sig is not None and prev_sig.type == tokenize.OP and prev_sig.string == ".":
138+
continue
139+
140+
next_sig = None
141+
for j in range(i + 1, len(tokens)):
142+
if tokens[j].type in significant:
143+
next_sig = tokens[j]
144+
break
145+
if next_sig is not None and next_sig.type == tokenize.OP and next_sig.string in assign_ops:
146+
continue
147+
148+
end_abs = _to_abs(line_starts, tok.end[0], tok.end[1])
149+
edits.append((start_abs, end_abs, repr(scalar_replacements[tok.string])))
150+
151+
if not edits:
152+
return text
153+
154+
out = text
155+
for start, end, repl in sorted(edits, key=lambda e: e[0], reverse=True):
156+
out = f"{out[:start]}{repl}{out[end:]}"
157+
return out
55158

56159

57160
def specialize_miniexpr_inputs(expr_string: str, operands: dict):
@@ -73,14 +176,9 @@ def specialize_miniexpr_inputs(expr_string: str, operands: dict):
73176
if not scalar_replacements:
74177
return expr_string, operands
75178

76-
tree = ast.parse(expr_string)
77-
tree = _MiniexprScalarSpecializer(scalar_replacements).visit(tree)
78-
for node in tree.body:
79-
if isinstance(node, ast.FunctionDef):
80-
node.args.posonlyargs = [a for a in node.args.posonlyargs if a.arg not in scalar_replacements]
81-
node.args.args = [a for a in node.args.args if a.arg not in scalar_replacements]
82-
ast.fix_missing_locations(tree)
83-
return ast.unparse(tree), array_operands
179+
rewritten, body_start = _remove_scalar_params_preserving_source(expr_string, scalar_replacements)
180+
rewritten = _replace_scalar_names_preserving_source(rewritten, scalar_replacements, body_start)
181+
return rewritten, array_operands
84182

85183

86184
def specialize_dsl_miniexpr_inputs(expr_string: str, operands: dict):
@@ -141,14 +239,37 @@ def _extract_dsl(self, func):
141239
if func_node is None:
142240
raise ValueError("No function definition found for DSL extraction")
143241

144-
builder = _DSLBuilder()
145-
dsl_source, input_names = builder.build(func_node)
242+
dsl_source = self._slice_function_source(source, func_node)
243+
input_names = self._input_names_from_signature(func_node)
146244
if _PRINT_DSL_KERNEL:
147245
func_name = getattr(func, "__name__", "<dsl_kernel>")
148246
print(f"[DSLKernel:{func_name}] dsl_source (full):")
149247
print(dsl_source)
150248
return dsl_source, input_names
151249

250+
@staticmethod
251+
def _slice_function_source(source: str, func_node: ast.FunctionDef) -> str:
252+
lines = source.splitlines()
253+
start = func_node.lineno - 1
254+
end_lineno = getattr(func_node, "end_lineno", None)
255+
if end_lineno is None:
256+
end = len(lines)
257+
else:
258+
end = end_lineno
259+
return "\n".join(lines[start:end])
260+
261+
@staticmethod
262+
def _input_names_from_signature(func_node: ast.FunctionDef) -> list[str]:
263+
args = func_node.args
264+
if args.vararg or args.kwarg or args.kwonlyargs:
265+
raise ValueError("DSL kernel does not support *args/**kwargs/kwonly args")
266+
if args.defaults or args.kw_defaults:
267+
raise ValueError("DSL kernel does not support default arguments")
268+
names = [a.arg for a in (args.posonlyargs + args.args)]
269+
if not names:
270+
raise ValueError("DSL kernel must accept at least one argument")
271+
return names
272+
152273
def __call__(self, inputs_tuple, output, offset=None):
153274
if self._legacy_udf_signature:
154275
return self.func(inputs_tuple, output, offset)

tests/ndarray/test_dsl_kernels.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def kernel_while_full(x, y):
8686
@blosc2.dsl_kernel
8787
def kernel_loop_param(x, y, niter):
8888
acc = x
89+
# loop count comes from scalar niter
8990
for _i in range(niter):
9091
acc = np.where(acc < y, acc + 1, acc - 1)
9192
return acc
@@ -183,10 +184,10 @@ def test_dsl_kernel_full_control_flow_kept_as_dsl_function():
183184
assert kernel_control_flow_full.dsl_source is not None
184185
assert "def kernel_control_flow_full(x, y):" in kernel_control_flow_full.dsl_source
185186
assert "for i in range(4):" in kernel_control_flow_full.dsl_source
186-
assert "elif (i == 1):" in kernel_control_flow_full.dsl_source
187+
assert "if i == 1:" in kernel_control_flow_full.dsl_source
187188
assert "continue" in kernel_control_flow_full.dsl_source
188189
assert "break" in kernel_control_flow_full.dsl_source
189-
assert "where(" in kernel_control_flow_full.dsl_source
190+
assert "np.where(" in kernel_control_flow_full.dsl_source
190191

191192
a, b, a2, b2 = _make_arrays()
192193
expr = blosc2.lazyudf(
@@ -205,7 +206,7 @@ def test_dsl_kernel_full_control_flow_kept_as_dsl_function():
205206
def test_dsl_kernel_while_kept_as_dsl_function():
206207
assert kernel_while_full.dsl_source is not None
207208
assert "def kernel_while_full(x, y):" in kernel_while_full.dsl_source
208-
assert "while (i < 3):" in kernel_while_full.dsl_source
209+
assert "while i < 3:" in kernel_while_full.dsl_source
209210

210211
a, b, a2, b2 = _make_arrays()
211212
expr = blosc2.lazyudf(
@@ -280,6 +281,7 @@ def wrapped_set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc=None,
280281
assert "def kernel_loop_param(x, y):" in captured["expr"]
281282
assert "for it in range(3):" not in captured["expr"]
282283
assert "for _i in range(3):" in captured["expr"]
284+
assert "# loop count comes from scalar niter" in captured["expr"]
283285
assert "range(niter)" not in captured["expr"]
284286
assert "float(niter)" not in captured["expr"]
285287
finally:
@@ -338,8 +340,9 @@ def test_jit_backend_pragma_wrapping_dsl_source():
338340
],
339341
)
340342
def test_dsl_kernel_flawed_syntax_detected_fallback_callable(kernel):
341-
assert kernel.dsl_source is None
342-
assert kernel.input_names is None
343+
assert kernel.dsl_source is not None
344+
assert kernel.dsl_source.startswith("def ")
345+
assert kernel.input_names == ["x", "y"]
343346

344347
a, b, a2, b2 = _make_arrays()
345348
expr = blosc2.lazyudf(

0 commit comments

Comments
 (0)