 
 
 def _build_global_shape_and_sharding(
-    local_shape: tuple[int, ...], global_mesh: Mesh, global_batch_size: int = 0
+    local_shape: tuple[int, ...], global_mesh: Mesh
 ) -> tuple[tuple[int, ...], NamedSharding]:
-  # Handle sharding for setting a gbs < jax.device_count
-  if global_batch_size > 0:
-    sharding = NamedSharding(global_mesh, PartitionSpec(*global_mesh.axis_names))
-  else:
-    sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
+  sharding = NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
 
   global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
+
   return global_shape, sharding
 
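Note the PartitionSpec distinction in the retained line: PartitionSpec(global_mesh.axis_names) passes the axis names as a single tuple, so the leading (batch) dimension is sharded jointly over every mesh axis, whereas the removed PartitionSpec(*global_mesh.axis_names) spread one array dimension per mesh axis. A minimal sketch of what the helper now computes; the 1-D "data" mesh and the shapes are illustrative assumptions, not part of this change:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Assumed 1-D mesh over all local devices, for illustration only.
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

    local_shape = (8, 128)  # assumed per-host batch: 8 sequences of length 128
    # Batch dimension sharded jointly over every mesh axis, as in the retained code.
    sharding = NamedSharding(mesh, PartitionSpec(mesh.axis_names))
    # Global batch is the per-host batch times the number of host processes.
    global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
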
 
-def _form_global_array(path, array: np.ndarray, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array:
+def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
   """Put the host-local sharded array onto the local devices."""
-  global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh, global_batch_size)
+  global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
   try:
-    local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=split_axis_index)
+    local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
   except ValueError as array_split_error:
     raise ValueError(
         f"Unable to put to devices shape {array.shape} with "
@@ -65,7 +62,7 @@ def _form_global_array(path, array: np.ndarray, global_mesh: Mesh, global_batch_
   return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
 
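For reference, this is the split-and-assemble pattern the function implements, sketched for a single host (reusing the assumed mesh and imports from the sketch above; the shapes are again illustrative):

    # Assumes the per-host batch divides evenly across local devices,
    # which is the condition the try/except above guards against.
    batch = np.arange(8 * 4, dtype=np.float32).reshape(8, 4)
    shards = np.split(batch, len(mesh.local_devices), axis=0)
    buffers = [jax.device_put(s, d) for s, d in zip(shards, mesh.local_devices)]
    global_batch = jax.make_array_from_single_device_arrays(
        (jax.process_count() * batch.shape[0],) + batch.shape[1:],
        NamedSharding(mesh, PartitionSpec(mesh.axis_names)),
        buffers,
    )
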
 
-def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh, global_batch_size: int = 0, split_axis_index: int = 0) -> jax.Array:
+def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh) -> jax.Array:
   """Splits the host-loaded data equally over all devices."""
 
   SLEEP_TIME = 10
@@ -86,33 +83,17 @@ def get_next_batch_sharded(local_dataset: Iterator, global_mesh: Mesh, global_ba
   if not loaded_data_success:
     local_data = local_dataset.next()
 
-  input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh, global_batch_size=global_batch_size, split_axis_index=split_axis_index), local_data)
+  input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh), local_data)
 
   return input_gdas
 
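The jtu.tree_map_with_path call applies _form_global_array to every leaf of the host-local batch, so a pytree-structured batch (e.g. a dict of arrays) comes back as the same structure of globally sharded jax.Arrays. A sketch, again reusing the assumed mesh from above:

    from functools import partial

    import jax.tree_util as jtu

    local_data = {"inputs": np.zeros((8, 128)), "targets": np.zeros((8, 128))}
    input_gdas = jtu.tree_map_with_path(
        partial(_form_global_array, global_mesh=mesh), local_data
    )
    # Same dict structure back, with each leaf a sharded jax.Array.
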
 
 
 class MultiHostDataLoadIterator:
   """Fold get_next_batch_sharded into an iterator class."""
 
-  def __init__(self, dataloader: Union[tf.data.Dataset, Iterable], global_mesh: Mesh, global_batch_size: int = 0):
+  def __init__(self, dataloader: Union[tf.data.Dataset, Iterable], global_mesh: Mesh):
     self.global_mesh = global_mesh
     self.dataloader = dataloader
-    # Handles sharding for when gbs < number of devices
-    self.global_batch_size = global_batch_size
-    # Use the correct axis to split the data across when using global_batch_size
-    split_axis_name = max(global_mesh.shape, key=global_mesh.shape.get)
-    split_axis_index = 0
-    if global_batch_size > 0:
-      max_logging.log(f"global_batch_size was set to {global_batch_size}, splitting data across {split_axis_name}.")
-      if split_axis_name == "data":
-        split_axis_index = 0
-      elif split_axis_name == "fsdp":
-        split_axis_index = 1
-      elif split_axis_name == "tensor":
-        split_axis_index = 2
-      else:
-        raise ValueError(f"Could not find {split_axis_name} to split data over.")
-    self.split_axis_index = split_axis_index
     if isinstance(self.dataloader, tf.data.Dataset):
       self.local_iterator = self.dataloader.as_numpy_iterator()
     elif isinstance(self.dataloader, Iterable):
@@ -133,4 +114,4 @@ def __iter__(self):
     return self
 
   def __next__(self):
-    return get_next_batch_sharded(self.local_iterator, self.global_mesh, self.global_batch_size, self.split_axis_index)
+    return get_next_batch_sharded(self.local_iterator, self.global_mesh)
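Hypothetical usage of the simplified iterator (the dataset, shapes, and mesh here are placeholders; a real setup would derive them from the training config):

    import tensorflow as tf

    # Per-host dataset; the per-host batch of 8 is assumed to divide evenly
    # across the local devices.
    ds = tf.data.Dataset.from_tensor_slices(np.zeros((64, 128), np.float32)).batch(8)
    data_iter = MultiHostDataLoadIterator(ds, mesh)
    first_batch = next(data_iter)  # one jax.Array (or pytree) sharded over the mesh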