Skip to content

Commit 69f56f6

Browse files
authored
Simplify _get_leaf_tensors (#209)
* Change it to only work with `tensors` and `excluded` tensors that have a grad_fn
* Stop excluding the excluded tensors themselves from the results (only exclude their grad_fn from the graph traversal)
* Remove tests that expected the leaves to be excluded
* Add tests to verify that errors are raised when necessary
1 parent 3852fc2 commit 69f56f6

2 files changed

Lines changed: 46 additions & 48 deletions

File tree

src/torchjd/autojac/_utils.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,26 @@ def _get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) ->
3737
"""
3838
Gets the leaves of the autograd graph of all specified ``tensors``.
3939
40-
:param tensors: Tensors from which the graph traversal should start.
41-
:param excluded: Tensors that should be excluded from the results and whose grad_fn should be
42-
excluded from the graph traversal.
40+
:param tensors: Tensors from which the graph traversal should start. They should all require
41+
grad and not be leaves.
42+
:param excluded: Tensors whose grad_fn should be excluded from the graph traversal. They should
43+
all require grad and not be leaves.
44+
4345
"""
4446

47+
if any([tensor.grad_fn is None for tensor in tensors]):
48+
raise ValueError("All `tensors` should have a `grad_fn`.")
49+
50+
if any([tensor.grad_fn is None for tensor in excluded]):
51+
raise ValueError("All `excluded` tensors should have a `grad_fn`.")
52+
4553
accumulate_grads = _get_descendant_accumulate_grads(
46-
roots={tensor.grad_fn for tensor in tensors if tensor.grad_fn is not None},
54+
roots={tensor.grad_fn for tensor in tensors},
4755
excluded_nodes={tensor.grad_fn for tensor in excluded},
4856
)
4957
leaves = {g.variable for g in accumulate_grads}
5058

51-
return leaves - set(excluded)
59+
return leaves
5260

5361

5462
def _get_descendant_accumulate_grads(roots: set[Node], excluded_nodes: set[Node]) -> set[Node]:

tests/unit/autojac/test_utils.py

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from pytest import mark
2+
from pytest import mark, raises
33
from torch.nn import Linear, MSELoss, ReLU, Sequential
44
from unit.conftest import DEVICE
55

@@ -61,24 +61,6 @@ def test_get_leaf_tensors_excluded_2():
6161
assert leaves == {p1, p2}
6262

6363

64-
def test_get_leaf_tensors_excluded_3():
65-
"""
66-
Tests that _get_leaf_tensors works correctly when some tensors are excluded from the search.
67-
68-
In this case, one of the leaves itself is excluded.
69-
"""
70-
71-
p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
72-
p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
73-
p3 = torch.tensor([5.0, 6.0], requires_grad=True, device=DEVICE)
74-
75-
y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
76-
y2 = (p1**2).sum() + p2.norm() + p3.sum()
77-
78-
leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={p3})
79-
assert leaves == {p1, p2}
80-
81-
8264
def test_get_leaf_tensors_leaf_not_requiring_grad():
8365
"""
8466
Tests that _get_leaf_tensors does not include tensors that do not require grad in its results.
@@ -113,25 +95,6 @@ def test_get_leaf_tensors_model():
11395
assert leaves == set(model.parameters())
11496

11597

116-
def test_get_leaf_tensors_model_excluded_1():
117-
"""
118-
Tests that _get_leaf_tensors works correctly when the autograd graph is generated by a simple
119-
sequential model, and some of the model's parameters are excluded.
120-
"""
121-
122-
x = torch.randn(16, 10)
123-
y = torch.randn(16, 1)
124-
125-
model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
126-
loss_fn = MSELoss(reduction="none")
127-
128-
y_hat = model(x)
129-
losses = loss_fn(y_hat, y)
130-
131-
leaves = _get_leaf_tensors(tensors=[losses], excluded=set(model[0].parameters()))
132-
assert leaves == set(model[2].parameters())
133-
134-
13598
def test_get_leaf_tensors_model_excluded_2():
13699
"""
137100
Tests that _get_leaf_tensors works correctly when the autograd graph is generated by a simple
@@ -197,11 +160,38 @@ def test_get_leaf_tensors_deep(depth: int):
197160

198161

199162
def test_get_leaf_tensors_leaf():
163+
"""Tests that _get_leaf_tensors raises an error when some of the provided tensors are leaves."""
164+
165+
a = torch.tensor(1.0, requires_grad=True, device=DEVICE)
166+
with raises(ValueError):
167+
_ = _get_leaf_tensors(tensors=[a], excluded=set())
168+
169+
170+
def test_get_leaf_tensors_tensor_not_requiring_grad():
200171
"""
201-
Tests that _get_leaf_tensors correctly returns an empty set when the provided tensors are
202-
leaves.
172+
Tests that _get_leaf_tensors raises an error when some of the provided tensors do not require grad.
203173
"""
204174

205-
a = torch.tensor(1.0, requires_grad=True, device=DEVICE)
206-
leaves = _get_leaf_tensors(tensors=[a], excluded=set())
207-
assert leaves == set()
175+
a = torch.tensor(1.0, requires_grad=False, device=DEVICE) * 2
176+
with raises(ValueError):
177+
_ = _get_leaf_tensors(tensors=[a], excluded=set())
178+
179+
180+
def test_get_leaf_tensors_excluded_leaf():
181+
"""Tests that _get_leaf_tensors raises an error when some of the excluded tensors are leaves."""
182+
183+
a = torch.tensor(1.0, requires_grad=True, device=DEVICE) * 2
184+
b = torch.tensor(2.0, requires_grad=True, device=DEVICE)
185+
with raises(ValueError):
186+
_ = _get_leaf_tensors(tensors=[a], excluded={b})
187+
188+
189+
def test_get_leaf_tensors_excluded_not_requiring_grad():
190+
"""
191+
Tests that _get_leaf_tensors raises an error when some of the excluded tensors do not require grad.
192+
"""
193+
194+
a = torch.tensor(1.0, requires_grad=True, device=DEVICE) * 2
195+
b = torch.tensor(2.0, requires_grad=False, device=DEVICE) * 2
196+
with raises(ValueError):
197+
_ = _get_leaf_tensors(tensors=[a], excluded={b})

0 commit comments

Comments
 (0)