
Commit fe49475

Merge pull request #9 from Lexsi-Labs/dlb_mem_optimize
Add Relevance Trace Memory Controls
2 parents: 24629c6 + 34d886a

5 files changed: 335 additions & 25 deletions

File tree

dl_backtrace/pytorch_backtrace/dlbacktrace/core/dlb_auto_sampler.py

Lines changed: 297 additions & 5 deletions
@@ -5,11 +5,22 @@
 from __future__ import annotations
 
 import time
-from typing import Optional, List, Tuple, cast
+import gzip
+import lzma
+import numpy as np
+from pathlib import Path
+from typing import Optional, List, Tuple, cast, Any
 
 import torch
 import torch.nn.functional as F
 
+# Optional: 7z support (requires py7zr)
+try:
+    import py7zr
+    HAS_7Z = True
+except ImportError:
+    HAS_7Z = False
+
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     TemperatureLogitsWarper,
@@ -68,6 +79,56 @@ def __init__(self, dlb, tokenizer):
         self.dlb = dlb
         self.tokenizer = tokenizer
 
+    # ---------- compression helpers ----------
+
+    @staticmethod
+    def load_compressed_relevance(file_path: str):
+        """
+        Load relevance data from a compressed file with automatic format detection.
+
+        Args:
+            file_path: Path to the compressed relevance file (.pt.gz, .pt.xz, .pt.7z, or .pt)
+
+        Returns:
+            The loaded relevance dictionary.
+
+        Example:
+            >>> relevance = DLBAutoSampler.load_compressed_relevance("step_00000.pt.gz")
+        """
+        path = Path(file_path)
+
+        if path.suffix == '.gz':
+            # Gzip compressed
+            with gzip.open(path, 'rb') as f:
+                return torch.load(f, weights_only=False)
+
+        elif path.suffix == '.xz':
+            # LZMA compressed
+            with lzma.open(path, 'rb') as f:
+                return torch.load(f, weights_only=False)
+
+        elif path.suffix == '.7z':
+            # 7z compressed
+            if not HAS_7Z:
+                raise ImportError(
+                    "The py7zr library is required to load 7z files. "
+                    "Install it with: pip install py7zr"
+                )
+            import tempfile
+            with tempfile.TemporaryDirectory() as tmpdir:
+                tmpdir_path = Path(tmpdir)
+                with py7zr.SevenZipFile(path, 'r') as archive:
+                    archive.extractall(tmpdir_path)
+                # Find the extracted .pt file
+                pt_files = list(tmpdir_path.glob('*.pt'))
+                if not pt_files:
+                    raise ValueError(f"No .pt file found in 7z archive: {path}")
+                return torch.load(pt_files[0], weights_only=False)
+
+        else:
+            # Uncompressed or unknown format
+            return torch.load(path, weights_only=False)
+
     # ---------- small dtype helpers ----------
 
     @staticmethod
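The loader pairs with the disk-cache writer added further down: files are written as step_NNNNN.pt plus a compression suffix, and the suffix alone drives format detection here. A minimal round trip through the gzip path might look like this (an illustrative sketch; the payload and filename are made up, not taken from the repo):

    import gzip
    import torch

    # Stand-in for a relevance dict as produced by dlb.predict(...)
    rel = {"layer_0": torch.randn(1, 8), "layer_1": torch.randn(1, 8)}

    # Written the same way the "disk" policy does with compression_method="gzip"
    with gzip.open("step_00000.pt.gz", "wb", compresslevel=6) as f:
        torch.save(rel, f, pickle_protocol=4)

    # The .gz suffix selects the gzip branch; no format argument is needed
    loaded = DLBAutoSampler.load_compressed_relevance("step_00000.pt.gz")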
@@ -252,6 +313,173 @@ def add_val(x):
         add_val(rel_dict)
         return total
 
+    @staticmethod
+    def _resolve_torch_dtype(dtype_hint):
+        if dtype_hint is None:
+            return None
+        if isinstance(dtype_hint, torch.dtype):
+            return dtype_hint
+        if isinstance(dtype_hint, str):
+            key = dtype_hint.strip().lower()
+            mapping = {
+                "float32": torch.float32,
+                "fp32": torch.float32,
+                "float": torch.float32,
+                "float16": torch.float16,
+                "fp16": torch.float16,
+                "half": torch.float16,
+                "bfloat16": torch.bfloat16,
+                "bf16": torch.bfloat16,
+                "float64": torch.float64,
+                "fp64": torch.float64,
+            }
+            if key in mapping:
+                return mapping[key]
+        raise ValueError(f"Unsupported relevance dtype hint: {dtype_hint}")
+
+    def _compress_relevance_tree(self, data, *, target_dtype=None, move_to_cpu=True):
+        if torch.is_tensor(data):
+            tensor = data.detach()
+            if move_to_cpu:
+                tensor = tensor.to("cpu")
+            if target_dtype is not None:
+                tensor = tensor.to(dtype=target_dtype)
+            return tensor.clone()
+        # Handle numpy arrays by converting to torch tensors with the target dtype
+        if isinstance(data, np.ndarray):
+            tensor = torch.from_numpy(data)
+            if move_to_cpu:
+                tensor = tensor.to("cpu")
+            if target_dtype is not None:
+                tensor = tensor.to(dtype=target_dtype)
+            return tensor
+        if isinstance(data, dict):
+            return {k: self._compress_relevance_tree(v, target_dtype=target_dtype, move_to_cpu=move_to_cpu) for k, v in data.items()}
+        if isinstance(data, list):
+            return [self._compress_relevance_tree(v, target_dtype=target_dtype, move_to_cpu=move_to_cpu) for v in data]
+        if isinstance(data, tuple):
+            return tuple(self._compress_relevance_tree(v, target_dtype=target_dtype, move_to_cpu=move_to_cpu) for v in data)
+        return data
+
+    def _prepare_cache_dir(self, base_dir: Optional[str], policy: str):
+        if policy != "disk":
+            return None
+        if not base_dir:
+            raise ValueError("relevance_cache_dir is required when relevance_cache_policy='disk'")
+        root = Path(base_dir).expanduser()
+        timestamp = int(time.time() * 1000)
+        run_dir = root / f"relevance_cache_run_{timestamp}"
+        run_dir.mkdir(parents=True, exist_ok=True)
+        return run_dir
+
+    def _store_relevance_entry(
+        self,
+        rel_dict,
+        *,
+        policy: str,
+        step_idx: int,
+        cache_dir: Optional[Path],
+        target_dtype,
+        move_to_cpu: bool,
+        use_compression: bool = True,
+        compression_method: str = "gzip",
+        pickle_protocol: int = 4,
+    ):
+        """
+        Store a relevance entry according to the specified policy.
+
+        Args:
+            rel_dict: Relevance dictionary to store.
+            policy: Cache policy ("full", "summary", "disk", or "none").
+            step_idx: Generation step index.
+            cache_dir: Directory for disk caching.
+            target_dtype: Target dtype for compression.
+            move_to_cpu: Whether to move tensors to CPU.
+            use_compression: If True, compress entries written to disk (default: True).
+            compression_method: Compression method - "gzip", "lzma", "7z", or "none".
+                - "gzip": fast, good compression (default)
+                - "lzma": better compression, slower
+                - "7z": best compression, slowest (requires py7zr)
+                - "none": no compression
+            pickle_protocol: Pickle protocol version (2-5); higher protocols
+                serialize more compactly. Protocol 4 (default) needs Python 3.4+,
+                protocol 5 needs Python 3.8+.
+        """
+        normalized_policy = (policy or "full").lower()
+        if normalized_policy == "none":
+            return None
+
+        processed = self._compress_relevance_tree(rel_dict, target_dtype=target_dtype, move_to_cpu=move_to_cpu)
+        if processed is None:
+            return None
+
+        if normalized_policy == "summary":
+            return {"summary": self._summarize_relevance(processed)}
+
+        if normalized_policy == "disk":
+            if cache_dir is None:
+                raise ValueError("relevance_cache_dir must be provided when relevance_cache_policy='disk'")
+
+            base_file_path = cache_dir / f"step_{step_idx:05d}.pt"
+
+            # Determine compression method and file extension
+            if not use_compression or compression_method == "none":
+                # No compression
+                file_path = base_file_path
+                torch.save(processed, file_path, pickle_protocol=pickle_protocol)
+
+            elif compression_method == "gzip":
+                # Gzip compression (fast, good ratio)
+                file_path = Path(str(base_file_path) + '.gz')
+                with gzip.open(file_path, 'wb', compresslevel=6) as f:
+                    torch.save(processed, f, pickle_protocol=pickle_protocol)
+
+            elif compression_method == "lzma":
+                # LZMA/xz compression (better ratio, slower)
+                file_path = Path(str(base_file_path) + '.xz')
+                with lzma.open(file_path, 'wb', preset=6) as f:
+                    torch.save(processed, f, pickle_protocol=pickle_protocol)
+
+            elif compression_method == "7z":
+                # 7z compression (best ratio, slowest)
+                if not HAS_7Z:
+                    raise ImportError(
+                        "The py7zr library is required for 7z compression. "
+                        "Install it with: pip install py7zr"
+                    )
+                file_path = Path(str(base_file_path) + '.7z')
+                # Save to a temporary .pt file first (written after the handle
+                # is closed so the path is safe to reopen on any platform)
+                import tempfile
+                with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as tmp:
+                    tmp_path = Path(tmp.name)
+                torch.save(processed, tmp_path, pickle_protocol=pickle_protocol)
+
+                # Compress with 7z
+                with py7zr.SevenZipFile(file_path, 'w') as archive:
+                    archive.write(tmp_path, arcname=f'step_{step_idx:05d}.pt')
+
+                # Clean up the temp file
+                tmp_path.unlink()
+
+            else:
+                raise ValueError(
+                    f"Unknown compression_method: {compression_method}. "
+                    f"Must be one of: 'gzip', 'lzma', '7z', 'none'"
+                )
+
+            return {
+                "summary": self._summarize_relevance(processed),
+                "path": str(file_path),
+                "compression": compression_method,
+            }
+
+        if normalized_policy != "full":
+            raise ValueError(
+                "relevance_cache_policy must be one of {'full', 'summary', 'disk', 'none'}"
+            )
+        return processed
+
     # ---------- public API ----------
 
     @torch.no_grad()
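For reference, the four policies produce differently shaped entries. A sketch, assuming a sampler instance s and a relevance dict rel_dict (both names illustrative):

    from pathlib import Path
    import torch

    # "summary": only a scalar summary is kept in RAM
    entry = s._store_relevance_entry(
        rel_dict, policy="summary", step_idx=0,
        cache_dir=None, target_dtype=None, move_to_cpu=True,
    )
    # -> {"summary": ...}

    # "disk": summary in RAM plus a compressed file on disk
    cache_dir = Path("cache_run"); cache_dir.mkdir(exist_ok=True)
    entry = s._store_relevance_entry(
        rel_dict, policy="disk", step_idx=0,
        cache_dir=cache_dir, target_dtype=torch.float16, move_to_cpu=True,
    )
    # -> {"summary": ..., "path": "cache_run/step_00000.pt.gz", "compression": "gzip"}

    # "full" returns the downcast, CPU-resident tree itself; "none" returns None.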
@@ -286,17 +514,52 @@ def generate(
         return_layerwise_output: bool = False,
         return_relevance: bool = False,
         debug: bool = False,
+        relevance_cache_policy: str = "full",
+        relevance_cache_dir: Optional[str] = None,
+        relevance_compress_dtype: Optional[Any] = "float16",
+        relevance_move_to_cpu: bool = True,
+        relevance_use_compression: bool = True,
+        relevance_compression_method: str = "gzip",
+        relevance_pickle_protocol: int = 4,
     ):
         """
         Always returns:
         - [1, T_total] (top-1 sequence)
         - Or (sequence, scores_trace) for sampling when return_scores=True
-        """
+
+        Relevance caching knobs:
+            relevance_cache_policy: "full" (default), "summary", "disk", or "none".
+            relevance_cache_dir: base directory for on-disk caching (policy="disk").
+            relevance_compress_dtype: dtype hint (str or torch.dtype) for stored tensors.
+            relevance_use_compression: if True, compress entries written to disk (default: True).
+            relevance_compression_method: compression method - "gzip" (default), "lzma", "7z", or "none".
+                - "gzip": fast, good compression (~75% reduction)
+                - "lzma": better compression (~80% reduction), slower
+                - "7z": best compression (~82% reduction), slowest
+                - "none": no compression (dtype downcasting only)
+            relevance_pickle_protocol: pickle protocol (2-5); higher protocols serialize more compactly. Default: 4.
+            relevance_move_to_cpu: move tensors to CPU before caching to reduce VRAM.
+        """
         model = self._get_causallm(self.dlb.model)
         device = input_ids.device
         B = input_ids.size(0)
         assert B == 1, "Current implementation assumes batch size = 1."
 
+        cache_policy = (relevance_cache_policy or "full").lower()
+        allowed_policies = {"full", "summary", "disk", "none"}
+        if cache_policy not in allowed_policies:
+            raise ValueError(
+                "relevance_cache_policy must be one of {'full', 'summary', 'disk', 'none'}"
+            )
+        cache_dtype = (
+            self._resolve_torch_dtype(relevance_compress_dtype)
+            if relevance_compress_dtype is not None
+            else None
+        )
+        cache_dir_path = None
+        if return_relevance and cache_policy == "disk":
+            cache_dir_path = self._prepare_cache_dir(relevance_cache_dir, cache_policy)
+
         # Dtypes up-front
         input_ids = self._as_long(input_ids)
 
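Putting the knobs together, a call might look like this (a hedged sketch: dlb and tokenizer stand for a configured DLBacktrace instance and its tokenizer, and the remaining generate arguments are elided):

    sampler = DLBAutoSampler(dlb, tokenizer)
    input_ids = tokenizer("Hello", return_tensors="pt").input_ids

    generated, info = sampler.generate(
        input_ids,
        return_relevance=True,
        relevance_cache_policy="disk",        # spill per-step relevance to disk
        relevance_cache_dir="~/dlb_cache",    # a relevance_cache_run_* subdir is created inside
        relevance_compress_dtype="bf16",      # any hint _resolve_torch_dtype accepts
        relevance_compression_method="lzma",  # smaller files, slower writes
        relevance_pickle_protocol=5,          # Python 3.8+
    )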
@@ -437,8 +700,19 @@ def generate(
                 task="generation",
                 debug=False,
             )
-            # rel_scalar = self._summarize_relevance(rel_dict)
-            relevance_trace.append(rel_dict)
+            step_idx = len(relevance_trace)
+            entry = self._store_relevance_entry(
+                rel_dict,
+                policy=cache_policy,
+                step_idx=step_idx,
+                cache_dir=cache_dir_path,
+                target_dtype=cache_dtype,
+                move_to_cpu=relevance_move_to_cpu,
+                use_compression=relevance_use_compression,
+                compression_method=relevance_compression_method,
+                pickle_protocol=relevance_pickle_protocol,
+            )
+            relevance_trace.append(entry)
 
         generated = torch.cat([generated, next_tokens], dim=1)
         attn = torch.cat(
@@ -473,6 +747,9 @@ def generate(
             info["scores_trace"] = scores_trace
         if return_relevance:
            info["relevance_trace"] = relevance_trace
+           info["relevance_cache_policy"] = cache_policy
+           if cache_dir_path is not None:
+               info["relevance_cache_dir"] = str(cache_dir_path)
        if return_layerwise_output:
            info["layerwise_output_trace"] = io_data_trace
        return generated, info  # ([1, T], dict)
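Downstream code can treat the trace uniformly: "full" entries are the relevance dicts themselves, while "disk" entries carry a summary plus a path to reload on demand. A sketch, continuing the usage example above:

    print(info["relevance_cache_policy"])     # e.g. "disk"

    for entry in info["relevance_trace"]:
        if isinstance(entry, dict) and "path" in entry:
            # "disk" policy: reload the full tensors from the compressed file
            rel = DLBAutoSampler.load_compressed_relevance(entry["path"])
        else:
            # "full" policy (or a summary-only entry) is already in memory
            rel = entry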
@@ -621,6 +898,7 @@ def generate(
 
         if return_relevance:
             step_rel_scores = []
+            step_offset = len(relevance_trace_beam)
             for b in range(beams):
                 # Use the OLD beam state (before new token) for relevance computation
                 self.dlb.predict(
@@ -640,7 +918,18 @@ def generate(
                     task="generation",
                     debug=False,
                 )
-                step_rel_scores.append(rel_dict_b)
+                entry = self._store_relevance_entry(
+                    rel_dict_b,
+                    policy=cache_policy,
+                    step_idx=step_offset * beams + b,
+                    cache_dir=cache_dir_path,
+                    target_dtype=cache_dtype,
+                    move_to_cpu=relevance_move_to_cpu,
+                    use_compression=relevance_use_compression,
+                    compression_method=relevance_compression_method,
+                    pickle_protocol=relevance_pickle_protocol,
+                )
+                step_rel_scores.append(entry)
 
             relevance_trace_beam.append(step_rel_scores)
 
@@ -700,6 +989,9 @@ def generate(
                 for step_rels in relevance_trace_beam
             ]
             info_beam["relevance_trace"] = flat_relevance
+            info_beam["relevance_cache_policy"] = cache_policy
+            if cache_dir_path is not None:
+                info_beam["relevance_cache_dir"] = str(cache_dir_path)
         if return_layerwise_output:
             # collapse to top-1 beam (final winner)
             flat_io_trace = [
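Since beam entries are filed under the flat index step_offset * beams + b (which also keeps the on-disk step_NNNNN filenames from colliding across beams), the flattened trace maps back to (step, beam) coordinates with a divmod. Illustrative, assuming info_beam is the info dict returned by the beam-search path:

    beams = 4  # the beam width used by the generate() call
    for idx, entry in enumerate(info_beam["relevance_trace"]):
        step, b = divmod(idx, beams)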
