Surprising difference of output between NumPy's float32 and JAX's float32 #25601
Comments
A simpler reproducer for this issue is:

>>> np.float32(0.1e9) / np.float32(1e9) == jnp.float32(0.1e9) / jnp.float32(1e9)
Array(False, dtype=bool)
>>> np.float32(0.1e9) / np.float32(1e9), jnp.float32(0.1e9) / jnp.float32(1e9)
(0.1, Array(0.09999999, dtype=float32))

It looks like NumPy uses intermediate float64 arithmetic in float32 division, while JAX's float32 division uses strictly float32 arithmetic. To illustrate this, consider:

>>> import mpmath
>>> mpmath.mp.dps = 15  # double precision
>>> mpmath.mp.mpf(0.1e9) / mpmath.mp.mpf(1e9)
mpf('0.10000000000000001')
>>> mpmath.mp.dps = 6  # single precision
>>> mpmath.mp.mpf(0.1e9) / mpmath.mp.mpf(1e9)
mpf('0.099999994')

So, the issue is not about a bug in JAX but rather belongs to https://jax.readthedocs.io/en/latest/faq.html
Thanks for the reply @pearu! It is quite surprising that NumPy doesn't document that behavior (or I couldn't find it) on their Data type promotion page, but it isn't JAX's fault ^^'. But I think that documenting this in JAX's FAQ would be great :-)
What exactly do you have in mind for documenting in the JAX FAQ?
That NumPy, as opposed to JAX, can promote data types for some computations, even though that is not documented, which might lead to differences between a JAX program and a NumPy program even when you explicitly use the same precision. And probably link to this issue or include the division example. I am not sure how best to phrase this.
OK – maybe something like that would fit as a bullet point here? https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#miscellaneous-divergences-from-numpy
Yes, I think this is a good place for that :)
FWIW, JAX and Numba are in the same boat regarding float32 semantics with respect to NumPy: https://numba.readthedocs.io/en/stable/reference/fpsemantics.html
Should I start a PR editing this page?
Description
Hello!
I am not sure how to title this issue correctly, but I recently had an issue where the output of some function would differ when @jax.jit was applied to the function. Here is the simplified version:
Which yields:
I know Python floats have 64-bit precision, while JAX's are 32-bit, and JIT can also optimize some operations, which could (I guess) change the order of operations and lead to different floating-point results (?). I tried to force float32 in both cases:

But the results are still surprising:
The NumPy version outputs the expected value, while the JAX variant doesn't.
Am I doing something wrong?
I could not find documentation about this, and I wonder whether this kind of thing can be prevented, or whether it should simply be expected?
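Since the simplified snippet itself is not reproduced above, here is a minimal, hypothetical sketch of the kind of eager-versus-jit comparison being described, reusing the division from the reproducer quoted earlier on this page (f is a stand-in, not the original function):

import jax
import jax.numpy as jnp

# Hypothetical stand-in for the original function; the real snippet is not
# reproduced above.
def f(x):
    return x / jnp.float32(1e9)

x = jnp.float32(0.1e9)

print(f(x))           # eager execution
print(jax.jit(f)(x))  # the same function under jax.jit; whether the two match
                      # can depend on how XLA lowers the division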
For this problem, I found a very simple fix:
that is, using multiplication instead of division.
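The actual fix snippet is not shown above; purely as an illustration of the kind of rewrite meant (the 1e9 scale factor is borrowed from the reproducer earlier on this page, not from the original code):

import jax.numpy as jnp

x = jnp.float32(0.1e9)

quotient = x / jnp.float32(1e9)   # strict float32 division
product = x * jnp.float32(1e-9)   # multiply by a precomputed reciprocal instead

print(quotient, product)

Keep in mind that 1e-9 is itself rounded when stored as float32, so the multiplicative form has its own rounding error; whether it matches the NumPy output depends on the values involved.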
System info (python version, jaxlib version, accelerator, etc.)