Skip to content

Commit

Permalink
Fix softmax scale (#903)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 authored Dec 24, 2024
1 parent f91709f commit e4ff72c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions axlearn/common/flash_attention/tpu_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
_flash_attention_kernel,
causal=causal,
mask_value=DEFAULT_MASK_VALUE,
softmax_scale=softmax_scale,
sm_scale=softmax_scale,
block_k=block_k,
kv_seq_len=kv_seq_len,
)
Expand Down Expand Up @@ -878,7 +878,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _):
_flash_attention_dkv_kernel,
block_q=block_q,
block_k=block_k,
softmax_scale=softmax_scale,
sm_scale=softmax_scale,
causal=causal,
mask_value=mask_value,
q_seq_len=q_seq_len,
Expand Down Expand Up @@ -1068,7 +1068,7 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)

kernel = functools.partial(
_flash_attention_dq_kernel,
softmax_scale=softmax_scale,
sm_scale=softmax_scale,
causal=causal,
mask_value=mask_value,
block_k=block_k,
Expand Down

0 comments on commit e4ff72c

Please sign in to comment.