Skip to content

Commit 1ae4aba

Browse files
committed
Add ring attention kernel
Signed-off-by: Kunjan patel <kunjanp@google.com>
1 parent c5d9018 commit 1ae4aba

8 files changed

Lines changed: 5321 additions & 6 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jit_initializers: True
5353
# Set true to load weights from pytorch
5454
from_pt: True
5555
split_head_dim: True
56-
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
56+
attention: 'ring_flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring_flash
5757

5858
flash_block_sizes: {}
5959
# Use on v6e
@@ -143,8 +143,8 @@ data_sharding: [['data', 'fsdp', 'tensor']]
143143
# By default, product of the DCN axes should equal number of slices
144144
# and product of the ICI axes should equal number of devices per slice.
145145
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
146-
dcn_fsdp_parallelism: -1
147-
dcn_tensor_parallelism: 1
146+
dcn_fsdp_parallelism: 4
147+
dcn_tensor_parallelism: 2
148148
ici_data_parallelism: 1
149149
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
150150
ici_tensor_parallelism: 1
@@ -217,11 +217,12 @@ adam_eps: 1.e-8 # A small constant applied to denominator outside of the square
217217
adam_weight_decay: 0 # AdamW Weight decay
218218
max_grad_norm: 1.0
219219

220-
enable_profiler: False
220+
enable_profiler: True
221221
# Skip first n steps for profiling, to omit things like compilation and to give
222222
# the iteration time a chance to stabilize.
223223
skip_first_n_steps_for_profiler: 5
224224
profiler_steps: 10
225+
tensorboard_dir: /home/kunjanp_google_com/wan-21-md/maxdiffusion/.trace/flash-tp
225226

226227
# Generation parameters
227228
prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."

src/maxdiffusion/models/attention_flax.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ def _tpu_ring_flash_attention_v1(
166166
dtype: jnp.dtype = jnp.float32,
167167
) -> jax.Array:
168168
"""TPU Ring Flash Attention with correct padding, transposition, and sharding."""
169-
from ringattention import ringattention
169+
# from ringattention import ringattention
170+
from maxdiffusion.models.ringattention.ringattention_pallas_tpu import ring_flash_attention_tpu as ringattention
170171
from einops import rearrange
171172

172173
max_block_size = 1024 if dtype == jnp.bfloat16 else 512
@@ -255,7 +256,7 @@ def _tpu_ring_flash_attention(
255256
usp_degree: Optional[int] = 1,
256257
) -> jax.Array:
257258
"""TPU Ring/USP Flash Attention with correct padding, transposition, and sharding."""
258-
from ringattention import ringattention
259+
from maxdiffusion.models.ringattention.ringattention_pallas_tpu import ring_flash_attention_tpu as ringattention
259260

260261
max_block_size = 1024 if dtype == jnp.bfloat16 else 512
261262
blockwise_kwargs = {

0 commit comments

Comments (0)