|
109 | 109 | for library in LOADABLE_CLASSES: |
110 | 110 | LIBRARIES.append(library) |
111 | 111 |
|
112 | | -SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()] |
| 112 | +SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"] |
113 | 113 |
|
114 | 114 | logger = logging.get_logger(__name__) |
115 | 115 |
|
@@ -462,8 +462,7 @@ def module_is_offloaded(module): |
462 | 462 | pipeline_is_sequentially_offloaded = any( |
463 | 463 | module_is_sequentially_offloaded(module) for _, module in self.components.items() |
464 | 464 | ) |
465 | | - |
466 | | - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 |
| 465 | + is_pipeline_device_mapped = self._is_pipeline_device_mapped() |
467 | 466 | if is_pipeline_device_mapped: |
468 | 467 | raise ValueError( |
469 | 468 | "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline." |
@@ -1164,7 +1163,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t |
1164 | 1163 | """ |
1165 | 1164 | self._maybe_raise_error_if_group_offload_active(raise_error=True) |
1166 | 1165 |
|
1167 | | - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 |
| 1166 | + is_pipeline_device_mapped = self._is_pipeline_device_mapped() |
1168 | 1167 | if is_pipeline_device_mapped: |
1169 | 1168 | raise ValueError( |
1170 | 1169 | "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`." |
@@ -1286,7 +1285,7 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un |
1286 | 1285 | raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") |
1287 | 1286 | self.remove_all_hooks() |
1288 | 1287 |
|
1289 | | - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 |
| 1288 | + is_pipeline_device_mapped = self._is_pipeline_device_mapped() |
1290 | 1289 | if is_pipeline_device_mapped: |
1291 | 1290 | raise ValueError( |
1292 | 1291 | "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`." |
@@ -2171,6 +2170,21 @@ def _maybe_raise_error_if_group_offload_active( |
2171 | 2170 | return True |
2172 | 2171 | return False |
2173 | 2172 |
|
def _is_pipeline_device_mapped(self) -> bool:
    """Return whether a multi-entry device map is active on the pipeline.

    Reads ``self.hf_device_map`` and reports ``True`` only when it is a ``dict``
    holding more than one entry. ``None``, an empty or single-entry dict, and
    any string value all yield ``False``.

    Returns:
        `bool`: ``True`` if ``self.hf_device_map`` is a dict with at least two entries.
    """
    device_map = self.hf_device_map
    # A plain device string such as `device_map="cuda"` or `device_map="cpu"` is supported when initializing a
    # pipeline. Declaring the device explicitly is desirable in limited VRAM environments because quantized models
    # often initialize directly on the accelerator. Such a string is not a device-mapping strategy: it is never a
    # dict, so the check below already returns False for it and no separate `torch.device(...)` probe is needed.
    return isinstance(device_map, dict) and len(device_map) > 1
| 2187 | + |
2174 | 2188 |
|
2175 | 2189 | class StableDiffusionMixin: |
2176 | 2190 | r""" |
|
0 commit comments