75 changes: 75 additions & 0 deletions bigframes/bigquery/_operations/ai.py
@@ -188,6 +188,81 @@ def generate_int(
return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_double(
prompt: PROMPT_TYPE,
*,
connection_id: str | None = None,
endpoint: str | None = None,
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
model_params: Mapping[Any, Any] | None = None,
) -> series.Series:
"""
Returns the AI analysis as a double-precision (FLOAT64) value, based on a prompt that can be any combination of text and unstructured data.

**Examples:**

>>> import bigframes.pandas as bpd
>>> import bigframes.bigquery as bbq
>>> bpd.options.display.progress_bar = None
>>> animal = bpd.Series(["Kangaroo", "Rabbit", "Spider"])
>>> bbq.ai.generate_double(("How many legs does a ", animal, " have?"))
0 {'result': 2.0, 'full_response': '{"candidates...
1 {'result': 4.0, 'full_response': '{"candidates...
2 {'result': 8.0, 'full_response': '{"candidates...
dtype: struct<result: double, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]

>>> bbq.ai.generate_double(("How many legs does a ", animal, " have?")).struct.field("result")
0 2.0
1 4.0
2 8.0
Name: result, dtype: Float64

Args:
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
A mixture of Series and string literals that specifies the prompt to send to the model. Each Series can be a BigFrames Series
or a pandas Series.
connection_id (str, optional):
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
If not provided, the connection from the current session will be used.
endpoint (str, optional):
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML selects a recent stable
version of Gemini to use.
request_type (Literal["dedicated", "shared", "unspecified"]):
Specifies the type of inference request to send to the Gemini model. The request type determines what quota the request uses.
* "dedicated": function only uses Provisioned Throughput quota. The function returns the error Provisioned throughput is not
purchased or is not active if Provisioned Throughput quota isn't available.
* "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota.
* "unspecified": If you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first.
If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
model_params (Mapping[Any, Any], optional):
Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.

Returns:
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
* "result": an DOUBLE value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
* "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
The generated text is in the text element.
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
"""

prompt_context, series_list = _separate_context_and_series(prompt)
assert len(series_list) > 0

operator = ai_ops.AIGenerateDouble(
prompt_context=tuple(prompt_context),
connection_id=_resolve_connection_id(series_list[0], connection_id),
endpoint=endpoint,
request_type=request_type,
model_params=json.dumps(model_params) if model_params else None,
)

return series_list[0]._apply_nary_op(operator, series_list[1:])


def _separate_context_and_series(
prompt: PROMPT_TYPE,
) -> Tuple[List[str | None], List[series.Series]]:
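For reference, a minimal usage sketch of the new API above, assuming a session with a default connection. The `model_params` mapping must follow the Vertex AI generateContent request body schema; `generationConfig.temperature` here is one illustrative key from that schema, not something this change defines:

import bigframes.pandas as bpd
import bigframes.bigquery as bbq

animal = bpd.Series(["Kangaroo", "Rabbit", "Spider"])

# model_params is serialized with json.dumps() by generate_double, so any
# JSON-serializable mapping in generateContent request-body format works.
result = bbq.ai.generate_double(
    ("How many legs does a ", animal, " have?"),
    endpoint="gemini-2.5-flash",
    request_type="shared",
    model_params={"generationConfig": {"temperature": 0.0}},
)

# Extract the numeric answers from the struct Series.
legs = result.struct.field("result")
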
16 changes: 15 additions & 1 deletion bigframes/core/compile/ibis_compiler/scalar_op_registry.py
@@ -1986,7 +1986,7 @@ def ai_generate_bool(

@scalar_op_compiler.register_nary_op(ops.AIGenerateInt, pass_op=True)
def ai_generate_int(
*values: ibis_types.Value, op: ops.AIGenerateBool
*values: ibis_types.Value, op: ops.AIGenerateInt
) -> ibis_types.StructValue:

return ai_ops.AIGenerateInt(
@@ -1998,6 +1998,20 @@ def ai_generate_int(
).to_expr()


@scalar_op_compiler.register_nary_op(ops.AIGenerateDouble, pass_op=True)
def ai_generate_double(
*values: ibis_types.Value, op: ops.AIGenerateDouble
) -> ibis_types.StructValue:

return ai_ops.AIGenerateDouble(
_construct_prompt(values, op.prompt_context), # type: ignore
op.connection_id, # type: ignore
op.endpoint, # type: ignore
op.request_type.upper(), # type: ignore
op.model_params, # type: ignore
).to_expr()


def _construct_prompt(
col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None]
) -> ibis_types.StructValue:
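As context for the registration above: `prompt_context` stores the literal fragments of the prompt, with `None` marking each position where a Series value is spliced in. A pure-Python sketch of that interleaving, as a simplification for illustration (the real `_construct_prompt` builds an Ibis struct expression from column references):

def interleave_prompt(prompt_context, col_refs):
    # Replace each None placeholder with the next column reference,
    # preserving the left-to-right order of the original prompt.
    refs = iter(col_refs)
    return tuple(next(refs) if part is None else part for part in prompt_context)

# (None, " is the same as ", None) with refs ["col_a", "col_b"]
# yields ("col_a", " is the same as ", "col_b").
print(interleave_prompt((None, " is the same as ", None), ["col_a", "col_b"]))
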
7 changes: 7 additions & 0 deletions bigframes/core/compile/sqlglot/expressions/ai_ops.py
@@ -40,6 +40,13 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateInt) -> sge.Expression:
return sge.func("AI.GENERATE_INT", *args)


@register_nary_op(ops.AIGenerateDouble, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression:
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)

return sge.func("AI.GENERATE_DOUBLE", *args)


def _construct_prompt(
exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...]
) -> sge.Kwarg:
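A standalone sketch of how the `sge.func` and `sge.Kwarg` pieces used above render in the BigQuery dialect, with hand-built placeholder identifiers and literals (the printed SQL is approximate):

import sqlglot.expressions as sge

prompt = sge.Kwarg(
    this=sge.to_identifier("prompt"),
    expression=sge.Tuple(
        expressions=[
            sge.column("bfcol_0"),
            sge.convert(" is the same as "),
            sge.column("bfcol_0"),
        ]
    ),
)
connection = sge.Kwarg(
    this=sge.to_identifier("connection_id"),
    expression=sge.convert("test_connection_id"),
)

sql = sge.func("AI.GENERATE_DOUBLE", prompt, connection).sql(dialect="bigquery")
print(sql)
# Roughly: AI.GENERATE_DOUBLE(prompt => (bfcol_0, ' is the same as ', bfcol_0),
#   connection_id => 'test_connection_id')
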
3 changes: 2 additions & 1 deletion bigframes/operations/__init__.py
@@ -14,7 +14,7 @@

from __future__ import annotations

from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateInt
from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateDouble, AIGenerateInt
from bigframes.operations.array_ops import (
ArrayIndexOp,
ArrayReduceOp,
@@ -413,6 +413,7 @@
"GeoStDistanceOp",
# AI ops
"AIGenerateBool",
"AIGenerateDouble",
"AIGenerateInt",
# Numpy ops mapping
"NUMPY_TO_BINOP",
22 changes: 22 additions & 0 deletions bigframes/operations/ai_ops.py
@@ -66,3 +66,25 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
)
)
)


@dataclasses.dataclass(frozen=True)
class AIGenerateDouble(base_ops.NaryOp):
name: ClassVar[str] = "ai_generate_double"

prompt_context: Tuple[str | None, ...]
connection_id: str
endpoint: str | None
request_type: Literal["dedicated", "shared", "unspecified"]
model_params: str | None

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
return pd.ArrowDtype(
pa.struct(
(
pa.field("result", pa.float64()),
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
pa.field("status", pa.string()),
)
)
)
39 changes: 39 additions & 0 deletions tests/system/small/bigquery/test_ai.py
@@ -146,5 +146,44 @@ def test_ai_generate_int_multi_model(session):
)


def test_ai_generate_double(session):
s = bpd.Series(["Cat"], session=session)
prompt = ("How many legs does a ", s, " have?")

result = bbq.ai.generate_double(prompt, endpoint="gemini-2.5-flash")

assert _contains_no_nulls(result)
assert result.dtype == pd.ArrowDtype(
pa.struct(
(
pa.field("result", pa.float64()),
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
pa.field("status", pa.string()),
)
)
)


def test_ai_generate_double_multi_model(session):
df = session.from_glob_path(
"gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
)

result = bbq.ai.generate_double(
("How many animals are there in the picture ", df["image"])
)

assert _contains_no_nulls(result)
assert result.dtype == pd.ArrowDtype(
pa.struct(
(
pa.field("result", pa.float64()),
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
pa.field("status", pa.string()),
)
)
)


def _contains_no_nulls(s: series.Series) -> bool:
return len(s) == s.count()
@@ -0,0 +1,18 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
AI.GENERATE_DOUBLE(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
connection_id => 'test_connection_id',
endpoint => 'gemini-2.5-flash',
request_type => 'SHARED'
) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `result`
FROM `bfcte_1`
@@ -0,0 +1,18 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
AI.GENERATE_DOUBLE(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
connection_id => 'test_connection_id',
request_type => 'SHARED',
model_params => JSON '{}'
) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `result`
FROM `bfcte_1`
45 changes: 45 additions & 0 deletions tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
@@ -111,3 +111,48 @@ def test_ai_generate_int_with_model_param(
)

snapshot.assert_match(sql, "out.sql")


def test_ai_generate_double(scalar_types_df: dataframe.DataFrame, snapshot):
col_name = "string_col"

op = ops.AIGenerateDouble(
# The prompt does not make semantic sense but we only care about syntax correctness.
prompt_context=(None, " is the same as ", None),
connection_id="test_connection_id",
endpoint="gemini-2.5-flash",
request_type="shared",
model_params=None,
)

sql = utils._apply_unary_ops(
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
)

snapshot.assert_match(sql, "out.sql")


def test_ai_generate_double_with_model_param(
scalar_types_df: dataframe.DataFrame, snapshot
):
if version.Version(sqlglot.__version__) < version.Version("25.18.0"):
pytest.skip(
"Skip test because SQLGLot cannot compile model params to JSON at this version."
)

col_name = "string_col"

op = ops.AIGenerateDouble(
# The prompt does not make semantic sense but we only care about syntax correctness.
prompt_context=(None, " is the same as ", None),
connection_id="test_connection_id",
endpoint=None,
request_type="shared",
model_params=json.dumps(dict()),
)

sql = utils._apply_unary_ops(
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
)

snapshot.assert_match(sql, "out.sql")
@@ -1110,6 +1110,9 @@ def visit_AIGenerateBool(self, op, **kwargs):
def visit_AIGenerateInt(self, op, **kwargs):
return sge.func("AI.GENERATE_INT", *self._compile_ai_args(**kwargs))

def visit_AIGenerateDouble(self, op, **kwargs):
return sge.func("AI.GENERATE_DOUBLE", *self._compile_ai_args(**kwargs))

def _compile_ai_args(self, **kwargs):
args = []

23 changes: 23 additions & 0 deletions third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py
@@ -49,3 +49,26 @@ def dtype(self) -> dt.Struct:
return dt.Struct.from_tuples(
(("result", dt.int64), ("full_resposne", dt.string), ("status", dt.string))
)


@public
class AIGenerateDouble(Value):
"""Generate integers based on the prompt"""

prompt: Value
connection_id: Value[dt.String]
endpoint: Optional[Value[dt.String]]
request_type: Value[dt.String]
model_params: Optional[Value[dt.String]]

shape = rlz.shape_like("prompt")

@attribute
def dtype(self) -> dt.Struct:
return dt.Struct.from_tuples(
(
("result", dt.float64),
("full_resposne", dt.string),
("status", dt.string),
)
)