Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Special remat for Neuron #898

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions axlearn/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -3956,15 +3956,22 @@ def policy(prim, *_, **params):
return policy


SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"
FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*"
# Regex patterns for matching remat names
class RematRegexSavePatterns(enum.Enum):
QKV_PROJ = r".*[kqv]_proj"
O_PROJ = r".*o_proj"
CONTEXT = r".*context"
LINEAR1_X = r".*linear1_[01]"
LINEAR2_X = r".*linear2_[01]"
SELF_ATTENTION = re.compile("|".join([QKV_PROJ, O_PROJ, CONTEXT])).pattern
FEED_FORWARD = re.compile("|".join([LINEAR1_X, LINEAR2_X])).pattern


def build_remat_spec(
stack_cfg: Union[
BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config" # type: ignore
],
save_pattern: _SavePattern = SELF_ATTENTION_SAVE_PATTERN,
save_pattern: _SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
offload_pattern: _SavePattern = None,
offload_dst: str = "pinned_host",
) -> Optional[RematSpec]:
Expand Down
68 changes: 66 additions & 2 deletions axlearn/common/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@

from axlearn.common import attention, attention_bias, test_utils, utils
from axlearn.common.attention import (
FEED_FORWARD_SAVE_PATTERN,
BaseStackedTransformerLayer,
BaseTransformerLayer,
BottleNeckAdapterTransformerLayer,
Expand All @@ -58,6 +57,7 @@
PipelinedTransformerLayer,
QKVLinear,
QLinear,
RematRegexSavePatterns,
RepeatedTransformerLayer,
RoFormerQKVLinear,
StackedTransformerLayer,
Expand Down Expand Up @@ -3420,7 +3420,7 @@ def f(x, layer_params):
jax.remat(
f,
policy=_save_and_offload_only_these_names_regex(
names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN,
names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value,
names_which_can_be_offloaded=None,
offload_src="device",
offload_dst="pinned_host",
Expand Down Expand Up @@ -3875,6 +3875,70 @@ def f(x, layer_params):
5,
)

def test_build_remat_spec_neuron(self):
model_dim, num_heads = 6, 2
cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim)
cfg.self_attention.attention.set(num_heads=num_heads, causal=True)
cfg.feed_forward.hidden_dim = model_dim * 4
cfg.vlog = 5

layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))

batch_size, tgt_len = 2, 5
rng = np.random.default_rng(seed=123)
target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32)

def f(x, layer_params):
forward_outputs, _ = F(
layer,
inputs=dict(
data=x,
),
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
)
return forward_outputs

# Ignore type errors.
spec: Any = build_remat_spec(mock.MagicMock())

policy = (
config_for_function(_save_and_offload_only_these_names_regex)
.set(
names_which_can_be_saved="|".join(
[
RematRegexSavePatterns.QKV_PROJ.value,
RematRegexSavePatterns.LINEAR1_X.value,
]
),
names_which_can_be_offloaded=None,
offload_src=None,
offload_dst=None,
)
.instantiate()
)

_, default_policy_backward = jax.linearize(
jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse),
jnp.asarray(target),
layer_params,
)
_, full_remat_backward = jax.linearize(
jax.remat(f),
jnp.asarray(target),
layer_params,
)

# Eliminated the remat of qkv_proj and linear1_0 = 4 dots. This assumes
# FlashAttention is not enabled.
self.assertEqual(
str(full_remat_backward).count(" dot_general")
- str(default_policy_backward).count(" dot_general"),
4,
)


class TestStackModel(BaseLayer):
"""A dummy transformer stack."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ model.encoder.context.context.num_layers: 17
model.encoder.context.context.remat_spec['prevent_cse']: False
model.encoder.context.context.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.encoder.context.context.remat_spec['policy'].names_which_can_be_offloaded: None
model.encoder.context.context.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.encoder.context.context.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.encoder.context.context.remat_spec['policy'].offload_dst: 'pinned_host'
model.encoder.context.context.remat_spec['policy'].offload_src: 'device'
model.encoder.context.context.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ model.encoder.context.context.num_layers: 1
model.encoder.context.context.remat_spec['prevent_cse']: False
model.encoder.context.context.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.encoder.context.context.remat_spec['policy'].names_which_can_be_offloaded: None
model.encoder.context.context.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.encoder.context.context.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.encoder.context.context.remat_spec['policy'].offload_dst: 'pinned_host'
model.encoder.context.context.remat_spec['policy'].offload_src: 'device'
model.encoder.context.context.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,22 @@ mesh_rules[1][1][2]: 1
mesh_rules[1][1][3]: 128
mesh_rules[1][1][4]: 1
mesh_rules[1][1][5]: 1
mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
Expand Down Expand Up @@ -214,7 +230,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,22 @@ mesh_rules[1][1][2]: 1
mesh_rules[1][1][3]: 128
mesh_rules[1][1][4]: 1
mesh_rules[1][1][5]: 1
mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
Expand Down Expand Up @@ -214,7 +230,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,22 @@ mesh_rules[1][1][2]: 1
mesh_rules[1][1][3]: 128
mesh_rules[1][1][4]: 1
mesh_rules[1][1][5]: 1
mesh_rules[2][0]: 'neuron-(trn2|trn2n).48xlarge-64'
mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: 1
mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: -1
mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 4
mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
Expand Down Expand Up @@ -214,7 +230,7 @@ model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLaye
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*o_proj|.*context'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
Expand Down
Loading