When writing custom distributions, it is often helpful to have numerically stable implementations of log_diff_exp(a, b) := log(exp(a) - exp(b)) and particularly log1m_exp(x) := log(1 - exp(x)). The naive implementations are not stable for many probabilistic programming use cases, so probabilistic programming languages including Stan and PyMC provide numerically stable implementations of these functions (typically following Mächler, 2012).
As far as I can tell, NumPyro does not, and these functions are not present in JAX either.
I wonder whether it would be worth providing them. I have written basic implementations following Mächler for my own use, shown below. I would happily make a PR including them, but someone more experienced could probably write better/more idiomatic ones.
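For concreteness, here is a minimal sketch of the failure mode (my own illustrative values, not from the references; assumes JAX's default float32, and the numbers are approximate):

import jax.numpy as jnp

x = jnp.float32(-1e-8)
# Naive: exp(-1e-8) rounds to exactly 1.0 in float32,
# so log(1 - exp(x)) = log(0) = -inf.
jnp.log(1.0 - jnp.exp(x))  # -> -inf
# Stable: expm1 keeps the tiny difference from 1,
# giving log(1e-8) ~= -18.42.
jnp.log(-jnp.expm1(x))  # -> ~ -18.42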
import jax.numpy as jnp

def log1m_exp(x):
    """
    Numerically stable calculation of the quantity
    log(1 - exp(x)), following the algorithm of
    Mächler [1]. This is the algorithm used in
    TensorFlow Probability, PyMC, and Stan, but it
    is not yet provided with NumPyro.

    Currently returns NaN for x > 0, but may be
    modified in the future to raise a ValueError.

    [1] https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
    """
    # return 0. rather than -0. if we get an
    # exponent so negative that exp(x) underflows
    # the floating point representation
    arr_x = 1.0 * jnp.array(x)
    oob = arr_x < jnp.log(jnp.finfo(arr_x.dtype).smallest_normal)
    mask = arr_x > -0.6931472  # approx -log(2)
    more_val = jnp.log(-jnp.expm1(arr_x))
    less_val = jnp.log1p(-jnp.exp(arr_x))
    return jnp.where(oob, 0., jnp.where(mask, more_val, less_val))
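# Illustrative sanity checks (my own example values,
# not from the references; assumes default float32):
#   log1m_exp(-1e-8)   # ~ -18.42 : expm1 branch (x > -log 2); naive form gives -inf
#   log1m_exp(-40.0)   # ~ -4.25e-18 : log1p branch
#   log1m_exp(-100.0)  # 0.0 : exp(x) underflows, clamped to 0. rather than -0.
#   log1m_exp(0.5)     # nan : log(1 - exp(x)) is undefined for x > 0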
def log_diff_exp(a, b):
    # note that following Stan, we want the
    # log diff exp of -inf, -inf to be -inf,
    # not nan, because that corresponds to
    # log(0 - 0) = -inf
    mask = a > b
    masktwo = (a == b) & (a < jnp.inf)
    return jnp.where(
        mask,
        1.0 * a + log1m_exp(1.0 * b - 1.0 * a),
        jnp.where(masktwo, -jnp.inf, jnp.nan))
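A quick usage sketch of log_diff_exp (again my own illustrative values, assuming default float32):

a = jnp.array([0.0, -jnp.inf, 1.0])
b = jnp.array([-1.0, -jnp.inf, 2.0])
log_diff_exp(a, b)
# -> [-0.4587, -inf, nan]
# log(exp(0) - exp(-1)) = log(1 - 1/e) ~= -0.4587;
# log(0 - 0) = -inf by the Stan convention;
# nan in the last slot since a < b.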
I think this is a nice approach. We have those computations scattered across various places. You can put these utilities in the distributions/util.py file. The implementation looks reasonable to me. How about discussing the details in your PR?