We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2a10f27 commit 277c082 — Copy full SHA for 277c082
1 file changed
modules/optimizer/chained_optimizer.py
@@ -1,3 +1,4 @@
1
+import torch
2
from torch import Tensor
3
from torch.optim import Optimizer
4
from torch.optim.optimizer import ParamsT
@@ -87,9 +88,14 @@ def _copy_lr_to_optimizers(self) -> None:
87
88
self.optimizers[optimizer_idx].param_groups[param_group_idx]["lr"] = param_group["lr"]
89
90
def step(self, closure=None) -> "Tensor | None":
    """Perform one optimization step on every chained optimizer.

    The closure, if given, is evaluated exactly once here (with gradients
    enabled) instead of being forwarded to each wrapped optimizer — passing
    it through would recompute the loss once per optimizer.

    Args:
        closure: Optional callable that re-evaluates the model and returns
            the loss (torch convention: typically a ``Tensor``).

    Returns:
        The value returned by ``closure``, or ``None`` when no closure
        was supplied.
    """
    loss = None
    if closure is not None:
        # step() may be called inside torch.no_grad(); re-enable grad so
        # the closure's backward pass can run.
        with torch.enable_grad():
            loss = closure()
    # Propagate any scheduler-driven lr changes on this wrapper down to
    # the wrapped optimizers before stepping them.
    self._copy_lr_to_optimizers()
    for opt in self.optimizers:
        # closure=None: gradients were already computed above; forwarding
        # the closure would evaluate it once per wrapped optimizer.
        opt.step(closure=None)
    return loss
99
100
def add_param_group(self, param_group: Dict[str, Any]) -> None:
101
super().add_param_group(param_group)
0 commit comments