Skip to content

Commit 55f7589

Browse files
committed
refactor: Address Steffen's review comments on PR #296
- Refactored _safe_torch_load() to use recursion instead of duplicate logic
- Added meaningful error messages when CPU fallback fails
- Added UserWarning when auto-remapping CUDA/MPS to CPU
- Extended _resolve_checkpoint_device() to handle MPS fallback
- Added test for MPS checkpoint fallback
- Added test for meaningful error on retry failure
- Added test for error with explicit map_location
- Created tests/generate_cuda_checkpoint.py utility for GPU test data
- Removed binary checkpoint files from repo
- Updated .gitignore to exclude test checkpoint binaries

All 53 tests pass (14 CUDA/MPS tests + 39 regression tests)
1 parent 97d5b90 commit 55f7589

51 files changed

Lines changed: 330 additions & 96 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ saved_models/
3535
*.svg
3636
*.mp4
3737

38+
# Test checkpoint binaries (generate with tests/generate_cuda_checkpoint.py)
39+
tests/test_data/cuda_saved_checkpoint/
40+
tests/test_data/*.pt
41+
3842
## gitignore for Python
3943
## Source: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
4044
# Byte-compiled / optimized / DLL files

cebra/integrations/sklearn/cebra.py

Lines changed: 78 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,27 @@ def check_version(estimator):
6969
sklearn.__version__) < packaging.version.parse("1.6.dev")
7070

7171

72-
def _safe_torch_load(filename, weights_only=False, **kwargs):
73-
checkpoint = None
72+
def _safe_torch_load(filename, weights_only=False, _is_retry=False, **kwargs):
73+
"""Load a checkpoint with automatic CUDA/MPS to CPU fallback.
74+
75+
If loading fails due to a CUDA or MPS device error (e.g. checkpoint was
76+
saved on a GPU but no GPU is available), this function automatically retries
77+
with ``map_location="cpu"`` and issues a warning.
78+
79+
Args:
80+
filename: Path to the checkpoint file.
81+
weights_only: Passed through to :func:`torch.load`.
82+
_is_retry: Internal flag to prevent infinite recursion. Do not set
83+
this manually.
84+
**kwargs: Additional keyword arguments forwarded to :func:`torch.load`.
85+
86+
Returns:
87+
The loaded checkpoint dictionary.
88+
89+
Raises:
90+
RuntimeError: If loading fails on the retry attempt or for non-device
91+
related reasons.
92+
"""
7493
legacy_mode = packaging.version.parse(
7594
torch.__version__) < packaging.version.parse("2.6.0")
7695

@@ -83,23 +102,46 @@ def _safe_torch_load(filename, weights_only=False, **kwargs):
83102
weights_only=weights_only,
84103
**kwargs)
85104
except RuntimeError as e:
86-
# Handle CUDA deserialization errors by retrying with map_location='cpu'
87-
if "CUDA" in str(e) or "cuda" in str(e).lower():
88-
if "map_location" not in kwargs:
89-
kwargs["map_location"] = torch.device("cpu")
90-
if legacy_mode:
91-
checkpoint = torch.load(filename, weights_only=False,
92-
**kwargs)
93-
else:
94-
with torch.serialization.safe_globals(
95-
CEBRA_LOAD_SAFE_GLOBALS):
96-
checkpoint = torch.load(filename,
97-
weights_only=weights_only,
98-
**kwargs)
99-
else:
100-
raise
101-
else:
102-
raise
105+
error_msg = str(e)
106+
is_device_error = ("CUDA" in error_msg
107+
or "cuda" in error_msg.lower()
108+
or "MPS" in error_msg
109+
or "mps" in error_msg.lower())
110+
if is_device_error:
111+
if _is_retry:
112+
raise RuntimeError(
113+
f"Failed to load checkpoint even with map_location='cpu'. "
114+
f"The checkpoint appears to require a device (CUDA/MPS) "
115+
f"that is not available in the current environment. "
116+
f"Please verify your PyTorch installation or load on a "
117+
f"machine with the required hardware. "
118+
f"Original error: {e}"
119+
) from e
120+
if "map_location" in kwargs:
121+
raise RuntimeError(
122+
f"Loading the checkpoint failed with a device error even "
123+
f"though map_location={kwargs['map_location']!r} was "
124+
f"explicitly specified. The checkpoint was likely saved on "
125+
f"a CUDA/MPS device that is not available. Please check "
126+
f"your PyTorch installation or use a machine with the "
127+
f"required hardware. Original error: {e}"
128+
) from e
129+
warnings.warn(
130+
f"Checkpoint was saved on a device that is not available "
131+
f"(error: {error_msg}). Automatically falling back to CPU. "
132+
f"To suppress this warning, pass map_location='cpu' "
133+
f"explicitly.",
134+
UserWarning,
135+
stacklevel=2,
136+
)
137+
kwargs["map_location"] = torch.device("cpu")
138+
return _safe_torch_load(
139+
filename,
140+
weights_only=weights_only,
141+
_is_retry=True,
142+
**kwargs,
143+
)
144+
raise
103145

104146
if not isinstance(checkpoint, dict):
105147
_check_type_checkpoint(checkpoint)
@@ -356,8 +398,8 @@ def _check_type_checkpoint(checkpoint):
356398
def _resolve_checkpoint_device(device):
357399
"""Resolve the device stored in a checkpoint for the current runtime.
358400
359-
If a checkpoint was saved on CUDA and CUDA is unavailable at load time, this
360-
falls back to CPU.
401+
If a checkpoint was saved on a device (CUDA, MPS, ...) that is unavailable
402+
at load time, this falls back to CPU and issues a warning.
361403
362404
Args:
363405
device: The device from the checkpoint (str or torch.device).
@@ -373,7 +415,22 @@ def _resolve_checkpoint_device(device):
373415
"Expected checkpoint device to be a string or torch.device, "
374416
f"got {type(device)}.")
375417

418+
fallback_to_cpu = False
419+
376420
if device.startswith("cuda") and not torch.cuda.is_available():
421+
fallback_to_cpu = True
422+
elif device.startswith("mps") and (
423+
not hasattr(torch.backends, "mps")
424+
or not torch.backends.mps.is_available()):
425+
fallback_to_cpu = True
426+
427+
if fallback_to_cpu:
428+
warnings.warn(
429+
f"Checkpoint was saved on '{device}' which is not available in "
430+
f"the current environment. Automatically falling back to CPU.",
431+
UserWarning,
432+
stacklevel=2,
433+
)
377434
return "cpu"
378435

379436
return sklearn_utils.check_device(device)

tests/generate_cuda_checkpoint.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#!/usr/bin/env python
2+
"""Generate a CUDA-saved checkpoint for integration testing.
3+
4+
Run this script on a machine with a CUDA GPU to produce a checkpoint file
5+
that can be used to verify the CUDA-to-CPU loading fallback in a CI
6+
environment (which typically has no GPU).
7+
8+
Usage::
9+
10+
# Default output path
11+
python tests/generate_cuda_checkpoint.py
12+
13+
# Custom output path
14+
python tests/generate_cuda_checkpoint.py --output /tmp/cuda_checkpoint.pt
15+
16+
# Verify an existing checkpoint
17+
python tests/generate_cuda_checkpoint.py --verify tests/test_data/cuda_checkpoint.pt
18+
19+
Requirements:
20+
- PyTorch with CUDA support (``torch.cuda.is_available()`` must be True)
21+
- CEBRA installed (``pip install -e .`` from the repo root)
22+
23+
The generated file is a standard ``torch.save`` checkpoint in the CEBRA
24+
sklearn format. It contains CUDA tensors, so loading it on a CPU-only
25+
machine *without* the fallback logic will fail with::
26+
27+
RuntimeError: Attempting to deserialize object on a CUDA device but
28+
torch.cuda.is_available() is False.
29+
"""
30+
31+
import argparse
32+
import os
33+
import sys
34+
35+
import numpy as np
36+
import torch
37+
38+
39+
def generate(output_path: str) -> None:
    """Train a minimal CEBRA model on CUDA and save the checkpoint.

    Args:
        output_path: Destination path for the saved checkpoint file.
            Parent directories are created if needed.

    Raises:
        SystemExit: If CUDA is not available (exit code 1).
        RuntimeError: If the trained model does not live on CUDA, or if
            the saved checkpoint fails the on-GPU round-trip check.
    """
    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. Run this on a GPU machine.",
              file=sys.stderr)
        sys.exit(1)

    # Imported lazily so the CUDA availability check above still runs in
    # environments where CEBRA is not installed.
    import cebra

    print(f"PyTorch {torch.__version__}, CUDA {torch.version.cuda}")
    print(f"Device: {torch.cuda.get_device_name(0)}")

    # Train a tiny model on GPU -- small data and few iterations, since we
    # only need a checkpoint that contains CUDA tensors.
    X = np.random.uniform(0, 1, (200, 10)).astype(np.float32)
    model = cebra.CEBRA(
        model_architecture="offset1-model",
        max_iterations=10,
        batch_size=64,
        output_dimension=4,
        device="cuda",
        verbose=False,
    )
    model.fit(X)

    # Sanity-check: model params should live on CUDA. Use an explicit raise
    # instead of `assert` so the check survives `python -O`.
    param_device = next(model.solver_.model.parameters()).device
    if param_device.type != "cuda":
        raise RuntimeError(f"Expected cuda, got {param_device}")

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    model.save(output_path)
    print(f"Saved CUDA checkpoint to {output_path}")

    # Verify round-trip on GPU. Again an explicit raise instead of `assert`
    # so the verification is not stripped under `python -O`.
    loaded = cebra.CEBRA.load(output_path)
    emb = loaded.transform(X)
    if emb.shape != (200, 4):
        raise RuntimeError(f"Unexpected shape: {emb.shape}")
    print("Round-trip verification on GPU: OK")
76+
77+
78+
def verify(path: str) -> None:
    """Load a checkpoint on CPU and confirm the fallback works.

    Args:
        path: Path to an existing checkpoint file.

    Raises:
        SystemExit: If ``path`` does not exist (exit code 1).
    """
    # Check the path before importing CEBRA so a missing file fails fast
    # with the intended message, even in environments where CEBRA (a heavy
    # import) is not installed.
    if not os.path.exists(path):
        print(f"ERROR: {path} does not exist.", file=sys.stderr)
        sys.exit(1)

    import cebra

    print(f"Loading checkpoint from {path} ...")
    model = cebra.CEBRA.load(path)
    print(f" device_: {model.device_}")
    print(f" device: {model.device}")

    # Run a forward pass on random data to confirm the loaded model is usable.
    X = np.random.uniform(0, 1, (50, model.n_features_)).astype(np.float32)
    emb = model.transform(X)
    print(f" transform shape: {emb.shape}")
    print("Verification: OK")
95+
96+
97+
def main() -> None:
    """Command-line entry point.

    With ``--verify PATH``, load an existing checkpoint and check it;
    otherwise generate a new CUDA checkpoint at ``--output``.
    """
    arg_parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument(
        "--output",
        default="tests/test_data/cuda_checkpoint.pt",
        help="Output path for the generated checkpoint (default: tests/test_data/cuda_checkpoint.pt)",
    )
    arg_parser.add_argument(
        "--verify",
        metavar="PATH",
        help="Instead of generating, verify an existing checkpoint can be loaded.",
    )
    cli_args = arg_parser.parse_args()

    # --verify takes precedence over generation.
    if cli_args.verify:
        verify(cli_args.verify)
    else:
        generate(cli_args.output)
116+
117+
118+
if __name__ == "__main__":
119+
main()

tests/test_data/cuda_saved_checkpoint/.data/serialization_id

Lines changed: 0 additions & 1 deletion
This file was deleted.

tests/test_data/cuda_saved_checkpoint/.format_version

Lines changed: 0 additions & 1 deletion
This file was deleted.

tests/test_data/cuda_saved_checkpoint/.storage_alignment

Lines changed: 0 additions & 1 deletion
This file was deleted.

tests/test_data/cuda_saved_checkpoint/byteorder

Lines changed: 0 additions & 1 deletion
This file was deleted.
-182 KB
Binary file not shown.
-23.3 KB
Binary file not shown.

tests/test_data/cuda_saved_checkpoint/data/1

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)