Skip to content

Commit

Permalink
Add Normal summary stats optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 7, 2024
1 parent cb3d501 commit 1469915
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 1 deletion.
13 changes: 13 additions & 0 deletions pymc_experimental/sampling/optimizations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# ruff: noqa: F401
# Add rewrites to the optimization DBs
import pymc_experimental.sampling.optimizations.summary_stats

from pymc_experimental.sampling.optimizations.optimize import (
optimize_model_for_mcmc_sampling,
posterior_optimization_db,
)

__all__ = [
"posterior_optimization_db",
"optimize_model_for_mcmc_sampling",
]
79 changes: 79 additions & 0 deletions pymc_experimental/sampling/optimizations/summary_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytensor.tensor as pt

from pymc.distributions import Gamma, Normal
from pymc.model.fgraph import ModelObservedRV, model_observed_rv
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter

from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db


@node_rewriter(tracks=[ModelObservedRV])
def summary_stats_normal(fgraph: FunctionGraph, node):
"""Applies the equivalence (up to a normalizing constant) described in:
https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics
"""
[observed_rv] = node.outputs
[rv, data] = node.inputs

if not isinstance(rv.owner.op, Normal):
return None

# Check the normal RV is not just a scalar
if all(rv.type.broadcastable):
return None

# Check that the observed RV is not used anywhere else (like a Potential or Deterministic)
# There should be only one use: as an "output"
if len(fgraph.clients[observed_rv]) > 1:
return None

mu, sigma = rv.owner.op.dist_params(rv.owner)

# Check if mu and sigma are scalar RVs
if not all(mu.type.broadcastable) and not all(sigma.type.broadcastable):
return None

# Check that mu and sigma are not used anywhere else
# Note: This is too restrictive, it's fine if they're used in Deterministics!
# There should only be two uses: as an "output" and as the param of the `rv`
if len(fgraph.clients[mu]) > 2 or len(fgraph.clients[sigma]) > 2:
return None

# Remove expand_dims
mu = mu.squeeze()
sigma = sigma.squeeze()

# Apply the rewrite
mean_data = pt.mean(data)
mean_data.name = None
var_data = pt.var(data, ddof=1)
var_data.name = None
N = data.size
sqrt_N = pt.sqrt(N)
nm1_over2 = (N - 1) / 2

observed_mean = model_observed_rv(
Normal.dist(mu=mu, sigma=sigma / sqrt_N),
mean_data,
)
observed_mean.name = f"{rv.name}_mean"

observed_var = model_observed_rv(
Gamma.dist(alpha=nm1_over2, beta=nm1_over2 / (sigma**2)),
var_data,
)
observed_var.name = f"{rv.name}_var"

fgraph.add_output(observed_mean, import_missing=True)
fgraph.add_output(observed_var, import_missing=True)
fgraph.remove_node(node)
# Just so it shows in the profile for verbose=True,
# It won't do anything because node is not in the fgraph anymore
return [node.out.copy()]


posterior_optimization_db.register(
summary_stats_normal.__name__, summary_stats_normal, "default", "summary_stats"
)
Empty file added tests/sampling/__init__.py
Empty file.
36 changes: 35 additions & 1 deletion tests/sampling/mcmc/test_mcmc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
import pytest

from pymc.distributions import Beta, Binomial, InverseGamma
from pymc.distributions import Beta, Binomial, HalfNormal, InverseGamma, Normal
from pymc.model.core import Model
from pymc.sampling.mcmc import sample
from pymc.step_methods import Slice

from pymc_experimental import opt_sample
Expand All @@ -18,3 +20,35 @@ def test_custom_step_raises():
ValueError, match="The `step` argument is not supported in `opt_sample`"
):
opt_sample(step=Slice([a, b]))


def test_sample_opt_summary_stats(capsys):
rng = np.random.default_rng(3)
y_data = rng.normal(loc=1, scale=0.5, size=(1000,))

with Model() as m:
mu = Normal("mu")
sigma = HalfNormal("sigma")
y = Normal("y", mu=mu, sigma=sigma, observed=y_data)

sample_kwargs = dict(
chains=1, tune=500, draws=500, compute_convergence_checks=False, progressbar=False
)
idata = sample(**sample_kwargs)
# TODO: Make extract_data more robust to avoid this warning/error
# Or alternatively extract data on the original model, not the optimized one
with pytest.warns(UserWarning, match="Could not extract data from symbolic observation"):
opt_idata = opt_sample(**sample_kwargs, verbose=True)

captured_out = capsys.readouterr().out
assert "Applied optimization: summary_stats_normal 1x" in captured_out

assert opt_idata.posterior.sizes["chain"] == 1
assert opt_idata.posterior.sizes["draw"] == 500
np.testing.assert_allclose(
idata.posterior["mu"].mean(), opt_idata.posterior["mu"].mean(), rtol=1e-2
)
np.testing.assert_allclose(
idata.posterior["sigma"].mean(), opt_idata.posterior["sigma"].mean(), rtol=1e-2
)
assert idata.sample_stats.sampling_time > opt_idata.sample_stats.sampling_time
Empty file.
43 changes: 43 additions & 0 deletions tests/sampling/optimizations/test_summary_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np

from pymc.distributions import HalfNormal, Normal
from pymc.model.core import Model

from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling


def test_summary_stats_normal():
rng = np.random.default_rng(3)
y_data = rng.normal(loc=1, scale=0.5, size=(1000,))

with Model() as m:
mu = Normal("mu")
sigma = HalfNormal("sigma")
y = Normal("y", mu=mu, sigma=sigma, observed=y_data)

assert len(m.free_RVs) == 2
assert len(m.observed_RVs) == 1

new_m, rewrite_counters = optimize_model_for_mcmc_sampling(m)
assert "summary_stats_normal" in (r.name for rc in rewrite_counters for r in rc)

assert len(new_m.free_RVs) == 2
assert len(new_m.observed_RVs) == 2

# Confirm equivalent (up to an additive normalization constant)
m_logp = m.compile_logp()
new_m_logp = new_m.compile_logp()

ip = m.initial_point()
first_logp_diff = m_logp(ip) - new_m_logp(ip)

ip["mu"] += 0.5
ip["sigma_log__"] += 1.5
second_logp_diff = m_logp(ip) - new_m_logp(ip)

np.testing.assert_allclose(first_logp_diff, second_logp_diff)

# dlogp should be the same
m_dlogp = m.compile_dlogp()
new_m_dlogp = new_m.compile_dlogp()
np.testing.assert_allclose(m_dlogp(ip), new_m_dlogp(ip))

0 comments on commit 1469915

Please sign in to comment.