Skip to content

Commit 83ec2fb

Browse files
committed
Support device-type `device_map` values (e.g. "cuda", "cpu") so they work with offloading.
1 parent 54fa074 commit 83ec2fb

1 file changed

Lines changed: 19 additions & 5 deletions

File tree

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
for library in LOADABLE_CLASSES:
110110
LIBRARIES.append(library)
111111

112-
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
112+
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"]
113113

114114
logger = logging.get_logger(__name__)
115115

@@ -462,8 +462,7 @@ def module_is_offloaded(module):
462462
pipeline_is_sequentially_offloaded = any(
463463
module_is_sequentially_offloaded(module) for _, module in self.components.items()
464464
)
465-
466-
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
465+
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
467466
if is_pipeline_device_mapped:
468467
raise ValueError(
469468
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
@@ -1164,7 +1163,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
11641163
"""
11651164
self._maybe_raise_error_if_group_offload_active(raise_error=True)
11661165

1167-
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
1166+
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
11681167
if is_pipeline_device_mapped:
11691168
raise ValueError(
11701169
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
@@ -1286,7 +1285,7 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
12861285
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
12871286
self.remove_all_hooks()
12881287

1289-
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
1288+
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
12901289
if is_pipeline_device_mapped:
12911290
raise ValueError(
12921291
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
@@ -2171,6 +2170,21 @@ def _maybe_raise_error_if_group_offload_active(
21712170
return True
21722171
return False
21732172

2173+
def _is_pipeline_device_mapped(self):
2174+
# We support passing `device_map="cuda"`, for example. This is helpful, in case
2175+
# users want to pass `device_map="cpu"` when initializing a pipeline. This explicit declaration is desirable
2176+
# in limited VRAM environments because quantized models often initialize directly on the accelerator.
2177+
device_map = self.hf_device_map
2178+
is_device_type_map = False
2179+
if isinstance(device_map, str):
2180+
try:
2181+
torch.device(device_map)
2182+
is_device_type_map = True
2183+
except RuntimeError:
2184+
pass
2185+
2186+
return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1
2187+
21742188

21752189
class StableDiffusionMixin:
21762190
r"""

0 commit comments

Comments
 (0)