log1m_exp and log_diff_exp functions #1368

Open
dylanhmorris opened this issue Mar 18, 2022 · 2 comments
Labels: enhancement, good first issue


@dylanhmorris (Contributor) commented Mar 18, 2022

When writing custom distributions, it is often helpful to have numerically stable implementations of log_diff_exp(a, b) := log(exp(a) - exp(b)) and particularly log1m_exp(x) := log(1 - exp(x)). The naive implementations are not stable for many probabilistic programming use cases, so probabilistic programming languages including Stan and PyMC provide numerically stable implementations of these functions, typically following Mächler (2012).
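
For illustration, here is a quick sketch of where the naive form loses precision (values assume float32, JAX's default dtype; the exact cutoffs differ for float64):

import jax.numpy as jnp

# near x = 0, exp(-1e-10) rounds to exactly 1.0 in float32,
# so the naive form log(1 - exp(x)) returns -inf
x = jnp.float32(-1e-10)
naive = jnp.log(1.0 - jnp.exp(x))   # -inf
stable = jnp.log(-jnp.expm1(x))     # approx -23.03, the correct value

# conversely, for very negative x, expm1(x) rounds to -1.0 and
# log(-expm1(x)) returns 0.0; log1p(-exp(x)) stays accurate there,
# hence the crossover at -log(2) in Mächler's algorithm
y = jnp.float32(-50.0)
also_stable = jnp.log1p(-jnp.exp(y))  # approx -1.93e-22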

As far as I can tell, NumPyro does not, and they are not present in JAX.

I wonder whether it would be worth providing them. I have written basic implementations following Mächler for my own use. I would happily make a PR including them, but someone more experienced could probably write better, more idiomatic ones.

import jax.numpy as jnp


def log1m_exp(x):
    """
    Numerically stable calculation
    of the quantity log(1 - exp(x)),
    following the algorithm of
    Machler [1]. This is
    the algorithm used in TensorFlow Probability,
    PyMC, and Stan, but it is not provided
    yet with Numpyro.

    Currently returns NaN for x > 0,
    but may be modified in the future
    to throw a ValueError

    [1] https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
    """
    # coerce x to a floating-point array
    arr_x = 1.0 * jnp.array(x)
    # if exp(x) underflows the floating-point
    # representation, return 0. rather than -0.
    oob = arr_x < jnp.log(jnp.finfo(
        arr_x.dtype).smallest_normal)
    # Mächler's crossover: use expm1 above
    # -log(2) and log1p below it
    mask = arr_x > -0.6931472  # approx -log(2)
    more_val = jnp.log(-jnp.expm1(arr_x))
    less_val = jnp.log1p(-jnp.exp(arr_x))

    return jnp.where(
        oob,
        0.,
        jnp.where(
            mask,
            more_val,
            less_val))


def log_diff_exp(a, b):
    """
    Numerically stable calculation of
    log(exp(a) - exp(b)), defined for a >= b.
    """
    # note that following Stan,
    # we want the log diff exp
    # of -inf, -inf to be -inf,
    # not NaN, because that
    # corresponds to log(0 - 0) = -inf
    mask = a > b
    masktwo = (a == b) & (a < jnp.inf)
    return jnp.where(mask,
                     1.0 * a + log1m_exp(
                         1.0 * b - 1.0 * a),
                     jnp.where(masktwo,
                               -jnp.inf,
                               jnp.nan))
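
A quick usage sketch (assuming the definitions above; jax.scipy.stats.norm supplies logcdf), computing interval and tail probabilities of a standard normal entirely in log space:

from jax.scipy.stats import norm

# log P(-1 < X < 1) for X ~ Normal(0, 1):
# log(CDF(1) - CDF(-1)), computed without leaving log space
log_interval = log_diff_exp(norm.logcdf(1.0), norm.logcdf(-1.0))
# approx log(0.6827) = -0.3817

# log survival function, log P(X > 1) = log(1 - CDF(1))
log_tail = log1m_exp(norm.logcdf(1.0))
# approx log(0.1587) = -1.8410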

fehiepsi added the enhancement label on Mar 18, 2022
@fehiepsi (Member) commented
I think it is a nice approach. We have those computations across various places. I think you can put those utilities in the distributions/util.py file. The implementation looks reasonable to me. How about discussing the details in your PR?

@dylanhmorris (Contributor, Author) commented
Sounds good. Will prepare one as soon as I have a chance.

fehiepsi added the good first issue label on Jul 2, 2024