Skip to content

Commit

Permalink
Add opt_sample
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 7, 2024
1 parent 5055262 commit cb3d501
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ methods in the current release of PyMC experimental.
MarginalModel
marginalize
model_builder.ModelBuilder
opt_sample

Inference
=========
Expand Down
1 change: 1 addition & 0 deletions pymc_experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pymc_experimental.inference.fit import fit
from pymc_experimental.model.marginal.marginal_model import MarginalModel, marginalize
from pymc_experimental.model.model_api import as_model
from pymc_experimental.sampling.mcmc import opt_sample
from pymc_experimental.version import __version__

_log = logging.getLogger("pmx")
Expand Down
Empty file.
76 changes: 76 additions & 0 deletions pymc_experimental/sampling/mcmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import sys

from pymc.model.core import Model
from pymc.sampling.mcmc import sample
from pytensor.graph.rewriting.basic import GraphRewriter

from pymc_experimental.sampling.optimizations.optimize import (
TAGS_TYPE,
optimize_model_for_mcmc_sampling,
)


def opt_sample(
    *args,
    model: Model | None = None,
    include: TAGS_TYPE = ("default",),
    exclude: TAGS_TYPE = None,
    rewriter: GraphRewriter | None = None,
    verbose: bool = False,
    **kwargs,
):
    """Optimize a model for MCMC sampling and then sample from the result.

    The model is first passed through `optimize_model_for_mcmc_sampling`;
    the model returned by that call is the one handed to `pm.sample`.

    Parameters
    ----------
    model : Model, optional
        The model to sample from. If None, use the model associated with the context.
    include : TAGS_TYPE
        The tags to include in the optimizations. Ignored if `rewriter` is not None.
    exclude : TAGS_TYPE
        The tags to exclude from the optimizations. Ignored if `rewriter` is not None.
    rewriter : GraphRewriter, optional
        The rewriter to use. If None, use the default rewriter with the given
        `include` and `exclude` tags.
    verbose : bool, default=False
        Print information about the optimizations applied.
    *args, **kwargs:
        Passed to `pm.sample`

    Returns
    -------
    sample_output:
        The output of `pm.sample`

    Raises
    ------
    ValueError
        If a non-None `step` keyword argument is given.

    Examples
    --------
    .. code:: python

        import pymc as pm
        import pymc_experimental as pmx

        with pm.Model() as m:
            p = pm.Beta("p", 1, 1, shape=(1000,))
            y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)

            idata = pmx.opt_sample(verbose=True)
    """
    # A user-built step method is tied to the variables of the original model,
    # not to the optimized model that is actually sampled, so refuse it early.
    custom_step = kwargs.get("step")
    if custom_step is not None:
        raise ValueError(
            "The `step` argument is not supported in `opt_sample`, as custom steps would refer to the original model.\n"
            "You can manually transform the model with `pymc_experimental.sampling.optimizations.optimize_model_for_mcmc_sampling` "
            "and then define the custom steps and forward them to `pymc.sample`."
        )

    opt_model, rewrite_counters = optimize_model_for_mcmc_sampling(
        model, include=include, exclude=exclude, rewriter=rewriter
    )

    if verbose:
        # Lazily produce one report line per (rewrite, count) entry across all
        # counters, printing each as it is produced.
        report_lines = (
            f"Applied optimization: {rewrite} {counts}x"
            for counter in rewrite_counters
            for rewrite, counts in counter.items()
        )
        printed_any = False
        for line in report_lines:
            print(line, file=sys.stdout)
            printed_any = True
        if not printed_any:
            print("No optimizations applied", file=sys.stdout)

    return sample(*args, model=opt_model, **kwargs)
37 changes: 37 additions & 0 deletions pymc_experimental/sampling/optimizations/optimize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from collections import Counter
from collections.abc import Sequence
from typing import TypeAlias

from pymc.model.core import Model, modelcontext
from pymc.model.fgraph import fgraph_from_model, model_from_fgraph
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery

# Database of rewrites. `optimize_model_for_mcmc_sampling` queries it with the
# `include` / `exclude` tags whenever no explicit rewriter is supplied.
# NOTE(review): no rewrites are registered in the visible code — presumably
# they are registered elsewhere; confirm.
posterior_optimization_db = EquilibriumDB()
posterior_optimization_db.failure_callback = None # Raise an error if an optimization fails
posterior_optimization_db.name = "posterior_optimization_db"

# Accepted forms for tag arguments: a single tag string, a sequence of tag
# strings, or None.
TAGS_TYPE: TypeAlias = str | Sequence[str] | None


def optimize_model_for_mcmc_sampling(
    model: Model | None = None,
    include: TAGS_TYPE = ("default",),
    exclude: TAGS_TYPE = None,
    rewriter=None,
) -> tuple[Model, Sequence[Counter]]:
    """Apply graph rewrites to a model and return the rewritten model.

    The model is converted to a function graph, the rewriter is run on that
    graph, and a model is rebuilt from the rewritten graph.

    Parameters
    ----------
    model : Model, optional
        The model to optimize. It is resolved through `modelcontext`, so None
        selects the model associated with the current context.
    include : TAGS_TYPE
        Tags of the rewrites to include. A single string is treated as a
        one-element tuple. Ignored if `rewriter` is not None.
    exclude : TAGS_TYPE
        Tags of the rewrites to exclude. A single string is treated as a
        one-element tuple. Ignored if `rewriter` is not None.
    rewriter : optional
        Object exposing a `rewrite(fgraph)` method. If None, one is obtained by
        querying `posterior_optimization_db` with the `include` and `exclude`
        tags.

    Returns
    -------
    opt_model : Model
        The model rebuilt from the rewritten graph.
    rewrite_counters : Sequence[Counter]
        The third element of the value returned by `rewriter.rewrite`; callers
        iterate it as mappings from rewrite to number of applications.
    """
    # Normalize a lone tag into a tuple so the query always receives a sequence.
    if isinstance(include, str):
        include = (include,)
    if isinstance(exclude, str):
        exclude = (exclude,)

    model = modelcontext(model)
    fgraph, _ = fgraph_from_model(model)

    if rewriter is None:
        rewriter = posterior_optimization_db.query(
            RewriteDatabaseQuery(include=include, exclude=exclude)
        )

    # NOTE(review): this unpacking assumes `rewrite` returns a sequence of at
    # least three items with the counters in third position — confirm this holds
    # for user-supplied rewriters.
    _, _, rewrite_counters, *_ = rewriter.rewrite(fgraph)

    opt_model = model_from_fgraph(fgraph, mutate_fgraph=True)
    return opt_model, rewrite_counters
Empty file added tests/sampling/mcmc/__init__.py
Empty file.
20 changes: 20 additions & 0 deletions tests/sampling/mcmc/test_mcmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest

from pymc.distributions import Beta, Binomial, InverseGamma
from pymc.model.core import Model
from pymc.step_methods import Slice

from pymc_experimental import opt_sample


def test_custom_step_raises():
    """`opt_sample` must reject a user-supplied `step` with a ValueError."""
    expected_message = "The `step` argument is not supported in `opt_sample`"

    with Model():
        alpha = InverseGamma("a", 1, 1)
        beta = InverseGamma("b", 1, 1)
        prob = Beta("p", alpha, beta)
        Binomial("y", n=100, p=prob, observed=99)

        # The step method is built inside the `raises` block, exactly as the
        # sampling call is, so both are covered by the expectation.
        with pytest.raises(ValueError, match=expected_message):
            opt_sample(step=Slice([alpha, beta]))

0 comments on commit cb3d501

Please sign in to comment.