-
Notifications
You must be signed in to change notification settings - Fork 95
fix: Allow loading CUDA-saved models on CPU-only machines #296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
acd63c7
97d5b90
55f7589
c02a95f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -69,18 +69,79 @@ def check_version(estimator): | |||
| sklearn.__version__) < packaging.version.parse("1.6.dev") | ||||
|
|
||||
|
|
||||
| def _safe_torch_load(filename, weights_only=False, **kwargs): | ||||
| checkpoint = None | ||||
| def _safe_torch_load(filename, weights_only=False, _is_retry=False, **kwargs): | ||||
| """Load a checkpoint with automatic CUDA/MPS to CPU fallback. | ||||
|
|
||||
| If loading fails due to a CUDA or MPS device error (e.g. checkpoint was | ||||
| saved on a GPU but no GPU is available), this function automatically retries | ||||
| with ``map_location="cpu"`` and issues a warning. | ||||
|
|
||||
| Args: | ||||
| filename: Path to the checkpoint file. | ||||
| weights_only: Passed through to :func:`torch.load`. | ||||
| _is_retry: Internal flag to prevent infinite recursion. Do not set | ||||
| this manually. | ||||
| **kwargs: Additional keyword arguments forwarded to :func:`torch.load`. | ||||
|
|
||||
| Returns: | ||||
| The loaded checkpoint dictionary. | ||||
|
|
||||
| Raises: | ||||
| RuntimeError: If loading fails on the retry attempt or for non-device | ||||
| related reasons. | ||||
| """ | ||||
| legacy_mode = packaging.version.parse( | ||||
| torch.__version__) < packaging.version.parse("2.6.0") | ||||
|
|
||||
| if legacy_mode: | ||||
| checkpoint = torch.load(filename, weights_only=False, **kwargs) | ||||
| else: | ||||
| with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS): | ||||
| checkpoint = torch.load(filename, | ||||
| weights_only=weights_only, | ||||
| **kwargs) | ||||
| try: | ||||
| if legacy_mode: | ||||
| checkpoint = torch.load(filename, weights_only=False, **kwargs) | ||||
| else: | ||||
| with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS): | ||||
| checkpoint = torch.load(filename, | ||||
| weights_only=weights_only, | ||||
| **kwargs) | ||||
| except RuntimeError as e: | ||||
| error_msg = str(e) | ||||
| is_device_error = ("CUDA" in error_msg | ||||
| or "cuda" in error_msg.lower() | ||||
| or "MPS" in error_msg | ||||
| or "mps" in error_msg.lower()) | ||||
| if is_device_error: | ||||
| if _is_retry: | ||||
| raise RuntimeError( | ||||
| f"Failed to load checkpoint even with map_location='cpu'. " | ||||
| f"The checkpoint appears to require a device (CUDA/MPS) " | ||||
| f"that is not available in the current environment. " | ||||
| f"Please verify your PyTorch installation or load on a " | ||||
| f"machine with the required hardware. " | ||||
| f"Original error: {e}" | ||||
| ) from e | ||||
| if "map_location" in kwargs: | ||||
| raise RuntimeError( | ||||
| f"Loading the checkpoint failed with a device error even " | ||||
| f"though map_location={kwargs['map_location']!r} was " | ||||
| f"explicitly specified. The checkpoint was likely saved on " | ||||
| f"a CUDA/MPS device that is not available. Please check " | ||||
| f"your PyTorch installation or use a machine with the " | ||||
| f"required hardware. Original error: {e}" | ||||
| ) from e | ||||
| warnings.warn( | ||||
| f"Checkpoint was saved on a device that is not available " | ||||
| f"(error: {error_msg}). Automatically falling back to CPU. " | ||||
| f"To suppress this warning, pass map_location='cpu' " | ||||
| f"explicitly.", | ||||
| UserWarning, | ||||
| stacklevel=2, | ||||
| ) | ||||
| kwargs["map_location"] = torch.device("cpu") | ||||
| return _safe_torch_load( | ||||
| filename, | ||||
| weights_only=weights_only, | ||||
| _is_retry=True, | ||||
| **kwargs, | ||||
| ) | ||||
| raise | ||||
|
|
||||
| if not isinstance(checkpoint, dict): | ||||
| _check_type_checkpoint(checkpoint) | ||||
|
|
@@ -334,6 +395,47 @@ def _check_type_checkpoint(checkpoint): | |||
| return checkpoint | ||||
|
|
||||
|
|
||||
| def _resolve_checkpoint_device(device): | ||||
| """Resolve the device stored in a checkpoint for the current runtime. | ||||
|
|
||||
| If a checkpoint was saved on a device (CUDA, MPS, ...) that is unavailable | ||||
| at load time, this falls back to CPU and issues a warning. | ||||
|
|
||||
| Args: | ||||
| device: The device from the checkpoint (str or torch.device). | ||||
|
|
||||
| Returns: | ||||
| str: The resolved device string ('cpu' or validated device). | ||||
| """ | ||||
| if isinstance(device, torch.device): | ||||
| device = str(device) | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not robust. use torch.device type instead of string parsing |
||||
|
|
||||
| if not isinstance(device, str): | ||||
| raise TypeError( | ||||
| "Expected checkpoint device to be a string or torch.device, " | ||||
| f"got {type(device)}.") | ||||
|
|
||||
| fallback_to_cpu = False | ||||
|
|
||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
| if device.startswith("cuda") and not torch.cuda.is_available(): | ||||
| fallback_to_cpu = True | ||||
| elif device.startswith("mps") and ( | ||||
| not hasattr(torch.backends, "mps") | ||||
| or not torch.backends.mps.is_available()): | ||||
| fallback_to_cpu = True | ||||
|
|
||||
| if fallback_to_cpu: | ||||
| warnings.warn( | ||||
| f"Checkpoint was saved on '{device}' which is not available in " | ||||
| f"the current environment. Automatically falling back to CPU.", | ||||
| UserWarning, | ||||
| stacklevel=2, | ||||
| ) | ||||
| return "cpu" | ||||
|
|
||||
| return sklearn_utils.check_device(device) | ||||
|
|
||||
|
|
||||
| def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": | ||||
| """Loads a CEBRA model with a Sklearn backend. | ||||
|
|
||||
|
|
@@ -357,11 +459,24 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": | |||
|
|
||||
| args, state, state_dict = cebra_info['args'], cebra_info[ | ||||
| 'state'], cebra_info['state_dict'] | ||||
|
|
||||
| # Resolve device: use CPU when checkpoint was saved on CUDA but CUDA is not available | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
remove comments that are obvious from context |
||||
| saved_device = state["device_"] | ||||
| load_device = _resolve_checkpoint_device(saved_device) | ||||
|
|
||||
|
||||
| cebra_ = cebra.CEBRA(**args) | ||||
|
|
||||
| for key, value in state.items(): | ||||
| setattr(cebra_, key, value) | ||||
|
|
||||
| # Update device attributes to the resolved device for the current runtime | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see above
Suggested change
|
||||
| cebra_.device_ = load_device | ||||
| saved_device_str = str(saved_device) if isinstance(saved_device, | ||||
| torch.device) else saved_device | ||||
| if isinstance(saved_device_str, | ||||
| str) and saved_device_str.startswith("cuda") and load_device == "cpu": | ||||
| cebra_.device = "cpu" | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see above; let's use torch.device instead of string operations. e.g. instead of startswith you can check the `type` attribute of the `torch.device` object (`device.type == "cuda"`). |
||||
|
|
||||
| #TODO(stes): unused right now | ||||
| #state_and_args = {**args, **state} | ||||
|
|
||||
|
|
@@ -375,7 +490,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": | |||
| num_neurons=state["n_features_in_"], | ||||
| num_units=args["num_hidden_units"], | ||||
| num_output=args["output_dimension"], | ||||
| ).to(state['device_']) | ||||
| ).to(load_device) | ||||
|
|
||||
| elif isinstance(cebra_.num_sessions_, int): | ||||
| model = nn.ModuleList([ | ||||
|
|
@@ -385,10 +500,10 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": | |||
| num_units=args["num_hidden_units"], | ||||
| num_output=args["output_dimension"], | ||||
| ) for n_features in state["n_features_in_"] | ||||
| ]).to(state['device_']) | ||||
| ]).to(load_device) | ||||
|
|
||||
| criterion = cebra_._prepare_criterion() | ||||
| criterion.to(state['device_']) | ||||
| criterion.to(load_device) | ||||
|
|
||||
| optimizer = torch.optim.Adam( | ||||
| itertools.chain(model.parameters(), criterion.parameters()), | ||||
|
|
@@ -404,7 +519,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": | |||
| tqdm_on=args['verbose'], | ||||
| ) | ||||
| solver.load_state_dict(state_dict) | ||||
| solver.to(state['device_']) | ||||
| solver.to(load_device) | ||||
|
|
||||
| cebra_.model_ = model | ||||
| cebra_.solver_ = solver | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| #!/usr/bin/env python | ||
| """Generate a CUDA-saved checkpoint for integration testing. | ||
|
|
||
| Run this script on a machine with a CUDA GPU to produce a checkpoint file | ||
| that can be used to verify the CUDA-to-CPU loading fallback in a CI | ||
| environment (which typically has no GPU). | ||
|
|
||
| Usage:: | ||
|
|
||
| # Default output path | ||
| python tests/generate_cuda_checkpoint.py | ||
|
|
||
| # Custom output path | ||
| python tests/generate_cuda_checkpoint.py --output /tmp/cuda_checkpoint.pt | ||
|
|
||
| # Verify an existing checkpoint | ||
| python tests/generate_cuda_checkpoint.py --verify tests/test_data/cuda_checkpoint.pt | ||
|
|
||
| Requirements: | ||
| - PyTorch with CUDA support (``torch.cuda.is_available()`` must be True) | ||
| - CEBRA installed (``pip install -e .`` from the repo root) | ||
|
|
||
| The generated file is a standard ``torch.save`` checkpoint in the CEBRA | ||
| sklearn format. It contains CUDA tensors, so loading it on a CPU-only | ||
| machine *without* the fallback logic will fail with:: | ||
|
|
||
| RuntimeError: Attempting to deserialize object on a CUDA device but | ||
| torch.cuda.is_available() is False. | ||
| """ | ||
|
|
||
| import argparse | ||
| import os | ||
| import sys | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
|
|
||
| def generate(output_path: str) -> None: | ||
| """Train a minimal CEBRA model on CUDA and save the checkpoint.""" | ||
| if not torch.cuda.is_available(): | ||
| print("ERROR: CUDA is not available. Run this on a GPU machine.", | ||
| file=sys.stderr) | ||
| sys.exit(1) | ||
|
|
||
| import cebra | ||
|
|
||
| print(f"PyTorch {torch.__version__}, CUDA {torch.version.cuda}") | ||
| print(f"Device: {torch.cuda.get_device_name(0)}") | ||
|
|
||
| # Train a tiny model on GPU | ||
| X = np.random.uniform(0, 1, (200, 10)).astype(np.float32) | ||
| model = cebra.CEBRA( | ||
| model_architecture="offset1-model", | ||
| max_iterations=10, | ||
| batch_size=64, | ||
| output_dimension=4, | ||
| device="cuda", | ||
| verbose=False, | ||
| ) | ||
| model.fit(X) | ||
|
|
||
| # Sanity-check: model params should live on CUDA | ||
| param_device = next(model.solver_.model.parameters()).device | ||
| assert param_device.type == "cuda", f"Expected cuda, got {param_device}" | ||
|
|
||
| os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) | ||
| model.save(output_path) | ||
| print(f"Saved CUDA checkpoint to {output_path}") | ||
|
|
||
| # Verify round-trip on GPU | ||
| loaded = cebra.CEBRA.load(output_path) | ||
| emb = loaded.transform(X) | ||
| assert emb.shape == (200, 4), f"Unexpected shape: {emb.shape}" | ||
| print("Round-trip verification on GPU: OK") | ||
|
|
||
|
|
||
| def verify(path: str) -> None: | ||
| """Load a checkpoint on CPU and confirm the fallback works.""" | ||
| import cebra | ||
|
|
||
| if not os.path.exists(path): | ||
| print(f"ERROR: {path} does not exist.", file=sys.stderr) | ||
| sys.exit(1) | ||
|
|
||
| print(f"Loading checkpoint from {path} ...") | ||
| model = cebra.CEBRA.load(path) | ||
| print(f" device_: {model.device_}") | ||
| print(f" device: {model.device}") | ||
|
|
||
| X = np.random.uniform(0, 1, (50, model.n_features_)).astype(np.float32) | ||
| emb = model.transform(X) | ||
| print(f" transform shape: {emb.shape}") | ||
| print("Verification: OK") | ||
|
|
||
|
|
||
| def main() -> None: | ||
| parser = argparse.ArgumentParser(description=__doc__, | ||
| formatter_class=argparse.RawDescriptionHelpFormatter) | ||
| parser.add_argument( | ||
| "--output", | ||
| default="tests/test_data/cuda_checkpoint.pt", | ||
| help="Output path for the generated checkpoint (default: tests/test_data/cuda_checkpoint.pt)", | ||
| ) | ||
| parser.add_argument( | ||
| "--verify", | ||
| metavar="PATH", | ||
| help="Instead of generating, verify an existing checkpoint can be loaded.", | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| if args.verify: | ||
| verify(args.verify) | ||
| else: | ||
| generate(args.output) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pls dont mention types in args/returns. type annotate instead