from typing import Optional

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from diffusers.configuration_utils import ConfigMixin

from diffsynth_engine.utils import logging
from diffsynth_engine.utils.constants import CONFIG_NAME
from diffsynth_engine.utils.load_utils import load_model_weights

logger = logging.get_logger(__name__)
class DiffusionModel(nn.Module, ConfigMixin):
    config_name = CONFIG_NAME

    # Same contract as diffusers' ModelMixin._keep_in_fp32_modules: modules whose
    # names appear here stay in float32 even when the rest of the model is cast.
    _keep_in_fp32_modules: list[str] | None = None
    # Same contract as diffusers' ModelMixin._keys_to_ignore_on_load_unexpected:
    # state dict keys matching these patterns are dropped instead of raising.
    _keys_to_ignore_on_load_unexpected: list[str] | None = None

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        subfolder: Optional[str] = None,
        device: Optional[str | torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        # Load the config, then build the model on the meta device so that no
        # memory is allocated until the real weights are assigned below.
        config_dict = cls.load_config(model_path, subfolder=subfolder, local_files_only=True)
        with init_empty_weights():
            model = cls.from_config(config_dict)

        if dtype is not None and dtype != torch.float32 and cls._keep_in_fp32_modules:
            # Load weights without casting, then keep the listed modules in
            # float32 to avoid precision loss; cast everything else to `dtype`.
            state_dict = load_model_weights(model_path, subfolder, device, dtype=None)
            for key in state_dict:
                if any(m in key.split(".") for m in cls._keep_in_fp32_modules):
                    state_dict[key] = state_dict[key].to(device=device, dtype=torch.float32)
                else:
                    state_dict[key] = state_dict[key].to(device=device, dtype=dtype)
        else:
            state_dict = load_model_weights(model_path, subfolder, device, dtype)

        # Filter out unexpected keys that the model explicitly ignores.
        if cls._keys_to_ignore_on_load_unexpected:
            keys_to_remove = [
                key
                for key in state_dict
                if any(pattern in key for pattern in cls._keys_to_ignore_on_load_unexpected)
            ]
            for key in keys_to_remove:
                del state_dict[key]
            if keys_to_remove:
                logger.info(
                    f"Dropped {len(keys_to_remove)} unexpected key(s) matching "
                    f"{cls._keys_to_ignore_on_load_unexpected} from state_dict."
                )

        # `assign=True` swaps the loaded tensors in place of the meta-device
        # parameters instead of copying into them.
        model.load_state_dict(state_dict, strict=True, assign=True)
        model.to(device=device)
        return model
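
# A minimal sketch of how a subclass might use the two hooks above. The class
# and module names here ("MyDiT", "norm_out", "time_embedding.extra") are
# hypothetical, not part of this repository:
#
#     class MyDiT(DiffusionModel):
#         # Keep normalization layers in float32 even when loading in bf16/fp16.
#         _keep_in_fp32_modules = ["norm_out"]
#         # Silently drop leftover training-only keys from the checkpoint.
#         _keys_to_ignore_on_load_unexpected = ["time_embedding.extra"]
#
#     model = MyDiT.from_pretrained("/path/to/model", device="cuda", dtype=torch.bfloat16)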

class AutoregressiveModel(nn.Module):
    config_name = CONFIG_NAME

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        subfolder: Optional[str] = None,
        device: Optional[str | torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        # Subclasses are expected to define a `config_class` attribute; its
        # from_pretrained loads the config, and the model is then built on the
        # meta device so no memory is allocated until weights are assigned.
        config = cls.config_class.from_pretrained(model_path, subfolder=subfolder, local_files_only=True)
        with init_empty_weights():
            model = cls(config)

        # Load the weights and assign them in place of the meta parameters.
        state_dict = load_model_weights(model_path, subfolder, device, dtype)
        model.load_state_dict(state_dict, strict=True, assign=True)
        model.to(device=device)
        return model
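
# A minimal usage sketch, assuming a hypothetical subclass that supplies the
# `config_class` attribute consumed by from_pretrained above. "MyLM" and
# "MyLMConfig" do not exist in this repository:
#
#     class MyLM(AutoregressiveModel):
#         config_class = MyLMConfig
#
#         def __init__(self, config):
#             super().__init__()
#             # build layers from config here
#
#     model = MyLM.from_pretrained("/path/to/checkpoint", device="cuda", dtype=torch.bfloat16)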