[WIP]feature(mixtral): support Mixtral-8x7B SFT, Reward, and Alignment #95

Open
wants to merge 33 commits into main
Changes from all commits (33 commits):
eb698e9  Support MCore models (haolin-nju, Sep 10, 2024)
b470622  Fix Megatron version (haolin-nju, Sep 10, 2024)
41db596  fix pylint (haolin-nju, Sep 10, 2024)
636141c  fix missing import in reward_inference (haolin-nju, Sep 11, 2024)
004211c  fix vllm (haolin-nju, Sep 11, 2024)
1d64707  refine hint (haolin-nju, Sep 11, 2024)
b7367a2  fix Makefile (haolin-nju, Sep 11, 2024)
a6fac05  Merge branch 'main' into mixtral (haolin-nju, Sep 19, 2024)
7ef59ca  Support Mixtral MoE (haolin-nju, Sep 19, 2024)
b2e544e  fix mixtral (haolin-nju, Sep 20, 2024)
1aa2301  fix model mixtral (haolin-nju, Sep 20, 2024)
5e4643b  Merge branch 'main' into mixtral (haolin-nju, Sep 20, 2024)
19143f1  align diff to main (haolin-nju, Sep 20, 2024)
20e6272  fix dpo and add test_checkpoint_conversion (haolin-nju, Sep 20, 2024)
f68afd1  fix error msg (haolin-nju, Sep 20, 2024)
d7152ac  fix src_gpu to get_or_cache (haolin-nju, Sep 30, 2024)
a554cf6  revert "fix src_gpu to get_or_cache" (haolin-nju, Oct 8, 2024)
5424d0b  Merge branch 'main' into mixtral (haolin-nju, Nov 5, 2024)
caebecf  fix comments (haolin-nju, Nov 5, 2024)
ef30253  fix import error (haolin-nju, Nov 6, 2024)
752f5ef  Merge branch 'mixtral' of github.com:alibaba/ChatLearn into mixtral (haolin-nju, Nov 6, 2024)
7cc7aa9  fix diff introduced in merge (haolin-nju, Nov 6, 2024)
7420414  Merge branch 'main' into mixtral (haolin-nju, Nov 6, 2024)
c027414  expost import error earlier (haolin-nju, Nov 7, 2024)
77f4230  fix merge error (haolin-nju, Nov 7, 2024)
2896285  fix recursion error (haolin-nju, Nov 7, 2024)
333e661  add validate_param_sync option to mixtral models (haolin-nju, Nov 7, 2024)
d86136f  fix redundant empty lines (haolin-nju, Nov 7, 2024)
4053415  Merge branch 'main' into mixtral (haolin-nju, Nov 7, 2024)
cd00af0  fix scripts (haolin-nju, Nov 8, 2024)
7ed2eff  mixtral sft ok (haolin-nju, Nov 11, 2024)
f23ac37  fix sft (haolin-nju, Nov 12, 2024)
27bdbcc  Merge branch 'main' into mixtral (haolin-nju, Nov 12, 2024)
75 changes: 52 additions & 23 deletions chatlearn/models/megatron/memory_manager/base_trainer.py
@@ -27,6 +27,7 @@
MixedPrecisionOptimizer,
DistributedOptimizer,
Float16OptimizerWithFloat16Params,
ChainedOptimizer,
)


@@ -92,34 +93,58 @@ def __init__(
self._use_distributed_optimizer = use_distributed_optimizer
self._bucket_size_mb = bucket_size_mb

def sanity_check(single_optimizer):
assert isinstance(
single_optimizer, (MixedPrecisionOptimizer,)
), f'Only support optimizer type MixedPrecisionOptimizer and its subclasses, current type is {str(type(single_optimizer))}.'

if self._use_distributed_optimizer:
assert isinstance(single_optimizer, DistributedOptimizer)
else:
log_rank_0('Current optimizer is Float16OptimizerWithFloat16Params')
assert isinstance(single_optimizer, Float16OptimizerWithFloat16Params)

assert isinstance(
model, (DistributedDataParallel,)
), f'Only support model type DistributedDataParallel, current type is {str(type(model))}.'
assert isinstance(
optimizer, (MixedPrecisionOptimizer,)
), f'Only support optimizer type MixedPrecisionOptimizer and its subclasses, current type is {str(type(optimizer))}.'

# sanity check
if self._use_distributed_optimizer:
assert isinstance(optimizer, DistributedOptimizer)
if isinstance(optimizer, ChainedOptimizer):
for single_optimizer in optimizer.chained_optimizers:
sanity_check(single_optimizer)
self._is_chained_optimizer = True
else:
log_rank_0('Current optimizer is Float16OptimizerWithFloat16Params')
assert isinstance(optimizer, Float16OptimizerWithFloat16Params)
sanity_check(optimizer)
self._is_chained_optimizer = False

self._main_weights_offloaded = False
self._group_flat_main_weights: Optional[List[BucketizedFlatTensors]] = None

self._megatron_version = get_megatron_version()

def _optimizer_load_state_bucket_into_device(self, device):
def get_optimizer_list(self):
if self._is_chained_optimizer:
optimizer_list = self._optimizer.chained_optimizers
else:
optimizer_list = [self._optimizer]
return optimizer_list

def _optimizer_load_state_bucket_into_device(self, device, optimizer=None):
"""put the state bucket onto a device"""
state_dict = self._optimizer.optimizer.state_dict()
for tensors in state_dict['state'].values():
keys = list(tensors.keys())
for key in keys:
# compatible with transformer_engine v1.10, state['master_param']=None
if tensors[key] is not None:
tensors[key] = tensors[key].to(device=device, non_blocking=True)
if optimizer is not None:
if isinstance(optimizer, ChainedOptimizer):
optimizer_list = optimizer.chained_optimizers
else:
optimizer_list = [optimizer]
else:
optimizer_list = self.get_optimizer_list()

for single_optimizer in optimizer_list:
state_dict = single_optimizer.optimizer.state_dict()
for tensors in state_dict['state'].values():
keys = list(tensors.keys())
for key in keys:
# compatible with transformer_engine v1.10, state['master_param']=None
if tensors[key] is not None:
tensors[key] = tensors[key].to(device=device, non_blocking=True)
# make sure the loading is finished before returning
torch.cuda.synchronize()

@@ -154,12 +179,16 @@ def offload_main_weights(self):
return

if self._group_flat_main_weights is None:
if self._use_distributed_optimizer:
self._group_flat_main_weights = self._flat_param_groups(
[self._optimizer.shard_fp32_from_float16_groups]
)
else:
self._group_flat_main_weights = self._flat_param_groups([self._optimizer.fp32_from_float16_groups])
self._group_flat_main_weights = []
optimizer_list = self.get_optimizer_list()

for optimizer in optimizer_list:
if self._use_distributed_optimizer:
self._group_flat_main_weights.extend(self._flat_param_groups(
[optimizer.shard_fp32_from_float16_groups]
))
else:
self._group_flat_main_weights.extend(self._flat_param_groups([optimizer.fp32_from_float16_groups]))

for flat_main_weights in self._group_flat_main_weights:
flat_main_weights.copy_to_primary_store()
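For reviewers skimming the diff: the core idea of this file's change is to stop assuming a single Megatron optimizer and instead unwrap a possible ChainedOptimizer (used for the Mixtral MoE models) into a flat list, then loop over that list wherever optimizer state is flattened, offloaded, or moved between devices. Below is a minimal, self-contained sketch of that pattern; the helper names are hypothetical, and only the chained_optimizers attribute and the state-moving loop mirror the actual code (which goes through Megatron's wrapper, i.e. single_optimizer.optimizer.state_dict()).

from typing import Any, List

import torch


def unwrap_optimizers(optimizer: Any) -> List[Any]:
    """Return the underlying optimizers, whether or not `optimizer` is chained."""
    # Megatron's ChainedOptimizer keeps its members in `chained_optimizers`;
    # duck typing avoids importing Megatron in this sketch.
    inner = getattr(optimizer, "chained_optimizers", None)
    return list(inner) if inner is not None else [optimizer]


def move_optimizer_state(optimizer: Any, device: str) -> None:
    """Move every optimizer state tensor to `device` (CPU offload or GPU onload)."""
    for single_optimizer in unwrap_optimizers(optimizer):
        state = single_optimizer.state_dict()["state"]
        for tensors in state.values():
            for key in list(tensors.keys()):
                # Skip non-tensor or empty entries (the real code checks for None,
                # e.g. state['master_param'] is None with transformer_engine v1.10).
                if torch.is_tensor(tensors[key]):
                    tensors[key] = tensors[key].to(device=device, non_blocking=True)

With the sanity check factored into sanity_check(single_optimizer), the dense single-optimizer path and the MoE ChainedOptimizer path share the same per-optimizer logic.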
204 changes: 106 additions & 98 deletions chatlearn/models/megatron/memory_manager/trainer_v1v2.py
@@ -87,27 +87,31 @@ def offload_weights(self):
log_rank_0('Call offload_weights when already offloaded. Ignore it.')
return

optimizer = self._optimizer
optimizer_list = self.get_optimizer_list()

if self._use_distributed_optimizer:
optimizer.shard_float16_groups.clear()
optimizer.shard_fp32_groups.clear()
for optimizer in optimizer_list:
optimizer.shard_float16_groups.clear()
optimizer.shard_fp32_groups.clear()

if self._group_flat_weights is None:
if self._use_distributed_optimizer:
self._group_flat_weights = self._flat_param_groups(
[
optimizer.model_float16_groups,
optimizer.model_fp32_groups,
],
)
else:
self._group_flat_weights = self._flat_param_groups(
[
optimizer.float16_groups,
optimizer.fp32_from_fp32_groups,
],
)
self._group_flat_weights = []

for optimizer in optimizer_list:
if self._use_distributed_optimizer:
self._group_flat_weights.extend(self._flat_param_groups(
[
optimizer.model_float16_groups,
optimizer.model_fp32_groups,
],
))
else:
self._group_flat_weights.extend(self._flat_param_groups(
[
optimizer.float16_groups,
optimizer.fp32_from_fp32_groups,
],
))

for flat_weights in self._group_flat_weights:
flat_weights.copy_to_primary_store()
@@ -124,7 +128,7 @@ def onload_weights(self):
log_rank_0('Call onload_weights when already onloaded. Ignore it.')
return

optimizer = self._optimizer
optimizer_list = self.get_optimizer_list()

for flat_weights in self._group_flat_weights:
flat_weights.copy_to_gpu_buffer()
@@ -148,55 +152,56 @@ def onload_weights(self):
self._weights_offloaded = False
return

shard_float16_groups = optimizer.shard_float16_groups
shard_fp32_groups = optimizer.shard_fp32_groups
param_gbuf_map = optimizer.model_param_gbuf_map
opt_group_ranges = optimizer.opt_group_ranges
model_gbuf_ranges = optimizer.model_gbuf_ranges

# Rebuild shard_float16_groups and shard_fp32_groups,
# see Megatron DistributedOptimizer#build_model_and_main_param_groups.
for _, group_range in enumerate(opt_group_ranges):
shard_float16_params_this_group = []
shard_fp32_params_this_group = []
shard_float16_groups.append(shard_float16_params_this_group)
shard_fp32_groups.append(shard_fp32_params_this_group)

for model_param in group_range["params"]:
assert model_param.requires_grad
if self._megatron_version == MegatronVersion.V2:
model_index, dtype, bucket_index = param_gbuf_map[model_param]
gbuf_range = model_gbuf_ranges[model_index][dtype][bucket_index]
param_range = gbuf_range["param_map"][model_param]["param"]
elif self._megatron_version == MegatronVersion.V1:
model_index, dtype = param_gbuf_map[model_param]
gbuf_range = model_gbuf_ranges[model_index][dtype]
param_range = gbuf_range["param_map"][model_param]["param"]

# fp16, bf16 params.
if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end]
tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared

shard_float16_params_this_group.append(shard_model_param)

# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
shard_fp32_params_this_group.append(shard_model_param)
tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
else:
raise TypeError(
'Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(model_param.type())
)
for optimizer in optimizer_list:
shard_float16_groups = optimizer.shard_float16_groups
shard_fp32_groups = optimizer.shard_fp32_groups
param_gbuf_map = optimizer.model_param_gbuf_map
opt_group_ranges = optimizer.opt_group_ranges
model_gbuf_ranges = optimizer.model_gbuf_ranges

# Rebuild shard_float16_groups and shard_fp32_groups,
# see Megatron DistributedOptimizer#build_model_and_main_param_groups.
for _, group_range in enumerate(opt_group_ranges):
shard_float16_params_this_group = []
shard_fp32_params_this_group = []
shard_float16_groups.append(shard_float16_params_this_group)
shard_fp32_groups.append(shard_fp32_params_this_group)

for model_param in group_range["params"]:
assert model_param.requires_grad
if self._megatron_version == MegatronVersion.V2:
model_index, dtype, bucket_index = param_gbuf_map[model_param]
gbuf_range = model_gbuf_ranges[model_index][dtype][bucket_index]
param_range = gbuf_range["param_map"][model_param]["param"]
elif self._megatron_version == MegatronVersion.V1:
model_index, dtype = param_gbuf_map[model_param]
gbuf_range = model_gbuf_ranges[model_index][dtype]
param_range = gbuf_range["param_map"][model_param]["param"]

# fp16, bf16 params.
if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end]
tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared

shard_float16_params_this_group.append(shard_model_param)

# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
shard_fp32_params_this_group.append(shard_model_param)
tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
else:
raise TypeError(
'Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(model_param.type())
)

self._weights_offloaded = False

@@ -208,16 +213,17 @@ def free_grad_buffers(self):
log_rank_0('Call free_grad_buffers when already freed. Ignore it.')
return

optimizer = self._optimizer
optimizer_list = self.get_optimizer_list()
grad_dtype_to_params = self._grad_dtype_to_params

# This is necessary, but don't know why.
optimizer.zero_grad(True)
for optimizer in optimizer_list:
# This is necessary, but don't know why.
optimizer.zero_grad(True)

if self._use_distributed_optimizer:
# Release param_buffers because they share storage with grad_buffers.
# Note: param_buffers are only available in DistributedOptimizer.
optimizer.param_buffers.clear()
if self._use_distributed_optimizer:
# Release param_buffers because they share storage with grad_buffers.
# Note: param_buffers are only available in DistributedOptimizer.
optimizer.param_buffers.clear()

# Release grad_buffers, including buckets in GradBuffer for newer Megatron version.
# Release `main_grad` of parameters.
@@ -249,7 +255,7 @@ def build_grad_buffers(self):
log_rank_0('Call build_grad_buffers when already built. Ignore it.')
return

optimizer = self._optimizer
optimizer_list = self.get_optimizer_list()
params_dtype = self._params_dtype
grad_dtype_to_params = self._grad_dtype_to_params

@@ -283,31 +289,33 @@ def build_grad_buffers(self):
return

# Re-allocate param_buffers, see Megatron DistributedOptimizer#__init__.
optimizer.param_buffers = []
for _, _ in enumerate(optimizer.models):
current_param_buffers = {}
for dtype, grad_buffer in self.get_grad_buffers().items():
current_param_buffers[dtype] = []
if self._megatron_version == MegatronVersion.V2:
for bucket in grad_buffer.buckets:
# pylint: disable=too-many-nested-blocks
for optimizer in optimizer_list:
optimizer.param_buffers = []
for _, _ in enumerate(optimizer.models):
current_param_buffers = {}
for dtype, grad_buffer in self.get_grad_buffers().items():
current_param_buffers[dtype] = []
if self._megatron_version == MegatronVersion.V2:
for bucket in grad_buffer.buckets:
try:
storage = bucket.data.storage()._untyped()
# pylint: disable-next=bare-except
except:
storage = bucket.data.storage().untyped()

param_buffer = torch.tensor([], dtype=params_dtype, device=bucket.data.device).set_(storage)
param_buffer = param_buffer[bucket.offset : bucket.offset + bucket.data.numel()]
current_param_buffers[dtype].append(param_buffer)
elif self._megatron_version == MegatronVersion.V1:
try:
storage = bucket.data.storage()._untyped()
storage = grad_buffer.data.storage()._untyped()
# pylint: disable-next=bare-except
except:
storage = bucket.data.storage().untyped()

param_buffer = torch.tensor([], dtype=params_dtype, device=bucket.data.device).set_(storage)
param_buffer = param_buffer[bucket.offset : bucket.offset + bucket.data.numel()]
current_param_buffers[dtype].append(param_buffer)
elif self._megatron_version == MegatronVersion.V1:
try:
storage = grad_buffer.data.storage()._untyped()
# pylint: disable-next=bare-except
except:
storage = grad_buffer.data.storage().untyped()
param_buffer = torch.tensor([], dtype=params_dtype, device=grad_buffer.data.device).set_(storage)
param_buffer = param_buffer[: grad_buffer.numel_padded]
current_param_buffers[dtype] = param_buffer
optimizer.param_buffers.append(current_param_buffers)
storage = grad_buffer.data.storage().untyped()
param_buffer = torch.tensor([], dtype=params_dtype, device=grad_buffer.data.device).set_(storage)
param_buffer = param_buffer[: grad_buffer.numel_padded]
current_param_buffers[dtype] = param_buffer
optimizer.param_buffers.append(current_param_buffers)

self._grad_buffers_freed = False
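One more note for reviewers: the param_buffers re-allocation at the end of build_grad_buffers (now done per optimizer) relies on a storage-aliasing trick that is easy to miss. The standalone snippet below illustrates only that trick, with made-up sizes and dtypes; it is not ChatLearn code, and just the try/except around the untyped-storage accessor and the set_/slice pattern mirror the diff.

import torch

# Stand-in for one fp32 grad-buffer bucket.
grad_data = torch.zeros(1024, dtype=torch.float32)

# Fetch the raw untyped storage; older PyTorch exposes it via the private
# _untyped() accessor, newer versions via untyped().
try:
    storage = grad_data.storage()._untyped()
# pylint: disable-next=bare-except
except:
    storage = grad_data.storage().untyped()

# Reinterpret the same bytes under the parameter dtype and slice the bucket's
# region; this is how optimizer.param_buffers is rebuilt so that parameters
# and gradients share one allocation.
params_dtype = torch.bfloat16
param_buffer = torch.tensor([], dtype=params_dtype, device=grad_data.device).set_(storage)
param_buffer = param_buffer[: grad_data.numel()]

# The two views alias the same memory: writing through the parameter view is
# visible through the gradient view, which is why free_grad_buffers also clears
# param_buffers before releasing the grad buffers.
param_buffer.fill_(1.0)
assert grad_data.abs().sum().item() > 0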