Skip to content

Commit c08a855

Browse files
authored
typing: Add missing annotations (#584)
* Add ANN ruff rule
* Add noqa comment in test_lightning_integration
* Add annotations in static_plotter and rename a variable
* Add annotations in conftest.py
* Add annotations in run_profiler.py
* Add -> None to __init__, __init_subclass__, main, and to test functions
* Add -> str to __str__
* Add -> "MemoryFrame" to MemoryFrame.from_event
* Add annotations for args and kwargs (Any)
* Add annotations for ctx
* Add annotations for obj
* Add annotations for chunk_size
* Add annotations to _make_tensors
* Add annotations to CloneParams
* Add annotations to InterModuleParamReuse.forward
* Add annotations to time_call
* Add annotations to update_gradient_coordinate
1 parent 448cca1 commit c08a855

93 files changed

Lines changed: 506 additions & 471 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

docs/source/conf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None:
100100
return link
101101

102102

103-
def _get_obj(_info: dict[str, str]):
103+
def _get_obj(_info: dict[str, str]) -> object:
104104
module_name = _info["module"]
105105
full_name = _info["fullname"]
106106
sub_module = sys.modules.get(module_name)
@@ -112,7 +112,7 @@ def _get_obj(_info: dict[str, str]):
112112
return obj
113113

114114

115-
def _get_file_name(obj) -> str | None:
115+
def _get_file_name(obj: object) -> str | None:
116116
try:
117117
file_name = inspect.getsourcefile(obj)
118118
file_name = os.path.relpath(file_name, start=_PATH_ROOT)
@@ -121,7 +121,7 @@ def _get_file_name(obj) -> str | None:
121121
return file_name
122122

123123

124-
def _get_line_str(obj) -> str:
124+
def _get_line_str(obj: object) -> str:
125125
source, start = inspect.getsourcelines(obj)
126126
end = start + len(source) - 1
127127
line_str = f"#L{start}-L{end}"

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ select = [
134134
"I", # isort
135135
"UP", # pyupgrade
136136
"ARG", # flake8-unused-arguments
137+
"ANN", # flake8-annotations
137138
"B", # flake8-bugbear
138139
"C4", # flake8-comprehensions
139140
"FIX", # flake8-fixme
@@ -156,12 +157,16 @@ ignore = [
156157
"RUF012", # Mutable default value for class attribute (a bit tedious to fix)
157158
"RET504", # Unnecessary assignment return statement
158159
"COM812", # Trailing comma missing (conflicts with formatter, see https://github.com/astral-sh/ruff/issues/9216)
160+
"ANN401", # Prevent annotating as Any (we rarely do that, and when we do it's hard to find an alternative)
159161
]
160162

161163
[tool.ruff.lint.per-file-ignores]
162164
"**/conftest.py" = ["ARG"] # Can't change argument names in the functions pytest expects
163165
"tests/doc/test_rst.py" = ["ARG"] # For the lightning example
164166

167+
[tool.ruff.lint.flake8-annotations]
168+
suppress-dummy-args = true
169+
165170
[tool.ruff.lint.isort]
166171
combine-as-imports = true
167172

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class Aggregator(nn.Module, ABC):
1313
:math:`m \times n` into row vectors of dimension :math:`n`.
1414
"""
1515

16-
def __init__(self):
16+
def __init__(self) -> None:
1717
super().__init__()
1818

1919
@staticmethod
@@ -48,7 +48,7 @@ class WeightedAggregator(Aggregator):
4848
:param weighting: The object responsible for extracting the vector of weights from the matrix.
4949
"""
5050

51-
def __init__(self, weighting: Weighting[Matrix]):
51+
def __init__(self, weighting: Weighting[Matrix]) -> None:
5252
super().__init__()
5353
self.weighting = weighting
5454

@@ -77,6 +77,6 @@ class GramianWeightedAggregator(WeightedAggregator):
7777
gramian.
7878
"""
7979

80-
def __init__(self, gramian_weighting: Weighting[PSDMatrix]):
80+
def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None:
8181
super().__init__(gramian_weighting << compute_gramian)
8282
self.gramian_weighting = gramian_weighting

src/torchjd/aggregation/_aligned_mtl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(
6161
self,
6262
pref_vector: Tensor | None = None,
6363
scale_mode: SUPPORTED_SCALE_MODE = "min",
64-
):
64+
) -> None:
6565
self._pref_vector = pref_vector
6666
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
6767
super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))
@@ -92,7 +92,7 @@ def __init__(
9292
self,
9393
pref_vector: Tensor | None = None,
9494
scale_mode: SUPPORTED_SCALE_MODE = "min",
95-
):
95+
) -> None:
9696
super().__init__()
9797
self._pref_vector = pref_vector
9898
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode

src/torchjd/aggregation/_cagrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class CAGrad(GramianWeightedAggregator):
3434
To install it, use ``pip install "torchjd[cagrad]"``.
3535
"""
3636

37-
def __init__(self, c: float, norm_eps: float = 0.0001):
37+
def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
3838
super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
3939
self._c = c
4040
self._norm_eps = norm_eps
@@ -67,7 +67,7 @@ class CAGradWeighting(Weighting[PSDMatrix]):
6767
function.
6868
"""
6969

70-
def __init__(self, c: float, norm_eps: float = 0.0001):
70+
def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
7171
super().__init__()
7272

7373
if c < 0.0:

src/torchjd/aggregation/_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class ConFIG(Aggregator):
5050
<https://github.com/tum-pbs/ConFIG/tree/main/conflictfree>`_.
5151
"""
5252

53-
def __init__(self, pref_vector: Tensor | None = None):
53+
def __init__(self, pref_vector: Tensor | None = None) -> None:
5454
super().__init__()
5555
self.weighting = pref_vector_to_weighting(pref_vector, default=SumWeighting())
5656
self._pref_vector = pref_vector

src/torchjd/aggregation/_constant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class Constant(WeightedAggregator):
1515
:param weights: The weights associated to the rows of the input matrices.
1616
"""
1717

18-
def __init__(self, weights: Tensor):
18+
def __init__(self, weights: Tensor) -> None:
1919
super().__init__(weighting=ConstantWeighting(weights=weights))
2020
self._weights = weights
2121

@@ -35,7 +35,7 @@ class ConstantWeighting(Weighting[Matrix]):
3535
:param weights: The weights to return at each call.
3636
"""
3737

38-
def __init__(self, weights: Tensor):
38+
def __init__(self, weights: Tensor) -> None:
3939
if weights.dim() != 1:
4040
raise ValueError(
4141
"Parameter `weights` should be a 1-dimensional tensor. Found `weights.shape = "

src/torchjd/aggregation/_dualproj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
norm_eps: float = 0.0001,
3434
reg_eps: float = 0.0001,
3535
solver: SUPPORTED_SOLVER = "quadprog",
36-
):
36+
) -> None:
3737
self._pref_vector = pref_vector
3838
self._norm_eps = norm_eps
3939
self._reg_eps = reg_eps
@@ -77,7 +77,7 @@ def __init__(
7777
norm_eps: float = 0.0001,
7878
reg_eps: float = 0.0001,
7979
solver: SUPPORTED_SOLVER = "quadprog",
80-
):
80+
) -> None:
8181
super().__init__()
8282
self._pref_vector = pref_vector
8383
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())

src/torchjd/aggregation/_flattening.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class Flattening(GeneralizedWeighting):
2020
:param weighting: The weighting to apply to the Gramian matrix.
2121
"""
2222

23-
def __init__(self, weighting: Weighting):
23+
def __init__(self, weighting: Weighting) -> None:
2424
super().__init__()
2525
self.weighting = weighting
2626

src/torchjd/aggregation/_graddrop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class GradDrop(Aggregator):
2626
through. Defaults to None, which means no leak.
2727
"""
2828

29-
def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
29+
def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None:
3030
if leak is not None and leak.dim() != 1:
3131
raise ValueError(
3232
"Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "

0 commit comments

Comments (0)