Skip to content

Commit 419ccf6

Browse files
authored
Make _pref_vector_to_weighting take a default weighting as param (#229)
1 parent 7602801 commit 419ccf6

4 files changed

Lines changed: 14 additions & 10 deletions

File tree

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from torch import Tensor
22

3+
from .bases import _Weighting
34
from .constant import _ConstantWeighting
4-
from .mean import _MeanWeighting
55

66

77
def _check_pref_vector(pref_vector: Tensor | None) -> None:
@@ -15,12 +15,13 @@ def _check_pref_vector(pref_vector: Tensor | None) -> None:
1515
)
1616

1717

18-
def _pref_vector_to_weighting(pref_vector: Tensor | None) -> _ConstantWeighting | _MeanWeighting:
19-
"""Returns the weighting associated to a given preference vector."""
18+
def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -> _Weighting:
19+
"""
20+
Returns the weighting associated to a given preference vector, with a fallback to a default
21+
weighting if the preference vector is None.
22+
"""
2023

2124
if pref_vector is None:
22-
weighting = _MeanWeighting()
25+
return default
2326
else:
24-
weighting = _ConstantWeighting(pref_vector)
25-
26-
return weighting
27+
return _ConstantWeighting(pref_vector)

src/torchjd/aggregation/aligned_mtl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ._pref_vector_utils import _check_pref_vector, _pref_vector_to_weighting
3333
from ._str_utils import _vector_to_str
3434
from .bases import _WeightedAggregator, _Weighting
35+
from .mean import _MeanWeighting
3536

3637

3738
class AlignedMTL(_WeightedAggregator):
@@ -63,7 +64,7 @@ class AlignedMTL(_WeightedAggregator):
6364

6465
def __init__(self, pref_vector: Tensor | None = None):
6566
_check_pref_vector(pref_vector)
66-
weighting = _pref_vector_to_weighting(pref_vector)
67+
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
6768
self._pref_vector = pref_vector
6869

6970
super().__init__(weighting=_AlignedMTLWrapper(weighting))

src/torchjd/aggregation/dualproj.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ._pref_vector_utils import _check_pref_vector, _pref_vector_to_weighting
1010
from ._str_utils import _vector_to_str
1111
from .bases import _WeightedAggregator, _Weighting
12+
from .mean import _MeanWeighting
1213

1314

1415
class DualProj(_WeightedAggregator):
@@ -50,7 +51,7 @@ def __init__(
5051
solver: Literal["quadprog"] = "quadprog",
5152
):
5253
_check_pref_vector(pref_vector)
53-
weighting = _pref_vector_to_weighting(pref_vector)
54+
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
5455
self._pref_vector = pref_vector
5556

5657
super().__init__(

src/torchjd/aggregation/upgrad.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ._pref_vector_utils import _check_pref_vector, _pref_vector_to_weighting
1010
from ._str_utils import _vector_to_str
1111
from .bases import _WeightedAggregator, _Weighting
12+
from .mean import _MeanWeighting
1213

1314

1415
class UPGrad(_WeightedAggregator):
@@ -49,7 +50,7 @@ def __init__(
4950
solver: Literal["quadprog"] = "quadprog",
5051
):
5152
_check_pref_vector(pref_vector)
52-
weighting = _pref_vector_to_weighting(pref_vector)
53+
weighting = _pref_vector_to_weighting(pref_vector, default=_MeanWeighting())
5354
self._pref_vector = pref_vector
5455

5556
super().__init__(

0 commit comments

Comments
 (0)