Skip to content
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
ae6be7d
WIP: add gramian-based jac_to_grad
ValerianRey Jan 21, 2026
8bdf512
Update changelog
ValerianRey Jan 21, 2026
aaf2544
Use deque to free memory asap
ValerianRey Jan 23, 2026
64b06ad
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 23, 2026
745f707
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
5eb77f9
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
8f65caa
Use gramian_weighting in jac_to_grad
ValerianRey Jan 28, 2026
6fe15a4
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
d5cb5c2
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 29, 2026
f986950
Only optimize when no forward hooks
ValerianRey Jan 29, 2026
4cf5cbb
Make _gramian_based take aggregator instead of weighting
ValerianRey Jan 29, 2026
add549c
Add _can_skip_jacobian_combination helper function
ValerianRey Jan 29, 2026
453971a
Add test_can_skip_jacobian_combination
ValerianRey Jan 29, 2026
9d4c41c
Optimize compute_gramian for when contracted_dims=-1
ValerianRey Jan 29, 2026
48cd70b
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 30, 2026
8f2660d
Use TypeGuard in _can_skip_jacobian_combination
ValerianRey Jan 30, 2026
fc9bbcf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2026
3f9a6d1
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 1, 2026
9d9cbf0
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 4, 2026
b5ca226
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
0baa914
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 5, 2026
2ed1d7c
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 13, 2026
86be778
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 13, 2026
4ace19e
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 13, 2026
2a84bef
Add ruff if-else squeezing
ValerianRey Feb 13, 2026
4b6209c
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 23, 2026
1b1c660
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 23, 2026
b714253
Many fixes of problems coming from the merge
ValerianRey Feb 23, 2026
2bb8ab1
Fix _can_skip_jacobian_combination
ValerianRey Feb 23, 2026
63c9dde
Make check_consistent_first_dimension work with Deque
ValerianRey Feb 23, 2026
0f85811
Improve test_can_skip_jacobian_combination
ValerianRey Feb 23, 2026
9d55215
Add optimize_gramian_computation param and add error when not compatible
ValerianRey Feb 23, 2026
456510b
Fix overloads (partly) and add missing code coverage
ValerianRey Feb 23, 2026
55c69d1
Fix overloads
ValerianRey Feb 23, 2026
8a401a3
Fix docstring
ValerianRey Feb 23, 2026
b4bf7c4
fixup what @ValerianRey did wrong
PierreQuinton Feb 23, 2026
24a991a
Improve error message
ValerianRey Feb 23, 2026
2ea44a4
Improve docstring
ValerianRey Feb 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,8 @@ changelog does not include internal changes that do not affect the user.
- `GeneralizedWeighting.__call__`: The `generalized_gramian` parameter is now positional-only.
Suggested change: `generalized_weighting(generalized_gramian=generalized_gramian)` =>
`generalized_weighting(generalized_gramian)`.
- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency
of `autojac`.
- Removed an unnecessary internal cloning of gradient. This should slightly improve the memory
efficiency of `autojac`.
- Removed several unnecessary memory duplications. This should significantly improve the memory
efficiency and speed of `autojac`.
- Increased the lower bounds of the torch (from 2.0.0 to 2.3.0) and numpy (from 1.21.0
to 1.21.2) dependencies to reflect what really works with torchjd. We now also run torchjd's tests
with the dependency lower-bounds specified in `pyproject.toml`, so we should now always accurately
Expand Down
19 changes: 14 additions & 5 deletions src/torchjd/_linalg/_gramian.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,20 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
first dimension).
"""

contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim
indices_source = list(range(t.ndim - contracted_dims))
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
transposed = t.movedim(indices_source, indices_dest)
gramian = torch.tensordot(t, transposed, dims=contracted_dims)
# Optimization: it's faster to do that than moving dims and using tensordot, and this case
# happens very often, sometimes hundreds of times for a single jac_to_grad.
if contracted_dims == -1:
matrix = t.unsqueeze(1) if t.ndim == 1 else t.flatten(start_dim=1)

gramian = matrix @ matrix.T

else:
contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim
indices_source = list(range(t.ndim - contracted_dims))
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
transposed = t.movedim(indices_source, indices_dest)
gramian = torch.tensordot(t, transposed, dims=contracted_dims)

return cast(PSDTensor, gramian)


Expand Down
87 changes: 76 additions & 11 deletions src/torchjd/autojac/_jac_to_grad.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from collections import deque
from collections.abc import Iterable
from typing import overload
from typing import TypeGuard, cast, overload

import torch
from torch import Tensor
from torch import Tensor, nn

from torchjd._linalg import Matrix
from torchjd._linalg import Matrix, PSDMatrix, compute_gramian
from torchjd.aggregation import Aggregator, Weighting
from torchjd.aggregation._aggregator_bases import WeightedAggregator
from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator, WeightedAggregator

from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
from ._utils import check_consistent_first_dimension
Expand Down Expand Up @@ -38,6 +39,7 @@
aggregator: Aggregator,
*,
retain_jac: bool = False,
optimize_gramian_computation: bool = False,
) -> Tensor | None:
r"""
Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result
Expand All @@ -50,6 +52,13 @@
the Jacobians, ``jac_to_grad`` will also return the computed weights.
:param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been
used. Defaults to ``False``.
:param optimize_gramian_computation: When the ``aggregator`` is a
:class:`GramianWeightedAggregator <torchjd.aggregation._aggregator_bases.GramianWeightedAggregator>`
Comment thread
ValerianRey marked this conversation as resolved.
Outdated
(e.g. :class:`UPGrad <torchjd.aggregation._upgrad.UPGrad>`), it's possible to skip the
concatenation of the Jacobians and to instead compute the Gramian as the sum of the Gramians
of the individual Jacobians. This saves memory (up to 50% memory saving) but can be slightly
slower (up to 15%) on CUDA. We advise to try this optimization if memory is an issue for
you. Defaults to ``False``.

