Skip to content

Commit 3971206

Browse files
Merge pull request #3312 from Angelogeb:cudnn_jax-scale-fix
PiperOrigin-RevId: 892411074
2 parents 4910293 + 9c2cc66 commit 3971206

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxtext/layers/attention_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1672,7 +1672,7 @@ def cudnn_jax_flash_attention(
16721672
key,
16731673
value,
16741674
mask_type=MaskType.CAUSAL,
1675-
scale=1.0 / math.sqrt(head_dim),
1675+
scale=1.0,
16761676
dropout_rate=self.dropout_rate,
16771677
qkv_layout="BTNH",
16781678
return_residual=True,

0 commit comments

Comments
 (0)