
Commit feb61fc

Small refactor of gmm in moe.py
1 parent 7988f44 commit feb61fc

3 files changed: 119 additions & 76 deletions


src/maxtext/layers/moe.py

Lines changed: 83 additions & 76 deletions
@@ -905,46 +905,86 @@ def sparse_matmul(
   ):
     """Perform sparse matrix multiplication of inputs and Experts."""

-    def gmm(
-        inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
-    ):
+    def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, padding_amount):
+      """Execute jax.lax.ragged_dot, with potential quantization."""
+      m, k, n = inputs.shape[0], inputs.shape[1], kernel.shape[2]
+      tiling = (
+          min(tiling[0], m),
+          min(tiling[1], k),
+          min(tiling[2], n),
+      )
+      rhs_inputs = kernel
+      if isinstance(kernel, aqt.QTensor):
+        if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1:
+          raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.")
+        rhs_inputs = kernel.qvalue
+      if self.config.use_qwix_quantization:
+        # Use full contraction for QWIX quantization to allow quantization
+        # fusion (max reduce over contracting dimension).
+        tiling = (tiling[0], k, tiling[2])
+
+      is_tpu = self.mesh.devices.flat[0] == "tpu"
+      # TPU needs random mosaic_fusion_group; GPU/CPU needs deterministic ID for autotuner sync
+      mosaic_group_id = f"{random.randint(0, 1000000000)}" if is_tpu else "0"
+      with set_xla_metadata(
+          ragged_dot_tiling=",".join([str(t) for t in tiling]),
+          mosaic_fusion_group=mosaic_group_id,
+      ):
+        output = jax.lax.ragged_dot(
+            lhs=inputs,
+            rhs=rhs_inputs,
+            group_sizes=group_sizes,
+            preferred_element_type=self.dtype,
+        )
+      if isinstance(kernel, aqt.QTensor):
+        # Multiply outputs by the kernel scale.
+        scales = jnp.take(kernel.scale[0].squeeze(), indices=expert_assignments, axis=0)
+        if padding_amount > 0:
+          scales = jax.lax.pad(
+              scales,
+              jnp.array(0.0, dtype=scales.dtype),
+              [(0, padding_amount, 0), (0, 0, 0)],
+          )
+        output *= scales
+      return output
+
+    def get_tokamax_group_sizes(group_sizes, inputs, kernel):
       # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
       if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
-        tokamax_group_sizes = group_sizes
+        return group_sizes
       elif self.config.attention == "vllm_rpa":
-        tokamax_group_sizes = group_sizes
+        return group_sizes
       else:
-        tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
+        return tokamax.RaggedDotGroupSizes(
            group_sizes,
            max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
        )
-      pad_length = self.config.wi_tile_fwd_batch_seq
-      hs_shape = inputs.shape
-      # pad length is the 1st dimension of tiling size in gmm call
-      if inputs.shape[0] != expert_assignments.shape[0]:
-        raise ValueError("The number of input tokens must match the number of expert" " assignments!")
-      padding_amount = 0
-      if hs_shape[0] % pad_length:
-        padding_amount = pad_length - hs_shape[0] % pad_length
-        inputs = jax.lax.pad(inputs, jnp.array(0.0, dtype=inputs.dtype), [(0, padding_amount, 0), (0, 0, 0)])
-
-      inputs = inputs.astype(self.dtype)
-      kernel = kernel.astype(self.dtype)

+    def get_quantization_dtypes():
       lhs_quantize_dtype, rhs_quantize_dtype = None, None
       if self.quant is not None:
         quant_dg = self.quant.quant_dg
         lhs_quantize_dtype = quant_dg.fwd.dg_quantizer.lhs.numerics.get_dtype()
         rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
-      m, k, n = inputs.shape[0], inputs.shape[1], kernel.shape[2]
-      if not self.config.megablox and not self.config.use_tokamax_gmm:
-        tiling = (
-            min(tiling[0], m),
-            min(tiling[1], k),
-            min(tiling[2], n),
-        )
+      return lhs_quantize_dtype, rhs_quantize_dtype
+
+    def gmm(
+        inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
+    ):
+      if inputs.shape[0] != expert_assignments.shape[0]:
+        raise ValueError("The number of input tokens must match the number of expert assignments!")
+
+      tokamax_group_sizes = get_tokamax_group_sizes(group_sizes, inputs, kernel)
+      orig_inputs_shape = inputs.shape  # Save the shape of inputs before potentially padding.
+      inputs, padding_amount = max_utils.maybe_pad(inputs, self.config.wi_tile_fwd_batch_seq)
+      inputs = inputs.astype(self.dtype)
+      kernel = kernel.astype(self.dtype)
+      lhs_quantize_dtype, rhs_quantize_dtype = get_quantization_dtypes()
+
+      # We support three implementations for gmm: tokamax, the older forked kernel, or jax.lax.ragged_dot.
+      # For quantized tokamax we call a forked version that supports our quantization recipes.
       if self.config.use_tokamax_gmm:
-        if self.config.quantization:
+        if self.config.quantization:  # tokamax (quantized)
           output = mblx.gmm(
               lhs=inputs,
               rhs=kernel,
@@ -959,7 +999,7 @@ def gmm(
               input_buffer_count=input_buffer_count,
               combine_scopes=combine_scopes,
           )
-        else:
+        else:  # tokamax (unquantized)
           output = tokamax.ragged_dot(
               lhs=inputs,
               rhs=kernel,
@@ -968,56 +1008,23 @@
               preferred_element_type=self.dtype,
               implementation="mosaic",
           )
-      else:
-        if self.config.megablox:
-          output = mblx.gmm(
-              lhs=inputs,
-              rhs=kernel,
-              group_sizes=group_sizes,
-              preferred_element_type=self.dtype,
-              tiling=tiling,
-              lhs_quantize_dtype=lhs_quantize_dtype,
-              rhs_quantize_dtype=rhs_quantize_dtype,
-              use_qwix_quantization=self.config.use_qwix_quantization,
-              use_tokamax_backend=self.config.use_tokamax_gmm,
-              weight_gather_axes=weight_gather_axes,
-          )
-        else:
-          rhs_inputs = kernel
-          if isinstance(kernel, aqt.QTensor):
-            if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1:
-              raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.")
-            rhs_inputs = kernel.qvalue
-          if self.config.use_qwix_quantization:
-            # Use full contraction for QWIX quantization to allow quantization
-            # fusion (max reduce over contracting dimension).
-            tiling = (tiling[0], k, tiling[2])
-
-          is_tpu = self.mesh.devices.flat[0] == "tpu"
-          # TPU needs random mosaic_fusion_group; GPU/CPU needs deterministic ID for autotuner sync
-          mosaic_group_id = f"{random.randint(0, 1000000000)}" if is_tpu else "0"
-          with set_xla_metadata(
-              ragged_dot_tiling=",".join([str(t) for t in tiling]),
-              mosaic_fusion_group=mosaic_group_id,
-          ):
-            output = jax.lax.ragged_dot(
-                lhs=inputs,
-                rhs=rhs_inputs,
-                group_sizes=group_sizes,
-                preferred_element_type=self.dtype,
-            )
-          if isinstance(kernel, aqt.QTensor):
-            # Multiply outputs by the kernely scale
-            scales = jnp.take(kernel.scale[0].squeeze(), indices=expert_assignments, axis=0)
-            if padding_amount > 0:
-              scales = jax.lax.pad(
-                  scales,
-                  jnp.array(0.0, dtype=scales.dtype),
-                  [(0, padding_amount, 0), (0, 0, 0)],
-              )
-            output *= scales
+      elif self.config.megablox:  # Older forked megablox
+        output = mblx.gmm(
+            lhs=inputs,
+            rhs=kernel,
+            group_sizes=group_sizes,
+            preferred_element_type=self.dtype,
+            tiling=tiling,
+            lhs_quantize_dtype=lhs_quantize_dtype,
+            rhs_quantize_dtype=rhs_quantize_dtype,
+            use_qwix_quantization=self.config.use_qwix_quantization,
+            use_tokamax_backend=self.config.use_tokamax_gmm,
+            weight_gather_axes=weight_gather_axes,
+        )
+      else:  # jax.lax.ragged_dot
+        output = jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, padding_amount)
       if padding_amount > 0:
-        output = output[: hs_shape[0]]
+        output = output[: orig_inputs_shape[0]]
       return output

     # Currently, we support data, tensor, and expert parallelism with Megablox.
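Note: when the kernel arrives as an aqt.QTensor, ragged_dot runs on the raw quantized values (kernel.qvalue) and the per-expert scale is multiplied back into the output afterwards, gathered per token via expert_assignments. A toy illustration of that gather-and-rescale step, with assumed shapes and values (not MaxText code):

import jax.numpy as jnp

# Assumed toy setup: 4 tokens, 3 experts, output dimension 2.
expert_assignments = jnp.array([0, 0, 2, 1])            # expert that handled each output row
per_expert_scale = jnp.array([[0.5, 0.5],
                              [2.0, 2.0],
                              [4.0, 4.0]])               # dequantization scale per expert and output column
raw_output = jnp.ones((4, 2))                            # stand-in for the ragged_dot result on qvalues

# Pick the scale row of the expert behind each token, then rescale the output.
scales = jnp.take(per_expert_scale, indices=expert_assignments, axis=0)
dequantized = raw_output * scales                        # rows scaled by 0.5, 0.5, 4.0, 2.0 respectively

If the inputs were padded, the scales are padded with zeros to the same length first (the jax.lax.pad call in the diff), so the extra rows are zeroed and then dropped by the final slice.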

src/maxtext/utils/max_utils.py

Lines changed: 10 additions & 0 deletions
@@ -1145,3 +1145,13 @@ def generate_representative_group_sizes(target_m: int, g: int) -> tuple[int, ...
   repr_val = np.int32((repr_val / np.sum(repr_val)) * target_m)
   repr_val[0] += target_m - np.sum(repr_val)
   return tuple(map(int, repr_val))
+
+
+def maybe_pad(inputs, tile_size):
+  """Pads the inputs leading dimension to be divisible by tile_size."""
+  inputs_dim = inputs.shape[0]
+  padding_amount = 0
+  if inputs_dim % tile_size:
+    padding_amount = tile_size - inputs_dim % tile_size
+    inputs = jax.lax.pad(inputs, jnp.array(0.0, dtype=inputs.dtype), [(0, padding_amount, 0), (0, 0, 0)])
+  return inputs, padding_amount
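Note: maybe_pad pads the leading (token) dimension up to the next multiple of tile_size and reports how much padding was added, so the caller can slice the result back. A small self-contained sketch of the round trip gmm now performs, using the same shapes as the unit test below and a stand-in for the kernel call:

import jax
import jax.numpy as jnp

# Assumed shapes: 9 tokens, tile size 4 -> pad by 3 rows up to 12.
inputs = jnp.ones((9, 8))
tile_size = 4
orig_rows = inputs.shape[0]

padding_amount = (-orig_rows) % tile_size                      # 3, same rule as maybe_pad
padded = jax.lax.pad(inputs, jnp.array(0.0, dtype=inputs.dtype),
                     [(0, padding_amount, 0), (0, 0, 0)])      # shape (12, 8)

output = padded * 2.0                                          # stand-in for the actual gmm kernel
if padding_amount > 0:
  output = output[:orig_rows]                                  # drop the padded rows again
print(output.shape)                                            # (9, 8)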

tests/unit/max_utils_test.py

Lines changed: 26 additions & 0 deletions
@@ -318,5 +318,31 @@ def test_initialize_jax_for_gpu_prefers_cuda_visible_devices_in_loop(self, mock_
     mock_log.assert_any_call("Using CUDA_VISIBLE_DEVICES to initialize JAX distributed system: 0,2")


+class TestMaybePad(unittest.TestCase):
+  """Tests that maybe_pad satisfies its contract."""
+
+  def test_odd_shape_padded(self):
+    inputs = jnp.ones((9, 8))
+    tile_size = 4
+
+    target_padding_amount = 3
+    target = jnp.concat((inputs, jnp.zeros((3, 8))))
+
+    padded, padding_amount = max_utils.maybe_pad(inputs, tile_size)
+    self.assertTrue(jnp.equal(padded, target).all())
+    self.assertEqual(padding_amount, target_padding_amount)
+
+  def test_regular_shape_unpadded(self):
+    inputs = jnp.ones((12, 13))
+    tile_size = 4
+
+    target_padding_amount = 0
+    target = jnp.ones((12, 13))
+
+    padded, padding_amount = max_utils.maybe_pad(inputs, tile_size)
+    self.assertTrue(jnp.equal(padded, target).all())
+    self.assertEqual(padding_amount, target_padding_amount)
+
+
 if __name__ == "__main__":
   unittest.main()

0 commit comments