From c9895af5d194f8f9dd7f83b8f8b4fc7b2ed47441 Mon Sep 17 00:00:00 2001 From: Lukas Welzel Date: Tue, 29 Oct 2024 14:19:24 +0100 Subject: [PATCH] Add GenerationStrategy configuration to hydra-ax-sweeper Implements configs for both ax.modelbridge.generation_strategy.GenerationStrategy and GenerationStep in the ax-sweeper config.py and a constructor method in the CoreAxSweeper class. --- .../hydra_plugins/hydra_ax_sweeper/_core.py | 34 ++++++++++++++++++- .../hydra_plugins/hydra_ax_sweeper/config.py | 20 +++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py b/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py index 4967626c3be..8ec27c7d70f 100644 --- a/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py +++ b/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py @@ -6,6 +6,8 @@ from ax.core import types as ax_types # type: ignore from ax.exceptions.core import SearchSpaceExhausted # type: ignore from ax.service.ax_client import AxClient # type: ignore +from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep # type: ignore +from ax.modelbridge.registry import Models # type: ignore from hydra.core.override_parser.overrides_parser import OverridesParser from hydra.core.override_parser.types import IntervalSweep, Override, Transformer from hydra.core.plugins import Plugins @@ -15,7 +17,7 @@ from omegaconf import DictConfig, OmegaConf from ._earlystopper import EarlyStopper -from .config import AxConfig, ClientConfig, ExperimentConfig +from .config import AxConfig, ClientConfig, ExperimentConfig, GenerationStrategyConfig log = logging.getLogger(__name__) @@ -117,6 +119,7 @@ def __init__(self, ax_config: AxConfig, max_batch_size: Optional[int]): epsilon=ax_config.early_stop.epsilon, minimize=ax_config.early_stop.minimize, ) + self.generation_strategy: GenerationStrategyConfig = ax_config.generation_strategy self.ax_client_config: ClientConfig = ax_config.client self.max_trials = ax_config.max_trials self.ax_params: DictConfig = OmegaConf.create({}) @@ -238,6 +241,30 @@ def sweep_over_batches( trial_index=batch[idx].trial_index, raw_data=val ) + def _create_generation_strategy(self) -> GenerationStrategy: + """Create an Ax GenerationStrategy from configuration.""" + steps = [] + for step_config in self.generation_strategy.steps: + # Convert string model name to Models enum if necessary + model = step_config.model + if isinstance(model, str): + model = getattr(Models, model.upper()) + + step = GenerationStep( + model=model, + num_trials=step_config.num_trials, + max_parallelism=step_config.max_parallelism, + model_kwargs=step_config.model_kwargs, + model_gen_kwargs=step_config.model_gen_kwargs, + ) + steps.append(step) + + return GenerationStrategy( + steps=steps, + name=self.generation_strategy.name, + ) + + def setup_ax_client(self, arguments: List[str]) -> AxClient: """Method to setup the Ax Client""" parameters: List[Dict[Any, Any]] = [] @@ -265,9 +292,14 @@ def setup_ax_client(self, arguments: List[str]) -> AxClient: log.info( f"AxSweeper is optimizing the following parameters: {encoder_parameters_into_string(parameters)}" ) + + generation_strategy = self._create_generation_strategy() + log.info(f"AxSweeper is optimizing the following generation strategy: {generation_strategy}") + ax_client = AxClient( verbose_logging=self.ax_client_config.verbose_logging, random_seed=self.ax_client_config.random_seed, + generation_strategy=generation_strategy, ) ax_client.create_experiment(parameters=parameters, **self.experiment) diff --git a/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/config.py b/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/config.py index 0aeff862f93..9e86949a1cf 100644 --- a/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/config.py +++ b/plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/config.py @@ -40,6 +40,24 @@ class ClientConfig: random_seed: Optional[int] = None +@dataclass +class GenerationStepConfig: + """Configuration for a single step in a generation strategy.""" + model: str = "SOBOL" # e.g., "SOBOL", "GPEI", "MOO" + num_trials: int = 5 # -1 means unlimited trials + max_parallelism: int = 1 + model_kwargs: Dict[str, Any] = field(default_factory=dict) + model_gen_kwargs: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class GenerationStrategyConfig: + """Configuration for the generation strategy""" + steps: List[GenerationStepConfig] = field(default_factory=lambda: [GenerationStepConfig()]) + name: Optional[str] = "default_strategy" + + + @dataclass class AxConfig: # max_trials is application-specific. Tune it for your use case @@ -51,6 +69,8 @@ class AxConfig: # is_noisy = True indicates measurements have unknown uncertainty # is_noisy = False indicates measurements have an uncertainty of zero is_noisy: bool = True + generation_strategy: GenerationStrategyConfig = field(default_factory=GenerationStrategyConfig) + @dataclass