-
-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5055262
commit cb3d501
Showing
7 changed files
with
135 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import sys | ||
|
||
from pymc.model.core import Model | ||
from pymc.sampling.mcmc import sample | ||
from pytensor.graph.rewriting.basic import GraphRewriter | ||
|
||
from pymc_experimental.sampling.optimizations.optimize import ( | ||
TAGS_TYPE, | ||
optimize_model_for_mcmc_sampling, | ||
) | ||
|
||
|
||
def opt_sample( | ||
*args, | ||
model: Model | None = None, | ||
include: TAGS_TYPE = ("default",), | ||
exclude: TAGS_TYPE = None, | ||
rewriter: GraphRewriter | None = None, | ||
verbose: bool = False, | ||
**kwargs, | ||
): | ||
"""Sample from a model after applying optimizations. | ||
Parameters | ||
---------- | ||
model : Model, optinoal | ||
The model to sample from. If None, use the model associated with the context. | ||
include : TAGS_TYPE | ||
The tags to include in the optimizations. Ignored if `rewriter` is not None. | ||
exclude : TAGS_TYPE | ||
The tags to exclude from the optimizations. Ignored if `rewriter` is not None. | ||
rewriter : RewriteDatabaseQuery (optional) | ||
The rewriter to use. If None, use the default rewriter with the given `include` and `exclude` tags. | ||
verbose : bool, default=False | ||
Print information about the optimizations applied. | ||
*args, **kwargs: | ||
Passed to `pm.sample` | ||
Returns | ||
------- | ||
sample_output: | ||
The output of `pm.sample` | ||
Examples | ||
-------- | ||
.. code:: python | ||
import pymc as pm | ||
import pymc_experimental as pmx | ||
with pm.Model() as m: | ||
p = pm.Beta("p", 1, 1, shape=(1000,)) | ||
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250) | ||
idata = pmx.opt_sample(verbose=True) | ||
""" | ||
if kwargs.get("step", None) is not None: | ||
raise ValueError( | ||
"The `step` argument is not supported in `opt_sample`, as custom steps would refer to the original model.\n" | ||
"You can manually transform the model with `pymc_experimental.sampling.optimizations.optimize_model_for_mcmc_sampling` " | ||
"and then define the custom steps and forward them to `pymc.sample`." | ||
) | ||
|
||
opt_model, rewrite_counters = optimize_model_for_mcmc_sampling( | ||
model, include=include, exclude=exclude, rewriter=rewriter | ||
) | ||
|
||
if verbose: | ||
applied_opt = False | ||
for rewrite_counter in rewrite_counters: | ||
for rewrite, counts in rewrite_counter.items(): | ||
applied_opt = True | ||
print(f"Applied optimization: {rewrite} {counts}x", file=sys.stdout) | ||
if not applied_opt: | ||
print("No optimizations applied", file=sys.stdout) | ||
|
||
return sample(*args, model=opt_model, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from collections import Counter | ||
from collections.abc import Sequence | ||
from typing import TypeAlias | ||
|
||
from pymc.model.core import Model, modelcontext | ||
from pymc.model.fgraph import fgraph_from_model, model_from_fgraph | ||
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery | ||
|
||
posterior_optimization_db = EquilibriumDB() | ||
posterior_optimization_db.failure_callback = None # Raise an error if an optimization fails | ||
posterior_optimization_db.name = "posterior_optimization_db" | ||
|
||
TAGS_TYPE: TypeAlias = str | Sequence[str] | None | ||
|
||
|
||
def optimize_model_for_mcmc_sampling( | ||
model: Model, | ||
include: TAGS_TYPE = ("default",), | ||
exclude: TAGS_TYPE = None, | ||
rewriter=None, | ||
) -> tuple[Model, Sequence[Counter]]: | ||
if isinstance(include, str): | ||
include = (include,) | ||
if isinstance(exclude, str): | ||
exclude = (exclude,) | ||
|
||
model = modelcontext(model) | ||
fgraph, _ = fgraph_from_model(model) | ||
|
||
if rewriter is None: | ||
rewriter = posterior_optimization_db.query( | ||
RewriteDatabaseQuery(include=include, exclude=exclude) | ||
) | ||
_, _, rewrite_counters, *_ = rewriter.rewrite(fgraph) | ||
|
||
opt_model = model_from_fgraph(fgraph, mutate_fgraph=True) | ||
return opt_model, rewrite_counters |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import pytest | ||
|
||
from pymc.distributions import Beta, Binomial, InverseGamma | ||
from pymc.model.core import Model | ||
from pymc.step_methods import Slice | ||
|
||
from pymc_experimental import opt_sample | ||
|
||
|
||
def test_custom_step_raises(): | ||
with Model() as m: | ||
a = InverseGamma("a", 1, 1) | ||
b = InverseGamma("b", 1, 1) | ||
p = Beta("p", a, b) | ||
y = Binomial("y", n=100, p=p, observed=99) | ||
|
||
with pytest.raises( | ||
ValueError, match="The `step` argument is not supported in `opt_sample`" | ||
): | ||
opt_sample(step=Slice([a, b])) |