Skip to content

Commit 95bb00d

Browse files
authored
Add representation tests when pref_vector is not None (#233)
1 parent 2085690 commit 95bb00d

3 files changed

Lines changed: 30 additions & 0 deletions

File tree

tests/unit/aggregation/test_aligned_mtl.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch
12
from pytest import mark
23

34
from torchjd.aggregation import AlignedMTL
@@ -14,3 +15,7 @@ def test_representations():
1415
A = AlignedMTL(pref_vector=None)
1516
assert repr(A) == "AlignedMTL(pref_vector=None)"
1617
assert str(A) == "AlignedMTL"
18+
19+
A = AlignedMTL(pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"))
20+
assert repr(A) == "AlignedMTL(pref_vector=tensor([1., 2., 3.]))"
21+
assert str(A) == "AlignedMTL([1., 2., 3.])"

tests/unit/aggregation/test_dualproj.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch
12
from pytest import mark
23

34
from torchjd.aggregation import DualProj
@@ -22,3 +23,15 @@ def test_representations():
2223
repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')"
2324
)
2425
assert str(A) == "DualProj"
26+
27+
A = DualProj(
28+
pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"),
29+
norm_eps=0.0001,
30+
reg_eps=0.0001,
31+
solver="quadprog",
32+
)
33+
assert (
34+
repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, "
35+
"solver='quadprog')"
36+
)
37+
assert str(A) == "DualProj([1., 2., 3.])"

tests/unit/aggregation/test_upgrad.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,15 @@ def test_representations():
4747
A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog")
4848
assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')"
4949
assert str(A) == "UPGrad"
50+
51+
A = UPGrad(
52+
pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"),
53+
norm_eps=0.0001,
54+
reg_eps=0.0001,
55+
solver="quadprog",
56+
)
57+
assert (
58+
repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, "
59+
"solver='quadprog')"
60+
)
61+
assert str(A) == "UPGrad([1., 2., 3.])"

0 commit comments

Comments
 (0)