Commit 1382bd3

fix: eliminate broadcast overhead before MoE ragged_all_to_all
Profiling revealed a severe performance bottleneck in every MoE layer, immediately before the ragged_all_to_all collective. The cause was the use of jnp.zeros to allocate the collective's output buffer, which forced XLA to broadcast a constant in order to zero-fill a large array on every call. This commit switches the buffer initialization to jax.lax.empty, which removes the broadcast overhead and improves step time.
1 parent 1e97f2e commit 1382bd3

1 file changed

Lines changed: 2 additions & 2 deletions


src/maxtext/layers/moe.py

@@ -1168,7 +1168,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
       max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
       buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
-      output_shape = jnp.zeros((buffer_size, self.config.emb_dim), dtype=x.dtype)
+      output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)

       x = jax.lax.ragged_all_to_all(
           x,
@@ -1331,7 +1331,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
     original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok
     if sorted_selected_experts.shape[0] != original_inputs_first_dim:
       raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!")
-    output_shape = jnp.zeros(
+    output_shape = jax.lax.empty(
         (
             original_inputs_first_dim,
             self.config.emb_dim // self.get_tensor_parallelism_size(),
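For context, here is a minimal standalone sketch of the two buffer-initialization strategies. The sizes, the input array x, and the variable names are illustrative only; the jax.lax.empty call mirrors the one introduced by this commit and assumes a JAX version that exposes it.

    import jax
    import jax.numpy as jnp

    buffer_size, emb_dim = 8, 4  # illustrative sizes, not the model's real dimensions
    x = jnp.ones((buffer_size, emb_dim), dtype=jnp.bfloat16)

    # Before: zero-initializing the output buffer makes XLA broadcast a constant
    # across the full (buffer_size, emb_dim) shape before the collective runs.
    output_buffer = jnp.zeros((buffer_size, emb_dim), dtype=x.dtype)

    # After: allocating the buffer uninitialized skips that broadcast; the
    # ragged_all_to_all collective overwrites the buffer, so its initial
    # values are never needed.
    output_buffer = jax.lax.empty((buffer_size, emb_dim), dtype=x.dtype)

This is safe as long as any buffer elements the collective does not write are never read downstream; a buffer whose initial contents could leak into the result would still need an explicit initialization.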
