Skip to content

Commit a9c0f3f

Browse files
committed
Refactor to track metadata objects using a new dict
1 parent 74c1cbd commit a9c0f3f

5 files changed

Lines changed: 65 additions & 49 deletions

File tree

sqlmesh/core/model/common.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def make_python_env(
4343
python_env = {} if python_env is None else python_env
4444
variables = variables or {}
4545
env: t.Dict[str, t.Any] = {}
46+
env_metadata: t.Set[str] = set()
4647
used_macros = {}
4748
used_variables = (used_variables or set()).copy()
4849

@@ -88,9 +89,15 @@ def make_python_env(
8889
if isinstance(used_macro, Executable):
8990
python_env[name] = used_macro
9091
elif not hasattr(used_macro, c.SQLMESH_BUILTIN) and name not in python_env:
91-
build_env(used_macro.func, env=env, name=name, path=module_path)
92+
build_env(
93+
used_macro.func,
94+
env=env,
95+
env_metadata=env_metadata,
96+
name=name,
97+
path=module_path,
98+
)
9299

93-
python_env.update(serialize_env(env, path=module_path))
100+
python_env.update(serialize_env(env, env_metadata=env_metadata, path=module_path))
94101
return _add_variables_to_python_env(
95102
python_env,
96103
used_variables,

sqlmesh/core/model/decorator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def model(
126126
) -> Model:
127127
"""Get the model registered by this function."""
128128
env: t.Dict[str, t.Any] = {}
129+
env_metadata: t.Set[str] = set()
129130
entrypoint = self.func.__name__
130131

131132
if not self.name_provided and not infer_names:
@@ -145,7 +146,7 @@ def model(
145146
f"""Python model "{self.name}"'s `kind` dictionary must contain a `name` key with a valid ModelKindName enum value."""
146147
)
147148

148-
build_env(self.func, env=env, name=entrypoint, path=module_path)
149+
build_env(self.func, env=env, env_metadata=env_metadata, name=entrypoint, path=module_path)
149150

150151
rendered_fields = render_meta_fields(
151152
fields={"name": self.name, **self.kwargs},
@@ -184,7 +185,7 @@ def model(
184185
"defaults": rendered_defaults,
185186
"path": path,
186187
"time_column_format": time_column_format,
187-
"python_env": serialize_env(env, path=module_path),
188+
"python_env": serialize_env(env, env_metadata=env_metadata, path=module_path),
188189
"physical_schema_mapping": physical_schema_mapping,
189190
"project": project,
190191
"default_catalog": default_catalog,

sqlmesh/core/model/definition.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2483,15 +2483,16 @@ def _create_model(
24832483
)
24842484

24852485
env: t.Dict[str, t.Any] = {}
2486+
env_metadata: t.Set[str] = set()
24862487

24872488
for signal_name, _ in model.signals:
24882489
if signal_definitions and signal_name in signal_definitions:
24892490
func = signal_definitions[signal_name].func
24902491
setattr(func, c.SQLMESH_METADATA, True)
2491-
build_env(func, env=env, name=signal_name, path=module_path)
2492+
build_env(func, env=env, env_metadata=env_metadata, name=signal_name, path=module_path)
24922493

24932494
model.python_env.update(python_env)
2494-
model.python_env.update(serialize_env(env, path=module_path))
2495+
model.python_env.update(serialize_env(env, env_metadata=env_metadata, path=module_path))
24952496
model._path = path
24962497
model.set_time_format(time_column_format)
24972498

sqlmesh/utils/metaprogramming.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def build_env(
267267
obj: t.Any,
268268
*,
269269
env: t.Dict[str, t.Any],
270+
env_metadata: t.Set[str],
270271
name: str,
271272
path: Path,
272273
) -> None:
@@ -277,6 +278,7 @@ def build_env(
277278
Args:
278279
obj: Any python object.
279280
env: Dictionary to store the env.
281+
env_metadata: Set to store the keys that correspond to "metadata only" objects in the env.
280282
name: Name of the object in the env.
281283
path: The module path to serialize. Other modules will not be walked and treated as imports.
282284
"""
@@ -291,16 +293,13 @@ def walk(obj: t.Any, name: str, is_metadata: t.Optional[bool] = None) -> None:
291293
visited.add(name)
292294
name_missing_from_env = name not in env
293295

294-
if name_missing_from_env or (
295-
not is_metadata and env[name] == obj and getattr(env[name], c.SQLMESH_METADATA, None)
296-
):
296+
if name_missing_from_env or (not is_metadata and env[name] == obj and name in env_metadata):
297297
if not name_missing_from_env:
298298
# The existing object in the env is "metadata only" but we're walking it again as a
299299
# non-"metadata only" dependency, so we update this flag to ensure all transitive
300300
# dependencies are also not marked as "metadata only"
301301
is_metadata = False
302-
if hasattr(obj, c.SQLMESH_METADATA):
303-
delattr(obj, c.SQLMESH_METADATA)
302+
env_metadata.remove(name)
304303

305304
if hasattr(obj, c.SQLMESH_MACRO):
306305
# We only need to add the undecorated code of @macro() functions in env, which
@@ -321,9 +320,7 @@ def walk(obj: t.Any, name: str, is_metadata: t.Optional[bool] = None) -> None:
321320
or not _is_relative_to(obj_module.__file__, path)
322321
):
323322
if is_metadata:
324-
setattr(obj, c.SQLMESH_METADATA, True)
325-
elif hasattr(obj, c.SQLMESH_METADATA):
326-
delattr(obj, c.SQLMESH_METADATA)
323+
env_metadata.add(name)
327324

328325
env[name] = obj
329326
return
@@ -360,7 +357,7 @@ def walk(obj: t.Any, name: str, is_metadata: t.Optional[bool] = None) -> None:
360357
walk(v, k, is_metadata)
361358

362359
if is_metadata:
363-
setattr(obj, c.SQLMESH_METADATA, True)
360+
env_metadata.add(name)
364361

365362
# We store the object in the environment after its dependencies, because otherwise we
366363
# could crash at environment hydration time, since dicts are ordered and the top-level
@@ -416,26 +413,31 @@ def is_value(self) -> bool:
416413
return self.kind == ExecutableKind.VALUE
417414

418415
@classmethod
419-
def value(cls, v: t.Any) -> Executable:
420-
return Executable(payload=repr(v), kind=ExecutableKind.VALUE)
416+
def value(cls, v: t.Any, is_metadata: t.Optional[bool] = None) -> Executable:
417+
return Executable(payload=repr(v), kind=ExecutableKind.VALUE, is_metadata=is_metadata)
421418

422419

423-
def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable]:
420+
def serialize_env(
421+
env: t.Dict[str, t.Any],
422+
env_metadata: t.Set[str],
423+
path: Path,
424+
) -> t.Dict[str, Executable]:
424425
"""Serializes a python function into a self contained dictionary.
425426
426427
Recursively walks a function's globals to store all other references inside of env.
427428
428429
Args:
429430
env: Dictionary to store the env.
431+
env_metadata: Keys that correspond to "metadata only" objects in the env.
430432
path: The root path to seralize. Other modules will not be walked and treated as imports.
431433
"""
432434
serialized = {}
433435

434436
for k, v in env.items():
435-
is_metadata = getattr(v, c.SQLMESH_METADATA, None)
437+
is_metadata = True if k in env_metadata else None
436438

437439
if isinstance(v, LITERALS) or v is None:
438-
serialized[k] = Executable.value(v)
440+
serialized[k] = Executable.value(v, is_metadata=is_metadata)
439441
elif inspect.ismodule(v):
440442
name = v.__name__
441443
if hasattr(v, "__file__") and _is_relative_to(v.__file__, path):
@@ -471,7 +473,6 @@ def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable
471473
v = wrapped
472474
file_path = Path(inspect.getfile(wrapped))
473475
relative_obj_file_path = _is_relative_to(file_path, path)
474-
is_metadata = is_metadata or getattr(v, c.SQLMESH_METADATA, None)
475476
except TypeError:
476477
file_path = None
477478
relative_obj_file_path = False

tests/utils/test_metaprogramming.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pytest_mock.plugin import MockerFixture
1212
from sqlglot import exp
1313
from sqlglot import exp as expressions
14-
from sqlglot.expressions import to_table
14+
from sqlglot.expressions import SQLGLOT_META, to_table
1515
from sqlglot.optimizer.pushdown_projections import SELECT_ALL
1616

1717
import tests.utils.test_date as test_date
@@ -100,13 +100,6 @@ def other_func(a: int) -> int:
100100
return X + a + W
101101

102102

103-
def noop_metadata() -> None:
104-
return None
105-
106-
107-
setattr(noop_metadata, c.SQLMESH_METADATA, True)
108-
109-
110103
@contextmanager
111104
def test_context_manager():
112105
yield
@@ -134,8 +127,7 @@ def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2)
134127
sqlglot.parse_one("1")
135128
MyClass()
136129
DataClass(x=y)
137-
noop_metadata()
138-
normalize_model_name("test")
130+
normalize_model_name("test" + SQLGLOT_META)
139131
fetch_data()
140132
function_with_custom_decorator()
141133

@@ -154,7 +146,6 @@ def test_func_globals() -> None:
154146
"Z": 3,
155147
"DataClass": DataClass,
156148
"MyClass": MyClass,
157-
"noop_metadata": noop_metadata,
158149
"normalize_model_name": normalize_model_name,
159150
"other_func": other_func,
160151
"sqlglot": sqlglot,
@@ -163,6 +154,7 @@ def test_func_globals() -> None:
163154
"fetch_data": fetch_data,
164155
"test_context_manager": test_context_manager,
165156
"function_with_custom_decorator": function_with_custom_decorator,
157+
"SQLGLOT_META": SQLGLOT_META,
166158
}
167159
assert func_globals(other_func) == {
168160
"X": 1,
@@ -194,8 +186,7 @@ def test_normalize_source() -> None:
194186
sqlglot.parse_one('1')
195187
MyClass()
196188
DataClass(x=y)
197-
noop_metadata()
198-
normalize_model_name('test')
189+
normalize_model_name('test' + SQLGLOT_META)
199190
fetch_data()
200191
function_with_custom_decorator()
201192
@@ -221,20 +212,23 @@ def closure(z: int):
221212
def test_serialize_env_error() -> None:
222213
with pytest.raises(SQLMeshError):
223214
# pretend to be the module pandas
224-
serialize_env({"test_date": test_date}, path=Path("tests/utils"))
215+
serialize_env({"test_date": test_date}, set(), path=Path("tests/utils"))
225216

226217
with pytest.raises(SQLMeshError):
227-
serialize_env({"select_all": SELECT_ALL}, path=Path("tests/utils"))
218+
serialize_env({"select_all": SELECT_ALL}, set(), path=Path("tests/utils"))
228219

229220

230221
def test_serialize_env() -> None:
231-
env: t.Dict[str, t.Any] = {}
232222
path = Path("tests/utils")
233-
build_env(main_func, env=env, name="MAIN", path=path)
234-
env = serialize_env(env, path=path) # type: ignore
235223

224+
env: t.Dict[str, t.Any] = {}
225+
env_metadata: t.Set[str] = set()
226+
227+
build_env(main_func, env=env, env_metadata=env_metadata, name="MAIN", path=path)
228+
env = serialize_env(env, env_metadata=env_metadata, path=path) # type: ignore
236229
assert prepare_env(env)
237-
assert env == {
230+
231+
expected_env = {
238232
"MAIN": Executable(
239233
name="main_func",
240234
alias="MAIN",
@@ -244,8 +238,7 @@ def test_serialize_env() -> None:
244238
sqlglot.parse_one('1')
245239
MyClass()
246240
DataClass(x=y)
247-
noop_metadata()
248-
normalize_model_name('test')
241+
normalize_model_name('test' + SQLGLOT_META)
249242
fetch_data()
250243
function_with_custom_decorator()
251244
@@ -319,13 +312,6 @@ def test_context_manager():
319312
path="test_metaprogramming.py",
320313
payload="my_lambda = lambda : print('z')",
321314
),
322-
"noop_metadata": Executable(
323-
name="noop_metadata",
324-
path="test_metaprogramming.py",
325-
payload="""def noop_metadata():
326-
return None""",
327-
is_metadata=True,
328-
),
329315
"normalize_model_name": Executable(
330316
payload="from sqlmesh.core.dialect import normalize_model_name",
331317
kind=ExecutableKind.IMPORT,
@@ -401,4 +387,24 @@ def function_with_custom_decorator():
401387
return""",
402388
alias="_func",
403389
),
390+
"SQLGLOT_META": Executable.value("sqlglot.meta"),
404391
}
392+
393+
assert env_metadata == set()
394+
assert env == expected_env
395+
396+
# Annotate the entrypoint as "metadata only" to show how it propagates
397+
setattr(main_func, c.SQLMESH_METADATA, True)
398+
399+
env = {}
400+
env_metadata = set()
401+
402+
build_env(main_func, env=env, env_metadata=env_metadata, name="MAIN", path=path)
403+
env = serialize_env(env, env_metadata=env_metadata, path=path) # type: ignore
404+
assert prepare_env(env)
405+
406+
expected_env = {k: Executable(**v.dict(), is_metadata=True) for k, v in expected_env.items()}
407+
408+
# Every object is treated as "metadata only", transitively
409+
assert env_metadata == set(env)
410+
assert env == expected_env

0 commit comments

Comments
 (0)