Skip to content

Commit 277c082

Browse files
committed
fix
1 parent 2a10f27 commit 277c082

1 file changed

Lines changed: 7 additions & 1 deletion

File tree

modules/optimizer/chained_optimizer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Optional

import torch
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.optimizer import ParamsT
@@ -87,9 +88,14 @@ def _copy_lr_to_optimizers(self) -> None:
8788
self.optimizers[optimizer_idx].param_groups[param_group_idx]["lr"] = param_group["lr"]
8889

8990
def step(self, closure=None) -> Optional[Tensor]:
    """Perform a single optimization step on every chained optimizer.

    If *closure* is given it is evaluated exactly once, under
    ``torch.enable_grad()`` so that any backward pass inside it works even
    when ``step`` is called from a no-grad context, and its loss is
    returned.  The closure is deliberately NOT forwarded to the inner
    optimizers (``closure=None``): forwarding it would re-run the
    forward/backward pass once per wrapped optimizer.

    Args:
        closure: Optional callable that re-evaluates the model and
            returns the loss.

    Returns:
        The loss returned by ``closure``, or ``None`` if no closure
        was supplied.
    """
    loss = None
    if closure is not None:
        # Closures typically run forward + backward; gradients must be
        # enabled regardless of the caller's grad mode.
        with torch.enable_grad():
            loss = closure()
    # Push any externally modified learning rates (e.g. set by a scheduler
    # on this wrapper's param_groups) down to the wrapped optimizers.
    self._copy_lr_to_optimizers()
    for opt in self.optimizers:
        # Each inner optimizer steps on the gradients produced above.
        opt.step(closure=None)
    return loss
9399

94100
def add_param_group(self, param_group: Dict[str, Any]) -> None:
95101
super().add_param_group(param_group)

0 commit comments

Comments
 (0)