-
-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Normal summary stats optimization
- Loading branch information
1 parent
cb3d501
commit 1469915
Showing
6 changed files
with
170 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# ruff: noqa: F401 | ||
# Add rewrites to the optimization DBs | ||
import pymc_experimental.sampling.optimizations.summary_stats | ||
|
||
from pymc_experimental.sampling.optimizations.optimize import ( | ||
optimize_model_for_mcmc_sampling, | ||
posterior_optimization_db, | ||
) | ||
|
||
# Public API of the optimizations package: the rewrite database and the
# entry point that applies its rewrites to a model.
__all__ = ["posterior_optimization_db", "optimize_model_for_mcmc_sampling"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import pytensor.tensor as pt | ||
|
||
from pymc.distributions import Gamma, Normal | ||
from pymc.model.fgraph import ModelObservedRV, model_observed_rv | ||
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.graph.rewriting.basic import node_rewriter | ||
|
||
from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db | ||
|
||
|
||
@node_rewriter(tracks=[ModelObservedRV])
def summary_stats_normal(fgraph: FunctionGraph, node):
    """Replace an observed Normal likelihood by its sufficient statistics.

    Applies the equivalence (up to a normalizing constant) described in:
    https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics

    An observed ``Normal(mu, sigma)`` with scalar ``mu`` and ``sigma`` and ``N``
    data points is replaced by two observed variables:

    * the sample mean, observed under ``Normal(mu, sigma / sqrt(N))``;
    * the sample variance (``ddof=1``), observed under
      ``Gamma(alpha=(N - 1) / 2, beta=(N - 1) / (2 * sigma**2))``.

    The new variables are named ``"{rv.name}_mean"`` and ``"{rv.name}_var"``.

    Parameters
    ----------
    fgraph
        Function graph of the model being rewritten. It is mutated in place:
        two outputs are added and ``node`` is removed.
    node
        The ``ModelObservedRV`` node under consideration; its inputs are the
        random variable and the observed data.

    Returns
    -------
    ``None`` when the rewrite does not apply, otherwise a one-element list with
    a copy of the removed node's output (only so the rewrite is counted in the
    profile).
    """
    [observed_rv] = node.outputs
    [rv, data] = node.inputs

    if not isinstance(rv.owner.op, Normal):
        return None

    # Check the normal RV is not just a scalar
    if all(rv.type.broadcastable):
        return None

    # Check that the observed RV is not used anywhere else (like a Potential or Deterministic)
    # There should be only one use: as an "output"
    if len(fgraph.clients[observed_rv]) > 1:
        return None

    mu, sigma = rv.owner.op.dist_params(rv.owner)

    # Both mu and sigma must be scalar: the summary statistics below collapse
    # the data to a single mean and a single variance, which is only valid when
    # every observation shares the same parameters. Bail out if *either* one
    # has a non-broadcastable dimension.
    if not (all(mu.type.broadcastable) and all(sigma.type.broadcastable)):
        return None

    # Check that mu and sigma are not used anywhere else
    # Note: This is too restrictive, it's fine if they're used in Deterministics!
    # There should only be two uses: as an "output" and as the param of the `rv`
    if len(fgraph.clients[mu]) > 2 or len(fgraph.clients[sigma]) > 2:
        return None

    # Remove expand_dims
    mu = mu.squeeze()
    sigma = sigma.squeeze()

    # Apply the rewrite
    mean_data = pt.mean(data)
    mean_data.name = None
    var_data = pt.var(data, ddof=1)
    var_data.name = None
    N = data.size
    sqrt_N = pt.sqrt(N)
    nm1_over2 = (N - 1) / 2

    observed_mean = model_observed_rv(
        Normal.dist(mu=mu, sigma=sigma / sqrt_N),
        mean_data,
    )
    observed_mean.name = f"{rv.name}_mean"

    observed_var = model_observed_rv(
        Gamma.dist(alpha=nm1_over2, beta=nm1_over2 / (sigma**2)),
        var_data,
    )
    observed_var.name = f"{rv.name}_var"

    fgraph.add_output(observed_mean, import_missing=True)
    fgraph.add_output(observed_var, import_missing=True)
    fgraph.remove_node(node)
    # Just so it shows in the profile for verbose=True,
    # It won't do anything because node is not in the fgraph anymore
    return [node.out.copy()]
|
||
|
||
# Register the rewrite in the posterior optimization database under its own
# `__name__`, tagged "default" and "summary_stats".
posterior_optimization_db.register(
    summary_stats_normal.__name__,
    summary_stats_normal,
    "default",
    "summary_stats",
)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import numpy as np | ||
|
||
from pymc.distributions import HalfNormal, Normal | ||
from pymc.model.core import Model | ||
|
||
from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling | ||
|
||
|
||
def test_summary_stats_normal():
    """Check that the Normal summary-stats rewrite is applied and preserves logp.

    A Normal likelihood with scalar ``mu`` and ``sigma`` and 1000 observations
    is optimized; the optimized model must report the ``summary_stats_normal``
    rewrite, have two observed variables instead of one, differ in logp from
    the original only by a constant, and have the same logp gradient.
    """
    observations = np.random.default_rng(3).normal(loc=1, scale=0.5, size=(1000,))

    with Model() as model:
        mu = Normal("mu")
        sigma = HalfNormal("sigma")
        Normal("y", mu=mu, sigma=sigma, observed=observations)

    assert len(model.free_RVs) == 2
    assert len(model.observed_RVs) == 1

    optimized, rewrite_counters = optimize_model_for_mcmc_sampling(model)
    applied = [rewrite.name for counter in rewrite_counters for rewrite in counter]
    assert "summary_stats_normal" in applied

    assert len(optimized.free_RVs) == 2
    assert len(optimized.observed_RVs) == 2

    # Confirm equivalent (up to an additive normalization constant):
    # the logp difference must be the same at two distinct points.
    logp = model.compile_logp()
    optimized_logp = optimized.compile_logp()

    point = model.initial_point()
    diff_at_start = logp(point) - optimized_logp(point)

    point["mu"] += 0.5
    point["sigma_log__"] += 1.5
    diff_after_move = logp(point) - optimized_logp(point)

    np.testing.assert_allclose(diff_at_start, diff_after_move)

    # dlogp should be the same
    dlogp = model.compile_dlogp()
    optimized_dlogp = optimized.compile_dlogp()
    np.testing.assert_allclose(dlogp(point), optimized_dlogp(point))