From a5c036a26acc94d2377fc1d837df26050c36c098 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 5 Dec 2024 14:57:37 +0100 Subject: [PATCH] Add Beta-Binomial conjugacy optimization --- .../model/marginal/distributions.py | 14 +- pymc_experimental/sampling/mcmc.py | 38 ++++ .../sampling/optimizations/__init__.py | 1 + .../sampling/optimizations/conjugacy.py | 199 ++++++++++++++++++ .../optimizations/conjugate_sampler.py | 115 ++++++++++ pymc_experimental/utils/ofg.py | 17 ++ tests/sampling/mcmc/test_mcmc.py | 55 +++++ .../sampling/optimizations/test_conjugacy.py | 59 ++++++ 8 files changed, 485 insertions(+), 13 deletions(-) create mode 100644 pymc_experimental/sampling/optimizations/conjugacy.py create mode 100644 pymc_experimental/sampling/optimizations/conjugate_sampler.py create mode 100644 pymc_experimental/utils/ofg.py create mode 100644 tests/sampling/optimizations/test_conjugacy.py diff --git a/pymc_experimental/model/marginal/distributions.py b/pymc_experimental/model/marginal/distributions.py index 661665e9..287e9065 100644 --- a/pymc_experimental/model/marginal/distributions.py +++ b/pymc_experimental/model/marginal/distributions.py @@ -7,7 +7,6 @@ from pymc.logprob.abstract import MeasurableOp, _logprob from pymc.logprob.basic import conditional_logp, logp from pymc.pytensorf import constant_fold -from pytensor import Variable from pytensor.compile.builders import OpFromGraph from pytensor.compile.mode import Mode from pytensor.graph import Op, vectorize_graph @@ -17,6 +16,7 @@ from pytensor.tensor import TensorVariable from pymc_experimental.distributions import DiscreteMarkovChain +from pymc_experimental.utils.ofg import inline_ofg_outputs class MarginalRV(OpFromGraph, MeasurableOp): @@ -126,18 +126,6 @@ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> Tens return logp.transpose(*dims_alignment) -def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]: - """Inline the inner graph (outputs) of an OpFromGraph Op. - - Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps" - the inner graph. - """ - return clone_replace( - op.inner_outputs, - replace=tuple(zip(op.inner_inputs, inputs)), - ) - - DUMMY_ZERO = pt.constant(0, name="dummy_zero") diff --git a/pymc_experimental/sampling/mcmc.py b/pymc_experimental/sampling/mcmc.py index 0d1085f4..c57e7fa2 100644 --- a/pymc_experimental/sampling/mcmc.py +++ b/pymc_experimental/sampling/mcmc.py @@ -52,6 +52,44 @@ def opt_sample( y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250) idata = pmx.opt_sample(verbose=True) + + # Applied optimization: beta_binomial_conjugacy 1x + # ConjugateRVSampler: [p] + + + You can control which optimizations are applied using the `include` and `exclude` arguments: + + .. code:: python + import pymc as pm + import pymc_experimental as pmx + + with pm.Model() as m: + p = pm.Beta("p", 1, 1, shape=(1000,)) + y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250) + + idata = pmx.opt_sample(exclude="conjugacy", verbose=True) + + # No optimizations applied + # NUTS: [p] + + .. code:: python + import pymc as pm + import pymc_experimental as pmx + + with pm.Model() as m: + a = pm.InverseGamma("a", 1, 1) + b = pm.InverseGamma("b", 1, 1) + p = pm.Beta("p", a, b, shape=(1000,)) + y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250) + + # By default, the conjugacy of p will not be applied because it depends on other free variables + idata = pmx.opt_sample(include="conjugacy-eager", verbose=True) + + # Applied optimization: beta_binomial_conjugacy_eager 1x + # CompoundStep + # >NUTS: [a, b] + # >ConjugateRVSampler: [p] + """ if kwargs.get("step", None) is not None: raise ValueError( diff --git a/pymc_experimental/sampling/optimizations/__init__.py b/pymc_experimental/sampling/optimizations/__init__.py index 363535b4..78d0c858 100644 --- a/pymc_experimental/sampling/optimizations/__init__.py +++ b/pymc_experimental/sampling/optimizations/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 # Add rewrites to the optimization DBs +import pymc_experimental.sampling.optimizations.conjugacy import pymc_experimental.sampling.optimizations.summary_stats from pymc_experimental.sampling.optimizations.optimize import ( diff --git a/pymc_experimental/sampling/optimizations/conjugacy.py b/pymc_experimental/sampling/optimizations/conjugacy.py new file mode 100644 index 00000000..ad357884 --- /dev/null +++ b/pymc_experimental/sampling/optimizations/conjugacy.py @@ -0,0 +1,199 @@ +from collections.abc import Sequence +from functools import partial + +from pymc.distributions import Beta, Binomial +from pymc.model.fgraph import ModelFreeRV, ModelValuedVar, model_free_rv +from pymc.pytensorf import collect_default_updates +from pytensor.graph.basic import Variable, ancestors +from pytensor.graph.fg import FunctionGraph, Output +from pytensor.graph.rewriting.basic import node_rewriter +from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.subtensor import _sum_grad_over_bcasted_dims as sum_bcasted_dims + +from pymc_experimental.sampling.optimizations.conjugate_sampler import ( + ConjugateRV, +) +from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db + + +def register_conjugacy_rewrites_variants(rewrite_fn, tracks=(ModelFreeRV,)): + """Register a rewrite function and its force variant in the posterior optimization DB.""" + name = rewrite_fn.__name__ + + rewrite_fn_default = partial(rewrite_fn, eager=False) + rewrite_fn_default.__name__ = name + rewrite_default = node_rewriter(tracks=tracks)(rewrite_fn_default) + + rewrite_fn_eager = partial(rewrite_fn, eager=True) + rewrite_fn_eager.__name__ = f"{name}_eager" + rewrite_eager = node_rewriter(tracks=tracks)(rewrite_fn_eager) + + posterior_optimization_db.register( + rewrite_default.__name__, + rewrite_default, + "default", + "conjugacy", + ) + + posterior_optimization_db.register( + rewrite_eager.__name__, + rewrite_eager, + "non-default", + "conjugacy-eager", + ) + + return rewrite_default, rewrite_eager + + +def has_free_rv_ancestor(vars: Variable | Sequence[Variable]) -> bool: + """Return True if any of the variables have a model variable as an ancestor.""" + if not isinstance(vars, Sequence): + vars = (vars,) + + # TODO: It should stop at observed RVs, it doesn't matter if they have a free RV above + # Did not implement due to laziness and it being a rare case + return any( + var.owner is not None and isinstance(var.owner.op, ModelFreeRV) for var in ancestors(vars) + ) + + +def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable: + """Return the Model dummy var that wraps the RV""" + for client, _ in fgraph.clients[rv]: + if isinstance(client.op, ModelValuedVar): + return client.outputs[0] + + +def get_dist_params(rv: Variable) -> tuple[Variable]: + return rv.owner.op.dist_params(rv.owner) + + +def rv_used_by( + fgraph: FunctionGraph, + rv: Variable, + used_by_type: type, + used_as_arg_idx: int | Sequence[int], + strict: bool = True, +) -> list[Variable]: + """Return the RVs that use `rv` as an argument in an operation of type `used_by_type`. + + RV may be used directly or broadcasted before being used. + + Parameters + ---------- + fgraph : FunctionGraph + The function graph containing the RVs + rv : Variable + The RV to check for uses. + used_by_type : type + The type of operation that may use the RV. + used_as_arg_idx : int | Sequence[int] + The index of the RV in the operation's inputs. + strict : bool, default=True + If True, return no results when the RV is used in an unrecognized way. + + """ + if isinstance(used_as_arg_idx, int): + used_as_arg_idx = (used_as_arg_idx,) + + clients = fgraph.clients + used_by: list[Variable] = [] + for client, inp_idx in clients[rv]: + if isinstance(client.op, Output): + continue + + if isinstance(client.op, used_by_type) and inp_idx in used_as_arg_idx: + # RV is directly used by the RV type + used_by.append(client.default_output()) + + elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims: + for sub_client, sub_inp_idx in clients[client.outputs[0]]: + if isinstance(sub_client.op, used_by_type) and sub_inp_idx in used_as_arg_idx: + # RV is broadcasted and then used by the RV type + used_by.append(sub_client.default_output()) + elif strict: + # Some other unrecognized use, bail out + return [] + elif strict: + # Some other unrecognized use, bail out + return [] + + return used_by + + +def wrap_rv_and_conjugate_rv( + fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable] +) -> Variable: + """Wrap the RV and its conjugate posterior RV in a ConjugateRV node. + + Also takes care of handling the random number generators used in the conjugate posterior. + """ + rngs, next_rngs = zip(*collect_default_updates(conjugate_rv, inputs=[rv, *inputs]).items()) + for rng in rngs: + if rng not in fgraph.inputs: + fgraph.add_input(rng) + conjugate_op = ConjugateRV(inputs=[rv, *inputs, *rngs], outputs=[rv, conjugate_rv, *next_rngs]) + return conjugate_op(rv, *inputs, *rngs)[0] + + +def create_untransformed_free_rv( + fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable] +) -> Variable: + """Create a model FreeRV without transform.""" + transform = None + value = rv.type(name=name) + fgraph.add_input(value) + free_rv = model_free_rv(rv, value, transform, *dims) + free_rv.name = name + return free_rv + + +def beta_binomial_conjugacy(fgraph: FunctionGraph, node, eager: bool = False): + if not isinstance(node.op, ModelFreeRV): + return None + + [beta_free_rv] = node.outputs + beta_rv, _, *beta_dims = node.inputs + + if not isinstance(beta_rv.owner.op, Beta): + return None + + _, beta_rv_size, a, b = beta_rv.owner.inputs + if not eager and has_free_rv_ancestor([a, b]): + # Don't apply rewrite if a, b depend on other model variables as that will force a Gibbs sampling scheme + return None + + p_arg_idx = 3 # inputs to Binomial are (rng, size, n, p) + binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx) + + if len(binomial_rvs) != 1: + # Question: Can we apply conjugacy when RV is used by more than one binomial? + return None + + [binomial_rv] = binomial_rvs + + binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv) + if binomial_model_var is None: + return None + + # We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv) + n, _ = get_dist_params(binomial_rv) + + # Use value of y in new graph to avoid circularity + y = binomial_model_var.owner.inputs[1] + + conjugate_a = sum_bcasted_dims(beta_rv, a + y) + conjugate_b = sum_bcasted_dims(beta_rv, b + (n - y)) + + conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b, shape=beta_rv_size) + + new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y]) + new_beta_free_rv = create_untransformed_free_rv( + fgraph, new_beta_rv, beta_free_rv.name, beta_dims + ) + return [new_beta_free_rv] + + +beta_binomial_conjugacy_default, beta_binomial_conjugacy_force = ( + register_conjugacy_rewrites_variants(beta_binomial_conjugacy) +) diff --git a/pymc_experimental/sampling/optimizations/conjugate_sampler.py b/pymc_experimental/sampling/optimizations/conjugate_sampler.py new file mode 100644 index 00000000..9ddaa462 --- /dev/null +++ b/pymc_experimental/sampling/optimizations/conjugate_sampler.py @@ -0,0 +1,115 @@ +import numpy as np + +from pymc import STEP_METHODS +from pymc.distributions.distribution import _support_point +from pymc.initial_point import PointType +from pymc.logprob.abstract import MeasurableOp, _logprob +from pymc.model.core import modelcontext +from pymc.pytensorf import compile_pymc +from pymc.step_methods.compound import BlockedStep, Competence, StepMethodState +from pymc.util import get_value_vars_from_user_vars +from pytensor import shared +from pytensor.compile.builders import OpFromGraph +from pytensor.link.jax.linker import JAXLinker +from pytensor.tensor.random.type import RandomGeneratorType + +from pymc_experimental.utils.ofg import inline_ofg_outputs + + +class ConjugateRV(OpFromGraph, MeasurableOp): + """Wrapper for ConjugateRVs, that outputs the original RV and the conjugate posterior expression. + + For partial step samplers to work, the logp and initial point correspond to the original RV + while the variable itself is sampled by default by the `ConjugateRVSampler` by evaluating directly the + conjugate posterior expression (i.e., taking forward random draws). + """ + + +@_logprob.register(ConjugateRV) +def conjugate_rv_logp(op, values, rv, *params, **kwargs): + # Logp is the same as the original RV + return _logprob(rv.owner.op, values, *rv.owner.inputs) + + +@_support_point.register(ConjugateRV) +def conjugate_rv_support_point(op, conjugate_rv, rv, *params): + # Support point is the same as the original RV + return _support_point(rv.owner.op, rv, *rv.owner.inputs) + + +class ConjugateRVSampler(BlockedStep): + name = "conjugate_rv_sampler" + _state_class = StepMethodState + + def __init__(self, vars, model=None, rng=None, compile_kwargs: dict | None = None, **kwargs): + if len(vars) != 1: + raise ValueError("ConjugateRVSampler can only be assigned to one variable at a time") + + model = modelcontext(model) + [value] = get_value_vars_from_user_vars(vars, model=model) + rv = model.values_to_rvs[value] + self.vars = (value,) + self.rv_name = value.name + + if model.rvs_to_transforms[rv] is not None: + raise ValueError("Variable assigned to ConjugateRVSampler cannot be transformed") + + rv_and_posterior_rv_node = rv.owner + op = rv_and_posterior_rv_node.op + if not isinstance(op, ConjugateRV): + raise ValueError("Variable must be a ConjugateRV") + + # Replace RVs in inputs of rv_posterior_rv_node by the corresponding value variables + value_inputs = model.replace_rvs_by_values( + [rv_and_posterior_rv_node.outputs[1]], + )[0].owner.inputs + # Inline the ConjugateRV graph to only compile `posterior_rv` + _, posterior_rv, *_ = inline_ofg_outputs(op, value_inputs) + + if compile_kwargs is None: + compile_kwargs = {} + self.posterior_fn = compile_pymc( + model.value_vars, + posterior_rv, + random_seed=rng, + on_unused_input="ignore", + **compile_kwargs, + ) + self.posterior_fn.trust_input = True + if isinstance(self.posterior_fn.maker.linker, JAXLinker): + # Reseeding RVs in JAX backend requires a different logic, becuase the SharedVariables + # used internally are not the ones that `function.get_shared()` returns. + raise ValueError("ConjugateRVSampler is not compatible with JAX backend") + + def set_rng(self, rng: np.random.Generator): + # Copy the function and replace any shared RNGs + # This is needed so that it can work correctly with multiple traces + # This will be costly if set_rng is called too often! + shared_rngs = [ + var + for var in self.posterior_fn.get_shared() + if isinstance(var.type, RandomGeneratorType) + ] + n_shared_rngs = len(shared_rngs) + swap = { + old_shared_rng: shared(rng, borrow=True) + for old_shared_rng, rng in zip(shared_rngs, rng.spawn(n_shared_rngs), strict=True) + } + self.posterior_fn = self.posterior_fn.copy(swap=swap) + + def step(self, point: PointType) -> tuple[PointType, list]: + new_point = point.copy() + new_point[self.rv_name] = self.posterior_fn(**point) + return new_point, [] + + @staticmethod + def competence(var, has_grad): + """BinaryMetropolis is only suitable for Bernoulli and Categorical variables with k=2.""" + if isinstance(var.owner.op, ConjugateRV): + return Competence.IDEAL + + return Competence.INCOMPATIBLE + + +# Register the ConjugateRVSampler +STEP_METHODS.append(ConjugateRVSampler) diff --git a/pymc_experimental/utils/ofg.py b/pymc_experimental/utils/ofg.py new file mode 100644 index 00000000..6de8ed4a --- /dev/null +++ b/pymc_experimental/utils/ofg.py @@ -0,0 +1,17 @@ +from collections.abc import Sequence + +from pytensor.compile.builders import OpFromGraph +from pytensor.graph.basic import Variable +from pytensor.graph.replace import clone_replace + + +def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]: + """Inline the inner graph (outputs) of an OpFromGraph Op. + + Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps" + the inner graph. + """ + return clone_replace( + op.inner_outputs, + replace=tuple(zip(op.inner_inputs, inputs)), + ) diff --git a/tests/sampling/mcmc/test_mcmc.py b/tests/sampling/mcmc/test_mcmc.py index dc8f0f30..9b0f18cc 100644 --- a/tests/sampling/mcmc/test_mcmc.py +++ b/tests/sampling/mcmc/test_mcmc.py @@ -1,3 +1,5 @@ +import logging + import numpy as np import pytest @@ -52,3 +54,56 @@ def test_sample_opt_summary_stats(capsys): idata.posterior["sigma"].mean(), opt_idata.posterior["sigma"].mean(), rtol=1e-2 ) assert idata.sample_stats.sampling_time > opt_idata.sample_stats.sampling_time + + +def test_sample_opt_conjugate(caplog, capsys): + caplog.set_level(logging.INFO, logger="pymc") + + sample_kwargs = dict( + include="conjugacy-eager", + tune=0, + draws=250, + chains=4, + progressbar=False, + compute_convergence_checks=False, + ) + + with Model() as m: + p = Beta("p", 1, 1) + y = Binomial("y", n=100, p=p, observed=99) + + idata = opt_sample( + **sample_kwargs, + random_seed=0, + verbose=True, + ) + + assert "Applied optimization: beta_binomial_conjugacy_eager 1x" in capsys.readouterr().out + + # Test it used ConjugateRVSampler + assert "ConjugateRVSampler: [p]" in caplog.text + + np.testing.assert_allclose(idata.posterior["p"].mean(), 100 / 102, atol=1e-3) + np.testing.assert_allclose( + idata.posterior["p"].std(), np.sqrt(100 * 2 / (102**2 * 103)), atol=1e-3 + ) + + # Draws are different across chains + assert (np.diff(idata.posterior["p"].isel(draw=0).values) != 0).all() + + # Check draws respect random_seed + with m: + new_idata = opt_sample( + **sample_kwargs, + random_seed=0, + ) + np.testing.assert_allclose( + idata.posterior["p"].isel(draw=0), new_idata.posterior["p"].isel(draw=0) + ) + + with m: + new_idata = opt_sample( + **sample_kwargs, + random_seed=1, + ) + assert not np.allclose(idata.posterior["p"].isel(draw=0), new_idata.posterior["p"].isel(draw=0)) diff --git a/tests/sampling/optimizations/test_conjugacy.py b/tests/sampling/optimizations/test_conjugacy.py new file mode 100644 index 00000000..b45aabfc --- /dev/null +++ b/tests/sampling/optimizations/test_conjugacy.py @@ -0,0 +1,59 @@ +import numpy as np +import pytest + +from pymc.distributions import Beta, Binomial, DiracDelta +from pymc.model.core import Model +from pymc.model.transform.conditioning import remove_value_transforms +from pymc.sampling import draw + +from pymc_experimental.sampling.optimizations.conjugate_sampler import ConjugateRV +from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling + + +@pytest.mark.parametrize("eager", [False, True]) +def test_beta_binomial_conjugacy(eager): + with Model() as m: + if eager: + a, b = DiracDelta("a,b", [1, 1]) + else: + a, b = 1, 1 + p = Beta("p", a, b) + y = Binomial("y", n=100, p=p, observed=99) + + assert m.rvs_to_transforms[p] is not None + assert isinstance(p.owner.op, Beta) + + new_m, rewrite_counters = optimize_model_for_mcmc_sampling(m) + rewrite_applied = "beta_binomial_conjugacy" in (r.name for rc in rewrite_counters for r in rc) + if eager: + assert not rewrite_applied + new_m, rewrite_counters = optimize_model_for_mcmc_sampling(m, include="conjugacy-eager") + assert "beta_binomial_conjugacy_eager" in (r.name for rc in rewrite_counters for r in rc) + else: + assert rewrite_applied + + new_p = new_m["p"] + assert isinstance(new_p.owner.op, ConjugateRV) + assert new_m.rvs_to_transforms[new_p] is None + beta_rv, conjugate_beta_rv, *_ = new_p.owner.outputs + + # Check it behaves like a beta and its conjugate + beta_draws, conjugate_beta_draws = draw( + [beta_rv, conjugate_beta_rv], draws=1000, random_seed=25 + ) + np.testing.assert_allclose(beta_draws.mean(), 1 / 2, atol=1e-2) + np.testing.assert_allclose(conjugate_beta_draws.mean(), 100 / 102, atol=1e-3) + np.testing.assert_allclose(beta_draws.std(), np.sqrt(1 / 12), atol=1e-2) + np.testing.assert_allclose( + conjugate_beta_draws.std(), np.sqrt(100 * 2 / (102**2 * 103)), atol=1e-3 + ) + + # Check if support point and logp is the same as the original model without transforms + untransformed_m = remove_value_transforms(m) + new_m_ip = new_m.initial_point() + for key, value in untransformed_m.initial_point().items(): + np.testing.assert_allclose(new_m_ip[key], value) + + new_m_logp = new_m.compile_logp()(new_m_ip) + untransformed_m_logp = untransformed_m.compile_logp()(new_m_ip) + np.testing.assert_allclose(new_m_logp, untransformed_m_logp)