
Commit 9f151e7

Freezing all but changed and norm layer student weights
1 parent 7f479f4 · commit 9f151e7

2 files changed: 38 additions & 5 deletions

src/maxtext/configs/types.py

Lines changed: 6 additions & 0 deletions

@@ -1126,6 +1126,12 @@ class Distillation(BaseModel):
   distill_beta: float = Field(0.0, description="Weight for the feature loss component. Use 0.0 to disable")
   distill_layer_indices: None | list = Field(None, description="Feature indices for feature loss.")
 
+  # --- Distillation freezing filter ---
+  student_params_to_update: None | list = Field(
+      None,
+      description="A list of model param name templates to finetune in the student model. The other parameters will be frozen if this attribute is non-empty.",
+  )
+
 
 class TrainingLoop(BaseModel):
   """Configuration for the main training loop, evaluation, and reproducibility."""

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 32 additions & 5 deletions

@@ -34,7 +34,7 @@
 """
 
 import inspect
-from typing import Sequence, Callable
+from typing import Sequence, Callable, Any
 from absl import app
 from flax import nnx
 from flax.linen import partitioning as nn_partitioning
@@ -199,7 +199,15 @@ class MaxTextDistillationTrainer(peft_trainer.PeftTrainer):
   (positions, segment_ids) are passed to the model.
   """
 
-  def __init__(self, model, strategy: distillation_utils.DistillationStrategy, optimizer, training_config, **kwargs):
+  def __init__(
+      self,
+      model,
+      strategy: distillation_utils.DistillationStrategy,
+      optimizer,
+      training_config,
+      student_freeze_param_filter: Callable[[Any], bool] | None = None,
+      **kwargs,
+  ):
     # We pass a dummy optimizer to the base PeftTrainer temporarily to prevent PeftTrainer from eagerly
     # allocating massive optimizer states for the entire ModelBundle (including the frozen teacher) before
     # redefining the trainer optimizer here.
@@ -211,8 +219,20 @@ def __init__(self, model, strategy: distillation_utils.DistillationStrategy, opt
     # override optimizer to only use student_model.
     if training_config.gradient_accumulation_steps is not None and training_config.gradient_accumulation_steps > 1:
       optimizer = optax.MultiSteps(optimizer, training_config.gradient_accumulation_steps)
-    wrt = nnx.LoRAParam if self._lora_enabled else nnx.Param
-    self.optimizer = nnx.Optimizer(model.student_model, optimizer, wrt=wrt)
+
+    base_wrt = nnx.LoRAParam if getattr(self, "_lora_enabled", False) else nnx.Param
+    if student_freeze_param_filter:
+
+      def wrt_filter(path, x):
+        if not isinstance(x, base_wrt):
+          return False
+        return not student_freeze_param_filter(path)
+
+      self.wrt_filter = wrt_filter
+    else:
+      self.wrt_filter = base_wrt
+
+    self.optimizer = nnx.Optimizer(model.student_model, optimizer, wrt=self.wrt_filter)
 
     # Detect if Tunix expects _train_step to return grad_norm by inspecting the source
     self._tunix_expects_grad_norm = False
@@ -282,7 +302,7 @@ def loss_wrapper(student, teacher, batch):
       # we only compute gradients for the student.
       grad_fn = nnx.value_and_grad(
           loss_wrapper,
-          argnums=0,
+          argnums=nnx.DiffState(0, self.wrt_filter),
          has_aux=True,
       )
 
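To make the filter semantics concrete, here is a small self-contained sketch, not taken from this repository, of how a (path, value) predicate restricts both the state an nnx.Optimizer allocates and the parameters that nnx.value_and_grad differentiates via nnx.DiffState. The Tiny module, the "norm" substring, and all values are hypothetical; it assumes a recent flax.nnx API in which nnx.Optimizer accepts a wrt filter, as this commit relies on.

import jax.numpy as jnp
import optax
from flax import nnx


class Tiny(nnx.Module):
  """Toy student stand-in: one projection kernel and one layer norm."""

  def __init__(self, rngs: nnx.Rngs):
    self.proj = nnx.Linear(4, 4, rngs=rngs)
    self.norm = nnx.LayerNorm(4, rngs=rngs)

  def __call__(self, x):
    return self.norm(self.proj(x))


def trainable(path, value) -> bool:
  # Trainable only if it is an nnx.Param AND its path mentions "norm";
  # the projection kernel therefore stays frozen.
  return isinstance(value, nnx.Param) and any("norm" in str(p) for p in path)


model = Tiny(nnx.Rngs(0))
x = jnp.ones((2, 4))

# Optimizer state is only allocated for params selected by the filter.
opt = nnx.Optimizer(model, optax.adam(1e-3), wrt=trainable)


def loss_fn(m, batch):
  return jnp.mean(m(batch) ** 2)


# DiffState(0, trainable): argument 0 (the model) is differentiated, but only
# the subset of its state matching the filter receives gradients.
grad_fn = nnx.value_and_grad(loss_fn, argnums=nnx.DiffState(0, trainable))
loss, grads = grad_fn(model, x)

The commit composes its filter the other way around: it takes a freeze predicate over the param path and inverts it inside wrt_filter, but the effect on the optimizer and on DiffState is the same kind of restriction.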
@@ -564,6 +584,12 @@ def train_distill(
   _log_config_details(student_config, "Student")
   student_model = get_maxtext_model(student_config, mesh)
 
+  student_params_to_update = getattr(student_config, "student_params_to_update", [])
+
+  def student_freeze_param_fn(path) -> bool:
+    path_str = "/".join(str(p) for p in path)
+    return not any(template in path_str for template in student_params_to_update)
+
   if is_offline:
     max_logging.log("Offline Distillation: Skipping Teacher Model loading.")
     teacher_model = None
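
As a quick, purely illustrative check of the substring matching above (the parameter paths are hypothetical, merely shaped like nnx param paths; they are not taken from MaxText):

student_params_to_update = ["norm"]


def student_freeze_param_fn(path) -> bool:
  path_str = "/".join(str(p) for p in path)
  return not any(template in path_str for template in student_params_to_update)


# Frozen: no path component contains "norm".
print(student_freeze_param_fn(("decoder", "layers", 0, "self_attention", "query", "kernel")))  # True
# Trainable: the layer-norm scale matches the "norm" template.
print(student_freeze_param_fn(("decoder", "layers", 0, "pre_self_attention_norm", "scale")))  # False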
@@ -582,6 +608,7 @@
       strategy=strategy,
       optimizer=optimizer,
       training_config=train_config,
+      student_freeze_param_filter=student_freeze_param_fn if student_params_to_update else None,
   )
   trainer.is_managed_externally = True
   trainer._has_aux = True  # pylint: disable=protected-access