diff --git a/jax_cfd/ml/forcings.py b/jax_cfd/ml/forcings.py index 61e93d0..1d62eb5 100644 --- a/jax_cfd/ml/forcings.py +++ b/jax_cfd/ml/forcings.py @@ -48,6 +48,7 @@ def spectral_kolmogorov_forcing(grid): @gin.register def vorticity_space_forcing(grid: grids.Grid, forcing_module: ForcingModule): + """Returns a forcing function that applies forcing in vorticity space.""" forcing_fn = forcing_module(grid, offsets=((0.0, 0.0), (0.0, 0.0))) velocity_solve = spectral_utils.vorticity_to_velocity(grid) kx, ky = grid.rfft_mesh() @@ -60,7 +61,7 @@ def forcing_fn_ret(vorticity): v = tuple( grids.GridVariable(grids.GridArray(ifft(u), offset, grid), bc) for u in velocity_solve(fft(vorticity))) - fhatu, fhatv = tuple(fft(u) for u in forcing_fn(v)) + fhatu, fhatv = tuple(fft(u.data) for u in forcing_fn(v)) fhat_vorticity = 2j * jnp.pi * (fhatv * kx - fhatu * ky) return ifft(fhat_vorticity)