Skip to content

Commit 2085690

Browse files
authored
Use torch.set_default_device in conftest.py (#232)
* Force device to be cpu for some representation tests * Use set_default_device in conftest.py instead of always specifying the device
1 parent 177e0f3 commit 2085690

24 files changed

Lines changed: 351 additions & 386 deletions

tests/doc/test_backward.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
"""
55

66
from torch.testing import assert_close
7-
from unit.conftest import DEVICE
87

98

109
def test_backward():
@@ -13,11 +12,11 @@ def test_backward():
1312
from torchjd import backward
1413
from torchjd.aggregation import UPGrad
1514

16-
param = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
15+
param = torch.tensor([1.0, 2.0], requires_grad=True)
1716
# Compute arbitrary quantities that are function of param
18-
y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ param
17+
y1 = torch.tensor([-1.0, 1.0]) @ param
1918
y2 = (param**2).sum()
2019

2120
backward([y1, y2], UPGrad())
2221

23-
assert_close(param.grad, torch.tensor([0.5000, 2.5000], device=DEVICE), rtol=0.0, atol=1e-04)
22+
assert_close(param.grad, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04)

tests/unit/aggregation/_inputs.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
22
from torch import Tensor
3-
from unit.conftest import DEVICE
43

54

65
def _check_valid_dimensions(n_rows: int, n_cols: int) -> None:
@@ -37,9 +36,9 @@ def _augment_orthogonal_matrix(orthogonal_matrix: Tensor) -> Tensor:
3736

3837
n_rows = orthogonal_matrix.shape[0]
3938
projection = orthogonal_matrix @ orthogonal_matrix.T
40-
zero = torch.zeros([n_rows], device=DEVICE)
39+
zero = torch.zeros([n_rows])
4140
while True:
42-
random_vector = torch.randn([n_rows], device=DEVICE)
41+
random_vector = torch.randn([n_rows])
4342
projected_vector = random_vector - projection @ random_vector
4443
if not torch.allclose(projected_vector, zero):
4544
break
@@ -70,7 +69,7 @@ def _generate_unitary_matrix(n_rows: int, n_cols: int) -> Tensor:
7069
"""Generates a unitary matrix of shape [n_rows, n_cols]."""
7170

7271
_check_valid_dimensions(n_rows, n_cols)
73-
partial_matrix = torch.randn([n_rows, 1], device=DEVICE)
72+
partial_matrix = torch.randn([n_rows, 1])
7473
partial_matrix = torch.nn.functional.normalize(partial_matrix, dim=0)
7574

7675
unitary_matrix = _complete_orthogonal_matrix(partial_matrix, n_cols)
@@ -83,7 +82,7 @@ def _generate_unitary_matrix_with_positive_column(n_rows: int, n_cols: int) -> T
8382
positive vector.
8483
"""
8584
_check_valid_dimensions(n_rows, n_cols)
86-
partial_matrix = torch.abs(torch.randn([n_rows, 1], device=DEVICE))
85+
partial_matrix = torch.abs(torch.randn([n_rows, 1]))
8786
partial_matrix = torch.nn.functional.normalize(partial_matrix, dim=0)
8887

8988
unitary_matrix_with_positive_column = _complete_orthogonal_matrix(partial_matrix, n_cols)
@@ -94,7 +93,7 @@ def _generate_diagonal_singular_values(rank: int) -> Tensor:
9493
"""
9594
generates a diagonal matrix of positive values sorted in descending order.
9695
"""
97-
singular_values = torch.abs(torch.randn([rank], device=DEVICE))
96+
singular_values = torch.abs(torch.randn([rank]))
9897
singular_values = torch.sort(singular_values, descending=True)[0]
9998
S = torch.diag(singular_values)
10099
return S
@@ -108,7 +107,7 @@ def generate_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
108107
_check_valid_rank(n_rows, n_cols, rank)
109108

110109
if rank == 0:
111-
matrix = torch.zeros([n_rows, n_cols], device=DEVICE)
110+
matrix = torch.zeros([n_rows, n_cols])
112111
else:
113112
U = _generate_unitary_matrix(n_rows, rank)
114113
V = _generate_unitary_matrix(n_cols, rank)
@@ -126,7 +125,7 @@ def generate_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
126125

127126
_check_valid_rank(n_rows, n_cols, rank)
128127
if rank == 0:
129-
matrix = torch.zeros([n_rows, n_cols], device=DEVICE)
128+
matrix = torch.zeros([n_rows, n_cols])
130129
else:
131130
U = _generate_unitary_matrix_with_positive_column(n_rows, rank)
132131
V = _generate_unitary_matrix(n_cols, rank)
@@ -161,9 +160,7 @@ def generate_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
161160
generate_matrix(n_rows, n_cols, rank) for n_rows, n_cols, rank in _matrix_dimension_triples
162161
]
163162
scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices]
164-
zero_rank_matrices = [
165-
torch.zeros([n_rows, n_cols], device=DEVICE) for n_rows, n_cols in _zero_rank_matrix_shapes
166-
]
163+
zero_rank_matrices = [torch.zeros([n_rows, n_cols]) for n_rows, n_cols in _zero_rank_matrix_shapes]
167164
matrices_2_plus_rows = [matrix for matrix in matrices + zero_rank_matrices if matrix.shape[0] >= 2]
168165
scaled_matrices_2_plus_rows = [
169166
matrix for matrix in scaled_matrices + zero_rank_matrices if matrix.shape[0] >= 2

tests/unit/aggregation/test_base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import torch
55
from pytest import mark, raises
66
from unit._utils import ExceptionContext
7-
from unit.conftest import DEVICE
87

98
from torchjd.aggregation import Aggregator
109

@@ -21,4 +20,4 @@
2120
)
2221
def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext):
2322
with expectation:
24-
Aggregator._check_is_matrix(torch.randn(shape, device=DEVICE))
23+
Aggregator._check_is_matrix(torch.randn(shape))

