@@ -9,6 +9,7 @@
 
 import numpy as np
 import PIL.Image
+import pytest
 import torch
 import torch.nn as nn
 from huggingface_hub import ModelCard, delete_repo
@@ -2362,6 +2363,73 @@ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4 |
         max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
         self.assertLess(max_diff, expected_max_difference)
 
+    @require_torch_accelerator
+    def test_pipeline_level_group_offloading_sanity_checks(self):
+        components = self.get_dummy_components()
+        pipe: DiffusionPipeline = self.pipeline_class(**components)
+
+        for component in pipe.components.values():
+            if hasattr(component, "_supports_group_offloading"):
+                if not component._supports_group_offloading:
+                    pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+
+        module_names = sorted(
+            [name for name, component in pipe.components.items() if isinstance(component, torch.nn.Module)]
+        )
+        exclude_module_name = module_names[0]
+        offload_device = "cpu"
+        pipe.enable_group_offload(
+            onload_device=torch_device,
+            offload_device=offload_device,
+            offload_type="leaf_level",
+            exclude_modules=exclude_module_name,
+        )
+        excluded_module = getattr(pipe, exclude_module_name)
+        self.assertEqual(torch.device(excluded_module.device).type, torch.device(torch_device).type)
+
+        for name, component in pipe.components.items():
+            if name != exclude_module_name and isinstance(component, torch.nn.Module):
+                # `component.device` reports the `onload_device` type. We should probably override the
+                # `device` property in `ModelMixin`.
+                component_device = next(component.parameters()).device
+                self.assertEqual(torch.device(component_device).type, torch.device(offload_device).type)
+
+    @require_torch_accelerator
+    def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
+        components = self.get_dummy_components()
+        pipe: DiffusionPipeline = self.pipeline_class(**components)
+
+        for component in pipe.components.values():
+            if hasattr(component, "_supports_group_offloading"):
+                if not component._supports_group_offloading:
+                    pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+
+        # Regular inference.
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        torch.manual_seed(0)
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["generator"] = torch.manual_seed(0)
+        out = pipe(**inputs)[0]
+
+        pipe.to("cpu")
+        del pipe
+
+        # Inference with offloading.
+        pipe: DiffusionPipeline = self.pipeline_class(**components)
+        offload_device = "cpu"
+        pipe.enable_group_offload(
+            onload_device=torch_device,
+            offload_device=offload_device,
+            offload_type="leaf_level",
+        )
+        pipe.set_progress_bar_config(disable=None)
+        inputs["generator"] = torch.manual_seed(0)
+        out_offload = pipe(**inputs)[0]
+
+        max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
+        self.assertLess(max_diff, expected_max_difference)
+
 
 @is_staging_test
 class PipelinePushToHubTester(unittest.TestCase):
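
For reference, here is a minimal usage sketch of the pipeline-level group offloading API these tests exercise. The checkpoint id and the excluded component name ("vae") are illustrative assumptions, not taken from this diff:

```python
import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint; any pipeline whose nn.Module components support
# group offloading should behave the same way.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Keep weights offloaded on CPU and stream them to the accelerator leaf by
# leaf during the forward pass. `exclude_modules` (here a single module name,
# as in the sanity-check test above) keeps that component on the onload device.
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    exclude_modules="vae",
)

image = pipe("an astronaut riding a horse on the moon").images[0]
```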
|
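The inline comment in the sanity-check test notes that `component.device` reports the `onload_device` type and suggests overriding the `device` property in `ModelMixin`. A hypothetical sketch of what such a helper could look like (not part of diffusers or of this diff), assuming the first parameter's or buffer's device reflects where the weights actually live:

```python
import torch
import torch.nn as nn

def current_weight_device(module: nn.Module) -> torch.device:
    # Report where the module's storage actually lives rather than a cached
    # or configured onload device: prefer parameters, fall back to buffers,
    # and default to CPU for stateless modules.
    tensor = next(module.parameters(), None)
    if tensor is None:
        tensor = next(module.buffers(), None)
    return tensor.device if tensor is not None else torch.device("cpu")
```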