Skip to content

Commit 56f319f

Browse files
authored
Maintain aggregation unit tests (#214)
* Remove dead code in tests/unit/aggregation/_inputs.py * Remove useless warning filtering in test_cagrad * Shorten test_imtlg_zero * Update docstring of test_equivalence_upgrad_sum_two_rows
1 parent a81e64d commit 56f319f

4 files changed

Lines changed: 2 additions & 44 deletions

File tree

tests/unit/aggregation/_inputs.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -118,37 +118,6 @@ def generate_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
118118
return matrix
119119

120120

121-
def generate_positively_oriented_matrix(
122-
n_rows: int, n_cols: int, rank: int
123-
) -> tuple[Tensor, tuple[Tensor, Tensor, Tensor]]:
124-
"""
125-
Generates a random matrix of shape [n_rows, n_cols] with an SVD such that the largest singular
126-
value corresponds to a left singular vector that is all positive. Also returns the singular
127-
triple corresponding the largest singular value.
128-
If ``M, (u, s, v)`` is the output, we guarantee that:
129-
- ``s * u ~= M @ v``
130-
- ``u`` is a positive vector.
131-
- ``s`` is the largest singular value of ``M``.
132-
"""
133-
134-
_check_valid_rank(n_rows, n_cols, rank)
135-
if rank == 0:
136-
matrix = torch.zeros([n_rows, n_cols], device=DEVICE)
137-
largest_singular_value_triple = (
138-
torch.zeros([n_rows], device=DEVICE),
139-
torch.zeros([], device=DEVICE),
140-
torch.zeros([n_cols], device=DEVICE),
141-
)
142-
else:
143-
U = _generate_unitary_matrix_with_positive_column(n_rows, rank)
144-
V = _generate_unitary_matrix(n_cols, rank)
145-
S = _generate_diagonal_singular_values(rank)
146-
matrix = U @ S @ V.T
147-
largest_singular_value_triple = (U[:, 0], S[0, 0], V[:, 0])
148-
149-
return matrix, largest_singular_value_triple
150-
151-
152121
def generate_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
153122
"""
154123
Generates a random matrix of shape [``n_rows``, ``n_cols``] with provided ``rank``. The matrix
@@ -199,10 +168,6 @@ def generate_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
199168
scaled_matrices_2_plus_rows = [
200169
matrix for matrix in scaled_matrices + zero_rank_matrices if matrix.shape[0] >= 2
201170
]
202-
matrices_and_triples = [
203-
generate_positively_oriented_matrix(n_rows, n_cols, rank)
204-
for n_rows, n_cols, rank in _matrix_dimension_triples
205-
]
206171
stationary_matrices = [
207172
generate_stationary_matrix(n_rows, n_cols, rank)
208173
for n_rows, n_cols, rank in _matrix_dimension_triples

tests/unit/aggregation/test_cagrad.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,18 @@
88
from ._property_testers import ExpectedStructureProperty, NonConflictingProperty
99

1010

11-
@mark.filterwarnings("ignore:np.find_common_type is deprecated:")
1211
@mark.parametrize("aggregator", [CAGrad(c=0.5)])
1312
class TestCAGrad(ExpectedStructureProperty):
1413
pass
1514

1615

17-
@mark.filterwarnings("ignore:np.find_common_type is deprecated:")
1816
@mark.parametrize("aggregator", [CAGrad(c=1.0), CAGrad(c=2.0)])
1917
class TestCAGradNonConflicting(NonConflictingProperty):
2018
"""Tests that CAGrad is non-conflicting when c >= 1 (it should not hold when c < 1)"""
2119

2220
pass
2321

2422

25-
@mark.filterwarnings("ignore:np.find_common_type is deprecated:")
2623
@mark.parametrize("matrix", stationary_matrices + matrices)
2724
def test_equivalence_mean(matrix: Tensor):
2825
"""Tests that CAGrad is equivalent to Mean when c=0."""

tests/unit/aggregation/test_imtl_g.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,7 @@ def test_imtlg_zero():
2121

2222
A = IMTLG()
2323
J = torch.zeros(2, 3, device=DEVICE)
24-
25-
aggregation = A(J)
26-
expected = torch.zeros(3, device=DEVICE)
27-
28-
assert_close(aggregation, expected)
24+
assert_close(A(J), torch.zeros(3, device=DEVICE))
2925

3026

3127
def test_representations():

tests/unit/aggregation/test_pcgrad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class TestPCGrad(ExpectedStructureProperty):
3333
)
3434
def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]):
3535
"""
36-
Tests that UPGradWrapper of a SumWeighting is equivalent to PCGradWeighting for matrices of 2
36+
Tests that _UPGradWrapper of a _SumWeighting is equivalent to _PCGradWeighting for matrices of 2
3737
rows.
3838
"""
3939

0 commit comments

Comments
 (0)