
Commit a698f68

Rename A to aggregator in backward and mtl_backward (#203)
* Rename A to aggregator in backward, mtl_backward and their usages
* Add changelog entry
1 parent 69f56f6 commit a698f68

11 files changed

Lines changed: 90 additions & 82 deletions
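In practice, the rename only touches the keyword at call sites; a minimal before/after sketch (the `loss1`, `loss2`, and `features` names are placeholders taken from the examples changed below):

# Before this commit
backward([loss1, loss2], A=UPGrad())
mtl_backward(losses=[loss1, loss2], features=features, A=UPGrad())

# After this commit
backward([loss1, loss2], aggregator=UPGrad())
mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())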


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ changes that do not affect the user.

 ### Changed

+- **BREAKING**: Changed the name of the parameter `A` to `aggregator` in `backward` and
+  `mtl_backward`.
 - **BREAKING**: Changed the order of the parameters of `backward` and `mtl_backward` to make it
   possible to have a default value for `inputs` and for `shared_params` and `tasks_params`,
   respectively. Usages of `backward` and `mtl_backward` that rely on the order between arguments
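Since the second entry in this hunk also reorders parameters so that `inputs`, `shared_params`, and `tasks_params` can have defaults, passing everything after the differentiated tensors and the aggregator by keyword is the safest way to absorb both changes. A sketch under that assumption (`params`, `task1_params`, `task2_params`, and `shared_params` are hypothetical variables):

backward([loss1, loss2], aggregator, inputs=params, retain_graph=True)
mtl_backward(
    losses=[loss1, loss2],
    features=features,
    aggregator=aggregator,
    tasks_params=[task1_params, task2_params],
    shared_params=shared_params,
)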

README.md

Lines changed: 2 additions & 2 deletions
@@ -78,7 +78,7 @@ params = [

 loss_fn = MSELoss()
 optimizer = SGD(params, lr=0.1)
-A = UPGrad()
+aggregator = UPGrad()

 inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
 task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
@@ -92,7 +92,7 @@ for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
     loss2 = loss_fn(output2, target2)

     optimizer.zero_grad()
-    mtl_backward(losses=[loss1, loss2], features=features, A=A)
+    mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
     optimizer.step()
 ```

docs/source/examples/basic_usage.rst

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@ Define the aggregator that will be used to combine the Jacobian matrix:

 .. code-block:: python

-    A = UPGrad()
+    aggregator = UPGrad()

 In essence, :doc:`UPGrad <../docs/aggregation/upgrad>` projects each gradient onto the dual cone of
 the rows of the Jacobian and averages the results. This ensures that locally, no loss will be
@@ -69,7 +69,7 @@ Perform the Jacobian descent backward pass:

 .. code-block:: python

-    torchjd.backward([loss1, loss2], A)
+    torchjd.backward([loss1, loss2], aggregator)

 This will populate the ``.grad`` field of each model parameter with the corresponding aggregated
 Jacobian matrix.

docs/source/examples/iwrm.rst

Lines changed: 2 additions & 2 deletions
@@ -92,13 +92,13 @@ each Jacobian matrix consists of one gradient per loss. In this example, we use

     params = model.parameters()
     optimizer = SGD(params, lr=0.1)
-    A = UPGrad()
+    aggregator = UPGrad()

     for x, y in zip(X, Y):
         y_hat = model(x)
         losses = loss_fn(y_hat, y)
         optimizer.zero_grad()
-        backward(losses, A)
+        backward(losses, aggregator)
         optimizer.step()

 Note that in both cases, we use the `torch.optim.SGD
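For context, the IWRM example feeds one loss per sample into `backward`, and the aggregator reduces the resulting per-sample Jacobian. A self-contained sketch under the new name; the model, data shapes, and the `reduction="none"` loss are illustrative assumptions, not taken verbatim from the file:

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd import backward
from torchjd.aggregation import UPGrad

X = torch.randn(8, 16, 10)  # 8 illustrative batches of 16 inputs of length 10
Y = torch.randn(8, 16, 1)  # matching targets

model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
loss_fn = MSELoss(reduction="none")  # keep one loss value per sample

params = model.parameters()
optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()

for x, y in zip(X, Y):
    y_hat = model(x)
    losses = loss_fn(y_hat, y)  # shape (16, 1): one loss per sample
    optimizer.zero_grad()
    backward(losses, aggregator)  # aggregates the per-sample Jacobian into the .grad fields
    optimizer.step()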

docs/source/examples/lightning_integration.rst

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ The following code example demonstrates a basic multi-task learning setup using

         opt = self.optimizers()
         opt.zero_grad()
-        mtl_backward(losses=[loss1, loss2], features=features, A=UPGrad())
+        mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
         opt.step()

     def configure_optimizers(self) -> OptimizerLRScheduler:

docs/source/examples/mtl.rst

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.

     loss_fn = MSELoss()
     optimizer = SGD(params, lr=0.1)
-    A = UPGrad()
+    aggregator = UPGrad()

     inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
     task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
@@ -53,7 +53,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
         loss2 = loss_fn(output2, target2)

         optimizer.zero_grad()
-        mtl_backward(losses=[loss1, loss2], features=features, A=A)
+        mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
         optimizer.step()

 .. note::

src/torchjd/autojac/backward.py

Lines changed: 5 additions & 4 deletions
@@ -15,18 +15,19 @@

 def backward(
     tensors: Sequence[Tensor] | Tensor,
-    A: Aggregator,
+    aggregator: Aggregator,
     inputs: Iterable[Tensor] | None = None,
     retain_graph: bool = False,
     parallel_chunk_size: int | None = None,
 ) -> None:
     r"""
     Computes the Jacobian of all values in ``tensors`` with respect to all ``inputs``. Computes its
-    aggregation by ``A`` and accumulates it in the ``.grad`` fields of the ``inputs``.
+    aggregation by the provided ``aggregator`` and accumulates it in the ``.grad`` fields of the
+    ``inputs``.

     :param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobian
         matrices will have one row for each value of each of these tensors.
-    :param A: Aggregator used to reduce the Jacobian into a vector.
+    :param aggregator: Aggregator used to reduce the Jacobian into a vector.
     :param inputs: The tensors with respect to which the Jacobian must be computed. These must have
         their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
         that were used to compute the ``tensors`` parameter.
@@ -95,7 +96,7 @@ def backward(
     jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph)

     # Transform that aggregates the Jacobians.
-    aggregate = Aggregate(A, inputs)
+    aggregate = Aggregate(aggregator, inputs)

     # Transform that accumulates the result in the .grad field of the inputs.
     accumulate = Accumulate(inputs)
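A minimal sketch of the updated `backward` signature in use (the tensor values are arbitrary; only the parameter names come from the diff above):

import torch

from torchjd import backward
from torchjd.aggregation import UPGrad

p1 = torch.tensor([1.0, 2.0], requires_grad=True)
p2 = torch.tensor([3.0, 4.0], requires_grad=True)

y1 = torch.tensor([-1.0, 1.0]) @ p1 + p2.sum()
y2 = (p1**2).sum() + p2.norm()

# The Jacobian of (y1, y2) with respect to p1 and p2 is aggregated by UPGrad
# and accumulated into p1.grad and p2.grad.
backward([y1, y2], aggregator=UPGrad(), inputs=[p1, p2])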

src/torchjd/autojac/mtl_backward.py

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,7 @@
 def mtl_backward(
     losses: Sequence[Tensor],
     features: Sequence[Tensor] | Tensor,
-    A: Aggregator,
+    aggregator: Aggregator,
     tasks_params: Sequence[Iterable[Tensor]] | None = None,
     shared_params: Iterable[Tensor] | None = None,
     retain_graph: bool = False,
@@ -45,7 +45,7 @@ def mtl_backward(
     :param losses: The task losses. The Jacobian matrix will have one row per loss.
     :param features: The last shared representation used for all tasks, as given by the feature
         extractor. Should be non-empty.
-    :param A: Aggregator used to reduce the Jacobian into a vector.
+    :param aggregator: Aggregator used to reduce the Jacobian into a vector.
     :param tasks_params: The parameters of each task-specific head. Their ``requires_grad`` flags
         must be set to ``True``. If not provided, the parameters considered for each task will
         default to the leaf tensors that are in the computation graph of its loss, but that were not
@@ -129,7 +129,7 @@ def mtl_backward(
     jac = Jac(features, shared_params, parallel_chunk_size, retain_graph)

     # Transform that aggregates the Jacobians.
-    aggregate = Aggregate(A, shared_params)
+    aggregate = Aggregate(aggregator, shared_params)

     # Transform that accumulates the result in the .grad field of the shared parameters.
     accumulate = Accumulate(shared_params)
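And a corresponding sketch for `mtl_backward`, adapted from the README example changed above (module sizes and data are illustrative; only the parameter names come from this diff):

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd import mtl_backward
from torchjd.aggregation import UPGrad

shared_module = Sequential(Linear(10, 5), ReLU())  # shared feature extractor
task1_head = Linear(5, 1)
task2_head = Linear(5, 1)
params = [*shared_module.parameters(), *task1_head.parameters(), *task2_head.parameters()]

loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()

input = torch.randn(16, 10)
target1 = torch.randn(16, 1)
target2 = torch.randn(16, 1)

features = shared_module(input)
loss1 = loss_fn(task1_head(features), target1)
loss2 = loss_fn(task2_head(features), target2)

optimizer.zero_grad()
# Aggregates the Jacobian of the losses w.r.t. the shared parameters with UPGrad
# and backpropagates each loss into its own task head.
mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
optimizer.step()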

tests/doc/test_rst.py

Lines changed: 7 additions & 7 deletions
@@ -9,7 +9,7 @@ def test_basic_usage():
     model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
     optimizer = SGD(model.parameters(), lr=0.1)

-    A = UPGrad()
+    aggregator = UPGrad()
     input = torch.randn(16, 10) # Batch of 16 random input vectors of length 10
     target1 = torch.randn(16) # First batch of 16 targets
     target2 = torch.randn(16) # Second batch of 16 targets
@@ -20,7 +20,7 @@ def test_basic_usage():
     loss2 = loss_fn(output[:, 1], target2)

     optimizer.zero_grad()
-    torchjd.backward([loss1, loss2], A)
+    torchjd.backward([loss1, loss2], aggregator)
     optimizer.step()


@@ -62,13 +62,13 @@ def test_iwrm_with_ssjd():

         params = model.parameters()
         optimizer = SGD(params, lr=0.1)
-        A = UPGrad()
+        aggregator = UPGrad()

         for x, y in zip(X, Y):
             y_hat = model(x)
             losses = loss_fn(y_hat, y)
             optimizer.zero_grad()
-            backward(losses, A)
+            backward(losses, aggregator)
             optimizer.step()

     test_erm_with_sgd()
@@ -94,7 +94,7 @@ def test_mtl():

     loss_fn = MSELoss()
     optimizer = SGD(params, lr=0.1)
-    A = UPGrad()
+    aggregator = UPGrad()

     inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
     task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
@@ -108,7 +108,7 @@ def test_mtl():
         loss2 = loss_fn(output2, target2)

         optimizer.zero_grad()
-        mtl_backward(losses=[loss1, loss2], features=features, A=A)
+        mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
         optimizer.step()


@@ -150,7 +150,7 @@ def training_step(self, batch, batch_idx) -> None:

             opt = self.optimizers()
             opt.zero_grad()
-            mtl_backward(losses=[loss1, loss2], features=features, A=UPGrad())
+            mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
             opt.step()

         def configure_optimizers(self) -> OptimizerLRScheduler:

tests/unit/autojac/test_backward.py

Lines changed: 24 additions & 24 deletions
@@ -10,8 +10,8 @@
 from torchjd.aggregation import MGDA, Aggregator, Mean, Random, UPGrad


-@mark.parametrize("A", [Mean(), UPGrad(), MGDA(), Random()])
-def test_backward_various_aggregators(A: Aggregator):
+@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()])
+def test_backward_various_aggregators(aggregator: Aggregator):
     """Tests that backward works for various aggregators."""

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
@@ -21,17 +21,17 @@ def test_backward_various_aggregators(A: Aggregator):
     y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
     y2 = (p1**2).sum() + p2.norm()

-    backward([y1, y2], A)
+    backward([y1, y2], aggregator)

     for p in params:
         assert (p.grad is not None) and (p.shape == p.grad.shape)


-@mark.parametrize("A", [Mean(), UPGrad(), MGDA()])
+@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA()])
 @mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (60, 55), (120, 143)])
 @mark.parametrize("manually_specify_inputs", [True, False])
 def test_backward_value_is_correct(
-    A: Aggregator, shape: tuple[int, int], manually_specify_inputs: bool
+    aggregator: Aggregator, shape: tuple[int, int], manually_specify_inputs: bool
 ):
     """
     Tests that the .grad value filled by backward is correct in a simple example of matrix-vector
@@ -47,15 +47,15 @@ def test_backward_value_is_correct(
     else:
         inputs = None

-    backward([output], A, inputs=inputs)
+    backward([output], aggregator, inputs=inputs)

-    assert_close(input.grad, A(J))
+    assert_close(input.grad, aggregator(J))


 def test_backward_empty_inputs():
     """Tests that backward does not fill the .grad values if no input is specified."""

-    A = Mean()
+    aggregator = Mean()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -64,7 +64,7 @@ def test_backward_empty_inputs():
     y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
     y2 = (p1**2).sum() + p2.norm()

-    backward([y1, y2], A, inputs=[])
+    backward([y1, y2], aggregator, inputs=[])

     for p in params:
         assert p.grad is None
@@ -76,15 +76,15 @@ def test_backward_partial_inputs():
     specified as inputs.
     """

-    A = Mean()
+    aggregator = Mean()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)

     y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
     y2 = (p1**2).sum() + p2.norm()

-    backward([y1, y2], A, inputs=[p1])
+    backward([y1, y2], aggregator, inputs=[p1])

     assert (p1.grad is not None) and (p1.shape == p1.grad.shape)
     assert p2.grad is None
@@ -93,13 +93,13 @@ def test_backward_partial_inputs():
 def test_backward_empty_tensors():
     """Tests that backward raises an error when called with an empty list of tensors."""

-    A = UPGrad()
+    aggregator = UPGrad()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)

     with raises(ValueError):
-        backward([], A, inputs=[p1, p2])
+        backward([], aggregator, inputs=[p1, p2])


 def test_backward_multiple_tensors():
@@ -108,7 +108,7 @@ def test_backward_multiple_tensors():
     containing the all the values of the original tensors.
     """

-    A = UPGrad()
+    aggregator = UPGrad()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -117,13 +117,13 @@ def test_backward_multiple_tensors():
     y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
     y2 = (p1**2).sum() + p2.norm()

-    backward([y1, y2], A, retain_graph=True)
+    backward([y1, y2], aggregator, retain_graph=True)

     param_to_grad = {p: p.grad for p in params}
     for p in params:
         p.grad = None

-    backward(torch.cat([y1.reshape(-1), y2.reshape(-1)]), A)
+    backward(torch.cat([y1.reshape(-1), y2.reshape(-1)]), aggregator)

     for p in params:
         assert (p.grad == param_to_grad[p]).all()
@@ -133,7 +133,7 @@ def test_backward_multiple_tensors():
 def test_backward_valid_chunk_size(chunk_size):
     """Tests that backward works for various valid values of parallel_chunk_size."""

-    A = UPGrad()
+    aggregator = UPGrad()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -142,7 +142,7 @@ def test_backward_valid_chunk_size(chunk_size):
     y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
     y2 = (p1**2).sum() + p2.norm()

-    backward([y1, y2], A, parallel_chunk_size=chunk_size, retain_graph=True)
+    backward([y1, y2], aggregator, parallel_chunk_size=chunk_size, retain_graph=True)

     for p in params:
         assert (p.grad is not None) and (p.shape == p.grad.shape)
@@ -152,7 +152,7 @@ def test_backward_valid_chunk_size(chunk_size):
 def test_backward_non_positive_chunk_size(chunk_size: int):
     """Tests that backward raises an error when using invalid chunk sizes."""

-    A = UPGrad()
+    aggregator = UPGrad()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -161,7 +161,7 @@ def test_backward_non_positive_chunk_size(chunk_size: int):
     y2 = (p1**2).sum() + p2.norm()

     with raises(ValueError):
-        backward([y1, y2], A, parallel_chunk_size=chunk_size)
+        backward([y1, y2], aggregator, parallel_chunk_size=chunk_size)


 @mark.parametrize(
@@ -174,7 +174,7 @@ def test_backward_no_retain_graph_small_chunk_size(chunk_size: int, expectation:
     large enough to allow differentiation of all tensors at once.
     """

-    A = UPGrad()
+    aggregator = UPGrad()

     p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
     p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -183,7 +183,7 @@ def test_backward_no_retain_graph_small_chunk_size(chunk_size: int, expectation:
     y2 = (p1**2).sum() + p2.norm()

     with expectation:
-        backward([y1, y2], A, retain_graph=False, parallel_chunk_size=chunk_size)
+        backward([y1, y2], aggregator, retain_graph=False, parallel_chunk_size=chunk_size)


 def test_backward_fails_with_input_retaining_grad():
@@ -198,7 +198,7 @@ def test_backward_fails_with_input_retaining_grad():
     c = 3 * b

     with raises(RuntimeError):
-        backward(tensors=c, A=UPGrad(), inputs=[b])
+        backward(tensors=c, aggregator=UPGrad(), inputs=[b])


 def test_backward_fails_with_non_input_retaining_grad():
@@ -213,7 +213,7 @@ def test_backward_fails_with_non_input_retaining_grad():
     c = 3 * b

     # backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor
-    backward(tensors=c, A=UPGrad(), inputs=[a])
+    backward(tensors=c, aggregator=UPGrad(), inputs=[a])

     with raises(RuntimeError):
         # Using such a BatchedTensor should result in an error
