|
9 | 9 | from sqlglot.dialects.dialect import DialectType |
10 | 10 |
|
11 | 11 | from sqlmesh.core.macros import MacroRegistry |
| 12 | +from sqlmesh.core.signal import SignalRegistry |
12 | 13 | from sqlmesh.utils.jinja import JinjaMacroRegistry |
13 | 14 | from sqlmesh.core import constants as c |
14 | 15 | from sqlmesh.core.dialect import MacroFunc, parse_one |
@@ -48,23 +49,24 @@ def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs: |
48 | 49 | self.kwargs = kwargs |
49 | 50 |
|
50 | 51 | # Make sure that argument values are expressions in order to pass validation in ModelMeta. |
51 | | - calls = self.kwargs.pop("audits", []) |
52 | | - self.kwargs["audits"] = [ |
53 | | - ( |
54 | | - (call, {}) |
55 | | - if isinstance(call, str) |
56 | | - else ( |
57 | | - call[0], |
58 | | - { |
59 | | - arg_key: exp.convert( |
60 | | - tuple(arg_value) if isinstance(arg_value, list) else arg_value |
61 | | - ) |
62 | | - for arg_key, arg_value in call[1].items() |
63 | | - }, |
| 52 | + for function_call_attribute in ("audits", "signals"): |
| 53 | + calls = self.kwargs.pop(function_call_attribute, []) |
| 54 | + self.kwargs[function_call_attribute] = [ |
| 55 | + ( |
| 56 | + (call, {}) |
| 57 | + if isinstance(call, str) |
| 58 | + else ( |
| 59 | + call[0], |
| 60 | + { |
| 61 | + arg_key: exp.convert( |
| 62 | + tuple(arg_value) if isinstance(arg_value, list) else arg_value |
| 63 | + ) |
| 64 | + for arg_key, arg_value in call[1].items() |
| 65 | + }, |
| 66 | + ) |
64 | 67 | ) |
65 | | - ) |
66 | | - for call in calls |
67 | | - ] |
| 68 | + for call in calls |
| 69 | + ] |
68 | 70 |
|
69 | 71 | if "default_catalog" in kwargs: |
70 | 72 | raise ConfigError("`default_catalog` cannot be set on a per-model basis.") |
@@ -142,6 +144,7 @@ def model( |
142 | 144 | defaults: t.Optional[t.Dict[str, t.Any]] = None, |
143 | 145 | macros: t.Optional[MacroRegistry] = None, |
144 | 146 | jinja_macros: t.Optional[JinjaMacroRegistry] = None, |
| 147 | + signal_definitions: t.Optional[SignalRegistry] = None, |
145 | 148 | audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None, |
146 | 149 | dialect: t.Optional[str] = None, |
147 | 150 | time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, |
@@ -223,6 +226,7 @@ def model( |
223 | 226 | "macros": macros, |
224 | 227 | "jinja_macros": jinja_macros, |
225 | 228 | "audit_definitions": audit_definitions, |
| 229 | + "signal_definitions": signal_definitions, |
226 | 230 | "blueprint_variables": blueprint_variables, |
227 | 231 | **rendered_fields, |
228 | 232 | } |
|
0 commit comments