diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 33938caea..3d9f8ce09 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -29,8 +29,8 @@ use diskann::{ }, neighbor::Neighbor, provider::{ - Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId, - NeighborAccessor, NoopGuard, + Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, + DistancesUnordered, HasElementRef, HasId, NeighborAccessor, NoopGuard, }, utils::{IntoUsize, VectorRepr}, ANNError, ANNResult, @@ -428,7 +428,13 @@ where .to_vec(), }) } +} +impl DistancesUnordered<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> +where + Data: GraphDataType, + VP: VertexProvider, +{ async fn distances_unordered( &mut self, vec_id_itr: Itr, @@ -689,6 +695,15 @@ where type Id = u32; } +impl HasElementRef for DiskAccessor<'_, Data, VP> +where + Data: GraphDataType, + VP: VertexProvider, +{ + /// `ElementRef` can have arbitrary lifetimes. + type ElementRef<'a> = &'a [u8]; +} + impl Accessor for DiskAccessor<'_, Data, VP> where Data: GraphDataType, @@ -701,9 +716,6 @@ where where Self: 'a; - /// `ElementRef` can have arbitrary lifetimes. - type ElementRef<'a> = &'a [u8]; - /// Choose to panic on an out-of-bounds access rather than propagate an error. 
type GetError = ANNError; diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 0be5be9a4..8350c6cb8 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -18,7 +18,8 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DelegateNeighbor, - Delete, ElementStatus, HasId, NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, + Delete, DistancesUnordered, ElementStatus, HasElementRef, HasId, NeighborAccessor, + NeighborAccessorMut, NoopGuard, SetElement, }, utils::VectorRepr, }; @@ -519,12 +520,15 @@ impl ExpandBeam<&[T]> for FullAccessor<'_, T> { } } +impl HasElementRef for FullAccessor<'_, T> { + type ElementRef<'a> = &'a [T]; +} + impl Accessor for FullAccessor<'_, T> { type Element<'a> = Vec where Self: 'a; - type ElementRef<'a> = &'a [T]; type GetError = GarnetProviderError; fn get_element( @@ -581,6 +585,8 @@ impl BuildQueryComputer<&[T]> for FullAccessor<'_, T> { } } +impl DistancesUnordered<&[T]> for FullAccessor<'_, T> {} + /// An escape hatch for the blanket implementation of [`workingset::Fill`]. 
/// /// Without an `&[T]: Into>`, the blanket implementation for `workingset::Map` diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 0d16248dd..15a2b5ce9 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -8,7 +8,10 @@ use std::sync::{Arc, RwLock}; use diskann::{ error::{ErrorExt, IntoANNResult}, graph::glue::{ExpandBeam, SearchExt}, - provider::{Accessor, AsNeighbor, BuildQueryComputer, DelegateNeighbor, HasId}, + provider::{ + Accessor, AsNeighbor, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, + HasElementRef, HasId, + }, ANNError, ANNErrorKind, }; use diskann_utils::Reborrow; @@ -68,6 +71,13 @@ where type Id = ::Id; } +impl HasElementRef for EncodedDocumentAccessor +where + IA: Accessor, +{ + type ElementRef<'a> = EncodedDocument, &'a RoaringTreemap>; +} + impl Accessor for EncodedDocumentAccessor where IA: Accessor, @@ -76,7 +86,6 @@ where = EncodedDocument, RoaringTreemap> where Self: 'a; - type ElementRef<'a> = EncodedDocument, &'a RoaringTreemap>; type GetError = ANNError; async fn get_element(&mut self, id: Self::Id) -> Result, Self::GetError> { @@ -164,7 +173,7 @@ where impl<'q, IA, Q> BuildQueryComputer<&'q FilteredQuery> for EncodedDocumentAccessor where - IA: BuildQueryComputer<&'q Q>, + IA: BuildQueryComputer<&'q Q> + Accessor, { type QueryComputerError = ANNError; type QueryComputer = InlineBetaComputer; @@ -195,6 +204,14 @@ where { } +impl DistancesUnordered for EncodedDocumentAccessor +where + IA: Accessor, + EncodedDocumentAccessor: BuildQueryComputer, + Q: Clone, +{ +} + impl<'a, IA> DelegateNeighbor<'a> for EncodedDocumentAccessor where IA: DelegateNeighbor<'a>, diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs 
b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 26b07f0f6..10aef0109 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -143,7 +143,7 @@ pub struct FilterResults { impl<'q, Q, IA, IPP> SearchPostProcess, &'q FilteredQuery> for FilterResults where - IA: BuildQueryComputer<&'q Q>, + IA: BuildQueryComputer<&'q Q> + Accessor, Q: Send + Sync, IPP: SearchPostProcess + Send + Sync, { diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index 8fde6f465..e67c50d00 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -28,8 +28,8 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultContext, - DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, NeighborAccessorMut, - NoopGuard, SetElement, + DelegateNeighbor, Delete, DistancesUnordered, ElementStatus, HasElementRef, HasId, + NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, }, utils::{IntoUsize, VectorRepr}, }; @@ -974,6 +974,15 @@ where } } +impl HasElementRef for FullAccessor<'_, T, Q, D> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, +{ + type ElementRef<'a> = &'a [T]; +} + impl Accessor for FullAccessor<'_, T, Q, D> where T: VectorRepr, @@ -986,9 +995,6 @@ where where Self: 'a; - /// The reference version of `Element` is the same as `Element`. - type ElementRef<'a> = &'a [T]; - // Choose to panic on an out-of-bounds access rather than propagate an error. 
// type GetError = Panics; @@ -1058,6 +1064,14 @@ where { } +impl DistancesUnordered<&[T]> for FullAccessor<'_, T, Q, D> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, +{ +} + impl<'a, T, Q, D> AsDeletionCheck for FullAccessor<'a, T, Q, D> where T: VectorRepr, @@ -1134,6 +1148,14 @@ where } } +impl HasElementRef for QuantAccessor<'_, T, D> +where + T: VectorRepr, + D: AsyncFriendly, +{ + type ElementRef<'a> = &'a [u8]; +} + impl Accessor for QuantAccessor<'_, T, D> where T: VectorRepr, @@ -1145,9 +1167,6 @@ where where Self: 'a; - /// The reference version of `Element` is simply `Element`. - type ElementRef<'a> = &'a [u8]; - // ANNError on access failures in bf-tree // type GetError = ANNError; @@ -1220,6 +1239,13 @@ where { } +impl DistancesUnordered<&[T]> for QuantAccessor<'_, T, D> +where + T: VectorRepr, + D: AsyncFriendly, +{ +} + impl<'a, T, D> AsDeletionCheck for QuantAccessor<'a, T, D> where T: VectorRepr, @@ -1282,6 +1308,14 @@ where } } +impl HasElementRef for HybridAccessor<'_, T, D> +where + T: VectorRepr, + D: AsyncFriendly, +{ + type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; +} + impl Accessor for HybridAccessor<'_, T, D> where T: VectorRepr, @@ -1296,9 +1330,6 @@ where where Self: 'a; - /// The generalized reference form of `Element`. - type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; - // Choose to panic on an out-of-bounds access rather than propagate an error. 
type GetError = Panics; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index e72f323ef..0b5eb1b72 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -20,7 +20,7 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - ExecutionContext, HasId, + DistancesUnordered, ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -185,6 +185,16 @@ where } } +impl HasElementRef for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type ElementRef<'a> = &'a [T]; +} + impl Accessor for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, @@ -199,9 +209,6 @@ where where Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = &'a [T]; - /// Choose to panic on an out-of-bounds access rather than propagate an error. 
type GetError = Panics; @@ -316,6 +323,15 @@ where { } +impl DistancesUnordered<&[T]> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + //-------------------// // In-mem Extensions // //-------------------// @@ -353,7 +369,7 @@ pub struct Rerank; impl<'a, A, T> glue::SearchPostProcess for Rerank where T: VectorRepr, - A: BuildQueryComputer<&'a [T], Id = u32> + GetFullPrecision + AsDeletionCheck, + A: BuildQueryComputer<&'a [T]> + HasId + GetFullPrecision + AsDeletionCheck, { type Error = Panics; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index b3867ad85..477dc517d 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -16,8 +16,8 @@ use diskann::{ workingset, }, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, - HasId, + Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, + ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -143,6 +143,15 @@ where } } +impl HasElementRef for QuantAccessor<'_, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type ElementRef<'a> = &'a [u8]; +} + impl Accessor for QuantAccessor<'_, V, D, Ctx> where V: AsyncFriendly, @@ -156,9 +165,6 @@ where where Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = &'a [u8]; - /// Choose to panic on an out-of-bounds access rather than propagate an error. 
type GetError = Panics; @@ -242,6 +248,15 @@ where { } +impl DistancesUnordered<&[T]> for QuantAccessor<'_, V, D, Ctx> +where + T: VectorRepr, + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + //-------------------// // In-mem Extensions // //-------------------// @@ -316,6 +331,15 @@ where } } +impl HasElementRef for HybridAccessor<'_, T, D, Ctx> +where + T: VectorRepr, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; +} + impl Accessor for HybridAccessor<'_, T, D, Ctx> where T: VectorRepr, @@ -331,9 +355,6 @@ where where Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; - /// Choose to panic on an out-of-bounds access rather than propagate an error. type GetError = Panics; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index a69344691..c9c781fed 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -16,8 +16,8 @@ use diskann::{ workingset, }, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, - HasId, + Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, + ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -420,6 +420,16 @@ where } } +impl HasElementRef for QuantAccessor<'_, NBITS, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, + Unsigned: Representation, +{ + type ElementRef<'a> = CVRef<'a, NBITS>; +} + impl Accessor for QuantAccessor<'_, NBITS, V, D, Ctx> where V: AsyncFriendly, @@ -434,9 +444,6 @@ where where Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. 
- type ElementRef<'a> = CVRef<'a, NBITS>; - /// Choose to panic on an out-of-bounds access rather than propagate an error. type GetError = ANNError; @@ -556,6 +563,18 @@ where { } +impl DistancesUnordered<&[T]> + for QuantAccessor<'_, NBITS, V, D, Ctx> +where + T: VectorRepr, + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, + Unsigned: Representation, + QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, +{ +} + impl BuildDistanceComputer for QuantAccessor<'_, NBITS, V, D, Ctx> where V: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 11f489aa4..0edaf4a77 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -18,8 +18,8 @@ use diskann::{ workingset, }, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, - HasId, + Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, + ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -340,6 +340,15 @@ where } } +impl HasElementRef for QuantAccessor<'_, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type ElementRef<'a> = spherical::iface::Opaque<'a>; +} + impl Accessor for QuantAccessor<'_, V, D, Ctx> where V: AsyncFriendly, @@ -353,9 +362,6 @@ where where Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = spherical::iface::Opaque<'a>; - /// Choose to panic on an out-of-bounds access rather than propagate an error. 
type GetError = ANNError; @@ -467,6 +473,15 @@ where { } +impl DistancesUnordered<&[T]> for QuantAccessor<'_, V, D, Ctx> +where + T: VectorRepr, + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + #[derive(Debug, Error)] #[error("unconstructible")] pub enum Infallible {} diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs index 2493a5ce2..61a99498f 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs @@ -21,7 +21,7 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - HasId, + DistancesUnordered, HasElementRef, HasId, }, utils::IntoUsize, }; @@ -164,6 +164,10 @@ impl HasId for FlakyAccessor<'_> { type Id = u32; } +impl HasElementRef for FlakyAccessor<'_> { + type ElementRef<'a> = &'a [f32]; +} + impl Accessor for FlakyAccessor<'_> { /// This accessor returns raw slices. There *is* a chance of racing when the fast /// providers are used. We just have to live with it. @@ -172,8 +176,6 @@ impl Accessor for FlakyAccessor<'_> { where Self: 'a; - type ElementRef<'a> = &'a [f32]; - /// Choose to panic on an out-of-bounds access rather than propagate an error. 
type GetError = TestError; @@ -231,6 +233,8 @@ impl<'a, 'b> BuildQueryComputer<&'a [f32]> for FlakyAccessor<'b> { impl ExpandBeam<&[f32]> for FlakyAccessor<'_> {} +impl DistancesUnordered<&[f32]> for FlakyAccessor<'_> {} + impl<'a> DelegateNeighbor<'a> for FlakyAccessor<'_> { type Delegate = &'a SimpleNeighborProviderAsync; fn delegate_neighbor(&'a mut self) -> Self::Delegate { diff --git a/diskann-providers/src/model/graph/provider/async_/postprocess.rs b/diskann-providers/src/model/graph/provider/async_/postprocess.rs index 3e1849bb0..dbbf08fa4 100644 --- a/diskann-providers/src/model/graph/provider/async_/postprocess.rs +++ b/diskann-providers/src/model/graph/provider/async_/postprocess.rs @@ -8,7 +8,7 @@ use diskann::{ graph::{SearchOutputBuffer, glue}, neighbor::Neighbor, - provider::BuildQueryComputer, + provider::{BuildQueryComputer, HasId}, }; /// A bridge allowing `Accessors` to opt-in to [`RemoveDeletedIdsAndCopy`] by delegating to @@ -39,7 +39,7 @@ pub struct RemoveDeletedIdsAndCopy; impl glue::SearchPostProcess for RemoveDeletedIdsAndCopy where - A: BuildQueryComputer + AsDeletionCheck, + A: BuildQueryComputer + HasId + AsDeletionCheck, { type Error = std::convert::Infallible; diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index f0ffc0451..ff822dd9d 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -24,7 +24,10 @@ use diskann::{ index::QueryLabelProvider, }, neighbor::Neighbor, - provider::{Accessor, AsNeighbor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, + provider::{ + Accessor, AsNeighbor, BuildQueryComputer, DataProvider, DelegateNeighbor, + DistancesUnordered, HasElementRef, HasId, + }, utils::VectorId, }; use diskann_utils::Reborrow; @@ -70,7 +73,7 @@ pub struct Unwrap; /// Delegate post-processing to the inner strategy's post-processing 
routine. impl SearchPostProcessStep, T, O> for Unwrap where - A: BuildQueryComputer, + A: BuildQueryComputer + Accessor, { type Error = NextError @@ -227,6 +230,13 @@ where type Id = Inner::Id; } +impl HasElementRef for BetaAccessor +where + Inner: Accessor, +{ + type ElementRef<'a> = Pair>; +} + impl Accessor for BetaAccessor where Inner: Accessor, @@ -236,7 +246,6 @@ where = Pair> where Self: 'a; - type ElementRef<'a> = Pair>; /// Use the same error type as `Inner`. type GetError = Inner::GetError; @@ -279,10 +288,10 @@ where impl BuildQueryComputer for BetaAccessor where - Inner: BuildQueryComputer, + Inner: BuildQueryComputer + Accessor, { /// Use a [`BetaComputer`] to apply filtering. - type QueryComputer = BetaComputer; + type QueryComputer = BetaComputer; /// Use the same error as `Inner`. type QueryComputerError = Inner::QueryComputerError; @@ -296,7 +305,15 @@ where } } -impl ExpandBeam for BetaAccessor where Inner: BuildQueryComputer + AsNeighbor {} +impl ExpandBeam for BetaAccessor where + Inner: BuildQueryComputer + AsNeighbor + Accessor +{ +} + +impl DistancesUnordered for BetaAccessor where + Inner: BuildQueryComputer + Accessor +{ +} /// A [`PreprocessedDistanceFunction`] that applied `beta` filtering to the inner computer. pub struct BetaComputer { @@ -432,12 +449,15 @@ mod tests { always_escalate!(NotAllowed); + impl HasElementRef for Doubler { + type ElementRef<'a> = u64; + } + impl Accessor for Doubler { type Element<'a> = u64 where Self: 'a; - type ElementRef<'a> = u64; type GetError = NotAllowed; @@ -498,6 +518,8 @@ mod tests { impl ExpandBeam for Doubler {} + impl DistancesUnordered for Doubler {} + #[derive(Debug)] struct SimpleStrategy; diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs new file mode 100644 index 000000000..8ebec490d --- /dev/null +++ b/diskann/src/flat/index.rs @@ -0,0 +1,204 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! 
[`FlatIndex`] — the index wrapper for a [`DataProvider`](crate::provider::DataProvider) +//! over which we do flat search. +use std::num::NonZeroUsize; + +use diskann_utils::future::SendFuture; + +use crate::{ + ANNResult, + error::{ErrorExt, IntoANNResult}, + flat::{DistancesUnordered, SearchStrategy}, + graph::{SearchOutputBuffer, glue::SearchPostProcess}, + neighbor::{Neighbor, NeighborPriorityQueue}, + provider::{BuildQueryComputer, DataProvider}, +}; + +/// Statistics collected during a flat search. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct SearchStats { + /// The total number of distance computations performed during the scan. + pub cmps: u32, + + /// The total number of results written to the output buffer. + pub result_count: u32, +} + +/// A `'static` thin wrapper around a [`DataProvider`] used for flat search. +/// +/// The provider is owned by the index. The index is constructed once at process startup and +/// shared across requests; per-query state lives in the [`crate::flat::DistancesUnordered`] +/// implementation that the [`SearchStrategy`] produces. +#[derive(Debug)] +pub struct FlatIndex { + /// The backing provider. + provider: P, +} + +impl FlatIndex

{ + /// Construct a new [`FlatIndex`] around `provider`. + pub fn new(provider: P) -> Self { + Self { provider } + } + + /// Borrow the underlying provider. + pub fn provider(&self) -> &P { + &self.provider + } + + /// Brute-force k-nearest-neighbor flat search. + /// + /// Streams every element produced by the strategy's visitor through the query + /// computer, keeps the best `k` candidates in a [`NeighborPriorityQueue`], and hands + /// the survivors to the post-processor. + /// + /// # Arguments + /// - `k`: number of nearest neighbors to return. + /// - `strategy`: produces the per-query iterator and the query computer. See [`SearchStrategy`]. + /// - `processor`: post-processes the survivor candidates into the output type. + /// - `context`: per-request context threaded through to the provider. + /// - `query`: the query. + /// - `output`: caller-owned output buffer. + pub fn knn_search( + &self, + k: NonZeroUsize, + strategy: &S, + processor: &PP, + context: &P::Context, + query: T, + output: &mut OB, + ) -> impl SendFuture> + where + S: SearchStrategy, + T: Copy + Send + Sync, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, + PP: for<'a> SearchPostProcess, T, O> + Send + Sync, + { + async move { + let mut visitor = strategy + .create_visitor(&self.provider, context) + .into_ann_result()?; + + let computer = visitor.build_query_computer(query).into_ann_result()?; + + let k = k.get(); + let mut queue = NeighborPriorityQueue::new(k); + let mut cmps: u32 = 0; + + visitor + .distances_unordered(&computer, |id, dist| { + cmps += 1; + queue.insert(Neighbor::new(id, dist)); + }) + .await + .escalate("flat scan must complete to produce correct k-NN results")?; + + let result_count = processor + .post_process(&mut visitor, query, &computer, queue.iter().take(k), output) + .await + .into_ann_result()? 
as u32; + + Ok(SearchStats { cmps, result_count }) + } + } +} + +///////////// +// Tests /// +///////////// + +#[cfg(test)] +mod tests { + use crate::flat::{ + FlatIndex, + test::{ + harness::KnnOracleRun, + provider::{self as flat_provider, Strategy}, + }, + }; + use crate::graph::test::synthetic::Grid; + + fn fixture(grid: Grid, size: usize) -> (FlatIndex, usize) { + let provider = flat_provider::Provider::grid(grid, size); + let len = provider.len(); + (FlatIndex::new(provider), len) + } + + /// `knn_search` returns a `Send` future, and a shared `&FlatIndex` can serve + /// many concurrent searches on a multi-threaded runtime, each producing the + /// correct top-k independently. + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn knn_search() { + use std::sync::Arc; + + let (index, len) = fixture(Grid::Two, 4); + let index = Arc::new(index); + + // Mix of corner, axis-aligned, and off-grid queries; k spans 1..=len. + let cases: &[(&[f32], usize)] = &[ + (&[-1.0, -1.0], 1), + (&[1.0, 1.0], len), + (&[-1.0, 1.0], len / 2), + (&[1.0, -1.0], len - 1), + (&[0.0, 0.0], 3), + (&[3.0, 3.0], len), + (&[-2.0, 0.5], 2), + (&[0.5, -0.5], len), + ]; + + let mut set = tokio::task::JoinSet::new(); + for (query, k) in cases { + let index = Arc::clone(&index); + let query: Vec = query.to_vec(); + let k = *k; + set.spawn(async move { + let outcome = KnnOracleRun::run(&index, &Strategy::new(), &query, k) + .await + .expect("knn_search failed"); + (query, k, outcome) + }); + } + + while let Some(joined) = set.join_next().await { + let (query, k, outcome) = joined.expect("task panicked"); + assert_eq!( + outcome.top_k, outcome.ground_truth, + "query = {query:?}, k = {k}: top-k must match brute force", + ); + assert_eq!(outcome.stats.cmps as usize, len); + assert_eq!(outcome.stats.result_count as usize, k.min(len)); + } + } + + /// A transient error from the visitor's scan must escalate up through `knn_search`. 
+ #[test] + fn transient_scan_error() { + let (index, _len) = fixture(Grid::Two, 3); + + // The flat scan must touch every id, so any transient id is guaranteed to be + // hit. + for transient_ids in [&[0u32][..], &[3][..], &[1, 2, 5][..]] { + let err = KnnOracleRun::run_sync( + &index, + &Strategy::with_transient(transient_ids.iter().copied()), + &[1.0, 0.0], + 4, + ) + .expect_err("transient error during full scan must escalate"); + + let msg = format!("{err}"); + assert!( + transient_ids + .iter() + .any(|id| msg.contains(&format!("id {id}"))), + "transients = {transient_ids:?}: expected error to name one of the \ + transient ids, got: {msg}", + ); + } + } +} diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs new file mode 100644 index 000000000..7bad2d131 --- /dev/null +++ b/diskann/src/flat/iterator.rs @@ -0,0 +1,615 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Sequential ("flat") access primitives. +//! +//! This module defines the traits that flat-search algorithms use to walk every element +//! of a [`DataProvider`](crate::provider::DataProvider) once. +//! +//! * [`DistancesUnordered`]: the single trait flat search consumes. It takes a +//! pre-built query computer and a callback, applies the callback to every +//! `(id, distance)` pair in the provider, and is the only trait an in-memory +//! visitor (such as [`crate::flat::test::provider::Visitor`]) needs to implement. +//! The super-traits [`HasId`] and [`BuildQueryComputer`] define the id and +//! query-computer types. +//! +//! * [`FlatIterator`]: a convenient entry point for backends whose natural shape is +//! element-at-a-time iteration. The trait exposes a single `next` method and an +//! associated `Element<'_>` type that must be [`Reborrow`]able to the `ElementRef<'_>` +//! exposed via the [`HasElementRef`] super-trait. +//! +//! * [`Iterated`]: bridges any [`FlatIterator`] implementation into a +//! 
[`DistancesUnordered`] by looping over [`FlatIterator::next`] and scoring each +//! element with the supplied computer. + +use std::fmt::Debug; + +use diskann_utils::{Reborrow, future::SendFuture}; +use diskann_vector::PreprocessedDistanceFunction; + +use crate::{ + error::ToRanked, + provider::{BuildQueryComputer, HasElementRef, HasId}, +}; + +/// Fused iterate-and-score primitive over the elements of a flat index. +/// +/// Implementations drive an entire scan over the underlying data, scoring each +/// element with the supplied [`BuildQueryComputer::QueryComputer`] and invoking +/// `f` with the resulting `(id, distance)` pair. The super-trait +/// [`BuildQueryComputer`] supplies the computer type. +pub trait DistancesUnordered: HasId + BuildQueryComputer + Send + Sync { + /// The error type yielded by [`Self::distances_unordered`]. + type Error: ToRanked + Debug + Send + Sync + 'static; + + /// Drive the entire scan, scoring each element with `computer` and invoking `f` + /// with the resulting `(id, distance)` pair. + fn distances_unordered( + &mut self, + computer: &>::QueryComputer, + f: F, + ) -> impl SendFuture> + where + F: Send + FnMut(::Id, f32); +} + +////////////// +// Iterator // +////////////// + +/// A lending, asynchronous iterator over the elements of a flat index. +/// +/// Implementations provide element-at-a-time access via [`Self::next`]. Providers that +/// only implement `FlatIterator` can be wrapped in [`Iterated`] to obtain a +/// default [`DistancesUnordered`] implementation. +pub trait FlatIterator: HasId + HasElementRef + Send + Sync { + /// The concrete element returned by [`Self::next`]. Reborrows to [`Self::ElementRef`]. + type Element<'a>: for<'b> Reborrow<'b, Target = ::ElementRef<'b>> + + Send + + Sync + where + Self: 'a; + + /// The error type yielded by [`Self::next`]. + type Error: ToRanked + Debug + Send + Sync + 'static; + + /// Advance the iterator and asynchronously yield the next `(id, element)` pair. 
+ /// + /// Returns `Ok(None)` when the scan is exhausted. The yielded element borrows from + /// the iterator and is invalidated by the next call to `next`. + #[allow(clippy::type_complexity)] + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>>; +} + +///////////// +// Default // +///////////// + +/// Bridges a [`FlatIterator`] into a [`DistancesUnordered`] by looping over +/// [`FlatIterator::next`], reborrowing each element, and scoring it with the +/// supplied computer. +/// +/// This is the default adapter for providers that implement element-at-a-time +/// iteration. Providers that can do better (prefetching, SIMD batching, bulk +/// I/O) should implement [`DistancesUnordered`] directly. +pub struct Iterated { + inner: I, +} + +impl Iterated { + /// Wrap an iterator to produce a [`DistancesUnordered`] implementation. + pub fn new(inner: I) -> Self { + Self { inner } + } + + /// Unwrap, returning the inner iterator. + pub fn into_inner(self) -> I { + self.inner + } +} + +impl HasId for Iterated { + type Id = I::Id; +} + +impl HasElementRef for Iterated { + type ElementRef<'a> = I::ElementRef<'a>; +} + +/// Forwards the inner iterator's [`BuildQueryComputer`] impl through the wrapper +/// so that callers (and the [`DistancesUnordered`] blanket below) can obtain the +/// query computer from the [`Iterated`] adapter directly. +impl BuildQueryComputer for Iterated +where + I: BuildQueryComputer, +{ + type QueryComputerError = I::QueryComputerError; + type QueryComputer = I::QueryComputer; + + fn build_query_computer( + &self, + from: T, + ) -> Result { + self.inner.build_query_computer(from) + } +} + +/// The blanket implementation of [`DistancesUnordered`] for any +/// [`FlatIterator`] that also exposes a [`BuildQueryComputer`]. 
+impl DistancesUnordered for Iterated +where + I: FlatIterator + BuildQueryComputer + Send + Sync, +{ + type Error = I::Error; + + fn distances_unordered( + &mut self, + computer: &Self::QueryComputer, + mut f: F, + ) -> impl SendFuture> + where + F: Send + FnMut(::Id, f32), + { + async move { + while let Some((id, element)) = self.inner.next().await? { + let dist = computer.evaluate_similarity(element.reborrow()); + f(id, dist); + } + Ok(()) + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + fmt::Debug, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + }; + + use diskann_utils::Reborrow; + use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; + + use super::*; + use crate::{ + ANNError, always_escalate, + error::Infallible, + provider::{BuildQueryComputer, HasElementRef, HasId}, + utils::VectorRepr, + }; + + /////////////////////////// + // Shared sample dataset // + /////////////////////////// + + /// Canonical sample dataset shared by every contract test below. + fn sample_items() -> Vec<(u32, Vec)> { + vec![ + (10, vec![0.0, 0.0]), + (11, vec![1.0, 0.0]), + (12, vec![0.0, 2.0]), + ] + } + + /// Backing store of `[f32]` vectors, used by every element-shape fixture + /// below to cover [`FlatIterator::Element`] variants without re-implementing + /// the data layout each time. + struct Store { + items: Vec<(u32, Vec)>, + } + + impl Store { + fn sample() -> Self { + Self { + items: sample_items(), + } + } + } + + /////////////////////// + // Common impl macro // + /////////////////////// + + /// Implement [`HasId`], [`HasElementRef`], and [`BuildQueryComputer`] for an + /// iterator type. Every fixture in this module shares these impls — only + /// [`FlatIterator::Element`] varies. The [`DistancesUnordered`] impl on + /// `Iterated<$T>` comes from the blanket impl in the parent module, so does + /// not need to be repeated here. + macro_rules! 
common_iterator_impls { + ($T:ty) => { + impl HasId for $T { + type Id = u32; + } + + impl HasElementRef for $T { + type ElementRef<'a> = &'a [f32]; + } + + impl BuildQueryComputer<&[f32]> for $T { + type QueryComputerError = Infallible; + type QueryComputer = ::QueryDistance; + + fn build_query_computer( + &self, + from: &[f32], + ) -> Result { + Ok(f32::query_distance(from, Metric::L2)) + } + } + }; + } + + ///////////////////////////////// + // Allocating: Element = Vec // + ///////////////////////////////// + + /// `Element<'a> = Vec` — owns its data, reborrows to `&'a [f32]`. + /// Mirrors the `Allocating` accessor in [`crate::provider`]. + struct Allocating<'a> { + store: &'a Store, + cursor: usize, + } + + impl<'a> Allocating<'a> { + fn new(store: &'a Store) -> Self { + Self { store, cursor: 0 } + } + } + + common_iterator_impls!(Allocating<'_>); + + impl FlatIterator for Allocating<'_> { + type Element<'a> + = Vec + where + Self: 'a; + type Error = Infallible; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>> { + async move { + let i = self.cursor; + if i >= self.store.items.len() { + return Ok(None); + } + self.cursor += 1; + let (id, ref v) = self.store.items[i]; + Ok(Some((id, v.clone()))) + } + } + } + + //////////////////////////////////////////////// + // Forwarding: Element = &'store [f32] // + //////////////////////////////////////////////// + + /// `Element<'a> = &'store [f32]` — borrows directly out of the underlying + /// store. The element lifetime is tied to the *store* (not the iterator), + /// proving the trait supports forwarding accessors. Mirrors the + /// `Forwarding` accessor in [`crate::provider`]. 
+ struct Forwarding<'store> { + store: &'store Store, + cursor: usize, + } + + impl<'store> Forwarding<'store> { + fn new(store: &'store Store) -> Self { + Self { store, cursor: 0 } + } + } + + common_iterator_impls!(Forwarding<'_>); + + impl<'store> FlatIterator for Forwarding<'store> { + type Element<'a> + = &'store [f32] + where + Self: 'a; + type Error = Infallible; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>> { + async move { + let i = self.cursor; + if i >= self.store.items.len() { + return Ok(None); + } + self.cursor += 1; + let (id, ref v) = self.store.items[i]; + Ok(Some((id, v.as_slice()))) + } + } + } + + ///////////////////////////////////////////////////////// + // Wrapping: Element = guard-shaped non-ref `Wrapped` // + ///////////////////////////////////////////////////////// + + /// A guard-shaped element that reborrows to `&'b [f32]` and counts its own + /// drops. Mirrors the `Wrapping` accessor's `Wrapped<'a>` in + /// [`crate::provider`], plus a [`Drop`] hook to verify the [`Iterated`] + /// adapter does not leak guards. + struct Wrapped<'g> { + data: &'g [f32], + drop_count: Arc, + } + + impl<'b> Reborrow<'b> for Wrapped<'_> { + type Target = &'b [f32]; + fn reborrow(&'b self) -> Self::Target { + self.data + } + } + + impl Drop for Wrapped<'_> { + fn drop(&mut self) { + self.drop_count.fetch_add(1, Ordering::SeqCst); + } + } + + /// `Element<'a> = Wrapped<'a>`. 
+ struct Wrapping<'a> { + store: &'a Store, + cursor: usize, + drop_count: Arc, + } + + impl<'a> Wrapping<'a> { + fn new(store: &'a Store) -> Self { + Self { + store, + cursor: 0, + drop_count: Arc::new(AtomicUsize::new(0)), + } + } + } + + common_iterator_impls!(Wrapping<'_>); + + impl FlatIterator for Wrapping<'_> { + type Element<'a> + = Wrapped<'a> + where + Self: 'a; + type Error = Infallible; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>> { + async move { + let i = self.cursor; + if i >= self.store.items.len() { + return Ok(None); + } + self.cursor += 1; + let (id, ref v) = self.store.items[i]; + Ok(Some(( + id, + Wrapped { + data: v.as_slice(), + drop_count: self.drop_count.clone(), + }, + ))) + } + } + } + + ///////////////////////////////////////////////// + // Sharing: Element = &'a [f32] via local buf // + ///////////////////////////////////////////////// + + /// `Element<'a> = &'a [f32]` — copies into an internal buffer per `next()` + /// to avoid per-call allocation. Mirrors the `Sharing` accessor in + /// [`crate::provider`]. + struct Sharing<'a> { + store: &'a Store, + cursor: usize, + buf: Vec, + } + + impl<'a> Sharing<'a> { + fn new(store: &'a Store) -> Self { + Self { + store, + cursor: 0, + buf: Vec::new(), + } + } + } + + common_iterator_impls!(Sharing<'_>); + + impl FlatIterator for Sharing<'_> { + type Element<'a> + = &'a [f32] + where + Self: 'a; + type Error = Infallible; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>> { + async move { + let i = self.cursor; + if i >= self.store.items.len() { + return Ok(None); + } + self.cursor += 1; + let (id, ref v) = self.store.items[i]; + self.buf.clear(); + self.buf.extend_from_slice(v); + Ok(Some((id, self.buf.as_slice()))) + } + } + } + + //////////////////////// + // Failing iterator // + //////////////////////// + + /// A critical (non-recoverable) error type the [`Failing`] iterator yields. 
+ #[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] + #[error("synthetic iterator failure at id {0}")] + struct Boom(u32); + + always_escalate!(Boom); + + impl From for ANNError { + #[track_caller] + fn from(boom: Boom) -> ANNError { + ANNError::opaque(boom) + } + } + + /// `Element<'a> = &'a [f32]`, but `next()` returns `Err(Boom(id))` exactly + /// once after `fail_after` successful yields. Used to verify error + /// propagation through [`Iterated`]'s [`DistancesUnordered`] impl. + struct Failing<'a> { + store: &'a Store, + cursor: usize, + fail_after: usize, + } + + impl HasId for Failing<'_> { + type Id = u32; + } + + impl HasElementRef for Failing<'_> { + type ElementRef<'a> = &'a [f32]; + } + + impl BuildQueryComputer<&[f32]> for Failing<'_> { + type QueryComputerError = Infallible; + type QueryComputer = ::QueryDistance; + + fn build_query_computer( + &self, + from: &[f32], + ) -> Result { + Ok(f32::query_distance(from, Metric::L2)) + } + } + + impl FlatIterator for Failing<'_> { + type Element<'a> + = &'a [f32] + where + Self: 'a; + type Error = Boom; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>> { + async move { + let i = self.cursor; + if i >= self.store.items.len() { + return Ok(None); + } + self.cursor += 1; + let (id, ref v) = self.store.items[i]; + if i == self.fail_after { + return Err(Boom(id)); + } + Ok(Some((id, v.as_slice()))) + } + } + } + + ///////////// + // Helpers // + ///////////// + + /// Build the canonical `(id, distance)` ground-truth list for a query under + /// L2, against [`sample_items`]. 
+ fn expected_distances(query: &[f32]) -> Vec<(u32, f32)> { + let computer = f32::query_distance(query, Metric::L2); + sample_items() + .into_iter() + .map(|(id, v)| (id, computer.evaluate_similarity(v.as_slice()))) + .collect() + } + + /////////// + // Tests // + /////////// + + /// The blanket [`DistancesUnordered`] impl on [`Iterated`] produces the + /// correct `(id, distance)` pairs for every supported + /// [`FlatIterator::Element`] shape: owning, forwarding, guard-wrapped, and + /// shared-buffer. + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn distances_unordered() { + let store = Store::sample(); + let query = vec![0.5_f32, 0.9]; + let expected = expected_distances(&query); + + async fn run(mut visitor: Iterated, query: &[f32], expected: &[(u32, f32)]) + where + I: FlatIterator + Send + Sync, + I: for<'a> HasElementRef = &'a [f32]>, + Iterated: HasId + for<'q> DistancesUnordered<&'q [f32]>, + { + let computer = visitor.build_query_computer(query).unwrap(); + let mut seen: Vec<(u32, f32)> = Vec::new(); + visitor + .distances_unordered(&computer, |id, d| seen.push((id, d))) + .await + .unwrap(); + assert_eq!(seen, expected); + } + + // Allocating: Element = Vec (owns). + run(Iterated::new(Allocating::new(&store)), &query, &expected).await; + + // Forwarding: Element = &'store [f32] (borrows from store). + run(Iterated::new(Forwarding::new(&store)), &query, &expected).await; + + // Round-trip through `Iterated::into_inner` to exercise the unwrap path. + let recovered = Iterated::new(Forwarding::new(&store)).into_inner(); + run(Iterated::new(recovered), &query, &expected).await; + + // Wrapping: Element = Wrapped<'a> (guard-shaped non-ref). + run(Iterated::new(Wrapping::new(&store)), &query, &expected).await; + + // Sharing: Element = &'a [f32] (per-call internal buffer). 
+ run(Iterated::new(Sharing::new(&store)), &query, &expected).await; + } + + /// An error returned mid-iteration by [`FlatIterator::next`] propagates up + /// through the [`Iterated`] adapter's [`DistancesUnordered`] impl, and the + /// closure stops being invoked at the failure point. + #[tokio::test] + async fn failures_midstream() { + let store = Store::sample(); + let mut visitor = Iterated::new(Failing { + store: &store, + cursor: 0, + fail_after: 1, // Yield item 0 successfully, fail on item 1. + }); + + let query = vec![0.0_f32, 0.0]; + let computer = visitor.build_query_computer(query.as_slice()).unwrap(); + + let mut seen: Vec = Vec::new(); + let err = visitor + .distances_unordered(&computer, |id, _d| seen.push(id)) + .await + .expect_err("Failing iterator must surface its error"); + + assert_eq!(err, Boom(11)); + assert_eq!( + seen, + vec![10], + "the closure must only see items yielded before the failure", + ); + } +} diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs new file mode 100644 index 000000000..8f3d66306 --- /dev/null +++ b/diskann/src/flat/mod.rs @@ -0,0 +1,38 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Sequential ("flat") search infrastructure. +//! +//! This module is the streaming counterpart to the random-access [`crate::provider::Accessor`] +//! family. It is designed for backends whose natural access pattern is a one-pass scan over +//! their data — for example append-only buffered stores, on-disk shards streamed via I/O, +//! or any provider where random access is significantly more expensive than sequential. +//! +//! # Architecture +//! +//! The module mirrors the layering used by graph search: +//! +//! | Graph (random access) | Flat (sequential) | Shared? | +//! | :------------------------------------ | :----------------------------------------- |:--------- | +//! | [`crate::provider::DataProvider`] | [`crate::provider::DataProvider`] | Yes | +//! 
| [`crate::graph::DiskANNIndex`] | [`FlatIndex`] | No | +//! | [`crate::provider::DistancesUnordered`] | [`DistancesUnordered`] | No | +//! | [`crate::graph::glue::SearchStrategy`] | [`SearchStrategy`] | No | +//! | [`crate::graph::glue::SearchPostProcess`] | [`crate::graph::glue::SearchPostProcess`] | Yes | +//! | [`crate::graph::Search`] | [`FlatIndex::knn_search`] | No | +//! +//! See [`FlatIndex::knn_search`] for the canonical brute-force k-NN algorithm built on these +//! primitives. + +pub mod index; +pub mod iterator; +pub mod strategy; + +pub use index::{FlatIndex, SearchStats}; +pub use iterator::{DistancesUnordered, FlatIterator, Iterated}; +pub use strategy::SearchStrategy; + +#[cfg(test)] +mod test; diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs new file mode 100644 index 000000000..a3ff980b1 --- /dev/null +++ b/diskann/src/flat/strategy.rs @@ -0,0 +1,78 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! [`SearchStrategy`] — glue between [`DataProvider`] and per-query +//! [`DistancesUnordered`] visitors. + +use crate::{error::StandardError, flat::DistancesUnordered, provider::DataProvider}; + +/// Per-call configuration that knows how to construct a per-query +/// [`DistancesUnordered`] visitor for a provider. +/// +/// `SearchStrategy` is the flat counterpart to [`crate::graph::glue::SearchStrategy`] +/// (disambiguated by module path). A strategy instance carries the per-query setup +/// recipe; the per-query mutable state lives in the visitor it produces, so a single +/// strategy may be reused across many searches. +/// +/// The strategy itself is a pure factory; the visitor it produces carries the +/// query-preprocessing capability via [`crate::provider::BuildQueryComputer`] (a +/// super-trait of [`DistancesUnordered`]). +pub trait SearchStrategy: Send + Sync +where + P: DataProvider, +{ + /// The visitor type produced by [`Self::create_visitor`]. 
Borrows from `self` and the + /// provider. + /// + /// The visitor implements both the streaming [`DistancesUnordered`] primitive and + /// the query preprocessor [`crate::provider::BuildQueryComputer`]. + type Visitor<'a>: DistancesUnordered + where + Self: 'a, + P: 'a; + + /// The error type for [`Self::create_visitor`]. + type Error: StandardError; + + /// Construct a fresh visitor over `provider` for the given request `context`. + /// + /// This is where lock acquisition, snapshot pinning, and any other per-query setup + /// should happen. The returned visitor owns whatever borrows / guards it needs to + /// remain valid until it is dropped. + fn create_visitor<'a>( + &'a self, + provider: &'a P, + context: &'a P::Context, + ) -> Result, Self::Error>; +} + +#[cfg(test)] +mod tests { + use crate::{ + flat::test::provider::{self as flat_provider, Strategy}, + graph::test::synthetic::Grid, + }; + + use super::SearchStrategy; + + /// `create_visitor` produces independent visitors on successive calls. + /// + /// The strategy is a stateless factory; calling it twice should yield two + /// distinct visitors that may be used in parallel without interfering with + /// each other. + #[test] + fn exercise_create_visitor() { + let provider = flat_provider::Provider::grid(Grid::Two, 3); + let context = flat_provider::Context::new(); + let strategy = Strategy::new(); + + let v1 = strategy.create_visitor(&provider, &context).unwrap(); + let v2 = strategy.create_visitor(&provider, &context).unwrap(); + + // The two visitors must occupy distinct stack slots — i.e. holding `v1` + // does not preclude constructing `v2`. + let _ = (&v1, &v2); + } +} diff --git a/diskann/src/flat/test/cases/flat_knn_search.rs b/diskann/src/flat/test/cases/flat_knn_search.rs new file mode 100644 index 000000000..76a00c33c --- /dev/null +++ b/diskann/src/flat/test/cases/flat_knn_search.rs @@ -0,0 +1,204 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! 
Baseline-cached regression sweep for [`crate::flat::FlatIndex::knn_search`]. +//! +//! Builds one index per `(grid, size)` combination, runs `knn_search` through the +//! [`crate::flat::test::harness`], snapshots the result + statistics into +//! [`FlatKnnBaseline`], and compares the entire batch against the JSON committed under +//! `diskann/test/generated/flat/test/cases/flat_knn_search/`. + +use crate::{ + flat::{ + FlatIndex, + test::{ + harness, + provider::{self as flat_provider, Metrics, Strategy}, + }, + }, + graph::test::synthetic::Grid, + test::{ + TestPath, TestRoot, + cmp::{assert_eq_verbose, verbose_eq}, + get_or_save_test_results, + }, +}; + +fn root() -> TestRoot { + TestRoot::new("flat/test/cases/flat_knn_search") +} + +/// `k` values exercised for every `(grid, query)` combination. +const KS: [usize; 3] = [1, 4, 10]; + +/// One row of the baseline JSON: a single `(grid, size, query, k)` execution of +/// `FlatIndex::knn_search` plus the brute-force ground truth, search stats, and +/// per-row provider metrics. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct FlatKnnBaseline { + /// Free-form description of what this row exercises. + description: String, + + /// The query vector. + query: Vec, + + /// The dimensionality of the underlying grid. + grid_dims: usize, + + /// The side length of the underlying grid. + grid_size: usize, + + /// The requested `k`. + k: usize, + + /// Sorted distance multiset of the top-`k` returned by `knn_search`. + /// We store the distance multiset rather than `(id, distance)` pairs because + /// the priority queue may evict different *ids* on a boundary distance tie + /// (the queue's tie-breaking is heap-internal, not id-based) — but the + /// multiset of distances is invariant. + top_k_distances: Vec, + + /// Brute-force ground-truth top-`k` `(id, distance)` (sorted by `(distance asc, + /// id asc)`).
The brute-force pass enumerates ids in ascending order, so on a + /// tie this prefers the smaller id and gives a canonical answer for the JSON. + ground_truth: Vec<(u32, f32)>, + + /// `cmps` reported by `knn_search`. Must equal `provider.len()`. + comparisons: usize, + + /// `result_count` reported by `knn_search`. Must equal `min(k, provider.len())`. + result_count: usize, + + /// Per-provider metrics observed for this row (see [`Metrics`]). + metrics: Metrics, +} + +verbose_eq!(FlatKnnBaseline { + description, + query, + grid_dims, + grid_size, + k, + top_k_distances, + ground_truth, + comparisons, + result_count, + metrics, +}); + +/// Run `knn_search` + brute-force oracle against a *shared* `index`, assert the +/// cross-row invariants, and produce the baseline row. The per-row provider metrics +/// captured into the baseline are the *delta* observed during this row, which keeps +/// the snapshot independent of how many rows preceded it. +fn run_row( + index: &FlatIndex, + grid_dim: usize, + grid_size: usize, + query: &[f32], + k: usize, + desc: &str, +) -> FlatKnnBaseline { + let len = index.provider().len(); + let metrics_before = index.provider().metrics(); + + let outcome = harness::KnnOracleRun::run_sync(index, &Strategy::new(), query, k).unwrap(); + let stats = outcome.stats; + + assert_eq!( + stats.cmps as usize, len, + "flat scan must touch every element exactly once", + ); + assert_eq!( + stats.result_count as usize, + k.min(len), + "result_count must equal min(k, provider.len())", + ); + + let gt_distances: Vec = outcome.ground_truth.iter().map(|(_, d)| *d).collect(); + assert_eq!( + outcome.top_k_distances, gt_distances, + "flat scan top-k distance multiset must agree with brute force", + ); + + let metrics_after = index.provider().metrics(); + let metrics = Metrics { + get_element: metrics_after.get_element - metrics_before.get_element, + }; + // `get_element` is incremented only by the [`Visitor`] used during `knn_search`; + // the brute-force 
oracle iterates `Provider::items()` directly and does not touch + // the visitor, so we expect exactly one scan's worth of increments per row. + assert_eq!( + metrics.get_element, len, + "expected exactly one scan (from knn_search) to increment get_element", + ); + + FlatKnnBaseline { + description: desc.to_string(), + query: query.to_vec(), + grid_dims: grid_dim, + grid_size, + k, + top_k_distances: outcome.top_k_distances, + ground_truth: outcome.ground_truth, + comparisons: stats.cmps as usize, + result_count: stats.result_count as usize, + metrics, + } +} + +/// Sweep [`KS`] × `queries` for the given `(grid, size)` and snapshot the results. +fn _flat_knn_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { + let dim: usize = grid.dim().into(); + + // Build the provider and index once, mirroring the production pattern where a + // single index serves many queries. + let provider = flat_provider::Provider::grid(grid, size); + let len = provider.len(); + assert_eq!( + len, + size.pow(dim as u32), + "flat::test::Provider::grid should produce size^dim rows", + ); + let index = FlatIndex::new(provider); + + let queries: [(Vec, &str); 2] = [ + ( + vec![-1.0; dim], + "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + ), + ( + vec![(size - 1) as f32; dim], + "All `size-1`: query coincides with the last grid corner.", + ), + ]; + + let index_ref = &index; + let results: Vec = queries + .iter() + .flat_map(|(q, desc)| { + KS.iter() + .map(move |&k| run_row(index_ref, dim, size, q, k, desc)) + }) + .collect(); + + let name = parent.push(format!("search_{dim}_{size}")); + let expected = get_or_save_test_results(&name, &results); + assert_eq_verbose!(expected, results); +} + +#[test] +fn flat_knn_search_1_100() { + _flat_knn_search(Grid::One, 100, root().path()); +} + +#[test] +fn flat_knn_search_2_5() { + _flat_knn_search(Grid::Two, 5, root().path()); +} + +#[test] +fn flat_knn_search_3_4() { + _flat_knn_search(Grid::Three, 4, root().path()); 
+} diff --git a/diskann/src/flat/test/cases/mod.rs b/diskann/src/flat/test/cases/mod.rs new file mode 100644 index 000000000..ffaf5b91c --- /dev/null +++ b/diskann/src/flat/test/cases/mod.rs @@ -0,0 +1,6 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +mod flat_knn_search; diff --git a/diskann/src/flat/test/harness.rs b/diskann/src/flat/test/harness.rs new file mode 100644 index 000000000..af8500511 --- /dev/null +++ b/diskann/src/flat/test/harness.rs @@ -0,0 +1,147 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Reusable execution harness for [`crate::flat::FlatIndex`] tests. +//! +//! Centralizes every piece of repeated boilerplate that surrounds a `knn_search` +//! call so individual tests stay close to a one-line statement of intent: +//! +//! * runtime construction (`current_thread_runtime`), +//! * scratch buffer + [`BackInserter`] plumbing, +//! * [`NonZeroUsize`] wrapping and [`CopyIds`] selection, +//! * brute-force ground-truth oracle, +//! * canonical `(distance asc, id asc)` re-sort that matches the oracle's +//! tie-breaking (the priority queue's heap-pop order is *not* id-stable on +//! ties, but the oracle's is). +//! +//! Use [`KnnOracleRun::new`] to drive `knn_search` and pair the result with the +//! brute-force ground truth. + +use std::{cmp::Ordering, num::NonZeroUsize}; + +use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; + +use crate::{ + ANNResult, + flat::{ + FlatIndex, SearchStats, + test::provider::{Provider, Strategy}, + }, + graph::glue::CopyIds, + neighbor::{BackInserter, Neighbor}, + test::tokio::current_thread_runtime, + utils::VectorRepr, +}; + +/// Result of running [`FlatIndex::knn_search`] under the harness alongside a +/// brute-force ground-truth oracle. +#[derive(Debug, Clone)] +pub(crate) struct KnnOracleRun { + /// Top-`k` `(id, distance)` pairs in canonical `(distance asc, id asc)` order. 
+ /// Re-sorted from the heap output so equality checks are deterministic on ties. + pub top_k: Vec<(u32, f32)>, + /// `top_k.iter().map(|(_, d)| d).collect()`. + pub top_k_distances: Vec, + /// Statistics returned by `knn_search` (cmps, result_count). + pub stats: SearchStats, + /// Brute-force ground-truth top-`k` `(id, distance)` pairs in `(distance asc, + /// id asc)` order. + pub ground_truth: Vec<(u32, f32)>, +} + +impl KnnOracleRun { + /// Run [`FlatIndex::knn_search`] once, blocking on a fresh single-threaded + /// runtime, and pair the result with the brute-force ground truth. + pub fn run_sync( + index: &FlatIndex, + strategy: &Strategy, + query: &[f32], + k: usize, + ) -> ANNResult { + current_thread_runtime().block_on(Self::run(index, strategy, query, k)) + } + + /// Async variant of [`KnnOracleRun::run_sync`]. Use this from tests that already + /// have a Tokio runtime (e.g. `#[tokio::test]`) or that need to drive + /// `knn_search` concurrently across tasks. + /// + /// Panics if `k` is zero. + pub async fn run( + index: &FlatIndex, + strategy: &Strategy, + query: &[f32], + k: usize, + ) -> ANNResult { + let context = crate::flat::test::provider::Context::new(); + let mut buf = vec![Neighbor::::default(); k]; + + let stats = index + .knn_search( + NonZeroUsize::new(k).expect("flat::test::harness requires k > 0"), + strategy, + &CopyIds, + &context, + query, + &mut BackInserter::new(buf.as_mut_slice()), + ) + .await?; + + let top_k = top_k_sorted(&buf, stats.result_count as usize); + let top_k_distances = top_k.iter().map(|(_, d)| *d).collect(); + let ground_truth = brute_force_topk(index.provider(), Metric::L2, query, k); + + Ok(Self { + top_k, + top_k_distances, + stats, + ground_truth, + }) + } +} + +/// Compute the brute-force top-`k` `(id, distance)` pairs over every element of +/// `provider` under `metric`.
Iterates [`Provider::items`] directly and scores with +/// a fresh [`f32::query_distance`] computer, so the oracle is independent of the +/// [`crate::flat::test::provider::Visitor`] under test. Ties are broken by ascending +/// id for determinism. +pub(crate) fn brute_force_topk( + provider: &Provider, + metric: Metric, + query: &[f32], + k: usize, +) -> Vec<(u32, f32)> { + let computer = f32::query_distance(query, metric); + + let mut neighbors: Vec> = provider + .items() + .iter() + .enumerate() + .map(|(id, element)| Neighbor::new(id as u32, computer.evaluate_similarity(element))) + .collect(); + + sort_neighbors(&mut neighbors); + neighbors + .into_iter() + .take(k) + .map(|n| n.as_tuple()) + .collect() +} + +/// Take the first `result_count` neighbors and return them in `(distance asc, id asc)` +/// order. +fn top_k_sorted(buf: &[Neighbor], result_count: usize) -> Vec<(u32, f32)> { + let mut neighbors: Vec> = buf.iter().copied().take(result_count).collect(); + sort_neighbors(&mut neighbors); + neighbors.into_iter().map(|n| n.as_tuple()).collect() +} + +/// Sort a slice of [`Neighbor`] by `(distance asc, id asc)`. NaN distances are +/// treated as equal (test data should not produce NaN). +fn sort_neighbors(neighbors: &mut [Neighbor]) { + neighbors.sort_by(|a, b| { + a.distance + .partial_cmp(&b.distance) + .unwrap_or(Ordering::Equal) + .then(a.id.cmp(&b.id)) + }); +} diff --git a/diskann/src/flat/test/mod.rs b/diskann/src/flat/test/mod.rs new file mode 100644 index 000000000..8cbd01622 --- /dev/null +++ b/diskann/src/flat/test/mod.rs @@ -0,0 +1,10 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Test fixtures and helpers for the flat module. 
+pub(crate) mod harness; +pub(crate) mod provider; + +mod cases; diff --git a/diskann/src/flat/test/provider.rs b/diskann/src/flat/test/provider.rs new file mode 100644 index 000000000..c6464b020 --- /dev/null +++ b/diskann/src/flat/test/provider.rs @@ -0,0 +1,406 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +#![allow(dead_code)] + +//! Self-contained test provider for the flat-search module. + +use std::{ + borrow::Cow, + collections::HashSet, + fmt::{self, Debug}, + future::Future, + sync::Arc, +}; + +use diskann_utils::future::SendFuture; +use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; +use thiserror::Error; + +use crate::{ + ANNError, always_escalate, + error::{Infallible, RankedError, ToRanked, TransientError}, + flat::{DistancesUnordered, SearchStrategy}, + graph::test::synthetic::Grid, + internal::counter::{Counter, LocalCounter}, + provider::{self, BuildQueryComputer, ExecutionContext, HasElementRef, HasId, NoopGuard}, + utils::VectorRepr, +}; + +////////////// +// Provider // +////////////// + +/// In-memory test provider for flat search. +#[derive(Debug)] +pub struct Provider { + items: Vec>, + dim: usize, + get_element: Counter, +} + +impl Provider { + /// Construct a provider that owns `items`. Every vector must have the same + /// (non-zero) length. + pub fn new(items: impl IntoIterator>) -> Self { + let items: Vec> = items.into_iter().collect(); + assert!( + !items.is_empty(), + "flat::test::Provider needs at least one item" + ); + let dim = items[0].len(); + assert!( + dim > 0, + "flat::test::Provider items must have non-zero dimension" + ); + for (i, v) in items.iter().enumerate() { + assert_eq!( + v.len(), + dim, + "flat::test::Provider item {i} has dim {} but expected {dim}", + v.len(), + ); + } + Self { + items, + dim, + get_element: Counter::new(), + } + } + + /// Build a provider over the row vectors of [`Grid::data`]. 
IDs are `0..n` in + /// row-major order (last coordinate varies fastest). + /// + /// Unlike the graph-side `Provider::grid`, this does *not* add a separate + /// start-point row — flat search has no notion of one. + pub fn grid(grid: Grid, size: usize) -> Self { + let data = grid.data(size); + let items: Vec> = data.row_iter().map(|row| row.to_vec()).collect(); + Self::new(items) + } + + /// Number of vectors in the provider. + pub fn len(&self) -> usize { + self.items.len() + } + + /// Snapshot of the per-provider counters. + pub fn metrics(&self) -> Metrics { + Metrics { + get_element: self.get_element.value(), + } + } + + /// Expose the items for brute force. + pub fn items(&self) -> &[Vec] { + self.items.as_slice() + } +} + +/// Counters tracked by [`Provider`]. +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(serde::Serialize, serde::Deserialize))] +pub struct Metrics { + /// The number of times any [`Visitor`] yielded an element. + pub get_element: usize, +} + +#[cfg(test)] +crate::test::cmp::verbose_eq!(Metrics { get_element }); + +///////////// +// Context // +///////////// + +/// Per-search execution context. No spawn/clone tracking — flat search runs on +/// the calling task and never spawns. +#[derive(Debug, Clone, Default)] +pub struct Context; + +impl Context { + pub fn new() -> Self { + Self + } +} + +impl ExecutionContext for Context { + fn wrap_spawn(&self, f: F) -> impl Future + Send + 'static + where + F: Future + Send + 'static, + { + f + } +} + +///////////////////// +// Errors / Guards // +///////////////////// + +/// Critical id-validation error: the requested id is out of range. +#[derive(Debug, Clone, Copy, Error, PartialEq, Eq)] +#[error("flat::test::Provider has no id {0}")] +pub struct InvalidId(pub u32); + +always_escalate!(InvalidId); + +impl From for ANNError { + #[track_caller] + fn from(err: InvalidId) -> ANNError { + ANNError::opaque(err) + } +} + +/// Transient access error injected by [`Visitor::flaky`]. 
+/// +/// Matches the shape of `graph::test::TransientAccessError`: panics in `Drop` if it +/// is dropped without being acknowledged or escalated. This guards against accidental +/// silent suppression of the error in the test code itself. +#[must_use] +#[derive(Debug)] +pub struct TransientGetError { + id: u32, + handled: bool, +} + +impl TransientGetError { + fn new(id: u32) -> Self { + Self { id, handled: false } + } +} + +impl fmt::Display for TransientGetError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "transient failure retrieving id {}", self.id) + } +} + +impl std::error::Error for TransientGetError {} + +impl Drop for TransientGetError { + fn drop(&mut self) { + assert!( + self.handled, + "dropped an unhandled TransientGetError for id {}", + self.id, + ); + } +} + +impl TransientError for TransientGetError { + fn acknowledge(mut self, _why: D) + where + D: fmt::Display, + { + self.handled = true; + } + + fn escalate(mut self, _why: D) -> InvalidId + where + D: fmt::Display, + { + self.handled = true; + InvalidId(self.id) + } +} + +/// Two-tier error for [`Visitor::distances_unordered`]: a critical [`InvalidId`] +/// or a recoverable [`TransientGetError`]. 
+#[derive(Debug)] +pub enum AccessError { + InvalidId(InvalidId), + Transient(TransientGetError), +} + +impl fmt::Display for AccessError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidId(e) => fmt::Display::fmt(e, f), + Self::Transient(e) => fmt::Display::fmt(e, f), + } + } +} + +impl std::error::Error for AccessError {} + +impl ToRanked for AccessError { + type Transient = TransientGetError; + type Error = InvalidId; + + fn to_ranked(self) -> RankedError { + match self { + Self::InvalidId(e) => RankedError::Error(e), + Self::Transient(e) => RankedError::Transient(e), + } + } + + fn from_transient(transient: TransientGetError) -> Self { + Self::Transient(transient) + } + + fn from_error(error: InvalidId) -> Self { + Self::InvalidId(error) + } +} + +////////////////// +// DataProvider // +////////////////// + +impl provider::DataProvider for Provider { + type Context = Context; + type InternalId = u32; + type ExternalId = u32; + type Error = InvalidId; + type Guard = NoopGuard; + + fn to_internal_id(&self, _ctx: &Context, gid: &u32) -> Result { + if (*gid as usize) < self.items.len() { + Ok(*gid) + } else { + Err(InvalidId(*gid)) + } + } + + fn to_external_id(&self, _ctx: &Context, id: u32) -> Result { + if (id as usize) < self.items.len() { + Ok(id) + } else { + Err(InvalidId(id)) + } + } +} + +///////////// +// Visitor // +///////////// + +/// Per-search visitor over a [`Provider`]. Analog of `graph::test::Accessor`: holds +/// the `'a` borrow of the provider, accumulates a local `get_element` counter that +/// flushes back on drop, and optionally injects transient errors for a configurable +/// set of ids. +pub struct Visitor<'a> { + provider: &'a Provider, + transient_ids: Option>>, + get_element: LocalCounter<'a>, +} + +impl<'a> Visitor<'a> { + /// Construct a visitor with no fault injection. 
+ pub fn new(provider: &'a Provider) -> Self { + Self { + provider, + transient_ids: None, + get_element: provider.get_element.local(), + } + } + + /// Construct a visitor that returns a [`TransientGetError`] for any id in + /// `transient_ids`. Other ids behave normally. + pub fn flaky(provider: &'a Provider, transient_ids: Cow<'a, HashSet>) -> Self { + Self { + provider, + transient_ids: Some(transient_ids), + get_element: provider.get_element.local(), + } + } +} + +impl Debug for Visitor<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Visitor") + .field("provider", &self.provider) + .field("transient_ids", &self.transient_ids) + .finish_non_exhaustive() + } +} + +impl HasId for Visitor<'_> { + type Id = u32; +} + +impl HasElementRef for Visitor<'_> { + type ElementRef<'a> = &'a [f32]; +} + +impl BuildQueryComputer<&[f32]> for Visitor<'_> { + type QueryComputerError = Infallible; + type QueryComputer = ::QueryDistance; + + fn build_query_computer( + &self, + from: &[f32], + ) -> Result { + Ok(f32::query_distance(from, Metric::L2)) + } +} + +impl DistancesUnordered<&[f32]> for Visitor<'_> { + type Error = AccessError; + + fn distances_unordered( + &mut self, + computer: &Self::QueryComputer, + mut f: F, + ) -> impl SendFuture> + where + F: Send + FnMut(Self::Id, f32), + { + async move { + for (i, vector) in self.provider.items.iter().enumerate() { + let id = i as u32; + if let Some(ids) = &self.transient_ids + && ids.contains(&id) + { + return Err(AccessError::Transient(TransientGetError::new(id))); + } + self.get_element.increment(); + let dist = computer.evaluate_similarity(vector.as_slice()); + f(id, dist); + } + Ok(()) + } + } +} + +////////////// +// Strategy // +////////////// + +/// Stateless factory of [`Visitor`]s. 
+#[derive(Clone, Debug, Default)] +pub struct Strategy { + transient_ids: Option>>, +} + +impl Strategy { + pub fn new() -> Self { + Self::default() + } + + /// Construct a strategy whose visitors return a transient error on `get_element` + /// for every id in `transient_ids`. + pub fn with_transient(transient_ids: impl IntoIterator) -> Self { + Self { + transient_ids: Some(Arc::new(transient_ids.into_iter().collect())), + } + } +} + +impl SearchStrategy for Strategy { + type Visitor<'a> = Visitor<'a>; + type Error = Infallible; + + fn create_visitor<'a>( + &'a self, + provider: &'a Provider, + _context: &'a Context, + ) -> Result, Self::Error> { + let visitor = match &self.transient_ids { + Some(ids) => Visitor::flaky(provider, Cow::Borrowed(ids)), + None => Visitor::new(provider), + }; + Ok(visitor) + } +} diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index cb098b85c..2f5bf5f0f 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -89,7 +89,7 @@ use crate::{ neighbor::Neighbor, provider::{ Accessor, AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, - DataProvider, HasId, NeighborAccessor, + DataProvider, DistancesUnordered, HasElementRef, HasId, NeighborAccessor, }, utils::VectorId, }; @@ -242,7 +242,7 @@ impl HybridPredicate for NotInMut<'_, T> where T: Clone + Eq + std::hash:: /// /// The provided implementation works on each element of `ids` sequentially, pre-filters /// the resulting candidate list using `pred.eval()` before invoking -/// [`BuildQueryComputer::distances_unordered`]. +/// [`DistancesUnordered::distances_unordered`]. /// /// The callback `on_neighbors` is decorated to the uses `pred.eval_mut()`. /// @@ -252,7 +252,7 @@ impl HybridPredicate for NotInMut<'_, T> where T: Clone + Eq + std::hash:: /// ## Error Handling /// /// Transient errors yielded by `distances_unordered` are acknowledged and not escalated. 
-pub trait ExpandBeam: BuildQueryComputer + AsNeighbor + Sized { +pub trait ExpandBeam: DistancesUnordered + AsNeighbor + Sized { fn expand_beam( &mut self, ids: Itr, @@ -299,7 +299,7 @@ where /// We could grab this type from the `SearchAccessor` associated type, but it's /// useful enough that we move it up here. type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< - as Accessor>::ElementRef<'b>, + as HasElementRef>::ElementRef<'b>, f32, > + Send + Sync @@ -386,7 +386,7 @@ macro_rules! default_post_processor { /// directly into the output buffer. pub trait SearchPostProcess::Id> where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, { type Error: StandardError; @@ -412,7 +412,7 @@ pub struct CopyIds; impl SearchPostProcess for CopyIds where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, { type Error = std::convert::Infallible; fn post_process( @@ -437,7 +437,7 @@ where /// using a [`Pipeline`]. pub trait SearchPostProcessStep::Id> where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, { /// A potentially modified version of the error yielded by the next state in the /// processing pipeline. @@ -446,7 +446,7 @@ where NextError: StandardError; /// The accessor that will be passed to the next processing stage. - type NextAccessor: BuildQueryComputer; + type NextAccessor: BuildQueryComputer + HasId; /// Perform any modification the `input`, `output`, `accessor`, or `computer` objects /// and invoke the [`SearchPostProcess`] routine `next` on stage. 
@@ -471,7 +471,7 @@ pub struct FilterStartPoints; impl SearchPostProcessStep for FilterStartPoints where - A: BuildQueryComputer + SearchExt, + A: BuildQueryComputer + SearchExt + HasId, T: Copy + Send + Sync, { /// A this level, sub-errors are converted into [`ANNError`] to provide additional @@ -540,7 +540,7 @@ impl Pipeline { impl SearchPostProcess for Pipeline where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, Head: SearchPostProcessStep, Tail: SearchPostProcess + Sync, { @@ -614,8 +614,8 @@ where /// We could grab this type from the `PruneAccessor` associated type, but it's /// useful enough that we move it up here. type DistanceComputer<'computer>: for<'a, 'b, 'c, 'd> DistanceFunction< - as Accessor>::ElementRef<'b>, - as Accessor>::ElementRef<'d>, + as HasElementRef>::ElementRef<'b>, + as HasElementRef>::ElementRef<'d>, f32, > + Send + Sync; @@ -854,7 +854,7 @@ mod tests { use super::*; use crate::{ ANNResult, neighbor, - provider::{DelegateNeighbor, ExecutionContext, HasId, NeighborAccessor}, + provider::{DelegateNeighbor, ExecutionContext, HasElementRef, HasId, NeighborAccessor}, }; // A really simple provider that just holds floats and uses the absolute value for its @@ -927,12 +927,15 @@ mod tests { type Id = u32; } + impl HasElementRef for Retriever<'_> { + type ElementRef<'a> = f32; + } + impl Accessor for Retriever<'_> { type Element<'a> = f32 where Self: 'a; - type ElementRef<'a> = f32; type GetError = ANNError; fn get_element( @@ -979,6 +982,8 @@ mod tests { impl ExpandBeam for Retriever<'_> {} + impl DistancesUnordered for Retriever<'_> {} + // This strategy explicitly does not define `post_process` so we can test the provided // implementation. 
struct Strategy; diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 051ec6fc7..e0660006e 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -1040,12 +1040,15 @@ impl provider::HasId for Accessor<'_> { type Id = u32; } +impl provider::HasElementRef for Accessor<'_> { + type ElementRef<'a> = &'a [f32]; +} + impl provider::Accessor for Accessor<'_> { type Element<'a> = &'a [f32] where Self: 'a; - type ElementRef<'a> = &'a [f32]; type GetError = AccessError; async fn get_element(&mut self, id: u32) -> Result<&[f32], AccessError> { @@ -1085,6 +1088,8 @@ impl provider::BuildQueryComputer<&[f32]> for Accessor<'_> { } } +impl provider::DistancesUnordered<&[f32]> for Accessor<'_> {} + impl provider::BuildDistanceComputer for Accessor<'_> { type DistanceComputerError = Infallible; type DistanceComputer = ::Distance; diff --git a/diskann/src/lib.rs b/diskann/src/lib.rs index 71cb3ed41..9c1f6ac76 100644 --- a/diskann/src/lib.rs +++ b/diskann/src/lib.rs @@ -13,6 +13,7 @@ pub mod utils; pub(crate) mod internal; // Index Implementations +pub mod flat; pub mod graph; // Top level exports. diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index f55af356c..82c1d64db 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -71,10 +71,14 @@ //! * [`BuildDistanceComputer`]: A sub-trait of [`Accessor`] that allows for random-access //! distance computations on the retrieved elements. //! -//! * [`BuildQueryComputer`]: A sub-trait of [`Accessor`] that allows for specialized query -//! based computations. This allows a query to be pre-processed in a way that allows +//! * [`BuildQueryComputer`]: A sub-trait of [`HasElementRef`] that allows for specialized +//! query based computations. This allows a query to be pre-processed in a way that allows //! faster computations. //! +//! * [`DistancesUnordered`]: A sub-trait of [`Accessor`] and [`BuildQueryComputer`] that +//! 
provides a fused iterate-and-compute primitive over a set of element ids using a +//! pre-built query computer. +//! //! # Neighbor Delegation //! //! Index search requires that accessor types implement both the data-centric [`Accessor`] @@ -381,6 +385,16 @@ where } } +/////////////////// +// HasElementRef // +/////////////////// + +/// A catch-all super trait for traits that have an associated `ElementRef<'_>` +/// type. Traits like [`Accessor`] are subtraits. +pub trait HasElementRef { + type ElementRef<'a>; +} + ////////////// // Accessor // ////////////// @@ -419,18 +433,15 @@ where /// The need for `ElementRef` arises to allow HRTB bounds to distance computers without /// inducing a `'static` bound on `Self`. In traits like [`BuildQueryComputer`], attempting /// to use `Element` directly will result in such a requirement on the implementing Accessor. -pub trait Accessor: HasId + Send + Sync { - /// A generalized reference type used for distance computations. - /// - /// Note that the lifetime of `ElementRef` is unconstrained and thus using it in a - /// [HRTB](https://doc.rust-lang.org/nomicon/hrtb.html) will not induce a `'static` - /// requirement on `Self`. - type ElementRef<'a>; - +pub trait Accessor: HasId + HasElementRef + Send + Sync { /// The concrete type of the data element associated with this accessor. /// /// For distance computations, this should be cheaply convertible via [`Reborrow`] to /// `Self::ElementRef`. + /// + /// Note that the lifetime of `ElementRef` is unconstrained and thus using it in a + /// [HRTB](https://doc.rust-lang.org/nomicon/hrtb.html) will not induce a `'static` + /// requirement on `Self`. type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync where Self: 'a; @@ -496,11 +507,18 @@ pub trait BuildDistanceComputer: Accessor { ) -> Result; } -/// A specialized [`Accessor`] that provides query computations for a query type `T`. +/// A trait that provides query computations for a query type `T`. 
/// /// Query computers are allowed to preprocess the query to enable more efficient distance /// computations. -pub trait BuildQueryComputer: Accessor { +/// +/// This trait only requires [`HasElementRef`] (so the query computer's element type can be +/// named) so that it can be used with multiple access patterns - like [`Accessor`] and +/// [`crate::flat::SearchStrategy`]. +/// +/// A fused iterate-and-compute primitive can be created as a sub-trait - +/// e.g. [`DistancesUnordered`], which requires both [`Accessor`] and `BuildQueryComputer`. +pub trait BuildQueryComputer: HasElementRef { /// The error type (if any) associated with distance computer construction. type QueryComputerError: std::error::Error + Into + Send + Sync + 'static; @@ -519,11 +537,16 @@ pub trait BuildQueryComputer: Accessor { &self, from: T, ) -> Result; +} - /// Compute the distances for the elements in the iterator `itr` using the - /// `computer` and apply the closure `f` to each distance and ID. The default - /// implementation uses on_elements_unordered to iterate over the elements - /// and compute the distances using `computer` parameter. +/// A sub-trait of [`Accessor`] and [`BuildQueryComputer`] that exposes the fused +/// iterate-and-compute primitive `distances_unordered`. +/// +/// The default implementation uses [`Accessor::on_elements_unordered`] to iterate over the +/// elements and computes their distances using the provided `computer`. +pub trait DistancesUnordered: Accessor + BuildQueryComputer { + /// Compute the distances for the elements in the iterator `vec_id_itr` using the + /// `computer` and apply the closure `f` to each distance and ID.
fn distances_unordered( &mut self, vec_id_itr: Itr, @@ -1046,12 +1069,14 @@ mod tests { impl HasId for FloatAccessor<'_> { type Id = u32; } + impl HasElementRef for FloatAccessor<'_> { + type ElementRef<'a> = f32; + } impl Accessor for FloatAccessor<'_> { type Element<'a> = f32 where Self: 'a; - type ElementRef<'a> = f32; type GetError = Missing; @@ -1112,12 +1137,14 @@ mod tests { impl HasId for StringAccessor<'_> { type Id = u32; } + impl HasElementRef for StringAccessor<'_> { + type ElementRef<'a> = &'a str; + } impl Accessor for StringAccessor<'_> { type Element<'a> = &'a str where Self: 'a; - type ElementRef<'a> = &'a str; type GetError = Missing; @@ -1271,12 +1298,15 @@ mod tests { common_test_accessor!(Allocating<'_>); + impl HasElementRef for Allocating<'_> { + type ElementRef<'a> = &'a [u8]; + } + impl Accessor for Allocating<'_> { type Element<'a> = Box<[u8]> where Self: 'a; - type ElementRef<'a> = &'a [u8]; type GetError = Infallible; async fn get_element(&mut self, _: u32) -> Result, Infallible> { @@ -1298,6 +1328,10 @@ mod tests { common_test_accessor!(Forwarding<'_>); + impl HasElementRef for Forwarding<'_> { + type ElementRef<'a> = &'a [u8]; + } + impl<'provider> Accessor for Forwarding<'provider> { // NOTE: The lifetime of `Element` is `'provider` - not `'a`. This is what makes // it a forwarding accessor. 
@@ -1305,7 +1339,6 @@ mod tests { = &'provider [u8] where Self: 'a; - type ElementRef<'a> = &'a [u8]; type GetError = Infallible; async fn get_element(&mut self, _: u32) -> Result<&'provider [u8], Infallible> { @@ -1342,12 +1375,15 @@ mod tests { common_test_accessor!(Wrapping<'_>); + impl HasElementRef for Wrapping<'_> { + type ElementRef<'a> = &'a [u8]; + } + impl Accessor for Wrapping<'_> { type Element<'a> = Wrapped<'a> where Self: 'a; - type ElementRef<'a> = &'a [u8]; type GetError = Infallible; async fn get_element(&mut self, _: u32) -> Result, Infallible> { @@ -1373,12 +1409,15 @@ mod tests { common_test_accessor!(Sharing<'_>); + impl HasElementRef for Sharing<'_> { + type ElementRef<'a> = &'a [u8]; + } + impl Accessor for Sharing<'_> { type Element<'a> = &'a [u8] where Self: 'a; - type ElementRef<'a> = &'a [u8]; type GetError = Infallible; async fn get_element(&mut self, _: u32) -> Result<&[u8], Infallible> { diff --git a/diskann/test/generated/flat/test/cases/flat_knn_search/search_1_100.json b/diskann/test/generated/flat/test/cases/flat_knn_search/search_1_100.json new file mode 100644 index 000000000..3dda3b7c8 --- /dev/null +++ b/diskann/test/generated/flat/test/cases/flat_knn_search/search_1_100.json @@ -0,0 +1,264 @@ +{ + "file": "diskann/src/flat/test/cases/flat_knn_search.rs", + "test": "flat/test/cases/flat_knn_search/search_1_100", + "payload": [ + { + "comparisons": 100, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 0, + 1.0 + ] + ], + "k": 1, + "metrics": { + "get_element": 100 + }, + "query": [ + -1.0 + ], + "result_count": 1, + "top_k_distances": [ + 1.0 + ] + }, + { + "comparisons": 100, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 0, + 1.0 + ], + [ + 1, + 4.0 + ], + [ + 2, + 9.0 + ], + [ + 3, + 16.0 + ] + ], + "k": 4, + "metrics": 
{ + "get_element": 100 + }, + "query": [ + -1.0 + ], + "result_count": 4, + "top_k_distances": [ + 1.0, + 4.0, + 9.0, + 16.0 + ] + }, + { + "comparisons": 100, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 0, + 1.0 + ], + [ + 1, + 4.0 + ], + [ + 2, + 9.0 + ], + [ + 3, + 16.0 + ], + [ + 4, + 25.0 + ], + [ + 5, + 36.0 + ], + [ + 6, + 49.0 + ], + [ + 7, + 64.0 + ], + [ + 8, + 81.0 + ], + [ + 9, + 100.0 + ] + ], + "k": 10, + "metrics": { + "get_element": 100 + }, + "query": [ + -1.0 + ], + "result_count": 10, + "top_k_distances": [ + 1.0, + 4.0, + 9.0, + 16.0, + 25.0, + 36.0, + 49.0, + 64.0, + 81.0, + 100.0 + ] + }, + { + "comparisons": 100, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 99, + 0.0 + ] + ], + "k": 1, + "metrics": { + "get_element": 100 + }, + "query": [ + 99.0 + ], + "result_count": 1, + "top_k_distances": [ + 0.0 + ] + }, + { + "comparisons": 100, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 99, + 0.0 + ], + [ + 98, + 1.0 + ], + [ + 97, + 4.0 + ], + [ + 96, + 9.0 + ] + ], + "k": 4, + "metrics": { + "get_element": 100 + }, + "query": [ + 99.0 + ], + "result_count": 4, + "top_k_distances": [ + 0.0, + 1.0, + 4.0, + 9.0 + ] + }, + { + "comparisons": 100, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 99, + 0.0 + ], + [ + 98, + 1.0 + ], + [ + 97, + 4.0 + ], + [ + 96, + 9.0 + ], + [ + 95, + 16.0 + ], + [ + 94, + 25.0 + ], + [ + 93, + 36.0 + ], + [ + 92, + 49.0 + ], + [ + 91, + 64.0 + ], + [ + 90, + 81.0 + ] + ], + "k": 10, + "metrics": { + "get_element": 100 + }, + "query": [ + 99.0 + ], + "result_count": 10, + "top_k_distances": [ + 0.0, + 1.0, + 4.0, + 9.0, + 16.0, + 25.0, 
+ 36.0, + 49.0, + 64.0, + 81.0 + ] + } + ] +} \ No newline at end of file diff --git a/diskann/test/generated/flat/test/cases/flat_knn_search/search_2_5.json b/diskann/test/generated/flat/test/cases/flat_knn_search/search_2_5.json new file mode 100644 index 000000000..3a978a520 --- /dev/null +++ b/diskann/test/generated/flat/test/cases/flat_knn_search/search_2_5.json @@ -0,0 +1,270 @@ +{ + "file": "diskann/src/flat/test/cases/flat_knn_search.rs", + "test": "flat/test/cases/flat_knn_search/search_2_5", + "payload": [ + { + "comparisons": 25, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 0, + 2.0 + ] + ], + "k": 1, + "metrics": { + "get_element": 25 + }, + "query": [ + -1.0, + -1.0 + ], + "result_count": 1, + "top_k_distances": [ + 2.0 + ] + }, + { + "comparisons": 25, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 0, + 2.0 + ], + [ + 1, + 5.0 + ], + [ + 5, + 5.0 + ], + [ + 6, + 8.0 + ] + ], + "k": 4, + "metrics": { + "get_element": 25 + }, + "query": [ + -1.0, + -1.0 + ], + "result_count": 4, + "top_k_distances": [ + 2.0, + 5.0, + 5.0, + 8.0 + ] + }, + { + "comparisons": 25, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 0, + 2.0 + ], + [ + 1, + 5.0 + ], + [ + 5, + 5.0 + ], + [ + 6, + 8.0 + ], + [ + 2, + 10.0 + ], + [ + 10, + 10.0 + ], + [ + 7, + 13.0 + ], + [ + 11, + 13.0 + ], + [ + 3, + 17.0 + ], + [ + 15, + 17.0 + ] + ], + "k": 10, + "metrics": { + "get_element": 25 + }, + "query": [ + -1.0, + -1.0 + ], + "result_count": 10, + "top_k_distances": [ + 2.0, + 5.0, + 5.0, + 8.0, + 10.0, + 10.0, + 13.0, + 13.0, + 17.0, + 17.0 + ] + }, + { + "comparisons": 25, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 2, + 
"grid_size": 5, + "ground_truth": [ + [ + 24, + 0.0 + ] + ], + "k": 1, + "metrics": { + "get_element": 25 + }, + "query": [ + 4.0, + 4.0 + ], + "result_count": 1, + "top_k_distances": [ + 0.0 + ] + }, + { + "comparisons": 25, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 24, + 0.0 + ], + [ + 19, + 1.0 + ], + [ + 23, + 1.0 + ], + [ + 18, + 2.0 + ] + ], + "k": 4, + "metrics": { + "get_element": 25 + }, + "query": [ + 4.0, + 4.0 + ], + "result_count": 4, + "top_k_distances": [ + 0.0, + 1.0, + 1.0, + 2.0 + ] + }, + { + "comparisons": 25, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 24, + 0.0 + ], + [ + 19, + 1.0 + ], + [ + 23, + 1.0 + ], + [ + 18, + 2.0 + ], + [ + 14, + 4.0 + ], + [ + 22, + 4.0 + ], + [ + 13, + 5.0 + ], + [ + 17, + 5.0 + ], + [ + 12, + 8.0 + ], + [ + 9, + 9.0 + ] + ], + "k": 10, + "metrics": { + "get_element": 25 + }, + "query": [ + 4.0, + 4.0 + ], + "result_count": 10, + "top_k_distances": [ + 0.0, + 1.0, + 1.0, + 2.0, + 4.0, + 4.0, + 5.0, + 5.0, + 8.0, + 9.0 + ] + } + ] +} \ No newline at end of file diff --git a/diskann/test/generated/flat/test/cases/flat_knn_search/search_3_4.json b/diskann/test/generated/flat/test/cases/flat_knn_search/search_3_4.json new file mode 100644 index 000000000..74e067761 --- /dev/null +++ b/diskann/test/generated/flat/test/cases/flat_knn_search/search_3_4.json @@ -0,0 +1,276 @@ +{ + "file": "diskann/src/flat/test/cases/flat_knn_search.rs", + "test": "flat/test/cases/flat_knn_search/search_3_4", + "payload": [ + { + "comparisons": 64, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 0, + 3.0 + ] + ], + "k": 1, + "metrics": { + "get_element": 64 + }, + "query": [ + -1.0, + -1.0, + -1.0 + ], + "result_count": 1, + "top_k_distances": [ + 
3.0 + ] + }, + { + "comparisons": 64, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 0, + 3.0 + ], + [ + 1, + 6.0 + ], + [ + 4, + 6.0 + ], + [ + 16, + 6.0 + ] + ], + "k": 4, + "metrics": { + "get_element": 64 + }, + "query": [ + -1.0, + -1.0, + -1.0 + ], + "result_count": 4, + "top_k_distances": [ + 3.0, + 6.0, + 6.0, + 6.0 + ] + }, + { + "comparisons": 64, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 0, + 3.0 + ], + [ + 1, + 6.0 + ], + [ + 4, + 6.0 + ], + [ + 16, + 6.0 + ], + [ + 5, + 9.0 + ], + [ + 17, + 9.0 + ], + [ + 20, + 9.0 + ], + [ + 2, + 11.0 + ], + [ + 8, + 11.0 + ], + [ + 32, + 11.0 + ] + ], + "k": 10, + "metrics": { + "get_element": 64 + }, + "query": [ + -1.0, + -1.0, + -1.0 + ], + "result_count": 10, + "top_k_distances": [ + 3.0, + 6.0, + 6.0, + 6.0, + 9.0, + 9.0, + 9.0, + 11.0, + 11.0, + 11.0 + ] + }, + { + "comparisons": 64, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 63, + 0.0 + ] + ], + "k": 1, + "metrics": { + "get_element": 64 + }, + "query": [ + 3.0, + 3.0, + 3.0 + ], + "result_count": 1, + "top_k_distances": [ + 0.0 + ] + }, + { + "comparisons": 64, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 63, + 0.0 + ], + [ + 47, + 1.0 + ], + [ + 59, + 1.0 + ], + [ + 62, + 1.0 + ] + ], + "k": 4, + "metrics": { + "get_element": 64 + }, + "query": [ + 3.0, + 3.0, + 3.0 + ], + "result_count": 4, + "top_k_distances": [ + 0.0, + 1.0, + 1.0, + 1.0 + ] + }, + { + "comparisons": 64, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 63, + 0.0 + ], + [ + 47, + 1.0 + ], + [ + 59, + 1.0 + 
], + [ + 62, + 1.0 + ], + [ + 43, + 2.0 + ], + [ + 46, + 2.0 + ], + [ + 58, + 2.0 + ], + [ + 42, + 3.0 + ], + [ + 31, + 4.0 + ], + [ + 55, + 4.0 + ] + ], + "k": 10, + "metrics": { + "get_element": 64 + }, + "query": [ + 3.0, + 3.0, + 3.0 + ], + "result_count": 10, + "top_k_distances": [ + 0.0, + 1.0, + 1.0, + 1.0, + 2.0, + 2.0, + 2.0, + 3.0, + 4.0, + 4.0 + ] + } + ] +} \ No newline at end of file diff --git a/rfcs/00983-flat-search.md b/rfcs/00983-flat-search.md new file mode 100644 index 000000000..a4a4312e2 --- /dev/null +++ b/rfcs/00983-flat-search.md @@ -0,0 +1,265 @@ +# Flat Search + +| | | +|------------------|--------------------------------| +| **Authors** | Aditya Krishnan, Alex Razumov, Dongliang Wu | +| **Created** | 2026-04-24 | +| **Updated** | 2026-05-14 | + +## 1. Motivation + +### 1.1 Background + +DiskANN today exposes a single abstraction family centered on the +[`crate::provider::Accessor`] trait. Accessors are random access by design since the graph greedy search algorithm needs to decide which ids to fetch and the accessor materializes the corresponding elements (vectors, quantized vectors and neighbor lists) on demand. This is the right contract for graph search, where neighborhood expansion is inherently random-access against the [`crate::provider::DataProvider`]. + +A growing class of consumers diverges from our current pattern of use by accessing their index **sequentially**. Some consumers build their index in an "append-only" fashion and require that they walk the index in a sequential, fixed order, relying on iteration position to enforce versioning / deduplication invariants. + +### 1.2 Problem Statement + +The problem statement here is simple: provide first-class support for sequential, one-pass scans over a data backend without +stuffing the algorithm or the backend through the `Accessor` trait surface. + +### 1.3 Goals + +1.
Define a fused iterate-and-score primitive — `flat::DistancesUnordered` — that + mirrors the role `Accessor` plays for graph search but exposes a sequential + scan-and-score operation instead of random access. +2. Provide flat-search algorithm implementations built on the new primitives, so consumers can use this against their own providers / backends. +3. Expose support for different distance computers and post-processing like re-ranking _out-of-the-box_ without having to reimplement these for the flat search path. + +## 2. Proposal + +The flat-search infrastructure is built on a small sequence of traits. The only trait a +backend *must* implement is `flat::DistancesUnordered`. A `flat::SearchStrategy` +then instantiates per-query visitors that implement it. + +An opt-in `FlatIterator` trait plus the `Iterated` adapter exist for +convenience for backends that naturally expose element-at-a-time iteration. + +### 2.1 Refactor `Accessor` and `BuildQueryComputer` +We start with a small refactor of the traits in `diskann::provider` and `diskann::graph::glue` that +will enable us to cleanly separate query preprocessing and result post-processing from the search pattern so that +both graph and flat search can share common components as much as possible: + +1. **Extract `HasElementRef` out of `Accessor`.** The `ElementRef<'a>` GAT moves to + its own zero-method trait so that streaming visitors (which are not `Accessor`s) + can still expose an element type. `Accessor` is now `Accessor: HasId + + HasElementRef + Send + Sync`. `HasElementRef` is simply: + + ```rust + pub trait HasElementRef { type ElementRef<'a>; } + ``` + +2. **Decouple `BuildQueryComputer` from `Accessor`.** Previously a + sub-trait of `Accessor`, `BuildQueryComputer` is lifted to depend only on + `HasElementRef`. Secondly, it now contains only a constructor `build_query_computer` + as an associated method and nothing else.
This is the + change that lets `BuildQueryComputer` and `graph::glue::SearchPostProcess` be + used unchanged by both the flat index and the graph. + +3. **Split distance scoring into a new `DistancesUnordered` trait family.** + Previously, the unordered iterate-and-score loop was a default method tucked + inside `BuildQueryComputer` (and shadowed by overrides on a few providers). It is now its + own subtrait of `BuildQueryComputer`, with two flavors that share a name and a + default-body shape but differ in their access super-trait: + + - **`provider::DistancesUnordered: Accessor + BuildQueryComputer`** — drives + the scan via the random-access `Accessor` machinery. Used by graph search. + - **`flat::DistancesUnordered: HasId + BuildQueryComputer`** — a + self-contained fused iterate-and-score trait used by flat search. Backends + implement the entire scan-and-score loop in a single method. More on it below. + +### 2.2 Core traits for flat search + +The single required trait for flat search is `flat::DistancesUnordered`. It fuses +iteration and scoring into one method: implementations drive an entire scan over their +underlying data, scoring each element with the supplied query computer and invoking a +callback with `(id, distance)` pairs. Implementations choose iteration order, +prefetching, and any bulk reads; algorithms see only `(Id, f32)` pairs. + +```rust +pub trait DistancesUnordered: HasId + BuildQueryComputer + Send + Sync { + type Error: ToRanked + Debug + Send + Sync + 'static; + + fn distances_unordered( + &mut self, + computer: &>::QueryComputer, + f: F, + ) -> impl SendFuture> + where + F: Send + FnMut(::Id, f32); +} +``` + +`DistancesUnordered` is scoped to a single query. We introduce a strategy that is the per-call +constructor that hands the algorithm a freshly-bound visitor. It is meant to be stateless, +cheap to construct, and lives only for the duration of one search.
+ +```rust +pub trait SearchStrategy: Send + Sync +where + P: DataProvider, +{ + /// The per-query visitor type produced by [`Self::create_visitor`]. Borrows from + /// `self` and the provider. The visitor implements both the streaming + /// [`DistancesUnordered`] primitive and the query preprocessor + /// [`BuildQueryComputer`]. + type Visitor<'a>: DistancesUnordered + where + Self: 'a, + P: 'a; + + type Error: StandardError; + + /// Construct a fresh visitor over `provider` for the given request `context`. + fn create_visitor<'a>( + &'a self, + provider: &'a P, + context: &'a P::Context, + ) -> Result, Self::Error>; +} +``` +This shape mirrors the random-access `graph::glue::SearchStrategy` and lets `FlatIndex::knn_search` accept the same +`graph::glue::SearchPostProcess` that graph search uses (see below). + +### 2.3 `FlatIndex` — the top-level handle + +`FlatIndex` is a thin `'static` wrapper around a `DataProvider`. The same +`DataProvider` trait used by graph search is reused — flat and graph share one +provider surface and the same `Context` / id-mapping / error machinery. + +```rust +pub struct FlatIndex { + provider: P, +} + +impl FlatIndex

