Improving memory utilization of Z2+MoE (#2079)
* Shards expert parameter groups
* Do upscaling, the optimizer step, and deletion of fp32 grads one-by-one on each parameter group in ZeRO-2
Co-authored-by: Olatunji Ruwase <[email protected]>
siddharth9820 authored Jul 13, 2022
1 parent b052378 commit c1af73f
Showing 2 changed files with 105 additions and 63 deletions.
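For orientation, the sketch below shows how the helper touched by this commit might be wired into optimizer construction for an MoE model. Only the function name, its module path, and the new max_group_size behaviour come from the diff; the toy model, the group keys, and the optimizer choice are illustrative assumptions.

import torch
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

# Illustrative model only; a real use case would contain deepspeed.moe.layer.MoE experts,
# whose parameters are what actually get pulled out into dedicated (and, after this
# commit, size-capped) groups. With a plain model the call should be a pass-through.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))

param_groups = [{'params': list(model.parameters()), 'name': 'base'}]

# Expert parameters are moved into their own groups; each such group is now additionally
# sharded so that no single group exceeds max_group_size elements (default 178956971),
# letting ZeRO-2 process one bounded chunk at a time instead of one huge flat group.
param_groups = split_params_into_different_moe_groups_for_optimizer(param_groups)

optimizer = torch.optim.AdamW(param_groups, lr=1e-4)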
35 changes: 30 additions & 5 deletions deepspeed/moe/utils.py
@@ -59,8 +59,9 @@ def split_params_grads_into_shared_and_expert_params(
    return shared_grads, expert_grads


def split_params_into_different_moe_groups_for_optimizer(
        param_groups: Tuple[Dict]) -> Tuple[Dict]:
def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict],
                                                         max_group_size=178956971
                                                         ) -> Tuple[Dict]:
    """Split parameters into different MoE groups for optimizer
    Args:
@@ -112,8 +113,32 @@ def split_params_into_different_moe_groups_for_optimizer(
        param_group['params'] = new_params

    # Flatten the moe groups
    for k, v in group_moe.items():
        for k1, v1 in v.items():
            param_groups.append(v1)
    if max_group_size is not None:
        for k, v in group_moe.items():
            for k1, v1 in v.items():
                cur_group = []
                all_groups = []
                size_of_cur_group = 0
                for param in v1['params']:
                    if size_of_cur_group + param.numel() <= max_group_size:
                        cur_group.append(param)
                        size_of_cur_group += param.numel()
                    else:
                        all_groups.append(cur_group)
                        cur_group = [param]
                        size_of_cur_group = param.numel()
                if cur_group:
                    all_groups.append(cur_group)
                for group in all_groups:
                    new_dict = {}
                    for key, val in v1.items():
                        if key != 'params':
                            new_dict[key] = val
                    new_dict['params'] = group
                    param_groups.append(new_dict)
    else:
        for k, v in group_moe.items():
            for k1, v1 in v.items():
                param_groups.append(v1)

    return tuple(param_groups)
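The heart of the change above is the size-capped re-grouping: each expert group's parameter list is greedily packed into sub-groups whose element counts sum to at most max_group_size, and every sub-group inherits the non-'params' keys of its parent. A standalone restatement of that loop follows; the helper name chunk_param_group is hypothetical, introduced here only for illustration.

from typing import Dict, List

import torch


def chunk_param_group(group: Dict, max_group_size: int) -> List[Dict]:
    """Greedily split one optimizer param group into sub-groups of <= max_group_size elements.

    Mirrors the loop added in this commit: parameters are packed into the current chunk
    until the next one would overflow the cap, then a new chunk is started. Non-'params'
    keys (lr, weight_decay, the moe/name flags, ...) are copied onto every sub-group.
    """
    chunks, cur, cur_size = [], [], 0
    for param in group['params']:
        if cur_size + param.numel() <= max_group_size:
            cur.append(param)
            cur_size += param.numel()
        else:
            chunks.append(cur)
            cur, cur_size = [param], param.numel()
    if cur:
        chunks.append(cur)

    shared_keys = {k: v for k, v in group.items() if k != 'params'}
    return [dict(shared_keys, params=chunk) for chunk in chunks]


# Toy usage: three 10-element parameters under a 15-element cap -> three sub-groups.
params = [torch.nn.Parameter(torch.zeros(10)) for _ in range(3)]
sub_groups = chunk_param_group({'params': params, 'lr': 1e-4}, max_group_size=15)
print([sum(p.numel() for p in g['params']) for g in sub_groups])  # [10, 10, 10]

Note that the cap only bounds how parameters are grouped together; it never splits an individual tensor.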
133 changes: 75 additions & 58 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -1653,6 +1653,44 @@ def override_loss_scale(self, loss_scale):
        self.custom_loss_scaler = True
        self.external_loss_scale = loss_scale

    def scaled_global_norm(self, norm_type=2):
        assert norm_type == 2, "only L2 norm supported"
        norm_groups = []
        for i, group in enumerate(self.bit16_groups):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            if self.cpu_offload:
                norm_groups.append(
                    self.complete_grad_norm_calculation_for_cpu_offload(
                        self.params_in_partition[i]))
                single_grad_partition = self.single_partition_of_fp32_groups[i].grad
            else:
                norm_groups.append(
                    self.get_grad_norm_direct(self.averaged_gradients[i],
                                              self.params_in_partition[i]))

        if self.has_moe_layers:
            self._average_expert_grad_norms(norm_groups)

        # note that the get_global_norm function only supports l2 norm
        return get_global_norm(norm_list=norm_groups)

    def get_bit16_param_group(self, group_no):
        bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
        partition_id = dist.get_rank(group=self.real_dp_process_group[group_no])
        return [
            bit16_partitions[dist.get_rank(group=self.real_dp_process_group[group_no])]
        ]

    def _optimizer_step(self, group_no):
        original_param_groups = self.optimizer.param_groups
        self.optimizer.param_groups = [original_param_groups[group_no]]
        from deepspeed.ops.adam import DeepSpeedCPUAdam
        if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
            self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
        else:
            self.optimizer.step()
        self.optimizer.param_groups = original_param_groups

    def step(self, closure=None):
        """
        Not supporting closure.
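The hunk above factors the per-group machinery out of step(): scaled_global_norm() computes one gradient norm across all partitions, and _optimizer_step(group_no) temporarily narrows optimizer.param_groups to a single group so the wrapped optimizer only touches that group's parameters and state before the full list is restored. A hedged sketch of that narrowing trick with a plain torch optimizer (step_one_group is a hypothetical name, not DeepSpeed API):

import torch


def step_one_group(optimizer: torch.optim.Optimizer, group_no: int) -> None:
    """Run optimizer.step() for a single param group, mirroring _optimizer_step above."""
    original_param_groups = optimizer.param_groups
    # Expose only the chosen group so step() touches just its params and state; this is
    # what keeps per-group temporaries (e.g. fp32 gradients being consumed) bounded.
    optimizer.param_groups = [original_param_groups[group_no]]
    try:
        optimizer.step()
    finally:
        # The commit restores the full list unconditionally; try/finally here is only for safety.
        optimizer.param_groups = original_param_groups


# Toy usage: two groups, stepped independently.
w1 = torch.nn.Parameter(torch.ones(4))
w2 = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.SGD([{'params': [w1]}, {'params': [w2]}], lr=0.1)
(w1.sum() + w2.sum()).backward()
step_one_group(opt, 0)  # only w1 is updated here
step_one_group(opt, 1)  # then w2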
@@ -1671,7 +1709,6 @@ def step(self, closure=None):
        prev_scale = self.loss_scale
        self._update_scale(self.overflow)
        if self.overflow:

            if dist.get_rank() == 0:
                logger.info(
                    "[deepspeed] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
@@ -1692,22 +1729,33 @@ def step(self, closure=None):
            self.stop_timers(timer_names)
            return

        self.start_timers([OPTIMIZER_GRADIENTS])
        norm_groups = []
        single_partition_grad_groups = []
        # skip = False
        # Step 1:- Calculate gradient norm using fp-16 grads
        see_memory_usage('Before norm calculation')
        scaled_global_grad_norm = self.scaled_global_norm()
        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale

        see_memory_usage('After norm before optimizer')
        # Step 2:- run optimizer and upscaling simultaneously
        for i, group in enumerate(self.bit16_groups):
            self.start_timers([OPTIMIZER_GRADIENTS])
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            if self.cpu_offload:
                norm_groups.append(
                    self.complete_grad_norm_calculation_for_cpu_offload(
                        self.params_in_partition[i]))
                single_grad_partition = self.single_partition_of_fp32_groups[i].grad
            else:
                norm_groups.append(
                    self.get_grad_norm_direct(self.averaged_gradients[i],
                                              self.params_in_partition[i]))
                self.unscale_and_clip_grads([single_grad_partition],
                                            scaled_global_grad_norm)
                self.stop_timers([OPTIMIZER_GRADIENTS])
                self.start_timers([OPTIMIZER_STEP])
                self._optimizer_step(i)

                from deepspeed.ops.adam import DeepSpeedCPUAdam
                if not (type(self.optimizer) == DeepSpeedCPUAdam
                        and self.dtype == torch.half):
                    bit16_partitions = self.parallel_partitioned_bit16_groups[i]
                    fp32_partition = self.single_partition_of_fp32_groups[i]
                    bit16_partitions[partition_id].data.copy_(fp32_partition.data)

                self.stop_timers([OPTIMIZER_STEP])
            else:
                # free gradients for all the parameters that are not updated by this process(ZeRO stage2)
                self.free_grad_in_param_list(self.params_not_in_partition[i])

@@ -1732,53 +1780,22 @@ def step(self, closure=None):

                self.averaged_gradients[i] = None

            single_partition_grad_groups.append(single_grad_partition)

        if self.has_moe_layers:
            self._average_expert_grad_norms(norm_groups)

        scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
        self.unscale_and_clip_grads(single_partition_grad_groups,
                                    scaled_global_grad_norm)

        # Stash unscaled gradient norm
        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale

        self.stop_timers([OPTIMIZER_GRADIENTS])

        self.start_timers([OPTIMIZER_STEP])
        if self.deepspeed_adam_offload:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
                bit16_param_groups = [
                    [
                        bit16_partitions[dist.get_rank(
                            group=self.real_dp_process_group[group_id])]
                    ] for group_id,
                    bit16_partitions in enumerate(self.parallel_partitioned_bit16_groups)
                ]
                self.optimizer.step(fp16_param_groups=bit16_param_groups)
            else:
                self.optimizer.step()
                for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
                    partition_id = dist.get_rank(
                        group=self.real_dp_process_group[group_id])

                    bit16_partitions[partition_id].data.copy_(fp32_partition.data)
        else:
            self.optimizer.step()

            # get rid of the fp32 gradients. Not needed anymore
            if not self.cpu_offload:
                for group in self.single_partition_of_fp32_groups:
                    group.grad = None # in step

            for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
                partition_id = dist.get_rank(group=self.real_dp_process_group[group_id])
                self.unscale_and_clip_grads([single_grad_partition],
                                            scaled_global_grad_norm)
                self.stop_timers([OPTIMIZER_GRADIENTS])

                # Step 3:- run the optimizer if no offloading
                self.start_timers([OPTIMIZER_STEP])
                self._optimizer_step(i)
                # Step 4:- get rid of the fp32 gradients. Not needed anymore
                self.single_partition_of_fp32_groups[i].grad = None
                del single_grad_partition
                bit16_partitions = self.parallel_partitioned_bit16_groups[i]
                fp32_partition = self.single_partition_of_fp32_groups[i]
                bit16_partitions[partition_id].data.copy_(fp32_partition.data)
                self.stop_timers([OPTIMIZER_STEP])

        self.stop_timers([OPTIMIZER_STEP])

        see_memory_usage('After optimizer before all-gather')
        if self.cpu_offload:
            self.reset_cpu_buffers()

@@ -1794,7 +1811,7 @@ def step(self, closure=None):
        self.stop_timers([OPTIMIZER_ALLGATHER])

        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
        for i in range(len(self.bit16_groups)):
            self._update_model_bit16_weights(i)

        self.log_timers(timer_names)
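Putting the pieces together, the reworked step() computes the scaled gradient norm once up front and then walks the bit16 groups one at a time, unscaling and clipping, stepping, dropping the fp32 gradient, and refreshing the bit16 partition before moving on, so only one group's temporaries are alive at any moment. The sketch below is a schematic, framework-free rendering of that per-group flow with plain tensors standing in for ZeRO partitions; it is not DeepSpeed's partitioning code, and the clipping arithmetic is only a simplified analogue of unscale_and_clip_grads.

import torch

loss_scale = 1024.0
clip_grad = 1.0

# Stand-ins for ZeRO-2 state: bit16 model partitions and their fp32 master copies.
bit16_partitions = [torch.randn(8, dtype=torch.float16) for _ in range(3)]
fp32_partitions = [p.detach().float().requires_grad_(True) for p in bit16_partitions]
for p in fp32_partitions:
    p.grad = torch.randn_like(p) * loss_scale  # gradients are still loss-scaled here

optimizer = torch.optim.SGD([{'params': [p]} for p in fp32_partitions], lr=0.01)

# Step 1: one global (scaled) gradient norm across all groups, as in scaled_global_norm().
scaled_global_norm = torch.norm(torch.stack([p.grad.norm(2) for p in fp32_partitions]), 2)
unscaled_norm = float(scaled_global_norm) / loss_scale
combined_scale = loss_scale * max(1.0, unscaled_norm / clip_grad)

# Steps 2-4: unscale/clip, optimizer step and fp32-gradient release happen per group,
# so peak memory holds one group's temporaries rather than every group's at once.
for i, fp32 in enumerate(fp32_partitions):
    fp32.grad.div_(combined_scale)                 # unscale + clip this group only
    original_groups = optimizer.param_groups
    optimizer.param_groups = [original_groups[i]]  # narrow to one group (cf. _optimizer_step)
    optimizer.step()
    optimizer.param_groups = original_groups
    fp32.grad = None                               # drop the fp32 gradient immediately
    bit16_partitions[i].copy_(fp32.detach())       # refresh the bit16 partition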