File tree Expand file tree Collapse file tree
tests/end_to_end/tpu/gpt_oss/20b Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -58,13 +58,15 @@ export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=81920'
5858python3 -m tests.utils.forward_pass_logit_checker " ${MAXTEXT_CONFIGS_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ configs} " //base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check model_name=${MODEL_NAME} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=false attention=dot_product sparse_matmul=True megablox=True per_device_batch_size=1 max_target_length=4 max_prefill_predict_length=4 dtype=float32 --atol=0.1 --rtol=0.1 --max_kl_div=3e-4
5959
6060# Run pre-training - megablox implementation
61- python3 -m maxtext.trainers.pre_train.train " ${MAXTEXT_CONFIGS_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ configs} " //base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=4
61+ python3 -m maxtext.trainers.pre_train.train " ${MAXTEXT_CONFIGS_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ configs} " //base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=4 gcs_metrics=true
6262
6363# Run fine-tuning - megablox implementation
64- python3 -m maxtext.trainers.pre_train.train " ${MAXTEXT_CONFIGS_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ configs} " //base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4
64+ # TODO: remove `abort_on_nan_loss=false` after b/497864549
65+ python3 -m maxtext.trainers.pre_train.train " ${MAXTEXT_CONFIGS_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ configs} " //base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 gcs_metrics=true abort_on_nan_loss=false
6566
6667# Run supervised fine-tuning - megablox implementation
67- python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated " ${MAXTEXT_CONFIGS_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ configs/ post_train} " //sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4
68+ # TODO: remove `abort_on_nan_loss=false` after b/497864549
69+ python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated " ${MAXTEXT_CONFIGS_DIR:- ${MAXTEXT_REPO_ROOT:- $PWD } / src/ maxtext/ configs/ post_train} " //sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 gcs_metrics=true abort_on_nan_loss=false
6870
6971# Run decoding - megablox implementation
7072# Note decode requires the access token for huggingface tokenizer even if the model is not gated
You can’t perform that action at this time.
0 commit comments