
Commit 1638ccd

refactor RoPE (#276)
1 parent 47ad6f6 commit 1638ccd

1 file changed: 53 additions & 299 deletions

@@ -1,323 +1,77 @@
 from __future__ import annotations
-from math import pi, log
-
 import torch
-from torch.amp import autocast
-from torch.nn import Module, ModuleList
-from torch import nn, einsum, broadcast_tensors, Tensor
-
+from torch import nn, einsum, Tensor
+from torch.nn import Module
 from einops import rearrange, repeat

-from typing import Literal
-
-# helper functions
-
-def exists(val):
-    return val is not None
-
-def default(val, d):
-    return val if exists(val) else d
-
-# broadcat, as tortoise-tts was using it
-
-def broadcat(tensors, dim = -1):
-    broadcasted_tensors = broadcast_tensors(*tensors)
-    return torch.cat(broadcasted_tensors, dim = dim)
-
-def slice_at_dim(t, dim_slice: slice, *, dim):
-    dim += (t.ndim if dim < 0 else 0)
-    colons = [slice(None)] * t.ndim
-    colons[dim] = dim_slice
-    return t[tuple(colons)]
-
-# rotary embedding helper functions
-
-def rotate_half(x):
-    x = rearrange(x, '... (d r) -> ... d r', r = 2)
-    x1, x2 = x.unbind(dim = -1)
-    x = torch.stack((-x2, x1), dim = -1)
-    return rearrange(x, '... d r -> ... (d r)')
-
-@autocast('cuda', enabled = False)
-def apply_rotary_emb(
-    freqs,
-    t,
-    start_index = 0,
-    scale = 1.,
-    seq_dim = -2,
-    freqs_seq_dim = None
-):
-    dtype = t.dtype

-    if not exists(freqs_seq_dim):
-        if freqs.ndim == 2 or t.ndim == 3:
-            freqs_seq_dim = 0
+def rotate_half(x: Tensor, interleaved=True) -> Tensor:
+    if not interleaved:
+        # x_half1, x_half2 = x.chunk(2, dim=-1)
+        # Using torch.split instead of chunk for ONNX export compatibility.
+        x1, x2 = torch.split(x, x.size(-1) // 2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x = rearrange(x, '... (d r) -> ... d r', r=2)
+        x1, x2 = x.unbind(dim=-1)
+        x = torch.stack((-x2, x1), dim=-1)
+        return rearrange(x, '... d r -> ... (d r)')

-    if t.ndim == 3 or exists(freqs_seq_dim):
-        seq_len = t.shape[seq_dim]
-        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim)

+def apply_rotary_emb(freqs: Tensor, t: Tensor, interleaved=True) -> Tensor:
     rot_dim = freqs.shape[-1]
-    end_index = start_index + rot_dim
+    t_to_rotate = t[..., :rot_dim]
+    t_pass_through = t[..., rot_dim:]
+
+    t_rotated = (t_to_rotate * freqs.cos()) + (rotate_half(t_to_rotate, interleaved) * freqs.sin())
+
+    return torch.cat((t_rotated, t_pass_through), dim=-1)

-    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
-
-    # Split t into three parts: left, middle (to be transformed), and right
-    t_left = t[..., :start_index]
-    t_middle = t[..., start_index:end_index]
-    t_right = t[..., end_index:]
-
-    # Apply rotary embeddings without modifying t in place
-    t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
-
-    out = torch.cat((t_left, t_transformed, t_right), dim=-1)
-
-    return out.type(dtype)
-
-# learned rotation helpers
-
-def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
-    if exists(freq_ranges):
-        rotations = einsum('..., f -> ... f', rotations, freq_ranges)
-        rotations = rearrange(rotations, '... r f -> ... (r f)')
-
-    rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
-    return apply_rotary_emb(rotations, t, start_index = start_index)
-
-# classes

 class RotaryEmbedding(Module):
     def __init__(
         self,
         dim,
-        custom_freqs: Tensor | None = None,
-        freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang',
-        theta = 10000,
-        max_freq = 10,
-        num_freqs = 1,
-        learned_freq = False,
-        use_xpos = False,
-        xpos_scale_base = 512,
-        interpolate_factor = 1.,
-        theta_rescale_factor = 1.,
-        seq_before_head_dim = False,
-        cache_if_possible = True,
-        cache_max_seq_len = 8192
+        theta=10000,
+        precompute_len=8192,
+        cache_max_seq_len=8192,
+        interleaved: bool = True
     ):
         super().__init__()
-        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
-        # has some connection to NTK literature
-        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
-
-        theta *= theta_rescale_factor ** (dim / (dim - 2))
-
-        self.freqs_for = freqs_for
+        self.interleaved = interleaved

-        if exists(custom_freqs):
-            freqs = custom_freqs
-        elif freqs_for == 'lang':
-            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
-        elif freqs_for == 'pixel':
-            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
-        elif freqs_for == 'constant':
-            freqs = torch.ones(num_freqs).float()
+        inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)

-        self.cache_if_possible = cache_if_possible
-        self.cache_max_seq_len = cache_max_seq_len
+        self._cache_max_seq_len = max(precompute_len, cache_max_seq_len)
+        self._precomputed_len = precompute_len

-        self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False)
+        self.register_buffer('cached_freqs', None, persistent=True)
         self.cached_freqs_seq_len = 0
+
+        if self._precomputed_len > 0:
+            self._precompute_cache(self._precomputed_len)

-        self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
-
-        self.learned_freq = learned_freq
-
-        # dummy for device
-
-        self.register_buffer('dummy', torch.tensor(0), persistent = False)
-
-        # default sequence dimension
-
-        self.seq_before_head_dim = seq_before_head_dim
-        self.default_seq_dim = -3 if seq_before_head_dim else -2
-
-        # interpolation factors
-
-        assert interpolate_factor >= 1.
-        self.interpolate_factor = interpolate_factor
-
-        # xpos
-
-        self.use_xpos = use_xpos
-
-        if not use_xpos:
-            return
-
-        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
-        self.scale_base = xpos_scale_base
-
-        self.register_buffer('scale', scale, persistent = False)
-        self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False)
-        self.cached_scales_seq_len = 0
-
-        # add apply_rotary_emb as static method
-
-        self.apply_rotary_emb = staticmethod(apply_rotary_emb)
-
-    @property
-    def device(self):
-        return self.dummy.device
-
-    def get_seq_pos(self, seq_len, device, dtype, offset = 0):
-        return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
-
-    def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, scale = None):
-        seq_dim = default(seq_dim, self.default_seq_dim)
-
-        assert not self.use_xpos or exists(scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
-
-        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
-
-        seq = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset)
-
-        freqs = self.forward(seq, seq_len = seq_len, offset = offset)
-
-        if seq_dim == -3:
-            freqs = rearrange(freqs, 'n d -> n 1 d')
-
-        return apply_rotary_emb(freqs, t, scale = default(scale, 1.), seq_dim = seq_dim)
-
-    def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
-        dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)
-
-        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
-        assert q_len <= k_len
-
-        q_scale = k_scale = 1.
-
-        if self.use_xpos:
-            seq = self.get_seq_pos(k_len, dtype = dtype, device = device)
-
-            q_scale = self.get_scale(seq[-q_len:]).type(dtype)
-            k_scale = self.get_scale(seq).type(dtype)
-
-        rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, scale = q_scale, offset = k_len - q_len + offset)
-        rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, scale = k_scale ** -1)
-
-        rotated_q = rotated_q.type(q.dtype)
-        rotated_k = rotated_k.type(k.dtype)
-
-        return rotated_q, rotated_k
-
-    def rotate_queries_and_keys(self, q, k, seq_dim = None):
-        seq_dim = default(seq_dim, self.default_seq_dim)
-
-        assert self.use_xpos
-        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
-
-        seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
-
-        freqs = self.forward(seq, seq_len = seq_len)
-        scale = self.get_scale(seq, seq_len = seq_len).to(dtype)
-
-        if seq_dim == -3:
-            freqs = rearrange(freqs, 'n d -> n 1 d')
-            scale = rearrange(scale, 'n d -> n 1 d')
-
-        rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
-        rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)
-
-        rotated_q = rotated_q.type(q.dtype)
-        rotated_k = rotated_k.type(k.dtype)
-
-        return rotated_q, rotated_k
-
-    def get_scale(
-        self,
-        t: Tensor,
-        seq_len: int | None = None,
-        offset = 0
-    ):
-        assert self.use_xpos
-
-        should_cache = (
-            self.cache_if_possible and
-            exists(seq_len) and
-            (offset + seq_len) <= self.cache_max_seq_len
-        )
-
-        if (
-            should_cache and \
-            exists(self.cached_scales) and \
-            (seq_len + offset) <= self.cached_scales_seq_len
-        ):
-            return self.cached_scales[offset:(offset + seq_len)]
-
-        scale = 1.
-        if self.use_xpos:
-            power = (t - len(t) // 2) / self.scale_base
-            scale = self.scale ** rearrange(power, 'n -> n 1')
-            scale = repeat(scale, 'n d -> n (d r)', r = 2)
-
-        if should_cache and offset == 0:
-            self.cached_scales[:seq_len] = scale.detach()
-            self.cached_scales_seq_len = seq_len
-
-        return scale
-
-    def get_axial_freqs(self, *dims):
-        Colon = slice(None)
-        all_freqs = []
-
-        for ind, dim in enumerate(dims):
-            if self.freqs_for == 'pixel':
-                pos = torch.linspace(-1, 1, steps = dim, device = self.device)
-            else:
-                pos = torch.arange(dim, device = self.device)
-
-            freqs = self.forward(pos, seq_len = dim)
-
-            all_axis = [None] * len(dims)
-            all_axis[ind] = Colon
-
-            new_axis_slice = (Ellipsis, *all_axis, Colon)
-            all_freqs.append(freqs[new_axis_slice])
-
-        all_freqs = broadcast_tensors(*all_freqs)
-        return torch.cat(all_freqs, dim = -1)
-
-    @autocast('cuda', enabled = False)
-    def forward(
-        self,
-        t: Tensor,
-        seq_len: int | None = None,
-        offset = 0
-    ):
-        should_cache = (
-            self.cache_if_possible and
-            not self.learned_freq and
-            exists(seq_len) and
-            self.freqs_for != 'pixel' and
-            (offset + seq_len) <= self.cache_max_seq_len
-        )
-
-        if (
-            should_cache and \
-            exists(self.cached_freqs) and \
-            (offset + seq_len) <= self.cached_freqs_seq_len
-        ):
-            freqs = self.cached_freqs[offset:(offset + seq_len)].detach()
-            # Fix issue about 'find_unused_parameters' when DDP training.(#244)
-            freqs = freqs + 0. * self.freqs.sum()
-            return freqs
-
-        freqs = self.freqs
+    def _precompute_cache(self, seq_len: int):
+        seq = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = einsum('i, j -> i j', seq, self.inv_freq)
+
+        if self.interleaved:
+            freqs = repeat(freqs, '... n -> ... (n r)', r=2)
+        else:
+            freqs = torch.cat((freqs, freqs), dim=-1)

-        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
-        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
+        self.cached_freqs = freqs
+        self.cached_freqs_seq_len = seq_len

-        if should_cache and offset == 0:
-            self.cached_freqs[:seq_len] = freqs.detach()
-            self.cached_freqs_seq_len = seq_len
+    def forward(self, t: Tensor, seq_len: int) -> Tensor:
+        if self.cached_freqs is None or seq_len > self.cached_freqs_seq_len:
+            self._precompute_cache(seq_len)
+
+        return self.cached_freqs[0: seq_len].detach()

-        return freqs
+    def rotate_queries_or_keys(self, t: Tensor) -> Tensor:
+        device, dtype, seq_len = t.device, t.dtype, t.shape[-2]
+        freqs = self.forward(t, seq_len=seq_len)
+
+        return apply_rotary_emb(freqs.to(device=device, dtype=dtype), t, self.interleaved)
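For reference, here is a minimal usage sketch of the refactored module; it is not part of the commit. It assumes queries and keys are laid out as (batch, heads, seq, dim_head), since the new rotate_queries_or_keys reads the sequence length from t.shape[-2], and that dim is the per-head feature size to rotate; the shapes and tensor names below are illustrative.

# Illustrative usage sketch, not part of this commit.
import torch

rope = RotaryEmbedding(dim=64, interleaved=True)   # dim = per-head feature size to rotate

q = torch.randn(1, 8, 128, 64)                     # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 128, 64)

# frequencies are served from the cache precomputed in __init__ (precompute_len=8192 by default)
q = rope.rotate_queries_or_keys(q)
k = rope.rotate_queries_or_keys(k)

# partial rotation: when freqs covers fewer features than dim_head,
# the trailing features pass through unchanged (see apply_rotary_emb above)
freqs = rope(q, seq_len=128)                       # cached (seq, dim) frequency tensor
out = apply_rotary_emb(freqs[..., :32], q, interleaved=True)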
