File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 11from torch import Tensor
22
3+ from .bases import _Weighting
34from .constant import _ConstantWeighting
4- from .mean import _MeanWeighting
55
66
77def _check_pref_vector (pref_vector : Tensor | None ) -> None :
@@ -15,12 +15,13 @@ def _check_pref_vector(pref_vector: Tensor | None) -> None:
1515 )
1616
1717
18- def _pref_vector_to_weighting (pref_vector : Tensor | None ) -> _ConstantWeighting | _MeanWeighting :
19- """Returns the weighting associated to a given preference vector."""
18+ def _pref_vector_to_weighting (pref_vector : Tensor | None , default : _Weighting ) -> _Weighting :
19+ """
20+ Returns the weighting associated to a given preference vector, with a fallback to a default
21+ weighting if the preference vector is None.
22+ """
2023
2124 if pref_vector is None :
22- weighting = _MeanWeighting ()
25+ return default
2326 else :
24- weighting = _ConstantWeighting (pref_vector )
25-
26- return weighting
27+ return _ConstantWeighting (pref_vector )
Original file line number Diff line number Diff line change 3232from ._pref_vector_utils import _check_pref_vector , _pref_vector_to_weighting
3333from ._str_utils import _vector_to_str
3434from .bases import _WeightedAggregator , _Weighting
35+ from .mean import _MeanWeighting
3536
3637
3738class AlignedMTL (_WeightedAggregator ):
@@ -63,7 +64,7 @@ class AlignedMTL(_WeightedAggregator):
6364
6465 def __init__ (self , pref_vector : Tensor | None = None ):
6566 _check_pref_vector (pref_vector )
66- weighting = _pref_vector_to_weighting (pref_vector )
67+ weighting = _pref_vector_to_weighting (pref_vector , default = _MeanWeighting () )
6768 self ._pref_vector = pref_vector
6869
6970 super ().__init__ (weighting = _AlignedMTLWrapper (weighting ))
Original file line number Diff line number Diff line change 99from ._pref_vector_utils import _check_pref_vector , _pref_vector_to_weighting
1010from ._str_utils import _vector_to_str
1111from .bases import _WeightedAggregator , _Weighting
12+ from .mean import _MeanWeighting
1213
1314
1415class DualProj (_WeightedAggregator ):
@@ -50,7 +51,7 @@ def __init__(
5051 solver : Literal ["quadprog" ] = "quadprog" ,
5152 ):
5253 _check_pref_vector (pref_vector )
53- weighting = _pref_vector_to_weighting (pref_vector )
54+ weighting = _pref_vector_to_weighting (pref_vector , default = _MeanWeighting () )
5455 self ._pref_vector = pref_vector
5556
5657 super ().__init__ (
Original file line number Diff line number Diff line change 99from ._pref_vector_utils import _check_pref_vector , _pref_vector_to_weighting
1010from ._str_utils import _vector_to_str
1111from .bases import _WeightedAggregator , _Weighting
12+ from .mean import _MeanWeighting
1213
1314
1415class UPGrad (_WeightedAggregator ):
@@ -49,7 +50,7 @@ def __init__(
4950 solver : Literal ["quadprog" ] = "quadprog" ,
5051 ):
5152 _check_pref_vector (pref_vector )
52- weighting = _pref_vector_to_weighting (pref_vector )
53+ weighting = _pref_vector_to_weighting (pref_vector , default = _MeanWeighting () )
5354 self ._pref_vector = pref_vector
5455
5556 super ().__init__ (
You can’t perform that action at this time.
0 commit comments