Skip to content

Commit

Permalink
Update RoFormerQKVLinear to support kv_state
Browse files Browse the repository at this point in the history
  • Loading branch information
qdavid1 committed Dec 17, 2024
1 parent 73625c9 commit 6bf7422
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 19 deletions.
52 changes: 46 additions & 6 deletions axlearn/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,7 @@ def forward(
*,
key: Optional[Tensor] = None,
value: Optional[Tensor] = None,
kv_state: Optional[Tensor] = None,
time_step: Optional[Tensor] = None,
) -> BaseQKVLinear.Output:
"""Computes attention for the given query, key, value.
Expand All @@ -875,6 +876,12 @@ def forward(
See parent class for full docstring.
"""
if kv_state is not None:
raise ValueError(
"QKVLinear computes key and value projections "
"and does not expect external `kv_state`."
)

key = query if key is None else key
value = query if value is None else value
q_proj = self.q_proj(query)
Expand Down Expand Up @@ -1019,6 +1026,7 @@ def forward(
*,
key: Optional[Tensor] = None,
value: Optional[Tensor] = None,
kv_state: Optional[KVState] = None,
time_step: Optional[Tensor] = None,
) -> BaseQKVLinear.Output:
"""Computes multi-head query, key, and value for the input query, key, value
Expand All @@ -1029,8 +1037,14 @@ def forward(
See parent class for full docstring.
Raises:
ValueError: If key and value are not both set or both None.
ValueError: If key and value are not both set or both None; or if kv_state is not None.
"""
if kv_state is not None:
raise ValueError(
"FusedQKVLinear computes key and value projections "
"and does not expect external `kv_state`."
)

