Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 263 additions & 0 deletions bigframes/core/compile/sqlglot/expressions/datetime_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler

register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op


@register_unary_op(ops.FloorDtOp, pass_op=True)
Expand Down Expand Up @@ -51,6 +52,28 @@ def _(expr: TypedExpr, op: ops.FloorDtOp) -> sge.Expression:
return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=bq_freq))


def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
if origin == "epoch":
return sge.convert(0)
elif origin == "start_day":
return sge.func(
"UNIX_MICROS",
sge.Cast(
this=sge.Cast(
this=y.expr, to=sge.DataType(this=sge.DataType.Type.DATE)
),
to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ),
),
)
elif origin == "start":
return sge.func(
"UNIX_MICROS",
sge.Cast(this=y.expr, to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)),
)
else:
raise ValueError(f"Origin {origin} not supported")


@register_unary_op(ops.hour_op)
def _(expr: TypedExpr) -> sge.Expression:
return sge.Extract(this=sge.Identifier(this="HOUR"), expression=expr.expr)
Expand Down Expand Up @@ -170,3 +193,243 @@ def _(expr: TypedExpr, op: ops.UnixSeconds) -> sge.Expression:
@register_unary_op(ops.year_op)
def _(expr: TypedExpr) -> sge.Expression:
return sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr.expr)


def _dtype_to_sql_string(dtype: dtypes.Dtype) -> str:
if dtype == dtypes.TIMESTAMP_DTYPE:
return "TIMESTAMP"
if dtype == dtypes.DATETIME_DTYPE:
return "DATETIME"
if dtype == dtypes.DATE_DTYPE:
return "DATE"
if dtype == dtypes.TIME_DTYPE:
return "TIME"
# Should not be reached in this operator
raise ValueError(f"Unsupported dtype for datetime conversion: {dtype}")


