Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
ae6be7d
WIP: add gramian-based jac_to_grad
ValerianRey Jan 21, 2026
8bdf512
Update changelog
ValerianRey Jan 21, 2026
aaf2544
Use deque to free memory asap
ValerianRey Jan 23, 2026
64b06ad
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 23, 2026
745f707
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
5eb77f9
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
8f65caa
Use gramian_weighting in jac_to_grad
ValerianRey Jan 28, 2026
6fe15a4
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
d5cb5c2
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 29, 2026
f986950
Only optimize when no forward hooks
ValerianRey Jan 29, 2026
4cf5cbb
Make _gramian_based take aggregator instead of weighting
ValerianRey Jan 29, 2026
add549c
Add _can_skip_jacobian_combination helper function
ValerianRey Jan 29, 2026
453971a
Add test_can_skip_jacobian_combination
ValerianRey Jan 29, 2026
9d4c41c
Optimize compute_gramian for when contracted_dims=-1
ValerianRey Jan 29, 2026
48cd70b
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 30, 2026
8f2660d
Use TypeGuard in _can_skip_jacobian_combination
ValerianRey Jan 30, 2026
fc9bbcf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2026
3f9a6d1
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 1, 2026
9d9cbf0
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 4, 2026
b5ca226
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
0baa914
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 5, 2026
2ed1d7c
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 13, 2026
86be778
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 13, 2026
4ace19e
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 13, 2026
2a84bef
Add ruff if-else squeezing
ValerianRey Feb 13, 2026
4b6209c
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 23, 2026
1b1c660
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 23, 2026
b714253
Many fixes of problems coming from the merge
ValerianRey Feb 23, 2026
2bb8ab1
Fix _can_skip_jacobian_combination
ValerianRey Feb 23, 2026
63c9dde
Make check_consistent_first_dimension work with Deque
ValerianRey Feb 23, 2026
0f85811
Improve test_can_skip_jacobian_combination
ValerianRey Feb 23, 2026
9d55215
Add optimize_gramian_computation param and add error when not compatible
ValerianRey Feb 23, 2026
456510b
Fix overloads (partly) and add missing code coverage
ValerianRey Feb 23, 2026
55c69d1
Fix overloads
ValerianRey Feb 23, 2026
8a401a3
Fix docstring
ValerianRey Feb 23, 2026
b4bf7c4
fixup what @ValerianRey did wrong
PierreQuinton Feb 23, 2026
24a991a
Improve error message
ValerianRey Feb 23, 2026
2ea44a4
Improve docstring
ValerianRey Feb 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,8 @@ changelog does not include internal changes that do not affect the user.
jac_to_grad(shared_module.parameters(), aggregator)
```

- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency
of `autojac`.
- Removed an unnecessary internal cloning of gradient. This should slightly improve the memory
efficiency of `autojac`.
- Removed several unnecessary memory duplications. This should significantly improve the memory
efficiency and speed of `autojac`.

## [0.8.1] - 2026-01-07

Expand Down
21 changes: 16 additions & 5 deletions src/torchjd/_linalg/_gramian.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,22 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
first dimension).
"""

contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim
indices_source = list(range(t.ndim - contracted_dims))
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
transposed = t.movedim(indices_source, indices_dest)
gramian = torch.tensordot(t, transposed, dims=contracted_dims)
# Optimization: it's faster to do that than moving dims and using tensordot, and this case
# happens very often, sometimes hundreds of times for a single jac_to_grad.
if contracted_dims == -1:
if t.ndim == 1:
matrix = t.unsqueeze(1)
else:
matrix = t.flatten(start_dim=1)

gramian = matrix @ matrix.T

else:
contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim
indices_source = list(range(t.ndim - contracted_dims))
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
transposed = t.movedim(indices_source, indices_dest)
gramian = torch.tensordot(t, transposed, dims=contracted_dims)
return cast(PSDTensor, gramian)


Expand Down
64 changes: 54 additions & 10 deletions src/torchjd/autojac/_jac_to_grad.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from collections import deque
from collections.abc import Iterable
from typing import cast
Comment thread
ValerianRey marked this conversation as resolved.
Outdated

import torch
from torch import Tensor
from torch import Tensor, nn

from torchjd._linalg import PSDMatrix, compute_gramian
from torchjd.aggregation import Aggregator
from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator

from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac

Expand Down Expand Up @@ -63,29 +67,69 @@ def jac_to_grad(
if len(tensors_) == 0:
return

jacobians = [t.jac for t in tensors_]
jacobians = deque(t.jac for t in tensors_)

if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]]):
if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians]):
raise ValueError("All Jacobians should have the same number of rows.")

if not retain_jac:
_free_jacs(tensors_)

if _can_skip_jacobian_combination(aggregator):
gradients = _gramian_based(cast(GramianWeightedAggregator, aggregator), jacobians, tensors_)
Comment thread
ValerianRey marked this conversation as resolved.
Outdated
else:
gradients = _jacobian_based(aggregator, jacobians, tensors_)
accumulate_grads(tensors_, gradients)


def _can_skip_jacobian_combination(aggregator: Aggregator) -> bool:
    """
    Return whether the Jacobian combination step can be skipped: the aggregator must derive its
    weights from the Gramian and must have no forward (pre-)hook that would need to observe the
    full Jacobian matrix.
    """
    if not isinstance(aggregator, GramianWeightedAggregator):
        return False
    return not _has_forward_hook(aggregator)
Comment thread
ValerianRey marked this conversation as resolved.
Outdated


def _has_forward_hook(module: nn.Module) -> bool:
"""Return whether the module has any forward hook registered."""
return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0


def _jacobian_based(
    aggregator: Aggregator, jacobians: deque[Tensor], tensors: list[TensorWithJac]
) -> list[Tensor]:
    """
    Combine the Jacobians into a single matrix, aggregate it into a gradient vector, and split
    that vector back into one gradient per tensor.

    :param aggregator: Aggregator mapping the united Jacobian matrix to a gradient vector.
    :param jacobians: Jacobians to combine; consumed (emptied) to free memory as soon as possible.
    :param tensors: Tensors whose numels/shapes determine how the gradient vector is split.
    """
    # NOTE(review): two leftover diff lines referencing the undefined name `tensors_` (an
    # artifact of the scraped merge diff) were removed from this body.
    jacobian_matrix = _unite_jacobians(jacobians)
    gradient_vector = aggregator(jacobian_matrix)
    gradients = _disunite_gradient(gradient_vector, tensors)
    return gradients


def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
jacobian_matrices = [jacobian.reshape(jacobian.shape[0], -1) for jacobian in jacobians]
def _gramian_based(
    aggregator: GramianWeightedAggregator, jacobians: deque[Tensor], tensors: list[TensorWithJac]
) -> list[Tensor]:
    """
    Compute the gradients without materializing the united Jacobian matrix: derive the weights
    from the sum of the Jacobians' Gramians, then apply those weights to each Jacobian
    individually.
    """
    gramian = _compute_gramian_sum(jacobians)
    weights = aggregator.gramian_weighting(gramian)

    results = list[Tensor]()
    # Popping from the left dereferences each Jacobian as soon as it has been weighted, so its
    # memory can be reclaimed as early as possible.
    while jacobians:
        results.append(torch.tensordot(weights, jacobians.popleft(), dims=1))

    return results


def _compute_gramian_sum(jacobians: deque[Tensor]) -> PSDMatrix:
    """Return the sum of the Gramians of all given Jacobians."""
    total = sum(compute_gramian(jacobian) for jacobian in jacobians)
    return cast(PSDMatrix, total)


def _unite_jacobians(jacobians: deque[Tensor]) -> Tensor:
jacobian_matrices = list[Tensor]()
while jacobians:
jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap
jacobian_matrices.append(jacobian.reshape(jacobian.shape[0], -1))
jacobian_matrix = torch.concat(jacobian_matrices, dim=1)
return jacobian_matrix


def _disunite_gradient(
gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac]
) -> list[Tensor]:
def _disunite_gradient(gradient_vector: Tensor, tensors: list[TensorWithJac]) -> list[Tensor]:
gradient_vectors = gradient_vector.split([t.numel() for t in tensors])
gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors)]
return gradients
Expand Down
117 changes: 115 additions & 2 deletions tests/unit/autojac/test_jac_to_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,28 @@
from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac
from utils.tensors import tensor_

from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad
from torchjd.autojac._jac_to_grad import jac_to_grad
from torchjd.aggregation import (
IMTLG,
MGDA,
Aggregator,
AlignedMTL,
ConFIG,
Constant,
DualProj,
GradDrop,
Krum,
Mean,
PCGrad,
Random,
Sum,
TrimmedMean,
UPGrad,
)
from torchjd.autojac._jac_to_grad import (
_can_skip_jacobian_combination,
_has_forward_hook,
jac_to_grad,
)


@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()])
Expand Down Expand Up @@ -101,3 +121,96 @@ def test_jacs_are_freed(retain_jac: bool):
check = assert_has_jac if retain_jac else assert_has_no_jac
check(t1)
check(t2)


def test_has_forward_hook():
    """Tests that _has_forward_hook correctly detects the presence of forward hooks."""

    module = UPGrad()

    def noop_forward_hook(_module, _input, _output):
        return _output

    def noop_forward_pre_hook(_module, _input):
        return _input

    def noop_backward_hook(_module, _grad_input, _grad_output):
        return _grad_input

    def noop_backward_pre_hook(_module, _grad_output):
        return _grad_output

    # Backward hooks alone must not count as forward hooks.
    assert not _has_forward_hook(module)
    module.register_full_backward_hook(noop_backward_hook)
    assert not _has_forward_hook(module)
    module.register_full_backward_pre_hook(noop_backward_pre_hook)
    assert not _has_forward_hook(module)

    # Forward hooks are detected until the last one is removed.
    first_handle = module.register_forward_hook(noop_forward_hook)
    assert _has_forward_hook(module)
    second_handle = module.register_forward_hook(noop_forward_hook)
    assert _has_forward_hook(module)
    first_handle.remove()
    assert _has_forward_hook(module)
    second_handle.remove()
    assert not _has_forward_hook(module)

    # Forward pre-hooks are detected until the last one is removed.
    first_pre_handle = module.register_forward_pre_hook(noop_forward_pre_hook)
    assert _has_forward_hook(module)
    second_pre_handle = module.register_forward_pre_hook(noop_forward_pre_hook)
    assert _has_forward_hook(module)
    first_pre_handle.remove()
    assert _has_forward_hook(module)
    second_pre_handle.remove()
    assert not _has_forward_hook(module)


# Pairs of (aggregator, expected result of _can_skip_jacobian_combination for it). True means the
# aggregator is Gramian-weighted, so the Jacobian combination step can be skipped.
_PARAMETRIZATIONS = [
    (AlignedMTL(), True),
    (DualProj(), True),
    (IMTLG(), True),
    (Krum(n_byzantine=1), True),
    (MGDA(), True),
    (PCGrad(), True),
    (UPGrad(), True),
    (ConFIG(), False),
    (Constant(tensor_([0.5, 0.5])), False),
    (GradDrop(), False),
    (Mean(), False),
    (Random(), False),
    (Sum(), False),
    (TrimmedMean(trim_number=1), False),
]

# CAGrad may be unavailable (presumably an optional dependency is missing); only test it then.
try:
    from torchjd.aggregation import CAGrad

    _PARAMETRIZATIONS.append((CAGrad(c=0.5), True))
except ImportError:
    pass

# NashMTL may be unavailable (presumably an optional dependency is missing); only test it then.
try:
    from torchjd.aggregation import NashMTL

    _PARAMETRIZATIONS.append((NashMTL(n_tasks=2), False))
except ImportError:
    pass


@mark.parametrize("aggregator, expected", _PARAMETRIZATIONS)
def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool):
    """
    Tests that _can_skip_jacobian_combination correctly identifies when optimization can be used.
    """

    assert _can_skip_jacobian_combination(aggregator) == expected

    # Registering a forward hook must disable the optimization; removing it must restore it.
    hook_handle = aggregator.register_forward_hook(lambda module, input, output: output)
    assert not _can_skip_jacobian_combination(aggregator)
    hook_handle.remove()

    # Same for a forward pre-hook.
    pre_hook_handle = aggregator.register_forward_pre_hook(lambda module, input: input)
    assert not _can_skip_jacobian_combination(aggregator)
    pre_hook_handle.remove()

    assert _can_skip_jacobian_combination(aggregator) == expected