add checkpoint #945

Open · wants to merge 1 commit into master
13 changes: 12 additions & 1 deletion training/DeepSpeed-Domino/domino/arguments.py
@@ -204,9 +204,20 @@ def parse_args():
'validation set.')
parser.add_argument('--log-interval', type=int, default=100,
help='Report loss and timing interval.')
parser.add_argument('--save', type=str, default=None,
help='Output directory to save checkpoints to.')
parser.add_argument('--no-save-optim', action='store_true', default=None,
help='Do not save current optimizer.')
parser.add_argument('--no-save-rng', action='store_true', default=None,
help='Do not save current rng state.')
parser.add_argument('--save-interval', type=int, default=None,
help='Number of iterations between checkpoint saves.')

parser.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.')
parser.add_argument('--no-load-optim', action='store_true', default=None,
help='Do not load optimizer when loading checkpoint.')
parser.add_argument('--no-load-rng', action='store_true', default=None,
help='Do not load rng state when loading checkpoint.')
args = parser.parse_args()

args.rank = int(os.getenv('RANK', '0'))
3 changes: 3 additions & 0 deletions training/DeepSpeed-Domino/domino/initialize.py
@@ -14,6 +14,7 @@
from domino.modules.fused_bias_gelu import bias_gelu

from megatron import fused_kernels
import deepspeed


def initialize_domino():
@@ -37,6 +38,8 @@ def initialize_domino():
world_size=args.world_size,
rank=args.rank
)
deepspeed.init_distributed()

mpu.initialize_model_parallel(args.tensor_model_parallel_size)
seed_ = args.seed
data_parallel_random_init = False
153 changes: 152 additions & 1 deletion training/DeepSpeed-Domino/domino/language_model.py
@@ -84,6 +84,70 @@ def forward(self, input_ids, position_ids):

return combined_embeds

def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""For easy load."""

state_dict_ = {}
state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
if self.add_position_embedding:
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)

return state_dict_

def load_state_dict(self, state_dict, strict=True):
"""Customized load."""

# Word embedding.
if self._word_embeddings_key in state_dict:
state_dict_ = state_dict[self._word_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'word_embeddings' in key:
state_dict_[key.split('word_embeddings.')[1]] \
= state_dict[key]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)

# Position embedding.
if self.add_position_embedding:
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)

# Tokentype embedding.
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
else:
# for backward compatibility.
for key in state_dict.keys():
if 'tokentype_embeddings' in key:
state_dict_[key.split('tokentype_embeddings.')[1]] \
= state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_,
strict=strict)
else:
print('***WARNING*** expected tokentype embeddings in the '
'checkpoint but could not find it', flush=True)


class RotaryEmbedding(nn.Module):
def __init__(self, dim, seq_len_interpolation_factor=None):
@@ -190,4 +254,91 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
encoder_output = encoder_output_t

return encoder_output


def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""For easy load."""

state_dict_ = {}
if self.pre_process:
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.add_encoder:
state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.post_process:
if self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.untie_embeddings_and_output_weights:
state_dict_[self._output_layer_key] \
= self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars)

if self.add_decoder:
state_dict_[self._decoder_key] \
= self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)

return state_dict_

def load_state_dict(self, state_dict, strict=True):
"""Customized load."""

# Embedding.
if self.pre_process:
if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if '_embeddings' in key:
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)

# Encoder.
if self.add_encoder:
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# For backward compatibility.
elif 'transformer' in state_dict:
state_dict_ = state_dict['transformer']
else:
# For backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]

# For backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention

self.encoder.load_state_dict(state_dict_, strict=strict)

# Pooler.
if self.post_process:
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
if self.untie_embeddings_and_output_weights:
assert 'output_layer' in state_dict, \
'could not find data for output_layer in the checkpoint'
self.output_layer.load_state_dict(state_dict[self._output_layer_key],
strict=strict)
# Decoder.
if self.add_decoder:
assert 'decoder' in state_dict, \
'could not find data for decoder in the checkpoint'
self.decoder.load_state_dict(state_dict[self._decoder_key],
strict=strict)
16 changes: 15 additions & 1 deletion training/DeepSpeed-Domino/domino/modules/module.py
@@ -25,6 +25,13 @@ def __init__(self, config=None, share_embeddings_and_output_weights=True):
self.config = config
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights

def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints.
"""

return self.state_dict(prefix=prefix, keep_vars=keep_vars)

def initialize_word_embeddings(self):
self.share_embeddings_and_output_weights = True
return
@@ -74,7 +81,8 @@ def float_conversion(val):
return conversion_helper(val, float_conversion)


class Float16Module(torch.nn.Module):
# class Float16Module(torch.nn.Module):
class Float16Module(DominoModule):

def __init__(self, module, args):
super(Float16Module, self).__init__()
@@ -91,3 +99,9 @@ def forward(self, *inputs, **kwargs):
outputs = float16_to_fp32(outputs)
return outputs

def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
""" Retrieve state_dict from the module being wrapped."""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)

def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
11 changes: 11 additions & 0 deletions training/DeepSpeed-Domino/domino/training.py
@@ -16,6 +16,8 @@
from domino.initialize import set_jit_fusion_options
from domino.tensor_parallel.data import broadcast_data

from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint

def is_rank_0():
# if torch.cuda.current_device() == 0:
@@ -109,6 +111,11 @@ def setup_model_and_optimizer(base_model,
optimizer = get_megatron_optimizer(models, no_wd_decay_cond, scale_lr_cond)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)

if args.load is not None:
args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
else:
args.iteration = 0

args.iteration = 0

Review comment: Delete this line. Otherwise, args.iteration is wrong if the checkpoint is successfully loaded.

Review comment: Also, do we need to change args.train_iters? args.iteration may be larger than args.train_iters if there is a checkpoint.
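
A minimal sketch of how this block in setup_model_and_optimizer could look once the stray reset is removed and the overflow case is reported. It assumes the names already used in this file (args, load_checkpoint, is_rank_0, args.train_iters); it illustrates the reviewer's suggestion and is not part of the PR.

```python
if args.load is not None:
    # load_checkpoint returns the iteration recorded in the checkpoint.
    args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
else:
    args.iteration = 0
# No unconditional `args.iteration = 0` after this point, so an iteration
# count restored from the checkpoint is preserved.

# Reviewer's second concern: the restored iteration may already exceed the
# requested training length, in which case the train loop runs no steps.
if args.iteration >= args.train_iters and is_rank_0():
    print(f'checkpoint iteration ({args.iteration}) >= train_iters '
          f'({args.train_iters}); no further optimization steps will run')
```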


return model, optimizer, opt_param_scheduler
@@ -297,6 +304,10 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
config)

iteration += 1
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * get_num_microbatches()

29 changes: 19 additions & 10 deletions training/DeepSpeed-Domino/megatron/checkpointing.py
@@ -9,12 +9,16 @@

import torch

from megatron import update_num_microbatches
from megatron.core import mpu, tensor_parallel
from .global_vars import get_args
from .utils import (unwrap_model,
print_rank_0)
# from megatron.core import mpu, tensor_parallel
# from .global_vars import get_args
# from .utils import (unwrap_model,
# print_rank_0)

# from megatron import update_num_microbatches
import domino.parallel_state as mpu
from domino.tensor_parallel.random import get_cuda_rng_tracker
from domino.arguments import get_args
from domino.utils import unwrap_model, print_rank_0

_CHECKPOINT_VERSION = None

@@ -194,7 +198,7 @@ def get_rng_state():
'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state(),
'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}
'rng_tracker_states': get_cuda_rng_tracker().get_states()}

rng_state_list = None
if torch.distributed.is_initialized() and \
@@ -218,6 +222,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):

# Only rank zero of the data parallel writes to the disk.
model = unwrap_model(model)
model_module = model.module
model = [model_module]

print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
@@ -241,7 +247,7 @@

# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
# state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration
if len(model) == 1:
@@ -503,6 +509,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
load_dir = getattr(args, load_arg)

model = unwrap_model(model)
model_module = model.module
model = [model_module]

state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=False)

Review comment: There is no exit_on_missing_checkpoint argument now, so the following error is produced when --load is given but no checkpoint exists:

[rank1]:     if not args.exit_on_missing_checkpoint:
[rank1]: AttributeError: 'Namespace' object has no attribute 'exit_on_missing_checkpoint'
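
Two possible ways to resolve the missing attribute, sketched under the assumption that the failing check lives in _load_base_checkpoint (which this diff does not show). Neither snippet is part of the PR; the flag name and the fallback path are illustrative.

```python
# Option A: register the flag next to the other checkpoint arguments this PR
# adds to domino/arguments.py (hypothetical addition).
parser.add_argument('--exit-on-missing-checkpoint', action='store_true',
                    default=False,
                    help='Exit instead of initializing randomly when --load '
                         'points at a directory with no checkpoint.')

# Option B: read the attribute defensively at the check site (assumed to be
# inside _load_base_checkpoint).
if not getattr(args, 'exit_on_missing_checkpoint', False):
    ...  # fall back to the existing random-initialization path
```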

@@ -522,6 +530,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
set_checkpoint_version(state_dict.get('checkpoint_version', 0))

# Set iteration.
args.finetune = False
if args.finetune or release:
iteration = 0
else:
@@ -544,7 +553,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
check_checkpoint_args(checkpoint_args)
args.consumed_train_samples = getattr(checkpoint_args,
'consumed_train_samples', 0)
update_num_microbatches(consumed_samples=args.consumed_train_samples)
# update_num_microbatches(consumed_samples=args.consumed_train_samples)
args.consumed_valid_samples = getattr(checkpoint_args,
'consumed_valid_samples', 0)
else:
@@ -614,7 +623,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check for empty states array
if not rng_state['rng_tracker_states']:
raise KeyError
tensor_parallel.get_cuda_rng_tracker().set_states(
get_cuda_rng_tracker().set_states(
rng_state['rng_tracker_states'])
else: # backward compatibility
random.setstate(state_dict['random_rng_state'])
@@ -624,7 +633,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Check for empty states array
if not state_dict['rng_tracker_states']:
raise KeyError
tensor_parallel.get_cuda_rng_tracker().set_states(
get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}. '
2 changes: 1 addition & 1 deletion training/DeepSpeed-Domino/megatron/initialize.py
@@ -16,7 +16,7 @@
from megatron import get_tensorboard_writer
from megatron.core import mpu, tensor_parallel
from megatron.arguments import parse_args, validate_args
from megatron.checkpointing import load_args_from_checkpoint
# from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu