1 | 1 | from __future__ import annotations |
2 | | -from math import pi, log |
3 | | - |
4 | 2 | import torch |
5 | | -from torch.amp import autocast |
6 | | -from torch.nn import Module, ModuleList |
7 | | -from torch import nn, einsum, broadcast_tensors, Tensor |
8 | | - |
| 3 | +from torch import nn, einsum, Tensor |
| 4 | +from torch.nn import Module |
9 | 5 | from einops import rearrange, repeat |
10 | 6 |
11 | | -from typing import Literal |
12 | | - |
13 | | -# helper functions |
14 | | - |
15 | | -def exists(val): |
16 | | - return val is not None |
17 | | - |
18 | | -def default(val, d): |
19 | | - return val if exists(val) else d |
20 | | - |
21 | | -# broadcat, as tortoise-tts was using it |
22 | | - |
23 | | -def broadcat(tensors, dim = -1): |
24 | | - broadcasted_tensors = broadcast_tensors(*tensors) |
25 | | - return torch.cat(broadcasted_tensors, dim = dim) |
26 | | - |
27 | | -def slice_at_dim(t, dim_slice: slice, *, dim): |
28 | | - dim += (t.ndim if dim < 0 else 0) |
29 | | - colons = [slice(None)] * t.ndim |
30 | | - colons[dim] = dim_slice |
31 | | - return t[tuple(colons)] |
32 | | - |
33 | | -# rotary embedding helper functions |
34 | | - |
35 | | -def rotate_half(x): |
36 | | - x = rearrange(x, '... (d r) -> ... d r', r = 2) |
37 | | - x1, x2 = x.unbind(dim = -1) |
38 | | - x = torch.stack((-x2, x1), dim = -1) |
39 | | - return rearrange(x, '... d r -> ... (d r)') |
40 | | - |
41 | | -@autocast('cuda', enabled = False) |
42 | | -def apply_rotary_emb( |
43 | | - freqs, |
44 | | - t, |
45 | | - start_index = 0, |
46 | | - scale = 1., |
47 | | - seq_dim = -2, |
48 | | - freqs_seq_dim = None |
49 | | -): |
50 | | - dtype = t.dtype |
51 | 7 |
52 | | - if not exists(freqs_seq_dim): |
53 | | - if freqs.ndim == 2 or t.ndim == 3: |
54 | | - freqs_seq_dim = 0 |
| 8 | +def rotate_half(x: Tensor, interleaved=True) -> Tensor: |
| 9 | + if not interleaved: |
|  | 10 | +        # x1, x2 = x.chunk(2, dim=-1) |
| 11 | + # Using torch.split instead of chunk for ONNX export compatibility. |
| 12 | + x1, x2 = torch.split(x, x.size(-1) // 2, dim=-1) |
| 13 | + return torch.cat((-x2, x1), dim=-1) |
| 14 | + else: |
| 15 | + x = rearrange(x, '... (d r) -> ... d r', r=2) |
| 16 | + x1, x2 = x.unbind(dim=-1) |
| 17 | + x = torch.stack((-x2, x1), dim=-1) |
| 18 | + return rearrange(x, '... d r -> ... (d r)') |
55 | 19 |
56 | | - if t.ndim == 3 or exists(freqs_seq_dim): |
57 | | - seq_len = t.shape[seq_dim] |
58 | | - freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim) |
59 | 20 |
| 21 | +def apply_rotary_emb(freqs: Tensor, t: Tensor, interleaved=True) -> Tensor: |
60 | 22 | rot_dim = freqs.shape[-1] |
61 | | - end_index = start_index + rot_dim |
| 23 | + t_to_rotate = t[..., :rot_dim] |
| 24 | + t_pass_through = t[..., rot_dim:] |
| 25 | + |
| 26 | + t_rotated = (t_to_rotate * freqs.cos()) + (rotate_half(t_to_rotate, interleaved) * freqs.sin()) |
| 27 | + |
| 28 | + return torch.cat((t_rotated, t_pass_through), dim=-1) |
62 | 29 |
63 | | - assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' |
64 | | - |
65 | | - # Split t into three parts: left, middle (to be transformed), and right |
66 | | - t_left = t[..., :start_index] |
67 | | - t_middle = t[..., start_index:end_index] |
68 | | - t_right = t[..., end_index:] |
69 | | - |
70 | | - # Apply rotary embeddings without modifying t in place |
71 | | - t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale) |
72 | | - |
73 | | - out = torch.cat((t_left, t_transformed, t_right), dim=-1) |
74 | | - |
75 | | - return out.type(dtype) |
76 | | - |
77 | | -# learned rotation helpers |
78 | | - |
79 | | -def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None): |
80 | | - if exists(freq_ranges): |
81 | | - rotations = einsum('..., f -> ... f', rotations, freq_ranges) |
82 | | - rotations = rearrange(rotations, '... r f -> ... (r f)') |
83 | | - |
84 | | - rotations = repeat(rotations, '... n -> ... (n r)', r = 2) |
85 | | - return apply_rotary_emb(rotations, t, start_index = start_index) |
86 | | - |
87 | | -# classes |
88 | 30 |
89 | 31 | class RotaryEmbedding(Module): |
90 | 32 | def __init__( |
91 | 33 | self, |
92 | 34 | dim, |
93 | | - custom_freqs: Tensor | None = None, |
94 | | - freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang', |
95 | | - theta = 10000, |
96 | | - max_freq = 10, |
97 | | - num_freqs = 1, |
98 | | - learned_freq = False, |
99 | | - use_xpos = False, |
100 | | - xpos_scale_base = 512, |
101 | | - interpolate_factor = 1., |
102 | | - theta_rescale_factor = 1., |
103 | | - seq_before_head_dim = False, |
104 | | - cache_if_possible = True, |
105 | | - cache_max_seq_len = 8192 |
| 35 | + theta=10000, |
| 36 | + precompute_len=8192, |
| 37 | + cache_max_seq_len=8192, |
| 38 | + interleaved: bool = True |
106 | 39 | ): |
107 | 40 | super().__init__() |
108 | | - # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning |
109 | | - # has some connection to NTK literature |
110 | | - # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ |
111 | | - |
112 | | - theta *= theta_rescale_factor ** (dim / (dim - 2)) |
113 | | - |
114 | | - self.freqs_for = freqs_for |
| 41 | + self.interleaved = interleaved |
115 | 42 |
116 | | - if exists(custom_freqs): |
117 | | - freqs = custom_freqs |
118 | | - elif freqs_for == 'lang': |
119 | | - freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) |
120 | | - elif freqs_for == 'pixel': |
121 | | - freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi |
122 | | - elif freqs_for == 'constant': |
123 | | - freqs = torch.ones(num_freqs).float() |
| 43 | + inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim)) |
| 44 | + self.register_buffer('inv_freq', inv_freq) |
124 | 45 |
125 | | - self.cache_if_possible = cache_if_possible |
126 | | - self.cache_max_seq_len = cache_max_seq_len |
| 46 | + self._cache_max_seq_len = max(precompute_len, cache_max_seq_len) |
| 47 | + self._precomputed_len = precompute_len |
127 | 48 |
128 | | - self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) |
| 49 | + self.register_buffer('cached_freqs', None, persistent=True) |
129 | 50 | self.cached_freqs_seq_len = 0 |
| 51 | + |
| 52 | + if self._precomputed_len > 0: |
| 53 | + self._precompute_cache(self._precomputed_len) |
130 | 54 |
131 | | - self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) |
132 | | - |
133 | | - self.learned_freq = learned_freq |
134 | | - |
135 | | - # dummy for device |
136 | | - |
137 | | - self.register_buffer('dummy', torch.tensor(0), persistent = False) |
138 | | - |
139 | | - # default sequence dimension |
140 | | - |
141 | | - self.seq_before_head_dim = seq_before_head_dim |
142 | | - self.default_seq_dim = -3 if seq_before_head_dim else -2 |
143 | | - |
144 | | - # interpolation factors |
145 | | - |
146 | | - assert interpolate_factor >= 1. |
147 | | - self.interpolate_factor = interpolate_factor |
148 | | - |
149 | | - # xpos |
150 | | - |
151 | | - self.use_xpos = use_xpos |
152 | | - |
153 | | - if not use_xpos: |
154 | | - return |
155 | | - |
156 | | - scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) |
157 | | - self.scale_base = xpos_scale_base |
158 | | - |
159 | | - self.register_buffer('scale', scale, persistent = False) |
160 | | - self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False) |
161 | | - self.cached_scales_seq_len = 0 |
162 | | - |
163 | | - # add apply_rotary_emb as static method |
164 | | - |
165 | | - self.apply_rotary_emb = staticmethod(apply_rotary_emb) |
166 | | - |
167 | | - @property |
168 | | - def device(self): |
169 | | - return self.dummy.device |
170 | | - |
171 | | - def get_seq_pos(self, seq_len, device, dtype, offset = 0): |
172 | | - return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor |
173 | | - |
174 | | - def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, scale = None): |
175 | | - seq_dim = default(seq_dim, self.default_seq_dim) |
176 | | - |
177 | | - assert not self.use_xpos or exists(scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' |
178 | | - |
179 | | - device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] |
180 | | - |
181 | | - seq = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset) |
182 | | - |
183 | | - freqs = self.forward(seq, seq_len = seq_len, offset = offset) |
184 | | - |
185 | | - if seq_dim == -3: |
186 | | - freqs = rearrange(freqs, 'n d -> n 1 d') |
187 | | - |
188 | | - return apply_rotary_emb(freqs, t, scale = default(scale, 1.), seq_dim = seq_dim) |
189 | | - |
190 | | - def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0): |
191 | | - dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim) |
192 | | - |
193 | | - q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] |
194 | | - assert q_len <= k_len |
195 | | - |
196 | | - q_scale = k_scale = 1. |
197 | | - |
198 | | - if self.use_xpos: |
199 | | - seq = self.get_seq_pos(k_len, dtype = dtype, device = device) |
200 | | - |
201 | | - q_scale = self.get_scale(seq[-q_len:]).type(dtype) |
202 | | - k_scale = self.get_scale(seq).type(dtype) |
203 | | - |
204 | | - rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, scale = q_scale, offset = k_len - q_len + offset) |
205 | | - rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, scale = k_scale ** -1) |
206 | | - |
207 | | - rotated_q = rotated_q.type(q.dtype) |
208 | | - rotated_k = rotated_k.type(k.dtype) |
209 | | - |
210 | | - return rotated_q, rotated_k |
211 | | - |
212 | | - def rotate_queries_and_keys(self, q, k, seq_dim = None): |
213 | | - seq_dim = default(seq_dim, self.default_seq_dim) |
214 | | - |
215 | | - assert self.use_xpos |
216 | | - device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] |
217 | | - |
218 | | - seq = self.get_seq_pos(seq_len, dtype = dtype, device = device) |
219 | | - |
220 | | - freqs = self.forward(seq, seq_len = seq_len) |
221 | | - scale = self.get_scale(seq, seq_len = seq_len).to(dtype) |
222 | | - |
223 | | - if seq_dim == -3: |
224 | | - freqs = rearrange(freqs, 'n d -> n 1 d') |
225 | | - scale = rearrange(scale, 'n d -> n 1 d') |
226 | | - |
227 | | - rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim) |
228 | | - rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim) |
229 | | - |
230 | | - rotated_q = rotated_q.type(q.dtype) |
231 | | - rotated_k = rotated_k.type(k.dtype) |
232 | | - |
233 | | - return rotated_q, rotated_k |
234 | | - |
235 | | - def get_scale( |
236 | | - self, |
237 | | - t: Tensor, |
238 | | - seq_len: int | None = None, |
239 | | - offset = 0 |
240 | | - ): |
241 | | - assert self.use_xpos |
242 | | - |
243 | | - should_cache = ( |
244 | | - self.cache_if_possible and |
245 | | - exists(seq_len) and |
246 | | - (offset + seq_len) <= self.cache_max_seq_len |
247 | | - ) |
248 | | - |
249 | | - if ( |
250 | | - should_cache and \ |
251 | | - exists(self.cached_scales) and \ |
252 | | - (seq_len + offset) <= self.cached_scales_seq_len |
253 | | - ): |
254 | | - return self.cached_scales[offset:(offset + seq_len)] |
255 | | - |
256 | | - scale = 1. |
257 | | - if self.use_xpos: |
258 | | - power = (t - len(t) // 2) / self.scale_base |
259 | | - scale = self.scale ** rearrange(power, 'n -> n 1') |
260 | | - scale = repeat(scale, 'n d -> n (d r)', r = 2) |
261 | | - |
262 | | - if should_cache and offset == 0: |
263 | | - self.cached_scales[:seq_len] = scale.detach() |
264 | | - self.cached_scales_seq_len = seq_len |
265 | | - |
266 | | - return scale |
267 | | - |
268 | | - def get_axial_freqs(self, *dims): |
269 | | - Colon = slice(None) |
270 | | - all_freqs = [] |
271 | | - |
272 | | - for ind, dim in enumerate(dims): |
273 | | - if self.freqs_for == 'pixel': |
274 | | - pos = torch.linspace(-1, 1, steps = dim, device = self.device) |
275 | | - else: |
276 | | - pos = torch.arange(dim, device = self.device) |
277 | | - |
278 | | - freqs = self.forward(pos, seq_len = dim) |
279 | | - |
280 | | - all_axis = [None] * len(dims) |
281 | | - all_axis[ind] = Colon |
282 | | - |
283 | | - new_axis_slice = (Ellipsis, *all_axis, Colon) |
284 | | - all_freqs.append(freqs[new_axis_slice]) |
285 | | - |
286 | | - all_freqs = broadcast_tensors(*all_freqs) |
287 | | - return torch.cat(all_freqs, dim = -1) |
288 | | - |
289 | | - @autocast('cuda', enabled = False) |
290 | | - def forward( |
291 | | - self, |
292 | | - t: Tensor, |
293 | | - seq_len: int | None = None, |
294 | | - offset = 0 |
295 | | - ): |
296 | | - should_cache = ( |
297 | | - self.cache_if_possible and |
298 | | - not self.learned_freq and |
299 | | - exists(seq_len) and |
300 | | - self.freqs_for != 'pixel' and |
301 | | - (offset + seq_len) <= self.cache_max_seq_len |
302 | | - ) |
303 | | - |
304 | | - if ( |
305 | | - should_cache and \ |
306 | | - exists(self.cached_freqs) and \ |
307 | | - (offset + seq_len) <= self.cached_freqs_seq_len |
308 | | - ): |
309 | | - freqs = self.cached_freqs[offset:(offset + seq_len)].detach() |
310 | | - # Fix issue about 'find_unused_parameters' when DDP training.(#244) |
311 | | - freqs = freqs + 0. * self.freqs.sum() |
312 | | - return freqs |
313 | | - |
314 | | - freqs = self.freqs |
| 55 | + def _precompute_cache(self, seq_len: int): |
| 56 | + seq = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) |
| 57 | + freqs = einsum('i, j -> i j', seq, self.inv_freq) |
| 58 | + |
| 59 | + if self.interleaved: |
| 60 | + freqs = repeat(freqs, '... n -> ... (n r)', r=2) |
| 61 | + else: |
| 62 | + freqs = torch.cat((freqs, freqs), dim=-1) |
315 | 63 |
316 | | - freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) |
317 | | - freqs = repeat(freqs, '... n -> ... (n r)', r = 2) |
| 64 | + self.cached_freqs = freqs |
| 65 | + self.cached_freqs_seq_len = seq_len |
318 | 66 |
319 | | - if should_cache and offset == 0: |
320 | | - self.cached_freqs[:seq_len] = freqs.detach() |
321 | | - self.cached_freqs_seq_len = seq_len |
| 67 | + def forward(self, t: Tensor, seq_len: int) -> Tensor: |
| 68 | + if self.cached_freqs is None or seq_len > self.cached_freqs_seq_len: |
| 69 | + self._precompute_cache(seq_len) |
| 70 | + |
| 71 | + return self.cached_freqs[0: seq_len].detach() |
322 | 72 |
323 | | - return freqs |
| 73 | + def rotate_queries_or_keys(self, t: Tensor) -> Tensor: |
| 74 | + device, dtype, seq_len = t.device, t.dtype, t.shape[-2] |
| 75 | + freqs = self.forward(t, seq_len=seq_len) |
| 76 | + |
| 77 | + return apply_rotary_emb(freqs.to(device=device, dtype=dtype), t, self.interleaved) |
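Below is a minimal usage sketch of the refactored API in this diff. The class and functions (RotaryEmbedding, rotate_queries_or_keys, apply_rotary_emb) are taken from the new code above; the tensor shapes, variable names, and the attention-style (batch, heads, seq_len, dim_head) layout are illustrative assumptions, not part of the change.

import torch

# Hypothetical attention-style shapes for illustration: (batch, heads, seq_len, dim_head).
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)

# Default: interleaved (adjacent-pair) rotation; frequencies are precomputed
# up to precompute_len=8192 positions at construction time.
rope = RotaryEmbedding(dim=64)
q_rot = rope.rotate_queries_or_keys(q)
k_rot = rope.rotate_queries_or_keys(k)

# Half-split layout instead of interleaved pairs (the branch added for
# ONNX-friendly export in rotate_half above).
rope_block = RotaryEmbedding(dim=64, interleaved=False)
q_rot_block = rope_block.rotate_queries_or_keys(q)

# The functional form can also be driven directly with the cached frequencies;
# for the same interleaved setting it should match rotate_queries_or_keys.
freqs = rope(q, seq_len=q.shape[-2])
assert torch.allclose(apply_rotary_emb(freqs, q, interleaved=True), q_rot)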