Skip to content

Commit

Permalink
Torch Profiling (#49)
Browse files Browse the repository at this point in the history
* Fix setting smaller run for profiling

* Make style

* Fix test_num_gpu

* Make style

* Update readme
  • Loading branch information
erogol authored Jun 3, 2022
1 parent 5217943 commit cb0237b
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 39 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,30 @@ We don't use ```.spawn()``` to initiate multi-gpu training since it causes certa
- ```.spawn()``` trains the model in subprocesses and the model in the main process is not updated.
- DataLoader with N processes gets really slow when the N is large.

## Profiling example

- Create the torch profiler as you like and pass it to the trainer.
```python
import torch
profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler/"),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
prof = trainer.profile_fit(profiler, epochs=1, small_run=64)
# Then run TensorBoard on the "./profiler/" trace directory.
```
- Run TensorBoard to inspect the collected traces.
```console
tensorboard --logdir="./profiler/"
```

## Supported Experiment Loggers
- [Tensorboard](https://www.tensorflow.org/tensorboard) - actively maintained
- [ClearML](https://clear.ml/) - actively maintained
Expand Down
42 changes: 22 additions & 20 deletions tests/test_num_gpus.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,46 @@
import os
from trainer.distribute import get_gpus
from trainer import TrainerArgs
import unittest
from argparse import Namespace
from unittest import TestCase, mock

from trainer import TrainerArgs
from trainer.distribute import get_gpus


class TestGpusStringParsingMethods(TestCase):

@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
def test_parse_gpus_set_in_env_var_and_args(self):
parsed_args = create_args_parser().parse_args(['--gpus', '0,1'])
gpus = get_gpus(parsed_args)
expected_value = ['0']
args = Namespace(gpus="0,1")
gpus = get_gpus(args)
expected_value = ["0"]
self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value))

@mock.patch.dict(os.environ, {})
def test_parse_gpus_set_in_args(self):
parsed_args = create_args_parser().parse_args(['--gpus', '0,1'])
gpus = get_gpus(parsed_args)
expected_value = ['0', '1']
args = Namespace(gpus="0,1")
gpus = get_gpus(args)
expected_value = ["0", "1"]
self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value))

@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
def test_parse_gpus_set_in_env_var(self):
parsed_args = create_args_parser().parse_args(None)
gpus = get_gpus(parsed_args)
expected_value = ['0', '1']
args = Namespace()
gpus = get_gpus(args)
expected_value = ["0", "1"]
self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value))

@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0, 1 "})
def test_parse_gpus_set_in_env_var_with_spaces(self):
parsed_args = create_args_parser().parse_args(None)
gpus = get_gpus(parsed_args)
expected_value = ['0', '1']
args = Namespace()
gpus = get_gpus(args)
expected_value = ["0", "1"]
self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value))

@mock.patch.dict(os.environ, {})
def test_parse_gpus_set_in_args_with_spaces(self):
parsed_args = create_args_parser().parse_args(['--gpus', '0, 1, 2, 3 '])
gpus = get_gpus(parsed_args)
expected_value = ['0', '1', '2', '3']
args = Namespace(gpus="0, 1, 2, 3 ")
gpus = get_gpus(args)
expected_value = ["0", "1", "2", "3"]
self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value))


Expand All @@ -52,5 +54,5 @@ def create_args_parser():
return parser


if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
unittest.main()
4 changes: 1 addition & 3 deletions trainer/distribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import subprocess
import time

import torch

from trainer import TrainerArgs, logger


Expand Down Expand Up @@ -64,7 +62,7 @@ def distribute():
def get_gpus(args):
# set active gpus from CUDA_VISIBLE_DEVICES or --gpus
if "CUDA_VISIBLE_DEVICES" in os.environ:
gpus = os.environ['CUDA_VISIBLE_DEVICES']
gpus = os.environ["CUDA_VISIBLE_DEVICES"]
else:
gpus = args.gpus
gpus = list(map(str.strip, gpus.split(",")))
Expand Down
38 changes: 22 additions & 16 deletions trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,11 +429,7 @@ def __init__( # pylint: disable=dangerous-default-value
self.test_samples = None

# only use a subset of the samples if small_run is set
if args.small_run is not None:
print(f"[!] Small Run, only using {args.small_run} samples.")
self.train_samples = None if self.train_samples is None else self.train_samples[: args.small_run]
self.eval_samples = None if self.eval_samples is None else self.eval_samples[: args.small_run]
self.test_samples = None if self.test_samples is None else self.test_samples[: args.small_run]
self.setup_small_run(args.small_run)

# init the model
if model is None and get_model is None:
Expand Down Expand Up @@ -547,6 +543,14 @@ def init_loggers(args: "Coqpit", config: "Coqpit", output_path: str, dashboard_l
dashboard_logger = logger_factory(config, output_path)
return dashboard_logger, c_logger

def setup_small_run(self, small_run: int = None):
"""Use a subset of samples for training, evaluation and testing."""
if small_run is not None:
logger.info("[!] Small Run, only using %i samples.", small_run)
self.train_samples = None if self.train_samples is None else self.train_samples[:small_run]
self.eval_samples = None if self.eval_samples is None else self.eval_samples[:small_run]
self.test_samples = None if self.test_samples is None else self.test_samples[:small_run]

def init_training(
self, args: TrainerArgs, coqpit_overrides: Dict, config: Coqpit = None
): # pylint: disable=no-self-use
Expand Down Expand Up @@ -1520,15 +1524,15 @@ def profile_fit(self, torch_profiler, epochs=None, small_run=None):
>>> import torch
>>> profiler = torch.profiler.profile(
>>> activities=[
>>> torch.profiler.ProfilerActivity.CPU,
>>> torch.profiler.ProfilerActivity.CUDA,
>>> ],
>>> schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
>>> on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler/"),
>>> record_shapes=True,
>>> profile_memory=True,
>>> with_stack=True,
>>> activities=[
>>> torch.profiler.ProfilerActivity.CPU,
>>> torch.profiler.ProfilerActivity.CUDA,
>>> ],
>>> schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
>>> on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler/"),
>>> record_shapes=True,
>>> profile_memory=True,
>>> with_stack=True,
>>> )
>>> prof = trainer.profile_fit(profiler, epochs=1, small_run=64)
"""
Expand All @@ -1538,13 +1542,15 @@ def profile_fit(self, torch_profiler, epochs=None, small_run=None):
self.config.epocshs = epochs
# use a smaller set of training samples for profiling
if small_run:
self.config.small_run = small_run
self.setup_small_run(small_run)
# run profiler
self.config.run_eval = False
self.config.test_delay_epochs = 9999999
self.config.epochs = epochs
# set a callback to progress the profiler
self.callbacks_on_train_step_end = [lambda trainer: trainer.torch_profiler.step()] # pylint: disable=attribute-defined-outside-init
self.callbacks_on_train_step_end = [ # pylint: disable=attribute-defined-outside-init
lambda trainer: trainer.torch_profiler.step()
]
# set the profiler to access in the Trainer
self.torch_profiler = torch_profiler # pylint: disable=attribute-defined-outside-init
# set logger output for Tensorboard
Expand Down

0 comments on commit cb0237b

Please sign in to comment.