|
39 | 39 | AxisIdxes, |
40 | 40 | AxisNames, |
41 | 41 | BATCH, |
42 | | - BATCH_NO_EXP, |
43 | 42 | CACHE_BATCH, |
44 | 43 | CACHE_BATCH_PREFILL, |
45 | 44 | CACHE_HEADS, |
|
61 | 60 | HEAD, |
62 | 61 | KV_LENGTH, |
63 | 62 | LENGTH, |
64 | | - LENGTH_NO_EXP, |
65 | 63 | MODEL_MODE_AUTOREGRESSIVE, |
66 | 64 | MODEL_MODE_PREFILL, |
67 | 65 | MODEL_MODE_TRAIN, |
@@ -302,12 +300,9 @@ def attention_op_as_linen( |
302 | 300 | float32_qk_product: bool = False, |
303 | 301 | max_prefill_predict_length: int = -1, |
304 | 302 | float32_logits: bool = False, |
305 | | - flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH_NO_EXP, D_KV), |
306 | | - flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH, D_KV), |
| 303 | + flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV), |
307 | 304 | flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV), |
308 | | - flash_axis_names_kv_ep: AxisNames = (BATCH_NO_EXP, HEAD, KV_LENGTH, D_KV), |
309 | | - flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH_NO_EXP), |
310 | | - flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH), |
| 305 | + flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH), |
311 | 306 | prefill_cache_logical_axis_names: AxisNames = ( |
312 | 307 | CACHE_BATCH_PREFILL, |
313 | 308 | CACHE_SEQUENCE, |
@@ -364,11 +359,8 @@ def attention_op_as_linen( |
364 | 359 | max_prefill_predict_length=max_prefill_predict_length, |
365 | 360 | float32_logits=float32_logits, |
366 | 361 | flash_axis_names_q=flash_axis_names_q, |
367 | | - flash_axis_names_q_ep=flash_axis_names_q_ep, |
368 | 362 | flash_axis_names_kv=flash_axis_names_kv, |
369 | | - flash_axis_names_kv_ep=flash_axis_names_kv_ep, |
370 | 363 | flash_axis_names_splash_kernel=flash_axis_names_splash_kernel, |
371 | | - flash_axis_names_splash_kernel_ep=flash_axis_names_splash_kernel_ep, |
372 | 364 | prefill_cache_logical_axis_names=prefill_cache_logical_axis_names, |
373 | 365 | cache_logical_axis_names=cache_logical_axis_names, |
374 | 366 | cache_scale_logical_axis_names=cache_scale_logical_axis_names, |
@@ -405,12 +397,9 @@ def __init__( |
405 | 397 | float32_qk_product: bool = False, |
406 | 398 | max_prefill_predict_length: int = -1, |
407 | 399 | float32_logits: bool = False, |
408 | | - flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH_NO_EXP, D_KV), |
409 | | - flash_axis_names_q_ep: AxisNames = (BATCH_NO_EXP, HEAD, LENGTH, D_KV), |
| 400 | + flash_axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV), |
410 | 401 | flash_axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV), |
411 | | - flash_axis_names_kv_ep: AxisNames = (BATCH_NO_EXP, HEAD, KV_LENGTH, D_KV), |
412 | | - flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH_NO_EXP), |
413 | | - flash_axis_names_splash_kernel_ep: AxisNames = (HEAD, LENGTH), |
| 402 | + flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH), |
414 | 403 | prefill_cache_logical_axis_names: AxisNames = ( |
415 | 404 | CACHE_BATCH_PREFILL, |
416 | 405 | CACHE_SEQUENCE, |
@@ -492,11 +481,8 @@ def __init__( |
492 | 481 | self.max_prefill_predict_length = max_prefill_predict_length |
493 | 482 | self.float32_logits = float32_logits |
494 | 483 | self.flash_axis_names_q = flash_axis_names_q |
495 | | - self.flash_axis_names_q_ep = flash_axis_names_q_ep |
496 | 484 | self.flash_axis_names_kv = flash_axis_names_kv |
497 | | - self.flash_axis_names_kv_ep = flash_axis_names_kv_ep |
498 | 485 | self.flash_axis_names_splash_kernel = flash_axis_names_splash_kernel |
499 | | - self.flash_axis_names_splash_kernel_ep = flash_axis_names_splash_kernel_ep |
500 | 486 | self.prefill_cache_logical_axis_names = prefill_cache_logical_axis_names |
501 | 487 | self.cache_logical_axis_names = cache_logical_axis_names |
502 | 488 | self.cache_scale_logical_axis_names = cache_scale_logical_axis_names |
@@ -1150,23 +1136,13 @@ def tpu_flash_attention( |
1150 | 1136 | segment_axis_names_kv = None |
1151 | 1137 | sink_axis_names = self._logical_to_mesh_axes((HEAD,)) |
1152 | 1138 | if decoder_segment_ids is not None: |
1153 | | - if self.config.expert_shard_attention_option == EP_AS_CONTEXT: |
1154 | | - segment_axis_names_q = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH)) |
1155 | | - segment_axis_names_kv = self._logical_to_mesh_axes((BATCH_NO_EXP, KV_LENGTH)) |
1156 | | - else: |
1157 | | - segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP)) |
1158 | | - segment_axis_names_kv = self._logical_to_mesh_axes((BATCH, KV_LENGTH)) |
1159 | | - |
1160 | | - if self.config.expert_shard_attention_option == EP_AS_CONTEXT: |
1161 | | - axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel_ep) |
1162 | | - axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q_ep) |
1163 | | - axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv_ep) |
1164 | | - indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH_NO_EXP, Q_LENGTH, KV_LENGTH)) |
1165 | | - else: |
1166 | | - axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel) |
1167 | | - axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q) |
1168 | | - axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv) |
1169 | | - indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH)) |
| 1139 | + segment_axis_names_q = self._logical_to_mesh_axes((BATCH, Q_LENGTH_NO_EXP)) |
| 1140 | + segment_axis_names_kv = self._logical_to_mesh_axes((BATCH, KV_LENGTH)) |
| 1141 | + |
| 1142 | + axis_names_splash_kernel = self._logical_to_mesh_axes(self.flash_axis_names_splash_kernel) |
| 1143 | + axis_names_q = self._logical_to_mesh_axes(self.flash_axis_names_q) |
| 1144 | + axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv) |
| 1145 | + indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH, Q_LENGTH, KV_LENGTH)) |
1170 | 1146 |
|
1171 | 1147 | global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv |
1172 | 1148 | global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel |
|
0 commit comments