Skip to content

Commit

Permalink
[pallas/pallas_mgpu] Discharging run_scoped should not be discharging…
Browse files Browse the repository at this point in the history
… the intermediates

When we do run_scoped[jaxpr, R1, R2], it can't be assumed that the references
corresponding to R1 and R2 can be safely discharged. Sometimes they can (e.g.,
Accumulator), but sometimes they can't (e.g., SMEM scratch). It should be up to
the lowering rule to do such discharging.

This further means that during lowering there is no guarantee that the
references will not be used/returned by nested scoped blocks, so we also remove
the check that raised "No references are allowed to escape a scope."

PiperOrigin-RevId: 703334972
  • Loading branch information
cperivol authored and Google-ML-Automation committed Dec 20, 2024
1 parent 20efbd9 commit aab18bd
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 21 deletions.
9 changes: 0 additions & 9 deletions jax/_src/pallas/mosaic_gpu/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,15 +1421,6 @@ def _run_scoped_lowering_rule(
ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts
)

for o in outs:
# This is definitely one of the accumulators we produced. Each
# run_scoped call is responsible for dereferencing its own
# accumulators.
if isinstance(o, mgpu.WGMMAAccumulator) or (
isinstance(o, ir.Value) and ir.MemRefType.isinstance(o.type)
):
raise ValueError(f"No references are allowed to escape a scope. (got {o})")

assert len(outs) == len(jaxpr.outvars), (jaxpr, outs)
return outs

Expand Down
27 changes: 15 additions & 12 deletions jax/_src/pallas/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,31 +893,34 @@ def _run_scoped_discharge_rule(
**_):
del out_avals
num_consts = len(args_flat)
# discharge_state only discharges invars, not consts, so in order to
# discharge the requested refs we need to move them to the invar set.
jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
num_return_values = len(jaxpr_noconst.outvars)
should_discharge = should_discharge + [
isinstance(var.aval, state.AbstractRef) for var in jaxpr.invars
]
discharged_body, new_consts = state_discharge.discharge_state(
jaxpr_noconst, [], should_discharge=should_discharge)
jaxpr_noconst,
[],
should_discharge=should_discharge + [False] * len(jaxpr.invars),
)
if new_consts:
raise NotImplementedError(
"Cannot handle new consts created by state discharge.")
# Create inputs filled with uninitialized values to the body.
body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
init_vals = [uninitialized_value(
aval.shape, aval.dtype) for aval in body_avals]
init_vals_with_consts = args_flat + tuple(init_vals)
out = jax_core.eval_jaxpr(discharged_body, [], *init_vals_with_consts)

# Lowering expects the jaxpr.consts to be the eqn.invals.
discharged_body = pe.convert_invars_to_constvars(discharged_body, num_consts)

# Run_scoped discharged the external variables but the scoped ones
# are not discharged.
out = run_scoped_p.bind(*args_flat, jaxpr=discharged_body)
# Order of outputs:
# (1) return values, (2) closed refs, (3) scoped refs.
return_values = out[:num_return_values]
ref_outputs = out[num_return_values:]
# We update all ref values with their updated values from the discharged
# body. For other values we leave them in place.
updates = [
ref_outputs.pop(0) if isinstance(aval, pallas_core.AbstractMemoryRef)
else None for aval in in_avals]
ref_outputs.pop(0) if should and isinstance(aval, pallas_core.AbstractMemoryRef)
else None for should, aval in zip(should_discharge, in_avals)]
assert len(updates) == len(in_avals), f'{len(updates)} != {len(in_avals)}'
return updates, return_values

Expand Down

0 comments on commit aab18bd

Please sign in to comment.