Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 16 additions & 36 deletions bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ def __init__(
# Now that we're starting the session, don't allow the options to be
# changed.
context._session_started = True
self._df_snapshot: Dict[bigquery.TableReference, datetime.datetime] = {}
self._df_snapshot: Dict[
bigquery.TableReference, Tuple[datetime.datetime, bigquery.Table]
] = {}

@property
def bqclient(self):
Expand Down Expand Up @@ -698,16 +700,25 @@ def _get_snapshot_sql_and_primary_key(
column(s), then return those too so that ordering generation can be
avoided.
"""
# If there are primary keys defined, the query engine assumes these
# columns are unique, even if the constraint is not enforced. We make
# the same assumption and use these columns as the total ordering keys.
table = self.bqclient.get_table(table_ref)
(
snapshot_timestamp,
table,
) = bigframes_io.get_snapshot_datetime_and_table_metadata(
self.bqclient,
table_ref=table_ref,
api_name=api_name,
cache=self._df_snapshot,
use_cache=use_cache,
)

if table.location.casefold() != self._location.casefold():
raise ValueError(
f"Current session is in {self._location} but dataset '{table.project}.{table.dataset_id}' is located in {table.location}"
)

# If there are primary keys defined, the query engine assumes these
# columns are unique, even if the constraint is not enforced. We make
# the same assumption and use these columns as the total ordering keys.
primary_keys = None
if (
(table_constraints := getattr(table, "table_constraints", None)) is not None
Expand All @@ -718,37 +729,6 @@ def _get_snapshot_sql_and_primary_key(
):
primary_keys = columns

job_config = bigquery.QueryJobConfig()
job_config.labels["bigframes-api"] = api_name
if use_cache and table_ref in self._df_snapshot.keys():
snapshot_timestamp = self._df_snapshot[table_ref]

# Cache hit could be unexpected. See internal issue 329545805.
# Raise a warning with more information about how to avoid the
# problems with the cache.
warnings.warn(
f"Reading cached table from {snapshot_timestamp} to avoid "
"incompatibilies with previous reads of this table. To read "
"the latest version, set `use_cache=False` or close the "
"current session with Session.close() or "
"bigframes.pandas.close_session().",
# There are many layers before we get to (possibly) the user's code:
# pandas.read_gbq_table
# -> with_default_session
# -> Session.read_gbq_table
# -> _read_gbq_table
# -> _get_snapshot_sql_and_primary_key
stacklevel=6,
)
else:
snapshot_timestamp = list(
self.bqclient.query(
"SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
job_config=job_config,
).result()
)[0][0]
self._df_snapshot[table_ref] = snapshot_timestamp

try:
table_expression = self.ibis_client.sql(
bigframes_io.create_snapshot_sql(table_ref, snapshot_timestamp)
Expand Down
54 changes: 54 additions & 0 deletions bigframes/session/_io/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import types
from typing import Dict, Iterable, Optional, Sequence, Tuple, Union
import uuid
import warnings

import google.api_core.exceptions
import google.cloud.bigquery as bigquery
Expand Down Expand Up @@ -121,6 +122,59 @@ def table_ref_to_sql(table: bigquery.TableReference) -> str:
return f"`{table.project}`.`{table.dataset_id}`.`{table.table_id}`"


def get_snapshot_datetime_and_table_metadata(
    bqclient: bigquery.Client,
    table_ref: bigquery.TableReference,
    *,
    api_name: str,
    cache: Dict[bigquery.TableReference, Tuple[datetime.datetime, bigquery.Table]],
    use_cache: bool = True,
) -> Tuple[datetime.datetime, bigquery.Table]:
    """Return the snapshot timestamp and table metadata for ``table_ref``.

    On the first read of a table (or when ``use_cache=False``), fetches the
    table metadata and BigQuery's ``CURRENT_TIMESTAMP()`` and stores both in
    ``cache``. Subsequent reads of the same table within the session reuse the
    cached pair so that repeated reads see a consistent snapshot.

    Args:
        bqclient: Client used to fetch table metadata and run the
            ``CURRENT_TIMESTAMP()`` query.
        table_ref: Reference to the table being read.
        api_name: Value for the ``bigframes-api`` job label, for telemetry.
        cache: Session-level mapping of table reference to
            ``(snapshot_timestamp, table_metadata)``. Mutated on cache miss.
        use_cache: If True (default), return the cached entry when present
            (emitting a warning); if False, always fetch fresh values and
            overwrite the cache entry.

    Returns:
        A ``(snapshot_timestamp, table_metadata)`` tuple.
    """
    cached_table = cache.get(table_ref)
    if use_cache and cached_table is not None:
        snapshot_timestamp, _ = cached_table

        # Cache hit could be unexpected. See internal issue 329545805.
        # Raise a warning with more information about how to avoid the
        # problems with the cache.
        warnings.warn(
            f"Reading cached table from {snapshot_timestamp} to avoid "
            "incompatibilities with previous reads of this table. To read "
            "the latest version, set `use_cache=False` or close the "
            "current session with Session.close() or "
            "bigframes.pandas.close_session().",
            # There are many layers before we get to (possibly) the user's code:
            # pandas.read_gbq_table
            # -> with_default_session
            # -> Session.read_gbq_table
            # -> _read_gbq_table
            # -> _get_snapshot_sql_and_primary_key
            # -> get_snapshot_datetime_and_table_metadata
            stacklevel=7,
        )
        return cached_table

    # TODO(swast): It's possible that the table metadata is changed between now
    # and when we run the CURRENT_TIMESTAMP() query to see when we can time
    # travel to. Find a way to fetch the table metadata and BQ's current time
    # atomically.
    table = bqclient.get_table(table_ref)

    # TODO(b/336521938): Refactor to make sure we set the "bigframes-api"
    # wherever we execute a query.
    job_config = bigquery.QueryJobConfig()
    job_config.labels["bigframes-api"] = api_name
    snapshot_timestamp = list(
        bqclient.query(
            "SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
            job_config=job_config,
        ).result()
    )[0][0]
    cached_table = (snapshot_timestamp, table)
    cache[table_ref] = cached_table
    return cached_table


def create_snapshot_sql(
table_ref: bigquery.TableReference, current_timestamp: datetime.datetime
) -> str:
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/session/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ def test_read_gbq_cached_table():
google.cloud.bigquery.DatasetReference("my-project", "my_dataset"),
"my_table",
)
session._df_snapshot[table_ref] = datetime.datetime(
1999, 1, 2, 3, 4, 5, 678901, tzinfo=datetime.timezone.utc
table = google.cloud.bigquery.Table(table_ref)
table._properties["location"] = session._location
session._df_snapshot[table_ref] = (
datetime.datetime(1999, 1, 2, 3, 4, 5, 678901, tzinfo=datetime.timezone.utc),
table,
)

with pytest.warns(UserWarning, match=re.escape("use_cache=False")):
Expand Down