Skip to content

Commit add549c

Browse files
committed
Add _can_skip_jacobian_combination helper function
Tbh I don't like it very much (because it's an extra function + some cast is required) but it's the only way to easily test that the correct aggregators use the optimized _gramian_based method. I also tried using return type hint of TypeGuard[GramianWeightedAggergator] instead of bool for _can_skip_jacobian_combination, but it's not really correct since we also check that the aggregator has no forward hook, so that TypeGuard would be really weird. So in the end we have to use this cast.
1 parent 4cf5cbb commit add549c

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,17 @@ def jac_to_grad(
7575
if not retain_jac:
7676
_free_jacs(tensors_)
7777

78-
if isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator):
79-
# When it's possible, avoid the concatenation of the jacobians that can be very costly in
80-
# memory.
81-
gradients = _gramian_based(aggregator, jacobians, tensors_)
78+
if _can_skip_jacobian_combination(aggregator):
79+
gradients = _gramian_based(cast(GramianWeightedAggregator, aggregator), jacobians, tensors_)
8280
else:
8381
gradients = _jacobian_based(aggregator, jacobians, tensors_)
8482
accumulate_grads(tensors_, gradients)
8583

8684

85+
def _can_skip_jacobian_combination(aggregator: Aggregator) -> bool:
86+
return isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator)
87+
88+
8789
def _has_forward_hook(module: nn.Module) -> bool:
8890
"""Return whether the module has any forward hook registered."""
8991
return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0

0 commit comments

Comments
 (0)