with child_context("qkv_proj"):
params = self.qkv_proj.parameters
if key is None and value is None:
Expand Down Expand Up @@ -1111,12 +1125,18 @@ def forward(
*,
key: Optional[Tensor] = None,
value: Optional[Tensor] = None,
kv_state: Optional[Tensor] = None,
time_step: Optional[Tensor] = None,
) -> FusedQKVLinear.Output:
"""See FusedQKVLinear for full docstring.
N.B. Only supports cases where key and value are both None.
"""
if kv_state is not None:
raise ValueError(
"FusedGroupedQKVLinear computes key and value projections "
"and does not expect external `kv_state`."
)
if key is not None or value is not None:
raise ValueError("Key and value should be both None.")
cfg = self.config
Expand Down Expand Up @@ -1193,6 +1213,7 @@ def apply_rotary_position_embeddings(
key: Tensor,
value: Tensor,
sinusoidal_pos: Tensor,
rotary_key: bool,
rotary_value: bool,
) -> tuple[Tensor, Tensor, Tensor]:
"""This is a jax implementation (a copy) of the RoPE apply_rotary_position_embeddings.
Expand All @@ -1205,7 +1226,8 @@ def apply_rotary_position_embeddings(
key: Key embeddings with shape [batch_size, seq_len, num_heads, dim].
value: Value embeddings with shape [batch_size, seq_len, num_heads, dim].
sinusoidal_pos: Rotary position embeddings with shape [batch_size, seq_len, 1, dim].
rotary_value: Whether to apply rotary position embeddings on value layer.
rotary_key: Whether to apply rotary position embeddings on key.
rotary_value: Whether to apply rotary position embeddings on value.
Returns:
A tuple of:
Expand All @@ -1226,9 +1248,13 @@ def apply_rotary_position_embeddings(
jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape
)
query = query * cos_pos + rotate_half_query * sin_pos
# rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
rotate_half_key = jnp.reshape(jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape)
key = key * cos_pos + rotate_half_key * sin_pos

if rotary_key:
# rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
rotate_half_key = jnp.reshape(
jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape
)
key = key * cos_pos + rotate_half_key * sin_pos
if rotary_value:
# rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
rotate_half_value = jnp.reshape(
Expand All @@ -1252,6 +1278,10 @@ class Config(BaseQKVLinear.Config):
RoFormerSinusoidalPositionalEmbedding.default_config()
)
input_linear: BaseQKVLinear.Config = QKVLinear.default_config()
# Whether to apply RoPE rotations to the key embeddings.
# Set to False to support pre-rotated key inputs.
rotary_key: Optional[bool] = None
# Whether to apply RoPE rotations to the value embeddings.
rotary_value: Required[bool] = REQUIRED

def __init__(self, cfg: QKVLinear.Config, *, parent: Module):
Expand All @@ -1277,17 +1307,26 @@ def num_kv_heads(self):
"""Propagate num KV heads from input linear."""
return self.i_proj.num_kv_heads

@property
def rotary_key(self):
    """Reports whether RoPE rotation is applied to the key embeddings.

    An unset (None) `rotary_key` config value is treated as True, so configs
    written before this option existed keep rotating keys.
    """
    configured = self.config.rotary_key
    return True if configured is None else configured

def forward(
self,
query: Tensor,
*,
key: Optional[Tensor] = None,
value: Optional[Tensor] = None,
kv_state: Optional[KVState] = None,
time_step: Optional[Tensor] = None,
) -> BaseQKVLinear.Output:
cfg = self.config
# Query should have shape of [batch_size, seq_len, num_heads, per_head_dim].
query, key, value = self.i_proj(query, key=key, value=value)
query, key, value = self.i_proj(query, key=key, value=value, kv_state=kv_state)
query_pos = jnp.arange(query.shape[1])[None] # [batch_size=1, seq_len].
if time_step is not None:
query_pos = query_pos + time_step[:, None] # [batch_size, seq_len].
Expand All @@ -1299,6 +1338,7 @@ def forward(
query=query,
key=key,
value=value,
rotary_key=self.rotary_key,
rotary_value=cfg.rotary_value,
)

Expand Down
53 changes: 40 additions & 13 deletions axlearn/common/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,18 +737,24 @@ def test_alibi_attention_mask(self):
class RoFormerSinusoidalPositionalEmbeddingTest(TestCase):
"""Tests RoFormerSinusoidalPositionalEmbedding."""

@parameterized.parameters(
(2, 3, 10, 32, True),
(2, 3, 8, 32, False),
(2, 4, 6, 32, True),
(2, 4, 8, 16, False),
(2, 5, 8, 48, True),
(2, 5, 8, 64, False),
@parameterized.product(
tensor_dimensions=(
(2, 3, 10, 32),
(2, 3, 8, 32),
(2, 4, 6, 32),
(2, 4, 8, 16),
(2, 5, 8, 48),
(2, 5, 8, 64),
),
rotary_key=(True, False),
rotary_value=(True, False),
)
def test_apply_rotary_position_embeddings(
self, batch_size, num_heads, max_len, dim, rotary_value
self, tensor_dimensions: tuple[int, int, int, int], rotary_key: bool, rotary_value: bool
):
# Unittest against the apply_rotary_position_embeddings in HF.
batch_size, num_heads, max_len, dim = tensor_dimensions

token_ids = np.random.randint(low=1, high=20, size=[batch_size, max_len])
sinusoidal_pos_layer = hf_roformer.RoFormerSinusoidalPositionalEmbedding(max_len, dim)
sinusoidal_pos = sinusoidal_pos_layer(as_torch_tensor(token_ids).shape)[None, None, :, :]
Expand All @@ -771,11 +777,15 @@ def test_apply_rotary_position_embeddings(
sinusoidal_pos, as_torch_tensor(query), as_torch_tensor(key)
)
ref_v_proj = as_torch_tensor(value)
if not rotary_key:
ref_k_proj = as_torch_tensor(key)

test_q_proj, test_k_proj, test_v_proj = test_layer(
sinusoidal_pos=as_tensor(sinusoidal_pos),
query=query,
key=key,
value=value,
rotary_key=rotary_key,
rotary_value=rotary_value,
)
np.testing.assert_allclose(test_q_proj, ref_q_proj, atol=5e-7)
Expand Down Expand Up @@ -1128,6 +1138,7 @@ def test_against_llama_for_apply_rotary_emb(self):
key=jnp.asarray(key),
value=jnp.asarray(value),
sinusoidal_pos=axlearn_rope,
rotary_key=True,
rotary_value=False,
)

Expand Down Expand Up @@ -1382,11 +1393,22 @@ def test_num_kv_heads(
layer = cfg.instantiate(parent=None)
self.assertEqual(expected, layer.num_kv_heads)

def test_qlinear(self):
@parameterized.parameters(
(QKVLinear.default_config(), QLinear.default_config()),
(
RoFormerQKVLinear.default_config().set(
input_linear=QKVLinear.default_config(), rotary_key=False, rotary_value=False
),
RoFormerQKVLinear.default_config().set(
input_linear=QLinear.default_config(), rotary_key=False, rotary_value=False
),
),
)
def test_qlinear(self, base_cfg, test_cfg):
"""Tests that QLinear is equivalent to QKVLinear with the same kv_state."""
with utils.numeric_checks(True):
model_dim = 12
num_heads = 4
num_heads = 3
per_head_dim = model_dim // num_heads
layer_kwargs = dict(
query_dim=model_dim,
Expand All @@ -1395,16 +1417,21 @@ def test_qlinear(self):
num_heads=num_heads,
per_head_dim=per_head_dim,
)
base_cfg = QKVLinear.default_config().set(**layer_kwargs)
test_cfg = QLinear.default_config().set(**layer_kwargs)
base_cfg = base_cfg.set(**layer_kwargs)
test_cfg = test_cfg.set(**layer_kwargs)
maybe_set_config(test_cfg, num_kv_heads=num_heads)
base_layer = base_cfg.set(name="base").instantiate(parent=None)
test_layer = test_cfg.set(name="test").instantiate(parent=None)

# Construct base layer state.
base_state = base_layer.initialize_parameters_recursively(jax.random.PRNGKey(0))
# Map state to QLinear.
test_state = {"q_proj": base_state["q_proj"]}
if "q_proj" in base_state:
test_state = {"q_proj": base_state["q_proj"]}
elif "i_proj" in base_state:
test_state = {"i_proj": {"q_proj": base_state["i_proj"]["q_proj"]}}
else:
raise ValueError("Cannot find expected q_proj state.")

# Construct test inputs.
batch_size, src_len, tgt_len = 2, 6, 6
Expand Down
7 changes: 7 additions & 0 deletions axlearn/common/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,8 +516,15 @@ def forward(
*,
key: Optional[Tensor] = None,
value: Optional[Tensor] = None,
kv_state: Optional[Tensor] = None,
time_step: Optional[Tensor] = None,
) -> BaseQKVLinear.Output:
if kv_state is not None:
raise ValueError(
"LoraFusedQKVLinear computes key and value projections "
"and does not expect external `kv_state`."
)

cfg = self.config
if key is None and value is None:
inputs = query
Expand Down

0 comments on commit 6bf7422

Please sign in to comment.