Skip to content

Commit ee16373

Browse files
committed
feat(bigframes): implement ai.similarity
1 parent aa43c83 commit ee16373

14 files changed

Lines changed: 263 additions & 1 deletion

File tree

packages/bigframes/bigframes/bigquery/_operations/ai.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,87 @@ def score(
869869
return series_list[0]._apply_nary_op(operator, series_list[1:])
870870

871871

872+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
def similarity(
    content1: str | series.Series | pd.Series,
    content2: str | series.Series | pd.Series,
    *,
    endpoint: str | None = None,
    model: str | None = None,
    model_params: Mapping[Any, Any] | None = None,
    connection_id: str | None = None,
) -> series.Series:
    """
    Computes the cosine similarity between two inputs as a FLOAT64 value.

    Exactly one of ``endpoint`` or ``model`` selects the embedding model.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> df = bpd.DataFrame({'word': ['happy', 'sad']})
        >>> bbq.ai.similarity(df['word'], 'glad', endpoint='text-embedding-005') # doctest: +SKIP
        0    0.916601
        1    0.660579

    Args:
        content1 (str | Series):
            The first value to compare. May be a string literal, a BigFrames
            Series, or a pandas Series.
        content2 (str | Series):
            The second value to compare. May be a string literal, a BigFrames
            Series, or a pandas Series.
        endpoint (str, optional):
            The Vertex AI endpoint of the text embedding model. If a model name
            such as `'text-embedding-005'` is given instead of a URL, BigQuery ML
            automatically identifies the model and uses its full endpoint.
        model (str, optional):
            A built-in text embedding model. The only supported value is the
            embeddinggemma-300m model. When this is given, `endpoint`,
            `model_params`, and `connection_id` must not be given.
        model_params (Mapping[Any, Any], optional):
            Additional parameters for the model; any of the parameters object
            fields may be used. For instance, `outputDimensionality` sets the
            number of dimensions used when generating embeddings.
        connection_id (str, optional):
            The connection used to communicate with the model, for example
            `myproject.us.myconnection`.

    Returns:
        bigframes.series.Series: A new series of FLOAT64 values representing the cosine similarity.

    Raises:
        ValueError: If neither `model` nor `endpoint` is given, or if `model`
            is combined with `endpoint`, `model_params`, or `connection_id`.
    """
    # Argument validation: `model` is exclusive with every other model option,
    # and without `model` an `endpoint` is mandatory.
    if model is None:
        if endpoint is None:
            raise ValueError("You must specify either 'model' or 'endpoint'.")
    elif not (endpoint is None and model_params is None and connection_id is None):
        raise ValueError(
            "If 'model' is specified, you cannot specify 'endpoint', 'model_params', or 'connection_id'."
        )

    # The op carries the model parameters as a JSON string; an empty or
    # missing mapping is passed through as None.
    serialized_params = json.dumps(model_params) if model_params else None
    similarity_op = ai_ops.AISimilarity(
        endpoint=endpoint,
        model=model,
        model_params=serialized_params,
        connection_id=connection_id,
    )

    # Pick the session of the first BigFrames series among the inputs, if any,
    # so that every series built below shares it.
    if isinstance(content1, series.Series):
        shared_session = content1._session
    elif isinstance(content2, series.Series):
        shared_session = content2._session
    else:
        shared_session = None

    if not isinstance(content1, str):
        # content1 is series-like: it becomes the left operand.
        left_operand = convert.to_bf_series(
            content1, default_index=None, session=shared_session
        )
        return left_operand._apply_binary_op(content2, similarity_op)

    if isinstance(content2, str):
        # Two literals: wrap the first one in a single-row series.
        left_operand = series.Series([content1], session=shared_session)
        return left_operand._apply_binary_op(content2, similarity_op)

    # content1 is a literal and content2 is series-like, so the series drives
    # the operation and the literal is passed as the other operand.
    series_operand = convert.to_bf_series(
        content2, default_index=None, session=shared_session
    )
    return series_operand._apply_binary_op(content1, similarity_op)
951+
952+
872953
@log_adapter.method_logger(custom_base_name="bigquery_ai")
873954
def forecast(
874955
df: dataframe.DataFrame | pd.DataFrame,

packages/bigframes/bigframes/bigquery/ai.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
generate_text,
6969
if_,
7070
score,
71+
similarity,
7172
)
7273

7374
__all__ = [
@@ -82,4 +83,5 @@
8283
"generate_text",
8384
"if_",
8485
"score",
86+
"similarity",
8587
]

packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1992,6 +1992,20 @@ def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructVal
19921992
).to_expr()
19931993

