Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mosaic TPU] Add support for true divide in bf16 on TPUv6 #25608

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,8 @@ LogicalResult canonicalize_elementwise(int hardware_generation_,
return failure();
}
auto element_type = ty.getElementType();
// PowFOp and DivFOp do not seem to be supported in bf16 on later
// hardware.
bool needs_cast = hardware_generation_ <= 5 || isa<math::PowFOp>(op) ||
isa<arith::DivFOp>(op);
// PowFOp does not seem to be supported in bf16 on later hardware.
bool needs_cast = hardware_generation_ <= 5 || isa<math::PowFOp>(op);
if (needs_cast && element_type.isBF16()) {
auto target_f32 =
builder.create<arith::ExtFOp>(op.getLoc(), target_f32_ty, operand)
Expand Down
17 changes: 15 additions & 2 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,18 +1116,31 @@ def kernel(x_ref, y_ref, out_ref):
@parameterized.parameters(
("int32", "float32"),
("float32", "float32"),
("bfloat16", "bfloat16"),
)
def test_true_divide(self, dtype, out_dtype):
if jtu.test_device_matches(["tpu"]):
if out_dtype == "bfloat16" and not jtu.is_device_tpu_at_least(6):
self.skipTest("bfloat16 is not supported on older TPU generations")
if not jtu.if_cloud_tpu_at_least(2024, 12, 21):
self.skipTest("Requires libtpu built after 2024-12-21")
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((8,), out_dtype),
out_shape=jax.ShapeDtypeStruct((8, 8), out_dtype),
)
def kernel(x_ref, y_ref, o_ref):
o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])

x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype)
y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype)
np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y))
x = jnp.repeat(x, 8, axis=0).reshape(8, 8)
y = jnp.tile(y, 8).reshape(8, 8)
rtol = 8e-3 if dtype == "bfloat16" else 1e-6
np.testing.assert_allclose(
jnp.true_divide(x, y).astype(jnp.float32),
kernel(x, y).astype(jnp.float32),
rtol=rtol,
)

@parameterized.parameters("float16", "bfloat16")
def test_true_divide_unsupported(self, dtype):
Expand Down
Loading