+from typing import Any
+
 from pytest import mark, raises
+from torch import Tensor
+from torch.testing import assert_close
 from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac
 from utils.tensors import tensor_

-from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad
+from torchjd.aggregation import (
+    Aggregator,
+    ConFIG,
+    Mean,
+    PCGrad,
+    UPGrad,
+)
+from torchjd.aggregation._aggregator_bases import WeightedAggregator
 from torchjd.autojac._jac_to_grad import jac_to_grad


-@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()])
+@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad(), ConFIG()])
 def test_various_aggregators(aggregator: Aggregator) -> None:
-    """Tests that jac_to_grad works for various aggregators."""
+    """
+    Tests that jac_to_grad works for various aggregators. For those that are weighted, the weights
+    should also be returned. For the others, None should be returned.
+    """

     t1 = tensor_(1.0, requires_grad=True)
     t2 = tensor_([2.0, 3.0], requires_grad=True)
@@ -19,11 +33,18 @@ def test_various_aggregators(aggregator: Aggregator) -> None:
     g1 = expected_grad[0]
     g2 = expected_grad[1:]

-    jac_to_grad([t1, t2], aggregator)
+    optional_weights = jac_to_grad([t1, t2], aggregator)

     assert_grad_close(t1, g1)
     assert_grad_close(t2, g2)

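+    # The returned weights should match what the aggregator's weighting computes on the Jacobian.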
+    if isinstance(aggregator, WeightedAggregator):
+        assert optional_weights is not None
+        expected_weights = aggregator.weighting(jac)
+        assert_close(optional_weights, expected_weights)
+    else:
+        assert optional_weights is None
+

 def test_single_tensor() -> None:
     """Tests that jac_to_grad works when a single tensor is provided."""
@@ -80,9 +101,10 @@ def test_row_mismatch() -> None:


 def test_no_tensors() -> None:
-    """Tests that jac_to_grad correctly does nothing when an empty list of tensors is provided."""
+    """Tests that jac_to_grad correctly raises when an empty list of tensors is provided."""

-    jac_to_grad([], aggregator=UPGrad())
+    with raises(ValueError):
+        jac_to_grad([], UPGrad())


 @mark.parametrize("retain_jac", [True, False])
@@ -115,3 +137,51 @@ def test_noncontiguous_jac() -> None:

     jac_to_grad([t], aggregator)
     assert_grad_close(t, g)
+
+
+@mark.parametrize("aggregator", [UPGrad(), ConFIG()])
+def test_aggregator_hook_is_run(aggregator: Aggregator) -> None:
+    """
+    Tests that jac_to_grad runs forward hooks registered on the aggregator, for both
+    WeightedAggregator (UPGrad) and plain Aggregator (ConFIG) paths.
+    """
+
+    call_count = [0]  # Mutable container so the hook closure can increment the count
+
+    def hook(_module: Any, _input: Any, _output: Any) -> None:
+        call_count[0] += 1
+
+    aggregator.register_forward_hook(hook)
+
+    t = tensor_([2.0, 3.0], requires_grad=True)
+    jac = tensor_([[-4.0, 1.0], [6.0, 1.0]])
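+    # jac_to_grad reads the Jacobian to aggregate from the tensor's .jac attribute.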
+    t.__setattr__("jac", jac)
+
+    jac_to_grad([t], aggregator)
+
+    assert call_count[0] == 1
+
+
+def test_with_hooks() -> None:
+    """Tests that jac_to_grad correctly returns the weights modified by all applicable hooks."""
+
+    def hook_aggregator(_module: Any, _input: Any, aggregation: Tensor) -> Tensor:
+        return aggregation * 2  # should not affect the weights
+
+    def hook_outer(_module: Any, _input: Any, weights: Tensor) -> Tensor:
+        return weights * 3  # should affect the weights returned by jac_to_grad
+
+    def hook_inner(_module: Any, _input: Any, weights: Tensor) -> Tensor:
+        return weights * 5  # should affect the weights returned by jac_to_grad
+
+    aggregator = UPGrad()
+    aggregator.register_forward_hook(hook_aggregator)
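+    # Hooks on the weighting and on the underlying gramian_weighting should both affect the result.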
+    aggregator.weighting.register_forward_hook(hook_outer)
+    aggregator.gramian_weighting.register_forward_hook(hook_inner)
+
+    t = tensor_([2.0, 3.0], requires_grad=True)
+    jac = tensor_([[-4.0, 1.0], [6.0, 1.0]])
+    t.__setattr__("jac", jac)
+
+    weights = jac_to_grad([t], aggregator)
+    assert_close(weights, aggregator.weighting(jac))