19941994

1995+
@scalar_op_compiler.register_binary_op(ops.AISimilarity, pass_op=True)
def ai_similarity(
    content1: ibis_types.Value, content2: ibis_types.Value, op: ops.AISimilarity
) -> ibis_types.Value:
    """Compile ``ops.AISimilarity`` into an Ibis ``AISimilarity`` expression.

    Args:
        content1: Ibis value for the first operand.
        content2: Ibis value for the second operand.
        op: The BigFrames op whose ``endpoint``, ``model``, ``model_params``
            and ``connection_id`` are forwarded to the Ibis node.

    Returns:
        The expression obtained from the Ibis node via ``to_expr()``.
    """
    # Arguments are passed positionally, in the order the Ibis node expects.
    similarity_node = ai_ops.AISimilarity(
        content1,  # type: ignore
        content2,  # type: ignore
        op.endpoint,  # type: ignore
        op.model,  # type: ignore
        op.model_params,  # type: ignore
        op.connection_id,  # type: ignore
    )
    return similarity_node.to_expr()
2007+
2008+
19952009
def _construct_prompt(
19962010
col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None]
19972011
) -> ibis_types.StructValue:

packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2424

2525
register_nary_op = expression_compiler.expression_compiler.register_nary_op
26+
register_binary_op = expression_compiler.expression_compiler.register_binary_op
2627

2728

2829
@register_nary_op(ops.AIGenerate, pass_op=True)
@@ -76,6 +77,16 @@ def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
7677
return sge.func("AI.SCORE", *args)
7778

7879

80+
@register_binary_op(ops.AISimilarity, pass_op=True)
def _(content1: TypedExpr, content2: TypedExpr, op: ops.AISimilarity) -> sge.Expression:
    """Compile ``ops.AISimilarity`` into an ``AI.SIMILARITY`` SQL function call.

    The two operands are emitted as the named arguments ``content1`` and
    ``content2``, followed by whatever named arguments
    ``_construct_named_args`` derives from ``op``.
    """
    content_kwargs = [
        sge.Kwarg(this=param_name, expression=operand.expr)
        for param_name, operand in (("content1", content1), ("content2", content2))
    ]
    return sge.func("AI.SIMILARITY", *content_kwargs, *_construct_named_args(op))
88+
89+
7990
def _construct_prompt(
8091
exprs: tuple[TypedExpr, ...],
8192
prompt_context: tuple[str | None, ...],
@@ -94,7 +105,7 @@ def _construct_prompt(
94105
return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt))
95106

96107

97-
def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
108+
def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
98109
args = []
99110

100111
op_args = asdict(op)

