Skip to content

Commit d370f95

Browse files
Merge pull request #3544 from AI-Hypercomputer:shuningjin-fix-error
PiperOrigin-RevId: 893250100
2 parents 1c76267 + 48d1761 commit d370f95

1 file changed

Lines changed: 9 additions & 7 deletions

File tree

tests/utils/forward_pass_logit_checker.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -406,13 +406,15 @@ def main(config, test_args): # pylint: disable=W0621
406406

407407
hf_model = AutoModelForCausalLM.from_pretrained(test_args.hf_model_path, dtype=torch_dtype, token=hf_token)
408408

409-
if os.path.isdir(test_args.hf_model_path):
410-
# local hf directory may not contain tokenizer, read from remote tokenizer
411-
tokenizer_path = config.tokenizer_path
412-
else:
413-
tokenizer_path = test_args.hf_model_path
414-
415-
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=hf_token)
409+
# Load tokenizer: `test_args.hf_model_path` or fallback to `config.tokenizer_path`
410+
try:
411+
# Try loading from `test_args.hf_model_path`
412+
max_logging.log(f"Loading tokenizer from {test_args.hf_model_path}.")
413+
tokenizer = AutoTokenizer.from_pretrained(test_args.hf_model_path, token=hf_token)
414+
except Exception as e: # pylint: disable=broad-except
415+
# Fallback to `config.tokenizer_path`: a local HF directory may not contain a tokenizer, so read from the remote tokenizer instead
416+
max_logging.log(f"Tokenizer loading error: {e}.\nLoading tokenizer from {config.tokenizer_path}.")
417+
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path, token=hf_token)
416418

417419
# maxtext model prefix, use eos token as pad token
418420
pad_token_prefixes = ["llama3.1", "mixtral"]

0 commit comments

Comments
 (0)