Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vmap does not work properly #11

Open
inversecrime opened this issue Dec 19, 2024 · 0 comments
Open

vmap does not work properly #11

inversecrime opened this issue Dec 19, 2024 · 0 comments

Comments

@inversecrime
Copy link

Hi,
I'm not sure whether this is a JAX bug or a klujax bug, but either way the following script fails:


import jax
import klujax
from jax import numpy as jnp


def random_sparse(key, shape, fill_ratio):
    """Draw random COO-style indices and complex values for a sparse array.

    Args:
        key: JAX PRNG key; it is split internally and never consumed directly.
        shape: tuple with the dense shape of the array.
        fill_ratio: fraction of the ``prod(shape)`` entries to generate.

    Returns:
        ``(indices, data, shape)`` where ``indices`` has shape ``(n, ndim)``,
        ``data`` has shape ``(n,)`` and ``n = int(fill_ratio * prod(shape))``.
        Indices are drawn independently, so duplicates are possible.
    """
    # Local import: the original referenced a bare ``prod`` that was never
    # imported, which raised NameError on the first call.
    from math import prod

    size = prod(shape)
    ndim = len(shape)
    n = int(fill_ratio * size)

    (key_indices, key_data) = jax.random.split(key, 2)
    # Separate keys for the two parts; reusing one key would make the real and
    # imaginary components identical.
    (key_real, key_imag) = jax.random.split(key_data, 2)

    # Scale uniform samples in [0, 1) by the extent of each axis and floor them
    # to obtain integer positions. ``jnp.integer`` is an abstract scalar type,
    # so a concrete integer dtype is requested instead.
    scaled = jnp.asarray(shape) * jax.random.uniform(key_indices, shape=(n, ndim))
    indices = jnp.floor(scaled).astype(jnp.int32)
    data = jax.random.normal(key_real, shape=(n,)) + 1j * jax.random.normal(key_imag, shape=(n,))

    return (indices, data, shape)


# Side length of the square system used in the reproduction.
n = 5_000

key = jax.random.PRNGKey(0)
# Independent keys for the matrix and the right-hand side. The original split
# the key but then kept passing the parent ``key`` everywhere, so the matrix
# and both parts of ``b`` were drawn from the same random stream.
(key_A, key_b) = jax.random.split(key, 2)
(key_b_real, key_b_imag) = jax.random.split(key_b, 2)

(A_indices, A_data, A_shape) = random_sparse(key_A, shape=(n, n), fill_ratio=0.01)
b = jax.random.uniform(key_b_real, shape=(n,)) + 1j * jax.random.uniform(key_b_imag, shape=(n,))


def foo(A_indices, A_data, b):
    """Solve the sparse system for ``b`` with klujax, printing input and output.

    Column 0 and column 1 of ``A_indices`` are passed to ``klujax.solve`` as its
    first and second arguments, followed by ``A_data`` and ``b``.

    Returns:
        Whatever ``klujax.solve`` returns for these arguments.
    """
    print(b)
    Ai = A_indices[:, 0]
    Aj = A_indices[:, 1]
    solution = klujax.solve(Ai, Aj, A_data, b)
    print(solution)
    return solution


# Stack ten copies of ``b`` along a new leading axis to form the batch.
b_batch = jnp.repeat(b[jnp.newaxis, ...], repeats=10, axis=0)


def _solve_batch(A_indices, A_data, b_batch):
    """Map ``foo`` over the leading axis of ``b_batch`` with a shared matrix."""
    return jax.vmap(lambda b: foo(A_indices, A_data, b))(b_batch)


# jit(vmap(...)) with only the right-hand side batched.
x_batch = jax.jit(_solve_batch)(A_indices, A_data, b_batch)

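The issue can be resolved by changing the final `return` statement of the `aAx is None` branch in the following klujax function, so that it reports the non-negative batch axis `result.ndim - 1` instead of `-1`: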

@vmap_register(solve_f64, solve)
@vmap_register(solve_c128, solve)
@vmap_register(coo_mul_vec_f64, coo_mul_vec)
@vmap_register(coo_mul_vec_c128, coo_mul_vec)
def coo_vec_operation_vmap(operation, vector_arg_values, batch_axes):
    """Batching rule shared by the registered solve / coo_mul_vec primitives.

    Args:
        operation: callable invoked as ``operation(Ai, Aj, Ax, b)``.
        vector_arg_values: the four operands ``(Ai, Aj, Ax, b)``.
        batch_axes: batch axis of each operand ``(aAi, aAj, aAx, ab)``;
            ``None`` means the operand is not batched.

    Returns:
        ``(result, out_axis)`` where ``out_axis`` is the non-negative index of
        the batch axis in ``result``.

    Raises:
        AssertionError: if ``Ai`` or ``Aj`` is batched, or a batch axis that
            must be an ``int`` is not one.
        ValueError: if no supported combination of batch axes matches.
    """
    aAi, aAj, aAx, ab = batch_axes
    Ai, Aj, Ax, b = vector_arg_values

    assert aAi is None, "Ai cannot be vectorized."
    assert aAj is None, "Aj cannot be vectorized."

    # Both the matrix values and the vector are batched: align both batch
    # axes at the front.
    if aAx is not None and ab is not None:
        assert isinstance(aAx, int) and isinstance(ab, int)
        Ax = jnp.moveaxis(Ax, aAx, 0)  # treat as lhs
        b = jnp.moveaxis(b, ab, 0)  # treat as lhs
        result = operation(Ai, Aj, Ax, b)
        return result, 0

    # Only the matrix values are batched: replicate b once per batch entry.
    if ab is None:
        assert isinstance(aAx, int)
        Ax = jnp.moveaxis(Ax, aAx, 0)  # treat as lhs
        b = jnp.broadcast_to(b[None], (Ax.shape[0], *b.shape))
        result = operation(Ai, Aj, Ax, b)
        return result, 0

    # Only the vector is batched: move the batch axis last so it is folded
    # into the trailing axis by the reshape below.
    if aAx is None:
        assert isinstance(ab, int)
        _log(f"vmap: {b.shape=}")
        b = jnp.moveaxis(b, ab, -1)  # treat as rhs
        _log(f"vmap: {b.shape=}")
        shape = b.shape

        # Left-pad with singleton axes up to rank 3; arrays of rank 3 or more
        # are left unchanged here.
        if b.ndim < 3:
            b = b.reshape((1,) * (3 - b.ndim) + shape)

        # Collapse everything after the first two axes into one trailing axis.
        b = b.reshape(b.shape[0], b.shape[1], -1)

        _log(f"vmap: {b.shape=}")
        # b is now guaranteed to have shape (n_lhs, n_col, n_rhs)
        result = operation(Ai, Aj, Ax, b)
        result = result.reshape(*shape)
        # Report the batch axis as a non-negative index. Returning -1 here was
        # reported to make jit(vmap(...)) over b fail.
        return result, result.ndim - 1

    raise ValueError("invalid arguments for vmap")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant