|
1 | 1 | import torch |
2 | | -from pytest import mark |
| 2 | +from pytest import mark, raises |
3 | 3 | from torch.nn import Linear, MSELoss, ReLU, Sequential |
4 | 4 | from unit.conftest import DEVICE |
5 | 5 |
|
@@ -61,24 +61,6 @@ def test_get_leaf_tensors_excluded_2(): |
61 | 61 | assert leaves == {p1, p2} |
62 | 62 |
|
63 | 63 |
|
64 | | -def test_get_leaf_tensors_excluded_3(): |
65 | | - """ |
66 | | - Tests that _get_leaf_tensors works correctly when some tensors are excluded from the search. |
67 | | -
|
68 | | - In this case, one of the leaves itself is excluded. |
69 | | - """ |
70 | | - |
71 | | - p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE) |
72 | | - p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE) |
73 | | - p3 = torch.tensor([5.0, 6.0], requires_grad=True, device=DEVICE) |
74 | | - |
75 | | - y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum() |
76 | | - y2 = (p1**2).sum() + p2.norm() + p3.sum() |
77 | | - |
78 | | - leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={p3}) |
79 | | - assert leaves == {p1, p2} |
80 | | - |
81 | | - |
82 | 64 | def test_get_leaf_tensors_leaf_not_requiring_grad(): |
83 | 65 | """ |
84 | 66 | Tests that _get_leaf_tensors does not include tensors that do not require grad in its results. |
@@ -113,25 +95,6 @@ def test_get_leaf_tensors_model(): |
113 | 95 | assert leaves == set(model.parameters()) |
114 | 96 |
|
115 | 97 |
|
116 | | -def test_get_leaf_tensors_model_excluded_1(): |
117 | | - """ |
118 | | - Tests that _get_leaf_tensors works correctly when the autograd graph is generated by a simple |
119 | | - sequential model, and some of the model's parameters are excluded. |
120 | | - """ |
121 | | - |
122 | | - x = torch.randn(16, 10) |
123 | | - y = torch.randn(16, 1) |
124 | | - |
125 | | - model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1)) |
126 | | - loss_fn = MSELoss(reduction="none") |
127 | | - |
128 | | - y_hat = model(x) |
129 | | - losses = loss_fn(y_hat, y) |
130 | | - |
131 | | - leaves = _get_leaf_tensors(tensors=[losses], excluded=set(model[0].parameters())) |
132 | | - assert leaves == set(model[2].parameters()) |
133 | | - |
134 | | - |
135 | 98 | def test_get_leaf_tensors_model_excluded_2(): |
136 | 99 | """ |
137 | 100 | Tests that _get_leaf_tensors works correctly when the autograd graph is generated by a simple |
@@ -197,11 +160,38 @@ def test_get_leaf_tensors_deep(depth: int): |
197 | 160 |
|
198 | 161 |
|
199 | 162 | def test_get_leaf_tensors_leaf(): |
| 163 | + """Tests that _get_leaf_tensors raises an error when some of the provided tensors are leaves.""" |
| 164 | + |
| 165 | + a = torch.tensor(1.0, requires_grad=True, device=DEVICE) |
| 166 | + with raises(ValueError): |
| 167 | + _ = _get_leaf_tensors(tensors=[a], excluded=set()) |
| 168 | + |
| 169 | + |
| 170 | +def test_get_leaf_tensors_tensor_not_requiring_grad(): |
200 | 171 | """ |
201 | | - Tests that _get_leaf_tensors correctly returns an empty set when the provided tensors are |
202 | | - leaves. |
| 172 | + Tests that _get_leaf_tensors raises an error when some of the provided tensors do not require grad. |
203 | 173 | """ |
204 | 174 |
|
205 | | - a = torch.tensor(1.0, requires_grad=True, device=DEVICE) |
206 | | - leaves = _get_leaf_tensors(tensors=[a], excluded=set()) |
207 | | - assert leaves == set() |
| 175 | + a = torch.tensor(1.0, requires_grad=False, device=DEVICE) * 2 |
| 176 | + with raises(ValueError): |
| 177 | + _ = _get_leaf_tensors(tensors=[a], excluded=set()) |
| 178 | + |
| 179 | + |
| 180 | +def test_get_leaf_tensors_excluded_leaf(): |
| 181 | + """Tests that _get_leaf_tensors raises an error when some of the excluded tensors are leaves.""" |
| 182 | + |
| 183 | + a = torch.tensor(1.0, requires_grad=True, device=DEVICE) * 2 |
| 184 | + b = torch.tensor(2.0, requires_grad=True, device=DEVICE) |
| 185 | + with raises(ValueError): |
| 186 | + _ = _get_leaf_tensors(tensors=[a], excluded={b}) |
| 187 | + |
| 188 | + |
| 189 | +def test_get_leaf_tensors_excluded_not_requiring_grad(): |
| 190 | + """ |
| 191 | + Tests that _get_leaf_tensors raises an error when some of the excluded tensors do not require grad. |
| 192 | + """ |
| 193 | + |
| 194 | + a = torch.tensor(1.0, requires_grad=True, device=DEVICE) * 2 |
| 195 | + b = torch.tensor(2.0, requires_grad=False, device=DEVICE) * 2 |
| 196 | + with raises(ValueError): |
| 197 | + _ = _get_leaf_tensors(tensors=[a], excluded={b}) |
0 commit comments