Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions bigframes/core/compile/sqlglot/aggregations/op_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,16 @@ def arg_checker(*args, **kwargs):
)
return item(*args, **kwargs)

if hasattr(op, "name"):
key = typing.cast(str, op.name)
if key in self._registered_ops:
raise ValueError(f"{key} is already registered")
else:
raise ValueError(f"The operator must have a 'name' attribute. Got {op}")
key = str(op)
if key in self._registered_ops:
raise ValueError(f"{key} is already registered")
self._registered_ops[key] = item
return arg_checker

return decorator

def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc:
    """Return the compilation function registered for ``op``.

    Registration stores each function under ``str(<object passed to
    register>)``, so the lookup builds its key the same way: a type is
    used as-is, and any other object is resolved through ``type(op)``.

    NOTE(review): a plain string argument is therefore resolved to the
    ``str`` class rather than being used as a key itself. Confirm whether
    ``str`` should remain in the annotation.

    Raises:
        ValueError: if nothing is registered under the resolved key.
    """
    op_type = op if isinstance(op, type) else type(op)
    key = str(op_type)
    if key not in self._registered_ops:
        # Raise an explicit ValueError instead of leaking a bare KeyError.
        raise ValueError(f"{op_type} is not registered")
    return self._registered_ops[key]
58 changes: 54 additions & 4 deletions bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,37 @@ def compile(
return UNARY_OP_REGISTRATION[op](op, column, window=window)


@UNARY_OP_REGISTRATION.register(agg_ops.ApproxQuartilesOp)
def _(
    op: agg_ops.ApproxQuartilesOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Build ``APPROX_QUANTILES(<column>, 4)[OFFSET(<op.quartile>)]``.

    Raises:
        NotImplementedError: if a window specification is supplied.
    """
    if window is not None:
        raise NotImplementedError("Approx Quartiles with windowing is not supported.")

    # Asking for 4 intervals yields a 5-element array of boundaries
    # Q0..Q4. The array is 0-indexed, so Q<k> sits at offset k and
    # op.quartile (1, 2 or 3) can be used as the offset unchanged.
    interval_count = sge.convert(4)
    quartile_offset = sge.func("OFFSET", sge.convert(op.quartile))
    return sge.Bracket(
        this=sge.func("APPROX_QUANTILES", column.expr, interval_count),
        expressions=[quartile_offset],
    )


@UNARY_OP_REGISTRATION.register(agg_ops.ApproxTopCountOp)
def _(
    op: agg_ops.ApproxTopCountOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Build ``APPROX_TOP_COUNT(<column>, <op.number>)``.

    Raises:
        NotImplementedError: if a window specification is supplied.
    """
    windowed = window is not None
    if windowed:
        raise NotImplementedError("Approx top count with windowing is not supported.")

    top_n = sge.convert(op.number)
    return sge.func("APPROX_TOP_COUNT", column.expr, top_n)


@UNARY_OP_REGISTRATION.register(agg_ops.CountOp)
def _(
op: agg_ops.CountOp,
Expand Down Expand Up @@ -109,13 +140,23 @@ def _(
return apply_window_if_present(sge.func("MIN", column.expr), window)


@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
def _(
    op: agg_ops.QuantileOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Build a ``PERCENTILE_CONT(<column>, <op.q>) OVER (...)`` expression.

    When ``op.should_floor_result`` is set, the windowed expression is
    wrapped as ``CAST(FLOOR(...) AS INT64)``.
    """
    # TODO: Support interpolation argument
    # TODO: Support percentile_disc
    result: sge.Expression = sge.func("PERCENTILE_CONT", column.expr, sge.convert(op.q))
    if window is None:
        # PERCENTILE_CONT is a navigation function, not an aggregate
        # function, so it always needs an OVER clause; use an empty one
        # when the caller did not supply a window.
        result = sge.Window(this=result)
    else:
        result = apply_window_if_present(result, window)
    if op.should_floor_result:
        result = sge.Cast(this=sge.func("FLOOR", result), to="INT64")
    return result


@UNARY_OP_REGISTRATION.register(agg_ops.RankOp)
Expand All @@ -130,6 +171,15 @@ def _(
)


@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)
def _(
    op: agg_ops.SizeUnaryOp,
    _,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Build ``COUNT(1)`` and pass it through ``apply_window_if_present``.

    The second positional argument is accepted but never read.
    """
    row_count = sge.func("COUNT", sge.convert(1))
    return apply_window_if_present(row_count, window)


@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
def _(
op: agg_ops.SumOp,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
WITH `bfcte_0` AS (
SELECT
`int64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
APPROX_QUANTILES(`bfcol_0`, 4)[OFFSET(1)] AS `bfcol_1`,
APPROX_QUANTILES(`bfcol_0`, 4)[OFFSET(2)] AS `bfcol_2`,
APPROX_QUANTILES(`bfcol_0`, 4)[OFFSET(3)] AS `bfcol_3`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `q1`,
`bfcol_2` AS `q2`,
`bfcol_3` AS `q3`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
WITH `bfcte_0` AS (
SELECT
`int64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
APPROX_TOP_COUNT(`bfcol_0`, 10) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `int64_col`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
WITH `bfcte_0` AS (
SELECT
`int64_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
PERCENTILE_CONT(`bfcol_0`, 0.5) OVER () AS `bfcol_1`,
CAST(FLOOR(PERCENTILE_CONT(`bfcol_0`, 0.5) OVER ()) AS INT64) AS `bfcol_2`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `quantile`,
`bfcol_2` AS `quantile_floor`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def test_func(op: agg_ops.SizeOp, input: sge.Expression) -> sge.Expression:
return input

assert reg[agg_ops.SizeOp()](op, input) == test_func(op, input)
assert reg[agg_ops.SizeOp.name](op, input) == test_func(op, input)


def test_register_function_first_argument_is_not_agg_op_raise_error():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,30 @@ def _apply_unary_window_op(
return sql


def test_approx_quartiles(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the output of ApproxQuartilesOp for quartiles 1, 2 and 3."""
    source_column = "int64_col"
    frame = scalar_types_df[[source_column]]

    # One output column per quartile, named q1/q2/q3 in that order.
    quartiles = (1, 2, 3)
    output_names = [f"q{k}" for k in quartiles]
    aggregations = [
        agg_ops.ApproxQuartilesOp(quartile=k).as_expr(source_column)
        for k in quartiles
    ]

    generated_sql = _apply_unary_agg_ops(frame, aggregations, output_names)

    snapshot.assert_match(generated_sql, "out.sql")


def test_approx_top_count(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the output of ApproxTopCountOp(number=10) on ``int64_col``."""
    source_column = "int64_col"
    frame = scalar_types_df[[source_column]]

    top_count = agg_ops.ApproxTopCountOp(number=10)
    # The result column reuses the source column's name.
    generated_sql = _apply_unary_agg_ops(
        frame,
        [top_count.as_expr(source_column)],
        [source_column],
    )

    snapshot.assert_match(generated_sql, "out.sql")


def test_count(scalar_types_df: bpd.DataFrame, snapshot):
col_name = "int64_col"
bf_df = scalar_types_df[[col_name]]
Expand Down Expand Up @@ -141,6 +165,22 @@ def test_min(scalar_types_df: bpd.DataFrame, snapshot):
snapshot.assert_match(sql, "out.sql")


def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the output of QuantileOp(q=0.5), with and without flooring."""
    source_column = "int64_col"
    frame = scalar_types_df[[source_column]]

    plain_median = agg_ops.QuantileOp(q=0.5)
    floored_median = agg_ops.QuantileOp(q=0.5, should_floor_result=True)

    # Output names are listed in the same order as the expressions.
    output_names = ["quantile", "quantile_floor"]
    aggregations = [
        plain_median.as_expr(source_column),
        floored_median.as_expr(source_column),
    ]

    generated_sql = _apply_unary_agg_ops(frame, aggregations, output_names)

    snapshot.assert_match(generated_sql, "out.sql")


def test_rank(scalar_types_df: bpd.DataFrame, snapshot):
col_name = "int64_col"
bf_df = scalar_types_df[[col_name]]
Expand Down