diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 973fb923..26d5f7ed 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -190,20 +190,12 @@ def update_model_remat_config(
 ):
     """Recomputes and sets the remat_spec based on provided layer_cfg.
 
-    Only applied if the stack_cfg is a RepeatedTransformerLayer.
-
     Args:
         stack_cfg: The transformer stack config.
         layer_cfg: The transformer layer config.
         offload_dst: Destination of remat checkpointing offloading.
 
-    Raises:
-        NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
     """
-    if stack_cfg.klass is not RepeatedTransformerLayer:
-        raise NotImplementedError(
-            f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
-        )
     remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
     layer_cfg.set(remat_spec=remat_spec)
 
@@ -277,8 +269,7 @@
         layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
     layer_cfg.self_attention.structure = atten_structure
     layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
-    if stack_cfg.klass is RepeatedTransformerLayer:
-        update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
+    update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
     # Stack.
     transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
     decoder_cfg = Decoder.default_config().set(
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 69f6b110..63670738 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -26,6 +26,7 @@
     MultiheadAttention,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
+    _save_and_offload_only_these_names_regex,
 )
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import config_for_function
@@ -86,6 +87,14 @@ class Version(enum.Enum):
 }
 
 
+# Regex patterns for matching remat names.
+class RematRegex(enum.Enum):
+    QKV_PROJ = r".*\.?(k|q|v)_proj"
+    LINEAR1_X = r".*\.?linear1_[01]"
+    RESIDUAL_ADD = r"TransformerAttentionLayer\.residual_add"
+    MLP_RESIDUAL = r"TransformerFeedForwardLayer\.mlp_residual"
+
+
 # Mapping from Fuji versions to total number of tokens used in training.
 TOTAL_TOKENS = {
     Version.V1: {
@@ -417,6 +426,38 @@
                     "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=128),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=True,
+                                        policy=config_for_function(
+                                            _save_and_offload_only_these_names_regex
+                                        ).set(
+                                            names_which_can_be_saved="|".join(
+                                                [
+                                                    RematRegex.QKV_PROJ.value,
+                                                    RematRegex.LINEAR1_X.value,
+                                                    RematRegex.RESIDUAL_ADD.value,
+                                                    RematRegex.MLP_RESIDUAL.value,
+                                                ]
+                                            ),
+                                            names_which_can_be_offloaded=None,
+                                            offload_src=None,
+                                            offload_dst=None,
+                                        ),
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     else:
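
For reference, a minimal standalone sketch of what the joined RematRegex pattern above selects. It uses only Python's re module; the sample remat names and the use of re.fullmatch are illustrative assumptions, not the exact matching logic of _save_and_offload_only_these_names_regex.

    import re

    # Same alternatives as RematRegex above, joined exactly as in the trn2 mesh rule.
    save_pattern = "|".join(
        [
            r".*\.?(k|q|v)_proj",  # QKV projections
            r".*\.?linear1_[01]",  # first feed-forward linear outputs
            r"TransformerAttentionLayer\.residual_add",  # attention residual add
            r"TransformerFeedForwardLayer\.mlp_residual",  # feed-forward residual
        ]
    )

    # Hypothetical remat names: the first four match (saveable), the last does not.
    for name in [
        "q_proj",
        "self_attention.k_proj",
        "linear1_0",
        "TransformerAttentionLayer.residual_add",
        "o_proj",
    ]:
        print(f"{name}: {bool(re.fullmatch(save_pattern, name))}")

Names that match are eligible to be kept in HBM across the forward pass, while non-matching names (for example o_proj above) are left to the policy's default behavior, which the policy name suggests is rematerialization in the backward pass.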