Skip to content

Commit 146ad2d

Browse files
committed
added a unit test
1 parent 9f151e7 commit 146ad2d

1 file changed

Lines changed: 93 additions & 0 deletions

File tree

tests/post_training/unit/train_distill_test.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def test_prepare_inputs_logic(self):
126126
trainer.teacher_model = mock.Mock()
127127
trainer.model = mock.Mock()
128128
trainer.gen_model_input_fn = lambda x: {"inputs": {"some_key": "some_val"}}
129+
trainer.wrt_filter = lambda path, x: True # type: ignore
129130

130131
# 2. Setup Input
131132
# pylint: disable=unexpected-keyword-arg
@@ -153,6 +154,7 @@ def test_train_step_skips_teacher_forward_when_output_present(
153154
# pylint: disable=no-value-for-parameter
154155
trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
155156
trainer.strategy = mock.Mock()
157+
trainer.wrt_filter = lambda path, x: True # type: ignore
156158

157159
# 2. Setup Batch WITH teacher_output
158160
mock_batch = {
@@ -205,6 +207,7 @@ def test_train_step_calls_teacher_forward_when_output_missing(
205207
# pylint: disable=no-value-for-parameter
206208
trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
207209
trainer.strategy = mock.Mock()
210+
trainer.wrt_filter = lambda path, x: True # type: ignore
208211

209212
# 2. Setup Batch WITHOUT teacher_output
210213
mock_batch = {
@@ -278,6 +281,7 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_
278281
# pylint: disable=no-value-for-parameter
279282
trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
280283
trainer.strategy = mock.Mock()
284+
trainer.wrt_filter = lambda path, x: True # type: ignore
281285

282286
# 2. Setup Batch WITH targets_segmentation
283287
mock_targets_segmentation = jnp.array([[1, 1, 0]])
@@ -579,6 +583,7 @@ def test_eval_step_calls_student_forward(self):
579583
# pylint: disable=no-value-for-parameter
580584
trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
581585
trainer.strategy = mock.Mock()
586+
trainer.wrt_filter = lambda path, x: True # type: ignore
582587

583588
# 2. Setup Input Mocks
584589
raw_inputs = mock.Mock()
@@ -675,6 +680,7 @@ def test_post_process_train_step(self):
675680
"""Verifies metrics are moved from aux dict to the trainer buffer."""
676681
# pylint: disable=no-value-for-parameter
677682
trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
683+
trainer.wrt_filter = lambda path, x: True # type: ignore
678684

679685
# Setup MetricsBuffer mock
680686
mock_buffer = mock.Mock()
@@ -723,6 +729,7 @@ def __call__(self, x):
723729
# pylint: disable=no-value-for-parameter
724730
trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
725731
trainer.strategy = mock.Mock()
732+
trainer.wrt_filter = lambda path, x: True # type: ignore
726733

727734
dummy_batch = {
728735
"input_tokens": jnp.ones((1, 2)),
@@ -1121,6 +1128,92 @@ def test_main_online_mode_loads_teacher(
11211128
self.assertIs(model_bundle.student_model, mock_student_model)
11221129
self.assertIs(model_bundle.teacher_model, mock_teacher_model)
11231130

1131+
def test_student_freeze_param_filter(self):
  """Verifies that ``student_freeze_param_filter`` freezes the selected params.

  Builds a two-layer dummy student, freezes ``layer1`` via the filter, runs a
  single train step, and asserts that layer1's kernel is byte-identical to its
  initial snapshot while layer2's kernel has been updated by the optimizer.
  """

  # 1. Dummy model with two distinct layers so we can freeze one and train
  # the other independently.
  class DummyModel(nnx.Module):

    def __init__(self):
      self.layer1 = nnx.Linear(in_features=2, out_features=2, rngs=nnx.Rngs(0))
      self.layer2 = nnx.Linear(in_features=2, out_features=2, rngs=nnx.Rngs(1))

    def __call__(self, input_tokens, **kwargs):
      # Sequentially apply both layers.
      return self.layer2(self.layer1(input_tokens))

  student = DummyModel()
  teacher = DummyModel()
  model_bundle = train_distill.ModelBundle(teacher_model=teacher, student_model=student)

  # Snapshot initial weights so we can detect which layers were updated.
  initial_layer1_weights = student.layer1.kernel.get_value().copy()
  initial_layer2_weights = student.layer2.kernel.get_value().copy()

  # 2. Freeze filter: freeze every param whose path mentions "layer1";
  # layer2 remains trainable.
  def freeze_filter(path):
    path_str = "/".join(str(p) for p in path)
    return "layer1" in path_str

  # 3. Strategy mock: loss is the sum of student logits (a simple scalar with
  # nonzero gradients w.r.t. both layers), forward fns just call the model.
  strategy = mock.Mock()
  strategy.compute_loss.side_effect = lambda s_out, t_out, labels: (jnp.sum(s_out.logits), {"aux": 1.0})
  strategy.create_labels.return_value = None
  strategy.student_forward_fn = lambda model, **kw: distillation_utils.DistillationForwardOutput(
      logits=model(kw["input_tokens"])
  )
  strategy.teacher_forward_fn = lambda model, **kw: distillation_utils.DistillationForwardOutput(
      logits=model(kw["input_tokens"])
  )

  # pylint: disable=import-outside-toplevel
  from tunix.sft import peft_trainer

  train_config = peft_trainer.TrainingConfig(
      max_steps=1,
      eval_every_n_steps=0,
      gradient_accumulation_steps=1,
  )

  # 4. Initialize the trainer under test with the freeze filter.
  trainer = train_distill.MaxTextDistillationTrainer(
      model=model_bundle,
      strategy=strategy,
      optimizer=optax.sgd(0.1),
      training_config=train_config,
      student_freeze_param_filter=freeze_filter,
  )
  trainer._lora_enabled = False
  trainer.is_managed_externally = True

  # Provide a pre-baked teacher_output so the teacher forward pass is skipped.
  trainer = trainer.with_gen_model_input_fn(
      lambda batch: {
          "input_tokens": batch["input_tokens"],
          "positions": None,
          "attention_mask": None,
          "decoder_segment_ids": None,
          "targets": None,
          "teacher_output": distillation_utils.DistillationForwardOutput(logits=jnp.ones((1, 2))),
      }
  )

  dummy_batch = {"input_tokens": jnp.ones((1, 2))}

  # 5. Execute one optimizer step.
  trainer._train_step(model_bundle, trainer.optimizer, dummy_batch)

  # 6. layer1 must be unchanged (frozen by the filter).
  np.testing.assert_allclose(
      student.layer1.kernel.get_value(),
      initial_layer1_weights,
      err_msg="layer1 weights should be frozen and remain unchanged.",
  )

  # layer2 must have moved (it received a real gradient update).
  is_layer2_unchanged = np.allclose(student.layer2.kernel.get_value(), initial_layer2_weights)
  self.assertFalse(is_layer2_unchanged, msg="layer2 weights should have updated.")
11241217

11251218
# Entry point: delegate to the absl test runner when invoked as a script.
if __name__ == "__main__":
  absltest.main()

0 commit comments

Comments
 (0)