Skip to content

Commit 145034a

Browse files
authored
fix: preserve aliases on cast columns and fix star selection in sqlglot (#17394) (#17455)
** This branch is under testing, not ready for review ** This PR resolves a regression introduced when switching to the default `sqlglot` compiler, where cast columns lost their aliases during type-coercion and were auto-named by BigQuery as `f0_`, `f1_`, etc. (fixes #17394). Before: screen/7FibgBYoY6EN8hR After: screen/AWsDt8aocqyzjup Fixes #<521420846> 🦕
1 parent dd59d36 commit 145034a

9 files changed

Lines changed: 40 additions & 15 deletions

File tree

packages/bigframes/bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -249,12 +249,13 @@ def select(
249249
# TODO: Explicitly insert CTEs into plan
250250
if len(selections) > 0:
251251
to_select = [
252-
sge.Alias(
253-
this=expr,
252+
expr
253+
if (isinstance(expr, sge.Alias) and expr.alias == id)
254+
or (isinstance(expr, sge.Column) and expr.name == id)
255+
else sge.Alias(
256+
this=expr.this if isinstance(expr, sge.Alias) else expr,
254257
alias=sql.identifier(id),
255258
)
256-
if expr.alias_or_name != id
257-
else expr
258259
for id, expr in selections
259260
]
260261
new_expr = self.expr.select(*to_select)

packages/bigframes/bigframes/core/sql_nodes.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,14 @@ def _node_expressions(self):
276276

277277
@property
278278
def is_star_selection(self) -> bool:
279-
return tuple(self.ids) == tuple(self.child.ids)
279+
if tuple(self.ids) != tuple(self.child.ids):
280+
return False
281+
for cdef in self.selections:
282+
if not isinstance(cdef.expression, ex.DerefOp):
283+
return False
284+
if cdef.expression.id != cdef.id:
285+
return False
286+
return True
280287

281288
@functools.cache
282289
def get_id_mapping(self) -> dict[identifiers.ColumnId, ex.Expression]:
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
SELECT
22
CAST(TIMESTAMP_MICROS(CAST(TRUNC(`int64_col` * 0.001) AS INT64)) AS DATETIME) AS `int64_col`,
3-
SAFE_CAST(`string_col` AS DATETIME),
3+
SAFE_CAST(`string_col` AS DATETIME) AS `string_col`,
44
CAST(TIMESTAMP_MICROS(CAST(TRUNC(`float64_col` * 0.001) AS INT64)) AS DATETIME) AS `float64_col`,
5-
SAFE_CAST(`timestamp_col` AS DATETIME),
5+
SAFE_CAST(`timestamp_col` AS DATETIME) AS `timestamp_col`,
66
SAFE_CAST(`string_col` AS DATETIME) AS `string_col_fmt`
7-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
7+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
SELECT
2-
CAST(CAST(`bool_col` AS INT64) AS FLOAT64),
2+
CAST(CAST(`bool_col` AS INT64) AS FLOAT64) AS `bool_col`,
33
CAST('1.34235e4' AS FLOAT64) AS `str_const`,
44
SAFE_CAST(SAFE_CAST(`bool_col` AS INT64) AS FLOAT64) AS `bool_w_safe`
5-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
SELECT
2-
CAST(`int64_col` AS STRING),
2+
CAST(`int64_col` AS STRING) AS `int64_col`,
33
INITCAP(CAST(`bool_col` AS STRING)) AS `bool_col`,
44
INITCAP(SAFE_CAST(`bool_col` AS STRING)) AS `bool_w_safe`
5-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def test_to_datetime(scalar_types_df: bpd.DataFrame, snapshot):
217217
)
218218

219219
sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys()))
220-
snapshot.assert_match(sql, "out.sql")
220+
snapshot.assert_match(sql + "\n", "out.sql")
221221

222222

223223
def test_to_timestamp(scalar_types_df: bpd.DataFrame, snapshot):

packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_astype_float(scalar_types_df: bpd.DataFrame, snapshot):
6060
"bool_w_safe": ops.AsTypeOp(to_type=to_type, safe=True).as_expr("bool_col"),
6161
}
6262
sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys()))
63-
snapshot.assert_match(sql, "out.sql")
63+
snapshot.assert_match(sql + "\n", "out.sql")
6464

6565

6666
def test_astype_bool(scalar_types_df: bpd.DataFrame, snapshot):
@@ -107,7 +107,7 @@ def test_astype_string(scalar_types_df: bpd.DataFrame, snapshot):
107107
"bool_w_safe": ops.AsTypeOp(to_type=to_type, safe=True).as_expr("bool_col"),
108108
}
109109
sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys()))
110-
snapshot.assert_match(sql, "out.sql")
110+
snapshot.assert_match(sql + "\n", "out.sql")
111111

112112

113113
def test_astype_json(scalar_types_df: bpd.DataFrame, snapshot):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
SELECT
2+
`rowindex`,
3+
CAST(`timestamp_col` AS STRING) AS `timestamp_col`,
4+
CAST(`int64_col` AS FLOAT64) AS `int64_col`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

packages/bigframes/tests/unit/core/compile/sqlglot/test_compile_readtable.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,15 @@ def test_compile_readtable_w_columns_filters(compiler_session, snapshot):
8080
filters=filters,
8181
)
8282
snapshot.assert_match(bf_df.sql, "out.sql")
83+
84+
85+
def test_compile_astype_aliases(scalar_types_df: bpd.DataFrame, snapshot):
86+
# Test case for issue #17394 (CAST columns lose their aliases)
87+
bf_df = scalar_types_df[["timestamp_col", "int64_col"]]
88+
result = bf_df.astype(
89+
{
90+
"timestamp_col": "string[pyarrow]",
91+
"int64_col": "Float64",
92+
}
93+
)
94+
snapshot.assert_match(result.sql + "\n", "out.sql")

0 commit comments

Comments
 (0)