tests/unit/aggregation/test_constant.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
from pytest import mark
33
from torch import Tensor
4-
from unit.conftest import DEVICE
54

65
from torchjd.aggregation import Constant
76

@@ -16,7 +15,7 @@
1615

1716
def _make_aggregator(matrix: Tensor) -> Constant:
1817
n_rows = matrix.shape[0]
19-
weights = torch.tensor([1.0 / n_rows] * n_rows, device=DEVICE)
18+
weights = torch.tensor([1.0 / n_rows] * n_rows)
2019
return Constant(weights)
2120

2221

@@ -38,6 +37,6 @@ def test_expected_structure_property(cls, aggregator: Constant, matrix: Tensor):
3837

3938

4039
def test_representations():
41-
A = Constant(weights=torch.tensor([1.0, 2.0]))
40+
A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu"))
4241
assert repr(A) == "Constant(weights=tensor([1., 2.]))"
4342
assert str(A) == "Constant([1., 2.])"

tests/unit/aggregation/test_graddrop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class TestGradDrop(ExpectedStructureProperty):
1212

1313

1414
def test_representations():
15-
A = GradDrop(leak=torch.tensor([0.0, 1.0]))
15+
A = GradDrop(leak=torch.tensor([0.0, 1.0], device="cpu"))
1616
assert repr(A) == "GradDrop(leak=tensor([0., 1.]))"
1717
assert str(A) == "GradDrop([0., 1.])"
1818

tests/unit/aggregation/test_imtl_g.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
from pytest import mark
33
from torch.testing import assert_close
4-
from unit.conftest import DEVICE
54

65
from torchjd.aggregation import IMTLG
76

@@ -20,8 +19,8 @@ def test_imtlg_zero():
2019
"""
2120

2221
A = IMTLG()
23-
J = torch.zeros(2, 3, device=DEVICE)
24-
assert_close(A(J), torch.zeros(3, device=DEVICE))
22+
J = torch.zeros(2, 3)
23+
assert_close(A(J), torch.zeros(3))
2524

2625

2726
def test_representations():

tests/unit/aggregation/test_mgda.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
from pytest import mark
33
from torch.testing import assert_close
4-
from unit.conftest import DEVICE
54

65
from torchjd.aggregation import MGDA
76
from torchjd.aggregation.mgda import _MGDAWeighting
@@ -29,7 +28,7 @@ class TestMGDA(ExpectedStructureProperty, NonConflictingProperty, PermutationInv
2928
],
3029
)
3130
def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]):
32-
matrix = torch.randn(shape, device=DEVICE)
31+
matrix = torch.randn(shape)
3332
weighting = _MGDAWeighting(epsilon=1e-05, max_iters=1000)
3433

3534
gramian = matrix @ matrix.T
@@ -45,7 +44,7 @@ def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]):
4544
assert_close(positive_weights.norm(), weights.norm())
4645

4746
weights_sum = weights.sum()
48-
assert_close(weights_sum, torch.ones([], device=DEVICE))
47+
assert_close(weights_sum, torch.ones([]))
4948

5049
# Dual feasibility
5150
positive_mu = mu[mu >= 0]

tests/unit/aggregation/test_pcgrad.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
from pytest import mark
33
from torch.testing import assert_close
4-
from unit.conftest import DEVICE
54

65
from torchjd.aggregation import PCGrad
76
from torchjd.aggregation.pcgrad import _PCGradWeighting
@@ -37,7 +36,7 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]):
3736
rows.
3837
"""
3938

