File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1102,11 +1102,14 @@ def __call__(
11021102 logits = None
11031103 # When in the Indexer Dense Warm-up stage, skip the expensive output head projection
11041104 # for efficiency, as the main model is frozen and the LM loss is not needed.
1105- elif (cfg .use_indexer and not cfg .indexer_sparse_training ) and self .model_mode == MODEL_MODE_TRAIN :
1105+ # TODO(b/501446870): Investigate model_mode as train at beginning for decoding stage
1106+ elif (
1107+ cfg .use_indexer and cfg .indexer_loss_scaling_factor > 0.0 and not cfg .indexer_sparse_training
1108+ ) and model_mode == MODEL_MODE_TRAIN :
11061109 logits = None
11071110 # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
11081111 # Instead, we keep track on the hidden states, which has smaller size compared to full logits
1109- elif cfg .num_vocab_tiling > 1 and self . model_mode == MODEL_MODE_TRAIN :
1112+ elif cfg .num_vocab_tiling > 1 and model_mode == MODEL_MODE_TRAIN :
11101113 logits = None
11111114 self .sow ("intermediates" , "hidden_states" , hidden_state )
11121115
You can’t perform that action at this time.
0 commit comments