Skip to content

Commit ec516f2

Browse files
feat(bigquery)!: add typed AI scalar function nodes (#7479)
Signed-off-by: Mridankan Mandal <xerontitan90@gmail.com> Co-authored-by: Mridankan Mandal <xerontitan90@gmail.com>
1 parent 98ca4cd commit ec516f2

3 files changed

Lines changed: 40 additions & 0 deletions

File tree

sqlglot/expressions/functions.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -321,6 +321,24 @@ class AIClassify(Expression, Func):
321321
_sql_names = ["AI_CLASSIFY"]
322322

323323

324+
class AIEmbed(Expression, Func):
    """Typed node for the BigQuery ``AI.EMBED`` scalar function.

    The BigQuery parser builds this node from the anonymous ``EMBED`` call
    that follows an ``AI.`` prefix, passing the call's arguments through
    as ``expressions``.
    """

    # Bare function name; the ``AI.`` qualifier is carried by the enclosing Dot.
    _sql_names = ["EMBED"]

    # All arguments are collected positionally into ``expressions``.
    is_var_len_args = True
    arg_types = {"expressions": False}
328+
329+
330+
class AISimilarity(Expression, Func):
    """Typed node for the BigQuery ``AI.SIMILARITY`` scalar function.

    The BigQuery parser builds this node from the anonymous ``SIMILARITY``
    call that follows an ``AI.`` prefix, passing the call's arguments
    through as ``expressions``.
    """

    # Bare function name; the ``AI.`` qualifier is carried by the enclosing Dot.
    _sql_names = ["SIMILARITY"]

    # All arguments are collected positionally into ``expressions``.
    is_var_len_args = True
    arg_types = {"expressions": False}
334+
335+
336+
class AIGenerate(Expression, Func):
    """Typed node for the BigQuery ``AI.GENERATE`` scalar function.

    The BigQuery parser builds this node from the anonymous ``GENERATE``
    call that follows an ``AI.`` prefix, passing the call's arguments
    through as ``expressions``.
    """

    # Bare function name; the ``AI.`` qualifier is carried by the enclosing Dot.
    _sql_names = ["GENERATE"]

    # All arguments are collected positionally into ``expressions``.
    is_var_len_args = True
    arg_types = {"expressions": False}
340+
341+
324342
class FeaturesAtTime(Expression, Func):
325343
arg_types = {"this": True, "time": False, "num_rows": False, "ignore_feature_nulls": False}
326344

sqlglot/parsers/bigquery.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,16 @@ def _parse_column_ops(self, this: exp.Expr | None) -> exp.Expr | None:
730730
self._retreat(func_index)
731731
parsed = self._parse_function(any_token=True)
732732
if parsed:
733+
if prefix == "AI" and isinstance(parsed, exp.Anonymous):
734+
ai_scalars: dict[str, type[exp.Func]] = {
735+
"EMBED": exp.AIEmbed,
736+
"SIMILARITY": exp.AISimilarity,
737+
"GENERATE": exp.AIGenerate,
738+
}
739+
expr_type = ai_scalars.get(parsed.name.upper())
740+
if expr_type:
741+
parsed = expr_type(expressions=parsed.expressions)
742+
733743
this = self.expression(exp.Dot(this=this.this, expression=parsed))
734744

735745
return this

tests/dialects/test_bigquery.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2449,6 +2449,18 @@ def test_ml_functions(self):
24492449
"SELECT AI.GENERATE_BOOL(MODEL `mydataset.gemini_model`, 'Is sky blue?')"
24502450
)
24512451

2452+
ast = self.validate_identity("SELECT AI.EMBED('hello')")
2453+
assert isinstance(ast.expressions[0], exp.Dot)
2454+
assert isinstance(ast.expressions[0].expression, exp.AIEmbed)
2455+
2456+
ast = self.validate_identity("SELECT AI.SIMILARITY('a', 'b')")
2457+
assert isinstance(ast.expressions[0], exp.Dot)
2458+
assert isinstance(ast.expressions[0].expression, exp.AISimilarity)
2459+
2460+
ast = self.validate_identity("SELECT AI.GENERATE('Write a haiku')")
2461+
assert isinstance(ast.expressions[0], exp.Dot)
2462+
assert isinstance(ast.expressions[0].expression, exp.AIGenerate)
2463+
24522464
def test_merge(self):
24532465
self.validate_all(
24542466
"""

0 commit comments

Comments
 (0)