diff --git a/examples/megatron/entry/train_sft.py b/examples/megatron/entry/train_sft.py
index c293439c..53d66aad 100644
--- a/examples/megatron/entry/train_sft.py
+++ b/examples/megatron/entry/train_sft.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """finetune sft"""
-
+import os
 from functools import partial
 
 import torch
@@ -27,17 +27,23 @@ from megatron.core.models.gpt import GPTModel as MCoreGPTModel
 from megatron.training import pretrain
 from megatron.training.arguments import core_transformer_config_from_args
-from megatron.training.utils import average_losses_across_data_parallel_group
-from megatron.training.utils import get_ltor_masks_and_position_ids
+from megatron.training.utils import (
+    get_batch_on_this_cp_rank,
+    get_ltor_masks_and_position_ids,
+)
 from megatron.core.models.gpt.gpt_layer_specs import (
     get_gpt_layer_local_spec,
     get_gpt_layer_with_transformer_engine_spec,
 )
 from megatron.core.transformer.spec_utils import import_module
+from megatron.core import mpu
+from megatron.core.utils import StragglerDetector
 
 from examples.megatron.data.sft_dataset import build_train_valid_test_datasets
 
+stimer = StragglerDetector()
+
 
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
@@ -81,6 +87,10 @@ def model_provider(pre_process=True, post_process=True):
 
 def get_batch(data_iterator):
     """Generate a batch"""
+
+    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
+        return None, None, None, None, None
+
     args = get_args()
     tokenizer = get_tokenizer()
 
@@ -110,32 +120,75 @@ def get_batch(data_iterator):
         args.eod_mask_loss,
     )
 
-    return tokens, labels, loss_mask, attention_mask, position_ids
-
+    batch = {
+        "tokens": tokens,
+        "labels": labels,
+        "loss_mask": loss_mask,
+        "attention_mask": attention_mask,
+        "position_ids": position_ids
+    }
+    batch = get_batch_on_this_cp_rank(batch)
+    return batch.values()
+
+
+def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
+    """Loss function.
+
+    Args:
+        loss_mask (torch.Tensor): Used to mask out some portions of the loss
+        output_tensor (torch.Tensor): The tensor with the losses
+
+    Returns:
+        the loss scalar for this micro-batch
+        the number of non-padded tokens in this microbatch
+        a dict containing reporting metrics on the loss and number of tokens across
+            the data parallel ranks
+    """
+    args = get_args()
 
-def loss_func(loss_mask, output_tensor):
     losses = output_tensor.float()
     loss_mask = loss_mask.view(-1).float()
-    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+    total_tokens = loss_mask.sum()
+    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
+
+    if args.context_parallel_size > 1:
+        torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
+
+    # Check individual rank losses are not NaN prior to DP all-reduce.
+    if args.check_for_nan_in_loss_and_grad:
+        global_rank = torch.distributed.get_rank()
+        assert not loss[0].isnan(), (
+            f'Rank {global_rank}: found NaN in local forward loss calculation. '
+            f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
+        )
 
     # Reduce loss for logging.
-    averaged_loss = average_losses_across_data_parallel_group([loss])
-
-    return loss, {'lm loss': averaged_loss[0]}
+    reporting_loss = loss.clone().detach()
+    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
+
+    local_num_tokens = loss[1].clone().detach().to(torch.int)
+    return (
+        loss[0] * args.context_parallel_size,
+        local_num_tokens,
+        {'lm loss': (reporting_loss[0], reporting_loss[1])},
+    )
 
 
 def forward_step(data_iterator, model):
-    """Forward step."""
+    """Forward training step
+    """
     timers = get_timers()
 
     # Get the batch.
     timers('batch-generator', log_level=2).start()
-    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
-        data_iterator)
+    with stimer(bdata=True):
+        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
+            data_iterator)
     timers('batch-generator').stop()
 
-    output_tensor = model(tokens, position_ids, attention_mask,
-                          labels=labels)
+    with stimer:
+        output_tensor = model(tokens, position_ids, attention_mask,
+                              labels=labels)
 
     return output_tensor, partial(loss_func, loss_mask)
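
Note on the loss change: loss_func no longer averages inside the function; it now returns the masked loss sum, the non-padded token count, and a data-parallel-reduced reporting tuple, leaving the per-token averaging to the training loop. The standalone sketch below (hypothetical helper names, not part of this patch) shows that with context_parallel_size == 1 this bookkeeping reproduces the old masked mean:

import torch

def old_style_loss(losses, loss_mask):
    # Previous behaviour: average over non-masked tokens inside loss_func.
    loss_mask = loss_mask.view(-1).float()
    return torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

def new_style_parts(losses, loss_mask):
    # New behaviour: return the masked sum and the token count separately,
    # so the caller can average over tokens across micro-batches and ranks.
    loss_mask = loss_mask.view(-1).float()
    return torch.sum(losses.view(-1) * loss_mask), loss_mask.sum()

torch.manual_seed(0)
losses = torch.rand(4, 16)                       # per-token LM losses
loss_mask = (torch.rand(4, 16) > 0.25).float()   # 1 = real token, 0 = padding
summed, num_tokens = new_style_parts(losses, loss_mask)
assert torch.allclose(old_style_loss(losses, loss_mask), summed / num_tokens)

When context parallelism is enabled, the patch first all-reduces the two-element [loss sum, token count] vector over the CP group, so the same sum/count ratio still corresponds to the per-token loss of the full sequence.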
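
get_batch now routes the batch through get_batch_on_this_cp_rank so each context-parallel rank keeps only its portion of the sequence dimension. The sketch below is a deliberately simplified illustration of that idea using contiguous slices; the function name and signature here are made up, and the real Megatron-LM helper partitions the sequence differently (a load-balanced split) to keep causal-attention work even across ranks:

import torch

def shard_batch_for_cp_rank(batch, cp_rank, cp_size, seq_dim=1):
    """Keep this rank's contiguous slice of the sequence axis (illustrative only)."""
    out = {}
    for key, val in batch.items():
        if val is None or val.dim() <= seq_dim:
            out[key] = val          # nothing to shard along the sequence axis
            continue
        chunk = val.size(seq_dim) // cp_size
        out[key] = val.narrow(seq_dim, cp_rank * chunk, chunk)
    return out

batch = {
    "tokens": torch.arange(32).view(2, 16),
    "loss_mask": torch.ones(2, 16),
}
sharded = shard_batch_for_cp_rank(batch, cp_rank=1, cp_size=4)
print(sharded["tokens"])           # columns 4..7 of each row
print(sharded["loss_mask"].shape)  # torch.Size([2, 4])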