Tuning clean up (#299)
* migrate to new lightray

* clean up tuning

* fix config

* use env var in tune.yaml to parse storage dir

* fix typo

* update tune remote get args
EthanMarx authored Oct 31, 2024
1 parent aabee27 commit 2bb1a4c
Showing 17 changed files with 285 additions and 263 deletions.
12 changes: 7 additions & 5 deletions aframe/base.py
@@ -96,10 +96,11 @@ def _get_volumes(self):
volumes[aws_dir] = aws_dir

# bind AFRAME env var data directories
# so users can point to other users directories
# so users can point to other users' directories.
# don't bind s3 directories!
for path in AFRAME_DATA_DIRS:
value = os.getenv(path)
if value is not None:
if value is not None and not value.startswith("s3://"):
volumes[value] = value
return volumes

@@ -266,10 +267,11 @@ def configure_cluster(self, cluster):
return cluster

def sandbox_env(self, _):
# hacky way to pass cluster ip to sandbox task
# that gets run in the container.
# set the RAY_ADDRESS environment variable
# in the container to the cluster ip;
# Ray will use it to initialize the cluster connection
env = super().sandbox_env(_)
env["AFRAME_RAY_CLUSTER_IP"] = self.ip
env["RAY_ADDRESS"] = f"ray://{self.ip}:10001"
env["WANDB_API_KEY"] = wandb().api_key
env["WANDB_USERNAME"] = wandb().username

2 changes: 1 addition & 1 deletion aframe/pipelines/sandbox/__init__.py
@@ -1 +1 @@
from .sandbox import Sandbox, Tune
from .sandbox import Sandbox, TunePipeline
21 changes: 16 additions & 5 deletions aframe/pipelines/sandbox/configs/tune.cfg
@@ -11,14 +11,25 @@ inherit = $AFRAME_REPO/aframe/pipelines/sandbox/configs/base.cfg

[luigi_TuneRemote]
config = $AFRAME_REPO/projects/train/config.yaml
tune_config = $AFRAME_REPO/projects/train/configs/tune.yaml
ifos = &::luigi_base::ifos
kernel_length = &::luigi_base::kernel_length
sample_rate = &::luigi_base::sample_rate
highpass = &::luigi_base::highpass
fduration = &::luigi_base::fduration
seed = &::luigi_base::seed
reduction_factor = 2
min_epochs = 20
max_epochs = 200
num_samples = 512
name = first-full-tune


[luigi_ray_head]
cpus = 32
memory = 32G

# configure how many pods
# and how many gpus per pod
[luigi_ray_worker]
replicas = 1
gpus_per_replica = 2

# set the path to your ssh key file
# if mounting remote code into the kubernetes pod
[luigi_ssh]
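Paths like $AFRAME_REPO above are environment-variable references; a hedged sketch of how such a path can be expanded in Python, with a hypothetical value, since the exact resolution hook used by the config machinery is not shown in this diff:

    import os

    # hypothetical example: expand $VARIABLE references in a config path
    # before handing it to luigi / lightray
    raw = "$AFRAME_REPO/projects/train/configs/tune.yaml"
    resolved = os.path.expandvars(raw)
    print(resolved)  # e.g. /opt/aframe/projects/train/configs/tune.yaml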
6 changes: 3 additions & 3 deletions aframe/pipelines/sandbox/sandbox.py
@@ -4,7 +4,7 @@
from aframe.tasks import TestingWaveforms, Train
from aframe.tasks.infer import Infer
from aframe.tasks.plots.sv import SensitiveVolume
from aframe.tasks.train.tune import TuneRemote
from aframe.tasks.train.tune import Tune


class SandboxInfer(Infer):
@@ -34,5 +34,5 @@ class Sandbox(_Sandbox):
train_task = Train


class Tune(_Sandbox):
train_task = TuneRemote
class TunePipeline(_Sandbox):
train_task = Tune
1 change: 1 addition & 0 deletions aframe/tasks/train/base.py
@@ -113,6 +113,7 @@ def configure_data_args(self, args: List[str]) -> None:

def get_args(self):
args = [
"fit",
"--config",
self.config,
"--seed_everything",
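With "fit" now prepended, get_args returns a full LightningCLI-style argument list. A small sketch of what a consumer sees follows; the values are illustrative, not taken from a real config:

    # hypothetical argument list, for illustration only
    args = ["fit", "--config", "/path/to/config.yaml", "--seed_everything", "101"]

    # the tune tasks below strip the subcommand before handing the rest
    # to lightray, which drives the lightning CLI itself
    subcommand = args.pop(0)
    assert subcommand == "fit"
    assert args[0] == "--config"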
162 changes: 90 additions & 72 deletions aframe/tasks/train/tune.py
@@ -1,46 +1,56 @@
import os
from typing import TYPE_CHECKING

import law
import luigi
from luigi.util import inherits

from aframe.base import AframeRayTask
from aframe.base import AframeRayTask, AframeSingularityTask, AframeWrapperTask
from aframe.config import ray_head, ray_worker, s3, ssh, wandb
from aframe.targets import Bytes, LawS3Target
from aframe.tasks.train.base import RemoteTrainBase
from aframe.tasks.train.base import RemoteTrainBase, TrainBase

if TYPE_CHECKING:
from aframe.helm import RayCluster


class TuneRemote(RemoteTrainBase, AframeRayTask):
name = luigi.Parameter(
default="ray-tune",
description="Name of the tune job. "
"Will be used to group runs in WandB",
)
search_space = luigi.Parameter(
default="train.tune.search_space",
description="Import path to the search space file "
"used for hyperparameter tuning. This file is expected "
"to contain a dictionary named `space` of the search space",
)
num_samples = luigi.IntParameter(description="Number of trials to run")
min_epochs = luigi.IntParameter(
description="Minimum number of epochs each trial "
"can run before early stopping is considered."
)
max_epochs = luigi.IntParameter(
description="Maximum number of epochs each trial can run"
)
reduction_factor = luigi.IntParameter(
description="Fraction of poor performing trials to stop early"
)
workers_per_trial = luigi.IntParameter(
default=1, description="Number of ray workers to use per trial"
)
gpus_per_worker = luigi.IntParameter(
default=1, description="Number of gpus to allocate to each ray worker"
class TuneLocal(TrainBase, AframeSingularityTask):
tune_config = luigi.Parameter(
description="Path to the `yaml` file used"
" to configure the lightray tune job. "
)

@property
def default_image(self):
return "train.sif"

def output(self):
path = self.run_dir / "best.pt"
return law.LocalFileTarget(str(path), format=Bytes)

def run(self):
from lightray.cli import cli

args = ["--config", self.tune_config, "--"]
lightning_args = self.get_args()
# remove the "fit" subcommand since lightray takes care of it
lightning_args.pop(0)
args.extend(lightning_args)

results = cli(args)
prefix = "s3://" if str(self.run_dir).startswith("s3://") else ""

# return path to best model weights from best trial
best = results.get_best_result(scope="all")
best = best.get_best_checkpoint()
weights = os.path.join(prefix, best.path, "model.pt")

# copy the best weights to the output location
s3().client.copy(weights, self.output().path)


class TuneRemote(RemoteTrainBase, AframeRayTask):
git_url = luigi.Parameter(
default="[email protected]:ML4GW/aframev2.git",
description="Git repository url to clone and"
@@ -52,6 +62,10 @@ class TuneRemote(RemoteTrainBase, AframeRayTask):
description="Git branch or commit to checkout. "
"Only used if `dev` is set to True",
)
tune_config = luigi.Parameter(
description="Path to the `yaml` file used"
" to configure the lightray tune job. "
)

# image used locally to connect to the ray cluster
@property
Expand All @@ -63,15 +77,6 @@ def use_wandb(self):
# always use wandb logging for tune jobs
return True

def get_ip(self):
"""
Get the ip of the ray cluster that
is stored via an environment variable
"""
ip = os.getenv("AFRAME_RAY_CLUSTER_IP")
ip += ":10001"
return f"ray://{ip}"

def configure_cluster(self, cluster: "RayCluster"):
# get ssh key for git-sync init container
with open(ssh().ssh_file, "r") as f:
@@ -108,44 +113,57 @@ def output(self):
return LawS3Target(str(path), format=Bytes)

def run(self):
from lightray.tune import run
from ray.tune.schedulers import ASHAScheduler
from lightray.cli import cli

from train.callbacks import TraceModel
from train.cli import AframeCLI
args = ["--config", self.tune_config, "--"]
lightning_args = self.get_args()
# remove the "fit" subcommand since lightray takes care of it
lightning_args.pop(0)
args.extend(lightning_args)

args = self.get_args()

scheduler = ASHAScheduler(
max_t=self.max_epochs,
grace_period=self.min_epochs,
reduction_factor=self.reduction_factor,
)
metric_name = "valid_auroc"
objective = "max"
results = cli(args)
prefix = "s3://" if str(self.run_dir).startswith("s3://") else ""
results = run(
cli_cls=AframeCLI,
name=self.name,
scheduler=scheduler,
metric_name=metric_name,
objective=objective,
search_space=self.search_space,
num_samples=self.num_samples,
workers_per_trial=self.workers_per_trial,
gpus_per_worker=self.gpus_per_worker,
cpus_per_gpu=ray_worker().cpus_per_gpu,
storage_dir=self.run_dir,
callbacks=[TraceModel],
address=self.get_ip(),
args=args,
)

# return path to best model weights from best trial
best = results.get_best_result(
metric=metric_name, mode=objective, scope="all"
)
best = best.get_best_checkpoint(metric=metric_name, mode=objective)
best = results.get_best_result(scope="all")
best = best.get_best_checkpoint()
weights = os.path.join(prefix, best.path, "model.pt")

# copy the best weights to the output location
s3().client.copy(weights, self.output().path)


@inherits(TuneLocal, TuneRemote)
class Tune(AframeWrapperTask):
"""
Class that dynamically chooses between
remote tuning on Nautilus or local tuning on LDG.
Useful for incorporating into pipelines where
you don't care where the tuning is run.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.remote = self.validate_dirs()

def validate_dirs(self) -> bool:
# train remotely if run_dir starts with s3://

# Note: one can specify a remote data_dir but
# train locally
remote = str(self.run_dir).startswith("s3://")

if remote and not str(self.data_dir).startswith("s3://"):
raise ValueError(
"If run_dir is an s3 path, data_dir must also be an s3 path"
"Got data_dir: {self.data_dir} and run_dir: {self.run_dir}"
)
return remote

def requires(self):
if self.remote:
return TuneRemote.req(self)
else:
return TuneLocal.req(self)
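A minimal sketch of the dispatch rule validate_dirs implements, using hypothetical run directories:

    # hypothetical run directories, for illustration only
    remote_run_dir = "s3://aframe/tune-runs/my-run"  # handled by TuneRemote
    local_run_dir = "/home/user/aframe/tune-runs/my-run"  # handled by TuneLocal

    def is_remote(run_dir: str) -> bool:
        # mirrors Tune.validate_dirs: tune remotely when run_dir lives on s3
        return str(run_dir).startswith("s3://")

    assert is_remote(remote_run_dir)
    assert not is_remote(local_run_dir)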
