Skip to content

Commit 755dc49

Browse files
committed
stable diffusion og.
1 parent 4bd7dd5 commit 755dc49

12 files changed

Lines changed: 553 additions & 1940 deletions

check_duplicate.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import ast
2+
import argparse
3+
from pathlib import Path
4+
from collections import defaultdict
5+
6+
# This script requires Python 3.9+ for the ast.unparse() function.
7+
8+
class ClassMethodVisitor(ast.NodeVisitor):
9+
"""
10+
An AST visitor that collects method names and the source code of their bodies.
11+
"""
12+
def __init__(self):
13+
self.methods = defaultdict(list)
14+
15+
def visit_ClassDef(self, node: ast.ClassDef):
16+
"""
17+
Visits a class definition, then inspects its methods.
18+
"""
19+
class_name = node.name
20+
for item in node.body:
21+
if isinstance(item, ast.FunctionDef):
22+
method_name = item.name
23+
body_source = ast.unparse(item.body).strip()
24+
self.methods[method_name].append((class_name, body_source))
25+
self.generic_visit(node)
26+
27+
def find_duplicate_method_content(directory: str, show_code: bool = True):
28+
"""
29+
Parses all Python files in a directory to find methods with duplicate content.
30+
31+
Args:
32+
directory: The path to the directory to inspect.
33+
show_code: If True, prints the shared code block for each duplicate.
34+
"""
35+
target_dir = Path(directory)
36+
if not target_dir.is_dir():
37+
print(f"❌ Error: '{directory}' is not a valid directory.")
38+
return
39+
40+
visitor = ClassMethodVisitor()
41+
42+
for py_file in target_dir.rglob('*.py'):
43+
try:
44+
with open(py_file, 'r', encoding='utf-8') as f:
45+
source_code = f.read()
46+
tree = ast.parse(source_code, filename=py_file)
47+
visitor.visit(tree)
48+
except Exception as e:
49+
print(f"⚠️ Warning: Could not process {py_file}. Error: {e}")
50+
51+
print("\n--- Duplicate Method Content Report ---")
52+
duplicates_found = False
53+
54+
for method_name, implementations in sorted(visitor.methods.items()):
55+
body_groups = defaultdict(list)
56+
for class_name, body_source in implementations:
57+
body_groups[body_source].append(class_name)
58+
59+
for body_source, class_list in body_groups.items():
60+
if len(class_list) > 1:
61+
duplicates_found = True
62+
unique_classes = sorted(list(set(class_list)))
63+
print(f"\n[+] Method `def {method_name}(...)` has identical content in {len(unique_classes)} classes:")
64+
for class_name in unique_classes:
65+
print(f" - {class_name}")
66+
67+
# Conditionally print the shared code block based on the flag
68+
if show_code:
69+
print("\n Shared Code Block:")
70+
indented_code = "\n".join([f" {line}" for line in body_source.splitlines()])
71+
print(indented_code)
72+
print(" " + "-" * 30)
73+
74+
if not duplicates_found:
75+
print("\n✅ No methods with identical content were found across classes.")
76+
77+
def main():
78+
"""Main function to set up argument parsing."""
79+
parser = argparse.ArgumentParser(
80+
description="Find methods with identical content across Python classes in a directory."
81+
)
82+
parser.add_argument(
83+
"directory",
84+
type=str,
85+
help="The path to the directory to inspect."
86+
)
87+
# New argument to control output verbosity
88+
parser.add_argument(
89+
"--hide-code",
90+
action="store_true",
91+
help="Do not print the shared code block for each duplicate found."
92+
)
93+
args = parser.parse_args()
94+
find_duplicate_method_content(args.directory, show_code=not args.hide_code)
95+
96+
if __name__ == "__main__":
97+
main()

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 6 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@
4444

4545
from .. import __version__
4646
from ..configuration_utils import ConfigMixin
47-
from ..models import AutoencoderKL
48-
from ..models.attention_processor import FusedAttnProcessor2_0
4947
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
5048
from ..quantizers import PipelineQuantizationConfig
5149
from ..quantizers.bitsandbytes.utils import _check_bnb_status
@@ -69,6 +67,7 @@
6967
)
7068
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
7169
from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
70+
from .stable_diffusion.pipeline_stable_diffusion_utils import StableDiffusionMixin as ActualStableDiffusionMixin
7271

7372

