Commit 448cca1

fix: Use `is_tensor_like` instead of `isinstance(·, Tensor)` (#583)
* Fix in `_jacobian_computer`.
* Fix in `module_hook_manager`.
* Fix in `backward`.
* Fix in `jac`.
* Fix in `autojac._utils`.
* For the casts, opened an issue: pytorch/pytorch#175324
1 parent: aeef2f4

5 files changed: 20 additions & 9 deletions

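Why the change matters: `isinstance(t, Tensor)` only accepts `torch.Tensor` and its subclasses, while `torch.overrides.is_tensor_like(t)` also accepts any object whose type defines `__torch_function__`, i.e. everything PyTorch's override machinery treats as a tensor. A minimal sketch of the difference (the `TensorLike` class is illustrative, adapted from the pattern in the `torch.overrides` documentation):

import torch
from torch import Tensor
from torch.overrides import is_tensor_like

# Any type that defines __torch_function__ is tensor-like to PyTorch,
# even though it does not subclass Tensor.
class TensorLike:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented

t = TensorLike()
print(isinstance(t, Tensor))          # False: the isinstance check misses it
print(is_tensor_like(t))              # True
print(is_tensor_like(torch.ones(2)))  # True: plain tensors still pass
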

src/torchjd/autogram/_jacobian_computer.py

Lines changed: 4 additions & 3 deletions
@@ -5,6 +5,7 @@
 import torch
 from torch import Tensor, nn
 from torch.nn import Parameter
+from torch.overrides import is_tensor_like
 from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only

 from torchjd._linalg import Matrix
@@ -83,8 +84,8 @@ def _compute_jacobian(
         /,
     ) -> Matrix:
         grad_outputs_in_dims = (0,) * len(grad_outputs)
-        args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
-        kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs)
+        args_in_dims = tree_map(lambda t: 0 if is_tensor_like(t) else None, args)
+        kwargs_in_dims = tree_map(lambda t: 0 if is_tensor_like(t) else None, kwargs)
         in_dims = (grad_outputs_in_dims, args_in_dims, kwargs_in_dims)
         vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims)
@@ -114,7 +115,7 @@ def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...]:
             ]
             output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j)
             flat_outputs = tree_flatten(output)[0]
-            rg_outputs = tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad)
+            rg_outputs = tuple(t for t in flat_outputs if is_tensor_like(t) and t.requires_grad)
             return rg_outputs

         vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1]

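For context on the first hunk: `torch.vmap` expects `in_dims` to mirror the structure of the function's inputs, with `0` for leaves batched along the first dimension and `None` for leaves passed through unmapped. A self-contained sketch of the pattern, not torchjd code (the toy function and arguments are invented):

import torch
from torch.overrides import is_tensor_like
from torch.utils._pytree import tree_map

# Batch tensor-like leaves along dim 0; leave everything else unmapped.
args = (torch.ones(4, 3), 2.0)
in_dims = tree_map(lambda t: 0 if is_tensor_like(t) else None, args)  # (0, None)

scaled = torch.vmap(lambda x, s: x * s, in_dims=in_dims)(*args)
print(scaled.shape)  # torch.Size([4, 3])

Using `is_tensor_like` in the `tree_map` means tensor-like leaves that are not `Tensor` instances also get a batch dimension instead of silently being treated as constants.
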
src/torchjd/autogram/_module_hook_manager.py

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@
 import torch
 from torch import Tensor, nn
 from torch.autograd.graph import get_gradient_edge
+from torch.overrides import is_tensor_like
 from torch.utils._pytree import PyTree, tree_flatten, tree_unflatten
 from torch.utils.hooks import RemovableHandle as TorchRemovableHandle
@@ -114,7 +115,7 @@ def __call__(
         rg_outputs = list[Tensor]()
         rg_output_indices = list[int]()
         for idx, output in enumerate(flat_outputs):
-            if isinstance(output, Tensor) and output.requires_grad:
+            if is_tensor_like(output) and output.requires_grad:
                 rg_outputs.append(output)
                 rg_output_indices.append(idx)

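The hook's filtering logic as a standalone sketch (the example output pytree is invented): `tree_flatten` turns an arbitrarily nested module output into a flat list of leaves, and only the differentiable tensor-like leaves are kept, along with their indices so results can later be routed back to the right positions.

import torch
from torch import Tensor
from torch.overrides import is_tensor_like
from torch.utils._pytree import tree_flatten

# A nested output mixing a trainable tensor, a frozen tensor, and a string.
output = {"logits": torch.randn(2, 5, requires_grad=True),
          "aux": (torch.zeros(2), "metadata")}

flat_outputs = tree_flatten(output)[0]
rg_outputs = list[Tensor]()
rg_output_indices = list[int]()
for idx, out in enumerate(flat_outputs):
    if is_tensor_like(out) and out.requires_grad:
        rg_outputs.append(out)
        rg_output_indices.append(idx)

print(rg_output_indices)  # indices of the leaves that require grad
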
src/torchjd/autojac/_backward.py

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,8 @@
 from collections.abc import Iterable, Sequence
+from typing import cast

 from torch import Tensor
+from torch.overrides import is_tensor_like

 from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform
 from ._utils import (
@@ -140,7 +142,9 @@ def _create_jac_tensors_dict(
         # Transform that turns the gradients into Jacobians.
         diag = Diagonalize(tensors)
         return (diag << init)({})
-    jac_tensors = [opt_jac_tensors] if isinstance(opt_jac_tensors, Tensor) else opt_jac_tensors
+    jac_tensors = cast(
+        Sequence[Tensor], (opt_jac_tensors,) if is_tensor_like(opt_jac_tensors) else opt_jac_tensors
+    )
     check_matching_length(jac_tensors, tensors, "jac_tensors", "tensors")
     check_matching_jac_shapes(jac_tensors, tensors, "jac_tensors", "tensors")
     check_consistent_first_dimension(jac_tensors, "jac_tensors")

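About the new `cast` calls: `is_tensor_like` is annotated as returning a plain `bool`, so unlike `isinstance`, the check does not let a static type checker narrow `Tensor | Sequence[Tensor]`; the `cast` papers over that, and this limitation is presumably what the upstream issue mentioned in the commit message is about. The pattern in isolation (the `as_sequence` helper is hypothetical):

from collections.abc import Sequence
from typing import cast

from torch import Tensor
from torch.overrides import is_tensor_like

def as_sequence(tensors: Tensor | Sequence[Tensor]) -> Sequence[Tensor]:
    # At runtime, is_tensor_like(tensors) being True means tensors is a
    # single tensor, but the type checker cannot infer that, hence the cast.
    return cast(Sequence[Tensor], (tensors,) if is_tensor_like(tensors) else tensors)

Called with a single tensor, `as_sequence` returns a 1-tuple; called with a sequence, it returns the sequence unchanged.
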
src/torchjd/autojac/_jac.py

Lines changed: 6 additions & 2 deletions
@@ -1,6 +1,8 @@
 from collections.abc import Sequence
+from typing import cast

 from torch import Tensor
+from torch.overrides import is_tensor_like

 from torchjd.autojac._transform._base import Transform
 from torchjd.autojac._transform._diagonalize import Diagonalize
@@ -154,7 +156,7 @@ def jac(
         raise ValueError("`outputs` cannot be empty.")

     # Preserve repetitions to duplicate jacobians at the return statement
-    inputs_with_repetition = (inputs,) if isinstance(inputs, Tensor) else inputs
+    inputs_with_repetition = cast(Sequence[Tensor], (inputs,) if is_tensor_like(inputs) else inputs)
     inputs_ = OrderedSet(inputs_with_repetition)

     jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs)
@@ -180,7 +182,9 @@ def _create_jac_outputs_dict(
         # Transform that turns the gradients into Jacobians.
         diag = Diagonalize(outputs)
         return (diag << init)({})
-    jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs
+    jac_outputs = cast(
+        Sequence[Tensor], (opt_jac_outputs,) if is_tensor_like(opt_jac_outputs) else opt_jac_outputs
+    )
     check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs")
     check_matching_jac_shapes(jac_outputs, outputs, "jac_outputs", "outputs")
     check_consistent_first_dimension(jac_outputs, "jac_outputs")

src/torchjd/autojac/_utils.py

Lines changed: 3 additions & 2 deletions
@@ -4,6 +4,7 @@

 from torch import Tensor
 from torch.autograd.graph import Node
+from torch.overrides import is_tensor_like

 from ._transform import OrderedSet

@@ -20,8 +21,8 @@ def as_checked_ordered_set(
     tensors: Sequence[Tensor] | Tensor,
     variable_name: str,
 ) -> OrderedSet[Tensor]:
-    if isinstance(tensors, Tensor):
-        tensors = [tensors]
+    if is_tensor_like(tensors):
+        tensors = (cast(Tensor, tensors),)

     original_length = len(tensors)
     output = OrderedSet(tensors)

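If upstream ever adds a narrowing annotation, the casts could go away entirely. A hedged sketch of one such direction (the `is_single_tensor` wrapper is hypothetical; `TypeIs` lives in `typing` from Python 3.13 and in `typing_extensions` before that):

from typing_extensions import TypeIs  # typing.TypeIs on Python 3.13+

from torch import Tensor
from torch.overrides import is_tensor_like

def is_single_tensor(obj: object) -> TypeIs[Tensor]:
    # Deliberately tells the checker to treat tensor-likes as Tensor,
    # matching how the runtime code handles them after this commit.
    return is_tensor_like(obj)

def as_tuple(tensors: Tensor | tuple[Tensor, ...]) -> tuple[Tensor, ...]:
    if is_single_tensor(tensors):
        return (tensors,)  # narrowed to Tensor; no cast needed
    return tensors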