Skip to content

Commit e0cb1d5

Browse files
Merge pull request #3541 from AI-Hypercomputer:parambole/494634065
PiperOrigin-RevId: 895394473
2 parents 9777a4c + 1382bd3 commit e0cb1d5

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,7 +1182,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
11821182
# experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
11831183
max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
11841184
buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
1185-
output_shape = jnp.zeros((buffer_size, self.config.emb_dim), dtype=x.dtype)
1185+
output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
11861186

11871187
x = jax.lax.ragged_all_to_all(
11881188
x,
@@ -1345,7 +1345,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
13451345
original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok
13461346
if sorted_selected_experts.shape[0] != original_inputs_first_dim:
13471347
raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!")
1348-
output_shape = jnp.zeros(
1348+
output_shape = jax.lax.empty(
13491349
(
13501350
original_inputs_first_dim,
13511351
self.config.emb_dim // self.get_tensor_parallelism_size(),

0 commit comments

Comments (0)