import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict

from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
from diffsynth_engine.utils.gguf import gguf_inference
from diffsynth_engine.utils import logging

logger = logging.get_logger(__name__)


def fp16_clamp(x):
    # Avoid fp16 overflow: if any value is inf, clamp the tensor to just inside the fp16 range.
    if x.dtype == torch.float16 and torch.isinf(x).any():
        clamp = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp, max=clamp)
    return x


class GELU(nn.Module):
    # Tanh approximation of GELU, matching F.gelu(x, approximate="tanh").
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


class T5LayerNorm(nn.Module):
    # RMS-style layer norm used by T5: no mean subtraction and no bias.
    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x


class T5Attention(nn.Module):
    def __init__(self, dim, dim_attn, num_heads, dropout=0.0, device="cuda:0"):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads
        self.device = device
        # layers
        self.q = nn.Linear(dim, dim_attn, bias=False, device=device)
        self.k = nn.Linear(dim, dim_attn, bias=False, device=device)
        self.v = nn.Linear(dim, dim_attn, bias=False, device=device)
        self.o = nn.Linear(dim_attn, dim, bias=False, device=device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x: [B, L1, C].
        context: [B, L2, C] or None.
        mask: [B, L2] or [B, L1, L2] or None.
        """
        # check inputs
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling)
        attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum("bnij,bjnc->binc", attn, v)

        # output
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
        return x
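

# Shape sketch for T5Attention.forward (illustrative note derived from the code above,
# not part of the original file):
#   q, k, v: [B, L, num_heads, head_dim] after the .view() calls
#   attn:    [B, num_heads, L1, L2] after the einsum, with the relative-position bias
#            and mask added as an additive attn_bias
#   output:  [B, L1, num_heads * head_dim], projected back to [B, L1, dim] by self.o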


class T5FeedForward(nn.Module):
    # Gated-GELU feed-forward (T5 v1.1 style): fc2(fc1(x) * GELU(gate(x))).
    def __init__(self, dim, dim_ffn, dropout=0.0, device="cuda:0"):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn
        self.device = device
        # layers
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False, device=device), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False, device=device)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False, device=device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x) * self.gate(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class T5SelfAttention(nn.Module):
    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.0, device="cuda:0"):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos
        self.device = device
        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, device)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout, device)
        self.pos_embedding = (
            None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, device=device)
        )

    def forward(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x


class T5RelativeEmbedding(nn.Module):
    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128, device="cuda:0"):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist
        self.device = device
        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads, device=device)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = (
            max_exact
            + (
                torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)
            ).long()
        )
        rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets
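

# Worked example for _relative_position_bucket (illustrative note, assuming the
# WanTextEncoder defaults below: num_buckets=32, bidirectional=True, max_dist=128).
# Each direction gets 16 buckets with max_exact=8: offsets 0..7 map one-to-one to
# buckets 0..7, larger offsets are binned logarithmically up to max_dist and clamped
# to bucket 15, and offsets in the positive direction are shifted by +16.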


def init_weights(m):
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5FeedForward):
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)


class WanTextEncoderStateDictConverter(StateDictConverter):
    def __init__(self, num_encoder_layers: int = 24):
        self.num_encoder_layers = num_encoder_layers

    def _from_diffusers(self, state_dict):
        rename_dict = {
            "shared.weight": "token_embedding.weight",
            "encoder.final_layer_norm.weight": "norm.weight",
        }
        for i in range(self.num_encoder_layers):
            rename_dict.update(
                {
                    f"encoder.block.{i}.layer.0.SelfAttention.q.weight": f"blocks.{i}.attn.q.weight",
                    f"encoder.block.{i}.layer.0.SelfAttention.k.weight": f"blocks.{i}.attn.k.weight",
                    f"encoder.block.{i}.layer.0.SelfAttention.v.weight": f"blocks.{i}.attn.v.weight",
                    f"encoder.block.{i}.layer.0.SelfAttention.o.weight": f"blocks.{i}.attn.o.weight",
                    f"encoder.block.{i}.layer.0.SelfAttention.relative_attention_bias.weight": f"blocks.{i}.pos_embedding.embedding.weight",
                    f"encoder.block.{i}.layer.0.layer_norm.weight": f"blocks.{i}.norm1.weight",
                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight": f"blocks.{i}.ffn.gate.0.weight",
                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight": f"blocks.{i}.ffn.fc1.weight",
                    f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight": f"blocks.{i}.ffn.fc2.weight",
                    f"encoder.block.{i}.layer.1.layer_norm.weight": f"blocks.{i}.norm2.weight",
                }
            )
        new_state_dict = {}
        for key, param in state_dict.items():
            if key in rename_dict:
                new_state_dict[rename_dict[key]] = param
        return new_state_dict

    def _from_enc_blk(self, state_dict):
        # Map enc.blk.* keys to blocks.* keys
        rename_dict = {
            "token_embd.weight": "token_embedding.weight",
            "enc.output_norm.weight": "norm.weight",
        }
        for i in range(self.num_encoder_layers):
            rename_dict.update(
                {
                    f"enc.blk.{i}.attn_norm.weight": f"blocks.{i}.norm1.weight",
                    f"enc.blk.{i}.attn_q.weight": f"blocks.{i}.attn.q.weight",
                    f"enc.blk.{i}.attn_k.weight": f"blocks.{i}.attn.k.weight",
                    f"enc.blk.{i}.attn_v.weight": f"blocks.{i}.attn.v.weight",
                    f"enc.blk.{i}.attn_o.weight": f"blocks.{i}.attn.o.weight",
                    f"enc.blk.{i}.attn_rel_b.weight": f"blocks.{i}.pos_embedding.embedding.weight",
                    f"enc.blk.{i}.ffn_norm.weight": f"blocks.{i}.norm2.weight",
                    f"enc.blk.{i}.ffn_gate.weight": f"blocks.{i}.ffn.gate.0.weight",
                    f"enc.blk.{i}.ffn_up.weight": f"blocks.{i}.ffn.fc1.weight",
                    f"enc.blk.{i}.ffn_down.weight": f"blocks.{i}.ffn.fc2.weight",
                }
            )
        new_state_dict = {}
        for key, param in state_dict.items():
            if key in rename_dict:
                new_state_dict[rename_dict[key]] = param
        return new_state_dict

    def convert(self, state_dict):
        if "encoder.final_layer_norm.weight" in state_dict:
            logger.info("use diffusers format state dict")
            return self._from_diffusers(state_dict)
        if any(k.startswith("enc.blk.") for k in state_dict):
            logger.info("use enc.blk.* format state dict")
            return self._from_enc_blk(state_dict)
        # Try to detect if already in model format (blocks.0.attn.q.weight, etc.)
        if any(k.startswith("blocks.") for k in state_dict):
            logger.info("use native model format state dict")
            return state_dict
        logger.warning("Unknown state dict format, passing through unchanged")
        return state_dict
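

# Illustrative example (derived from the rename tables above, not part of the original
# file): a diffusers-format key such as "encoder.block.0.layer.0.SelfAttention.q.weight"
# is renamed by convert() to "blocks.0.attn.q.weight"; keys absent from the rename
# table are dropped from the returned state dict.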


class WanTextEncoder(PreTrainedModel):
    converter = WanTextEncoderStateDictConverter()

    def __init__(
        self,
        vocab=256384,
        dim=4096,
        dim_attn=4096,
        dim_ffn=10240,
        num_heads=64,
        num_layers=24,
        num_buckets=32,
        shared_pos=False,
        dropout=0.0,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos
        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, device=device)
        self.pos_embedding = (
            T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, device=device) if shared_pos else None
        )
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [
                T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, device)
                for _ in range(num_layers)
            ]
        )
        self.norm = T5LayerNorm(dim)

    def forward(self, ids, mask=None):
        with gguf_inference():
            x = self.token_embedding(ids)
            x = self.dropout(x)
            e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
            for block in self.blocks:
                x = block(x, mask, pos_bias=e)
            x = self.norm(x)
            x = self.dropout(x)
            return x

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
    ):
        model = cls(device="meta", dtype=dtype)
        model = model.requires_grad_(False)
        model.load_state_dict(state_dict, assign=True)
        # Alternatively, allow partial loading when keys are missing or unexpected:
        # model.load_state_dict(state_dict, strict=False, assign=True)
        model.to(device=device, dtype=dtype, non_blocking=True)
        return model
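

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module).
    # It assumes a CPU run with tiny dimensions and random token ids; the real
    # checkpoint uses the defaults above (vocab=256384, dim=4096, num_layers=24)
    # with ids produced by the matching T5 tokenizer, which is not included here.
    encoder = WanTextEncoder(
        vocab=1000,
        dim=64,
        dim_attn=64,
        dim_ffn=128,
        num_heads=4,
        num_layers=2,
        device="cpu",
        dtype=torch.float32,
    )
    ids = torch.randint(0, 1000, (1, 16))
    mask = torch.ones(1, 16, dtype=torch.long)
    with torch.no_grad():
        out = encoder(ids, mask)
    print(out.shape)  # expected: torch.Size([1, 16, 64])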