Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Define RowJoinNode and defer rewrite #1183

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,9 @@ def try_row_join(
lcol.name: lcol.name for lcol in self.node.ids
}
other_node, r_mapping = self.prepare_join_names(other)
import bigframes.core.rewrite
import bigframes.core.row_join

result_node = bigframes.core.rewrite.try_row_join(
result_node = bigframes.core.row_join.try_row_join(
self.node, other_node, conditions
)
if result_node is None:
Expand Down
8 changes: 4 additions & 4 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2335,8 +2335,8 @@ def join(
# Handle null index, which only supports row join
# This is the canonical way of aligning on null index, so always allow (ignore block_identity_join)
if self.index.nlevels == other.index.nlevels == 0:
result = try_legacy_row_join(self, other, how=how) or try_new_row_join(
self, other
result = try_new_row_join(self, other) or try_legacy_row_join(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably rename "try_new_row_join" now or in the future. The reason is that I guess "try_legacy_row_join" will be eventually removed, and it would be very confusing if we have only a "new" version.

We can just call it "try_row_join".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed to try_row_join

self, other, how=how
)
if result is not None:
return result
Expand All @@ -2350,8 +2350,8 @@ def join(
and (self.index.nlevels == other.index.nlevels)
and (self.index.dtypes == other.index.dtypes)
):
result = try_legacy_row_join(self, other, how=how) or try_new_row_join(
self, other
result = try_new_row_join(self, other) or try_legacy_row_join(
self, other, how=how
)
if result is not None:
return result
Expand Down
6 changes: 6 additions & 0 deletions bigframes/core/compile/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,12 @@ def compile_join(self, node: nodes.JoinNode):
def compile_concat(self, node: nodes.ConcatNode):
return pl.concat(self.compile_node(child) for child in node.child_nodes)

@compile_node.register
def compile_row_join(self, node: nodes.RowJoinNode):
    """Compile a RowJoinNode by concatenating its compiled children horizontally.

    Each child of ``node`` is compiled with ``self.compile_node`` and the
    results are handed to ``pl.concat`` with ``how="horizontal"``.
    """
    # Lazily compile children in child order; pl.concat consumes the generator.
    compiled_children = (self.compile_node(child) for child in node.child_nodes)
    return pl.concat(compiled_children, how="horizontal")

@compile_node.register
def compile_agg(self, node: nodes.AggregateNode):
df = self.compile_node(node.child)
Expand Down
83 changes: 81 additions & 2 deletions bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ def explicitly_ordered(self) -> bool:
"""
...

@functools.cached_property
def height(self) -> int:
    """Return the height of the tree rooted at this node.

    A node without children has height 0; otherwise the height is one more
    than the tallest child's height.
    """
    child_heights = [child.height for child in self.child_nodes]
    return max(child_heights) + 1 if child_heights else 0

@functools.cached_property
def total_variables(self) -> int:
return self.variables_introduced + sum(
Expand Down Expand Up @@ -477,6 +483,76 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
return dataclasses.replace(self, conditions=new_conds) # type: ignore


@dataclasses.dataclass(frozen=True, eq=False)
class RowJoinNode(BigFrameNode):
    """Node that exposes the fields of two child nodes side by side.

    The resulting fields are the left child's fields followed by the right
    child's fields, and the reported row count is taken from the left child.
    The column ids of the two children must not overlap.
    NOTE(review): presumably both children are expected to be row-aligned;
    nothing in this class checks that -- confirm against callers.
    """

    left_child: BigFrameNode
    right_child: BigFrameNode

    def _validate(self):
        """Assert that the children do not share any column id."""
        assert set(self.left_child.ids).isdisjoint(
            self.right_child.ids
        ), "Join ids collide"

    @property
    def child_nodes(self) -> typing.Sequence[BigFrameNode]:
        """Both children, left first."""
        return (self.left_child, self.right_child)

    @property
    def fields(self) -> Iterable[Field]:
        """Left child's fields followed by the right child's fields."""
        return itertools.chain(self.left_child.fields, self.right_child.fields)

    @property
    def added_fields(self) -> Tuple[Field, ...]:
        """The right child's fields, materialized as a tuple."""
        return tuple(self.right_child.fields)

    @property
    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
        """This node defines no column ids of its own."""
        return ()

    @functools.cached_property
    def variables_introduced(self) -> int:
        """This node introduces no variables."""
        return 0

    @property
    def row_count(self) -> Optional[int]:
        """Row count as reported by the left child."""
        return self.left_child.row_count

    @property
    def row_preserving(self) -> bool:
        return True

    @property
    def order_ambiguous(self) -> bool:
        return False

    @property
    def explicitly_ordered(self) -> bool:
        # The user's pre-join ordering intent is deliberately not considered:
        # in unordered mode they need to re-order after the join.
        return True

    def transform_children(
        self, t: Callable[[BigFrameNode], BigFrameNode]
    ) -> BigFrameNode:
        """Apply ``t`` to the left child, then the right child, and rebuild the node.

        Returns ``self`` when the rebuilt node compares equal to it.
        """
        new_left = t(self.left_child)
        new_right = t(self.right_child)
        candidate = dataclasses.replace(
            self, left_child=new_left, right_child=new_right
        )
        # Reusing the existing object speeds up eq and saves a small amount of memory.
        return self if self == candidate else candidate

    def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
        """Prune both children with the same set of used columns."""

        def prune_child(child: BigFrameNode) -> BigFrameNode:
            return child.prune(used_cols)

        return self.transform_children(prune_child)

    def remap_vars(
        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
    ) -> BigFrameNode:
        """Return ``self`` unchanged; ``mappings`` is not used."""
        return self

    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
        """Return ``self`` unchanged; ``mappings`` is not used."""
        return self


@dataclasses.dataclass(frozen=True, eq=False)
class ConcatNode(BigFrameNode):
# TODO: Explicitly map column ids from each child
Expand Down Expand Up @@ -1519,8 +1595,9 @@ def bottom_up(
root: BigFrameNode,
transform: Callable[[BigFrameNode], BigFrameNode],
*,
memoize=False,
validate=False,
stop: Optional[Callable[[BigFrameNode], bool]] = None,
memoize: bool = False,
validate: bool = False,
) -> BigFrameNode:
"""
Perform a bottom-up transformation of the BigFrameNode tree.
Expand All @@ -1529,6 +1606,8 @@ def bottom_up(
"""

def bottom_up_internal(root: BigFrameNode) -> BigFrameNode:
if (stop is not None) and (stop(root)):
return root
return transform(root.transform_children(bottom_up_internal))

if memoize:
Expand Down
5 changes: 3 additions & 2 deletions bigframes/core/rewrite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# limitations under the License.

from bigframes.core.rewrite.identifiers import remap_variables
from bigframes.core.rewrite.implicit_align import try_row_join
from bigframes.core.rewrite.implicit_align import rewrite_row_join
from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
from bigframes.core.rewrite.slices import pullup_limit_from_slice, rewrite_slice

__all__ = [
"legacy_join_as_projection",
"try_row_join",
"rewrite_slice",
"pullup_limit_from_slice",
"remap_variables",
"combine_nodes",
"rewrite_row_join",
]
158 changes: 70 additions & 88 deletions bigframes/core/rewrite/implicit_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
from __future__ import annotations

import dataclasses
from typing import Iterable, Optional, Tuple
import itertools
from typing import Optional, Set, Tuple

import bigframes.core.expression
import bigframes.core.guid
Expand All @@ -34,82 +34,40 @@
ALIGNABLE_NODES = (
*ADDITIVE_NODES,
bigframes.core.nodes.SelectionNode,
bigframes.core.nodes.RowJoinNode,
)


@dataclasses.dataclass(frozen=True)
class ExpressionSpec:
expression: bigframes.core.expression.Expression
node: bigframes.core.nodes.BigFrameNode

def rewrite_row_join(
node: bigframes.core.nodes.BigFrameNode,
) -> bigframes.core.nodes.BigFrameNode:
if not isinstance(node, bigframes.core.nodes.RowJoinNode):
return node

def get_expression_spec(
node: bigframes.core.nodes.BigFrameNode, id: bigframes.core.identifiers.ColumnId
) -> ExpressionSpec:
"""Normalizes column value by chaining expressions across multiple selection and projection nodes if possible.
This normalization helps identify whether columns are equivalent.
"""
# TODO: While we chain expression fragments from different nodes
# we could further normalize with constant folding and other scalar expression rewrites
expression: bigframes.core.expression.Expression = (
bigframes.core.expression.DerefOp(id)
l_node = node.left_child
r_node = node.right_child
divergent_node = first_shared_descendent(
{l_node, r_node}, descendable_types=ALIGNABLE_NODES
)
assert divergent_node is not None
# The inner handler can't handle a nested RowJoinNode, so apply the algorithm locally, bottom-up
return bigframes.core.nodes.bottom_up(
node,
lambda x: _rewrite_row_join_node(x, divergent_node),
stop=lambda x: x == divergent_node,
memoize=True,
)
curr_node = node
while True:
if isinstance(curr_node, bigframes.core.nodes.SelectionNode):
select_mappings = {
col_id: ref for ref, col_id in curr_node.input_output_pairs
}
expression = expression.bind_refs(
select_mappings, allow_partial_bindings=True
)
elif isinstance(curr_node, bigframes.core.nodes.ProjectionNode):
proj_mappings = {col_id: expr for expr, col_id in curr_node.assignments}
expression = expression.bind_refs(
proj_mappings, allow_partial_bindings=True
)
elif isinstance(
curr_node,
(
bigframes.core.nodes.WindowOpNode,
bigframes.core.nodes.PromoteOffsetsNode,
),
):
if set(expression.column_references).isdisjoint(
field.id for field in curr_node.added_fields
):
# we don't yet have a way of normalizing window ops into a ExpressionSpec, which only
# handles normalizing scalar expressions at the moment.
pass
else:
return ExpressionSpec(expression, curr_node)
else:
return ExpressionSpec(expression, curr_node)
curr_node = curr_node.child


def try_row_join(
l_node: bigframes.core.nodes.BigFrameNode,
r_node: bigframes.core.nodes.BigFrameNode,
join_keys: Tuple[Tuple[str, str], ...],
) -> Optional[bigframes.core.nodes.BigFrameNode]:
"""Joins the two nodes"""
divergent_node = first_shared_descendent(
l_node, r_node, descendable_types=ALIGNABLE_NODES
)
if divergent_node is None:
return None
# check join keys are equivalent by normalizing the expressions as much as possible
# instead of just comparing ids
for l_key, r_key in join_keys:
# Caller is block, so they still work with raw strings rather than ids
left_id = bigframes.core.identifiers.ColumnId(l_key)
right_id = bigframes.core.identifiers.ColumnId(r_key)
if get_expression_spec(l_node, left_id) != get_expression_spec(
r_node, right_id
):
return None
def _rewrite_row_join_node(
node: bigframes.core.nodes.BigFrameNode,
divergent_node: bigframes.core.nodes.BigFrameNode,
) -> bigframes.core.nodes.BigFrameNode:
if not isinstance(node, bigframes.core.nodes.RowJoinNode):
return node

l_node = node.left_child
r_node = node.right_child
l_node, l_selection = pull_up_selection(l_node, stop=divergent_node)
r_node, r_selection = pull_up_selection(
r_node, stop=divergent_node, rename_vars=True
Expand All @@ -131,9 +89,35 @@ def _linearize_trees(
)

merged_node = _linearize_trees(l_node, r_node)
# RowJoin rewrites can otherwise create too deep a tree
merged_node = bigframes.core.nodes.bottom_up(
merged_node,
fold_projections,
stop=lambda x: divergent_node in x.child_nodes,
memoize=True,
)
return bigframes.core.nodes.SelectionNode(merged_node, combined_selection)


def fold_projections(
root: bigframes.core.nodes.BigFrameNode,
) -> bigframes.core.nodes.BigFrameNode:
"""If root and child are projection nodes, merge them."""
if not isinstance(root, bigframes.core.nodes.ProjectionNode):
return root
if not isinstance(root.child, bigframes.core.nodes.ProjectionNode):
return root
mapping = {id: expr for expr, id in root.child.assignments}
new_exprs = (
*root.child.assignments,
*(
(expr.bind_refs(mapping, allow_partial_bindings=True), id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This tuple comprehension is very long. Shall we use another local variable "root_assignments" to hold the value?

And
"new_exprs = tuple(root.child.assignments) + root_assignments"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extracted out to variable as suggested

for expr, id in root.assignments
),
)
return bigframes.core.nodes.ProjectionNode(root.child.child, new_exprs)


def pull_up_selection(
node: bigframes.core.nodes.BigFrameNode,
stop: bigframes.core.nodes.BigFrameNode,
Expand Down Expand Up @@ -201,26 +185,24 @@ def pull_up_selection(

## Traversal helpers
def first_shared_descendent(
left: bigframes.core.nodes.BigFrameNode,
right: bigframes.core.nodes.BigFrameNode,
descendable_types: Tuple[type[bigframes.core.nodes.UnaryNode], ...],
roots: Set[bigframes.core.nodes.BigFrameNode],
descendable_types: Tuple[type[bigframes.core.nodes.BigFrameNode], ...],
) -> Optional[bigframes.core.nodes.BigFrameNode]:
l_path = tuple(descend(left, descendable_types))
r_path = tuple(descend(right, descendable_types))
if l_path[-1] != r_path[-1]:
if not roots:
return None
if len(roots) == 1:
return next(iter(roots))

for l_node, r_node in zip(l_path[-len(r_path) :], r_path[-len(l_path) :]):
if l_node == r_node:
return l_node
# should be impossible, as l_path[-1] == r_path[-1]
raise ValueError()
min_height = min(root.height for root in roots)
to_descend = set(root for root in roots if root.height > min_height)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It took me a while to realize that to_descend is a set of root nodes.

Let's name it "descend_roots" ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed to roots_to_descend

if not to_descend:
to_descend = roots

if any(not isinstance(root, descendable_types) for root in to_descend):
return None
as_is = roots - to_descend
descended = set(
itertools.chain.from_iterable(root.child_nodes for root in to_descend)
)

def descend(
root: bigframes.core.nodes.BigFrameNode,
descendable_types: Tuple[type[bigframes.core.nodes.UnaryNode], ...],
) -> Iterable[bigframes.core.nodes.BigFrameNode]:
yield root
if isinstance(root, descendable_types):
yield from descend(root.child, descendable_types)
return first_shared_descendent(as_is.union(descended), descendable_types)
Loading
Loading