.. note::
This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all
Expand Down Expand Up @@ -96,13 +105,46 @@
if len(tensors_) == 0:
raise ValueError("The `tensors` parameter cannot be empty.")

jacobians = [t.jac for t in tensors_]

jacobians = deque(t.jac for t in tensors_)
check_consistent_first_dimension(jacobians, "tensors.jac")

if not retain_jac:
_free_jacs(tensors_)

if optimize_gramian_computation:
if not _can_skip_jacobian_combination(aggregator):
raise ValueError(

Check warning on line 116 in src/torchjd/autojac/_jac_to_grad.py

View check run for this annotation

Codecov / codecov/patch

src/torchjd/autojac/_jac_to_grad.py#L115-L116

Added lines #L115 - L116 were not covered by tests
"In order to use `jac_to_grad` with `optimize_gramian_computation=True`, you must "
"provide a `GramianWeightedAggregator` that doesn't have any forward hooks attached"
" to it."
)

gradients, weights = _gramian_based(aggregator, jacobians)

Check warning on line 122 in src/torchjd/autojac/_jac_to_grad.py

View check run for this annotation

Codecov / codecov/patch

src/torchjd/autojac/_jac_to_grad.py#L122

Added line #L122 was not covered by tests
else:
gradients, weights = _jacobian_based(aggregator, jacobians, tensors_)
accumulate_grads(tensors_, gradients)

return weights


def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianWeightedAggregator]:
    """
    Return whether the concatenation of the Jacobians can be skipped for this aggregator.

    This is only possible when the aggregator computes its weights from the Gramian alone
    (i.e. it is a ``GramianWeightedAggregator``) and neither the aggregator nor its inner
    weighting has any forward (pre-)hook that would need to observe the full Jacobian matrix.
    """
    if not isinstance(aggregator, GramianWeightedAggregator):
        return False
    return not (_has_forward_hook(aggregator) or _has_forward_hook(aggregator.weighting))


def _has_forward_hook(module: nn.Module) -> bool:
"""Return whether the module has any forward hook registered."""
return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0


