-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path_gramian_computer.py
More file actions
46 lines (35 loc) · 1.44 KB
/
_gramian_computer.py
File metadata and controls
46 lines (35 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils._pytree import PyTree
from torchjd.autogram._jacobian_computer import JacobianComputer
class GramianComputer(ABC):
    """
    Abstract interface for objects that compute the Gramian associated with a module call.

    Implementations are used as callables: they receive what was captured for one call of a
    module and return a Tensor.
    """

    @abstractmethod
    def __call__(
        self,
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Tensor:
        """
        Compute the Gramian for a module call and return it.

        :param rg_outputs: Outputs of the module call (presumably the ones that require grad —
            TODO confirm against the caller).
        :param grad_outputs: Tensors associated with ``rg_outputs``; the visible implementation
            forwards them unchanged to its Jacobian computer.
        :param args: Presumably the positional arguments of the module call.
        :param kwargs: Presumably the keyword arguments of the module call.
        :returns: The Gramian, as a Tensor.

        NOTE(review): if some implementation needs to defer (return nothing until the Gramian
        is complete), the return annotation must be widened to ``Tensor | None`` and callers
        updated accordingly.
        """
class JacobianBasedGramianComputer(GramianComputer, ABC):
    """
    Base class for GramianComputers that obtain a Jacobian from a JacobianComputer and turn it
    into a Gramian.

    ``__call__`` is inherited as abstract and must still be implemented by subclasses.
    """

    def __init__(self, jacobian_computer: JacobianComputer):
        # Public attribute: subclasses invoke it to obtain the Jacobian matrix.
        self.jacobian_computer = jacobian_computer

    @staticmethod
    def _to_gramian(jacobian: Tensor) -> Tensor:
        """
        Return ``jacobian`` multiplied by its transpose (``J J^T``).

        For a 2D ``jacobian`` of shape ``[m, n]`` the result has shape ``[m, m]``.
        """
        jacobian_t = jacobian.T
        return jacobian.matmul(jacobian_t)
class JacobianBasedGramianComputerWithoutCrossTerms(JacobianBasedGramianComputer):
    """
    JacobianBasedGramianComputer that directly returns the Gramian, without considering
    cross-terms (except intra-module cross-terms).

    ``__call__`` does not modify ``self``: every call computes a Jacobian and returns its Gramian
    immediately. Whether ``jacobian_computer`` itself holds state is up to that object.
    """

    def __call__(
        self,
        rg_outputs: tuple[Tensor, ...],
        grad_outputs: tuple[Tensor, ...],
        args: tuple[PyTree, ...],
        kwargs: dict[str, PyTree],
    ) -> Tensor:
        """
        Compute the Jacobian matrix for a module call and return its Gramian.

        All four arguments are forwarded unchanged to ``self.jacobian_computer``. The resulting
        matrix ``J`` is converted with ``_to_gramian`` into ``J @ J.T``.

        :returns: The Gramian; it is always computed and returned on every call.
        """
        jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)
        return self._to_gramian(jacobian_matrix)