Skip to content

Commit 926ad3f

Browse files
authored
Add _pref_vector_to_str_suffix (#230)
1 parent 419ccf6 commit 926ad3f

4 files changed

Lines changed: 28 additions & 21 deletions

File tree

src/torchjd/aggregation/_pref_vector_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from torch import Tensor
22

3+
from ._str_utils import _vector_to_str
34
from .bases import _Weighting
45
from .constant import _ConstantWeighting
56

@@ -25,3 +26,12 @@ def _pref_vector_to_weighting(pref_vector: Tensor | None, default: _Weighting) -
2526
return default
2627
else:
2728
return _ConstantWeighting(pref_vector)
29+
30+
31+
def _pref_vector_to_str_suffix(pref_vector: Tensor | None) -> str:
32+
"""Returns a suffix string containing the representation of the optional preference vector."""
33+
34+
if pref_vector is None:
35+
return ""
36+
else:
37+
return f"([{_vector_to_str(pref_vector)}])"

src/torchjd/aggregation/aligned_mtl.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@
2929
from torch import Tensor
3030
from torch.linalg import LinAlgError
3131

32-
from ._pref_vector_utils import _check_pref_vector, _pref_vector_to_weighting
33-
from ._str_utils import _vector_to_str
32+
from ._pref_vector_utils import (
33+
_check_pref_vector,
34+
_pref_vector_to_str_suffix,
35+
_pref_vector_to_weighting,
36+
)
3437
from .bases import _WeightedAggregator, _Weighting
3538
from .mean import _MeanWeighting
3639

@@ -73,11 +76,7 @@ def __repr__(self) -> str:
7376
return f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)})"
7477

7578
def __str__(self) -> str:
76-
if self._pref_vector is None:
77-
suffix = ""
78-
else:
79-
suffix = f"([{_vector_to_str(self._pref_vector)}])"
80-
return f"AlignedMTL{suffix}"
79+
return f"AlignedMTL{_pref_vector_to_str_suffix(self._pref_vector)}"
8180

8281

8382
class _AlignedMTLWrapper(_Weighting):

src/torchjd/aggregation/dualproj.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
from torch import Tensor
77

88
from ._gramian_utils import _compute_normalized_gramian
9-
from ._pref_vector_utils import _check_pref_vector, _pref_vector_to_weighting
10-
from ._str_utils import _vector_to_str
9+
from ._pref_vector_utils import (
10+
_check_pref_vector,
11+
_pref_vector_to_str_suffix,
12+
_pref_vector_to_weighting,
13+
)
1114
from .bases import _WeightedAggregator, _Weighting
1215
from .mean import _MeanWeighting
1316

@@ -68,11 +71,7 @@ def __repr__(self) -> str:
6871
)
6972

7073
def __str__(self) -> str:
71-
if self._pref_vector is None:
72-
suffix = ""
73-
else:
74-
suffix = f"([{_vector_to_str(self._pref_vector)}])"
75-
return f"DualProj{suffix}"
74+
return f"DualProj{_pref_vector_to_str_suffix(self._pref_vector)}"
7675

7776

7877
class _DualProjWrapper(_Weighting):

src/torchjd/aggregation/upgrad.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
from torch import Tensor
77

88
from ._gramian_utils import _compute_normalized_gramian
9-
from ._pref_vector_utils import _check_pref_vector, _pref_vector_to_weighting
10-
from ._str_utils import _vector_to_str
9+
from ._pref_vector_utils import (
10+
_check_pref_vector,
11+
_pref_vector_to_str_suffix,
12+
_pref_vector_to_weighting,
13+
)
1114
from .bases import _WeightedAggregator, _Weighting
1215
from .mean import _MeanWeighting
1316

@@ -67,11 +70,7 @@ def __repr__(self) -> str:
6770
)
6871

6972
def __str__(self) -> str:
70-
if self._pref_vector is None:
71-
suffix = ""
72-
else:
73-
suffix = f"([{_vector_to_str(self._pref_vector)}])"
74-
return f"UPGrad{suffix}"
73+
return f"UPGrad{_pref_vector_to_str_suffix(self._pref_vector)}"
7574

7675

7776
class _UPGradWrapper(_Weighting):

0 commit comments

Comments
 (0)