@@ -126,6 +126,7 @@ def test_prepare_inputs_logic(self):
     trainer.teacher_model = mock.Mock()
     trainer.model = mock.Mock()
     trainer.gen_model_input_fn = lambda x: {"inputs": {"some_key": "some_val"}}
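+    # Stub the filter that (presumably) selects which params gradients are taken w.r.t.;
+    # accept every parameter so the unit tests exercise the full train step.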
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     # 2. Setup Input
     # pylint: disable=unexpected-keyword-arg
@@ -153,6 +154,7 @@ def test_train_step_skips_teacher_forward_when_output_present(
     # pylint: disable=no-value-for-parameter
     trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
     trainer.strategy = mock.Mock()
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     # 2. Setup Batch WITH teacher_output
     mock_batch = {
@@ -205,6 +207,7 @@ def test_train_step_calls_teacher_forward_when_output_missing(
     # pylint: disable=no-value-for-parameter
     trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
     trainer.strategy = mock.Mock()
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     # 2. Setup Batch WITHOUT teacher_output
     mock_batch = {
@@ -278,6 +281,7 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_
     # pylint: disable=no-value-for-parameter
     trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
     trainer.strategy = mock.Mock()
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     # 2. Setup Batch WITH targets_segmentation
     mock_targets_segmentation = jnp.array([[1, 1, 0]])
@@ -579,6 +583,7 @@ def test_eval_step_calls_student_forward(self):
     # pylint: disable=no-value-for-parameter
     trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
     trainer.strategy = mock.Mock()
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     # 2. Setup Input Mocks
     raw_inputs = mock.Mock()
@@ -675,6 +680,7 @@ def test_post_process_train_step(self):
     """Verifies metrics are moved from aux dict to the trainer buffer."""
     # pylint: disable=no-value-for-parameter
     trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     # Setup MetricsBuffer mock
     mock_buffer = mock.Mock()
@@ -723,6 +729,7 @@ def __call__(self, x):
     # pylint: disable=no-value-for-parameter
     trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
     trainer.strategy = mock.Mock()
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     dummy_batch = {
         "input_tokens": jnp.ones((1, 2)),
@@ -1121,6 +1128,92 @@ def test_main_online_mode_loads_teacher(
     self.assertIs(model_bundle.student_model, mock_student_model)
     self.assertIs(model_bundle.teacher_model, mock_teacher_model)

+  def test_student_freeze_param_filter(self):
+    """Verifies that student_freeze_param_filter correctly freezes specified parameters."""
+
+    # 1. Setup a dummy model with multiple layers
+    class DummyModel(nnx.Module):
+
+      def __init__(self):
+        self.layer1 = nnx.Linear(in_features=2, out_features=2, rngs=nnx.Rngs(0))
+        self.layer2 = nnx.Linear(in_features=2, out_features=2, rngs=nnx.Rngs(1))
+
+      def __call__(self, input_tokens, **kwargs):
+        # Apply both layers in sequence.
+        return self.layer2(self.layer1(input_tokens))
+
+    student = DummyModel()
+    teacher = DummyModel()
+    model_bundle = train_distill.ModelBundle(teacher_model=teacher, student_model=student)
+
+    # Snapshot initial weights
+    initial_layer1_weights = student.layer1.kernel.get_value().copy()
+    initial_layer2_weights = student.layer2.kernel.get_value().copy()
+
+    # 2. Setup freeze filter (freeze layer1, train layer2)
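+    # The filter receives a parameter's state path; returning True marks that parameter
+    # as frozen, which the assertions below verify for layer1.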
+    def freeze_filter(path):
+      path_str = "/".join(str(p) for p in path)
+      return "layer1" in path_str
+
+    # 3. Setup Strategy and TrainingConfig
+    strategy = mock.Mock()
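+    # Summing the logits yields a scalar loss with (generically) nonzero gradients for both
+    # layers, so any parameter left unchanged must have been excluded by the freeze filter.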
+    strategy.compute_loss.side_effect = lambda s_out, t_out, labels: (jnp.sum(s_out.logits), {"aux": 1.0})
+    strategy.create_labels.return_value = None
+    strategy.student_forward_fn = lambda model, **kw: distillation_utils.DistillationForwardOutput(
+        logits=model(kw["input_tokens"])
+    )
+    strategy.teacher_forward_fn = lambda model, **kw: distillation_utils.DistillationForwardOutput(
+        logits=model(kw["input_tokens"])
+    )
+
+    # pylint: disable=import-outside-toplevel
+    from tunix.sft import peft_trainer
+
+    train_config = peft_trainer.TrainingConfig(
+        max_steps=1,
+        eval_every_n_steps=0,
+        # checkpointing_options=ocp.CheckpointManagerOptions(create=False),
+        gradient_accumulation_steps=1,
+    )
+
+    # 4. Initialize Trainer
+    trainer = train_distill.MaxTextDistillationTrainer(
+        model=model_bundle,
+        strategy=strategy,
+        optimizer=optax.sgd(0.1),
+        training_config=train_config,
+        student_freeze_param_filter=freeze_filter,
+    )
+    trainer._lora_enabled = False
+    trainer.is_managed_externally = True
+
+    trainer = trainer.with_gen_model_input_fn(
+        lambda batch: {
+            "input_tokens": batch["input_tokens"],
+            "positions": None,
+            "attention_mask": None,
+            "decoder_segment_ids": None,
+            "targets": None,
+            "teacher_output": distillation_utils.DistillationForwardOutput(logits=jnp.ones((1, 2))),
+        }
+    )
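+    # The batch already carries a teacher_output, so the train step skips the teacher
+    # forward pass (see test_train_step_skips_teacher_forward_when_output_present above).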
+
+    dummy_batch = {"input_tokens": jnp.ones((1, 2))}
+
+    # 5. Execute Pass
+    trainer._train_step(model_bundle, trainer.optimizer, dummy_batch)
+
+    # 6. Verify layer1 is unchanged (frozen)
+    np.testing.assert_allclose(
+        student.layer1.kernel.get_value(),
+        initial_layer1_weights,
+        err_msg="layer1 weights should be frozen and remain unchanged.",
+    )
+
+    # Verify layer2 has changed (trained)
+    is_layer2_unchanged = np.allclose(student.layer2.kernel.get_value(), initial_layer2_weights)
+    self.assertFalse(is_layer2_unchanged, msg="layer2 weights should have updated.")
+

 if __name__ == "__main__":
   absltest.main()