|
32 | 32 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel |
33 | 33 | from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask |
34 | 34 | import jax.numpy as jnp |
35 | | -from jax.sharding import Mesh, NamedSharding |
| 35 | +from jax.sharding import Mesh |
36 | 36 | from maxtext.common.common_types import ( |
37 | 37 | Array, |
38 | 38 | AttentionType, |
|
78 | 78 | from maxtext.layers.initializers import variable_to_logically_partitioned |
79 | 79 | from maxtext.layers.quantizations import AqtQuantization as Quant |
80 | 80 | from maxtext.utils import max_utils |
81 | | -from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_name |
| 81 | +from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_pspec |
82 | 82 | import numpy as np |
83 | 83 | from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel |
84 | 84 | from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask |
@@ -1484,26 +1484,19 @@ def kernel_fn(q, k, v, d, s): |
1484 | 1484 |
|
1485 | 1485 | return attention_output, None |
1486 | 1486 |
|
1487 | | - def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None): |
1488 | | - # decoder_segment_ids can be None |
1489 | | - if pspec is None: |
1490 | | - return None |
1491 | | - sharding = NamedSharding(self.mesh, pspec) |
1492 | | - return maybe_shard_with_name( |
1493 | | - inputs, |
1494 | | - sharding, |
1495 | | - shard_mode=self.config.shard_mode, |
1496 | | - debug_sharding=self.config.debug_sharding, |
1497 | | - extra_stack_level=1, |
1498 | | - ) |
1499 | | - |
1500 | | - query = _maybe_shard_with_pspec(query, axis_names_q) |
1501 | | - key = _maybe_shard_with_pspec(key, axis_names_kv) |
1502 | | - value = _maybe_shard_with_pspec(value, axis_names_kv) |
1503 | | - decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q) |
1504 | | - decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv) |
1505 | | - sinks = _maybe_shard_with_pspec(sinks, sink_axis_names) |
1506 | | - indexer_mask = _maybe_shard_with_pspec(indexer_mask, indexer_mask_axis_names) |
| 1487 | + query = maybe_shard_with_pspec(query, self.mesh, self.config.shard_mode, axis_names_q, self.config.debug_sharding) |
| 1488 | + key = maybe_shard_with_pspec(key, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding) |
| 1489 | + value = maybe_shard_with_pspec(value, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding) |
| 1490 | + decoder_segment_ids_q = maybe_shard_with_pspec( |
| 1491 | + decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_q, self.config.debug_sharding |
| 1492 | + ) |
| 1493 | + decoder_segment_ids_kv = maybe_shard_with_pspec( |
| 1494 | + decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_kv, self.config.debug_sharding |
| 1495 | + ) |
| 1496 | + sinks = maybe_shard_with_pspec(sinks, self.mesh, self.config.shard_mode, sink_axis_names, self.config.debug_sharding) |
| 1497 | + indexer_mask = maybe_shard_with_pspec( |
| 1498 | + indexer_mask, self.mesh, self.config.shard_mode, indexer_mask_axis_names, self.config.debug_sharding |
| 1499 | + ) |
1507 | 1500 |
|
1508 | 1501 | ret = wrap_flash_attention( |
1509 | 1502 | query, |
|
0 commit comments