jax.random.choice(replace=True) samples 0 probability index #25498

Open
LemonATsu opened this issue Dec 16, 2024 · 1 comment
Labels
bug (Something isn't working), duplicate (This issue or pull request already exists)


Description

jax.random.choice(replace=True) samples entries that have zero probability when the input array is large and the average probability is low (~1e-07):

```python
import jax
import jax.numpy as jnp
import numpy as np

# 7M entries: the first 5M share the probability mass, the last 2M have p = 0.
sample_prob = np.zeros((7000000,))
sample_prob[:5000000] = 1.0
sample_prob = jnp.array(sample_prob / (sample_prob.sum()))
print(sample_prob.max())  # Output: 2e-07
print(sample_prob.min())  # Output: 0.0

sampled_idxs = jax.random.choice(
    jax.random.PRNGKey(0),
    a=jnp.arange(len(sample_prob)),
    shape=(len(sample_prob),),
    p=sample_prob,
    replace=True,
)

print((sample_prob[sampled_idxs]).min())  # Output: 0.0, shouldn't happen
```

The NumPy counterpart, np.random.choice, behaves correctly:

```python
import numpy as np

# Same setup in float32: the last 2M entries have p = 0.
sample_prob = np.zeros((7000000,)).astype(np.float32)
sample_prob[:5000000] = 1.0
sample_prob = sample_prob / (sample_prob.sum())
print(sample_prob.max())  # Output: 2e-07
print(sample_prob.min())  # Output: 0.0

sampled_idxs = np.random.choice(
    a=np.arange(len(sample_prob)),
    size=(len(sample_prob),),
    p=sample_prob,
    replace=True,
)

print((sample_prob[sampled_idxs]).min())  # Output: 2e-07, expected
```

This looks like unexpected behavior, possibly a bug?
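For context, weighted sampling with replacement is commonly done by inverse-CDF lookup: draw a uniform number and find the first entry of the cumulative sum of `p` that exceeds it. The sketch below assumes that model (it is not the actual code of either library) and uses a hypothetical helper `inverse_cdf_sample`. With an exact, flat cumulative sum over the zero-probability tail, those entries get zero-width intervals and are never picked. If rounding instead makes the tail drift upward (one plausible effect of summing millions of float32 terms, though nothing in this thread confirms it), the tail gets a nonzero-width interval and zero-probability indices can be sampled.

```python
import numpy as np


def inverse_cdf_sample(cdf, u):
    """Return, for each u, the index of the first cdf entry strictly greater than u."""
    return np.searchsorted(cdf, u, side="right")


rng = np.random.default_rng(0)

# Toy distribution: the last two entries have probability exactly 0.
p = np.array([0.25, 0.25, 0.5, 0.0, 0.0])
cdf = np.cumsum(p)  # flat tail: the zero entries add nothing

# Uniforms in [0, cdf[-1]); a zero-probability entry i owns the empty interval
# [cdf[i-1], cdf[i]) and can never be selected.
u = rng.random(10_000) * cdf[-1]
idx = inverse_cdf_sample(cdf, u)
print(p[idx].min())  # strictly positive

# Hypothetical rounding drift: the tail of the cumulative sum creeps upward
# instead of staying flat, giving index 3 a nonzero-width interval.
bad_cdf = cdf.copy()
bad_cdf[3:] += 0.05
u_bad = rng.random(10_000) * bad_cdf[-1]
idx_bad = inverse_cdf_sample(bad_cdf, u_bad)
print(p[idx_bad].min())  # zero-probability samples now appear
```

The drift size (0.05) is exaggerated so the effect shows up in a small sample; it is not a measurement of the error in the 7M-entry case.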

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.1.3
python: 3.11.8 (stable, redacted, redacted) [Clang 9999.0.0 (be2df95e9281985b61270bb6420ea0eeeffbbe59)]
device info: Tesla V100-SXM2-16GB-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.10.0-smp-1106.20.0.0', version='#1 [v5.10.0-1106.20.0.0] SMP @1728697352', machine='x86_64')
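The fields above follow the layout printed by `jax.print_environment_info()`. A minimal way to regenerate them on another machine, assuming a JAX version recent enough to include that function:

```python
import jax

# Prints the jax, jaxlib, numpy and python versions, device info,
# process count and platform, in the layout shown above.
jax.print_environment_info()
```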
LemonATsu added the bug label Dec 16, 2024
jakevdp (Collaborator) commented Dec 16, 2024

Thanks for the report – this looks to be a duplicate of #22682.
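Until #22682 is resolved, one possible workaround (not suggested in this thread, so treat it as an untested assumption for the 7M-entry case) is Gumbel-max sampling with `jax.random.categorical` on `jnp.log(p)`. Zero-probability entries get logit `-inf`, and adding finite Gumbel noise keeps them at `-inf`, so they can never win the argmax. The sketch uses a small stand-in for the report's setup:

```python
import jax
import jax.numpy as jnp

# Small stand-in: the first 5 of 7 entries share the mass, the last 2 have p = 0.
p = jnp.zeros(7).at[:5].set(1.0)
p = p / p.sum()

key = jax.random.PRNGKey(0)

# log(0) = -inf; -inf plus finite Gumbel noise stays -inf, so those
# entries never win the argmax inside categorical.
logits = jnp.log(p)
n_samples = 1000
sampled_idxs = jax.random.categorical(key, logits, shape=(n_samples,))

print(p[sampled_idxs].min())  # strictly positive
```

Design note: `categorical` draws Gumbel noise of shape `(n_samples, len(p))`, so drawing 7M samples over 7M entries in one call would not fit in memory; the draws would need to be split into small batches with separate keys.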
