Skip to content

Commit 448978d

Browse files
committed
Enable indexer_sparse_training for logits and decoding
1 parent 49742a1 commit 448978d

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

src/maxtext/layers/decoders.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,11 +1102,14 @@ def __call__(
         logits = None
       # When in the Indexer Dense Warm-up stage, skip the expensive output head projection
       # for efficiency, as the main model is frozen and the LM loss is not needed.
-      elif (cfg.use_indexer and not cfg.indexer_sparse_training) and self.model_mode == MODEL_MODE_TRAIN:
+      # TODO(b/501446870): Investigate model_mode as train at beginning for decoding stage
+      elif (
+          cfg.use_indexer and cfg.indexer_loss_scaling_factor > 0.0 and not cfg.indexer_sparse_training
+      ) and model_mode == MODEL_MODE_TRAIN:
         logits = None
       # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
       # Instead, we keep track on the hidden states, which has smaller size compared to full logits
-      elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
+      elif cfg.num_vocab_tiling > 1 and model_mode == MODEL_MODE_TRAIN:
         logits = None
       self.sow("intermediates", "hidden_states", hidden_state)

0 commit comments

Comments (0)