|
5 | 5 | import torch |
6 | 6 | from torch import Tensor, nn |
7 | 7 | from torch.nn import Parameter |
| 8 | +from torch.overrides import is_tensor_like |
8 | 9 | from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only |
9 | 10 |
|
10 | 11 | from torchjd._linalg import Matrix |
@@ -83,8 +84,8 @@ def _compute_jacobian( |
83 | 84 | /, |
84 | 85 | ) -> Matrix: |
85 | 86 | grad_outputs_in_dims = (0,) * len(grad_outputs) |
86 | | - args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) |
87 | | - kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs) |
| 87 | + args_in_dims = tree_map(lambda t: 0 if is_tensor_like(t) else None, args) |
| 88 | + kwargs_in_dims = tree_map(lambda t: 0 if is_tensor_like(t) else None, kwargs) |
88 | 89 | in_dims = (grad_outputs_in_dims, args_in_dims, kwargs_in_dims) |
89 | 90 | vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims) |
90 | 91 |
|
@@ -114,7 +115,7 @@ def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...] |
114 | 115 | ] |
115 | 116 | output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j) |
116 | 117 | flat_outputs = tree_flatten(output)[0] |
117 | | - rg_outputs = tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad) |
| 118 | + rg_outputs = tuple(t for t in flat_outputs if is_tensor_like(t) and t.requires_grad) |
118 | 119 | return rg_outputs |
119 | 120 |
|
120 | 121 | vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1] |
|
0 commit comments