Commit 9c2cc66

committed
Fix scale to 1.0 for cudnn_flash_jax
1 parent e6cd443 commit 9c2cc66

1 file changed: 1 addition & 1 deletion

File changed: src/maxtext/layers/attention_op.py (1 addition, 1 deletion)
```diff
@@ -1672,7 +1672,7 @@ def cudnn_jax_flash_attention(
         key,
         value,
         mask_type=MaskType.CAUSAL,
-        scale=1.0 / math.sqrt(head_dim),
+        scale=1.0,
         dropout_rate=self.dropout_rate,
         qkv_layout="BTNH",
         return_residual=True,
```
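The diff drops the `1.0 / math.sqrt(head_dim)` factor from the kernel call. A plausible reading (an assumption, not confirmed by the commit message) is that the query tensor is already scaled by `1/sqrt(head_dim)` before the cuDNN kernel is invoked, so passing the factor again would apply it twice. The sketch below shows with plain NumPy why pre-scaling the query and calling attention with `scale=1.0` is equivalent to passing the scale into the kernel; the `attention` helper is a hypothetical stand-in for the fused kernel.

```python
import math
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable row-wise softmax.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, scale):
    # Hypothetical stand-in for the fused attention kernel:
    # softmax(scale * q @ k^T) @ v, no mask or dropout.
    scores = (q @ k.T) * scale
    return softmax(scores) @ v

rng = np.random.default_rng(0)
seq_len, head_dim = 4, 8
q = rng.standard_normal((seq_len, head_dim))
k = rng.standard_normal((seq_len, head_dim))
v = rng.standard_normal((seq_len, head_dim))

# Scaling inside the kernel ...
out_kernel_scaled = attention(q, k, v, scale=1.0 / math.sqrt(head_dim))
# ... equals pre-scaling q and passing scale=1.0, since
# (q / sqrt(d)) @ k^T == (q @ k^T) / sqrt(d).
out_prescaled = attention(q / math.sqrt(head_dim), k, v, scale=1.0)

assert np.allclose(out_kernel_scaled, out_prescaled)
```

Under this assumption, if MaxText scales queries upstream of the call, `scale=1.0` here avoids double-scaling the logits.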