def _jacobian_based(
aggregator: Aggregator,
jacobians: deque[Tensor],
tensors: list[TensorWithJac],
) -> tuple[list[Tensor], Tensor | None]:
jacobian_matrix = _unite_jacobians(jacobians)
weights: Tensor | None = None

Expand All @@ -124,13 +166,36 @@
handle.remove()
else:
gradient_vector = aggregator(jacobian_matrix)
gradients = _disunite_gradient(gradient_vector, tensors_)
accumulate_grads(tensors_, gradients)
return weights
gradients = _disunite_gradient(gradient_vector, tensors)
return gradients, weights


def _gramian_based(
    aggregator: GramianWeightedAggregator,
    jacobians: deque[Tensor],
) -> tuple[list[Tensor], Tensor]:
    """
    Aggregate the Jacobians without materializing their concatenation.

    The weights are obtained by applying the aggregator's gramian weighting to the Gramian
    of the (virtual) concatenated Jacobian, computed as a sum of per-tensor Gramians. The
    weights are then contracted with each Jacobian individually to produce the gradients.
    """
    weighting = aggregator.gramian_weighting
    gramian = _compute_gramian_sum(jacobians)
    weights = weighting(gramian)

    gradients: list[Tensor] = []
    while jacobians:
        # Popping dereferences each Jacobian as soon as it has been used, freeing memory asap.
        gradients.append(torch.tensordot(weights, jacobians.popleft(), dims=1))

    return gradients, weights

Check warning on line 186 in src/torchjd/autojac/_jac_to_grad.py

View check run for this annotation

Codecov / codecov/patch

src/torchjd/autojac/_jac_to_grad.py#L186

Added line #L186 was not covered by tests


def _compute_gramian_sum(jacobians: deque[Tensor]) -> PSDMatrix:
    """
    Compute the Gramian of the concatenation of the Jacobians as the sum of their Gramians.

    For matrices concatenated along dim 1, ``J = [J_1 | ... | J_k]`` satisfies
    ``J @ J.T == sum_i J_i @ J_i.T``, so the full Jacobian matrix never has to be built.

    .. note::
        The deque is iterated, not consumed: the Jacobians are reused afterwards to compute
        the gradients. Assumes ``jacobians`` is non-empty (enforced by the caller, which
        rejects empty ``tensors``); on an empty deque, ``sum`` would return the int ``0``.
    """
    # Generator instead of an intermediate list: each Gramian is folded into the running sum
    # and can be freed immediately.
    gramian = sum(compute_gramian(matrix) for matrix in jacobians)
    return cast(PSDMatrix, gramian)

Check warning on line 191 in src/torchjd/autojac/_jac_to_grad.py

View check run for this annotation

Codecov / codecov/patch

src/torchjd/autojac/_jac_to_grad.py#L190-L191

Added lines #L190 - L191 were not covered by tests


