-
Notifications
You must be signed in to change notification settings - Fork 42
Expand file tree
/
Copy pathlora.py
More file actions
355 lines (325 loc) · 11.6 KB
/
lora.py
File metadata and controls
355 lines (325 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
import torch
import torch.nn as nn
from torch.nn.common_types import _size_2_t
from typing import Union
from collections import OrderedDict
from contextlib import contextmanager
from diffsynth_engine.models.basic.linear import fp8_linear
from diffsynth_engine.utils.platform import DTYPE_FP8
class LoRA(nn.Module):
    """One low-rank adapter whose contribution is ``scale * (alpha / rank) * up(down(x))``.

    ``up`` and ``down`` are either raw weight tensors (applied with matrix products) or
    ``nn.Linear`` / ``nn.Conv2d`` modules (called directly). Both are moved to
    ``device`` / ``dtype`` when the adapter is constructed.

    Args:
        scale: user-controlled strength multiplier.
        rank: inner dimension of the low-rank factorization.
        alpha: LoRA alpha; a tensor value is converted to a Python number.
        up: up-projection (weight tensor or module).
        down: down-projection (weight tensor or module).
        device: device the factors are moved to.
        dtype: dtype the factors are moved to; ``apply_to`` also performs the merge in this dtype.
    """

    def __init__(
        self,
        scale: float,
        rank: int,
        alpha: int,
        up: Union[nn.Linear, nn.Conv2d, torch.Tensor],
        down: Union[nn.Linear, nn.Conv2d, torch.Tensor],
        device: str,
        dtype: torch.dtype,
    ):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.scale = scale
        self.rank = rank
        # alpha may be loaded as a 0-d tensor; keep a plain Python number.
        self.alpha = alpha.item() if isinstance(alpha, torch.Tensor) else alpha
        self.up = up.to(device=device, dtype=dtype)
        self.down = down.to(device=device, dtype=dtype)

    def _multiplier(self) -> float:
        """Return ``scale * alpha / rank``, the factor applied to the low-rank product."""
        return self.scale * (self.alpha / self.rank)

    def _holds_raw_tensors(self) -> bool:
        """Return True when ``up``/``down`` are plain weight tensors rather than modules."""
        return isinstance(self.up, torch.Tensor) and isinstance(self.down, torch.Tensor)

    def forward(self, x):
        """Return the adapter's contribution for ``x`` (to be added to the base layer output)."""
        if self._holds_raw_tensors():
            return self._multiplier() * (x @ self.down.T @ self.up.T)
        return self._multiplier() * (self.up(self.down(x)))

    def _delta_weight(self) -> torch.Tensor:
        """Return the dense weight delta ``scale * alpha / rank * (up composed with down)``."""
        if self._holds_raw_tensors():
            up_w, down_w = self.up, self.down
        else:
            up_w, down_w = self.up.weight, self.down.weight
        if up_w.dim() == 4 and down_w.dim() == 4:
            # Conv2d factors: up is (out, rank, 1, 1), down is (rank, in, kh, kw).
            # A 4-D ``@`` treats the two leading axes as batch dims and broadcasts them,
            # giving a wrong result or a shape error. Contract over ``rank`` with a 1x1
            # convolution instead, which yields (out, in, kh, kw).
            if tuple(up_w.shape[2:]) != (1, 1):
                raise ValueError(f"Cannot merge conv LoRA with a non-1x1 up kernel: {tuple(up_w.shape[2:])}")
            delta = nn.functional.conv2d(down_w.permute(1, 0, 2, 3), up_w).permute(1, 0, 2, 3)
        else:
            delta = up_w @ down_w
        return self._multiplier() * delta

    def _add_delta(self, data: torch.Tensor, delta_w: torch.Tensor) -> torch.Tensor:
        """Add ``delta_w`` to ``data`` in ``self.dtype`` and return the sum in ``data``'s dtype.

        When ``data`` already has ``self.dtype``, the addition happens in place on its storage
        and ``data`` itself is returned. Otherwise a new tensor is returned.
        """
        orig_dtype = data.dtype
        work = data.to(self.dtype)
        work.add_(delta_w.to(device=data.device, dtype=self.dtype))
        return work.to(orig_dtype)

    def apply_to(self, w: Union[nn.Linear, nn.Conv2d, nn.Parameter, torch.Tensor]):
        """Permanently merge this adapter's delta into ``w``.

        ``w`` may be a Linear/Conv2d module (its ``weight`` is updated), a Parameter, or a
        plain tensor (updated in place). The addition is done in ``self.dtype`` and the
        result is cast back to the target's original dtype.
        """
        delta_w = self._delta_weight()
        if isinstance(w, (nn.Linear, nn.Conv2d)):
            w.weight.data = self._add_delta(w.weight.data, delta_w)
        elif isinstance(w, nn.Parameter):  # checked before Tensor: Parameter subclasses Tensor
            w.data = self._add_delta(w.data, delta_w)
        elif isinstance(w, torch.Tensor):
            merged = self._add_delta(w, delta_w)
            # A plain tensor cannot be rebound for the caller. When the dtype round-trip
            # produced a new tensor, write it back into w's storage (copy_ casts dtype).
            if merged is not w:
                w.copy_(merged)
class LoRALinear(nn.Linear):
    """``nn.Linear`` that can carry LoRA adapters.

    Two mechanisms are supported:

    * Dynamic LoRA (``add_lora``): adapters are stored in ``_lora_dict`` and their outputs
      are added to the base output in ``forward``. ``modify_scale`` changes an adapter's
      strength afterwards.
    * Frozen LoRA (``add_frozen_lora``): the low-rank delta is merged into ``weight``.
      When ``save_original_weight`` is set, a pinned CPU copy of the pre-merge weight is
      kept so ``clear`` can restore it.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        # Dynamic LoRA adapters by name, evaluated in insertion order in forward().
        self._lora_dict = OrderedDict()
        # Frozen LoRA: whether a delta has been merged into weight, and the optional CPU
        # copy of the weight taken before the first merge.
        self.patched_frozen_lora = False
        self._original_weight = None

    @staticmethod
    def from_linear(linear: nn.Linear):
        """Wrap an existing ``nn.Linear``, sharing (not copying) its weight and bias Parameters."""
        # Build on the meta device so no initialisation runs; to_empty only allocates
        # storage, which is immediately replaced by the shared Parameters below.
        lora_linear = LoRALinear(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            device="meta",
            dtype=linear.weight.dtype,
        ).to_empty(device=linear.weight.device)
        lora_linear.weight = linear.weight
        lora_linear.bias = linear.bias
        return lora_linear

    def add_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        **kwargs,
    ):
        """Register a dynamic LoRA adapter under ``name``, replacing any existing one.

        Args:
            name: key in ``_lora_dict``.
            scale, rank, alpha: the adapter output is multiplied by ``scale * alpha / rank``.
            up: up-projection weight of shape (out_features, rank).
            down: down-projection weight of shape (rank, in_features).
            device, dtype: where the adapter's modules live.
            **kwargs: ignored; accepted for call-site compatibility.
        """
        # nn.Linear(in, out) holds a weight of shape (out, in).
        up_linear = nn.Linear(
            up.shape[1],
            up.shape[0],
            bias=False,
            device="meta",
            dtype=dtype,
        ).to_empty(device=device)
        down_linear = nn.Linear(
            down.shape[1],  # in_features
            down.shape[0],  # out_features == rank
            bias=False,
            device="meta",
            dtype=dtype,
        ).to_empty(device=device)
        up_linear.weight.data = up
        down_linear.weight.data = down
        lora = LoRA(scale, rank, alpha, up_linear, down_linear, device, dtype)
        self._lora_dict[name] = lora

    def modify_scale(self, name: str, scale: float):
        """Change the ``scale`` of the dynamic adapter ``name``; raises ValueError if it is unknown."""
        if name not in self._lora_dict:
            raise ValueError(f"LoRA name {name} not found in LoRALinear {self.__class__.__name__}")
        self._lora_dict[name].scale = scale

    def add_frozen_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        save_original_weight: bool = True,
        **kwargs,
    ):
        """Merge a LoRA delta directly into ``weight``.

        If ``save_original_weight`` is True and no copy exists yet, the current weight is
        first copied to pinned CPU memory so ``clear`` can restore it. ``name`` and
        ``**kwargs`` are not used.
        """
        if save_original_weight and self._original_weight is None:
            # fp8 weights are saved as bfloat16 (presumably because pinned fp8 storage is
            # unsupported or unwanted here — confirm).
            if self.weight.dtype == torch.float8_e4m3fn:
                self._original_weight = self.weight.to(dtype=torch.bfloat16, device="cpu", copy=True).pin_memory()
            else:
                self._original_weight = self.weight.to(device="cpu", copy=True).pin_memory()
        lora = LoRA(scale, rank, alpha, up, down, device, dtype)
        lora.apply_to(self)
        self.patched_frozen_lora = True

    def clear(self, release_all_cpu_memory: bool = False):
        """Drop all dynamic adapters and restore the saved pre-merge weight, if any.

        Args:
            release_all_cpu_memory: also drop the saved CPU copy of the original weight.

        Raises:
            RuntimeError: a frozen LoRA was merged but no original weight was saved.
        """
        if self.patched_frozen_lora and self._original_weight is None:
            raise RuntimeError(
                "Current LoRALinear has patched by frozen LoRA, but original weight is not saved, so you cannot clear LoRA."
            )
        self._lora_dict.clear()
        if self._original_weight is not None:
            self.weight.data.copy_(
                self._original_weight.to(device=self.weight.data.device, dtype=self.weight.data.dtype)
            )
        if release_all_cpu_memory:
            # Reset rather than ``del``: later calls read this attribute unconditionally.
            self._original_weight = None
        self.patched_frozen_lora = False

    def forward(self, x):
        """Base linear output plus the sum of all dynamic adapter outputs."""
        w_x = super().forward(x)
        for lora in self._lora_dict.values():
            w_x += lora(x)
        return w_x
class LoRAFP8Linear(LoRALinear):
    """LoRALinear whose base weight is stored as ``DTYPE_FP8`` and applied via ``fp8_linear``.

    Only the base projection goes through the fp8 kernel. Dynamic adapters registered with
    ``add_lora`` are evaluated by their own modules and added on top.

    Args:
        scaling: forwarded to ``fp8_linear`` on every call.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        scaling: bool = True,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        # Re-store the freshly created base weight in the fp8 dtype.
        self.weight.data = self.weight.data.to(DTYPE_FP8)
        self.scaling = scaling

    @staticmethod
    def from_linear(linear: nn.Linear, scaling: bool = True):
        """Build an fp8 copy of ``linear``'s weight; the bias Parameter is shared, not copied."""
        source_weight = linear.weight
        fp8_layer = LoRAFP8Linear(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            device="meta",
            dtype=source_weight.dtype,
            scaling=scaling,
        ).to_empty(device=source_weight.device)
        fp8_layer.weight.data = source_weight.data.to(DTYPE_FP8)
        fp8_layer.bias = linear.bias
        return fp8_layer

    def forward(self, x):
        """fp8 base projection plus the sum of all dynamic adapter outputs."""
        output = fp8_linear(x, self.weight, self.bias, self.scaling)
        for adapter in self._lora_dict.values():
            output += adapter(x)
        return output
class LoRAConv2d(nn.Conv2d):
    """``nn.Conv2d`` that can carry LoRA adapters.

    Two mechanisms are supported:

    * Dynamic LoRA (``add_lora``): adapters are stored in ``_lora_dict`` and their outputs
      are added to the base output in ``forward``.
    * Frozen LoRA (``add_frozen_lora``): the low-rank delta is merged into ``weight``.
      When ``save_original_weight`` is set, a pinned CPU copy of the pre-merge weight is
      kept so ``clear`` can restore it.

    Adapter factors are a down conv with weight (rank, in_channels, kh, kw) and a 1x1 up
    conv with weight (out_channels, rank, 1, 1).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype
        )
        # Dynamic LoRA adapters by name, evaluated in insertion order in forward().
        self._lora_dict = OrderedDict()
        # Frozen LoRA: optional CPU copy of the pre-merge weight, and whether a merge happened.
        self._original_weight = None
        self.patched_frozen_lora = False

    @staticmethod
    def from_conv2d(conv2d: nn.Conv2d):
        """Wrap an existing ``nn.Conv2d``, sharing (not copying) its weight and bias Parameters."""
        lora_conv2d = LoRAConv2d(
            conv2d.in_channels,
            conv2d.out_channels,
            conv2d.kernel_size,
            conv2d.stride,
            conv2d.padding,
            conv2d.dilation,
            conv2d.groups,
            conv2d.bias is not None,
            conv2d.padding_mode,
            device="meta",
            dtype=conv2d.weight.dtype,
        ).to_empty(device=conv2d.weight.device)
        lora_conv2d.weight = conv2d.weight
        lora_conv2d.bias = conv2d.bias
        return lora_conv2d

    def _construct_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
    ):
        """Build a ``LoRA`` from conv factor weights ``up`` / ``down``; ``name`` is unused."""
        # The down conv mirrors this layer's kernel/stride/padding/dilation/padding_mode.
        # Its output grid therefore matches the base conv output (required by
        # ``w_x += lora(x)``), and the unmerged path equals adding the low-rank delta to
        # this layer's weight.
        down_conv = nn.Conv2d(
            self.in_channels,
            rank,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=False,
            padding_mode=self.padding_mode,
            device="meta",
            dtype=dtype,
        ).to_empty(device=device)
        down_conv.weight.data = down
        # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
        # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
        # refer from diffusers
        up_conv = nn.Conv2d(
            rank,
            self.out_channels,
            kernel_size=(1, 1),
            stride=(1, 1),
            bias=False,
            device="meta",
            dtype=dtype,
        ).to_empty(device=device)
        up_conv.weight.data = up
        return LoRA(scale, rank, alpha, up_conv, down_conv, device, dtype)

    def add_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        **kwargs,
    ):
        """Register a dynamic LoRA adapter under ``name``, replacing any existing one; ``**kwargs`` is ignored."""
        self._lora_dict[name] = self._construct_lora(name, scale, rank, alpha, up, down, device, dtype)

    def modify_scale(self, name: str, scale: float):
        """Change the ``scale`` of the dynamic adapter ``name``; raises ValueError if it is unknown."""
        if name not in self._lora_dict:
            raise ValueError(f"LoRA name {name} not found in LoRAConv2d {self.__class__.__name__}")
        self._lora_dict[name].scale = scale

    def add_frozen_lora(
        self,
        name: str,
        scale: float,
        rank: int,
        alpha: int,
        up: torch.Tensor,
        down: torch.Tensor,
        device: str,
        dtype: torch.dtype,
        save_original_weight: bool = True,
        **kwargs,
    ):
        """Merge a LoRA delta directly into ``weight``.

        If ``save_original_weight`` is True and no copy exists yet, the current weight is
        first copied to pinned CPU memory so ``clear`` can restore it. ``**kwargs`` is
        accepted for parity with ``LoRALinear.add_frozen_lora`` and ignored.
        """
        if save_original_weight and self._original_weight is None:
            # fp8 weights are saved as bfloat16 (presumably because pinned fp8 storage is
            # unsupported or unwanted here — confirm).
            if self.weight.dtype == torch.float8_e4m3fn:
                self._original_weight = self.weight.to(dtype=torch.bfloat16, device="cpu", copy=True).pin_memory()
            else:
                self._original_weight = self.weight.to(device="cpu", copy=True).pin_memory()
        lora = self._construct_lora(name, scale, rank, alpha, up, down, device, dtype)
        lora.apply_to(self)
        self.patched_frozen_lora = True

    def clear(self, release_all_cpu_memory: bool = False):
        """Drop all dynamic adapters and restore the saved pre-merge weight, if any.

        Args:
            release_all_cpu_memory: also drop the saved CPU copy of the original weight.

        Raises:
            RuntimeError: a frozen LoRA was merged but no original weight was saved.
        """
        if self.patched_frozen_lora and self._original_weight is None:
            raise RuntimeError(
                "Current LoRAConv2d has patched by frozen LoRA, but original weight is not saved, so you cannot clear LoRA."
            )
        self._lora_dict.clear()
        if self._original_weight is not None:
            # Write through .data: an in-place copy_ on a leaf Parameter that requires
            # grad raises a RuntimeError under autograd.
            self.weight.data.copy_(
                self._original_weight.to(device=self.weight.data.device, dtype=self.weight.data.dtype)
            )
        if release_all_cpu_memory:
            # Reset rather than ``del``: later calls read this attribute unconditionally.
            self._original_weight = None
        self.patched_frozen_lora = False

    def forward(self, x):
        """Base convolution output plus the sum of all dynamic adapter outputs."""
        w_x = super().forward(x)
        for lora in self._lora_dict.values():
            w_x += lora(x)
        return w_x
@contextmanager
def LoRAContext():
    """Temporarily rebind ``torch.nn.Linear`` / ``torch.nn.Conv2d`` to ``LoRALinear`` / ``LoRAConv2d``.

    Modules constructed through those ``torch.nn`` attributes inside the ``with`` body get
    LoRA support. The patch is process-wide because it rebinds attributes on the
    ``torch.nn`` module. The original classes are restored on exit, including when the
    body raises.
    """
    origin_linear = torch.nn.Linear
    origin_conv2d = torch.nn.Conv2d
    torch.nn.Linear = LoRALinear
    torch.nn.Conv2d = LoRAConv2d
    try:
        yield
    finally:
        # Without finally, an exception in the body would leave torch.nn globally patched.
        torch.nn.Linear = origin_linear
        torch.nn.Conv2d = origin_conv2d