Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 266 additions & 0 deletions bigframes/core/compile/sqlglot/expressions/datetime_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,272 @@
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler

register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op


def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
    """Build the expression for the resampling origin, in microseconds.

    Args:
        y: Typed expression whose ``expr`` is the reference datetime value.
            It is not read when ``origin`` is ``"epoch"``.
        origin: One of ``"epoch"``, ``"start_day"`` or ``"start"``.

    Returns:
        A literal ``0`` for ``"epoch"``; otherwise a ``UNIX_MICROS(...)`` call
        over ``y`` cast to a timestamp. For ``"start_day"`` the value is first
        cast to a date, which drops the time-of-day part.

    Raises:
        ValueError: If ``origin`` is not one of the supported values.
    """
    if origin == "epoch":
        return sge.convert(0)
    if origin not in ("start_day", "start"):
        raise ValueError(f"Origin {origin} not supported")

    anchor = y.expr
    if origin == "start_day":
        # Round-trip through DATE so the anchor lands on the start of its day.
        anchor = sge.Cast(
            this=anchor, to=sge.DataType(this=sge.DataType.Type.DATE)
        )
    anchor_ts = sge.Cast(
        this=anchor, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)
    )
    return sge.func("UNIX_MICROS", anchor_ts)


@register_binary_op(ops.DatetimeToIntegerLabelOp, pass_op=True)
def datetime_to_integer_label_op(
    x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp
) -> sge.Expression:
    """Compile ``DatetimeToIntegerLabelOp`` into an integer bucket label.

    Dispatches to the fixed-frequency or the non-fixed-frequency builder
    depending on whether ``op.freq`` exposes a fixed duration.

    Args:
        x: Typed expression for the datetime values being labelled.
        y: Typed expression for the reference (first) datetime value.
        op: The operation, carrying ``freq`` and ``origin``.

    Returns:
        The expression computing the integer label.

    Raises:
        ValueError: If the origin or the frequency rule code is not supported
            by the selected builder.
    """
    # Determine if the frequency is fixed by checking if 'op.freq.nanos' is
    # defined. Only this probe is guarded: a ValueError raised while building
    # the fixed-frequency expression (e.g. an unsupported origin) must reach
    # the caller instead of being mistaken for a non-fixed frequency.
    try:
        _ = op.freq.nanos
    except ValueError:
        return _datetime_to_integer_label_non_fixed_frequency(x, y, op)
    return _datetime_to_integer_label_fixed_frequency(x, y, op)


def _datetime_to_integer_label_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp
) -> sge.Expression:
    """
    Build the integer label for fixed frequency conversions, where the unit
    can range from microseconds (us) to days.

    The label is ``FLOOR((UNIX_MICROS(x) - first) / width)`` cast to INT64,
    where ``first`` comes from ``_calculate_resample_first`` and ``width`` is
    the frequency expressed in whole microseconds.
    """
    # Reading `nanos` first keeps the failure for non-fixed offsets ahead of
    # any expression building.
    width_us = int(op.freq.nanos / 1000)

    timestamp_type = sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)
    x_micros = sge.func("UNIX_MICROS", sge.Cast(this=x.expr, to=timestamp_type))
    origin_micros = _calculate_resample_first(y, op.origin)  # type: ignore

    elapsed = sge.Sub(this=x_micros, expression=origin_micros)
    bucket = sge.func("IEEE_DIVIDE", elapsed, sge.convert(width_us))
    return sge.Cast(this=sge.Floor(this=bucket), to=sge.DataType.build("INT64"))


def _datetime_to_integer_label_non_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp
) -> sge.Expression:
    """
    This function handles non-fixed frequency conversions for units ranging
    from weeks to years.

    Both ``x`` and ``y`` are mapped onto the same integer scale (microseconds
    for weeks; a running month, quarter or year count otherwise) and the
    difference is turned into a label by ``_integer_label_case``.

    Args:
        x: Typed expression for the datetime values being labelled.
        y: Typed expression for the reference (first) datetime value.
        op: The operation; ``op.freq.rule_code`` selects the unit and
            ``op.freq.n`` is the number of units per bucket.

    Returns:
        A ``CASE`` expression yielding the integer label.

    Raises:
        ValueError: If ``op.freq.rule_code`` is not one of ``"W-SUN"``,
            ``"ME"``, ``"QE-DEC"`` or ``"YE-DEC"``.
    """
    rule_code = op.freq.rule_code
    n = op.freq.n
    if rule_code == "W-SUN":  # Weekly
        us = n * 7 * 24 * 60 * 60 * 1000000
        x_week = _week_end_unix_micros(x.expr)
        first_week = _week_end_unix_micros(y.expr)
        return _integer_label_case(x_week, first_week, us)
    if rule_code == "ME":  # Monthly
        x_month = _sub_year_period_index(x.expr, "MONTH", 12)
        first_month = _sub_year_period_index(y.expr, "MONTH", 12)
        return _integer_label_case(x_month, first_month, n)
    if rule_code == "QE-DEC":  # Quarterly
        x_quarter = _sub_year_period_index(x.expr, "QUARTER", 4)
        first_quarter = _sub_year_period_index(y.expr, "QUARTER", 4)
        return _integer_label_case(x_quarter, first_quarter, n)
    if rule_code == "YE-DEC":  # Yearly
        x_year = sge.Extract(this=sge.Identifier(this="YEAR"), expression=x.expr)
        first_year = sge.Extract(
            this=sge.Identifier(this="YEAR"), expression=y.expr
        )
        return _integer_label_case(x_year, first_year, n)
    raise ValueError(rule_code)


