Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions bigframes/core/compile/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,18 +1062,24 @@ def _bake_ordering(self) -> OrderedIR:
)
new_baked_cols.append(baked_column)
new_expr = OrderingExpression(
ex.free_var(baked_column.name), expr.direction, expr.na_last
ex.free_var(baked_column.get_name()), expr.direction, expr.na_last
)
new_exprs.append(new_expr)
else:
elif isinstance(expr.scalar_expression, ex.UnboundVariableExpression):
new_exprs.append(expr)
new_baked_cols.append(self._ibis_bindings[expr.scalar_expression.id])

ordering = self._ordering.with_ordering_columns(new_exprs)
new_ordering = ExpressionOrdering(
tuple(new_exprs),
self._ordering.integer_encoding,
self._ordering.string_encoding,
self._ordering.total_ordering_columns,
)
return OrderedIR(
self._table,
columns=self.columns,
hidden_ordering_columns=[*self._hidden_ordering_columns, *new_baked_cols],
ordering=ordering,
hidden_ordering_columns=tuple(new_baked_cols),
ordering=new_ordering,
predicates=self._predicates,
)

Expand Down
30 changes: 20 additions & 10 deletions bigframes/core/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class SquashedSelect:
columns: Tuple[Tuple[scalar_exprs.Expression, str], ...]
predicate: Optional[scalar_exprs.Expression]
ordering: Tuple[order.OrderingExpression, ...]
reverse_root: bool = False

@classmethod
def from_node(cls, node: nodes.BigFrameNode) -> SquashedSelect:
Expand Down Expand Up @@ -63,7 +64,9 @@ def project(
new_columns = tuple(
(expr.bind_all_variables(self.column_lookup), id) for expr, id in projection
)
return SquashedSelect(self.root, new_columns, self.predicate, self.ordering)
return SquashedSelect(
self.root, new_columns, self.predicate, self.ordering, self.reverse_root
)

def filter(self, predicate: scalar_exprs.Expression) -> SquashedSelect:
if self.predicate is None:
Expand All @@ -72,38 +75,40 @@ def filter(self, predicate: scalar_exprs.Expression) -> SquashedSelect:
new_predicate = ops.and_op.as_expr(
self.predicate, predicate.bind_all_variables(self.column_lookup)
)
return SquashedSelect(self.root, self.columns, new_predicate, self.ordering)
return SquashedSelect(
self.root, self.columns, new_predicate, self.ordering, self.reverse_root
)

def reverse(self) -> SquashedSelect:
new_ordering = tuple(expr.with_reverse() for expr in self.ordering)
return SquashedSelect(self.root, self.columns, self.predicate, new_ordering)
return SquashedSelect(
self.root, self.columns, self.predicate, new_ordering, not self.reverse_root
)

def order_with(self, by: Tuple[order.OrderingExpression, ...]):
adjusted_orderings = [
order_part.bind_variables(self.column_lookup) for order_part in by
]
new_ordering = (*adjusted_orderings, *self.ordering)
return SquashedSelect(self.root, self.columns, self.predicate, new_ordering)
return SquashedSelect(
self.root, self.columns, self.predicate, new_ordering, self.reverse_root
)

def maybe_join(
self, right: SquashedSelect, join_def: join_defs.JoinDefinition
) -> Optional[SquashedSelect]:
if join_def.type == "cross":
# Cannot convert cross join to projection
return None

r_exprs_by_id = {id: expr for expr, id in right.columns}
l_exprs_by_id = {id: expr for expr, id in self.columns}
l_join_exprs = [l_exprs_by_id[cond.left_id] for cond in join_def.conditions]
r_join_exprs = [r_exprs_by_id[cond.right_id] for cond in join_def.conditions]

if (self.root != right.root) or any(
l_expr != r_expr for l_expr, r_expr in zip(l_join_exprs, r_join_exprs)
):
return None

join_type = join_def.type

# Mask columns and remap names to expected schema
lselection = self.columns
rselection = right.columns
Expand All @@ -115,7 +120,6 @@ def maybe_join(
new_predicate = self.predicate
elif join_type == "right":
new_predicate = right.predicate

l_relative, r_relative = relative_predicates(self.predicate, right.predicate)
lmask = l_relative if join_type in {"right", "outer"} else None
rmask = r_relative if join_type in {"left", "outer"} else None
Expand All @@ -126,8 +130,10 @@ def maybe_join(
new_columns = remap_names(join_def, lselection, rselection)

# Reconstruct ordering
reverse_root = self.reverse_root
if join_type == "right":
new_ordering = right.ordering
reverse_root = right.reverse_root
elif join_type == "outer":
if lmask is not None:
prefix = order.OrderingExpression(lmask, order.OrderingDirection.DESC)
Expand Down Expand Up @@ -158,11 +164,15 @@ def maybe_join(
new_ordering = self.ordering
else:
raise ValueError(f"Unexpected join type {join_type}")
return SquashedSelect(self.root, new_columns, new_predicate, new_ordering)
return SquashedSelect(
self.root, new_columns, new_predicate, new_ordering, reverse_root
)

def expand(self) -> nodes.BigFrameNode:
# Safest to apply predicates first, as it may filter out inputs that cannot be handled by other expressions
root = self.root
if self.reverse_root:
root = nodes.ReversedNode(child=root)
if self.predicate:
root = nodes.FilterNode(child=root, predicate=self.predicate)
if self.ordering:
Expand Down