Enable context parallelism in SFT #190

Open · wants to merge 2 commits into base: main
83 changes: 68 additions & 15 deletions examples/megatron/entry/train_sft.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""finetune sft"""

import os
from functools import partial

import torch
@@ -27,17 +27,23 @@
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.training import pretrain
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.utils import average_losses_across_data_parallel_group
from megatron.training.utils import get_ltor_masks_and_position_ids
from megatron.training.utils import (
get_batch_on_this_cp_rank,
get_ltor_masks_and_position_ids,
)
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.transformer.spec_utils import import_module
from megatron.core import mpu
from megatron.core.utils import StragglerDetector

from examples.megatron.data.sft_dataset import build_train_valid_test_datasets


stimer = StragglerDetector()

def model_provider(pre_process=True, post_process=True):
"""Build the model."""

@@ -81,6 +87,10 @@ def model_provider(pre_process=True, post_process=True):

def get_batch(data_iterator):
"""Generate a batch"""

if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
    return None, None, None, None, None

Collaborator (review comment on the condition above): this line of code can be simplified.
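One way to act on that suggestion (a sketch only, not part of this PR, assuming Megatron's mpu pipeline-stage helpers) is to fold the double negation into a single check:

    # Only the first and last pipeline stages need real batch data;
    # intermediate stages can return placeholders immediately.
    if not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()):
        return None, None, None, None, None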

args = get_args()
tokenizer = get_tokenizer()

@@ -110,32 +120,75 @@ def get_batch(data_iterator):
args.eod_mask_loss,
)

return tokens, labels, loss_mask, attention_mask, position_ids

batch = {
"tokens": tokens,
"labels": labels,
"loss_mask": loss_mask,
"attention_mask": attention_mask,
"position_ids": position_ids
}
batch = get_batch_on_this_cp_rank(batch)
return batch.values()
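For context, get_batch_on_this_cp_rank slices every tensor in the batch along the sequence dimension so that each context-parallel rank only processes its own shard. A minimal standalone sketch of the load-balanced 2 * cp_size chunking scheme used by the stock Megatron-LM helper (an approximation that ignores the attention_mask's extra sequence dimension) could look like the following; shard_for_cp_rank is a hypothetical name, not part of the PR:

    import torch

    def shard_for_cp_rank(batch, cp_size, cp_rank, seq_dim=1):
        """Keep chunks cp_rank and (2*cp_size - 1 - cp_rank) out of 2*cp_size
        sequence chunks, which balances causal-attention work across CP ranks."""
        sharded = {}
        for key, tensor in batch.items():
            if tensor is None:
                sharded[key] = None
                continue
            chunks = tensor.chunk(2 * cp_size, dim=seq_dim)
            sharded[key] = torch.cat(
                [chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=seq_dim
            )
        return sharded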


def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
"""Loss function.

Args:
loss_mask (torch.Tensor): Used to mask out some portions of the loss
output_tensor (torch.Tensor): The tensor with the losses

Returns:
the loss scalar for this micro-batch
the number of non-padded tokens in this microbatch
a dict containing reporting metrics on the loss and number of tokens across
the data parallel ranks
"""
args = get_args()

def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
total_tokens = loss_mask.sum()
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

if args.context_parallel_size > 1:
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())

# Check individual rank losses are not NaN prior to DP all-reduce.
if args.check_for_nan_in_loss_and_grad:
global_rank = torch.distributed.get_rank()
assert not loss[0].isnan(), (
f'Rank {global_rank}: found NaN in local forward loss calculation. '
f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
)

# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])

return loss, {'lm loss': averaged_loss[0]}
reporting_loss = loss.clone().detach()
torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())

local_num_tokens = loss[1].clone().detach().to(torch.int)
return (
loss[0] * args.context_parallel_size,
local_num_tokens,
{'lm loss': (reporting_loss[0], reporting_loss[1])},
)
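Reading this against the stock Megatron-LM pretrain_gpt.py that it mirrors: each CP rank computes the masked loss sum and token count only over its own sequence shard, and the all_reduce(SUM) over the context-parallel group reconstructs the full-sequence pair; the final multiplication by context_parallel_size appears to compensate for the subsequent loss and gradient averaging over the combined data- and context-parallel group. A toy check of the sharded-sum equivalence (made-up numbers, cp_size = 2):

    import torch

    per_token_loss = torch.tensor([1.0, 2.0, 3.0, 2.0])
    loss_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])

    def partial_loss(losses, mask):
        # Same [loss_sum, num_tokens] layout as the loss tensor in loss_func above.
        return torch.cat([torch.sum(losses * mask).view(1), mask.sum().view(1)])

    full = partial_loss(per_token_loss, loss_mask)             # tensor([5., 3.])
    shard_a = partial_loss(per_token_loss[:2], loss_mask[:2])  # rank 0's shard
    shard_b = partial_loss(per_token_loss[2:], loss_mask[2:])  # rank 1's shard
    assert torch.equal(shard_a + shard_b, full)  # all_reduce(SUM) recovers the full pair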


def forward_step(data_iterator, model):
"""Forward step."""
"""Forward training step
"""
timers = get_timers()

# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
with stimer(bdata=True):
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()

output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
with stimer:
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)

return output_tensor, partial(loss_func, loss_mask)
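For reference, a simplified view (not the PR's code) of how the training schedule consumes this return value: the forward pass yields output_tensor, and the returned partial is later invoked on it, so the loss ends up computed as loss_func(loss_mask, output_tensor).

    # Conceptual illustration only; the real pipeline schedule handles this internally.
    output_tensor, loss_fn = forward_step(data_iterator, model)
    loss, num_tokens, metrics = loss_fn(output_tensor)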
