diff --git a/xarray/groupers.py b/xarray/groupers.py index e4cb884e6de..97de79602d4 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -16,8 +16,10 @@ from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.core import duck_array_ops +from xarray.core.computation import apply_ufunc from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray +from xarray.core.duck_array_ops import isnull from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper @@ -29,6 +31,7 @@ SideOptions, ) from xarray.core.variable import Variable +from xarray.namedarray.pycompat import is_chunked_array __all__ = [ "EncodedGroups", @@ -96,7 +99,7 @@ def __init__( assert isinstance(full_index, pd.Index) self.full_index = full_index - if group_indices is None: + if group_indices is None and not is_chunked_array(codes.data): self.group_indices = tuple( g for g in _codes_to_group_indices(codes.data.ravel(), len(full_index)) @@ -155,10 +158,17 @@ class UniqueGrouper(Grouper): """Grouper object for grouping by a categorical variable.""" _group_as_index: pd.Index | None = field(default=None, repr=False) + labels: np.ndarray | None = field(default=None) + + def __post_init__(self) -> None: + if self.labels is not None: + self.labels = np.sort(self.labels) @property def group_as_index(self) -> pd.Index: """Caches the group DataArray as a pandas Index.""" + if is_chunked_array(self.group): + raise ValueError("Please call compute manually.") if self._group_as_index is None: if self.group.ndim == 1: self._group_as_index = self.group.to_index() @@ -169,6 +179,11 @@ def group_as_index(self) -> pd.Index: def factorize(self, group: T_Group) -> EncodedGroups: self.group = group + if is_chunked_array(group.data) and self.labels is None: + raise ValueError("When grouping by a dask array, `labels` must be passed.") + if self.labels is not None: + return self._factorize_given_labels(group) + index = self.group_as_index is_unique_and_monotonic = isinstance(self.group, _DummyGroup) or ( index.is_unique @@ -182,6 +197,24 @@ def factorize(self, group: T_Group) -> EncodedGroups: else: return self._factorize_unique() + def _factorize_given_labels(self, group: T_Group) -> EncodedGroups: + codes = apply_ufunc( + _factorize_given_labels, + group, + kwargs={"labels": self.labels}, + dask="parallelized", + output_dtypes=[np.int64], + ) + return EncodedGroups( + codes=codes, + full_index=pd.Index(self.labels), + unique_coord=Variable( + dims=codes.name, + data=self.labels, + attrs=self.group.attrs, + ), + ) + def _factorize_unique(self) -> EncodedGroups: # look through group to find the unique values sort = not isinstance(self.group_as_index, pd.MultiIndex) @@ -291,13 +324,9 @@ def __post_init__(self) -> None: if duck_array_ops.isnull(self.bins).all(): raise ValueError("All bin edges are NaN.") - def factorize(self, group: T_Group) -> EncodedGroups: - from xarray.core.dataarray import DataArray - - data = np.asarray(group.data) # Cast _DummyGroup data to array - - binned, self.bins = pd.cut( # type: ignore [call-overload] - data.ravel(), + def _cut(self, data): + return pd.cut( # type: ignore [call-overload] + np.asarray(data).ravel(), bins=self.bins, right=self.right, labels=self.labels, @@ -307,23 +336,43 @@ def factorize(self, group: T_Group) -> EncodedGroups: retbins=True, ) - binned_codes = binned.codes - if (binned_codes == -1).all(): + def _factorize_lazy(self, group: T_Group) -> DataArray: + def _wrapper(data, **kwargs): + binned, bins = self._cut(data) + if isinstance(self.bins, int): + # we are running eagerly, update self.bins with actual edges instead + self.bins = bins + return binned.codes.reshape(data.shape) + + return apply_ufunc(_wrapper, group, dask="parallelized") + + def factorize(self, group: T_Group) -> EncodedGroups: + if isinstance(group, _DummyGroup): + group = DataArray(group.data, dims=group.dims, name=group.name) + by_is_chunked = is_chunked_array(group.data) + if isinstance(self.bins, int) and by_is_chunked: + raise ValueError( + f"Bin edges must be provided when grouping by chunked arrays. Received {self.bins=!r} instead" + ) + codes = self._factorize_lazy(group) + if not by_is_chunked and (codes == -1).all(): raise ValueError( f"None of the data falls within bins with edges {self.bins!r}" ) new_dim_name = f"{group.name}_bins" + codes.name = new_dim_name + + # This seems silly, but it lets us have Pandas handle the complexity + # of labels, precision, and include_lowest, even when group is a chunked array + dummy, _ = self._cut(np.array([1, 2, 3]).astype(group.dtype)) + full_index = dummy.categories + if not by_is_chunked: + uniques = np.sort(pd.unique(codes.data.ravel())) + unique_values = full_index[uniques[uniques != -1]] + else: + unique_values = full_index - full_index = binned.categories - uniques = np.sort(pd.unique(binned_codes)) - unique_values = full_index[uniques[uniques != -1]] - - codes = DataArray( - binned_codes.reshape(group.shape), - getattr(group, "coords", None), - name=new_dim_name, - ) unique_coord = Variable( dims=new_dim_name, data=unique_values, attrs=group.attrs ) @@ -461,6 +510,21 @@ def factorize(self, group: T_Group) -> EncodedGroups: ) +def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray: + # Copied from flox + sort = False # use labels as provided + sorter = np.argsort(labels) + codes = np.searchsorted(labels, data, sorter=sorter) + mask = ~np.isin(data, labels) | isnull(data) | (codes == len(labels)) + if not sort: + # codes is the index in to the sorted array. + # if we didn't want sorting, unsort it back + codes[(codes == len(labels),)] = -1 + codes = sorter[(codes,)] + codes[mask] = -1 + return codes + + def unique_value_groups( ar, sort: bool = True ) -> tuple[np.ndarray | pd.Index, np.ndarray]: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index dc869cc3a34..bdf017e2be9 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2583,7 +2583,9 @@ def test_groupby_math_auto_chunk() -> None: sub = xr.DataArray( InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]} ) - actual = da.chunk(x=1, y=2).groupby("label") - sub + chunked = da.chunk(x=1, y=2) + chunked.label.load() + actual = chunked.groupby("label") - sub assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)}