Skip to content

Commit 085c990

Browse files
authored
Add note about params disjointness in mtl_backward (#207)
* Add test_mtl_backward_default_shared_params_overlap_with_default_tasks_params * Add note asking for shared_params and tasks_params to be disjoint
1 parent 59579bc commit 085c990

2 files changed

Lines changed: 26 additions & 0 deletions

File tree

src/torchjd/autojac/mtl_backward.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def mtl_backward(
6969
A usage example of ``mtl_backward`` is provided in
7070
:doc:`Multi-Task Learning (MTL) <../../examples/mtl>`.
7171
72+
.. note::
73+
`shared_params` and `tasks_params` must be disjoint.
74+
7275
.. warning::
7376
``mtl_backward`` relies on a usage of ``torch.vmap`` that is not compatible with compiled
7477
functions. The arguments of ``mtl_backward`` should thus not come from a compiled model.

tests/unit/autojac/test_mtl_backward.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,3 +566,26 @@ def test_mtl_backward_shared_params_overlap_with_tasks_params():
566566
shared_params=[p0],
567567
retain_graph=True,
568568
)
569+
570+
571+
def test_mtl_backward_default_shared_params_overlap_with_default_tasks_params():
572+
"""
573+
Tests that mtl_backward raises an error when the set of shared params obtained by default
574+
overlaps with the set of task-specific params obtained by default.
575+
"""
576+
577+
p0 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
578+
p1 = torch.tensor(2.0, requires_grad=True, device=DEVICE)
579+
p2 = torch.tensor(3.0, requires_grad=True, device=DEVICE)
580+
581+
r = torch.tensor([-1.0, 1.0], device=DEVICE) @ p0
582+
y1 = r * p1
583+
y2 = p0.sum() * r * p2
584+
585+
with raises(ValueError):
586+
mtl_backward(
587+
losses=[y1, y2],
588+
features=[r],
589+
A=UPGrad(),
590+
retain_graph=True,
591+
)

0 commit comments

Comments
 (0)