Commit
* migrate to new lightray
* clean up tuning
* fix config
* use env var in tune.yaml to parse storage dir (sketched below)
* fix typo
* update tune remote get args
Showing 17 changed files with 285 additions and 263 deletions.
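One of the commit bullets is using an environment variable in tune.yaml to resolve the storage directory. As a rough illustration of the idea only: the key name, variable name, and expansion mechanism below are assumptions, not lightray's documented behavior, but an env reference in the config could be expanded at load time like this:

    import os

    import yaml  # assumes PyYAML is available

    # hypothetical tune.yaml entry:
    #   storage_dir: ${AFRAME_RUN_DIR}
    with open("tune.yaml") as f:
        config = yaml.safe_load(f)

    # os.path.expandvars replaces ${VAR} with its environment value,
    # leaving unknown references untouched
    config["storage_dir"] = os.path.expandvars(str(config["storage_dir"]))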
Sandbox exports:

@@ -1 +1 @@
-from .sandbox import Sandbox, Tune
+from .sandbox import Sandbox, TunePipeline
The tune task module:

@@ -1,46 +1,56 @@
 import os
 from typing import TYPE_CHECKING
 
 import law
 import luigi
+from luigi.util import inherits
 
-from aframe.base import AframeRayTask
+from aframe.base import AframeRayTask, AframeSingularityTask, AframeWrapperTask
 from aframe.config import ray_head, ray_worker, s3, ssh, wandb
 from aframe.targets import Bytes, LawS3Target
-from aframe.tasks.train.base import RemoteTrainBase
+from aframe.tasks.train.base import RemoteTrainBase, TrainBase
 
 if TYPE_CHECKING:
     from aframe.helm import RayCluster
 
 
-class TuneRemote(RemoteTrainBase, AframeRayTask):
-    name = luigi.Parameter(
-        default="ray-tune",
-        description="Name of the tune job. "
-        "Will be used to group runs in WandB",
-    )
-    search_space = luigi.Parameter(
-        default="train.tune.search_space",
-        description="Import path to the search space file "
-        "used for hyperparameter tuning. This file is expected "
-        "to contain a dictionary named `space` of the search space",
-    )
-    num_samples = luigi.IntParameter(description="Number of trials to run")
-    min_epochs = luigi.IntParameter(
-        description="Minimum number of epochs each trial "
-        "can run before early stopping is considered."
-    )
-    max_epochs = luigi.IntParameter(
-        description="Maximum number of epochs each trial can run"
-    )
-    reduction_factor = luigi.IntParameter(
-        description="Fraction of poor performing trials to stop early"
-    )
-    workers_per_trial = luigi.IntParameter(
-        default=1, description="Number of ray workers to use per trial"
-    )
-    gpus_per_worker = luigi.IntParameter(
-        default=1, description="Number of gpus to allocate to each ray worker"
+class TuneLocal(TrainBase, AframeSingularityTask):
+    tune_config = luigi.Parameter(
+        description="Path to the `yaml` file used "
+        "to configure the lightray tune job."
     )
 
+    @property
+    def default_image(self):
+        return "train.sif"
+
+    def output(self):
+        path = self.run_dir / "best.pt"
+        return law.LocalFileTarget(str(path), format=Bytes)
+
+    def run(self):
+        from lightray.cli import cli
+
+        args = ["--config", self.tune_config, "--"]
+        lightning_args = self.get_args()
+        # remove the "fit" subcommand since lightray takes care of it
+        lightning_args.pop(0)
+        args.extend(lightning_args)
+
+        results = cli(args)
+        prefix = "s3://" if str(self.run_dir).startswith("s3://") else ""
+
+        # return path to best model weights from best trial
+        best = results.get_best_result(scope="all")
+        best = best.get_best_checkpoint()
+        weights = os.path.join(prefix, best.path, "model.pt")
+
+        # copy the best weights to the output location
+        s3().client.copy(weights, self.output().path)
+
+
+class TuneRemote(RemoteTrainBase, AframeRayTask):
     git_url = luigi.Parameter(
         default="git@github.com:ML4GW/aframev2.git",
         description="Git repository url to clone and"
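The new TuneLocal.run() drives lightray's CLI in-process: arguments before "--" configure the tune job itself, and everything after it is forwarded to the underlying Lightning CLI (with the "fit" subcommand stripped, since lightray supplies it). A minimal standalone sketch of that argument layout; the config file name and forwarded flags here are hypothetical, while the cli and result calls mirror the diff:

    from lightray.cli import cli

    # tune-level config before "--", lightning-level overrides after it
    # (file names and flags below are illustrative, not from the repo)
    args = ["--config", "tune.yaml", "--", "--data.fname", "train.hdf5"]
    results = cli(args)

    # best checkpoint across all reported iterations of the best trial
    best = results.get_best_result(scope="all").get_best_checkpoint()
    print(best.path)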
@@ -52,6 +62,10 @@ class TuneRemote(RemoteTrainBase, AframeRayTask):
         description="Git branch or commit to checkout. "
         "Only used if `dev` is set to True",
     )
+    tune_config = luigi.Parameter(
+        description="Path to the `yaml` file used "
+        "to configure the lightray tune job."
+    )
 
     # image used locally to connect to the ray cluster
     @property
@@ -63,15 +77,6 @@ def use_wandb(self):
         # always use wandb logging for tune jobs
         return True
 
-    def get_ip(self):
-        """
-        Get the ip of the ray cluster that
-        is stored via an environment variable
-        """
-        ip = os.getenv("AFRAME_RAY_CLUSTER_IP")
-        ip += ":10001"
-        return f"ray://{ip}"
-
     def configure_cluster(self, cluster: "RayCluster"):
         # get ssh key for git-sync init container
         with open(ssh().ssh_file, "r") as f:
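The deleted get_ip() helper assembled a Ray client address from an environment variable; presumably the cluster address now reaches lightray through the tune config instead. For reference, a standalone sketch of what the old helper produced (the IP value is illustrative; 10001 is Ray's default client server port):

    import os

    # e.g. AFRAME_RAY_CLUSTER_IP=10.0.0.1 (illustrative value)
    ip = os.getenv("AFRAME_RAY_CLUSTER_IP", "10.0.0.1")
    address = f"ray://{ip}:10001"  # -> "ray://10.0.0.1:10001"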
@@ -108,44 +113,57 @@ def output(self):
         return LawS3Target(str(path), format=Bytes)
 
     def run(self):
-        from lightray.tune import run
-        from ray.tune.schedulers import ASHAScheduler
+        from lightray.cli import cli
 
-        from train.callbacks import TraceModel
-        from train.cli import AframeCLI
+        args = ["--config", self.tune_config, "--"]
+        lightning_args = self.get_args()
+        # remove the "fit" subcommand since lightray takes care of it
+        lightning_args.pop(0)
+        args.extend(lightning_args)
 
-        args = self.get_args()
-
-        scheduler = ASHAScheduler(
-            max_t=self.max_epochs,
-            grace_period=self.min_epochs,
-            reduction_factor=self.reduction_factor,
-        )
-        metric_name = "valid_auroc"
-        objective = "max"
+        results = cli(args)
         prefix = "s3://" if str(self.run_dir).startswith("s3://") else ""
-        results = run(
-            cli_cls=AframeCLI,
-            name=self.name,
-            scheduler=scheduler,
-            metric_name=metric_name,
-            objective=objective,
-            search_space=self.search_space,
-            num_samples=self.num_samples,
-            workers_per_trial=self.workers_per_trial,
-            gpus_per_worker=self.gpus_per_worker,
-            cpus_per_gpu=ray_worker().cpus_per_gpu,
-            storage_dir=self.run_dir,
-            callbacks=[TraceModel],
-            address=self.get_ip(),
-            args=args,
-        )
 
         # return path to best model weights from best trial
-        best = results.get_best_result(
-            metric=metric_name, mode=objective, scope="all"
-        )
-        best = best.get_best_checkpoint(metric=metric_name, mode=objective)
+        best = results.get_best_result(scope="all")
+        best = best.get_best_checkpoint()
         weights = os.path.join(prefix, best.path, "model.pt")
 
         # copy the best weights to the output location
         s3().client.copy(weights, self.output().path)
+
+
+@inherits(TuneLocal, TuneRemote)
+class Tune(AframeWrapperTask):
+    """
+    Class that dynamically chooses between
+    remote training on nautilus or local training on LDG.
+    Useful for incorporating into pipelines where
+    you don't care where the training is run.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.remote = self.validate_dirs()
+
+    def validate_dirs(self) -> bool:
+        # train remotely if run_dir starts with s3://
+        # Note: one can specify a remote data_dir but train locally
+        remote = str(self.run_dir).startswith("s3://")
+
+        if remote and not str(self.data_dir).startswith("s3://"):
+            raise ValueError(
+                "If run_dir is an s3 path, data_dir must also be an s3 path. "
+                f"Got data_dir: {self.data_dir} and run_dir: {self.run_dir}"
+            )
+        return remote
+
+    def requires(self):
+        if self.remote:
+            return TuneRemote.req(self)
+        else:
+            return TuneLocal.req(self)
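The wrapper's dispatch rule is worth spelling out: an s3:// run_dir selects remote tuning, anything else selects local, and a remote run_dir paired with a local data_dir is rejected. A standalone sketch of just that logic (paths are illustrative):

    # mirrors Tune.validate_dirs() from the diff above
    def is_remote(run_dir: str, data_dir: str) -> bool:
        remote = run_dir.startswith("s3://")
        if remote and not data_dir.startswith("s3://"):
            raise ValueError(
                "If run_dir is an s3 path, data_dir must also be an s3 path. "
                f"Got data_dir: {data_dir} and run_dir: {run_dir}"
            )
        return remote

    assert is_remote("s3://bucket/run", "s3://bucket/data")  # -> TuneRemote
    assert not is_remote("/local/run", "/local/data")        # -> TuneLocal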