Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Single Level Regridding #124

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ndpyramid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# flake8: noqa

from .coarsen import pyramid_coarsen
from .reproject import pyramid_reproject
from .regrid import pyramid_regrid
from .reproject import pyramid_reproject, level_reproject
from .regrid import pyramid_regrid, level_regrid
from ._version import __version__
95 changes: 65 additions & 30 deletions ndpyramid/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,70 @@ def generate_weights_pyramid(
return dt.DataTree.from_dict(plevels)


def level_regrid(
ds: xr.Dataset,
*,
level: int,
weights_pyramid: dt.DataTree = None,
method: str = 'bilinear',
regridder_kws: dict = None,
projection: typing.Literal['web-mercator', 'equidistant-cylindrical'] = 'web-mercator',
other_chunks: dict = None,
pixels_per_tile: int = 128,
regridder_apply_kws: dict = None,
):

import xesmf as xe

projection_model = Projection(name=projection)

save_kwargs = {
'level': level,
'pixels_per_tile': pixels_per_tile,
'projection': projection,
'other_chunks': other_chunks,
'method': method,
'regridder_kws': regridder_kws,
'regridder_apply_kws': regridder_apply_kws,
}

regridder_kws = {} if regridder_kws is None else regridder_kws
regridder_kws = {'periodic': True, **regridder_kws}

grid = ds.load()
# get the regridder object
if weights_pyramid is None:
regridder = xe.Regridder(ds, grid, method, **regridder_kws)
else:
# Reconstruct weights into format that xESMF understands
# this is a hack that assumes the weights were generated by
# the `generate_weights_pyramid` function

ds_w = weights_pyramid[str(level)].ds
weights = _reconstruct_xesmf_weights(ds_w)
regridder = xe.Regridder(
ds, grid, method, reuse_weights=True, weights=weights, **regridder_kws
)
# regrid
if regridder_apply_kws is None:
regridder_apply_kws = {}
regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws}
# plevels[str(level)] = regridder(ds, **regridder_apply_kws)
level_ds = regridder(ds, **regridder_apply_kws)
level_attrs = {
'multiscales': multiscales_template(
datasets=[{'path': '.', 'level': level, 'crs': projection_model._crs}],
type='reduce',
method='pyramid_regrid',
version=get_version(),
kwargs=save_kwargs,
)
}

level_ds.attrs['multiscales'] = level_attrs['multiscales']
return level_ds


def pyramid_regrid(
ds: xr.Dataset,
projection: typing.Literal['web-mercator', 'equidistant-cylindrical'] = 'web-mercator',
Expand Down Expand Up @@ -229,7 +293,6 @@ def pyramid_regrid(
pyramid : dt.DataTree
Multiscale data pyramid
"""
import xesmf as xe

if target_pyramid is None:
if levels is not None:
Expand Down Expand Up @@ -274,35 +337,7 @@ def pyramid_regrid(

# pyramid data
for level in range(levels):
grid = target_pyramid[str(level)].ds.load()
# get the regridder object
if weights_pyramid is None:
regridder = xe.Regridder(ds, grid, method, **regridder_kws)
else:
# Reconstruct weights into format that xESMF understands
# this is a hack that assumes the weights were generated by
# the `generate_weights_pyramid` function

ds_w = weights_pyramid[str(level)].ds
weights = _reconstruct_xesmf_weights(ds_w)
regridder = xe.Regridder(
ds, grid, method, reuse_weights=True, weights=weights, **regridder_kws
)
# regrid
if regridder_apply_kws is None:
regridder_apply_kws = {}
regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws}
plevels[str(level)] = regridder(ds, **regridder_apply_kws)
level_attrs = {
'multiscales': multiscales_template(
datasets=[{'path': '.', 'level': level, 'crs': projection_model._crs}],
type='reduce',
method='pyramid_regrid',
version=get_version(),
kwargs=save_kwargs,
)
}
plevels[str(level)].attrs['multiscales'] = level_attrs['multiscales']
plevels[str(level)] = level_regrid(ds=ds, level=level)

root = xr.Dataset(attrs=attrs)
plevels['/'] = root
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest
import xarray as xr


@pytest.fixture
def temperature():
ds = xr.tutorial.open_dataset('air_temperature')
ds['air'].encoding = {}
return ds
8 changes: 0 additions & 8 deletions tests/test_pyramids.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
import numpy as np
import pytest
import xarray as xr
from zarr.storage import MemoryStore

from ndpyramid import pyramid_coarsen, pyramid_regrid, pyramid_reproject
from ndpyramid.regrid import generate_weights_pyramid, make_grid_ds
from ndpyramid.testing import verify_bounds


@pytest.fixture
def temperature():
ds = xr.tutorial.open_dataset('air_temperature')
ds['air'].encoding = {}
return ds


def test_xarray_coarsened_pyramid(temperature, benchmark):
factors = [4, 2, 1]
pyramid = benchmark(
Expand Down
35 changes: 35 additions & 0 deletions tests/test_single_level.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest
from zarr.storage import MemoryStore

from ndpyramid.regrid import level_regrid
from ndpyramid.reproject import level_reproject


@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': False}])
def test_level_regrid(temperature, regridder_apply_kws, benchmark):
pytest.importorskip('xesmf')
regrid_ds = benchmark(
lambda: level_regrid(
temperature, level=1, regridder_apply_kws=regridder_apply_kws, other_chunks={'time': 2}
)
)
assert regrid_ds.attrs['multiscales']
assert regrid_ds.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857'
expected_attrs = (
temperature['air'].attrs
if not regridder_apply_kws or regridder_apply_kws.get('keep_attrs')
else {}
)
assert regrid_ds.air.attrs == expected_attrs
regrid_ds.to_zarr(MemoryStore())


def test_reprojected_pyramid(temperature, benchmark):
pytest.importorskip('rioxarray')
temperature = temperature.rio.write_crs('EPSG:4326')
reproject_ds = benchmark(lambda: level_reproject(temperature, level=1))
assert reproject_ds.attrs['multiscales']
assert len(reproject_ds.attrs['multiscales'][0]['datasets']) == 1
assert reproject_ds.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857'

reproject_ds.to_zarr(MemoryStore())
Loading