Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/python/feast/infra/offline_stores/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,11 +434,15 @@ def offline_write_batch(
location=config.offline_store.location,
)

parquet_options = bigquery.ParquetOptions()
parquet_options.enable_list_inference = True

job_config = bigquery.LoadJobConfig(
source_format=bigquery.SourceFormat.PARQUET,
schema=arrow_schema_to_bq_schema(pa_schema),
create_disposition=config.offline_store.table_create_disposition,
write_disposition="WRITE_APPEND", # Default but included for clarity
parquet_options=parquet_options,
)

with tempfile.TemporaryFile() as parquet_temp_file:
Expand Down
61 changes: 61 additions & 0 deletions sdk/python/tests/unit/infra/offline_stores/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,64 @@ def test_table_property_unaffected_by_query_priority(self):
timestamp_field="ts",
)
assert source.table == "project.dataset.write_target"


class TestOfflineWriteBatch:
    """Unit tests for ``BigQueryOfflineStore.offline_write_batch``."""

    @patch("feast.infra.offline_stores.bigquery._get_bigquery_client")
    def test_offline_write_batch_enables_list_inference(self, mock_get_client):
        """Verify the load job is configured with Parquet list inference.

        The ``job_config`` handed to ``load_table_from_file`` must carry
        ``parquet_options`` with ``enable_list_inference`` set to ``True`` so
        that BigQuery correctly interprets PyArrow list columns from parquet.
        """
        from unittest.mock import MagicMock

        # A schema with a list-typed column, which is the case list inference
        # is meant to cover.
        schema_fields = [
            pyarrow.field("entity_id", pyarrow.string()),
            pyarrow.field("tags", pyarrow.list_(pyarrow.string())),
            pyarrow.field("ts", pyarrow.timestamp("us", tz="UTC")),
        ]
        pa_schema = pyarrow.schema(schema_fields)
        row_data = {
            "entity_id": ["e1"],
            "tags": [["a", "b"]],
            "ts": [datetime(2024, 1, 1, tzinfo=timezone.utc)],
        }
        pa_table = pyarrow.table(row_data, schema=pa_schema)

        # Feature view stand-in: only ``batch_source`` is set explicitly.
        batch_source = BigQuerySource(
            name="test",
            table="project.dataset.table",
            timestamp_field="ts",
        )
        feature_view = MagicMock(batch_source=batch_source)

        # Replace the BigQuery client so the load call can be inspected
        # afterwards instead of reaching a real service.
        bq_client = MagicMock()
        bq_client.load_table_from_file.return_value = MagicMock()
        mock_get_client.return_value = bq_client

        offline_store_config = BigQueryOfflineStoreConfig(project_id="test-project")
        repo_config = RepoConfig(
            registry="gs://test/registry.db",
            project="test",
            provider="gcp",
            offline_store=offline_store_config,
            online_store=SqliteOnlineStoreConfig(),
        )

        schema_lookup_target = (
            "feast.infra.offline_stores.offline_utils"
            ".get_pyarrow_schema_from_batch_source"
        )
        schema_lookup_result = (pa_schema, pa_table.column_names)
        with patch(schema_lookup_target, return_value=schema_lookup_result):
            BigQueryOfflineStore.offline_write_batch(
                config=repo_config,
                feature_view=feature_view,
                table=pa_table,
                progress=None,
            )

        # Inspect the keyword arguments of the most recent load call.
        load_kwargs = bq_client.load_table_from_file.call_args.kwargs
        submitted_job_config = load_kwargs["job_config"]
        assert submitted_job_config.parquet_options is not None
        assert submitted_job_config.parquet_options.enable_list_inference is True
Loading