
Commit 3852fc2

Fix docstring of mtl_backward (#210)

* Clarify note in `mtl_backward`
* Fix formatting

1 parent: 085c990

1 file changed: src/torchjd/autojac/mtl_backward.py (5 additions, 3 deletions)
```diff
@@ -49,11 +49,11 @@ def mtl_backward(
     :param tasks_params: The parameters of each task-specific head. Their ``requires_grad`` flags
         must be set to ``True``. If not provided, the parameters considered for each task will
         default to the leaf tensors that are in the computation graph of its loss, but that were not
-        used to compute the `features`.
+        used to compute the ``features``.
     :param shared_params: The parameters of the shared feature extractor. The Jacobian matrix will
         have one column for each value in these tensors. Their ``requires_grad`` flags must be set
         to ``True``. If not provided, defaults to the leaf tensors that are in the computation graph
-        of the `features`.
+        of the ``features``.
     :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to
         ``False``.
     :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the
@@ -70,7 +70,9 @@ def mtl_backward(
     :doc:`Multi-Task Learning (MTL) <../../examples/mtl>`.

     .. note::
-        `shared_params` and `tasks_params` must be disjoint.
+        ``shared_params`` should contain no parameter in common with ``tasks_params``. The different
+        tasks may have some parameters in common. In this case, the sum of the gradients with
+        respect to those parameters will be accumulated into their ``.grad`` fields.

     .. warning::
         ``mtl_backward`` relies on a usage of ``torch.vmap`` that is not compatible with compiled
```
