
Commit 843dd45

feat(autojac): Make jac_to_grad return optional weights (#586)
Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
1 parent f18ab5a commit 843dd45

3 files changed

Lines changed: 135 additions & 15 deletions
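In short, `jac_to_grad` now returns the weights computed by the aggregator's `Weighting` when a `WeightedAggregator` is passed, and `None` otherwise. A minimal sketch of the new call pattern, mirroring the updated docstring example; the initial value of `param` (not shown in the hunks) is assumed to be `[1., 2.]` because it reproduces the expected gradient, and the `backward` import path is likewise an assumption:

import torch

from torchjd import backward  # assumed import path; the hunks below do not show it
from torchjd.aggregation import UPGrad
from torchjd.autojac._jac_to_grad import jac_to_grad  # path used by the unit tests below

# Assumed initial value: consistent with the expected .grad of [0.5, 2.5] in the diff.
param = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = torch.tensor([-1.0, 1.0]) @ param
y2 = (param**2).sum()

backward([y1, y2])  # fills param.jac with the 2-row Jacobian of [y1, y2]
weights = jac_to_grad([param], UPGrad())  # fills param.grad and returns the weights

print(param.grad)  # tensor([0.5000, 2.5000])
print(weights)     # tensor([0.5000, 0.5000]): no conflict, so UPGrad keeps equal weights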


src/torchjd/autojac/_jac_to_grad.py

Lines changed: 56 additions & 8 deletions
@@ -1,28 +1,53 @@
 from collections.abc import Iterable
+from typing import overload

 import torch
 from torch import Tensor

-from torchjd.aggregation import Aggregator
+from torchjd._linalg import Matrix
+from torchjd.aggregation import Aggregator, Weighting
+from torchjd.aggregation._aggregator_bases import WeightedAggregator

 from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
 from ._utils import check_consistent_first_dimension


+@overload
+def jac_to_grad(
+    tensors: Iterable[Tensor],
+    /,
+    aggregator: WeightedAggregator,
+    *,
+    retain_jac: bool = False,
+) -> Tensor: ...
+
+
+@overload
+def jac_to_grad(
+    tensors: Iterable[Tensor],
+    /,
+    aggregator: Aggregator,  # Not a WeightedAggregator, because overloads are checked in order
+    *,
+    retain_jac: bool = False,
+) -> None: ...
+
+
 def jac_to_grad(
     tensors: Iterable[Tensor],
     /,
     aggregator: Aggregator,
     *,
     retain_jac: bool = False,
-) -> None:
+) -> Tensor | None:
     r"""
     Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result
     into their ``.grad`` fields.

     :param tensors: The tensors whose ``.jac`` fields should be aggregated. All Jacobians must
         have the same first dimension (e.g. number of losses).
-    :param aggregator: The aggregator used to reduce the Jacobians into gradients.
+    :param aggregator: The aggregator used to reduce the Jacobians into gradients. If it uses a
+        :class:`Weighting <torchjd.aggregation._weighting_bases.Weighting>` to combine the rows of
+        the Jacobians, ``jac_to_grad`` will also return the computed weights.
     :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been
         used. Defaults to ``False``.

@@ -48,12 +73,15 @@ def jac_to_grad(
     >>> y2 = (param ** 2).sum()
     >>>
     >>> backward([y1, y2])  # param now has a .jac field
-    >>> jac_to_grad([param], aggregator=UPGrad())  # param now has a .grad field
+    >>> weights = jac_to_grad([param], UPGrad())  # param now has a .grad field
     >>> param.grad
-    tensor([-1., 1.])
+    tensor([0.5000, 2.5000])
+    >>> weights
+    tensor([0.5, 0.5])

     The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of
-    :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
+    :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. In this case, the
+    weights used to combine the Jacobian are equal because there was no conflict.
     """

     tensors_ = list[TensorWithJac]()
@@ -66,7 +94,7 @@ def jac_to_grad(
         tensors_.append(t)

     if len(tensors_) == 0:
-        return
+        raise ValueError("The `tensors` parameter cannot be empty.")

     jacobians = [t.jac for t in tensors_]

@@ -76,9 +104,29 @@ def jac_to_grad(
         _free_jacs(tensors_)

     jacobian_matrix = _unite_jacobians(jacobians)
-    gradient_vector = aggregator(jacobian_matrix)
+    weights: Tensor | None = None
+
+    if isinstance(aggregator, WeightedAggregator):
+
+        def capture_hook(_m: Weighting[Matrix], _i: tuple[Tensor], output: Tensor) -> None:
+            nonlocal weights
+            weights = output
+
+        # Append the weight-capturing post-hook to the outer weighting to ensure that all other
+        # post-hooks of the outer and inner weighting are run (potentially with effect on the
+        # weights) prior to capturing the weights.
+        handle = aggregator.weighting.register_forward_hook(capture_hook)
+
+        # Using a try-finally here in case an exception is raised by the aggregator.
+        try:
+            gradient_vector = aggregator(jacobian_matrix)
+        finally:
+            handle.remove()
+    else:
+        gradient_vector = aggregator(jacobian_matrix)
     gradients = _disunite_gradient(gradient_vector, tensors_)
     accumulate_grads(tensors_, gradients)
+    return weights


 def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
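The weight capture above is a forward post-hook plus a `nonlocal` closure, registered only for the duration of the aggregation. The same pattern can be sketched with a plain `torch.nn.Module` standing in for torchjd's `Weighting`, whose `register_forward_hook` is assumed to follow PyTorch's hook semantics:

import torch
from torch import Tensor, nn


def call_and_capture(module: nn.Module, x: Tensor) -> tuple[Tensor, Tensor | None]:
    """Run module(x) and also return the output seen by a temporarily attached post-hook."""
    captured: Tensor | None = None

    def capture_hook(_m: nn.Module, _inputs: tuple[Tensor, ...], output: Tensor) -> None:
        nonlocal captured
        captured = output  # record the (possibly hook-modified) output

    # Registered last, so any post-hooks registered earlier run first and may already
    # have replaced the output before it is captured.
    handle = module.register_forward_hook(capture_hook)
    try:
        result = module(x)
    finally:
        handle.remove()  # never leave the temporary hook attached, even on error
    return result, captured


linear = nn.Linear(3, 2)
out, seen = call_and_capture(linear, torch.randn(3))
assert seen is not None and torch.equal(out, seen)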

tests/doc/test_jac_to_grad.py

Lines changed: 3 additions & 1 deletion
@@ -3,6 +3,7 @@
 the obtained `.grad` field.
 """

+from torch.testing import assert_close
 from utils.asserts import assert_grad_close


@@ -17,6 +18,7 @@ def test_jac_to_grad() -> None:
     y1 = torch.tensor([-1.0, 1.0]) @ param
     y2 = (param**2).sum()
     backward([y1, y2])  # param now has a .jac field
-    jac_to_grad([param], aggregator=UPGrad())  # param now has a .grad field
+    weights = jac_to_grad([param], UPGrad())  # param now has a .grad field

     assert_grad_close(param, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04)
+    assert_close(weights, torch.tensor([0.5, 0.5]), rtol=0.0, atol=0.0)

tests/unit/autojac/test_jac_to_grad.py

Lines changed: 76 additions & 6 deletions
@@ -1,14 +1,28 @@
+from typing import Any
+
 from pytest import mark, raises
+from torch import Tensor
+from torch.testing import assert_close
 from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac
 from utils.tensors import tensor_

-from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad
+from torchjd.aggregation import (
+    Aggregator,
+    ConFIG,
+    Mean,
+    PCGrad,
+    UPGrad,
+)
+from torchjd.aggregation._aggregator_bases import WeightedAggregator
 from torchjd.autojac._jac_to_grad import jac_to_grad


-@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()])
+@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad(), ConFIG()])
 def test_various_aggregators(aggregator: Aggregator) -> None:
-    """Tests that jac_to_grad works for various aggregators."""
+    """
+    Tests that jac_to_grad works for various aggregators. For those that are weighted, the weights
+    should also be returned. For the others, None should be returned.
+    """

     t1 = tensor_(1.0, requires_grad=True)
     t2 = tensor_([2.0, 3.0], requires_grad=True)
@@ -19,11 +33,18 @@ def test_various_aggregators(aggregator: Aggregator) -> None:
     g1 = expected_grad[0]
     g2 = expected_grad[1:]

-    jac_to_grad([t1, t2], aggregator)
+    optional_weights = jac_to_grad([t1, t2], aggregator)

     assert_grad_close(t1, g1)
     assert_grad_close(t2, g2)

+    if isinstance(aggregator, WeightedAggregator):
+        assert optional_weights is not None
+        expected_weights = aggregator.weighting(jac)
+        assert_close(optional_weights, expected_weights)
+    else:
+        assert optional_weights is None
+

 def test_single_tensor() -> None:
     """Tests that jac_to_grad works when a single tensor is provided."""
@@ -80,9 +101,10 @@ def test_row_mismatch() -> None:


 def test_no_tensors() -> None:
-    """Tests that jac_to_grad correctly does nothing when an empty list of tensors is provided."""
+    """Tests that jac_to_grad correctly raises when an empty list of tensors is provided."""

-    jac_to_grad([], aggregator=UPGrad())
+    with raises(ValueError):
+        jac_to_grad([], UPGrad())


 @mark.parametrize("retain_jac", [True, False])
@@ -115,3 +137,51 @@ def test_noncontiguous_jac() -> None:

     jac_to_grad([t], aggregator)
     assert_grad_close(t, g)
+
+
+@mark.parametrize("aggregator", [UPGrad(), ConFIG()])
+def test_aggregator_hook_is_run(aggregator: Aggregator) -> None:
+    """
+    Tests that jac_to_grad runs forward hooks registered on the aggregator, for both
+    WeightedAggregator (UPGrad) and plain Aggregator (ConFIG) paths.
+    """
+
+    call_count = [0]  # Pointer to int
+
+    def hook(_module: Any, _input: Any, _output: Any) -> None:
+        call_count[0] += 1
+
+    aggregator.register_forward_hook(hook)
+
+    t = tensor_([2.0, 3.0], requires_grad=True)
+    jac = tensor_([[-4.0, 1.0], [6.0, 1.0]])
+    t.__setattr__("jac", jac)
+
+    jac_to_grad([t], aggregator)
+
+    assert call_count[0] == 1
+
+
+def test_with_hooks() -> None:
+    """Tests that jac_to_grad correctly returns the weights modified by all applicable hooks."""
+
+    def hook_aggregator(_module: Any, _input: Any, aggregation: Tensor) -> Tensor:
+        return aggregation * 2  # should not affect the weights
+
+    def hook_outer(_module: Any, _input: Any, weights: Tensor) -> Tensor:
+        return weights * 3  # should affect the weights returned by jac_to_grad
+
+    def hook_inner(_module: Any, _input: Any, weights: Tensor) -> Tensor:
+        return weights * 5  # should affect the weights returned by jac_to_grad
+
+    aggregator = UPGrad()
+    aggregator.register_forward_hook(hook_aggregator)
+    aggregator.weighting.register_forward_hook(hook_outer)
+    aggregator.gramian_weighting.register_forward_hook(hook_inner)
+
+    t = tensor_([2.0, 3.0], requires_grad=True)
+    jac = tensor_([[-4.0, 1.0], [6.0, 1.0]])
+    t.__setattr__("jac", jac)
+
+    weights = jac_to_grad([t], aggregator)
+    assert_close(weights, aggregator.weighting(jac))
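test_with_hooks relies on two properties that PyTorch forward post-hooks have and that torchjd's Weighting hooks are assumed to share: hooks run in registration order, and a hook that returns a value replaces the output passed to the hooks that follow it. That is why the capture hook that jac_to_grad appends last sees the weights already rescaled by hook_outer, while hook_aggregator, which only touches the final aggregation, leaves them untouched. A small sketch of those two properties with torch.nn:

import torch
from torch import Tensor, nn
from typing import Any

module = nn.Identity()
captured: Tensor | None = None


def double(_m: Any, _i: Any, output: Tensor) -> Tensor:
    return output * 2  # returning a value replaces the output for later hooks


def capture(_m: Any, _i: Any, output: Tensor) -> None:
    global captured
    captured = output  # registered last, so it sees the already-doubled output


module.register_forward_hook(double)
module.register_forward_hook(capture)

x = torch.tensor([1.0, 3.0])
y = module(x)
assert torch.equal(y, x * 2)
assert captured is not None and torch.equal(captured, x * 2)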
