
Commit 161f6de

remove dependence on rotary embeddings
1 parent febf165 commit 161f6de

1 file changed: 2 additions & 9 deletions

mambular/arch_utils/layer_utils/attention_utils.py
@@ -5,7 +5,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from rotary_embedding_torch import RotaryEmbedding


 class GEGLU(nn.Module):
@@ -25,7 +24,7 @@ def FeedForward(dim, mult=4, dropout=0.0):


 class Attention(nn.Module):
-    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
+    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -34,18 +33,13 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)
         self.dropout = nn.Dropout(dropout)
-        self.rotary = rotary
         dim = np.int64(dim / 2)
-        self.rotary_embedding = RotaryEmbedding(dim=dim)

     def forward(self, x):
         h = self.heads
         x = self.norm(x)
         q, k, v = self.to_qkv(x).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))  # type: ignore
-        if self.rotary:
-            q = self.rotary_embedding.rotate_queries_or_keys(q)
-            k = self.rotary_embedding.rotate_queries_or_keys(k)
         q = q * self.scale

         sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
@@ -61,7 +55,7 @@


 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False):
+    def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
         super().__init__()
         self.layers = nn.ModuleList([])

@@ -74,7 +68,6 @@ def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary
                             heads=heads,
                             dim_head=dim_head,
                             dropout=attn_dropout,
-                            rotary=rotary,
                         ),
                         FeedForward(dim, dropout=ff_dropout),
                     ]
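
For readers tracing the change, below is a minimal, self-contained sketch of the Attention block as it stands after this commit. The diff does not show the lines defining self.norm and self.scale, nor the steps after the sim einsum, so the LayerNorm, the dim_head ** -0.5 scaling, and the softmax/output projection below are assumptions inferred from the surrounding code, not the file's exact contents. The now-dead dim = np.int64(dim / 2) line (formerly the rotary dimension) is omitted here, although the diff keeps it in the file.

    # Sketch only: norm/scale definitions and the post-softmax steps are
    # assumed, since the diff does not show them.
    import torch
    import torch.nn as nn
    from einops import rearrange

    class Attention(nn.Module):
        def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
            super().__init__()
            inner_dim = dim_head * heads
            self.heads = heads
            self.scale = dim_head ** -0.5   # assumed; defined in lines the diff omits
            self.norm = nn.LayerNorm(dim)   # assumed; defined in lines the diff omits
            self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
            self.to_out = nn.Linear(inner_dim, dim, bias=False)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x):
            h = self.heads
            x = self.norm(x)
            # Project once, then split into per-head queries, keys, values.
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
            q = q * self.scale
            sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
            attn = self.dropout(sim.softmax(dim=-1))  # assumed softmax + dropout
            out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
            out = rearrange(out, "b h n d -> b n (h d)")
            return self.to_out(out)

    # Quick shape check:
    x = torch.randn(2, 16, 128)
    print(Attention(dim=128)(x).shape)  # torch.Size([2, 16, 128])

With the rotary path removed, the attention itself carries no positional information; any position signal now has to come from whatever embeddings the caller feeds into the Transformer.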
