Skip to content

Commit 6b1e0ef

Browse files
Handle None values for HLO dump config keys.
Sets `dump_hlo_local_module_name` and `dump_hlo_module_name` to an empty string if they are provided as None. PiperOrigin-RevId: 896185309
1 parent 1e343fe commit 6b1e0ef

2 files changed

Lines changed: 15 additions & 3 deletions

File tree

src/maxtext/configs/pyconfig.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,7 @@ def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]:
8383
return resolve_config_path(argv[1]), argv[2:]
8484
module = _module_from_path(argv[0])
8585
if module not in _CONFIG_FILE_MAPPING:
86-
raise ValueError(
87-
f"No config file provided and no default config found for module '{module}'"
88-
)
86+
raise ValueError(f"No config file provided and no default config found for module '{module}'")
8987
config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
9088
logger.warning("No config file provided, using default config mapping: %s", config_path)
9189
return config_path, argv[1:]
@@ -203,6 +201,9 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
203201
if key == "run_name" and new_value is None:
204202
new_value = ""
205203

204+
if key in ("dump_hlo_local_module_name", "dump_hlo_module_name") and new_value is None:
205+
new_value = ""
206+
206207
if key == "tokenizer_path" and new_value is None:
207208
try:
208209
new_value = HF_IDS[raw_keys["model_name"]]

tests/unit/pyconfig_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,21 @@ def test_config_file_mapping(self):
122122

123123
def test_module_from_path(self):
124124
import maxtext.trainers.pre_train.train as train_module
125+
125126
module_file = train_module.__file__
126127
result = _module_from_path(module_file)
127128
self.assertEqual(result, "maxtext.trainers.pre_train.train")
128129

130+
def test_hlo_dump_module_names_none_coercion(self):
131+
config = pyconfig.initialize(
132+
[os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()],
133+
skip_jax_distributed_system=True,
134+
dump_hlo_local_module_name=None,
135+
dump_hlo_module_name=None,
136+
)
137+
self.assertEqual(config.dump_hlo_local_module_name, "")
138+
self.assertEqual(config.dump_hlo_module_name, "")
139+
129140
def test_unknown_module_raises(self):
130141
with self.assertRaises(ValueError):
131142
pyconfig.initialize_pydantic(["/custom_rl/module.py", "run_name=test"])

0 commit comments

Comments
 (0)