Skip to content

Commit

Permalink
deepspeed-chat: calculate loss in fp32
Browse files Browse the repository at this point in the history
Using loss in fp32 can improve training accuracy for all 3 stages.
This was tested with Bloom model using bf16 dtype

While at it, fix stage2 reward model creation: pass zero_stage to
create_critic_model.

Also, in stage3, when using bf16 and tensorboard enabled, we record the actor
and critic loss. Tensorboard accepets a scalar bf16 loss tensor and converts
it to numpy. This fails since numpy does not support conversion from tensor to
bf16. Fix it by logging to tensorboard the loss.item().

Change-Id: I9c8e95d4886cdb44aaa6c14c4aee738e133ae405
Signed-off-by: Moshe Island <[email protected]>
  • Loading branch information
mosheisland committed Oct 4, 2023
1 parent 4364031 commit 044bd98
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer
from utils.ds_utils import get_train_ds_config
from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
from utils.model.model_utils import create_hf_model
from utils.model.model_utils import create_hf_model, causal_lm_model_to_fp32_loss
from utils.perf import print_throughput


Expand Down Expand Up @@ -178,6 +178,12 @@ def parse_args():
help=
"Initial LoRA learning rate (after the potential warmup period) to use."
)
## low precision
parser.add_argument(
'--compute_fp32_loss',
action='store_true',
help='Relevant for low precision dtypes (fp16, bf16, etc.). '
'If specified, loss is calculated in fp32.')
## Tensorboard logging
parser.add_argument('--enable_tensorboard',
action='store_true',
Expand Down Expand Up @@ -234,6 +240,12 @@ def main():
ds_config,
dropout=args.dropout)

if args.compute_fp32_loss:
print_rank_0(
f"Using model {model.__class__.__name__} with loss in fp32",
args.global_rank)
causal_lm_model_to_fp32_loss(model)

if args.lora_dim > 0:
model = convert_linear_layer_to_lora(model, args.lora_module_name,
args.lora_dim)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,12 @@ def parse_args():
help=
"Initial LoRA learning rate (after the potential warmup period) to use."
)
## low precision
parser.add_argument(
'--compute_fp32_loss',
action='store_true',
help='Relevant for low precision dtypes (fp16, bf16, etc.). '
'If specified, loss is calculated in fp32.')
## Tensorboard logging
parser.add_argument('--enable_tensorboard',
action='store_true',
Expand Down Expand Up @@ -226,7 +232,9 @@ def main():
tokenizer,
ds_config,
args.num_padding_at_beginning,
dropout=args.dropout)
dropout=args.dropout,
zero_stage=args.zero_stage,
compute_fp32_loss=args.compute_fp32_loss)

if args.lora_dim > 0:
rm_model = convert_linear_layer_to_lora(rm_model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,13 @@ def parse_args():
'--enable_mixed_precision_lora',
action='store_true',
help='Enable Mixed Precision ZeRO++ for training and generation.')
## low precision
parser.add_argument(
'--compute_fp32_loss',
action='store_true',
help='Relevant for low precision dtypes (fp16, bf16, etc.). '
'If specified, loss is calculated in fp32.'
'This applies for both actor and critic models.')
## Tensorboard logging
parser.add_argument('--enable_tensorboard',
action='store_true',
Expand Down Expand Up @@ -572,13 +579,13 @@ def main():
average_reward / inner_iter,
global_step=step)
writer.add_scalar('actor_loss',
actor_loss,
actor_loss.item(),
global_step=step)
writer.add_scalar('actor_loss_sum',
actor_loss_sum,
global_step=step)
writer.add_scalar('critic_loss',
critic_loss,
critic_loss.item(),
global_step=step)
writer.add_scalar('critic_loss_sum',
critic_loss_sum,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, rlhf_engine, args):
self.end_of_conversation_token_id = self.tokenizer(
args.end_of_conversation_token)['input_ids'][-1]
self.z3_enabled = args.actor_zero_stage == 3
self.compute_fp32_loss = self.args.compute_fp32_loss

# Those value can be changed
self.kl_ctl = 0.1
Expand Down Expand Up @@ -139,6 +140,9 @@ def generate_experience(self, prompts, mask, step):

