Skip to content

Commit 9c29977

Browse files
authored
refactor(autojac): Add OrderedSet.__add__ (#313)
* Use it in mtl_backward rather than creating a new OrderedSet and updating it
1 parent 2b26100 commit 9c29977

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

src/torchjd/autojac/_transform/ordered_set.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,8 @@ def add(self, element: _KeyType) -> None:
2222
"""Adds the specified element to the OrderedSet."""
2323

2424
self[element] = None
25+
26+
def __add__(self, other: "OrderedSet[_KeyType]") -> "OrderedSet[_KeyType]":
27+
"""Creates a new OrderedSet with the elements of self followed by the elements of other."""
28+
29+
return OrderedSet([*self, *other])

src/torchjd/autojac/mtl_backward.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,7 @@ def _create_task_transform(
168168
retain_graph: bool,
169169
) -> Transform[EmptyTensorDict, Gradients]:
170170
# Tensors with respect to which we compute the gradients.
171-
to_differentiate = OrderedSet(task_params) # Re-instantiate set to avoid modifying input
172-
to_differentiate.update(features)
171+
to_differentiate = task_params + features
173172

174173
# Transform that initializes the gradient output to 1.
175174
init = Init(loss)

0 commit comments

Comments
 (0)