File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1114,11 +1114,14 @@ def __call__(
11141114 logits = None
11151115 # When in the Indexer Dense Warm-up stage, skip the expensive output head projection
11161116 # for efficiency, as the main model is frozen and the LM loss is not needed.
1117- elif (cfg .use_indexer and not cfg .indexer_sparse_training ) and self .model_mode == MODEL_MODE_TRAIN :
1117+ # TODO(b/501446870): Investigate model_mode as train at beginning for decoding stage
1118+ elif (
1119+ cfg .use_indexer and cfg .indexer_loss_scaling_factor > 0.0 and not cfg .indexer_sparse_training
1120+ ) and model_mode == MODEL_MODE_TRAIN :
11181121 logits = None
11191122 # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
11201123 # Instead, we keep track on the hidden states, which has smaller size compared to full logits
1121- elif cfg .num_vocab_tiling > 1 and self . model_mode == MODEL_MODE_TRAIN :
1124+ elif cfg .num_vocab_tiling > 1 and model_mode == MODEL_MODE_TRAIN :
11221125 logits = None
11231126 self .sow ("intermediates" , "hidden_states" , hidden_state )
11241127
You can’t perform that action at this time.
0 commit comments