Skip to content

Commit 291ca1a

Browse files
committed
Fix correctness bugs and improve robustness

Fixes:
- Sync __version__ to 0.2.1 (was 0.1.0, pyproject.toml says 0.2.1)
- Wire up check_tensor_health after SVD to catch NaN/Inf in components
- Fix TIES merge normalization: use contributor_count divisor instead of
  over-scaling by total adapter count n
- Warn on task ID collision in absorb/absorb_incremental
- Sanitize task IDs for filesystem-safe filenames in save/load (stores
  tid→filename mapping in metadata, backwards-compatible)
- Add task_ids length validation to from_adapters_streaming

Performance:
- Cache module handles in VLoRAModel at init (avoids O(M) scan of all
  named_modules on every task switch)
- Replace NF4 distance broadcast with torch.bucketize (binary search),
  reducing memory from O(N*16) to O(N) for large weight matrices
1 parent aa4c46d commit 291ca1a

5 files changed

Lines changed: 59 additions & 21 deletions

File tree

src/vlora/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
maintain one shared basis and per-task coefficient vectors.
66
"""
77

8-
__version__ = "0.1.0"
8+
__version__ = "0.2.1"
99

1010
from vlora.io import LoRAWeights, load_adapter, load_adapter_from_hub, save_adapter
1111
from vlora.ops import (

src/vlora/merge.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -141,9 +141,9 @@ def ties_merge(
141141
weighted = stacked * w
142142
# Zero out values with wrong sign
143143
weighted = weighted * mask.float()
144-
# Sum and normalize by number of contributors (avoid division by zero)
144+
# Average over contributors that match elected sign
145145
contributor_count = mask.float().sum(dim=0).clamp(min=1)
146-
merged = weighted.sum(dim=0) * (n / contributor_count)
146+
merged = weighted.sum(dim=0) / contributor_count
147147

148148
out_dict[layer] = merged
149149

src/vlora/model.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,12 @@ def __init__(
8080
self._active_task: str | None = None
8181
self._cached_deltas: dict[str, Tensor] | None = None
8282
self._hooks: list[torch.utils.hooks.RemovableHook] = []
83+
# Cache module handles once to avoid O(M) scan on every task switch
84+
self._target_modules: dict[str, nn.Module] = {
85+
name: module
86+
for name, module in self.base_model.named_modules()
87+
if name in self.subspace.layer_names and _is_linear_layer(module)
88+
}
8389
self._qlora_info = self._detect_quantization()
8490

8591
def set_task(self, task_id: str) -> None:
@@ -113,8 +119,8 @@ def _apply_hooks(self) -> None:
113119
if self._cached_deltas is None:
114120
return
115121

116-
for name, module in self.base_model.named_modules():
117-
if name in self._cached_deltas and _is_linear_layer(module):
122+
for name, module in self._target_modules.items():
123+
if name in self._cached_deltas:
118124
delta = self._cached_deltas[name]
119125
hook = module.register_forward_hook(
120126
self._make_lora_hook(delta)
@@ -174,12 +180,12 @@ def _detect_quantization(self) -> dict:
174180
"quantized": False,
175181
"method": None,
176182
"num_quantized_layers": 0,
177-
"num_target_layers": 0,
183+
"num_target_layers": len(self._target_modules),
178184
}
179185
try:
180186
import bitsandbytes as bnb
181187

182-
for name, module in self.base_model.named_modules():
188+
for module in self._target_modules.values():
183189
if isinstance(module, bnb.nn.Linear4bit):
184190
info["quantized"] = True
185191
info["method"] = info["method"] or "nf4"
@@ -191,11 +197,6 @@ def _detect_quantization(self) -> dict:
191197
except ImportError:
192198
pass
193199

194-
# Count how many subspace layers match modules in the base model
195-
for name, module in self.base_model.named_modules():
196-
if name in self.subspace.layer_names and _is_linear_layer(module):
197-
info["num_target_layers"] += 1
198-
199200
return info
200201

201202
@property

src/vlora/ops.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -246,6 +246,9 @@ def nf4_quantize_dequantize(tensor: Tensor, block_size: int = 64) -> Tensor:
246246
The returned tensor is the same dtype as input but only contains
247247
values representable in NF4 format.
248248
249+
Uses ``torch.bucketize`` (binary search) for O(N log 16) lookup
250+
instead of broadcasting all 16 distances, keeping memory O(N).
251+
249252
Based on QLoRA (Dettmers et al., 2023, arXiv:2305.14314).
250253
251254
Args:
@@ -274,11 +277,12 @@ def nf4_quantize_dequantize(tensor: Tensor, block_size: int = 64) -> Tensor:
274277
absmax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-10)
275278
normalized = blocks / absmax
276279

277-
# Snap each value to nearest NF4 level
280+
# Snap to nearest NF4 level via binary search (memory-efficient).
281+
# Midpoints between adjacent NF4 levels serve as bucket boundaries.
278282
table = NF4_QUANT_TABLE.to(device=normalized.device, dtype=normalized.dtype)
279-
# (num_blocks, block_size, 1) vs (1, 1, 16) → distances (num_blocks, block_size, 16)
280-
distances = (normalized.unsqueeze(-1) - table).abs()
281-
indices = distances.argmin(dim=-1)
283+
midpoints = (table[:-1] + table[1:]) / 2 # 15 boundaries
284+
# bucketize returns the index of the bucket each value falls into
285+
indices = torch.bucketize(normalized, midpoints)
282286
quantized_normalized = table[indices]
283287

284288
# Dequantize: scale back by absmax

src/vlora/subspace.py

Lines changed: 38 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,8 @@ def from_adapters(
149149
]:
150150
data = stacked[layer]
151151
comps, svals, mean = compute_svd(data, num_components=None, center=True)
152+
check_tensor_health(comps, f"{layer}.components_{side.lower()}")
153+
check_tensor_health(mean, f"{layer}.mean_{side.lower()}")
152154

153155
if adaptive_k:
154156
# Per-layer: each layer/side gets its own k
@@ -265,6 +267,12 @@ def absorb(self, new_adapter: LoRAWeights, new_task_id: str) -> None:
265267
reruns SVD to produce an updated basis.
266268
"""
267269
check_adapter_matches_subspace(new_adapter, self, "absorb")
270+
if new_task_id in self.tasks:
271+
import warnings
272+
warnings.warn(
273+
f"Task '{new_task_id}' already exists and will be overwritten by absorb.",
274+
stacklevel=2,
275+
)
268276
logger.info("Absorbing adapter '%s' (full SVD recompute, %d existing tasks)", new_task_id, len(self.tasks))
269277
# Reconstruct all existing tasks as full adapters
270278
all_adapters = []
@@ -308,6 +316,12 @@ def absorb_incremental(self, new_adapter: LoRAWeights, new_task_id: str) -> None
308316
approximation trade-off.
309317
"""
310318
check_adapter_matches_subspace(new_adapter, self, "absorb_incremental")
319+
if new_task_id in self.tasks:
320+
import warnings
321+
warnings.warn(
322+
f"Task '{new_task_id}' already exists and will be overwritten by absorb_incremental.",
323+
stacklevel=2,
324+
)
311325
logger.debug("Absorbing adapter '%s' incrementally", new_task_id)
312326
loadings_a: dict[str, Tensor] = {}
313327
loadings_b: dict[str, Tensor] = {}
@@ -390,6 +404,11 @@ def from_adapters_streaming(
390404
paths = [Path(p) for p in adapter_paths]
391405
if task_ids is None:
392406
task_ids = [p.name for p in paths]
407+
if len(task_ids) != len(paths):
408+
raise ValueError(
409+
f"task_ids length ({len(task_ids)}) must match "
410+
f"adapter_paths length ({len(paths)})"
411+
)
393412

394413
# Initialize from first adapter(s) — use first two if available
395414
# so SVD has enough samples to find >1 component
@@ -633,8 +652,16 @@ def get_trainable_params(
633652

634653
return params
635654

655+
@staticmethod
656+
def _safe_filename(task_id: str) -> str:
657+
"""Convert a task ID to a filesystem-safe filename component."""
658+
import re
659+
return re.sub(r'[^\w\-.]', '_', task_id)
660+
636661
def save(self, path: str | Path) -> None:
637662
"""Serialize the subspace to disk."""
663+
import json
664+
638665
path = Path(path)
639666
path.mkdir(parents=True, exist_ok=True)
640667

@@ -650,19 +677,22 @@ def save(self, path: str | Path) -> None:
650677

651678
save_file(tensors, str(path / "subspace.safetensors"))
652679

653-
# Save per-task loadings
680+
# Save per-task loadings (with sanitized filenames)
681+
tid_to_filename: dict[str, str] = {}
654682
for tid, proj in self.tasks.items():
683+
safe_name = self._safe_filename(tid)
684+
tid_to_filename[tid] = safe_name
655685
task_tensors = {}
656686
for layer in self.layer_names:
657687
task_tensors[f"{layer}.loadings_a"] = proj.loadings_a[layer].contiguous()
658688
task_tensors[f"{layer}.loadings_b"] = proj.loadings_b[layer].contiguous()
659-
save_file(task_tensors, str(path / f"task_{tid}.safetensors"))
689+
save_file(task_tensors, str(path / f"task_{safe_name}.safetensors"))
660690

661-
# Save metadata
662-
import json
691+
# Save metadata (includes filename mapping for safe round-trip)
663692
meta = {
664693
"layer_names": self.layer_names,
665694
"task_ids": list(self.tasks.keys()),
695+
"task_filenames": tid_to_filename,
666696
"rank": self.rank,
667697
"num_components": self.num_components,
668698
}
@@ -683,6 +713,8 @@ def load(cls, path: str | Path) -> SharedSubspace:
683713
task_ids = meta["task_ids"]
684714
rank = meta["rank"]
685715
num_components = meta["num_components"]
716+
# Support both old format (no mapping) and new format
717+
tid_to_filename = meta.get("task_filenames", {})
686718

687719
tensors = load_file(str(path / "subspace.safetensors"))
688720
components_a = {l: tensors[f"{l}.components_a"] for l in layer_names}
@@ -694,7 +726,8 @@ def load(cls, path: str | Path) -> SharedSubspace:
694726

695727
tasks = {}
696728
for tid in task_ids:
697-
task_tensors = load_file(str(path / f"task_{tid}.safetensors"))
729+
safe_name = tid_to_filename.get(tid, tid)
730+
task_tensors = load_file(str(path / f"task_{safe_name}.safetensors"))
698731
loadings_a = {l: task_tensors[f"{l}.loadings_a"] for l in layer_names}
699732
loadings_b = {l: task_tensors[f"{l}.loadings_b"] for l in layer_names}
700733
tasks[tid] = TaskProjection(

0 commit comments

Comments (0)