
Commit 7e5ef3e

Merge pull request #3633 from AI-Hypercomputer:indexer_fix
PiperOrigin-RevId: 897809905
2 parents f82cf83 + 448978d commit 7e5ef3e

1 file changed: 5 additions & 2 deletions

src/maxtext/layers/decoders.py

@@ -1114,11 +1114,14 @@ def __call__(
       logits = None
     # When in the Indexer Dense Warm-up stage, skip the expensive output head projection
     # for efficiency, as the main model is frozen and the LM loss is not needed.
-    elif (cfg.use_indexer and not cfg.indexer_sparse_training) and self.model_mode == MODEL_MODE_TRAIN:
+    # TODO(b/501446870): Investigate model_mode as train at beginning for decoding stage
+    elif (
+        cfg.use_indexer and cfg.indexer_loss_scaling_factor > 0.0 and not cfg.indexer_sparse_training
+    ) and model_mode == MODEL_MODE_TRAIN:
       logits = None
     # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
     # Instead, we keep track on the hidden states, which has smaller size compared to full logits
-    elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
+    elif cfg.num_vocab_tiling > 1 and model_mode == MODEL_MODE_TRAIN:
       logits = None
       self.sow("intermediates", "hidden_states", hidden_state)

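For readers outside the MaxText codebase, the following is a minimal, self-contained sketch of the skip-logits condition as it reads after this commit. The config field names (use_indexer, indexer_loss_scaling_factor, indexer_sparse_training, num_vocab_tiling), the MODEL_MODE_TRAIN constant, and the sow of hidden states mirror the diff; the Config dataclass, its value for MODEL_MODE_TRAIN, and the should_skip_output_head helper are hypothetical stand-ins for illustration, not MaxText APIs.

# Minimal sketch (not the MaxText implementation) of the gating logic in this hunk.
from dataclasses import dataclass

MODEL_MODE_TRAIN = "train"  # assumed value; MaxText defines its own constant


@dataclass
class Config:
  use_indexer: bool = False
  indexer_loss_scaling_factor: float = 0.0
  indexer_sparse_training: bool = False
  num_vocab_tiling: int = 1


def should_skip_output_head(cfg: Config, model_mode: str) -> bool:
  """True when the decoder can leave `logits = None` and skip the output head."""
  # Indexer Dense Warm-up: the main model is frozen and the LM loss is not needed,
  # but only when the indexer loss is actually enabled (scaling factor > 0).
  indexer_dense_warmup = (
      cfg.use_indexer
      and cfg.indexer_loss_scaling_factor > 0.0
      and not cfg.indexer_sparse_training
  )
  # Vocab tiling: full logits are not materialized; hidden states are sowed instead.
  vocab_tiling = cfg.num_vocab_tiling > 1
  # The commit also checks the local model_mode argument rather than self.model_mode.
  return (indexer_dense_warmup or vocab_tiling) and model_mode == MODEL_MODE_TRAIN


# With the added indexer_loss_scaling_factor check, a run that enables the indexer
# but leaves its loss disabled (scaling factor 0.0) no longer skips the projection.
assert not should_skip_output_head(
    Config(use_indexer=True, indexer_loss_scaling_factor=0.0), MODEL_MODE_TRAIN
)
assert should_skip_output_head(
    Config(use_indexer=True, indexer_loss_scaling_factor=1.0), MODEL_MODE_TRAIN
)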