def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
jacobian_matrices = [jacobian.reshape(jacobian.shape[0], -1) for jacobian in jacobians]
def _unite_jacobians(jacobians: deque[Tensor]) -> Tensor:
jacobian_matrices = list[Tensor]()
while jacobians:
jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap
jacobian_matrices.append(jacobian.reshape(jacobian.shape[0], -1))
jacobian_matrix = torch.concat(jacobian_matrices, dim=1)
return jacobian_matrix

Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/autojac/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ def check_consistent_first_dimension(
:param jacobians: Sequence of Jacobian tensors to validate.
:param variable_name: Name of the variable to include in the error message.
"""

if len(jacobians) > 0 and not all(
jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]
jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians
):
raise ValueError(f"All Jacobians in `{variable_name}` should have the same number of rows.")

Expand Down
122 changes: 120 additions & 2 deletions tests/unit/autojac/test_jac_to_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,28 @@
from utils.tensors import tensor_

from torchjd.aggregation import (
IMTLG,
MGDA,
Aggregator,
AlignedMTL,
ConFIG,
Constant,
DualProj,
GradDrop,
Krum,
Mean,
PCGrad,
Random,
Sum,
TrimmedMean,
UPGrad,
)
from torchjd.aggregation._aggregator_bases import WeightedAggregator
from torchjd.autojac._jac_to_grad import jac_to_grad
from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator, WeightedAggregator
from torchjd.autojac._jac_to_grad import (
_can_skip_jacobian_combination,
_has_forward_hook,
jac_to_grad,
)


@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad(), ConFIG()])
Expand Down Expand Up @@ -125,6 +139,110 @@ def test_jacs_are_freed(retain_jac: bool) -> None:
check(t2)


def test_has_forward_hook() -> None:
    """Tests that _has_forward_hook correctly detects the presence of forward hooks."""

    module = UPGrad()

    def dummy_hook(*args):
        # Generic no-op usable as a forward, forward pre-, backward or backward pre-hook.
        # It is only registered in this test, never actually called.
        return args[-1]

    # No hooks at all: no forward hook detected.
    assert not _has_forward_hook(module)

    # Backward hooks alone must not be detected as forward hooks.
    module.register_full_backward_hook(dummy_hook)
    assert not _has_forward_hook(module)
    module.register_full_backward_pre_hook(dummy_hook)
    assert not _has_forward_hook(module)

    # Forward hooks and forward pre-hooks must each be detected, including when several are
    # registered, for as long as at least one of them remains.
    for register in (module.register_forward_hook, module.register_forward_pre_hook):
        first = register(dummy_hook)
        assert _has_forward_hook(module)
        second = register(dummy_hook)
        assert _has_forward_hook(module)
        first.remove()
        assert _has_forward_hook(module)
        second.remove()
        assert not _has_forward_hook(module)


# Pairs of (aggregator, expected result of _can_skip_jacobian_combination). The aggregators
# expected to return True are presumably the GramianWeightedAggregator subclasses — TODO
# confirm against the aggregator class hierarchy.
_PARAMETRIZATIONS = [
    (AlignedMTL(), True),
    (DualProj(), True),
    (IMTLG(), True),
    (Krum(n_byzantine=1), True),
    (MGDA(), True),
    (PCGrad(), True),
    (UPGrad(), True),
    (ConFIG(), False),
    (Constant(tensor_([0.5, 0.5])), False),
    (GradDrop(), False),
    (Mean(), False),
    (Random(), False),
    (Sum(), False),
    (TrimmedMean(trim_number=1), False),
]

# CAGrad is only importable when its optional dependency is installed; include it if possible.
try:
    from torchjd.aggregation import CAGrad

    _PARAMETRIZATIONS.append((CAGrad(c=0.5), True))
except ImportError:
    pass

# Same for NashMTL: only parametrize with it when its optional dependency is available.
try:
    from torchjd.aggregation import NashMTL

    _PARAMETRIZATIONS.append((NashMTL(n_tasks=2), False))
except ImportError:
    pass


@mark.parametrize("aggregator, expected", _PARAMETRIZATIONS)
def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool) -> None:
    """
    Tests that _can_skip_jacobian_combination correctly identifies when optimization can be used.
    """

    def forward_hook(_module, _input, output):
        return output

    def forward_pre_hook(_module, input):
        return input

    assert _can_skip_jacobian_combination(aggregator) == expected

    # Registering any forward (pre-)hook on the aggregator — or, for gramian-weighted
    # aggregators, on their inner weighting — must disable the optimization; removing the
    # hook must restore the original verdict.
    registrations = [
        (aggregator.register_forward_hook, forward_hook),
        (aggregator.register_forward_pre_hook, forward_pre_hook),
    ]
    if isinstance(aggregator, GramianWeightedAggregator):
        registrations += [
            (aggregator.weighting.register_forward_hook, forward_hook),
            (aggregator.weighting.register_forward_pre_hook, forward_pre_hook),
        ]

    for register, hook in registrations:
        handle = register(hook)
        assert not _can_skip_jacobian_combination(aggregator)
        handle.remove()
        assert _can_skip_jacobian_combination(aggregator) == expected


def test_noncontiguous_jac() -> None:
"""Tests that jac_to_grad works when the .jac field is non-contiguous."""

Expand Down
Loading