[BUG] MoE load balancing loss is accumulated twice when using activation checkpointing
Describe the bug
Load balancing loss is accumulated twice when using activation checkpointing
To Reproduce
Train from scratch with and without --moe-layer-recompute, setting --moe-router-load-balancing-type aux_loss.
Expected behavior
The load balancing loss should be the same in both settings (and slightly above 1, since a value of 1 corresponds to fully balanced routing).
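For context, here is a minimal sketch of the switch-style auxiliary load balancing loss (an assumed formulation for illustration, not Megatron's exact implementation; aux_loss_sketch and its arguments are hypothetical names, and top-1 routing is assumed), showing why a value of exactly 1 corresponds to fully balanced routing:

```python
import torch

def aux_loss_sketch(router_probs: torch.Tensor, top1_expert: torch.Tensor, num_experts: int) -> torch.Tensor:
    """Switch-style aux loss: L = N * sum_i f_i * P_i (top-1 routing assumed).
    f_i = fraction of tokens dispatched to expert i, P_i = mean router prob for expert i."""
    # f_i: fraction of tokens routed to each expert
    dispatch_fraction = torch.bincount(top1_expert, minlength=num_experts).float() / top1_expert.numel()
    # P_i: mean router probability assigned to each expert
    mean_prob = router_probs.mean(dim=0)
    return num_experts * torch.sum(dispatch_fraction * mean_prob)

# Perfectly uniform routing over 4 experts and 8 tokens -> loss is exactly 1.0.
num_experts, num_tokens = 4, 8
uniform_probs = torch.full((num_tokens, num_experts), 1.0 / num_experts)
even_assignment = torch.arange(num_tokens) % num_experts
print(aux_loss_sketch(uniform_probs, even_assignment, num_experts))  # tensor(1.0000)
```

Any routing imbalance pushes the value above 1, so the ~2.2 values in the logs below point to double counting rather than poor balance.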
Stack trace/logs
Without --moe-layer-recompute:
iteration 10: load_balancing_loss: 1.091395E+00
iteration 20: load_balancing_loss: 1.096082E+00
iteration 30: load_balancing_loss: 1.037049E+00

With --moe-layer-recompute:
iteration 10: load_balancing_loss: 2.202137E+00
iteration 20: load_balancing_loss: 2.298303E+00
iteration 30: load_balancing_loss: 2.120842E+00
Environment
Proposed fix
Replace if self.training: with if self.training and torch.is_grad_enabled(): in TopKRouter.aux_loss_load_balancing within megatron/core/transformer/moe/router.py.
Reason: When using activation checkpointing with --moe-layer-recompute, the forward function is executed twice, so the load balancing loss is accumulated twice if the condition is only if self.training:. Changing the condition to if self.training and torch.is_grad_enabled(): skips the accumulation during the first pass (where gradients are disabled) while leaving standard training without --moe-layer-recompute unaffected. A similar issue occurs with z_loss.
The fix is included in PR #1331.
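For illustration, the self-contained sketch below (a hypothetical ToyRouter, not the actual TopKRouter code) demonstrates why the guard works: reentrant torch.utils.checkpoint runs the wrapped forward once under no_grad and again with gradients enabled during the backward-pass recompute, so gating the accumulation on torch.is_grad_enabled() counts the loss only once:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyRouter(nn.Module):
    """Hypothetical stand-in for the router; only the accumulation guard matters here."""
    def __init__(self):
        super().__init__()
        self.aux_loss_count = 0  # how many times the aux loss would be accumulated

    def forward(self, x):
        # Proposed guard: skip the no-grad pass that activation checkpointing
        # performs, so the loss is accumulated exactly once per forward.
        if self.training and torch.is_grad_enabled():
            self.aux_loss_count += 1
        return x * 2

router = ToyRouter().train()

# Without checkpointing: forward runs once, loss accumulated once.
x = torch.ones(4, requires_grad=True)
router(x).sum().backward()
print(router.aux_loss_count)  # 1

# With reentrant checkpointing: forward runs twice (no-grad pass + recompute in
# backward), but the guard still counts the loss only once. With a bare
# `if self.training:` the count would be 2, matching the doubled loss in the logs.
router.aux_loss_count = 0
y = checkpoint(router, torch.ones(4, requires_grad=True), use_reentrant=True)
y.sum().backward()
print(router.aux_loss_count)  # 1
```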
Additional context
N/A