@@ -49,11 +49,11 @@ def mtl_backward(
     :param tasks_params: The parameters of each task-specific head. Their ``requires_grad`` flags
         must be set to ``True``. If not provided, the parameters considered for each task will
         default to the leaf tensors that are in the computation graph of its loss, but that were not
-        used to compute the `features`.
+        used to compute the ``features``.
     :param shared_params: The parameters of the shared feature extractor. The Jacobian matrix will
         have one column for each value in these tensors. Their ``requires_grad`` flags must be set
         to ``True``. If not provided, defaults to the leaf tensors that are in the computation graph
-        of the `features`.
+        of the ``features``.
     :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to
         ``False``.
     :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the
@@ -70,7 +70,9 @@ def mtl_backward(
     :doc:`Multi-Task Learning (MTL) <../../examples/mtl>`.

     .. note::
-        `shared_params` and `tasks_params` must be disjoint.
+        ``shared_params`` should contain no parameter in common with ``tasks_params``. The different
+        tasks may have some parameters in common. In this case, the sum of the gradients with
+        respect to those parameters will be accumulated into their ``.grad`` fields.

     .. warning::
         ``mtl_backward`` relies on a usage of ``torch.vmap`` that is not compatible with compiled