You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
The issue can be resolved by changing the following line in the klujax source code:
@vmap_register(solve_f64, solve)
@vmap_register(solve_c128, solve)
@vmap_register(coo_mul_vec_f64, coo_mul_vec)
@vmap_register(coo_mul_vec_c128, coo_mul_vec)
def coo_vec_operation_vmap(operation, vector_arg_values, batch_axes):
    """Batching (vmap) rule shared by the ``solve`` and ``coo_mul_vec`` primitives.

    The sparse matrix is given in COO form by ``Ai``/``Aj`` (index arrays,
    which must not be batched) and ``Ax`` (values); ``b`` is the dense
    operand. Depending on which of ``Ax``/``b`` carries the vmapped axis,
    the batch is moved to a leading "lhs" axis or a trailing "rhs" axis and
    ``operation`` is called once on the whole batch.

    Args:
        operation: callable invoked as ``operation(Ai, Aj, Ax, b)``;
            presumably the primitive selected by ``vmap_register`` for the
            decorated dtype variant -- TODO confirm.
        vector_arg_values: the four operands ``(Ai, Aj, Ax, b)``.
        batch_axes: the batch axis of each operand (``None`` if that operand
            is not batched), in the same order as ``vector_arg_values``.

    Returns:
        ``(result, out_axis)``: the batched result and the (non-negative)
        axis of ``result`` that carries the vmapped dimension.

    Raises:
        AssertionError: if ``Ai`` or ``Aj`` is batched.
        ValueError: if neither ``Ax`` nor ``b`` is batched.
    """
    aAi, aAj, aAx, ab = batch_axes
    Ai, Aj, Ax, b = vector_arg_values
    assert aAi is None, "Ai cannot be vectorized."
    assert aAj is None, "Aj cannot be vectorized."

    # Checked first so this error is actually reachable; previously the
    # "ab is None" branch below intercepted this case via its assertion.
    if aAx is None and ab is None:
        raise ValueError("invalid arguments for vmap")

    if aAx is not None and ab is not None:
        # Both operands batched: put both batch axes in front (lhs).
        assert isinstance(aAx, int) and isinstance(ab, int)
        Ax = jnp.moveaxis(Ax, aAx, 0)  # treat as lhs
        b = jnp.moveaxis(b, ab, 0)  # treat as lhs
        result = operation(Ai, Aj, Ax, b)
        return result, 0

    if ab is None:
        # Only Ax batched: move its batch axis to the front and broadcast the
        # unbatched b along a matching leading axis.
        assert isinstance(aAx, int)
        Ax = jnp.moveaxis(Ax, aAx, 0)  # treat as lhs
        b = jnp.broadcast_to(b[None], (Ax.shape[0], *b.shape))
        result = operation(Ai, Aj, Ax, b)
        return result, 0

    # Only b batched: move its batch axis to the end (rhs) so one call on
    # the shared matrix handles the whole batch.
    assert isinstance(ab, int)
    _log(f"vmap: {b.shape=}")
    b = jnp.moveaxis(b, ab, -1)  # treat as rhs
    _log(f"vmap: {b.shape=}")
    shape = b.shape
    if b.ndim == 0:
        b = b[None, None, None]
    elif b.ndim == 1:
        b = b[None, None, :]
    elif b.ndim == 2:
        b = b[None, :, :]
    # For ndim > 3, the reshape merges all axes after the second (including
    # the batch axis) into one; for ndim <= 3 it leaves b unchanged.
    b = b.reshape(b.shape[0], b.shape[1], -1)
    _log(f"vmap: {b.shape=}")
    # b is now guaranteed to have shape (n_lhs, n_col, n_rhs)
    result = operation(Ai, Aj, Ax, b)
    result = result.reshape(*shape)
    # result has the shape of the moved b, so the batch axis is its last
    # axis. Report it as a non-negative index: returning -1 here was reported
    # to make JAX's batching machinery fail downstream.
    return result, result.ndim - 1
The text was updated successfully, but these errors were encountered:
Hi,
I'm not sure whether this is a JAX bug or a klujax bug, but either way, the following script fails:
The issue can be resolved by changing the following line in the klujax source code:
The text was updated successfully, but these errors were encountered: