|
52 | 52 | HEAD, |
53 | 53 | Q_LORA_UP_PROJ, |
54 | 54 | KV_BATCH, |
55 | | - KV_BATCH_NO_EXP, |
56 | 55 | KV_HEAD, |
57 | 56 | KV_HEAD_DIM, |
58 | 57 | KV_LORA_UP_PROJ, |
59 | 58 | LENGTH, |
60 | | - LENGTH_NO_EXP, |
61 | 59 | MODEL_MODE_PREFILL, |
62 | 60 | MODEL_MODE_TRAIN, |
63 | 61 | PREFILL_KV_BATCH, |
@@ -424,14 +422,11 @@ def mla_as_linen( |
424 | 422 | prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
425 | 423 | prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
426 | 424 | prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
427 | | - query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
428 | | - key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
429 | | - value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
430 | | - ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
431 | | - ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
432 | | - ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
433 | | - input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED), |
434 | | - out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV), |
| 425 | + query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 426 | + key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 427 | + value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 428 | + input_axis_names: AxisNames = (BATCH, LENGTH, EMBED), |
| 429 | + out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV), |
435 | 430 | prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED), |
436 | 431 | decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED), |
437 | 432 | prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), |
@@ -496,9 +491,6 @@ def mla_as_linen( |
496 | 491 | query_axis_names=query_axis_names, |
497 | 492 | key_axis_names=key_axis_names, |
498 | 493 | value_axis_names=value_axis_names, |
499 | | - ep_query_axis_names=ep_query_axis_names, |
500 | | - ep_key_axis_names=ep_key_axis_names, |
501 | | - ep_value_axis_names=ep_value_axis_names, |
502 | 494 | input_axis_names=input_axis_names, |
503 | 495 | out_axis_names=out_axis_names, |
504 | 496 | prefill_input_axis_names=prefill_input_axis_names, |
@@ -568,14 +560,11 @@ def __init__( |
568 | 560 | prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
569 | 561 | prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
570 | 562 | prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
571 | | - query_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
572 | | - key_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
573 | | - value_axis_names: AxisNames = (KV_BATCH, LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
574 | | - ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
575 | | - ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
576 | | - ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, LENGTH, KV_HEAD, KV_HEAD_DIM), |
577 | | - input_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, EMBED), |
578 | | - out_axis_names: AxisNames = (BATCH, LENGTH_NO_EXP, HEAD, D_KV), |
| 563 | + query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 564 | + key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 565 | + value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 566 | + input_axis_names: AxisNames = (BATCH, LENGTH, EMBED), |
| 567 | + out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV), |
579 | 568 | prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, EMBED), |
580 | 569 | decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, EMBED), |
581 | 570 | prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), |
@@ -657,9 +646,6 @@ def __init__( |
657 | 646 | query_axis_names=query_axis_names, |
658 | 647 | key_axis_names=key_axis_names, |
659 | 648 | value_axis_names=value_axis_names, |
660 | | - ep_query_axis_names=ep_query_axis_names, |
661 | | - ep_key_axis_names=ep_key_axis_names, |
662 | | - ep_value_axis_names=ep_value_axis_names, |
663 | 649 | input_axis_names=input_axis_names, |
664 | 650 | out_axis_names=out_axis_names, |
665 | 651 | prefill_input_axis_names=prefill_input_axis_names, |
@@ -873,12 +859,9 @@ def mla_query_projection( |
873 | 859 | if model_mode == MODEL_MODE_PREFILL: |
874 | 860 | query_logical_name = self.prefill_query_axis_names |
875 | 861 | wqa_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, Q_LORA_UP_PROJ) |
876 | | - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: |
877 | | - query_logical_name = self.ep_query_axis_names |
878 | | - wqa_logical_name = (KV_BATCH_NO_EXP, LENGTH, Q_LORA_UP_PROJ) |
879 | 862 | else: |
880 | 863 | query_logical_name = self.query_axis_names |
881 | | - wqa_logical_name = (KV_BATCH, LENGTH_NO_EXP, Q_LORA_UP_PROJ) |
| 864 | + wqa_logical_name = (KV_BATCH, LENGTH, Q_LORA_UP_PROJ) |
882 | 865 | query_sharding = create_sharding(self.mesh, query_logical_name) |
883 | 866 | wqa_out_sharding = create_sharding(self.mesh, wqa_logical_name) |
884 | 867 | # Set softmax scaling. |
@@ -1029,10 +1012,8 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm |
1029 | 1012 | """MLA key/value projection with integrated rotary embedding.""" |
1030 | 1013 | if model_mode == MODEL_MODE_PREFILL: |
1031 | 1014 | wka_logical_name = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_LORA_UP_PROJ) |
1032 | | - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: |
1033 | | - wka_logical_name = (KV_BATCH_NO_EXP, LENGTH, KV_LORA_UP_PROJ) |
1034 | 1015 | else: |
1035 | | - wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ) |
| 1016 | + wka_logical_name = (KV_BATCH, LENGTH, KV_LORA_UP_PROJ) |
1036 | 1017 | wkva_out_sharding = create_sharding(self.mesh, wka_logical_name) |
1037 | 1018 | low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding) |
1038 | 1019 | low_rank = checkpoint_name(low_rank, "kv_wa_proj") |
@@ -1172,7 +1153,7 @@ def __call__( |
1172 | 1153 | else: |
1173 | 1154 | inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names) |
1174 | 1155 | inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names) |
1175 | | - out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV) |
| 1156 | + out_logical_name = (BATCH, LENGTH, HEAD, D_KV) |
1176 | 1157 |
|
1177 | 1158 | if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None: |
1178 | 1159 | decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32) |
|
0 commit comments