|
2 | 2 | from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac |
3 | 3 | from utils.tensors import tensor_ |
4 | 4 |
|
5 | | -from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad |
6 | | -from torchjd.autojac._jac_to_grad import _has_forward_hook, jac_to_grad |
| 5 | +from torchjd.aggregation import ( |
| 6 | + IMTLG, |
| 7 | + MGDA, |
| 8 | + Aggregator, |
| 9 | + AlignedMTL, |
| 10 | + ConFIG, |
| 11 | + Constant, |
| 12 | + DualProj, |
| 13 | + GradDrop, |
| 14 | + Krum, |
| 15 | + Mean, |
| 16 | + PCGrad, |
| 17 | + Random, |
| 18 | + Sum, |
| 19 | + TrimmedMean, |
| 20 | + UPGrad, |
| 21 | +) |
| 22 | +from torchjd.autojac._jac_to_grad import ( |
| 23 | + _can_skip_jacobian_combination, |
| 24 | + _has_forward_hook, |
| 25 | + jac_to_grad, |
| 26 | +) |
7 | 27 |
|
8 | 28 |
|
9 | 29 | @mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()]) |
@@ -146,3 +166,51 @@ def dummy_backward_pre_hook(_module, _grad_output): |
146 | 166 | assert _has_forward_hook(module) |
147 | 167 | handle4.remove() |
148 | 168 | assert not _has_forward_hook(module) |
| 169 | + |
| 170 | + |
# Pairs of (aggregator instance, expected result of _can_skip_jacobian_combination).
# The first group is expected to allow skipping the jacobian combination, the
# second group is not. NOTE(review): the exact criterion separating the two
# groups (presumably a property of the aggregator's weighting) is defined by
# _can_skip_jacobian_combination itself — confirm against its implementation.
_PARAMETRIZATIONS = [
    (AlignedMTL(), True),
    (DualProj(), True),
    (IMTLG(), True),
    (Krum(n_byzantine=1), True),
    (MGDA(), True),
    (PCGrad(), True),
    (UPGrad(), True),
    (ConFIG(), False),
    (Constant(tensor_([0.5, 0.5])), False),
    (GradDrop(), False),
    (Mean(), False),
    (Random(), False),
    (Sum(), False),
    (TrimmedMean(trim_number=1), False),
]

# CAGrad may not be importable in every environment (presumably it relies on an
# optional dependency — TODO confirm); include it in the parametrization only
# when the import succeeds.
try:
    from torchjd.aggregation import CAGrad

    _PARAMETRIZATIONS.append((CAGrad(c=0.5), True))
except ImportError:
    pass

# Same treatment for NashMTL: appended only when importable.
try:
    from torchjd.aggregation import NashMTL

    _PARAMETRIZATIONS.append((NashMTL(n_tasks=2), False))
except ImportError:
    pass
| 201 | + |
| 202 | + |
@mark.parametrize("aggregator, expected", _PARAMETRIZATIONS)
def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool):
    """
    Verifies that _can_skip_jacobian_combination returns the expected value for each
    aggregator, and that registering any forward (pre-)hook disables the optimization
    until the hook is removed.
    """

    # With no hooks registered, the result should match the expectation.
    assert _can_skip_jacobian_combination(aggregator) == expected

    # A forward hook must prevent skipping the jacobian combination.
    hook_handle = aggregator.register_forward_hook(lambda _module, _args, output: output)
    assert not _can_skip_jacobian_combination(aggregator)
    hook_handle.remove()

    # The same holds for a forward pre-hook.
    pre_hook_handle = aggregator.register_forward_pre_hook(lambda _module, args: args)
    assert not _can_skip_jacobian_combination(aggregator)
    pre_hook_handle.remove()

    # After all hooks are removed, the original result is restored.
    assert _can_skip_jacobian_combination(aggregator) == expected
0 commit comments