Skip to content

Commit 2f80169

Browse files
committed
added logging
1 parent 146ad2d commit 2f80169

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"""
3535

3636
import inspect
37+
import logging
3738
from typing import Sequence, Callable, Any
3839
from absl import app
3940
from flax import nnx
@@ -226,7 +227,9 @@ def __init__(
226227
def wrt_filter(path, x):
    """Decide whether the value `x` found at `path` is selected by this filter.

    Args:
        path: Location of `x`; forwarded unchanged to
            `student_freeze_param_filter`.
        x: Candidate value. Anything that is not an instance of `base_wrt`
            is rejected without consulting the freeze filter.

    Returns:
        False when `x` is not a `base_wrt` instance or when
        `student_freeze_param_filter(path)` is truthy; True otherwise.
    """
    if not isinstance(x, base_wrt):
        return False
    freeze = student_freeze_param_filter(path)
    # Pass the values as lazy %-style arguments: the previous message was a
    # plain string (no `f` prefix), so the literal text "{path}" and
    # "{freeze}" was logged instead of the actual values.
    # NOTE(review): other code in this file logs through `max_logging.log`;
    # confirm that stdlib `logging` output is surfaced in this entry point.
    logging.info(
        "Student model freezing info: Parameter %s; freeze=%s", path, freeze
    )
    return not freeze
230233

231234
self.wrt_filter = wrt_filter
232235
else:
@@ -589,7 +592,7 @@ def train_distill(
589592
def student_freeze_param_fn(path) -> bool:
    """Report whether `path` matches none of `student_params_to_update`.

    The elements of `path` are stringified and joined with "/"; each template
    is then looked up as a plain substring of that joined string.

    Returns:
        False as soon as one template occurs in the joined path, True when
        no template does.
    """
    joined_path = "/".join(map(str, path))
    for template in student_params_to_update:
        if template in joined_path:
            return False
    return True
592-
595+
593596
if is_offline:
594597
max_logging.log("Offline Distillation: Skipping Teacher Model loading.")
595598
teacher_model = None

0 commit comments

Comments
 (0)