Commit deb686d

revert global_batch_size change.

1 parent: 340d7c4

2 files changed: 11 additions & 30 deletions

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def _parse_tfrecord_fn(example):
   )
 
   # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
-  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh, config.global_batch_size)
+  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
   return train_iter
 
 
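After this revert, callers construct the iterator from just the dataset and the mesh. A minimal usage sketch, not taken from the repo: the toy `mesh` and `train_ds` below stand in for the ones maxdiffusion builds from its config, and it assumes the per-host batch divides evenly across the local devices:

import jax
import numpy as np
import tensorflow as tf
from jax.sharding import Mesh
from maxdiffusion import multihost_dataloading

# Toy 1-axis mesh over all visible devices; maxdiffusion derives its real mesh from config.
mesh = Mesh(np.array(jax.devices()), ("data",))

# Toy dataset whose batch size equals the local device count so np.split divides evenly.
train_ds = tf.data.Dataset.from_tensor_slices(
    np.zeros((8 * len(jax.devices()), 4), np.float32)
).batch(len(jax.devices()))

train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
batch = next(train_iter)  # a jax.Array whose batch dimension is sharded over the mesh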

src/maxdiffusion/multihost_dataloading.py

Lines changed: 10 additions & 29 deletions
@@ -37,23 +37,20 @@
 
 
 def _build_global_shape_and_sharding(
-    local_shape: tuple[int, ...], global_mesh: Mesh, global_batch_size: int = 0
+    local_shape: tuple[int, ...], global_mesh: Mesh
 ) -> tuple[tuple[int, ...], NamedSharding]:
-  #Handle sharding for setting a gbs < jax.device_count
-  if global_batch_size > 0:
-    sharding = NamedSharding(global_mesh, PartitionSpec(*global_mesh.axis_names))
-  else:
-    sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
+  sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
 
   global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
+
   return global_shape, sharding
 
 
-def _form_global_array(path, array: np.ndarray, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array:
+def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
   """Put local sharded array into local devices"""
-  global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh, global_batch_size)
+  global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
   try:
-    local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=split_axis_index)
+    local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
   except ValueError as array_split_error:
     raise ValueError(
       f"Unable to put to devices shape {array.shape} with "
@@ -65,7 +62,7 @@ def _form_global_array(path, array: np.ndarray, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array:
   return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
 
 
-def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array:
+def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh) -> jax.Array:
   """Splits the host loaded data equally over all devices."""
 
   SLEEP_TIME = 10
@@ -86,33 +83,17 @@ def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array:
   if not loaded_data_success:
     local_data = local_dataset.next()
 
-  input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh, global_batch_size=global_batch_size, split_axis_index=split_axis_index), local_data)
+  input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh), local_data)
 
   return input_gdas
 
 
 class MultiHostDataLoadIterator:
   """fold get_next_batch_sharded into a iterator class"""
 
-  def __init__(self, dataloader: Union[tf.data.Dataset, Iterable], global_mesh: Mesh, global_batch_size: int = 0):
+  def __init__(self, dataloader: Union[tf.data.Dataset, Iterable], global_mesh: Mesh):
     self.global_mesh = global_mesh
     self.dataloader = dataloader
-    # Handles sharding for when gbs < number of devices
-    self.global_batch_size = global_batch_size
-    # Use the correct axis for splitting the data across when using global_batch_size
-    split_axis_name = max(global_mesh.shape, key=global_mesh.shape.get)
-    split_axis_index = 0
-    if global_batch_size > 0:
-      max_logging.log(f"global_batch_size was set to {global_batch_size}, splitting data across {split_axis_name}.")
-      if split_axis_name == "data":
-        split_axis_index = 0
-      elif split_axis_name == "fsdp":
-        split_axis_index = 1
-      elif split_axis_name == "tensor":
-        split_axis_index = 2
-      else:
-        raise ValueError(f"Could not find {split_axis_name} to split data over.")
-    self.split_axis_index = split_axis_index
     if isinstance(self.dataloader, tf.data.Dataset):
       self.local_iterator = self.dataloader.as_numpy_iterator()
     elif isinstance(self.dataloader, Iterable):

@@ -133,4 +114,4 @@ def __iter__(self):
     return self
 
   def __next__(self):
-    return get_next_batch_sharded(self.local_iterator, self.global_mesh, self.global_batch_size, self.split_axis_index)
+    return get_next_batch_sharded(self.local_iterator, self.global_mesh)
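Taken together, the post-revert pipeline on each host is: pull a host-local numpy batch, split it along axis 0 into one shard per local device, and stitch the shards into a global `jax.Array` whose leading dimension is `jax.process_count()` times the local batch. A self-contained sketch of that path; `form_global_array` here is a paraphrase of the module's `_form_global_array`, and the `device_put` loop stands in for code between the hunks that this diff does not show:

from functools import partial

import jax
import jax.tree_util as jtu
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
  """Sketch of _form_global_array after the revert: always split on axis 0."""
  global_shape = (jax.process_count() * array.shape[0],) + array.shape[1:]
  sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
  shards = np.split(array, len(global_mesh.local_devices), axis=0)
  buffers = [jax.device_put(s, d) for s, d in zip(shards, global_mesh.local_devices)]
  return jax.make_array_from_single_device_arrays(global_shape, sharding, buffers)

mesh = Mesh(np.array(jax.devices()), ("data",))
local_batch = {"pixels": np.zeros((len(jax.devices()), 4), np.float32)}
global_batch = jtu.tree_map_with_path(partial(form_global_array, global_mesh=mesh), local_batch)
print(global_batch["pixels"].shape)  # leading dim = process_count * local batch size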
