 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from rotary_embedding_torch import RotaryEmbedding
 
 
 class GEGLU(nn.Module):
@@ -25,7 +24,7 @@ def FeedForward(dim, mult=4, dropout=0.0):
 
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
+    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -34,18 +33,13 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)
         self.dropout = nn.Dropout(dropout)
-        self.rotary = rotary
         dim = np.int64(dim / 2)
-        self.rotary_embedding = RotaryEmbedding(dim=dim)
 
     def forward(self, x):
         h = self.heads
         x = self.norm(x)
         q, k, v = self.to_qkv(x).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))  # type: ignore
-        if self.rotary:
-            q = self.rotary_embedding.rotate_queries_or_keys(q)
-            k = self.rotary_embedding.rotate_queries_or_keys(k)
         q = q * self.scale
 
         sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
@@ -61,7 +55,7 @@ def forward(self, x):
 
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False):
+    def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
         super().__init__()
         self.layers = nn.ModuleList([])
 
@@ -74,7 +68,6 @@ def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary
                         heads=heads,
                         dim_head=dim_head,
                         dropout=attn_dropout,
-                        rotary=rotary,
                     ),
                     FeedForward(dim, dropout=ff_dropout),
                 ]
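
For reference, here is a minimal self-contained sketch of the `Attention` block as it stands after this change. The definitions of `self.norm` and `self.scale`, and the tail of `forward` after the `einsum`, fall outside the hunks above, so the `nn.LayerNorm(dim)`, `dim_head ** -0.5`, and standard softmax-attention tail used here are assumptions, not taken from this diff; the now-unused `dim = np.int64(dim / 2)` reassignment is omitted for brevity.

```python
import torch
import torch.nn as nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5  # assumed; defined outside the hunks above
        self.norm = nn.LayerNorm(dim)  # assumed; defined outside the hunks above
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.heads
        x = self.norm(x)
        # Project to queries/keys/values and split the head dimension out.
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
        q = q * self.scale
        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
        # Assumed standard attention tail (elided by the @@ -61,7 +55,7 hunk).
        attn = sim.softmax(dim=-1)
        attn = self.dropout(attn)
        out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)

x = torch.randn(2, 16, 128)         # (batch, tokens, dim)
print(Attention(dim=128)(x).shape)  # torch.Size([2, 16, 128])
```

With `rotary=False` having been the default, removing the flag leaves the forward pass behaviorally identical to the previous default path: q/k go straight from the head split to scaling, with no positional rotation applied.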