
Surprising difference of output between NumPy's float32 and JAX's float32 #25601

Open
jeertmans opened this issue Dec 19, 2024 · 8 comments
Labels
bug Something isn't working

Comments

@jeertmans
Contributor

Description

Hello!

I am not sure how to title this issue correctly, but I recently ran into a case where the output of one of my functions differed depending on whether @jax.jit was applied to it.
Here is the simplified version:

def fun(f):
    f_ghz = f / 1e9
    return f_ghz >= 0.1

Which yields:

>>> fun(0.1e9)
True
>>> jax.jit(fun)(0.1e9)
Array(False, dtype=bool, weak_type=True)

I know that Python floats are 64-bit, while JAX uses 32-bit floats by default. I also know that JIT compilation can optimize some operations, which could (I guess) change the order of operations and lead to different floating-point results (?). So I tried forcing float32 in both cases:

import jax.numpy as jnp
import numpy as np


def fun_np(f):
    f_ghz = f / np.float32(1e9)
    print(f"{f_ghz.dtype = }")
    return f_ghz >= np.float32(0.1)

def fun_jnp(f):
    f_ghz = f / jnp.float32(1e9)
    print(f"{f_ghz.dtype = }")
    return f_ghz >= jnp.float32(0.1)

But the results are still surprising:

>>> fun_np(np.float32(0.1e9))
f_ghz.dtype = dtype('float32')
True
>>> fun_jnp(jnp.float32(0.1e9))
f_ghz.dtype = dtype('float32')
Array(False, dtype=bool)

The NumPy version outputs the expected value, while the JAX variant doesn't.
Am I doing something wrong?

I could not find any documentation about this. Can this kind of discrepancy be prevented, or should it be expected?

For this problem, I found a very simple fix:

def fun(f):
    return f >= 0.1 * 1e9

that is, using multiplication instead of division.

>>> fun(0.1e9)
True
>>> jax.jit(fun)(0.1e9)
Array(True, dtype=bool, weak_type=True)
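A minimal NumPy sketch of why the multiplication variant avoids the problem, assuming the threshold is written as `0.1 * 1e9` as above. The product is evaluated once as a Python float64 and rounds to exactly `100000000.0`, and that value is exactly representable in float32, so the comparison never depends on how a quotient is rounded. `fun_mul` is a hypothetical name, and NumPy stands in for JAX so the sketch runs anywhere.

```python
import numpy as np

def fun_mul(f):
    # Threshold is computed in Python float64; no division touches the input.
    return f >= 0.1 * 1e9

# 0.1 * 1e9 rounds to exactly 1e8 in float64.
print(0.1 * 1e9 == 1e8)  # True

# 1e8 = 390625 * 2**8 with 390625 < 2**24, so float32 stores it exactly.
print(float(np.float32(1e8)) == 1e8)  # True

# The comparison is now between two exactly equal values.
print(fun_mul(np.float32(0.1e9)))  # True
```

This only works cleanly because the threshold itself is exact in float32; a threshold that is not exactly representable can still make inputs right at the boundary sensitive to how that threshold rounds.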

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  1.26.4
python: 3.11.8 (main, Feb 25 2024, 04:18:18) [Clang 17.0.6 ]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='jeertmans', release='6.8.0-49-generic', version='#49-Ubuntu SMP PREEMPT_DYNAMIC Mon Nov  4 02:06:24 UTC 2024', machine='x86_64')
@jeertmans jeertmans added the bug Something isn't working label Dec 19, 2024
@jeertmans jeertmans changed the title Surprising difference of output betweenNumPy's float32 and JAX's float32 Surprising difference of output between NumPy's float32 and JAX's float32 Dec 19, 2024
@pearu
Collaborator

pearu commented Dec 19, 2024

A simpler reproducer for this issue:

>>> np.float32(0.1e9) / np.float32(1e9) == jnp.float32(0.1e9) / jnp.float32(1e9)
Array(False, dtype=bool)
>>> np.float32(0.1e9) / np.float32(1e9), jnp.float32(0.1e9) / jnp.float32(1e9)
(0.1, Array(0.09999999, dtype=float32))
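The two values printed above differ by roughly one float32 unit in the last place (ulp), which is enough to flip the `>= 0.1` test in the original function. A minimal NumPy sketch of that boundary, assuming JAX's `0.09999999` is the float32 neighbour just below `float32(0.1)`; its printed digits are consistent with that, but the thread does not confirm the exact bits.

```python
import numpy as np

threshold = np.float32(0.1)  # nearest float32 to 0.1
numpy_quotient = np.float32(0.1e9) / np.float32(1e9)

# Assumption: JAX's result is the float32 value one ulp below float32(0.1).
one_ulp_below = np.nextafter(threshold, np.float32(0.0))

print(numpy_quotient == threshold)  # True
print(one_ulp_below >= threshold)   # False: one ulp flips the comparison
print(np.spacing(threshold))        # size of one float32 ulp near 0.1
print(f"{float(threshold):.17g} {float(one_ulp_below):.17g}")
```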

It looks like NumPy uses intermediate float64 arithmetic in float32 division, while JAX float32 division uses strictly float32 arithmetic. To illustrate this, consider:

>>> import mpmath
>>> mpmath.mp.dps = 15  # double precision
>>> mpmath.mp.mpf(0.1e9) / mpmath.mp.mpf(1e9)
mpf('0.10000000000000001')
>>> mpmath.mp.dps = 6  # single precision
>>> mpmath.mp.mpf(0.1e9) / mpmath.mp.mpf(1e9)
mpf('0.099999994')

So the issue is not a bug in JAX; it rather belongs in the FAQ: https://jax.readthedocs.io/en/latest/faq.html
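The arithmetic discussed in the comment above can also be reproduced with NumPy's own float32 and float64 types. This sketch computes 1e8 / 1e9 three ways: directly in float32, in float64 and then rounded to float32, and by multiplying by a rounded float32 reciprocal. The third route is only a hypothesis for how a compiled float32 division could yield the JAX value; this thread does not establish what XLA actually does on the reporter's GPU. Note also that IEEE 754 float32 division is correctly rounded, so the first two routes agree here, and this check alone does not distinguish a float64 intermediate from a correctly rounded float32 division.

```python
import numpy as np

a = np.float32(0.1e9)  # 1e8, exactly representable in float32
b = np.float32(1e9)    # 1e9, exactly representable in float32

# Route 1: float32 division (IEEE 754 division is correctly rounded).
direct_f32 = a / b

# Route 2: divide in float64, then round the result to float32.
via_f64 = np.float32(np.float64(a) / np.float64(b))

# Route 3 (hypothetical): multiply by a float32 reciprocal, rounding twice.
reciprocal = np.float32(1.0) / b
via_reciprocal = a * reciprocal

for name, value in [("direct_f32", direct_f32),
                    ("via_f64", via_f64),
                    ("via_reciprocal", via_reciprocal)]:
    passes = bool(value >= np.float32(0.1))
    print(f"{name:15s} {float(value):.9g}  >= float32(0.1): {passes}")
```

The extra rounding of the reciprocal is what pushes route 3 one ulp below `float32(0.1)`.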

@jeertmans
Contributor Author

Thanks for the reply @pearu!

It is quite surprising that NumPy doesn't document that behavior on its Data type promotion page (or I couldn't find it), but it isn't JAX's fault ^^'.

Still, I think documenting this in JAX's FAQ would be great :-)

@jakevdp
Collaborator

jakevdp commented Dec 19, 2024

What exactly do you have in mind for documenting in the JAX FAQ?

@jeertmans
Contributor Author

> What exactly do you have in mind for documenting in the JAX FAQ?

That NumPy, unlike JAX, may promote data types internally for some computations, even though this is not documented. This can lead to differences between a JAX program and a NumPy program even when both explicitly use the same precision. The entry could link to this issue or include the division example.

I am not sure how best to phrase this.

@jakevdp
Collaborator

jakevdp commented Dec 19, 2024

OK – maybe something like that would fit as a bullet point here? https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#miscellaneous-divergences-from-numpy

@jeertmans
Contributor Author

> OK – maybe something like that would fit as a bullet point here? https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#miscellaneous-divergences-from-numpy

Yes, I think this is a good place for that :)

@pearu
Collaborator

pearu commented Dec 19, 2024

FWIW, JAX and Numba are in the same boat regarding float32 semantics with respect to NumPy: https://numba.readthedocs.io/en/stable/reference/fpsemantics.html

@jeertmans
Contributor Author

> OK – maybe something like that would fit as a bullet point here? https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#miscellaneous-divergences-from-numpy

Should I start a PR editing this page?


3 participants