Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions bigframes/core/compile/sqlglot/aggregations/windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

from bigframes.core import utils, window_spec
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
import bigframes.core.expression as ex
import bigframes.core.ordering as ordering_spec
import bigframes.dtypes as dtypes


def apply_window_if_present(
Expand Down Expand Up @@ -52,10 +54,7 @@ def apply_window_if_present(
order = sge.Order(expressions=order_by)

group_by = (
[
scalar_compiler.scalar_op_compiler.compile_expression(key)
for key in window.grouping_keys
]
[_compile_group_by_key(key) for key in window.grouping_keys]
if window.grouping_keys
else None
)
Expand Down Expand Up @@ -164,3 +163,18 @@ def _get_window_bounds(

side = "PRECEDING" if value < 0 else "FOLLOWING"
return sge.convert(abs(value)), side


def _compile_group_by_key(key: ex.Expression) -> sge.Expression:
expr = scalar_compiler.scalar_op_compiler.compile_expression(key)
# The group_by keys has been rewritten by bind_schema_to_node
assert isinstance(key, ex.ResolvedDerefOp)

# Some types need to be converted to another type to enable groupby
if key.dtype == dtypes.FLOAT_DTYPE:
expr = sge.Cast(this=expr, to="STRING")
elif key.dtype == dtypes.GEO_DTYPE:
expr = sge.Cast(this=expr, to="BYTES")
elif key.dtype == dtypes.JSON_DTYPE:
expr = sge.func("TO_JSON_STRING", expr)
return expr
47 changes: 42 additions & 5 deletions tests/unit/core/compile/sqlglot/aggregations/test_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
import pytest
import sqlglot.expressions as sge

from bigframes import dtypes
from bigframes.core import window_spec
from bigframes.core.compile.sqlglot.aggregations.windows import (
apply_window_if_present,
get_window_order_by,
)
import bigframes.core.expression as ex
import bigframes.core.identifiers as ids
import bigframes.core.ordering as ordering


Expand Down Expand Up @@ -82,16 +84,37 @@ def test_apply_window_if_present_row_bounded_no_ordering_raises(self):
),
)

def test_apply_window_if_present_unbounded_grouping_no_ordering(self):
def test_apply_window_if_present_grouping_no_ordering(self):
result = apply_window_if_present(
sge.Var(this="value"),
window_spec.WindowSpec(
grouping_keys=(ex.deref("col1"),),
grouping_keys=(
ex.ResolvedDerefOp(
ids.ColumnId("col1"),
dtype=dtypes.STRING_DTYPE,
is_nullable=True,
),
ex.ResolvedDerefOp(
ids.ColumnId("col2"),
dtype=dtypes.FLOAT_DTYPE,
is_nullable=True,
),
ex.ResolvedDerefOp(
ids.ColumnId("col3"),
dtype=dtypes.JSON_DTYPE,
is_nullable=True,
),
ex.ResolvedDerefOp(
ids.ColumnId("col4"),
dtype=dtypes.GEO_DTYPE,
is_nullable=True,
),
),
),
)
self.assertEqual(
result.sql(dialect="bigquery"),
"value OVER (PARTITION BY `col1`)",
"value OVER (PARTITION BY `col1`, CAST(`col2` AS STRING), TO_JSON_STRING(`col3`), CAST(`col4` AS BYTES))",
)

def test_apply_window_if_present_range_bounded(self):
Expand Down Expand Up @@ -126,8 +149,22 @@ def test_apply_window_if_present_all_params(self):
result = apply_window_if_present(
sge.Var(this="value"),
window_spec.WindowSpec(
grouping_keys=(ex.deref("col1"),),
ordering=(ordering.OrderingExpression(ex.deref("col2")),),
grouping_keys=(
ex.ResolvedDerefOp(
ids.ColumnId("col1"),
dtype=dtypes.STRING_DTYPE,
is_nullable=True,
),
),
ordering=(
ordering.OrderingExpression(
ex.ResolvedDerefOp(
ids.ColumnId("col2"),
dtype=dtypes.STRING_DTYPE,
is_nullable=True,
)
),
),
bounds=window_spec.RowsWindowBounds(start=-1, end=0),
),
)
Expand Down