Skip to content

Commit 94a9d17

Browse files
test(aggregation): Add test coverage for Flattening (#628)
1 parent 049c2d0 commit 94a9d17

1 file changed

Lines changed: 35 additions & 0 deletions

File tree

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from pytest import mark
2+
from torch.testing import assert_close
3+
from utils.tensors import randn_
4+
5+
from torchjd._linalg import PSDMatrix, compute_gramian, flatten
6+
from torchjd.aggregation import Flattening, MeanWeighting, SumWeighting, UPGradWeighting, Weighting
7+
8+
9+
@mark.parametrize(
10+
"half_shape",
11+
[
12+
[1],
13+
[12],
14+
[4, 3],
15+
[2, 3, 2],
16+
],
17+
)
18+
@mark.parametrize(
19+
"weighting",
20+
[
21+
SumWeighting(),
22+
MeanWeighting(),
23+
UPGradWeighting(),
24+
],
25+
)
26+
def test_flattening(half_shape: list[int], weighting: Weighting[PSDMatrix]) -> None:
27+
matrix = randn_([*half_shape, 2])
28+
generalized_gramian = compute_gramian(matrix, 1)
29+
gramian = flatten(generalized_gramian)
30+
31+
flattening = Flattening(weighting)
32+
weights = flattening(generalized_gramian)
33+
34+
expected_weights = weighting(gramian).reshape(half_shape)
35+
assert_close(weights, expected_weights)

0 commit comments

Comments
 (0)