Commit 1382bd3
committed
fix: eliminate broadcast overhead before MoE ragged_all_to_all
Profiling revealed a severe performance bottleneck occurring in every
MoE layer immediately before the ragged_all_to_all collective. The issue
was traced to the use of jnp.zeros for the collective's output buffer,
which forced XLA to broadcast a constant and zero out a large shape.
This commit changes the output buffer initialization from jnp.zeros to
jax.lax.empty.
This change removes the broadcast overhead and improves step time.1 parent 1e97f2e commit 1382bd3
1 file changed
Lines changed: 2 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1168 | 1168 | | |
1169 | 1169 | | |
1170 | 1170 | | |
1171 | | - | |
| 1171 | + | |
1172 | 1172 | | |
1173 | 1173 | | |
1174 | 1174 | | |
| |||
1331 | 1331 | | |
1332 | 1332 | | |
1333 | 1333 | | |
1334 | | - | |
| 1334 | + | |
1335 | 1335 | | |
1336 | 1336 | | |
1337 | 1337 | | |
| |||
0 commit comments