From 60a631467ca8b166df7e852156c5fa714e912094 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Fri, 26 Apr 2024 12:24:53 -0500 Subject: [PATCH] Add time remaining column to progress bars (#7273) * Add time remaining column to progress bars * Consistent order remaining/elapsed * Disable sample_posterior_predictive taskbar when progressbar=False * Formatting * More formatting * More formatting (why doesnt pre-commit fix this?) * Disable progress bar when progress=False * Set refresh flag in progress bar updates * Typo --- pymc/backends/arviz.py | 2 +- pymc/sampling/forward.py | 4 +++- pymc/sampling/mcmc.py | 4 ++-- pymc/sampling/parallel.py | 6 +++++- pymc/sampling/population.py | 6 ++++-- pymc/smc/sampling.py | 14 ++++++++++++-- pymc/tuning/starting.py | 3 ++- pymc/variational/inference.py | 4 +++- 8 files changed, 32 insertions(+), 11 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 72a02d9091b..40ab7d3ae0e 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -659,7 +659,7 @@ def apply_function_over_dataset( out_dict = _DefaultTrace(n_pts) indices = range(n_pts) - with Progress(console=Console(theme=progressbar_theme)) as progress: + with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress: task = progress.add_task("Computing ...", total=n_pts, visible=progressbar) for idx in indices: out = fn(posterior_pts[idx]) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index fe0f2085bb0..cc1dcd52f6e 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -829,7 +829,9 @@ def sample_posterior_predictive( _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore ppc_trace_t = _DefaultTrace(samples) try: - with Progress(console=Console(theme=progressbar_theme)) as progress: + with Progress( + console=Console(theme=progressbar_theme), disable=not progressbar + ) as progress: task = progress.add_task("Sampling ...", total=samples, visible=progressbar) for idx in np.arange(samples): if nchain > 1: diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 7f750090f81..96190e5ff03 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1041,8 +1041,8 @@ def _sample( for it, diverging in enumerate(sampling_gen): if it >= skip_first and diverging: _pbar_data["divergences"] += 1 - progress.update(task, advance=1) - progress.update(task, advance=1, completed=True) + progress.update(task, refresh=True, advance=1) + progress.update(task, refresh=True, advance=1, completed=True) except KeyboardInterrupt: pass diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index cc6908647ec..c2f9791de5c 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -27,7 +27,7 @@ import numpy as np from rich.console import Console -from rich.progress import BarColumn, Progress, TimeRemainingColumn +from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.theme import Theme from pymc.blocking import DictToArrayBijection @@ -428,7 +428,10 @@ def __init__( BarColumn(), "[progress.percentage]{task.percentage:>3.0f}%", TimeRemainingColumn(), + TextColumn("/"), + TimeElapsedColumn(), console=Console(theme=progressbar_theme), + disable=not progressbar, ) self._show_progress = progressbar self._divergences = 0 @@ -465,6 +468,7 @@ def __iter__(self): self._divergences += 1 progress.update( task, + refresh=True, completed=self._completed_draws, total=self._total_draws, description=self._desc.format(self), diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 1627bb8de77..2b0aad2b32a 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -24,7 +24,7 @@ import cloudpickle import numpy as np -from rich.progress import BarColumn, Progress, TimeRemainingColumn +from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn from pymc.backends.base import BaseTrace from pymc.initial_point import PointType @@ -104,7 +104,7 @@ def _sample_population( task = progress.add_task("[red]Sampling...", total=draws, visible=progressbar) for _ in sampling: - progress.update(task, advance=1) + progress.update(task, advance=1, refresh=True) return @@ -180,6 +180,8 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True): BarColumn(), "[progress.percentage]{task.percentage:>3.0f}%", TimeRemainingColumn(), + TextColumn("/"), + TimeElapsedColumn(), ) as self._progress: for c, stepper in enumerate(steppers): # enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index d9b76f211ce..db4044a4fe9 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -25,7 +25,13 @@ import numpy as np from arviz import InferenceData -from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn +from rich.progress import ( + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) import pymc @@ -366,6 +372,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): with Progress( TextColumn("{task.description}"), SpinnerColumn(), + TimeRemainingColumn(), + TextColumn("/"), TimeElapsedColumn(), TextColumn("{task.fields[status]}"), ) as progress: @@ -403,6 +411,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): stage = update_data["stage"] beta = update_data["beta"] # update the progress bar for this task: - progress.update(status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id) + progress.update( + status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id, refresh=True + ) return tuple(cloudpickle.loads(r.result()) for r in futures) diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 129d6f89730..09b787c506d 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -178,7 +178,7 @@ def find_MAP( if isinstance(e, StopIteration): pm._log.info(e) finally: - cost_func.progress.update(cost_func.task, completed=cost_func.n_eval) + cost_func.progress.update(cost_func.task, completed=cost_func.n_eval, refresh=True) print(file=sys.stdout) mx0 = RaveledVars(mx0, x0.point_map_info) @@ -223,6 +223,7 @@ def __init__( *Progress.get_default_columns(), TextColumn("{task.fields[loss]}"), console=Console(theme=progressbar_theme), + disable=not progressbar, ) self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="") diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index 3d9e6fd8eae..3a9a69add72 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -166,7 +166,9 @@ def fit( def _iterate_without_loss(self, s, n, step_func, progressbar, progressbar_theme, callbacks): i = 0 try: - with Progress(console=Console(theme=progressbar_theme)) as progress: + with Progress( + console=Console(theme=progressbar_theme), disable=not progressbar + ) as progress: task = progress.add_task("Fitting", total=n, visible=progressbar) for i in range(n): step_func()