
Commit e2c62d6

up
1 parent e6d4612 commit e2c62d6

1 file changed

Lines changed: 72 additions & 41 deletions


src/diffusers/models/transformers/transformer_z_image.py

@@ -97,7 +97,8 @@ def __call__(
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        freqs_cis: Optional[torch.Tensor] = None,
+        freqs_cos: Optional[torch.Tensor] = None,
+        freqs_sin: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
@@ -113,17 +114,26 @@ def __call__(
         if attn.norm_k is not None:
             key = attn.norm_k(key)

-        # Apply RoPE
-        def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-            with torch.amp.autocast("cuda", enabled=False):
-                x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
-                freqs_cis = freqs_cis.unsqueeze(2)
-                x_out = torch.view_as_real(x * freqs_cis).flatten(3)
-            return x_out.type_as(x_in)  # todo
-
-        if freqs_cis is not None:
-            query = apply_rotary_emb(query, freqs_cis)
-            key = apply_rotary_emb(key, freqs_cis)
+        # # Apply RoPE
+        # def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+        #     with torch.amp.autocast("cuda", enabled=False):
+        #         x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
+        #         freqs_cis = freqs_cis.unsqueeze(2)
+        #         x_out = torch.view_as_real(x * freqs_cis).flatten(3)
+        #     return x_out.type_as(x_in)  # todo
+
+        def apply_rotary_emb(x_in: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
+            freqs_cos = freqs_cos.unsqueeze(2)  # [batch, seq, 1, head_dim//2]
+            freqs_sin = freqs_sin.unsqueeze(2)
+            x = x_in.reshape(*x_in.shape[:-1], -1, 2)
+            x0, x1 = x[..., 0], x[..., 1]
+            out0 = x0 * freqs_cos - x1 * freqs_sin
+            out1 = x0 * freqs_sin + x1 * freqs_cos
+            return torch.stack([out0, out1], dim=-1).flatten(-2).type_as(x_in)
+
+        if freqs_cos is not None and freqs_sin is not None:
+            query = apply_rotary_emb(query, freqs_cos, freqs_sin)
+            key = apply_rotary_emb(key, freqs_cos, freqs_sin)

         # Cast to correct dtype
         dtype = query.dtype
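For readers comparing the two code paths: the new real-valued rotation is algebraically the same operation as the old complex multiplication, since (x0 + i*x1) * (cos t + i*sin t) = (x0*cos t - x1*sin t) + i*(x0*sin t + x1*cos t). Below is a minimal standalone sketch of that equivalence; the tensor shapes and names are illustrative, not taken from the model.

import torch

def rope_complex(x, angles):
    # old-style path: complex multiplication via torch.polar / view_as_complex
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    cis = torch.polar(torch.ones_like(angles), angles).unsqueeze(2)
    return torch.view_as_real(xc * cis).flatten(3).type_as(x)

def rope_real(x, angles):
    # new-style path: explicit cos/sin rotation on paired channels
    cos, sin = angles.cos().unsqueeze(2), angles.sin().unsqueeze(2)
    xr = x.reshape(*x.shape[:-1], -1, 2)
    x0, x1 = xr[..., 0], xr[..., 1]
    out = torch.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1)
    return out.flatten(-2).type_as(x)

x = torch.randn(2, 16, 4, 64)   # [batch, seq, heads, head_dim]
angles = torch.randn(2, 16, 32)  # [batch, seq, head_dim // 2]
assert torch.allclose(rope_complex(x, angles), rope_real(x, angles), atol=1e-5)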
@@ -219,7 +229,8 @@ def forward(
         self,
         x: torch.Tensor,
         attn_mask: torch.Tensor,
-        freqs_cis: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
         adaln_input: Optional[torch.Tensor] = None,
     ):
         if self.modulation:
@@ -232,7 +243,8 @@ def forward(
             attn_out = self.attention(
                 self.attention_norm1(x) * scale_msa,
                 attention_mask=attn_mask,
-                freqs_cis=freqs_cis,
+                freqs_cos=freqs_cos,
+                freqs_sin=freqs_sin,
             )
             x = x + gate_msa * self.attention_norm2(attn_out)

@@ -247,7 +259,8 @@ def forward(
             attn_out = self.attention(
                 self.attention_norm1(x),
                 attention_mask=attn_mask,
-                freqs_cis=freqs_cis,
+                freqs_cos=freqs_cos,
+                freqs_sin=freqs_sin,
             )
             x = x + self.attention_norm2(attn_out)

@@ -290,39 +303,48 @@ def __init__(
         self.axes_dims = axes_dims
         self.axes_lens = axes_lens
         assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
-        self.freqs_cis = None
+        self.freqs_cos = None
+        self.freqs_sin = None

     @staticmethod
     def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
         with torch.device("cpu"):
-            freqs_cis = []
+            freqs_cos = []
+            freqs_sin = []
             for i, (d, e) in enumerate(zip(dim, end)):
                 freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
                 timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
                 freqs = torch.outer(timestep, freqs).float()
-                freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)  # complex64
-                freqs_cis.append(freqs_cis_i)
+                freqs_cos.append(freqs.cos())
+                freqs_sin.append(freqs.sin())
+                # freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)  # complex64
+                # freqs_cis.append(freqs_cis_i)

-        return freqs_cis
+        return freqs_cos, freqs_sin

     def __call__(self, ids: torch.Tensor):
         assert ids.ndim == 2
         assert ids.shape[-1] == len(self.axes_dims)
         device = ids.device

-        if self.freqs_cis is None:
-            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
-            self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+        if self.freqs_cos is None or self.freqs_sin is None:
+            self.freqs_cos, self.freqs_sin = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
+            self.freqs_cos = [f.to(device) for f in self.freqs_cos]
+            self.freqs_sin = [f.to(device) for f in self.freqs_sin]
         else:
             # Ensure freqs_cis are on the same device as ids
-            if self.freqs_cis[0].device != device:
-                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+            if self.freqs_cos[0].device != device:
+                self.freqs_cos = [f.to(device) for f in self.freqs_cos]
+            if self.freqs_sin[0].device != device:
+                self.freqs_sin = [f.to(device) for f in self.freqs_sin]

-        result = []
+        cos_result = []
+        sin_result = []
         for i in range(len(self.axes_dims)):
             index = ids[:, i]
-            result.append(self.freqs_cis[i][index])
-        return torch.cat(result, dim=-1)
+            cos_result.append(self.freqs_cos[i][index])
+            sin_result.append(self.freqs_sin[i][index])
+        return torch.cat(cos_result, dim=-1), torch.cat(sin_result, dim=-1)


 class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
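To see what the reworked embedder produces: each axis keeps cos/sin lookup tables of shape [axis_len, axis_dim // 2], rows are gathered by the integer position ids, and the per-axis results are concatenated along the last dimension. A small sketch with made-up axis sizes follows; the axis values and ids are hypothetical, only the theta default of 256.0 comes from the file.

import torch

axes_dims, axes_lens, theta = [8, 16], [64, 128], 256.0

# per-axis tables, mirroring precompute_freqs_cis but returning cos/sin directly
cos_tabs, sin_tabs = [], []
for d, e in zip(axes_dims, axes_lens):
    freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
    angles = torch.outer(torch.arange(e, dtype=torch.float64), freqs).float()
    cos_tabs.append(angles.cos())  # [e, d // 2]
    sin_tabs.append(angles.sin())

ids = torch.tensor([[3, 10], [7, 42]])  # [num_tokens, num_axes] position ids
cos = torch.cat([cos_tabs[i][ids[:, i]] for i in range(len(axes_dims))], dim=-1)
sin = torch.cat([sin_tabs[i][ids[:, i]] for i in range(len(axes_dims))], dim=-1)
print(cos.shape)  # torch.Size([2, 12]) == (num_tokens, sum(d // 2 for d in axes_dims))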
@@ -587,20 +609,23 @@ def forward(
         adaln_input = t.type_as(x)
         x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
         x = list(x.split(x_item_seqlens, dim=0))
-        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
+        x_freqs_cos, x_freqs_sin = self.rope_embedder(torch.cat(x_pos_ids, dim=0))
+        x_freqs_cos = list(x_freqs_cos.split(x_item_seqlens, dim=0))
+        x_freqs_sin = list(x_freqs_sin.split(x_item_seqlens, dim=0))

         x = pad_sequence(x, batch_first=True, padding_value=0.0)
-        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+        x_freqs_cos = pad_sequence(x_freqs_cos, batch_first=True, padding_value=0.0)
+        x_freqs_sin = pad_sequence(x_freqs_sin, batch_first=True, padding_value=0.0)
         x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(x_item_seqlens):
             x_attn_mask[i, :seq_len] = 1

         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.noise_refiner:
-                x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
+                x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cos, x_freqs_sin, adaln_input)
         else:
             for layer in self.noise_refiner:
-                x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
+                x = layer(x, x_attn_mask, x_freqs_cos, x_freqs_sin, adaln_input)

         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
@@ -611,47 +636,53 @@ def forward(
         cap_feats = self.cap_embedder(cap_feats)
         cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
-        cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
+        cap_freqs_cos, cap_freqs_sin = self.rope_embedder(torch.cat(cap_pos_ids, dim=0))
+        cap_freqs_cos = list(cap_freqs_cos.split(cap_item_seqlens, dim=0))
+        cap_freqs_sin = list(cap_freqs_sin.split(cap_item_seqlens, dim=0))

         cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
-        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
+        cap_freqs_cos = pad_sequence(cap_freqs_cos, batch_first=True, padding_value=0.0)
+        cap_freqs_sin = pad_sequence(cap_freqs_sin, batch_first=True, padding_value=0.0)
         cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(cap_item_seqlens):
             cap_attn_mask[i, :seq_len] = 1

         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.context_refiner:
-                cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
+                cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cos, cap_freqs_sin)
         else:
             for layer in self.context_refiner:
-                cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
+                cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cos, cap_freqs_sin)

         # unified
         unified = []
-        unified_freqs_cis = []
+        unified_freqs_cos = []
+        unified_freqs_sin = []
         for i in range(bsz):
             x_len = x_item_seqlens[i]
             cap_len = cap_item_seqlens[i]
             unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
-            unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
+            unified_freqs_cos.append(torch.cat([x_freqs_cos[i][:x_len], cap_freqs_cos[i][:cap_len]]))
+            unified_freqs_sin.append(torch.cat([x_freqs_sin[i][:x_len], cap_freqs_sin[i][:cap_len]]))
         unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
         assert unified_item_seqlens == [len(_) for _ in unified]
         unified_max_item_seqlen = max(unified_item_seqlens)

         unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
-        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
+        unified_freqs_cos = pad_sequence(unified_freqs_cos, batch_first=True, padding_value=0.0)
+        unified_freqs_sin = pad_sequence(unified_freqs_sin, batch_first=True, padding_value=0.0)
         unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(unified_item_seqlens):
             unified_attn_mask[i, :seq_len] = 1

         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.layers:
                 unified = self._gradient_checkpointing_func(
-                    layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
+                    layer, unified, unified_attn_mask, unified_freqs_cos, unified_freqs_sin, adaln_input
                 )
         else:
             for layer in self.layers:
-                unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
+                unified = layer(unified, unified_attn_mask, unified_freqs_cos, unified_freqs_sin, adaln_input)

         unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
         unified = list(unified.unbind(dim=0))
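Because the rope embedder now returns a (cos, sin) pair, every split / pad_sequence / concat that previously touched a single freqs_cis tensor is mirrored for both tensors throughout this forward pass. A hypothetical helper like the one below (not part of the commit, just a sketch of the shared pattern) shows the operation that gets repeated:

import torch
from torch.nn.utils.rnn import pad_sequence

def split_and_pad(freqs: torch.Tensor, seqlens: list) -> torch.Tensor:
    # split a flat [total_tokens, dim] tensor into per-item chunks, then pad into a batch
    return pad_sequence(list(freqs.split(seqlens, dim=0)), batch_first=True, padding_value=0.0)

# usage sketch: cos and sin are treated identically
# x_freqs_cos = split_and_pad(x_freqs_cos, x_item_seqlens)
# x_freqs_sin = split_and_pad(x_freqs_sin, x_item_seqlens)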