40-
matrix = torch.randn(shape, device=DEVICE)
39+
matrix = torch.randn(shape)
4140

4241
pc_grad_weighting = _PCGradWeighting()
4342
upgrad_sum_weighting = _UPGradWrapper(

tests/unit/aggregation/test_upgrad.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
from pytest import mark
33
from torch.testing import assert_close
4-
from unit.conftest import DEVICE
54

65
from torchjd.aggregation import UPGrad
76
from torchjd.aggregation.mean import _MeanWeighting
@@ -21,8 +20,8 @@ class TestUPGrad(ExpectedStructureProperty, NonConflictingProperty, PermutationI
2120

2221
@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
2322
def test_upgrad_lagrangian_satisfies_kkt_conditions(shape: tuple[int, int]):
24-
matrix = torch.randn(shape, device=DEVICE)
25-
weights = torch.rand(shape[0], device=DEVICE)
23+
matrix = torch.randn(shape)
24+
weights = torch.rand(shape[0])
2625

2726
gramian = matrix @ matrix.T
2827

@@ -41,7 +40,7 @@ def test_upgrad_lagrangian_satisfies_kkt_conditions(shape: tuple[int, int]):
4140
assert_close(positive_constraint.norm(), constraint.norm(), atol=1e-04, rtol=0)
4241

4342
slackness = torch.trace(lagrange_multiplier @ constraint)
44-
assert_close(slackness, torch.zeros_like(slackness, device=DEVICE), atol=3e-03, rtol=0)
43+
assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)
4544

4645

4746
def test_representations():

tests/unit/autojac/_transform/test_accumulate.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
22
from pytest import mark, raises
3-
from unit.conftest import DEVICE
43

54
from torchjd.autojac._transform import Accumulate, Gradients
65

@@ -13,12 +12,12 @@ def test_single_accumulation():
1312
once.
1413
"""
1514

16-
key1 = torch.zeros([], requires_grad=True, device=DEVICE)
17-
key2 = torch.zeros([1], requires_grad=True, device=DEVICE)
18-
key3 = torch.zeros([2, 3], requires_grad=True, device=DEVICE)
19-
value1 = torch.ones([], device=DEVICE)
20-
value2 = torch.ones([1], device=DEVICE)
21-
value3 = torch.ones([2, 3], device=DEVICE)
15+
key1 = torch.zeros([], requires_grad=True)
16+
key2 = torch.zeros([1], requires_grad=True)
17+
key3 = torch.zeros([2, 3], requires_grad=True)
18+
value1 = torch.ones([])
19+
value2 = torch.ones([1])
20+
value3 = torch.ones([2, 3])
2221
input = Gradients({key1: value1, key2: value2, key3: value3})
2322

2423
accumulate = Accumulate([key1, key2, key3])
@@ -41,12 +40,12 @@ def test_multiple_accumulation(iterations: int):
4140
`iterations` times.
4241
"""
4342

44-
key1 = torch.zeros([], requires_grad=True, device=DEVICE)
45-
key2 = torch.zeros([1], requires_grad=True, device=DEVICE)
46-
key3 = torch.zeros([2, 3], requires_grad=True, device=DEVICE)
47-
value1 = torch.ones([], device=DEVICE)
48-
value2 = torch.ones([1], device=DEVICE)
49-
value3 = torch.ones([2, 3], device=DEVICE)
43+
key1 = torch.zeros([], requires_grad=True)
44+
key2 = torch.zeros([1], requires_grad=True)
45+
key3 = torch.zeros([2, 3], requires_grad=True)
46+
value1 = torch.ones([])
47+
value2 = torch.ones([1])
48+
value3 = torch.ones([2, 3])
5049
input = Gradients({key1: value1, key2: value2, key3: value3})
5150

5251
accumulate = Accumulate([key1, key2, key3])
@@ -70,8 +69,8 @@ def test_no_requires_grad_fails():
7069
tensor that does not require grad.
7170
"""
7271

73-
key = torch.zeros([1], requires_grad=False, device=DEVICE)
74-
value = torch.ones([1], device=DEVICE)
72+
key = torch.zeros([1], requires_grad=False)
73+
value = torch.ones([1])
7574
input = Gradients({key: value})
7675

7776
accumulate = Accumulate([key])
@@ -86,8 +85,8 @@ def test_no_leaf_and_no_retains_grad_fails():
8685
tensor that is not a leaf and that does not retain grad.
8786
"""
8887

89-
key = torch.tensor([1.0], requires_grad=True, device=DEVICE) * 2
90-
value = torch.ones([1], device=DEVICE)
88+
key = torch.tensor([1.0], requires_grad=True) * 2
89+
value = torch.ones([1])
9190
input = Gradients({key: value})
9291

9392
accumulate = Accumulate([key])

0 commit comments

Comments (0)