7473
if is_torch_npu_available():
@@ -2172,137 +2171,8 @@ def _maybe_raise_error_if_group_offload_active(
21722171
return False
21732172

21742173

2175-
class StableDiffusionMixin:
2176-
r"""
2177-
Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion)
2178-
"""
2179-
2180-
def enable_vae_slicing(self):
2181-
r"""
2182-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
2183-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
2184-
"""
2185-
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
2186-
deprecate(
2187-
"enable_vae_slicing",
2188-
"0.40.0",
2189-
depr_message,
2190-
)
2191-
self.vae.enable_slicing()
2192-
2193-
def disable_vae_slicing(self):
2194-
r"""
2195-
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
2196-
computing decoding in one step.
2197-
"""
2198-
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
2199-
deprecate(
2200-
"disable_vae_slicing",
2201-
"0.40.0",
2202-
depr_message,
2203-
)
2204-
self.vae.disable_slicing()
2205-
2206-
def enable_vae_tiling(self):
2207-
r"""
2208-
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
2209-
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
2210-
processing larger images.
2211-
"""
2212-
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
2213-
deprecate(
2214-
"enable_vae_tiling",
2215-
"0.40.0",
2216-
depr_message,
2217-
)
2218-
self.vae.enable_tiling()
2219-
2220-
def disable_vae_tiling(self):
2221-
r"""
2222-
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
2223-
computing decoding in one step.
2224-
"""
2225-
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
2226-
deprecate(
2227-
"disable_vae_tiling",
2228-
"0.40.0",
2229-
depr_message,
2230-
)
2231-
self.vae.disable_tiling()
2232-
2233-
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
2234-
r"""Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497.
2235-
2236-
The suffixes after the scaling factors represent the stages where they are being applied.
2237-
2238-
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
2239-
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
2240-
2241-
Args:
2242-
s1 (`float`):
2243-
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
2244-
mitigate "oversmoothing effect" in the enhanced denoising process.
2245-
s2 (`float`):
2246-
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
2247-
mitigate "oversmoothing effect" in the enhanced denoising process.
2248-
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
2249-
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
2250-
"""
2251-
if not hasattr(self, "unet"):
2252-
raise ValueError("The pipeline must have `unet` for using FreeU.")
2253-
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
2254-
2255-
def disable_freeu(self):
2256-
"""Disables the FreeU mechanism if enabled."""
2257-
self.unet.disable_freeu()
2258-
2259-
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
2260-
"""
2261-
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
2262-
are fused. For cross-attention modules, key and value projection matrices are fused.
2263-
2264-
> [!WARNING] > This API is 🧪 experimental.
2265-
2266-
Args:
2267-
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
2268-
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
2269-
"""
2270-
self.fusing_unet = False
2271-
self.fusing_vae = False
2272-
2273-
if unet:
2274-
self.fusing_unet = True
2275-
self.unet.fuse_qkv_projections()
2276-
self.unet.set_attn_processor(FusedAttnProcessor2_0())
2277-
2278-
if vae:
2279-
if not isinstance(self.vae, AutoencoderKL):
2280-
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
2281-
2282-
self.fusing_vae = True
2283-
self.vae.fuse_qkv_projections()
2284-
self.vae.set_attn_processor(FusedAttnProcessor2_0())
2285-
2286-
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
2287-
"""Disable QKV projection fusion if enabled.
2288-
2289-
> [!WARNING] > This API is 🧪 experimental.
2290-
2291-
Args:
2292-
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
2293-
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
2294-
2295-
"""
2296-
if unet:
2297-
if not self.fusing_unet:
2298-
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
2299-
else:
2300-
self.unet.unfuse_qkv_projections()
2301-
self.fusing_unet = False
2302-
2303-
if vae:
2304-
if not self.fusing_vae:
2305-
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
2306-
else:
2307-
self.vae.unfuse_qkv_projections()
2308-
self.fusing_vae = False
2174+
class StableDiffusionMixin(ActualStableDiffusionMixin):
2175+
def __init__(self, *args, **kwargs):
2176+
deprecation_message = "`StableDiffusionMixin` from `diffusers.pipelines.pipeline_utils` is deprecated and this will be removed in a future version. Please use `StableDiffusionMixin` from `diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils`, instead."
2177+
deprecate("StableDiffusionMixin", "1.0.0", deprecation_message)
2178+
super().__init__(*args, **kwargs)

0 commit comments

Comments
 (0)