@@ -114,8 +114,10 @@ def compare_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_na
114114# Requires JAX TPU support to generate the simulated TPU topology.
115115@pytest .mark .cpu_only
116116@pytest .mark .tpu_backend
117- @pytest .mark .parametrize ("model_name, topology, num_slice" , TEST_CASES )
118- def test_sharding_dump_for_model (model_name : str , topology : str , num_slice : str ) -> None :
117+ @pytest .mark .parametrize ("model_name, topology, num_slice, custom_mesh_and_rule, overrides" , TEST_CASES )
118+ def test_sharding_dump_for_model (
119+ model_name : str , topology : str , num_slice : str , custom_mesh_and_rule : str , overrides : tuple
120+ ) -> None :
119121 """
120122 Test sharding configurations from train_compile.get_shaped_inputs.
121123 This test verifies that the sharding configurations for various models and topologies remain consistent with golden files.
@@ -132,9 +134,16 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
132134 "enable_nnx=False" ,
133135 "pure_nnx_decoder=False" ,
134136 ]
137+ if custom_mesh_and_rule :
138+ params .append (f"custom_mesh_and_rule={ custom_mesh_and_rule } " )
139+ if overrides :
140+ params .extend (overrides )
135141
136142 root_dir = "tests/utils/sharding_info"
137- base_path = os .path .join (root_dir , model_name , topology , f"slice_{ num_slice } " )
143+ rule_name = f"rule_{ custom_mesh_and_rule } " if custom_mesh_and_rule else "rule_default"
144+ if overrides :
145+ rule_name += "_" + "_" .join (overrides )
146+ base_path = os .path .join (root_dir , model_name , topology , f"slice_{ num_slice } " , rule_name )
138147
139148 named_json_path = os .path .join (base_path , "named_shardings.json" )
140149 logical_json_path = os .path .join (base_path , "logical_shardings.json" )
@@ -206,12 +215,16 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str)
206215
207216@pytest .fixture (
208217 scope = "module" ,
209- params = [pytest .param (case , id = f"{ case [0 ]} -{ case [1 ]} -{ case [2 ]} " ) for case in TEST_CASES ],
218+ params = [pytest .param (case , id = f"{ case [0 ]} -{ case [1 ]} -{ case [2 ]} - { case [ 3 ] } - { '' . join ( case [ 4 ]) } " ) for case in TEST_CASES ],
210219)
211220def abstract_state_and_shardings (request ):
212221 """Pytest fixture to set up model, config, and generate abstract state once per test case."""
213- model_name , topology , num_slice = request .param
214- print (f"Testing model: { model_name } , topology: { topology } , num_slices: { num_slice } " , flush = True )
222+ model_name , topology , num_slice , custom_mesh_and_rule , overrides = request .param
223+ print (
224+ f"Testing model: { model_name } , topology: { topology } , num_slices: { num_slice } , "
225+ f"rule: {custom_mesh_and_rule}, overrides: {overrides}" ,
226+ flush = True ,
227+ )
215228 params = [
216229 "/deps/MaxText/tests/unit/sharding_compare_test" ,
217230 get_test_config_path (),
@@ -223,6 +236,10 @@ def abstract_state_and_shardings(request):
223236 "enable_nnx=False" ,
224237 "pure_nnx_decoder=False" ,
225238 ]
239+ if custom_mesh_and_rule :
240+ params .append (f"custom_mesh_and_rule={ custom_mesh_and_rule } " )
241+ if overrides :
242+ params .extend (overrides )
226243 config = pyconfig .initialize (params )
227244 validate_config (config )
228245
@@ -245,7 +262,16 @@ def abstract_state_and_shardings(request):
245262 # Get logical shardings from maxtext_utils
246263 logical_shardings = maxtext_utils .get_logical_annotations (config , topology_mesh , init_state_fn )
247264
248- return model_name , topology , num_slice , abstract_state , state_mesh_shardings , logical_shardings
265+ return (
266+ model_name ,
267+ topology ,
268+ num_slice ,
269+ custom_mesh_and_rule ,
270+ overrides ,
271+ abstract_state ,
272+ state_mesh_shardings ,
273+ logical_shardings ,
274+ )
249275
250276
251277@pytest .mark .cpu_only
@@ -257,9 +283,16 @@ class TestGetAbstractState:
257283 def test_get_abstract_state_sharding (self , abstract_state_and_shardings ): # pylint: disable=redefined-outer-name
258284 """Tests that get_abstract_state returns a state with the correct abstract structure and compares sharding."""
259285
260- model_name , topology , num_slice , abstract_state , state_mesh_shardings , logical_shardings = (
261- abstract_state_and_shardings
262- )
286+ (
287+ model_name ,
288+ topology ,
289+ num_slice ,
290+ custom_mesh_and_rule ,
291+ overrides ,
292+ abstract_state ,
293+ state_mesh_shardings ,
294+ logical_shardings ,
295+ ) = abstract_state_and_shardings
263296
264297 assert hasattr (abstract_state , "params" )
265298 assert hasattr (abstract_state , "opt_state" )
@@ -268,7 +301,10 @@ def test_get_abstract_state_sharding(self, abstract_state_and_shardings): # pyl
268301 assert param_leaf .dtype == jnp .float32
269302
270303 root_dir = "tests/utils/sharding_info" # Or your target directory
271- base_path = os .path .join (root_dir , model_name , topology , f"slice_{ num_slice } " )
304+ rule_name = f"rule_{ custom_mesh_and_rule } " if custom_mesh_and_rule else "rule_default"
305+ if overrides :
306+ rule_name += "_" + "_" .join (overrides )
307+ base_path = os .path .join (root_dir , model_name , topology , f"slice_{ num_slice } " , rule_name )
272308 os .makedirs (base_path , exist_ok = True ) # Ensure directory exists for saving actual
273309
274310 error_messages = []
0 commit comments