55from __future__ import annotations
66
77import time
8- from typing import Optional , List , Tuple , cast
8+ import gzip
9+ import lzma
10+ import numpy as np
11+ from pathlib import Path
12+ from typing import Optional , List , Tuple , cast , Any
913
1014import torch
1115import torch .nn .functional as F
1216
17+ # Optional: 7z support (requires py7zr)
18+ try :
19+ import py7zr
20+ HAS_7Z = True
21+ except ImportError :
22+ HAS_7Z = False
23+
1324from transformers .generation .logits_process import (
1425 LogitsProcessorList ,
1526 TemperatureLogitsWarper ,
@@ -68,6 +79,56 @@ def __init__(self, dlb, tokenizer):
6879 self .dlb = dlb
6980 self .tokenizer = tokenizer
7081
82+ # ---------- compression helpers ----------
83+
84+ @staticmethod
85+ def load_compressed_relevance (file_path : str ):
86+ """
87+ Load relevance data from compressed file with automatic format detection.
88+
89+ Args:
90+ file_path: Path to compressed relevance file (.pt.gz, .pt.xz, .pt.7z, or .pt)
91+
92+ Returns:
93+ Loaded relevance dictionary
94+
95+ Example:
96+ >>> relevance = DLBAutoSampler.load_compressed_relevance("step_00000.pt.gz")
97+ """
98+ path = Path (file_path )
99+
100+ if path .suffix == '.gz' :
101+ # Gzip compressed
102+ with gzip .open (path , 'rb' ) as f :
103+ return torch .load (f , weights_only = False )
104+
105+ elif path .suffix == '.xz' :
106+ # LZMA compressed
107+ with lzma .open (path , 'rb' ) as f :
108+ return torch .load (f , weights_only = False )
109+
110+ elif path .suffix == '.7z' :
111+ # 7z compressed
112+ if not HAS_7Z :
113+ raise ImportError (
114+ "py7zr library required to load 7z files. "
115+ "Install with: pip install py7zr"
116+ )
117+ import tempfile
118+ with tempfile .TemporaryDirectory () as tmpdir :
119+ tmpdir_path = Path (tmpdir )
120+ with py7zr .SevenZipFile (path , 'r' ) as archive :
121+ archive .extractall (tmpdir_path )
122+ # Find the extracted .pt file
123+ pt_files = list (tmpdir_path .glob ('*.pt' ))
124+ if not pt_files :
125+ raise ValueError (f"No .pt file found in 7z archive: { path } " )
126+ return torch .load (pt_files [0 ], weights_only = False )
127+
128+ else :
129+ # Uncompressed or unknown format
130+ return torch .load (path , weights_only = False )
131+
71132 # ---------- small dtype helpers ----------
72133
73134 @staticmethod
@@ -252,6 +313,173 @@ def add_val(x):
252313 add_val (rel_dict )
253314 return total
254315
316+ @staticmethod
317+ def _resolve_torch_dtype (dtype_hint ):
318+ if dtype_hint is None :
319+ return None
320+ if isinstance (dtype_hint , torch .dtype ):
321+ return dtype_hint
322+ if isinstance (dtype_hint , str ):
323+ key = dtype_hint .strip ().lower ()
324+ mapping = {
325+ "float32" : torch .float32 ,
326+ "fp32" : torch .float32 ,
327+ "float" : torch .float32 ,
328+ "float16" : torch .float16 ,
329+ "fp16" : torch .float16 ,
330+ "half" : torch .float16 ,
331+ "bfloat16" : torch .bfloat16 ,
332+ "bf16" : torch .bfloat16 ,
333+ "float64" : torch .float64 ,
334+ "fp64" : torch .float64 ,
335+ }
336+ if key in mapping :
337+ return mapping [key ]
338+ raise ValueError (f"Unsupported relevance dtype hint: { dtype_hint } " )
339+
340+ def _compress_relevance_tree (self , data , * , target_dtype = None , move_to_cpu = True ):
341+ if torch .is_tensor (data ):
342+ tensor = data .detach ()
343+ if move_to_cpu :
344+ tensor = tensor .to ("cpu" )
345+ if target_dtype is not None :
346+ tensor = tensor .to (dtype = target_dtype )
347+ return tensor .clone ()
348+ # Handle numpy arrays by converting to torch tensor with target dtype
349+ if isinstance (data , np .ndarray ):
350+ tensor = torch .from_numpy (data )
351+ if move_to_cpu :
352+ tensor = tensor .to ("cpu" )
353+ if target_dtype is not None :
354+ tensor = tensor .to (dtype = target_dtype )
355+ return tensor
356+ if isinstance (data , dict ):
357+ return {k : self ._compress_relevance_tree (v , target_dtype = target_dtype , move_to_cpu = move_to_cpu ) for k , v in data .items ()}
358+ if isinstance (data , list ):
359+ return [self ._compress_relevance_tree (v , target_dtype = target_dtype , move_to_cpu = move_to_cpu ) for v in data ]
360+ if isinstance (data , tuple ):
361+ return tuple (self ._compress_relevance_tree (v , target_dtype = target_dtype , move_to_cpu = move_to_cpu ) for v in data )
362+ return data
363+
364+ def _prepare_cache_dir (self , base_dir : Optional [str ], policy : str ):
365+ if policy != "disk" :
366+ return None
367+ if not base_dir :
368+ raise ValueError ("relevance_cache_dir is required when relevance_cache_policy='disk'" )
369+ root = Path (base_dir ).expanduser ()
370+ timestamp = int (time .time () * 1000 )
371+ run_dir = root / f"relevance_cache_run_{ timestamp } "
372+ run_dir .mkdir (parents = True , exist_ok = True )
373+ return run_dir
374+
375+ def _store_relevance_entry (
376+ self ,
377+ rel_dict ,
378+ * ,
379+ policy : str ,
380+ step_idx : int ,
381+ cache_dir : Optional [Path ],
382+ target_dtype ,
383+ move_to_cpu : bool ,
384+ use_compression : bool = True ,
385+ compression_method : str = "gzip" ,
386+ pickle_protocol : int = 4 ,
387+ ):
388+ """
389+ Store relevance entry according to specified policy.
390+
391+ Args:
392+ rel_dict: Relevance dictionary to store
393+ policy: Cache policy ("full", "summary", "disk", "none")
394+ step_idx: Generation step index
395+ cache_dir: Directory for disk caching
396+ target_dtype: Target dtype for compression
397+ move_to_cpu: Whether to move tensors to CPU
398+ use_compression: If True, use compression (default: True)
399+ compression_method: Compression method - "gzip", "lzma", "7z", or "none"
400+ - "gzip": Fast, good compression (default)
401+ - "lzma": Better compression, slower
402+ - "7z": Best compression, slowest (requires py7zr)
403+ - "none": No compression
404+ pickle_protocol: Pickle protocol version (2-5). Higher = better compression.
405+ Protocol 4 (default): Python 3.4+, good compression
406+ Protocol 5: Best compression, Python 3.8+
407+ """
408+ normalized_policy = (policy or "full" ).lower ()
409+ if normalized_policy == "none" :
410+ return None
411+
412+ processed = self ._compress_relevance_tree (rel_dict , target_dtype = target_dtype , move_to_cpu = move_to_cpu )
413+ if processed is None :
414+ return None
415+
416+ if normalized_policy == "summary" :
417+ return {"summary" : self ._summarize_relevance (processed )}
418+
419+ if normalized_policy == "disk" :
420+ if cache_dir is None :
421+ raise ValueError ("relevance_cache_dir must be provided when relevance_cache_policy='disk'" )
422+
423+ base_file_path = cache_dir / f"step_{ step_idx :05d} .pt"
424+
425+ # Determine compression method and file extension
426+ if not use_compression or compression_method == "none" :
427+ # No compression
428+ file_path = base_file_path
429+ torch .save (processed , file_path , pickle_protocol = pickle_protocol )
430+
431+ elif compression_method == "gzip" :
432+ # Gzip compression (fast, good ratio)
433+ file_path = Path (str (base_file_path ) + '.gz' )
434+ with gzip .open (file_path , 'wb' , compresslevel = 6 ) as f :
435+ torch .save (processed , f , pickle_protocol = pickle_protocol )
436+
437+ elif compression_method == "lzma" :
438+ # LZMA/xz compression (better ratio, slower)
439+ file_path = Path (str (base_file_path ) + '.xz' )
440+ with lzma .open (file_path , 'wb' , preset = 6 ) as f :
441+ torch .save (processed , f , pickle_protocol = pickle_protocol )
442+
443+ elif compression_method == "7z" :
444+ # 7z compression (best ratio, slowest)
445+ if not HAS_7Z :
446+ raise ImportError (
447+ "py7zr library required for 7z compression. "
448+ "Install with: pip install py7zr"
449+ )
450+ file_path = Path (str (base_file_path ) + '.7z' )
451+ # Save to temporary .pt file first
452+ import tempfile
453+ with tempfile .NamedTemporaryFile (suffix = '.pt' , delete = False ) as tmp :
454+ tmp_path = Path (tmp .name )
455+ torch .save (processed , tmp_path , pickle_protocol = pickle_protocol )
456+
457+ # Compress with 7z
458+ with py7zr .SevenZipFile (file_path , 'w' ) as archive :
459+ archive .write (tmp_path , arcname = f'step_{ step_idx :05d} .pt' )
460+
461+ # Clean up temp file
462+ tmp_path .unlink ()
463+
464+ else :
465+ raise ValueError (
466+ f"Unknown compression_method: { compression_method } . "
467+ f"Must be one of: 'gzip', 'lzma', '7z', 'none'"
468+ )
469+
470+ return {
471+ "summary" : self ._summarize_relevance (processed ),
472+ "path" : str (file_path ),
473+ "compression" : compression_method ,
474+ }
475+
476+
477+ if normalized_policy != "full" :
478+ raise ValueError (
479+ "relevance_cache_policy must be one of {'full', 'summary', 'disk', 'none'}"
480+ )
481+ return processed
482+
255483 # ---------- public API ----------
256484
257485 @torch .no_grad ()
@@ -286,17 +514,52 @@ def generate(
286514 return_layerwise_output : bool = False ,
287515 return_relevance : bool = False ,
288516 debug : bool = False ,
517+ relevance_cache_policy : str = "full" ,
518+ relevance_cache_dir : Optional [str ] = None ,
519+ relevance_compress_dtype : Optional [Any ] = "float16" ,
520+ relevance_move_to_cpu : bool = True ,
521+ relevance_use_compression : bool = True ,
522+ relevance_compression_method : str = "gzip" ,
523+ relevance_pickle_protocol : int = 4 ,
289524 ):
290525 """
291526 Always returns:
292527 - [1, T_total] (top-1 sequence)
293528 - Or (sequence, scores_trace) for sampling when return_scores=True
294- """
529+
530+ Relevance caching knobs:
531+ relevance_cache_policy: "full" (default), "summary", "disk", or "none".
532+ relevance_cache_dir: base directory for on-disk caching (policy="disk").
533+ relevance_compress_dtype: dtype hint (str or torch.dtype) for stored tensors.
534+ relevance_use_compression: If True, use compression for disk storage (default: True).
535+ relevance_compression_method: Compression method - "gzip" (default), "lzma", "7z", or "none".
536+ - "gzip": Fast, good compression (~75% reduction)
537+ - "lzma": Better compression (~80% reduction), slower
538+ - "7z": Best compression (~82% reduction), slowest
539+ - "none": No compression (only dtype compression)
540+ relevance_pickle_protocol: Pickle protocol (2-5). Higher = better compression. Default=4.
541+ relevance_move_to_cpu: move tensors to CPU before caching to reduce VRAM.
542+ """
295543 model = self ._get_causallm (self .dlb .model )
296544 device = input_ids .device
297545 B = input_ids .size (0 )
298546 assert B == 1 , "Current implementation assumes batch size = 1."
299547
548+ cache_policy = (relevance_cache_policy or "full" ).lower ()
549+ allowed_policies = {"full" , "summary" , "disk" , "none" }
550+ if cache_policy not in allowed_policies :
551+ raise ValueError (
552+ "relevance_cache_policy must be one of {'full', 'summary', 'disk', 'none'}"
553+ )
554+ cache_dtype = (
555+ self ._resolve_torch_dtype (relevance_compress_dtype )
556+ if relevance_compress_dtype is not None
557+ else None
558+ )
559+ cache_dir_path = None
560+ if return_relevance and cache_policy == "disk" :
561+ cache_dir_path = self ._prepare_cache_dir (relevance_cache_dir , cache_policy )
562+
300563 # Dtypes up-front
301564 input_ids = self ._as_long (input_ids )
302565
@@ -437,8 +700,19 @@ def generate(
437700 task = "generation" ,
438701 debug = False ,
439702 )
440- # rel_scalar = self._summarize_relevance(rel_dict)
441- relevance_trace .append (rel_dict )
703+ step_idx = len (relevance_trace )
704+ entry = self ._store_relevance_entry (
705+ rel_dict ,
706+ policy = cache_policy ,
707+ step_idx = step_idx ,
708+ cache_dir = cache_dir_path ,
709+ target_dtype = cache_dtype ,
710+ move_to_cpu = relevance_move_to_cpu ,
711+ use_compression = relevance_use_compression ,
712+ compression_method = relevance_compression_method ,
713+ pickle_protocol = relevance_pickle_protocol ,
714+ )
715+ relevance_trace .append (entry )
442716
443717 generated = torch .cat ([generated , next_tokens ], dim = 1 )
444718 attn = torch .cat (
@@ -473,6 +747,9 @@ def generate(
473747 info ["scores_trace" ] = scores_trace
474748 if return_relevance :
475749 info ["relevance_trace" ] = relevance_trace
750+ info ["relevance_cache_policy" ] = cache_policy
751+ if cache_dir_path is not None :
752+ info ["relevance_cache_dir" ] = str (cache_dir_path )
476753 if return_layerwise_output :
477754 info ["layerwise_output_trace" ] = io_data_trace
478755 return generated , info # ([1, T], dict)
@@ -621,6 +898,7 @@ def generate(
621898
622899 if return_relevance :
623900 step_rel_scores = []
901+ step_offset = len (relevance_trace_beam )
624902 for b in range (beams ):
625903 # Use the OLD beam state (before new token) for relevance computation
626904 self .dlb .predict (
@@ -640,7 +918,18 @@ def generate(
640918 task = "generation" ,
641919 debug = False ,
642920 )
643- step_rel_scores .append (rel_dict_b )
921+ entry = self ._store_relevance_entry (
922+ rel_dict_b ,
923+ policy = cache_policy ,
924+ step_idx = step_offset * beams + b ,
925+ cache_dir = cache_dir_path ,
926+ target_dtype = cache_dtype ,
927+ move_to_cpu = relevance_move_to_cpu ,
928+ use_compression = relevance_use_compression ,
929+ compression_method = relevance_compression_method ,
930+ pickle_protocol = relevance_pickle_protocol ,
931+ )
932+ step_rel_scores .append (entry )
644933
645934 relevance_trace_beam .append (step_rel_scores )
646935
@@ -700,6 +989,9 @@ def generate(
700989 for step_rels in relevance_trace_beam
701990 ]
702991 info_beam ["relevance_trace" ] = flat_relevance
992+ info_beam ["relevance_cache_policy" ] = cache_policy
993+ if cache_dir_path is not None :
994+ info_beam ["relevance_cache_dir" ] = str (cache_dir_path )
703995 if return_layerwise_output :
704996 # collapse to top-1 beam (final winner)
705997 flat_io_trace = [
0 commit comments