-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Example and benchmark of APIs to offload states (#942)
* add benchmarking for offloading states * fix api names
- Loading branch information
Showing
4 changed files
with
233 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Offloading States Example | ||
|
||
The script `offload_states.py` demonstrates how to offload the state of a model. Here is the example usage. | ||
|
||
```bash | ||
$ deepspeed --num_gpus=4 offload_states.py --hidden_dim 32768 --nlayers 4 --pin_memory --non_blocking | ||
... | ||
Memory usage (0): include=None, pin_memory=True, non_blocking=True alloc_before_offload=18198419456 alloc_after_offload=17763840 | ||
Memory usage (1): include=None, pin_memory=True, non_blocking=True alloc_before_offload=18198760960 alloc_after_offload=17763840 | ||
... | ||
Summary: pin_memory=True non_blocking=True offload=5.643414640426636 load=2.4087101459503173 | ||
``` | ||
|
||
`run_benchmark.sh` shows how to run the script with different configurations. The script outputs the time for offloading and loading the states. | ||
|
||
```bash | ||
$ ./run_benchmark.sh | ||
... | ||
| |pin_memory=0_non_blocking=0|pin_memory=0_non_blocking=1|pin_memory=1_non_blocking=0|pin_memory=1_non_blocking=1| | ||
|--:|---------------------------|---------------------------|---------------------------|---------------------------| | ||
| 1|4.34 / 3.42 |4.99 / 2.37 |6.5 / 2.42 |6.0 / 2.39 | | ||
| 2|9.9 / 3.28 |5.1 / 2.34 |6.21 / 2.42 |6.25 / 2.45 | | ||
| 3|9.92 / 3.19 |6.71 / 2.35 |6.33 / 2.38 |5.93 / 2.42 | | ||
| 4|9.55 / 2.82 |7.11 / 2.39 |6.9 / 2.38 |6.5 / 2.43 |... | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
import time | ||
import argparse | ||
|
||
import deepspeed.comm as dist | ||
from deepspeed.accelerator import get_accelerator | ||
import torch | ||
|
||
import deepspeed | ||
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum | ||
|
||
|
||
class SimpleModel(torch.nn.Module): | ||
|
||
def __init__(self, hidden_dim, empty_grad=False, nlayers=1): | ||
super(SimpleModel, self).__init__() | ||
self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)]) | ||
if empty_grad: | ||
self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) | ||
self.cross_entropy_loss = torch.nn.CrossEntropyLoss() | ||
|
||
def forward(self, x, y): | ||
for l in self.linears: | ||
x = l(x) | ||
return self.cross_entropy_loss(x, y) | ||
|
||
|
||
def random_dataset(total_samples, hidden_dim, device, dtype): | ||
train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype) | ||
train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) | ||
train_dataset = torch.utils.data.TensorDataset(train_data, train_label) | ||
return train_dataset | ||
|
||
|
||
def random_dataloader(model, total_samples, hidden_dim, device, dtype): | ||
batch_size = model.train_micro_batch_size_per_gpu() | ||
train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype) | ||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) | ||
return train_loader | ||
|
||
|
||
def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking, iteration, warmup): | ||
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) | ||
data_loader = random_dataloader(model=model, | ||
total_samples=iteration, | ||
hidden_dim=hidden_dim, | ||
device=model.device, | ||
dtype=dtype) | ||
|
||
time_offload_list = [] | ||
time_load_list = [] | ||
|
||
dist.barrier() | ||
for i, batch in enumerate(data_loader): | ||
loss = model(batch[0], batch[1]) | ||
model.backward(loss) | ||
model.step() | ||
|
||
# Start offloading | ||
alloc_before_offload = get_accelerator().memory_allocated() | ||
dist.barrier() | ||
|
||
time_start = time.time() | ||
model.offload_states(include=include, | ||
device=OffloadDeviceEnum.cpu, | ||
pin_memory=pin_memory, | ||
non_blocking=non_blocking) | ||
dist.barrier() | ||
time_after_offload = time.time() | ||
alloc_after_offload = get_accelerator().memory_allocated() | ||
assert alloc_after_offload < alloc_before_offload, f"Allocated memory should decrease after offload" | ||
|
||
# Load offloaded states back | ||
model.reload_states() | ||
dist.barrier() | ||
time_after_load = time.time() | ||
|
||
time_offload_list.append(time_after_offload - time_start) | ||
time_load_list.append(time_after_load - time_after_offload) | ||
|
||
assert alloc_after_offload < get_accelerator().memory_allocated( | ||
), f"Allocated memory should increase after offload back" | ||
|
||
if dist.get_rank() == 0: | ||
print( | ||
f"Memory usage ({i}): include={include}, pin_memory={pin_memory}, non_blocking={non_blocking} alloc_before_offload={alloc_before_offload} alloc_after_offload={alloc_after_offload}" | ||
) | ||
|
||
# remove warmup | ||
time_offload_list = time_offload_list[warmup:] | ||
time_load_list = time_load_list[warmup:] | ||
|
||
if dist.get_rank() == 0: | ||
with open("offload_states.log", "a") as f: | ||
offload_time = sum(time_offload_list) / len(time_offload_list) | ||
load_time = sum(time_load_list) / len(time_load_list) | ||
msg = f"{1 if pin_memory else 0},{1 if non_blocking else 0},{offload_time},{load_time}" | ||
f.write(f"{msg}\n") | ||
print(f"Summary: pin_memory={pin_memory} non_blocking={non_blocking} offload={offload_time} load={load_time}") | ||
|
||
# Needed in ZeRO 3. Not doing so can give memory leak | ||
model.destroy() | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description="Test Offload States") | ||
parser.add_argument("--included_state", type=str, choices=[e.name for e in OffloadStateTypeEnum] + [None], default=None, help="State to include") | ||
parser.add_argument("--pin_memory", action='store_true', help="Pin memory") | ||
parser.add_argument("--non_blocking", action='store_true', help="Non blocking") | ||
parser.add_argument("--nlayers", type=int, default=1, help="Number of layers") | ||
parser.add_argument("--hidden_dim", type=int, default=1024, help="Hidden dimension") | ||
parser.add_argument('--dtype', choices=['torch.bfloat16', 'torch.float16', 'torch.float32'], default='torch.bfloat16', help='Data type') | ||
parser.add_argument("--local_rank", type=int, default=-1, help="Local rank") | ||
parser.add_argument("--iteration", type=int, default=10, help="Warmup") | ||
parser.add_argument("--warmup", type=int, default=5, help="Warmup") | ||
|
||
args = parser.parse_args() | ||
|
||
dtype = eval(args.dtype) | ||
hidden_dim = args.hidden_dim | ||
|
||
config_dict = { | ||
"train_micro_batch_size_per_gpu": 1, | ||
"optimizer": { | ||
"type": "Adam", | ||
"params": { | ||
"lr": 1e-6 | ||
} | ||
}, | ||
"zero_optimization": { | ||
"stage": 3, | ||
}, | ||
} | ||
|
||
if dtype == torch.float16: | ||
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} | ||
elif dtype == torch.bfloat16: | ||
config_dict["bf16"] = {"enabled": True} | ||
|
||
with deepspeed.zero.Init(config_dict_or_path=config_dict): | ||
model = SimpleModel(hidden_dim, nlayers=args.nlayers) | ||
|
||
included_state = None if args.included_state is None else [OffloadStateTypeEnum[args.included_state]] | ||
run_model(model, config_dict, hidden_dim, dtype, included_state, args.pin_memory, args.non_blocking, args.iteration, args.warmup) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import pandas as pd | ||
from pytablewriter import MarkdownTableWriter | ||
|
||
|
||
def read_csv(file_path): | ||
return pd.read_csv(file_path) | ||
|
||
df = read_csv('offload_states.log') | ||
df.columns = ['pin_memory', 'non_blocking', 'offload_time', 'load_time'] | ||
|
||
df['ratio_string'] = df['offload_time'].round(2).astype(str) + " / " + df['load_time'].round(2).astype(str) | ||
|
||
result_df = pd.DataFrame({ | ||
'pin_memory=0_non_blocking=0': df[(df['pin_memory'] == 0) & (df['non_blocking'] == 0)]['ratio_string'].reset_index(drop=True), | ||
'pin_memory=0_non_blocking=1': df[(df['pin_memory'] == 0) & (df['non_blocking'] == 1)]['ratio_string'].reset_index(drop=True), | ||
'pin_memory=1_non_blocking=0': df[(df['pin_memory'] == 1) & (df['non_blocking'] == 0)]['ratio_string'].reset_index(drop=True), | ||
'pin_memory=1_non_blocking=1': df[(df['pin_memory'] == 1) & (df['non_blocking'] == 1)]['ratio_string'].reset_index(drop=True) | ||
}) | ||
result_df = result_df.dropna() | ||
result_df.index = range(1, len(result_df) + 1) | ||
result_df.index.name = 'trial' | ||
# print(result_df) | ||
|
||
writer = MarkdownTableWriter() | ||
writer.from_dataframe(result_df, | ||
add_index_column=True, | ||
) | ||
writer.write_table() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
NGPUS=4 | ||
HIDDEN_SIZE=32768 | ||
NUM_LAYERS=4 | ||
|
||
TRIALS=10 | ||
|
||
PIN_MEMORY_OPTS=(0 1) | ||
NON_BLOCKING_OPTS=(0 1) | ||
|
||
for i in $(seq 1 $TRIALS); do | ||
for PIN_MEMORY in "${PIN_MEMORY_OPTS[@]}"; do | ||
PIN_MEMORY_ARG="" | ||
if [ $PIN_MEMORY -eq 1 ]; then | ||
PIN_MEMORY_ARG="--pin_memory" | ||
fi | ||
|
||
for NON_BLOCKING in "${NON_BLOCKING_OPTS[@]}"; do | ||
NON_BLOCKING_ARG="" | ||
if [ $NON_BLOCKING -eq 1 ]; then | ||
NON_BLOCKING_ARG="--non_blocking" | ||
fi | ||
|
||
echo "Running iteration $i" | ||
deepspeed --num_gpus=$NGPUS offload_states.py --hidden_dim $HIDDEN_SIZE --nlayers $NUM_LAYERS $PIN_MEMORY_ARG $NON_BLOCKING_ARG | ||
done | ||
done | ||
done | ||
python output_table.py |