diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 13c5e30d..9c3759f2 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -1459,3 +1459,30 @@ def validate_contains_paths(x: Nested[Tensor], paths: Sequence[str]):
                 f"Input is expected to contain '{path}'; "
                 f"instead, it contains: '{jax.tree_structure(x)}'."
             ) from e
+
+
+def save_only_these_regex_patterns(*regex_patterns_to_save):
+    """Saves only the values whose names match one of the given regex patterns.
+
+    Args:
+        regex_patterns_to_save: Regex patterns of value names to save.
+
+    Returns:
+        A remat policy that saves a value iff its name matches any of the patterns.
+    """
+    regex_patterns_to_save = frozenset(regex_patterns_to_save)
+
+    def policy(*_, **params):
+        if "name" in params:
+            param_name = params["name"]
+            for regex_to_save in regex_patterns_to_save:
+                if re.search(regex_to_save, param_name):
+                    # The name matches one of the specified regex patterns: save it.
+                    logging.info("Remat: Saving %s", param_name)
+                    return True
+            # Named, but not matched by any of the specified patterns.
+            logging.info("Remat: Not saving %s", param_name)
+        # Unnamed tensors are not saved.
+        return False
+
+    return policy
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 973fb923..26d5f7ed 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -190,20 +190,12 @@ def update_model_remat_config(
 ):
     """Recomputes and sets the remat_spec based on provided layer_cfg.
 
-    Only applied if the stack_cfg is a RepeatedTransformerLayer.
-
     Args:
         stack_cfg: The transformer stack config.
         layer_cfg: The transformer layer config.
         offload_dst: Destination of remat checkpointing offloading.
 
-    Raises:
-        NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
     """
-    if stack_cfg.klass is not RepeatedTransformerLayer:
-        raise NotImplementedError(
-            f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
-        )
     remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
     layer_cfg.set(remat_spec=remat_spec)
 
@@ -277,8 +269,7 @@
         layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
     layer_cfg.self_attention.structure = atten_structure
     layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
-    if stack_cfg.klass is RepeatedTransformerLayer:
-        update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
+    update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
     # Stack.
     transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
     decoder_cfg = Decoder.default_config().set(
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 69f6b110..2d8df6be 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -39,7 +39,7 @@
     MeshShapeModifier,
     RematSpecModifier,
 )
-from axlearn.common.utils import extended_checkpoint_policies
+from axlearn.common.utils import extended_checkpoint_policies, save_only_these_regex_patterns
 from axlearn.experiments.text.gpt.common import (
     STEP_DTYPE,
     SourceBuilder,
@@ -417,6 +417,36 @@
                     "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=128),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=True,
+                                        policy=config_for_function(
+                                            save_only_these_regex_patterns
+                                        ).set(
+                                            # pylint: disable=anomalous-backslash-in-string
+                                            regex_patterns_to_save=[
+                                                r"TransformerAttentionLayer\.residual_add",
+                                                r"\.?(k|q|v)_proj$",
+                                                r"\.?linear1_0",
+                                                r"\.?linear1_1",
+                                                r"TransformerFeedForwardLayer\.mlp_residual",
+                                            ]
+                                            # pylint: enable=anomalous-backslash-in-string
+                                        ),
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     else:
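
A minimal usage sketch of the new save_only_these_regex_patterns policy (not part of the diff): it assumes values are tagged with jax.ad_checkpoint.checkpoint_name so the policy receives a "name" param during remat; the layer and tensor names below are illustrative only, not taken from the change.

    # Hypothetical example; "q_proj" and "linear2" are illustrative checkpoint names.
    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name
    from axlearn.common.utils import save_only_these_regex_patterns

    # Save only tensors whose checkpoint name ends in k_proj/q_proj/v_proj.
    policy = save_only_these_regex_patterns(r"\.?(k|q|v)_proj$")

    def layer(x, w):
        q = checkpoint_name(x @ w, name="q_proj")  # Matches a pattern: saved across remat.
        h = checkpoint_name(jnp.tanh(q), name="linear2")  # Named but unmatched: rematerialized.
        return h.sum()

    fn = jax.checkpoint(layer, policy=policy)
    grads = jax.grad(fn)(jnp.ones((4, 4)), jnp.ones((4, 4)))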