def _week_end_unix_micros(expr: sge.Expression) -> sge.Expression:
    """Return UNIX_MICROS of the Monday-truncated week of ``expr`` plus 6 days."""
    week_start = sge.TimestampTrunc(this=expr, unit=sge.Var(this="WEEK(MONDAY)"))
    week_end = sge.Add(
        this=week_start,
        expression=sge.Interval(
            this=sge.convert(6), unit=sge.Identifier(this="DAY")
        ),
    )
    return sge.func(
        "UNIX_MICROS",
        sge.Cast(this=week_end, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)),
    )


def _sub_year_period_index(
    expr: sge.Expression, part: str, parts_per_year: int
) -> sge.Expression:
    """Return ``(EXTRACT(YEAR) * parts_per_year + (EXTRACT(part) - 1))``."""
    years_as_parts = sge.Mul(
        this=sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr),
        expression=sge.convert(parts_per_year),
    )
    # EXTRACT(part) is shifted by one so the running index is zero-based
    # within the year.
    part_offset = sge.Sub(
        this=sge.Extract(this=sge.Identifier(this=part), expression=expr),
        expression=sge.convert(1),
    )
    return sge.Paren(this=sge.Add(this=years_as_parts, expression=part_offset))


def _integer_label_case(
    x_int: sge.Expression, first: sge.Expression, divisor: int
) -> sge.Expression:
    """Return ``CASE WHEN x_int = first THEN 0 ELSE FLOOR((x_int - first - 1) / divisor) + 1 END``."""
    # The WHEN branch is built before the ELSE branch, matching the order in
    # which the two shared operand nodes were attached originally.
    same_bucket = sge.If(
        this=sge.EQ(this=x_int, expression=first),
        true=sge.convert(0),
    )
    offset = sge.Sub(
        this=sge.Sub(this=x_int, expression=first),
        expression=sge.convert(1),
    )
    buckets_past_first = sge.Cast(
        this=sge.Floor(
            this=sge.func("IEEE_DIVIDE", offset, sge.convert(divisor))
        ),
        to=sge.DataType.build("INT64"),
    )
    return sge.Case(
        ifs=[same_bucket],
        default=sge.Add(this=buckets_past_first, expression=sge.convert(1)),
    )


@register_unary_op(ops.FloorDtOp, pass_op=True)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
WITH `bfcte_0` AS (
SELECT
`datetime_col`,
`timestamp_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CAST(FLOOR(
IEEE_DIVIDE(
UNIX_MICROS(CAST(`datetime_col` AS TIMESTAMP)) - UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)),
86400000000
)
) AS INT64) AS `bfcol_2`,
CASE
WHEN UNIX_MICROS(
CAST(TIMESTAMP_TRUNC(`datetime_col`, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMP)
) = UNIX_MICROS(
CAST(TIMESTAMP_TRUNC(`timestamp_col`, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMP)
)
THEN 0
ELSE CAST(FLOOR(
IEEE_DIVIDE(
UNIX_MICROS(
CAST(TIMESTAMP_TRUNC(`datetime_col`, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMP)
) - UNIX_MICROS(
CAST(TIMESTAMP_TRUNC(`timestamp_col`, WEEK(MONDAY)) + INTERVAL 6 DAY AS TIMESTAMP)
) - 1,
604800000000
)
) AS INT64) + 1
END AS `bfcol_3`
FROM `bfcte_0`
)
SELECT
`bfcol_2` AS `fixed_freq`,
`bfcol_3` AS `non_fixed_freq_weekly`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ def test_dayofyear(scalar_types_df: bpd.DataFrame, snapshot):
snapshot.assert_match(sql, "out.sql")


def test_datetime_to_integer_label(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL for a fixed (daily) and a non-fixed (weekly) label op."""
    input_columns = ["datetime_col", "timestamp_col"]
    bf_df = scalar_types_df[input_columns]

    daily_op = ops.DatetimeToIntegerLabelOp(
        freq=pd.tseries.offsets.Day(), origin="start", closed="left"  # type: ignore
    )
    weekly_op = ops.DatetimeToIntegerLabelOp(
        freq=pd.tseries.offsets.Week(weekday=6), origin="start", closed="left"  # type: ignore
    )
    # Insertion order fixes the output column order in the snapshot.
    labelled_exprs = {
        "fixed_freq": daily_op.as_expr("datetime_col", "timestamp_col"),
        "non_fixed_freq_weekly": weekly_op.as_expr("datetime_col", "timestamp_col"),
    }

    sql = utils._apply_ops_to_sql(
        bf_df, list(labelled_exprs.values()), list(labelled_exprs.keys())
    )
    snapshot.assert_match(sql, "out.sql")


def test_floor_dt(scalar_types_df: bpd.DataFrame, snapshot):
col_names = ["datetime_col", "timestamp_col", "date_col"]
bf_df = scalar_types_df[col_names]
Expand Down