From d8fdfc0259a83a85053f5df47c3eb4ce1ad5d970 Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Mon, 8 Apr 2024 10:40:03 -0600 Subject: [PATCH 1/3] WIP on single_level_regrid --- ndpyramid/regrid.py | 96 +++++++++++++++++++++++++++++++-------------- 1 file changed, 66 insertions(+), 30 deletions(-) diff --git a/ndpyramid/regrid.py b/ndpyramid/regrid.py index b9ffe23..bcb3fd1 100644 --- a/ndpyramid/regrid.py +++ b/ndpyramid/regrid.py @@ -186,6 +186,71 @@ def generate_weights_pyramid( return dt.DataTree.from_dict(plevels) +def level_regrid( + ds: xr.Dataset, + *, + level: int, + weights_pyramid: dt.DataTree = None, + method: str = 'bilinear', + regridder_kws: dict = None, + projection: typing.Literal['web-mercator', 'equidistant-cylindrical'] = 'web-mercator', + other_chunks: dict = None, + pixels_per_tile: int = 128, + regridder_apply_kws: dict = None, +): + + import xesmf as xe + + projection_model = Projection(name=projection) + + save_kwargs = { + 'level': level, + 'pixels_per_tile': pixels_per_tile, + 'projection': projection, + 'other_chunks': other_chunks, + 'method': method, + 'regridder_kws': regridder_kws, + 'regridder_apply_kws': regridder_apply_kws, + } + + regridder_kws = {} if regridder_kws is None else regridder_kws + regridder_kws = {'periodic': True, **regridder_kws} + + grid = ds.load() + # get the regridder object + if weights_pyramid is None: + regridder = xe.Regridder(ds, grid, method, **regridder_kws) + else: + # Reconstruct weights into format that xESMF understands + # this is a hack that assumes the weights were generated by + # the `generate_weights_pyramid` function + + ds_w = weights_pyramid[str(level)].ds + weights = _reconstruct_xesmf_weights(ds_w) + regridder = xe.Regridder( + ds, grid, method, reuse_weights=True, weights=weights, **regridder_kws + ) + # regrid + if regridder_apply_kws is None: + regridder_apply_kws = {} + regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws} + # plevels[str(level)] = regridder(ds, **regridder_apply_kws) + level_ds = regridder(ds, **regridder_apply_kws) + level_attrs = { + 'multiscales': multiscales_template( + datasets=[{'path': '.', 'level': level, 'crs': projection_model._crs}], + type='reduce', + method='pyramid_regrid', + version=get_version(), + kwargs=save_kwargs, + ) + } + + # import pdb; pdb.set_trace() + level_ds.attrs['multiscales'] = level_attrs['multiscales'] + return level_ds + + def pyramid_regrid( ds: xr.Dataset, projection: typing.Literal['web-mercator', 'equidistant-cylindrical'] = 'web-mercator', @@ -229,7 +294,6 @@ def pyramid_regrid( pyramid : dt.DataTree Multiscale data pyramid """ - import xesmf as xe if target_pyramid is None: if levels is not None: @@ -274,35 +338,7 @@ def pyramid_regrid( # pyramid data for level in range(levels): - grid = target_pyramid[str(level)].ds.load() - # get the regridder object - if weights_pyramid is None: - regridder = xe.Regridder(ds, grid, method, **regridder_kws) - else: - # Reconstruct weights into format that xESMF understands - # this is a hack that assumes the weights were generated by - # the `generate_weights_pyramid` function - - ds_w = weights_pyramid[str(level)].ds - weights = _reconstruct_xesmf_weights(ds_w) - regridder = xe.Regridder( - ds, grid, method, reuse_weights=True, weights=weights, **regridder_kws - ) - # regrid - if regridder_apply_kws is None: - regridder_apply_kws = {} - regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws} - plevels[str(level)] = regridder(ds, **regridder_apply_kws) - level_attrs = { - 'multiscales': multiscales_template( - datasets=[{'path': '.', 'level': level, 'crs': projection_model._crs}], - type='reduce', - method='pyramid_regrid', - version=get_version(), - kwargs=save_kwargs, - ) - } - plevels[str(level)].attrs['multiscales'] = level_attrs['multiscales'] + plevels[str(level)] = level_regrid(ds=ds, level=level) root = xr.Dataset(attrs=attrs) plevels['/'] = root From ff4c3419323e708e063af06ee82b6d9102e9f0ac Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Mon, 8 Apr 2024 10:40:22 -0600 Subject: [PATCH 2/3] remove pdb --- ndpyramid/regrid.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ndpyramid/regrid.py b/ndpyramid/regrid.py index bcb3fd1..8ea5419 100644 --- a/ndpyramid/regrid.py +++ b/ndpyramid/regrid.py @@ -246,7 +246,6 @@ def level_regrid( ) } - # import pdb; pdb.set_trace() level_ds.attrs['multiscales'] = level_attrs['multiscales'] return level_ds From 5fc11ccf52f0951e617365eb47fc33822d474e46 Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Wed, 15 May 2024 11:59:33 -0600 Subject: [PATCH 3/3] added testing for level regrid and reproject --- ndpyramid/__init__.py | 4 ++-- tests/conftest.py | 9 +++++++++ tests/test_pyramids.py | 8 -------- tests/test_single_level.py | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 10 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_single_level.py diff --git a/ndpyramid/__init__.py b/ndpyramid/__init__.py index 23a2d95..ae1485d 100644 --- a/ndpyramid/__init__.py +++ b/ndpyramid/__init__.py @@ -1,6 +1,6 @@ # flake8: noqa from .coarsen import pyramid_coarsen -from .reproject import pyramid_reproject -from .regrid import pyramid_regrid +from .reproject import pyramid_reproject, level_reproject +from .regrid import pyramid_regrid, level_regrid from ._version import __version__ diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..e263c48 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,9 @@ +import pytest +import xarray as xr + + +@pytest.fixture +def temperature(): + ds = xr.tutorial.open_dataset('air_temperature') + ds['air'].encoding = {} + return ds diff --git a/tests/test_pyramids.py b/tests/test_pyramids.py index 540d28e..7bd2f12 100644 --- a/tests/test_pyramids.py +++ b/tests/test_pyramids.py @@ -1,6 +1,5 @@ import numpy as np import pytest -import xarray as xr from zarr.storage import MemoryStore from ndpyramid import pyramid_coarsen, pyramid_regrid, pyramid_reproject @@ -8,13 +7,6 @@ from ndpyramid.testing import verify_bounds -@pytest.fixture -def temperature(): - ds = xr.tutorial.open_dataset('air_temperature') - ds['air'].encoding = {} - return ds - - def test_xarray_coarsened_pyramid(temperature, benchmark): factors = [4, 2, 1] pyramid = benchmark( diff --git a/tests/test_single_level.py b/tests/test_single_level.py new file mode 100644 index 0000000..f9fb6cb --- /dev/null +++ b/tests/test_single_level.py @@ -0,0 +1,35 @@ +import pytest +from zarr.storage import MemoryStore + +from ndpyramid.regrid import level_regrid +from ndpyramid.reproject import level_reproject + + +@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': False}]) +def test_level_regrid(temperature, regridder_apply_kws, benchmark): + pytest.importorskip('xesmf') + regrid_ds = benchmark( + lambda: level_regrid( + temperature, level=1, regridder_apply_kws=regridder_apply_kws, other_chunks={'time': 2} + ) + ) + assert regrid_ds.attrs['multiscales'] + assert regrid_ds.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' + expected_attrs = ( + temperature['air'].attrs + if not regridder_apply_kws or regridder_apply_kws.get('keep_attrs') + else {} + ) + assert regrid_ds.air.attrs == expected_attrs + regrid_ds.to_zarr(MemoryStore()) + + +def test_reprojected_pyramid(temperature, benchmark): + pytest.importorskip('rioxarray') + temperature = temperature.rio.write_crs('EPSG:4326') + reproject_ds = benchmark(lambda: level_reproject(temperature, level=1)) + assert reproject_ds.attrs['multiscales'] + assert len(reproject_ds.attrs['multiscales'][0]['datasets']) == 1 + assert reproject_ds.attrs['multiscales'][0]['datasets'][0]['crs'] == 'EPSG:3857' + + reproject_ds.to_zarr(MemoryStore())