logits = output.logits
logits_ref = output_ref.logits
if self.compute_fp32_loss:
logits = logits.to(torch.float)
logits_ref = logits_ref.to(torch.float)

self.generate_time = generate_end - generate_start

Expand Down Expand Up @@ -271,6 +275,9 @@ def critic_loss_fn(self, values, old_values, returns, mask):
old_values - self.cliprange_value,
old_values + self.cliprange_value,
)
if self.compute_fp32_loss:
values = values.float()
values_clipped = values_clipped.float()
vf_loss1 = (values - returns)**2
vf_loss2 = (values_clipped - returns)**2
vf_loss = 0.5 * torch.sum(
Expand Down
61 changes: 59 additions & 2 deletions applications/DeepSpeed-Chat/training/utils/model/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,61 @@ def configure_dropout(model_config, dropout):
setattr(model_config, key, dropout)


def causal_lm_model_to_fp32_loss(model):
""" Convert CausalLM model to calculate loss in fp32 """

def causal_lm_forward(
input_ids=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**deprecated_arguments,
):
output = model.__original_forward__(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=None,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

return_dict = isinstance(output, dict)
lm_logits = output.logits if return_dict else output[0]
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].float().contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size),
shift_labels.view(batch_size * seq_length))

if not return_dict:
# re-pack output with fp32 loss
return ((loss, ) + output) if loss is not None else output

output.loss = loss
return output

model.__original_forward__ = model.forward
model.forward = causal_lm_forward


def create_hf_model(model_class,
model_name_or_path,
tokenizer,
Expand Down Expand Up @@ -64,7 +119,8 @@ def create_critic_model(model_name_or_path,
num_padding_at_beginning=0,
rlhf_training=False,
dropout=None,
zero_stage=0):
zero_stage=0,
compute_fp32_loss=False):
# OPT model family always put a padding token at the beginning of the sequence,
# we did not see this in other models but not sure if it is a general rule

Expand All @@ -80,7 +136,8 @@ def create_critic_model(model_name_or_path,
critic_model = RewardModel(
critic_model,
tokenizer,
num_padding_at_beginning=num_padding_at_beginning)
num_padding_at_beginning=num_padding_at_beginning,
compute_fp32_loss=compute_fp32_loss)

if rlhf_training:
# load critic model from checkpoint
Expand Down
12 changes: 10 additions & 2 deletions applications/DeepSpeed-Chat/training/utils/model/reward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
## https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
class RewardModel(nn.Module):

def __init__(self, base_model, tokenizer, num_padding_at_beginning=0):
def __init__(self,
base_model,
tokenizer,
num_padding_at_beginning=0,
compute_fp32_loss=False):
super().__init__()
self.config = base_model.config
self.num_padding_at_beginning = num_padding_at_beginning
Expand All @@ -27,6 +31,7 @@ def __init__(self, base_model, tokenizer, num_padding_at_beginning=0):
self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
self.rwtranrsformer = base_model
self.PAD_ID = tokenizer.pad_token_id
self.compute_fp32_loss = compute_fp32_loss

def gradient_checkpointing_enable(self):
self.rwtranrsformer.gradient_checkpointing_enable()
Expand Down Expand Up @@ -73,7 +78,7 @@ def forward(self,
rejected_rewards = rewards[bs:]

# Compute pairwise loss. Only backprop on the different tokens before padding
loss = 0
loss = 0.
for i in range(bs):
chosen_id = chosen_ids[i]
rejected_id = rejected_ids[i]
Expand Down Expand Up @@ -104,6 +109,9 @@ def forward(self,
chosen_reward[c_ind - 1]) #use the end score for reference
rejected_mean_scores.append(rejected_reward[r_ind - 1])

if self.compute_fp32_loss:
c_truncated_reward = c_truncated_reward.float()
r_truncated_reward = r_truncated_reward.float()
loss += -torch.nn.functional.logsigmoid(c_truncated_reward -
r_truncated_reward).mean()

Expand Down

0 comments on commit 044bd98

Please sign in to comment.