-
Notifications
You must be signed in to change notification settings - Fork 41
Expand file tree
/
Copy pathqwen_image_dit_nunchaku.py
More file actions
344 lines (287 loc) · 13.3 KB
/
qwen_image_dit_nunchaku.py
File metadata and controls
344 lines (287 loc) · 13.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
import torch
import torch.nn as nn
from typing import Any, Dict, List, Tuple, Optional
from einops import rearrange
from diffsynth_engine.models.basic import attention as attention_ops
from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, RMSNorm
from diffsynth_engine.models.qwen_image.qwen_image_dit import (
QwenFeedForward,
apply_rotary_emb_qwen,
QwenDoubleStreamAttention,
QwenImageTransformerBlock,
QwenImageDiT,
QwenEmbedRope,
)
from nunchaku.models.utils import fuse_linears
from nunchaku.ops.fused import fused_gelu_mlp
from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
from diffsynth_engine.models.basic.lora_nunchaku import LoRASVDQW4A4Linear, LoRAAWQW4A16Linear
class QwenDoubleStreamAttentionNunchaku(QwenDoubleStreamAttention):
    """Joint text/image attention whose projections are SVDQuant W4A4 linears.

    The parent's separate q/k/v projections are fused into single linears and
    converted with ``SVDQW4A4Linear.from_linear``:

    - image stream: ``to_q/to_k/to_v`` become ``to_qkv``;
    - text stream: ``add_q_proj/add_k_proj/add_v_proj`` become ``add_qkv_proj``.

    The output projections ``to_out`` / ``to_add_out`` are converted the same way.
    """

    def __init__(
        self,
        dim_a,
        dim_b,
        num_heads,
        head_dim,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
        nunchaku_rank: int = 32,
    ):
        super().__init__(dim_a, dim_b, num_heads, head_dim, device=device, dtype=dtype)

        # Image stream: one fused qkv linear, quantized together with its output projection.
        fused_img_qkv = fuse_linears([self.to_q, self.to_k, self.to_v])
        self.to_qkv = SVDQW4A4Linear.from_linear(fused_img_qkv, rank=nunchaku_rank)
        self.to_out = SVDQW4A4Linear.from_linear(self.to_out, rank=nunchaku_rank)
        # The unfused projections are no longer used by forward().
        del self.to_q, self.to_k, self.to_v

        # Text stream: same treatment for the "add_*" projections.
        fused_txt_qkv = fuse_linears([self.add_q_proj, self.add_k_proj, self.add_v_proj])
        self.add_qkv_proj = SVDQW4A4Linear.from_linear(fused_txt_qkv, rank=nunchaku_rank)
        self.to_add_out = SVDQW4A4Linear.from_linear(self.to_add_out, rank=nunchaku_rank)
        del self.add_q_proj, self.add_k_proj, self.add_v_proj

    def forward(
        self,
        image: torch.FloatTensor,
        text: torch.FloatTensor,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attn_mask: Optional[torch.Tensor] = None,
        attn_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run joint attention over the concatenated text+image token sequence.

        Args:
            image: Image tokens. The rearrange patterns below require the
                projected tensor to be ``(batch, seq, heads*head_dim)``.
            text: Text tokens, same layout convention as ``image``.
            rotary_emb: Optional ``(img_freqs, txt_freqs)`` applied to q and k
                of each stream via ``apply_rotary_emb_qwen``.
            attn_mask: Forwarded to ``attention_ops.attention``.
            attn_kwargs: Extra keyword arguments for ``attention_ops.attention``.

        Returns:
            ``(image_out, text_out)``: each stream's slice of the joint attention
            output, passed through its own output projection.
        """

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (b, s, h*d) -> (b, s, h, d)
            return rearrange(t, "b s (h d) -> b s h d", h=self.num_heads)

        img_q, img_k, img_v = (to_heads(t) for t in self.to_qkv(image).chunk(3, dim=-1))
        txt_q, txt_k, txt_v = (to_heads(t) for t in self.add_qkv_proj(text).chunk(3, dim=-1))

        # Per-head normalization of queries/keys for each stream.
        img_q = self.norm_q(img_q)
        img_k = self.norm_k(img_k)
        txt_q = self.norm_added_q(txt_q)
        txt_k = self.norm_added_k(txt_k)

        if rotary_emb is not None:
            img_freqs, txt_freqs = rotary_emb
            img_q, img_k = (apply_rotary_emb_qwen(t, img_freqs) for t in (img_q, img_k))
            txt_q, txt_k = (apply_rotary_emb_qwen(t, txt_freqs) for t in (txt_q, txt_k))

        # Text tokens come first in the joint sequence; the split below relies on that.
        joint_q, joint_k, joint_v = (
            torch.cat(pair, dim=1) for pair in ((txt_q, img_q), (txt_k, img_k), (txt_v, img_v))
        )

        extra = {} if attn_kwargs is None else attn_kwargs
        joint_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **extra)
        joint_out = rearrange(joint_out, "b s h d -> b s (h d)").to(joint_q.dtype)

        n_txt = text.shape[1]
        txt_out = joint_out[:, :n_txt, :]
        img_out = joint_out[:, n_txt:, :]
        return self.to_out(img_out), self.to_add_out(txt_out)
class QwenFeedForwardNunchaku(QwenFeedForward):
    """Feed-forward block whose two linears are SVDQuant W4A4 and run via ``fused_gelu_mlp``.

    ``self.net[0].proj`` (up projection) and ``self.net[2]`` (down projection)
    are replaced in place with ``SVDQW4A4Linear`` conversions of the parent's layers.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        dropout: float = 0.0,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
        rank: int = 32,
    ):
        super().__init__(dim, dim_out, dropout, device=device, dtype=dtype)
        activation_block = self.net[0]
        activation_block.proj = SVDQW4A4Linear.from_linear(activation_block.proj, rank=rank)

        down_proj = SVDQW4A4Linear.from_linear(self.net[2], rank=rank)
        # NOTE(review): act_unsigned presumably lets the W4A4 kernel quantize its
        # (post-GELU) input as unsigned; it is disabled for nvfp4 — confirm
        # against nunchaku.
        down_proj.act_unsigned = down_proj.precision != "nvfp4"
        self.net[2] = down_proj

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Apply the quantized MLP; extra positional/keyword arguments are ignored.

        NOTE(review): the parent's dropout (``self.net`` entries other than 0
        and 2) is not applied here.
        """
        up_proj = self.net[0].proj
        down_proj = self.net[2]
        return fused_gelu_mlp(hidden_states, up_proj, down_proj)
class QwenImageTransformerBlockNunchaku(QwenImageTransformerBlock):
    """Qwen-Image dual-stream transformer block with nunchaku-quantized layers.

    Args:
        use_nunchaku_awq: When True, the modulation linears ``img_mod[1]`` and
            ``txt_mod[1]`` become ``AWQW4A16Linear``. Their output is then
            re-ordered in ``forward``, and ``_modulate`` multiplies by
            ``scale (+ scale_shift)`` rather than ``1 + scale``.
        use_nunchaku_attn: Select ``QwenDoubleStreamAttentionNunchaku`` (W4A4)
            instead of the full-precision ``QwenDoubleStreamAttention``.
        nunchaku_rank: Low-rank size passed to every quantized-layer conversion.
        scale_shift: Offset added to the modulation scale on the AWQ path.

    Both MLPs are always ``QwenFeedForwardNunchaku``.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
        scale_shift: float = 1.0,
        use_nunchaku_awq: bool = True,
        use_nunchaku_attn: bool = True,
        nunchaku_rank: int = 32,
    ):
        super().__init__(dim, num_attention_heads, attention_head_dim, eps, device=device, dtype=dtype)
        self.use_nunchaku_awq = use_nunchaku_awq

        # Image stream: modulation linear, attention, MLP.
        if use_nunchaku_awq:
            self.img_mod[1] = AWQW4A16Linear.from_linear(self.img_mod[1], rank=nunchaku_rank)

        attn_args = dict(
            dim_a=dim,
            dim_b=dim,
            num_heads=num_attention_heads,
            head_dim=attention_head_dim,
            device=device,
            dtype=dtype,
        )
        if use_nunchaku_attn:
            self.attn = QwenDoubleStreamAttentionNunchaku(**attn_args, nunchaku_rank=nunchaku_rank)
        else:
            self.attn = QwenDoubleStreamAttention(**attn_args)

        self.img_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank)

        # Text stream: modulation linear and MLP (attention is shared/joint).
        if use_nunchaku_awq:
            self.txt_mod[1] = AWQW4A16Linear.from_linear(self.txt_mod[1], rank=nunchaku_rank)
        self.txt_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank)

        self.scale_shift = scale_shift

    def _modulate(self, x, mod_params):
        """Apply shift/scale modulation and return ``(modulated_x, gate)``.

        ``mod_params`` is split into ``shift, scale, gate`` along the last dim;
        each is unsqueezed at dim 1 to broadcast over the sequence axis of ``x``.

        On the AWQ path the result is ``x * scale + shift``. When
        ``scale_shift != 0``, the scale slice is first incremented IN PLACE, so
        ``mod_params`` is mutated. Otherwise the result is ``x * (1 + scale) + shift``.
        """
        shift, scale, gate = mod_params.chunk(3, dim=-1)
        shift_b = shift.unsqueeze(1)
        gate_b = gate.unsqueeze(1)
        if not self.use_nunchaku_awq:
            return x * (1 + scale.unsqueeze(1)) + shift_b, gate_b
        if self.scale_shift != 0:
            scale.add_(self.scale_shift)
        return x * scale.unsqueeze(1) + shift_b, gate_b

    def forward(
        self,
        image: torch.Tensor,
        text: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attn_mask: Optional[torch.Tensor] = None,
        attn_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the attention and MLP sub-blocks on both streams.

        Returns:
            ``(text, image)``: note the order is text first.
        """

        def split_mod(mod_layer):
            # Returns (attn_params, mlp_params), each [B, 3*dim] after chunking a [B, 6*dim] output.
            params = mod_layer(temb)
            if self.use_nunchaku_awq:
                # The AWQ modulation output is interleaved as (dim, 6) per sample;
                # reorder it to six contiguous dim-sized blocks, i.e. (6, dim).
                bsz = params.shape[0]
                params = params.view(bsz, -1, 6).transpose(1, 2).reshape(bsz, -1)
            return params.chunk(2, dim=-1)

        img_mod_attn, img_mod_mlp = split_mod(self.img_mod)
        txt_mod_attn, txt_mod_mlp = split_mod(self.txt_mod)

        # Attention sub-block (gated residual).
        img_in, img_gate = self._modulate(self.img_norm1(image), img_mod_attn)
        txt_in, txt_gate = self._modulate(self.txt_norm1(text), txt_mod_attn)
        img_attn, txt_attn = self.attn(
            image=img_in,
            text=txt_in,
            rotary_emb=rotary_emb,
            attn_mask=attn_mask,
            attn_kwargs=attn_kwargs,
        )
        image = image + img_gate * img_attn
        text = text + txt_gate * txt_attn

        # MLP sub-block (gated residual).
        img_in, img_gate = self._modulate(self.img_norm2(image), img_mod_mlp)
        txt_in, txt_gate = self._modulate(self.txt_norm2(text), txt_mod_mlp)
        image = image + img_gate * self.img_mlp(img_in)
        text = text + txt_gate * self.txt_mlp(txt_in)
        return text, image
class QwenImageDiTNunchaku(QwenImageDiT):
    """Qwen-Image DiT whose transformer blocks use nunchaku-quantized layers.

    Embeddings, input/output projections and the final norm are ordinary
    modules with hard-coded Qwen-Image sizes. Each of the ``num_layers`` blocks
    is a ``QwenImageTransformerBlockNunchaku``.
    """

    def __init__(
        self,
        num_layers: int = 60,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
        use_nunchaku_awq: bool = True,
        use_nunchaku_attn: bool = True,
        nunchaku_rank: int = 32,
    ):
        super().__init__(num_layers, device=device, dtype=dtype)
        # NOTE(review): super().__init__ likely already creates equivalent modules
        # (including its own transformer blocks); everything below replaces them.
        # Building on "meta" (as from_state_dict does) keeps that cheap.
        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True, device=device)
        self.time_text_embed = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
        self.txt_norm = RMSNorm(3584, eps=1e-6, device=device, dtype=dtype)
        self.img_in = nn.Linear(64, 3072, device=device, dtype=dtype)
        self.txt_in = nn.Linear(3584, 3072, device=device, dtype=dtype)
        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlockNunchaku(
                    dim=3072,
                    num_attention_heads=24,
                    attention_head_dim=128,
                    device=device,
                    dtype=dtype,
                    # NOTE(review): scale_shift=0 presumably because the "+1" on the
                    # modulation scale is baked into the quantized checkpoint — confirm.
                    scale_shift=0,
                    use_nunchaku_awq=use_nunchaku_awq,
                    use_nunchaku_attn=use_nunchaku_attn,
                    nunchaku_rank=nunchaku_rank,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
        self.proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
        num_layers: int = 60,
        use_nunchaku_awq: bool = True,
        use_nunchaku_attn: bool = True,
        nunchaku_rank: int = 32,
    ):
        """Build the model on the meta device, adopt ``state_dict`` tensors, then move to ``device``.

        ``load_state_dict(..., assign=True)`` makes the module reference the
        given tensors instead of copying into meta placeholders. Gradients are
        disabled on every parameter.
        """
        model = cls(
            device="meta",
            dtype=dtype,
            num_layers=num_layers,
            use_nunchaku_awq=use_nunchaku_awq,
            use_nunchaku_attn=use_nunchaku_attn,
            nunchaku_rank=nunchaku_rank,
        )
        model = model.requires_grad_(False)
        model.load_state_dict(state_dict, assign=True)
        model.to(device=device, non_blocking=True)
        return model

    def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = False):
        """Attach the LoRA adapters described by ``lora_args`` to submodules of this model.

        Each entry's ``"key"`` is a submodule path (resolved with ``get_submodule``).

        Entries whose last path segment is a single q/k/v projection are grouped
        per attention module and applied as one LoRA on the fused linear via
        ``add_qkv_lora``:

        - ``to_q`` / ``to_k`` / ``to_v`` map to ``to_qkv``;
        - ``add_q_proj`` / ``add_k_proj`` / ``add_v_proj`` map to ``add_qkv_proj``.

        For each group, the ``"q"`` entry supplies ``name``, ``scale``, ``rank``,
        ``alpha``, ``device`` and ``dtype``; every projection supplies its own
        ``up``/``down``.

        All other entries are forwarded as keyword arguments to the target
        module's ``add_frozen_lora`` when ``fused`` is True and the module is not
        a ``LoRAAWQW4A16Linear``, and to ``add_lora`` otherwise.

        Args:
            lora_args: LoRA descriptors as described above.
            fused: Prefer ``add_frozen_lora`` where allowed. Fused q/k/v groups
                always use ``add_qkv_lora`` regardless of this flag.

        Raises:
            ValueError: if a key resolves to a module without LoRA support, or a
                q/k/v group is missing one of its three projections.
        """
        # Last path segment of an unfused projection -> name of the fused linear replacing it.
        qkv_leaf_to_fused = {
            "to_q": "to_qkv",
            "to_k": "to_qkv",
            "to_v": "to_qkv",
            "add_q_proj": "add_qkv_proj",
            "add_k_proj": "add_qkv_proj",
            "add_v_proj": "add_qkv_proj",
        }
        # fused module path -> {"q"|"k"|"v": lora args}
        fuse_dict: Dict[str, Dict[str, Dict[str, Any]]] = {}

        for args in lora_args:
            key = args["key"]
            parent, _, leaf = key.rpartition(".")
            fused_leaf = qkv_leaf_to_fused.get(leaf)
            if fused_leaf is not None:
                fuse_key = f"{parent}.{fused_leaf}" if parent else fused_leaf
                proj = leaf.split("_")[1]  # "to_q" -> "q", "add_k_proj" -> "k"
                fuse_dict.setdefault(fuse_key, {})[proj] = args
                continue

            module = self.get_submodule(key)
            if not isinstance(module, (LoRALinear, LoRAConv2d, LoRASVDQW4A4Linear, LoRAAWQW4A16Linear)):
                raise ValueError(f"Unsupported lora key: {key}")
            # LoRAAWQW4A16Linear always takes the add_lora path, even when fused=True.
            if fused and not isinstance(module, LoRAAWQW4A16Linear):
                module.add_frozen_lora(**args)
            else:
                module.add_lora(**args)

        for fuse_key, parts in fuse_dict.items():
            module = self.get_submodule(fuse_key)
            if not isinstance(module, LoRASVDQW4A4Linear):
                raise ValueError(f"Unsupported lora key: {fuse_key}")
            missing = [proj for proj in ("q", "k", "v") if proj not in parts]
            if missing:
                raise ValueError(f"Incomplete qkv lora for {fuse_key}: missing {', '.join(missing)}")
            q_args, k_args, v_args = parts["q"], parts["k"], parts["v"]
            module.add_qkv_lora(
                # Use this group's own name; the previous code read the stale loop
                # variable, i.e. the name of whichever LoRA came last in lora_args.
                name=q_args["name"],
                scale=q_args["scale"],
                rank=q_args["rank"],
                alpha=q_args["alpha"],
                q_up=q_args["up"],
                q_down=q_args["down"],
                k_up=k_args["up"],
                k_down=k_args["down"],
                v_up=v_args["up"],
                v_down=v_args["down"],
                device=q_args["device"],
                dtype=q_args["dtype"],
            )

    def enable_fp8_linear(self):
        """FP8 linears are not supported for the nunchaku-quantized model."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support FP8 linear")