Skip to content

Commit 948fc63

Browse files
committed
fix bugs in model restore.
1 parent 19f1077 commit 948fc63

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

codebase/torchutils/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,8 @@ def save(self, minitor: str, metrics: dict, states: dict):
331331
def restore(self, metrics: dict, states: dict, device="cuda:0"):
332332
checkpoint_path = self.output_directory / "checkpoint.pt"
333333
if checkpoint_path.exists():
334-
checkpoint: dict = torch.load(checkpoint_path, map_location=device)
334+
map_location= f"cuda:{device}" if isinstance(device, int) else device
335+
checkpoint: dict = torch.load(checkpoint_path, map_location=map_location)
335336
metrics.update(checkpoint.pop("metrics", dict()))
336337
for name, module in states.items():
337338
module.load_state_dict(checkpoint[name])

0 commit comments

Comments
 (0)