Skip to content

Commit f3cb4ad

Browse files
feat(bigframes): update ai.if_() params to match the SQL version (#16857)
Reference: https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-if

I left out the `embeddings` param to keep things simple. It will be introduced later.

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 8fa0f81 commit f3cb4ad

10 files changed

Lines changed: 71 additions & 10 deletions

File tree

packages/bigframes/bigframes/bigquery/_operations/ai.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,9 @@ def if_(
817817
prompt: PROMPT_TYPE,
818818
*,
819819
connection_id: str | None = None,
820+
endpoint: str | None = None,
821+
optimization_mode: Literal["minimize_cost", "maximize_quality"] = "minimize_cost",
822+
max_error_ratio: float | None = None,
820823
) -> series.Series:
821824
"""
822825
Evaluates the prompt to True or False. Compared to `ai.generate_bool()`, this function
@@ -838,20 +841,26 @@ def if_(
838841
1 Illinois
839842
dtype: string
840843
841-
.. note::
842-
843-
This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
844-
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
845-
and might have limited support. For more information, see the launch stage descriptions
846-
(https://cloud.google.com/products#product-launch-stages).
847-
848844
Args:
849845
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
850846
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
851847
or pandas Series.
852848
connection_id (str, optional):
853849
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
854850
If not provided, the query uses your end-user credential.
851+
endpoint (str, optional):
852+
Specifies the Vertex AI endpoint to use for the model. For example, `"gemini-2.5-flash"`. You can specify any
853+
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
854+
uses the full endpoint of the model. If you don't specify an `endpoint` value, BigQuery ML dynamically chooses a model based on your query to have the
855+
best cost to quality tradeoff for the task.
856+
optimization_mode (Literal["minimize_cost", "maximize_quality"], optional):
857+
Specifies the optimization strategy to use. Supported values are:
858+
* "minimize_cost" (default): uses a local, distilled model to process the majority of rows, reducing latency and cost.
859+
* "maximize_quality": always uses the remote LLM for inference.
860+
max_error_ratio (float, optional):
861+
A float value between 0.0 and 1.0 that contains the maximum acceptable ratio of row-level inference failures to
862+
rows processed on this function. If this value is exceeded, then the query fails. The default value is 1.0.
863+
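This argument isn't supported when `optimization_mode` is set to "minimize_cost". Because "minimize_cost" is the default, pass `optimization_mode="maximize_quality"` whenever you set this argument.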
855864
856865
Returns:
857866
bigframes.series.Series: A new series of bools.
@@ -863,6 +872,9 @@ def if_(
863872
operator = ai_ops.AIIf(
864873
prompt_context=tuple(prompt_context),
865874
connection_id=connection_id,
875+
endpoint=endpoint,
876+
optimization_mode=optimization_mode,
877+
max_error_ratio=max_error_ratio,
866878
)
867879

868880
return series_list[0]._apply_nary_op(operator, series_list[1:])

packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,6 +1983,9 @@ def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue:
19831983
return ai_ops.AIIf(
19841984
_construct_prompt(values, op.prompt_context), # type: ignore
19851985
op.connection_id, # type: ignore
1986+
op.endpoint, # type: ignore
1987+
op.optimization_mode.upper() if op.optimization_mode is not None else None, # type: ignore
1988+
op.max_error_ratio, # type: ignore
19861989
).to_expr()
19871990

19881991

packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
139139
expression=sge.JSON(this=sge.Literal.string(value)),
140140
)
141141
)
142+
elif field == "optimization_mode":
143+
args.append(
144+
sge.Kwarg(this=field, expression=sge.Literal.string(value.upper()))
145+
)
146+
elif field == "max_error_ratio":
147+
args.append(sge.Kwarg(this=field, expression=sge.Literal.number(value)))
142148
elif field == "request_type":
143149
args.append(
144150
sge.Kwarg(this=field, expression=sge.Literal.string(value.upper()))

packages/bigframes/bigframes/operations/ai_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ class AIIf(base_ops.NaryOp):
146146

147147
prompt_context: Tuple[str | None, ...]
148148
connection_id: str | None
149+
endpoint: str | None = None
150+
optimization_mode: str | None = None
151+
max_error_ratio: float | None = None
149152

150153
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
151154
return dtypes.BOOL_DTYPE

packages/bigframes/tests/system/small/bigquery/test_ai.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,11 @@ def test_ai_if(session):
323323
s2 = bpd.Series(["fruit", "tree"], session=session)
324324
prompt = (s1, " is a ", s2)
325325

326-
result = bbq.ai.if_(prompt)
326+
result = bbq.ai.if_(
327+
prompt,
328+
optimization_mode="maximize_quality",
329+
max_error_ratio=0.5,
330+
)
327331

328332
assert _contains_no_nulls(result)
329333
assert result.dtype == dtypes.BOOL_DTYPE
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
SELECT
2-
AI.IF(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result`
2+
AI.IF(
3+
prompt => (`string_col`, ' is the same as ', `string_col`),
4+
optimization_mode => 'MINIMIZE_COST',
5+
max_error_ratio => 0.5
6+
) AS `result`
37
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
SELECT
22
AI.IF(
33
prompt => (`string_col`, ' is the same as ', `string_col`),
4-
connection_id => 'bigframes-dev.us.bigframes-default-connection'
4+
connection_id => 'bigframes-dev.us.bigframes-default-connection',
5+
optimization_mode => 'MINIMIZE_COST',
6+
max_error_ratio => 0.5
57
) AS `result`
68
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
SELECT
2+
AI.IF(
3+
prompt => (`string_col`, ' is the same as ', `string_col`),
4+
endpoint => 'gemini-2.5-flash'
5+
) AS `result`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,24 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot, connection_id):
358358
op = ops.AIIf(
359359
prompt_context=(None, " is the same as ", None),
360360
connection_id=connection_id,
361+
optimization_mode="minimize_cost",
362+
max_error_ratio=0.5,
363+
)
364+
365+
sql = utils._apply_ops_to_sql(
366+
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
367+
)
368+
369+
snapshot.assert_match(sql, "out.sql")
370+
371+
372+
def test_ai_if_with_endpoint(scalar_types_df: dataframe.DataFrame, snapshot):
373+
col_name = "string_col"
374+
375+
op = ops.AIIf(
376+
prompt_context=(None, " is the same as ", None),
377+
connection_id=None,
378+
endpoint="gemini-2.5-flash",
361379
)
362380

363381
sql = utils._apply_ops_to_sql(

packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ class AIIf(Value):
138138

139139
prompt: Value
140140
connection_id: Optional[Value[dt.String]]
141+
endpoint: Optional[Value[dt.String]] = None
142+
optimization_mode: Optional[Value[dt.String]] = None
143+
max_error_ratio: Optional[Value[dt.Float64]] = None
141144

142145
shape = rlz.shape_like("prompt")
143146

0 commit comments

Comments
 (0)