def _build_global_shape_and_sharding(
-    local_shape: tuple[int, ...], global_mesh: Mesh
+    local_shape: tuple[int, ...], global_mesh: Mesh, global_batch_size: int = 0
) -> tuple[tuple[int, ...], NamedSharding]:
-  sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
+  # Handle sharding for setting a global batch size smaller than jax.device_count()
+  if global_batch_size > 0:
+    sharding = NamedSharding(global_mesh, PartitionSpec(*global_mesh.axis_names))
+  else:
+    sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))

  global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
-
  return global_shape, sharding


-def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
+def _form_global_array(path, array: np.ndarray, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array:
  """Put local sharded array into local devices"""
-  global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
+  global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh, global_batch_size)
  try:
-    local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
+    local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=split_axis_index)
  except ValueError as array_split_error:
    raise ValueError(
        f"Unable to put to devices shape {array.shape} with "
@@ -62,7 +65,7 @@ def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
  return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)


-def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh) -> jax.Array:
+def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array:
  """Splits the host loaded data equally over all devices."""

  SLEEP_TIME = 10
@@ -83,17 +86,33 @@ def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh) -> jax.Array:
  if not loaded_data_success:
    local_data = local_dataset.next()

-  input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh), local_data)
+  input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh, global_batch_size=global_batch_size, split_axis_index=split_axis_index), local_data)

  return input_gdas


class MultiHostDataLoadIterator:
9295 """fold get_next_batch_sharded into a iterator class"""
9396
94- def __init__ (self , dataloader : Union [tf .data .Dataset , Iterable ], global_mesh : Mesh ):
97+ def __init__ (self , dataloader : Union [tf .data .Dataset , Iterable ], global_mesh : Mesh , global_batch_size : int = 0 ):
9598 self .global_mesh = global_mesh
9699 self .dataloader = dataloader
+    # Handle sharding for when the global batch size is smaller than the number of devices
+    self.global_batch_size = global_batch_size
+    # Pick the axis to split the data across when global_batch_size is set: the largest mesh axis
+    split_axis_name = max(global_mesh.shape, key=global_mesh.shape.get)
+    split_axis_index = 0
+    if global_batch_size > 0:
+      max_logging.log(f"global_batch_size was set to {global_batch_size}, splitting data across {split_axis_name}.")
+      if split_axis_name == "data":
+        split_axis_index = 0
+      elif split_axis_name == "fsdp":
+        split_axis_index = 1
+      elif split_axis_name == "tensor":
+        split_axis_index = 2
+      else:
+        raise ValueError(f"Could not find {split_axis_name} to split data over.")
+    self.split_axis_index = split_axis_index
    if isinstance(self.dataloader, tf.data.Dataset):
      self.local_iterator = self.dataloader.as_numpy_iterator()
    elif isinstance(self.dataloader, Iterable):
@@ -114,4 +133,4 @@ def __iter__(self):
    return self

  def __next__(self):
-    return get_next_batch_sharded(self.local_iterator, self.global_mesh)
+    return get_next_batch_sharded(self.local_iterator, self.global_mesh, self.global_batch_size, self.split_axis_index)
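For orientation, a hypothetical end-to-end use of the new constructor argument; the mesh shape, dataset, and batch size below are illustrative assumptions (single host, 8 local devices) and are not taken from the commit, and MultiHostDataLoadIterator is assumed importable from the module changed above.

import jax
import numpy as np
import tensorflow as tf
from jax.sharding import Mesh

# Illustrative (1, 8) mesh: "fsdp" is the largest axis, so batches are split along dim 1.
mesh = Mesh(np.array(jax.devices()).reshape(1, 8), axis_names=("data", "fsdp"))

# Illustrative host-local dataset yielding (4, 16) numpy batches; 16 divides evenly into 8 shards.
dataset = tf.data.Dataset.from_tensor_slices(np.ones((32, 16), np.float32)).batch(4)

# global_batch_size=4 < jax.device_count() opts into the new split/sharding path;
# the default of 0 keeps the original behaviour.
data_iter = MultiHostDataLoadIterator(dataset, mesh, global_batch_size=4)
batch = next(data_iter)  # a jax.Array with dim 0 on "data" and dim 1 sharded over "fsdp"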