Skip to content

Commit 3c7dd61

Browse files
committed
Production hardening: progress bars, device awareness, API boundaries
Five targeted improvements for production readiness:
1. CLI progress bars — compress, analyze, merge show loading progress
2. Device-aware orthogonal_init — tensors created on subspace device/dtype
3. VLoRAModel/SubspaceTrainer __repr__ + thread-safety documentation
4. __all__ in all submodules — proper API boundaries for each module
5. Memory estimation logging in from_adapters() before SVD
1 parent 5da852b commit 3c7dd61

10 files changed

Lines changed: 98 additions & 23 deletions

File tree

src/vlora/analysis.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22

33
from __future__ import annotations
44

5+
__all__ = [
6+
"adapter_diff",
7+
"compute_similarity_matrix",
8+
"find_clusters",
9+
"find_outliers",
10+
"subspace_coverage",
11+
]
12+
513
from typing import TYPE_CHECKING
614

715
import torch

src/vlora/cli.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,13 @@ def info(subspace_path: str, as_json: bool):
9292
@click.option("--adaptive-k", is_flag=True, help="Use per-layer adaptive k selection.")
9393
def compress(adapter_dirs: tuple[str, ...], output: str, num_components: int | None, variance_threshold: float, adaptive_k: bool):
9494
"""Build shared subspace from adapter directories."""
95-
click.echo(f"\n Loading {len(adapter_dirs)} adapters...")
96-
9795
adapters = []
9896
task_ids = []
99-
for d in adapter_dirs:
100-
path = Path(d)
101-
adapters.append(load_adapter(path))
102-
task_ids.append(path.name)
103-
click.echo(f" Loaded: {path.name}")
97+
with click.progressbar(adapter_dirs, label=" Loading adapters") as bar:
98+
for d in bar:
99+
path = Path(d)
100+
adapters.append(load_adapter(path))
101+
task_ids.append(path.name)
104102

