[BUG] MoE load balancing loss is accumulated twice when using activation checkpointing #1330

Open
thuwzt opened this issue Dec 20, 2024 · 0 comments


thuwzt commented Dec 20, 2024

Describe the bug
The MoE load balancing loss is accumulated twice when activation checkpointing (`--moe-layer-recompute`) is enabled.

To Reproduce
Train from scratch with and without `--moe-layer-recompute`, with `--moe-router-load-balancing-type aux_loss` set, and compare the reported `load_balancing_loss`.
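For context, the underlying double execution can also be seen outside Megatron-LM with plain PyTorch reentrant checkpointing, which is similar in spirit to what `--moe-layer-recompute` does. This is a minimal, self-contained sketch, not the actual Megatron-LM code path:

```python
# Minimal illustration (plain PyTorch, not Megatron-LM): reentrant activation
# checkpointing runs the wrapped forward twice -- once under no_grad in the
# forward pass, and once with gradients enabled during recomputation -- so any
# side effect in the forward (e.g. accumulating an aux loss) fires twice.
import torch
from torch.utils.checkpoint import checkpoint

grad_states = []  # records torch.is_grad_enabled() for every forward execution

def forward(x):
    grad_states.append(torch.is_grad_enabled())
    return x * 2

x = torch.randn(4, requires_grad=True)
y = checkpoint(forward, x, use_reentrant=True)
y.sum().backward()

print(grad_states)  # [False, True]: two executions, only the recomputation has grad enabled
```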

Expected behavior
The load balancing loss should be identical in the two settings, and should be slightly above 1 (a value of 1 means the routing is fully balanced).

Stack trace/logs

  • without --moe-layer-recompute:
    iteration 10: load_balancing_loss: 1.091395E+00
    iteration 20: load_balancing_loss: 1.096082E+00
    iteration 30: load_balancing_loss: 1.037049E+00

  • with --moe-layer-recompute:
    iteration 10: load_balancing_loss: 2.202137E+00
    iteration 20: load_balancing_loss: 2.298303E+00
    iteration 30: load_balancing_loss: 2.120842E+00

Environment (please complete the following information):

  • Megatron-LM d4e72c0
  • PyTorch 2.4.1
  • CUDA 12.1
  • NCCL 2.20.5

Proposed fix
Replace `if self.training:` with `if self.training and torch.is_grad_enabled():`.

Reason: with `--moe-layer-recompute`, activation checkpointing executes the router's forward function twice: once in the normal forward pass, where gradients are disabled, and once when activations are recomputed during the backward pass, where gradients are enabled. With the condition `if self.training:` alone, the load balancing loss in `TopKRouter.aux_loss_load_balancing` (megatron/core/transformer/moe/router.py) is therefore accumulated on both passes. Changing the condition to `if self.training and torch.is_grad_enabled():` skips the accumulation during the no-grad forward pass, while standard training without `--moe-layer-recompute` is unaffected because gradients are enabled in its forward pass.
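A minimal sketch of the guarded condition, using a toy router rather than the real `TopKRouter` (the class, helper names, and toy aux loss below are illustrative only; the actual change is in the referenced PR):

```python
# Sketch of the proposed guard in TopKRouter.aux_loss_load_balancing
# (megatron/core/transformer/moe/router.py). The router below is a toy
# stand-in; only the `self.training and torch.is_grad_enabled()` check
# mirrors the actual fix.
import torch

class ToyRouter(torch.nn.Module):
    def __init__(self, hidden, experts, topk):
        super().__init__()
        self.gate = torch.nn.Linear(hidden, experts, bias=False)
        self.topk = topk
        self.accumulated_aux_loss = 0.0  # tracks how often the loss is added

    def forward(self, x):
        scores = torch.softmax(self.gate(x), dim=-1)
        probs, indices = torch.topk(scores, k=self.topk, dim=-1)
        if self.training and torch.is_grad_enabled():
            # Not executed during the no-grad pass of activation checkpointing,
            # so the aux loss is accumulated exactly once per micro-batch.
            self.accumulated_aux_loss += scores.mean()  # toy aux loss
        return probs, indices
```

Running this toy router either directly or wrapped in `torch.utils.checkpoint(..., use_reentrant=True)` accumulates the loss once per micro-batch in both cases, which is the intended behavior.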

A similar issue occurs with z_loss.

The fix is included in PR #1331.

Additional context
N/A
