-
Notifications
You must be signed in to change notification settings - Fork 41
Expand file tree
/
Copy pathqwen_image_dit_fbcache.py
More file actions
135 lines (117 loc) · 4.54 KB
/
qwen_image_dit_fbcache.py
File metadata and controls
135 lines (117 loc) · 4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
from typing import Any, Dict, Optional
from diffsynth_engine.models.qwen_image import QwenImageDiT
from diffsynth_engine.utils.gguf import gguf_inference
from diffsynth_engine.utils.parallel import cfg_parallel, cfg_parallel_unshard
class QwenImageDiTFBCache(QwenImageDiT):
    """Qwen-Image DiT with First-Block Cache (FBCache).

    Runs the first transformer block every step and measures the relative L1
    change of its residual against the previous step. When the change is below
    ``relative_l1_threshold``, the remaining blocks are skipped and the cached
    residual from the last fully computed step is reused instead.
    """

    def __init__(
        self,
        num_layers: int = 60,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
        relative_l1_threshold: float = 0.05,
    ):
        super().__init__(num_layers=num_layers, device=device, dtype=dtype)
        # Relative L1 distance below which the remaining blocks are skipped.
        self.relative_l1_threshold = relative_l1_threshold
        # Current denoising step index; advanced once per forward() call.
        self.step_count = 0
        # Total steps for the current run; set via refresh_cache_status().
        self.num_inference_steps = 0

    def is_relative_l1_below_threshold(
        self,
        prev_residual: torch.Tensor,
        residual: torch.Tensor,
        threshold: float,
    ) -> bool:
        """Return True when mean|prev - cur| / mean|prev| is below ``threshold``.

        Returns False for a non-positive threshold, mismatched shapes, or an
        all-zero previous residual (to avoid a divide-by-zero).
        """
        if threshold <= 0.0:
            return False
        if prev_residual.shape != residual.shape:
            return False
        mean_diff = (prev_residual - residual).abs().mean()
        mean_prev_residual = prev_residual.abs().mean()
        # Guard: an all-zero previous residual would make the ratio inf/nan,
        # and "nan < threshold" is an undefined decision. Treat as "not close".
        if mean_prev_residual.item() == 0.0:
            return False
        diff = mean_diff / mean_prev_residual
        return diff.item() < threshold

    def refresh_cache_status(self, num_inference_steps: int):
        """Reset the step counter at the start of a new sampling run."""
        self.step_count = 0
        self.num_inference_steps = num_inference_steps

    def forward(
        self,
        image: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        txt_seq_lens: Optional[torch.LongTensor] = None,
        attn_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Denoise one step, skipping blocks 1..N-1 when the cache permits.

        The first and last steps of a run always compute all blocks so the
        cache is seeded and the final output is exact.
        """
        h, w = image.shape[-2:]
        # Batch > 1 indicates classifier-free guidance (cond + uncond).
        use_cfg = image.shape[0] > 1
        with (
            gguf_inference(),
            cfg_parallel(
                (
                    image,
                    text,
                    timestep,
                    txt_seq_lens,
                ),
                use_cfg=use_cfg,
            ),
        ):
            conditioning = self.time_text_embed(timestep, image.dtype)
            video_fhw = (1, h // 2, w // 2)  # frame, height, width
            max_length = txt_seq_lens.max().item()
            image_rotary_emb = self.pos_embed(video_fhw, max_length, image.device)

            image = self.patchify(image)
            image = self.img_in(image)
            text = self.txt_in(self.txt_norm(text[:, :max_length]))

            # Always run the first block; its residual is the cache signal.
            original_hidden_states = image
            text, image = self.transformer_blocks[0](
                image=image,
                text=text,
                temb=conditioning,
                image_rotary_emb=image_rotary_emb,
                attn_kwargs=attn_kwargs,
            )
            first_hidden_states_residual = image - original_hidden_states

            if self.step_count == 0 or self.step_count == (self.num_inference_steps - 1):
                # First step seeds the cache; last step must be exact.
                should_calc = True
            else:
                # Fix: pass the cached residual as prev_residual and the
                # current one as residual, matching the method's signature —
                # the metric normalizes by the *previous* residual's magnitude.
                skip = self.is_relative_l1_below_threshold(
                    self.prev_first_hidden_states_residual,
                    first_hidden_states_residual,
                    threshold=self.relative_l1_threshold,
                )
                should_calc = not skip
            self.step_count += 1

            if not should_calc:
                # Reuse the residual of blocks 1..N-1 from the last full pass.
                image += self.previous_residual
            else:
                self.prev_first_hidden_states_residual = first_hidden_states_residual
                first_hidden_states = image.clone()
                for block in self.transformer_blocks[1:]:
                    text, image = block(
                        image=image,
                        text=text,
                        temb=conditioning,
                        image_rotary_emb=image_rotary_emb,
                        attn_kwargs=attn_kwargs,
                    )
                # Cache the aggregate residual of the skipped-able blocks.
                previous_residual = image - first_hidden_states
                self.previous_residual = previous_residual

            image = self.norm_out(image, conditioning)
            image = self.proj_out(image)
            image = self.unpatchify(image, h, w)
            (image,) = cfg_parallel_unshard((image,), use_cfg=use_cfg)
            return image

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
        num_layers: int = 60,
        relative_l1_threshold: float = 0.05,
    ):
        """Build the model on the meta device, then load weights onto ``device``."""
        model = cls(
            device="meta",
            dtype=dtype,
            num_layers=num_layers,
            relative_l1_threshold=relative_l1_threshold,
        )
        model = model.requires_grad_(False)
        # assign=True moves the meta-device parameters to real storage.
        model.load_state_dict(state_dict, assign=True)
        model.to(device=device, dtype=dtype, non_blocking=True)
        return model