Skip to content

Commit 3385aa0

Browse files
Merge pull request #3648 from AI-Hypercomputer:shuningjin-gpt-oss-nan
PiperOrigin-RevId: 899197184
2 parents 4a1da22 + ac0cf4d commit 3385aa0

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

tests/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,15 @@ export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=81920'
5858
python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check model_name=${MODEL_NAME} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=false attention=dot_product sparse_matmul=True megablox=True per_device_batch_size=1 max_target_length=4 max_prefill_predict_length=4 dtype=float32 --atol=0.1 --rtol=0.1 --max_kl_div=3e-4
5959

6060
# Run pre-training - megablox implementation
61-
python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=4
61+
python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=4 gcs_metrics=true
6262

6363
# Run fine-tuning - megablox implementation
64-
python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4
64+
# TODO: remove `abort_on_nan_loss=false` after b/497864549
65+
python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 gcs_metrics=true abort_on_nan_loss=false
6566

6667
# Run supervised fine-tuning - megablox implementation
67-
python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"//sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4
68+
# TODO: remove `abort_on_nan_loss=false` after b/497864549
69+
python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs/post_train}"//sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 gcs_metrics=true abort_on_nan_loss=false
6870

6971
# Run decoding - megablox implementation
7072
# Note decode requires the access token for huggingface tokenizer even if the model is not gated

0 commit comments

Comments
 (0)