Skip to content

Commit 9a55e74

Browse files
committed
Shorten gradient accumulation test
1 parent 5478bad commit 9a55e74

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

tests/integration/gradient_accumulation_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,11 @@ def test_grad_accumulate_same_loss(self):
         "gradient_clipping_threshold=0",  # Ensures we are testing raw scales of gradients (clipping off)
         "enable_checkpointing=False",
         "enable_goodput_recording=False",
+        "decoder_block=simple",
         "base_emb_dim=256",
         "base_num_decoder_layers=4",
         rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
-        "steps=20",
+        "steps=2",
     ]
     # Run with gradient accumulation with accumulate_steps=10, per_device_batch=1 --> simulating per_device_batch=10
     train_main(

0 commit comments

Comments (0)