Skip to content

Commit 453971a

Browse files
committed
Add test_can_skip_jacobian_combination
1 parent add549c commit 453971a

1 file changed

Lines changed: 70 additions & 2 deletions

File tree

tests/unit/autojac/test_jac_to_grad.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,28 @@
22
from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac
33
from utils.tensors import tensor_
44

5-
from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad
6-
from torchjd.autojac._jac_to_grad import _has_forward_hook, jac_to_grad
5+
from torchjd.aggregation import (
6+
IMTLG,
7+
MGDA,
8+
Aggregator,
9+
AlignedMTL,
10+
ConFIG,
11+
Constant,
12+
DualProj,
13+
GradDrop,
14+
Krum,
15+
Mean,
16+
PCGrad,
17+
Random,
18+
Sum,
19+
TrimmedMean,
20+
UPGrad,
21+
)
22+
from torchjd.autojac._jac_to_grad import (
23+
_can_skip_jacobian_combination,
24+
_has_forward_hook,
25+
jac_to_grad,
26+
)
727

828

929
@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()])
@@ -146,3 +166,51 @@ def dummy_backward_pre_hook(_module, _grad_output):
146166
assert _has_forward_hook(module)
147167
handle4.remove()
148168
assert not _has_forward_hook(module)
169+
170+
171+
_PARAMETRIZATIONS = [
172+
(AlignedMTL(), True),
173+
(DualProj(), True),
174+
(IMTLG(), True),
175+
(Krum(n_byzantine=1), True),
176+
(MGDA(), True),
177+
(PCGrad(), True),
178+
(UPGrad(), True),
179+
(ConFIG(), False),
180+
(Constant(tensor_([0.5, 0.5])), False),
181+
(GradDrop(), False),
182+
(Mean(), False),
183+
(Random(), False),
184+
(Sum(), False),
185+
(TrimmedMean(trim_number=1), False),
186+
]
187+
188+
try:
189+
from torchjd.aggregation import CAGrad
190+
191+
_PARAMETRIZATIONS.append((CAGrad(c=0.5), True))
192+
except ImportError:
193+
pass
194+
195+
try:
196+
from torchjd.aggregation import NashMTL
197+
198+
_PARAMETRIZATIONS.append((NashMTL(n_tasks=2), False))
199+
except ImportError:
200+
pass
201+
202+
203+
@mark.parametrize("aggregator, expected", _PARAMETRIZATIONS)
204+
def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool):
205+
"""
206+
Tests that _can_skip_jacobian_combination correctly identifies when optimization can be used.
207+
"""
208+
209+
assert _can_skip_jacobian_combination(aggregator) == expected
210+
handle = aggregator.register_forward_hook(lambda module, input, output: output)
211+
assert not _can_skip_jacobian_combination(aggregator)
212+
handle.remove()
213+
handle = aggregator.register_forward_pre_hook(lambda module, input: input)
214+
assert not _can_skip_jacobian_combination(aggregator)
215+
handle.remove()
216+
assert _can_skip_jacobian_combination(aggregator) == expected

0 commit comments

Comments
 (0)