Commit e27f572

Merge branch 'main' into bugfix-scalar-arr-casting
* main:
  (feat): Support for `pandas` `ExtensionArray` (pydata#8723)
  Migrate datatree mapping.py (pydata#8948)
  Add mypy to dev dependencies (pydata#8947)
  Convert 360_day calendars by choosing random dates to drop or add (pydata#8603)
dcherian committed Apr 18, 2024
2 parents e3493b0 + 9eb180b commit e27f572
Showing 25 changed files with 562 additions and 90 deletions.
11 changes: 10 additions & 1 deletion doc/whats-new.rst
@@ -22,7 +22,13 @@ v2024.04.0 (unreleased)

New Features
~~~~~~~~~~~~

- New "random" method for converting to and from 360_day calendars (:pull:`8603`).
By `Pascal Bourgault <https://github.com/aulemahal>`_.
- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array
  by supporting 1D `ExtensionArray` objects internally where possible. Thus, a `Dataset` initialized with a `pd.Categorical`,
  for example, will retain that object. However, operations the `ExtensionArray` does not support,
  such as broadcasting, remain unavailable (see the sketch below).
  By `Ilan Gold <https://github.com/ilan-gold>`_.
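
A minimal sketch of the new behaviour (names are illustrative, and the exact
dtype repr may differ):

    import pandas as pd
    import xarray as xr

    cat = pd.Categorical(["a", "b", "a", "c"])
    ds = xr.Dataset({"animal": ("index", cat)})
    print(ds["animal"].dtype)  # expected: category, not object
    # Operations the ExtensionArray itself cannot perform, such as
    # broadcasting against a new dimension, are expected to fail.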

Breaking changes
~~~~~~~~~~~~~~~~
@@ -34,6 +40,9 @@ Bug fixes

Internal Changes
~~~~~~~~~~~~~~~~
- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`).
  By `Matt Savoie <https://github.com/flamingbear>`_, `Owen Littlejohns
  <https://github.com/owenlittlejohns>`_ and `Tom Nicholas <https://github.com/TomNicholas>`_.


.. _whats-new.2024.03.0:
4 changes: 3 additions & 1 deletion properties/test_pandas_roundtrip.py
@@ -17,7 +17,9 @@
from hypothesis import given # isort:skip

numeric_dtypes = st.one_of(
npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes()
npst.unsigned_integer_dtypes(endianness="="),
npst.integer_dtypes(endianness="="),
npst.floating_dtypes(endianness="="),
)

numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt))
14 changes: 8 additions & 6 deletions pyproject.toml
@@ -33,6 +33,7 @@ accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"]
complete = ["xarray[accel,io,parallel,viz,dev]"]
dev = [
"hypothesis",
"mypy",
"pre-commit",
"pytest",
"pytest-cov",
@@ -86,8 +87,8 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]
[tool.mypy]
enable_error_code = "redundant-self"
exclude = [
'xarray/util/generate_.*\.py',
'xarray/datatree_/.*\.py',
'xarray/util/generate_.*\.py',
'xarray/datatree_/.*\.py',
]
files = "xarray"
show_error_codes = true
@@ -98,8 +99,8 @@ warn_unused_ignores = true

# Ignore mypy errors for modules imported from datatree_.
[[tool.mypy.overrides]]
module = "xarray.datatree_.*"
ignore_errors = true
module = "xarray.datatree_.*"

# Much of the numerical computing stack doesn't have type annotations yet.
[[tool.mypy.overrides]]
@@ -129,6 +130,7 @@ module = [
"opt_einsum.*",
"pandas.*",
"pooch.*",
"pyarrow.*",
"pydap.*",
"pytest.*",
"scipy.*",
@@ -255,6 +257,9 @@ target-version = "py39"
# E402: module level import not at top of file
# E501: line too long - let black worry about that
# E731: do not assign a lambda expression, use a def
extend-safe-fixes = [
"TID252", # absolute imports
]
ignore = [
"E402",
"E501",
@@ -268,9 +273,6 @@ select = [
"I", # isort
"UP", # Pyupgrade
]
extend-safe-fixes = [
"TID252", # absolute imports
]

[tool.ruff.lint.per-file-ignores]
# don't enforce absolute imports
53 changes: 45 additions & 8 deletions xarray/coding/calendar_ops.py
@@ -64,7 +64,7 @@ def convert_calendar(
The target calendar name.
dim : str
Name of the time coordinate in the input DataArray or Dataset.
align_on : {None, 'date', 'year'}
align_on : {None, 'date', 'year', 'random'}
Must be specified when either the source or target is a `"360_day"`
calendar; ignored otherwise. See Notes.
missing : any, optional
@@ -143,6 +143,16 @@
will be dropped as there are no equivalent dates in a standard calendar.
This option is best used with data on a frequency coarser than daily.
"random"
Similar to "year", each day of year of the source is mapped to another day of year
of the target. However, instead of having always the same missing days according
the source and target years, here 5 days are chosen randomly, one for each fifth
of the year. However, February 29th is always missing when converting to a leap year,
or its value is dropped when converting from a leap year. This is similar to the method
used in the LOCA dataset (see Pierce, Cayan, and Thrasher (2014). doi:10.1175/JHM-D-14-0082.1).
This option is best used on daily data.
"""
from xarray.core.dataarray import DataArray

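A brief usage sketch of the new option (data values are illustrative):

    import numpy as np
    import xarray as xr

    time = xr.cftime_range("2000-01-01", periods=360, freq="D", calendar="360_day")
    da = xr.DataArray(np.arange(360), dims="time", coords={"time": time})
    # The five missing days are inserted at random positions across the year.
    converted = da.convert_calendar("noleap", align_on="random")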
@@ -174,14 +184,20 @@

out = obj.copy()

if align_on == "year":
if align_on in ["year", "random"]:
# Special case for conversion involving 360_day calendar
# Instead of translating dates directly, this tries to keep the position within a year similar.

new_doy = time.groupby(f"{dim}.year").map(
_interpolate_day_of_year, target_calendar=calendar, use_cftime=use_cftime
)

if align_on == "year":
# Instead of translating dates directly, this tries to keep the position within a year similar.
new_doy = time.groupby(f"{dim}.year").map(
_interpolate_day_of_year,
target_calendar=calendar,
use_cftime=use_cftime,
)
elif align_on == "random":
# The 5 days to remove are randomly chosen, one from each of the five 72-day periods of the year.
new_doy = time.groupby(f"{dim}.year").map(
_random_day_of_year, target_calendar=calendar, use_cftime=use_cftime
)
# Convert the source datetimes, but override the day of year with our new day of years.
out[dim] = DataArray(
[
@@ -229,6 +245,27 @@ def _interpolate_day_of_year(time, target_calendar, use_cftime):
).astype(int)


def _random_day_of_year(time, target_calendar, use_cftime):
"""Return a day of year in the new calendar.
Removes Feb 29th and five other days chosen randomly within five sections of 72 days.
"""
year = int(time.dt.year[0])
source_calendar = time.dt.calendar
new_doy = np.arange(360) + 1
rm_idx = np.random.default_rng().integers(0, 72, 5) + 72 * np.arange(5)
if source_calendar == "360_day":
for idx in rm_idx:
new_doy[idx + 1 :] = new_doy[idx + 1 :] + 1
if _days_in_year(year, target_calendar, use_cftime) == 366:
new_doy[new_doy >= 60] = new_doy[new_doy >= 60] + 1
elif target_calendar == "360_day":
new_doy = np.insert(new_doy, rm_idx - np.arange(5), -1)
if _days_in_year(year, source_calendar, use_cftime) == 366:
new_doy = np.insert(new_doy, 60, -1)
return new_doy[time.dt.dayofyear - 1]
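
For intuition, a self-contained sketch of the 360-day-to-365-day branch above,
with fixed (non-random) indices:

    import numpy as np

    new_doy = np.arange(360) + 1
    rm_idx = np.array([10, 100, 170, 250, 330])  # hypothetical draws, one per 72-day fifth
    for idx in rm_idx:
        new_doy[idx + 1 :] += 1
    # new_doy now spreads the 360 source days over 1..365, skipping target
    # days 12, 103, 174, 255 and 336 (the randomly removed ones).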


def _convert_to_new_calendar_with_new_day_of_year(
date, day_of_year, calendar, use_cftime
):
60 changes: 47 additions & 13 deletions xarray/core/dataset.py
@@ -24,6 +24,7 @@
from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload

import numpy as np
from pandas.api.types import is_extension_array_dtype

# remove once numpy 2.0 is the oldest supported version
try:
@@ -6852,10 +6853,13 @@ def reduce(
if (
# Some reduction functions (e.g. std, var) need to run on variables
# that don't have the reduce dims: PR5393
not reduce_dims
or not numeric_only
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
not is_extension_array_dtype(var.dtype)
and (
not reduce_dims
or not numeric_only
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
)
):
# prefer to aggregate over axis=None rather than
# axis=(0, 1) if they will be equivalent, because
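
A hedged sketch of what the new guard changes (names are illustrative): numeric
reductions now skip extension-array variables entirely rather than coercing them.

    import numpy as np
    import pandas as pd
    import xarray as xr

    ds = xr.Dataset(
        {
            "speed": ("obs", np.array([1.0, 2.0, 3.0])),
            "kind": ("obs", pd.Categorical(["a", "b", "a"])),
        }
    )
    print(ds.mean())  # expected: "speed" is averaged; "kind" is dropped, not coerced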
@@ -7168,13 +7172,37 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
)

def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
columns = [k for k in self.variables if k not in self.dims]
columns_in_order = [k for k in self.variables if k not in self.dims]
non_extension_array_columns = [
k
for k in columns_in_order
if not is_extension_array_dtype(self.variables[k].data)
]
extension_array_columns = [
k
for k in columns_in_order
if is_extension_array_dtype(self.variables[k].data)
]
data = [
self._variables[k].set_dims(ordered_dims).values.reshape(-1)
for k in columns
for k in non_extension_array_columns
]
index = self.coords.to_index([*ordered_dims])
return pd.DataFrame(dict(zip(columns, data)), index=index)
broadcasted_df = pd.DataFrame(
dict(zip(non_extension_array_columns, data)), index=index
)
for extension_array_column in extension_array_columns:
extension_array = self.variables[extension_array_column].data.array
index = self[self.variables[extension_array_column].dims[0]].data
extension_array_df = pd.DataFrame(
{extension_array_column: extension_array}, index=index
)
extension_array_df.index.name = self.variables[extension_array_column].dims[
0
]
broadcasted_df = broadcasted_df.join(extension_array_df)
return broadcasted_df[columns_in_order]
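
A sketch of the intended effect (self-contained; names are illustrative):

    import numpy as np
    import pandas as pd
    import xarray as xr

    ds = xr.Dataset(
        {
            "speed": ("obs", np.array([1.0, 2.0, 3.0])),
            "kind": ("obs", pd.Categorical(["a", "b", "a"])),
        }
    )
    df = ds.to_dataframe()
    print(df["kind"].dtype)   # expected: category, joined back rather than broadcast
    print(df["speed"].dtype)  # float64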

def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame:
"""Convert this dataset into a pandas.DataFrame.
@@ -7321,11 +7349,13 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
"cannot convert a DataFrame with a non-unique MultiIndex into xarray"
)

# Cast to a NumPy array first, in case the Series is a pandas Extension
# array (which doesn't have a valid NumPy dtype)
# TODO: allow users to control how this casting happens, e.g., by
# forwarding arguments to pandas.Series.to_numpy?
arrays = [(k, np.asarray(v)) for k, v in dataframe.items()]
arrays = []
extension_arrays = []
for k, v in dataframe.items():
if not is_extension_array_dtype(v):
arrays.append((k, np.asarray(v)))
else:
extension_arrays.append((k, v))

indexes: dict[Hashable, Index] = {}
index_vars: dict[Hashable, Variable] = {}
@@ -7339,6 +7369,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
xr_idx = PandasIndex(lev, dim)
indexes[dim] = xr_idx
index_vars.update(xr_idx.create_variables())
arrays += [(k, np.asarray(v)) for k, v in extension_arrays]
extension_arrays = []
else:
index_name = idx.name if idx.name is not None else "index"
dims = (index_name,)
@@ -7352,7 +7384,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
obj._set_sparse_data_from_dataframe(idx, arrays, dims)
else:
obj._set_numpy_data_from_dataframe(idx, arrays, dims)
return obj
for name, extension_array in extension_arrays:
obj[name] = (dims, extension_array)
return obj[dataframe.columns] if len(dataframe.columns) else obj
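
A round-trip sketch of the new path (column names are illustrative):

    import pandas as pd
    import xarray as xr

    df = pd.DataFrame(
        {"x": [1, 2, 3], "cat": pd.Categorical(["u", "v", "u"])},
        index=pd.Index([10, 20, 30], name="i"),
    )
    ds = xr.Dataset.from_dataframe(df)
    print(ds["cat"].dtype)     # expected: category, no np.asarray coercion
    print(list(ds.data_vars))  # expected: ['x', 'cat'], the original column order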

def to_dask_dataframe(
self, dim_order: Sequence[Hashable] | None = None, set_index: bool = False
10 changes: 5 additions & 5 deletions xarray/core/datatree.py
@@ -18,6 +18,11 @@
from xarray.core.coordinates import DatasetCoordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, DataVariables
from xarray.core.datatree_mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.core.indexes import Index, Indexes
from xarray.core.merge import dataset_update_method
from xarray.core.options import OPTIONS as XR_OPTS
@@ -36,11 +41,6 @@
from xarray.datatree_.datatree.formatting_html import (
datatree_repr as datatree_repr_html,
)
from xarray.datatree_.datatree.mapping import (
TreeIsomorphismError,
check_isomorphic,
map_over_subtree,
)
from xarray.datatree_.datatree.ops import (
DataTreeArithmeticMixin,
MappedDatasetMethodsMixin,
xarray/datatree_/datatree/mapping.py → xarray/core/datatree_mapping.py
@@ -4,10 +4,9 @@
import sys
from itertools import repeat
from textwrap import dedent
from typing import TYPE_CHECKING, Callable, Tuple
from typing import TYPE_CHECKING, Callable

from xarray import DataArray, Dataset

from xarray.core.iterators import LevelOrderIter
from xarray.core.treenode import NodePath, TreeNode

@@ -84,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str:
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
path_a, path_b = node_a.path, node_b.path

if require_names_equal:
if node_a.name != node_b.name:
diff = dedent(
f"""\
if require_names_equal and node_a.name != node_b.name:
diff = dedent(
f"""\
Node '{path_a}' in the left object has name '{node_a.name}'
Node '{path_b}' in the right object has name '{node_b.name}'"""
)
return diff
)
return diff

if len(node_a.children) != len(node_b.children):
diff = dedent(
@@ -125,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable:
func : callable
Function to apply to datasets with signature:
`func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.
`func(*args, **kwargs) -> Union[DataTree, Iterable[DataTree]]`.
(i.e. func must accept at least one Dataset and return at least one Dataset.)
Function will not be applied to any nodes without datasets.
@@ -154,7 +152,7 @@ def map_over_subtree(func: Callable) -> Callable:
# TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

@functools.wraps(func)
def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
"""Internal function which maps func over every node in tree, returning a tree of the results."""
from xarray.core.datatree import DataTree

@@ -259,19 +257,18 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
return _map_over_subtree


def _handle_errors_with_path_context(path):
def _handle_errors_with_path_context(path: str):
"""Wraps given function so that if it fails it also raises path to node on which it failed."""

def decorator(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if sys.version_info >= (3, 11):
# Add the context information to the error message
e.add_note(
f"Raised whilst mapping function over node with path {path}"
)
# Add the context information to the error message
add_note(
e, f"Raised whilst mapping function over node with path {path}"
)
raise

return wrapper
@@ -287,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None:
err.add_note(msg)
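
A usage sketch of this backport helper (the import path assumes this PR's module
layout, and the message is illustrative):

    from xarray.core.datatree_mapping import add_note

    try:
        raise ValueError("func failed on this node")
    except ValueError as err:
        # On Python >= 3.11 this attaches the note via BaseException.add_note.
        add_note(err, "Raised whilst mapping function over node with path /a/b")
        print(getattr(err, "__notes__", []))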


def _check_single_set_return_values(path_to_node, obj):
def _check_single_set_return_values(
path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
):
"""Check types returned from single evaluation of func, and return number of return values received from func."""
if isinstance(obj, (Dataset, DataArray)):
return 1
(Diffs for the remaining changed files are not shown.)