Skip to content

Commit

Permalink
enable special remat for neuron
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvtintin committed Dec 17, 2024
1 parent 3ae8f9f commit 9715aef
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 10 deletions.
28 changes: 28 additions & 0 deletions axlearn/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
# pylint: disable=abstract-method,too-many-lines
import enum
import functools
import logging
import math
import re
from collections.abc import Sequence
Expand Down Expand Up @@ -3934,6 +3935,33 @@ def forward(
_SavePattern = Union[str, re.Pattern, None]


def save_only_these_regex_patterns(*regex_patterns_to_save):
    """Builds a remat policy that saves only values whose name matches a regex.

    Args:
        *regex_patterns_to_save: Regex patterns (strings or pre-compiled patterns) to
            match against the `name` keyword param the policy is invoked with. For
            convenience, a list/tuple/set/frozenset of patterns passed as a single
            argument is flattened, so `f("a", "b")` and `f(["a", "b"])` are equivalent.

    Returns:
        A callable `policy(*args, **params) -> bool`. It returns True iff `params`
        contains a "name" entry and `re.search` finds at least one of the patterns in
        it. Calls without a "name" param return False.

    Raises:
        re.error: If a pattern is not a valid regular expression.
        TypeError: If a pattern is neither a string nor a compiled pattern.
    """
    flattened_patterns = []
    for pattern_or_group in regex_patterns_to_save:
        if isinstance(pattern_or_group, (list, tuple, set, frozenset)):
            flattened_patterns.extend(pattern_or_group)
        else:
            flattened_patterns.append(pattern_or_group)

    # Drop duplicates (keeping first-seen order) and compile once up front, so the
    # policy does no per-call pattern lookup and invalid patterns fail fast here
    # rather than on the first policy invocation.
    compiled_patterns = tuple(
        re.compile(pattern) for pattern in dict.fromkeys(flattened_patterns)
    )

    def policy(*_, **params):
        # NOTE(review): positional args are ignored, so anything invoked with a `name`
        # param is considered, whatever the caller passes positionally -- confirm
        # this is intended.
        if "name" not in params:
            # Unnamed values are never saved.
            return False

        param_name = params["name"]
        if any(pattern.search(param_name) for pattern in compiled_patterns):
            # Named and matched by at least one of the requested patterns.
            logging.info("Remat: Saving %s", param_name)
            return True

        # Named, but not requested by any pattern.
        logging.info("Remat: Not saving %s", param_name)
        return False

    return policy


# Adapted from jax source code to support regex. Reference:
# https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120
def _save_and_offload_only_these_names_regex(
Expand Down
11 changes: 1 addition & 10 deletions axlearn/experiments/text/gpt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,20 +190,12 @@ def update_model_remat_config(
):
"""Recomputes and sets the remat_spec based on provided layer_cfg.
Only applied if the stack_cfg is a RepeatedTransformerLayer.
Args:
stack_cfg: The transformer stack config.
layer_cfg: The transformer layer config.
offload_dst: Destination of remat checkpointing offloading.
Raises:
NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
"""
if stack_cfg.klass is not RepeatedTransformerLayer:
raise NotImplementedError(
f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
)

remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
layer_cfg.set(remat_spec=remat_spec)
Expand Down Expand Up @@ -277,8 +269,7 @@ def model_config(
layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
layer_cfg.self_attention.structure = atten_structure
layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
if stack_cfg.klass is RepeatedTransformerLayer:
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
# Stack.
transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
decoder_cfg = Decoder.default_config().set(
Expand Down
31 changes: 31 additions & 0 deletions axlearn/experiments/text/gpt/fuji.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
MultiheadAttention,
RepeatedTransformerLayer,
RoFormerQKVLinear,
save_only_these_regex_patterns,
)
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import config_for_function
Expand Down Expand Up @@ -417,6 +418,36 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
mesh_shape_from_axes(data=-1, fsdp=128),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=config_for_function(
save_only_these_regex_patterns
).set(
# pylint: disable=anomalous-backslash-in-string
regex_patterns_to_save=[
r"TransformerAttentionLayer\.residual_add",
r"\.?(k|q|v)_proj$",
r"\.?linear1_0",
r"\.?linear1_1",
r"TransformerFeedForwardLayer\.mlp_residual",
]
# pylint: enable=anomalous-backslash-in-string
),
),
}
),
],
),
),
),
)
else:
Expand Down

0 comments on commit 9715aef

Please sign in to comment.