Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Define RowJoinNode and defer rewrite #1183

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,9 @@ def try_row_join(
lcol.name: lcol.name for lcol in self.node.ids
}
other_node, r_mapping = self.prepare_join_names(other)
import bigframes.core.rewrite
import bigframes.core.row_join

result_node = bigframes.core.rewrite.try_row_join(
result_node = bigframes.core.row_join.try_row_join(
self.node, other_node, conditions
)
if result_node is None:
Expand Down
10 changes: 5 additions & 5 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2335,8 +2335,8 @@ def join(
# Handle null index, which only supports row join
# This is the canonical way of aligning on null index, so always allow (ignore block_identity_join)
if self.index.nlevels == other.index.nlevels == 0:
result = try_legacy_row_join(self, other, how=how) or try_new_row_join(
self, other
result = try_row_join(self, other) or try_legacy_row_join(
self, other, how=how
)
if result is not None:
return result
Expand All @@ -2350,8 +2350,8 @@ def join(
and (self.index.nlevels == other.index.nlevels)
and (self.index.dtypes == other.index.dtypes)
):
result = try_legacy_row_join(self, other, how=how) or try_new_row_join(
self, other
result = try_row_join(self, other) or try_legacy_row_join(
self, other, how=how
)
if result is not None:
return result
Expand Down Expand Up @@ -2691,7 +2691,7 @@ def is_uniquely_named(self: BlockIndexProperties):
return len(set(self.names)) == len(self.names)


def try_new_row_join(
def try_row_join(
left: Block, right: Block
) -> Optional[Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]]:
join_keys = tuple(
Expand Down
6 changes: 6 additions & 0 deletions bigframes/core/compile/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,12 @@ def compile_join(self, node: nodes.JoinNode):
def compile_concat(self, node: nodes.ConcatNode):
return pl.concat(self.compile_node(child) for child in node.child_nodes)

@compile_node.register
def compile_row_join(self, node: nodes.RowJoinNode):
    """Compile a RowJoinNode by placing its compiled children side by side.

    Each child is compiled in order via ``compile_node`` and the results are
    handed to ``pl.concat`` with ``how="horizontal"``.
    """
    compiled_children = (self.compile_node(child) for child in node.child_nodes)
    return pl.concat(compiled_children, how="horizontal")

@compile_node.register
def compile_agg(self, node: nodes.AggregateNode):
df = self.compile_node(node.child)
Expand Down
92 changes: 90 additions & 2 deletions bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ def roots(self) -> typing.Set[BigFrameNode]:
)
return set(roots)

@property
def all_nodes(self) -> Iterable[BigFrameNode]:
    """Yield this node and every node beneath it, in pre-order.

    A node that is reachable through more than one parent is yielded once
    per path, exactly as a naive recursive walk would do.

    The traversal keeps an explicit stack instead of chaining recursive
    generators, so it does not raise ``RecursionError`` on very deep trees
    and each node is produced without passing through one generator frame
    per level of depth.
    """
    pending = [self]
    while pending:
        node = pending.pop()
        yield node
        # Push children right-to-left so the leftmost child is popped (and
        # therefore visited) first, preserving left-to-right pre-order.
        pending.extend(reversed(tuple(node.child_nodes)))

def contains(self, node: BigFrameNode) -> bool:
    """Return True if ``node`` is this node or appears anywhere beneath it.

    Membership is decided by hashing: every node produced by ``all_nodes``
    is collected into a set and ``node`` is looked up in that set.
    """
    reachable = set(self.all_nodes)
    return node in reachable

# TODO: Store some local data lazily for select, aggregate nodes.
@property
@abc.abstractmethod
Expand Down Expand Up @@ -209,6 +218,12 @@ def explicitly_ordered(self) -> bool:
"""
...

@functools.cached_property
def height(self) -> int:
    """Length of the longest downward path from this node to a childless node.

    A node without children has height 0; otherwise the height is one more
    than the greatest height among its children.
    """
    return max((child.height + 1 for child in self.child_nodes), default=0)

@functools.cached_property
def total_variables(self) -> int:
return self.variables_introduced + sum(
Expand Down Expand Up @@ -477,6 +492,76 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
return dataclasses.replace(self, conditions=new_conds) # type: ignore


@dataclasses.dataclass(frozen=True, eq=False)
class RowJoinNode(BigFrameNode):
    """Node that combines the columns of two child nodes.

    The output fields are the left child's fields followed by the right
    child's fields, and the row count is reported from the left child. The
    node defines no column ids of its own.

    NOTE(review): presumably both children are expected to be row-aligned
    with each other - confirm against the rewrite that consumes this node.
    """

    left_child: BigFrameNode
    right_child: BigFrameNode

    def _validate(self):
        # Column ids must be unique across the two sides.
        left_ids = set(self.left_child.ids)
        assert left_ids.isdisjoint(set(self.right_child.ids)), "Join ids collide"

    @property
    def row_preserving(self) -> bool:
        return True

    @property
    def child_nodes(self) -> typing.Sequence[BigFrameNode]:
        # Left first, then right.
        return (self.left_child, self.right_child)

    @property
    def order_ambiguous(self) -> bool:
        return False

    @property
    def explicitly_ordered(self) -> bool:
        # Do not consider user pre-join ordering intent - they need to re-order post-join in unordered mode.
        return True

    @property
    def fields(self) -> Iterable[Field]:
        # Left child's fields come first, followed by the right child's.
        left_fields = self.left_child.fields
        right_fields = self.right_child.fields
        return itertools.chain(left_fields, right_fields)

    @functools.cached_property
    def variables_introduced(self) -> int:
        return 0

    @property
    def row_count(self) -> Optional[int]:
        # Only the left child is consulted for the row count.
        return self.left_child.row_count

    @property
    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
        # This node creates no new column ids.
        return ()

    @property
    def added_fields(self) -> Tuple[Field, ...]:
        # The right child's fields are what this node adds.
        return tuple(self.right_child.fields)

    def transform_children(
        self, t: Callable[[BigFrameNode], BigFrameNode]
    ) -> BigFrameNode:
        """Apply ``t`` to the left child, then the right child, and rebuild the node.

        Returns ``self`` when the rebuilt node compares equal to it.
        """
        new_left = t(self.left_child)
        new_right = t(self.right_child)
        candidate = dataclasses.replace(
            self, left_child=new_left, right_child=new_right
        )
        # reusing existing object speeds up eq, and saves a small amount of memory
        return self if self == candidate else candidate

    def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
        """Prune both children against the same ``used_cols`` set."""

        def prune_child(child: BigFrameNode) -> BigFrameNode:
            return child.prune(used_cols)

        return self.transform_children(prune_child)

    def remap_vars(
        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
    ) -> BigFrameNode:
        # Nothing to remap: the node is returned unchanged.
        return self

    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
        # Nothing to remap: the node is returned unchanged.
        return self


@dataclasses.dataclass(frozen=True, eq=False)
class ConcatNode(BigFrameNode):
# TODO: Explicitly map column ids from each child
Expand Down Expand Up @@ -1519,8 +1604,9 @@ def bottom_up(
root: BigFrameNode,
transform: Callable[[BigFrameNode], BigFrameNode],
*,
memoize=False,
validate=False,
stop: Optional[Callable[[BigFrameNode], bool]] = None,
memoize: bool = False,
validate: bool = False,
) -> BigFrameNode:
"""
Perform a bottom-up transformation of the BigFrameNode tree.
Expand All @@ -1529,6 +1615,8 @@ def bottom_up(
"""

def bottom_up_internal(root: BigFrameNode) -> BigFrameNode:
if (stop is not None) and (stop(root)):
return root
return transform(root.transform_children(bottom_up_internal))

if memoize:
Expand Down
5 changes: 3 additions & 2 deletions bigframes/core/rewrite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# limitations under the License.

from bigframes.core.rewrite.identifiers import remap_variables
from bigframes.core.rewrite.implicit_align import try_row_join
from bigframes.core.rewrite.implicit_align import rewrite_row_join
from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
from bigframes.core.rewrite.slices import pullup_limit_from_slice, rewrite_slice

__all__ = [
"legacy_join_as_projection",
"try_row_join",
"rewrite_slice",
"pullup_limit_from_slice",
"remap_variables",
"combine_nodes",
"rewrite_row_join",
]
Loading
Loading