105103
click.echo(" Building subspace...")
106104
sub = SharedSubspace.from_adapters(
@@ -182,14 +180,16 @@ def analyze(adapter_dirs: tuple[str, ...], threshold: float, as_json: bool):
182180

183181
adapters = []
184182
names = []
185-
for d in adapter_dirs:
186-
path = Path(d)
187-
adapters.append(load_adapter(path))
188-
names.append(path.name)
183+
with click.progressbar(adapter_dirs, label=" Loading adapters") as bar:
184+
for d in bar:
185+
path = Path(d)
186+
adapters.append(load_adapter(path))
187+
names.append(path.name)
189188

190189
if len(adapters) < 2:
191190
raise click.ClickException("Need at least 2 adapters for analysis.")
192191

192+
click.echo(" Computing similarity matrix...")
193193
sim = compute_similarity_matrix(adapters)
194194
clusters = find_clusters(sim, threshold=threshold)
195195

@@ -206,10 +206,6 @@ def analyze(adapter_dirs: tuple[str, ...], threshold: float, as_json: bool):
206206
click.echo(json_mod.dumps(output, indent=2))
207207
return
208208

209-
click.echo(f"\n Loading {len(adapter_dirs)} adapters...")
210-
for n in names:
211-
click.echo(f" Loaded: {n}")
212-
213209
click.echo("\n Pairwise Cosine Similarity:")
214210
header = " " + " " * 20 + " ".join(f"{n[:8]:>8}" for n in names)
215211
click.echo(header)
@@ -399,12 +395,11 @@ def merge(adapter_dirs: tuple[str, ...], output: str, method: str, weights: str
399395
"""Merge multiple adapters into one using task arithmetic, TIES, or DARE."""
400396
from vlora.merge import MERGE_METHODS
401397

402-
click.echo(f"\n Loading {len(adapter_dirs)} adapters...")
403398
adapters = []
404-
for d in adapter_dirs:
405-
path = Path(d)
406-
adapters.append(load_adapter(path))
407-
click.echo(f" Loaded: {path.name}")
399+
with click.progressbar(adapter_dirs, label=" Loading adapters") as bar:
400+
for d in bar:
401+
path = Path(d)
402+
adapters.append(load_adapter(path))
408403

409404
if len(adapters) < 2:
410405
raise click.ClickException("Need at least 2 adapters to merge.")

src/vlora/io.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66

77
from __future__ import annotations
88

9+
__all__ = [
10+
"LoRAWeights",
11+
"load_adapter",
12+
"load_adapter_from_hub",
13+
"save_adapter",
14+
"parse_state_dict",
15+
"stack_lora_weights",
16+
]
17+
918
import json
1019
import logging
1120
import re

src/vlora/merge.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from __future__ import annotations
1414

15+
__all__ = ["task_arithmetic", "ties_merge", "dare_merge", "MERGE_METHODS"]
16+
1517
import logging
1618

1719
import torch

src/vlora/model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
__all__ = ["VLoRAModel"]
6+
57
from typing import Any
68

79
import torch
@@ -38,6 +40,11 @@ class VLoRAModel(nn.Module):
3840
``compute_dtype`` to match the model's compute precision (typically
3941
``torch.bfloat16``).
4042
43+
Note: This class is **not thread-safe**. Concurrent calls to
44+
``set_task()``, ``merge()``, or ``forward()`` from multiple threads
45+
may produce incorrect results. Use a lock or separate model instances
46+
for multi-threaded serving.
47+
4148
Usage:
4249
subspace = SharedSubspace.load("shared_subspace/")
4350
base_model = AutoModelForCausalLM.from_pretrained("model-name")
@@ -312,6 +319,15 @@ def is_merged(self) -> bool:
312319
"""Whether LoRA deltas are currently baked into base weights."""
313320
return self._merged
314321

322+
def __repr__(self) -> str:
323+
task = self._active_task or "none"
324+
merged = " merged" if self._merged else ""
325+
return (
326+
f"VLoRAModel(tasks={len(self.subspace.tasks)}, "
327+
f"active={task!r}{merged}, "
328+
f"layers={len(self._target_modules)})"
329+
)
330+
315331
def compile(self, **kwargs) -> VLoRAModel:
316332
"""Compile the base model with torch.compile for faster inference.
317333

src/vlora/ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,20 @@
66

77
from __future__ import annotations
88

9+
__all__ = [
10+
"NF4_QUANT_TABLE",
11+
"compute_svd",
12+
"explained_variance_ratio",
13+
"gram_schmidt",
14+
"incremental_svd_update",
15+
"nf4_pack",
16+
"nf4_quantize_dequantize",
17+
"nf4_unpack",
18+
"project_onto_subspace",
19+
"reconstruct_from_subspace",
20+
"select_num_components",
21+
]
22+
923
import torch
1024
from torch import Tensor
1125

src/vlora/pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
__all__ = ["init_subspace", "absorb_task", "extract_adapter"]
6+
57
from pathlib import Path
68

79
from vlora.io import LoRAWeights, load_adapter, save_adapter

src/vlora/router.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
shared subspace, blending is a cheap linear combination rather than
66
reconstructing and merging full LoRA matrices.
77
8+
89
Usage:
910
subspace = SharedSubspace.load("shared_subspace/")
1011
router = TaskRouter.from_subspace(subspace, hidden_dim=64)
@@ -17,6 +18,8 @@
1718

1819
from __future__ import annotations
1920

21+
__all__ = ["TaskRouter"]
22+
2023
import torch
2124
import torch.nn as nn
2225
import torch.nn.functional as F

src/vlora/subspace.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from __future__ import annotations
99

10+
__all__ = ["SharedSubspace", "TaskProjection"]
11+
1012
import logging
1113
from dataclasses import dataclass
1214
from pathlib import Path
@@ -133,7 +135,20 @@ def from_adapters(
133135
the threshold. Overrides num_components.
134136
"""
135137
check_adapters_compatible(adapters)
136-
logger.info("Building subspace from %d adapters", len(adapters))
138+
139+
# Log memory estimate to help users anticipate resource needs
140+
n_adapters = len(adapters)
141+
sample_layer = adapters[0].layer_names[0]
142+
dim_a = adapters[0].lora_a[sample_layer].numel()
143+
dim_b = adapters[0].lora_b[sample_layer].numel()
144+
n_layers = len(adapters[0].layer_names)
145+
# SVD working memory: ~2 * (N * D * 4 bytes) per layer for A and B
146+
svd_bytes = 2 * n_adapters * (dim_a + dim_b) * 4 * n_layers
147+
svd_mb = svd_bytes / (1024 * 1024)
148+
logger.info(
149+
"Building subspace from %d adapters (%d layers, ~%.0f MB estimated)",
150+
n_adapters, n_layers, svd_mb,
151+
)
137152

138153
if task_ids is None:
139154
task_ids = [f"task_{i}" for i in range(len(adapters))]

src/vlora/training.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
from __future__ import annotations
2121

22+
__all__ = ["SubspaceTrainer", "orthogonal_init"]
23+
2224
import torch
2325
from torch import Tensor
2426

@@ -49,9 +51,11 @@ def orthogonal_init(
4951

5052
for layer in subspace.layer_names:
5153
actual_k = subspace.components_a[layer].shape[0]
52-
loadings_a[layer] = torch.randn(actual_k) * scale
54+
device = subspace.components_a[layer].device
55+
dtype = subspace.components_a[layer].dtype
56+
loadings_a[layer] = torch.randn(actual_k, device=device, dtype=dtype) * scale
5357
# Initialize B-side to zero (like standard LoRA) so initial delta is zero
54-
loadings_b[layer] = torch.zeros(actual_k)
58+
loadings_b[layer] = torch.zeros(actual_k, device=device, dtype=dtype)
5559

5660
proj = TaskProjection(task_id=task_id, loadings_a=loadings_a, loadings_b=loadings_b)
5761
subspace.tasks[task_id] = proj
@@ -145,3 +149,10 @@ def num_trainable_params(self) -> int:
145149
def step_count(self) -> int:
146150
"""Number of optimizer steps taken."""
147151
return self._step_count
152+
153+
def __repr__(self) -> str:
154+
return (
155+
f"SubspaceTrainer(task={self.task_id!r}, "
156+
f"params={self.num_trainable_params}, "
157+
f"steps={self._step_count})"
158+
)

0 commit comments

Comments (0)