54 changes: 39 additions & 15 deletions bigframes/bigquery/_operations/ai.py
@@ -19,16 +19,25 @@
from __future__ import annotations

import json
from typing import Any, List, Literal, Mapping, Tuple
from typing import Any, List, Literal, Mapping, Tuple, Union

from bigframes import clients, dtypes, series
from bigframes.core import log_adapter
import pandas as pd

from bigframes import clients, dtypes, series, session
from bigframes.core import convert, log_adapter
from bigframes.operations import ai_ops

PROMPT_TYPE = Union[
series.Series,
pd.Series,
List[Union[str, series.Series, pd.Series]],
Tuple[Union[str, series.Series, pd.Series], ...],
]


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_bool(
prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...],
prompt: PROMPT_TYPE,
*,
connection_id: str | None = None,
endpoint: str | None = None,
@@ -51,7 +60,7 @@ def generate_bool(
0 {'result': True, 'full_response': '{"candidate...
1 {'result': True, 'full_response': '{"candidate...
2 {'result': False, 'full_response': '{"candidat...
dtype: struct<result: bool, full_response: string, status: string>[pyarrow]
dtype: struct<result: bool, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]

>>> bbq.ai.generate_bool((df["col_1"], " is a ", df["col_2"])).struct.field("result")
0 True
@@ -60,8 +69,9 @@
Name: result, dtype: boolean

Args:
prompt (series.Series | List[str|series.Series] | Tuple[str|series.Series, ...]):
A mixture of Series and string literals that specifies the prompt to send to the model.
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
or pandas Series.
connection_id (str, optional):
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
If not provided, the connection from the current session will be used.
@@ -84,7 +94,7 @@ def generate_bool(
Returns:
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
* "result": a BOOL value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
* "full_response": a STRING value containing the JSON response from the projects.locations.endpoints.generateContent call to the model.
* "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
The generated text is in the text element.
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
"""
@@ -104,7 +114,7 @@ def generate_bool(


def _separate_context_and_series(
prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...],
prompt: PROMPT_TYPE,
) -> Tuple[List[str | None], List[series.Series]]:
"""
Returns the two values. The first value is the prompt with all series replaced by None. The second value is all the series
@@ -123,18 +133,19 @@ def _separate_context_and_series(
return [None], [prompt]

prompt_context: List[str | None] = []
series_list: List[series.Series] = []
series_list: List[series.Series | pd.Series] = []

session = None
for item in prompt:
if isinstance(item, str):
prompt_context.append(item)

elif isinstance(item, series.Series):
elif isinstance(item, (series.Series, pd.Series)):
prompt_context.append(None)

if item.dtype == dtypes.OBJ_REF_DTYPE:
# Multi-model support
item = item.blob.read_url()
@chelsea-lin (Contributor) commented on Sep 16, 2025:

Can the pd.Series be converted into a series.Series representation for the multi-model case? I would suggest having two if branches, one for pd.Series and another for series.Series, for readability.

The PR author (Contributor Author) replied:

The Series conversion happens at the bottom of the function body, at line 157. This if branch is just there to grab the session from the first BigFrames Series.

if isinstance(item, series.Series) and session is None:
# Use the first available BF session if there's any.
session = item._session
series_list.append(item)

else:
@@ -143,7 +154,20 @@
if not series_list:
raise ValueError("Please provide at least one Series in the prompt")

return prompt_context, series_list
converted_list = [_convert_series(s, session) for s in series_list]

return prompt_context, converted_list


def _convert_series(
s: series.Series | pd.Series, session: session.Session | None
) -> series.Series:
result = convert.to_bf_series(s, default_index=None, session=session)
A Contributor commented:

When two series.Series have two different sessions, should we throw an error here?

@sycai (Contributor Author) replied on Sep 16, 2025:

I would like to leave that check to the place where we "align" series:

    values, block = self._align_n(
        others, ignore_self=ignore_self, cast_scalars=False
    )


if result.dtype == dtypes.OBJ_REF_DTYPE:
# Support multimodel
return result.blob.read_url()
return result
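
A hedged aside on the conversion path (a sketch; the fallback behavior of convert.to_bf_series when no session is captured is an assumption, not something this diff shows): if the prompt contains only pandas Series, session stays None and the default session is expected to be used.

# Sketch: an all-pandas prompt; the string literal stays in the context list
# and both pandas Series come back as BigFrames Series.
context, converted = _separate_context_and_series(
    (pd.Series(["apple"]), " is a ", pd.Series(["fruit"]))
)
# context == [None, " is a ", None]; converted holds two BigFrames Series.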


def _resolve_connection_id(series: series.Series, connection_id: str | None):
2 changes: 1 addition & 1 deletion bigframes/operations/ai_ops.py
@@ -40,7 +40,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
pa.struct(
(
pa.field("result", pa.bool_()),
pa.field("full_response", pa.string()),
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
pa.field("status", pa.string()),
)
)
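
A short illustration of what this dtype change means for callers (a sketch; df is assumed to be any DataFrame with two string columns, and the accessor mirrors the struct.field usage in the generate_bool docstring above):

result = bbq.ai.generate_bool((df["col_1"], " is a ", df["col_2"]))
# full_response now carries the JSON Arrow extension type
# (dtypes.JSON_ARROW_TYPE) instead of pa.string().
full = result.struct.field("full_response")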
27 changes: 23 additions & 4 deletions tests/system/small/bigquery/test_ai.py
@@ -18,7 +18,7 @@
import pyarrow as pa
import pytest

from bigframes import series
from bigframes import dtypes, series
import bigframes.bigquery as bbq
import bigframes.pandas as bpd

@@ -35,7 +35,26 @@ def test_ai_generate_bool(session):
pa.struct(
(
pa.field("result", pa.bool_()),
pa.field("full_response", pa.string()),
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
pa.field("status", pa.string()),
)
)
)


def test_ai_generate_bool_with_pandas(session):
s1 = pd.Series(["apple", "bear"])
s2 = bpd.Series(["fruit", "tree"], session=session)
prompt = (s1, " is a ", s2)

result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash")

assert _contains_no_nulls(result)
assert result.dtype == pd.ArrowDtype(
pa.struct(
(
pa.field("result", pa.bool_()),
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
pa.field("status", pa.string()),
)
)
@@ -62,7 +81,7 @@ def test_ai_generate_bool_with_model_params(session):
pa.struct(
(
pa.field("result", pa.bool_()),
pa.field("full_response", pa.string()),
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
pa.field("status", pa.string()),
)
)
@@ -81,7 +100,7 @@ def test_ai_generate_bool_multi_model(session):
pa.struct(
(
pa.field("result", pa.bool_()),
pa.field("full_response", pa.string()),
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
pa.field("status", pa.string()),
)
)