@@ -97,7 +97,8 @@ def __call__(
9797 hidden_states : torch .Tensor ,
9898 encoder_hidden_states : Optional [torch .Tensor ] = None ,
9999 attention_mask : Optional [torch .Tensor ] = None ,
100- freqs_cis : Optional [torch .Tensor ] = None ,
100+ freqs_cos : Optional [torch .Tensor ] = None ,
101+ freqs_sin : Optional [torch .Tensor ] = None ,
101102 ) -> torch .Tensor :
102103 query = attn .to_q (hidden_states )
103104 key = attn .to_k (hidden_states )
@@ -113,17 +114,26 @@ def __call__(
113114 if attn .norm_k is not None :
114115 key = attn .norm_k (key )
115116
116- # Apply RoPE
117- def apply_rotary_emb (x_in : torch .Tensor , freqs_cis : torch .Tensor ) -> torch .Tensor :
118- with torch .amp .autocast ("cuda" , enabled = False ):
119- x = torch .view_as_complex (x_in .float ().reshape (* x_in .shape [:- 1 ], - 1 , 2 ))
120- freqs_cis = freqs_cis .unsqueeze (2 )
121- x_out = torch .view_as_real (x * freqs_cis ).flatten (3 )
122- return x_out .type_as (x_in ) # todo
123-
124- if freqs_cis is not None :
125- query = apply_rotary_emb (query , freqs_cis )
126- key = apply_rotary_emb (key , freqs_cis )
117+ # # Apply RoPE
118+ # def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
119+ # with torch.amp.autocast("cuda", enabled=False):
120+ # x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
121+ # freqs_cis = freqs_cis.unsqueeze(2)
122+ # x_out = torch.view_as_real(x * freqs_cis).flatten(3)
123+ # return x_out.type_as(x_in) # todo
124+
def apply_rotary_emb(x_in: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding (RoPE) using separate cos/sin tables.

    Equivalent to viewing consecutive channel pairs (x0, x1) as complex
    numbers x0 + i*x1 and multiplying by cos + i*sin.

    Args:
        x_in: input of shape [batch, seq, heads, head_dim], head_dim even.
        freqs_cos: cosine table, [batch, seq, head_dim // 2].
        freqs_sin: sine table, [batch, seq, head_dim // 2].

    Returns:
        Rotated tensor with the same shape and dtype as ``x_in``.
    """
    freqs_cos = freqs_cos.unsqueeze(2)  # [batch, seq, 1, head_dim//2] — broadcast over heads
    freqs_sin = freqs_sin.unsqueeze(2)
    # Rotate in float32 for precision parity with the original complex64
    # implementation (which upcast via x_in.float()), then cast back.
    x = x_in.float().reshape(*x_in.shape[:-1], -1, 2)
    x0, x1 = x[..., 0], x[..., 1]
    out0 = x0 * freqs_cos - x1 * freqs_sin
    out1 = x0 * freqs_sin + x1 * freqs_cos
    return torch.stack([out0, out1], dim=-1).flatten(-2).type_as(x_in)
133+
134+ if freqs_cos is not None and freqs_sin is not None :
135+ query = apply_rotary_emb (query , freqs_cos , freqs_sin )
136+ key = apply_rotary_emb (key , freqs_cos , freqs_sin )
127137
128138 # Cast to correct dtype
129139 dtype = query .dtype
@@ -219,7 +229,8 @@ def forward(
219229 self ,
220230 x : torch .Tensor ,
221231 attn_mask : torch .Tensor ,
222- freqs_cis : torch .Tensor ,
232+ freqs_cos : torch .Tensor ,
233+ freqs_sin : torch .Tensor ,
223234 adaln_input : Optional [torch .Tensor ] = None ,
224235 ):
225236 if self .modulation :
@@ -232,7 +243,8 @@ def forward(
232243 attn_out = self .attention (
233244 self .attention_norm1 (x ) * scale_msa ,
234245 attention_mask = attn_mask ,
235- freqs_cis = freqs_cis ,
246+ freqs_cos = freqs_cos ,
247+ freqs_sin = freqs_sin ,
236248 )
237249 x = x + gate_msa * self .attention_norm2 (attn_out )
238250
@@ -247,7 +259,8 @@ def forward(
247259 attn_out = self .attention (
248260 self .attention_norm1 (x ),
249261 attention_mask = attn_mask ,
250- freqs_cis = freqs_cis ,
262+ freqs_cos = freqs_cos ,
263+ freqs_sin = freqs_sin ,
251264 )
252265 x = x + self .attention_norm2 (attn_out )
253266
@@ -290,39 +303,48 @@ def __init__(
290303 self .axes_dims = axes_dims
291304 self .axes_lens = axes_lens
292305 assert len (axes_dims ) == len (axes_lens ), "axes_dims and axes_lens must have the same length"
293- self .freqs_cis = None
306+ self .freqs_cos = None
307+ self .freqs_sin = None
294308
295309 @staticmethod
296310 def precompute_freqs_cis (dim : List [int ], end : List [int ], theta : float = 256.0 ):
297311 with torch .device ("cpu" ):
298- freqs_cis = []
312+ freqs_cos = []
313+ freqs_sin = []
299314 for i , (d , e ) in enumerate (zip (dim , end )):
300315 freqs = 1.0 / (theta ** (torch .arange (0 , d , 2 , dtype = torch .float64 , device = "cpu" ) / d ))
301316 timestep = torch .arange (e , device = freqs .device , dtype = torch .float64 )
302317 freqs = torch .outer (timestep , freqs ).float ()
303- freqs_cis_i = torch .polar (torch .ones_like (freqs ), freqs ).to (torch .complex64 ) # complex64
304- freqs_cis .append (freqs_cis_i )
318+ freqs_cos .append (freqs .cos ())
319+ freqs_sin .append (freqs .sin ())
320+ # freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
321+ # freqs_cis.append(freqs_cis_i)
305322
306- return freqs_cis
323+ return freqs_cos , freqs_sin
307324
def __call__(self, ids: torch.Tensor):
    """Gather per-position cos/sin RoPE tables for multi-axis position ids.

    Args:
        ids: integer tensor of shape [num_tokens, num_axes], one position id
            per axis for every token.

    Returns:
        Tuple ``(cos, sin)``, each of shape [num_tokens, sum(dims) // 2],
        formed by concatenating the per-axis table rows along the last dim.
    """
    assert ids.ndim == 2
    assert ids.shape[-1] == len(self.axes_dims)
    device = ids.device

    if self.freqs_cos is None or self.freqs_sin is None:
        # Lazily build the CPU tables on first use, then pin them to the
        # device the ids live on.
        cos_tables, sin_tables = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
        self.freqs_cos = [tbl.to(device) for tbl in cos_tables]
        self.freqs_sin = [tbl.to(device) for tbl in sin_tables]
    else:
        # Cached tables may live on a stale device; migrate if needed.
        if self.freqs_cos[0].device != device:
            self.freqs_cos = [tbl.to(device) for tbl in self.freqs_cos]
        if self.freqs_sin[0].device != device:
            self.freqs_sin = [tbl.to(device) for tbl in self.freqs_sin]

    cos_parts = []
    sin_parts = []
    for axis in range(len(self.axes_dims)):
        axis_ids = ids[:, axis]
        cos_parts.append(self.freqs_cos[axis][axis_ids])
        sin_parts.append(self.freqs_sin[axis][axis_ids])
    return torch.cat(cos_parts, dim=-1), torch.cat(sin_parts, dim=-1)
326348
327349
328350class ZImageTransformer2DModel (ModelMixin , ConfigMixin , PeftAdapterMixin , FromOriginalModelMixin ):
@@ -587,20 +609,23 @@ def forward(
587609 adaln_input = t .type_as (x )
588610 x [torch .cat (x_inner_pad_mask )] = self .x_pad_token
589611 x = list (x .split (x_item_seqlens , dim = 0 ))
590- x_freqs_cis = list (self .rope_embedder (torch .cat (x_pos_ids , dim = 0 )).split (x_item_seqlens , dim = 0 ))
612+ x_freqs_cos , x_freqs_sin = self .rope_embedder (torch .cat (x_pos_ids , dim = 0 ))
613+ x_freqs_cos = list (x_freqs_cos .split (x_item_seqlens , dim = 0 ))
614+ x_freqs_sin = list (x_freqs_sin .split (x_item_seqlens , dim = 0 ))
591615
592616 x = pad_sequence (x , batch_first = True , padding_value = 0.0 )
593- x_freqs_cis = pad_sequence (x_freqs_cis , batch_first = True , padding_value = 0.0 )
617+ x_freqs_cos = pad_sequence (x_freqs_cos , batch_first = True , padding_value = 0.0 )
618+ x_freqs_sin = pad_sequence (x_freqs_sin , batch_first = True , padding_value = 0.0 )
594619 x_attn_mask = torch .zeros ((bsz , x_max_item_seqlen ), dtype = torch .bool , device = device )
595620 for i , seq_len in enumerate (x_item_seqlens ):
596621 x_attn_mask [i , :seq_len ] = 1
597622
598623 if torch .is_grad_enabled () and self .gradient_checkpointing :
599624 for layer in self .noise_refiner :
600- x = self ._gradient_checkpointing_func (layer , x , x_attn_mask , x_freqs_cis , adaln_input )
625+ x = self ._gradient_checkpointing_func (layer , x , x_attn_mask , x_freqs_cos , x_freqs_sin , adaln_input )
601626 else :
602627 for layer in self .noise_refiner :
603- x = layer (x , x_attn_mask , x_freqs_cis , adaln_input )
628+ x = layer (x , x_attn_mask , x_freqs_cos , x_freqs_sin , adaln_input )
604629
605630 # cap embed & refine
606631 cap_item_seqlens = [len (_ ) for _ in cap_feats ]
@@ -611,47 +636,53 @@ def forward(
611636 cap_feats = self .cap_embedder (cap_feats )
612637 cap_feats [torch .cat (cap_inner_pad_mask )] = self .cap_pad_token
613638 cap_feats = list (cap_feats .split (cap_item_seqlens , dim = 0 ))
614- cap_freqs_cis = list (self .rope_embedder (torch .cat (cap_pos_ids , dim = 0 )).split (cap_item_seqlens , dim = 0 ))
639+ cap_freqs_cos , cap_freqs_sin = self .rope_embedder (torch .cat (cap_pos_ids , dim = 0 ))
640+ cap_freqs_cos = list (cap_freqs_cos .split (cap_item_seqlens , dim = 0 ))
641+ cap_freqs_sin = list (cap_freqs_sin .split (cap_item_seqlens , dim = 0 ))
615642
616643 cap_feats = pad_sequence (cap_feats , batch_first = True , padding_value = 0.0 )
617- cap_freqs_cis = pad_sequence (cap_freqs_cis , batch_first = True , padding_value = 0.0 )
644+ cap_freqs_cos = pad_sequence (cap_freqs_cos , batch_first = True , padding_value = 0.0 )
645+ cap_freqs_sin = pad_sequence (cap_freqs_sin , batch_first = True , padding_value = 0.0 )
618646 cap_attn_mask = torch .zeros ((bsz , cap_max_item_seqlen ), dtype = torch .bool , device = device )
619647 for i , seq_len in enumerate (cap_item_seqlens ):
620648 cap_attn_mask [i , :seq_len ] = 1
621649
622650 if torch .is_grad_enabled () and self .gradient_checkpointing :
623651 for layer in self .context_refiner :
624- cap_feats = self ._gradient_checkpointing_func (layer , cap_feats , cap_attn_mask , cap_freqs_cis )
652+ cap_feats = self ._gradient_checkpointing_func (layer , cap_feats , cap_attn_mask , cap_freqs_cos , cap_freqs_sin )
625653 else :
626654 for layer in self .context_refiner :
627- cap_feats = layer (cap_feats , cap_attn_mask , cap_freqs_cis )
655+ cap_feats = layer (cap_feats , cap_attn_mask , cap_freqs_cos , cap_freqs_sin )
628656
629657 # unified
630658 unified = []
631- unified_freqs_cis = []
659+ unified_freqs_cos = []
660+ unified_freqs_sin = []
632661 for i in range (bsz ):
633662 x_len = x_item_seqlens [i ]
634663 cap_len = cap_item_seqlens [i ]
635664 unified .append (torch .cat ([x [i ][:x_len ], cap_feats [i ][:cap_len ]]))
636- unified_freqs_cis .append (torch .cat ([x_freqs_cis [i ][:x_len ], cap_freqs_cis [i ][:cap_len ]]))
665+ unified_freqs_cos .append (torch .cat ([x_freqs_cos [i ][:x_len ], cap_freqs_cos [i ][:cap_len ]]))
666+ unified_freqs_sin .append (torch .cat ([x_freqs_sin [i ][:x_len ], cap_freqs_sin [i ][:cap_len ]]))
637667 unified_item_seqlens = [a + b for a , b in zip (cap_item_seqlens , x_item_seqlens )]
638668 assert unified_item_seqlens == [len (_ ) for _ in unified ]
639669 unified_max_item_seqlen = max (unified_item_seqlens )
640670
641671 unified = pad_sequence (unified , batch_first = True , padding_value = 0.0 )
642- unified_freqs_cis = pad_sequence (unified_freqs_cis , batch_first = True , padding_value = 0.0 )
672+ unified_freqs_cos = pad_sequence (unified_freqs_cos , batch_first = True , padding_value = 0.0 )
673+ unified_freqs_sin = pad_sequence (unified_freqs_sin , batch_first = True , padding_value = 0.0 )
643674 unified_attn_mask = torch .zeros ((bsz , unified_max_item_seqlen ), dtype = torch .bool , device = device )
644675 for i , seq_len in enumerate (unified_item_seqlens ):
645676 unified_attn_mask [i , :seq_len ] = 1
646677
647678 if torch .is_grad_enabled () and self .gradient_checkpointing :
648679 for layer in self .layers :
649680 unified = self ._gradient_checkpointing_func (
650- layer , unified , unified_attn_mask , unified_freqs_cis , adaln_input
681+ layer , unified , unified_attn_mask , unified_freqs_cos , unified_freqs_sin , adaln_input
651682 )
652683 else :
653684 for layer in self .layers :
654- unified = layer (unified , unified_attn_mask , unified_freqs_cis , adaln_input )
685+ unified = layer (unified , unified_attn_mask , unified_freqs_cos , unified_freqs_sin , adaln_input )
655686
656687 unified = self .all_final_layer [f"{ patch_size } -{ f_patch_size } " ](unified , adaln_input )
657688 unified = list (unified .unbind (dim = 0 ))
0 commit comments