-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor: Define RowJoinNode and defer rewrite #1183
base: main
Are you sure you want to change the base?
Changes from 7 commits
40599d2
3d96d10
7bf5d0d
151e478
72222ce
7903620
3d20473
b864480
b422974
d5b9989
96e37f2
737cc9e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,8 +13,8 @@ | |
# limitations under the License. | ||
from __future__ import annotations | ||
|
||
import dataclasses | ||
from typing import Iterable, Optional, Tuple | ||
import itertools | ||
from typing import Optional, Set, Tuple | ||
|
||
import bigframes.core.expression | ||
import bigframes.core.guid | ||
|
@@ -34,82 +34,40 @@ | |
ALIGNABLE_NODES = ( | ||
*ADDITIVE_NODES, | ||
bigframes.core.nodes.SelectionNode, | ||
bigframes.core.nodes.RowJoinNode, | ||
) | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class ExpressionSpec: | ||
expression: bigframes.core.expression.Expression | ||
node: bigframes.core.nodes.BigFrameNode | ||
|
||
def rewrite_row_join( | ||
node: bigframes.core.nodes.BigFrameNode, | ||
) -> bigframes.core.nodes.BigFrameNode: | ||
if not isinstance(node, bigframes.core.nodes.RowJoinNode): | ||
return node | ||
|
||
def get_expression_spec( | ||
node: bigframes.core.nodes.BigFrameNode, id: bigframes.core.identifiers.ColumnId | ||
) -> ExpressionSpec: | ||
"""Normalizes column value by chaining expressions across multiple selection and projection nodes if possible. | ||
This normalization helps identify whether columns are equivalent. | ||
""" | ||
# TODO: While we chain expression fragments from different nodes | ||
# we could further normalize with constant folding and other scalar expression rewrites | ||
expression: bigframes.core.expression.Expression = ( | ||
bigframes.core.expression.DerefOp(id) | ||
l_node = node.left_child | ||
r_node = node.right_child | ||
divergent_node = first_shared_descendent( | ||
{l_node, r_node}, descendable_types=ALIGNABLE_NODES | ||
) | ||
assert divergent_node is not None | ||
# Inner handler can't RowJoinNode, so bottom up apply the algorithm locally | ||
return bigframes.core.nodes.bottom_up( | ||
node, | ||
lambda x: _rewrite_row_join_node(x, divergent_node), | ||
stop=lambda x: x == divergent_node, | ||
memoize=True, | ||
) | ||
curr_node = node | ||
while True: | ||
if isinstance(curr_node, bigframes.core.nodes.SelectionNode): | ||
select_mappings = { | ||
col_id: ref for ref, col_id in curr_node.input_output_pairs | ||
} | ||
expression = expression.bind_refs( | ||
select_mappings, allow_partial_bindings=True | ||
) | ||
elif isinstance(curr_node, bigframes.core.nodes.ProjectionNode): | ||
proj_mappings = {col_id: expr for expr, col_id in curr_node.assignments} | ||
expression = expression.bind_refs( | ||
proj_mappings, allow_partial_bindings=True | ||
) | ||
elif isinstance( | ||
curr_node, | ||
( | ||
bigframes.core.nodes.WindowOpNode, | ||
bigframes.core.nodes.PromoteOffsetsNode, | ||
), | ||
): | ||
if set(expression.column_references).isdisjoint( | ||
field.id for field in curr_node.added_fields | ||
): | ||
# we don't yet have a way of normalizing window ops into a ExpressionSpec, which only | ||
# handles normalizing scalar expressions at the moment. | ||
pass | ||
else: | ||
return ExpressionSpec(expression, curr_node) | ||
else: | ||
return ExpressionSpec(expression, curr_node) | ||
curr_node = curr_node.child | ||
|
||
|
||
def try_row_join( | ||
l_node: bigframes.core.nodes.BigFrameNode, | ||
r_node: bigframes.core.nodes.BigFrameNode, | ||
join_keys: Tuple[Tuple[str, str], ...], | ||
) -> Optional[bigframes.core.nodes.BigFrameNode]: | ||
"""Joins the two nodes""" | ||
divergent_node = first_shared_descendent( | ||
l_node, r_node, descendable_types=ALIGNABLE_NODES | ||
) | ||
if divergent_node is None: | ||
return None | ||
# check join keys are equivalent by normalizing the expressions as much as posisble | ||
# instead of just comparing ids | ||
for l_key, r_key in join_keys: | ||
# Caller is block, so they still work with raw strings rather than ids | ||
left_id = bigframes.core.identifiers.ColumnId(l_key) | ||
right_id = bigframes.core.identifiers.ColumnId(r_key) | ||
if get_expression_spec(l_node, left_id) != get_expression_spec( | ||
r_node, right_id | ||
): | ||
return None | ||
def _rewrite_row_join_node( | ||
node: bigframes.core.nodes.BigFrameNode, | ||
divergent_node: bigframes.core.nodes.BigFrameNode, | ||
) -> bigframes.core.nodes.BigFrameNode: | ||
if not isinstance(node, bigframes.core.nodes.RowJoinNode): | ||
return node | ||
|
||
l_node = node.left_child | ||
r_node = node.right_child | ||
l_node, l_selection = pull_up_selection(l_node, stop=divergent_node) | ||
r_node, r_selection = pull_up_selection( | ||
r_node, stop=divergent_node, rename_vars=True | ||
|
@@ -131,9 +89,35 @@ def _linearize_trees( | |
) | ||
|
||
merged_node = _linearize_trees(l_node, r_node) | ||
# RowJoin rewrites can otherwise create too deep a tree | ||
merged_node = bigframes.core.nodes.bottom_up( | ||
merged_node, | ||
fold_projections, | ||
stop=lambda x: divergent_node in x.child_nodes, | ||
memoize=True, | ||
) | ||
return bigframes.core.nodes.SelectionNode(merged_node, combined_selection) | ||
|
||
|
||
def fold_projections( | ||
root: bigframes.core.nodes.BigFrameNode, | ||
) -> bigframes.core.nodes.BigFrameNode: | ||
"""If root and child are projection nodes, merge them.""" | ||
if not isinstance(root, bigframes.core.nodes.ProjectionNode): | ||
return root | ||
if not isinstance(root.child, bigframes.core.nodes.ProjectionNode): | ||
return root | ||
mapping = {id: expr for expr, id in root.child.assignments} | ||
new_exprs = ( | ||
*root.child.assignments, | ||
*( | ||
(expr.bind_refs(mapping, allow_partial_bindings=True), id) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This tuple comprehension is very long. Shall we use another local variable "root_assignments" to hold the value? And There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. extracted out to variable as suggested |
||
for expr, id in root.assignments | ||
), | ||
) | ||
return bigframes.core.nodes.ProjectionNode(root.child.child, new_exprs) | ||
|
||
|
||
def pull_up_selection( | ||
node: bigframes.core.nodes.BigFrameNode, | ||
stop: bigframes.core.nodes.BigFrameNode, | ||
|
@@ -201,26 +185,24 @@ def pull_up_selection( | |
|
||
## Traversal helpers | ||
def first_shared_descendent( | ||
left: bigframes.core.nodes.BigFrameNode, | ||
right: bigframes.core.nodes.BigFrameNode, | ||
descendable_types: Tuple[type[bigframes.core.nodes.UnaryNode], ...], | ||
roots: Set[bigframes.core.nodes.BigFrameNode], | ||
descendable_types: Tuple[type[bigframes.core.nodes.BigFrameNode], ...], | ||
) -> Optional[bigframes.core.nodes.BigFrameNode]: | ||
l_path = tuple(descend(left, descendable_types)) | ||
r_path = tuple(descend(right, descendable_types)) | ||
if l_path[-1] != r_path[-1]: | ||
if not roots: | ||
return None | ||
if len(roots) == 1: | ||
return next(iter(roots)) | ||
|
||
for l_node, r_node in zip(l_path[-len(r_path) :], r_path[-len(l_path) :]): | ||
if l_node == r_node: | ||
return l_node | ||
# should be impossible, as l_path[-1] == r_path[-1] | ||
raise ValueError() | ||
min_height = min(root.height for root in roots) | ||
to_descend = set(root for root in roots if root.height > min_height) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It took me a while to realize that Let's name it "descend_roots" ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. renamed to |
||
if not to_descend: | ||
to_descend = roots | ||
|
||
if any(not isinstance(root, descendable_types) for root in to_descend): | ||
return None | ||
as_is = roots - to_descend | ||
descended = set( | ||
itertools.chain.from_iterable(root.child_nodes for root in to_descend) | ||
) | ||
|
||
def descend( | ||
root: bigframes.core.nodes.BigFrameNode, | ||
descendable_types: Tuple[type[bigframes.core.nodes.UnaryNode], ...], | ||
) -> Iterable[bigframes.core.nodes.BigFrameNode]: | ||
yield root | ||
if isinstance(root, descendable_types): | ||
yield from descend(root.child, descendable_types) | ||
return first_shared_descendent(as_is.union(descended), descendable_types) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably rename "try_new_row_join" now or in the future. The reason is that I guess "try_legacy_row_join" will be eventually removed, and it would be very confusing if we have only a "new" version.
We can just call it "try_row_join".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
renamed to
try_row_join