@@ -905,46 +905,86 @@ def sparse_matmul(
   ):
     """Perform sparse matrix multiplication of inputs and Experts."""

-    def gmm(
-        inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
-    ):
+    def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, padding_amount):
+      """Execute jax.lax.ragged_dot, with potential quantization"""
+      m, k, n = inputs.shape[0], inputs.shape[1], kernel.shape[2]
+      tiling = (
+          min(tiling[0], m),
+          min(tiling[1], k),
+          min(tiling[2], n),
+      )
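+      # For aqt.QTensor kernels, ragged_dot runs on the raw qvalues and the per-expert scales are re-applied afterwards.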
+      rhs_inputs = kernel
+      if isinstance(kernel, aqt.QTensor):
+        if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1:
+          raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.")
+        rhs_inputs = kernel.qvalue
+      if self.config.use_qwix_quantization:
+        # Use full contraction for QWIX quantization to allow quantization
+        # fusion (max reduce over contracting dimension).
+        tiling = (tiling[0], k, tiling[2])
+
+      is_tpu = self.mesh.devices.flat[0].platform == "tpu"
+      # TPU needs random mosaic_fusion_group; GPU/CPU needs deterministic ID for autotuner sync
+      mosaic_group_id = f"{random.randint(0, 1000000000)}" if is_tpu else "0"
+      with set_xla_metadata(
+          ragged_dot_tiling=",".join([str(t) for t in tiling]),
+          mosaic_fusion_group=mosaic_group_id,
+      ):
+        output = jax.lax.ragged_dot(
+            lhs=inputs,
+            rhs=rhs_inputs,
+            group_sizes=group_sizes,
+            preferred_element_type=self.dtype,
+        )
+      if isinstance(kernel, aqt.QTensor):
+        # Multiply outputs by the kernel's scale
+        scales = jnp.take(kernel.scale[0].squeeze(), indices=expert_assignments, axis=0)
+        if padding_amount > 0:
+          scales = jax.lax.pad(
+              scales,
+              jnp.array(0.0, dtype=scales.dtype),
+              [(0, padding_amount, 0), (0, 0, 0)],
+          )
+        output *= scales
+      return output
+
+    def get_tokamax_group_sizes(group_sizes, inputs, kernel):
       # TODO(b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
       if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
-        tokamax_group_sizes = group_sizes
+        return group_sizes
       elif self.config.attention == "vllm_rpa":
-        tokamax_group_sizes = group_sizes
+        return group_sizes
       else:
-        tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
+        return tokamax.RaggedDotGroupSizes(
             group_sizes,
             max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
         )
-      pad_length = self.config.wi_tile_fwd_batch_seq
-      hs_shape = inputs.shape
-      # pad length is the 1st dimension of tiling size in gmm call
-      if inputs.shape[0] != expert_assignments.shape[0]:
-        raise ValueError("The number of input tokens must match the number of expert" " assignments!")
-      padding_amount = 0
-      if hs_shape[0] % pad_length:
-        padding_amount = pad_length - hs_shape[0] % pad_length
-        inputs = jax.lax.pad(inputs, jnp.array(0.0, dtype=inputs.dtype), [(0, padding_amount, 0), (0, 0, 0)])
-
-      inputs = inputs.astype(self.dtype)
-      kernel = kernel.astype(self.dtype)

+    def get_quantization_dtypes():
       lhs_quantize_dtype, rhs_quantize_dtype = None, None
       if self.quant is not None:
         quant_dg = self.quant.quant_dg
         lhs_quantize_dtype = quant_dg.fwd.dg_quantizer.lhs.numerics.get_dtype()
         rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
-      m, k, n = inputs.shape[0], inputs.shape[1], kernel.shape[2]
-      if not self.config.megablox and not self.config.use_tokamax_gmm:
-        tiling = (
-            min(tiling[0], m),
-            min(tiling[1], k),
-            min(tiling[2], n),
-        )
+      return lhs_quantize_dtype, rhs_quantize_dtype
+
+    def gmm(
+        inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
+    ):
+      if inputs.shape[0] != expert_assignments.shape[0]:
+        raise ValueError("The number of input tokens must match the number of expert assignments!")
+
+      tokamax_group_sizes = get_tokamax_group_sizes(group_sizes, inputs, kernel)
+      orig_inputs_shape = inputs.shape  # save shape of inputs before potentially padding.
+      inputs, padding_amount = max_utils.maybe_pad(inputs, self.config.wi_tile_fwd_batch_seq)
+      inputs = inputs.astype(self.dtype)
+      kernel = kernel.astype(self.dtype)
+      lhs_quantize_dtype, rhs_quantize_dtype = get_quantization_dtypes()
+
+      # We support three implementations for gmm - tokamax, older forked kernel, or jax.lax.ragged_dot
+      # For quantized tokamax we call a forked version that supports our quantization recipes.
       if self.config.use_tokamax_gmm:
-        if self.config.quantization:
+        if self.config.quantization:  # tokamax (quantized)
           output = mblx.gmm(
               lhs=inputs,
               rhs=kernel,
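A note on the `jax.lax.ragged_dot` path factored out above: the inputs are assumed to be sorted so that rows routed to the same expert are contiguous, and `group_sizes[i]` gives how many consecutive rows are multiplied by expert `i`'s `[k, n]` weight slice. A minimal standalone sketch of that contract (illustrative only, not part of this change; shapes are made up):

```python
# Illustrative sketch of jax.lax.ragged_dot group-size semantics (made-up shapes).
import jax
import jax.numpy as jnp

num_experts, m, k, n = 4, 8, 16, 32
lhs = jnp.ones((m, k), dtype=jnp.bfloat16)               # tokens, sorted by expert
rhs = jnp.ones((num_experts, k, n), dtype=jnp.bfloat16)  # one [k, n] weight matrix per expert
# group_sizes[i] = number of consecutive lhs rows routed to expert i (entries sum to m).
group_sizes = jnp.array([3, 0, 1, 4], dtype=jnp.int32)

out = jax.lax.ragged_dot(
    lhs=lhs,
    rhs=rhs,
    group_sizes=group_sizes,
    preferred_element_type=jnp.bfloat16,
)
print(out.shape)  # (8, 32): row j is lhs[j] @ rhs[expert assigned to row j]
```

When the kernel is an `aqt.QTensor`, `jax_ragged_dot_gmm` runs this product on `kernel.qvalue` and afterwards multiplies each output row by the scale of its assigned expert, gathered with `expert_assignments` and zero-padded to cover any padded rows.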
@@ -959,7 +999,7 @@ def gmm(
               input_buffer_count=input_buffer_count,
               combine_scopes=combine_scopes,
           )
-        else:
+        else:  # tokamax (unquantized)
           output = tokamax.ragged_dot(
               lhs=inputs,
               rhs=kernel,
@@ -968,56 +1008,23 @@ def gmm(
               preferred_element_type=self.dtype,
               implementation="mosaic",
           )
-      else:
-        if self.config.megablox:
-          output = mblx.gmm(
-              lhs=inputs,
-              rhs=kernel,
-              group_sizes=group_sizes,
-              preferred_element_type=self.dtype,
-              tiling=tiling,
-              lhs_quantize_dtype=lhs_quantize_dtype,
-              rhs_quantize_dtype=rhs_quantize_dtype,
-              use_qwix_quantization=self.config.use_qwix_quantization,
-              use_tokamax_backend=self.config.use_tokamax_gmm,
-              weight_gather_axes=weight_gather_axes,
-          )
-        else:
-          rhs_inputs = kernel
-          if isinstance(kernel, aqt.QTensor):
-            if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1:
-              raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.")
-            rhs_inputs = kernel.qvalue
-          if self.config.use_qwix_quantization:
-            # Use full contraction for QWIX quantization to allow quantization
-            # fusion (max reduce over contracting dimension).
-            tiling = (tiling[0], k, tiling[2])
-
-          is_tpu = self.mesh.devices.flat[0] == "tpu"
-          # TPU needs random mosaic_fusion_group; GPU/CPU needs deterministic ID for autotuner sync
-          mosaic_group_id = f"{random.randint(0, 1000000000)}" if is_tpu else "0"
-          with set_xla_metadata(
-              ragged_dot_tiling=",".join([str(t) for t in tiling]),
-              mosaic_fusion_group=mosaic_group_id,
-          ):
-            output = jax.lax.ragged_dot(
-                lhs=inputs,
-                rhs=rhs_inputs,
-                group_sizes=group_sizes,
-                preferred_element_type=self.dtype,
-            )
-          if isinstance(kernel, aqt.QTensor):
-            # Multiply outputs by the kernel's scale
-            scales = jnp.take(kernel.scale[0].squeeze(), indices=expert_assignments, axis=0)
-            if padding_amount > 0:
-              scales = jax.lax.pad(
-                  scales,
-                  jnp.array(0.0, dtype=scales.dtype),
-                  [(0, padding_amount, 0), (0, 0, 0)],
-              )
-            output *= scales
+      elif self.config.megablox:  # Older forked megablox
+        output = mblx.gmm(
+            lhs=inputs,
+            rhs=kernel,
+            group_sizes=group_sizes,
+            preferred_element_type=self.dtype,
+            tiling=tiling,
+            lhs_quantize_dtype=lhs_quantize_dtype,
+            rhs_quantize_dtype=rhs_quantize_dtype,
+            use_qwix_quantization=self.config.use_qwix_quantization,
+            use_tokamax_backend=self.config.use_tokamax_gmm,
+            weight_gather_axes=weight_gather_axes,
+        )
+      else:  # jax.lax.ragged_dot
+        output = jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, padding_amount)
       if padding_amount > 0:
-        output = output[: hs_shape[0]]
+        output = output[: orig_inputs_shape[0]]
       return output

     # Currently, we support data, tensor, and expert parallelism with Megablox.
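The inline row padding that previously lived in `gmm` is now delegated to `max_utils.maybe_pad`, which, judging from the call site, returns the possibly padded inputs together with the number of rows added so the result can later be sliced back to `orig_inputs_shape[0]`. A sketch of the behavior this implies, reconstructed from the removed inline code (the actual helper in `max_utils` may differ):

```python
# Hypothetical reconstruction of max_utils.maybe_pad, based on the inline code this change removes.
import jax
import jax.numpy as jnp


def maybe_pad(inputs, pad_length):
  """Zero-pads the leading (token) dimension of `inputs` up to a multiple of `pad_length`.

  Returns the (possibly) padded array and the number of rows that were added.
  """
  padding_amount = 0
  if inputs.shape[0] % pad_length:
    padding_amount = pad_length - inputs.shape[0] % pad_length
    inputs = jax.lax.pad(inputs, jnp.array(0.0, dtype=inputs.dtype), [(0, padding_amount, 0), (0, 0, 0)])
  return inputs, padding_amount


x = jnp.ones((10, 4))
padded, pad_rows = maybe_pad(x, 8)
print(padded.shape, pad_rows)  # (16, 4) 6; the caller later slices output[:10] to drop the padding
```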