{ + pub fn new(provider: P) -> Self; + pub fn provider(&self) -> &P; + + pub fn knn_search( + &self, + k: NonZeroUsize, + strategy: &S, + processor: &PP, + context: &P::Context, + query: T, + output: &mut OB, + ) -> impl SendFuture> + where + S: flat::SearchStrategy, + T: Copy + Send + Sync, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, + PP: for<'a> graph::glue::SearchPostProcess, T, O> + Send + Sync, +} +``` + +**Note:** The `PP` bound uses the same `graph::glue::SearchPostProcess` trait as graph search; +there is no flat-specific post-process trait. Reuse is enabled by the trait splits +described above (the visitor implements `BuildQueryComputer + HasId`, which is all +`SearchPostProcess` requires). + +The `knn_search` method is the canonical brute-force search algorithm: + +1. Construct the per-query visitor via `strategy.create_visitor`. +2. Build the query computer from the visitor via `BuildQueryComputer::build_query_computer`. +3. Drive the scan via `visitor.distances_unordered(&computer, ...)`, inserting each + `(id, distance)` pair into a `NeighborPriorityQueue` of capacity `k`. +4. Hand the survivors (in distance order) to `processor.post_process`. +5. Return search stats. + +Other algorithms (filtered, range, diverse) can be added later as additional methods on +`FlatIndex`. + +#### Search call chain (AI Generated) + +The diagram below traces the trait dispatch sequence inside one `search` call for +each of graph and flat search. The centre lane shows the shared traits that both +columns dip into. 
+ +```text + Graph Shared Flat + ───── ────── ──── + + DiskANNIndex::search FlatIndex::knn_search + │ │ + ▼ ▼ + graph::glue::SearchStrategy flat::SearchStrategy + ::search_accessor ::create_visitor + │ │ + ▼ ▼ + ExpandBeam visitor DistancesUnordered visitor + (Accessor + BuildQueryComputer) (HasId + BuildQueryComputer) + │ │ + │ BuildQueryComputer │ + ├─────────────────►::build_query_computer ◄──────────────────────┤ + │ (visitor → QueryComputer) │ + │ │ + ▼ ▼ + ExpandBeam::expand_beam DistancesUnordered + (greedy beam loop: ::distances_unordered + for each frontier id, (one pass over every + get_neighbors, element; computer scores + distances_unordered) each one) + │ │ + ▼ ▼ + NeighborPriorityQueue NeighborPriorityQueue + │ │ + │ graph::glue::SearchPostProcess │ + └─────────────►::post_process ◄──────────────────────────────────┘ + │ + ▼ + SearchOutputBuffer +``` + +### 2.4 `FlatIterator` and `Iterated` — convenience for element-at-a-time backends + +For backends that naturally expose element-at-a-time iteration, `FlatIterator` is a +lending async iterator: + +```rust +pub trait FlatIterator: HasId + HasElementRef + Send + Sync { + type Element<'a>: for<'b> Reborrow<'b, Target = ::ElementRef<'b>> + + Send + Sync + where Self: 'a; + type Error: StandardError; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>>; +} +``` + +`Iterated` wraps any `FlatIterator` and implements `DistancesUnordered` (when +the inner type also implements `BuildQueryComputer`) by looping over `next()`, +reborrowing each element, and scoring it with the supplied query computer. + +## Trade-offs + +### Reusing `DataProvider` + +This design leans into using the `DataProvider` trait which requires implementations to implement `InternalId` and `ExternalId` conversions (via the context). Arguably, this requirement is too restrictive for some consumers of a flat-index. 
Reasons for sticking with `DataProvider`: + +- Every concrete provider already implements `DataProvider`, so a separate trait adds + an abstraction that existing consumers will have to implement if they want to opt in to the flat-index path. +- Sharing `DataProvider` means the `Context`, id-mapping (`to_internal_id` / + `to_external_id`), and error machinery are identical across graph and flat search, + reducing the learning surface for new contributors. + +### Expand `Element` to support batched distance computation? + +The current optional iterator `FlatIterator` yields one element per `next()` call, and the query computer scores +elements one at a time via `PreprocessedDistanceFunction::evaluate_similarity`. This could leave some optimization opportunities and performance on the table, especially with the upcoming effort around batched distance kernels. Of course, a consumer can choose to provide their own optimized implementation of `distances_unordered` that uses batching. + +An alternative is to make `next()` yield a *batch* instead of a single vector representation like `Element<'_>`. Some work will need to be done to define the right interaction between the batch type, the element type in the batch, the `QueryComputer`'s types, and the way IDs and distances are collected in the queue. + +### Intra-query parallelism + +The current design of `DistancesUnordered` does not allow an implementation to exploit parallelism within a query, since its method takes `&mut self`. Especially for a flat index, some implementations might want to parallelize within the scan for a query. Arguably, we will need a more complex extension of this architecture to support this. + +## Future Work +- Support for other flat-search algorithms (filtered, range, and diverse) as additional methods on `FlatIndex`. 
+- Index build -- this is just one part of the picture; more work needs to be done around how this fits in with any traits / interface we need for index build. +