diff --git a/README.md b/README.md index 014f5683f..da961c0ee 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,8 @@ This folder contains end-to-end applications that use DeepSpeed to train and use There are several training and finetuning examples so please see the individual folders for specific instructions. ## 3. Inference -The DeepSpeed Huggingface inference [README](./inference/huggingface/README.md) explains how to get started with running DeepSpeed Huggingface inference examples. +- The DeepSpeed-MII inference [README](./inference/mii/README.md) explains how to get started with running model inference with [DeepSpeed-MII](https://github.com/Microsoft/DeepSpeed-MII) and [DeepSpeed-FastGen](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen). +- The DeepSpeed Huggingface inference [README](./inference/huggingface/README.md) explains how to get started with running DeepSpeed Huggingface inference examples. ## 4. Compression Model compression examples. diff --git a/applications/DeepSpeed-Chat/.gitignore b/applications/DeepSpeed-Chat/.gitignore new file mode 100644 index 000000000..cbc922f8f --- /dev/null +++ b/applications/DeepSpeed-Chat/.gitignore @@ -0,0 +1,137 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +results/ +outputs/ + +.amltconfig +.test_output +*.hdf5 +*.h5 diff --git a/applications/DeepSpeed-Chat/README.md b/applications/DeepSpeed-Chat/README.md index c0a456823..ac9f3ab6a 100644 --- a/applications/DeepSpeed-Chat/README.md +++ b/applications/DeepSpeed-Chat/README.md @@ -33,20 +33,24 @@ A fast, affordable, scalable and open system framework for enabling end-to-end R ## Table of Contents +- [πŸ•DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All ScalesπŸ•](#deepspeed-chat-easy-fast-and-affordable-rlhf-training-of-chatgpt-like-models-at-all-scales) +- [Table of Contents](#table-of-contents) - [πŸ“° Latest News πŸ“°](#-latest-news-) -- [πŸš€ What is DeepSpeed Chat πŸš€οΈ](#-what-is-deepspeed-chat-) +- [πŸš€ What is DeepSpeed Chat πŸš€](#-what-is-deepspeed-chat-) - [🧨 Capabilities 🧨](#-capabilities-) - [β˜• Quick Start β˜•](#-quick-start-) - [🐼 Installation](#-installation) - - [🐼 Single Script for Training 3-Step RLHF Pipeline](#-one-single-script-completes-all-three-stages-of-rlhf-training-and-generate-your-first-chatgpt-model) + - [🐼 One Single Script Completes All Three Steps of RLHF Training and Generate Your First ChatGPT Model](#-one-single-script-completes-all-three-steps-of-rlhf-training-and-generate-your-first-chatgpt-model) - [🐼 Demonstration: Individual Step Fine-Tuning](#-demonstration-individual-step-fine-tuning) - [πŸ• Step 1 - Supervised Fine-Tuning](#-step-1---supervised-fine-tuning) - [πŸ•‘ Step 2 - Reward Model](#-step-2---reward-model) - [πŸ•’ Step 3 - Reinforcement Learning with Human Feedback](#-step-3---reinforcement-learning-with-human-feedback) - - [🐼 Adding and using your own datasets in DeepSpeed-Chat](#-adding-and-using-your-own-datasets-in-deepspeed-chat) - - [🐼 Customizing RLHF training pipeline via DeepSpeed-Chat’s APIs](#-customizing-your-own-rlhf-training-pipeline-using-deepspeed-chats-rlhf-apis) - - [🐼 Serving Your Model: Plug-in and Test!](#-serving-plug-in-your-final-model-trained-by-deepspeed-chat-and-test-it-out) + - [🐼 Adding and using your own datasets in DeepSpeed-Chat](#-adding-and-using-your-own-datasets-in-deepspeed-chat) + - [🐼 Customizing your own RLHF training pipeline using DeepSpeed-Chat’s RLHF APIs](#-customizing-your-own-rlhf-training-pipeline-using-deepspeed-chats-rlhf-apis) + - [🐼 Serving: Plug-in your final model trained by DeepSpeed-Chat and test it out!](#-serving-plug-in-your-final-model-trained-by-deepspeed-chat-and-test-it-out) - [πŸ”₯ Training Performance Evaluation πŸ”₯](#-training-performance-evaluation-) + - [🐲 Superior Model Scale and Low Training Cost](#-superior-model-scale-and-low-training-cost) + - [🐲 Throughput and Model Size Scalability Comparisons with Existing RLHF Systems](#-throughput-and-model-size-scalability-comparisons-with-existing-rlhf-systems) - [😽 Supported Models 😽](#-supported-models-) - [πŸ”¬ Build Pipeline Status πŸ”¬](#-build-pipeline-status-) - [βš“ Documentation and Tutorial βš“](#-documentation-and-tutorial-) @@ -119,6 +123,7 @@ pip install deepspeed>=0.9.0 git clone https://github.com/microsoft/DeepSpeedExamples.git cd DeepSpeedExamples/applications/DeepSpeed-Chat/ pip install -r requirements.txt +pip install -e . ``` ### 🐼 One Single Script Completes All Three Steps of RLHF Training and Generate Your First ChatGPT Model diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py b/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py similarity index 89% rename from applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py rename to applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py index c79d644ef..22cba6be0 100644 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py +++ b/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py @@ -4,17 +4,12 @@ # DeepSpeed Team import torch import torch.nn.functional as F -import sys -import os import time import deepspeed from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from deepspeed.accelerator import get_accelerator -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) - -from utils.utils import print_rank_0 +from dschat.utils.utils import print_rank_0 def print_all_ranks(tag, value, rank): @@ -62,6 +57,11 @@ def __init__(self, rlhf_engine, args): self.z3_enabled = args.actor_zero_stage == 3 self.compute_fp32_loss = self.args.compute_fp32_loss + # In case the generated experience is not valid (too short), we use the last valid + # generated experience. Alternatively, we can skip the step (on all workers). + # For now, use the last valid experience which is a simpler solution + self.last_generated_experience = None + # Those value can be changed self.kl_ctl = 0.1 self.clip_reward_value = 5 @@ -100,7 +100,8 @@ def _generate_sequence(self, prompts, mask, step): ans = seq[:, prompt_length:] valid_ans_len = (ans != self.tokenizer.pad_token_id).sum(dim=-1) - if self.args.print_answers: + if self.args.print_answers and (step % self.args.print_answers_interval + == 0): print( f"--- prompt --> step={step}, rank={torch.distributed.get_rank()}, {self.tokenizer.batch_decode(prompts, skip_special_tokens=True)}" ) @@ -112,10 +113,24 @@ def _generate_sequence(self, prompts, mask, step): for i in range(batch_size): if valid_ans_len[ i] <= 1: # if the answer is shorter than 1 token, drop it + print( + f'Dropping too short generated answer: {step=}: \n' + f'prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n' + f'answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}' + ) continue else: out_seq.append(seq[i:i + 1]) - out_seq = torch.cat(out_seq, dim=0) # concate output in the batch dim + + if not out_seq: + print( + f'All generated results are too short for rank={self.args.local_rank} step={step}\n' + f'-> prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n' + f'-> answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}' + ) + return None + + out_seq = torch.cat(out_seq, dim=0) # concat output in the batch dim return out_seq @@ -124,6 +139,12 @@ def generate_experience(self, prompts, mask, step): generate_start = time.time() seq = self._generate_sequence(prompts, mask, step) generate_end = time.time() + if seq is None: + assert self.last_generated_experience is not None, f'Invalid generated experience at {step=}' + prompts = self.last_generated_experience['prompts'] + seq = self.last_generated_experience['seq'] + else: + self.last_generated_experience = {'prompts': prompts, 'seq': seq} self.train() pad_token_id = self.tokenizer.pad_token_id diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py similarity index 97% rename from applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py rename to applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py index 3a192d017..5b6778cc2 100755 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py +++ b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py @@ -9,10 +9,10 @@ from deepspeed.ops.adam import DeepSpeedCPUAdam from transformers import AutoModelForCausalLM, get_scheduler -from utils.ds_utils import get_train_ds_config, get_eval_ds_config -from utils.module.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible -from utils.model.model_utils import create_hf_model, create_critic_model -from utils.utils import get_optimizer_grouped_parameters +from dschat.utils.ds_utils import get_train_ds_config, get_eval_ds_config +from dschat.utils.module.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible +from dschat.utils.model.model_utils import create_hf_model, create_critic_model +from dschat.utils.utils import get_optimizer_grouped_parameters """ TODOs: * support HF models for critic (for debugging), must be a previously saved ckpt from step-2 diff --git a/applications/DeepSpeed-Chat/training/utils/data/data_utils.py b/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py similarity index 99% rename from applications/DeepSpeed-Chat/training/utils/data/data_utils.py rename to applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py index 277ebd730..744ad7927 100644 --- a/applications/DeepSpeed-Chat/training/utils/data/data_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py @@ -14,7 +14,7 @@ import os import hashlib from itertools import chain -from . import raw_datasets +from dschat.utils.data import raw_datasets from deepspeed.accelerator import get_accelerator diff --git a/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py b/applications/DeepSpeed-Chat/dschat/utils/data/raw_datasets.py similarity index 99% rename from applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py rename to applications/DeepSpeed-Chat/dschat/utils/data/raw_datasets.py index 3c84f4b07..2838f9dc0 100644 --- a/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py +++ b/applications/DeepSpeed-Chat/dschat/utils/data/raw_datasets.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # SPDX-License-Identifier: Apache-2.0 +import os # DeepSpeed Team -from datasets import load_dataset +from datasets import load_dataset, load_from_disk from torch.utils.data import Subset import re @@ -15,7 +16,9 @@ def __init__(self, output_path, seed, local_rank, dataset_name): self.output_path = output_path self.seed = seed self.local_rank = local_rank - if not dataset_name == 'local/jsonfile': + if os.path.exists(dataset_name): + self.raw_datasets = load_from_disk(dataset_name) + elif not dataset_name == 'local/jsonfile': self.raw_datasets = load_dataset(dataset_name) def get_train_data(self): diff --git a/applications/DeepSpeed-Chat/training/utils/ds_utils.py b/applications/DeepSpeed-Chat/dschat/utils/ds_utils.py similarity index 100% rename from applications/DeepSpeed-Chat/training/utils/ds_utils.py rename to applications/DeepSpeed-Chat/dschat/utils/ds_utils.py diff --git a/applications/DeepSpeed-Chat/training/utils/model/model_utils.py b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py similarity index 90% rename from applications/DeepSpeed-Chat/training/utils/model/model_utils.py rename to applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py index 508b58317..97d3bff15 100644 --- a/applications/DeepSpeed-Chat/training/utils/model/model_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py @@ -12,8 +12,8 @@ from huggingface_hub import snapshot_download from transformers.deepspeed import HfDeepSpeedConfig -from .reward_model import RewardModel -from ..utils import load_state_dict_into_model +from dschat.utils.model.reward_model import RewardModel +from dschat.utils.utils import load_state_dict_into_model, print_rank_0 def configure_dropout(model_config, dropout): @@ -41,17 +41,19 @@ def causal_lm_forward( return_dict=None, **deprecated_arguments, ): + kwargs = dict() if model.config.model_type == "llama" else dict( + head_mask=head_mask) output = model.__original_forward__( input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, - head_mask=head_mask, inputs_embeds=inputs_embeds, labels=None, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict) + return_dict=return_dict, + **kwargs) return_dict = isinstance(output, dict) lm_logits = output.logits if return_dict else output[0] @@ -130,8 +132,8 @@ def create_critic_model(model_name_or_path, critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer, ds_config, rlhf_training, dropout) end = time.time() - if torch.distributed.get_rank() == 0: - print(f"> Creating model from_config took {end - start} seconds") + print_rank_0(f">Creating model from_config took {end - start} seconds", + None) critic_model = RewardModel( critic_model, @@ -152,8 +154,8 @@ def create_critic_model(model_name_or_path, start = time.time() model_ckpt_state_dict = torch.load(model_ckpt_path, map_location='cpu') end = time.time() - if torch.distributed.get_rank() == 0: - print(f"> torch.load took {end - start} seconds") + print_rank_0(f">Creating model from_config took {end - start} seconds", + None) # load critic model from checkpoint with zero-stage 3 compatibility # this functionality may be moved to DS checkpoint load API in future @@ -163,7 +165,8 @@ def create_critic_model(model_name_or_path, "", zero_stage=zero_stage) end = time.time() - if torch.distributed.get_rank() == 0: - print(f"> Loading model state dict took {end - start} seconds") + + print_rank_0(f">Creating model from_config took {end - start} seconds", + None) return critic_model diff --git a/applications/DeepSpeed-Chat/training/utils/model/reward_model.py b/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py similarity index 96% rename from applications/DeepSpeed-Chat/training/utils/model/reward_model.py rename to applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py index f11d8787a..60d063b18 100644 --- a/applications/DeepSpeed-Chat/training/utils/model/reward_model.py +++ b/applications/DeepSpeed-Chat/dschat/utils/model/reward_model.py @@ -29,15 +29,15 @@ def __init__(self, self.config.n_embd = self.config.hidden_size if hasattr( self.config, "hidden_size") else self.config.n_embd self.v_head = nn.Linear(self.config.n_embd, 1, bias=False) - self.rwtranrsformer = base_model + self.rwtransformer = base_model self.PAD_ID = tokenizer.pad_token_id self.compute_fp32_loss = compute_fp32_loss def gradient_checkpointing_enable(self): - self.rwtranrsformer.gradient_checkpointing_enable() + self.rwtransformer.gradient_checkpointing_enable() def gradient_checkpointing_disable(self): - self.rwtranrsformer.gradient_checkpointing_disable() + self.rwtransformer.gradient_checkpointing_disable() def forward(self, input_ids=None, @@ -54,7 +54,7 @@ def forward(self, else: kwargs = dict(head_mask=head_mask) - transformer_outputs = self.rwtranrsformer( + transformer_outputs = self.rwtransformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -140,7 +140,7 @@ def forward_value(self, else: kwargs = dict(head_mask=head_mask) - transformer_outputs = self.rwtranrsformer( + transformer_outputs = self.rwtransformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, diff --git a/applications/DeepSpeed-Chat/training/utils/module/lora.py b/applications/DeepSpeed-Chat/dschat/utils/module/lora.py similarity index 98% rename from applications/DeepSpeed-Chat/training/utils/module/lora.py rename to applications/DeepSpeed-Chat/dschat/utils/module/lora.py index cd37e6496..32c9730b6 100644 --- a/applications/DeepSpeed-Chat/training/utils/module/lora.py +++ b/applications/DeepSpeed-Chat/dschat/utils/module/lora.py @@ -131,10 +131,10 @@ def convert_lora_to_linear_layer(model): return model -def only_optimize_lora_parameters(model): +def only_optimize_lora_parameters(model, force_optimize_params=[]): # turn off the gradient of all the parameters except the LoRA parameters for name, param in model.named_parameters(): - if "lora_right_weight" in name or "lora_left_weight" in name: + if "lora_right_weight" in name or "lora_left_weight" in name or name in force_optimize_params: param.requires_grad = True else: param.requires_grad = False diff --git a/applications/DeepSpeed-Chat/training/utils/perf.py b/applications/DeepSpeed-Chat/dschat/utils/perf.py similarity index 100% rename from applications/DeepSpeed-Chat/training/utils/perf.py rename to applications/DeepSpeed-Chat/dschat/utils/perf.py diff --git a/applications/DeepSpeed-Chat/training/utils/utils.py b/applications/DeepSpeed-Chat/dschat/utils/utils.py similarity index 86% rename from applications/DeepSpeed-Chat/training/utils/utils.py rename to applications/DeepSpeed-Chat/dschat/utils/utils.py index b5cfb8d6d..e4dc7d036 100644 --- a/applications/DeepSpeed-Chat/training/utils/utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/utils.py @@ -14,9 +14,22 @@ import torch.nn as nn -def print_rank_0(msg, rank=0): - if rank <= 0: +def print_rank_0(msg, rank=None): + if rank is not None and rank <= 0: print(msg) + elif is_rank_0(): + print(msg) + + +def is_rank_0(): + """Check whether it is rank 0.""" + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + return True + else: + return False + else: + return True def to_device(batch, device): @@ -44,6 +57,21 @@ def update(self, num): return self.mean +class ExponentialMovingAverage: + + def __init__(self, alpha=0.9): + self.alpha = alpha + self.ema = None + + def update(self, num): + prev_ema = num if self.ema is None else self.ema + self.ema = self.alpha * prev_ema + (1.0 - self.alpha) * num + return self.ema + + def get(self): + return self.ema if self.ema is not None else 0. + + def get_tokenizer(model_name_or_path, fast_tokenizer=True): if "llama" in model_name_or_path: from transformers.models.llama import LlamaTokenizer @@ -63,7 +91,9 @@ def get_tokenizer(model_name_or_path, fast_tokenizer=True): return tokenizer -def load_hf_tokenizer(model_name_or_path, fast_tokenizer=True): +def load_hf_tokenizer(model_name_or_path, + fast_tokenizer=True, + add_special_tokens=None): if os.path.exists(model_name_or_path): # Locally tokenizer loading has some issue, so we need to force download model_json = os.path.join(model_name_or_path, "config.json") @@ -77,6 +107,12 @@ def load_hf_tokenizer(model_name_or_path, fast_tokenizer=True): tokenizer = get_tokenizer(model_name_or_path, fast_tokenizer=fast_tokenizer) + if add_special_tokens is not None: + add_special_tokens = [add_special_tokens] if isinstance(add_special_tokens, str) \ + else add_special_tokens + tokenizer.add_special_tokens( + {'additional_special_tokens': add_special_tokens}) + return tokenizer @@ -174,15 +210,18 @@ def get_optimizer_grouped_parameters( model, weight_decay, lora_lr=5e-4, - no_decay_name_list=["bias", "LayerNorm.weight"], + no_decay_name_list=[ + "bias", "layer_norm.weight", "layernorm.weight", "norm.weight", + "ln_f.weight" + ], lora_name_list=["lora_right_weight", "lora_left_weight"], ): optimizer_grouped_parameters = [ { "params": [ p for n, p in model.named_parameters() - if (not any(nd in n for nd in no_decay_name_list) - and p.requires_grad and not any(nd in n + if (not any(nd in n.lower() for nd in no_decay_name_list) + and p.requires_grad and not any(nd in n.lower() for nd in lora_name_list)) ], "weight_decay": @@ -191,8 +230,8 @@ def get_optimizer_grouped_parameters( { "params": [ p for n, p in model.named_parameters() - if (not any(nd in n for nd in no_decay_name_list) - and p.requires_grad and any(nd in n + if (not any(nd in n.lower() for nd in no_decay_name_list) + and p.requires_grad and any(nd in n.lower() for nd in lora_name_list)) ], "weight_decay": @@ -203,7 +242,7 @@ def get_optimizer_grouped_parameters( { "params": [ p for n, p in model.named_parameters() - if (any(nd in n + if (any(nd in n.lower() for nd in no_decay_name_list) and p.requires_grad) ], "weight_decay": diff --git a/applications/DeepSpeed-Chat/train.py b/applications/DeepSpeed-Chat/e2e_rlhf.py similarity index 100% rename from applications/DeepSpeed-Chat/train.py rename to applications/DeepSpeed-Chat/e2e_rlhf.py diff --git a/applications/DeepSpeed-Chat/setup.py b/applications/DeepSpeed-Chat/setup.py new file mode 100644 index 000000000..01a1ed83f --- /dev/null +++ b/applications/DeepSpeed-Chat/setup.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# setup.py: install script for deepspeed_chat +""" +to install deepspeed_chat and its dependencies for development work, +run this cmd from the root directory: + pip install -e . +""" +import setuptools + +setuptools.setup( + name="deepspeed-chat", + version="0.1", + url= + "https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat", + include_package_data=True, + packages=setuptools.find_packages(include=['dschat']), + install_requires=[ + "datasets>=2.8.0", "sentencepiece>=0.1.97", "protobuf==3.20.3", + "accelerate>=0.15.0", "torch>=1.12.0", "deepspeed>=0.9.2", + "transformers>=4.31.0,!=4.33.2", "tensorboard" + ], + extras_require={ + "azureml": [ + "azure-ml-component", + "azureml-core", + ], + }) diff --git a/applications/DeepSpeed-Chat/training/tests/test_training.py b/applications/DeepSpeed-Chat/tests/test_training.py similarity index 97% rename from applications/DeepSpeed-Chat/training/tests/test_training.py rename to applications/DeepSpeed-Chat/tests/test_training.py index 3be4f6ff6..7ffe02972 100644 --- a/applications/DeepSpeed-Chat/training/tests/test_training.py +++ b/applications/DeepSpeed-Chat/tests/test_training.py @@ -66,7 +66,7 @@ def test_ds_chat(zero_stage, hybrid_engine, offload, lora): # cd into execution dir wd = os.getcwd() - os.chdir("../step3_rlhf_finetuning") + os.chdir("../training/step3_rlhf_finetuning") sweep_script = "training_scripts/opt/single_node/sweep/run_single.sh" # Run bash script @@ -85,3 +85,5 @@ def test_ds_chat(zero_stage, hybrid_engine, offload, lora): ), "Actor model was not saved during step 3 training." assert file_exists(f"{output_path}/critic/", "pytorch_model.bin" ), "Critic model was not saved during step 3 training." + + os.chdir(wd) diff --git a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py index 6dd91ca5d..c37d1f4cd 100755 --- a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py @@ -4,9 +4,7 @@ # DeepSpeed Team import argparse -import os import math -import sys import torch from torch.utils.data import DataLoader, RandomSampler, SequentialSampler @@ -23,14 +21,12 @@ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam from deepspeed import get_accelerator -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) -from utils.data.data_utils import create_prompt_dataset -from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer -from utils.ds_utils import get_train_ds_config -from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible -from utils.model.model_utils import create_hf_model, causal_lm_model_to_fp32_loss -from utils.perf import print_throughput +from dschat.utils.data.data_utils import create_prompt_dataset +from dschat.utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer +from dschat.utils.ds_utils import get_train_ds_config +from dschat.utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible +from dschat.utils.model.model_utils import create_hf_model, causal_lm_model_to_fp32_loss +from dschat.utils.perf import print_throughput def parse_args(): @@ -191,6 +187,11 @@ def parse_args(): parser.add_argument('--tensorboard_path', type=str, default="step1_tensorboard") + ## Tokenizer + parser.add_argument( + "--add_eot_token", + action='store_true', + help="Add <|endoftext|> as additional special token to tokenizer") ## Print loss parser.add_argument('--print_loss', action='store_true', @@ -233,7 +234,12 @@ def main(): torch.distributed.barrier() # load_hf_tokenizer will get the correct tokenizer and set padding tokens based on the model family - tokenizer = load_hf_tokenizer(args.model_name_or_path, fast_tokenizer=True) + args.end_of_conversation_token = "<|endoftext|>" + additional_special_tokens = args.end_of_conversation_token if args.add_eot_token else None + tokenizer = load_hf_tokenizer(args.model_name_or_path, + fast_tokenizer=True, + add_special_tokens=additional_special_tokens) + model = create_hf_model(AutoModelForCausalLM, args.model_name_or_path, tokenizer, @@ -293,14 +299,14 @@ def evaluation(model, eval_dataloader): losses += loss.float() losses = losses / (step + 1) try: - perplexity = torch.exp(losses) - except OverflowError: - perplexity = float("inf") - try: - perplexity = get_all_reduce_mean(perplexity).item() + losses = get_all_reduce_mean(losses) except: pass - return perplexity + try: + perplexity = torch.exp(losses).item() + except OverflowError: + perplexity = float("inf") + return perplexity, losses.item() # Split weights in two groups, one with weight decay and the other not. optimizer_grouped_parameters = get_optimizer_grouped_parameters( @@ -336,8 +342,8 @@ def evaluation(model, eval_dataloader): print_rank_0( f"***** Evaluating perplexity, Epoch {0}/{args.num_train_epochs} *****", args.global_rank) - perplexity = evaluation(model, eval_dataloader) - print_rank_0(f"ppl: {perplexity}", args.global_rank) + perplexity, eval_loss = evaluation(model, eval_dataloader) + print_rank_0(f"ppl: {perplexity}, loss: {eval_loss}", args.global_rank) for epoch in range(args.num_train_epochs): print_rank_0( @@ -365,8 +371,8 @@ def evaluation(model, eval_dataloader): print_rank_0( f"***** Evaluating perplexity, Epoch {epoch+1}/{args.num_train_epochs} *****", args.global_rank) - perplexity = evaluation(model, eval_dataloader) - print_rank_0(f"ppl: {perplexity}", args.global_rank) + perplexity, eval_loss = evaluation(model, eval_dataloader) + print_rank_0(f"ppl: {perplexity}, loss: {eval_loss}", args.global_rank) model.tput_timer.update_epoch_count() if args.output_dir is not None: diff --git a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/prompt_eval.py b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/prompt_eval.py index bcdc4be18..a25b0edea 100644 --- a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/prompt_eval.py +++ b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/prompt_eval.py @@ -5,16 +5,12 @@ import argparse import logging import torch -import sys -import os from transformers import ( AutoModelForCausalLM, ) -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) -from utils.model.model_utils import create_hf_model -from utils.utils import load_hf_tokenizer +from dschat.utils.model.model_utils import create_hf_model +from dschat.utils.utils import load_hf_tokenizer from deepspeed import get_accelerator logger = logging.getLogger(__name__) @@ -74,6 +70,10 @@ def parse_args(): type=str, default="English", choices=["English", "Chinese", "Japanese"]) + parser.add_argument( + "--add_eot_token", + action='store_true', + help="Add <|endoftext|> as additional special token to tokenizer") args = parser.parse_args() @@ -197,8 +197,11 @@ def main(): device = torch.device(get_accelerator().device_name(0)) + args.end_of_conversation_token = "<|endoftext|>" + additional_special_tokens = args.end_of_conversation_token if args.add_eot_token else None tokenizer = load_hf_tokenizer(args.model_name_or_path_baseline, - fast_tokenizer=True) + fast_tokenizer=True, + add_special_tokens=additional_special_tokens) model_baseline = create_hf_model(AutoModelForCausalLM, args.model_name_or_path_baseline, diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py index d3352ce3d..265c1caf4 100644 --- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py @@ -4,9 +4,7 @@ # DeepSpeed Team import argparse -import os import math -import sys import torch from torch.utils.data import DataLoader, RandomSampler, SequentialSampler @@ -21,13 +19,11 @@ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam from deepspeed.accelerator import get_accelerator -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) -from utils.model.model_utils import create_critic_model -from utils.data.data_utils import create_prompt_dataset, DataCollatorReward -from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer -from utils.ds_utils import get_train_ds_config -from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible +from dschat.utils.model.model_utils import create_critic_model +from dschat.utils.data.data_utils import create_prompt_dataset, DataCollatorReward +from dschat.utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer +from dschat.utils.ds_utils import get_train_ds_config +from dschat.utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible def parse_args(): @@ -201,6 +197,11 @@ def parse_args(): parser.add_argument('--tensorboard_path', type=str, default="step2_tensorboard") + ## Tokenizer + parser.add_argument( + "--add_eot_token", + action='store_true', + help="Add <|endoftext|> as additional special token to tokenizer") parser = deepspeed.add_config_arguments(parser) args = parser.parse_args() @@ -238,7 +239,11 @@ def main(): torch.distributed.barrier() # load_hf_tokenizer will get the correct tokenizer and set padding tokens based on the model family - tokenizer = load_hf_tokenizer(args.model_name_or_path, fast_tokenizer=True) + args.end_of_conversation_token = "<|endoftext|>" + additional_special_tokens = args.end_of_conversation_token if args.add_eot_token else None + tokenizer = load_hf_tokenizer(args.model_name_or_path, + fast_tokenizer=True, + add_special_tokens=additional_special_tokens) rm_model = create_critic_model(args.model_name_or_path, tokenizer, ds_config, @@ -247,12 +252,25 @@ def main(): zero_stage=args.zero_stage, compute_fp32_loss=args.compute_fp32_loss) + # Model bigscience/bloom-560m has large variance at ln_f.weight parameter + # This makes bf16 finetuning hard. + # In general, since we are replacing the model head, it makes sense to reset + # the LN that precedes it. + force_optimize_params = [] + if "bigscience/bloom-" in args.model_name_or_path: + torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight) + torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias) + force_optimize_params.extend( + ['rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias']) + if args.lora_dim > 0: rm_model = convert_linear_layer_to_lora(rm_model, args.lora_module_name, args.lora_dim) if args.only_optimize_lora: - rm_model = only_optimize_lora_parameters(rm_model) + force_optimize_params.append('v_head.weight') + rm_model = only_optimize_lora_parameters(rm_model, + force_optimize_params) rm_model = make_model_gradient_checkpointing_compatible(rm_model) train_phase = 2 diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/rw_eval.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/rw_eval.py index 7df1af6c2..23f9a66af 100644 --- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/rw_eval.py +++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/rw_eval.py @@ -4,16 +4,10 @@ # DeepSpeed Team import argparse -import os import torch -import sys - -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) -from utils.model.model_utils import create_critic_model -from utils.utils import to_device -from utils.utils import load_hf_tokenizer +from dschat.utils.model.model_utils import create_critic_model +from dschat.utils.utils import to_device, load_hf_tokenizer from deepspeed import get_accelerator @@ -35,13 +29,20 @@ def parse_args(): "OPT model has a fixed number (1) of padding tokens at the beginning of the input. " "We did not see this in other models but keep it as an option for now.", ) + parser.add_argument( + "--add_eot_token", + action='store_true', + help="Add <|endoftext|> as additional special token to tokenizer") args = parser.parse_args() return args -def load_stuff(model_name_or_path, num_padding_at_beginning): +def load_stuff(model_name_or_path, num_padding_at_beginning, + additional_special_tokens): - tokenizer = load_hf_tokenizer(model_name_or_path, fast_tokenizer=True) + tokenizer = load_hf_tokenizer(model_name_or_path, + fast_tokenizer=True, + add_special_tokens=additional_special_tokens) tokenizer.pad_token = tokenizer.eos_token model = create_critic_model(model_name_or_path, tokenizer, @@ -106,8 +107,12 @@ def run_pair_comparison(): device = torch.device(get_accelerator().device_name(0)) + args.end_of_conversation_token = "<|endoftext|>" + additional_special_tokens = args.end_of_conversation_token if args.add_eot_token else None + rm_model, tokenizer = load_stuff(args.model_name_or_path, - args.num_padding_at_beginning) + args.num_padding_at_beginning, + additional_special_tokens) rm_model.to(device) rm_model.eval() @@ -126,12 +131,13 @@ def run_pair_comparison(): for prompt, good_ans, bad_ans in zip(prompt_list, good_ans_list, bad_ans_list): - batch = prepare_datapair(prompt, - good_ans, - bad_ans, - tokenizer, - max_seq_len=512, - end_of_conversation_token="<|endoftext|>") + batch = prepare_datapair( + prompt, + good_ans, + bad_ans, + tokenizer, + max_seq_len=512, + end_of_conversation_token=args.end_of_conversation_token) batch = to_device(batch, device) # Run inference with torch.no_grad(): @@ -150,18 +156,23 @@ def run_single_sample(): args = parse_args() device = torch.device(get_accelerator().device_name()) + args.end_of_conversation_token = "<|endoftext|>" + additional_special_tokens = args.end_of_conversation_token if args.add_eot_token else None + rm_model, tokenizer = load_stuff(args.model_name_or_path, - args.num_padding_at_beginning) + args.num_padding_at_beginning, + additional_special_tokens) rm_model.to(device) prompt = "Human: Explain the moon landing to a 6 year old in a few sentences." my_ans = "Assistant: The moon landing was a major milestone in the history of human exploration of the solar system. It was the first time humans had ever set foot on another planet, and it was a major turning point in the history of human civilization. The astronauts, Neil Armstrong, Buzz Aldrin, and Michael Collins, successfully landed the Apollo 11 spacecraft on the moon, marking the first time humans had ever set foot on another" - batch = prepare_singlesample(prompt, - my_ans, - tokenizer, - max_seq_len=512, - end_of_conversation_token="<|endoftext|>") + batch = prepare_singlesample( + prompt, + my_ans, + tokenizer, + max_seq_len=512, + end_of_conversation_token=args.end_of_conversation_token) batch = to_device(batch, device) rm_model.eval() diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py index 0ad68f761..a5be5671b 100644 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py @@ -33,17 +33,13 @@ import deepspeed -from ppo_trainer import DeepSpeedPPOTrainer, DeepSpeedPPOTrainerUnsupervised -from rlhf_engine import DeepSpeedRLHFEngine - -import sys - -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) -from utils.data.data_utils import create_prompt_dataset, MiniDataset, DataCollatorRLHF, get_unsupervised_data -from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, moving_average, save_zero_three_model, load_hf_tokenizer -from utils.module.lora import convert_lora_to_linear_layer -from utils.perf import print_throughput_step3 +from dschat.rlhf.ppo_trainer import DeepSpeedPPOTrainer, DeepSpeedPPOTrainerUnsupervised +from dschat.rlhf.rlhf_engine import DeepSpeedRLHFEngine +from dschat.utils.data.data_utils import create_prompt_dataset, MiniDataset, DataCollatorRLHF, get_unsupervised_data +from dschat.utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, moving_average, save_zero_three_model, load_hf_tokenizer, \ + ExponentialMovingAverage +from dschat.utils.module.lora import convert_lora_to_linear_layer +from dschat.utils.perf import print_throughput_step3 from deepspeed.accelerator import get_accelerator writer = None @@ -339,6 +335,11 @@ def parse_args(): parser.add_argument('--tensorboard_path', type=str, default="step3_tensorboard") + ## Tokenizer + parser.add_argument( + "--add_eot_token", + action='store_true', + help="Add <|endoftext|> as additional special token to tokenizer") ## Actor/critic model overflow alignment parser.add_argument( '--align_overflow', @@ -348,6 +349,11 @@ def parse_args(): parser.add_argument('--print_answers', action='store_true', help='Print prompt and answers during training') + parser.add_argument( + "--print_answers_interval", + type=int, + default=1, + help="If --print_answers enabled, controls the printing interval.") ## Testing parser.add_argument( '--enable_test_mode', @@ -459,8 +465,12 @@ def main(): torch.distributed.barrier() # load_hf_tokenizer will get the correct tokenizer and set padding tokens based on the model family + args.end_of_conversation_token = "<|endoftext|>" + additional_special_tokens = args.end_of_conversation_token if args.add_eot_token else None tokenizer = load_hf_tokenizer(args.actor_model_name_or_path, - fast_tokenizer=True) + fast_tokenizer=True, + add_special_tokens=additional_special_tokens) + prompt_train_dataloader, unsupervised_train_dataloader, num_total_iters = create_datasets( args=args, tokenizer=tokenizer, train_phase=3) @@ -479,8 +489,6 @@ def main(): rlhf_engine.actor.optimizer.quantize_nontrainable_params() print_rank_0("Mixed Precision ZeRO++ enabled") - args.end_of_conversation_token = "<|endoftext|>" - ppo_trainer = DeepSpeedPPOTrainerUnsupervised if unsupervised_training_enabled else DeepSpeedPPOTrainer trainer = ppo_trainer(rlhf_engine, args) @@ -491,9 +499,13 @@ def main(): args.per_device_training_batch_size) # Train! - print_rank_0("***** Running training *****", args.global_rank) + print_rank_0( + f"***** Running training (total_iters={num_total_iters}) *****", + args.global_rank) non_overflow_step_count = 0 + step_average_reward = 0. + ema_reward_score = ExponentialMovingAverage() for epoch in range(args.num_train_epochs): print_rank_0( @@ -565,9 +577,15 @@ def main(): rlhf_engine.critic, args, e2e_time, trainer.generate_time, training_time, args.global_rank) + average_reward = get_all_reduce_mean(average_reward).item() + step_average_reward += average_reward / args.gradient_accumulation_steps_actor + if (step + 1) % args.gradient_accumulation_steps_actor == 0: + ema_reward_score.update(step_average_reward) + step_average_reward = 0. + print_rank_0( - f"Average reward score: {average_reward/inner_iter}", + f"Average reward score: {average_reward/inner_iter} | EMA reward score: {ema_reward_score.get()}", args.global_rank) print_rank_0( "-------------------------------------------------------------------------------------", diff --git a/inference/mii/README.md b/inference/mii/README.md new file mode 100644 index 000000000..d701d5537 --- /dev/null +++ b/inference/mii/README.md @@ -0,0 +1,5 @@ +# DeepSpeed MII Examples + +Install the requirements by running `pip install -r requirements.txt`. + +Once [DeepSpeed-MII](https://github.com/microsoft/deepspeed-mii) is installed you have two options for deployment: an interactive non-persistent pipeline or a persistent serving deployment. For details on these files please refer to the [Getting Started guide for MII](https://github.com/microsoft/deepspeed-mii#getting-started-with-mii). diff --git a/inference/mii/client.py b/inference/mii/client.py new file mode 100644 index 000000000..6d19fec3a --- /dev/null +++ b/inference/mii/client.py @@ -0,0 +1,6 @@ +import mii + +client = mii.client("mistralai/Mistral-7B-v0.1") +output = client.generate("Deepspeed is", max_new_tokens=128) + +print(output) diff --git a/inference/mii/pipeline.py b/inference/mii/pipeline.py new file mode 100644 index 000000000..dcf9e8b03 --- /dev/null +++ b/inference/mii/pipeline.py @@ -0,0 +1,6 @@ +from mii import pipeline + +pipe = pipeline("mistralai/Mistral-7B-v0.1") +output = pipe(["Hello, my name is", "DeepSpeed is"], max_new_tokens=128) + +print(output) diff --git a/inference/mii/requirements.txt b/inference/mii/requirements.txt new file mode 100644 index 000000000..07d9f7e16 --- /dev/null +++ b/inference/mii/requirements.txt @@ -0,0 +1 @@ +mii>=0.1.0 diff --git a/inference/mii/serve.py b/inference/mii/serve.py new file mode 100644 index 000000000..09c0c306c --- /dev/null +++ b/inference/mii/serve.py @@ -0,0 +1,3 @@ +import mii + +mii.serve("mistralai/Mistral-7B-v0.1") diff --git a/inference/mii/terminate.py b/inference/mii/terminate.py new file mode 100644 index 000000000..2a7ed3211 --- /dev/null +++ b/inference/mii/terminate.py @@ -0,0 +1,4 @@ +import mii + +client = mii.client("mistralai/Mistral-7B-v0.1") +client.terminate_server() diff --git a/training/cifar/cifar10_deepspeed.py b/training/cifar/cifar10_deepspeed.py index a28bdcad0..da82e60db 100755 --- a/training/cifar/cifar10_deepspeed.py +++ b/training/cifar/cifar10_deepspeed.py @@ -343,7 +343,7 @@ def create_moe_param_groups(model): # We simply have to loop over our data iterator, and feed the inputs to the # network and optimize. -for epoch in range(2): # loop over the dataset multiple times +for epoch in range(args.epochs): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(trainloader): diff --git a/training/data_efficiency/vit_finetuning/requirement.txt b/training/data_efficiency/vit_finetuning/requirement.txt index 8bec1b063..9cf596612 100644 --- a/training/data_efficiency/vit_finetuning/requirement.txt +++ b/training/data_efficiency/vit_finetuning/requirement.txt @@ -1,4 +1,4 @@ -timm +timm==0.6.5 torch>1.10.0 torchvision>0.11.1 -mpi4py \ No newline at end of file +mpi4py