60 changes: 44 additions & 16 deletions sdk/python/feast/infra/offline_stores/bigquery.py
@@ -19,7 +19,7 @@
 import pandas as pd
 import pyarrow
 import pyarrow.parquet
-from pydantic import StrictStr
+from pydantic import StrictStr, validator
 from pydantic.typing import Literal
 from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed
 
@@ -83,7 +83,8 @@ class BigQueryOfflineStoreConfig(FeastConfigBaseModel):
 
     project_id: Optional[StrictStr] = None
     """ (optional) GCP project name used for the BigQuery offline store """
-
+    billing_project_id: Optional[StrictStr] = None
+    """ (optional) GCP project name used to run the BigQuery jobs in """
     location: Optional[StrictStr] = None
     """ (optional) GCP location name used for the BigQuery offline store.
     Examples of location names include ``US``, ``EU``, ``us-central1``, ``us-west4``.
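The intended split is that `project_id` keeps owning the datasets and tables while `billing_project_id`, when set, only determines where query jobs run. A minimal sketch of constructing the config with both fields (project and dataset names here are hypothetical):

    from feast.infra.offline_stores.bigquery import BigQueryOfflineStoreConfig

    # Jobs are billed to billing_project_id; datasets/tables stay in project_id.
    config = BigQueryOfflineStoreConfig(
        project_id="analytics-data-proj",        # hypothetical data project
        billing_project_id="analytics-billing",  # hypothetical billing project
        dataset="feast_staging",
        location="US",
    )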
@@ -94,6 +95,14 @@ class BigQueryOfflineStoreConfig(FeastConfigBaseModel):
     gcs_staging_location: Optional[str] = None
     """ (optional) GCS location used for offloading BigQuery results as parquet files."""
 
+    @validator("billing_project_id")
+    def project_id_exists(cls, v, values, **kwargs):
+        if v and not values["project_id"]:
+            raise ValueError(
+                "please specify project_id if billing_project_id is specified"
+            )
+        return v
+
 
 class BigQueryOfflineStore(OfflineStore):
     @staticmethod
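Under pydantic v1 semantics, `values` contains only the fields declared before the one being validated, so `project_id` is available when `billing_project_id` is checked. A self-contained sketch of the guard's behavior, using a stand-in model rather than Feast's actual class:

    from typing import Optional

    from pydantic import BaseModel, StrictStr, ValidationError, validator

    class _Cfg(BaseModel):  # hypothetical stand-in for BigQueryOfflineStoreConfig
        project_id: Optional[StrictStr] = None
        billing_project_id: Optional[StrictStr] = None

        @validator("billing_project_id")
        def project_id_exists(cls, v, values, **kwargs):
            if v and not values["project_id"]:
                raise ValueError("please specify project_id if billing_project_id is specified")
            return v

    try:
        _Cfg(billing_project_id="analytics-billing")  # no project_id -> rejected
    except ValidationError as err:
        print(err)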
@@ -122,9 +131,11 @@ def pull_latest_from_table_or_query(
             timestamps.append(created_timestamp_column)
         timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
         field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)
-
+        project_id = (
+            config.offline_store.billing_project_id or config.offline_store.project_id
+        )
         client = _get_bigquery_client(
-            project=config.offline_store.project_id,
+            project=project_id,
             location=config.offline_store.location,
         )
         query = f"""
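The same two-line fallback recurs in every method below: the client is created in the billing project when one is configured, otherwise in the data project. A hedged sketch of the resulting client construction with the google-cloud-bigquery API (project names hypothetical):

    from google.cloud import bigquery

    billing_project_id = "analytics-billing"  # None when not configured
    project_id = "analytics-data-proj"

    # `or` falls back to the data project when billing_project_id is unset.
    client = bigquery.Client(project=billing_project_id or project_id, location="US")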
@@ -162,9 +173,11 @@ def pull_all_from_table_or_query(
         assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
         assert isinstance(data_source, BigQuerySource)
         from_expression = data_source.get_table_query_string()
-
+        project_id = (
+            config.offline_store.billing_project_id or config.offline_store.project_id
+        )
         client = _get_bigquery_client(
-            project=config.offline_store.project_id,
+            project=project_id,
             location=config.offline_store.location,
         )
         field_string = ", ".join(
@@ -197,17 +210,22 @@ def get_historical_features(
         assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
         for fv in feature_views:
             assert isinstance(fv.batch_source, BigQuerySource)
-
+        project_id = (
+            config.offline_store.billing_project_id or config.offline_store.project_id
+        )
         client = _get_bigquery_client(
-            project=config.offline_store.project_id,
+            project=project_id,
             location=config.offline_store.location,
         )
-
         assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
 
+        if config.offline_store.billing_project_id:
+            dataset_project = str(config.offline_store.project_id)
+        else:
+            dataset_project = client.project
         table_reference = _get_table_reference_for_new_entity(
             client,
-            client.project,
+            dataset_project,
             config.offline_store.dataset,
             config.offline_store.location,
         )
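The net effect: jobs are created (and billed) under the client's project, while the staging table for the entity dataframe is addressed by its fully qualified name in the data project. A hedged illustration (dataset and table names hypothetical):

    from google.cloud import bigquery

    client = bigquery.Client(project="analytics-billing")  # query jobs billed here
    dataset_project = "analytics-data-proj"                # staging table lives here

    table_reference = f"{dataset_project}.feast_staging.entity_df_1a2b3c4"
    client.query(f"SELECT COUNT(*) FROM `{table_reference}`")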
@@ -295,9 +313,11 @@ def write_logged_features(
     ):
         destination = logging_config.destination
         assert isinstance(destination, BigQueryLoggingDestination)
-
+        project_id = (
+            config.offline_store.billing_project_id or config.offline_store.project_id
+        )
         client = _get_bigquery_client(
-            project=config.offline_store.project_id,
+            project=project_id,
             location=config.offline_store.location,
         )
 
@@ -353,9 +373,11 @@ def offline_write_batch(
 
         if table.schema != pa_schema:
             table = table.cast(pa_schema)
-
+        project_id = (
+            config.offline_store.billing_project_id or config.offline_store.project_id
+        )
         client = _get_bigquery_client(
-            project=config.offline_store.project_id,
+            project=project_id,
             location=config.offline_store.location,
         )
 
@@ -451,7 +473,10 @@ def to_bigquery(
         if not job_config:
             today = date.today().strftime("%Y%m%d")
             rand_id = str(uuid.uuid4())[:7]
-            path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
+            if self.config.offline_store.billing_project_id:
+                path = f"{self.config.offline_store.project_id}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
+            else:
+                path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
             job_config = bigquery.QueryJobConfig(destination=path)
 
         if not job_config.dry_run and self.on_demand_feature_views:
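Because `self.client.project` now resolves to the billing project when `billing_project_id` is set, the destination path has to be built from `project_id` explicitly; otherwise the results table would land in the billing project. For example, with assumed names:

    from google.cloud import bigquery

    # Destination stays in the data project even though the job bills elsewhere.
    path = "analytics-data-proj.feast_staging.historical_20240101_ab12cd3"
    job_config = bigquery.QueryJobConfig(destination=path)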
@@ -525,7 +550,10 @@ def to_remote_storage(self) -> List[str]:
 
         bucket: str
         prefix: str
-        storage_client = StorageClient(project=self.client.project)
+        if self.config.offline_store.billing_project_id:
+            storage_client = StorageClient(project=self.config.offline_store.project_id)
+        else:
+            storage_client = StorageClient(project=self.client.project)
         bucket, prefix = self._gcs_path[len("gs://") :].split("/", 1)
         prefix = prefix.rsplit("/", 1)[0]
         if prefix.startswith("/"):
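The same correction applies to GCS staging: the bucket holding offloaded parquet files belongs to the data project, so the storage client should not default to `self.client.project` when that is the billing project. A minimal sketch with the google-cloud-storage API (bucket path hypothetical):

    from google.cloud import storage

    storage_client = storage.Client(project="analytics-data-proj")
    gcs_path = "gs://feast-staging-bucket/offloaded/results"
    bucket, prefix = gcs_path[len("gs://"):].split("/", 1)
    blobs = storage_client.list_blobs(bucket, prefix=prefix)  # exported files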