enable special remat for neuron

apoorvtintin committed Dec 17, 2024
1 parent 3ae8f9f commit a08fb3c
Showing 3 changed files with 59 additions and 11 deletions.
27 changes: 27 additions & 0 deletions axlearn/common/utils.py
@@ -1459,3 +1459,30 @@ def validate_contains_paths(x: Nested[Tensor], paths: Sequence[str]):
                f"Input is expected to contain '{path}'; "
                f"instead, it contains: '{jax.tree_structure(x)}'."
            ) from e


def save_only_these_regex_patterns(*regex_patterns_to_save):
    """Saves only the tensors whose names match any of the given regex patterns.

    Args:
        regex_patterns_to_save: Regex patterns; a named tensor is saved if its name
            matches any of them.

    Returns:
        A remat checkpoint policy that saves a tensor iff it is named and its name
        matches one of the patterns.
    """
    regex_patterns_to_save = frozenset(regex_patterns_to_save)

    def policy(*_, **params):
        if "name" in params:
            param_name = params["name"]
            for regex_to_save in regex_patterns_to_save:
                if re.search(regex_to_save, param_name):
                    # The name matches one of the specified regex patterns: save it.
                    logging.info("Remat: Saving %s", param_name)
                    return True
            # Named, but matched by none of the patterns: recompute it.
            logging.info("Remat: Not saving %s", param_name)
        # Unnamed tensors are never saved.
        return False

    return policy
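
The helper above returns a policy compatible with JAX's rematerialization checkpointing. Below is a minimal usage sketch that is not part of this commit: the function f, its shapes, and the "q_proj" tag are illustrative assumptions; only save_only_these_regex_patterns comes from the diff above.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

from axlearn.common.utils import save_only_these_regex_patterns

def f(x, w):
    # Tag an intermediate so the policy can match it by name.
    q = checkpoint_name(x @ w, name="q_proj")
    return jnp.sum(jnp.tanh(q))

# Intermediates whose names match a pattern are saved during the forward pass;
# everything else is rematerialized in the backward pass.
f_remat = jax.checkpoint(f, policy=save_only_these_regex_patterns(r"(k|q|v)_proj$"))
grads = jax.grad(f_remat)(jnp.ones((4, 8)), jnp.ones((8, 8)))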
11 changes: 1 addition & 10 deletions axlearn/experiments/text/gpt/common.py
@@ -190,20 +190,12 @@ def update_model_remat_config(
):
    """Recomputes and sets the remat_spec based on provided layer_cfg.

    Only applied if the stack_cfg is a RepeatedTransformerLayer.

    Args:
        stack_cfg: The transformer stack config.
        layer_cfg: The transformer layer config.
        offload_dst: Destination of remat checkpointing offloading.

    Raises:
        NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
    """
    if stack_cfg.klass is not RepeatedTransformerLayer:
        raise NotImplementedError(
            f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
        )

    remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
    layer_cfg.set(remat_spec=remat_spec)
@@ -277,8 +269,7 @@ def model_config(
    layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
    layer_cfg.self_attention.structure = atten_structure
    layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
    if stack_cfg.klass is RepeatedTransformerLayer:
        update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
    update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
    # Stack.
    transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
    decoder_cfg = Decoder.default_config().set(
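
For reference, a hedged sketch of how update_model_remat_config is invoked; it mirrors the call in model_config above, but the default configs used here are illustrative assumptions and not part of this commit.

from axlearn.common.attention import RepeatedTransformerLayer, TransformerLayer
from axlearn.experiments.text.gpt.common import update_model_remat_config

stack_cfg = RepeatedTransformerLayer.default_config()
layer_cfg = TransformerLayer.default_config()

# Populates layer_cfg.remat_spec via build_remat_spec(stack_cfg.clone(layer=layer_cfg)).
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
print(layer_cfg.remat_spec)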
32 changes: 31 additions & 1 deletion axlearn/experiments/text/gpt/fuji.py
@@ -39,7 +39,7 @@
    MeshShapeModifier,
    RematSpecModifier,
)
from axlearn.common.utils import extended_checkpoint_policies
from axlearn.common.utils import extended_checkpoint_policies, save_only_these_regex_patterns
from axlearn.experiments.text.gpt.common import (
    STEP_DTYPE,
    SourceBuilder,
@@ -417,6 +417,36 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
mesh_shape_from_axes(data=-1, fsdp=128),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=config_for_function(
save_only_these_regex_patterns
).set(
# pylint: disable=anomalous-backslash-in-string
regex_patterns_to_save=[
r"TransformerAttentionLayer\.residual_add",
r"\.?(k|q|v)_proj$",
r"\.?linear1_0",
r"\.?linear1_1",
r"TransformerFeedForwardLayer\.mlp_residual",
]
# pylint: enable=anomalous-backslash-in-string
),
),
}
),
],
),
),
),
)
else:
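
The regex patterns above select which named intermediates the Neuron remat policy keeps. The following small illustration is not part of the commit and uses hypothetical tensor names; it only shows how the patterns classify candidate names.

import re

patterns = [
    r"TransformerAttentionLayer\.residual_add",
    r"\.?(k|q|v)_proj$",
    r"\.?linear1_0",
    r"\.?linear1_1",
    r"TransformerFeedForwardLayer\.mlp_residual",
]

# Hypothetical intermediate names, for illustration only.
candidates = [
    "TransformerAttentionLayer.residual_add",  # attention residual add: saved
    "i_proj.q_proj",                           # query projection output: saved
    "linear1_0",                               # first feed-forward projection: saved
    "context",                                 # attention context: rematerialized
]

for name in candidates:
    saved = any(re.search(p, name) for p in patterns)
    print(f"{name}: {'saved' if saved else 'rematerialized'}")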
