From a0fceca6d52e324dcdbc9bd65fd7b52e613a8f40 Mon Sep 17 00:00:00 2001 From: Jack Moffitt Date: Tue, 10 Mar 2026 13:36:38 -0500 Subject: [PATCH 1/3] RFC for quantization bootstrapping --- rfcs/00000-quantizer-bootstrap.md | 158 ++++++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 rfcs/00000-quantizer-bootstrap.md diff --git a/rfcs/00000-quantizer-bootstrap.md b/rfcs/00000-quantizer-bootstrap.md new file mode 100644 index 000000000..18f45ba3d --- /dev/null +++ b/rfcs/00000-quantizer-bootstrap.md @@ -0,0 +1,158 @@ +# Quantization Bootstrapping + +| | | +|---|---| +| **Authors** | Jack Moffitt | +| **Contributors** | | +| **Created** | 2026-03-10 | +| **Updated** | 2026-03-10 | + +## Summary + +Indexes that use quantization must have a minimum number of vectors before they +can build quantization tables and start inserting vectors. Bootstrapping is the +process of incrementally building an index that starts empty, operates on +non-quantized vectors until enough vectors are present to build quantization +tables, and then transitions to normal operation. + +## Motivation + +### Background + +DiskANN's quantizers require some statistical information in order to build +quantization tables. For PQ, 10,000 vectors are generally required to build good +tables; for spherical, 100 are needed. In order to create an index, these +vectors must be provided at creation time in order to build the quantization +tables, at which point each vector is quantized as it is inserted. + +This requirement is easy to fulfill when building indexes from existing +datasets, but when starting from scratch, there is no ability for DiskANN to +build a quantized index since the quantization tables are a required part of the +constructor. + +Current deployments of DiskANN work around this issue by not allowing index +creation until a dataset is sufficient large (pg_diskann), or operating a +separate flat index until sufficient vectors are collected at which point the +quantization tables are calculated and a graph index is built with DiskANN. + +### Problem Statement + +This RFC proposes changing DiskANN to operate in a quantization bootstrap mode +where it operates on full precision vectors until sufficient vectors exist to +create quantization tables, and then seamlessly transitions to a quantized +index. + +This means the index will operate in three different phases. In Phase 1, the +index operates in full precision mode only until sufficient vectors exist to +build quantization tables. During Phase 2, quantization tables will be built and +vectors will be quantized on insert; pre-existing vectors be quantized in the +background. Once all vectors are quantized, Phase Three begins the normal +operation of the quantized index. + +### Goals + +1. Allow quantized indices to start empty and use full precision data only to + operate until sufficient vectors are inserted. +2. Aside from allowing construction of `DiskANNIndex` without providing + quantization tables, there should user visible changes to using the index. +3. Performance should remain as high as possible during the three phases. +4. The quantization of previously inserted full vectors during Phase Two should + be controllable by the data provider. + +## Proposal + +Bootstrapping needs two changes to DiskANN. + +1. **Switching Strategies**: DiskANN needs to start by using full precision only + strategies during Phase One, and switching to a hybrid strategy for Phase + Two, and if the user's intent is to use quantized-only strategies, switching + to quantized-only for Phase Three. +2. **Quantization Backfill**: During Phase Two, previously inserted vectors will + need to be quantized. As background jobs are a performance concern, DiskANN + will need hooks for customizing this behavior. + +### Switching Strategies + +DiskANN already has the ability to run multiple strategies including hybrid full +precision and quantized ones. These should be sufficient for purposes of +bootstrapping, but we will need to orchestrate seamless transitions between +them. + +As the caller designates a strategy to use, we can implement new +`BootstrappedQuantized` and `BootstrappedHybrid` strategies that layer over +existing `FullPrecision`, `Quantized`, and `Hybrid` strategies. These new +bootstrapped strategies will delegate operation to the existing strategies +depending on the current phase. + +*Open question*: How exactly do we do this? + +### Quantization Backfill + +After quantization tables are built, newly inserted vectors will be quantized +before insertion, but previously inserted vectors won't have quantized +representations yet. During Phase Two, these previously inserted full precision +vectors will need to be quantized before the index enters Phase Three. + +Since integrators of DiskANN are sensitive to background jobs, how the index +manages backfilling quantized vectors should be controllable. + +The simplest way is to backfill all missing quantized vectors immediately during +the insert that starts Phase Two. This will cause a latency spike on that single +insert, but doesn't require any background processing. + +A more complicated solution would be to launch a background job that iterates +over full-precision only vectors and quantizes them. DiskANN should provide such +a job that integrators can use, but should also provide some callback that the +hosting database can pump to make incremental progress under its own control. + +Both of these methods can be realized by having a new trait `QuantBackfill`: + +```rust + +pub enum QuantBackfillStatus { + Incomplete, + Complete, +} + +pub trait QuantBackfill { + type BackfillError: AsyncFriendly; + + /// Backfill quantization vectors for up to approximately `duration` amount of time. + fn backfill(duration: Duration) -> impl Future> + AsyncFriendly; +} +``` + +This trait would be implemented on the type that implements the +`BootstrappedQuantized` and `BootstrappedHybrid` strategies. + +*Open question*: How to implement the background task and make it overridable? + +## Trade-offs + +Currently the workarounds in use are either no index at all until sufficient +vectors exist or operating a side index until sufficient vectors exist and then +building a quantized graph. + +pg_diskann uses the former method, which means users are confused when they try +to create indexes on empty tables or insufficiently populated tables and get an +error. Cosmos DB uses the latter strategy and operates a flat index until an +asynchronous graph build is complete enough to use the graph index. This +requires the Cosmos DB team to maintain all their own infrastructure for the +flat index and the code around transitioning to the graph index. + +This proposal mitigates the downsides while still allowing the integrator to +retain control over key performance details. + +## Benchmark Results + +Since there is no way to build an index currently until quantization tables are +built, there is no way to benchmark the first two phases. There should be no +impact during Phase Three to performance. + +## Future Work + +None. + +## References + +None. \ No newline at end of file From 94050f0f83c559f18d67783cbfe0d8bee741124c Mon Sep 17 00:00:00 2001 From: Jack Moffitt Date: Thu, 19 Mar 2026 17:19:11 -0500 Subject: [PATCH 2/3] changes based on brainstorming with Mark --- rfcs/00000-quantizer-bootstrap.md | 140 +++++++++++++++++++----------- 1 file changed, 90 insertions(+), 50 deletions(-) diff --git a/rfcs/00000-quantizer-bootstrap.md b/rfcs/00000-quantizer-bootstrap.md index 18f45ba3d..a74c60100 100644 --- a/rfcs/00000-quantizer-bootstrap.md +++ b/rfcs/00000-quantizer-bootstrap.md @@ -3,7 +3,7 @@ | | | |---|---| | **Authors** | Jack Moffitt | -| **Contributors** | | +| **Contributors** | Mark Hildebrand | | **Created** | 2026-03-10 | | **Updated** | 2026-03-10 | @@ -31,7 +31,7 @@ build a quantized index since the quantization tables are a required part of the constructor. Current deployments of DiskANN work around this issue by not allowing index -creation until a dataset is sufficient large (pg_diskann), or operating a +creation until a dataset is sufficiently large (pg_diskann), or operating a separate flat index until sufficient vectors are collected at which point the quantization tables are calculated and a graph index is built with DiskANN. @@ -45,8 +45,8 @@ index. This means the index will operate in three different phases. In Phase 1, the index operates in full precision mode only until sufficient vectors exist to build quantization tables. During Phase 2, quantization tables will be built and -vectors will be quantized on insert; pre-existing vectors be quantized in the -background. Once all vectors are quantized, Phase Three begins the normal +vectors will be quantized on insert; pre-existing vectors will be quantized in +the background. Once all vectors are quantized, Phase Three begins the normal operation of the quantized index. ### Goals @@ -54,78 +54,113 @@ operation of the quantized index. 1. Allow quantized indices to start empty and use full precision data only to operate until sufficient vectors are inserted. 2. Aside from allowing construction of `DiskANNIndex` without providing - quantization tables, there should user visible changes to using the index. + quantization tables, there should be no user-visible changes to using the index. 3. Performance should remain as high as possible during the three phases. 4. The quantization of previously inserted full vectors during Phase Two should be controllable by the data provider. ## Proposal -Bootstrapping needs two changes to DiskANN. +Bootstrapping needs two changes to a DiskANN data provider implementation. 1. **Switching Strategies**: DiskANN needs to start by using full precision only - strategies during Phase One, and switching to a hybrid strategy for Phase - Two, and if the user's intent is to use quantized-only strategies, switching - to quantized-only for Phase Three. + strategies during Phase One, and switching to a quantized-only or hybrid + strategy for Phase Two. 2. **Quantization Backfill**: During Phase Two, previously inserted vectors will - need to be quantized. As background jobs are a performance concern, DiskANN - will need hooks for customizing this behavior. + need to be quantized. As background jobs are a performance concern, how + exactly this is accomplished must be customizable by the data provider. ### Switching Strategies DiskANN already has the ability to run multiple strategies including hybrid full -precision and quantized ones. These should be sufficient for purposes of -bootstrapping, but we will need to orchestrate seamless transitions between -them. - -As the caller designates a strategy to use, we can implement new -`BootstrappedQuantized` and `BootstrappedHybrid` strategies that layer over -existing `FullPrecision`, `Quantized`, and `Hybrid` strategies. These new -bootstrapped strategies will delegate operation to the existing strategies +precision and quantized ones. These strategies represent the high level intent, +but the data provider can choose alternate implementions depending on the data +available during the current phase. + +#### Insertion and Deletion + +Insertion and deletion can remain largely the same in the data provider +implementation. Inserts will need to write vectors, mappings, attributes, et al +into storage, and can track the current phase to gate writes to quantized +vectors. The search portion of these operations will return different objects depending on the current phase. -*Open question*: How exactly do we do this? +For example, consider a `DataProvder::set_element()` implementation: -### Quantization Backfill +```rust +struct ExampleProvider { + // other fields omitted + quantizer: Option, +} -After quantization tables are built, newly inserted vectors will be quantized -before insertion, but previously inserted vectors won't have quantized -representations yet. During Phase Two, these previously inserted full precision -vectors will need to be quantized before the index enters Phase Three. +impl SetElement<[f32]> for ExampleProvider { + // associated types ommitted + + async fn set_element( + &self, + context: &Self::Context, + id: &Self::ExternalId, + element: &[T], + ) -> Result { + let internal_id = self.new_id()?; + self.write_vector(context, internal_id, element)?; + self.set_internal_map(internal_id, id)?; + self.set_external_map(id, internal_id)?; + + // Quantize and storage quant vector if we have a quantizer. + if let Some(quantizer) = self.quantizer { + let qv = quantizer.quantize(element)?; + self.write_quant_vector(context, internal_id, element)?; + } else { + // This function will check if we are ready for Phase Two, and if so, do or schedule the quantizer intialization. + self.maybe_initialize_quantizer()?; + } + + Ok(NoopGuard::new(internal_id)) + } +} +``` -Since integrators of DiskANN are sensitive to background jobs, how the index -manages backfilling quantized vectors should be controllable. +Delete can similarly check the status of the quantizer, and delete quantized +vectors if they exist. -The simplest way is to backfill all missing quantized vectors immediately during -the insert that starts Phase Two. This will cause a latency spike on that single -insert, but doesn't require any background processing. +#### Searching -A more complicated solution would be to launch a background job that iterates -over full-precision only vectors and quantizes them. DiskANN should provide such -a job that integrators can use, but should also provide some callback that the -hosting database can pump to make incremental progress under its own control. +To avoid complexity of hybrid distance calculations, either full precision +distances will be used (Phase One and Two) or quantized distances will be used +(Phase Three). If a hybrid strategy is in use, then the hybrid distances will +not be used until Phase Three. -Both of these methods can be realized by having a new trait `QuantBackfill`: +Since vector data may be in one of two representations, the `Accessor::Element` +type should be `Poly` (this should be over-aligned to the correct alignment +for the primitive element type), and the data provider should interpret based on +data size. The distance and query computers will also need modifications to +accept both vector representations, and in the case of query computer the +representation much match that of the query. -```rust -pub enum QuantBackfillStatus { - Incomplete, - Complete, -} +### Quantization Backfill + +After quantization tables are built, newly inserted vectors will be quantized +before insertion, but previously inserted vectors won't have quantized +representations yet. During Phase Two, these previously inserted full precision +vectors will need to be quantized before the index enters Phase Three. -pub trait QuantBackfill { - type BackfillError: AsyncFriendly; +Since integrators of DiskANN are sensitive to background jobs, how the index +manages backfilling quantized vectors is controlled by the data provider +implementation. The data provider must have some way to track which vectors have +missing quantized representations so that it generate them. - /// Backfill quantization vectors for up to approximately `duration` amount of time. - fn backfill(duration: Duration) -> impl Future> + AsyncFriendly; -} -``` +Once Phase Two is reached, the data provider can either pause during insertion +of the phase changing vector, or schedule the work to happen asynchronously +however it likes. -This trait would be implemented on the type that implements the -`BootstrappedQuantized` and `BootstrappedHybrid` strategies. +One possibility is to piggy-back on deletion tracking to track quantization +status of vectors. For example, in diskann-garnet, a free space map is kept that +tracks deletes. This could be expanded from 1-bit to 2-bits, and the second bit +used to track whether the vector is quantized. Alternatively, metadata about the +allocated range can be kept and used to iterate over the unquantized set. -*Open question*: How to implement the background task and make it overridable? ## Trade-offs @@ -134,7 +169,7 @@ vectors exist or operating a side index until sufficient vectors exist and then building a quantized graph. pg_diskann uses the former method, which means users are confused when they try -to create indexes on empty tables or insufficiently populated tables and get an +to create indexes on empty or insufficiently populated tables and get an error. Cosmos DB uses the latter strategy and operates a flat index until an asynchronous graph build is complete enough to use the graph index. This requires the Cosmos DB team to maintain all their own infrastructure for the @@ -143,6 +178,11 @@ flat index and the code around transitioning to the graph index. This proposal mitigates the downsides while still allowing the integrator to retain control over key performance details. +This proposal also entirely encapsulates this inside the `DataProvider` +implementation. Alternatively, one could attempt to solve this with some kind of +index or strategy layering, but the complexity this would introduce seems not +worth the cost. + ## Benchmark Results Since there is no way to build an index currently until quantization tables are From b5a3759e463e0356f38f5b98094bda85dc6ebaeb Mon Sep 17 00:00:00 2001 From: Jack Moffitt Date: Mon, 13 Apr 2026 15:08:04 -0500 Subject: [PATCH 3/3] Implement BIN and Q8 quantizers --- Cargo.lock | 1 + diskann-garnet/Cargo.toml | 5 +- diskann-garnet/src/alloc.rs | 2 +- diskann-garnet/src/dyn_index.rs | 36 +- diskann-garnet/src/fsm.rs | 101 ++- diskann-garnet/src/garnet.rs | 8 + diskann-garnet/src/lib.rs | 73 +- diskann-garnet/src/provider.rs | 799 +++++++++++++++++--- diskann-garnet/src/quantization.rs | 317 ++++++++ diskann-garnet/src/test_utils.rs | 19 +- diskann-providers/src/common/minmax_repr.rs | 6 + diskann-providers/src/common/mod.rs | 2 +- vectorset/src/loader.rs | 52 +- vectorset/src/main.rs | 61 +- 14 files changed, 1310 insertions(+), 172 deletions(-) create mode 100644 diskann-garnet/src/quantization.rs diff --git a/Cargo.lock b/Cargo.lock index 1713f4b87..81fcc6f78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -778,6 +778,7 @@ dependencies = [ "diskann-utils", "diskann-vector", "foldhash 0.2.0", + "rand 0.9.4", "thiserror 2.0.17", "tokio", ] diff --git a/diskann-garnet/Cargo.toml b/diskann-garnet/Cargo.toml index c16857e05..b81049947 100644 --- a/diskann-garnet/Cargo.toml +++ b/diskann-garnet/Cargo.toml @@ -16,8 +16,9 @@ dashmap = { workspace = true, features = ["inline"] } diskann.workspace = true diskann-quantization.workspace = true diskann-providers.workspace = true +diskann-utils.workspace = true diskann-vector.workspace = true foldhash = "0.2.0" +rand.workspace = true thiserror.workspace = true -tokio.workspace = true -diskann-utils.workspace = true +tokio = { workspace = true, features = ["sync"] } diff --git a/diskann-garnet/src/alloc.rs b/diskann-garnet/src/alloc.rs index 17db0775e..16e3cfb3c 100644 --- a/diskann-garnet/src/alloc.rs +++ b/diskann-garnet/src/alloc.rs @@ -9,7 +9,7 @@ use std::ptr::NonNull; /// Custom allocator that over-aligns to 8 bytes. This is needed since Garnet will hand us byte slices for f32 data /// that may be unaligned, so we need an allocator to make owned, aligned byte containers. #[derive(Debug, Clone, Copy)] -pub(crate) struct AlignToEight; +pub struct AlignToEight; unsafe impl AllocatorCore for AlignToEight { #[inline] diff --git a/diskann-garnet/src/dyn_index.rs b/diskann-garnet/src/dyn_index.rs index badedba51..77153dabe 100644 --- a/diskann-garnet/src/dyn_index.rs +++ b/diskann-garnet/src/dyn_index.rs @@ -7,7 +7,7 @@ use crate::{ SearchResults, garnet::{Context, GarnetId}, labels::GarnetQueryLabelProvider, - provider::{self, GarnetProvider}, + provider::{self, DynamicQuantization, GarnetProvider}, }; use diskann::{ ANNError, ANNResult, @@ -16,8 +16,7 @@ use diskann::{ utils::VectorRepr, }; use diskann_providers::{ - index::wrapped_async::DiskANNIndex, - model::graph::provider::{async_::common::FullPrecision, layers::BetaFilter}, + index::wrapped_async::DiskANNIndex, model::graph::provider::layers::BetaFilter, }; use std::sync::Arc; @@ -55,6 +54,10 @@ pub trait DynIndex: Send + Sync { fn internal_id_exists(&self, context: &Context, id: u32) -> bool; fn external_id_exists(&self, context: &Context, id: &GarnetId) -> bool; + + fn train_quantizer(&self, context: &Context) -> bool; + + fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize); } impl DynIndex for DiskANNIndex> { @@ -63,7 +66,7 @@ impl DynIndex for DiskANNIndex> { /// The data slice here must be aligned to `T` or this will panic. fn insert(&self, context: &Context, id: &GarnetId, data: &[u8]) -> ANNResult<()> { self.insert( - FullPrecision, + DynamicQuantization, context, id, bytemuck::cast_slice::(data), @@ -87,10 +90,10 @@ impl DynIndex for DiskANNIndex> { ) -> ANNResult { let query = bytemuck::cast_slice::(data); if let Some((labels, beta)) = filter { - let beta_filter = BetaFilter::new(FullPrecision, Arc::new(labels.clone()), beta); + let beta_filter = BetaFilter::new(DynamicQuantization, Arc::new(labels.clone()), beta); self.search(*params, &beta_filter, context, query, output) } else { - self.search(*params, &FullPrecision, context, query, output) + self.search(*params, &DynamicQuantization, context, query, output) } } @@ -105,9 +108,9 @@ impl DynIndex for DiskANNIndex> { let rt = tokio::runtime::Builder::new_current_thread() .build() .map_err(|e| ANNError::new(diskann::ANNErrorKind::Opaque, e))?; - let mut accessor: provider::FullAccessor<'_, T> = - >::search_accessor( - &FullPrecision, + let mut accessor: provider::DynamicAccessor<'_, T> = + >::search_accessor( + &DynamicQuantization, self.inner.provider(), context, )?; @@ -115,13 +118,12 @@ impl DynIndex for DiskANNIndex> { // Look up internal ID let iid = self.inner.provider().to_internal_id(context, id)?; let data = rt.block_on(accessor.get_element(iid))?; - let data_bytes = bytemuck::cast_slice::(&data); - self.search_vector(context, data_bytes, params, filter, output) + self.search_vector(context, &data, params, filter, output) } fn remove(&self, context: &Context, id: &GarnetId) -> ANNResult<()> { self.inplace_delete( - FullPrecision, + DynamicQuantization, context, id, 3, @@ -147,4 +149,14 @@ impl DynIndex for DiskANNIndex> { fn external_id_exists(&self, context: &Context, id: &GarnetId) -> bool { self.inner.provider().vector_id_exists(context, id) } + + fn train_quantizer(&self, context: &Context) -> bool { + self.inner.provider().train_quantizer(context) + } + + fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize) { + self.inner + .provider() + .backfill_quant_vectors(context, task_idx, task_count); + } } diff --git a/diskann-garnet/src/fsm.rs b/diskann-garnet/src/fsm.rs index 37096278f..6b6e56130 100644 --- a/diskann-garnet/src/fsm.rs +++ b/diskann-garnet/src/fsm.rs @@ -26,7 +26,7 @@ use crossbeam::queue::ArrayQueue; use std::sync::{ RwLock, - atomic::{AtomicBool, AtomicU32, Ordering}, + atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}, }; use thiserror::Error; @@ -47,11 +47,23 @@ pub enum FsmError { } pub struct FreeSpaceMap { + /// Garnet callbacks for reading/writing FSM keys callbacks: Callbacks, + /// A flag to signal whether there are free IDs in the FSM. + /// This is set after a scan of the FSM, and is used to prevent extraneous reads + /// of FSM blocks. has_free_ids: AtomicBool, + /// A queue of previously deleted IDs to prevent excessive reads to the FSM fast_free_list: ArrayQueue, + /// The maximum block ID stored in the FSM max_block: RwLock, + /// The next ID that will be minted if reusing previously deleted IDs is unavailable next_id: AtomicU32, + /// The total number of IDs marked used in the FSM + total_used: AtomicUsize, + /// Lock that prevents reuse of prevously used IDs. + /// This is used to disable ID reuse during quantization backfill. + reuse_lock: AtomicUsize, } impl FreeSpaceMap { @@ -60,6 +72,7 @@ impl FreeSpaceMap { let fast_free_list = ArrayQueue::new(FAST_SIZE); let max_block = RwLock::new(u32::MAX); let next_id = AtomicU32::new(0); + let total_used = AtomicUsize::new(0); let mut this = Self { callbacks, @@ -67,6 +80,8 @@ impl FreeSpaceMap { fast_free_list, max_block, next_id, + total_used, + reuse_lock: AtomicUsize::new(0), }; // Attempt to load state from Garnet. @@ -97,6 +112,7 @@ impl FreeSpaceMap { let mut block = vec![0u8; BLOCK_SIZE_BYTES]; let mut last_used_id = -1i64; + let mut total_used = 0usize; for block_id in (0..max_block_id).rev() { let block_key = Self::block_key(block_id); @@ -115,8 +131,9 @@ impl FreeSpaceMap { let used = bit_used(byte, bidx); if used { last_used_id = last_used_id.max(id as i64); - } else if (id as i64) < last_used_id && self.fast_free_list.push(id).is_err() { - break; + total_used += 1; + } else if (id as i64) < last_used_id { + let _ = self.fast_free_list.push(id); } id = id.saturating_sub(1); @@ -130,6 +147,8 @@ impl FreeSpaceMap { self.next_id .store((last_used_id + 1) as u32, Ordering::Release); + self.total_used.store(total_used, Ordering::Release); + if !self.fast_free_list.is_empty() { self.has_free_ids.store(true, Ordering::Release); } @@ -174,6 +193,14 @@ impl FreeSpaceMap { return Err(FsmError::Garnet(GarnetError::Write)); } + if changed { + if used { + self.total_used.fetch_add(1, Ordering::AcqRel); + } else { + self.total_used.fetch_sub(1, Ordering::AcqRel); + } + } + // NOTE: We don't modify the free list if the id was already free. if !used && changed { // Push the id onto the fast free list. If the queue is full, ignore it. @@ -214,7 +241,7 @@ impl FreeSpaceMap { /// This may be a a fresh ID larger than all the others, or it may be a reused ID that /// previously belonged to a deleted element. The returned ID is marked as used. pub fn next_id(&self, ctx: Context) -> Result { - if self.has_free_ids.load(Ordering::Acquire) { + if self.can_reuse() && self.has_free_ids.load(Ordering::Acquire) { // We retry reusing a freed ID until there are none or we get one and marking it used // succeeds in changing the value. loop { @@ -270,6 +297,10 @@ impl FreeSpaceMap { self.next_id.load(Ordering::Acquire).saturating_sub(1) } + pub fn total_used(&self) -> usize { + self.total_used.load(Ordering::Acquire) + } + /// Return the FSM block number, byte index, and bit index for a given ID. /// The block number is the block which stores this ID, the byte index is byte offset /// within the block which contains the status bits, and the bit index is the bit index @@ -301,9 +332,9 @@ impl FreeSpaceMap { let mut has_free_ids = false; let mut id = 0u32; + let mut block = vec![0u8; BLOCK_SIZE_BYTES]; 'scan: for block_id in 0..*max_block { let block_key = Self::block_key(block_id); - let mut block = vec![0u8; BLOCK_SIZE_BYTES]; if !self .callbacks .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) @@ -363,6 +394,66 @@ impl FreeSpaceMap { Ok(()) } + + /// Visit each used id in the FSM, invoking f on each id. + pub fn visit_used(&self, ctx: Context, mut f: F) -> Result<(), FsmError> + where + F: FnMut(u32) -> bool, + { + let max_block = { *self.max_block.read().unwrap() }; + let mut block = vec![0u8; BLOCK_SIZE_BYTES]; + let mut id = 0u32; + + for block_id in 0..max_block + 1 { + let block_key = Self::block_key(block_id); + if !self + .callbacks + .read_single_wid(ctx.term(Term::Metadata), block_key, &mut block) + { + return Err(FsmError::Garnet(GarnetError::Read)); + } + + for &byte in &block { + if byte == 0x00 { + id += 8; + continue; + } + + for bidx in 0..8 { + if bit_used(byte, bidx) { + let keep_going = f(id); + if !keep_going { + return Ok(()); + } + } + id += 1; + } + } + } + + Ok(()) + } + + /// Returns whether previously deleted IDs may be reused. + fn can_reuse(&self) -> bool { + self.reuse_lock.load(Ordering::Acquire) == 0 + } + + /// Prevent the reuse of previously deleted IDs. + /// Each call to this increments a counter, and only once the counter is back to zero + /// will reuse be allowed again. + pub fn lock_reuse(&self) { + self.reuse_lock.fetch_add(1, Ordering::AcqRel); + } + + /// Resume reuse of previously deleted IDs. + /// Each call to this decrements a counter, and only once the counter is back to zero + /// will reuse be allowed. This returns whether reuse was actually enabled. + pub fn unlock_reuse(&self) -> bool { + let prev = self.reuse_lock.fetch_sub(1, Ordering::AcqRel); + debug_assert_ne!(prev, 0); + prev == 1 + } } /// Return whether the `bidx`th bit is set in byte, where bits are labeled from left to right. diff --git a/diskann-garnet/src/garnet.rs b/diskann-garnet/src/garnet.rs index 85ba0cce5..566460af1 100644 --- a/diskann-garnet/src/garnet.rs +++ b/diskann-garnet/src/garnet.rs @@ -86,6 +86,10 @@ impl Callbacks { self.rmw_callback } + #[expect( + dead_code, + reason = "currently unused, but may be needed in the future" + )] pub fn exists_iid(&self, ctx: Context, id: u32) -> bool { let key = [4, id]; // SAFETY: Key bytes are preceded by 4 bytes of space. @@ -100,6 +104,10 @@ impl Callbacks { unsafe { self.exists_raw(ctx, &key_bytes[4..]) } } + #[expect( + dead_code, + reason = "currently unused, but may be needed in the future" + )] pub fn exists_eid(&self, ctx: Context, id: &GarnetId) -> bool { // SAFETY: GarnetId ensures there are 4 bytes preceding the key bytes. unsafe { self.exists_raw(ctx, id) } diff --git a/diskann-garnet/src/lib.rs b/diskann-garnet/src/lib.rs index a66cd7ab0..c6b40e37f 100644 --- a/diskann-garnet/src/lib.rs +++ b/diskann-garnet/src/lib.rs @@ -51,6 +51,7 @@ mod fsm; mod garnet; mod labels; mod provider; +mod quantization; #[cfg(test)] mod test_utils; @@ -175,7 +176,8 @@ fn create_index_impl( callbacks: Callbacks, context: Context, ) -> Result, GarnetProviderError> { - let provider = GarnetProvider::::new(dim, metric_type, max_degree, callbacks, context)?; + let provider = + GarnetProvider::::new(dim, quant_type, metric_type, max_degree, callbacks, context)?; let state = if provider.start_points_exist() { AtomicUsize::new(IndexState::Ready as usize) } else { @@ -217,7 +219,7 @@ pub unsafe extern "C" fn create_index( target_degree, config::MaxDegree::Value(max_degree as usize), l_build as usize, - config::PruneKind::TriangleInequality, + metric_type.into(), ) .build() { @@ -230,6 +232,7 @@ pub unsafe extern "C" fn create_index( let callbacks = Callbacks::new(read_callback, write_callback, delete_callback, rmw_callback); match quant_type { + VectorQuantType::Invalid => ptr::null(), VectorQuantType::XPreQ8 => { if let Ok(index) = create_index_impl::( quant_type, @@ -245,7 +248,7 @@ pub unsafe extern "C" fn create_index( ptr::null() } } - VectorQuantType::NoQuant => { + VectorQuantType::NoQuant | VectorQuantType::Bin | VectorQuantType::Q8 => { if let Ok(index) = create_index_impl::( quant_type, config, @@ -260,7 +263,6 @@ pub unsafe extern "C" fn create_index( ptr::null() } } - _ => ptr::null(), } } @@ -318,6 +320,9 @@ fn interpret_vector<'a>( let v = match vector_value_type { VectorValueType::Invalid => return None, VectorValueType::FP32 => match quant_type { + VectorQuantType::Invalid => { + return None; + } VectorQuantType::XPreQ8 => { let mut bp = if let Ok(bp) = Poly::broadcast(0u8, vector_len, AlignToEight) { bp @@ -332,11 +337,11 @@ fn interpret_vector<'a>( } PolyCow::from(bp) } - VectorQuantType::NoQuant if v.as_ptr().align_offset(4) == 0 => { + _ if v.as_ptr().align_offset(4) == 0 => { // pointer is correctly aligned to interpret as f32 PolyCow::from(v) } - VectorQuantType::NoQuant => { + _ => { // need to copy f32 data as it is unaligned let mut fp = if let Ok(fp) = Poly::broadcast(0u8, vector_len_bytes, AlignToEight) { fp @@ -346,9 +351,6 @@ fn interpret_vector<'a>( fp.copy_from_slice(v); PolyCow::from(fp) } - _ => { - return None; - } }, VectorValueType::XB8 => match quant_type { VectorQuantType::XPreQ8 => PolyCow::from(v), @@ -375,6 +377,10 @@ fn interpret_vector<'a>( Some(v) } +const INSERT_FAIL: u8 = 0; +const INSERT_SUCCESS: u8 = 1; +const INSERT_SUCCESS_START_TRAINING: u8 = 2; + /// # Safety /// /// FFI @@ -389,7 +395,7 @@ pub unsafe extern "C" fn insert( vector_len: usize, attribute_data: *const u8, attribute_len: usize, -) -> bool { +) -> u8 { let index = unsafe { &*index_ptr.cast::() }; let ctx = Context(ctx); @@ -404,13 +410,13 @@ pub unsafe extern "C" fn insert( ) { v } else { - return false; + return INSERT_FAIL; }; if let Some(_err) = ensure_index_ready_or_init(index, || index.inner.maybe_set_start_point(&ctx, &v).err()) { - return false; + return INSERT_FAIL; }; // Write attributes to garnet @@ -420,11 +426,22 @@ pub unsafe extern "C" fn insert( &[] }; if index.inner.set_attributes(&ctx, &id, attr_data).is_err() { - return false; + return INSERT_FAIL; } + let old_ready = provider::QUANTIZER_READY.with(|v| v.load(Ordering::Acquire)); + // Insert the vector - index.inner.insert(&ctx, &id, &v).is_ok() + if index.inner.insert(&ctx, &id, &v).is_ok() { + let ready = provider::QUANTIZER_READY.with(|v| v.load(Ordering::Acquire)); + if !old_ready && ready { + INSERT_SUCCESS_START_TRAINING + } else { + INSERT_SUCCESS + } + } else { + INSERT_FAIL + } } fn ensure_index_ready_or_init(index: &Index, init: F) -> Option @@ -466,6 +483,34 @@ where None } +/// # Safety +/// +/// FFI +#[unsafe(no_mangle)] +pub unsafe extern "C" fn build_quant_table(context: u64, index_ptr: *const c_void) -> bool { + let index = unsafe { &*index_ptr.cast::() }; + let ctx = Context(context); + + index.inner.train_quantizer(&ctx) +} + +/// # Safety +/// +/// FFI +#[unsafe(no_mangle)] +pub unsafe extern "C" fn backfill_quant_vectors( + context: u64, + index_ptr: *const c_void, + task_index: usize, + task_count: usize, +) { + let index = unsafe { &*index_ptr.cast::() }; + let ctx = Context(context); + index + .inner + .backfill_quant_vectors(&ctx, task_index, task_count); +} + /// # Safety /// /// FFI diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 0be5be9a4..6f8750135 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -11,7 +11,7 @@ use diskann::{ config::defaults::MAX_OCCLUSION_SIZE, glue::{ self, DefaultPostProcessor, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, - PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + PruneStrategy, SearchExt, SearchPostProcess, SearchPostProcessStep, SearchStrategy, }, workingset::{self, map::Entry}, }, @@ -22,21 +22,40 @@ use diskann::{ }, utils::VectorRepr, }; -use diskann_providers::model::graph::provider::async_::common::FullPrecision; -use diskann_utils::Reborrow; +use diskann_quantization::alloc::{AllocatorError, Poly}; use diskann_utils::object_pool::{AsPooled, ObjectPool, PooledRef, Undef}; -use diskann_vector::{PreprocessedDistanceFunction, contains::ContainsSimd, distance::Metric}; +use diskann_utils::{Reborrow, views::Matrix}; +use diskann_vector::{ + DistanceFunction, PreprocessedDistanceFunction, contains::ContainsSimd, distance::Metric, +}; use std::{ - future, mem, + any::TypeId, + future, + marker::PhantomData, + mem, ops::{Deref, DerefMut}, + sync::atomic::{AtomicBool, AtomicU64, Ordering}, + time::SystemTime, }; use thiserror::Error; use crate::{ + VectorQuantType, + alloc::AlignToEight, fsm::{FreeSpaceMap, FsmError}, garnet::{Callbacks, Context, GarnetError, GarnetId, Term}, + quantization::{ + self, DynDistanceComputer, DynQueryComputer, GarnetQuantizer, GarnetQuantizerError, + }, }; +thread_local! { + /// Thread local flag to detect when we've reached the quantization threshold. This is needed + /// to return the correct status at the end of insert() since the provider code that becomes + /// aware of the state change cannot directly communicate it back to the caller. + pub static QUANTIZER_READY: AtomicBool = const { AtomicBool::new(false) }; +} + #[derive(Clone)] struct AdjList(AdjacencyList); @@ -64,6 +83,44 @@ impl AsPooled for AdjList { } } +/// A type erased vector, properly aligned so that it can hold any size element +/// (up to 8 bytes long). +pub struct DynVector { + inner: Poly<[u8], AlignToEight>, + ty: PhantomData, +} + +impl DynVector { + fn new(inner: Poly<[u8], AlignToEight>) -> Self { + Self { + inner, + ty: PhantomData, + } + } +} + +impl Deref for DynVector { + type Target = Poly<[u8], AlignToEight>; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for DynVector { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl<'a, T: VectorRepr> Reborrow<'a> for DynVector { + type Target = &'a [u8]; + + fn reborrow(&'a self) -> Self::Target { + &self.inner + } +} + #[derive(Debug, Error)] pub enum GarnetProviderError { #[error("Garnet operation failed")] @@ -72,6 +129,18 @@ pub enum GarnetProviderError { Fsm(#[from] FsmError), #[error("Start point invalid")] StartPoint, + #[error("Allocation failed")] + AllocFailed(#[from] AllocatorError), + #[error("Invalid quantizer for vector data")] + InvalidQuantizer, + #[error("Expected quantizer is missing")] + MissingQuantizer, + #[error("Failed to gather quantizer training data")] + MissingTrainingData, + #[error("Quantizer error: {0}")] + Quantizer(#[from] GarnetQuantizerError), + #[error("Post processing error: {0}")] + PostProcessing(Box), } impl From for ANNError { @@ -83,21 +152,32 @@ impl From for ANNError { diskann::always_escalate!(GarnetProviderError); +/// The Garnet DataProvider implementation. pub struct GarnetProvider { dim: usize, metric_type: Metric, max_degree: usize, callbacks: Callbacks, + /// The quantizer the index will use, or None if NOQUANT is used. + quantizer: Option>, + /// Tracks whether quantization backfill is complete. + all_quantized: AtomicBool, + /// Tracks whether training has already started. + training_started: AtomicU64, id_buffer_pool: ObjectPool, filtered_ids_pool: ObjectPool>, + quant_buffer_pool: ObjectPool>, neighbor_cache: DashMap, foldhash::fast::RandomState>, - start_point_cache: DashMap, foldhash::fast::RandomState>, + start_point_cache: DashMap, foldhash::fast::RandomState>, + start_point_quant_cache: DashMap, foldhash::fast::RandomState>, fsm: FreeSpaceMap, + _phantom: PhantomData, } impl GarnetProvider { pub fn new( dim: usize, + quant_type: VectorQuantType, metric_type: Metric, max_degree: usize, callbacks: Callbacks, @@ -114,11 +194,13 @@ impl GarnetProvider { let start_point_cache = DashMap::with_capacity_and_hasher(1, foldhash::fast::RandomState::default()); + let start_point_quant_cache = + DashMap::with_capacity_and_hasher(1, foldhash::fast::RandomState::default()); let neighbor_cache = DashMap::with_capacity_and_hasher(1, foldhash::fast::RandomState::default()); // Try to read the start point from Garnet - let mut v = vec![T::default(); dim]; + let mut v = Poly::broadcast(0u8, dim * mem::size_of::(), AlignToEight)?; if callbacks.read_single_iid(context.term(Term::Vector), 0, &mut v) { let mut neighbors = vec![0u32; max_degree + 1]; if !callbacks.read_single_iid(context.term(Term::Neighbors), 0, &mut neighbors) { @@ -134,16 +216,50 @@ impl GarnetProvider { let fsm = FreeSpaceMap::new(context, callbacks)?; + let (quantizer, canonical_bytes, all_quantized) = match quant_type { + VectorQuantType::NoQuant | VectorQuantType::XPreQ8 => (None, 0, false), + VectorQuantType::Invalid => return Err(GarnetProviderError::InvalidQuantizer), + VectorQuantType::Q8 => { + if TypeId::of::() != TypeId::of::() { + return Err(GarnetProviderError::InvalidQuantizer); + } + + let quantizer = Box::new(quantization::MinMax8Bit::new(dim, metric_type)?) + as Box; + let canonical_bytes = quantizer.canonical_bytes(); + // NOTE: Q8 needs no training, so it always starts with backfill complete. + (Some(quantizer), canonical_bytes, true) + } + VectorQuantType::Bin => { + if TypeId::of::() != TypeId::of::() { + return Err(GarnetProviderError::InvalidQuantizer); + } + + let quantizer = + Box::new(quantization::Spherical1Bit::new(dim)) as Box; + let canonical_bytes = quantizer.canonical_bytes(); + (Some(quantizer), canonical_bytes, false) + } + }; + let quant_buffer_pool = + ObjectPool::new(Undef::new(canonical_bytes), parallelism, Some(parallelism)); + Ok(Self { dim, metric_type, max_degree, callbacks, + quantizer, + all_quantized: AtomicBool::new(all_quantized), + training_started: AtomicU64::new(0), id_buffer_pool, filtered_ids_pool, + quant_buffer_pool, start_point_cache, + start_point_quant_cache, neighbor_cache, fsm, + _phantom: PhantomData, }) } @@ -152,7 +268,7 @@ impl GarnetProvider { context: &Context, point: &[T], ) -> Result<(), GarnetProviderError> { - let mut v = vec![T::default(); self.dim]; + let mut v = Poly::broadcast(0u8, self.dim * mem::size_of::(), AlignToEight)?; if self .callbacks .read_single_iid(context.term(Term::Vector), 0, &mut v) @@ -187,6 +303,25 @@ impl GarnetProvider { return Err(GarnetError::Write.into()); } + if self.is_quantized() + && let Some(quantizer) = self.quantizer() + { + // We are already able to quantize, so store the quantized start point + + let mut qpoint = vec![0u8; quantizer.canonical_bytes()]; + quantizer.compress(bytemuck::cast_slice::(point), &mut qpoint)?; + + if !self + .callbacks + .write_iid(context.term(Term::Quantized), 0, &qpoint) + { + return Err(GarnetError::Write.into()); + } + + self.start_point_quant_cache + .insert(0, Poly::from_iter(qpoint.iter().copied(), AlignToEight)?); + } + if !self .callbacks .write_iid(context.term(Term::Neighbors), 0, &neighbors) @@ -194,7 +329,13 @@ impl GarnetProvider { return Err(GarnetError::Write.into()); } - self.start_point_cache.insert(0, point.to_vec()); + self.start_point_cache.insert( + 0, + Poly::from_iter( + bytemuck::cast_slice::(point).iter().copied(), + AlignToEight, + )?, + ); self.neighbor_cache .insert(0, Vec::with_capacity(self.max_degree + 1)); } @@ -206,10 +347,6 @@ impl GarnetProvider { self.start_point_cache.get(&0).is_some() && self.neighbor_cache.get(&0).is_some() } - pub fn callbacks(&self) -> &Callbacks { - &self.callbacks - } - pub fn set_attributes( &self, context: &Context, @@ -240,6 +377,189 @@ impl GarnetProvider { pub fn max_internal_id(&self) -> u32 { self.fsm.max_id() } + + /// Train the quantizer. + /// + /// This should only be called when at least `quantizer.required_vectors()` vectors exist in + /// the provider. This will build quantization tables, but does not quantize any vectors. + pub fn train_quantizer(&self, context: &Context) -> bool { + // Collect up to `self.quantizer.required_vectors()` vectors and use them for quantizer training. + + let current_time: u64 = + match SystemTime::now().duration_since(std::time::SystemTime::UNIX_EPOCH) { + Ok(t) => { + let t = t.as_millis().try_into().unwrap_or(0); + if t == 0 { + return false; + } + t + } + Err(_) => return false, + }; + + // Ensure we don't kick off training twice. + match self.training_started.compare_exchange( + 0, + current_time, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => (), + Err(_) => return false, + } + + let quantizer = match &self.quantizer { + Some(q) => q, + None => { + self.training_started.store(0, Ordering::Release); + return false; + } + }; + + debug_assert_eq!(std::any::TypeId::of::(), std::any::TypeId::of::()); + + let rows = quantizer.required_vectors(); + let mut data = Matrix::new(0f32, rows, self.dim); + let mut row_idx = 0usize; + + if self + .fsm + .visit_used(*context, |id| { + // Skip the start point. + if id == 0 { + return true; + } + + if row_idx >= rows { + return false; + } + + // Read the vector into the data matrix. + // Note that it's ok to read f32 instead of T here, because this can only get called when T == f32. + let row = data.row_mut(row_idx); + if !self + .callbacks + .read_single_iid(context.term(Term::Vector), id, row) + { + return false; + } + + row_idx += 1; + + true + }) + .is_err() + { + // Training failed + self.training_started.store(0, Ordering::Release); + return false; + } + + if row_idx < quantizer.required_vectors() { + self.training_started.store(0, Ordering::Release); + return false; + } + + let view = if let Some(view) = data.subview(0..row_idx) { + view + } else { + return false; + }; + + // Train the quantizer. + match quantizer.train(self.metric_type, view) { + Ok(()) => true, + Err(_e) => { + self.training_started.store(0, Ordering::Release); + false + } + } + } + + /// Bulk quantize previously inserted vectors. + /// + /// This function will be invoked on multiple threads. The total number of tasks and the ID of + /// the current task are given as inputs. + pub fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize) { + let quantizer = match &self.quantizer { + Some(q) => q, + None => return, + }; + + let max_id = self.fsm.max_id() as usize; + + let task_count = task_count.min(max_id + 1); + if task_idx >= task_count { + return; + } + + let work_count = (max_id + 1) / task_count; // will be >= 1 + let start_id = (work_count * task_idx) as u32; + let end_id = (work_count * (task_idx + 1)).min(max_id + 1) as u32; + + self.fsm.lock_reuse(); + + let mut v = vec![0f32; self.dim]; + let mut q = vec![0u8; quantizer.canonical_bytes()]; + for id in start_id..end_id { + if !self + .callbacks + .read_single_iid(context.term(Term::Vector), id, &mut v) + { + continue; + } + + if quantizer.compress(&v, &mut q).is_err() { + continue; + }; + + if !self + .callbacks + .write_iid(context.term(Term::Quantized), id, &q) + { + continue; + } + } + + let backfill_finished = self.fsm.unlock_reuse(); + + if backfill_finished { + // Finish by quantizing the start points + if let Some(v) = self.start_point_cache.get(&0) + && quantizer + .compress(bytemuck::cast_slice::(&v), &mut q) + .is_ok() + { + let _ = self + .callbacks + .write_iid(context.term(Term::Quantized), 0, &q); + + // set the cache + let point = if let Ok(p) = Poly::from_iter(q.iter().copied(), AlignToEight) { + p + } else { + return; + }; + self.start_point_quant_cache.insert(0, point); + } + + self.all_quantized.store(true, Ordering::Release); + } + } + + /// Returns the quantizer associated with the index. + fn quantizer(&self) -> Option<&dyn GarnetQuantizer> { + if let Some(quantizer) = &self.quantizer { + return Some(&**quantizer as &dyn GarnetQuantizer); + } + + None + } + + /// Returns quantization status. If this is true, the index is operating fully quantized. + pub fn is_quantized(&self) -> bool { + self.quantizer.is_some() && self.all_quantized.load(Ordering::Acquire) + } } impl DataProvider for GarnetProvider { @@ -287,11 +607,31 @@ impl SetElement<&[T]> for GarnetProvider { ) -> Result { let internal_id = self.fsm.next_id(*context)?; + // Set quantization readiness + if let Some(quantizer) = &self.quantizer + && !quantizer.is_prepared() + && self.fsm.total_used() > quantizer.required_vectors() + { + QUANTIZER_READY.with(|v| v.store(true, Ordering::Release)); + } + let insert = || -> Result<(), Self::SetError> { self.callbacks .write_iid(context.term(Term::Vector), internal_id, element) .then_some(()) .ok_or(GarnetError::Write)?; + if let Some(quantizer) = &self.quantizer + && quantizer.is_prepared() + { + let mut quant = self + .quant_buffer_pool + .get_ref(Undef::new(quantizer.canonical_bytes())); + quantizer.compress(bytemuck::cast_slice::(element), &mut quant)?; + self.callbacks + .write_iid(context.term(Term::Quantized), internal_id, &quant) + .then_some(()) + .ok_or(GarnetError::Write)?; + } self.callbacks .write_iid(context.term(Term::ExtMap), internal_id, id) .then_some(()) @@ -345,6 +685,7 @@ impl Delete for GarnetProvider { // NOTE: Commented out until DiskANN fixes accessing neighbor data post-delete. //ok &= self.callbacks.delete_iid(context.term(Term::Neighbors), id); ok &= self.callbacks.delete_iid(context.term(Term::Vector), id); + ok &= self.callbacks.delete_iid(context.term(Term::Quantized), id); if !ok { return future::ready(Err(GarnetError::Delete.into())); @@ -386,20 +727,25 @@ impl Delete for GarnetProvider { } } -#[allow(dead_code)] -pub struct FullAccessor<'a, T: VectorRepr> { +/// Dynamic accessor that seamlessly transitions from full precision vector based operation to +/// quantized-only operation. +#[derive(Copy, Clone, Debug)] +pub struct DynamicQuantization; + +pub struct DynamicAccessor<'a, T: VectorRepr> { provider: &'a GarnetProvider, context: &'a Context, - is_search: bool, + /// Whether this accessor should use quantized vectors + quantized: bool, id_buffer: PooledRef<'a, AdjList>, filtered_ids: PooledRef<'a, Vec>, } -impl<'a, T: VectorRepr> FullAccessor<'a, T> { +impl<'a, T: VectorRepr> DynamicAccessor<'a, T> { pub(crate) fn new( provider: &'a GarnetProvider, context: &'a Context, - is_search: bool, + quantized: bool, ) -> Self { let id_buffer = provider .id_buffer_pool @@ -407,10 +753,10 @@ impl<'a, T: VectorRepr> FullAccessor<'a, T> { let filtered_ids = provider .filtered_ids_pool .get_ref(Undef::new(MAX_OCCLUSION_SIZE.get() as usize * 2)); // x2 to allow for the length prefixes for garnet - FullAccessor { + DynamicAccessor { provider, context, - is_search, + quantized, id_buffer, filtered_ids, } @@ -434,6 +780,7 @@ impl<'a, T: VectorRepr> FullAccessor<'a, T> { id, &mut guard, ) { + guard.finish(0); return false; } @@ -444,11 +791,11 @@ impl<'a, T: VectorRepr> FullAccessor<'a, T> { } } -impl HasId for FullAccessor<'_, T> { +impl HasId for DynamicAccessor<'_, T> { type Id = u32; } -impl SearchExt for FullAccessor<'_, T> { +impl SearchExt for DynamicAccessor<'_, T> { fn starting_points(&self) -> impl Future>> + Send { let points = if self.provider.start_points_exist() { vec![0] @@ -466,7 +813,7 @@ impl SearchExt for FullAccessor<'_, T> { } } -impl ExpandBeam<&[T]> for FullAccessor<'_, T> { +impl ExpandBeam<&[T]> for DynamicAccessor<'_, T> { fn expand_beam( &mut self, ids: Itr, @@ -490,14 +837,31 @@ impl ExpandBeam<&[T]> for FullAccessor<'_, T> { .filter(|id| pred.eval_mut(id)) { if id == 0 { - let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { - r + let dist = if self.quantized + && let Some(_quantizer) = self.provider.quantizer() + { + let guard = if let Some(r) = self.provider.start_point_quant_cache.get(&id) + { + r + } else { + return future::ready(Err(GarnetProviderError::Garnet( + GarnetError::Read, + ) + .into())); + }; + computer.evaluate_similarity(&*guard) } else { - return future::ready(Err( - GarnetProviderError::Garnet(GarnetError::Read).into() - )); + let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { + r + } else { + return future::ready(Err(GarnetProviderError::Garnet( + GarnetError::Read, + ) + .into())); + }; + computer.evaluate_similarity(&*guard) }; - let dist = computer.evaluate_similarity(&*guard); + on_neighbors(dist, id); } else { self.filtered_ids.push(4); @@ -505,49 +869,78 @@ impl ExpandBeam<&[T]> for FullAccessor<'_, T> { } } - self.provider.callbacks.read_multi_lpiid( - self.context.term(Term::Vector), - &self.filtered_ids, - |i, v| { - let dist = computer.evaluate_similarity(v); - on_neighbors(dist, self.filtered_ids[i as usize * 2 + 1]); - }, - ); + let ctx = if self.quantized { + self.context.term(Term::Quantized) + } else { + self.context.term(Term::Vector) + }; + + if !self.filtered_ids.is_empty() { + self.provider + .callbacks + .read_multi_lpiid(ctx, &self.filtered_ids, |i, v| { + let dist = computer.evaluate_similarity(v); + on_neighbors(dist, self.filtered_ids[i as usize * 2 + 1]); + }); + } } future::ready(Ok(())) } } -impl Accessor for FullAccessor<'_, T> { +impl Accessor for DynamicAccessor<'_, T> { type Element<'a> - = Vec + = DynVector where Self: 'a; - type ElementRef<'a> = &'a [T]; + type ElementRef<'a> = &'a [u8]; type GetError = GarnetProviderError; fn get_element( &mut self, id: Self::Id, ) -> impl Future, Self::GetError>> + Send { - let mut v = vec![T::default(); self.provider.dim]; + let v_len = if self.quantized + && let Some(quantizer) = self.provider.quantizer() + { + quantizer.canonical_bytes() + } else { + self.provider.dim * mem::size_of::() + }; + + let mut v = match Poly::broadcast(0u8, v_len, AlignToEight) { + Ok(v) => DynVector::new(v), + Err(e) => return future::ready(Err(GarnetProviderError::AllocFailed(e))), + }; if id == 0 { - let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { - r + if self.quantized { + let guard = if let Some(r) = self.provider.start_point_quant_cache.get(&id) { + r + } else { + return future::ready(Err(GarnetError::Read.into())); + }; + v.copy_from_slice(&guard); + return future::ready(Ok(v)); } else { - return future::ready(Err(GarnetError::Read.into())); - }; - v.copy_from_slice(&guard); - return future::ready(Ok(v)); + let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { + r + } else { + return future::ready(Err(GarnetError::Read.into())); + }; + v.copy_from_slice(&guard); + return future::ready(Ok(v)); + } } - if !self - .provider - .callbacks - .read_single_iid(self.context.term(Term::Vector), id, &mut v) - { + let ctx = if self.quantized { + self.context.term(Term::Quantized) + } else { + self.context.term(Term::Vector) + }; + + if !self.provider.callbacks.read_single_iid(ctx, id, &mut v) { return future::ready(Err(GarnetError::Read.into())); } @@ -555,29 +948,102 @@ impl Accessor for FullAccessor<'_, T> { } } -impl BuildDistanceComputer for FullAccessor<'_, T> { - type DistanceComputer = T::Distance; +/// Wrapper for full precision distance computer. +pub struct FullPrecisionDistance(T::Distance); + +impl DynDistanceComputer for FullPrecisionDistance { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { + self.0.evaluate_similarity( + bytemuck::cast_slice::(a), + bytemuck::cast_slice::(b), + ) + } +} + +/// Wrapper for full precision query computer. +pub struct FullPrecisionQueryDistance(T::QueryDistance); + +impl DynQueryComputer for FullPrecisionQueryDistance { + fn evaluate_similarity(&self, a: &[u8]) -> f32 { + self.0.evaluate_similarity(bytemuck::cast_slice::(a)) + } +} + +/// Type-erased distance computer. +pub struct GarnetDistanceComputer { + inner: Box, +} + +impl GarnetDistanceComputer { + pub fn new(computer: T) -> Self { + Self { + inner: Box::new(computer), + } + } +} +impl DistanceFunction<&[u8], &[u8]> for GarnetDistanceComputer { + fn evaluate_similarity(&self, x: &[u8], y: &[u8]) -> f32 { + self.inner.evaluate_similarity(x, y) + } +} + +/// Type-erased query computer. +pub struct GarnetQueryComputer { + inner: Box, +} + +impl GarnetQueryComputer { + pub fn new(computer: T) -> Self { + Self { + inner: Box::new(computer), + } + } +} + +impl PreprocessedDistanceFunction<&[u8]> for GarnetQueryComputer { + fn evaluate_similarity(&self, changing: &[u8]) -> f32 { + self.inner.evaluate_similarity(changing) + } +} + +impl BuildDistanceComputer for DynamicAccessor<'_, T> { + type DistanceComputer = GarnetDistanceComputer; type DistanceComputerError = GarnetProviderError; fn build_distance_computer( &self, ) -> Result { - Ok(T::distance( - self.provider.metric_type, - Some(self.provider.dim), - )) + if self.quantized + && let Some(quantizer) = self.provider.quantizer() + { + Ok(quantizer.distance_computer()?) + } else { + Ok(GarnetDistanceComputer::new(FullPrecisionDistance::( + T::distance(self.provider.metric_type, Some(self.provider.dim)), + ))) + } } } -impl BuildQueryComputer<&[T]> for FullAccessor<'_, T> { - type QueryComputer = T::QueryDistance; +impl BuildQueryComputer<&[T]> for DynamicAccessor<'_, T> { + type QueryComputer = GarnetQueryComputer; type QueryComputerError = GarnetProviderError; fn build_query_computer( &self, from: &[T], ) -> Result { - Ok(T::query_distance(from, self.provider.metric_type)) + if self.quantized + && let Some(quantizer) = self.provider.quantizer() + { + Ok(quantizer + .query_computer(bytemuck::cast_slice::(from)) + .map_err(|e| GarnetQuantizerError::QueryComputer(Box::new(e)))?) + } else { + Ok(GarnetQueryComputer::new(FullPrecisionQueryDistance::( + T::query_distance(from, self.provider.metric_type), + ))) + } } } @@ -594,61 +1060,106 @@ impl<'a, T> Reborrow<'a> for Escape { } } -type WorkingSet = workingset::Map>; -type WorkingSetView<'a, T> = workingset::map::View<'a, u32, Escape>; +pub struct WorkingSet { + map: workingset::Map>, + contains_unquantized: bool, +} + +impl WorkingSet { + pub fn new(capacity_type: workingset::map::Capacity, capacity: usize) -> Self { + Self { + map: workingset::map::Builder::new(capacity_type).build(capacity), + contains_unquantized: false, + } + } +} + +impl Deref for WorkingSet { + type Target = workingset::Map>; + + fn deref(&self) -> &Self::Target { + &self.map + } +} + +impl DerefMut for WorkingSet { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.map + } +} + +type WorkingSetView<'a> = workingset::map::View<'a, u32, Escape>; -impl workingset::Fill> for FullAccessor<'_, T> { +impl workingset::Fill for DynamicAccessor<'_, T> { type Error = GarnetProviderError; type View<'a> - = WorkingSetView<'a, T> + = WorkingSetView<'a> where Self: 'a; async fn fill<'a, Itr>( &'a mut self, - set: &'a mut WorkingSet, + set: &'a mut WorkingSet, itr: Itr, ) -> Result, Self::Error> where Itr: ExactSizeIterator + Clone + Send + Sync, Self: 'a, { + if !self.quantized { + // Mark this working set as having full vectors if we're not quantizing yet + set.contains_unquantized = true; + } else if set.contains_unquantized { + // Working set is polluted by full vectors, it must be cleared + set.clear(); + set.contains_unquantized = false; + } + // Evict items from the working set to make room if needed. set.prepare(itr.clone()); self.filtered_ids.clear(); for id in itr { if id == 0 { - if let Entry::Vacant(e) = set.entry(id) { - let guard = if let Some(r) = self.provider.start_point_cache.get(&id) { - r - } else { - return Err(GarnetError::Read.into()); - }; - e.insert(Escape((&**guard).into())); - } + if self.quantized + && let Entry::Vacant(e) = set.entry(id) + { + if let Some(guard) = self.provider.start_point_quant_cache.get(&id) { + e.insert(Escape((&**guard).into())); + } + } else if let Entry::Vacant(e) = set.entry(id) { + if let Some(guard) = self.provider.start_point_cache.get(&id) { + e.insert(Escape((&**guard).into())); + } + } else { + continue; + }; } else if !set.contains_key(&id) { self.filtered_ids.push(4); self.filtered_ids.push(id); } } + let ctx = if self.quantized { + self.context.term(Term::Quantized) + } else { + self.context.term(Term::Vector) + }; + if !self.filtered_ids.is_empty() { - self.provider.callbacks.read_multi_lpiid( - self.context.term(Term::Vector), - &self.filtered_ids, - |id, v| { + self.provider + .callbacks + .read_multi_lpiid(ctx, &self.filtered_ids, |id, v| { set.insert(self.filtered_ids[id as usize * 2 + 1], Escape(v.into())); - }, - ); + }); } Ok(set.view()) } } -pub struct DelegateNeighborAccessor<'p, 'a, T: VectorRepr>(&'a mut FullAccessor<'p, T>); +pub struct DelegateNeighborAccessor<'p, 'a, T: VectorRepr>(&'a mut DynamicAccessor<'p, T>); impl HasId for DelegateNeighborAccessor<'_, '_, T> { type Id = u32; @@ -668,7 +1179,7 @@ impl NeighborAccessor for DelegateNeighborAccessor<'_, '_, T> { } } -impl<'p, 'a, T: VectorRepr> DelegateNeighbor<'a> for FullAccessor<'p, T> { +impl<'p, 'a, T: VectorRepr> DelegateNeighbor<'a> for DynamicAccessor<'p, T> { type Delegate = DelegateNeighborAccessor<'p, 'a, T>; fn delegate_neighbor(&'a mut self) -> Self::Delegate { @@ -754,21 +1265,21 @@ impl NeighborAccessorMut for DelegateNeighborAccessor<'_, '_, T> #[derive(Debug, Default, Clone, Copy)] pub struct CopyExternalIds; -impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], GarnetId> +impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], GarnetId> for CopyExternalIds { type Error = GarnetProviderError; fn post_process( &self, - accessor: &mut FullAccessor<'a, T>, + accessor: &mut DynamicAccessor<'a, T>, _query: &[T], - _computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, + _computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, candidates: I, output: &mut B, ) -> impl Future> + Send where - I: Iterator as HasId>::Id>> + Send, + I: Iterator as HasId>::Id>> + Send, B: SearchOutputBuffer + Send + ?Sized, { let initial = output.current_len(); @@ -782,41 +1293,124 @@ impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], Garn break; } } + let count = output.current_len() - initial; future::ready(Ok(count)) } } -impl SearchStrategy, &[T]> for FullPrecision { - type SearchAccessor<'a> = FullAccessor<'a, T>; +/// A [`SearchPostProcess`] base object that reranks quantized vectors by full precision distance. +#[derive(Debug, Default, Clone, Copy)] +pub struct Rerank; + +impl<'a, 'b, T: VectorRepr> SearchPostProcessStep, &'b [T], GarnetId> + for Rerank +{ + type Error + = GarnetProviderError + where + NextError: diskann::error::StandardError; + + type NextAccessor = DynamicAccessor<'a, T>; + + async fn post_process_step( + &self, + next: &Next, + accessor: &mut DynamicAccessor<'a, T>, + query: &'b [T], + computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, + candidates: I, + output: &mut B, + ) -> Result> + where + I: Iterator as HasId>::Id>> + Send, + B: SearchOutputBuffer + Send + ?Sized, + Next: SearchPostProcess + Sync, + { + if !accessor.quantized { + // Skip reranking if the accessor if working with full precision + return next + .post_process(accessor, query, computer, candidates, output) + .await + .map_err(|e| GarnetProviderError::PostProcessing(Box::new(e))); + } + + let provider = &accessor.provider; + let f = T::distance(provider.metric_type, Some(provider.dim)); + let mut v = Poly::broadcast(0u8, provider.dim * mem::size_of::(), AlignToEight)?; + + // Filter before computing the full precision distances. + let mut reranked: Vec<(u32, f32)> = candidates + .filter_map(|n| { + if !provider.vector_iid_exists(accessor.context, n.id) { + None + } else if provider.callbacks.read_single_iid( + accessor.context.term(Term::Vector), + n.id, + &mut v, + ) { + Some(( + n.id, + f.evaluate_similarity(query, bytemuck::cast_slice::(&v)), + )) + } else { + None + } + }) + .collect(); + + // Sort the full precision distances. + reranked + .sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + + next.post_process( + accessor, + query, + computer, + reranked.into_iter().map(|(id, d)| Neighbor::new(id, d)), + output, + ) + .await + .map_err(|e| GarnetProviderError::PostProcessing(Box::new(e))) + } +} + +impl SearchStrategy, &[T]> for DynamicQuantization { + type SearchAccessor<'a> = DynamicAccessor<'a, T>; type SearchAccessorError = GarnetProviderError; - type QueryComputer = T::QueryDistance; + type QueryComputer = GarnetQueryComputer; fn search_accessor<'a>( &'a self, provider: &'a GarnetProvider, context: &'a as DataProvider>::Context, ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider, context, true)) + let quantized = provider.is_quantized(); + Ok(DynamicAccessor::new(provider, context, quantized)) } } -impl DefaultPostProcessor, &[T], GarnetId> for FullPrecision { - default_post_processor!(glue::Pipeline); +impl DefaultPostProcessor, &[T], GarnetId> + for DynamicQuantization +{ + default_post_processor!( + glue::Pipeline> + ); } -impl PruneStrategy> for FullPrecision { - type PruneAccessor<'a> = FullAccessor<'a, T>; +impl PruneStrategy> for DynamicQuantization { + type PruneAccessor<'a> = DynamicAccessor<'a, T>; type PruneAccessorError = GarnetProviderError; - type DistanceComputer<'a> = T::Distance; - type WorkingSet = WorkingSet; + type DistanceComputer<'a> = GarnetDistanceComputer; + type WorkingSet = WorkingSet; fn prune_accessor<'a>( &'a self, provider: &'a GarnetProvider, context: &'a as DataProvider>::Context, ) -> Result, Self::PruneAccessorError> { - Ok(FullAccessor::new(provider, context, false)) + let quantized = provider.is_quantized(); + Ok(DynamicAccessor::new(provider, context, quantized)) } fn create_working_set(&self, capacity: usize) -> Self::WorkingSet { @@ -824,11 +1418,11 @@ impl PruneStrategy> for FullPrecision { // cache and persist up to `capacity` items across uses of the working set. // // This reuse is limited to a single collection of backedges for an insert or multi-insert. - workingset::map::Builder::new(workingset::map::Capacity::Default).build(capacity) + WorkingSet::new(workingset::map::Capacity::Default, capacity) } } -impl InsertStrategy, &[T]> for FullPrecision { +impl InsertStrategy, &[T]> for DynamicQuantization { type PruneStrategy = Self; fn insert_search_accessor<'a>( @@ -836,7 +1430,8 @@ impl InsertStrategy, &[T]> for FullPrecision { provider: &'a GarnetProvider, context: &'a as DataProvider>::Context, ) -> Result, Self::SearchAccessorError> { - Ok(FullAccessor::new(provider, context, false)) + let quantized = provider.is_quantized(); + Ok(DynamicAccessor::new(provider, context, quantized)) } fn prune_strategy(&self) -> Self::PruneStrategy { @@ -844,13 +1439,13 @@ impl InsertStrategy, &[T]> for FullPrecision { } } -impl InplaceDeleteStrategy> for FullPrecision { +impl InplaceDeleteStrategy> for DynamicQuantization { type DeleteElement<'a> = &'a [T]; type DeleteElementGuard = Box<[T]>; type DeleteElementError = GarnetProviderError; type PruneStrategy = Self; - type DeleteSearchAccessor<'a> = FullAccessor<'a, T>; + type DeleteSearchAccessor<'a> = DynamicAccessor<'a, T>; type SearchPostProcessor = glue::CopyIds; type SearchStrategy = Self; diff --git a/diskann-garnet/src/quantization.rs b/diskann-garnet/src/quantization.rs new file mode 100644 index 000000000..265bda056 --- /dev/null +++ b/diskann-garnet/src/quantization.rs @@ -0,0 +1,317 @@ +use std::{num::NonZero, sync::RwLock}; + +use diskann::utils::VectorRepr; +use diskann_quantization::{ + CompressInto, + algorithms::{Transform, TransformKind, transforms::NewTransformError}, + alloc::{GlobalAllocator, ScopedAllocator}, + minmax, + num::Positive, + spherical::{ + self, Data, PreScale, SphericalQuantizer, SupportedMetric, + iface::{self, Opaque, OpaqueMut, Quantizer}, + }, +}; +use diskann_utils::views::MatrixView; +use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric}; +use thiserror::Error; + +use crate::provider::{GarnetDistanceComputer, GarnetQueryComputer}; + +#[derive(Debug, Error)] +pub enum GarnetQuantizerError { + #[error("Quantization training error: {0}")] + Training(Box), + #[error("Quantization alloc error: {0}")] + Alloc(Box), + #[error("Query computer error: {0}")] + QueryComputer(Box), + #[error("Binary quantization error: {0}")] + Compression(Box), + #[error("No quantizer found")] + NoQuantizer, + #[error("Got zero dimension")] + ZeroDim, + #[error("Transform error: {0}")] + BadTransform(#[from] NewTransformError), +} + +/// Quantizer trait that all diskann-garnet quantizers must implement +pub trait GarnetQuantizer: Send + Sync { + /// Check whether the quantizer is ready to be used + fn is_prepared(&self) -> bool; + /// Returns the number of vectors needed before the quantizer can be trained + fn required_vectors(&self) -> usize; + /// Returns the size of a quantized vector + fn canonical_bytes(&self) -> usize; + /// Train the quantizer. + /// Each row of the matrix will be a vector + fn train(&self, metric: Metric, data: MatrixView) -> Result<(), GarnetQuantizerError>; + /// Quantize a vector + fn compress(&self, v: &[f32], into: &mut [u8]) -> Result<(), GarnetQuantizerError>; + /// Returns a distance computer for comparing quantized vectors + fn distance_computer(&self) -> Result; + /// Returns a query computer for comparing distances to a particular query + fn query_computer(&self, query: &[f32]) -> Result; +} + +/// Type-erased distance computer +pub trait DynDistanceComputer: Send + Sync { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32; +} + +/// Type-erased query computer +pub trait DynQueryComputer: Send + Sync { + fn evaluate_similarity(&self, a: &[u8]) -> f32; +} + +/// Spherical 1-bit quantization. +/// +/// This quantizer corresponds to `BIN` quantizer in the Redis protocol. It requires hundreds of +/// vectors (but not thousands) for training. Quantized vectors have 1 bit per dimension plus up +/// to 6 bytes of overhead. +pub struct Spherical1Bit { + dim: usize, + inner: RwLock>>, +} + +impl Spherical1Bit { + pub fn new(dim: usize) -> Self { + Self { + dim, + inner: RwLock::new(None), + } + } +} + +impl GarnetQuantizer for Spherical1Bit { + fn is_prepared(&self) -> bool { + self.inner.read().unwrap().is_some() + } + + fn required_vectors(&self) -> usize { + 1000 + } + + fn canonical_bytes(&self) -> usize { + Data::<1, GlobalAllocator>::canonical_bytes(self.dim) + } + + fn train( + &self, + metric_type: Metric, + data: MatrixView, + ) -> Result<(), GarnetQuantizerError> { + let mut rng = rand::rng(); + let quantizer = SphericalQuantizer::train( + data.as_view(), + TransformKind::DoubleHadamard { + target_dim: diskann_quantization::algorithms::transforms::TargetDim::Same, + }, + SupportedMetric::try_from(metric_type) + .map_err(|e| GarnetQuantizerError::Training(Box::new(e)))?, + PreScale::ReciprocalMeanNorm, + &mut rng, + GlobalAllocator, + ) + .map_err(|e| GarnetQuantizerError::Training(Box::new(e)))?; + + let mut inner = self.inner.write().unwrap(); + *inner = Some( + spherical::iface::Impl::<1>::new(quantizer) + .map_err(|e| GarnetQuantizerError::Alloc(Box::new(e)))?, + ); + + Ok(()) + } + + fn compress(&self, v: &[f32], into: &mut [u8]) -> Result<(), GarnetQuantizerError> { + let guard = self.inner.read().unwrap(); + if let Some(quantizer) = &*guard { + spherical::iface::Quantizer::::compress( + quantizer, + v, + OpaqueMut::new(into), + ScopedAllocator::global(), + ) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + Ok(()) + } else { + Err(GarnetQuantizerError::NoQuantizer) + } + } + + fn distance_computer(&self) -> Result { + let guard = self.inner.read().unwrap(); + if let Some(quantizer) = &*guard { + let computer = quantizer + .distance_computer(GlobalAllocator) + .map_err(|e| GarnetQuantizerError::Alloc(Box::new(e)))?; + Ok(GarnetDistanceComputer::new(computer)) + } else { + Err(GarnetQuantizerError::NoQuantizer) + } + } + + fn query_computer(&self, query: &[f32]) -> Result { + let guard = self.inner.read().unwrap(); + if let Some(quantizer) = &*guard { + let computer = quantizer + .fused_query_computer( + query, + iface::QueryLayout::FullPrecision, + true, + GlobalAllocator, + ScopedAllocator::global(), + ) + .map_err(|e| GarnetQuantizerError::QueryComputer(Box::new(e)))?; + Ok(GarnetQueryComputer::new(computer)) + } else { + Err(GarnetQuantizerError::NoQuantizer) + } + } +} + +impl DynDistanceComputer for iface::DistanceComputer { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { + , Opaque<'_>, _>>::evaluate_similarity( + self, + Opaque::new(a), + Opaque::new(b), + ) + .unwrap() + } +} + +impl DynQueryComputer for iface::QueryComputer { + fn evaluate_similarity(&self, a: &[u8]) -> f32 { + , _>>::evaluate_similarity( + self, + Opaque::new(a), + ) + .unwrap() + } +} + +/// 8-bit scalar quantizer using MinMax +/// +/// This quantizer requires no training at all and is usable immediately on the first first. Each +/// quantized vector has 8 bits per dimension and 20 bytes of overhead. +pub struct MinMax8Bit { + dim: usize, + metric: Metric, + inner: minmax::MinMaxQuantizer, +} + +impl MinMax8Bit { + pub fn new(dim: usize, metric: Metric) -> Result { + let dim = match NonZero::new(dim) { + Some(d) => d, + None => return Err(GarnetQuantizerError::ZeroDim), + }; + let mut rng = rand::rng(); + let transform = Transform::new( + TransformKind::DoubleHadamard { + target_dim: diskann_quantization::algorithms::transforms::TargetDim::Same, + }, + dim, + Some(&mut rng), + GlobalAllocator, + )?; + let grid_scale = Positive::new(1.0).unwrap(); + + Ok(Self { + dim: dim.get(), + metric, + inner: minmax::MinMaxQuantizer::new(transform, grid_scale), + }) + } +} + +impl GarnetQuantizer for MinMax8Bit { + fn is_prepared(&self) -> bool { + true + } + + fn required_vectors(&self) -> usize { + 0 + } + + fn canonical_bytes(&self) -> usize { + minmax::Data::<8>::canonical_bytes(self.dim) + } + + fn train(&self, _metric: Metric, _data: MatrixView) -> Result<(), GarnetQuantizerError> { + Ok(()) + } + + fn compress(&self, v: &[f32], into: &mut [u8]) -> Result<(), GarnetQuantizerError> { + let into = minmax::DataMutRef::<8>::from_canonical_front_mut(into, self.dim) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + self.inner + .compress_into(v, into) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + Ok(()) + } + + fn distance_computer(&self) -> Result { + let computer = GarnetDistanceComputer::new( + ::distance( + self.metric, + Some(self.dim), + ), + ); + Ok(computer) + } + + fn query_computer(&self, query: &[f32]) -> Result { + let computer = GarnetQueryComputer::new(MinMax8BitQueryComputer::new( + &self.inner, + query, + self.dim, + self.metric, + )?); + Ok(computer) + } +} + +impl DynDistanceComputer for diskann_providers::common::FnPtr { + fn evaluate_similarity(&self, a: &[u8], b: &[u8]) -> f32 { + let a = diskann_providers::common::MinMax8::from_bytes(a); + let b = diskann_providers::common::MinMax8::from_bytes(b); + >::evaluate_similarity(self, a, b) + } +} +struct MinMax8BitQueryComputer( + diskann_providers::common::BufferedFnPtr, +); + +impl MinMax8BitQueryComputer { + fn new( + quantizer: &minmax::MinMaxQuantizer, + query: &[f32], + dim: usize, + metric: Metric, + ) -> Result { + let mut v = vec![Default::default(); minmax::Data::<8>::canonical_bytes(dim)]; + quantizer + .compress_into( + query, + minmax::DataMutRef::<8>::from_canonical_front_mut(&mut v, dim) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?, + ) + .map_err(|e| GarnetQuantizerError::Compression(Box::new(e)))?; + let inner = diskann_providers::common::MinMax8::query_distance( + diskann_providers::common::MinMax8::from_bytes(&v), + metric, + ); + Ok(Self(inner)) + } +} + +impl DynQueryComputer for MinMax8BitQueryComputer { + fn evaluate_similarity(&self, a: &[u8]) -> f32 { + let a = diskann_providers::common::MinMax8::from_bytes(a); + self.0.evaluate_similarity(a) + } +} diff --git a/diskann-garnet/src/test_utils.rs b/diskann-garnet/src/test_utils.rs index 26b72f774..ba9caed10 100644 --- a/diskann-garnet/src/test_utils.rs +++ b/diskann-garnet/src/test_utils.rs @@ -131,6 +131,7 @@ mod tests { use std::collections::HashMap; use crate::{ + VectorQuantType, garnet::{Context, Term}, test_utils::Store, }; @@ -203,7 +204,14 @@ mod tests { // Create a u8 GarnetProvider with the test Store callbacks let dim = 8; let max_degree = 32; - let provider = GarnetProvider::::new(dim, Metric::L2, max_degree, callbacks, ctx); + let provider = GarnetProvider::::new( + dim, + VectorQuantType::NoQuant, + Metric::L2, + max_degree, + callbacks, + ctx, + ); // Provider should be created successfully assert!(provider.is_ok()); @@ -215,7 +223,14 @@ mod tests { // Create a u8 GarnetProvider with the test Store callbacks let dim = 8; let max_degree = 32; - let provider = GarnetProvider::::new(dim, Metric::L2, max_degree, callbacks, ctx); + let provider = GarnetProvider::::new( + dim, + VectorQuantType::NoQuant, + Metric::L2, + max_degree, + callbacks, + ctx, + ); // Provider should be created successfully assert!(provider.is_ok()); diff --git a/diskann-providers/src/common/minmax_repr.rs b/diskann-providers/src/common/minmax_repr.rs index 73c19b10d..19dc8ec96 100644 --- a/diskann-providers/src/common/minmax_repr.rs +++ b/diskann-providers/src/common/minmax_repr.rs @@ -117,6 +117,12 @@ impl num_traits::FromPrimitive for MinMaxElement { impl_from_primitive!(NBITS, from_u32, u32); } +impl MinMaxElement { + pub fn from_bytes(x: &[u8]) -> &[Self] { + bytemuck::must_cast_slice(x) + } +} + //////////////////////// /// DistanceProvider /// //////////////////////// diff --git a/diskann-providers/src/common/mod.rs b/diskann-providers/src/common/mod.rs index b5c4a42a4..ec39946f7 100644 --- a/diskann-providers/src/common/mod.rs +++ b/diskann-providers/src/common/mod.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ mod minmax_repr; -pub use minmax_repr::{MinMax4, MinMax8, MinMaxElement}; +pub use minmax_repr::{BufferedFnPtr, FnPtr, MinMax4, MinMax8, MinMaxElement}; mod ignore_lock_poison; pub use ignore_lock_poison::IgnoreLockPoison; diff --git a/vectorset/src/loader.rs b/vectorset/src/loader.rs index f1e62444e..9e0f7a856 100644 --- a/vectorset/src/loader.rs +++ b/vectorset/src/loader.rs @@ -4,22 +4,14 @@ */ use anyhow::{Result, anyhow}; -use std::{ - collections::HashMap, - marker::PhantomData, - mem, - path::{Path, PathBuf}, - sync::Arc, -}; +use std::{collections::HashMap, marker::PhantomData, mem, path::Path, sync::Arc}; use tokio::{fs::File, io::AsyncReadExt, sync::Mutex}; const BATCH_SIZE: usize = 1024; /// Loads base or query vectors from a given path and allow iteration over them. -#[allow(dead_code)] pub struct DatasetLoader { file: Mutex<(File, usize)>, - path: PathBuf, num_vectors: usize, dim: usize, type_: PhantomData, @@ -36,7 +28,6 @@ impl DatasetLoader { Ok(Arc::new(Self { file: Mutex::new((file, 0)), - path, num_vectors, dim, type_: PhantomData, @@ -60,35 +51,33 @@ impl DatasetLoader { /// Returns (count, first_id) where `count` is the number of vectors loaded /// and `first_id` is the id of the first vector. pub async fn next(&self, buffer: &mut Vec) -> Result<(usize, usize)> { + let mut f = self.file.lock().await; + let mut count; let mut first_id; loop { - { - let mut f = self.file.lock().await; - - first_id = f.1; - if f.1 >= self.num_vectors { - buffer.clear(); - return Ok((0, first_id)); - } - - buffer.resize(BATCH_SIZE * self.dim, T::zeroed()); + first_id = f.1; + if f.1 >= self.num_vectors { + buffer.clear(); + return Ok((0, first_id)); + } - let mut buf: &mut [u8] = bytemuck::cast_slice_mut::(&mut *buffer); - while let bytes_read = f.0.read(buf).await? - && bytes_read > 0 - { - buf = &mut buf[bytes_read..]; - } + buffer.resize(BATCH_SIZE * self.dim, T::zeroed()); - let elements_left = buf.len() / mem::size_of::(); - if !buf.is_empty() && !elements_left.is_multiple_of(self.dim) { - return Err(anyhow!("unexpected EOF")); - } + let mut buf: &mut [u8] = bytemuck::cast_slice_mut::(&mut *buffer); + while let bytes_read = f.0.read(buf).await? + && bytes_read > 0 + { + buf = &mut buf[bytes_read..]; + } - count = BATCH_SIZE - elements_left / self.dim; + let elements_left = buf.len() / mem::size_of::(); + if !buf.is_empty() && !elements_left.is_multiple_of(self.dim) { + return Err(anyhow!("unexpected EOF")); } + count = BATCH_SIZE - elements_left / self.dim; + if count == 0 { continue; } @@ -96,7 +85,6 @@ impl DatasetLoader { break; } - let mut f = self.file.lock().await; f.1 += count; Ok((count, first_id)) diff --git a/vectorset/src/main.rs b/vectorset/src/main.rs index 4ba864d56..8333e4f86 100644 --- a/vectorset/src/main.rs +++ b/vectorset/src/main.rs @@ -107,6 +107,14 @@ struct IngestArgs { #[arg(long)] no_header_with_dim: Option, + /// Quantizer + #[arg(long)] + quantizer: Option, + + /// Metric + #[arg(long)] + metric: Option, + /// Paths to base vectors base_path: PathBuf, } @@ -161,6 +169,50 @@ enum DataType { Float32, } +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +enum Quantizer { + None, + Bin, + Q8, +} + +impl ToRedisArgs for Quantizer { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + redis::RedisWrite, + { + let q = match self { + Quantizer::None => b"NOQUANT".as_slice(), + Quantizer::Bin => b"BIN".as_slice(), + Quantizer::Q8 => b"Q8".as_slice(), + }; + out.write_arg(q); + } +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +enum DistanceMetric { + L2, + Cosine, + CosineNormalized, + InnerProduct, +} + +impl ToRedisArgs for DistanceMetric { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + redis::RedisWrite, + { + let q = match self { + DistanceMetric::L2 => b"L2".as_slice(), + DistanceMetric::Cosine => b"COSINE".as_slice(), + DistanceMetric::CosineNormalized => b"COSINE_NORMALIZED".as_slice(), + DistanceMetric::InnerProduct => b"INNER_PRODUCT".as_slice(), + }; + out.write_arg(q); + } +} + struct VectorId(u32); impl ToRedisArgs for VectorId { @@ -362,6 +414,8 @@ async fn ingest( let degree = args.degree; let mut cred = cred.clone(); let data_type = opts.data_type; + let quantizer = args.quantizer; + let metric = args.metric; tasks.spawn(async move { let mut buf = vec![T::zeroed(); ds.batch_size() * ds.dim()]; @@ -387,6 +441,7 @@ async fn ingest( let element = VectorId((first_id + i) as u32); let buf_start = i * ds.dim(); let buf_end = buf_start + ds.dim(); + pipeline.cmd("VADD").arg(&vset); match data_type { @@ -407,10 +462,14 @@ async fn ingest( pipeline.arg(b"XPREQ8"); } DataType::Float32 => { - pipeline.arg(b"NOQUANT"); + pipeline.arg(quantizer); } } + if let Some(metric) = metric { + pipeline.arg(b"XDISTANCE_METRIC").arg(metric); + } + pipeline .arg(b"EF") .arg(l_build.to_string().as_bytes())