Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mqa parallelization #51

Open
wants to merge 3 commits into
base: multi-query-attention
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 27 additions & 23 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@


from .glu_activations import GLU_ACTIVATIONS
from ..mpu import copy_to_tensor_model_parallel_region, LinearWithGradAccumulationAndAsyncCommunication

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
Expand Down Expand Up @@ -552,10 +553,9 @@ def __init__(self, init_method,
init_method=init_method)
elif attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery':
# TODO: Find a way to merge the query and key-value computations?
self.query = mpu.ColumnParallelLinear(
self.query = get_linear_layer(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
# In MultiQuery attention, keys and values are shared across heads
# Use args.kv_channels instead of projection_size
Expand All @@ -565,6 +565,11 @@ def __init__(self, init_method,
args.hidden_size,
2 * args.kv_channels,
init_method=init_method)

self.async_tensor_model_parallel_allreduce = args.async_tensor_model_parallel_allreduce and world_size > 1
self.sequence_parallel = args.sequence_parallel and world_size > 1
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

elif attention_type == AttnType.cross_attn and self.attention_head_type == 'multihead':
assert attention_type == AttnType.cross_attn
self.query = mpu.ColumnParallelLinear(
Expand Down Expand Up @@ -686,28 +691,25 @@ def forward(self, hidden_states, attention_mask,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
elif self.attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery':
kv_input=hidden_states
# Attention heads [sq, b, h] --> [sq, b, (2 * hn)]
mixed_kv_layer = self.key_value(kv_input)

# Reduce the KV gradients in the tensor-parallel direction.
# This is different from multi-head attention which reduces the KV input,
# because the sum over attn heads happens in the attn weight gradient instead of the KV layer:
# A [b, n * sq, sk] = Q [b, n * sq, hn] x K^T [b, hn, sk]
# G_K [b, sk, hn] = G_A [b, sk, n * sq] x Q [b, n * sq, hn]
# = sum_p (G_Ap [b, sk, np * sq] x Q_p [b, np * sq, hn])
if get_args().sequence_parallel:
# We switch to the tensor parallel regime here instead of at the KV input
# so that the KV layer is done in parallel instead of just duplicated.
mixed_kv_layer = mpu.gather_from_sequence_parallel_region(mixed_kv_layer, tensor_parallel_output_grad=True)
kv_input = hidden_states

# Manually handle communication of kv_input
if self.async_tensor_model_parallel_allreduce or \
self.sequence_parallel:
kv_input = kv_input
else:
mixed_kv_layer = mpu.copy_to_tensor_model_parallel_region(mixed_kv_layer)
kv_input = copy_to_tensor_model_parallel_region(kv_input)

# [sq, b, (2 * hn)] --> [sq, b, np (expanded), 2 * hn]
# new_tensor_shape = mixed_kv_layer.size()[:-1] + \
# (self.num_attention_heads_per_partition,
# 2 * self.hidden_size_per_attention_head)
# mixed_kv_layer = mixed_kv_layer.unsqueeze(2).expand(*new_tensor_shape)
# TODO @thomasw21: This is stupid because `LinearWithGradAccumulationAndAsyncCommunication` also all_gathers the activations
if self.sequence_parallel:
kv_input_gathered = mpu.gather_from_sequence_parallel_region(
kv_input,
tensor_parallel_output_grad=True)
else:
kv_input_gathered = kv_input

# Attention heads [sq, b, h] --> [sq, b, (2 * hn)]
mixed_kv_layer = self.key_value(kv_input_gathered)

# [sq, b, (2 * hn)] --> [sq, b, 1, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
Expand All @@ -720,7 +722,9 @@ def forward(self, hidden_states, attention_mask,
value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

# Attention head [sq, b, h] --> [sq, b, np * hn]
query_layer, _ = self.query(hidden_states)
query_layer = LinearWithGradAccumulationAndAsyncCommunication.apply(
kv_input, self.query.weight, self.query.bias, self.gradient_accumulation_fusion,
self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
# [sq, b, np * hn] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
Expand Down
1 change: 0 additions & 1 deletion megatron/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,6 @@ def reduce_model_grads(self, args, timers):
if (
args.attention_head_type == "multiquery"
and mpu.get_tensor_model_parallel_world_size() > 1
and args.sequence_parallel
):
timers('backward-key-value-all-reduce').start()
self.allreduce_key_value_grads(args)
Expand Down