Skip to content

Commit a81e64d

Browse files
authored
Maintain usage examples and their tests (#215)
* Use only floating point values in the definition of J in TrimmedMean's usage example
* Reorder imports in rst examples
* Remove filterwarnings in cagrad and nash_mtl usage examples
* Add module docstrings for tests/doc modules
* Wrap extra code of doc tests into Extra blocks
1 parent 1eaafee commit a81e64d

8 files changed

Lines changed: 41 additions & 19 deletions

File tree

docs/source/examples/basic_usage.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Import several classes from ``torch`` and ``torchjd``:
1515
.. code-block:: python
1616
1717
import torch
18-
from torch.nn import MSELoss, Sequential, Linear, ReLU
18+
from torch.nn import Linear, MSELoss, ReLU, Sequential
1919
from torch.optim import SGD
2020
2121
import torchjd

docs/source/examples/iwrm.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ each Jacobian matrix consists of one gradient per loss. In this example, we use
2929
3030
import torch
3131
from torch.nn import (
32-
MSELoss,
33-
Sequential,
3432
Linear,
35-
ReLU
33+
MSELoss,
34+
ReLU,
35+
Sequential
3636
)
3737
from torch.optim import SGD
3838
@@ -70,10 +70,10 @@ each Jacobian matrix consists of one gradient per loss. In this example, we use
7070
7171
import torch
7272
from torch.nn import (
73-
MSELoss,
74-
Sequential,
7573
Linear,
76-
ReLU
74+
MSELoss,
75+
ReLU,
76+
Sequential
7777
)
7878
from torch.optim import SGD
7979

src/torchjd/aggregation/cagrad.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ class CAGrad(_WeightedAggregator):
2121
2222
Use CAGrad to aggregate a matrix.
2323
24-
>>> import warnings
25-
>>> warnings.filterwarnings("ignore")
26-
>>>
2724
>>> from torch import tensor
2825
>>> from torchjd.aggregation import CAGrad
2926
>>>

src/torchjd/aggregation/nash_mtl.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ class NashMTL(_WeightedAggregator):
5151
5252
Use NashMTL to aggregate a matrix.
5353
54-
>>> import warnings
55-
>>> warnings.filterwarnings("ignore")
56-
>>>
5754
>>> from torch import tensor
5855
>>> from torchjd.aggregation import NashMTL
5956
>>>

src/torchjd/aggregation/trimmed_mean.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ class TrimmedMean(Aggregator):
2525
>>>
2626
>>> A = TrimmedMean(trim_number=1)
2727
>>> J = tensor([
28-
... [ 1e11, 3],
29-
... [ 1, -1e11],
28+
... [ 1e11, 3.],
29+
... [ 1., -1e11],
3030
... [-1e10, 1e10],
31-
... [ 2, 2],
31+
... [ 2., 2.],
3232
... ])
3333
>>>
3434
>>> A(J)

tests/doc/test_aggregation.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
"""
2+
This file contains tests for the usage examples provided in the Aggregator subclasses. Each
3+
Aggregator's usage example, including its imports, should be copied in a function here, with the
4+
only difference that the advertised output should be replaced by a call to `assert_close`. The
5+
functions should be in alphabetical order.
6+
"""
7+
18
import torch
29
from torch.testing import assert_close
310

@@ -14,9 +21,11 @@ def test_aligned_mtl():
1421

1522

1623
def test_cagrad():
24+
# Extra ----------------------------------------------------------------------------------------
1725
import warnings
1826

1927
warnings.filterwarnings("ignore")
28+
# ----------------------------------------------------------------------------------------------
2029

2130
from torch import tensor
2231

@@ -116,9 +125,11 @@ def test_mgda():
116125

117126

118127
def test_nash_mtl():
128+
# Extra ----------------------------------------------------------------------------------------
119129
import warnings
120130

121131
warnings.filterwarnings("ignore")
132+
# ----------------------------------------------------------------------------------------------
122133

123134
from torch import tensor
124135

@@ -149,7 +160,10 @@ def test_random():
149160
A = Random()
150161
J = tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
151162

163+
# Extra ----------------------------------------------------------------------------------------
152164
_ = torch.manual_seed(0)
165+
# ----------------------------------------------------------------------------------------------
166+
153167
assert_close(A(J), tensor([-2.6229, 1.0000, 1.0000]), rtol=0, atol=1e-4)
154168
assert_close(A(J), tensor([5.3976, 1.0000, 1.0000]), rtol=0, atol=1e-4)
155169

@@ -173,10 +187,10 @@ def test_trimmed_mean():
173187
A = TrimmedMean(trim_number=1)
174188
J = tensor(
175189
[
176-
[1e11, 3],
177-
[1, -1e11],
190+
[1e11, 3.0],
191+
[1.0, -1e11],
178192
[-1e10, 1e10],
179-
[2, 2],
193+
[2.0, 2.0],
180194
]
181195
)
182196

tests/doc/test_backward.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""
2+
This file contains the test of the backward usage example, with a verification of the value of the
3+
obtained `.grad` field.
4+
"""
5+
16
from torch.testing import assert_close
27
from unit.conftest import DEVICE
38

tests/doc/test_rst.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
"""
2+
This file contains the tests corresponding to the extra usage examples contained in the `.rst` files
3+
of the documentation. When there are multiple examples within a single `.rst` file, we use nested
4+
functions here to test them.
5+
"""
6+
7+
18
def test_basic_usage():
29
import torch
310
from torch.nn import Linear, MSELoss, ReLU, Sequential
@@ -113,11 +120,13 @@ def test_mtl():
113120

114121

115122
def test_lightning_integration():
123+
# Extra ----------------------------------------------------------------------------------------
116124
import logging
117125
import warnings
118126

119127
warnings.filterwarnings("ignore")
120128
logging.disable(logging.INFO)
129+
# ----------------------------------------------------------------------------------------------
121130

122131
import torch
123132
from lightning import LightningModule, Trainer

0 commit comments

Comments (0)