From b04ebdb3459bbde2dea526ecdd40ff599319ee11 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Tue, 23 Apr 2024 08:45:35 -0500 Subject: [PATCH] Disable sample_posterior_predictive taskbar when progressbar=False --- pymc/sampling/forward.py | 2 +- pymc/sampling/parallel.py | 2 +- pymc/sampling/population.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index fe0f2085bb0..dc89ebc7035 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -829,7 +829,7 @@ def sample_posterior_predictive( _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore ppc_trace_t = _DefaultTrace(samples) try: - with Progress(console=Console(theme=progressbar_theme)) as progress: + with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress: task = progress.add_task("Sampling ...", total=samples, visible=progressbar) for idx in np.arange(samples): if nchain > 1: diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 9505daef70c..c32ed23ef35 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -429,7 +429,7 @@ def __init__( "[progress.percentage]{task.percentage:>3.0f}%", TimeRemainingColumn(), TextColumn("/"), - TimeElapsedColumn() + TimeElapsedColumn(), console=Console(theme=progressbar_theme), ) self._show_progress = progressbar diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 660b9e60e9d..69cb4793726 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -24,7 +24,7 @@ import cloudpickle import numpy as np -from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn, TimeElaspedColumn +from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn, TimeElapsedColumn from pymc.backends.base import BaseTrace from pymc.initial_point import PointType @@ -181,7 +181,7 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True): "[progress.percentage]{task.percentage:>3.0f}%", TimeRemainingColumn(), TextColumn("/"), - TimeElaspedColumn(), + TimeElapsedColumn(), ) as self._progress: for c, stepper in enumerate(steppers): # enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)