50 changes: 35 additions & 15 deletions bigframes/session/__init__.py
@@ -54,6 +54,7 @@
 import google.api_core.gapic_v1.client_info
 import google.auth.credentials
 import google.cloud.bigquery as bigquery
+import google.cloud.bigquery.table
 import google.cloud.bigquery_connection_v1
 import google.cloud.bigquery_storage_v1
 import google.cloud.functions_v2
@@ -693,7 +694,7 @@ def read_gbq_table(
 
     def _get_snapshot_sql_and_primary_key(
         self,
-        table_ref: bigquery.table.TableReference,
+        table: google.cloud.bigquery.table.Table,
         *,
         api_name: str,
         use_cache: bool = True,
@@ -709,7 +710,7 @@ def _get_snapshot_sql_and_primary_key(
             table,
         ) = bigframes_io.get_snapshot_datetime_and_table_metadata(
             self.bqclient,
-            table_ref=table_ref,
+            table_ref=table.reference,
             api_name=api_name,
             cache=self._df_snapshot,
             use_cache=use_cache,
@@ -735,7 +736,7 @@ def _get_snapshot_sql_and_primary_key(
 
         try:
             table_expression = self.ibis_client.sql(
-                bigframes_io.create_snapshot_sql(table_ref, snapshot_timestamp)
+                bigframes_io.create_snapshot_sql(table.reference, snapshot_timestamp)
             )
         except google.api_core.exceptions.Forbidden as ex:
            if "Drive credentials" in ex.message:
@@ -763,8 +764,9 @@ def _read_gbq_table(
             query, default_project=self.bqclient.project
         )
 
+        table = self.bqclient.get_table(table_ref)
         (table_expression, primary_keys,) = self._get_snapshot_sql_and_primary_key(
-            table_ref, api_name=api_name, use_cache=use_cache
+            table, api_name=api_name, use_cache=use_cache
         )
         total_ordering_cols = primary_keys
 
@@ -836,9 +838,13 @@ def _read_gbq_table(
                     ordering=ordering,
                 )
             else:
-                array_value = self._create_total_ordering(table_expression)
+                array_value = self._create_total_ordering(
+                    table_expression, table_rows=table.num_rows
+                )
         else:
-            array_value = self._create_total_ordering(table_expression)
+            array_value = self._create_total_ordering(
+                table_expression, table_rows=table.num_rows
+            )
 
         value_columns = [col for col in array_value.column_ids if col not in index_cols]
         block = blocks.Block(
@@ -1459,10 +1465,19 @@ def _create_empty_temp_table(
     def _create_total_ordering(
         self,
         table: ibis_types.Table,
+        table_rows: Optional[int],
     ) -> core.ArrayValue:
         # Since this might also be used as the index, don't use the default
         # "ordering ID" name.
 
+        # For small tables, a 64-bit hash is enough to avoid collisions; a
+        # 128-bit hash will essentially never collide. Assume the table is
+        # large if its row count is unknown.
+        use_double_hash = (
+            (table_rows is None) or (table_rows == 0) or (table_rows > 100000)
+        )
+
         ordering_hash_part = guid.generate_guid("bigframes_ordering_")
+        ordering_hash_part2 = guid.generate_guid("bigframes_ordering_")
         ordering_rand_part = guid.generate_guid("bigframes_ordering_")
 
         # All inputs into hash must be non-null or resulting hash will be null
@@ -1475,25 +1490,30 @@ def _create_total_ordering(
             else str_values[0]
         )
         full_row_hash = full_row_str.hash().name(ordering_hash_part)
+        # By modifying the value slightly, we get a second hash uncorrelated with the first
+        full_row_hash_p2 = (full_row_str + "_").hash().name(ordering_hash_part2)
         # Used to disambiguate between identical rows (which will have identical hash)
         random_value = ibis.random().name(ordering_rand_part)
 
+        order_values = (
+            [full_row_hash, full_row_hash_p2, random_value]
+            if use_double_hash
+            else [full_row_hash, random_value]
+        )
+
         original_column_ids = table.columns
         table_with_ordering = table.select(
-            itertools.chain(original_column_ids, [full_row_hash, random_value])
+            itertools.chain(original_column_ids, order_values)
         )
 
-        ordering_ref1 = order.ascending_over(ordering_hash_part)
-        ordering_ref2 = order.ascending_over(ordering_rand_part)
         ordering = order.ExpressionOrdering(
-            ordering_value_columns=(ordering_ref1, ordering_ref2),
-            total_ordering_columns=frozenset([ordering_hash_part, ordering_rand_part]),
+            ordering_value_columns=tuple(
+                order.ascending_over(col.get_name()) for col in order_values
+            ),
+            total_ordering_columns=frozenset(col.get_name() for col in order_values),
         )
         columns = [table_with_ordering[col] for col in original_column_ids]
-        hidden_columns = [
-            table_with_ordering[ordering_hash_part],
-            table_with_ordering[ordering_rand_part],
-        ]
+        hidden_columns = [table_with_ordering[col.get_name()] for col in order_values]
         return core.ArrayValue.from_ibis(
             self,
             table_with_ordering,
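
A back-of-the-envelope check on the 100,000-row cutoff (my arithmetic, not code from the PR): by the birthday bound, the probability of any collision among n rows hashed into b bits is at most n(n−1)/2^(b+1).

```python
# Birthday-problem upper bound on hash collisions: P <= n*(n-1) / 2**(b+1).
def collision_bound(rows: int, hash_bits: int) -> float:
    return min(1.0, rows * (rows - 1) / 2 ** (hash_bits + 1))

print(collision_bound(100_000, 64))         # ~2.7e-10: 64 bits is safe here
print(collision_bound(1_000_000_000, 64))   # ~2.7e-02: risky for huge tables
print(collision_bound(1_000_000_000, 128))  # ~1.5e-21: effectively never
```

So below the threshold a single 64-bit hash is effectively collision-free, while tables with large or unknown row counts get the second hash; the random column still breaks ties between genuinely identical rows.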
10 changes: 10 additions & 0 deletions tests/unit/resources.py
@@ -44,6 +44,12 @@ def create_bigquery_session(
             google.auth.credentials.Credentials, instance=True
         )
 
+    if anonymous_dataset is None:
+        anonymous_dataset = google.cloud.bigquery.DatasetReference(
+            "test-project",
+            "test_dataset",
+        )
+
     if bqclient is None:
         bqclient = mock.create_autospec(google.cloud.bigquery.Client, instance=True)
         bqclient.project = "test-project"
@@ -53,6 +59,10 @@ def create_bigquery_session(
         table._properties = {}
         type(table).location = mock.PropertyMock(return_value="test-region")
         type(table).schema = mock.PropertyMock(return_value=table_schema)
+        type(table).reference = mock.PropertyMock(
+            return_value=anonymous_dataset.table("test_table")
+        )
+        type(table).num_rows = mock.PropertyMock(return_value=1000000000)
         bqclient.get_table.return_value = table
 
     if anonymous_dataset is None:
21 changes: 17 additions & 4 deletions tests/unit/session/test_session.py
@@ -49,6 +49,16 @@ def test_read_gbq_cached_table():
         table,
     )
 
+    def get_table_mock(table_ref):
+        table = google.cloud.bigquery.Table(
+            table_ref, (google.cloud.bigquery.SchemaField("col", "INTEGER"),)
+        )
+        table._properties["numRows"] = "1000000000"
+        table._properties["location"] = session._location
+        return table
+
+    session.bqclient.get_table = get_table_mock
+
     with pytest.warns(UserWarning, match=re.escape("use_cache=False")):
         df = session.read_gbq("my-project.my_dataset.my_table")
 
@@ -137,10 +147,13 @@ def query_mock(query, *args, **kwargs):
 
     session.bqclient.query = query_mock
 
-    def get_table_mock(dataset_ref):
-        dataset = google.cloud.bigquery.Dataset(dataset_ref)
-        dataset.location = session._location
-        return dataset
+    def get_table_mock(table_ref):
+        table = google.cloud.bigquery.Table(
+            table_ref, (google.cloud.bigquery.SchemaField("col", "INTEGER"),)
+        )
+        table._properties["numRows"] = 1000000000
+        table._properties["location"] = session._location
+        return table
 
     session.bqclient.get_table = get_table_mock
 
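
One nit: the two mocks spell `numRows` differently — a string in `test_read_gbq_cached_table` (matching what the REST API actually returns) and an int here. Both should behave identically, since `Table.num_rows` coerces the raw property value to `int`; a quick check, assuming current `google-cloud-bigquery` behavior:

```python
# Table.num_rows coerces the stored "numRows" property, so the string and
# int spellings used in the two mocks produce the same value.
import google.cloud.bigquery

ref = google.cloud.bigquery.TableReference.from_string("proj.dset.tbl")
t_str = google.cloud.bigquery.Table(ref)
t_int = google.cloud.bigquery.Table(ref)
t_str._properties["numRows"] = "1000000000"
t_int._properties["numRows"] = 1000000000
assert t_str.num_rows == t_int.num_rows == 1_000_000_000
```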