
Commit d3aabe3

Merge branch 'tohtana/stage-mii-benchmark' of github.com:microsoft/DeepSpeedExamples into tohtana/stage-mii-benchmark

tohtana committed Nov 7, 2023
2 parents e85f98a + 5261263
Showing 29 changed files with 453 additions and 129 deletions.
README.md (3 changes: 2 additions & 1 deletion)
@@ -8,7 +8,8 @@ This folder contains end-to-end applications that use DeepSpeed to train and use
There are several training and finetuning examples so please see the individual folders for specific instructions.

## 3. Inference
The DeepSpeed Huggingface inference [README](./inference/huggingface/README.md) explains how to get started with running DeepSpeed Huggingface inference examples.
- The DeepSpeed-MII inference [README](./inference/mii/README.md) explains how to get started with running model inference with [DeepSpeed-MII](https://github.com/Microsoft/DeepSpeed-MII) and [DeepSpeed-FastGen](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen).
- The DeepSpeed Huggingface inference [README](./inference/huggingface/README.md) explains how to get started with running DeepSpeed Huggingface inference examples.
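For orientation, a minimal DeepSpeed-MII generation call looks roughly like the following — a sketch based on the MII README, where the model id is illustrative and any supported Hugging Face causal LM can be substituted:

```python
import mii

# Non-persistent pipeline: loads the model in-process and serves
# generation requests directly (model id is illustrative).
pipe = mii.pipeline("mistralai/Mistral-7B-v0.1")
responses = pipe(["DeepSpeed is", "Seattle is"], max_new_tokens=64)
for r in responses:
    print(r)
```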

## 4. Compression
Model compression examples.
applications/DeepSpeed-Chat/.gitignore (137 changes: 137 additions & 0 deletions)
@@ -0,0 +1,137 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

results/
outputs/

.amltconfig
.test_output
*.hdf5
*.h5
applications/DeepSpeed-Chat/README.md (15 changes: 10 additions & 5 deletions)
@@ -33,20 +33,24 @@ A fast, affordable, scalable and open system framework for enabling end-to-end R
<!-- markdown-toc start - Don't edit this section. Run M-x markdown-toc-refresh-toc -->
## Table of Contents

- [🐕DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales🐕](#deepspeed-chat-easy-fast-and-affordable-rlhf-training-of-chatgpt-like-models-at-all-scales)
- [Table of Contents](#table-of-contents)
- [📰 Latest News 📰](#-latest-news-)
- [🚀 What is DeepSpeed Chat 🚀](#-what-is-deepspeed-chat-)
- [🚀 What is DeepSpeed Chat 🚀](#-what-is-deepspeed-chat-)
- [🧨 Capabilities 🧨](#-capabilities-)
- [☕ Quick Start ☕](#-quick-start-)
- [🐼 Installation](#-installation)
- [🐼 Single Script for Training 3-Step RLHF Pipeline](#-one-single-script-completes-all-three-stages-of-rlhf-training-and-generate-your-first-chatgpt-model)
- [🐼 One Single Script Completes All Three Steps of RLHF Training and Generate Your First ChatGPT Model](#-one-single-script-completes-all-three-steps-of-rlhf-training-and-generate-your-first-chatgpt-model)
- [🐼 Demonstration: Individual Step Fine-Tuning](#-demonstration-individual-step-fine-tuning)
- [🕐 Step 1 - Supervised Fine-Tuning](#-step-1---supervised-fine-tuning)
- [🕑 Step 2 - Reward Model](#-step-2---reward-model)
- [🕒 Step 3 - Reinforcement Learning with Human Feedback](#-step-3---reinforcement-learning-with-human-feedback)
- [🐼 Adding and using your own datasets in DeepSpeed-Chat](#-adding-and-using-your-own-datasets-in-deepspeed-chat)
- [🐼 Customizing RLHF training pipeline via DeepSpeed-Chat’s APIs](#-customizing-your-own-rlhf-training-pipeline-using-deepspeed-chats-rlhf-apis)
- [🐼 Serving Your Model: Plug-in and Test!](#-serving-plug-in-your-final-model-trained-by-deepspeed-chat-and-test-it-out)
- [🐼 Adding and using your own datasets in DeepSpeed-Chat](#-adding-and-using-your-own-datasets-in-deepspeed-chat)
- [🐼 Customizing your own RLHF training pipeline using DeepSpeed-Chat’s RLHF APIs](#-customizing-your-own-rlhf-training-pipeline-using-deepspeed-chats-rlhf-apis)
- [🐼 Serving: Plug-in your final model trained by DeepSpeed-Chat and test it out!](#-serving-plug-in-your-final-model-trained-by-deepspeed-chat-and-test-it-out)
- [🔥 Training Performance Evaluation 🔥](#-training-performance-evaluation-)
- [🐲 Superior Model Scale and Low Training Cost](#-superior-model-scale-and-low-training-cost)
- [🐲 Throughput and Model Size Scalability Comparisons with Existing RLHF Systems](#-throughput-and-model-size-scalability-comparisons-with-existing-rlhf-systems)
- [😽 Supported Models 😽](#-supported-models-)
- [🔬 Build Pipeline Status 🔬](#-build-pipeline-status-)
- [⚓ Documentation and Tutorial ⚓](#-documentation-and-tutorial-)
@@ -119,6 +123,7 @@ pip install deepspeed>=0.9.0
git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/applications/DeepSpeed-Chat/
pip install -r requirements.txt
pip install -e .
```
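The added `pip install -e .` installs the repository's new `dschat` package in editable mode, which is what lets the training scripts import it directly instead of patching `sys.path` (see the import changes further down). A quick smoke test — our illustration, not part of the README:

```python
# Should import cleanly once `pip install -e .` has run; previously the
# scripts appended the parent directory to sys.path to find these utils.
from dschat.utils.utils import print_rank_0
```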

### 🐼 One Single Script Completes All Three Steps of RLHF Training and Generate Your First ChatGPT Model
@@ -4,17 +4,12 @@
# DeepSpeed Team
import torch
import torch.nn.functional as F
import sys
import os
import time
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.accelerator import get_accelerator

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))

from utils.utils import print_rank_0
from dschat.utils.utils import print_rank_0


def print_all_ranks(tag, value, rank):
@@ -62,6 +57,11 @@ def __init__(self, rlhf_engine, args):
self.z3_enabled = args.actor_zero_stage == 3
self.compute_fp32_loss = self.args.compute_fp32_loss

# If the generated experience is not valid (too short), reuse the last valid
# generated experience. Alternatively, we could skip the step (on all workers);
# for now, reusing the last valid experience is the simpler solution.
self.last_generated_experience = None

# These values can be changed
self.kl_ctl = 0.1
self.clip_reward_value = 5
@@ -100,7 +100,8 @@ def _generate_sequence(self, prompts, mask, step):
ans = seq[:, prompt_length:]
valid_ans_len = (ans != self.tokenizer.pad_token_id).sum(dim=-1)

if self.args.print_answers:
if self.args.print_answers and (step % self.args.print_answers_interval
== 0):
print(
f"--- prompt --> step={step}, rank={torch.distributed.get_rank()}, {self.tokenizer.batch_decode(prompts, skip_special_tokens=True)}"
)
@@ -112,10 +113,24 @@ def _generate_sequence(self, prompts, mask, step):
for i in range(batch_size):
if valid_ans_len[
i] <= 1:  # if the answer is 1 token or shorter, drop it
print(
f'Dropping too short generated answer: {step=}: \n'
f'prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n'
f'answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}'
)
continue
else:
out_seq.append(seq[i:i + 1])
out_seq = torch.cat(out_seq, dim=0) # concate output in the batch dim

if not out_seq:
print(
f'All generated results are too short for rank={self.args.local_rank} step={step}\n'
f'-> prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n'
f'-> answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}'
)
return None

out_seq = torch.cat(out_seq, dim=0) # concat output in the batch dim

return out_seq

@@ -124,6 +139,12 @@ def generate_experience(self, prompts, mask, step):
generate_start = time.time()
seq = self._generate_sequence(prompts, mask, step)
generate_end = time.time()
if seq is None:
assert self.last_generated_experience is not None, f'Invalid generated experience at {step=}'
prompts = self.last_generated_experience['prompts']
seq = self.last_generated_experience['seq']
else:
self.last_generated_experience = {'prompts': prompts, 'seq': seq}
self.train()

pad_token_id = self.tokenizer.pad_token_id
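The fallback added above boils down to a cache-and-reuse pattern; here is a standalone sketch with illustrative names (not the trainer's actual API):

```python
class ExperienceCache:
    """Reuse the last valid generation when the current one is unusable."""

    def __init__(self):
        self.last_valid = None

    def get_or_fallback(self, generated):
        # None signals an invalid generation, e.g. every answer in the
        # batch was too short and was dropped.
        if generated is None:
            assert self.last_valid is not None, "no valid experience to fall back to"
            return self.last_valid
        self.last_valid = generated
        return generated
```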
@@ -9,10 +9,10 @@
from deepspeed.ops.adam import DeepSpeedCPUAdam
from transformers import AutoModelForCausalLM, get_scheduler

from utils.ds_utils import get_train_ds_config, get_eval_ds_config
from utils.module.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
from utils.model.model_utils import create_hf_model, create_critic_model
from utils.utils import get_optimizer_grouped_parameters
from dschat.utils.ds_utils import get_train_ds_config, get_eval_ds_config
from dschat.utils.module.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
from dschat.utils.model.model_utils import create_hf_model, create_critic_model
from dschat.utils.utils import get_optimizer_grouped_parameters
"""
TODOs:
* support HF models for critic (for debugging), must be a previously saved ckpt from step-2
@@ -14,7 +14,7 @@
import os
import hashlib
from itertools import chain
from . import raw_datasets
from dschat.utils.data import raw_datasets
from deepspeed.accelerator import get_accelerator


@@ -1,8 +1,9 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

import os
# DeepSpeed Team
from datasets import load_dataset
from datasets import load_dataset, load_from_disk
from torch.utils.data import Subset
import re

@@ -15,7 +16,9 @@ def __init__(self, output_path, seed, local_rank, dataset_name):
self.output_path = output_path
self.seed = seed
self.local_rank = local_rank
if not dataset_name == 'local/jsonfile':
if os.path.exists(dataset_name):
self.raw_datasets = load_from_disk(dataset_name)
elif not dataset_name == 'local/jsonfile':
self.raw_datasets = load_dataset(dataset_name)

def get_train_data(self):
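With this change, the dataset name may also be a filesystem path created by the `datasets` library's `save_to_disk`; a sketch of preparing such a local copy (the path and dataset id are illustrative):

```python
from datasets import load_dataset

ds = load_dataset("Dahoas/rm-static")  # fetch once from the Hugging Face Hub
ds.save_to_disk("/data/rm-static")     # materialize a local on-disk copy

# Passing "/data/rm-static" as the dataset name now takes the
# os.path.exists branch above and loads via load_from_disk.
```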
@@ -12,8 +12,8 @@
from huggingface_hub import snapshot_download
from transformers.deepspeed import HfDeepSpeedConfig

from .reward_model import RewardModel
from ..utils import load_state_dict_into_model
from dschat.utils.model.reward_model import RewardModel
from dschat.utils.utils import load_state_dict_into_model, print_rank_0


def configure_dropout(model_config, dropout):
@@ -41,17 +41,19 @@ def causal_lm_forward(
return_dict=None,
**deprecated_arguments,
):
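# LlamaForCausalLM.forward() does not accept head_mask, so pass it
# through only for model types that support it: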
kwargs = dict() if model.config.model_type == "llama" else dict(
head_mask=head_mask)
output = model.__original_forward__(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=None,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
return_dict=return_dict,
**kwargs)

return_dict = isinstance(output, dict)
lm_logits = output.logits if return_dict else output[0]
@@ -130,8 +132,8 @@ def create_critic_model(model_name_or_path,
critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer,
ds_config, rlhf_training, dropout)
end = time.time()
if torch.distributed.get_rank() == 0:
print(f"> Creating model from_config took {end - start} seconds")
print_rank_0(f">Creating model from_config took {end - start} seconds",
None)

critic_model = RewardModel(
critic_model,
@@ -152,8 +154,8 @@ def create_critic_model(model_name_or_path,
start = time.time()
model_ckpt_state_dict = torch.load(model_ckpt_path, map_location='cpu')
end = time.time()
if torch.distributed.get_rank() == 0:
print(f"> torch.load took {end - start} seconds")
print_rank_0(f">Creating model from_config took {end - start} seconds",
None)

# load critic model from checkpoint with zero-stage 3 compatibility
# this functionality may be moved to DS checkpoint load API in future
@@ -163,7 +165,8 @@
"",
zero_stage=zero_stage)
end = time.time()
if torch.distributed.get_rank() == 0:
print(f"> Loading model state dict took {end - start} seconds")

print_rank_0(f">Creating model from_config took {end - start} seconds",
None)

return critic_model
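For reference, `print_rank_0` replaces the repeated `torch.distributed.get_rank() == 0` guards above; a plausible minimal implementation (a sketch only — the actual helper in `dschat.utils.utils` may differ):

```python
import torch.distributed as dist

def print_rank_0(msg, rank=None):
    # Print once per job: honor an explicit rank when given; otherwise
    # consult the global rank, or print unconditionally when
    # torch.distributed is not initialized.
    if rank is not None:
        if rank <= 0:
            print(msg)
    elif not dist.is_initialized() or dist.get_rank() == 0:
        print(msg)
```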
(The diff for the remaining changed files was not loaded.)
