File tree Expand file tree Collapse file tree
src/maxtext/trainers/post_train/distillation Expand file tree Collapse file tree Original file line number Diff line number Diff line change 3434"""
3535
3636import inspect
37+ import logging
3738from typing import Sequence , Callable , Any
3839from absl import app
3940from flax import nnx
@@ -226,7 +227,9 @@ def __init__(
def wrt_filter(path, x):
  """Decide whether `x` at `path` passes the filter.

  Args:
    path: Parameter path, forwarded unchanged to
      `student_freeze_param_filter`.
    x: Candidate value; anything that is not an instance of `base_wrt` is
      rejected immediately.

  Returns:
    False if `x` is not a `base_wrt` instance; otherwise the negation of
    `student_freeze_param_filter(path)`.
  """
  if not isinstance(x, base_wrt):
    return False
  freeze = student_freeze_param_filter(path)
  # Pass the values as lazy %-style arguments. The previous message was a
  # plain string containing "{path}" / "{freeze}" without an f-prefix, so the
  # placeholders were logged literally and the real values never appeared.
  logging.info(
      "Student model freezing info: Parameter %s; freeze=%s", path, freeze
  )
  return not freeze
230233
231234 self .wrt_filter = wrt_filter
232235 else :
@@ -589,7 +592,7 @@ def train_distill(
def student_freeze_param_fn(path) -> bool:
  """Report whether the parameter at `path` matches no update template.

  The elements of `path` are stringified and joined with "/". The result is
  True when none of the entries in `student_params_to_update` occurs as a
  substring of that joined path, and False as soon as one does.
  """
  joined_path = "/".join(map(str, path))
  for template in student_params_to_update:
    if template in joined_path:
      return False
  return True
592-
595+
593596 if is_offline :
594597 max_logging .log ("Offline Distillation: Skipping Teacher Model loading." )
595598 teacher_model = None
You can’t perform that action at this time.
0 commit comments