From 38f4195edb7b2f359ff6a25bd57b1f4ead7cf075 Mon Sep 17 00:00:00 2001 From: qdavid Date: Thu, 21 Nov 2024 11:22:12 -0800 Subject: [PATCH] Update RoformerQKVLinear to support kv_state --- axlearn/common/attention.py | 45 ++++++++++++++++++++++----- axlearn/common/attention_test.py | 53 ++++++++++++++++++++++++-------- axlearn/common/lora.py | 7 +++++ 3 files changed, 85 insertions(+), 20 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index 26aceb797..37baf3d8b 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -867,6 +867,7 @@ def forward( *, key: Optional[Tensor] = None, value: Optional[Tensor] = None, + kv_state: Optional[Tensor] = None, time_step: Optional[Tensor] = None, ) -> BaseQKVLinear.Output: """Computes attention for the given query, key, value. @@ -875,6 +876,12 @@ def forward( See parent class for full docstring. """ + if kv_state is not None: + raise ValueError( + "QKVLinear computes key and value projections " + "and does not expect external `kv_state`." + ) + key = query if key is None else key value = query if value is None else value q_proj = self.q_proj(query) @@ -1019,6 +1026,7 @@ def forward( *, key: Optional[Tensor] = None, value: Optional[Tensor] = None, + kv_state: Optional[KVState] = None, time_step: Optional[Tensor] = None, ) -> BaseQKVLinear.Output: """Computes multi-head query, key, and value for the input query, key, value @@ -1029,8 +1037,14 @@ def forward( See parent class for full docstring. Raises: - ValueError: If key and value are not both set or both None. + ValueError: If key and value are not both set or both None; or if kv_state is not None. """ + if kv_state is not None: + raise ValueError( + "FusedQKVLinear computes key and value projections " + "and does not expect external `kv_state`." + ) + with child_context("qkv_proj"): params = self.qkv_proj.parameters if key is None and value is None: @@ -1111,12 +1125,18 @@ def forward( *, key: Optional[Tensor] = None, value: Optional[Tensor] = None, + kv_state: Optional[Tensor] = None, time_step: Optional[Tensor] = None, ) -> FusedQKVLinear.Output: """See FusedQKVLinear for full docstring. N.B. Only supports cases where key and value are both None. """ + if kv_state is not None: + raise ValueError( + "FusedGroupedQKVLinear computes key and value projections " + "and does not expect external `kv_state`." + ) if key is not None or value is not None: raise ValueError("Key and value should be both None.") cfg = self.config @@ -1193,6 +1213,7 @@ def apply_rotary_position_embeddings( key: Tensor, value: Tensor, sinusoidal_pos: Tensor, + rotary_key: bool, rotary_value: bool, ) -> tuple[Tensor, Tensor, Tensor]: """This is a jax implementation (a copy) of the RoPE apply_rotary_position_embeddings. @@ -1205,7 +1226,8 @@ def apply_rotary_position_embeddings( key: Key embeddings with shape [batch_size, seq_len, num_heads, dim]. value: Value embeddings with shape [batch_size, seq_len, num_heads, dim]. sinusoidal_pos: Rotary position embeddings with shape [batch_size, seq_len, 1, dim]. - rotary_value: Whether to apply rotary position embeddings on value layer. + rotary_key: Whether to apply rotary position embeddings on key. + rotary_value: Whether to apply rotary position embeddings on value. Returns: A tuple of: @@ -1226,9 +1248,13 @@ def apply_rotary_position_embeddings( jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape ) query = query * cos_pos + rotate_half_query * sin_pos - # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] - rotate_half_key = jnp.reshape(jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape) - key = key * cos_pos + rotate_half_key * sin_pos + + if rotary_key: + # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] + rotate_half_key = jnp.reshape( + jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape + ) + key = key * cos_pos + rotate_half_key * sin_pos if rotary_value: # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2] rotate_half_value = jnp.reshape( @@ -1252,6 +1278,7 @@ class Config(BaseQKVLinear.Config): RoFormerSinusoidalPositionalEmbedding.default_config() ) input_linear: BaseQKVLinear.Config = QKVLinear.default_config() + # Whether to apply RoPE rotations to the value embeddings. rotary_value: Required[bool] = REQUIRED def __init__(self, cfg: QKVLinear.Config, *, parent: Module): @@ -1283,23 +1310,27 @@ def forward( *, key: Optional[Tensor] = None, value: Optional[Tensor] = None, + kv_state: Optional[KVState] = None, time_step: Optional[Tensor] = None, ) -> BaseQKVLinear.Output: cfg = self.config # Query should have shape of [batch_size, seq_len, num_heads, per_head_dim]. - query, key, value = self.i_proj(query, key=key, value=value) + query, key, value = self.i_proj(query, key=key, value=value, kv_state=kv_state) query_pos = jnp.arange(query.shape[1])[None] # [batch_size=1, seq_len]. if time_step is not None: query_pos = query_pos + time_step[:, None] # [batch_size, seq_len]. sinusoidal_pos_emb = self.rope_pos_emb_layer.forward(query_pos).astype(query.dtype) # sinusoidal_pos_emb shape should be [batch_size, seq_len, 1, dim] sinusoidal_pos_emb = jnp.expand_dims(sinusoidal_pos_emb, 2) + + i_proj_computes_kv = kv_state is None query, key, value = apply_rotary_position_embeddings( sinusoidal_pos=sinusoidal_pos_emb, query=query, key=key, value=value, - rotary_value=cfg.rotary_value, + rotary_key=i_proj_computes_kv, + rotary_value=i_proj_computes_kv and cfg.rotary_value, ) return self.Output(query, key, value) diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index 5d4aeb623..1e188ecc0 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -737,18 +737,24 @@ def test_alibi_attention_mask(self): class RoFormerSinusoidalPositionalEmbeddingTest(TestCase): """Tests RoFormerSinusoidalPositionalEmbedding.""" - @parameterized.parameters( - (2, 3, 10, 32, True), - (2, 3, 8, 32, False), - (2, 4, 6, 32, True), - (2, 4, 8, 16, False), - (2, 5, 8, 48, True), - (2, 5, 8, 64, False), + @parameterized.product( + tensor_dimensions=( + (2, 3, 10, 32), + (2, 3, 8, 32), + (2, 4, 6, 32), + (2, 4, 8, 16), + (2, 5, 8, 48), + (2, 5, 8, 64), + ), + rotary_key=(True, False), + rotary_value=(True, False), ) def test_apply_rotary_position_embeddings( - self, batch_size, num_heads, max_len, dim, rotary_value + self, tensor_dimensions: tuple[int, int, int, int], rotary_key: bool, rotary_value: bool ): # Unittest against the apply_rotary_position_embeddings in HF. + batch_size, num_heads, max_len, dim = tensor_dimensions + token_ids = np.random.randint(low=1, high=20, size=[batch_size, max_len]) sinusoidal_pos_layer = hf_roformer.RoFormerSinusoidalPositionalEmbedding(max_len, dim) sinusoidal_pos = sinusoidal_pos_layer(as_torch_tensor(token_ids).shape)[None, None, :, :] @@ -771,11 +777,15 @@ def test_apply_rotary_position_embeddings( sinusoidal_pos, as_torch_tensor(query), as_torch_tensor(key) ) ref_v_proj = as_torch_tensor(value) + if not rotary_key: + ref_k_proj = as_torch_tensor(key) + test_q_proj, test_k_proj, test_v_proj = test_layer( sinusoidal_pos=as_tensor(sinusoidal_pos), query=query, key=key, value=value, + rotary_key=rotary_key, rotary_value=rotary_value, ) np.testing.assert_allclose(test_q_proj, ref_q_proj, atol=5e-7) @@ -1128,6 +1138,7 @@ def test_against_llama_for_apply_rotary_emb(self): key=jnp.asarray(key), value=jnp.asarray(value), sinusoidal_pos=axlearn_rope, + rotary_key=True, rotary_value=False, ) @@ -1382,11 +1393,22 @@ def test_num_kv_heads( layer = cfg.instantiate(parent=None) self.assertEqual(expected, layer.num_kv_heads) - def test_qlinear(self): + @parameterized.parameters( + (QKVLinear.default_config(), QLinear.default_config()), + ( + RoFormerQKVLinear.default_config().set( + input_linear=QKVLinear.default_config(), rotary_value=False + ), + RoFormerQKVLinear.default_config().set( + input_linear=QLinear.default_config(), rotary_value=False + ), + ), + ) + def test_qlinear(self, base_cfg, test_cfg): """Tests that QLinear is equivalent to QKVLinear with the same kv_state.""" with utils.numeric_checks(True): model_dim = 12 - num_heads = 4 + num_heads = 3 per_head_dim = model_dim // num_heads layer_kwargs = dict( query_dim=model_dim, @@ -1395,8 +1417,8 @@ def test_qlinear(self): num_heads=num_heads, per_head_dim=per_head_dim, ) - base_cfg = QKVLinear.default_config().set(**layer_kwargs) - test_cfg = QLinear.default_config().set(**layer_kwargs) + base_cfg = base_cfg.set(**layer_kwargs) + test_cfg = test_cfg.set(**layer_kwargs) maybe_set_config(test_cfg, num_kv_heads=num_heads) base_layer = base_cfg.set(name="base").instantiate(parent=None) test_layer = test_cfg.set(name="test").instantiate(parent=None) @@ -1404,7 +1426,12 @@ def test_qlinear(self): # Construct base layer state. base_state = base_layer.initialize_parameters_recursively(jax.random.PRNGKey(0)) # Map state to QLinear. - test_state = {"q_proj": base_state["q_proj"]} + if "q_proj" in base_state: + test_state = {"q_proj": base_state["q_proj"]} + elif "i_proj" in base_state: + test_state = {"i_proj": {"q_proj": base_state["i_proj"]["q_proj"]}} + else: + raise ValueError("Cannot find expected q_proj state.") # Construct test inputs. batch_size, src_len, tgt_len = 2, 6, 6 diff --git a/axlearn/common/lora.py b/axlearn/common/lora.py index b968f1548..199cef603 100644 --- a/axlearn/common/lora.py +++ b/axlearn/common/lora.py @@ -516,8 +516,15 @@ def forward( *, key: Optional[Tensor] = None, value: Optional[Tensor] = None, + kv_state: Optional[Tensor] = None, time_step: Optional[Tensor] = None, ) -> BaseQKVLinear.Output: + if kv_state is not None: + raise ValueError( + "LoraFusedQKVLinear computes key and value projections " + "and does not expect external `kv_state`." + ) + cfg = self.config if key is None and value is None: inputs = query