65 changes: 23 additions & 42 deletions bigframes/core/blocks.py
@@ -818,49 +818,30 @@ def _materialize_local(
total_rows = result_batches.approx_total_rows
# Remove downsampling config from subsequent invocations, as it could otherwise result in
# many iterations if downsampling undershoots
return self._downsample(
total_rows=total_rows,
sampling_method=sample_config.sampling_method,
fraction=fraction,
random_state=sample_config.random_state,
)._materialize_local(
MaterializationOptions(ordered=materialize_options.ordered)
)
else:
df = result_batches.to_pandas()
df = self._copy_index_to_pandas(df)
df.set_axis(self.column_labels, axis=1, copy=False)
return df, execute_result.query_job

def _downsample(
self, total_rows: int, sampling_method: str, fraction: float, random_state
) -> Block:
# either selecting fraction or number of rows
if sampling_method == _HEAD:
filtered_block = self.slice(stop=int(total_rows * fraction))
return filtered_block
elif (sampling_method == _UNIFORM) and (random_state is None):
filtered_expr = self.expr._uniform_sampling(fraction)
block = Block(
filtered_expr,
index_columns=self.index_columns,
column_labels=self.column_labels,
index_labels=self.index.names,
)
return block
elif sampling_method == _UNIFORM:
block = self.split(
fracs=(fraction,),
random_state=random_state,
sort=False,
)[0]
return block
if sample_config.sampling_method == "head":
# Just truncates the result iterator without a follow-up query
raw_df = result_batches.to_pandas(limit=int(total_rows * fraction))
elif (
sample_config.sampling_method == "uniform"
and sample_config.random_state is None
):
# Pushes sample into result without new query
sampled_batches = execute_result.batches(sample_rate=fraction)
raw_df = sampled_batches.to_pandas()
else: # uniform sample with random state requires a full follow-up query
down_sampled_block = self.split(
fracs=(fraction,),
random_state=sample_config.random_state,
sort=False,
)[0]
return down_sampled_block._materialize_local(
MaterializationOptions(ordered=materialize_options.ordered)
)
else:
# This part should never be called, just in case.
raise NotImplementedError(
f"The downsampling method {sampling_method} is not implemented, "
f"please choose from {','.join(_SAMPLING_METHODS)}."
)
raw_df = result_batches.to_pandas()
df = self._copy_index_to_pandas(raw_df)
df.set_axis(self.column_labels, axis=1, copy=False)
return df, execute_result.query_job

def split(
self,
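Reviewer sketch of the new control flow in `_materialize_local`: sampling is now resolved at materialization time, choosing between truncating the result iterator, pushing a `RAND()` predicate into the read, or falling back to a seeded follow-up query. The `SampleConfig` dataclass and the returned labels below are illustrative stand-ins, not bigframes names.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SampleConfig:
    sampling_method: str            # "head" or "uniform", as in the diff
    random_state: Optional[int]     # a seed forces the deterministic path


def choose_sampling_path(config: SampleConfig) -> str:
    """Return which of the three downsampling strategies applies."""
    if config.sampling_method == "head":
        # Truncate the already-fetched result iterator; no extra query.
        return "truncate-iterator"
    if config.sampling_method == "uniform" and config.random_state is None:
        # Push a RAND() < fraction predicate into the batch read; no extra query.
        return "pushed-down-sample"
    # A seeded uniform sample still requires a follow-up query via split().
    return "follow-up-query"


assert choose_sampling_path(SampleConfig("head", None)) == "truncate-iterator"
assert choose_sampling_path(SampleConfig("uniform", None)) == "pushed-down-sample"
assert choose_sampling_path(SampleConfig("uniform", 42)) == "follow-up-query"
```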
13 changes: 12 additions & 1 deletion bigframes/core/bq_data.py
@@ -171,11 +171,22 @@ def get_arrow_batches(
columns: Sequence[str],
storage_read_client: bigquery_storage_v1.BigQueryReadClient,
project_id: str,
sample_rate: Optional[float] = None,
) -> ReadResult:
table_mod_options = {}
read_options_dict: dict[str, Any] = {"selected_fields": list(columns)}

predicates = []
if data.sql_predicate:
read_options_dict["row_restriction"] = data.sql_predicate
predicates.append(data.sql_predicate)
if sample_rate is not None:
assert isinstance(sample_rate, float)
predicates.append(f"RAND() < {sample_rate}")

if predicates:
full_predicates = " AND ".join(f"( {pred} )" for pred in predicates)
read_options_dict["row_restriction"] = full_predicates

read_options = bq_storage_types.ReadSession.TableReadOptions(**read_options_dict)

if data.at_time:
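The `get_arrow_batches` change folds an optional `RAND() < sample_rate` predicate into the Storage Read API `row_restriction`, combined with any existing filter. Below is a standalone sketch of just that predicate-combining step; the helper name `build_row_restriction` is assumed for illustration.

```python
from typing import Optional


def build_row_restriction(
    sql_predicate: Optional[str], sample_rate: Optional[float]
) -> Optional[str]:
    """Combine an existing filter with a Bernoulli sample into a single
    row_restriction string, mirroring the logic added in the hunk above."""
    predicates = []
    if sql_predicate:
        predicates.append(sql_predicate)
    if sample_rate is not None:
        # BigQuery evaluates RAND() per row, so roughly sample_rate of rows pass.
        predicates.append(f"RAND() < {sample_rate}")
    if not predicates:
        return None
    return " AND ".join(f"( {pred} )" for pred in predicates)


print(build_row_restriction("x > 0", 0.1))  # ( x > 0 ) AND ( RAND() < 0.1 )
```

Because the sample is applied server-side inside the read session, the unseeded uniform path needs no follow-up query job.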
11 changes: 10 additions & 1 deletion bigframes/core/local_data.py
@@ -25,6 +25,7 @@
import uuid

import geopandas # type: ignore
import numpy
import numpy as np
import pandas as pd
import pyarrow as pa
@@ -124,13 +124,21 @@ def to_arrow(
geo_format: Literal["wkb", "wkt"] = "wkt",
duration_type: Literal["int", "duration"] = "duration",
json_type: Literal["string"] = "string",
sample_rate: Optional[float] = None,
max_chunksize: Optional[int] = None,
) -> tuple[pa.Schema, Iterable[pa.RecordBatch]]:
if geo_format != "wkt":
raise NotImplementedError(f"geo format {geo_format} not yet implemented")
assert json_type == "string"

batches = self.data.to_batches(max_chunksize=max_chunksize)
data = self.data

# This exists for symmetry with remote sources, but sampling local data like this shouldn't really happen
if sample_rate is not None:
to_take = numpy.random.rand(data.num_rows) < sample_rate
data = data.filter(to_take)

batches = data.to_batches(max_chunksize=max_chunksize)
schema = self.data.schema
if duration_type == "int":
schema = _schema_durations_to_ints(schema)
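On the local-data path, the new `sample_rate` argument is just a Bernoulli mask over the in-memory Arrow table. An isolated sketch of the same technique; the wrapper function is assumed for illustration.

```python
import numpy as np
import pyarrow as pa


def sample_arrow_table(data: pa.Table, sample_rate: float) -> pa.Table:
    """Keep each row independently with probability sample_rate,
    mirroring the local-data fallback added above."""
    keep = np.random.rand(data.num_rows) < sample_rate
    return data.filter(pa.array(keep))


table = pa.table({"a": list(range(1_000))})
print(sample_arrow_table(table, 0.1).num_rows)  # roughly 100 rows
```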
26 changes: 15 additions & 11 deletions bigframes/session/executor.py
@@ -88,7 +88,7 @@ def arrow_batches(self) -> Iterator[pyarrow.RecordBatch]:

yield batch

def to_arrow_table(self) -> pyarrow.Table:
def to_arrow_table(self, limit: Optional[int] = None) -> pyarrow.Table:
# Need to provide schema if no result rows, as arrow can't infer
# If there are rows, it is safest to infer schema from batches.
# Any discrepancies between the predicted schema and the actual schema will produce errors.
@@ -97,18 +97,21 @@ def to_arrow_table(self) -> pyarrow.Table:
peek_value = list(peek_it)
# TODO: Enforce our internal schema on the table for consistency
if len(peek_value) > 0:
return pyarrow.Table.from_batches(
itertools.chain(peek_value, batches), # reconstruct
)
batches = itertools.chain(peek_value, batches) # reconstruct
if limit:
batches = pyarrow_utils.truncate_pyarrow_iterable(
batches, max_results=limit
)
return pyarrow.Table.from_batches(batches)
else:
try:
return self._schema.to_pyarrow().empty_table()
except pa.ArrowNotImplementedError:
# Bug with some pyarrow versions, empty_table only supports base storage types, not extension types.
return self._schema.to_pyarrow(use_storage_types=True).empty_table()

def to_pandas(self) -> pd.DataFrame:
return io_pandas.arrow_to_pandas(self.to_arrow_table(), self._schema)
def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
return io_pandas.arrow_to_pandas(self.to_arrow_table(limit=limit), self._schema)

def to_pandas_batches(
self, page_size: Optional[int] = None, max_results: Optional[int] = None
@@ -158,7 +161,7 @@ def schema(self) -> bigframes.core.schema.ArraySchema:
...

@abc.abstractmethod
def batches(self) -> ResultsIterator:
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
...

@property
@@ -200,9 +203,9 @@ def execution_metadata(self) -> ExecutionMetadata:
def schema(self) -> bigframes.core.schema.ArraySchema:
return self._data.schema

def batches(self) -> ResultsIterator:
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
return ResultsIterator(
iter(self._data.to_arrow()[1]),
iter(self._data.to_arrow(sample_rate=sample_rate)[1]),
self.schema,
self._data.metadata.row_count,
self._data.metadata.total_bytes,
@@ -226,7 +229,7 @@ def execution_metadata(self) -> ExecutionMetadata:
def schema(self) -> bigframes.core.schema.ArraySchema:
return self._schema

def batches(self) -> ResultsIterator:
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
return ResultsIterator(iter([]), self.schema, 0, 0)


@@ -260,12 +263,13 @@ def schema(self) -> bigframes.core.schema.ArraySchema:
source_ids = [selection[0] for selection in self._selected_fields]
return self._data.schema.select(source_ids).rename(dict(self._selected_fields))

def batches(self) -> ResultsIterator:
def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
read_batches = bq_data.get_arrow_batches(
self._data,
[x[0] for x in self._selected_fields],
self._storage_client,
self._project_id,
sample_rate=sample_rate,
)
arrow_batches: Iterator[pa.RecordBatch] = map(
functools.partial(
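`to_arrow_table` now takes a row `limit` and applies it by truncating the batch stream with `pyarrow_utils.truncate_pyarrow_iterable`, which is what lets the "head" path avoid a follow-up query. A rough, self-contained approximation of that truncation step; the real helper's behavior may differ in details.

```python
from typing import Iterable, Iterator

import pyarrow as pa


def truncate_batches(
    batches: Iterable[pa.RecordBatch], max_results: int
) -> Iterator[pa.RecordBatch]:
    """Yield batches until max_results rows have been emitted,
    slicing the final batch if needed."""
    remaining = max_results
    for batch in batches:
        if remaining <= 0:
            return
        if batch.num_rows <= remaining:
            remaining -= batch.num_rows
            yield batch
        else:
            yield batch.slice(0, remaining)
            return


table = pa.table({"a": list(range(10))})
limited = pa.Table.from_batches(truncate_batches(table.to_batches(max_chunksize=3), 7))
print(limited.num_rows)  # 7
```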
2 changes: 1 addition & 1 deletion tests/system/small/test_anywidget.py
@@ -165,7 +165,7 @@ def execution_metadata(self) -> ExecutionMetadata:
def schema(self) -> Any:
return schema

def batches(self) -> ResultsIterator:
def batches(self, sample_rate=None) -> ResultsIterator:
return ResultsIterator(
arrow_batches_val,
self.schema,
6 changes: 3 additions & 3 deletions tests/system/small/test_dataframe.py
@@ -4524,7 +4524,7 @@ def test_df_kurt(scalars_dfs):
"n_default",
],
)
def test_sample(scalars_dfs, frac, n, random_state):
def test_df_to_pandas_sample(scalars_dfs, frac, n, random_state):
scalars_df, _ = scalars_dfs
df = scalars_df.sample(frac=frac, n=n, random_state=random_state)
bf_result = df.to_pandas()
@@ -4535,15 +4535,15 @@ def test_sample(scalars_dfs, frac, n, random_state):
assert bf_result.shape[1] == scalars_df.shape[1]


def test_sample_determinism(penguins_df_default_index):
def test_df_to_pandas_sample_determinism(penguins_df_default_index):
df = penguins_df_default_index.sample(n=100, random_state=12345).head(15)
bf_result = df.to_pandas()
bf_result2 = df.to_pandas()

pandas.testing.assert_frame_equal(bf_result, bf_result2)


def test_sample_raises_value_error(scalars_dfs):
def test_df_to_pandas_sample_raises_value_error(scalars_dfs):
scalars_df, _ = scalars_dfs
with pytest.raises(
ValueError, match="Only one of 'n' or 'frac' parameter can be specified."
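For context, the renamed tests exercise the user-facing path end to end: `DataFrame.sample` followed by `to_pandas`, with a seeded sample expected to be deterministic across materializations. A minimal usage sketch in the spirit of `test_df_to_pandas_sample_determinism`; it assumes the `penguins_df_default_index` fixture and the `pandas` import from the test module.

```python
# A seeded uniform sample goes through the follow-up-query path, so repeated
# materializations of the same sampled frame are expected to be identical.
sampled = penguins_df_default_index.sample(n=100, random_state=12345).head(15)
first = sampled.to_pandas()
second = sampled.to_pandas()
pandas.testing.assert_frame_equal(first, second)
```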