packages/bigframes/bigframes/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
AIGenerateInt,
2323
AIIf,
2424
AIScore,
25+
AISimilarity,
2526
)
2627
from bigframes.operations.array_ops import (
2728
ArrayIndexOp,
@@ -436,6 +437,7 @@
436437
"AIGenerateInt",
437438
"AIIf",
438439
"AIScore",
440+
"AISimilarity",
439441
# Numpy ops mapping
440442
"NUMPY_TO_BINOP",
441443
"NUMPY_TO_OP",

packages/bigframes/bigframes/operations/ai_ops.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,16 @@ class AIScore(base_ops.NaryOp):
150150

151151
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
152152
return dtypes.FLOAT_DTYPE
153+
154+
155+
@dataclasses.dataclass(frozen=True)
class AISimilarity(base_ops.BinaryOp):
    """Binary op describing an AI similarity computation between two operands.

    The fields hold the model configuration; each one may be ``None`` when the
    corresponding option is not supplied.
    """

    name: ClassVar[str] = "ai_similarity"

    # Endpoint of the embedding model, if one is given.
    endpoint: str | None
    # Name of a built-in embedding model, if one is given.
    model: str | None
    # Extra model parameters, stored as a string (hashable, unlike a mapping).
    model_params: str | None
    # Connection used to reach the model, if one is given.
    connection_id: str | None

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Return the float dtype; the operand types do not affect the result."""
        return dtypes.FLOAT_DTYPE

packages/bigframes/tests/system/small/bigquery/test_ai.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,5 +370,53 @@ def test_forecast_w_params(time_series_df_default_index: dataframe.DataFrame):
370370
)
371371

372372

373+
def test_ai_similarity(session):
    """A BigFrames series compared with a pandas series gives a null-free float series."""
    bf_words = bpd.Series(["happy", "sad"], session=session)
    pd_words = pd.Series(["glad", "angry"])

    similarity_scores = bbq.ai.similarity(
        bf_words, pd_words, endpoint="text-embedding-005"
    )

    assert _contains_no_nulls(similarity_scores)
    assert similarity_scores.dtype == dtypes.FLOAT_DTYPE
381+
382+
383+
def test_ai_similarity_one_content_is_string_literal(session):
    """A string literal compared with a BigFrames series gives a null-free float series."""
    literal_word = "happy"
    bf_words = bpd.Series(["glad", "angry"], session=session)

    similarity_scores = bbq.ai.similarity(
        literal_word, bf_words, model="embeddinggemma-300m"
    )

    assert _contains_no_nulls(similarity_scores)
    assert similarity_scores.dtype == dtypes.FLOAT_DTYPE
391+
392+
393+
def test_ai_similarity_both_contents_are_string_literals(session):
    """Two string literals still produce a null-free float series."""
    first_word = "happy"
    second_word = "glad"

    similarity_scores = bbq.ai.similarity(
        first_word, second_word, endpoint="text-embedding-005"
    )

    assert _contains_no_nulls(similarity_scores)
    assert similarity_scores.dtype == dtypes.FLOAT_DTYPE
401+
402+
403+
def test_ai_similarity_no_endpoint_or_model__raises_error(session):
    """Omitting both `endpoint` and `model` is rejected with a ValueError."""
    left_words = bpd.Series(["happy", "sad"], session=session)
    right_words = bpd.Series(["glad", "angry"], session=session)

    with pytest.raises(ValueError):
        bbq.ai.similarity(left_words, right_words)
409+
410+
411+
def test_ai_similarity_both_endpoint_and_model__raises_error(session):
    """Passing `endpoint` together with `model` is rejected with a ValueError."""
    first_word = "happy"
    second_word = "glad"

    with pytest.raises(ValueError):
        bbq.ai.similarity(
            first_word,
            second_word,
            endpoint="text-embedding-005",
            model="embeddinggemma-300m",
        )
419+
420+
373421
def _contains_no_nulls(s: series.Series) -> bool:
    """Return True when the length of ``s`` equals ``s.count()``.

    NOTE(review): relies on ``count()`` skipping nulls, as in pandas; confirm.
    """
    total_rows = len(s)
    counted_rows = s.count()
    return total_rows == counted_rows
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.SIMILARITY(content1 => `string_col`, content2 => `string_col`, endpoint => 'text-embedding-005') AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
SELECT
2+
AI.SIMILARITY(
3+
content1 => `string_col`,
4+
content2 => `string_col`,
5+
endpoint => 'text-embedding-005',
6+
connection_id => 'bigframes-dev.us.bigframes-default-connection'
7+
) AS `result`
8+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.SIMILARITY(content1 => `string_col`, content2 => `string_col`, model => 'embeddinggemma-300m') AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

0 commit comments

Comments
 (0)