enable special remat for neuron
apoorvtintin committed Dec 18, 2024
1 parent 3ae8f9f commit b8c8fa5
Showing 2 changed files with 32 additions and 10 deletions.
axlearn/experiments/text/gpt/common.py (1 addition, 10 deletions)
@@ -190,20 +190,12 @@ def update_model_remat_config(
):
"""Recomputes and sets the remat_spec based on provided layer_cfg.
Only applied if the stack_cfg is a RepeatedTransformerLayer.
Args:
stack_cfg: The transformer stack config.
layer_cfg: The transformer layer config.
offload_dst: Destination of remat checkpointing offloading.
Raises:
NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
"""
if stack_cfg.klass is not RepeatedTransformerLayer:
raise NotImplementedError(
f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
)

remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
layer_cfg.set(remat_spec=remat_spec)
@@ -277,8 +269,7 @@ def model_config(
layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
layer_cfg.self_attention.structure = atten_structure
layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
if stack_cfg.klass is RepeatedTransformerLayer:
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
# Stack.
transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
decoder_cfg = Decoder.default_config().set(
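For context, the core of `update_model_remat_config` is simply: derive a remat spec from the stack config (with the layer attached) and install it on the layer config, which `model_config` now calls directly. A minimal sketch of that wiring, assuming `RepeatedTransformerLayer`, `TransformerLayer`, and `build_remat_spec` are importable from `axlearn.common.attention` as in these files; the configs below are stand-ins, not the real fuji model configs.

```python
from axlearn.common.attention import (
    RepeatedTransformerLayer,
    TransformerLayer,
    build_remat_spec,
)

# Stand-in configs; the real call site uses the configs assembled in model_config().
stack_cfg = RepeatedTransformerLayer.default_config()
layer_cfg = TransformerLayer.default_config()

# Core of update_model_remat_config: build a remat spec from the stack config
# (with the layer attached) and set it on the layer config.
remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
layer_cfg.set(remat_spec=remat_spec)
```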
axlearn/experiments/text/gpt/fuji.py (31 additions, 0 deletions)
@@ -26,6 +26,7 @@
MultiheadAttention,
RepeatedTransformerLayer,
RoFormerQKVLinear,
_save_and_offload_only_these_names_regex,
)
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import config_for_function
@@ -417,6 +418,36 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
mesh_shape_from_axes(data=-1, fsdp=128),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=config_for_function(
_save_and_offload_only_these_names_regex
).set(
# pylint: disable=anomalous-backslash-in-string
names_which_can_be_saved=r"(TransformerAttentionLayer\.residual_add" # pylint: disable=C0301
"|.*\.?(k|q|v)_proj"
"|.*\.?linear1_[01]"
"|TransformerFeedForwardLayer\.mlp_residual)",
names_which_can_be_offloaded=None,
offload_src=None,
offload_dst=None,
# pylint: enable=anomalous-backslash-in-string
),
),
}
),
],
),
),
),
)
else:
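The new trn2 mesh rule pairs a mesh reshape (`fsdp=-1` fills the remaining devices after `model=4`) with a per-layer `RematSpec` whose policy saves only activations whose checkpoint names match the regex (the attention residual add, the k/q/v projections, the `linear1_0`/`linear1_1` outputs, and the MLP residual), recomputing everything else on the backward pass. Under the hood this is JAX's name-based rematerialization; below is a minimal upstream-JAX sketch of the same idea, using exact-name matching via `jax.checkpoint_policies.save_only_these_names` rather than axlearn's regex variant (`_save_and_offload_only_these_names_regex`, which additionally supports offloading via `offload_src`/`offload_dst`). The layer and names here are illustrative, not the fuji model.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def layer(x, w_qkv, w_out):
    # Tag intermediates so a name-based remat policy can pick what to save.
    qkv = checkpoint_name(x @ w_qkv, "qkv_proj")
    hidden = checkpoint_name(jax.nn.relu(qkv), "linear1_0")
    return hidden @ w_out


# Save only the tagged projection; everything else is recomputed in backward.
policy = jax.checkpoint_policies.save_only_these_names("qkv_proj")
layer_remat = jax.checkpoint(layer, prevent_cse=True, policy=policy)

x = jnp.ones((8, 16))
w_qkv = jnp.ones((16, 16))
w_out = jnp.ones((16, 4))
grads = jax.grad(lambda x_: layer_remat(x_, w_qkv, w_out).sum())(x)
```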
