Commit 00bf1cd

add support for < 1 batch item per device.
1 parent 128ef01 commit 00bf1cd

7 files changed: 41 additions & 17 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,7 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+flash_min_seq_length: 4096
 
 flash_block_sizes: {}
 # Use on v6e
@@ -131,6 +132,7 @@ logical_axis_rules: [
   ['activation_batch', 'data'],
   ['mlp','tensor'],
   ['embed','fsdp'],
+  ['heads', 'tensor'],
   ['norm', 'tensor'],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
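
For context on the second hunk: logical-axis rules map named array dimensions onto physical mesh axes, so the new ['heads', 'tensor'] entry lets a logical `heads` dimension resolve to the `tensor` mesh axis. A minimal sketch of the lookup, with an illustrative rule subset and dimension names (not taken from this commit):

import flax.linen as nn

# Subset of the rules above; the new entry maps `heads` -> `tensor`.
rules = (("activation_batch", "data"), ("heads", "tensor"), ("embed", "fsdp"))

spec = nn.logical_to_mesh_axes(("activation_batch", "heads", "length", "head_dim"), rules)
print(spec)  # PartitionSpec('data', 'tensor', None, None)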

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def _parse_tfrecord_fn(example):
   )
 
   # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
-  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
+  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh, config.global_batch_size)
   return train_iter
 
 
src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 5 deletions
@@ -187,11 +187,6 @@ def _tpu_flash_attention(
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-  flash_axis_names_splash_kernel: AxisNames = (HEAD, KV_LENGTH)
-  axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
-  named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
-
-  shard_head_size = mesh.shape["tensor"]
 
   @functools.partial(
       shard_map.shard_map,
@@ -215,6 +210,9 @@ def wrap_flash_attention(query, key, value):
         q_seq_shards=1,  # number of shards over the seq_len axis
         block_sizes=block_sizes,
     )
+    # jax.debug.print("query.shape: {x}", x=query.shape)
+    # jax.debug.print("key.shape: {x}", x=key.shape)
+    # jax.debug.print("value.shape: {x}", x=value.shape)
     attention_output = jax.vmap(splash_kernel)(query, key, value)
     return attention_output
 
@@ -799,6 +797,7 @@ def __call__(
     query_proj = _unflatten_heads(query_proj, self.heads)
     key_proj = _unflatten_heads(key_proj, self.heads)
     value_proj = _unflatten_heads(value_proj, self.heads)
+    # output of _unflatten_heads is (batch, heads, seq_len, head_dim)
     query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
 
     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
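
The new inline comment pins down the layout that _apply_rope and apply_attention receive. A sketch of that head-unflattening reshape, assuming the conventional flattened-heads input (illustrative, not the module's code):

import jax.numpy as jnp

def unflatten_heads(x: jnp.ndarray, heads: int) -> jnp.ndarray:
  # (batch, seq_len, heads * head_dim) -> (batch, heads, seq_len, head_dim)
  batch, seq_len, _ = x.shape
  x = x.reshape(batch, seq_len, heads, -1)
  return jnp.transpose(x, (0, 2, 1, 3))

x = jnp.ones((2, 16, 8 * 64))
print(unflatten_heads(x, heads=8).shape)  # (2, 8, 16, 64)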

src/maxdiffusion/multihost_dataloading.py

Lines changed: 29 additions & 10 deletions
@@ -37,20 +37,23 @@
 
 
 def _build_global_shape_and_sharding(
-    local_shape: tuple[int, ...], global_mesh: Mesh
+    local_shape: tuple[int, ...], global_mesh: Mesh, global_batch_size: int = 0
 ) -> tuple[tuple[int, ...], NamedSharding]:
-  sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
+  # Handle sharding when setting a gbs < jax.device_count().
+  if global_batch_size > 0:
+    sharding = NamedSharding(global_mesh, PartitionSpec(*global_mesh.axis_names))
+  else:
+    sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
 
   global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
-
   return global_shape, sharding
 
 
-def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
+def _form_global_array(path, array: np.ndarray, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array:
   """Put local sharded array into local devices"""
-  global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
+  global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh, global_batch_size)
   try:
-    local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
+    local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=split_axis_index)
   except ValueError as array_split_error:
     raise ValueError(
         f"Unable to put to devices shape {array.shape} with "
@@ -62,7 +65,7 @@ def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
   return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
 
 
-def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh) -> jax.Array:
+def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array:
   """Splits the host loaded data equally over all devices."""
 
   SLEEP_TIME = 10
@@ -83,17 +86,33 @@ def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh) -> jax.Array:
   if not loaded_data_success:
     local_data = local_dataset.next()
 
-  input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh), local_data)
+  input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh, global_batch_size=global_batch_size, split_axis_index=split_axis_index), local_data)
 
   return input_gdas
 
 
 class MultiHostDataLoadIterator:
   """fold get_next_batch_sharded into an iterator class"""
 
-  def __init__(self, dataloader: Union[tf.data.Dataset, Iterable], global_mesh: Mesh):
+  def __init__(self, dataloader: Union[tf.data.Dataset, Iterable], global_mesh: Mesh, global_batch_size: int = 0):
     self.global_mesh = global_mesh
     self.dataloader = dataloader
+    # Handles sharding for when gbs < number of devices.
+    self.global_batch_size = global_batch_size
+    # Pick the correct axis to split the data across when global_batch_size is set.
+    split_axis_name = max(global_mesh.shape, key=global_mesh.shape.get)
+    split_axis_index = 0
+    if global_batch_size > 0:
+      max_logging.log(f"global_batch_size was set to {global_batch_size}, splitting data across {split_axis_name}.")
+      if split_axis_name == "data":
+        split_axis_index = 0
+      elif split_axis_name == "fsdp":
+        split_axis_index = 1
+      elif split_axis_name == "tensor":
+        split_axis_index = 2
+      else:
+        raise ValueError(f"Could not find {split_axis_name} to split data over.")
+    self.split_axis_index = split_axis_index
     if isinstance(self.dataloader, tf.data.Dataset):
       self.local_iterator = self.dataloader.as_numpy_iterator()
     elif isinstance(self.dataloader, Iterable):
@@ -114,4 +133,4 @@ def __iter__(self):
     return self
 
   def __next__(self):
-    return get_next_batch_sharded(self.local_iterator, self.global_mesh)
+    return get_next_batch_sharded(self.local_iterator, self.global_mesh, self.global_batch_size, self.split_axis_index)
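
The behavioral core of this change is the PartitionSpec construction: splatting the axis names shards each leading array dimension over one mesh axis, instead of sharding dimension 0 over every axis at once, which is what allows a global batch smaller than the device count. A minimal sketch of the difference (axis names as in the mesh used here):

from jax.sharding import PartitionSpec

axis_names = ("data", "fsdp", "tensor")

# Old path (global_batch_size == 0): dim 0 is sharded jointly over all mesh
# axes, so the batch must be divisible by the total device count.
joint = PartitionSpec(axis_names)     # PartitionSpec((('data', 'fsdp', 'tensor'),))

# New path (global_batch_size > 0): one mesh axis per leading dim; only the
# dim matching the largest mesh axis carries the per-host np.split, which is
# what split_axis_index selects.
per_dim = PartitionSpec(*axis_names)  # PartitionSpec('data', 'fsdp', 'tensor')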

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_config["precision"] = get_precision(config)
   wan_config["flash_block_sizes"] = get_flash_block_sizes(config)
   wan_config["remat_policy"] = config.remat_policy
+  wan_config["flash_min_seq_length"] = config.flash_min_seq_length
 
   # 2. eval_shape - will not use flops or create weights on device
   # thus not using HBM memory.
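
flash_min_seq_length is now threaded from the config into the WAN model. The typical role of this kind of threshold is to skip the flash kernel for sequences too short to amortize it; a sketch of that selection pattern (an assumption about intent, not the library's exact code):

def choose_attention(seq_len: int, flash_min_seq_length: int) -> str:
  # Flash kernels pay a block-setup cost that short sequences cannot
  # amortize, so fall back to plain dot-product attention below the cutoff.
  return "flash" if seq_len >= flash_min_seq_length else "dot_product"

assert choose_attention(4096, flash_min_seq_length=4096) == "flash"
assert choose_attention(512, flash_min_seq_length=4096) == "dot_product"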

src/maxdiffusion/pyconfig.py

Lines changed: 2 additions & 0 deletions
@@ -181,6 +181,8 @@ def user_init(raw_keys):
   raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"])
   raw_keys["num_slices"] = get_num_slices(raw_keys)
   raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
+  if "global_batch_size" not in raw_keys.keys():
+    raw_keys["global_batch_size"] = 0
 
 
 def get_num_slices(raw_keys):
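
The backfill matters because config.global_batch_size is read unconditionally downstream (trainer and input pipeline), while existing config files never define the key. A minimal sketch of the same guard, treating raw_keys as a plain dict:

raw_keys = {"per_device_batch_size": 1}  # a config without the new key

if "global_batch_size" not in raw_keys:
  raw_keys["global_batch_size"] = 0  # 0 means "derive the batch size from per-device settings"

assert raw_keys["global_batch_size"] == 0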

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 2 additions & 1 deletion
@@ -69,7 +69,8 @@ def __init__(self, config):
     if config.train_text_encoder:
       raise ValueError("this script currently doesn't support training text_encoders")
 
-    self.global_batch_size = self.config.per_device_batch_size * jax.device_count()
+    # self.global_batch_size = self.config.per_device_batch_size * jax.device_count()
+    self.global_batch_size = config.global_batch_size if config.global_batch_size > 0 else config.per_device_batch_size * jax.device_count()
 
   def post_training_steps(self, pipeline, params, train_states, msg=""):
     pass
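
The fallback keeps old configs working: with global_batch_size left at its 0 default, the trainer computes the batch size exactly as before, while any positive value overrides it, even one below jax.device_count(). A small illustration of the selection logic (the user setting is hypothetical):

import jax

per_device_batch_size = 1
global_batch_size = 2  # hypothetical user setting; 0 preserves the old behavior

effective = global_batch_size if global_batch_size > 0 else per_device_batch_size * jax.device_count()
print(effective)  # 2, even on a host with 8 devices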
