diff --git a/jax/_src/random.py b/jax/_src/random.py index 12aa5b93efbf..f956241ce06f 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -2070,8 +2070,8 @@ def orthogonal( n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()") z = normal(key, (*shape, n, n), dtype) q, r = jnp.linalg.qr(z) - d = jnp.diagonal(r, 0, -2, -1) - return lax.mul(q, lax.expand_dims(lax.div(d, abs(d).astype(d.dtype)), [-2])) + d = jnp.diagonal(r, axis1=-2, axis2=-1) + return q * jnp.expand_dims(jnp.sign(d), -2) def generalized_normal( key: ArrayLike,