diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index 37baf3d8..448d728f 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -55,6 +55,7 @@
 # pylint: disable=abstract-method,too-many-lines
 import enum
 import functools
+import logging
 import math
 import re
 from collections.abc import Sequence
@@ -3934,6 +3935,33 @@ def forward(
 _SavePattern = Union[str, re.Pattern, None]
 
 
+def save_only_these_regex_patterns(*regex_patterns_to_save):
+    """Saves only the values whose names match any of the given regex patterns.
+
+    Args:
+        regex_patterns_to_save: Regex patterns to match against checkpoint names.
+
+    Returns:
+        A remat policy callable for use with `jax.checkpoint`.
+    """
+    regex_patterns_to_save = frozenset(regex_patterns_to_save)
+
+    def policy(*_, **params):
+        if "name" in params:
+            param_name = params["name"]
+            for regex_to_save in regex_patterns_to_save:
+                if re.search(regex_to_save, param_name):
+                    # The value is named and matches one of the patterns: save it.
+                    logging.info("Remat: Saving %s", param_name)
+                    return True
+            # Named, but matched by none of the patterns: recompute it.
+            logging.info("Remat: Not saving %s", param_name)
+        # Unnamed values are always recomputed.
+        return False
+
+    return policy
+
+
 # Adapted from jax source code to support regex. Reference:
 # https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120
 def _save_and_offload_only_these_names_regex(
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 973fb923..26d5f7ed 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -190,20 +190,12 @@ def update_model_remat_config(
 ):
     """Recomputes and sets the remat_spec based on provided layer_cfg.
 
-    Only applied if the stack_cfg is a RepeatedTransformerLayer.
-
     Args:
         stack_cfg: The transformer stack config.
         layer_cfg: The transformer layer config.
         offload_dst: Destination of remat checkpointing offloading.
-
-    Raises:
-        NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
     """
-    if stack_cfg.klass is not RepeatedTransformerLayer:
-        raise NotImplementedError(
-            f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
-        )
     remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
     layer_cfg.set(remat_spec=remat_spec)
 
@@ -277,8 +269,7 @@ def model_config(
         layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
     layer_cfg.self_attention.structure = atten_structure
     layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
-    if stack_cfg.klass is RepeatedTransformerLayer:
-        update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
+    update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
 
     # Stack.
     transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
     decoder_cfg = Decoder.default_config().set(
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 69f6b110..5708c173 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -26,6 +26,7 @@
     MultiheadAttention,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
+    save_only_these_regex_patterns,
 )
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import config_for_function
@@ -417,6 +418,36 @@ def get_trainer_kwargs(
                     "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=128),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=True,
+                                        policy=config_for_function(
+                                            save_only_these_regex_patterns
+                                        ).set(
+                                            # pylint: disable=anomalous-backslash-in-string
+                                            regex_patterns_to_save=[
+                                                r"TransformerAttentionLayer\.residual_add",
+                                                r"\.?(k|q|v)_proj$",
+                                                r"\.?linear1_0",
+                                                r"\.?linear1_1",
+                                                r"TransformerFeedForwardLayer\.mlp_residual",
+                                            ]
+                                            # pylint: enable=anomalous-backslash-in-string
+                                        ),
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     else:
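
Usage sketch (illustrative, not part of the patch): the new policy is an ordinary `jax.checkpoint` policy, so it can be exercised directly, outside of `RematSpec`. The toy `mlp` function and the "linear1_0" tag below are hypothetical, chosen only to mirror one of the fuji patterns above; in axlearn the names come from `checkpoint_name` tags emitted inside the layers. Because the policy only consults params["name"], untagged intermediates are always rematerialized, which is the conservative default for memory savings.

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    # Importable once this patch is applied.
    from axlearn.common.attention import save_only_these_regex_patterns

    def mlp(x, w1, w2):
        # Tag intermediates; the policy sees each tag as params["name"].
        h = checkpoint_name(jax.nn.relu(x @ w1), "linear1_0")  # matched -> saved
        out = checkpoint_name(h @ w2, "linear2")  # unmatched -> recomputed
        return jnp.sum(out)

    policy = save_only_these_regex_patterns(r"\.?linear1_0")
    grad_fn = jax.grad(jax.checkpoint(mlp, policy=policy))
    grads = grad_fn(jnp.ones((2, 4)), jnp.ones((4, 8)), jnp.ones((8, 1)))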