Skip to content

Commit b8cf440

Browse files
committed
update
Remove dense matmul changes; update formatting; apply fixes.
1 parent c57e4c5 commit b8cf440

3 files changed

Lines changed: 39 additions & 23 deletions

File tree

src/maxtext/layers/attention_op.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
3333
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
3434
import jax.numpy as jnp
35-
from jax.sharding import Mesh, NamedSharding
35+
from jax.sharding import Mesh
3636
from maxtext.common.common_types import (
3737
Array,
3838
AttentionType,
@@ -78,7 +78,7 @@
7878
from maxtext.layers.initializers import variable_to_logically_partitioned
7979
from maxtext.layers.quantizations import AqtQuantization as Quant
8080
from maxtext.utils import max_utils
81-
from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_name
81+
from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_pspec
8282
import numpy as np
8383
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel
8484
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask
@@ -1484,26 +1484,19 @@ def kernel_fn(q, k, v, d, s):
14841484

14851485
return attention_output, None
14861486

1487-
def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
1488-
# decoder_segment_ids can be None
1489-
if pspec is None:
1490-
return None
1491-
sharding = NamedSharding(self.mesh, pspec)
1492-
return maybe_shard_with_name(
1493-
inputs,
1494-
sharding,
1495-
shard_mode=self.config.shard_mode,
1496-
debug_sharding=self.config.debug_sharding,
1497-
extra_stack_level=1,
1498-
)
1499-
1500-
query = _maybe_shard_with_pspec(query, axis_names_q)
1501-
key = _maybe_shard_with_pspec(key, axis_names_kv)
1502-
value = _maybe_shard_with_pspec(value, axis_names_kv)
1503-
decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q)
1504-
decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv)
1505-
sinks = _maybe_shard_with_pspec(sinks, sink_axis_names)
1506-
indexer_mask = _maybe_shard_with_pspec(indexer_mask, indexer_mask_axis_names)
1487+
query = maybe_shard_with_pspec(query, self.mesh, self.config.shard_mode, axis_names_q, self.config.debug_sharding)
1488+
key = maybe_shard_with_pspec(key, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding)
1489+
value = maybe_shard_with_pspec(value, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding)
1490+
decoder_segment_ids_q = maybe_shard_with_pspec(
1491+
decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_q, self.config.debug_sharding
1492+
)
1493+
decoder_segment_ids_kv = maybe_shard_with_pspec(
1494+
decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_kv, self.config.debug_sharding
1495+
)
1496+
sinks = maybe_shard_with_pspec(sinks, self.mesh, self.config.shard_mode, sink_axis_names, self.config.debug_sharding)
1497+
indexer_mask = maybe_shard_with_pspec(
1498+
indexer_mask, self.mesh, self.config.shard_mode, indexer_mask_axis_names, self.config.debug_sharding
1499+
)
15071500

15081501
ret = wrap_flash_attention(
15091502
query,

src/maxtext/layers/moe.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from maxtext.kernels import megablox as mblx
3737
from maxtext.utils import max_logging
3838
from maxtext.utils import max_utils
39-
from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding
39+
from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding, maybe_shard_with_pspec
4040
from maxtext.utils.sharding import logical_to_mesh_axes
4141
import numpy as np
4242
import qwix.pallas as qpl
@@ -1439,6 +1439,16 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
14391439
gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes)
14401440
pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, pre_bias_logits_axes)
14411441

1442+
w0_kernel = maybe_shard_with_pspec(w0_kernel, self.mesh, self.config.shard_mode, w0_pspec)
1443+
w1_kernel = maybe_shard_with_pspec(w1_kernel, self.mesh, self.config.shard_mode, w1_pspec)
1444+
wo_kernel = maybe_shard_with_pspec(wo_kernel, self.mesh, self.config.shard_mode, wo_pspec)
1445+
if w0_bias is not None:
1446+
w0_bias = maybe_shard_with_pspec(w0_bias, self.mesh, self.config.shard_mode, w0_bias_pspec)
1447+
if w1_bias is not None:
1448+
w1_bias = maybe_shard_with_pspec(w1_bias, self.mesh, self.config.shard_mode, w1_bias_pspec)
1449+
if wo_bias is not None:
1450+
wo_bias = maybe_shard_with_pspec(wo_bias, self.mesh, self.config.shard_mode, wo_bias_pspec)
1451+
14421452
return wrapper(
14431453
inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs
14441454
)

src/maxtext/utils/sharding.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,19 @@ def maybe_shard_with_name(
115115
return jax.lax.with_sharding_constraint(inputs, named_sharding)
116116

117117

118+
def maybe_shard_with_pspec(inputs, mesh, shard_mode, pspec: jax.sharding.PartitionSpec | None, debug_sharding=False):
119+
if pspec is None:
120+
return None
121+
sharding = NamedSharding(mesh, pspec)
122+
return maybe_shard_with_name(
123+
inputs,
124+
sharding,
125+
shard_mode=shard_mode,
126+
debug_sharding=debug_sharding,
127+
extra_stack_level=1,
128+
)
129+
130+
118131
def maybe_shard_with_logical(
119132
inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc=""
120133
):

0 commit comments

Comments (0)