diff --git a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py
index 84623ae76..cb5636f0e 100755
--- a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py
+++ b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py
@@ -139,9 +139,12 @@ def parse_args():
     parser.add_argument('--gradient_checkpointing',
                         action='store_true',
                         help='Enable HF gradient checkpointing for model.')
-    parser.add_argument('--disable_dropout',
-                        action='store_true',
-                        help='Disable the dropout of the model.')
+    parser.add_argument(
+        "--dropout",
+        type=float,
+        default=None,
+        help="If dropout configured, use it. "
+        "Otherwise, keep the default dropout configuration of the model.")
     # deepspeed features
     parser.add_argument('--offload',
                         action='store_true',
@@ -229,7 +232,7 @@ def main():
                             args.model_name_or_path,
                             tokenizer,
                             ds_config,
-                            disable_dropout=args.disable_dropout)
+                            dropout=args.dropout)
 
     if args.lora_dim > 0:
         model = convert_linear_layer_to_lora(model, args.lora_module_name,
diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py
index b84ccdbaf..56b8a6110 100644
--- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py
+++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py
@@ -138,9 +138,12 @@ def parse_args():
         '--gradient_checkpointing',
         action='store_true',
         help='Enable HF gradient checkpointing for Actor model.')
-    parser.add_argument('--disable_dropout',
-                        action='store_true',
-                        help='Disable the dropout of the model.')
+    parser.add_argument(
+        "--dropout",
+        type=float,
+        default=None,
+        help="If dropout configured, use it. "
+        "Otherwise, keep the default dropout configuration of the model.")
     # deepspeed features
     parser.add_argument('--offload',
                         action='store_true',
@@ -223,7 +226,7 @@ def main():
                                    tokenizer,
                                    ds_config,
                                    args.num_padding_at_beginning,
-                                   disable_dropout=args.disable_dropout)
+                                   dropout=args.dropout)
 
     if args.lora_dim > 0:
         rm_model = convert_linear_layer_to_lora(rm_model,
diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/rw_eval.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/rw_eval.py
index 343c1d0e1..7df1af6c2 100644
--- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/rw_eval.py
+++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/rw_eval.py
@@ -43,8 +43,11 @@ def load_stuff(model_name_or_path, num_padding_at_beginning):
 
     tokenizer = load_hf_tokenizer(model_name_or_path, fast_tokenizer=True)
     tokenizer.pad_token = tokenizer.eos_token
-    model = create_critic_model(model_name_or_path, tokenizer, None,
-                                num_padding_at_beginning, True)
+    model = create_critic_model(model_name_or_path,
+                                tokenizer,
+                                None,
+                                num_padding_at_beginning,
+                                dropout=0.)
 
     return model, tokenizer
 
diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/multi_node/run_350m.sh b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/multi_node/run_350m.sh
index cea008824..51852af45 100644
--- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/multi_node/run_350m.sh
+++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/multi_node/run_350m.sh
@@ -23,7 +23,7 @@ deepspeed main.py \
    --max_seq_len 512 \
    --learning_rate 5e-5 \
    --weight_decay 0.1 \
-   --disable_dropout \
+   --dropout 0.0 \
    --num_train_epochs 1 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_gpu/run_350m.sh b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_gpu/run_350m.sh
index 5f836a46f..284fd44a0 100644
--- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_gpu/run_350m.sh
+++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_gpu/run_350m.sh
@@ -14,7 +14,7 @@ fi
 mkdir -p $OUTPUT
 
 deepspeed --num_gpus 1 main.py --model_name_or_path facebook/opt-350m \
-   --num_padding_at_beginning 1 --weight_decay 0.1 --disable_dropout --gradient_accumulation_steps 4 --zero_stage $ZERO_STAGE \
+   --num_padding_at_beginning 1 --weight_decay 0.1 --dropout 0.0 --gradient_accumulation_steps 4 --zero_stage $ZERO_STAGE \
    --enable_tensorboard \
    --tensorboard_path $OUTPUT \
    --deepspeed --output_dir $OUTPUT &> $OUTPUT/training.log
diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_node/run_350m.sh b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_node/run_350m.sh
index 2d1709955..d7ff70106 100644
--- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_node/run_350m.sh
+++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_node/run_350m.sh
@@ -24,7 +24,7 @@ deepspeed main.py \
    --learning_rate 5e-5 \
    --weight_decay 0.1 \
    --num_train_epochs 1 \
-   --disable_dropout \
+   --dropout 0.0 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --num_warmup_steps 0 \
diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_node/sweep/run_single.sh b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_node/sweep/run_single.sh
index c308a2c5f..6f5453af1 100644
--- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_node/sweep/run_single.sh
+++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/opt/single_node/sweep/run_single.sh
@@ -30,7 +30,7 @@ cmd="deepspeed main.py \
    --learning_rate 5e-5 \
    --weight_decay 0.1 \
    --num_train_epochs 1 \
-   --disable_dropout \
+   --dropout 0.0 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --num_warmup_steps 0 \
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
index 342eeea4a..e14490ba8 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py
@@ -268,12 +268,20 @@ def parse_args():
         '--critic_gradient_checkpointing',
         action='store_true',
         help='Enable HF gradient checkpointing for Critic model.')
-    parser.add_argument('--disable_actor_dropout',
-                        action='store_true',
-                        help='Disable the dropout of the actor model.')
-    parser.add_argument('--disable_critic_dropout',
-                        action='store_true',
-                        help='Disable the dropout of the critical model.')
+    parser.add_argument(
+        "--actor_dropout",
+        type=float,
+        default=None,
+        help="If actor dropout configured, use it. "
+        "Otherwise, keep the default dropout configuration of the actor model."
+    )
+    parser.add_argument(
+        "--critic_dropout",
+        type=float,
+        default=None,
+        help="If critic dropout configured, use it. "
+        "Otherwise, keep the default dropout configuration of the critic model."
+    )
     ## LoRA for efficient training setting
     parser.add_argument("--actor_lora_dim",
                         type=int,
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py
index 96b3ad632..3a192d017 100755
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py
@@ -92,7 +92,7 @@ def _init_actor(self, actor_model_name_or_path):
             model_name_or_path=actor_model_name_or_path,
             tokenizer=self.tokenizer,
             ds_config=ds_config,
-            disable_dropout=self.args.disable_actor_dropout)
+            dropout=self.args.actor_dropout)
 
         # LoRA
         if self.args.actor_lora_dim > 0:
@@ -221,7 +221,7 @@ def _init_critic(self, critic_model_name_or_path):
             ds_config=ds_eval_config,
             num_padding_at_beginning=self.args.num_padding_at_beginning,
             rlhf_training=True,
-            disable_dropout=self.args.disable_critic_dropout,
+            dropout=self.args.critic_dropout,
             zero_stage=self.args.critic_zero_stage)
 
         # LoRA
@@ -295,7 +295,7 @@ def _init_reward(self, critic_model_name_or_path):
             ds_config=ds_eval_config,
             num_padding_at_beginning=self.args.num_padding_at_beginning,
             rlhf_training=True,
-            disable_dropout=self.args.disable_critic_dropout,
+            dropout=self.args.critic_dropout,
             zero_stage=zero_stage)
 
         reward_engine, *_ = deepspeed.initialize(model=reward_model,
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b.sh b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b.sh
index 8de770a1e..c58e94eab 100755
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b.sh
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b.sh
@@ -46,7 +46,7 @@ deepspeed --master_port 12346 main.py \
    --actor_gradient_checkpointing \
    --critic_gradient_checkpointing \
    --offload_reference_model \
-   --disable_actor_dropout \
+   --actor_dropout 0.0 \
    --num_warmup_steps 100 \
    --deepspeed --seed 1234 \
    --actor_zero_stage $ACTOR_ZERO_STAGE \
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b_lora.sh b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b_lora.sh
index 263297acb..830c3750e 100755
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b_lora.sh
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b_lora.sh
@@ -46,7 +46,7 @@ deepspeed --master_port 12346 main.py \
    --actor_gradient_checkpointing \
    --critic_gradient_checkpointing \
    --offload_reference_model \
-   --disable_actor_dropout \
+   --actor_dropout 0.0 \
    --num_warmup_steps 100 \
    --deepspeed --seed 1234 \
    --actor_zero_stage $ACTOR_ZERO_STAGE \
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b_mixz.sh b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b_mixz.sh
index 8d8dcb6db..abde0b54a 100755
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b_mixz.sh
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/llama2/run_llama2_7b_mixz.sh
@@ -46,7 +46,7 @@ deepspeed --master_port 12346 main.py \
    --actor_gradient_checkpointing \
    --critic_gradient_checkpointing \
    --offload_reference_model \
-   --disable_actor_dropout \
+   --actor_dropout 0.0 \
    --num_warmup_steps 100 \
    --deepspeed --seed 1234 \
    --actor_zero_stage $ACTOR_ZERO_STAGE \
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/multi_node/run_66b.sh b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/multi_node/run_66b.sh
index 4cd944611..c70908ceb 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/multi_node/run_66b.sh
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/multi_node/run_66b.sh
@@ -51,7 +51,7 @@ deepspeed --master_port 12346 main.py \
    --actor_zero_stage $ACTOR_ZERO_STAGE \
    --critic_zero_stage $CRITIC_ZERO_STAGE \
    --actor_gradient_checkpointing \
-   --disable_actor_dropout \
+   --actor_dropout 0.0 \
    --actor_lora_dim 128 \
    --actor_lora_module_name decoder.layers. \
    --output_dir $OUTPUT \
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_gpu/run_1.3b.sh b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_gpu/run_1.3b.sh
index 1b1a5c489..41cacebab 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_gpu/run_1.3b.sh
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_gpu/run_1.3b.sh
@@ -23,5 +23,5 @@ deepspeed --num_gpus 1 main.py \
    --actor_model_name_or_path $ACTOR_MODEL_PATH --critic_model_name_or_path $CRITIC_MODEL_PATH \
    --actor_zero_stage $ACTOR_ZERO_STAGE --critic_zero_stage $CRITIC_ZERO_STAGE \
    --num_padding_at_beginning 1 --gradient_accumulation_steps 2 \
-   --deepspeed --actor_lora_dim 128 --enable_hybrid_engine --actor_gradient_checkpointing --disable_actor_dropout \
+   --deepspeed --actor_lora_dim 128 --enable_hybrid_engine --actor_gradient_checkpointing --actor_dropout 0.0 \
    --output_dir $OUTPUT &> $OUTPUT/training.log
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_gpu/run_6.7b_lora.sh b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_gpu/run_6.7b_lora.sh
index 3ae8a6d37..2c3a01d5f 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_gpu/run_6.7b_lora.sh
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_gpu/run_6.7b_lora.sh
@@ -43,7 +43,7 @@ deepspeed --num_gpus 1 main.py \
    --actor_lora_dim 128 \
    --actor_gradient_checkpointing \
    --critic_gradient_checkpointing \
-   --disable_actor_dropout \
+   --actor_dropout 0.0 \
    --enable_hybrid_engine \
    --output_dir $OUTPUT \
    &> $OUTPUT/training.log
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_1.3b.sh b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_1.3b.sh
index d306fb58c..5449bfea4 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_1.3b.sh
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_1.3b.sh
@@ -50,7 +50,7 @@ deepspeed --master_port 12346 main.py \
    --num_train_epochs 1 \
    --lr_scheduler_type cosine \
    --gradient_accumulation_steps 1 \
-   --disable_actor_dropout \
+   --actor_dropout 0.0 \
    --num_warmup_steps 100 \
    --deepspeed --seed 1234 \
    --enable_hybrid_engine \
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_1.3b_lora.sh b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_1.3b_lora.sh
index bcc440530..b39ccb833 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_1.3b_lora.sh
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_1.3b_lora.sh
@@ -38,7 +38,7 @@ deepspeed --master_port 12346 main.py \
    --gradient_accumulation_steps 1 \
    --num_warmup_steps 100 \
    --deepspeed --seed 1234 \
-   --disable_actor_dropout \
+   --actor_dropout 0.0 \
    ${ACTOR_ZERO_STAGE} \
    ${CRITIC_ZERO_STAGE} \
    --actor_lora_dim 128 \
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_13b.sh b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_13b.sh
index f037eb985..82751bd7f 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_13b.sh
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_13b.sh
@@ -48,7 +48,7 @@ deepspeed --master_port 12346 main.py \
    --actor_zero_stage $ACTOR_ZERO_STAGE \
    --critic_zero_stage $CRITIC_ZERO_STAGE \
    --actor_gradient_checkpointing \
-   --disable_actor_dropout \
+   --actor_dropout 0.0 \
    --actor_lora_dim 128 \
    --actor_lora_module_name decoder.layers. \
    --output_dir $OUTPUT \
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_30b_lora.sh b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_30b_lora.sh
index 4fbdc4faf..c5c9133ff 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_30b_lora.sh
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_30b_lora.sh
@@ -38,7 +38,7 @@ deepspeed --master_port 12346 main.py \
    --lr_scheduler_type cosine \
    --gradient_accumulation_steps 1 \
    --actor_gradient_checkpointing \
-   --disable_actor_dropout \
+   --actor_dropout 0.0 \
    --num_warmup_steps 100 \
    --deepspeed --seed 1234 \
    ${ACTOR_ZERO_STAGE} \
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_6.7b.sh b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_6.7b.sh
index 920a54894..f877bebdf 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_6.7b.sh
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_6.7b.sh
@@ -44,7 +44,7 @@ deepspeed --master_port 12346 main.py \
    --lr_scheduler_type cosine \
    --gradient_accumulation_steps 1 \
    --actor_gradient_checkpointing \
-   --disable_actor_dropout \
+   --actor_dropout 0.0 \
    --num_warmup_steps 100 \
    --deepspeed --seed 1234 \
    --enable_hybrid_engine \
diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/sweep/run_single.sh b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/sweep/run_single.sh
index 630ad2e0c..15ec6e576 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/sweep/run_single.sh
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/sweep/run_single.sh
@@ -85,7 +85,7 @@ cmd="deepspeed --num_nodes=1 main.py \
    --critic_weight_decay 0 \
    --num_warmup_steps 100 \
    --deepspeed --seed 1234 \
-   --disable_actor_dropout \
+   --actor_dropout 0.0 \
    --print_answers \
    --actor_zero_stage ${ACTOR_ZERO_STAGE} \
    --critic_zero_stage ${CRITIC_ZERO_STAGE} \
diff --git a/applications/DeepSpeed-Chat/training/utils/model/model_utils.py b/applications/DeepSpeed-Chat/training/utils/model/model_utils.py
index fdf9bc9c8..8a0051523 100644
--- a/applications/DeepSpeed-Chat/training/utils/model/model_utils.py
+++ b/applications/DeepSpeed-Chat/training/utils/model/model_utils.py
@@ -16,15 +16,24 @@
 from ..utils import load_state_dict_into_model
 
 
+def configure_dropout(model_config, dropout):
+    if dropout is not None:
+        for key in ('dropout', 'attention_dropout', 'hidden_dropout',
+                    'activation_dropout'):
+            if hasattr(model_config, key):
+                print(f"Setting model_config.{key} to {dropout}")
+                setattr(model_config, key, dropout)
+
+
 def create_hf_model(model_class,
                     model_name_or_path,
                     tokenizer,
                     ds_config=None,
                     rlhf_training=False,
-                    disable_dropout=False):
+                    dropout=None):
     model_config = AutoConfig.from_pretrained(model_name_or_path)
-    if disable_dropout:
-        model_config.dropout = 0.0
+    configure_dropout(model_config, dropout)
+
     # Note: dschf is defined in function scope to avoid global effects
     # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration
     if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
@@ -54,7 +63,7 @@ def create_critic_model(model_name_or_path,
                         ds_config,
                         num_padding_at_beginning=0,
                         rlhf_training=False,
-                        disable_dropout=False,
+                        dropout=None,
                         zero_stage=0):
     # OPT model family always put a padding token at the beginning of the sequence,
     # we did not see this in other models but not sure if it is a general rule
@@ -63,7 +72,7 @@ def create_critic_model(model_name_or_path,
 
     start = time.time()
     critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer,
-                                   ds_config, rlhf_training, disable_dropout)
+                                   ds_config, rlhf_training, dropout)
     end = time.time()
     if torch.distributed.get_rank() == 0:
         print(f"> Creating model from_config took {end - start} seconds")
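
Note (not part of the patch): the minimal sketch below illustrates the behavior the new configure_dropout helper gives every entry point above. Passing --dropout 0.0 (or --actor_dropout / --critic_dropout in step 3) now overrides every dropout-style field the loaded HF config actually defines, which reproduces the old --disable_dropout behavior, while the default of None leaves the model's own dropout settings untouched. The helper body is copied from the patched model_utils.py; the SimpleNamespace stand-in config is hypothetical and only substitutes for the transformers.AutoConfig object used in the real code path.

# Sketch only: stand-in config instead of a real transformers.AutoConfig.
from types import SimpleNamespace


def configure_dropout(model_config, dropout):
    # Copied from the patched model_utils.py: when a dropout value is given,
    # override every dropout-style attribute the config actually has.
    if dropout is not None:
        for key in ('dropout', 'attention_dropout', 'hidden_dropout',
                    'activation_dropout'):
            if hasattr(model_config, key):
                print(f"Setting model_config.{key} to {dropout}")
                setattr(model_config, key, dropout)


# Hypothetical OPT-like config: defines 'dropout' and 'attention_dropout',
# but not 'hidden_dropout' or 'activation_dropout'.
cfg = SimpleNamespace(dropout=0.1, attention_dropout=0.1)

configure_dropout(cfg, None)   # --dropout not given: config left as-is
configure_dropout(cfg, 0.0)    # --dropout 0.0: equivalent to the old --disable_dropout
print(cfg.dropout, cfg.attention_dropout)  # -> 0.0 0.0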