|
37 | 37 | AxisIdxes, |
38 | 38 | AxisNames, |
39 | 39 | BATCH, |
40 | | - BATCH_NO_EXP, |
41 | 40 | CACHE_BATCH, |
42 | 41 | CACHE_BATCH_PREFILL, |
43 | 42 | CACHE_SEQUENCE, |
|
53 | 52 | HEAD, |
54 | 53 | Q_LORA_UP_PROJ, |
55 | 54 | KV_BATCH, |
56 | | - KV_BATCH_NO_EXP, |
57 | 55 | KV_HEAD, |
58 | 56 | KV_HEAD_DIM, |
59 | 57 | KV_LORA_UP_PROJ, |
60 | 58 | LENGTH, |
61 | | - LENGTH_NO_EXP, |
62 | 59 | MODEL_MODE_PREFILL, |
63 | 60 | MODEL_MODE_TRAIN, |
64 | 61 | PREFILL_KV_BATCH, |
@@ -425,16 +422,11 @@ def mla_as_linen( |
425 | 422 | prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
426 | 423 | prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
427 | 424 | prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
428 | | - query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
429 | | - key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
430 | | - value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
431 | | - ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
432 | | - ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
433 | | - ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
434 | | - input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED), |
435 | | - ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED), |
436 | | - out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV), |
437 | | - ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV), |
| 425 | + query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 426 | + key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 427 | + value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 428 | + input_axis_names: AxisNames = (BATCH, LENGTH, EMBED), |
| 429 | + out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV), |
438 | 430 | prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED), |
439 | 431 | decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED), |
440 | 432 | prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), |
@@ -499,13 +491,8 @@ def mla_as_linen( |
499 | 491 | query_axis_names=query_axis_names, |
500 | 492 | key_axis_names=key_axis_names, |
501 | 493 | value_axis_names=value_axis_names, |
502 | | - ep_query_axis_names=ep_query_axis_names, |
503 | | - ep_key_axis_names=ep_key_axis_names, |
504 | | - ep_value_axis_names=ep_value_axis_names, |
505 | 494 | input_axis_names=input_axis_names, |
506 | | - ep_input_axis_names=ep_input_axis_names, |
507 | 495 | out_axis_names=out_axis_names, |
508 | | - ep_out_axis_names=ep_out_axis_names, |
509 | 496 | prefill_input_axis_names=prefill_input_axis_names, |
510 | 497 | decode_input_axis_names=decode_input_axis_names, |
511 | 498 | prefill_out_axis_names=prefill_out_axis_names, |
@@ -573,16 +560,11 @@ def __init__( |
573 | 560 | prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
574 | 561 | prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
575 | 562 | prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
576 | | - query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
577 | | - key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
578 | | - value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
579 | | - ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
580 | | - ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
581 | | - ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
582 | | - input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED), |
583 | | - ep_input_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, EMBED), |
584 | | - out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV), |
585 | | - ep_out_axis_names: AxisNames = (BATCH_NO_EXP, LENGTH, HEAD, D_KV), |
| 563 | + query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 564 | + key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 565 | + value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 566 | + input_axis_names: AxisNames = (BATCH, LENGTH, EMBED), |
| 567 | + out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV), |
586 | 568 | prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED), |
587 | 569 | decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED), |
588 | 570 | prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), |
@@ -664,13 +646,8 @@ def __init__( |
664 | 646 | query_axis_names=query_axis_names, |
665 | 647 | key_axis_names=key_axis_names, |
666 | 648 | value_axis_names=value_axis_names, |
667 | | - ep_query_axis_names=ep_query_axis_names, |
668 | | - ep_key_axis_names=ep_key_axis_names, |
669 | | - ep_value_axis_names=ep_value_axis_names, |
670 | 649 | input_axis_names=input_axis_names, |
671 | | - ep_input_axis_names=ep_input_axis_names, |
672 | 650 | out_axis_names=out_axis_names, |
673 | | - ep_out_axis_names=ep_out_axis_names, |
674 | 651 | prefill_input_axis_names=prefill_input_axis_names, |
675 | 652 | decode_input_axis_names=decode_input_axis_names, |
676 | 653 | prefill_out_axis_names=prefill_out_axis_names, |
@@ -882,12 +859,9 @@ def mla_query_projection( |
882 | 859 | if model_mode == MODEL_MODE_PREFILL: |
883 | 860 | query_logical_name = self.prefill_query_axis_names |
884 | 861 | wqa_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, Q_LORA_UP_PROJ) |
885 | | - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: |
886 | | - query_logical_name = self.ep_query_axis_names |
887 | | - wqa_logical_name = (KV_BATCH_NO_EXP, LENGTH, Q_LORA_UP_PROJ) |
888 | 862 | else: |
889 | 863 | query_logical_name = self.query_axis_names |
890 | | - wqa_logical_name = (KV_BATCH, LENGTH_NO_EXP, Q_LORA_UP_PROJ) |
| 864 | + wqa_logical_name = (KV_BATCH, LENGTH, Q_LORA_UP_PROJ) |
891 | 865 | query_sharding = create_sharding(self.mesh, query_logical_name) |
892 | 866 | wqa_out_sharding = create_sharding(self.mesh, wqa_logical_name) |
893 | 867 | # Set softmax scaling. |
@@ -1038,10 +1012,8 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm |
1038 | 1012 | """MLA key/value projection with integrated rotary embedding.""" |
1039 | 1013 | if model_mode == MODEL_MODE_PREFILL: |
1040 | 1014 | wka_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_LORA_UP_PROJ) |
1041 | | - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: |
1042 | | - wka_logical_name = (KV_BATCH_NO_EXP, LENGTH, KV_LORA_UP_PROJ) |
1043 | 1015 | else: |
1044 | | - wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ) |
| 1016 | + wka_logical_name = (KV_BATCH, LENGTH, KV_LORA_UP_PROJ) |
1045 | 1017 | wkva_out_sharding = create_sharding(self.mesh, wka_logical_name) |
1046 | 1018 | low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding) |
1047 | 1019 | low_rank = checkpoint_name(low_rank, "kv_wa_proj") |
@@ -1178,14 +1150,10 @@ def __call__( |
1178 | 1150 | inputs_q = self._maybe_shard_with_logical(inputs_q, self.prefill_input_axis_names) |
1179 | 1151 | inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.prefill_input_axis_names) |
1180 | 1152 | out_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV) |
1181 | | - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: |
1182 | | - inputs_q = self._maybe_shard_with_logical(inputs_q, self.ep_input_axis_names) |
1183 | | - inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.ep_input_axis_names) |
1184 | | - out_logical_name = (BATCH_NO_EXP, LENGTH, HEAD, D_KV) |
1185 | 1153 | else: |
1186 | 1154 | inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names) |
1187 | 1155 | inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names) |
1188 | | - out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV) |
| 1156 | + out_logical_name = (BATCH, LENGTH, HEAD, D_KV) |
1189 | 1157 |
|
1190 | 1158 | if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None: |
1191 | 1159 | decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32) |
|
0 commit comments