diff --git a/Cargo.lock b/Cargo.lock index 1713f4b87..81fcc6f78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -778,6 +778,7 @@ dependencies = [ "diskann-utils", "diskann-vector", "foldhash 0.2.0", + "rand 0.9.4", "thiserror 2.0.17", "tokio", ] diff --git a/diskann-garnet/Cargo.toml b/diskann-garnet/Cargo.toml index c16857e05..b81049947 100644 --- a/diskann-garnet/Cargo.toml +++ b/diskann-garnet/Cargo.toml @@ -16,8 +16,9 @@ dashmap = { workspace = true, features = ["inline"] } diskann.workspace = true diskann-quantization.workspace = true diskann-providers.workspace = true +diskann-utils.workspace = true diskann-vector.workspace = true foldhash = "0.2.0" +rand.workspace = true thiserror.workspace = true -tokio.workspace = true -diskann-utils.workspace = true +tokio = { workspace = true, features = ["sync"] } diff --git a/diskann-garnet/src/alloc.rs b/diskann-garnet/src/alloc.rs index 17db0775e..16e3cfb3c 100644 --- a/diskann-garnet/src/alloc.rs +++ b/diskann-garnet/src/alloc.rs @@ -9,7 +9,7 @@ use std::ptr::NonNull; /// Custom allocator that over-aligns to 8 bytes. This is needed since Garnet will hand us byte slices for f32 data /// that may be unaligned, so we need an allocator to make owned, aligned byte containers. #[derive(Debug, Clone, Copy)] -pub(crate) struct AlignToEight; +pub struct AlignToEight; unsafe impl AllocatorCore for AlignToEight { #[inline] diff --git a/diskann-garnet/src/dyn_index.rs b/diskann-garnet/src/dyn_index.rs index badedba51..77153dabe 100644 --- a/diskann-garnet/src/dyn_index.rs +++ b/diskann-garnet/src/dyn_index.rs @@ -7,7 +7,7 @@ use crate::{ SearchResults, garnet::{Context, GarnetId}, labels::GarnetQueryLabelProvider, - provider::{self, GarnetProvider}, + provider::{self, DynamicQuantization, GarnetProvider}, }; use diskann::{ ANNError, ANNResult, @@ -16,8 +16,7 @@ use diskann::{ utils::VectorRepr, }; use diskann_providers::{ - index::wrapped_async::DiskANNIndex, - model::graph::provider::{async_::common::FullPrecision, layers::BetaFilter}, + index::wrapped_async::DiskANNIndex, model::graph::provider::layers::BetaFilter, }; use std::sync::Arc; @@ -55,6 +54,10 @@ pub trait DynIndex: Send + Sync { fn internal_id_exists(&self, context: &Context, id: u32) -> bool; fn external_id_exists(&self, context: &Context, id: &GarnetId) -> bool; + + fn train_quantizer(&self, context: &Context) -> bool; + + fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize); } impl DynIndex for DiskANNIndex> { @@ -63,7 +66,7 @@ impl DynIndex for DiskANNIndex> { /// The data slice here must be aligned to `T` or this will panic. fn insert(&self, context: &Context, id: &GarnetId, data: &[u8]) -> ANNResult<()> { self.insert( - FullPrecision, + DynamicQuantization, context, id, bytemuck::cast_slice::(data), @@ -87,10 +90,10 @@ impl DynIndex for DiskANNIndex> { ) -> ANNResult { let query = bytemuck::cast_slice::(data); if let Some((labels, beta)) = filter { - let beta_filter = BetaFilter::new(FullPrecision, Arc::new(labels.clone()), beta); + let beta_filter = BetaFilter::new(DynamicQuantization, Arc::new(labels.clone()), beta); self.search(*params, &beta_filter, context, query, output) } else { - self.search(*params, &FullPrecision, context, query, output) + self.search(*params, &DynamicQuantization, context, query, output) } } @@ -105,9 +108,9 @@ impl DynIndex for DiskANNIndex> { let rt = tokio::runtime::Builder::new_current_thread() .build() .map_err(|e| ANNError::new(diskann::ANNErrorKind::Opaque, e))?; - let mut accessor: provider::FullAccessor<'_, T> = - >::search_accessor( - &FullPrecision, + let mut accessor: provider::DynamicAccessor<'_, T> = + >::search_accessor( + &DynamicQuantization, self.inner.provider(), context, )?; @@ -115,13 +118,12 @@ impl DynIndex for DiskANNIndex> { // Look up internal ID let iid = self.inner.provider().to_internal_id(context, id)?; let data = rt.block_on(accessor.get_element(iid))?; - let data_bytes = bytemuck::cast_slice::(&data); - self.search_vector(context, data_bytes, params, filter, output) + self.search_vector(context, &data, params, filter, output) } fn remove(&self, context: &Context, id: &GarnetId) -> ANNResult<()> { self.inplace_delete( - FullPrecision, + DynamicQuantization, context, id, 3, @@ -147,4 +149,14 @@ impl DynIndex for DiskANNIndex> { fn external_id_exists(&self, context: &Context, id: &GarnetId) -> bool { self.inner.provider().vector_id_exists(context, id) } + + fn train_quantizer(&self, context: &Context) -> bool { + self.inner.provider().train_quantizer(context) + } + + fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize) { + self.inner + .provider() + .backfill_quant_vectors(context, task_idx, task_count); + } } diff --git a/diskann-garnet/src/fsm.rs b/diskann-garnet/src/fsm.rs index 37096278f..6b6e56130 100644 --- a/diskann-garnet/src/fsm.rs +++ b/diskann-garnet/src/fsm.rs @@ -26,7 +26,7 @@ use crossbeam::queue::ArrayQueue; use std::sync::{ RwLock, - atomic::{AtomicBool, AtomicU32, Ordering}, + atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}, }; use thiserror::Error; @@ -47,11 +47,23 @@ pub enum FsmError { } pub struct FreeSpaceMap { + /// Garnet callbacks for reading/writing FSM keys callbacks: Callbacks, + /// A flag to signal whether there are free IDs in the FSM. + /// This is set after a scan of the FSM, and is used to prevent extraneous reads + /// of FSM blocks. has_free_ids: AtomicBool, + /// A queue of previously deleted IDs to prevent excessive reads to the FSM fast_free_list: ArrayQueue, + /// The maximum block ID stored in the FSM max_block: RwLock, + /// The next ID that will be minted if reusing previously deleted IDs is unavailable next_id: AtomicU32, + /// The total number of IDs marked used in the FSM + total_used: AtomicUsize, + /// Lock that prevents reuse of prevously used IDs. + /// This is used to disable ID reuse during quantization backfill. + reuse_lock: AtomicUsize, } impl FreeSpaceMap { @@ -60,6 +72,7 @@ impl FreeSpaceMap { let fast_free_list = ArrayQueue::new(FAST_SIZE); let max_block = RwLock::new(u32::MAX); let next_id = AtomicU32::new(0); + let total_used = AtomicUsize::new(0); let mut this = Self { callbacks, @@ -67,6 +80,8 @@ impl FreeSpaceMap { fast_free_list, max_block, next_id, + total_used, + reuse_lock: AtomicUsize::new(0), }; // Attempt to load state from Garnet. @@ -97,6 +112,7 @@ impl FreeSpaceMap { let mut block = vec![0u8; BLOCK_SIZE_BYTES]; let mut last_used_id = -1i64; + let mut total_used = 0usize; for block_id in (0..max_block_id).rev() { let block_key = Self::block_key(block_id); @@ -115,8 +131,9 @@ impl FreeSpaceMap { let used = bit_used(byte, bidx); if used { last_used_id = last_used_id.max(id as i64); - } else if (id as i64) < last_used_id && self.fast_free_list.push(id).is_err() { - break; + total_used += 1; + } else if (id as i64) < last_used_id { + let _ = self.fast_free_list.push(id); } id = id.saturating_sub(1); @@ -130,6 +147,8 @@ impl FreeSpaceMap { self.next_id .store((last_used_id + 1) as u32, Ordering::Release); + self.total_used.store(total_used, Ordering::Release); + if !self.fast_free_list.is_empty() { self.has_free_ids.store(true, Ordering::Release); } @@ -174,6 +193,14 @@ impl FreeSpaceMap { return Err(FsmError::Garnet(GarnetError::Write)); } + if changed { + if used { + self.total_used.fetch_add(1, Ordering::AcqRel); + } else { + self.total_used.fetch_sub(1, Ordering::AcqRel); + } + } + // NOTE: We don't modify the free list if the id was already free. if !used && changed { // Push the id onto the fast free list. If the queue is full, ignore it. @@ -214,7 +241,7 @@ impl FreeSpaceMap { /// This may be a a fresh ID larger than all the others, or it may be a reused ID that /// previously belonged to a deleted element. The returned ID is marked as used. pub fn next_id(&self, ctx: Context) -> Result { - if self.has_free_ids.load(Ordering::Acquire) { + if self.can_reuse() && self.has_free_ids.load(Ordering::Acquire) { // We retry reusing a freed ID until there are none or we get one and marking it used // succeeds in changing the value. loop { @@ -270,6 +297,10 @@ impl FreeSpaceMap { self.next_id.load(Ordering::Acquire).saturating_sub(1) } + pub fn total_used(&self) -> usize { + self.total_used.load(Ordering::Acquire) + } + /// Return the FSM block number, byte index, and bit index for a given ID. /// The block number is the block which stores this ID, the byte index is byte offset /// within the block which contains the status bits, and the bit index is the bit index @@ -301,9 +332,9 @@ impl FreeSpaceMap { let mut has_free_ids = false; let mut id = 0u32; + let mut block = vec![0u8; BLOCK_SIZE_BYTES]; 'scan: for block_id in 0..*max_block { let block_key = Self::block_key(block_id); - let mut block = vec![0u8; BLOCK_SIZE_BYTES]; if !self .callbacks .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) @@ -363,6 +394,66 @@ impl FreeSpaceMap { Ok(()) } + + /// Visit each used id in the FSM, invoking f on each id. + pub fn visit_used(&self, ctx: Context, mut f: F) -> Result<(), FsmError> + where + F: FnMut(u32) -> bool, + { + let max_block = { *self.max_block.read().unwrap() }; + let mut block = vec![0u8; BLOCK_SIZE_BYTES]; + let mut id = 0u32; + + for block_id in 0..max_block + 1 { + let block_key = Self::block_key(block_id); + if !self + .callbacks + .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) + { + return Err(FsmError::Garnet(GarnetError::Read)); + } + + for &byte in &block { + if byte == 0x00 { + id += 8; + continue; + } + + for bidx in 0..8 { + if bit_used(byte, bidx) { + let keep_going = f(id); + if !keep_going { + return Ok(()); + } + } + id += 1; + } + } + } + + Ok(()) + } + + /// Returns whether previously deleted IDs may be reused. + fn can_reuse(&self) -> bool { + self.reuse_lock.load(Ordering::Acquire) == 0 + } + + /// Prevent the reuse of previously deleted IDs. + /// Each call to this increments a counter, and only once the counter is back to zero + /// will reuse be allowed again. + pub fn lock_reuse(&self) { + self.reuse_lock.fetch_add(1, Ordering::AcqRel); + } + + /// Resume reuse of previously deleted IDs. + /// Each call to this decrements a counter, and only once the counter is back to zero + /// will reuse be allowed. This returns whether reuse was actually enabled. + pub fn unlock_reuse(&self) -> bool { + let prev = self.reuse_lock.fetch_sub(1, Ordering::AcqRel); + debug_assert_ne!(prev, 0); + prev == 1 + } } /// Return whether the `bidx`th bit is set in byte, where bits are labeled from left to right. diff --git a/diskann-garnet/src/garnet.rs b/diskann-garnet/src/garnet.rs index 85ba0cce5..566460af1 100644 --- a/diskann-garnet/src/garnet.rs +++ b/diskann-garnet/src/garnet.rs @@ -86,6 +86,10 @@ impl Callbacks { self.rmw_callback } + #[expect( + dead_code, + reason = "currently unused, but may be needed in the future" + )] pub fn exists_iid(&self, ctx: Context, id: u32) -> bool { let key = [4, id]; // SAFETY: Key bytes are preceded by 4 bytes of space. @@ -100,6 +104,10 @@ impl Callbacks { unsafe { self.exists_raw(ctx, &key_bytes[4..]) } } + #[expect( + dead_code, + reason = "currently unused, but may be needed in the future" + )] pub fn exists_eid(&self, ctx: Context, id: &GarnetId) -> bool { // SAFETY: GarnetId ensures there are 4 bytes preceding the key bytes. unsafe { self.exists_raw(ctx, id) } diff --git a/diskann-garnet/src/lib.rs b/diskann-garnet/src/lib.rs index a66cd7ab0..c6b40e37f 100644 --- a/diskann-garnet/src/lib.rs +++ b/diskann-garnet/src/lib.rs @@ -51,6 +51,7 @@ mod fsm; mod garnet; mod labels; mod provider; +mod quantization; #[cfg(test)] mod test_utils; @@ -175,7 +176,8 @@ fn create_index_impl( callbacks: Callbacks, context: Context, ) -> Result, GarnetProviderError> { - let provider = GarnetProvider::::new(dim, metric_type, max_degree, callbacks, context)?; + let provider = + GarnetProvider::::new(dim, quant_type, metric_type, max_degree, callbacks, context)?; let state = if provider.start_points_exist() { AtomicUsize::new(IndexState::Ready as usize) } else { @@ -217,7 +219,7 @@ pub unsafe extern "C" fn create_index( target_degree, config::MaxDegree::Value(max_degree as usize), l_build as usize, - config::PruneKind::TriangleInequality, + metric_type.into(), ) .build() { @@ -230,6 +232,7 @@ pub unsafe extern "C" fn create_index( let callbacks = Callbacks::new(read_callback, write_callback, delete_callback, rmw_callback); match quant_type { + VectorQuantType::Invalid => ptr::null(), VectorQuantType::XPreQ8 => { if let Ok(index) = create_index_impl::( quant_type, @@ -245,7 +248,7 @@ pub unsafe extern "C" fn create_index( ptr::null() } } - VectorQuantType::NoQuant => { + VectorQuantType::NoQuant | VectorQuantType::Bin | VectorQuantType::Q8 => { if let Ok(index) = create_index_impl::( quant_type, config, @@ -260,7 +263,6 @@ pub unsafe extern "C" fn create_index( ptr::null() } } - _ => ptr::null(), } } @@ -318,6 +320,9 @@ fn interpret_vector<'a>( let v = match vector_value_type { VectorValueType::Invalid => return None, VectorValueType::FP32 => match quant_type { + VectorQuantType::Invalid => { + return None; + } VectorQuantType::XPreQ8 => { let mut bp = if let Ok(bp) = Poly::broadcast(0u8, vector_len, AlignToEight) { bp @@ -332,11 +337,11 @@ fn interpret_vector<'a>( } PolyCow::from(bp) } - VectorQuantType::NoQuant if v.as_ptr().align_offset(4) == 0 => { + _ if v.as_ptr().align_offset(4) == 0 => { // pointer is correctly aligned to interpret as f32 PolyCow::from(v) } - VectorQuantType::NoQuant => { + _ => { // need to copy f32 data as it is unaligned let mut fp = if let Ok(fp) = Poly::broadcast(0u8, vector_len_bytes, AlignToEight) { fp @@ -346,9 +351,6 @@ fn interpret_vector<'a>( fp.copy_from_slice(v); PolyCow::from(fp) } - _ => { - return None; - } }, VectorValueType::XB8 => match quant_type { VectorQuantType::XPreQ8 => PolyCow::from(v), @@ -375,6 +377,10 @@ fn interpret_vector<'a>( Some(v) } +const INSERT_FAIL: u8 = 0; +const INSERT_SUCCESS: u8 = 1; +const INSERT_SUCCESS_START_TRAINING: u8 = 2; + /// # Safety /// /// FFI @@ -389,7 +395,7 @@ pub unsafe extern "C" fn insert( vector_len: usize, attribute_data: *const u8, attribute_len: usize, -) -> bool { +) -> u8 { let index = unsafe { &*index_ptr.cast::() }; let ctx = Context(ctx); @@ -404,13 +410,13 @@ pub unsafe extern "C" fn insert( ) { v } else { - return false; + return INSERT_FAIL; }; if let Some(_err) = ensure_index_ready_or_init(index, || index.inner.maybe_set_start_point(&ctx, &v).err()) { - return false; + return INSERT_FAIL; }; // Write attributes to garnet @@ -420,11 +426,22 @@ pub unsafe extern "C" fn insert( &[] }; if index.inner.set_attributes(&ctx, &id, attr_data).is_err() { - return false; + return INSERT_FAIL; } + let old_ready = provider::QUANTIZER_READY.with(|v| v.load(Ordering::Acquire)); + // Insert the vector - index.inner.insert(&ctx, &id, &v).is_ok() + if index.inner.insert(&ctx, &id, &v).is_ok() { + let ready = provider::QUANTIZER_READY.with(|v| v.load(Ordering::Acquire)); + if !old_ready && ready { + INSERT_SUCCESS_START_TRAINING + } else { + INSERT_SUCCESS + } + } else { + INSERT_FAIL + } } fn ensure_index_ready_or_init(index: &Index, init: F) -> Option @@ -466,6 +483,34 @@ where None } +/// # Safety +/// +/// FFI +#[unsafe(no_mangle)] +pub unsafe extern "C" fn build_quant_table(context: u64, index_ptr: *const c_void) -> bool { + let index = unsafe { &*index_ptr.cast::() }; + let ctx = Context(context); + + index.inner.train_quantizer(&ctx) +} + +/// # Safety +/// +/// FFI +#[unsafe(no_mangle)] +pub unsafe extern "C" fn backfill_quant_vectors( + context: u64, + index_ptr: *const c_void, + task_index: usize, + task_count: usize, +) { + let index = unsafe { &*index_ptr.cast::() }; + let ctx = Context(context); + index + .inner + .backfill_quant_vectors(&ctx, task_index, task_count); +} + /// # Safety /// /// FFI diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 0be5be9a4..6f8750135 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -11,7 +11,7 @@ use diskann::{ config::defaults::MAX_OCCLUSION_SIZE, glue::{ self, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + PruneStrategy, SearchExt, SearchPostProcess, SearchPostProcessStep, SearchStrategy, }, workingset::{self, map::Entry}, }, @@ -22,21 +22,40 @@ use diskann::{ }, utils::VectorRepr, }; -use diskann_providers::model::graph::provider::async_::common::FullPrecision; -use diskann_utils::Reborrow; +use diskann_quantization::alloc::{AllocatorError, Poly}; use diskann_utils::object_pool::{AsPooled, ObjectPool, PooledRef, Undef}; -use diskann_vector::{PreprocessedDistanceFunction, contains::ContainsSimd, distance::Metric}; +use diskann_utils::{Reborrow, views::Matrix}; +use diskann_vector::{ + DistanceFunction, PreprocessedDistanceFunction, contains::ContainsSimd, distance::Metric, +}; use std::{ - future, mem, + any::TypeId, + future, + marker::PhantomData, + mem, ops::{Deref, DerefMut}, + sync::atomic::{AtomicBool, AtomicU64, Ordering}, + time::SystemTime, }; use thiserror::Error; use crate::{ + VectorQuantType, + alloc::AlignToEight, fsm::{FreeSpaceMap, FsmError}, garnet::{Callbacks, Context, GarnetError, GarnetId, Term}, + quantization::{ + self, DynDistanceComputer, DynQueryComputer, GarnetQuantizer, GarnetQuantizerError, + }, }; +thread_local! { + /// Thread local flag to detect when we've reached the quantization threshold. This is needed + /// to return the correct status at the end of insert() since the provider code that becomes + /// aware of the state change cannot directly communicate it back to the caller. + pub static QUANTIZER_READY: AtomicBool = const { AtomicBool::new(false) }; +} + #[derive(Clone)] struct AdjList(AdjacencyList); @@ -64,6 +83,44 @@ impl AsPooled for AdjList { } } +/// A type erased vector, properly aligned so that it can hold any size element +/// (up to 8 bytes long). +pub struct DynVector { + inner: Poly<[u8], AlignToEight>, + ty: PhantomData, +} + +impl DynVector { + fn new(inner: Poly<[u8], AlignToEight>) -> Self { + Self { + inner, + ty: PhantomData, + } + } +} + +impl Deref for DynVector { + type Target = Poly<[u8], AlignToEight>; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for DynVector { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl<'a, T: VectorRepr> Reborrow<'a> for DynVector { + type Target = &'a [u8]; + + fn reborrow(&'a self) -> Self::Target { + &self.inner + } +} + #[derive(Debug, Error)] pub enum GarnetProviderError { #[error("Garnet operation failed")] @@ -72,6 +129,18 @@ pub enum GarnetProviderError { Fsm(#[from] FsmError), #[error("Start point invalid")] StartPoint, + #[error("Allocation failed")] + AllocFailed(#[from] AllocatorError), + #[error("Invalid quantizer for vector data")] + InvalidQuantizer, + #[error("Expected quantizer is missing")] + MissingQuantizer, + #[error("Failed to gather quantizer training data")] + MissingTrainingData, + #[error("Quantizer error: {0}")] + Quantizer(#[from] GarnetQuantizerError), + #[error("Post processing error: {0}")] + PostProcessing(Box), } impl From for ANNError { @@ -83,21 +152,32 @@ impl From for ANNError { diskann::always_escalate!(GarnetProviderError); +/// The Garnet DataProvider implementation. pub struct GarnetProvider { dim: usize, metric_type: Metric, max_degree: usize, callbacks: Callbacks, + /// The quantizer the index will use, or None if NOQUANT is used. + quantizer: Option>, + /// Tracks whether quantization backfill is complete. + all_quantized: AtomicBool, + /// Tracks whether training has already started. + training_started: AtomicU64, id_buffer_pool: ObjectPool, filtered_ids_pool: ObjectPool>, + quant_buffer_pool: ObjectPool>, neighbor_cache: DashMap, foldhash::fast::RandomState>, - start_point_cache: DashMap, foldhash::fast::RandomState>, + start_point_cache: DashMap, foldhash::fast::RandomState>, + start_point_quant_cache: DashMap, foldhash::fast::RandomState>, fsm: FreeSpaceMap, + _phantom: PhantomData, } impl GarnetProvider { pub fn new( dim: usize, + quant_type: VectorQuantType, metric_type: Metric, max_degree: usize, callbacks: Callbacks, @@ -114,11 +194,13 @@ impl GarnetProvider { let start_point_cache = DashMap::with_capacity_and_hasher(1, foldhash::fast::RandomState::default()); + let start_point_quant_cache = + DashMap::with_capacity_and_hasher(1, foldhash::fast::RandomState::default()); let neighbor_cache = DashMap::with_capacity_and_hasher(1, foldhash::fast::RandomState::default()); // Try to read the start point from Garnet - let mut v = vec![T::default(); dim]; + let mut v = Poly::broadcast(0u8, dim * mem::size_of::(), AlignToEight)?; if callbacks.read_single_iid(context.term(Term::Vector), 0, &mut v) { let mut neighbors = vec![0u32; max_degree + 1]; if !callbacks.read_single_iid(context.term(Term::Neighbors), 0, &mut neighbors) { @@ -134,16 +216,50 @@ impl GarnetProvider { let fsm = FreeSpaceMap::new(context, callbacks)?; + let (quantizer, canonical_bytes, all_quantized) = match quant_type { + VectorQuantType::NoQuant | VectorQuantType::XPreQ8 => (None, 0, false), + VectorQuantType::Invalid => return Err(GarnetProviderError::InvalidQuantizer), + VectorQuantType::Q8 => { + if TypeId::of::() != TypeId::of::() { + return Err(GarnetProviderError::InvalidQuantizer); + } + + let quantizer = Box::new(quantization::MinMax8Bit::new(dim, metric_type)?) + as Box; + let canonical_bytes = quantizer.canonical_bytes(); + // NOTE: Q8 needs no training, so it always starts with backfill complete. + (Some(quantizer), canonical_bytes, true) + } + VectorQuantType::Bin => { + if TypeId::of::() != TypeId::of::() { + return Err(GarnetProviderError::InvalidQuantizer); + } + + let quantizer = + Box::new(quantization::Spherical1Bit::new(dim)) as Box; + let canonical_bytes = quantizer.canonical_bytes(); + (Some(quantizer), canonical_bytes, false) + } + }; + let quant_buffer_pool = + ObjectPool::new(Undef::new(canonical_bytes), parallelism, Some(parallelism)); + Ok(Self { dim, metric_type, max_degree, callbacks, + quantizer, + all_quantized: AtomicBool::new(all_quantized), + training_started: AtomicU64::new(0), id_buffer_pool, filtered_ids_pool, + quant_buffer_pool, start_point_cache, + start_point_quant_cache, neighbor_cache, fsm, + _phantom: PhantomData, }) } @@ -152,7 +268,7 @@ impl GarnetProvider { context: &Context, point: &[T], ) -> Result<(), GarnetProviderError> { - let mut v = vec![T::default(); self.dim]; + let mut v = Poly::broadcast(0u8, self.dim * mem::size_of::(), AlignToEight)?; if self .callbacks .read_single_iid(context.term(Term::Vector), 0, &mut v) @@ -187,6 +303,25 @@ impl GarnetProvider { return Err(GarnetError::Write.into()); } + if self.is_quantized() + && let Some(quantizer) = self.quantizer() + { + // We are already able to quantize, so store the quantized start point + + let mut qpoint = vec![0u8; quantizer.canonical_bytes()]; + quantizer.compress(bytemuck::cast_slice::(point), &mut qpoint)?; + + if !self + .callbacks + .write_iid(context.term(Term::Quantized), 0, &qpoint) + { + return Err(GarnetError::Write.into()); + } + + self.start_point_quant_cache + .insert(0, Poly::from_iter(qpoint.iter().copied(), AlignToEight)?); + } + if !self .callbacks .write_iid(context.term(Term::Neighbors), 0, &neighbors) @@ -194,7 +329,13 @@ impl GarnetProvider { return Err(GarnetError::Write.into()); } - self.start_point_cache.insert(0, point.to_vec()); + self.start_point_cache.insert( + 0, + Poly::from_iter( + bytemuck::cast_slice::(point).iter().copied(), + AlignToEight, + )?, + ); self.neighbor_cache .insert(0, Vec::with_capacity(self.max_degree + 1)); } @@ -206,10 +347,6 @@ impl GarnetProvider { self.start_point_cache.get(&0).is_some() && self.neighbor_cache.get(&0).is_some() } - pub fn callbacks(&self) -> &Callbacks { - &self.callbacks - } - pub fn set_attributes( &self, context: &Context, @@ -240,6 +377,189 @@ impl GarnetProvider { pub fn max_internal_id(&self) -> u32 { self.fsm.max_id() } + + /// Train the quantizer. + /// + /// This should only be called when at least `quantizer.required_vectors()` vectors exist in + /// the provider. This will build quantization tables, but does not quantize any vectors. + pub fn train_quantizer(&self, context: &Context) -> bool { + // Collect up to `self.quantizer.required_vectors()` vectors and use them for quantizer training. + + let current_time: u64 = + match SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH) { + Ok(t) => { + let t = t.as_millis().try_into().unwrap_or(0); + if t == 0 { + return false; + } + t + } + Err(_) => return false, + }; + + // Ensure we don't kick off training twice. + match self.training_started.compare_exchange( + 0, + current_time, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => (), + Err(_) => return false, + } + + let quantizer = match &self.quantizer { + Some(q) => q, + None => { + self.training_started.store(0, Ordering::Release); + return false; + } + }; + + debug_assert_eq!(std::any::TypeId::of::(), std::any::TypeId::of::()); + + let rows = quantizer.required_vectors(); + let mut data = Matrix::new(0f32, rows, self.dim); + let mut row_idx = 0usize; + + if self + .fsm + .visit_used(*context, |id| { + // Skip the start point. + if id == 0 { + return true; + } + + if row_idx >= rows { + return false; + } + + // Read the vector into the data matrix. + // Note that it's ok to read f32 instead of T here, because this can only get called when T == f32. + let row = data.row_mut(row_idx); + if !self + .callbacks + .read_single_iid(context.term(Term::Vector), id, row) + { + return false; + } + + row_idx += 1; + + true + }) + .is_err() + { + // Training failed + self.training_started.store(0, Ordering::Release); + return false; + } + + if row_idx < quantizer.required_vectors() { + self.training_started.store(0, Ordering::Release); + return false; + } + + let view = if let Some(view) = data.subview(0..row_idx) { + view + } else { + return false; + }; + + // Train the quantizer. + match quantizer.train(self.metric_type, view) { + Ok(()) => true, + Err(_e) => { + self.training_started.store(0, Ordering::Release); + false + } + } + } + + /// Bulk quantize previously inserted vectors. + /// + /// This function will be invoked on multiple threads. The total number of tasks and the ID of + /// the current task are given as inputs. + pub fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize) { + let quantizer = match &self.quantizer { + Some(q) => q, + None => return, + }; + + let max_id = self.fsm.max_id() as usize; + + let task_count = task_count.min(max_id + 1); + if task_idx >= task_count { + return; + } + + let work_count = (max_id + 1) / task_count; // will be >= 1 + let start_id = (work_count * task_idx) as u32; + let end_id = (work_count * (task_idx + 1)).min(max_id + 1) as u32; + + self.fsm.lock_reuse(); + + let mut v = vec![0f32; self.dim]; + let mut q = vec![0u8; quantizer.canonical_bytes()]; + for id in start_id..end_id { + if !self + .callbacks + .read_single_iid(context.term(Term::Vector), id, &mut v) + { + continue; + } + + if quantizer.compress(&v, &mut q).is_err() { + continue; + }; + + if !self + .callbacks + .write_iid(context.term(Term::Quantized), id, &q) + { + continue; + } + } + + let backfill_finished = self.fsm.unlock_reuse(); + + if backfill_finished { + // Finish by quantizing the start points + if let Some(v) = self.start_point_cache.get(&0) + && quantizer + .compress(bytemuck::cast_slice::(&v), &mut q) + .is_ok() + { + let _ = self + .callbacks + .write_iid(context.term(Term::Quantized), 0, &q); + + // set the cache + let point = if let Ok(p) = Poly::from_iter(q.iter().copied(), AlignToEight) { + p + } else { + return; + }; + self.start_point_quant_cache.insert(0, point); + } + + self.all_quantized.store(true, Ordering::Release); + } + } + + /// Returns the quantizer associated with the index. + fn quantizer(&self) -> Option<&dyn GarnetQuantizer> { + if let Some(quantizer) = &self.quantizer { + return Some(&**quantizer as &dyn GarnetQuantizer); + } + + None + } + + /// Returns quantization status. If this is true, the index is operating fully quantized. + pub fn is_quantized(&self) -> bool { + self.quantizer.is_some() && self.all_quantized.load(Ordering::Acquire) + } } impl DataProvider for GarnetProvider { @@ -287,11 +607,31 @@ impl SetElement<&[T]> for GarnetProvider { ) -> Result { let internal_id = self.fsm.next_id(*context)?; + // Set quantization readiness + if let Some(quantizer) = &self.quantizer + && !quantizer.is_prepared() + && self.fsm.total_used() > quantizer.required_vectors() + { + QUANTIZER_READY.with(|v| v.store(true, Ordering::Release)); + } + let insert = || -> Result<(), Self::SetError> { self.callbacks .write_iid(context.term(Term::Vector), internal_id, element) .then_some(()) .ok_or(GarnetError::Write)?; + if let Some(quantizer) = &self.quantizer + && quantizer.is_prepared() + { + let mut quant = self + .quant_buffer_pool + .get_ref(Undef::new(quantizer.canonical_bytes())); + quantizer.compress(bytemuck::cast_slice::(element), &mut quant)?; + self.callbacks + .write_iid(context.term(Term::Quantized), internal_id, &quant) + .then_some(()) + .ok_or(GarnetError::Write)?; + } self.callbacks .write_iid(context.term(Term::ExtMap), internal_id, id) .then_some(()) @@ -345,6 +685,7 @@ impl Delete for GarnetProvider { // NOTE: Commented out until DiskANN fixes accessing neighbor data post-delete. //ok &= self.callbacks.delete_iid(context.term(Term::Neighbors), id); ok &= self.callbacks.delete_iid(context.term(Term::Vector), id); + ok &= self.callbacks.delete_iid(context.term(Term::Quantized), id); if !ok { return future::ready(Err(GarnetError::Delete.into())); @@ -386,20 +727,25 @@ impl Delete for GarnetProvider { } } -#[allow(dead_code)] -pub struct FullAccessor<'a, T: VectorRepr> { +/// Dynamic accessor that seamlessly transitions from full precision vector based operation to +/// quantized-only operation. +#[derive(Copy, Clone, Debug)] +pub struct DynamicQuantization; + +pub struct DynamicAccessor<'a, T: VectorRepr> { provider: &'a GarnetProvider, context: &'a Context, - is_search: bool, + /// Whether this accessor should use quantized vectors + quantized: bool, id_buffer: PooledRef<'a, AdjList>, filtered_ids: PooledRef<'a, Vec>, } -impl<'a, T: VectorRepr> FullAccessor<'a, T> { +impl<'a, T: VectorRepr> DynamicAccessor<'a, T> { pub(crate) fn new( provider: &'a GarnetProvider, context: &'a Context, - is_search: bool, + quantized: bool, ) -> Self { let id_buffer = provider .id_buffer_pool @@ -407,10 +753,10 @@ impl<'a, T: VectorRepr> FullAccessor<'a, T> { let filtered_ids = provider .filtered_ids_pool .get_ref(Undef::new(MAX_OCCLUSION_SIZE.get() as usize * 2)); // x2 to allow for the length prefixes for garnet - FullAccessor { + DynamicAccessor { provider, context, - is_search, + quantized, id_buffer, filtered_ids, } @@ -434,6 +780,7 @@ impl<'a, T: VectorRepr> FullAccessor<'a, T> { id, &mut guard, ) { + guard.finish(0); return false; } @@ -444,11 +791,11 @@ impl<'a, T: VectorRepr> FullAccessor<'a, T> { } } -impl HasId for FullAccessor<'_, T> { +impl HasId for DynamicAccessor<'_, T> { type Id = u32; } -impl SearchExt for FullAccessor<'_, T> { +impl SearchExt for DynamicAccessor<'_, T> { fn starting_points(&self) -> impl Future>> + Send { let points = if self.provider.start_points_exist() { vec![0] @@ -466,7 +813,7 @@ impl SearchExt for FullAccessor<'_, T> { } } -impl ExpandBeam<&[T]> for FullAccessor<'_, T> { +impl ExpandBeam<&[T]> for DynamicAccessor<'_, T> { fn expand_beam( &mut self, ids: Itr, @@ -490,14 +837,31 @@ impl ExpandBeam<&[T]> for FullAccessor<'_, T> { .filter(|id| pred.eval_mut(id)) { if id == 0 { - let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { - r + let dist = if self.quantized + && let Some(_quantizer) = self.provider.quantizer() + { + let guard = if let Some(r) = self.provider.start_point_quant_cache.get(&id) + { + r + } else { + return future::ready(Err(GarnetProviderError::Garnet( + GarnetError::Read, + ) + .into())); + }; + computer.evaluate_similarity(&*guard) } else { - return future::ready(Err( - GarnetProviderError::Garnet(GarnetError::Read).into() - )); + let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { + r + } else { + return future::ready(Err(GarnetProviderError::Garnet( + GarnetError::Read, + ) + .into())); + }; + computer.evaluate_similarity(&*guard) }; - let dist = computer.evaluate_similarity(&*guard); + on_neighbors(dist, id); } else { self.filtered_ids.push(4); @@ -505,49 +869,78 @@ impl ExpandBeam<&[T]> for FullAccessor<'_, T> { } } - self.provider.callbacks.read_multi_lpiid( - self.context.term(Term::Vector), - &self.filtered_ids, - |i, v| { - let dist = computer.evaluate_similarity(v); - on_neighbors(dist, self.filtered_ids[i as usize * 2 + 1]); - }, - ); + let ctx = if self.quantized { + self.context.term(Term::Quantized) + } else { + self.context.term(Term::Vector) + }; + + if !self.filtered_ids.is_empty() { + self.provider + .callbacks + .read_multi_lpiid(ctx, &self.filtered_ids, |i, v| { + let dist = computer.evaluate_similarity(v); + on_neighbors(dist, self.filtered_ids[i as usize * 2 + 1]); + }); + } } future::ready(Ok(())) } } -impl Accessor for FullAccessor<'_, T> { +impl Accessor for DynamicAccessor<'_, T> { type Element<'a> - = Vec + = DynVector where Self: 'a; - type ElementRef<'a> = &'a [T]; + type ElementRef<'a> = &'a [u8]; type GetError = GarnetProviderError; fn get_element( &mut self, id: Self::Id, ) -> impl Future, Self::GetError>> + Send { - let mut v = vec![T::default(); self.provider.dim]; + let v_len = if self.quantized + && let Some(quantizer) = self.provider.quantizer() + { + quantizer.canonical_bytes() + } else { + self.provider.dim * mem::size_of::() + }; + + let mut v = match Poly::broadcast(0u8, v_len, AlignToEight) { + Ok(v) => DynVector::new(v), + Err(e) => return future::ready(Err(GarnetProviderError::AllocFailed(e))), + }; if id == 0 { - let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { - r + if self.quantized { + let guard = if let Some(r) = self.provider.start_point_quant_cache.get(&id) { + r + } else { + return future::ready(Err(GarnetError::Read.into())); + }; + v.copy_from_slice(&guard); + return future::ready(Ok(v)); } else { - return future::ready(Err(GarnetError::Read.into())); - }; - v.copy_from_slice(&guard); - return future::ready(Ok(v)); + let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { + r + } else { + return future::ready(Err(GarnetError::Read.into())); + }; + v.copy_from_slice(&guard); + return future::ready(Ok(v)); + } } - if !self - .provider - .callbacks - .read_single_iid(self.context.term(Term::Vector), id, &mut v) - { + let ctx = if self.quantized { + self.context.term(Term::Quantized) + } else { + self.context.term(Term::Vector) + }; + + if !self.provider.callbacks.read_single_iid(ctx, id, &mut v) { return future::ready(Err(GarnetError::Read.into())); } @@ -555,29 +948,102 @@ impl Accessor for FullAccessor<'_, T> { } } -impl BuildDistanceComputer for FullAccessor<'_, T> { - type DistanceComputer = T::Distance; +/// Wrapper for full precision distance computer. +pub struct FullPrecisionDistance(T::Distance); + +impl DynDistanceComputer for FullPrecisionDistance { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { + self.0.evaluate_similarity( + bytemuck::cast_slice::(a), + bytemuck::cast_slice::(b), + ) + } +} + +/// Wrapper for full precision query computer. +pub struct FullPrecisionQueryDistance(T::QueryDistance); + +impl DynQueryComputer for FullPrecisionQueryDistance { + fn evaluate_similarity(&self, a: &[u8]) -> f32 { + self.0.evaluate_similarity(bytemuck::cast_slice::(a)) + } +} + +/// Type-erased distance computer. +pub struct GarnetDistanceComputer { + inner: Box, +} + +impl GarnetDistanceComputer { + pub fn new(computer: T) -> Self { + Self { + inner: Box::new(computer), + } + } +} +impl DistanceFunction<&[u8], &[u8]> for GarnetDistanceComputer { + fn evaluate_similarity(&self, x: &[u8], y: &[u8]) -> f32 { + self.inner.evaluate_similarity(x, y) + } +} + +/// Type-erased query computer. +pub struct GarnetQueryComputer { + inner: Box, +} + +impl GarnetQueryComputer { + pub fn new(computer: T) -> Self { + Self { + inner: Box::new(computer), + } + } +} + +impl PreprocessedDistanceFunction<&[u8]> for GarnetQueryComputer { + fn evaluate_similarity(&self, changing: &[u8]) -> f32 { + self.inner.evaluate_similarity(changing) + } +} + +impl BuildDistanceComputer for DynamicAccessor<'_, T> { + type DistanceComputer = GarnetDistanceComputer; type DistanceComputerError = GarnetProviderError; fn build_distance_computer( &self, ) -> Result { - Ok(T::distance( - self.provider.metric_type, - Some(self.provider.dim), - )) + if self.quantized + && let Some(quantizer) = self.provider.quantizer() + { + Ok(quantizer.distance_computer()?) + } else { + Ok(GarnetDistanceComputer::new(FullPrecisionDistance::( + T::distance(self.provider.metric_type, Some(self.provider.dim)), + ))) + } } } -impl BuildQueryComputer<&[T]> for FullAccessor<'_, T> { - type QueryComputer = T::QueryDistance; +impl BuildQueryComputer<&[T]> for DynamicAccessor<'_, T> { + type QueryComputer = GarnetQueryComputer; type QueryComputerError = GarnetProviderError; fn build_query_computer( &self, from: &[T], ) -> Result { - Ok(T::query_distance(from, self.provider.metric_type)) + if self.quantized + && let Some(quantizer) = self.provider.quantizer() + { + Ok(quantizer + .query_computer(bytemuck::cast_slice::(from)) + .map_err(|e| GarnetQuantizerError::QueryComputer(Box::new(e)))?) + } else { + Ok(GarnetQueryComputer::new(FullPrecisionQueryDistance::( + T::query_distance(from, self.provider.metric_type), + ))) + } } } @@ -594,61 +1060,106 @@ impl<'a, T> Reborrow<'a> for Escape { } } -type WorkingSet = workingset::Map>; -type WorkingSetView<'a, T> = workingset::map::View<'a, u32, Escape>; +pub struct WorkingSet { + map: workingset::Map>, + contains_unquantized: bool, +} + +impl WorkingSet { + pub fn new(capacity_type: workingset::map::Capacity, capacity: usize) -> Self { + Self { + map: workingset::map::Builder::new(capacity_type).build(capacity), + contains_unquantized: false, + } + } +} + +impl Deref for WorkingSet { + type Target = workingset::Map>; + + fn deref(&self) -> &Self::Target { + &self.map + } +} + +impl DerefMut for WorkingSet { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.map + } +} + +type WorkingSetView<'a> = workingset::map::View<'a, u32, Escape>; -impl workingset::Fill> for FullAccessor<'_, T> { +impl workingset::Fill for DynamicAccessor<'_, T> { type Error = GarnetProviderError; type View<'a> - = WorkingSetView<'a, T> + = WorkingSetView<'a> where Self: 'a; async fn fill<'a, Itr>( &'a mut self, - set: &'a mut WorkingSet, + set: &'a mut WorkingSet, itr: Itr, ) -> Result, Self::Error> where Itr: ExactSizeIterator + Clone + Send + Sync, Self: 'a, { + if !self.quantized { + // Mark this working set as having full vectors if we're not quantizing yet + set.contains_unquantized = true; + } else if set.contains_unquantized { + // Working set is polluted by full vectors, it must be cleared + set.clear(); + set.contains_unquantized = false; + } + // Evict items from the working set to make room if needed. set.prepare(itr.clone()); self.filtered_ids.clear(); for id in itr { if id == 0 { - if let Entry::Vacant(e) = set.entry(id) { - let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { - r - } else { - return Err(GarnetError::Read.into()); - }; - e.insert(Escape((&**guard).into())); - } + if self.quantized + && let Entry::Vacant(e) = set.entry(id) + { + if let Some(guard) = self.provider.start_point_quant_cache.get(&id) { + e.insert(Escape((&**guard).into())); + } + } else if let Entry::Vacant(e) = set.entry(id) { + if let Some(guard) = self.provider.start_point_cache.get(&id) { + e.insert(Escape((&**guard).into())); + } + } else { + continue; + }; } else if !set.contains_key(&id) { self.filtered_ids.push(4); self.filtered_ids.push(id); } } + let ctx = if self.quantized { + self.context.term(Term::Quantized) + } else { + self.context.term(Term::Vector) + }; + if !self.filtered_ids.is_empty() { - self.provider.callbacks.read_multi_lpiid( - self.context.term(Term::Vector), - &self.filtered_ids, - |id, v| { + self.provider + .callbacks + .read_multi_lpiid(ctx, &self.filtered_ids, |id, v| { set.insert(self.filtered_ids[id as usize * 2 + 1], Escape(v.into())); - }, - ); + }); } Ok(set.view()) } } -pub struct DelegateNeighborAccessor<'p, 'a, T: VectorRepr>(&'a mut FullAccessor<'p, T>); +pub struct DelegateNeighborAccessor<'p, 'a, T: VectorRepr>(&'a mut DynamicAccessor<'p, T>); impl HasId for DelegateNeighborAccessor<'_, '_, T> { type Id = u32; @@ -668,7 +1179,7 @@ impl NeighborAccessor for DelegateNeighborAccessor<'_, '_, T> { } } -impl<'p, 'a, T: VectorRepr> DelegateNeighbor<'a> for FullAccessor<'p, T> { +impl<'p, 'a, T: VectorRepr> DelegateNeighbor<'a> for DynamicAccessor<'p, T> { type Delegate = DelegateNeighborAccessor<'p, 'a, T>; fn delegate_neighbor(&'a mut self) -> Self::Delegate { @@ -754,21 +1265,21 @@ impl NeighborAccessorMut for DelegateNeighborAccessor<'_, '_, T> #[derive(Debug, Default, Clone, Copy)] pub struct CopyExternalIds; -impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], GarnetId> +impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], GarnetId> for CopyExternalIds { type Error = GarnetProviderError; fn post_process( &self, - accessor: &mut FullAccessor<'a, T>, + accessor: &mut DynamicAccessor<'a, T>, _query: &[T], - _computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, + _computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, candidates: I, output: &mut B, ) -> impl Future> + Send where - I: Iterator as HasId>::Id>> + Send, + I: Iterator as HasId>::Id>> + Send, B: SearchOutputBuffer + Send + ?Sized, { let initial = output.current_len(); @@ -782,41 +1293,124 @@ impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], Garn break; } } + let count = output.current_len() - initial; future::ready(Ok(count)) } } -impl SearchStrategy, &[T]> for FullPrecision { - type SearchAccessor<'a> = FullAccessor<'a, T>; +/// A [`SearchPostProcess`] base object that reranks quantized vectors by full precision distance. +#[derive(Debug, Default, Clone, Copy)] +pub struct Rerank; + +impl<'a, 'b, T: VectorRepr> SearchPostProcessStep, &'b [T], GarnetId> + for Rerank +{ + type Error + = GarnetProviderError + where + NextError: diskann::error::StandardError; + + type NextAccessor = DynamicAccessor<'a, T>; + + async fn post_process_step( + &self, + next: &Next, + accessor: &mut DynamicAccessor<'a, T>, + query: &'b [T], + computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, + candidates: I, + output: &mut B, + ) -> Result> + where + I: Iterator as HasId>::Id>> + Send, + B: SearchOutputBuffer + Send + ?Sized, + Next: SearchPostProcess + Sync, + { + if !accessor.quantized { + // Skip reranking if the accessor if working with full precision + return next + .post_process(accessor, query, computer, candidates, output) + .await + .map_err(|e| GarnetProviderError::PostProcessing(Box::new(e))); + } + + let provider = &accessor.provider; + let f = T::distance(provider.metric_type, Some(provider.dim)); + let mut v = Poly::broadcast(0u8, provider.dim * mem::size_of::(), AlignToEight)?; + + // Filter before computing the full precision distances. + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if !provider.vector_iid_exists(accessor.context, n.id) { + None + } else if provider.callbacks.read_single_iid( + accessor.context.term(Term::Vector), + n.id, + &mut v, + ) { + Some(( + n.id, + f.evaluate_similarity(query, bytemuck::cast_slice::(&v)), + )) + } else { + None + } + }) + .collect(); + + // Sort the full precision distances. + reranked + .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + + next.post_process( + accessor, + query, + computer, + reranked.into_iter().map(|(id, d)| Neighbor::new(id, d)), + output, + ) + .await + .map_err(|e| GarnetProviderError::PostProcessing(Box::new(e))) + } +} + +impl SearchStrategy, &[T]> for DynamicQuantization { + type SearchAccessor<'a> = DynamicAccessor<'a, T>; type SearchAccessorError = GarnetProviderError; - type QueryComputer = T::QueryDistance; + type QueryComputer = GarnetQueryComputer; fn search_accessor<'a>( &'a self, provider: &'a GarnetProvider, context: &'a as DataProvider>::Context, ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider, context, true)) + let quantized = provider.is_quantized(); + Ok(DynamicAccessor::new(provider, context, quantized)) } } -impl DefaultPostProcessor, &[T], GarnetId> for FullPrecision { - default_post_processor!(glue::Pipeline); +impl DefaultPostProcessor, &[T], GarnetId> + for DynamicQuantization +{ + default_post_processor!( + glue::Pipeline> + ); } -impl PruneStrategy> for FullPrecision { - type PruneAccessor<'a> = FullAccessor<'a, T>; +impl PruneStrategy> for DynamicQuantization { + type PruneAccessor<'a> = DynamicAccessor<'a, T>; type PruneAccessorError = GarnetProviderError; - type DistanceComputer<'a> = T::Distance; - type WorkingSet = WorkingSet; + type DistanceComputer<'a> = GarnetDistanceComputer; + type WorkingSet = WorkingSet; fn prune_accessor<'a>( &'a self, provider: &'a GarnetProvider, context: &'a as DataProvider>::Context, ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider, context, false)) + let quantized = provider.is_quantized(); + Ok(DynamicAccessor::new(provider, context, quantized)) } fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { @@ -824,11 +1418,11 @@ impl PruneStrategy> for FullPrecision { // cache and persist up to `capacity` items across uses of the working set. // // This reuse is limited to a single collection of backedges for an insert or multi-insert. - workingset::map::Builder::new(workingset::map::Capacity::Default).build(capacity) + WorkingSet::new(workingset::map::Capacity::Default, capacity) } } -impl InsertStrategy, &[T]> for FullPrecision { +impl InsertStrategy, &[T]> for DynamicQuantization { type PruneStrategy = Self; fn insert_search_accessor<'a>( @@ -836,7 +1430,8 @@ impl InsertStrategy, &[T]> for FullPrecision { provider: &'a GarnetProvider, context: &'a as DataProvider>::Context, ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider, context, false)) + let quantized = provider.is_quantized(); + Ok(DynamicAccessor::new(provider, context, quantized)) } fn prune_strategy(&self) -> Self::PruneStrategy { @@ -844,13 +1439,13 @@ impl InsertStrategy, &[T]> for FullPrecision { } } -impl InplaceDeleteStrategy> for FullPrecision { +impl InplaceDeleteStrategy> for DynamicQuantization { type DeleteElement<'a> = &'a [T]; type DeleteElementGuard = Box<[T]>; type DeleteElementError = GarnetProviderError; type PruneStrategy = Self; - type DeleteSearchAccessor<'a> = FullAccessor<'a, T>; + type DeleteSearchAccessor<'a> = DynamicAccessor<'a, T>; type SearchPostProcessor = glue::CopyIds; type SearchStrategy = Self; diff --git a/diskann-garnet/src/quantization.rs b/diskann-garnet/src/quantization.rs new file mode 100644 index 000000000..265bda056 --- /dev/null +++ b/diskann-garnet/src/quantization.rs @@ -0,0 +1,317 @@ +use std::{num::NonZero, sync::RwLock}; + +use diskann::utils::VectorRepr; +use diskann_quantization::{ + CompressInto, + algorithms::{Transform, TransformKind, transforms::NewTransformError}, + alloc::{GlobalAllocator, ScopedAllocator}, + minmax, + num::Positive, + spherical::{ + self, Data, PreScale, SphericalQuantizer, SupportedMetric, + iface::{self, Opaque, OpaqueMut, Quantizer}, + }, +}; +use diskann_utils::views::MatrixView; +use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric}; +use thiserror::Error; + +use crate::provider::{GarnetDistanceComputer, GarnetQueryComputer}; + +#[derive(Debug, Error)] +pub enum GarnetQuantizerError { + #[error("Quantization training error: {0}")] + Training(Box), + #[error("Quantization alloc error: {0}")] + Alloc(Box), + #[error("Query computer error: {0}")] + QueryComputer(Box), + #[error("Binary quantization error: {0}")] + Compression(Box), + #[error("No quantizer found")] + NoQuantizer, + #[error("Got zero dimension")] + ZeroDim, + #[error("Transform error: {0}")] + BadTransform(#[from] NewTransformError), +} + +/// Quantizer trait that all diskann-garnet quantizers must implement +pub trait GarnetQuantizer: Send + Sync { + /// Check whether the quantizer is ready to be used + fn is_prepared(&self) -> bool; + /// Returns the number of vectors needed before the quantizer can be trained + fn required_vectors(&self) -> usize; + /// Returns the size of a quantized vector + fn canonical_bytes(&self) -> usize; + /// Train the quantizer. + /// Each row of the matrix will be a vector + fn train(&self, metric: Metric, data: MatrixView) -> Result<(), GarnetQuantizerError>; + /// Quantize a vector + fn compress(&self, v: &[f32], into: &mut [u8]) -> Result<(), GarnetQuantizerError>; + /// Returns a distance computer for comparing quantized vectors + fn distance_computer(&self) -> Result; + /// Returns a query computer for comparing distances to a particular query + fn query_computer(&self, query: &[f32]) -> Result; +} + +/// Type-erased distance computer +pub trait DynDistanceComputer: Send + Sync { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32; +} + +/// Type-erased query computer +pub trait DynQueryComputer: Send + Sync { + fn evaluate_similarity(&self, a: &[u8]) -> f32; +} + +/// Spherical 1-bit quantization. +/// +/// This quantizer corresponds to `BIN` quantizer in the Redis protocol. It requires hundreds of +/// vectors (but not thousands) for training. Quantized vectors have 1 bit per dimension plus up +/// to 6 bytes of overhead. +pub struct Spherical1Bit { + dim: usize, + inner: RwLock>>, +} + +impl Spherical1Bit { + pub fn new(dim: usize) -> Self { + Self { + dim, + inner: RwLock::new(None), + } + } +} + +impl GarnetQuantizer for Spherical1Bit { + fn is_prepared(&self) -> bool { + self.inner.read().unwrap().is_some() + } + + fn required_vectors(&self) -> usize { + 1000 + } + + fn canonical_bytes(&self) -> usize { + Data::<1, GlobalAllocator>::canonical_bytes(self.dim) + } + + fn train( + &self, + metric_type: Metric, + data: MatrixView, + ) -> Result<(), GarnetQuantizerError> { + let mut rng = rand::rng(); + let quantizer = SphericalQuantizer::train( + data.as_view(), + TransformKind::DoubleHadamard { + target_dim: diskann_quantization::algorithms::transforms::TargetDim::Same, + }, + SupportedMetric::try_from(metric_type) + .map_err(|e| GarnetQuantizerError::Training(Box::new(e)))?, + PreScale::ReciprocalMeanNorm, + &mut rng, + GlobalAllocator, + ) + .map_err(|e| GarnetQuantizerError::Training(Box::new(e)))?; + + let mut inner = self.inner.write().unwrap(); + *inner = Some( + spherical::iface::Impl::<1>::new(quantizer) + .map_err(|e| GarnetQuantizerError::Alloc(Box::new(e)))?, + ); + + Ok(()) + } + + fn compress(&self, v: &[f32], into: &mut [u8]) -> Result<(), GarnetQuantizerError> { + let guard = self.inner.read().unwrap(); + if let Some(quantizer) = &*guard { + spherical::iface::Quantizer::::compress( + quantizer, + v, + OpaqueMut::new(into), + ScopedAllocator::global(), + ) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + Ok(()) + } else { + Err(GarnetQuantizerError::NoQuantizer) + } + } + + fn distance_computer(&self) -> Result { + let guard = self.inner.read().unwrap(); + if let Some(quantizer) = &*guard { + let computer = quantizer + .distance_computer(GlobalAllocator) + .map_err(|e| GarnetQuantizerError::Alloc(Box::new(e)))?; + Ok(GarnetDistanceComputer::new(computer)) + } else { + Err(GarnetQuantizerError::NoQuantizer) + } + } + + fn query_computer(&self, query: &[f32]) -> Result { + let guard = self.inner.read().unwrap(); + if let Some(quantizer) = &*guard { + let computer = quantizer + .fused_query_computer( + query, + iface::QueryLayout::FullPrecision, + true, + GlobalAllocator, + ScopedAllocator::global(), + ) + .map_err(|e| GarnetQuantizerError::QueryComputer(Box::new(e)))?; + Ok(GarnetQueryComputer::new(computer)) + } else { + Err(GarnetQuantizerError::NoQuantizer) + } + } +} + +impl DynDistanceComputer for iface::DistanceComputer { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { + , Opaque<'_>, _>>::evaluate_similarity( + self, + Opaque::new(a), + Opaque::new(b), + ) + .unwrap() + } +} + +impl DynQueryComputer for iface::QueryComputer { + fn evaluate_similarity(&self, a: &[u8]) -> f32 { + , _>>::evaluate_similarity( + self, + Opaque::new(a), + ) + .unwrap() + } +} + +/// 8-bit scalar quantizer using MinMax +/// +/// This quantizer requires no training at all and is usable immediately on the first first. Each +/// quantized vector has 8 bits per dimension and 20 bytes of overhead. +pub struct MinMax8Bit { + dim: usize, + metric: Metric, + inner: minmax::MinMaxQuantizer, +} + +impl MinMax8Bit { + pub fn new(dim: usize, metric: Metric) -> Result { + let dim = match NonZero::new(dim) { + Some(d) => d, + None => return Err(GarnetQuantizerError::ZeroDim), + }; + let mut rng = rand::rng(); + let transform = Transform::new( + TransformKind::DoubleHadamard { + target_dim: diskann_quantization::algorithms::transforms::TargetDim::Same, + }, + dim, + Some(&mut rng), + GlobalAllocator, + )?; + let grid_scale = Positive::new(1.0).unwrap(); + + Ok(Self { + dim: dim.get(), + metric, + inner: minmax::MinMaxQuantizer::new(transform, grid_scale), + }) + } +} + +impl GarnetQuantizer for MinMax8Bit { + fn is_prepared(&self) -> bool { + true + } + + fn required_vectors(&self) -> usize { + 0 + } + + fn canonical_bytes(&self) -> usize { + minmax::Data::<8>::canonical_bytes(self.dim) + } + + fn train(&self, _metric: Metric, _data: MatrixView) -> Result<(), GarnetQuantizerError> { + Ok(()) + } + + fn compress(&self, v: &[f32], into: &mut [u8]) -> Result<(), GarnetQuantizerError> { + let into = minmax::DataMutRef::<8>::from_canonical_front_mut(into, self.dim) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + self.inner + .compress_into(v, into) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + Ok(()) + } + + fn distance_computer(&self) -> Result { + let computer = GarnetDistanceComputer::new( + ::distance( + self.metric, + Some(self.dim), + ), + ); + Ok(computer) + } + + fn query_computer(&self, query: &[f32]) -> Result { + let computer = GarnetQueryComputer::new(MinMax8BitQueryComputer::new( + &self.inner, + query, + self.dim, + self.metric, + )?); + Ok(computer) + } +} + +impl DynDistanceComputer for diskann_providers::common::FnPtr { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { + let a = diskann_providers::common::MinMax8::from_bytes(a); + let b = diskann_providers::common::MinMax8::from_bytes(b); + >::evaluate_similarity(self, a, b) + } +} +struct MinMax8BitQueryComputer( + diskann_providers::common::BufferedFnPtr, +); + +impl MinMax8BitQueryComputer { + fn new( + quantizer: &minmax::MinMaxQuantizer, + query: &[f32], + dim: usize, + metric: Metric, + ) -> Result { + let mut v = vec![Default::default(); minmax::Data::<8>::canonical_bytes(dim)]; + quantizer + .compress_into( + query, + minmax::DataMutRef::<8>::from_canonical_front_mut(&mut v, dim) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?, + ) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + let inner = diskann_providers::common::MinMax8::query_distance( + diskann_providers::common::MinMax8::from_bytes(&v), + metric, + ); + Ok(Self(inner)) + } +} + +impl DynQueryComputer for MinMax8BitQueryComputer { + fn evaluate_similarity(&self, a: &[u8]) -> f32 { + let a = diskann_providers::common::MinMax8::from_bytes(a); + self.0.evaluate_similarity(a) + } +} diff --git a/diskann-garnet/src/test_utils.rs b/diskann-garnet/src/test_utils.rs index 26b72f774..ba9caed10 100644 --- a/diskann-garnet/src/test_utils.rs +++ b/diskann-garnet/src/test_utils.rs @@ -131,6 +131,7 @@ mod tests { use std::collections::HashMap; use crate::{ + VectorQuantType, garnet::{Context, Term}, test_utils::Store, }; @@ -203,7 +204,14 @@ mod tests { // Create a u8 GarnetProvider with the test Store callbacks let dim = 8; let max_degree = 32; - let provider = GarnetProvider::::new(dim, Metric::L2, max_degree, callbacks, ctx); + let provider = GarnetProvider::::new( + dim, + VectorQuantType::NoQuant, + Metric::L2, + max_degree, + callbacks, + ctx, + ); // Provider should be created successfully assert!(provider.is_ok()); @@ -215,7 +223,14 @@ mod tests { // Create a u8 GarnetProvider with the test Store callbacks let dim = 8; let max_degree = 32; - let provider = GarnetProvider::::new(dim, Metric::L2, max_degree, callbacks, ctx); + let provider = GarnetProvider::::new( + dim, + VectorQuantType::NoQuant, + Metric::L2, + max_degree, + callbacks, + ctx, + ); // Provider should be created successfully assert!(provider.is_ok()); diff --git a/diskann-providers/src/common/minmax_repr.rs b/diskann-providers/src/common/minmax_repr.rs index 73c19b10d..19dc8ec96 100644 --- a/diskann-providers/src/common/minmax_repr.rs +++ b/diskann-providers/src/common/minmax_repr.rs @@ -117,6 +117,12 @@ impl num_traits::FromPrimitive for MinMaxElement { impl_from_primitive!(NBITS, from_u32, u32); } +impl MinMaxElement { + pub fn from_bytes(x: &[u8]) -> &[Self] { + bytemuck::must_cast_slice(x) + } +} + //////////////////////// /// DistanceProvider /// //////////////////////// diff --git a/diskann-providers/src/common/mod.rs b/diskann-providers/src/common/mod.rs index b5c4a42a4..ec39946f7 100644 --- a/diskann-providers/src/common/mod.rs +++ b/diskann-providers/src/common/mod.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ mod minmax_repr; -pub use minmax_repr::{MinMax4, MinMax8, MinMaxElement}; +pub use minmax_repr::{BufferedFnPtr, FnPtr, MinMax4, MinMax8, MinMaxElement}; mod ignore_lock_poison; pub use ignore_lock_poison::IgnoreLockPoison; diff --git a/rfcs/00000-quantizer-bootstrap.md b/rfcs/00000-quantizer-bootstrap.md new file mode 100644 index 000000000..a74c60100 --- /dev/null +++ b/rfcs/00000-quantizer-bootstrap.md @@ -0,0 +1,198 @@ +# Quantization Bootstrapping + +| | | +|---|---| +| **Authors** | Jack Moffitt | +| **Contributors** | Mark Hildebrand | +| **Created** | 2026-03-10 | +| **Updated** | 2026-03-10 | + +## Summary + +Indexes that use quantization must have a minimum number of vectors before they +can build quantization tables and start inserting vectors. Bootstrapping is the +process of incrementally building an index that starts empty, operates on +non-quantized vectors until enough vectors are present to build quantization +tables, and then transitions to normal operation. + +## Motivation + +### Background + +DiskANN's quantizers require some statistical information in order to build +quantization tables. For PQ, 10,000 vectors are generally required to build good +tables; for spherical, 100 are needed. In order to create an index, these +vectors must be provided at creation time in order to build the quantization +tables, at which point each vector is quantized as it is inserted. + +This requirement is easy to fulfill when building indexes from existing +datasets, but when starting from scratch, there is no ability for DiskANN to +build a quantized index since the quantization tables are a required part of the +constructor. + +Current deployments of DiskANN work around this issue by not allowing index +creation until a dataset is sufficiently large (pg_diskann), or operating a +separate flat index until sufficient vectors are collected at which point the +quantization tables are calculated and a graph index is built with DiskANN. + +### Problem Statement + +This RFC proposes changing DiskANN to operate in a quantization bootstrap mode +where it operates on full precision vectors until sufficient vectors exist to +create quantization tables, and then seamlessly transitions to a quantized +index. + +This means the index will operate in three different phases. In Phase 1, the +index operates in full precision mode only until sufficient vectors exist to +build quantization tables. During Phase 2, quantization tables will be built and +vectors will be quantized on insert; pre-existing vectors will be quantized in +the background. Once all vectors are quantized, Phase Three begins the normal +operation of the quantized index. + +### Goals + +1. Allow quantized indices to start empty and use full precision data only to + operate until sufficient vectors are inserted. +2. Aside from allowing construction of `DiskANNIndex` without providing + quantization tables, there should be no user-visible changes to using the index. +3. Performance should remain as high as possible during the three phases. +4. The quantization of previously inserted full vectors during Phase Two should + be controllable by the data provider. + +## Proposal + +Bootstrapping needs two changes to a DiskANN data provider implementation. + +1. **Switching Strategies**: DiskANN needs to start by using full precision only + strategies during Phase One, and switching to a quantized-only or hybrid + strategy for Phase Two. +2. **Quantization Backfill**: During Phase Two, previously inserted vectors will + need to be quantized. As background jobs are a performance concern, how + exactly this is accomplished must be customizable by the data provider. + +### Switching Strategies + +DiskANN already has the ability to run multiple strategies including hybrid full +precision and quantized ones. These strategies represent the high level intent, +but the data provider can choose alternate implementions depending on the data +available during the current phase. + +#### Insertion and Deletion + +Insertion and deletion can remain largely the same in the data provider +implementation. Inserts will need to write vectors, mappings, attributes, et al +into storage, and can track the current phase to gate writes to quantized +vectors. The search portion of these operations will return different objects +depending on the current phase. + +For example, consider a `DataProvder::set_element()` implementation: + +```rust +struct ExampleProvider { + // other fields omitted + quantizer: Option, +} + +impl SetElement<[f32]> for ExampleProvider { + // associated types ommitted + + async fn set_element( + &self, + context: &Self::Context, + id: &Self::ExternalId, + element: &[T], + ) -> Result { + let internal_id = self.new_id()?; + self.write_vector(context, internal_id, element)?; + self.set_internal_map(internal_id, id)?; + self.set_external_map(id, internal_id)?; + + // Quantize and storage quant vector if we have a quantizer. + if let Some(quantizer) = self.quantizer { + let qv = quantizer.quantize(element)?; + self.write_quant_vector(context, internal_id, element)?; + } else { + // This function will check if we are ready for Phase Two, and if so, do or schedule the quantizer intialization. + self.maybe_initialize_quantizer()?; + } + + Ok(NoopGuard::new(internal_id)) + } +} +``` + +Delete can similarly check the status of the quantizer, and delete quantized +vectors if they exist. + +#### Searching + +To avoid complexity of hybrid distance calculations, either full precision +distances will be used (Phase One and Two) or quantized distances will be used +(Phase Three). If a hybrid strategy is in use, then the hybrid distances will +not be used until Phase Three. + +Since vector data may be in one of two representations, the `Accessor::Element` +type should be `Poly` (this should be over-aligned to the correct alignment +for the primitive element type), and the data provider should interpret based on +data size. The distance and query computers will also need modifications to +accept both vector representations, and in the case of query computer the +representation much match that of the query. + + +### Quantization Backfill + +After quantization tables are built, newly inserted vectors will be quantized +before insertion, but previously inserted vectors won't have quantized +representations yet. During Phase Two, these previously inserted full precision +vectors will need to be quantized before the index enters Phase Three. + +Since integrators of DiskANN are sensitive to background jobs, how the index +manages backfilling quantized vectors is controlled by the data provider +implementation. The data provider must have some way to track which vectors have +missing quantized representations so that it generate them. + +Once Phase Two is reached, the data provider can either pause during insertion +of the phase changing vector, or schedule the work to happen asynchronously +however it likes. + +One possibility is to piggy-back on deletion tracking to track quantization +status of vectors. For example, in diskann-garnet, a free space map is kept that +tracks deletes. This could be expanded from 1-bit to 2-bits, and the second bit +used to track whether the vector is quantized. Alternatively, metadata about the +allocated range can be kept and used to iterate over the unquantized set. + + +## Trade-offs + +Currently the workarounds in use are either no index at all until sufficient +vectors exist or operating a side index until sufficient vectors exist and then +building a quantized graph. + +pg_diskann uses the former method, which means users are confused when they try +to create indexes on empty or insufficiently populated tables and get an +error. Cosmos DB uses the latter strategy and operates a flat index until an +asynchronous graph build is complete enough to use the graph index. This +requires the Cosmos DB team to maintain all their own infrastructure for the +flat index and the code around transitioning to the graph index. + +This proposal mitigates the downsides while still allowing the integrator to +retain control over key performance details. + +This proposal also entirely encapsulates this inside the `DataProvider` +implementation. Alternatively, one could attempt to solve this with some kind of +index or strategy layering, but the complexity this would introduce seems not +worth the cost. + +## Benchmark Results + +Since there is no way to build an index currently until quantization tables are +built, there is no way to benchmark the first two phases. There should be no +impact during Phase Three to performance. + +## Future Work + +None. + +## References + +None. \ No newline at end of file diff --git a/vectorset/src/loader.rs b/vectorset/src/loader.rs index f1e62444e..9e0f7a856 100644 --- a/vectorset/src/loader.rs +++ b/vectorset/src/loader.rs @@ -4,22 +4,14 @@ */ use anyhow::{Result, anyhow}; -use std::{ - collections::HashMap, - marker::PhantomData, - mem, - path::{Path, PathBuf}, - sync::Arc, -}; +use std::{collections::HashMap, marker::PhantomData, mem, path::Path, sync::Arc}; use tokio::{fs::File, io::AsyncReadExt, sync::Mutex}; const BATCH_SIZE: usize = 1024; /// Loads base or query vectors from a given path and allow iteration over them. -#[allow(dead_code)] pub struct DatasetLoader { file: Mutex<(File, usize)>, - path: PathBuf, num_vectors: usize, dim: usize, type_: PhantomData, @@ -36,7 +28,6 @@ impl DatasetLoader { Ok(Arc::new(Self { file: Mutex::new((file, 0)), - path, num_vectors, dim, type_: PhantomData, @@ -60,35 +51,33 @@ impl DatasetLoader { /// Returns (count, first_id) where `count` is the number of vectors loaded /// and `first_id` is the id of the first vector. pub async fn next(&self, buffer: &mut Vec) -> Result<(usize, usize)> { + let mut f = self.file.lock().await; + let mut count; let mut first_id; loop { - { - let mut f = self.file.lock().await; - - first_id = f.1; - if f.1 >= self.num_vectors { - buffer.clear(); - return Ok((0, first_id)); - } - - buffer.resize(BATCH_SIZE * self.dim, T::zeroed()); + first_id = f.1; + if f.1 >= self.num_vectors { + buffer.clear(); + return Ok((0, first_id)); + } - let mut buf: &mut [u8] = bytemuck::cast_slice_mut::(&mut *buffer); - while let bytes_read = f.0.read(buf).await? - && bytes_read > 0 - { - buf = &mut buf[bytes_read..]; - } + buffer.resize(BATCH_SIZE * self.dim, T::zeroed()); - let elements_left = buf.len() / mem::size_of::(); - if !buf.is_empty() && !elements_left.is_multiple_of(self.dim) { - return Err(anyhow!("unexpected EOF")); - } + let mut buf: &mut [u8] = bytemuck::cast_slice_mut::(&mut *buffer); + while let bytes_read = f.0.read(buf).await? + && bytes_read > 0 + { + buf = &mut buf[bytes_read..]; + } - count = BATCH_SIZE - elements_left / self.dim; + let elements_left = buf.len() / mem::size_of::(); + if !buf.is_empty() && !elements_left.is_multiple_of(self.dim) { + return Err(anyhow!("unexpected EOF")); } + count = BATCH_SIZE - elements_left / self.dim; + if count == 0 { continue; } @@ -96,7 +85,6 @@ impl DatasetLoader { break; } - let mut f = self.file.lock().await; f.1 += count; Ok((count, first_id)) diff --git a/vectorset/src/main.rs b/vectorset/src/main.rs index 4ba864d56..8333e4f86 100644 --- a/vectorset/src/main.rs +++ b/vectorset/src/main.rs @@ -107,6 +107,14 @@ struct IngestArgs { #[arg(long)] no_header_with_dim: Option, + /// Quantizer + #[arg(long)] + quantizer: Option, + + /// Metric + #[arg(long)] + metric: Option, + /// Paths to base vectors base_path: PathBuf, } @@ -161,6 +169,50 @@ enum DataType { Float32, } +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +enum Quantizer { + None, + Bin, + Q8, +} + +impl ToRedisArgs for Quantizer { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + redis::RedisWrite, + { + let q = match self { + Quantizer::None => b"NOQUANT".as_slice(), + Quantizer::Bin => b"BIN".as_slice(), + Quantizer::Q8 => b"Q8".as_slice(), + }; + out.write_arg(q); + } +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +enum DistanceMetric { + L2, + Cosine, + CosineNormalized, + InnerProduct, +} + +impl ToRedisArgs for DistanceMetric { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + redis::RedisWrite, + { + let q = match self { + DistanceMetric::L2 => b"L2".as_slice(), + DistanceMetric::Cosine => b"COSINE".as_slice(), + DistanceMetric::CosineNormalized => b"COSINE_NORMALIZED".as_slice(), + DistanceMetric::InnerProduct => b"INNER_PRODUCT".as_slice(), + }; + out.write_arg(q); + } +} + struct VectorId(u32); impl ToRedisArgs for VectorId { @@ -362,6 +414,8 @@ async fn ingest( let degree = args.degree; let mut cred = cred.clone(); let data_type = opts.data_type; + let quantizer = args.quantizer; + let metric = args.metric; tasks.spawn(async move { let mut buf = vec![T::zeroed(); ds.batch_size() * ds.dim()]; @@ -387,6 +441,7 @@ async fn ingest( let element = VectorId((first_id + i) as u32); let buf_start = i * ds.dim(); let buf_end = buf_start + ds.dim(); + pipeline.cmd("VADD").arg(&vset); match data_type { @@ -407,10 +462,14 @@ async fn ingest( pipeline.arg(b"XPREQ8"); } DataType::Float32 => { - pipeline.arg(b"NOQUANT"); + pipeline.arg(quantizer); } } + if let Some(metric) = metric { + pipeline.arg(b"XDISTANCE_METRIC").arg(metric); + } + pipeline .arg(b"EF") .arg(l_build.to_string().as_bytes())