Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ saved_models/
*.svg
*.mp4

# Test checkpoint binaries (generate with tests/generate_cuda_checkpoint.py)
tests/test_data/cuda_saved_checkpoint/
tests/test_data/*.pt

## gitignore for Python
## Source: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
# Byte-compiled / optimized / DLL files
Expand Down
141 changes: 128 additions & 13 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,79 @@ def check_version(estimator):
sklearn.__version__) < packaging.version.parse("1.6.dev")


def _safe_torch_load(filename, weights_only=False, **kwargs):
checkpoint = None
def _safe_torch_load(filename, weights_only=False, _is_retry=False, **kwargs):
"""Load a checkpoint with automatic CUDA/MPS to CPU fallback.

If loading fails due to a CUDA or MPS device error (e.g. checkpoint was
saved on a GPU but no GPU is available), this function automatically retries
with ``map_location="cpu"`` and issues a warning.

Args:
filename: Path to the checkpoint file.
weights_only: Passed through to :func:`torch.load`.
_is_retry: Internal flag to prevent infinite recursion. Do not set
this manually.
**kwargs: Additional keyword arguments forwarded to :func:`torch.load`.

Returns:
The loaded checkpoint dictionary.

Raises:
RuntimeError: If loading fails on the retry attempt or for non-device
related reasons.
"""
legacy_mode = packaging.version.parse(
torch.__version__) < packaging.version.parse("2.6.0")

if legacy_mode:
checkpoint = torch.load(filename, weights_only=False, **kwargs)
else:
with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
checkpoint = torch.load(filename,
weights_only=weights_only,
**kwargs)
try:
if legacy_mode:
checkpoint = torch.load(filename, weights_only=False, **kwargs)
else:
with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
checkpoint = torch.load(filename,
weights_only=weights_only,
**kwargs)
except RuntimeError as e:
error_msg = str(e)
is_device_error = ("CUDA" in error_msg
or "cuda" in error_msg.lower()
or "MPS" in error_msg
or "mps" in error_msg.lower())
if is_device_error:
if _is_retry:
raise RuntimeError(
f"Failed to load checkpoint even with map_location='cpu'. "
f"The checkpoint appears to require a device (CUDA/MPS) "
f"that is not available in the current environment. "
f"Please verify your PyTorch installation or load on a "
f"machine with the required hardware. "
f"Original error: {e}"
) from e
if "map_location" in kwargs:
raise RuntimeError(
f"Loading the checkpoint failed with a device error even "
f"though map_location={kwargs['map_location']!r} was "
f"explicitly specified. The checkpoint was likely saved on "
f"a CUDA/MPS device that is not available. Please check "
f"your PyTorch installation or use a machine with the "
f"required hardware. Original error: {e}"
) from e
warnings.warn(
f"Checkpoint was saved on a device that is not available "
f"(error: {error_msg}). Automatically falling back to CPU. "
f"To suppress this warning, pass map_location='cpu' "
f"explicitly.",
UserWarning,
stacklevel=2,
)
kwargs["map_location"] = torch.device("cpu")
return _safe_torch_load(
filename,
weights_only=weights_only,
_is_retry=True,
**kwargs,
)
raise

if not isinstance(checkpoint, dict):
_check_type_checkpoint(checkpoint)
Expand Down Expand Up @@ -334,6 +395,47 @@ def _check_type_checkpoint(checkpoint):
return checkpoint


def _resolve_checkpoint_device(device):
"""Resolve the device stored in a checkpoint for the current runtime.

If a checkpoint was saved on a device (CUDA, MPS, ...) that is unavailable
at load time, this falls back to CPU and issues a warning.

Args:
device: The device from the checkpoint (str or torch.device).

Returns:
str: The resolved device string ('cpu' or validated device).
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pls dont mention types in args/returns. type annotate instead

"""
if isinstance(device, torch.device):
device = str(device)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not robust. use torch.device type instead of string parsing
https://docs.pytorch.org/docs/stable/tensor_attributes.html#torch.device


if not isinstance(device, str):
raise TypeError(
"Expected checkpoint device to be a string or torch.device, "
f"got {type(device)}.")

fallback_to_cpu = False

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

if device.startswith("cuda") and not torch.cuda.is_available():
fallback_to_cpu = True
elif device.startswith("mps") and (
not hasattr(torch.backends, "mps")
or not torch.backends.mps.is_available()):
fallback_to_cpu = True

if fallback_to_cpu:
warnings.warn(
f"Checkpoint was saved on '{device}' which is not available in "
f"the current environment. Automatically falling back to CPU.",
UserWarning,
stacklevel=2,
)
return "cpu"

return sklearn_utils.check_device(device)


def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
"""Loads a CEBRA model with a Sklearn backend.

Expand All @@ -357,11 +459,24 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":

args, state, state_dict = cebra_info['args'], cebra_info[
'state'], cebra_info['state_dict']

# Resolve device: use CPU when checkpoint was saved on CUDA but CUDA is not available
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Resolve device: use CPU when checkpoint was saved on CUDA but CUDA is not available

remove comments that are obvious from context

saved_device = state["device_"]
load_device = _resolve_checkpoint_device(saved_device)

Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new CPU-fallback logic only changes subsequent .to(load_device) calls, but loading a truly CUDA-saved checkpoint can still fail earlier in torch.load when the checkpoint contains CUDA tensors and CUDA isn’t available. Consider adding a retry/automatic fallback in CEBRA.load / _safe_torch_load that catches the CUDA deserialization RuntimeError and re-loads with map_location='cpu' (when the caller didn’t already pass map_location).

Copilot uses AI. Check for mistakes.
cebra_ = cebra.CEBRA(**args)

for key, value in state.items():
setattr(cebra_, key, value)

# Update device attributes to the resolved device for the current runtime
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above

Suggested change
# Update device attributes to the resolved device for the current runtime

cebra_.device_ = load_device
saved_device_str = str(saved_device) if isinstance(saved_device,
torch.device) else saved_device
if isinstance(saved_device_str,
str) and saved_device_str.startswith("cuda") and load_device == "cpu":
cebra_.device = "cpu"
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above; lets use torch.device instead of string operations. e.g. instead of startswith you can check the .type of the device


#TODO(stes): unused right now
#state_and_args = {**args, **state}

Expand All @@ -375,7 +490,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
num_neurons=state["n_features_in_"],
num_units=args["num_hidden_units"],
num_output=args["output_dimension"],
).to(state['device_'])
).to(load_device)

elif isinstance(cebra_.num_sessions_, int):
model = nn.ModuleList([
Expand All @@ -385,10 +500,10 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
num_units=args["num_hidden_units"],
num_output=args["output_dimension"],
) for n_features in state["n_features_in_"]
]).to(state['device_'])
]).to(load_device)

criterion = cebra_._prepare_criterion()
criterion.to(state['device_'])
criterion.to(load_device)

optimizer = torch.optim.Adam(
itertools.chain(model.parameters(), criterion.parameters()),
Expand All @@ -404,7 +519,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
tqdm_on=args['verbose'],
)
solver.load_state_dict(state_dict)
solver.to(state['device_'])
solver.to(load_device)

cebra_.model_ = model
cebra_.solver_ = solver
Expand Down
119 changes: 119 additions & 0 deletions tests/generate_cuda_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#!/usr/bin/env python
"""Generate a CUDA-saved checkpoint for integration testing.

Run this script on a machine with a CUDA GPU to produce a checkpoint file
that can be used to verify the CUDA-to-CPU loading fallback in a CI
environment (which typically has no GPU).

Usage::

# Default output path
python tests/generate_cuda_checkpoint.py

# Custom output path
python tests/generate_cuda_checkpoint.py --output /tmp/cuda_checkpoint.pt

# Verify an existing checkpoint
python tests/generate_cuda_checkpoint.py --verify tests/test_data/cuda_checkpoint.pt

Requirements:
- PyTorch with CUDA support (``torch.cuda.is_available()`` must be True)
- CEBRA installed (``pip install -e .`` from the repo root)

The generated file is a standard ``torch.save`` checkpoint in the CEBRA
sklearn format. It contains CUDA tensors, so loading it on a CPU-only
machine *without* the fallback logic will fail with::

RuntimeError: Attempting to deserialize object on a CUDA device but
torch.cuda.is_available() is False.
"""

import argparse
import os
import sys

import numpy as np
import torch


def generate(output_path: str) -> None:
    """Train a minimal CEBRA model on CUDA and save the checkpoint.

    Args:
        output_path: Destination for the ``torch.save``-format checkpoint.
            Missing parent directories are created.

    Raises:
        SystemExit: With exit code 1 when CUDA is not available.
        RuntimeError: If a sanity check fails (model parameters not on CUDA,
            or the round-trip embedding has an unexpected shape).
    """
    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. Run this on a GPU machine.",
              file=sys.stderr)
        sys.exit(1)

    # Imported lazily so the CUDA availability check above runs (and reports
    # a clear error) even when CEBRA itself is not installed.
    import cebra

    print(f"PyTorch {torch.__version__}, CUDA {torch.version.cuda}")
    print(f"Device: {torch.cuda.get_device_name(0)}")

    # Train a tiny model on GPU
    X = np.random.uniform(0, 1, (200, 10)).astype(np.float32)
    model = cebra.CEBRA(
        model_architecture="offset1-model",
        max_iterations=10,
        batch_size=64,
        output_dimension=4,
        device="cuda",
        verbose=False,
    )
    model.fit(X)

    # Sanity-check: model params should live on CUDA. Raise explicitly
    # instead of using `assert` so the check survives `python -O`.
    param_device = next(model.solver_.model.parameters()).device
    if param_device.type != "cuda":
        raise RuntimeError(f"Expected cuda, got {param_device}")

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    model.save(output_path)
    print(f"Saved CUDA checkpoint to {output_path}")

    # Verify round-trip on GPU
    loaded = cebra.CEBRA.load(output_path)
    emb = loaded.transform(X)
    if emb.shape != (200, 4):
        raise RuntimeError(f"Unexpected shape: {emb.shape}")
    print("Round-trip verification on GPU: OK")


def verify(path: str) -> None:
    """Load a checkpoint on CPU and confirm the fallback works.

    Args:
        path: Path to a checkpoint produced by :func:`generate`.

    Raises:
        SystemExit: With exit code 1 when ``path`` does not exist.
    """
    # Fail fast on a missing file before paying for the (slow) cebra import.
    if not os.path.exists(path):
        print(f"ERROR: {path} does not exist.", file=sys.stderr)
        sys.exit(1)

    import cebra

    print(f"Loading checkpoint from {path} ...")
    model = cebra.CEBRA.load(path)
    print(f" device_: {model.device_}")
    print(f" device: {model.device}")

    # A successful transform proves the weights deserialized onto a usable
    # device (the CPU fallback, when no GPU is present).
    X = np.random.uniform(0, 1, (50, model.n_features_)).astype(np.float32)
    emb = model.transform(X)
    print(f" transform shape: {emb.shape}")
    print("Verification: OK")


def main() -> None:
    """Parse CLI arguments and dispatch to :func:`generate` or :func:`verify`."""
    arg_parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument(
        "--output",
        default="tests/test_data/cuda_checkpoint.pt",
        help="Output path for the generated checkpoint (default: tests/test_data/cuda_checkpoint.pt)",
    )
    arg_parser.add_argument(
        "--verify",
        metavar="PATH",
        help="Instead of generating, verify an existing checkpoint can be loaded.",
    )
    options = arg_parser.parse_args()

    # --verify switches the script into read-only verification mode.
    if options.verify:
        verify(options.verify)
        return
    generate(options.output)


# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
Loading