@register_binary_op(ops.IntegerLabelToDatetimeOp, pass_op=True)
def integer_label_to_datetime_op(
x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
# Determine if the frequency is fixed by checking if 'op.freq.nanos' is defined.
try:
return _integer_label_to_datetime_op_fixed_frequency(x, y, op)
except ValueError:
return _integer_label_to_datetime_op_non_fixed_frequency(x, y, op)


def _integer_label_to_datetime_op_fixed_frequency(
x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
"""
This function handles fixed frequency conversions where the unit can range
from microseconds (us) to days.
"""
us = op.freq.nanos / 1000
first = _calculate_resample_first(y, op.origin) # type: ignore
x_label = sge.Cast(
this=sge.func(
"TIMESTAMP_MICROS",
sge.Cast(
this=sge.Add(
this=sge.Mul(
this=sge.Cast(this=x.expr, to=sge.DataType.build("BIGNUMERIC")),
expression=sge.convert(int(us)),
),
expression=sge.Cast(
this=first, to=sge.DataType.build("BIGNUMERIC")
),
),
to=sge.DataType.build("INT64"),
),
),
to=_dtype_to_sql_string(y.dtype), # type: ignore
)
return x_label


def _integer_label_to_datetime_op_non_fixed_frequency(
x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
"""
This function handles non-fixed frequency conversions for units ranging
from weeks to years.
"""
rule_code = op.freq.rule_code
n = op.freq.n
if rule_code == "W-SUN": # Weekly
us = n * 7 * 24 * 60 * 60 * 1000000
first = sge.func(
"UNIX_MICROS",
sge.Add(
this=sge.TimestampTrunc(
this=sge.Cast(
this=y.expr,
to=sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ),
),
unit=sge.Var(this="WEEK(MONDAY)"),
),
expression=sge.Interval(
this=sge.convert(6), unit=sge.Identifier(this="DAY")
),
),
)
x_label = sge.Cast(
this=sge.func(
"TIMESTAMP_MICROS",
sge.Cast(
this=sge.Add(
this=sge.Mul(
this=sge.Cast(
this=x.expr, to=sge.DataType.build("BIGNUMERIC")
),
expression=sge.convert(us),
),
expression=sge.Cast(
this=first, to=sge.DataType.build("BIGNUMERIC")
),
),
to=sge.DataType.build("INT64"),
),
),
to=_dtype_to_sql_string(y.dtype), # type: ignore
)
elif rule_code == "ME": # Monthly
one = sge.convert(1)
twelve = sge.convert(12)
first = sge.Sub( # type: ignore
this=sge.Add(
this=sge.Mul(
this=sge.Extract(this="YEAR", expression=y.expr),
expression=twelve,
),
expression=sge.Extract(this="MONTH", expression=y.expr),
),
expression=one,
)
x_val = sge.Add(
this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first
)
year = sge.Cast(
this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, twelve)),
to=sge.DataType.build("INT64"),
)
month = sge.Add(this=sge.Mod(this=x_val, expression=twelve), expression=one)
next_year = sge.Case(
ifs=[
sge.If(
this=sge.EQ(this=month, expression=twelve),
true=sge.Add(this=year, expression=one),
)
],
default=year,
)
next_month = sge.Case(
ifs=[
sge.If(
this=sge.EQ(this=month, expression=twelve),
true=one,
)
],
default=sge.Add(this=month, expression=one),
)
next_month_date = sge.func(
"TIMESTAMP",
sge.Anonymous(
this="DATETIME",
expressions=[
next_year,
next_month,
one,
sge.convert(0),
sge.convert(0),
sge.convert(0),
],
),
)
x_label = sge.Sub( # type: ignore
this=next_month_date, expression=sge.Interval(this=one, unit="DAY")
)
elif rule_code == "QE-DEC": # Quarterly
one = sge.convert(1)
three = sge.convert(3)
four = sge.convert(4)
twelve = sge.convert(12)
first = sge.Sub( # type: ignore
this=sge.Add(
this=sge.Mul(
this=sge.Extract(this="YEAR", expression=y.expr),
expression=four,
),
expression=sge.Extract(this="QUARTER", expression=y.expr),
),
expression=one,
)
x_val = sge.Add(
this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first
)
year = sge.Cast(
this=sge.Floor(this=sge.func("IEEE_DIVIDE", x_val, four)),
to=sge.DataType.build("INT64"),
)
month = sge.Mul( # type: ignore
this=sge.Paren(
this=sge.Add(this=sge.Mod(this=x_val, expression=four), expression=one)
),
expression=three,
)
next_year = sge.Case(
ifs=[
sge.If(
this=sge.EQ(this=month, expression=twelve),
true=sge.Add(this=year, expression=one),
)
],
default=year,
)
next_month = sge.Case(
ifs=[sge.If(this=sge.EQ(this=month, expression=twelve), true=one)],
default=sge.Add(this=month, expression=one),
)
next_month_date = sge.Anonymous(
this="DATETIME",
expressions=[
next_year,
next_month,
one,
sge.convert(0),
sge.convert(0),
sge.convert(0),
],
)
x_label = sge.Sub( # type: ignore
this=next_month_date, expression=sge.Interval(this=one, unit="DAY")
)
elif rule_code == "YE-DEC": # Yearly
one = sge.convert(1)
first = sge.Extract(this="YEAR", expression=y.expr)
x_val = sge.Add(
this=sge.Mul(this=x.expr, expression=sge.convert(n)), expression=first
)
next_year = sge.Add(this=x_val, expression=one) # type: ignore
next_month_date = sge.func(
"TIMESTAMP",
sge.Anonymous(
this="DATETIME",
expressions=[
next_year,
one,
one,
sge.convert(0),
sge.convert(0),
sge.convert(0),
],
),
)
x_label = sge.Sub( # type: ignore
this=next_month_date, expression=sge.Interval(this=one, unit="DAY")
)
else:
raise ValueError(rule_code)
return sge.Cast(this=x_label, to=_dtype_to_sql_string(y.dtype)) # type: ignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
WITH `bfcte_0` AS (
SELECT
`rowindex`,
`timestamp_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CAST(TIMESTAMP_MICROS(
CAST(CAST(`rowindex` AS BIGNUMERIC) * 86400000000 + CAST(UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)) AS BIGNUMERIC) AS INT64)
) AS TIMESTAMP) AS `bfcol_2`,
CAST(DATETIME(
CASE
WHEN (
MOD(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
) + 1
) * 3 = 12
THEN CAST(FLOOR(
IEEE_DIVIDE(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
)
) AS INT64) + 1
ELSE CAST(FLOOR(
IEEE_DIVIDE(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
)
) AS INT64)
END,
CASE
WHEN (
MOD(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
) + 1
) * 3 = 12
THEN 1
ELSE (
MOD(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
) + 1
) * 3 + 1
END,
1,
0,
0,
0
) - INTERVAL 1 DAY AS TIMESTAMP) AS `bfcol_3`
FROM `bfcte_0`
)
SELECT
`bfcol_2` AS `fixed_freq`,
`bfcol_3` AS `non_fixed_freq`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
WITH `bfcte_0` AS (
SELECT
`rowindex`,
`timestamp_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CAST(TIMESTAMP_MICROS(
CAST(CAST(`rowindex` AS BIGNUMERIC) * 86400000000 + CAST(UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)) AS BIGNUMERIC) AS INT64)
) AS TIMESTAMP) AS `bfcol_2`
FROM `bfcte_0`
)
SELECT
`bfcol_2` AS `fixed_freq`
FROM `bfcte_1`
Loading
Loading