Skip to content

Commit

Permalink
Bugfix: vorticity_space_forcing was applying jnp.fft.rfft2 to GridVar…
Browse files Browse the repository at this point in the history
…iable when array was expected.

PiperOrigin-RevId: 660911244
  • Loading branch information
kochkov92 authored and JAX-CFD authors committed Aug 8, 2024
1 parent 48f69f5 commit 0c17e38
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax_cfd/ml/forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def spectral_kolmogorov_forcing(grid):

@gin.register
def vorticity_space_forcing(grid: grids.Grid, forcing_module: ForcingModule):
"""Returns a forcing function that applies forcing in vorticity space."""
forcing_fn = forcing_module(grid, offsets=((0.0, 0.0), (0.0, 0.0)))
velocity_solve = spectral_utils.vorticity_to_velocity(grid)
kx, ky = grid.rfft_mesh()
Expand All @@ -60,7 +61,7 @@ def forcing_fn_ret(vorticity):
v = tuple(
grids.GridVariable(grids.GridArray(ifft(u), offset, grid), bc)
for u in velocity_solve(fft(vorticity)))
fhatu, fhatv = tuple(fft(u) for u in forcing_fn(v))
fhatu, fhatv = tuple(fft(u.data) for u in forcing_fn(v))
fhat_vorticity = 2j * jnp.pi * (fhatv * kx - fhatu * ky)
return ifft(fhat_vorticity)

Expand Down

0 comments on commit 0c17e38

Please sign in to comment.