
Commit f216197

gagika authored and Google-ML-Automation committed
Add support for L2 loss in feature distillation.
PiperOrigin-RevId: 898201648
1 parent: 3ac667a

5 files changed: 29 additions & 7 deletions


src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
@@ -1194,6 +1194,8 @@ distill_temperature: 1.0
 # distill_beta is used for cosine similarity loss between intermediate activations of out_proj in teacher/student models.
 # 0.0 value disables this feature.
 distill_beta: 0.0
+# distill_feature_loss_type is the type of loss to use for feature distillation ("cosine" or "l2").
+distill_feature_loss_type: "cosine"
 distill_layer_indices: None
 
 ##### Elastic training parameters
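
For context, enabling the new loss from a config only takes the two distillation knobs above. A minimal illustrative override in the config file's own YAML format (the 1.0 weight is an arbitrary example value, not a recommendation):

# Feature distillation with L2 loss; distill_beta must be non-zero, or the feature loss is disabled.
distill_beta: 1.0
distill_feature_loss_type: "l2"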

src/maxtext/configs/types.py

Lines changed: 3 additions & 0 deletions
@@ -1155,6 +1155,9 @@ class Distillation(BaseModel):
   distill_alpha: float = Field(0.5, description="Weight for the distillation loss component.")
   distill_temperature: float = Field(1.0, description="Temperature for distillation softening.")
   distill_beta: float = Field(0.0, description="Weight for the feature loss component. Use 0.0 to disable")
+  distill_feature_loss_type: Literal["cosine", "l2"] = Field(
+      "cosine", description="The type of loss to use for feature distillation ('cosine' or 'l2')."
+  )
   distill_layer_indices: None | list = Field(None, description="Feature indices for feature loss.")
 
   # --- Distillation freezing filter --
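
Because the field is typed as Literal["cosine", "l2"], pydantic rejects any other value at config-validation time rather than at training time. A minimal sketch with a toy stand-in model (not the actual MaxText Distillation class, which has many more fields):

from typing import Literal
from pydantic import BaseModel, Field, ValidationError

class ToyDistillation(BaseModel):
  # Mirrors the field added in the diff above.
  distill_feature_loss_type: Literal["cosine", "l2"] = Field(
      "cosine", description="The type of loss to use for feature distillation ('cosine' or 'l2')."
  )

print(ToyDistillation().distill_feature_loss_type)  # "cosine" (default)
print(ToyDistillation(distill_feature_loss_type="l2").distill_feature_loss_type)  # "l2"
try:
  ToyDistillation(distill_feature_loss_type="huber")
except ValidationError:
  print("'huber' is rejected at validation time")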

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 14 additions & 4 deletions
@@ -23,7 +23,7 @@
 from array_record.python import array_record_module
 
 import abc
-from typing import Any, Iterator, Optional, List, Callable
+from typing import Any, Iterator, Optional, List, Callable, Literal
 
 import flax
 from flax import nnx
@@ -262,6 +262,7 @@ def __init__(
       beta_feature: float = 0.0,
       layer_indices: Optional[List[int]] = None,
       feature_loss_fn: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
+      feature_loss_type: Literal["cosine", "l2"] = "cosine",
       cosine_distance_axis: int | tuple[int, ...] = -1,
       vocab_size: int = 0,
   ):
@@ -275,6 +276,8 @@ def __init__(
       alpha: Weight to balance distillation loss and task loss (0.0 to 1.0).
       beta_feature: Weight to balance feature loss (0.0 to 1.0). 0.0 disables feature loss.
       layer_indices: Layer indices to apply feature loss.
+      feature_loss_type: The type of feature loss to use if `feature_loss_fn` is None.
+        Can be "cosine" (default) or "l2".
       feature_loss_fn: A function that takes two jax.Arrays (student_map,
         teacher_map) and returns a scalar loss. Defaults to Cosine Distance.
       cosine_distance_axis: The axis to use for cosine distance computation if
@@ -295,9 +298,16 @@ def __init__(
 
     self.feature_loss_fn = feature_loss_fn
     if feature_loss_fn is None:
-      self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean(
-          optax.cosine_distance(student_features, teacher_features, axis=cosine_distance_axis)
-      )
+      if feature_loss_type == "cosine":
+        self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean(
+            optax.cosine_distance(student_features, teacher_features, axis=cosine_distance_axis)
+        )
+      elif feature_loss_type == "l2":
+        self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean(
+            optax.l2_loss(student_features, teacher_features)
+        )
+      else:
+        raise ValueError(f"Unsupported feature_loss_type: {feature_loss_type!r}")
 
   def compute_loss(
       self,
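
The practical difference between the two default loss functions: optax.cosine_distance compares only the direction of the feature vectors along the chosen axis (it is invariant to positive rescaling), while optax.l2_loss penalizes element-wise magnitude differences as 0.5 * (student - teacher)**2. A standalone sketch on toy activations, mirroring the optax calls in the hunk above (the shapes and names are illustrative, not MaxText's):

import jax
import jax.numpy as jnp
import optax

key_s, key_t = jax.random.split(jax.random.PRNGKey(0))
# Stand-ins for intermediate out_proj feature maps: [batch, seq, features].
student = jax.random.normal(key_s, (2, 4, 8))
teacher = jax.random.normal(key_t, (2, 4, 8))

# feature_loss_type="cosine": mean cosine distance over the feature axis.
cosine = jnp.mean(optax.cosine_distance(student, teacher, axis=-1))
# feature_loss_type="l2": mean of 0.5 * (student - teacher)**2 over all elements.
l2 = jnp.mean(optax.l2_loss(student, teacher))

# Rescaling the student features leaves the cosine loss unchanged but not the L2 loss.
print(jnp.allclose(jnp.mean(optax.cosine_distance(2.0 * student, teacher, axis=-1)), cosine))  # True
print(jnp.allclose(jnp.mean(optax.l2_loss(2.0 * student, teacher)), l2))  # False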

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 1 addition & 0 deletions
@@ -507,6 +507,7 @@ def build_training_components(
       alpha=student_config.distill_alpha,
       beta_feature=student_config.distill_beta,
       layer_indices=student_config.distill_layer_indices,
+      feature_loss_type=student_config.distill_feature_loss_type,
       vocab_size=student_config.vocab_size,
   )

tests/post_training/unit/train_distill_test.py

Lines changed: 9 additions & 3 deletions
@@ -24,6 +24,7 @@
 import shutil
 import tempfile
 import unittest
+from typing import Literal
 from unittest import mock
 import jax
 import jax.numpy as jnp
@@ -372,12 +373,14 @@ def test_optimizer_factory(self):
     train_distill.get_distillation_optimizer(config, max_train_steps=100)
 
   def test_monitored_strategy(self):
-    self._test_monitored_strategy(False)
+    self._test_monitored_strategy(sft_mode=False, feature_loss_type="cosine")
+    self._test_monitored_strategy(sft_mode=False, feature_loss_type="l2")
 
   def test_monitored_strategy_sft(self):
-    self._test_monitored_strategy(True)
+    self._test_monitored_strategy(sft_mode=True, feature_loss_type="cosine")
+    self._test_monitored_strategy(sft_mode=True, feature_loss_type="l2")
 
-  def _test_monitored_strategy(self, sft_mode: bool):
+  def _test_monitored_strategy(self, *, sft_mode: bool, feature_loss_type: Literal["cosine", "l2"] = "cosine"):
     """Verifies the strategy calculates metrics and returns the correct tuple."""
     strategy = distillation_utils.CombinedDistillationStrategy(
         student_forward_fn=lambda m, **k: None,
@@ -386,6 +389,7 @@ def _test_monitored_strategy(self, sft_mode: bool):
         temperature=1.0,
         alpha=0.5,
         beta_feature=1.0,
+        feature_loss_type=feature_loss_type,
         layer_indices=None,
     )
 
@@ -1012,6 +1016,7 @@ def test_main_offline_mode_skips_teacher_loading(
     mock_student_cfg.distill_alpha = 0.5
     mock_student_cfg.distill_beta = 0.0
     mock_student_cfg.distill_layer_indices = None
+    mock_student_cfg.distill_feature_loss_type = "cosine"
     mock_student_cfg.use_sft = False
     mock_student_cfg.enable_dropout = False
 
@@ -1091,6 +1096,7 @@ def test_main_online_mode_loads_teacher(
     mock_student_cfg.distill_alpha = 0.5
     mock_student_cfg.distill_beta = 0.0
     mock_student_cfg.distill_layer_indices = None
+    mock_student_cfg.distill_feature_loss_type = "cosine"
     mock_student_cfg.use_sft = False
     mock_student_cfg.enable_dropout = False
