Skip to content

Commit

Permalink
[pallas:mosaic_gpu] Change the fori tests to also take the while_p path and fix the bug.
Browse files Browse the repository at this point in the history

The bug was that the loop bounds were dropped from ctx.avals_in and only
afterwards extracted from it. Extract them before dropping them.

PiperOrigin-RevId: 708266659
  • Loading branch information
cperivol authored and Google-ML-Automation committed Dec 20, 2024
1 parent 0b190bb commit 20efbd9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
4 changes: 3 additions & 1 deletion jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,8 +1572,11 @@ def _lower_while_via_fori(
body_nconsts,
):
assert not fori_jaxpr.constvars
# The pattern matcher looks for conditions with no constants.
assert cond_nconsts == 0

# Reflect the changes of the pattern matcher to the context.
lb_aval, ub_aval, *_ = ctx.avals_in[cond_nconsts + body_nconsts:]
ctx = ctx.replace(
avals_in=(
*ctx.avals_in[cond_nconsts:body_nconsts],
Expand All @@ -1585,7 +1588,6 @@ def _lower_while_via_fori(
_, consts, (lb, ub, *args) = util.split_list(
args, [cond_nconsts, body_nconsts]
)
lb_aval, ub_aval, *_ = ctx.avals_in[body_nconsts:]
lb = _ensure_ir_value(lb, lb_aval.dtype)
ub = _ensure_ir_value(ub, ub_aval.dtype)
for_out = _lower_jaxpr_to_for_loop(
Expand Down
29 changes: 21 additions & 8 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@
jax.config.parse_flags_with_absl()


def _fori_loop(force_while: bool, lb, ub, body, init):
if force_while:
# using jnp.asarray make the matcher for while or scan to think
# that the bounds are dynamic and forces the use of the while
# primitive.
lb, ub = jnp.asarray(lb), jnp.asarray(ub)
return jax.lax.fori_loop(lb, ub, body, init)


class PallasTest(jtu.JaxTestCase):

def setUp(self):
Expand Down Expand Up @@ -705,19 +714,21 @@ def kernel(x_ref, o_ref):
x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
np.testing.assert_array_equal(kernel(x), x)

def test_fori_loop_array(self):
@parameterized.parameters(False, True)
def test_fori_loop_array(self, force_while):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
)
def kernel(x_ref, o_ref):
# Equivalent to x_ref[...] + 2 + 3.
o_ref[...] = jax.lax.fori_loop(2, 4, lambda i, x: x + i, x_ref[...])
o_ref[...] = _fori_loop(force_while, 2, 4, lambda i, x: x + i, x_ref[...])

x = jnp.arange(256).astype(jnp.int32)
np.testing.assert_array_equal(kernel(x), x + 2 + 3)

def test_fori_loop_scalar(self):
@parameterized.parameters(False, True)
def test_fori_loop_scalar(self, force_while):

@functools.partial(
pl.pallas_call,
Expand All @@ -726,7 +737,7 @@ def test_fori_loop_scalar(self):
def kernel(o_ref):
# Equivalent to 2 + 3.
o_ref[...] = jax.lax.broadcast(
jax.lax.fori_loop(2, 4, lambda i, x: x + i, 0), o_ref.shape
_fori_loop(force_while, 2, 4, lambda i, x: x + i, 0), o_ref.shape
)

np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32))
Expand All @@ -747,7 +758,8 @@ def kernel(o_ref):

np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32))

def test_fori_loop_tuple(self):
@parameterized.parameters(False, True)
def test_fori_loop_tuple(self, force_while):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
Expand All @@ -761,14 +773,15 @@ def body(step, xs):

# Equivalent to 3 * (0 + 1).
o_ref[...] = jax.lax.broadcast(
sum(jax.lax.fori_loop(2, 4, body, (0, 0, 0))), o_ref.shape
sum(_fori_loop(force_while, 2, 4, body, (0, 0, 0))), o_ref.shape
)

np.testing.assert_array_equal(
kernel(), jnp.full([256], 3 * (0 + 1), dtype=jnp.int32)
)

def test_fori_loop_indexed_store(self):
@parameterized.parameters(False, True)
def test_fori_loop_indexed_store(self, force_while):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32),
Expand All @@ -778,7 +791,7 @@ def body(idx, _):
o_ref[idx] = x_ref[idx] + y_ref[idx]
return ()

jax.lax.fori_loop(0, 4, body, ())
_fori_loop(force_while, 0, 4, body, ())

x = jnp.arange(4 * 128).reshape(4, 128).astype(jnp.float32)
y = x + 1
Expand Down

0 comments on commit 20efbd9

Please sign in to comment.