From 0c17e3855702f884265b97bd6ff0793c34f3155e Mon Sep 17 00:00:00 2001 From: Dmitrii Kochkov Date: Thu, 8 Aug 2024 11:18:42 -0700 Subject: [PATCH] Bugfix: vorticity_space_forcing was applying jnp.fft.rfft2 to GridVariable when array was expected. PiperOrigin-RevId: 660911244 --- jax_cfd/ml/forcings.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax_cfd/ml/forcings.py b/jax_cfd/ml/forcings.py index 61e93d0..1d62eb5 100644 --- a/jax_cfd/ml/forcings.py +++ b/jax_cfd/ml/forcings.py @@ -48,6 +48,7 @@ def spectral_kolmogorov_forcing(grid): @gin.register def vorticity_space_forcing(grid: grids.Grid, forcing_module: ForcingModule): + """Returns a forcing function that applies forcing in vorticity space.""" forcing_fn = forcing_module(grid, offsets=((0.0, 0.0), (0.0, 0.0))) velocity_solve = spectral_utils.vorticity_to_velocity(grid) kx, ky = grid.rfft_mesh() @@ -60,7 +61,7 @@ def forcing_fn_ret(vorticity): v = tuple( grids.GridVariable(grids.GridArray(ifft(u), offset, grid), bc) for u in velocity_solve(fft(vorticity))) - fhatu, fhatv = tuple(fft(u) for u in forcing_fn(v)) + fhatu, fhatv = tuple(fft(u.data) for u in forcing_fn(v)) fhat_vorticity = 2j * jnp.pi * (fhatv * kx - fhatu * ky) return ifft(fhat_vorticity)