From 6b6c63c2c97773ae7634c3306ed724a41a3eb7b2 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Mon, 27 Apr 2026 22:13:05 -0700 Subject: [PATCH 01/24] first draft --- diskann/src/flat/index.rs | 106 +++++++++++ diskann/src/flat/iterator.rs | 91 ++++++++++ diskann/src/flat/mod.rs | 50 ++++++ diskann/src/flat/post_process.rs | 73 ++++++++ diskann/src/flat/search.rs | 35 ++++ diskann/src/flat/stats.rs | 16 ++ diskann/src/flat/strategy.rs | 77 ++++++++ diskann/src/flat/test/mod.rs | 293 +++++++++++++++++++++++++++++++ diskann/src/lib.rs | 1 + rfcs/00000-flat-search.md | 217 +++++++++++++++++++++++ 10 files changed, 959 insertions(+) create mode 100644 diskann/src/flat/index.rs create mode 100644 diskann/src/flat/iterator.rs create mode 100644 diskann/src/flat/mod.rs create mode 100644 diskann/src/flat/post_process.rs create mode 100644 diskann/src/flat/search.rs create mode 100644 diskann/src/flat/stats.rs create mode 100644 diskann/src/flat/strategy.rs create mode 100644 diskann/src/flat/test/mod.rs create mode 100644 rfcs/00000-flat-search.md diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs new file mode 100644 index 000000000..bebc1ea6a --- /dev/null +++ b/diskann/src/flat/index.rs @@ -0,0 +1,106 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! [`FlatIndex`] — the index wrapper for flat search. + +use std::marker::PhantomData; +use std::num::NonZeroUsize; + +use diskann_utils::future::SendFuture; +use diskann_vector::PreprocessedDistanceFunction; + +use crate::{ + ANNResult, + error::IntoANNResult, + flat::{ + FlatIterator, FlatPostProcess, FlatSearchStats, FlatSearchStrategy, + }, + graph::SearchOutputBuffer, + neighbor::{Neighbor, NeighborPriorityQueue}, + provider::DataProvider, +}; + +/// A `'static` thin wrapper around a [`DataProvider`] used for flat search. +/// +/// The provider is owned by the index. 
The index is constructed once at process startup and +/// shared across requests; per-query state lives in the [`crate::flat::FlatIterator`] that +/// the [`crate::flat::FlatSearchStrategy`] produces. +#[derive(Debug)] +pub struct FlatIndex { + /// The backing provider. + pub provider: P, + _marker: PhantomData P>, +} + +impl FlatIndex

{ + /// Construct a new [`FlatIndex`] around `provider`. + pub fn new(provider: P) -> Self { + Self { + provider, + _marker: PhantomData, + } + } + + /// Borrow the underlying provider. + pub fn provider(&self) -> &P { + &self.provider + } + + /// Brute-force k-nearest-neighbor flat search. + /// + /// Streams every element produced by the strategy's iterator through the query + /// computer, keeps the best `k` candidates in a [`NeighborPriorityQueue`], and hands + /// the survivors to the post-processor. + /// + /// # Arguments + /// - `k`: number of nearest neighbors to return. + /// - `strategy`: produces the per-query iterator and the query computer. + /// - `processor`: post-processes the survivor candidates into the output type. + /// - `context`: per-request context threaded through to the provider. + /// - `query`: the query. + /// - `output`: caller-owned output buffer. + pub fn knn_search( + &self, + k: NonZeroUsize, + strategy: &S, + processor: &PP, + context: &P::Context, + query: &T, + output: &mut OB, + ) -> impl SendFuture> + where + S: FlatSearchStrategy, + T: ?Sized + Sync, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, + PP: for<'a> FlatPostProcess, T, O> + Send + Sync, + { + async move { + let mut iter = strategy + .create_iter(&self.provider, context) + .into_ann_result()?; + let computer = strategy.build_query_computer(query).into_ann_result()?; + + let k = k.get(); + let mut queue = NeighborPriorityQueue::new(k); + let mut cmps: u32 = 0; + + iter.on_elements_unordered(|id, element| { + let dist = computer.evaluate_similarity(element); + cmps += 1; + queue.insert(Neighbor::new(id, dist)); + }) + .await + .into_ann_result()?; + + let result_count = processor + .post_process(&mut iter, query, queue.iter().take(k), output) + .await + .into_ann_result()? 
as u32; + + Ok(FlatSearchStats { cmps, result_count }) + } + } +} diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs new file mode 100644 index 000000000..02ef66f99 --- /dev/null +++ b/diskann/src/flat/iterator.rs @@ -0,0 +1,91 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! [`FlatIterator`] — the sequential access primitive for flat search. + +use diskann_utils::{Reborrow, future::SendFuture}; + +use crate::{error::StandardError, provider::HasId}; + +/// A lending, asynchronous iterator over the elements of a flat index. +/// +/// `FlatIterator` is the streaming counterpart to [`crate::provider::Accessor`]. Where an +/// accessor exposes random retrieval by id, a flat iterator exposes a *sequential* walk — +/// each call to [`Self::next`] advances an internal cursor and yields the next element. +/// +/// Like [`crate::provider::Accessor::get_element`], advancing the cursor is **async**: it +/// may need to await an I/O fetch (e.g., reading the next disk page, awaiting a network +/// response, etc.). Iterators backed by purely in-memory data should return a ready +/// future. +/// +/// The iterator is responsible for: +/// - Choosing the iteration order (buffer-sequential, hash-walked, partitioned, …). +/// - Skipping items that should not be visible to the algorithm (deleted, obsolete, …). +/// - Holding any borrows / locks needed to keep the underlying storage alive. +/// +/// Algorithms see only `(Id, ElementRef)` pairs and treat the stream as opaque. +/// +/// # `Element` vs `ElementRef` +/// +/// Same pattern as [`crate::provider::Accessor`]: +/// +/// - `Element<'a>` is the type returned by `next`. Its lifetime is bound to the iterator +/// borrow at the call site, so only one element is live at a time. +/// - `ElementRef<'a>` is an unconstrained-lifetime reborrow used in distance-function +/// bounds. 
Required to keep [HRTB](https://doc.rust-lang.org/nomicon/hrtb.html) bounds +/// on query computers from forcing `Self: 'static`. +/// +/// # Hot path +/// +/// Algorithms drive the scan via [`Self::on_elements_unordered`]. The provided +/// implementation simply loops over [`Self::next`]; iterators that can amortize +/// per-element cost (prefetching the next chunk, batching distance computation, +/// performing SIMD-friendly bulk reads) should override it. +pub trait FlatIterator: HasId + Send + Sync { + /// A reference to a yielded element with an unconstrained lifetime, suitable for + /// distance-function HRTB bounds. + type ElementRef<'a>; + + /// The concrete element returned by [`Self::next`]. Reborrows to [`Self::ElementRef`]. + type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync + where + Self: 'a; + + /// The error type yielded by [`Self::next`] and [`Self::on_elements_unordered`]. + type Error: StandardError; + + /// Advance the iterator and asynchronously yield the next `(id, element)` pair. + /// + /// Returns `Ok(None)` when the scan is exhausted. The yielded element borrows from + /// the iterator and is invalidated by the next call to `next`. + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>>; + + /// Drive the entire scan, invoking `f` for each yielded element. + /// + /// The default implementation loops over [`Self::next`]. Implementations that benefit + /// from bulk dispatch (prefetching, batched SIMD distance computation, etc.) should + /// override this method. + /// + /// The order of invocation is unspecified and may differ between calls. The closure + /// `f` is **synchronous**; if you need to await inside the per-element handler, drive + /// the iterator manually with [`Self::next`]. 
+ fn on_elements_unordered( + &mut self, + mut f: F, + ) -> impl SendFuture> + where + F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>), + { + async move { + while let Some((id, element)) = self.next().await? { + f(id, element.reborrow()); + } + Ok(()) + } + } +} + diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs new file mode 100644 index 000000000..0cd9dcc33 --- /dev/null +++ b/diskann/src/flat/mod.rs @@ -0,0 +1,50 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Sequential ("flat") search infrastructure. +//! +//! This module is the streaming counterpart to the random-access [`crate::provider::Accessor`] +//! family. It is designed for backends whose natural access pattern is a one-pass scan over +//! their data — for example append-only buffered stores, on-disk shards streamed via I/O, +//! or any provider where random access is significantly more expensive than sequential. +//! +//! # Architecture +//! +//! The module mirrors the layering used by graph search: +//! +//! | Graph (random access) | Flat (sequential) | +//! | :------------------------------------ | :-------------------------------- | +//! | [`crate::provider::DataProvider`] | [`crate::provider::DataProvider`] | +//! | [`crate::graph::DiskANNIndex`] | [`FlatIndex`] | +//! | [`crate::provider::Accessor`] | [`FlatIterator`] | +//! | [`crate::graph::glue::SearchStrategy`] | [`FlatSearchStrategy`] | +//! | [`crate::graph::glue::SearchPostProcess`] | [`FlatPostProcess`] | +//! | [`crate::graph::Search`] | [`FlatIndex::knn_search`] | +//! +//! # Hot loop +//! +//! Algorithms drive the scan via [`FlatIterator::next`] (lending iterator) or override +//! [`FlatIterator::on_elements_unordered`] when batching/prefetching wins. The default +//! implementation of `on_elements_unordered` simply loops over `next`. +//! +//! See [`FlatIndex::knn_search`] for the canonical brute-force k-NN algorithm built on these +//! primitives. 
+ +pub mod index; +pub mod iterator; +pub mod post_process; +pub mod search; +pub mod stats; +pub mod strategy; + +#[cfg(any(test, feature = "testing"))] +pub mod test; + +pub use index::FlatIndex; +pub use iterator::FlatIterator; +pub use post_process::{CopyFlatIds, FlatPostProcess}; +pub use search::{KnnFlatError, validate_k}; +pub use stats::FlatSearchStats; +pub use strategy::FlatSearchStrategy; diff --git a/diskann/src/flat/post_process.rs b/diskann/src/flat/post_process.rs new file mode 100644 index 000000000..38267fef6 --- /dev/null +++ b/diskann/src/flat/post_process.rs @@ -0,0 +1,73 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! [`FlatPostProcess`] — terminal stage of the flat search pipeline. + +use diskann_utils::future::SendFuture; + +use crate::{ + error::StandardError, flat::FlatIterator, graph::SearchOutputBuffer, neighbor::Neighbor, provider::HasId, +}; + +/// Hydrate / filter / transform the survivor candidates produced by a flat search and +/// write them into an output buffer. +/// +/// This is the flat counterpart to [`crate::graph::glue::SearchPostProcess`]. Processors +/// receive `&mut S` so they can consult any iterator-owned lookup state (e.g., an +/// `Id -> rich-record` table built up during the scan) when assembling outputs. As with +/// the graph counterpart, [`Self::post_process`] is **async** so that processors can +/// hydrate via I/O without blocking. +/// +/// The `O` type parameter lets callers pick the output element type (raw `(Id, f32)` +/// pairs, fully hydrated hits, etc.). +pub trait FlatPostProcess::Id> +where + S: FlatIterator, + T: ?Sized, +{ + /// Errors yielded by [`Self::post_process`]. + type Error: StandardError; + + /// Consume `candidates` (in distance order) and write at most `k` results into + /// `output`. Returns the number of results written. 
+ fn post_process( + &self, + iter: &mut S, + query: &T, + candidates: I, + output: &mut B, + ) -> impl SendFuture> + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized; +} + +/// A trivial [`FlatPostProcess`] that copies each `(Id, distance)` pair straight into the +/// output buffer. +#[derive(Debug, Default, Clone, Copy)] +pub struct CopyFlatIds; + +impl FlatPostProcess for CopyFlatIds +where + S: FlatIterator, + T: ?Sized, +{ + type Error = crate::error::Infallible; + + fn post_process( + &self, + _iter: &mut S, + _query: &T, + candidates: I, + output: &mut B, + ) -> impl SendFuture> + where + I: Iterator::Id>> + Send, + B: SearchOutputBuffer<::Id> + Send + ?Sized, + { + let count = output.extend(candidates.map(|n| (n.id, n.distance))); + std::future::ready(Ok(count)) + } +} diff --git a/diskann/src/flat/search.rs b/diskann/src/flat/search.rs new file mode 100644 index 000000000..b51b561f1 --- /dev/null +++ b/diskann/src/flat/search.rs @@ -0,0 +1,35 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Error types for flat search parameter validation. + +use std::num::NonZeroUsize; + +use thiserror::Error; + +use crate::{ANNError, ANNErrorKind}; + +/// Errors raised when validating flat search parameters. +#[derive(Debug, Error)] +pub enum KnnFlatError { + /// `k` was zero. + #[error("k cannot be zero")] + KZero, +} + +impl From for ANNError { + #[track_caller] + fn from(err: KnnFlatError) -> Self { + Self::new(ANNErrorKind::IndexError, err) + } +} + +/// Validate and wrap a `k` value as [`NonZeroUsize`]. +/// +/// This is a convenience for callers that want to validate `k` before passing it to +/// [`FlatIndex::knn_search`](crate::flat::FlatIndex::knn_search). 
+pub fn validate_k(k: usize) -> Result { + NonZeroUsize::new(k).ok_or(KnnFlatError::KZero) +} diff --git a/diskann/src/flat/stats.rs b/diskann/src/flat/stats.rs new file mode 100644 index 000000000..faf14f888 --- /dev/null +++ b/diskann/src/flat/stats.rs @@ -0,0 +1,16 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Statistics returned by a flat search. + +/// Statistics collected during a single flat search invocation. +#[derive(Debug, Clone, Copy, Default)] +pub struct FlatSearchStats { + /// Number of distance computations performed (i.e., elements visited by the scanner). + pub cmps: u32, + + /// Number of results written into the caller-provided output buffer. + pub result_count: u32, +} diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs new file mode 100644 index 000000000..8b0f5e445 --- /dev/null +++ b/diskann/src/flat/strategy.rs @@ -0,0 +1,77 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! [`FlatSearchStrategy`] — glue between [`DataProvider`] and per-query [`FlatIterator`]s. + +use diskann_vector::PreprocessedDistanceFunction; + +use crate::{ + error::StandardError, + flat::FlatIterator, + provider::DataProvider, +}; + +/// Per-call configuration that knows how to construct a [`FlatIterator`] for a provider +/// and how to pre-process queries of type `T` into a distance computer. +/// +/// `FlatSearchStrategy` is the flat counterpart to [`crate::graph::glue::SearchStrategy`]. +/// A strategy instance is stateless config — typically constructed at the call site, used +/// for one search, and dropped. +/// +/// # Why two methods? +/// +/// - [`Self::create_iter`] is query-independent and may be called multiple times per +/// request (e.g., once per parallel query in a batched search). +/// - [`Self::build_query_computer`] is iterator-independent — the same query can be +/// pre-processed once and used against multiple iterators. 
+/// +/// Both methods may borrow from the strategy itself. +/// +/// # Type parameters +/// +/// - `Provider`: the [`DataProvider`] that backs the index. +/// - `T`: the query type. Often `[E]` for vector queries; can be any `?Sized` type. +pub trait FlatSearchStrategy: Send + Sync +where + P: DataProvider, + T: ?Sized, +{ + /// The iterator type produced by [`Self::create_iter`]. Borrows from `self` and the + /// provider. + type Iter<'a>: FlatIterator + where + Self: 'a, + P: 'a; + + /// The query computer produced by [`Self::build_query_computer`]. + /// + /// The HRTB on `ElementRef` ensures the same computer can score every element yielded + /// by every lifetime of `Iter`. Two lifetimes are needed: `'a` for the iterator + /// instance and `'b` for the reborrowed element. + type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< + as FlatIterator>::ElementRef<'b>, + f32, + > + Send + + Sync + + 'static; + + /// The error type for both factory methods. + type Error: StandardError; + + /// Construct a fresh iterator over `provider` for the given request `context`. + /// + /// This is where lock acquisition, snapshot pinning, and any other per-query setup + /// should happen. The returned iterator owns whatever borrows / guards it needs to + /// remain valid until it is dropped. + fn create_iter<'a>( + &'a self, + provider: &'a P, + context: &'a P::Context, + ) -> Result, Self::Error>; + + /// Pre-process a query into a [`Self::QueryComputer`] usable for distance computation + /// against any iterator produced by [`Self::create_iter`]. + fn build_query_computer(&self, query: &T) -> Result; +} diff --git a/diskann/src/flat/test/mod.rs b/diskann/src/flat/test/mod.rs new file mode 100644 index 000000000..61307ab61 --- /dev/null +++ b/diskann/src/flat/test/mod.rs @@ -0,0 +1,293 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Trivial in-memory provider, iterator, and strategy used for unit-testing the flat +//! 
search infrastructure. +//! +//! This is intentionally simple: vectors live in a `Vec>`, ids are `u32`, and +//! distance is squared Euclidean. It exists so the trait shapes in [`crate::flat`] can be +//! exercised end-to-end without dragging in any provider-side machinery. + +use diskann_utils::future::SendFuture; +use diskann_vector::PreprocessedDistanceFunction; +use thiserror::Error; + +use crate::{ + always_escalate, + ANNError, ANNErrorKind, + flat::{FlatIterator, FlatPostProcess, FlatSearchStrategy}, + graph::SearchOutputBuffer, + neighbor::Neighbor, + provider::{DataProvider, DefaultContext, HasId, NoopGuard}, +}; + +/// Trivial flat provider holding a list of fixed-dimension `f32` vectors. +#[derive(Debug)] +pub struct InMemoryFlatProvider { + pub dim: usize, + pub vectors: Vec>, +} + +impl InMemoryFlatProvider { + pub fn new(dim: usize, vectors: Vec>) -> Self { + Self { dim, vectors } + } +} + +#[derive(Debug, Error)] +#[error("invalid vector id {0}")] +pub struct InMemoryProviderError(u32); + +impl From for ANNError { + #[track_caller] + fn from(err: InMemoryProviderError) -> Self { + ANNError::new(ANNErrorKind::IndexError, err) + } +} + +always_escalate!(InMemoryProviderError); + +impl DataProvider for InMemoryFlatProvider { + type Context = DefaultContext; + type InternalId = u32; + type ExternalId = u32; + type Error = InMemoryProviderError; + type Guard = NoopGuard; + + fn to_internal_id( + &self, + _context: &Self::Context, + gid: &u32, + ) -> Result { + if (*gid as usize) < self.vectors.len() { + Ok(*gid) + } else { + Err(InMemoryProviderError(*gid)) + } + } + + fn to_external_id( + &self, + _context: &Self::Context, + id: u32, + ) -> Result { + if (id as usize) < self.vectors.len() { + Ok(id) + } else { + Err(InMemoryProviderError(id)) + } + } +} + +/// Sequential iterator over [`InMemoryFlatProvider`]. 
+pub struct InMemoryIterator<'a> { + vectors: &'a [Vec], + cursor: u32, +} + +#[derive(Debug, Error)] +#[error("in-memory iterator does not error")] +pub struct InMemoryIteratorError; + +impl From for ANNError { + #[track_caller] + fn from(err: InMemoryIteratorError) -> Self { + ANNError::new(ANNErrorKind::IndexError, err) + } +} + +always_escalate!(InMemoryIteratorError); + +impl<'a> HasId for InMemoryIterator<'a> { + type Id = u32; +} + +impl<'a> FlatIterator for InMemoryIterator<'a> { + type ElementRef<'b> = &'b [f32]; + type Element<'b> + = &'b [f32] + where + Self: 'b; + type Error = InMemoryIteratorError; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>> { + let idx = self.cursor as usize; + let result = self.vectors.get(idx).map(|v| { + self.cursor += 1; + (idx as u32, v.as_slice()) + }); + std::future::ready(Ok(result)) + } +} + +/// Squared Euclidean computer: holds a copy of the query and scores against `&[f32]`. +#[derive(Debug, Clone)] +pub struct L2QueryComputer { + query: Vec, +} + +impl<'a> PreprocessedDistanceFunction<&'a [f32], f32> for L2QueryComputer { + fn evaluate_similarity(&self, changing: &'a [f32]) -> f32 { + debug_assert_eq!(self.query.len(), changing.len()); + self.query + .iter() + .zip(changing.iter()) + .map(|(a, b)| { + let d = a - b; + d * d + }) + .sum() + } +} + +/// Strategy: produces an [`InMemoryIterator`] and an [`L2QueryComputer`]. 
+#[derive(Debug, Default, Clone, Copy)] +pub struct InMemoryStrategy; + +#[derive(Debug, Error)] +pub enum InMemoryStrategyError { + #[error("query length {query} does not match provider dimension {dim}")] + DimMismatch { query: usize, dim: usize }, +} + +impl From for ANNError { + #[track_caller] + fn from(err: InMemoryStrategyError) -> Self { + ANNError::new(ANNErrorKind::IndexError, err) + } +} + +impl FlatSearchStrategy for InMemoryStrategy { + type Iter<'a> = InMemoryIterator<'a>; + type QueryComputer = L2QueryComputer; + type Error = InMemoryStrategyError; + + fn create_iter<'a>( + &'a self, + provider: &'a InMemoryFlatProvider, + _context: &'a DefaultContext, + ) -> Result, Self::Error> { + Ok(InMemoryIterator { + vectors: &provider.vectors, + cursor: 0, + }) + } + + fn build_query_computer(&self, query: &[f32]) -> Result { + Ok(L2QueryComputer { + query: query.to_vec(), + }) + } +} + +/// Post-processor that copies the surviving `(id, distance)` pairs straight to output. +/// +/// Identical in behavior to [`crate::flat::CopyFlatIds`] but typed concretely against the +/// in-memory iterator, useful in tests where we want to assert against the exact output +/// shape. +#[derive(Debug, Default, Clone, Copy)] +pub struct CopyInMemoryHits; + +impl<'a> FlatPostProcess, [f32]> for CopyInMemoryHits { + type Error = crate::error::Infallible; + + fn post_process( + &self, + _iter: &mut InMemoryIterator<'a>, + _query: &[f32], + candidates: I, + output: &mut B, + ) -> impl SendFuture> + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + let count = output.extend(candidates.map(|n| (n.id, n.distance))); + std::future::ready(Ok(count)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + flat::{CopyFlatIds, FlatIndex, validate_k}, + neighbor::Neighbor, + }; + + fn build_provider() -> InMemoryFlatProvider { + // 5 two-dimensional points; the closest to (0.0, 0.0) is index 0. 
+ InMemoryFlatProvider::new( + 2, + vec![ + vec![0.1, 0.0], // d^2 = 0.01 + vec![1.0, 0.0], // d^2 = 1.00 + vec![0.0, 0.5], // d^2 = 0.25 + vec![5.0, 5.0], // d^2 = 50.0 + vec![-0.2, 0.1], // d^2 = 0.05 + ], + ) + } + + #[tokio::test] + async fn knn_flat_returns_top_k_in_distance_order() { + let provider = build_provider(); + let index = FlatIndex::new(provider); + let strategy = InMemoryStrategy; + let processor = CopyInMemoryHits; + let query = vec![0.0_f32, 0.0]; + + let mut output: Vec> = Vec::new(); + let stats = index + .knn_search( + validate_k(3).unwrap(), + &strategy, + &processor, + &DefaultContext, + query.as_slice(), + &mut output, + ) + .await + .expect("search succeeds"); + + assert_eq!(stats.cmps, 5); + assert_eq!(stats.result_count, 3); + + let ids: Vec = output.iter().map(|n| n.id).collect(); + assert_eq!(ids, vec![0, 4, 2]); + } + + #[tokio::test] + async fn knn_flat_with_k_larger_than_n_returns_all() { + let provider = build_provider(); + let index = FlatIndex::new(provider); + let strategy = InMemoryStrategy; + let processor = CopyFlatIds; + let query = vec![0.0_f32, 0.0]; + + let mut output: Vec> = Vec::new(); + let stats = index + .knn_search( + validate_k(100).unwrap(), + &strategy, + &processor, + &DefaultContext, + query.as_slice(), + &mut output, + ) + .await + .expect("search succeeds"); + + assert_eq!(stats.cmps, 5); + assert_eq!(stats.result_count, 5); + } + + #[test] + fn knn_flat_rejects_zero_k() { + assert!(validate_k(0).is_err()); + } +} diff --git a/diskann/src/lib.rs b/diskann/src/lib.rs index 71cb3ed41..9c1f6ac76 100644 --- a/diskann/src/lib.rs +++ b/diskann/src/lib.rs @@ -13,6 +13,7 @@ pub mod utils; pub(crate) mod internal; // Index Implementations +pub mod flat; pub mod graph; // Top level exports. 
diff --git a/rfcs/00000-flat-search.md b/rfcs/00000-flat-search.md new file mode 100644 index 000000000..78606ddb1 --- /dev/null +++ b/rfcs/00000-flat-search.md @@ -0,0 +1,217 @@ +# Flat Search + +| | | +|------------------|--------------------------------| +| **Authors** | Aditya Krishnan, Alex Razumov, Dongliang Wu | +| **Created** | 2026-04-24 | +| **Updated** | 2026-04-27 | + +## Motivation + +### Background + +DiskANN today exposes a single abstraction family centered on the +[`crate::provider::Accessor`] trait. Accessors are random access by design since the graph greedy search algorithm needs to decide which ids to fetch and the accessor materializes the corresponding elements (vectors, quantized vectors and neighbor lists) on demand. This is the right contract for graph search, where neighborhood expansion is inherently random-access against the [`crate::provider::DataProvider`]. + +A growing class of consumers diverge from our current pattern of use by accessing their index **sequentially**. Some consumers build their index in an "append-only" fashion and require that they walk the index in a sequential, fixed order, relying on iteration position to enforce versioning / deduplication invariants. + +### Problem Statement + +The problem statement here is simple: provide first-class support for sequential, one-pass scans over a data backend without +stuffing the algorithm or the backend through the `Accessor` trait surface. + +### Goals + +1. Define a streaming access primitive — `FlatIterator` — that mirrors the role + `Accessor` plays for graph search but exposes a lending-iterator interface instead of + a random-access one. +2. Provide flat-search algorithm implementations (with `knn_search` as default and filtered and diverse variants to opt into) built on the new + primitives, so consumers can use this against their own providers / backends. +3. 
Expose support for features and implementations native to the repo like quantized distance computers out-of-the-box. + +## Proposal + +Let's start with the main analog to the `Accessor` trait for the `FlatIndex` - `FlatIterator`. + + +### `FlatIterator` + +```rust +pub trait FlatIterator: HasId + Send + Sync { // Has Id support + // Element yielded by iterator + type ElementRef<'a>; + + // Mostly machinery to play nice with HRTB + type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync + where + Self: 'a; + + type Error: StandardError; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>>; + + // Default implementation for driving a closure on the items in the index. + fn on_elements_unordered( + &mut self, + mut f: F, + ) -> impl SendFuture> + where F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>), + { + async move { + while let Some((id, element)) = self.next().await? { + f(id, element.reborrow()); + } + + Ok(()) + } + } +} +``` + +The trait combines two access patterns: + +- A required lending-iterator `next()`. +- A defaulted bulk method `on_elements_unordered` that consumes the entire scan via a + callback. The default impl loops over `next`; iterators that benefit from prefetching, + SIMD batching, or amortized per-element cost could override it. + +Both methods are **async** (returning `impl SendFuture<...>`), matching +[`crate::provider::Accessor::get_element`]. Iterators backed by I/O — disk pages, +remote shards — return a real future; in-memory iterators wrap their result in +`std::future::ready`. + +The `Element` / `ElementRef` split is identical to `Accessor` and exists for the same +reason: to keep HRTB bounds on query computers from inducing `'static` requirements on +the iterator type. + + +### The glue: `FlatSearchStrategy` + +While the `FlatIterator` is the primary object that provides access to the elements in the index for the algorithm, it is scoped to each query. 
We introduce a constructor - `FlatSearchStrategy` - similar to `SearchStrategy` for `Accessor` to instantiate this object. A strategy is per-call configuration: stateless, cheap to construct, scoped to one +search. It produces both a per-query iterator and a query computer. + +```rust +pub trait FlatSearchStrategy: Send + Sync +where + P: DataProvider, + T: ?Sized, +{ + /// The iterator type produced by [`Self::create_iter`]. Borrows from `self` and the + /// provider. + type Iter<'a>: FlatIterator + where + Self: 'a, + + /// The query computer produced by [`Self::build_query_computer`]. + type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< + as FlatIterator>::ElementRef<'b>, + f32, + > + Send + + Sync + + 'static; + + /// The error type for both factory methods. + type Error: StandardError; + + /// Construct a fresh iterator over `provider` for the given request `context`. + fn create_iter<'a>( + &'a self, + provider: &'a P, + context: &'a P::Context, + ) -> Result, Self::Error>; + + /// Pre-process a query into a [`Self::QueryComputer`] usable for distance computation + /// against any iterator produced by [`Self::create_iter`]. + fn build_query_computer(&self, query: &T) -> Result; +} +``` + +The `ElementRef<'b>` that the distance function `QueryComputer` acts on is tied to the (reborrowed) element yielded by `FlatIterator::next()`. + +### `FlatIndex` + +`FlatIndex` is a thin `'static` wrapper around a `DataProvider`. The same `DataProvider` +trait used by graph search is reused here — flat and graph subsystems share a single +provider surface and the same `Context` / id-mapping / error machinery. + +```rust +pub struct FlatIndex { + provider: P, + /* private */ +} + +impl FlatIndex

{ + pub fn new(provider: P) -> Self; + pub fn provider(&self) -> &P; + + pub fn knn_search( + &self, + k: NonZeroUsize, + strategy: &S, + processor: &PP, + context: &P::Context, + query: &T, + output: &mut OB, + ) -> impl SendFuture> + where + S: FlatSearchStrategy, + T: ?Sized + Sync, + O: Send, + OB: SearchOutputBuffer + Send + ?Sized, +} +``` + +The `knn_search` method is the canonical brute-force search algorithm: + +1. Construct the iterator via `strategy.create_iter` to obtain a scoped iterator over the elements. +2. Build the query computer via `strategy.build_query_computer`. +3. Drive the scan via `iter.on_elements_unordered`, scoring each element and + inserting `Neighbor`s into a `NeighborPriorityQueue` of capacity `k`. +4. Hand the survivors (in distance order) to `processor.post_process`. +5. Return search stats. + +Other algorithms (filtered, range, diverse) can be added later as additional methods on +`FlatIndex`. + +## Trade-offs + +### Reusing `DataProvider` + +This design leans into using the `DataProvider` trait which requires implementations to implement `InternalId` and `ExternalId` conversions (via the context). Arguably, this requirement is too restrictive for some consumers of a flat-index. Reasons for sticking with `DataProvider`: + +- Every concrete provider already implements `DataProvider`, so a separate trait adds + an abstraction that existing consumers will have to implement if they want to opt-in to the flat-index path. +- Sharing `DataProvider` means the `Context`, id-mapping (`to_internal_id` / + `to_external_id`), and error machinery are identical across graph and flat search, + reducing the learning surface for new contributors. + +### Async vs sync API for `FlatIterator` + +`next()` and `on_elements_unordered` return a future, making the trait +async. This is the right default for disk-backed and network-backed iterators +where advancing the cursor involves real I/O. 
It also matches the `Accessor` surface, +keeping the two subsystems shaped the same way. + +The cost is paid by in-memory consumers: every call to `next()` goes through the future +machinery even when the result is immediately available via `std::future::ready`. In a +tight brute-force loop this overhead — poll scaffolding, pinning etc — could be measurable. + +We chose async because the wider audience of consumers (disk, network, mixed) benefits +more than in-memory consumers lose. + +### Expand `Element` to support batched distance computation? + +The current design yields one element per `next()` call, and the query computer scores +elements one at a time via `PreprocessedDistanceFunction::evaluate_similarity`. This could leave some optimization and performance on the table; especially with the upcoming effort around batched distance kernels. + +An alternative is to make `next()` yield a *batch* instead of a single vector representation like `Element<'_>`. Some work will need to be done to define the right interaction between the batch type, the element type in the batch, the interaction with `QueryComputer`'s types and way IDs and distances are collected in the queue. + +We opted for the scalar-per-element design for now because it is simpler to implement and +reason about. The hope is that batched distance computation can be layered on later as an opt-in sub-trait without breaking +existing iterators. + +## Future Work +- Support for other flat-search algorithms like - filtered, range and diverse flat algorithms as additional methods on `FlatIndex`. 
+ From 9f718a13387745befc908688ded5f064066ded6f Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Mon, 27 Apr 2026 22:31:41 -0700 Subject: [PATCH 02/24] remove unnecessary parts --- diskann/src/flat/index.rs | 17 +- diskann/src/flat/iterator.rs | 37 +--- diskann/src/flat/mod.rs | 7 - diskann/src/flat/post_process.rs | 8 +- diskann/src/flat/search.rs | 35 ---- diskann/src/flat/stats.rs | 16 -- diskann/src/flat/test/mod.rs | 293 ------------------------------- 7 files changed, 16 insertions(+), 397 deletions(-) delete mode 100644 diskann/src/flat/search.rs delete mode 100644 diskann/src/flat/stats.rs delete mode 100644 diskann/src/flat/test/mod.rs diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs index bebc1ea6a..04227b1da 100644 --- a/diskann/src/flat/index.rs +++ b/diskann/src/flat/index.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -//! [`FlatIndex`] — the index wrapper for flat search. +//! [`FlatIndex`] — the index wrapper for an on which we do flat search. use std::marker::PhantomData; use std::num::NonZeroUsize; @@ -15,9 +15,9 @@ use crate::{ ANNResult, error::IntoANNResult, flat::{ - FlatIterator, FlatPostProcess, FlatSearchStats, FlatSearchStrategy, + FlatIterator, FlatPostProcess, FlatSearchStrategy, }, - graph::SearchOutputBuffer, + graph::{SearchOutputBuffer, index::SearchStats}, neighbor::{Neighbor, NeighborPriorityQueue}, provider::DataProvider, }; @@ -56,7 +56,7 @@ impl FlatIndex

{ /// /// # Arguments /// - `k`: number of nearest neighbors to return. - /// - `strategy`: produces the per-query iterator and the query computer. + /// - `strategy`: produces the per-query iterator and the query computer. See [`FlatSearchStrategy`] /// - `processor`: post-processes the survivor candidates into the output type. /// - `context`: per-request context threaded through to the provider. /// - `query`: the query. @@ -69,7 +69,7 @@ impl FlatIndex

{ context: &P::Context, query: &T, output: &mut OB, - ) -> impl SendFuture> + ) -> impl SendFuture> where S: FlatSearchStrategy, T: ?Sized + Sync, @@ -100,7 +100,12 @@ impl FlatIndex

{ .await .into_ann_result()? as u32; - Ok(FlatSearchStats { cmps, result_count }) + Ok(SearchStats { + cmps, + hops: 0, + result_count, + range_search_second_round: false, + }) } } } diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index 02ef66f99..895c909ad 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -//! [`FlatIterator`] — the sequential access primitive for flat search. +//! [`FlatIterator`] — the sequential access primitive for accessing a flat index. use diskann_utils::{Reborrow, future::SendFuture}; @@ -15,34 +15,7 @@ use crate::{error::StandardError, provider::HasId}; /// accessor exposes random retrieval by id, a flat iterator exposes a *sequential* walk — /// each call to [`Self::next`] advances an internal cursor and yields the next element. /// -/// Like [`crate::provider::Accessor::get_element`], advancing the cursor is **async**: it -/// may need to await an I/O fetch (e.g., reading the next disk page, awaiting a network -/// response, etc.). Iterators backed by purely in-memory data should return a ready -/// future. -/// -/// The iterator is responsible for: -/// - Choosing the iteration order (buffer-sequential, hash-walked, partitioned, …). -/// - Skipping items that should not be visible to the algorithm (deleted, obsolete, …). -/// - Holding any borrows / locks needed to keep the underlying storage alive. -/// /// Algorithms see only `(Id, ElementRef)` pairs and treat the stream as opaque. -/// -/// # `Element` vs `ElementRef` -/// -/// Same pattern as [`crate::provider::Accessor`]: -/// -/// - `Element<'a>` is the type returned by `next`. Its lifetime is bound to the iterator -/// borrow at the call site, so only one element is live at a time. -/// - `ElementRef<'a>` is an unconstrained-lifetime reborrow used in distance-function -/// bounds. 
Required to keep [HRTB](https://doc.rust-lang.org/nomicon/hrtb.html) bounds -/// on query computers from forcing `Self: 'static`. -/// -/// # Hot path -/// -/// Algorithms drive the scan via [`Self::on_elements_unordered`]. The provided -/// implementation simply loops over [`Self::next`]; iterators that can amortize -/// per-element cost (prefetching the next chunk, batching distance computation, -/// performing SIMD-friendly bulk reads) should override it. pub trait FlatIterator: HasId + Send + Sync { /// A reference to a yielded element with an unconstrained lifetime, suitable for /// distance-function HRTB bounds. @@ -66,13 +39,7 @@ pub trait FlatIterator: HasId + Send + Sync { /// Drive the entire scan, invoking `f` for each yielded element. /// - /// The default implementation loops over [`Self::next`]. Implementations that benefit - /// from bulk dispatch (prefetching, batched SIMD distance computation, etc.) should - /// override this method. - /// - /// The order of invocation is unspecified and may differ between calls. The closure - /// `f` is **synchronous**; if you need to await inside the per-element handler, drive - /// the iterator manually with [`Self::next`]. + /// The default implementation loops over [`Self::next`]. 
fn on_elements_unordered( &mut self, mut f: F, diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs index 0cd9dcc33..34fe62ac8 100644 --- a/diskann/src/flat/mod.rs +++ b/diskann/src/flat/mod.rs @@ -35,16 +35,9 @@ pub mod index; pub mod iterator; pub mod post_process; -pub mod search; -pub mod stats; pub mod strategy; -#[cfg(any(test, feature = "testing"))] -pub mod test; - pub use index::FlatIndex; pub use iterator::FlatIterator; pub use post_process::{CopyFlatIds, FlatPostProcess}; -pub use search::{KnnFlatError, validate_k}; -pub use stats::FlatSearchStats; pub use strategy::FlatSearchStrategy; diff --git a/diskann/src/flat/post_process.rs b/diskann/src/flat/post_process.rs index 38267fef6..2f95c2932 100644 --- a/diskann/src/flat/post_process.rs +++ b/diskann/src/flat/post_process.rs @@ -11,17 +11,15 @@ use crate::{ error::StandardError, flat::FlatIterator, graph::SearchOutputBuffer, neighbor::Neighbor, provider::HasId, }; -/// Hydrate / filter / transform the survivor candidates produced by a flat search and +/// Post-process the survivor candidates produced by a flat search and /// write them into an output buffer. /// /// This is the flat counterpart to [`crate::graph::glue::SearchPostProcess`]. Processors /// receive `&mut S` so they can consult any iterator-owned lookup state (e.g., an -/// `Id -> rich-record` table built up during the scan) when assembling outputs. As with -/// the graph counterpart, [`Self::post_process`] is **async** so that processors can -/// hydrate via I/O without blocking. +/// `Id -> rich-record` table built up during the scan) when assembling outputs. /// /// The `O` type parameter lets callers pick the output element type (raw `(Id, f32)` -/// pairs, fully hydrated hits, etc.). +/// pairs, fully hydrated hits etc.). 
pub trait FlatPostProcess::Id> where S: FlatIterator, diff --git a/diskann/src/flat/search.rs b/diskann/src/flat/search.rs deleted file mode 100644 index b51b561f1..000000000 --- a/diskann/src/flat/search.rs +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -//! Error types for flat search parameter validation. - -use std::num::NonZeroUsize; - -use thiserror::Error; - -use crate::{ANNError, ANNErrorKind}; - -/// Errors raised when validating flat search parameters. -#[derive(Debug, Error)] -pub enum KnnFlatError { - /// `k` was zero. - #[error("k cannot be zero")] - KZero, -} - -impl From for ANNError { - #[track_caller] - fn from(err: KnnFlatError) -> Self { - Self::new(ANNErrorKind::IndexError, err) - } -} - -/// Validate and wrap a `k` value as [`NonZeroUsize`]. -/// -/// This is a convenience for callers that want to validate `k` before passing it to -/// [`FlatIndex::knn_search`](crate::flat::FlatIndex::knn_search). -pub fn validate_k(k: usize) -> Result { - NonZeroUsize::new(k).ok_or(KnnFlatError::KZero) -} diff --git a/diskann/src/flat/stats.rs b/diskann/src/flat/stats.rs deleted file mode 100644 index faf14f888..000000000 --- a/diskann/src/flat/stats.rs +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -//! Statistics returned by a flat search. - -/// Statistics collected during a single flat search invocation. -#[derive(Debug, Clone, Copy, Default)] -pub struct FlatSearchStats { - /// Number of distance computations performed (i.e., elements visited by the scanner). - pub cmps: u32, - - /// Number of results written into the caller-provided output buffer. - pub result_count: u32, -} diff --git a/diskann/src/flat/test/mod.rs b/diskann/src/flat/test/mod.rs deleted file mode 100644 index 61307ab61..000000000 --- a/diskann/src/flat/test/mod.rs +++ /dev/null @@ -1,293 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. 
- * Licensed under the MIT license. - */ - -//! Trivial in-memory provider, iterator, and strategy used for unit-testing the flat -//! search infrastructure. -//! -//! This is intentionally simple: vectors live in a `Vec>`, ids are `u32`, and -//! distance is squared Euclidean. It exists so the trait shapes in [`crate::flat`] can be -//! exercised end-to-end without dragging in any provider-side machinery. - -use diskann_utils::future::SendFuture; -use diskann_vector::PreprocessedDistanceFunction; -use thiserror::Error; - -use crate::{ - always_escalate, - ANNError, ANNErrorKind, - flat::{FlatIterator, FlatPostProcess, FlatSearchStrategy}, - graph::SearchOutputBuffer, - neighbor::Neighbor, - provider::{DataProvider, DefaultContext, HasId, NoopGuard}, -}; - -/// Trivial flat provider holding a list of fixed-dimension `f32` vectors. -#[derive(Debug)] -pub struct InMemoryFlatProvider { - pub dim: usize, - pub vectors: Vec>, -} - -impl InMemoryFlatProvider { - pub fn new(dim: usize, vectors: Vec>) -> Self { - Self { dim, vectors } - } -} - -#[derive(Debug, Error)] -#[error("invalid vector id {0}")] -pub struct InMemoryProviderError(u32); - -impl From for ANNError { - #[track_caller] - fn from(err: InMemoryProviderError) -> Self { - ANNError::new(ANNErrorKind::IndexError, err) - } -} - -always_escalate!(InMemoryProviderError); - -impl DataProvider for InMemoryFlatProvider { - type Context = DefaultContext; - type InternalId = u32; - type ExternalId = u32; - type Error = InMemoryProviderError; - type Guard = NoopGuard; - - fn to_internal_id( - &self, - _context: &Self::Context, - gid: &u32, - ) -> Result { - if (*gid as usize) < self.vectors.len() { - Ok(*gid) - } else { - Err(InMemoryProviderError(*gid)) - } - } - - fn to_external_id( - &self, - _context: &Self::Context, - id: u32, - ) -> Result { - if (id as usize) < self.vectors.len() { - Ok(id) - } else { - Err(InMemoryProviderError(id)) - } - } -} - -/// Sequential iterator over [`InMemoryFlatProvider`]. 
-pub struct InMemoryIterator<'a> { - vectors: &'a [Vec], - cursor: u32, -} - -#[derive(Debug, Error)] -#[error("in-memory iterator does not error")] -pub struct InMemoryIteratorError; - -impl From for ANNError { - #[track_caller] - fn from(err: InMemoryIteratorError) -> Self { - ANNError::new(ANNErrorKind::IndexError, err) - } -} - -always_escalate!(InMemoryIteratorError); - -impl<'a> HasId for InMemoryIterator<'a> { - type Id = u32; -} - -impl<'a> FlatIterator for InMemoryIterator<'a> { - type ElementRef<'b> = &'b [f32]; - type Element<'b> - = &'b [f32] - where - Self: 'b; - type Error = InMemoryIteratorError; - - fn next( - &mut self, - ) -> impl SendFuture)>, Self::Error>> { - let idx = self.cursor as usize; - let result = self.vectors.get(idx).map(|v| { - self.cursor += 1; - (idx as u32, v.as_slice()) - }); - std::future::ready(Ok(result)) - } -} - -/// Squared Euclidean computer: holds a copy of the query and scores against `&[f32]`. -#[derive(Debug, Clone)] -pub struct L2QueryComputer { - query: Vec, -} - -impl<'a> PreprocessedDistanceFunction<&'a [f32], f32> for L2QueryComputer { - fn evaluate_similarity(&self, changing: &'a [f32]) -> f32 { - debug_assert_eq!(self.query.len(), changing.len()); - self.query - .iter() - .zip(changing.iter()) - .map(|(a, b)| { - let d = a - b; - d * d - }) - .sum() - } -} - -/// Strategy: produces an [`InMemoryIterator`] and an [`L2QueryComputer`]. 
-#[derive(Debug, Default, Clone, Copy)] -pub struct InMemoryStrategy; - -#[derive(Debug, Error)] -pub enum InMemoryStrategyError { - #[error("query length {query} does not match provider dimension {dim}")] - DimMismatch { query: usize, dim: usize }, -} - -impl From for ANNError { - #[track_caller] - fn from(err: InMemoryStrategyError) -> Self { - ANNError::new(ANNErrorKind::IndexError, err) - } -} - -impl FlatSearchStrategy for InMemoryStrategy { - type Iter<'a> = InMemoryIterator<'a>; - type QueryComputer = L2QueryComputer; - type Error = InMemoryStrategyError; - - fn create_iter<'a>( - &'a self, - provider: &'a InMemoryFlatProvider, - _context: &'a DefaultContext, - ) -> Result, Self::Error> { - Ok(InMemoryIterator { - vectors: &provider.vectors, - cursor: 0, - }) - } - - fn build_query_computer(&self, query: &[f32]) -> Result { - Ok(L2QueryComputer { - query: query.to_vec(), - }) - } -} - -/// Post-processor that copies the surviving `(id, distance)` pairs straight to output. -/// -/// Identical in behavior to [`crate::flat::CopyFlatIds`] but typed concretely against the -/// in-memory iterator, useful in tests where we want to assert against the exact output -/// shape. -#[derive(Debug, Default, Clone, Copy)] -pub struct CopyInMemoryHits; - -impl<'a> FlatPostProcess, [f32]> for CopyInMemoryHits { - type Error = crate::error::Infallible; - - fn post_process( - &self, - _iter: &mut InMemoryIterator<'a>, - _query: &[f32], - candidates: I, - output: &mut B, - ) -> impl SendFuture> - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized, - { - let count = output.extend(candidates.map(|n| (n.id, n.distance))); - std::future::ready(Ok(count)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - flat::{CopyFlatIds, FlatIndex, validate_k}, - neighbor::Neighbor, - }; - - fn build_provider() -> InMemoryFlatProvider { - // 5 two-dimensional points; the closest to (0.0, 0.0) is index 0. 
- InMemoryFlatProvider::new( - 2, - vec![ - vec![0.1, 0.0], // d^2 = 0.01 - vec![1.0, 0.0], // d^2 = 1.00 - vec![0.0, 0.5], // d^2 = 0.25 - vec![5.0, 5.0], // d^2 = 50.0 - vec![-0.2, 0.1], // d^2 = 0.05 - ], - ) - } - - #[tokio::test] - async fn knn_flat_returns_top_k_in_distance_order() { - let provider = build_provider(); - let index = FlatIndex::new(provider); - let strategy = InMemoryStrategy; - let processor = CopyInMemoryHits; - let query = vec![0.0_f32, 0.0]; - - let mut output: Vec> = Vec::new(); - let stats = index - .knn_search( - validate_k(3).unwrap(), - &strategy, - &processor, - &DefaultContext, - query.as_slice(), - &mut output, - ) - .await - .expect("search succeeds"); - - assert_eq!(stats.cmps, 5); - assert_eq!(stats.result_count, 3); - - let ids: Vec = output.iter().map(|n| n.id).collect(); - assert_eq!(ids, vec![0, 4, 2]); - } - - #[tokio::test] - async fn knn_flat_with_k_larger_than_n_returns_all() { - let provider = build_provider(); - let index = FlatIndex::new(provider); - let strategy = InMemoryStrategy; - let processor = CopyFlatIds; - let query = vec![0.0_f32, 0.0]; - - let mut output: Vec> = Vec::new(); - let stats = index - .knn_search( - validate_k(100).unwrap(), - &strategy, - &processor, - &DefaultContext, - query.as_slice(), - &mut output, - ) - .await - .expect("search succeeds"); - - assert_eq!(stats.cmps, 5); - assert_eq!(stats.result_count, 5); - } - - #[test] - fn knn_flat_rejects_zero_k() { - assert!(validate_k(0).is_err()); - } -} From f0a9dbd2ebb68ef89060ece0b0317edf1b3a247e Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Mon, 27 Apr 2026 22:37:48 -0700 Subject: [PATCH 03/24] rename file --- rfcs/{00000-flat-search.md => 00983-flat-search.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename rfcs/{00000-flat-search.md => 00983-flat-search.md} (100%) diff --git a/rfcs/00000-flat-search.md b/rfcs/00983-flat-search.md similarity index 100% rename from rfcs/00000-flat-search.md rename to 
rfcs/00983-flat-search.md From 0672f3d5ab35cc82a42309e48e1c66dbcddb5d9b Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Tue, 28 Apr 2026 08:19:33 -0700 Subject: [PATCH 04/24] fmt --- diskann/src/flat/index.rs | 4 +--- diskann/src/flat/iterator.rs | 9 +++------ diskann/src/flat/post_process.rs | 5 +++-- diskann/src/flat/strategy.rs | 6 +----- 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs index 04227b1da..6cf4c87c2 100644 --- a/diskann/src/flat/index.rs +++ b/diskann/src/flat/index.rs @@ -14,9 +14,7 @@ use diskann_vector::PreprocessedDistanceFunction; use crate::{ ANNResult, error::IntoANNResult, - flat::{ - FlatIterator, FlatPostProcess, FlatSearchStrategy, - }, + flat::{FlatIterator, FlatPostProcess, FlatSearchStrategy}, graph::{SearchOutputBuffer, index::SearchStats}, neighbor::{Neighbor, NeighborPriorityQueue}, provider::DataProvider, diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index 895c909ad..c822e61d8 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -33,17 +33,15 @@ pub trait FlatIterator: HasId + Send + Sync { /// /// Returns `Ok(None)` when the scan is exhausted. The yielded element borrows from /// the iterator and is invalidated by the next call to `next`. + #[allow(clippy::type_complexity)] fn next( &mut self, ) -> impl SendFuture)>, Self::Error>>; /// Drive the entire scan, invoking `f` for each yielded element. /// - /// The default implementation loops over [`Self::next`]. - fn on_elements_unordered( - &mut self, - mut f: F, - ) -> impl SendFuture> + /// The default implementation loops over [`Self::next`]. 
+ fn on_elements_unordered(&mut self, mut f: F) -> impl SendFuture> where F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>), { @@ -55,4 +53,3 @@ pub trait FlatIterator: HasId + Send + Sync { } } } - diff --git a/diskann/src/flat/post_process.rs b/diskann/src/flat/post_process.rs index 2f95c2932..3e688e5bd 100644 --- a/diskann/src/flat/post_process.rs +++ b/diskann/src/flat/post_process.rs @@ -8,7 +8,8 @@ use diskann_utils::future::SendFuture; use crate::{ - error::StandardError, flat::FlatIterator, graph::SearchOutputBuffer, neighbor::Neighbor, provider::HasId, + error::StandardError, flat::FlatIterator, graph::SearchOutputBuffer, neighbor::Neighbor, + provider::HasId, }; /// Post-process the survivor candidates produced by a flat search and @@ -16,7 +17,7 @@ use crate::{ /// /// This is the flat counterpart to [`crate::graph::glue::SearchPostProcess`]. Processors /// receive `&mut S` so they can consult any iterator-owned lookup state (e.g., an -/// `Id -> rich-record` table built up during the scan) when assembling outputs. +/// `Id -> rich-record` table built up during the scan) when assembling outputs. /// /// The `O` type parameter lets callers pick the output element type (raw `(Id, f32)` /// pairs, fully hydrated hits etc.). diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs index 8b0f5e445..423b2817c 100644 --- a/diskann/src/flat/strategy.rs +++ b/diskann/src/flat/strategy.rs @@ -7,11 +7,7 @@ use diskann_vector::PreprocessedDistanceFunction; -use crate::{ - error::StandardError, - flat::FlatIterator, - provider::DataProvider, -}; +use crate::{error::StandardError, flat::FlatIterator, provider::DataProvider}; /// Per-call configuration that knows how to construct a [`FlatIterator`] for a provider /// and how to pre-process queries of type `T` into a distance computer. 
From 3ac0e1b2f4eac2fa11545f04a5908d016be400bb Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Tue, 28 Apr 2026 17:27:11 -0700 Subject: [PATCH 05/24] split iterator to callback --- diskann/src/flat/index.rs | 15 +++--- diskann/src/flat/iterator.rs | 81 +++++++++++++++++++++++++++----- diskann/src/flat/mod.rs | 2 +- diskann/src/flat/post_process.rs | 6 +-- diskann/src/flat/strategy.rs | 18 +++---- 5 files changed, 91 insertions(+), 31 deletions(-) diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs index 6cf4c87c2..9087e4b4b 100644 --- a/diskann/src/flat/index.rs +++ b/diskann/src/flat/index.rs @@ -14,7 +14,7 @@ use diskann_vector::PreprocessedDistanceFunction; use crate::{ ANNResult, error::IntoANNResult, - flat::{FlatIterator, FlatPostProcess, FlatSearchStrategy}, + flat::{OnElementsUnordered, FlatPostProcess, FlatSearchStrategy}, graph::{SearchOutputBuffer, index::SearchStats}, neighbor::{Neighbor, NeighborPriorityQueue}, provider::DataProvider, @@ -28,7 +28,7 @@ use crate::{ #[derive(Debug)] pub struct FlatIndex { /// The backing provider. - pub provider: P, + provider: P, _marker: PhantomData P>, } @@ -73,19 +73,20 @@ impl FlatIndex

{ T: ?Sized + Sync, O: Send, OB: SearchOutputBuffer + Send + ?Sized, - PP: for<'a> FlatPostProcess, T, O> + Send + Sync, + PP: for<'a> FlatPostProcess, T, O> + Send + Sync, { async move { - let mut iter = strategy - .create_iter(&self.provider, context) + let mut callback = strategy + .create_callback(&self.provider, context) .into_ann_result()?; + let computer = strategy.build_query_computer(query).into_ann_result()?; let k = k.get(); let mut queue = NeighborPriorityQueue::new(k); let mut cmps: u32 = 0; - iter.on_elements_unordered(|id, element| { + callback.on_elements_unordered(|id, element| { let dist = computer.evaluate_similarity(element); cmps += 1; queue.insert(Neighbor::new(id, dist)); @@ -94,7 +95,7 @@ impl FlatIndex

{ .into_ann_result()?; let result_count = processor - .post_process(&mut iter, query, queue.iter().take(k), output) + .post_process(&mut callback, query, queue.iter().take(k), output) .await .into_ann_result()? as u32; diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index c822e61d8..01a5d7cea 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -3,19 +3,41 @@ * Licensed under the MIT license. */ -//! [`FlatIterator`] — the sequential access primitive for accessing a flat index. +//! [`OnElementsUnordered`] — the sequential access primitive for accessing a flat index. +//! +//! [`FlatIterator`] — a lending async iterator that can be bridged into +//! [`OnElementsUnordered`] via [`DefaultIteratedOperator`]. use diskann_utils::{Reborrow, future::SendFuture}; use crate::{error::StandardError, provider::HasId}; -/// A lending, asynchronous iterator over the elements of a flat index. +/// Callback-driven sequential scan over the elements of a flat index. /// -/// `FlatIterator` is the streaming counterpart to [`crate::provider::Accessor`]. Where an -/// accessor exposes random retrieval by id, a flat iterator exposes a *sequential* walk — -/// each call to [`Self::next`] advances an internal cursor and yields the next element. +/// `OnElementsUnordered` is the streaming counterpart to [`crate::provider::Accessor`]. +/// Where an accessor exposes random retrieval by id, this trait exposes a *sequential* +/// walk that invokes a caller-supplied closure for every element. /// /// Algorithms see only `(Id, ElementRef)` pairs and treat the stream as opaque. +pub trait OnElementsUnordered: HasId + Send + Sync { + /// A reference to a yielded element with an unconstrained lifetime, suitable for + /// distance-function HRTB bounds. + type ElementRef<'a>; + + /// The error type yielded by [`Self::on_elements_unordered`]. + type Error: StandardError; + + /// Drive the entire scan, invoking `f` for each yielded element. 
+ fn on_elements_unordered(&mut self, f: F) -> impl SendFuture> + where + F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>); +} + +/// A lending, asynchronous iterator over the elements of a flat index. +/// +/// Implementations provide element-at-a-time access via [`Self::next`]. Providers that +/// only implement `FlatIterator` can be wrapped in [`DefaultIteratedOperator`] to obtain +/// an [`OnElementsUnordered`] implementation automatically. pub trait FlatIterator: HasId + Send + Sync { /// A reference to a yielded element with an unconstrained lifetime, suitable for /// distance-function HRTB bounds. @@ -26,7 +48,7 @@ pub trait FlatIterator: HasId + Send + Sync { where Self: 'a; - /// The error type yielded by [`Self::next`] and [`Self::on_elements_unordered`]. + /// The error type yielded by [`Self::next`]. type Error: StandardError; /// Advance the iterator and asynchronously yield the next `(id, element)` pair. @@ -37,19 +59,56 @@ pub trait FlatIterator: HasId + Send + Sync { fn next( &mut self, ) -> impl SendFuture)>, Self::Error>>; +} + + +/////////////// +/// Default /// +/////////////// + + +/// Bridges a [`FlatIterator`] into an [`OnElementsUnordered`] by looping over +/// [`FlatIterator::next`] and reborrowing each element into the closure. +/// +/// This is the default adapter for providers that implement element-at-a-time iteration. +/// Providers that can do better (prefetching, SIMD batching, bulk I/O) should implement +/// [`OnElementsUnordered`] directly. +pub struct DefaultIteratedOperator { + inner: I, +} + +impl DefaultIteratedOperator { + /// Wrap an iterator to produce an [`OnElementsUnordered`] implementation. + pub fn new(inner: I) -> Self { + Self { inner } + } + + /// Unwrap, returning the inner iterator. 
+ pub fn into_inner(self) -> I { + self.inner + } +} + +impl HasId for DefaultIteratedOperator { + type Id = I::Id; +} + +impl OnElementsUnordered for DefaultIteratedOperator +where + I: FlatIterator + HasId + Send + Sync, +{ + type ElementRef<'a> = I::ElementRef<'a>; + type Error = I::Error; - /// Drive the entire scan, invoking `f` for each yielded element. - /// - /// The default implementation loops over [`Self::next`]. fn on_elements_unordered(&mut self, mut f: F) -> impl SendFuture> where F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>), { async move { - while let Some((id, element)) = self.next().await? { + while let Some((id, element)) = self.inner.next().await? { f(id, element.reborrow()); } Ok(()) } } -} +} \ No newline at end of file diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs index 34fe62ac8..8754ae9d4 100644 --- a/diskann/src/flat/mod.rs +++ b/diskann/src/flat/mod.rs @@ -38,6 +38,6 @@ pub mod post_process; pub mod strategy; pub use index::FlatIndex; -pub use iterator::FlatIterator; +pub use iterator::{DefaultIteratedOperator, FlatIterator, OnElementsUnordered}; pub use post_process::{CopyFlatIds, FlatPostProcess}; pub use strategy::FlatSearchStrategy; diff --git a/diskann/src/flat/post_process.rs b/diskann/src/flat/post_process.rs index 3e688e5bd..71ffabf3c 100644 --- a/diskann/src/flat/post_process.rs +++ b/diskann/src/flat/post_process.rs @@ -8,7 +8,7 @@ use diskann_utils::future::SendFuture; use crate::{ - error::StandardError, flat::FlatIterator, graph::SearchOutputBuffer, neighbor::Neighbor, + error::StandardError, flat::OnElementsUnordered, graph::SearchOutputBuffer, neighbor::Neighbor, provider::HasId, }; @@ -23,7 +23,7 @@ use crate::{ /// pairs, fully hydrated hits etc.). pub trait FlatPostProcess::Id> where - S: FlatIterator, + S: OnElementsUnordered, T: ?Sized, { /// Errors yielded by [`Self::post_process`]. 
@@ -50,7 +50,7 @@ pub struct CopyFlatIds; impl FlatPostProcess for CopyFlatIds where - S: FlatIterator, + S: OnElementsUnordered, T: ?Sized, { type Error = crate::error::Infallible; diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs index 423b2817c..0e77df36b 100644 --- a/diskann/src/flat/strategy.rs +++ b/diskann/src/flat/strategy.rs @@ -7,7 +7,7 @@ use diskann_vector::PreprocessedDistanceFunction; -use crate::{error::StandardError, flat::FlatIterator, provider::DataProvider}; +use crate::{error::StandardError, flat::OnElementsUnordered, provider::DataProvider}; /// Per-call configuration that knows how to construct a [`FlatIterator`] for a provider /// and how to pre-process queries of type `T` into a distance computer. @@ -18,7 +18,7 @@ use crate::{error::StandardError, flat::FlatIterator, provider::DataProvider}; /// /// # Why two methods? /// -/// - [`Self::create_iter`] is query-independent and may be called multiple times per +/// - [`Self::create_callback`] is query-independent and may be called multiple times per /// request (e.g., once per parallel query in a batched search). /// - [`Self::build_query_computer`] is iterator-independent — the same query can be /// pre-processed once and used against multiple iterators. @@ -34,9 +34,9 @@ where P: DataProvider, T: ?Sized, { - /// The iterator type produced by [`Self::create_iter`]. Borrows from `self` and the + /// The iterator type produced by [`Self::create_callback`]. Borrows from `self` and the /// provider. - type Iter<'a>: FlatIterator + type Callback<'a>: OnElementsUnordered where Self: 'a, P: 'a; @@ -47,7 +47,7 @@ where /// by every lifetime of `Iter`. Two lifetimes are needed: `'a` for the iterator /// instance and `'b` for the reborrowed element. 
type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< - as FlatIterator>::ElementRef<'b>, + as OnElementsUnordered>::ElementRef<'b>, f32, > + Send + Sync @@ -59,15 +59,15 @@ where /// Construct a fresh iterator over `provider` for the given request `context`. /// /// This is where lock acquisition, snapshot pinning, and any other per-query setup - /// should happen. The returned iterator owns whatever borrows / guards it needs to + /// should happen. The returned callback object owns whatever borrows / guards it needs to /// remain valid until it is dropped. - fn create_iter<'a>( + fn create_callback<'a>( &'a self, provider: &'a P, context: &'a P::Context, - ) -> Result, Self::Error>; + ) -> Result, Self::Error>; /// Pre-process a query into a [`Self::QueryComputer`] usable for distance computation - /// against any iterator produced by [`Self::create_iter`]. + /// against any iterator produced by [`Self::create_callback`]. fn build_query_computer(&self, query: &T) -> Result; } From f887f2fad9112303bf58f128ce43765d47ddea45 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Tue, 28 Apr 2026 18:42:16 -0700 Subject: [PATCH 06/24] use distance unordered callback --- diskann/src/flat/index.rs | 10 ++++------ diskann/src/flat/iterator.rs | 33 +++++++++++++++++++++++++++++++- diskann/src/flat/mod.rs | 4 ++-- diskann/src/flat/post_process.rs | 4 ++-- diskann/src/flat/strategy.rs | 4 ++-- 5 files changed, 42 insertions(+), 13 deletions(-) diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs index 9087e4b4b..825567f8a 100644 --- a/diskann/src/flat/index.rs +++ b/diskann/src/flat/index.rs @@ -9,12 +9,11 @@ use std::marker::PhantomData; use std::num::NonZeroUsize; use diskann_utils::future::SendFuture; -use diskann_vector::PreprocessedDistanceFunction; use crate::{ ANNResult, error::IntoANNResult, - flat::{OnElementsUnordered, FlatPostProcess, FlatSearchStrategy}, + flat::{DistancesUnordered, FlatPostProcess, FlatSearchStrategy}, 
graph::{SearchOutputBuffer, index::SearchStats}, neighbor::{Neighbor, NeighborPriorityQueue}, provider::DataProvider, @@ -23,8 +22,8 @@ use crate::{ /// A `'static` thin wrapper around a [`DataProvider`] used for flat search. /// /// The provider is owned by the index. The index is constructed once at process startup and -/// shared across requests; per-query state lives in the [`crate::flat::FlatIterator`] that -/// the [`crate::flat::FlatSearchStrategy`] produces. +/// shared across requests; per-query state lives in the [`crate::flat::OnElementsUnordered`] +/// implementation that the [`crate::flat::FlatSearchStrategy`] produces. #[derive(Debug)] pub struct FlatIndex { /// The backing provider. @@ -86,8 +85,7 @@ impl FlatIndex

{ let mut queue = NeighborPriorityQueue::new(k); let mut cmps: u32 = 0; - callback.on_elements_unordered(|id, element| { - let dist = computer.evaluate_similarity(element); + callback.distances_unordered(&computer, |id, dist| { cmps += 1; queue.insert(Neighbor::new(id, dist)); }) diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index 01a5d7cea..fd69b312a 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -9,6 +9,7 @@ //! [`OnElementsUnordered`] via [`DefaultIteratedOperator`]. use diskann_utils::{Reborrow, future::SendFuture}; +use diskann_vector::PreprocessedDistanceFunction; use crate::{error::StandardError, provider::HasId}; @@ -33,6 +34,34 @@ pub trait OnElementsUnordered: HasId + Send + Sync { F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>); } + +/// Extension of [`OnElementsUnordered`] that drives the scan with a pre-built query +/// computer, invoking a callback with `(id, distance)` pairs instead of raw elements. +/// +/// The concrete computer is instantiated and supplied externally +/// by the [`FlatSearchStrategy`](crate::flat::FlatSearchStrategy). +/// +/// The default implementation delegates to [`OnElementsUnordered::on_elements_unordered`], +/// calling `computer.evaluate_similarity` on each element. +pub trait DistancesUnordered: OnElementsUnordered { + /// Drive the entire scan, scoring each element with `computer` and invoking `f` with + /// the resulting `(id, distance)` pair. + fn distances_unordered( + &mut self, + computer: &C, + mut f: F, + ) -> impl SendFuture> + where + C: for<'a> PreprocessedDistanceFunction, f32> + Send + Sync, + F: Send + FnMut(Self::Id, f32), + { + self.on_elements_unordered(move |id, element| { + let dist = computer.evaluate_similarity(element); + f(id, dist); + }) + } +} + /// A lending, asynchronous iterator over the elements of a flat index. /// /// Implementations provide element-at-a-time access via [`Self::next`].
Providers that @@ -111,4 +140,6 @@ where Ok(()) } } -} \ No newline at end of file +} + +impl DistancesUnordered for DefaultIteratedOperator where I: FlatIterator + HasId + Send + Sync {} \ No newline at end of file diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs index 8754ae9d4..2516509da 100644 --- a/diskann/src/flat/mod.rs +++ b/diskann/src/flat/mod.rs @@ -38,6 +38,6 @@ pub mod post_process; pub mod strategy; pub use index::FlatIndex; -pub use iterator::{DefaultIteratedOperator, FlatIterator, OnElementsUnordered}; -pub use post_process::{CopyFlatIds, FlatPostProcess}; +pub use iterator::{DefaultIteratedOperator, DistancesUnordered, FlatIterator, OnElementsUnordered}; +pub use post_process::{CopyIds, FlatPostProcess}; pub use strategy::FlatSearchStrategy; diff --git a/diskann/src/flat/post_process.rs b/diskann/src/flat/post_process.rs index 71ffabf3c..2cab763dd 100644 --- a/diskann/src/flat/post_process.rs +++ b/diskann/src/flat/post_process.rs @@ -46,9 +46,9 @@ where /// A trivial [`FlatPostProcess`] that copies each `(Id, distance)` pair straight into the /// output buffer. #[derive(Debug, Default, Clone, Copy)] -pub struct CopyFlatIds; +pub struct CopyIds; -impl FlatPostProcess for CopyFlatIds +impl FlatPostProcess for CopyIds where S: OnElementsUnordered, T: ?Sized, diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs index 0e77df36b..9fa2c6e00 100644 --- a/diskann/src/flat/strategy.rs +++ b/diskann/src/flat/strategy.rs @@ -7,7 +7,7 @@ use diskann_vector::PreprocessedDistanceFunction; -use crate::{error::StandardError, flat::OnElementsUnordered, provider::DataProvider}; +use crate::{error::StandardError, flat::{DistancesUnordered, OnElementsUnordered}, provider::DataProvider}; /// Per-call configuration that knows how to construct a [`FlatIterator`] for a provider /// and how to pre-process queries of type `T` into a distance computer. @@ -36,7 +36,7 @@ where { /// The iterator type produced by [`Self::create_callback`]. 
Borrows from `self` and the /// provider. - type Callback<'a>: OnElementsUnordered + type Callback<'a>: DistancesUnordered where Self: 'a, P: 'a; From dc2281cfb26607eab6a47f6f97eab446c6badd38 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Tue, 28 Apr 2026 21:26:24 -0700 Subject: [PATCH 07/24] rfc update --- rfcs/00983-flat-search.md | 155 ++++++++++++++++++++------------------ 1 file changed, 83 insertions(+), 72 deletions(-) diff --git a/rfcs/00983-flat-search.md b/rfcs/00983-flat-search.md index 78606ddb1..dfef8f516 100644 --- a/rfcs/00983-flat-search.md +++ b/rfcs/00983-flat-search.md @@ -22,75 +22,88 @@ stuffing the algorithm or the backend through the `Accessor` trait surface. ### Goals -1. Define a streaming access primitive — `FlatIterator` — that mirrors the role - `Accessor` plays for graph search but exposes a lending-iterator interface instead of - a random-access one. +1. Define a streaming access primitive — `OnElementsUnordered` — that mirrors the role + `Accessor` plays for graph search but exposes a callback-driven scan instead of + random access. 2. Provide flat-search algorithm implementations (with `knn_search` as default and filtered and diverse variants to opt-into) built on the new primitives, so consumers can use this against their own providers / backends. 3. Expose support for features and implementations native to the repo like quantized distance computers out-of-the-box. ## Proposal -Let's start with the main analog to the `Accessor` trait for the `FlatIndex` - `FlatIterator`. +The flat-search infrastructure is built on a small sequence of traits. The only required traits for the algorithm are `OnElementsUnordered` and its subtrait `DistancesUnordered`. A strategy - `FlatSearchStrategy` - instantiates these implementations for specific providers.
An opt-in iterator trait `FlatIterator` and default implementations of the core traits - `DefaultIteratedOperator` - exist for convenience for backends that naturally expose element-at-a-time iteration. - -### `FlatIterator` +### `OnElementsUnordered` — the core scan ```rust -pub trait FlatIterator: HasId + Send + Sync { // Has Id support - // Element yielded by iterator - type ElementRef<'a>; +pub trait OnElementsUnordered: HasId + Send + Sync { + type ElementRef<'a>; + type Error: StandardError; - // Mostly machinery to play nice with HRTB - type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync - where - Self: 'a; + fn on_elements_unordered(&mut self, f: F) -> impl SendFuture> + where + F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>); +} +``` - type Error: StandardError; +A single required method: drive the entire scan via a callback. Async to match +[`crate::provider::Accessor`]. Implementations choose iteration order, prefetching, and +any SIMD-friendly bulk reads if they want; algorithms see only `(Id, ElementRef)` pairs. - fn next( - &mut self, - ) -> impl SendFuture)>, Self::Error>>; +### `DistancesUnordered` — the distance subtrait - // Default implementation for driving a closure on the items in the index. - fn on_elements_unordered( - &mut self, - mut f: F, +```rust +pub trait DistancesUnordered: OnElementsUnordered { + fn distances_unordered( + &mut self, computer: &C, mut f: F, ) -> impl SendFuture> - where F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>), + where + C: for<'a> PreprocessedDistanceFunction, f32> + Send + Sync, + F: Send + FnMut(Self::Id, f32), { - async move { - while let Some((id, element)) = self.next().await? { - f(id, element.reborrow()); - } - - Ok(()) - } + // default delegates to on_elements_unordered + evaluate_similarity } } ``` -The trait combines two access patterns: +A subtrait that fuses scanning with scoring. 
The default implementation loops +`on_elements_unordered` and calls `computer.evaluate_similarity` on each element. -- A required lending-iterator `next()`. -- A defaulted bulk method `on_elements_unordered` that consumes the entire scan via a - callback. The default impl loops over `next`; iterators that benefit from prefetching, - SIMD batching, or amortized per-element cost could override it. +The query computer is a generic parameter rather than an associated type, so the same +callback type can be driven by different computers. The `FlatSearchStrategy` is the +source of truth for which computer is used in any given search. -Both methods are **async** (returning `impl SendFuture<...>`), matching -[`crate::provider::Accessor::get_element`]. Iterators backed by I/O — disk pages, -remote shards — return a real future; in-memory iterators wrap their result in -`std::future::ready`. +### `FlatIterator` and `DefaultIteratedOperator` — convenience for element-at-a-time backends -The `Element` / `ElementRef` split is identical to `Accessor` and exists for the same -reason: to keep HRTB bounds on query computers from inducing `'static` requirements on -the iterator type. +For backends that naturally expose element-at-a-time iteration, `FlatIterator` is a +lending async iterator: + +```rust +pub trait FlatIterator: HasId + Send + Sync { + type ElementRef<'a>; + // lifetime gymnastics to make lifetime of `Element<'_>` to play nice with HRTB + type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync + where Self: 'a; + type Error: StandardError; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>>; +} +``` + +`DefaultIteratedOperator` wraps any `FlatIterator` and implements `OnElementsUnordered` +(and `DistancesUnordered` by inheritance) by looping over `next()` and reborrowing each +element. 
### The glue: `FlatSearchStrategy` -While the `FlatIterator` is the primary object that provides access to the elements in the index for the algorithm, it is scoped to each query. We intorduce a constructor - `FlatSearchStrategy` - similar to `SearchStrategy` for `Accessor` to instantiate this object. A strategy is per-call configuration: stateless, cheap to construct, scoped to one -search. It produces both a per-query iterator and a query computer. +While `OnElementsUnordered` is the primary handle the algorithm uses to walk the index, +it is scoped to each query. We introduce a constructor — `FlatSearchStrategy` — similar +to `SearchStrategy` for `Accessor`, to instantiate the per-query callback object. +A strategy is per-call configuration that is stateless, cheap to construct and scoped to one +search. It produces both a per-query callback and a query computer. ```rust pub trait FlatSearchStrategy: Send + Sync @@ -98,55 +111,55 @@ where P: DataProvider, T: ?Sized, { - /// The iterator type produced by [`Self::create_iter`]. Borrows from `self` and the - /// provider. - type Iter<'a>: FlatIterator + /// The per-query callback type produced by [`Self::create_callback`]. Borrows from + /// `self` and the provider. + type Callback<'a>: DistancesUnordered where Self: 'a, /// The query computer produced by [`Self::build_query_computer`]. type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< - as FlatIterator>::ElementRef<'b>, + as OnElementsUnordered>::ElementRef<'b>, f32, > + Send + Sync + 'static; - /// The error type for both factory methods. + /// The error type type Error: StandardError; - /// Construct a fresh iterator over `provider` for the given request `context`. - fn create_iter<'a>( + /// Construct a fresh callback over `provider` for the given request `context`. 
+ fn create_callback<'a>( &'a self, provider: &'a P, context: &'a P::Context, - ) -> Result, Self::Error>; + ) -> Result, Self::Error>; /// Pre-process a query into a [`Self::QueryComputer`] usable for distance computation - /// against any iterator produced by [`Self::create_iter`]. + /// against any callback produced by [`Self::create_callback`]. fn build_query_computer(&self, query: &T) -> Result; } ``` -The `ElementRef<'b>` that the distance function `QueryComputer` acts on is tied to the (reborrowed) element yielded by the `FlatIterator::next()`. +The `ElementRef<'b>` that the `QueryComputer` acts on is tied to the +`OnElementsUnordered::ElementRef` of the callback produced by `create_callback`. ### `FlatIndex` `FlatIndex` is a thin `'static` wrapper around a `DataProvider`. The same `DataProvider` -trait used by graph search is reused here — flat and graph subsystems share a single +trait used by graph search is reused here - flat and graph subsystems share a single provider surface and the same `Context` / id-mapping / error machinery. ```rust pub struct FlatIndex { provider: P, - /* private */ } impl FlatIndex

{ pub fn new(provider: P) -> Self; pub fn provider(&self) -> &P; - pub fn knn_search( + pub fn knn_search( &self, k: NonZeroUsize, strategy: &S, @@ -160,15 +173,16 @@ impl FlatIndex

{ T: ?Sized + Sync, O: Send, OB: SearchOutputBuffer + Send + ?Sized, + PP: for<'a> FlatPostProcess, T, O> + Send + Sync, } ``` The `knn_search` method is the canonical brute-force search algorithm: -1. Construct the iterator via `strategy.create_iter` to obtain a scoped iterator over the elements. +1. Construct the per-query callback via `strategy.create_callback`. 2. Build the query computer via `strategy.build_query_computer`. -3. Drive the scan via `iter.on_elements_unordered`, scoring each element and - inserting `Neighbor`s into a `NeighborPriorityQueue` of capacity `k`. +3. Drive the scan via `callback.distances_unordered(&computer, ...)`, inserting each + `(id, distance)` pair into a `NeighborPriorityQueue` of capacity `k`. 4. Hand the survivors (in distance order) to `processor.post_process`. 5. Return search stats. @@ -187,31 +201,28 @@ This design leans into using the `DataProvider` trait which requires implementat `to_external_id`), and error machinery are identical across graph and flat search, reducing the learning surface for new contributors. -### Async vs sync API for `FlatIterator` +### Async vs sync scan API -`next()` and `on_elements_unordered` return a future, making the trait -async. This is the right default for disk-backed and network-backed iterators -where advancing the cursor involves real I/O. It also matches the `Accessor` surface, +`on_elements_unordered` and `distances_unordered` return a future, making the scan +surface async. This is the right default for disk-backed and network-backed backends +where advancing the scan involves real I/O. It also matches the `Accessor` surface, keeping the two subsystems shaped the same way. -The cost is paid by in-memory consumers: every call to `next()` goes through the future -machinery even when the result is immediately available via `std::future::ready`. In a -tight brute-force loop this overhead — poll scaffolding, pinning etc — could be measurable. 
+The cost is paid by in-memory consumers: the scan goes through the future machinery +even when results are immediately available. In a tight brute-force loop this overhead — +poll scaffolding, pinning etc — could be measurable. We chose async because the wider audience of consumers (disk, network, mixed) benefits -more than in-memory consumers lose. +more than in-memory consumers lose. ### Expand `Element` to support batched distance computation? The current design yields one element per `next()` call, and the query computer scores -elements one at a time via `PreprocessedDistanceFunction::evaluate_similarity`. This could leave some optimization and performance on the table; especially with the upcoming effort around batched distance kernels. +elements one at a time via `PreprocessedDistanceFunction::evaluate_similarity`. This could leave some optimization and performance on the table; especially with the upcoming effort around batched distance kernels. Of course, a consumer can choose to implement their own optimized implementation of `distances_unordered` that uses batching. An alternative is to make `next()` yield a *batch* instead of a single vector representation like `Element<'_>`. Some work will need to be done to define the right interaction between the batch type, the element type in the batch, the interaction with `QueryComputer`'s types and way IDs and distances are collected in the queue. -We opted for the scalar-per-element design for now because it is simpler to implement and -reason about. The hope is that batched distance computation can be layered on later as an opt-in sub-trait without breaking -existing iterators. - ## Future Work - Support for other flat-search algorithms like - filtered, range and diverse flat algorithms as additional methods on `FlatIndex`. +- Index build -- this is just one part of the picture; more work needs to be done around how this fits in with any traits / interface we need for index build. 
From ee48f7dc752f6500b379c448ca42d29430ae49ff Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Tue, 28 Apr 2026 21:27:13 -0700 Subject: [PATCH 08/24] rustfmt flat module --- diskann/src/flat/index.rs | 15 ++++++++------- diskann/src/flat/iterator.rs | 8 +++----- diskann/src/flat/mod.rs | 4 +++- diskann/src/flat/strategy.rs | 6 +++++- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs index 825567f8a..f91a9ccac 100644 --- a/diskann/src/flat/index.rs +++ b/diskann/src/flat/index.rs @@ -22,7 +22,7 @@ use crate::{ /// A `'static` thin wrapper around a [`DataProvider`] used for flat search. /// /// The provider is owned by the index. The index is constructed once at process startup and -/// shared across requests; per-query state lives in the [`crate::flat::OnElementsUnordered`] +/// shared across requests; per-query state lives in the [`crate::flat::OnElementsUnordered`] /// implementation that the [`crate::flat::FlatSearchStrategy`] produces. #[derive(Debug)] pub struct FlatIndex { @@ -85,12 +85,13 @@ impl FlatIndex

{ let mut queue = NeighborPriorityQueue::new(k); let mut cmps: u32 = 0; - callback.distances_unordered(&computer, |id, dist| { - cmps += 1; - queue.insert(Neighbor::new(id, dist)); - }) - .await - .into_ann_result()?; + callback + .distances_unordered(&computer, |id, dist| { + cmps += 1; + queue.insert(Neighbor::new(id, dist)); + }) + .await + .into_ann_result()?; let result_count = processor .post_process(&mut callback, query, queue.iter().take(k), output) diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index fd69b312a..6ccebc214 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -34,7 +34,6 @@ pub trait OnElementsUnordered: HasId + Send + Sync { F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>); } - /// Extension of [`OnElementsUnordered`] that drives the scan with a pre-built query /// computer, invoking a callback with `(id, distance)` pairs instead of raw elements. /// @@ -42,7 +41,7 @@ pub trait OnElementsUnordered: HasId + Send + Sync { /// by the [`FlatSearchStrategy`](crate::flat::FlatSearchStrategy). /// /// The default implementation delegates to [`OnElementsUnordered::on_elements_unordered`], -/// calling `computer.evaluate_similarity` on each element. +/// calling `computer.evaluate_similarity` on each element. pub trait DistancesUnordered: OnElementsUnordered { /// Drive the entire scan, scoring each element with `computer` and invoking `f` with /// the resulting `(id, distance)` pair. @@ -90,12 +89,10 @@ pub trait FlatIterator: HasId + Send + Sync { ) -> impl SendFuture)>, Self::Error>>; } - /////////////// /// Default /// /////////////// - /// Bridges a [`FlatIterator`] into an [`OnElementsUnordered`] by looping over /// [`FlatIterator::next`] and reborrowing each element into the closure. 
/// @@ -142,4 +139,5 @@ where } } -impl DistancesUnordered for DefaultIteratedOperator where I: FlatIterator + HasId + Send + Sync {} \ No newline at end of file +impl DistancesUnordered for DefaultIteratedOperator where I: FlatIterator + HasId + Send + Sync +{} diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs index 2516509da..bf1290ed8 100644 --- a/diskann/src/flat/mod.rs +++ b/diskann/src/flat/mod.rs @@ -38,6 +38,8 @@ pub mod post_process; pub mod strategy; pub use index::FlatIndex; -pub use iterator::{DefaultIteratedOperator, DistancesUnordered, FlatIterator, OnElementsUnordered}; +pub use iterator::{ + DefaultIteratedOperator, DistancesUnordered, FlatIterator, OnElementsUnordered, +}; pub use post_process::{CopyIds, FlatPostProcess}; pub use strategy::FlatSearchStrategy; diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs index 9fa2c6e00..4b8ad8fe1 100644 --- a/diskann/src/flat/strategy.rs +++ b/diskann/src/flat/strategy.rs @@ -7,7 +7,11 @@ use diskann_vector::PreprocessedDistanceFunction; -use crate::{error::StandardError, flat::{DistancesUnordered, OnElementsUnordered}, provider::DataProvider}; +use crate::{ + error::StandardError, + flat::{DistancesUnordered, OnElementsUnordered}, + provider::DataProvider, +}; /// Per-call configuration that knows how to construct a [`FlatIterator`] for a provider /// and how to pre-process queries of type `T` into a distance computer. 
From 7fd903ecb31c63992340411b41bf919014383065 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Tue, 28 Apr 2026 21:31:55 -0700 Subject: [PATCH 09/24] fix clippy: replace doc-comment divider with regular comment --- diskann/src/flat/iterator.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index 6ccebc214..b725eeae9 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -89,9 +89,7 @@ pub trait FlatIterator: HasId + Send + Sync { ) -> impl SendFuture)>, Self::Error>>; } -/////////////// -/// Default /// -/////////////// +// ─── Default adapter ──────────────────────────────────────────────────────── /// Bridges a [`FlatIterator`] into an [`OnElementsUnordered`] by looping over /// [`FlatIterator::next`] and reborrowing each element into the closure. From 8af5e004255872b8306e8d78c7527fc7a6aa3820 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Wed, 29 Apr 2026 10:31:05 -0700 Subject: [PATCH 10/24] small edits --- diskann/src/flat/iterator.rs | 8 +++++++- diskann/src/flat/strategy.rs | 2 +- rfcs/00983-flat-search.md | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index b725eeae9..fa07e92b4 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -61,6 +61,10 @@ pub trait DistancesUnordered: OnElementsUnordered { } } +////////////// +// Iterator // +////////////// + /// A lending, asynchronous iterator over the elements of a flat index. /// /// Implementations provide element-at-a-time access via [`Self::next`]. 
Providers that @@ -89,7 +93,9 @@ pub trait FlatIterator: HasId + Send + Sync { ) -> impl SendFuture)>, Self::Error>>; } -// ─── Default adapter ──────────────────────────────────────────────────────── +///////////// +// Default // +///////////// /// Bridges a [`FlatIterator`] into an [`OnElementsUnordered`] by looping over /// [`FlatIterator::next`] and reborrowing each element into the closure. diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs index 4b8ad8fe1..5da3349ed 100644 --- a/diskann/src/flat/strategy.rs +++ b/diskann/src/flat/strategy.rs @@ -13,7 +13,7 @@ use crate::{ provider::DataProvider, }; -/// Per-call configuration that knows how to construct a [`FlatIterator`] for a provider +/// Per-call configuration that knows how to construct a [`DistancesUnordered`] for a provider /// and how to pre-process queries of type `T` into a distance computer. /// /// `FlatSearchStrategy` is the flat counterpart to [`crate::graph::glue::SearchStrategy`]. diff --git a/rfcs/00983-flat-search.md b/rfcs/00983-flat-search.md index dfef8f516..b79018a5d 100644 --- a/rfcs/00983-flat-search.md +++ b/rfcs/00983-flat-search.md @@ -4,7 +4,7 @@ |------------------|--------------------------------| | **Authors** | Aditya Krishnan, Alex Razumov, Dongliang Wu | | **Created** | 2026-04-24 | -| **Updated** | 2026-04-27 | +| **Updated** | 2026-04-28 | ## Motivation From 1dd1c727a0a32abb3fc20ac5181d8878c43fe7e3 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Mon, 4 May 2026 18:09:27 -0400 Subject: [PATCH 11/24] renames and uplevel query computer --- diskann/src/flat/index.rs | 18 ++++++------------ diskann/src/flat/iterator.rs | 23 +++++++++++++---------- diskann/src/flat/mod.rs | 4 +--- diskann/src/flat/strategy.rs | 24 ++++++++++++------------ rfcs/00983-flat-search.md | 30 +++++++++++++++--------------- 5 files changed, 47 insertions(+), 52 deletions(-) diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs index 
f91a9ccac..c5ee18824 100644 --- a/diskann/src/flat/index.rs +++ b/diskann/src/flat/index.rs @@ -4,8 +4,6 @@ */ //! [`FlatIndex`] — the index wrapper for an on which we do flat search. - -use std::marker::PhantomData; use std::num::NonZeroUsize; use diskann_utils::future::SendFuture; @@ -28,16 +26,12 @@ use crate::{ pub struct FlatIndex { /// The backing provider. provider: P, - _marker: PhantomData P>, } impl FlatIndex

{ /// Construct a new [`FlatIndex`] around `provider`. pub fn new(provider: P) -> Self { - Self { - provider, - _marker: PhantomData, - } + Self { provider } } /// Borrow the underlying provider. @@ -72,11 +66,11 @@ impl FlatIndex

{ T: ?Sized + Sync, O: Send, OB: SearchOutputBuffer + Send + ?Sized, - PP: for<'a> FlatPostProcess, T, O> + Send + Sync, + PP: for<'a> FlatPostProcess, T, O> + Send + Sync, { async move { - let mut callback = strategy - .create_callback(&self.provider, context) + let mut visitor = strategy + .create_visitor(&self.provider, context) .into_ann_result()?; let computer = strategy.build_query_computer(query).into_ann_result()?; @@ -85,7 +79,7 @@ impl FlatIndex

{ let mut queue = NeighborPriorityQueue::new(k); let mut cmps: u32 = 0; - callback + visitor .distances_unordered(&computer, |id, dist| { cmps += 1; queue.insert(Neighbor::new(id, dist)); @@ -94,7 +88,7 @@ impl FlatIndex

{ .into_ann_result()?; let result_count = processor - .post_process(&mut callback, query, queue.iter().take(k), output) + .post_process(&mut visitor, query, queue.iter().take(k), output) .await .into_ann_result()? as u32; diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index fa07e92b4..7defe0c80 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -43,15 +43,21 @@ pub trait OnElementsUnordered: HasId + Send + Sync { /// The default implementation delegates to [`OnElementsUnordered::on_elements_unordered`], /// calling `computer.evaluate_similarity` on each element. pub trait DistancesUnordered: OnElementsUnordered { + /// The concrete type of the distance computer for a query, which should be applicable for + /// all elements in the underlying driver. + type QueryComputer: for<'a> PreprocessedDistanceFunction, f32> + + Send + + Sync + + 'static; + /// Drive the entire scan, scoring each element with `computer` and invoking `f` with /// the resulting `(id, distance)` pair. - fn distances_unordered( + fn distances_unordered( &mut self, - computer: &C, + computer: &Self::QueryComputer, mut f: F, ) -> impl SendFuture> where - C: for<'a> PreprocessedDistanceFunction, f32> + Send + Sync, F: Send + FnMut(Self::Id, f32), { self.on_elements_unordered(move |id, element| { @@ -103,11 +109,11 @@ pub trait FlatIterator: HasId + Send + Sync { /// This is the default adapter for providers that implement element-at-a-time iteration. /// Providers that can do better (prefetching, SIMD batching, bulk I/O) should implement /// [`OnElementsUnordered`] directly. -pub struct DefaultIteratedOperator { +pub struct Iterated { inner: I, } -impl DefaultIteratedOperator { +impl Iterated { /// Wrap an iterator to produce an [`OnElementsUnordered`] implementation. 
pub fn new(inner: I) -> Self { Self { inner } @@ -119,11 +125,11 @@ impl DefaultIteratedOperator { } } -impl HasId for DefaultIteratedOperator { +impl HasId for Iterated { type Id = I::Id; } -impl OnElementsUnordered for DefaultIteratedOperator +impl OnElementsUnordered for Iterated where I: FlatIterator + HasId + Send + Sync, { @@ -142,6 +148,3 @@ where } } } - -impl DistancesUnordered for DefaultIteratedOperator where I: FlatIterator + HasId + Send + Sync -{} diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs index bf1290ed8..d3407310a 100644 --- a/diskann/src/flat/mod.rs +++ b/diskann/src/flat/mod.rs @@ -38,8 +38,6 @@ pub mod post_process; pub mod strategy; pub use index::FlatIndex; -pub use iterator::{ - DefaultIteratedOperator, DistancesUnordered, FlatIterator, OnElementsUnordered, -}; +pub use iterator::{DistancesUnordered, FlatIterator, Iterated, OnElementsUnordered}; pub use post_process::{CopyIds, FlatPostProcess}; pub use strategy::FlatSearchStrategy; diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs index 5da3349ed..4d7e474ad 100644 --- a/diskann/src/flat/strategy.rs +++ b/diskann/src/flat/strategy.rs @@ -22,10 +22,10 @@ use crate::{ /// /// # Why two methods? /// -/// - [`Self::create_callback`] is query-independent and may be called multiple times per +/// - [`Self::create_visitor`] is query-independent and may be called multiple times per /// request (e.g., once per parallel query in a batched search). -/// - [`Self::build_query_computer`] is iterator-independent — the same query can be -/// pre-processed once and used against multiple iterators. +/// - [`Self::build_query_computer`] is visitor-independent — the same query can be +/// pre-processed once and used against multiple visitors. /// /// Both methods may borrow from the strategy itself. /// @@ -38,9 +38,9 @@ where P: DataProvider, T: ?Sized, { - /// The iterator type produced by [`Self::create_callback`]. 
Borrows from `self` and the + /// The visitor type produced by [`Self::create_visitor`]. Borrows from `self` and the /// provider. - type Callback<'a>: DistancesUnordered + type Visitor<'a>: DistancesUnordered where Self: 'a, P: 'a; @@ -48,10 +48,10 @@ where /// The query computer produced by [`Self::build_query_computer`]. /// /// The HRTB on `ElementRef` ensures the same computer can score every element yielded - /// by every lifetime of `Iter`. Two lifetimes are needed: `'a` for the iterator + /// by every lifetime of `Visitor`. Two lifetimes are needed: `'a` for the visitor /// instance and `'b` for the reborrowed element. type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< - as OnElementsUnordered>::ElementRef<'b>, + as OnElementsUnordered>::ElementRef<'b>, f32, > + Send + Sync @@ -60,18 +60,18 @@ where /// The error type for both factory methods. type Error: StandardError; - /// Construct a fresh iterator over `provider` for the given request `context`. + /// Construct a fresh visitor over `provider` for the given request `context`. /// /// This is where lock acquisition, snapshot pinning, and any other per-query setup - /// should happen. The returned callback object owns whatever borrows / guards it needs to + /// should happen. The returned visitor owns whatever borrows / guards it needs to /// remain valid until it is dropped. - fn create_callback<'a>( + fn create_visitor<'a>( &'a self, provider: &'a P, context: &'a P::Context, - ) -> Result, Self::Error>; + ) -> Result, Self::Error>; /// Pre-process a query into a [`Self::QueryComputer`] usable for distance computation - /// against any iterator produced by [`Self::create_callback`]. + /// against any visitor produced by [`Self::create_visitor`]. 
fn build_query_computer(&self, query: &T) -> Result; } diff --git a/rfcs/00983-flat-search.md b/rfcs/00983-flat-search.md index b79018a5d..e552b6164 100644 --- a/rfcs/00983-flat-search.md +++ b/rfcs/00983-flat-search.md @@ -73,7 +73,7 @@ The query computer is a generic parameter rather than an associated type, so the callback type can be driven by different computers. The `FlatSearchStrategy` is the source of truth for which computer is used in any given search. -### `FlatIterator` and `DefaultIteratedOperator` — convenience for element-at-a-time backends +### `FlatIterator` and `Iterated` — convenience for element-at-a-time backends For backends that naturally expose element-at-a-time iteration, `FlatIterator` is a lending async iterator: @@ -92,7 +92,7 @@ pub trait FlatIterator: HasId + Send + Sync { } ``` -`DefaultIteratedOperator` wraps any `FlatIterator` and implements `OnElementsUnordered` +`Iterated` wraps any `FlatIterator` and implements `OnElementsUnordered` (and `DistancesUnordered` by inheritance) by looping over `next()` and reborrowing each element. @@ -101,9 +101,9 @@ element. While `OnElementsUnordered` is the primary handle the algorithm uses to walk the index, it is scoped to each query. We introduce a constructor — `FlatSearchStrategy` — similar -to `SearchStrategy` for `Accessor`, to instantiate the per-query callback object. +to `SearchStrategy` for `Accessor`, to instantiate the per-query visitor. A strategy is per-call configuration that is stateless, cheap to construct and scoped to one -search. It produces both a per-query callback and a query computer. +search. It produces both a per-query visitor and a query computer. ```rust pub trait FlatSearchStrategy: Send + Sync @@ -111,15 +111,15 @@ where P: DataProvider, T: ?Sized, { - /// The per-query callback type produced by [`Self::create_callback`]. Borrows from + /// The per-query visitor type produced by [`Self::create_visitor`]. Borrows from /// `self` and the provider. 
- type Callback<'a>: DistancesUnordered + type Visitor<'a>: DistancesUnordered where Self: 'a, /// The query computer produced by [`Self::build_query_computer`]. type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< - as OnElementsUnordered>::ElementRef<'b>, + as OnElementsUnordered>::ElementRef<'b>, f32, > + Send + Sync @@ -128,21 +128,21 @@ where /// The error type type Error: StandardError; - /// Construct a fresh callback over `provider` for the given request `context`. - fn create_callback<'a>( + /// Construct a fresh visitor over `provider` for the given request `context`. + fn create_visitor<'a>( &'a self, provider: &'a P, context: &'a P::Context, - ) -> Result, Self::Error>; + ) -> Result, Self::Error>; /// Pre-process a query into a [`Self::QueryComputer`] usable for distance computation - /// against any callback produced by [`Self::create_callback`]. + /// against any visitor produced by [`Self::create_visitor`]. fn build_query_computer(&self, query: &T) -> Result; } ``` The `ElementRef<'b>` that the `QueryComputer` acts on is tied to the -`OnElementsUnordered::ElementRef` of the callback produced by `create_callback`. +`OnElementsUnordered::ElementRef` of the visitor produced by `create_visitor`. ### `FlatIndex` @@ -173,15 +173,15 @@ impl FlatIndex

{ T: ?Sized + Sync, O: Send, OB: SearchOutputBuffer + Send + ?Sized, - PP: for<'a> FlatPostProcess, T, O> + Send + Sync, + PP: for<'a> FlatPostProcess, T, O> + Send + Sync, } ``` The `knn_search` method is the canonical brute-force search algorithm: -1. Construct the per-query callback via `strategy.create_callback`. +1. Construct the per-query visitor via `strategy.create_visitor`. 2. Build the query computer via `strategy.build_query_computer`. -3. Drive the scan via `callback.distances_unordered(&computer, ...)`, inserting each +3. Drive the scan via `visitor.distances_unordered(&computer, ...)`, inserting each `(id, distance)` pair into a `NeighborPriorityQueue` of capacity `k`. 4. Hand the survivors (in distance order) to `processor.post_process`. 5. Return search stats. From 3d24ef757b745f7726657770d3315875072e1557 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Tue, 5 May 2026 14:11:37 -0400 Subject: [PATCH 12/24] split buildquerycomputer, haselementref and distancesunordered --- .../src/search/provider/disk_provider.rs | 22 +++-- diskann-garnet/src/provider.rs | 10 ++- .../encoded_document_accessor.rs | 23 ++++- .../inline_beta_search/inline_beta_filter.rs | 2 +- .../graph/provider/async_/bf_tree/provider.rs | 53 +++++++++--- .../graph/provider/async_/caching/provider.rs | 20 ++++- .../provider/async_/inmem/full_precision.rs | 26 ++++-- .../graph/provider/async_/inmem/product.rs | 37 ++++++-- .../graph/provider/async_/inmem/scalar.rs | 28 +++++-- .../graph/provider/async_/inmem/spherical.rs | 25 ++++-- .../model/graph/provider/async_/inmem/test.rs | 10 ++- .../graph/provider/async_/postprocess.rs | 4 +- .../model/graph/provider/layers/betafilter.rs | 33 ++++++-- diskann/src/flat/iterator.rs | 25 +++--- diskann/src/flat/strategy.rs | 6 +- diskann/src/graph/glue.rs | 33 ++++---- diskann/src/graph/test/provider.rs | 7 +- diskann/src/provider.rs | 84 ++++++++++++++----- 18 files changed, 333 insertions(+), 115 deletions(-) diff --git 
a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index b7b30e94a..4f05b63e0 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -28,8 +28,8 @@ use diskann::{ }, neighbor::Neighbor, provider::{ - Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId, - NeighborAccessor, NoopGuard, + Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, + DistancesUnordered, HasElementRef, HasId, NeighborAccessor, NoopGuard, }, utils::{IntoUsize, VectorRepr}, ANNError, ANNResult, @@ -427,7 +427,13 @@ where .to_vec(), }) } +} +impl DistancesUnordered<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> +where + Data: GraphDataType, + VP: VertexProvider, +{ async fn distances_unordered( &mut self, vec_id_itr: Itr, @@ -687,6 +693,15 @@ where type Id = u32; } +impl HasElementRef for DiskAccessor<'_, Data, VP> +where + Data: GraphDataType, + VP: VertexProvider, +{ + /// `ElementRef` can have arbitrary lifetimes. + type ElementRef<'a> = &'a [u8]; +} + impl Accessor for DiskAccessor<'_, Data, VP> where Data: GraphDataType, @@ -699,9 +714,6 @@ where where Self: 'a; - /// `ElementRef` can have arbitrary lifetimes. - type ElementRef<'a> = &'a [u8]; - /// Choose to panic on an out-of-bounds access rather than propagate an error. 
type GetError = ANNError; diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index e2522b147..8efd37267 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -18,7 +18,8 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DelegateNeighbor, - Delete, ElementStatus, HasId, NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, + Delete, DistancesUnordered, ElementStatus, HasElementRef, HasId, NeighborAccessor, + NeighborAccessorMut, NoopGuard, SetElement, }, utils::VectorRepr, }; @@ -519,12 +520,15 @@ impl ExpandBeam<&[T]> for FullAccessor<'_, T> { } } +impl HasElementRef for FullAccessor<'_, T> { + type ElementRef<'a> = &'a [T]; +} + impl Accessor for FullAccessor<'_, T> { type Element<'a> = Vec where Self: 'a; - type ElementRef<'a> = &'a [T]; type GetError = GarnetProviderError; fn get_element( @@ -581,6 +585,8 @@ impl BuildQueryComputer<&[T]> for FullAccessor<'_, T> { } } +impl DistancesUnordered<&[T]> for FullAccessor<'_, T> {} + /// An escape hatch for the blanket implementation of [`workingset::Fill`]. 
/// /// Without an `&[T]: Into>`, the blanket implementation for `workingset::Map` diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 0d16248dd..15a2b5ce9 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -8,7 +8,10 @@ use std::sync::{Arc, RwLock}; use diskann::{ error::{ErrorExt, IntoANNResult}, graph::glue::{ExpandBeam, SearchExt}, - provider::{Accessor, AsNeighbor, BuildQueryComputer, DelegateNeighbor, HasId}, + provider::{ + Accessor, AsNeighbor, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, + HasElementRef, HasId, + }, ANNError, ANNErrorKind, }; use diskann_utils::Reborrow; @@ -68,6 +71,13 @@ where type Id = ::Id; } +impl HasElementRef for EncodedDocumentAccessor +where + IA: Accessor, +{ + type ElementRef<'a> = EncodedDocument, &'a RoaringTreemap>; +} + impl Accessor for EncodedDocumentAccessor where IA: Accessor, @@ -76,7 +86,6 @@ where = EncodedDocument, RoaringTreemap> where Self: 'a; - type ElementRef<'a> = EncodedDocument, &'a RoaringTreemap>; type GetError = ANNError; async fn get_element(&mut self, id: Self::Id) -> Result, Self::GetError> { @@ -164,7 +173,7 @@ where impl<'q, IA, Q> BuildQueryComputer<&'q FilteredQuery> for EncodedDocumentAccessor where - IA: BuildQueryComputer<&'q Q>, + IA: BuildQueryComputer<&'q Q> + Accessor, { type QueryComputerError = ANNError; type QueryComputer = InlineBetaComputer; @@ -195,6 +204,14 @@ where { } +impl DistancesUnordered for EncodedDocumentAccessor +where + IA: Accessor, + EncodedDocumentAccessor: BuildQueryComputer, + Q: Clone, +{ +} + impl<'a, IA> DelegateNeighbor<'a> for EncodedDocumentAccessor where IA: DelegateNeighbor<'a>, diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs 
b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 26b07f0f6..10aef0109 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -143,7 +143,7 @@ pub struct FilterResults { impl<'q, Q, IA, IPP> SearchPostProcess, &'q FilteredQuery> for FilterResults where - IA: BuildQueryComputer<&'q Q>, + IA: BuildQueryComputer<&'q Q> + Accessor, Q: Send + Sync, IPP: SearchPostProcess + Send + Sync, { diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index f7289e146..4754fb6ce 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -28,8 +28,8 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultContext, - DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, NeighborAccessorMut, - NoopGuard, SetElement, + DelegateNeighbor, Delete, DistancesUnordered, ElementStatus, HasElementRef, HasId, + NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, }, utils::{IntoUsize, VectorRepr}, }; @@ -974,6 +974,15 @@ where } } +impl HasElementRef for FullAccessor<'_, T, Q, D> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, +{ + type ElementRef<'a> = &'a [T]; +} + impl Accessor for FullAccessor<'_, T, Q, D> where T: VectorRepr, @@ -986,9 +995,6 @@ where where Self: 'a; - /// The reference version of `Element` is the same as `Element`. - type ElementRef<'a> = &'a [T]; - // Choose to panic on an out-of-bounds access rather than propagate an error. 
// type GetError = Panics; @@ -1058,6 +1064,14 @@ where { } +impl DistancesUnordered<&[T]> for FullAccessor<'_, T, Q, D> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, +{ +} + impl<'a, T, Q, D> AsDeletionCheck for FullAccessor<'a, T, Q, D> where T: VectorRepr, @@ -1134,6 +1148,14 @@ where } } +impl HasElementRef for QuantAccessor<'_, T, D> +where + T: VectorRepr, + D: AsyncFriendly, +{ + type ElementRef<'a> = &'a [u8]; +} + impl Accessor for QuantAccessor<'_, T, D> where T: VectorRepr, @@ -1145,9 +1167,6 @@ where where Self: 'a; - /// The reference version of `Element` is simply `Element`. - type ElementRef<'a> = &'a [u8]; - // ANNError on access failures in bf-tree // type GetError = ANNError; @@ -1220,6 +1239,13 @@ where { } +impl DistancesUnordered<&[T]> for QuantAccessor<'_, T, D> +where + T: VectorRepr, + D: AsyncFriendly, +{ +} + impl<'a, T, D> AsDeletionCheck for QuantAccessor<'a, T, D> where T: VectorRepr, @@ -1282,6 +1308,14 @@ where } } +impl HasElementRef for HybridAccessor<'_, T, D> +where + T: VectorRepr, + D: AsyncFriendly, +{ + type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; +} + impl Accessor for HybridAccessor<'_, T, D> where T: VectorRepr, @@ -1296,9 +1330,6 @@ where where Self: 'a; - /// The generalized reference form of `Element`. - type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; - // Choose to panic on an out-of-bounds access rather than propagate an error. 
type GetError = Panics; diff --git a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs index f60ea100e..f7f67c8ef 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs @@ -74,8 +74,8 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, AsNeighbor, BuildDistanceComputer, BuildQueryComputer, CacheableAccessor, - DataProvider, DelegateNeighbor, Delete, ElementStatus, HasId, NeighborAccessor, - NeighborAccessorMut, SetElement, + DataProvider, DelegateNeighbor, Delete, DistancesUnordered, ElementStatus, HasElementRef, + HasId, NeighborAccessor, NeighborAccessorMut, SetElement, }, }; use diskann_utils::{ @@ -739,6 +739,14 @@ where } } +impl HasElementRef for CachingAccessor +where + A: CacheableAccessor, + C: ElementCache, +{ + type ElementRef<'a> = A::ElementRef<'a>; +} + impl Accessor for CachingAccessor where A: CacheableAccessor, @@ -748,7 +756,6 @@ where = A::Element<'a> where Self: 'a; - type ElementRef<'a> = A::ElementRef<'a>; type GetError = CachingError; @@ -825,6 +832,13 @@ where { } +impl DistancesUnordered for CachingAccessor +where + A: BuildQueryComputer + CacheableAccessor + AsNeighbor, + C: ElementCache + NeighborCache, +{ +} + /// Post Process #[derive(Debug, Default, Clone, Copy)] pub struct Unwrap; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 4f60b9510..16b963bc9 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -20,7 +20,7 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - ExecutionContext, HasId, + DistancesUnordered, 
ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -185,6 +185,16 @@ where } } +impl HasElementRef for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type ElementRef<'a> = &'a [T]; +} + impl Accessor for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, @@ -199,9 +209,6 @@ where where Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = &'a [T]; - /// Choose to panic on an out-of-bounds access rather than propagate an error. type GetError = Panics; @@ -316,6 +323,15 @@ where { } +impl DistancesUnordered<&[T]> for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + //-------------------// // In-mem Extensions // //-------------------// @@ -353,7 +369,7 @@ pub struct Rerank; impl<'a, A, T> glue::SearchPostProcess for Rerank where T: VectorRepr, - A: BuildQueryComputer<&'a [T], Id = u32> + GetFullPrecision + AsDeletionCheck, + A: BuildQueryComputer<&'a [T]> + HasId + GetFullPrecision + AsDeletionCheck, { type Error = Panics; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index 4f3f931ba..fa202d09f 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -16,8 +16,8 @@ use diskann::{ workingset, }, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, - HasId, + Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, ExecutionContext, + HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -143,6 +143,15 @@ where } } +impl HasElementRef for QuantAccessor<'_, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type ElementRef<'a> 
= &'a [u8]; +} + impl Accessor for QuantAccessor<'_, V, D, Ctx> where V: AsyncFriendly, @@ -156,9 +165,6 @@ where where Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = &'a [u8]; - /// Choose to panic on an out-of-bounds access rather than propagate an error. type GetError = Panics; @@ -242,6 +248,15 @@ where { } +impl DistancesUnordered<&[T]> for QuantAccessor<'_, V, D, Ctx> +where + T: VectorRepr, + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + //-------------------// // In-mem Extensions // //-------------------// @@ -316,6 +331,15 @@ where } } +impl HasElementRef for HybridAccessor<'_, T, D, Ctx> +where + T: VectorRepr, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; +} + impl Accessor for HybridAccessor<'_, T, D, Ctx> where T: VectorRepr, @@ -331,9 +355,6 @@ where where Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = distances::pq::Hybrid<&'a [T], &'a [u8]>; - /// Choose to panic on an out-of-bounds access rather than propagate an error. 
type GetError = Panics; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index 0e1f5954c..78f60779c 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -16,8 +16,8 @@ use diskann::{ workingset, }, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, - HasId, + Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, ExecutionContext, + HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -420,6 +420,16 @@ where } } +impl HasElementRef for QuantAccessor<'_, NBITS, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, + Unsigned: Representation, +{ + type ElementRef<'a> = CVRef<'a, NBITS>; +} + impl Accessor for QuantAccessor<'_, NBITS, V, D, Ctx> where V: AsyncFriendly, @@ -434,9 +444,6 @@ where where Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = CVRef<'a, NBITS>; - /// Choose to panic on an out-of-bounds access rather than propagate an error. 
type GetError = ANNError; @@ -556,6 +563,17 @@ where { } +impl DistancesUnordered<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx> +where + T: VectorRepr, + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, + Unsigned: Representation, + QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, +{ +} + impl BuildDistanceComputer for QuantAccessor<'_, NBITS, V, D, Ctx> where V: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 5c6f7d85a..7b7d9c2b4 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -18,8 +18,8 @@ use diskann::{ workingset, }, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, ExecutionContext, - HasId, + Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, ExecutionContext, + HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -340,6 +340,15 @@ where } } +impl HasElementRef for QuantAccessor<'_, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type ElementRef<'a> = spherical::iface::Opaque<'a>; +} + impl Accessor for QuantAccessor<'_, V, D, Ctx> where V: AsyncFriendly, @@ -353,9 +362,6 @@ where where Self: 'a; - /// `ElementRef` has an arbitrarily short lifetime. - type ElementRef<'a> = spherical::iface::Opaque<'a>; - /// Choose to panic on an out-of-bounds access rather than propagate an error. 
type GetError = ANNError; @@ -467,6 +473,15 @@ where { } +impl DistancesUnordered<&[T]> for QuantAccessor<'_, V, D, Ctx> +where + T: VectorRepr, + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ +} + #[derive(Debug, Error)] #[error("unconstructible")] pub enum Infallible {} diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs index bfb18c3c1..a7a80ca56 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs @@ -21,7 +21,7 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - HasId, + DistancesUnordered, HasElementRef, HasId, }, utils::IntoUsize, }; @@ -164,6 +164,10 @@ impl HasId for FlakyAccessor<'_> { type Id = u32; } +impl HasElementRef for FlakyAccessor<'_> { + type ElementRef<'a> = &'a [f32]; +} + impl Accessor for FlakyAccessor<'_> { /// This accessor returns raw slices. There *is* a chance of racing when the fast /// providers are used. We just have to live with it. @@ -172,8 +176,6 @@ impl Accessor for FlakyAccessor<'_> { where Self: 'a; - type ElementRef<'a> = &'a [f32]; - /// Choose to panic on an out-of-bounds access rather than propagate an error. 
type GetError = TestError; @@ -231,6 +233,8 @@ impl<'a, 'b> BuildQueryComputer<&'a [f32]> for FlakyAccessor<'b> { impl ExpandBeam<&[f32]> for FlakyAccessor<'_> {} +impl DistancesUnordered<&[f32]> for FlakyAccessor<'_> {} + impl<'a> DelegateNeighbor<'a> for FlakyAccessor<'_> { type Delegate = &'a SimpleNeighborProviderAsync; fn delegate_neighbor(&'a mut self) -> Self::Delegate { diff --git a/diskann-providers/src/model/graph/provider/async_/postprocess.rs b/diskann-providers/src/model/graph/provider/async_/postprocess.rs index 3e1849bb0..dbbf08fa4 100644 --- a/diskann-providers/src/model/graph/provider/async_/postprocess.rs +++ b/diskann-providers/src/model/graph/provider/async_/postprocess.rs @@ -8,7 +8,7 @@ use diskann::{ graph::{SearchOutputBuffer, glue}, neighbor::Neighbor, - provider::BuildQueryComputer, + provider::{BuildQueryComputer, HasId}, }; /// A bridge allowing `Accessors` to opt-in to [`RemoveDeletedIdsAndCopy`] by delegating to @@ -39,7 +39,7 @@ pub struct RemoveDeletedIdsAndCopy; impl glue::SearchPostProcess for RemoveDeletedIdsAndCopy where - A: BuildQueryComputer + AsDeletionCheck, + A: BuildQueryComputer + HasId + AsDeletionCheck, { type Error = std::convert::Infallible; diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index f0ffc0451..9a38379cf 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -24,7 +24,10 @@ use diskann::{ index::QueryLabelProvider, }, neighbor::Neighbor, - provider::{Accessor, AsNeighbor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, + provider::{ + Accessor, AsNeighbor, BuildQueryComputer, DataProvider, DelegateNeighbor, + DistancesUnordered, HasElementRef, HasId, + }, utils::VectorId, }; use diskann_utils::Reborrow; @@ -70,7 +73,7 @@ pub struct Unwrap; /// Delegate post-processing to the inner strategy's post-processing 
routine. impl SearchPostProcessStep, T, O> for Unwrap where - A: BuildQueryComputer, + A: BuildQueryComputer + Accessor, { type Error = NextError @@ -227,6 +230,13 @@ where type Id = Inner::Id; } +impl HasElementRef for BetaAccessor +where + Inner: Accessor, +{ + type ElementRef<'a> = Pair>; +} + impl Accessor for BetaAccessor where Inner: Accessor, @@ -236,7 +246,6 @@ where = Pair> where Self: 'a; - type ElementRef<'a> = Pair>; /// Use the same error type as `Inner`. type GetError = Inner::GetError; @@ -279,10 +288,10 @@ where impl BuildQueryComputer for BetaAccessor where - Inner: BuildQueryComputer, + Inner: BuildQueryComputer + Accessor, { /// Use a [`BetaComputer`] to apply filtering. - type QueryComputer = BetaComputer; + type QueryComputer = BetaComputer; /// Use the same error as `Inner`. type QueryComputerError = Inner::QueryComputerError; @@ -296,7 +305,12 @@ where } } -impl ExpandBeam for BetaAccessor where Inner: BuildQueryComputer + AsNeighbor {} +impl ExpandBeam for BetaAccessor where + Inner: BuildQueryComputer + AsNeighbor + Accessor +{ +} + +impl DistancesUnordered for BetaAccessor where Inner: BuildQueryComputer + Accessor {} /// A [`PreprocessedDistanceFunction`] that applied `beta` filtering to the inner computer. 
pub struct BetaComputer { @@ -432,12 +446,15 @@ mod tests { always_escalate!(NotAllowed); + impl HasElementRef for Doubler { + type ElementRef<'a> = u64; + } + impl Accessor for Doubler { type Element<'a> = u64 where Self: 'a; - type ElementRef<'a> = u64; type GetError = NotAllowed; @@ -498,6 +515,8 @@ mod tests { impl ExpandBeam for Doubler {} + impl DistancesUnordered for Doubler {} + #[derive(Debug)] struct SimpleStrategy; diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index 7defe0c80..459b02216 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -11,7 +11,7 @@ use diskann_utils::{Reborrow, future::SendFuture}; use diskann_vector::PreprocessedDistanceFunction; -use crate::{error::StandardError, provider::HasId}; +use crate::{error::StandardError, provider::{HasElementRef, HasId}}; /// Callback-driven sequential scan over the elements of a flat index. /// @@ -20,18 +20,14 @@ use crate::{error::StandardError, provider::HasId}; /// walk that invokes a caller-supplied closure for every element. /// /// Algorithms see only `(Id, ElementRef)` pairs and treat the stream as opaque. -pub trait OnElementsUnordered: HasId + Send + Sync { - /// A reference to a yielded element with an unconstrained lifetime, suitable for - /// distance-function HRTB bounds. - type ElementRef<'a>; - +pub trait OnElementsUnordered: HasId + HasElementRef + Send + Sync { /// The error type yielded by [`Self::on_elements_unordered`]. type Error: StandardError; /// Drive the entire scan, invoking `f` for each yielded element. 
fn on_elements_unordered(&mut self, f: F) -> impl SendFuture> where - F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>); + F: Send + for<'a> FnMut(Self::Id, ::ElementRef<'a>); } /// Extension of [`OnElementsUnordered`] that drives the scan with a pre-built query @@ -45,7 +41,7 @@ pub trait OnElementsUnordered: HasId + Send + Sync { pub trait DistancesUnordered: OnElementsUnordered { /// The concrete type of the distance computer for a query, which should be applicable for /// all elements in the underlying driver. - type QueryComputer: for<'a> PreprocessedDistanceFunction, f32> + type QueryComputer: for<'a> PreprocessedDistanceFunction<::ElementRef<'a>, f32> + Send + Sync + 'static; @@ -76,13 +72,9 @@ pub trait DistancesUnordered: OnElementsUnordered { /// Implementations provide element-at-a-time access via [`Self::next`]. Providers that /// only implement `FlatIterator` can be wrapped in [`DefaultIteratedOperator`] to obtain /// an [`OnElementsUnordered`] implementation automatically. -pub trait FlatIterator: HasId + Send + Sync { - /// A reference to a yielded element with an unconstrained lifetime, suitable for - /// distance-function HRTB bounds. - type ElementRef<'a>; - +pub trait FlatIterator: HasId + HasElementRef + Send + Sync { /// The concrete element returned by [`Self::next`]. Reborrows to [`Self::ElementRef`]. 
- type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync + type Element<'a>: for<'b> Reborrow<'b, Target = ::ElementRef<'b>> + Send + Sync where Self: 'a; @@ -129,11 +121,14 @@ impl HasId for Iterated { type Id = I::Id; } +impl HasElementRef for Iterated { + type ElementRef<'a> = I::ElementRef<'a>; +} + impl OnElementsUnordered for Iterated where I: FlatIterator + HasId + Send + Sync, { - type ElementRef<'a> = I::ElementRef<'a>; type Error = I::Error; fn on_elements_unordered(&mut self, mut f: F) -> impl SendFuture> diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs index 4d7e474ad..ee3fce96c 100644 --- a/diskann/src/flat/strategy.rs +++ b/diskann/src/flat/strategy.rs @@ -9,8 +9,8 @@ use diskann_vector::PreprocessedDistanceFunction; use crate::{ error::StandardError, - flat::{DistancesUnordered, OnElementsUnordered}, - provider::DataProvider, + flat::{DistancesUnordered}, + provider::{DataProvider, HasElementRef}, }; /// Per-call configuration that knows how to construct a [`DistancesUnordered`] for a provider @@ -51,7 +51,7 @@ where /// by every lifetime of `Visitor`. Two lifetimes are needed: `'a` for the visitor /// instance and `'b` for the reborrowed element. 
type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< - as OnElementsUnordered>::ElementRef<'b>, + as HasElementRef>::ElementRef<'b>, f32, > + Send + Sync diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index f796a5054..8fcbcc420 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -89,7 +89,7 @@ use crate::{ neighbor::Neighbor, provider::{ Accessor, AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, - DataProvider, HasId, NeighborAccessor, + DataProvider, DistancesUnordered, HasElementRef, HasId, NeighborAccessor, }, utils::VectorId, }; @@ -242,7 +242,7 @@ impl HybridPredicate for NotInMut<'_, T> where T: Clone + Eq + std::hash:: /// /// The provided implementation works on each element of `ids` sequentially, pre-filters /// the resulting candidate list using `pred.eval()` before invoking -/// [`BuildQueryComputer::distances_unordered`]. +/// [`DistancesUnordered::distances_unordered`]. /// /// The callback `on_neighbors` is decorated to the uses `pred.eval_mut()`. /// @@ -252,7 +252,7 @@ impl HybridPredicate for NotInMut<'_, T> where T: Clone + Eq + std::hash:: /// ## Error Handling /// /// Transient errors yielded by `distances_unordered` are acknowledged and not escalated. -pub trait ExpandBeam: BuildQueryComputer + AsNeighbor + Sized { +pub trait ExpandBeam: DistancesUnordered + AsNeighbor + Sized { fn expand_beam( &mut self, ids: Itr, @@ -299,7 +299,7 @@ where /// We could grab this type from the `SearchAccessor` associated type, but it's /// useful enough that we move it up here. type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< - as Accessor>::ElementRef<'b>, + as HasElementRef>::ElementRef<'b>, f32, > + Send + Sync @@ -386,7 +386,7 @@ macro_rules! default_post_processor { /// directly into the output buffer. 
pub trait SearchPostProcess::Id> where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, { type Error: StandardError; @@ -412,7 +412,7 @@ pub struct CopyIds; impl SearchPostProcess for CopyIds where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, { type Error = std::convert::Infallible; fn post_process( @@ -437,7 +437,7 @@ where /// using a [`Pipeline`]. pub trait SearchPostProcessStep::Id> where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, { /// A potentially modified version of the error yielded by the next state in the /// processing pipeline. @@ -446,7 +446,7 @@ where NextError: StandardError; /// The accessor that will be passed to the next processing stage. - type NextAccessor: BuildQueryComputer; + type NextAccessor: BuildQueryComputer + HasId; /// Perform any modification the `input`, `output`, `accessor`, or `computer` objects /// and invoke the [`SearchPostProcess`] routine `next` on stage. @@ -471,7 +471,7 @@ pub struct FilterStartPoints; impl SearchPostProcessStep for FilterStartPoints where - A: BuildQueryComputer + SearchExt, + A: BuildQueryComputer + SearchExt + HasId, T: Copy + Send + Sync, { /// A this level, sub-errors are converted into [`ANNError`] to provide additional @@ -540,7 +540,7 @@ impl Pipeline { impl SearchPostProcess for Pipeline where - A: BuildQueryComputer, + A: BuildQueryComputer + HasId, Head: SearchPostProcessStep, Tail: SearchPostProcess + Sync, { @@ -614,8 +614,8 @@ where /// We could grab this type from the `PruneAccessor` associated type, but it's /// useful enough that we move it up here. 
type DistanceComputer: for<'a, 'b, 'c, 'd> DistanceFunction< - as Accessor>::ElementRef<'b>, - as Accessor>::ElementRef<'d>, + as HasElementRef>::ElementRef<'b>, + as HasElementRef>::ElementRef<'d>, f32, > + Send + Sync @@ -855,7 +855,7 @@ mod tests { use super::*; use crate::{ ANNResult, neighbor, - provider::{DelegateNeighbor, ExecutionContext, HasId, NeighborAccessor}, + provider::{DelegateNeighbor, ExecutionContext, HasElementRef, HasId, NeighborAccessor}, }; // A really simple provider that just holds floats and uses the absolute value for its @@ -928,12 +928,15 @@ mod tests { type Id = u32; } + impl HasElementRef for Retriever<'_> { + type ElementRef<'a> = f32; + } + impl Accessor for Retriever<'_> { type Element<'a> = f32 where Self: 'a; - type ElementRef<'a> = f32; type GetError = ANNError; fn get_element( @@ -980,6 +983,8 @@ mod tests { impl ExpandBeam for Retriever<'_> {} + impl DistancesUnordered for Retriever<'_> {} + // This strategy explicitly does not define `post_process` so we can test the provided // implementation. 
struct Strategy; diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 62a4a45be..3025cd742 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -1032,12 +1032,15 @@ impl provider::HasId for Accessor<'_> { type Id = u32; } +impl provider::HasElementRef for Accessor<'_> { + type ElementRef<'a> = &'a [f32]; +} + impl provider::Accessor for Accessor<'_> { type Element<'a> = &'a [f32] where Self: 'a; - type ElementRef<'a> = &'a [f32]; type GetError = AccessError; async fn get_element(&mut self, id: u32) -> Result<&[f32], AccessError> { @@ -1077,6 +1080,8 @@ impl provider::BuildQueryComputer<&[f32]> for Accessor<'_> { } } +impl provider::DistancesUnordered<&[f32]> for Accessor<'_> {} + impl provider::BuildDistanceComputer for Accessor<'_> { type DistanceComputerError = Infallible; type DistanceComputer = ::Distance; diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index b4a652e6b..e409ef1b0 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -71,10 +71,14 @@ //! * [`BuildDistanceComputer`]: A sub-trait of [`Accessor`] that allows for random-access //! distance computations on the retrieved elements. //! -//! * [`BuildQueryComputer`]: A sub-trait of [`Accessor`] that allows for specialized query -//! based computations. This allows a query to be pre-processed in a way that allows +//! * [`BuildQueryComputer`]: A sub-trait of [`HasElementRef`] that allows for specialized +//! query based computations. This allows a query to be pre-processed in a way that allows //! faster computations. //! +//! * [`DistancesUnordered`]: A sub-trait of [`Accessor`] and [`BuildQueryComputer`] that +//! provides a fused iterate-and-compute primitive over a set of element ids using a +//! pre-built query computer. +//! //! # Neighbor Delegation //! //! 
Index search requires that accessor types implement both the data-centric [`Accessor`] @@ -381,6 +385,17 @@ where } } + +/////////////////// +// HasElementRef // +/////////////////// + +/// A catch-all super trait for traits that have an associated `ElementRef<'_>` +/// type. Traits like [`Accessor`] are subtraits. +pub trait HasElementRef { + type ElementRef<'a>; +} + ////////////// // Accessor // ////////////// @@ -419,18 +434,15 @@ where /// The need for `ElementRef` arises to allow HRTB bounds to distance computers without /// inducing a `'static` bound on `Self`. In traits like [`BuildQueryComputer`], attempting /// to use `Element` directly will result in such a requirement on the implementing Accessor. -pub trait Accessor: HasId + Send + Sync { - /// A generalized reference type used for distance computations. - /// - /// Note that the lifetime of `ElementRef` is unconstrained and thus using it in a - /// [HRTB](https://doc.rust-lang.org/nomicon/hrtb.html) will not induce a `'static` - /// requirement on `Self`. - type ElementRef<'a>; - +pub trait Accessor: HasId + HasElementRef + Send + Sync { /// The concrete type of the data element associated with this accessor. /// /// For distance computations, this should be cheaply convertible via [`Reborrow`] to /// `Self::ElementRef`. + /// + /// Note that the lifetime of `ElementRef` is unconstrained and thus using it in a + /// [HRTB](https://doc.rust-lang.org/nomicon/hrtb.html) will not induce a `'static` + /// requirement on `Self`. type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync where Self: 'a; @@ -523,11 +535,18 @@ pub trait BuildDistanceComputer: Accessor { ) -> Result; } -/// A specialized [`Accessor`] that provides query computations for a query type `T`. +/// A trait that provides query computations for a query type `T`. /// /// Query computers are allowed to preprocess the query to enable more efficient distance /// computations. 
-pub trait BuildQueryComputer: Accessor { +/// +/// This trait only requires [`HasElementRef`] (so the query computer's element type can be +/// named) so that it can be used with multiple access patterns - like [`Accessor`] and +/// [`crate::flat::FlastSearchStrategy`]. +/// +/// A fused iterate-and-compute primitive can be created as a sub-trait - +/// e.g. [`DistancesUnordered`], which requires both [`Accessor`] and `BuildQueryComputer`. +pub trait BuildQueryComputer: HasElementRef { /// The error type (if any) associated with distance computer construction. type QueryComputerError: std::error::Error + Into + Send + Sync + 'static; @@ -546,11 +565,16 @@ pub trait BuildQueryComputer: Accessor { &self, from: T, ) -> Result; +} - /// Compute the distances for the elements in the iterator `itr` using the - /// `computer` and apply the closure `f` to each distance and ID. The default - /// implementation uses on_elements_unordered to iterate over the elements - /// and compute the distances using `computer` parameter. +/// A sub-trait of [`Accessor`] and [`BuildQueryComputer`] that exposes the fused +/// iterate-and-compute primitive `distances_unordered`. +/// +/// The default implementation uses [`Accessor::on_elements_unordered`] to iterate over the +/// elements and computes their distances using the provided `computer`. +pub trait DistancesUnordered: Accessor + BuildQueryComputer { + /// Compute the distances for the elements in the iterator `vec_id_itr` using the + /// `computer` and apply the closure `f` to each distance and ID. 
fn distances_unordered( &mut self, vec_id_itr: Itr, @@ -1073,12 +1097,14 @@ mod tests { impl HasId for FloatAccessor<'_> { type Id = u32; } + impl HasElementRef for FloatAccessor<'_> { + type ElementRef<'a> = f32; + } impl Accessor for FloatAccessor<'_> { type Element<'a> = f32 where Self: 'a; - type ElementRef<'a> = f32; type GetError = Missing; @@ -1139,12 +1165,14 @@ mod tests { impl HasId for StringAccessor<'_> { type Id = u32; } + impl HasElementRef for StringAccessor<'_> { + type ElementRef<'a> = &'a str; + } impl Accessor for StringAccessor<'_> { type Element<'a> = &'a str where Self: 'a; - type ElementRef<'a> = &'a str; type GetError = Missing; @@ -1298,12 +1326,15 @@ mod tests { common_test_accessor!(Allocating<'_>); + impl HasElementRef for Allocating<'_> { + type ElementRef<'a> = &'a [u8]; + } + impl Accessor for Allocating<'_> { type Element<'a> = Box<[u8]> where Self: 'a; - type ElementRef<'a> = &'a [u8]; type GetError = Infallible; async fn get_element(&mut self, _: u32) -> Result, Infallible> { @@ -1325,6 +1356,10 @@ mod tests { common_test_accessor!(Forwarding<'_>); + impl HasElementRef for Forwarding<'_> { + type ElementRef<'a> = &'a [u8]; + } + impl<'provider> Accessor for Forwarding<'provider> { // NOTE: The lifetime of `Element` is `'provider` - not `'a`. This is what makes // it a forwarding accessor. 
@@ -1332,7 +1367,6 @@ mod tests { = &'provider [u8] where Self: 'a; - type ElementRef<'a> = &'a [u8]; type GetError = Infallible; async fn get_element(&mut self, _: u32) -> Result<&'provider [u8], Infallible> { @@ -1369,12 +1403,15 @@ mod tests { common_test_accessor!(Wrapping<'_>); + impl HasElementRef for Wrapping<'_> { + type ElementRef<'a> = &'a [u8]; + } + impl Accessor for Wrapping<'_> { type Element<'a> = Wrapped<'a> where Self: 'a; - type ElementRef<'a> = &'a [u8]; type GetError = Infallible; async fn get_element(&mut self, _: u32) -> Result, Infallible> { @@ -1400,12 +1437,15 @@ mod tests { common_test_accessor!(Sharing<'_>); + impl HasElementRef for Sharing<'_> { + type ElementRef<'a> = &'a [u8]; + } + impl Accessor for Sharing<'_> { type Element<'a> = &'a [u8] where Self: 'a; - type ElementRef<'a> = &'a [u8]; type GetError = Infallible; async fn get_element(&mut self, _: u32) -> Result<&[u8], Infallible> { From 0a63759f73a5aba721e9d0947c871b57d22338f1 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Tue, 5 May 2026 17:40:08 -0400 Subject: [PATCH 13/24] delete flatpostprocess, cleanup and docs --- .../graph/provider/async_/inmem/product.rs | 4 +- .../graph/provider/async_/inmem/scalar.rs | 7 +- .../graph/provider/async_/inmem/spherical.rs | 4 +- .../model/graph/provider/layers/betafilter.rs | 5 +- diskann/src/flat/index.rs | 23 +++--- diskann/src/flat/iterator.rs | 45 +++++++----- diskann/src/flat/mod.rs | 20 +++--- diskann/src/flat/post_process.rs | 72 ------------------- diskann/src/flat/strategy.rs | 61 +++++----------- diskann/src/provider.rs | 17 +++-- rfcs/00983-flat-search.md | 15 ++-- 11 files changed, 93 insertions(+), 180 deletions(-) delete mode 100644 diskann/src/flat/post_process.rs diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index fa202d09f..d13505b8e 100644 --- 
a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -16,8 +16,8 @@ use diskann::{ workingset, }, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, ExecutionContext, - HasElementRef, HasId, + Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, + ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index 78f60779c..ec278a552 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -16,8 +16,8 @@ use diskann::{ workingset, }, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, ExecutionContext, - HasElementRef, HasId, + Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, + ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -563,7 +563,8 @@ where { } -impl DistancesUnordered<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx> +impl DistancesUnordered<&[T]> + for QuantAccessor<'_, NBITS, V, D, Ctx> where T: VectorRepr, V: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 7b7d9c2b4..f8045dbbe 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -18,8 +18,8 @@ use diskann::{ workingset, }, provider::{ - Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, ExecutionContext, - HasElementRef, HasId, + Accessor, BuildDistanceComputer, 
BuildQueryComputer, DelegateNeighbor, DistancesUnordered, + ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index 9a38379cf..ff822dd9d 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -310,7 +310,10 @@ impl ExpandBeam for BetaAccessor where { } -impl DistancesUnordered for BetaAccessor where Inner: BuildQueryComputer + Accessor {} +impl DistancesUnordered for BetaAccessor where + Inner: BuildQueryComputer + Accessor +{ +} /// A [`PreprocessedDistanceFunction`] that applied `beta` filtering to the inner computer. pub struct BetaComputer { diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs index c5ee18824..332ba8edb 100644 --- a/diskann/src/flat/index.rs +++ b/diskann/src/flat/index.rs @@ -11,17 +11,17 @@ use diskann_utils::future::SendFuture; use crate::{ ANNResult, error::IntoANNResult, - flat::{DistancesUnordered, FlatPostProcess, FlatSearchStrategy}, - graph::{SearchOutputBuffer, index::SearchStats}, + flat::{DistancesUnordered, SearchStrategy}, + graph::{SearchOutputBuffer, glue::SearchPostProcess, index::SearchStats}, neighbor::{Neighbor, NeighborPriorityQueue}, - provider::DataProvider, + provider::{BuildQueryComputer, DataProvider}, }; /// A `'static` thin wrapper around a [`DataProvider`] used for flat search. /// /// The provider is owned by the index. The index is constructed once at process startup and /// shared across requests; per-query state lives in the [`crate::flat::OnElementsUnordered`] -/// implementation that the [`crate::flat::FlatSearchStrategy`] produces. +/// implementation that the [`SearchStrategy`] produces. #[derive(Debug)] pub struct FlatIndex { /// The backing provider. @@ -47,7 +47,7 @@ impl FlatIndex

{ /// /// # Arguments /// - `k`: number of nearest neighbors to return. - /// - `strategy`: produces the per-query iterator and the query computer. See [`FlatSearchStrategy`] + /// - `strategy`: produces the per-query iterator and the query computer. See [`SearchStrategy`]. /// - `processor`: post-processes the survivor candidates into the output type. /// - `context`: per-request context threaded through to the provider. /// - `query`: the query. @@ -58,22 +58,23 @@ impl FlatIndex

{ strategy: &S, processor: &PP, context: &P::Context, - query: &T, + query: T, output: &mut OB, ) -> impl SendFuture> where - S: FlatSearchStrategy, - T: ?Sized + Sync, + S: SearchStrategy, + T: Copy + Send + Sync, O: Send, OB: SearchOutputBuffer + Send + ?Sized, - PP: for<'a> FlatPostProcess, T, O> + Send + Sync, + PP: for<'a> SearchPostProcess, T, O> + Send + Sync, { async move { let mut visitor = strategy .create_visitor(&self.provider, context) .into_ann_result()?; - let computer = strategy.build_query_computer(query).into_ann_result()?; + let computer = + BuildQueryComputer::build_query_computer(&visitor, query).into_ann_result()?; let k = k.get(); let mut queue = NeighborPriorityQueue::new(k); @@ -88,7 +89,7 @@ impl FlatIndex

{ .into_ann_result()?; let result_count = processor - .post_process(&mut visitor, query, queue.iter().take(k), output) + .post_process(&mut visitor, query, &computer, queue.iter().take(k), output) .await .into_ann_result()? as u32; diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index 459b02216..40061f9d9 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -4,14 +4,19 @@ */ //! [`OnElementsUnordered`] — the sequential access primitive for accessing a flat index. +//! [`DistancesUnordered`] — sub-trait of [`OnElementsUnordered`] and [`BuildQueryComputer`] +//! that computes distances over the elements in the flat index. //! //! [`FlatIterator`] — a lending async iterator that can be bridged into -//! [`OnElementsUnordered`] via [`DefaultIteratedOperator`]. +//! [`OnElementsUnordered`] via [`Iterated`]. use diskann_utils::{Reborrow, future::SendFuture}; use diskann_vector::PreprocessedDistanceFunction; -use crate::{error::StandardError, provider::{HasElementRef, HasId}}; +use crate::{ + error::StandardError, + provider::{BuildQueryComputer, HasElementRef, HasId}, +}; /// Callback-driven sequential scan over the elements of a flat index. /// @@ -30,31 +35,31 @@ pub trait OnElementsUnordered: HasId + HasElementRef + Send + Sync { F: Send + for<'a> FnMut(Self::Id, ::ElementRef<'a>); } -/// Extension of [`OnElementsUnordered`] that drives the scan with a pre-built query -/// computer, invoking a callback with `(id, distance)` pairs instead of raw elements. +/// Extension of [`OnElementsUnordered`] that drives the scan with a query computer +/// produced by the visitor's [`BuildQueryComputer`] impl, invoking a callback with +/// `(id, distance)` pairs. /// -/// The concrete computer is insantiated and supplied externally -/// by the [`FlatSearchStrategy`](crate::flat::FlatSearchStrategy). +/// This fuses the scan with a pre-processed query computer and runs over a +/// streaming visitor. 
It pulls the computer type from the implementor's own +/// [`BuildQueryComputer`] impl. /// /// The default implementation delegates to [`OnElementsUnordered::on_elements_unordered`], /// calling `computer.evaluate_similarity` on each element. -pub trait DistancesUnordered: OnElementsUnordered { - /// The concrete type of the distance computer for a query, which should be applicable for - /// all elements in the underlying driver. - type QueryComputer: for<'a> PreprocessedDistanceFunction<::ElementRef<'a>, f32> - + Send - + Sync - + 'static; - +/// +/// ## Note +/// +/// This is the flat analog to [`crate::provider::DistancesUnordered`] which runs over +/// a random-access [`crate::provider::Accessor`]. +pub trait DistancesUnordered: OnElementsUnordered + BuildQueryComputer { /// Drive the entire scan, scoring each element with `computer` and invoking `f` with /// the resulting `(id, distance)` pair. fn distances_unordered( &mut self, - computer: &Self::QueryComputer, + computer: &>::QueryComputer, mut f: F, - ) -> impl SendFuture> + ) -> impl SendFuture::Error>> where - F: Send + FnMut(Self::Id, f32), + F: Send + FnMut(::Id, f32), { self.on_elements_unordered(move |id, element| { let dist = computer.evaluate_similarity(element); @@ -74,7 +79,9 @@ pub trait DistancesUnordered: OnElementsUnordered { /// an [`OnElementsUnordered`] implementation automatically. pub trait FlatIterator: HasId + HasElementRef + Send + Sync { /// The concrete element returned by [`Self::next`]. Reborrows to [`Self::ElementRef`]. 
- type Element<'a>: for<'b> Reborrow<'b, Target = ::ElementRef<'b>> + Send + Sync + type Element<'a>: for<'b> Reborrow<'b, Target = ::ElementRef<'b>> + + Send + + Sync where Self: 'a; @@ -123,7 +130,7 @@ impl HasId for Iterated { impl HasElementRef for Iterated { type ElementRef<'a> = I::ElementRef<'a>; -} +} impl OnElementsUnordered for Iterated where diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs index d3407310a..b939056a0 100644 --- a/diskann/src/flat/mod.rs +++ b/diskann/src/flat/mod.rs @@ -14,14 +14,14 @@ //! //! The module mirrors the layering used by graph search: //! -//! | Graph (random access) | Flat (sequential) | -//! | :------------------------------------ | :-------------------------------- | -//! | [`crate::provider::DataProvider`] | [`crate::provider::DataProvider`] | -//! | [`crate::graph::DiskANNIndex`] | [`FlatIndex`] | -//! | [`crate::provider::Accessor`] | [`FlatIterator`] | -//! | [`crate::graph::glue::SearchStrategy`] | [`FlatSearchStrategy`] | -//! | [`crate::graph::glue::SearchPostProcess`] | [`FlatPostProcess`] | -//! | [`crate::graph::Search`] | [`FlatIndex::knn_search`] | +//! | Graph (random access) | Flat (sequential) | Shared? | +//! | :------------------------------------ | :-------------------------------- |:--------- | +//! | [`crate::provider::DataProvider`] | [`crate::provider::DataProvider`] | Yes | +//! | [`crate::graph::DiskANNIndex`] | [`FlatIndex`] | No | +//! | [`crate::provider::Accessor`] | [`FlatIterator`] | No | +//! | [`crate::graph::glue::SearchStrategy`] | [`SearchStrategy`] | No | +//! | [`crate::graph::glue::SearchPostProcess`] | [`crate::graph::glue::SearchPostProcess`] | Yes | +//! | [`crate::graph::Search`] | [`FlatIndex::knn_search`] | No | //! //! # Hot loop //! 
@@ -34,10 +34,8 @@ pub mod index; pub mod iterator; -pub mod post_process; pub mod strategy; pub use index::FlatIndex; pub use iterator::{DistancesUnordered, FlatIterator, Iterated, OnElementsUnordered}; -pub use post_process::{CopyIds, FlatPostProcess}; -pub use strategy::FlatSearchStrategy; +pub use strategy::SearchStrategy; diff --git a/diskann/src/flat/post_process.rs b/diskann/src/flat/post_process.rs deleted file mode 100644 index 2cab763dd..000000000 --- a/diskann/src/flat/post_process.rs +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -//! [`FlatPostProcess`] — terminal stage of the flat search pipeline. - -use diskann_utils::future::SendFuture; - -use crate::{ - error::StandardError, flat::OnElementsUnordered, graph::SearchOutputBuffer, neighbor::Neighbor, - provider::HasId, -}; - -/// Post-process the survivor candidates produced by a flat search and -/// write them into an output buffer. -/// -/// This is the flat counterpart to [`crate::graph::glue::SearchPostProcess`]. Processors -/// receive `&mut S` so they can consult any iterator-owned lookup state (e.g., an -/// `Id -> rich-record` table built up during the scan) when assembling outputs. -/// -/// The `O` type parameter lets callers pick the output element type (raw `(Id, f32)` -/// pairs, fully hydrated hits etc.). -pub trait FlatPostProcess::Id> -where - S: OnElementsUnordered, - T: ?Sized, -{ - /// Errors yielded by [`Self::post_process`]. - type Error: StandardError; - - /// Consume `candidates` (in distance order) and write at most `k` results into - /// `output`. Returns the number of results written. - fn post_process( - &self, - iter: &mut S, - query: &T, - candidates: I, - output: &mut B, - ) -> impl SendFuture> - where - I: Iterator> + Send, - B: SearchOutputBuffer + Send + ?Sized; -} - -/// A trivial [`FlatPostProcess`] that copies each `(Id, distance)` pair straight into the -/// output buffer. 
-#[derive(Debug, Default, Clone, Copy)] -pub struct CopyIds; - -impl FlatPostProcess for CopyIds -where - S: OnElementsUnordered, - T: ?Sized, -{ - type Error = crate::error::Infallible; - - fn post_process( - &self, - _iter: &mut S, - _query: &T, - candidates: I, - output: &mut B, - ) -> impl SendFuture> - where - I: Iterator::Id>> + Send, - B: SearchOutputBuffer<::Id> + Send + ?Sized, - { - let count = output.extend(candidates.map(|n| (n.id, n.distance))); - std::future::ready(Ok(count)) - } -} diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs index ee3fce96c..f8ec19e78 100644 --- a/diskann/src/flat/strategy.rs +++ b/diskann/src/flat/strategy.rs @@ -3,61 +3,40 @@ * Licensed under the MIT license. */ -//! [`FlatSearchStrategy`] — glue between [`DataProvider`] and per-query [`FlatIterator`]s. +//! [`SearchStrategy`] — glue between [`DataProvider`] and per-query [`crate::flat::FlatIterator`]s. -use diskann_vector::PreprocessedDistanceFunction; +use crate::{error::StandardError, flat::DistancesUnordered, provider::DataProvider}; -use crate::{ - error::StandardError, - flat::{DistancesUnordered}, - provider::{DataProvider, HasElementRef}, -}; - -/// Per-call configuration that knows how to construct a [`DistancesUnordered`] for a provider -/// and how to pre-process queries of type `T` into a distance computer. -/// -/// `FlatSearchStrategy` is the flat counterpart to [`crate::graph::glue::SearchStrategy`]. -/// A strategy instance is stateless config — typically constructed at the call site, used -/// for one search, and dropped. +/// Per-call configuration that knows how to construct a per-query +/// [`DistancesUnordered`] visitor for a provider. /// -/// # Why two methods? +/// `SearchStrategy` is the flat counterpart to [`crate::graph::glue::SearchStrategy`] +/// (disambiguated by module path). A strategy instance is stateless config — typically +/// constructed at the call site, used for one search, and dropped. 
/// -/// - [`Self::create_visitor`] is query-independent and may be called multiple times per -/// request (e.g., once per parallel query in a batched search). -/// - [`Self::build_query_computer`] is visitor-independent — the same query can be -/// pre-processed once and used against multiple visitors. -/// -/// Both methods may borrow from the strategy itself. +/// The strategy itself is a pure factory; the visitor it produces carries the +/// query-preprocessing capability via [`crate::provider::BuildQueryComputer`] (a +/// super-trait of [`DistancesUnordered`]). /// /// # Type parameters /// -/// - `Provider`: the [`DataProvider`] that backs the index. -/// - `T`: the query type. Often `[E]` for vector queries; can be any `?Sized` type. -pub trait FlatSearchStrategy: Send + Sync +/// - `P`: the [`DataProvider`] that backs the index. +/// - `T`: the query type that the query computer is constructed using. +pub trait SearchStrategy: Send + Sync where P: DataProvider, - T: ?Sized, { /// The visitor type produced by [`Self::create_visitor`]. Borrows from `self` and the /// provider. - type Visitor<'a>: DistancesUnordered + /// + /// The visitor implements both the streaming [`DistancesUnordered`] primitive and + /// the query preprocessor [`crate::provider::BuildQueryComputer`]. + type Visitor<'a>: DistancesUnordered where Self: 'a, P: 'a; - /// The query computer produced by [`Self::build_query_computer`]. - /// - /// The HRTB on `ElementRef` ensures the same computer can score every element yielded - /// by every lifetime of `Visitor`. Two lifetimes are needed: `'a` for the visitor - /// instance and `'b` for the reborrowed element. - type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< - as HasElementRef>::ElementRef<'b>, - f32, - > + Send - + Sync - + 'static; - - /// The error type for both factory methods. + /// The error type for [`Self::create_visitor`]. 
type Error: StandardError; /// Construct a fresh visitor over `provider` for the given request `context`. @@ -70,8 +49,4 @@ where provider: &'a P, context: &'a P::Context, ) -> Result, Self::Error>; - - /// Pre-process a query into a [`Self::QueryComputer`] usable for distance computation - /// against any visitor produced by [`Self::create_visitor`]. - fn build_query_computer(&self, query: &T) -> Result; } diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index e409ef1b0..505e08337 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -385,15 +385,14 @@ where } } - /////////////////// // HasElementRef // -/////////////////// +/////////////////// -/// A catch-all super trait for traits that have an associated `ElementRef<'_>` +/// A catch-all super trait for traits that have an associated `ElementRef<'_>` /// type. Traits like [`Accessor`] are subtraits. pub trait HasElementRef { - type ElementRef<'a>; + type ElementRef<'a>; } ////////////// @@ -439,7 +438,7 @@ pub trait Accessor: HasId + HasElementRef + Send + Sync { /// /// For distance computations, this should be cheaply convertible via [`Reborrow`] to /// `Self::ElementRef`. - /// + /// /// Note that the lifetime of `ElementRef` is unconstrained and thus using it in a /// [HRTB](https://doc.rust-lang.org/nomicon/hrtb.html) will not induce a `'static` /// requirement on `Self`. @@ -541,10 +540,10 @@ pub trait BuildDistanceComputer: Accessor { /// computations. /// /// This trait only requires [`HasElementRef`] (so the query computer's element type can be -/// named) so that it can be used with multiple access patterns - like [`Accessor`] and -/// [`crate::flat::FlastSearchStrategy`]. -/// -/// A fused iterate-and-compute primitive can be created as a sub-trait - +/// named) so that it can be used with multiple access patterns - like [`Accessor`] and +/// [`crate::flat::FlastSearchStrategy`]. +/// +/// A fused iterate-and-compute primitive can be created as a sub-trait - /// e.g. 
[`DistancesUnordered`], which requires both [`Accessor`] and `BuildQueryComputer`. pub trait BuildQueryComputer: HasElementRef { /// The error type (if any) associated with distance computer construction. diff --git a/rfcs/00983-flat-search.md b/rfcs/00983-flat-search.md index e552b6164..fbf9889f1 100644 --- a/rfcs/00983-flat-search.md +++ b/rfcs/00983-flat-search.md @@ -31,7 +31,7 @@ stuffing the algorithm or the backend through the `Accessor` trait surface. ## Proposal -The flat-search infrastructure is built on a small sequence of traits. The only required traits for the algorithm is `OnElementsUnordered` and its subtrait `DistancesUnordered`. A strategy - `FlatSearchStrategy` - instantiates these implementations for specific providers. An opt-in iterator trait `FlatIterator` and default implementations of the core traits - `DefaultIteratedOperator` - exist for convenience for backends that naturally expose element-at-a-time iteration. +The flat-search infrastructure is built on a small sequence of traits. The only required traits for the algorithm is `OnElementsUnordered` and its subtrait `DistancesUnordered`. A strategy - `flat::SearchStrategy` - instantiates these implementations for specific providers. An opt-in iterator trait `FlatIterator` and default implementations of the core traits - `DefaultIteratedOperator` - exist for convenience for backends that naturally expose element-at-a-time iteration. ### `OnElementsUnordered` — the core scan @@ -70,7 +70,7 @@ A subtrait that fuses scanning with scoring. The default implementation loops `on_elements_unordered` and calls `computer.evaluate_similarity` on each element. The query computer is a generic parameter rather than an associated type, so the same -callback type can be driven by different computers. The `FlatSearchStrategy` is the +callback type can be driven by different computers. The `flat::SearchStrategy` is the source of truth for which computer is used in any given search. 
### `FlatIterator` and `Iterated` — convenience for element-at-a-time backends @@ -97,16 +97,17 @@ pub trait FlatIterator: HasId + Send + Sync { element. -### The glue: `FlatSearchStrategy` +### The glue: `flat::SearchStrategy` While `OnElementsUnordered` is the primary handle the algorithm uses to walk the index, -it is scoped to each query. We introduce a constructor — `FlatSearchStrategy` — similar -to `SearchStrategy` for `Accessor`, to instantiate the per-query visitor. +it is scoped to each query. We introduce a constructor — `flat::SearchStrategy` — similar +to the random-access `graph::glue::SearchStrategy` (the two share a name and live in +distinct modules), to instantiate the per-query visitor. A strategy is per-call configuration that is stateless, cheap to construct and scoped to one search. It produces both a per-query visitor and a query computer. ```rust -pub trait FlatSearchStrategy: Send + Sync +pub trait SearchStrategy: Send + Sync where P: DataProvider, T: ?Sized, @@ -169,7 +170,7 @@ impl FlatIndex

{ output: &mut OB, ) -> impl SendFuture> where - S: FlatSearchStrategy, + S: flat::SearchStrategy, T: ?Sized + Sync, O: Send, OB: SearchOutputBuffer + Send + ?Sized, From 57d86e9636780834a5777f01acac89c86f37e430 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Tue, 5 May 2026 23:21:54 -0400 Subject: [PATCH 14/24] update rfc --- rfcs/00983-flat-search.md | 267 +++++++++++++++++++++++--------------- 1 file changed, 160 insertions(+), 107 deletions(-) diff --git a/rfcs/00983-flat-search.md b/rfcs/00983-flat-search.md index fbf9889f1..867f66f90 100644 --- a/rfcs/00983-flat-search.md +++ b/rfcs/00983-flat-search.md @@ -4,129 +4,129 @@ |------------------|--------------------------------| | **Authors** | Aditya Krishnan, Alex Razumov, Dongliang Wu | | **Created** | 2026-04-24 | -| **Updated** | 2026-04-28 | +| **Updated** | 2026-05-05 | -## Motivation +## 1. Motivation -### Background +### 1.1 Background DiskANN today exposes a single abstraction family centered on the [`crate::provider::Accessor`] trait. Accessors are random access by design since the graph greedy search algorithm needs to decide which ids to fetch and the accessor materializes the corresponding elements (vectors, quantized vectors and neighbor lists) on demand. This is the right contract for graph search, where neighborhood expansion is inherently random-access against the [`crate::provider::DataProvider`]. A growing class of consumers diverge from our current pattern of use by accesssing their index **sequentially**. Some consumers build their index in an "append-only" fashion and require that they walk the index in a sequential, fixed order, relying on iteration position to enforce versioning / deduplication invariants. -### Problem Statement +### 1.2 Problem Statement The problem-statement here is simple: provide first-class support for sequential, one-pass scans over a data backend without stuffing the algorithm or the backend through the `Accessor` trait surface. 
-### Goals +### 1.3 Goals 1. Define a streaming access primitive — `OnElementsUnordered` — that mirrors the role `Accessor` plays for graph search but exposes a callback-driven scan instead of random access. -2. Provide flat-search algorithm implementations (with `knn_search` as default and filtered and diverse variants to opt-into) built on the new - primitives, so consumers can use this against their own providers / backends. -3. Expose support for features and implementations native to the repo like quantized distance computers out-of-the-box. - -## Proposal - -The flat-search infrastructure is built on a small sequence of traits. The only required traits for the algorithm is `OnElementsUnordered` and its subtrait `DistancesUnordered`. A strategy - `flat::SearchStrategy` - instantiates these implementations for specific providers. An opt-in iterator trait `FlatIterator` and default implementations of the core traits - `DefaultIteratedOperator` - exist for convenience for backends that naturally expose element-at-a-time iteration. - -### `OnElementsUnordered` — the core scan +2. Provide flat-search algorithm implementations built on the new primitives, so consumers can use this against their own providers / backends. +3. Expose support for different distance computers and post-processing like re-ranking _out-of-the-box_ without having to reimplement these for the flat search path. + +## 2. Proposal + +The flat-search infrastructure is built on a small sequence of traits. The only traits a +backend *must* implement are `OnElementsUnordered` and its subtrait +`flat::DistancesUnordered`. A `flat::SearchStrategy` then instantiates them per +query. + +An opt-in `FlatIterator` trait plus the `Iterated` adapter exist for +convenience for backends that naturally expose element-at-a-time iteration. 
+ +### 2.1 Refactor `Accessor` and `BuildQueryComputer ` +We start with a small refactor of the traits in `diskann::provider` and `diskann::graph::glue` that +will enable us to cleanly separate query preprocessing and result post-processing from the search pattern so that +both graph and flat search can share common components as much as possible: + +1. **Extract `HasElementRef` out of `Accessor`.** The `ElementRef<'a>` GAT moves to + its own zero-method trait so that streaming visitors (which are not `Accessor`s) + can still expose an element type. `Accessor` is now `Accessor: HasId + + HasElementRef + Send + Sync`. `HasElementRef` is simply: + + ```rust + pub trait HasElementRef { type ElementRef<'a> } + ``` + +2. **Decouple `BuildQueryComputer` from `Accessor`.** Previously a + sub-trait of `Accessor`, `BuildQueryComputer` is lifted to depend only on + `HasElementRef`. Secondly, it now contains only a constructor `build_query_computer` + as an associated method and nothing else. This is the + change that lets `BuildQueryComputer` and `graph::glue::SearchPostProcess` be + used unchanged by both the flat index and the graph. + +3. **Split distance scoring into a new `DistancesUnordered` trait family.** + Previously, the unordered iterate-and-score loop was a default method tucked + inside `Accessor` (and shadowed by overrides on a few providers). It is now its + own subtrait of `BuildQueryComputer`, with two flavors that share a name and a + default-body shape but differ in their access super-trait: + + - **`provider::DistancesUnordered: Accessor + BuildQueryComputer`** — drives + the scan via the random-access `Accessor` machinery. Used by graph search. + - **`flat::DistancesUnordered: OnElementsUnordered + BuildQueryComputer`** — + drives the scan via the new sequential `OnElementsUnordered` primitive. This primitive + is used by flat search. More on it below. 
+ +### 2.2 Core traits for flat search +At the very core is the `OnElementsUnordered` trait, which is simply an API to implement +a callback on the entire index. Implementations choose iteration order, prefetching, and +any bulk reads if they want; algorithms see only `(Id, ElementRef)` pairs. ```rust -pub trait OnElementsUnordered: HasId + Send + Sync { - type ElementRef<'a>; +pub trait OnElementsUnordered: HasId + HasElementRef + Send + Sync { type Error: StandardError; fn on_elements_unordered(&mut self, f: F) -> impl SendFuture> where - F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>); + F: Send + for<'a> FnMut(Self::Id, ::ElementRef<'a>); } ``` -A single required method: drive the entire scan via a callback. Async to match -[`crate::provider::Accessor`]. Implementations choose iteration order, prefetching, and -any SIMD-friendly bulk reads if they want; algorithms see only `(Id, ElementRef)` pairs. +`Id` and `ElementRef<'a>` come from the shared `HasId` / `HasElementRef` traits, so a +type that implements `Accessor` and `OnElementsUnordered` exposes the same id and +element types to both subsystems. -### `DistancesUnordered` — the distance subtrait +For computing distance with a query specifically, we define a sub-trait of the above - `flat::DistancesUnordered`. 
```rust -pub trait DistancesUnordered: OnElementsUnordered { - fn distances_unordered( - &mut self, computer: &C, mut f: F, - ) -> impl SendFuture> +pub trait DistancesUnordered: OnElementsUnordered + BuildQueryComputer { + fn distances_unordered( + &mut self, + computer: &>::QueryComputer, + f: F, + ) -> impl SendFuture::Error>> where - C: for<'a> PreprocessedDistanceFunction, f32> + Send + Sync, - F: Send + FnMut(Self::Id, f32), + F: Send + FnMut(::Id, f32), { // default delegates to on_elements_unordered + evaluate_similarity } } ``` +The default implementation loops `on_elements_unordered` and calls `computer.evaluate_similarity` on each element; +backends that can fuse retrieval and scoring can override it. -A subtrait that fuses scanning with scoring. The default implementation loops -`on_elements_unordered` and calls `computer.evaluate_similarity` on each element. - -The query computer is a generic parameter rather than an associated type, so the same -callback type can be driven by different computers. The `flat::SearchStrategy` is the -source of truth for which computer is used in any given search. - -### `FlatIterator` and `Iterated` — convenience for element-at-a-time backends - -For backends that naturally expose element-at-a-time iteration, `FlatIterator` is a -lending async iterator: - -```rust -pub trait FlatIterator: HasId + Send + Sync { - type ElementRef<'a>; - // lifetime gymnastics to make lifetime of `Element<'_>` to play nice with HRTB - type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync - where Self: 'a; - type Error: StandardError; - - fn next( - &mut self, - ) -> impl SendFuture)>, Self::Error>>; -} -``` - -`Iterated` wraps any `FlatIterator` and implements `OnElementsUnordered` -(and `DistancesUnordered` by inheritance) by looping over `next()` and reborrowing each -element. 
- - -### The glue: `flat::SearchStrategy` - -While `OnElementsUnordered` is the primary handle the algorithm uses to walk the index, -it is scoped to each query. We introduce a constructor — `flat::SearchStrategy` — similar -to the random-access `graph::glue::SearchStrategy` (the two share a name and live in -distinct modules), to instantiate the per-query visitor. -A strategy is per-call configuration that is stateless, cheap to construct and scoped to one -search. It produces both a per-query visitor and a query computer. +`DistancesUnordered` is scoped to a single query. We introduce a strategy that is the per-call +constructor that hands the algorithm a freshly-bound visitor. It is stateless, +cheap to construct, and lives only for the duration of one search. ```rust pub trait SearchStrategy: Send + Sync where P: DataProvider, - T: ?Sized, { /// The per-query visitor type produced by [`Self::create_visitor`]. Borrows from - /// `self` and the provider. - type Visitor<'a>: DistancesUnordered + /// `self` and the provider. The visitor implements both the streaming + /// [`DistancesUnordered`] primitive and the query preprocessor + /// [`BuildQueryComputer`]. + type Visitor<'a>: DistancesUnordered where Self: 'a, + P: 'a; - /// The query computer produced by [`Self::build_query_computer`]. - type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction< - as OnElementsUnordered>::ElementRef<'b>, - f32, - > + Send - + Sync - + 'static; - - /// The error type type Error: StandardError; /// Construct a fresh visitor over `provider` for the given request `context`. @@ -135,20 +135,15 @@ where provider: &'a P, context: &'a P::Context, ) -> Result, Self::Error>; - - /// Pre-process a query into a [`Self::QueryComputer`] usable for distance computation - /// against any visitor produced by [`Self::create_visitor`]. 
- fn build_query_computer(&self, query: &T) -> Result; } ``` +This shape mirrors the random-access `graph::glue::SearchStrategy` and lets `FlatIndex::knn_search` accept the same +`graph::glue::SearchPostProcess` that graph search uses (see below). -The `ElementRef<'b>` that the `QueryComputer` acts on is tied to the -`OnElementsUnordered::ElementRef` of the visitor produced by `create_visitor`. +### 2.3 `FlatIndex` — the top-level handle -### `FlatIndex` - -`FlatIndex` is a thin `'static` wrapper around a `DataProvider`. The same `DataProvider` -trait used by graph search is reused here - flat and graph subsystems share a single +`FlatIndex` is a thin `'static` wrapper around a `DataProvider`. The same +`DataProvider` trait used by graph search is reused — flat and graph share one provider surface and the same `Context` / id-mapping / error machinery. ```rust @@ -166,22 +161,27 @@ impl FlatIndex

{ strategy: &S, processor: &PP, context: &P::Context, - query: &T, + query: T, output: &mut OB, ) -> impl SendFuture> where S: flat::SearchStrategy, - T: ?Sized + Sync, + T: Copy + Send + Sync, O: Send, OB: SearchOutputBuffer + Send + ?Sized, - PP: for<'a> FlatPostProcess, T, O> + Send + Sync, + PP: for<'a> graph::glue::SearchPostProcess, T, O> + Send + Sync, } ``` +**Note:** The `PP` bound uses the same `graph::glue::SearchPostProcess` trait as graph search; +there is no flat-specific post-process trait. Reuse is enabled by the trait splits +described above (the visitor implements `BuildQueryComputer + HasId`, which is all +`SearchPostProcess` requires). + The `knn_search` method is the canonical brute-force search algorithm: 1. Construct the per-query visitor via `strategy.create_visitor`. -2. Build the query computer via `strategy.build_query_computer`. +2. Build the query computer from the visitor via `BuildQueryComputer::build_query_computer`. 3. Drive the scan via `visitor.distances_unordered(&computer, ...)`, inserting each `(id, distance)` pair into a `NeighborPriorityQueue` of capacity `k`. 4. Hand the survivors (in distance order) to `processor.post_process`. @@ -190,6 +190,69 @@ The `knn_search` method is the canonical brute-force search algorithm: Other algorithms (filtered, range, diverse) can be added later as additional methods on `FlatIndex`. +#### Search call chain (AI Generated) + +The diagram below traces the trait dispatch sequence inside one `search` call for +each of graph and flat search. The centre lane shows the shared traits that both +columns dip into. 
+ +```text + Graph Shared Flat + ───── ────── ──── + + DiskANNIndex::search FlatIndex::knn_search + │ │ + ▼ ▼ + graph::glue::SearchStrategy flat::SearchStrategy + ::search_accessor ::create_visitor + │ │ + ▼ ▼ + ExpandBeam visitor DistancesUnordered visitor + (Accessor + BuildQueryComputer) (OnElementsUnordered + BuildQueryComputer) + │ │ + │ BuildQueryComputer │ + ├─────────────────►::build_query_computer ◄──────────────────────┤ + │ (visitor → QueryComputer) │ + │ │ + ▼ ▼ + ExpandBeam::expand_beam DistancesUnordered + (greedy beam loop: ::distances_unordered + for each frontier id, (one pass over every + get_neighbors, element; computer scores + distances_unordered) each one) + │ │ + ▼ ▼ + NeighborPriorityQueue NeighborPriorityQueue + │ │ + │ graph::glue::SearchPostProcess │ + └─────────────►::post_process ◄──────────────────────────────────┘ + │ + ▼ + SearchOutputBuffer +``` + +### 2.4 `FlatIterator` and `Iterated` — convenience for element-at-a-time backends + +For backends that naturally expose element-at-a-time iteration, `FlatIterator` is a +lending async iterator: + +```rust +pub trait FlatIterator: HasId + HasElementRef + Send + Sync { + type Element<'a>: for<'b> Reborrow<'b, Target = ::ElementRef<'b>> + + Send + Sync + where Self: 'a; + type Error: StandardError; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>>; +} +``` + +`Iterated` wraps any `FlatIterator` and implements `OnElementsUnordered` (and +`DistancesUnordered` by inheritance, when the inner type implements +`BuildQueryComputer`) by looping over `next()` and reborrowing each element. + ## Trade-offs ### Reusing `DataProvider` @@ -202,27 +265,17 @@ This design leans into using the `DataProvider` trait which requires implementat `to_external_id`), and error machinery are identical across graph and flat search, reducing the learning surface for new contributors. 
-### Async vs sync scan API - -`on_elements_unordered` and `distances_unordered` return a future, making the scan -surface async. This is the right default for disk-backed and network-backed backends -where advancing the scan involves real I/O. It also matches the `Accessor` surface, -keeping the two subsystems shaped the same way. - -The cost is paid by in-memory consumers: the scan goes through the future machinery -even when results are immediately available. In a tight brute-force loop this overhead — -poll scaffolding, pinning etc — could be measurable. - -We chose async because the wider audience of consumers (disk, network, mixed) benefits -more than in-memory consumers lose. - ### Expand `Element` to support batched distance computation? -The current design yields one element per `next()` call, and the query computer scores +The current optional iterator `FlatIterator` yields one element per `next()` call, and the query computer scores elements one at a time via `PreprocessedDistanceFunction::evaluate_similarity`. This could leave some optimization and performance on the table; especially with the upcoming effort around batched distance kernels. Of course, a consumer can choose to implement their own optimized implementation of `distances_unordered` that uses batching. An alternative is to make `next()` yield a *batch* instead of a single vector representation like `Element<'_>`. Some work will need to be done to define the right interaction between the batch type, the element type in the batch, the interaction with `QueryComputer`'s types and way IDs and distances are collected in the queue. +### Intra-query parallelism + +The current design of `OnElementsUnordered` does not allow an implementation to exploit parallelism within a query, since the trait requires `&mut self`. Especially for a flat index, some implementations might want to parallelize within the scan for a query.
Arguably we will need a more complex extension of this architecture to support this. + ## Future Work - Support for other flat-search algorithms like - filtered, range and diverse flat algorithms as additional methods on `FlatIndex`. - Index build -- this is just one part of the picture; more work needs to be done around how this fits in with any traits / interface we need for index build. From 715150ab1209b1acf71190ffe139e895495cf574 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Wed, 6 May 2026 15:30:56 -0400 Subject: [PATCH 15/24] error types --- diskann/src/flat/index.rs | 26 +++++++++++++++----------- diskann/src/flat/iterator.rs | 8 +++++--- diskann/src/flat/mod.rs | 2 +- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs index 332ba8edb..06cb5f29f 100644 --- a/diskann/src/flat/index.rs +++ b/diskann/src/flat/index.rs @@ -10,13 +10,23 @@ use diskann_utils::future::SendFuture; use crate::{ ANNResult, - error::IntoANNResult, + error::{ErrorExt, IntoANNResult}, flat::{DistancesUnordered, SearchStrategy}, - graph::{SearchOutputBuffer, glue::SearchPostProcess, index::SearchStats}, + graph::{SearchOutputBuffer, glue::SearchPostProcess}, neighbor::{Neighbor, NeighborPriorityQueue}, provider::{BuildQueryComputer, DataProvider}, }; +/// Statistics collected during a flat search. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct SearchStats { + /// The total number of distance computations performed during the scan. + pub cmps: u32, + + /// The total number of results written to the output buffer. + pub result_count: u32, +} + /// A `'static` thin wrapper around a [`DataProvider`] used for flat search. /// /// The provider is owned by the index. The index is constructed once at process startup and @@ -73,8 +83,7 @@ impl FlatIndex

{ .create_visitor(&self.provider, context) .into_ann_result()?; - let computer = - BuildQueryComputer::build_query_computer(&visitor, query).into_ann_result()?; + let computer = visitor.build_query_computer(query).into_ann_result()?; let k = k.get(); let mut queue = NeighborPriorityQueue::new(k); @@ -86,19 +95,14 @@ impl FlatIndex

{ queue.insert(Neighbor::new(id, dist)); }) .await - .into_ann_result()?; + .escalate("flat scan must complete to produce correct k-NN results")?; let result_count = processor .post_process(&mut visitor, query, &computer, queue.iter().take(k), output) .await .into_ann_result()? as u32; - Ok(SearchStats { - cmps, - hops: 0, - result_count, - range_search_second_round: false, - }) + Ok(SearchStats { cmps, result_count }) } } } diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index 40061f9d9..e4021bc00 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -10,11 +10,13 @@ //! [`FlatIterator`] — a lending async iterator that can be bridged into //! [`OnElementsUnordered`] via [`Iterated`]. +use std::fmt::Debug; + use diskann_utils::{Reborrow, future::SendFuture}; use diskann_vector::PreprocessedDistanceFunction; use crate::{ - error::StandardError, + error::ToRanked, provider::{BuildQueryComputer, HasElementRef, HasId}, }; @@ -27,7 +29,7 @@ use crate::{ /// Algorithms see only `(Id, ElementRef)` pairs and treat the stream as opaque. pub trait OnElementsUnordered: HasId + HasElementRef + Send + Sync { /// The error type yielded by [`Self::on_elements_unordered`]. - type Error: StandardError; + type Error: ToRanked + Debug + Send + Sync + 'static; /// Drive the entire scan, invoking `f` for each yielded element. fn on_elements_unordered(&mut self, f: F) -> impl SendFuture> @@ -86,7 +88,7 @@ pub trait FlatIterator: HasId + HasElementRef + Send + Sync { Self: 'a; /// The error type yielded by [`Self::next`]. - type Error: StandardError; + type Error: ToRanked + Debug + Send + Sync + 'static; /// Advance the iterator and asynchronously yield the next `(id, element)` pair. 
/// diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs index b939056a0..8f619dc91 100644 --- a/diskann/src/flat/mod.rs +++ b/diskann/src/flat/mod.rs @@ -36,6 +36,6 @@ pub mod index; pub mod iterator; pub mod strategy; -pub use index::FlatIndex; +pub use index::{FlatIndex, SearchStats}; pub use iterator::{DistancesUnordered, FlatIterator, Iterated, OnElementsUnordered}; pub use strategy::SearchStrategy; From 0627cc067d6828b8f83bf94e8be1dae98e1f18bf Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Thu, 7 May 2026 11:59:22 -0400 Subject: [PATCH 16/24] small doc fixes --- diskann/src/flat/index.rs | 5 +++-- diskann/src/flat/iterator.rs | 7 ++++--- diskann/src/flat/mod.rs | 23 +++++++++-------------- diskann/src/flat/strategy.rs | 13 +++++-------- 4 files changed, 21 insertions(+), 27 deletions(-) diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs index 06cb5f29f..5633d8d47 100644 --- a/diskann/src/flat/index.rs +++ b/diskann/src/flat/index.rs @@ -3,7 +3,8 @@ * Licensed under the MIT license. */ -//! [`FlatIndex`] — the index wrapper for an on which we do flat search. +//! [`FlatIndex`] — the index wrapper for a [`DataProvider`](crate::provider::DataProvider) +//! over which we do flat search. use std::num::NonZeroUsize; use diskann_utils::future::SendFuture; @@ -51,7 +52,7 @@ impl FlatIndex

{ /// Brute-force k-nearest-neighbor flat search. /// - /// Streams every element produced by the strategy's iterator through the query + /// Streams every element produced by the strategy's visitor through the query /// computer, keeps the best `k` candidates in a [`NeighborPriorityQueue`], and hands /// the survivors to the post-processor. /// diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index e4021bc00..2f019ce25 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -5,7 +5,8 @@ //! [`OnElementsUnordered`] — the sequential access primitive for accessing a flat index. //! [`DistancesUnordered`] — sub-trait of [`OnElementsUnordered`] and [`BuildQueryComputer`] -//! that computes distances over the elements in the flat index. +//! that fuses a pre-built query computer with a sequential scan and yields +//! `(id, distance)` pairs to its callback. //! //! [`FlatIterator`] — a lending async iterator that can be bridged into //! [`OnElementsUnordered`] via [`Iterated`]. @@ -77,8 +78,8 @@ pub trait DistancesUnordered: OnElementsUnordered + BuildQueryComputer { /// A lending, asynchronous iterator over the elements of a flat index. /// /// Implementations provide element-at-a-time access via [`Self::next`]. Providers that -/// only implement `FlatIterator` can be wrapped in [`DefaultIteratedOperator`] to obtain -/// an [`OnElementsUnordered`] implementation automatically. +/// only implement `FlatIterator` can be wrapped in [`Iterated`] to obtain an +/// [`OnElementsUnordered`] implementation automatically. pub trait FlatIterator: HasId + HasElementRef + Send + Sync { /// The concrete element returned by [`Self::next`]. Reborrows to [`Self::ElementRef`]. type Element<'a>: for<'b> Reborrow<'b, Target = ::ElementRef<'b>> diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs index 8f619dc91..86920e4a1 100644 --- a/diskann/src/flat/mod.rs +++ b/diskann/src/flat/mod.rs @@ -14,20 +14,15 @@ //! //! 
The module mirrors the layering used by graph search: //! -//! | Graph (random access) | Flat (sequential) | Shared? | -//! | :------------------------------------ | :-------------------------------- |:--------- | -//! | [`crate::provider::DataProvider`] | [`crate::provider::DataProvider`] | Yes | -//! | [`crate::graph::DiskANNIndex`] | [`FlatIndex`] | No | -//! | [`crate::provider::Accessor`] | [`FlatIterator`] | No | -//! | [`crate::graph::glue::SearchStrategy`] | [`SearchStrategy`] | No | -//! | [`crate::graph::glue::SearchPostProcess`] | [`crate::graph::glue::SearchPostProcess`] | Yes | -//! | [`crate::graph::Search`] | [`FlatIndex::knn_search`] | No | -//! -//! # Hot loop -//! -//! Algorithms drive the scan via [`FlatIterator::next`] (lending iterator) or override -//! [`FlatIterator::on_elements_unordered`] when batching/prefetching wins. The default -//! implementation of `on_elements_unordered` simply loops over `next`. +//! | Graph (random access) | Flat (sequential) | Shared? | +//! | :------------------------------------ | :----------------------------------------- |:--------- | +//! | [`crate::provider::DataProvider`] | [`crate::provider::DataProvider`] | Yes | +//! | [`crate::graph::DiskANNIndex`] | [`FlatIndex`] | No | +//! | [`crate::provider::Accessor`] | [`OnElementsUnordered`] (and [`Iterated`] for an element-at-a-time bridge from [`FlatIterator`]) | No | +//! | [`crate::provider::DistancesUnordered`] | [`DistancesUnordered`] | No | +//! | [`crate::graph::glue::SearchStrategy`] | [`SearchStrategy`] | No | +//! | [`crate::graph::glue::SearchPostProcess`] | [`crate::graph::glue::SearchPostProcess`] | Yes | +//! | [`crate::graph::Search`] | [`FlatIndex::knn_search`] | No | //! //! See [`FlatIndex::knn_search`] for the canonical brute-force k-NN algorithm built on these //! primitives. 
diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs index f8ec19e78..467de49d3 100644 --- a/diskann/src/flat/strategy.rs +++ b/diskann/src/flat/strategy.rs @@ -3,7 +3,8 @@ * Licensed under the MIT license. */ -//! [`SearchStrategy`] — glue between [`DataProvider`] and per-query [`crate::flat::FlatIterator`]s. +//! [`SearchStrategy`] — glue between [`DataProvider`] and per-query +//! [`DistancesUnordered`] visitors. use crate::{error::StandardError, flat::DistancesUnordered, provider::DataProvider}; @@ -11,17 +12,13 @@ use crate::{error::StandardError, flat::DistancesUnordered, provider::DataProvid /// [`DistancesUnordered`] visitor for a provider. /// /// `SearchStrategy` is the flat counterpart to [`crate::graph::glue::SearchStrategy`] -/// (disambiguated by module path). A strategy instance is stateless config — typically -/// constructed at the call site, used for one search, and dropped. +/// (disambiguated by module path). A strategy instance carries the per-query setup +/// recipe; the per-query mutable state lives in the visitor it produces, so a single +/// strategy may be reused across many searches. /// /// The strategy itself is a pure factory; the visitor it produces carries the /// query-preprocessing capability via [`crate::provider::BuildQueryComputer`] (a /// super-trait of [`DistancesUnordered`]). -/// -/// # Type parameters -/// -/// - `P`: the [`DataProvider`] that backs the index. -/// - `T`: the query type that the query computer is constructed using. 
pub trait SearchStrategy: Send + Sync where P: DataProvider, From 1d4e90b3b82820bec9a004dadc52525c4c5a1c58 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Thu, 7 May 2026 15:54:51 -0400 Subject: [PATCH 17/24] iterator docs --- diskann/src/flat/iterator.rs | 39 +++++++++++++++++++++++------------- diskann/src/flat/mod.rs | 2 +- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index 2f019ce25..77be634fb 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -3,13 +3,28 @@ * Licensed under the MIT license. */ -//! [`OnElementsUnordered`] — the sequential access primitive for accessing a flat index. -//! [`DistancesUnordered`] — sub-trait of [`OnElementsUnordered`] and [`BuildQueryComputer`] -//! that fuses a pre-built query computer with a sequential scan and yields -//! `(id, distance)` pairs to its callback. +//! Sequential ("flat") access primitives. //! -//! [`FlatIterator`] — a lending async iterator that can be bridged into -//! [`OnElementsUnordered`] via [`Iterated`]. +//! This module defines the traits that flat-search algorithms use to walk every element +//! of a [`DataProvider`](crate::provider::DataProvider) once. +//! +//! * [`OnElementsUnordered`]: the lowest-level entry point and the only required trait +//! to implement. It is a single-method trait that applies a caller-supplied closure +//! to every `(id, element ref)` pair in the provider. The super-traits [`HasId`] and +//! [`HasElementRef`] define the concrete id and element reference types. +//! +//! * [`DistancesUnordered`]: a sub-trait of [`OnElementsUnordered`] that takes a query +//! computer and a closure (typically to filter results through a priority queue) and +//! applies the closure to the `(id, distance)` pair for every element, with each +//! distance computed using the supplied computer. +//! +//! 
* [`FlatIterator`]: a convenient entry point for backends whose natural shape is +//! element-at-a-time iteration. The trait exposes a single `next` method and an +//! associated `Element<'_>` type that must be [`Reborrow`]able to the `ElementRef<'_>` +//! exposed via the [`HasElementRef`] super-trait. +//! +//! * [`Iterated`]: bridges any [`FlatIterator`] implementation into an +//! [`OnElementsUnordered`] by looping over [`FlatIterator::next`]. use std::fmt::Debug; @@ -38,9 +53,10 @@ pub trait OnElementsUnordered: HasId + HasElementRef + Send + Sync { F: Send + for<'a> FnMut(Self::Id, ::ElementRef<'a>); } -/// Extension of [`OnElementsUnordered`] that drives the scan with a query computer -/// produced by the visitor's [`BuildQueryComputer`] impl, invoking a callback with -/// `(id, distance)` pairs. +/// Extension of [`OnElementsUnordered`] that drives the scan with a query computer. +/// +/// The computer is produced by the visitor's [`BuildQueryComputer`] impl, +/// and invokes a callback with `(id, distance)` pairs. /// /// This fuses the scan with a pre-processed query computer and runs over a /// streaming visitor. It pulls the computer type from the implementor's own @@ -48,11 +64,6 @@ pub trait OnElementsUnordered: HasId + HasElementRef + Send + Sync { /// /// The default implementation delegates to [`OnElementsUnordered::on_elements_unordered`], /// calling `computer.evaluate_similarity` on each element. -/// -/// ## Note -/// -/// This is the flat analog to [`crate::provider::DistancesUnordered`] which runs over -/// a random-access [`crate::provider::Accessor`]. pub trait DistancesUnordered: OnElementsUnordered + BuildQueryComputer { /// Drive the entire scan, scoring each element with `computer` and invoking `f` with /// the resulting `(id, distance)` pair. diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs index 86920e4a1..605c634d2 100644 --- a/diskann/src/flat/mod.rs +++ b/diskann/src/flat/mod.rs @@ -18,7 +18,7 @@ //! 
| :------------------------------------ | :----------------------------------------- |:--------- | //! | [`crate::provider::DataProvider`] | [`crate::provider::DataProvider`] | Yes | //! | [`crate::graph::DiskANNIndex`] | [`FlatIndex`] | No | -//! | [`crate::provider::Accessor`] | [`OnElementsUnordered`] (and [`Iterated`] for an element-at-a-time bridge from [`FlatIterator`]) | No | +//! | [`crate::provider::Accessor`] | [`OnElementsUnordered`] | No | //! | [`crate::provider::DistancesUnordered`] | [`DistancesUnordered`] | No | //! | [`crate::graph::glue::SearchStrategy`] | [`SearchStrategy`] | No | //! | [`crate::graph::glue::SearchPostProcess`] | [`crate::graph::glue::SearchPostProcess`] | Yes | From b4d9df023f58cbfd46c57f325f75f31daa3a90aa Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Thu, 7 May 2026 20:32:05 -0400 Subject: [PATCH 18/24] strong pass on testing --- diskann/src/flat/index.rs | 95 ++++ diskann/src/flat/iterator.rs | 479 +++++++++++++++++- diskann/src/flat/mod.rs | 3 + diskann/src/flat/strategy.rs | 29 ++ .../src/flat/test/cases/flat_knn_search.rs | 205 ++++++++ diskann/src/flat/test/cases/mod.rs | 6 + diskann/src/flat/test/harness.rs | 142 ++++++ diskann/src/flat/test/mod.rs | 14 + diskann/src/flat/test/provider.rs | 442 ++++++++++++++++ .../cases/flat_knn_search/search_1_100.json | 264 ++++++++++ .../cases/flat_knn_search/search_2_5.json | 270 ++++++++++ .../cases/flat_knn_search/search_3_4.json | 276 ++++++++++ 12 files changed, 2221 insertions(+), 4 deletions(-) create mode 100644 diskann/src/flat/test/cases/flat_knn_search.rs create mode 100644 diskann/src/flat/test/cases/mod.rs create mode 100644 diskann/src/flat/test/harness.rs create mode 100644 diskann/src/flat/test/mod.rs create mode 100644 diskann/src/flat/test/provider.rs create mode 100644 diskann/test/generated/flat/test/cases/flat_knn_search/search_1_100.json create mode 100644 diskann/test/generated/flat/test/cases/flat_knn_search/search_2_5.json create mode 
100644 diskann/test/generated/flat/test/cases/flat_knn_search/search_3_4.json diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs index 5633d8d47..3a032536d 100644 --- a/diskann/src/flat/index.rs +++ b/diskann/src/flat/index.rs @@ -107,3 +107,98 @@ impl FlatIndex

{ } } } + +///////////// +// Tests /// +///////////// + +#[cfg(test)] +mod tests { + use crate::flat::{ + FlatIndex, + test::{ + harness::KnnOracleRun, + provider::{self as flat_provider, Strategy}, + }, + }; + use crate::graph::test::synthetic::Grid; + + fn fixture(grid: Grid, size: usize) -> (FlatIndex, usize) { + let provider = flat_provider::Provider::grid(grid, size); + let len = provider.len(); + (FlatIndex::new(provider), len) + } + + /// `knn_search` returns a `Send` future, and a shared `&FlatIndex` can serve + /// many concurrent searches on a multi-threaded runtime, each producing the + /// correct top-k independently. + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn knn_search() { + use std::sync::Arc; + + let (index, len) = fixture(Grid::Two, 4); + let index = Arc::new(index); + + // Mix of corner, axis-aligned, and off-grid queries; k spans 1..=len. + let cases: &[(&[f32], usize)] = &[ + (&[-1.0, -1.0], 1), + (&[1.0, 1.0], len), + (&[-1.0, 1.0], len / 2), + (&[1.0, -1.0], len - 1), + (&[0.0, 0.0], 3), + (&[3.0, 3.0], len), + (&[-2.0, 0.5], 2), + (&[0.5, -0.5], len), + ]; + + let mut set = tokio::task::JoinSet::new(); + for (query, k) in cases { + let index = Arc::clone(&index); + let query: Vec = query.to_vec(); + let k = *k; + set.spawn(async move { + let outcome = KnnOracleRun::run(&index, &Strategy::new(), &query, k) + .await + .expect("knn_search failed"); + (query, k, outcome) + }); + } + + while let Some(joined) = set.join_next().await { + let (query, k, outcome) = joined.expect("task panicked"); + assert_eq!( + outcome.top_k, outcome.ground_truth, + "query = {query:?}, k = {k}: top-k must match brute force", + ); + assert_eq!(outcome.stats.cmps as usize, len); + assert_eq!(outcome.stats.result_count as usize, k.min(len)); + } + } + + /// A transient error from the visitor's scan must escalate up through `knn_search`. 
+ #[test] + fn transient_scan_error() { + let (index, _len) = fixture(Grid::Two, 3); + + // The flat scan must touch every id, so any transient id is guaranteed to be + // hit. + for transient_ids in [&[0u32][..], &[3][..], &[1, 2, 5][..]] { + let err = KnnOracleRun::run_sync( + &index, + &Strategy::with_transient(transient_ids.iter().copied()), + &[1.0, 0.0], + 4, + ) + .expect_err("transient error during full scan must escalate"); + + let msg = format!("{err}"); + assert!( + transient_ids + .iter() + .any(|id| msg.contains(&format!("id {id}"))), + "transients = {transient_ids:?}: expected error to name one of the \ + transient ids, got: {msg}", + ); + } + } +} diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index 77be634fb..9229a3a92 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -38,10 +38,6 @@ use crate::{ /// Callback-driven sequential scan over the elements of a flat index. /// -/// `OnElementsUnordered` is the streaming counterpart to [`crate::provider::Accessor`]. -/// Where an accessor exposes random retrieval by id, this trait exposes a *sequential* -/// walk that invokes a caller-supplied closure for every element. -/// /// Algorithms see only `(Id, ElementRef)` pairs and treat the stream as opaque. pub trait OnElementsUnordered: HasId + HasElementRef + Send + Sync { /// The error type yielded by [`Self::on_elements_unordered`]. 
@@ -164,3 +160,478 @@ where } } } + +#[cfg(test)] +mod tests { + use std::{ + fmt::Debug, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + }; + + use diskann_utils::Reborrow; + use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; + + use super::*; + use crate::{ + ANNError, always_escalate, + error::Infallible, + provider::{BuildQueryComputer, HasElementRef, HasId}, + utils::VectorRepr, + }; + + /////////////////////////// + // Shared sample dataset // + /////////////////////////// + + /// Canonical sample dataset shared by every contract test below. + fn sample_items() -> Vec<(u32, Vec)> { + vec![ + (10, vec![0.0, 0.0]), + (11, vec![1.0, 0.0]), + (12, vec![0.0, 2.0]), + ] + } + + /// Backing store of `[f32]` vectors, used by every element-shape fixture + /// below to cover [`FlatIterator::Element`] variants without re-implementing + /// the data layout each time. + struct Store { + items: Vec<(u32, Vec)>, + } + + impl Store { + fn sample() -> Self { + Self { + items: sample_items(), + } + } + } + + /////////////////////// + // Common impl macro // + /////////////////////// + + /// Implement [`HasId`], [`HasElementRef`], [`BuildQueryComputer`], and + /// [`DistancesUnordered`] for an iterator type. Every fixture in this module + /// shares these impls — only [`FlatIterator::Element`] varies. + macro_rules! common_iterator_impls { + ($T:ty) => { + impl HasId for $T { + type Id = u32; + } + + impl HasElementRef for $T { + type ElementRef<'a> = &'a [f32]; + } + + impl BuildQueryComputer<&[f32]> for $T { + type QueryComputerError = Infallible; + type QueryComputer = ::QueryDistance; + + fn build_query_computer( + &self, + from: &[f32], + ) -> Result { + Ok(f32::query_distance(from, Metric::L2)) + } + } + + // Forward `BuildQueryComputer` through the `Iterated` adapter so the + // `DistancesUnordered` supertrait bound is satisfied. 
+ impl BuildQueryComputer<&[f32]> for Iterated<$T> { + type QueryComputerError = Infallible; + type QueryComputer = ::QueryDistance; + + fn build_query_computer( + &self, + from: &[f32], + ) -> Result { + Ok(f32::query_distance(from, Metric::L2)) + } + } + + impl DistancesUnordered<&[f32]> for Iterated<$T> {} + }; + } + + ///////////////////////////////// + // Allocating: Element = Vec // + ///////////////////////////////// + + /// `Element<'a> = Vec` — owns its data, reborrows to `&'a [f32]`. + /// Mirrors the `Allocating` accessor in [`crate::provider`]. + struct Allocating<'a> { + store: &'a Store, + cursor: usize, + } + + impl<'a> Allocating<'a> { + fn new(store: &'a Store) -> Self { + Self { store, cursor: 0 } + } + } + + common_iterator_impls!(Allocating<'_>); + + impl FlatIterator for Allocating<'_> { + type Element<'a> + = Vec + where + Self: 'a; + type Error = Infallible; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>> { + async move { + let i = self.cursor; + if i >= self.store.items.len() { + return Ok(None); + } + self.cursor += 1; + let (id, ref v) = self.store.items[i]; + Ok(Some((id, v.clone()))) + } + } + } + + //////////////////////////////////////////////// + // Forwarding: Element = &'store [f32] // + //////////////////////////////////////////////// + + /// `Element<'a> = &'store [f32]` — borrows directly out of the underlying + /// store. The element lifetime is tied to the *store* (not the iterator), + /// proving the trait supports forwarding accessors. Mirrors the + /// `Forwarding` accessor in [`crate::provider`]. 
+ struct Forwarding<'store> { + store: &'store Store, + cursor: usize, + } + + impl<'store> Forwarding<'store> { + fn new(store: &'store Store) -> Self { + Self { store, cursor: 0 } + } + } + + common_iterator_impls!(Forwarding<'_>); + + impl<'store> FlatIterator for Forwarding<'store> { + type Element<'a> + = &'store [f32] + where + Self: 'a; + type Error = Infallible; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>> { + async move { + let i = self.cursor; + if i >= self.store.items.len() { + return Ok(None); + } + self.cursor += 1; + let (id, ref v) = self.store.items[i]; + Ok(Some((id, v.as_slice()))) + } + } + } + + ///////////////////////////////////////////////////////// + // Wrapping: Element = guard-shaped non-ref `Wrapped` // + ///////////////////////////////////////////////////////// + + /// A guard-shaped element that reborrows to `&'b [f32]` and counts its own + /// drops. Mirrors the `Wrapping` accessor's `Wrapped<'a>` in + /// [`crate::provider`], plus a [`Drop`] hook to verify the [`Iterated`] + /// adapter does not leak guards. + struct Wrapped<'g> { + data: &'g [f32], + drop_count: Arc, + } + + impl<'b> Reborrow<'b> for Wrapped<'_> { + type Target = &'b [f32]; + fn reborrow(&'b self) -> Self::Target { + self.data + } + } + + impl Drop for Wrapped<'_> { + fn drop(&mut self) { + self.drop_count.fetch_add(1, Ordering::SeqCst); + } + } + + /// `Element<'a> = Wrapped<'a>`. 
+ struct Wrapping<'a> { + store: &'a Store, + cursor: usize, + drop_count: Arc, + } + + impl<'a> Wrapping<'a> { + fn new(store: &'a Store) -> Self { + Self { + store, + cursor: 0, + drop_count: Arc::new(AtomicUsize::new(0)), + } + } + } + + common_iterator_impls!(Wrapping<'_>); + + impl FlatIterator for Wrapping<'_> { + type Element<'a> + = Wrapped<'a> + where + Self: 'a; + type Error = Infallible; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>> { + async move { + let i = self.cursor; + if i >= self.store.items.len() { + return Ok(None); + } + self.cursor += 1; + let (id, ref v) = self.store.items[i]; + Ok(Some(( + id, + Wrapped { + data: v.as_slice(), + drop_count: self.drop_count.clone(), + }, + ))) + } + } + } + + ///////////////////////////////////////////////// + // Sharing: Element = &'a [f32] via local buf // + ///////////////////////////////////////////////// + + /// `Element<'a> = &'a [f32]` — copies into an internal buffer per `next()` + /// to avoid per-call allocation. Mirrors the `Sharing` accessor in + /// [`crate::provider`]. + struct Sharing<'a> { + store: &'a Store, + cursor: usize, + buf: Vec, + } + + impl<'a> Sharing<'a> { + fn new(store: &'a Store) -> Self { + Self { + store, + cursor: 0, + buf: Vec::new(), + } + } + } + + common_iterator_impls!(Sharing<'_>); + + impl FlatIterator for Sharing<'_> { + type Element<'a> + = &'a [f32] + where + Self: 'a; + type Error = Infallible; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>> { + async move { + let i = self.cursor; + if i >= self.store.items.len() { + return Ok(None); + } + self.cursor += 1; + let (id, ref v) = self.store.items[i]; + self.buf.clear(); + self.buf.extend_from_slice(v); + Ok(Some((id, self.buf.as_slice()))) + } + } + } + + //////////////////////// + // Failing iterator // + //////////////////////// + + /// A critical (non-recoverable) error type the [`Failing`] iterator yields. 
+ #[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] + #[error("synthetic iterator failure at id {0}")] + struct Boom(u32); + + always_escalate!(Boom); + + impl From for ANNError { + #[track_caller] + fn from(boom: Boom) -> ANNError { + ANNError::opaque(boom) + } + } + + /// `Element<'a> = &'a [f32]`, but `next()` returns `Err(Boom(id))` exactly + /// once after `fail_after` successful yields. Used to verify error + /// propagation through [`Iterated::on_elements_unordered`]. + struct Failing<'a> { + store: &'a Store, + cursor: usize, + fail_after: usize, + } + + impl HasId for Failing<'_> { + type Id = u32; + } + + impl HasElementRef for Failing<'_> { + type ElementRef<'a> = &'a [f32]; + } + + impl FlatIterator for Failing<'_> { + type Element<'a> + = &'a [f32] + where + Self: 'a; + type Error = Boom; + + fn next( + &mut self, + ) -> impl SendFuture)>, Self::Error>> { + async move { + let i = self.cursor; + if i >= self.store.items.len() { + return Ok(None); + } + self.cursor += 1; + let (id, ref v) = self.store.items[i]; + if i == self.fail_after { + return Err(Boom(id)); + } + Ok(Some((id, v.as_slice()))) + } + } + } + + ///////////// + // Helpers // + ///////////// + + /// Drive `visitor.on_elements_unordered` to completion and assert the + /// yielded `(id, element)` pairs equal [`sample_items`] in iteration order. + async fn check_visitor(visitor: &mut V) + where + V: OnElementsUnordered + HasId, + V: for<'a> HasElementRef = &'a [f32]>, + V::Error: Debug, + { + let mut out = Vec::new(); + visitor + .on_elements_unordered(|id, e: &[f32]| out.push((id, e.to_vec()))) + .await + .unwrap(); + assert_eq!(out, sample_items()); + } + + /////////// + // Tests // + /////////// + + /// `Iterated::on_elements_unordered` is correct for every supported + /// [`FlatIterator::Element`] shape: owning, forwarding, guard-wrapped, and + /// shared-buffer. 
+ #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn default_implementations() { + let store = Store::sample(); + + // Allocating: Element = Vec (owns). + check_visitor(&mut Iterated::new(Allocating::new(&store))).await; + + // Forwarding: Element = &'store [f32] (borrows from store). + check_visitor(&mut Iterated::new(Forwarding::new(&store))).await; + + let recovered = Iterated::new(Forwarding::new(&store)).into_inner(); + + check_visitor(&mut Iterated::new(recovered)).await; + + // Wrapping: Element = Wrapped<'a> (guard-shaped non-ref). + check_visitor(&mut Iterated::new(Wrapping::new(&store))).await; + + // Sharing: Element = &'a [f32] (per-call internal buffer). + check_visitor(&mut Iterated::new(Sharing::new(&store))).await; + } + + /// The default body of [`DistancesUnordered::distances_unordered`] produces + /// `(id, computer.evaluate_similarity(elem))` pairs for every element shape. + #[tokio::test] + async fn distances_unordered() { + let store = Store::sample(); + let query = vec![0.5_f32, 0.9]; + let computer = f32::query_distance(&query, Metric::L2); + let expected = sample_items() + .into_iter() + .map(|(id, v)| (id, computer.evaluate_similarity(v.as_slice()))) + .collect::>(); + + async fn run(mut visitor: Iterated, query: &[f32], expected: &[(u32, f32)]) + where + I: FlatIterator + Send + Sync, + I: for<'a> HasElementRef = &'a [f32]>, + Iterated: HasId + + for<'q> BuildQueryComputer< + &'q [f32], + QueryComputerError = Infallible, + QueryComputer = ::QueryDistance, + > + for<'q> DistancesUnordered<&'q [f32]>, + { + let computer = visitor.build_query_computer(query).unwrap(); + let mut seen: Vec<(u32, f32)> = Vec::new(); + visitor + .distances_unordered(&computer, |id, d| seen.push((id, d))) + .await + .unwrap(); + assert_eq!(seen, expected); + } + + run(Iterated::new(Allocating::new(&store)), &query, &expected).await; + run(Iterated::new(Forwarding::new(&store)), &query, &expected).await; + 
run(Iterated::new(Wrapping::new(&store)), &query, &expected).await; + run(Iterated::new(Sharing::new(&store)), &query, &expected).await; + } + + /// An error returned mid-iteration by [`FlatIterator::next`] propagates up + /// through [`Iterated::on_elements_unordered`], and the closure stops being + /// invoked at the failure point. + #[tokio::test] + async fn failures_midstream() { + let store = Store::sample(); + let mut visitor = Iterated::new(Failing { + store: &store, + cursor: 0, + fail_after: 1, // Yield item 0 successfully, fail on item 1. + }); + + let mut seen: Vec = Vec::new(); + let err = visitor + .on_elements_unordered(|id, _e: &[f32]| seen.push(id)) + .await + .expect_err("Failing iterator must surface its error"); + + assert_eq!(err, Boom(11)); + assert_eq!( + seen, + vec![10], + "the closure must only see items yielded before the failure", + ); + } +} diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs index 605c634d2..fc06e21b9 100644 --- a/diskann/src/flat/mod.rs +++ b/diskann/src/flat/mod.rs @@ -34,3 +34,6 @@ pub mod strategy; pub use index::{FlatIndex, SearchStats}; pub use iterator::{DistancesUnordered, FlatIterator, Iterated, OnElementsUnordered}; pub use strategy::SearchStrategy; + +#[cfg(test)] +mod test; diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs index 467de49d3..a3ff980b1 100644 --- a/diskann/src/flat/strategy.rs +++ b/diskann/src/flat/strategy.rs @@ -47,3 +47,32 @@ where context: &'a P::Context, ) -> Result, Self::Error>; } + +#[cfg(test)] +mod tests { + use crate::{ + flat::test::provider::{self as flat_provider, Strategy}, + graph::test::synthetic::Grid, + }; + + use super::SearchStrategy; + + /// `create_visitor` produces independent visitors on successive calls. + /// + /// The strategy is a stateless factory; calling it twice should yield two + /// distinct visitors that may be used in parallel without interfering with + /// each other. 
+ #[test] + fn exercise_create_visitor() { + let provider = flat_provider::Provider::grid(Grid::Two, 3); + let context = flat_provider::Context::new(); + let strategy = Strategy::new(); + + let v1 = strategy.create_visitor(&provider, &context).unwrap(); + let v2 = strategy.create_visitor(&provider, &context).unwrap(); + + // The two visitors must occupy distinct stack slots — i.e. holding `v1` + // does not preclude constructing `v2`. + let _ = (&v1, &v2); + } +} diff --git a/diskann/src/flat/test/cases/flat_knn_search.rs b/diskann/src/flat/test/cases/flat_knn_search.rs new file mode 100644 index 000000000..b5bd47df1 --- /dev/null +++ b/diskann/src/flat/test/cases/flat_knn_search.rs @@ -0,0 +1,205 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Baseline-cached regression sweep for [`crate::flat::FlatIndex::knn_search`]. +//! +//! Mirrors [`crate::graph::test::cases::grid_search`]: builds a fresh index per +//! parameter combination, runs `knn_search` through the +//! [`crate::flat::test::harness`], snapshots the result + statistics into +//! [`FlatKnnBaseline`], and compares the entire batch against the JSON committed under +//! `diskann/test/generated/flat/test/cases/flat_knn_search/`. + +use crate::{ + flat::{ + FlatIndex, + test::{ + harness, + provider::{self as flat_provider, Metrics, Strategy}, + }, + }, + graph::test::synthetic::Grid, + test::{ + TestPath, TestRoot, + cmp::{assert_eq_verbose, verbose_eq}, + get_or_save_test_results, + }, +}; + +fn root() -> TestRoot { + TestRoot::new("flat/test/cases/flat_knn_search") +} + +/// `k` values exercised for every `(grid, query)` combination. +const KS: [usize; 3] = [1, 4, 10]; + +/// One row of the baseline JSON: a single `(grid, size, query, k)` execution of +/// `FlatIndex::knn_search` plus the brute-force ground truth, search stats, and +/// per-row provider metrics. 
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct FlatKnnBaseline { + /// Free-form description of what this row exercises. + description: String, + + /// The query vector. + query: Vec, + + /// The dimensionality of the underlying grid. + grid_dims: usize, + + /// The side length of the underlying grid. + grid_size: usize, + + /// The requested `k`. + k: usize, + + /// Sorted distance multiset of the top-`k` returned by `knn_search`. + /// We store the distance multiset rather than `(id, distance)` pairs because + /// the priority queue may evict different *ids* on a boundary distance tie + /// (the queue's tie-breaking is heap-internal, not id-based) — but the + /// multiset of distances is invariant. + top_k_distances: Vec, + + /// Brute-force ground-truth top-`k` `(id, distance)` (sorted by `(distance asc, + /// id asc)`). The brute-force pass enumerates ids in ascending order, so on a + /// tie this prefers the smaller id and gives a canonical answer for the JSON. + ground_truth: Vec<(u32, f32)>, + + /// `cmps` reported by `knn_search`. Must equal `provider.len()`. + comparisons: usize, + + /// `result_count` reported by `knn_search`. Must equal `min(k, provider.len())`. + result_count: usize, + + /// Per-provider metrics observed for this row (see [`Metrics`]). + metrics: Metrics, +} + +verbose_eq!(FlatKnnBaseline { + description, + query, + grid_dims, + grid_size, + k, + top_k_distances, + ground_truth, + comparisons, + result_count, + metrics, +}); + +/// Run `knn_search` + brute-force oracle against a *shared* `index`, assert the +/// cross-row invariants, and produce the baseline row. The per-row provider metrics +/// captured into the baseline are the *delta* observed during this row, which keeps +/// the snapshot independent of how many rows preceded it. 
+fn run_row( + index: &FlatIndex, + grid_dim: usize, + grid_size: usize, + query: &[f32], + k: usize, + desc: &str, +) -> FlatKnnBaseline { + let len = index.provider().len(); + let metrics_before = index.provider().metrics(); + + let outcome = harness::KnnOracleRun::run_sync(index, &Strategy::new(), query, k).unwrap(); + let stats = outcome.stats; + + assert_eq!( + stats.cmps as usize, len, + "flat scan must touch every element exactly once", + ); + assert_eq!( + stats.result_count as usize, + k.min(len), + "result_count must equal min(k, provider.len())", + ); + + let gt_distances: Vec = outcome.ground_truth.iter().map(|(_, d)| *d).collect(); + assert_eq!( + outcome.top_k_distances, gt_distances, + "flat scan top-k distance multiset must agree with brute force", + ); + + let metrics_after = index.provider().metrics(); + let metrics = Metrics { + get_element: metrics_after.get_element - metrics_before.get_element, + }; + // `get_element` is incremented only by the [`Visitor`] used during `knn_search`; + // the brute-force oracle iterates `Provider::items()` directly and does not touch + // the visitor, so we expect exactly one scan's worth of increments per row. + assert_eq!( + metrics.get_element, len, + "expected exactly one scan (from knn_search) to increment get_element", + ); + + FlatKnnBaseline { + description: desc.to_string(), + query: query.to_vec(), + grid_dims: grid_dim, + grid_size, + k, + top_k_distances: outcome.top_k_distances, + ground_truth: outcome.ground_truth, + comparisons: stats.cmps as usize, + result_count: stats.result_count as usize, + metrics, + } +} + +/// Sweep [`KS`] × `queries` for the given `(grid, size)` and snapshot the results. +fn _flat_knn_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { + let dim: usize = grid.dim().into(); + + // Build the provider and index once, mirroring the production pattern where a + // single index serves many queries. 
+ let provider = flat_provider::Provider::grid(grid, size); + let len = provider.len(); + assert_eq!( + len, + size.pow(dim as u32), + "flat::test::Provider::grid should produce size^dim rows", + ); + let index = FlatIndex::new(provider); + + let queries: [(Vec, &str); 2] = [ + ( + vec![-1.0; dim], + "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + ), + ( + vec![(size - 1) as f32; dim], + "All `size-1`: query coincides with the last grid corner.", + ), + ]; + + let index_ref = &index; + let results: Vec = queries + .iter() + .flat_map(|(q, desc)| { + KS.iter() + .map(move |&k| run_row(index_ref, dim, size, q, k, desc)) + }) + .collect(); + + let name = parent.push(format!("search_{dim}_{size}")); + let expected = get_or_save_test_results(&name, &results); + assert_eq_verbose!(expected, results); +} + +#[test] +fn flat_knn_search_1_100() { + _flat_knn_search(Grid::One, 100, root().path()); +} + +#[test] +fn flat_knn_search_2_5() { + _flat_knn_search(Grid::Two, 5, root().path()); +} + +#[test] +fn flat_knn_search_3_4() { + _flat_knn_search(Grid::Three, 4, root().path()); +} diff --git a/diskann/src/flat/test/cases/mod.rs b/diskann/src/flat/test/cases/mod.rs new file mode 100644 index 000000000..ffaf5b91c --- /dev/null +++ b/diskann/src/flat/test/cases/mod.rs @@ -0,0 +1,6 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +mod flat_knn_search; diff --git a/diskann/src/flat/test/harness.rs b/diskann/src/flat/test/harness.rs new file mode 100644 index 000000000..af177378b --- /dev/null +++ b/diskann/src/flat/test/harness.rs @@ -0,0 +1,142 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Reusable execution harness for [`crate::flat::FlatIndex`] tests. +//! +//! Centralizes every piece of repeated boilerplate that surrounds a `knn_search` +//! call so individual tests stay close to a one-line statement of intent: +//! +//! 
* runtime construction (`current_thread_runtime`), +//! * scratch buffer + [`BackInserter`] plumbing, +//! * [`NonZeroUsize`] wrapping and [`CopyIds`] selection, +//! * brute-force ground-truth oracle, +//! * canonical `(distance asc, id asc)` re-sort that matches the oracle's +//! tie-breaking (the priority queue's heap-pop order is *not* id-stable on +//! ties, but the oracle's is). +//! +//! Use [`KnnOracleRun::run_sync`] to drive `knn_search` and pair the result with the +//! brute-force ground truth. + +use std::{cmp::Ordering, num::NonZeroUsize}; + +use diskann_vector::PreprocessedDistanceFunction; + +use crate::{ + ANNResult, + flat::{ + FlatIndex, SearchStats, + test::provider::{Provider, Strategy}, + }, + graph::glue::CopyIds, + neighbor::{BackInserter, Neighbor}, + test::tokio::current_thread_runtime, + utils::VectorRepr, +}; + +/// Result of running [`FlatIndex::knn_search`] under the harness alongside a +/// brute-force ground-truth oracle. +#[derive(Debug, Clone)] +pub(crate) struct KnnOracleRun { + /// Top-`k` `(id, distance)` pairs in canonical `(distance asc, id asc)` order. + /// Re-sorted from the heap output so equality checks are deterministic on ties. + pub top_k: Vec<(u32, f32)>, + /// `top_k.iter().map(|(_, d)| d).collect()`. + pub top_k_distances: Vec, + /// Statistics returned by `knn_search` (cmps, result_count). + pub stats: SearchStats, + /// Brute-force ground-truth top-`k` `(id, distance)` pairs in `(distance asc, + /// id asc)` order. + pub ground_truth: Vec<(u32, f32)>, +} + +impl KnnOracleRun { + /// Run [`FlatIndex::knn_search`] once, blocking on a fresh single-threaded + /// runtime, and pair the result with the brute-force ground truth. + pub fn run_sync( + index: &FlatIndex, + strategy: &Strategy, + query: &[f32], + k: usize, + ) -> ANNResult { + current_thread_runtime().block_on(Self::run(index, strategy, query, k)) + } + + /// Async variant of [`KnnOracleRun::run_sync`]. Use this from tests that already + /// have a Tokio runtime (e.g.
`#[tokio::test]`) or that need to drive + /// `knn_search` concurrently across tasks. + pub async fn run( + index: &FlatIndex, + strategy: &Strategy, + query: &[f32], + k: usize, + ) -> ANNResult { + let context = crate::flat::test::provider::Context::new(); + let mut buf = vec![Neighbor::::default(); k]; + + let stats = index + .knn_search( + NonZeroUsize::new(k).expect("flat::test::harness requires k > 0"), + strategy, + &CopyIds, + &context, + query, + &mut BackInserter::new(buf.as_mut_slice()), + ) + .await?; + + let top_k = top_k_sorted(&buf, stats.result_count as usize); + let top_k_distances = top_k.iter().map(|(_, d)| *d).collect(); + let ground_truth = brute_force_topk(index.provider(), query, k); + + Ok(Self { + top_k, + top_k_distances, + stats, + ground_truth, + }) + } +} + +/// Compute the brute-force top-`k` `(id, distance)` pairs over every element of +/// `provider`. Iterates [`Provider::items`] directly and scores with a fresh +/// [`f32::query_distance`] computer, so the oracle is independent of the +/// [`crate::flat::test::provider::Visitor`] under test. Ties are broken by ascending +/// id for determinism. +pub(crate) fn brute_force_topk(provider: &Provider, query: &[f32], k: usize) -> Vec<(u32, f32)> { + let computer = f32::query_distance(query, provider.metric()); + + let mut neighbors: Vec> = provider + .items() + .iter() + .enumerate() + .map(|(id, element)| Neighbor::new(id as u32, computer.evaluate_similarity(element))) + .collect(); + + sort_neighbors(&mut neighbors); + neighbors + .into_iter() + .take(k) + .map(|n| n.as_tuple()) + .collect() +} + +/// Take the first `result_count` neighbors and return them in `(distance asc, id asc)` +/// order. 
+fn top_k_sorted(buf: &[Neighbor], result_count: usize) -> Vec<(u32, f32)> { + let mut neighbors: Vec> = buf.iter().copied().take(result_count).collect(); + sort_neighbors(&mut neighbors); + neighbors.into_iter().map(|n| n.as_tuple()).collect() +} + +/// Sort a slice of [`Neighbor`] by `(distance asc, id asc)`. NaN distances are +/// treated as equal (test data should not produce NaN). +fn sort_neighbors(neighbors: &mut [Neighbor]) { + neighbors.sort_by(|a, b| { + a.distance + .partial_cmp(&b.distance) + .unwrap_or(Ordering::Equal) + .then(a.id.cmp(&b.id)) + }); +} diff --git a/diskann/src/flat/test/mod.rs b/diskann/src/flat/test/mod.rs new file mode 100644 index 000000000..70856a964 --- /dev/null +++ b/diskann/src/flat/test/mod.rs @@ -0,0 +1,14 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Test fixtures and helpers for the flat module. +//! +//! Mirrors the layout of [`crate::graph::test`]: shared visitor / strategy / data helpers +//! live at this level, while end-to-end baseline-cached tests live under [`cases`]. + +pub(crate) mod harness; +pub(crate) mod provider; + +mod cases; diff --git a/diskann/src/flat/test/provider.rs b/diskann/src/flat/test/provider.rs new file mode 100644 index 000000000..bc9f4767e --- /dev/null +++ b/diskann/src/flat/test/provider.rs @@ -0,0 +1,442 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +#![allow(dead_code)] + +//! Self-contained test provider for the flat-search module. +//! +//! Mirrors the spirit of [`crate::graph::test::provider`] but stripped down to just +//! the surface needed by [`crate::flat::FlatIndex`]: +//! +//! * no adjacency lists / max-degree / start points, +//! * no `Delete` / `SetElement` / multi-insert plumbing, +//! * a single per-element counter (`get_element`) instead of a 5-counter metrics struct, +//! * no element-buffering accessor — vectors live in a `Vec<Vec<f32>>` and are handed +//! to the visitor by reference. +//!
+//! The three public types parallel their graph counterparts: +//! +//! | Graph | Flat | Role | +//! | :----------------------------------------- | :-------------------- | :------------------------------------------ | +//! | [`crate::graph::test::provider::Provider`] | [`Provider`] | Owns the data and per-element counter. | +//! | [`crate::graph::test::provider::Accessor`] | [`Visitor`] | Per-search visitor; carries fault state. | +//! | [`crate::graph::test::provider::Strategy`] | [`Strategy`] | Stateless factory of [`Visitor`]s. | + +use std::{ + borrow::Cow, + collections::HashSet, + fmt::{self, Debug}, + future::Future, + sync::Arc, +}; + +use diskann_utils::future::SendFuture; +use diskann_vector::distance::Metric; +use thiserror::Error; + +use crate::{ + ANNError, always_escalate, + error::{Infallible, RankedError, ToRanked, TransientError}, + flat::{DistancesUnordered, OnElementsUnordered, SearchStrategy}, + graph::test::synthetic::Grid, + internal::counter::{Counter, LocalCounter}, + provider::{self, BuildQueryComputer, ExecutionContext, HasElementRef, HasId, NoopGuard}, + utils::VectorRepr, +}; + +////////////// +// Provider // +////////////// + +/// In-memory test provider for flat search. +#[derive(Debug)] +pub struct Provider { + items: Vec>, + dim: usize, + metric: Metric, + get_element: Counter, +} + +impl Provider { + /// Construct a provider that owns `items`. Every vector must have the same + /// (non-zero) length. 
+ pub fn new(metric: Metric, items: impl IntoIterator>) -> Self { + let items: Vec> = items.into_iter().collect(); + assert!( + !items.is_empty(), + "flat::test::Provider needs at least one item" + ); + let dim = items[0].len(); + assert!( + dim > 0, + "flat::test::Provider items must have non-zero dimension" + ); + for (i, v) in items.iter().enumerate() { + assert_eq!( + v.len(), + dim, + "flat::test::Provider item {i} has dim {} but expected {dim}", + v.len(), + ); + } + Self { + items, + dim, + metric, + get_element: Counter::new(), + } + } + + /// Build a provider over the row vectors of [`Grid::data`]. IDs are `0..n` in + /// row-major order (last coordinate varies fastest). Uses [`Metric::L2`]. + /// + /// Unlike the graph-side `Provider::grid`, this does *not* add a separate + /// start-point row — flat search has no notion of one. + pub fn grid(grid: Grid, size: usize) -> Self { + let data = grid.data(size); + let items: Vec> = data.row_iter().map(|row| row.to_vec()).collect(); + Self::new(Metric::L2, items) + } + + /// Dimensionality of every vector in the provider. + pub fn dim(&self) -> usize { + self.dim + } + + /// Number of vectors in the provider. + pub fn len(&self) -> usize { + self.items.len() + } + + /// `true` if there are no vectors. + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } + + /// Distance metric the provider was constructed with. + pub fn metric(&self) -> Metric { + self.metric + } + + /// Snapshot of the per-provider counters. + pub fn metrics(&self) -> Metrics { + Metrics { + get_element: self.get_element.value(), + } + } + + /// Expose the items for brute force. + pub fn items(&self) -> &[Vec] { + self.items.as_slice() + } +} + +/// Counters tracked by [`Provider`]. +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(serde::Serialize, serde::Deserialize))] +pub struct Metrics { + /// The number of times any [`Visitor`] yielded an element. 
+ pub get_element: usize, +} + +#[cfg(test)] +crate::test::cmp::verbose_eq!(Metrics { get_element }); + +///////////// +// Context // +///////////// + +/// Per-search execution context. No spawn/clone tracking — flat search runs on +/// the calling task and never spawns. +#[derive(Debug, Clone, Default)] +pub struct Context; + +impl Context { + pub fn new() -> Self { + Self + } +} + +impl ExecutionContext for Context { + fn wrap_spawn(&self, f: F) -> impl Future + Send + 'static + where + F: Future + Send + 'static, + { + f + } +} + +///////////////////// +// Errors / Guards // +///////////////////// + +/// Critical id-validation error: the requested id is out of range. +#[derive(Debug, Clone, Copy, Error, PartialEq, Eq)] +#[error("flat::test::Provider has no id {0}")] +pub struct InvalidId(pub u32); + +always_escalate!(InvalidId); + +impl From for ANNError { + #[track_caller] + fn from(err: InvalidId) -> ANNError { + ANNError::opaque(err) + } +} + +/// Transient access error injected by [`Visitor::flaky`]. +/// +/// Matches the shape of `graph::test::TransientAccessError`: panics in `Drop` if it +/// is dropped without being acknowledged or escalated. This guards against accidental +/// silent suppression of the error in the test code itself. 
+#[must_use] +#[derive(Debug)] +pub struct TransientGetError { + id: u32, + handled: bool, +} + +impl TransientGetError { + fn new(id: u32) -> Self { + Self { id, handled: false } + } +} + +impl fmt::Display for TransientGetError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "transient failure retrieving id {}", self.id) + } +} + +impl std::error::Error for TransientGetError {} + +impl Drop for TransientGetError { + fn drop(&mut self) { + assert!( + self.handled, + "dropped an unhandled TransientGetError for id {}", + self.id, + ); + } +} + +impl TransientError for TransientGetError { + fn acknowledge(mut self, _why: D) + where + D: fmt::Display, + { + self.handled = true; + } + + fn escalate(mut self, _why: D) -> InvalidId + where + D: fmt::Display, + { + self.handled = true; + InvalidId(self.id) + } +} + +/// Two-tier error for [`Visitor::on_elements_unordered`]: a critical [`InvalidId`] +/// or a recoverable [`TransientGetError`]. +#[derive(Debug)] +pub enum AccessError { + InvalidId(InvalidId), + Transient(TransientGetError), +} + +impl fmt::Display for AccessError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidId(e) => fmt::Display::fmt(e, f), + Self::Transient(e) => fmt::Display::fmt(e, f), + } + } +} + +impl std::error::Error for AccessError {} + +impl ToRanked for AccessError { + type Transient = TransientGetError; + type Error = InvalidId; + + fn to_ranked(self) -> RankedError { + match self { + Self::InvalidId(e) => RankedError::Error(e), + Self::Transient(e) => RankedError::Transient(e), + } + } + + fn from_transient(transient: TransientGetError) -> Self { + Self::Transient(transient) + } + + fn from_error(error: InvalidId) -> Self { + Self::InvalidId(error) + } +} + +////////////////// +// DataProvider // +////////////////// + +impl provider::DataProvider for Provider { + type Context = Context; + type InternalId = u32; + type ExternalId = u32; + type Error = InvalidId; + type 
Guard = NoopGuard; + + fn to_internal_id(&self, _ctx: &Context, gid: &u32) -> Result { + if (*gid as usize) < self.items.len() { + Ok(*gid) + } else { + Err(InvalidId(*gid)) + } + } + + fn to_external_id(&self, _ctx: &Context, id: u32) -> Result { + if (id as usize) < self.items.len() { + Ok(id) + } else { + Err(InvalidId(id)) + } + } +} + +///////////// +// Visitor // +///////////// + +/// Per-search visitor over a [`Provider`]. Analog of `graph::test::Accessor`: holds +/// the `'a` borrow of the provider, accumulates a local `get_element` counter that +/// flushes back on drop, and optionally injects transient errors for a configurable +/// set of ids. +pub struct Visitor<'a> { + provider: &'a Provider, + transient_ids: Option>>, + get_element: LocalCounter<'a>, +} + +impl<'a> Visitor<'a> { + /// Construct a visitor with no fault injection. + pub fn new(provider: &'a Provider) -> Self { + Self { + provider, + transient_ids: None, + get_element: provider.get_element.local(), + } + } + + /// Construct a visitor that returns a [`TransientGetError`] for any id in + /// `transient_ids`. Other ids behave normally. + pub fn flaky(provider: &'a Provider, transient_ids: Cow<'a, HashSet>) -> Self { + Self { + provider, + transient_ids: Some(transient_ids), + get_element: provider.get_element.local(), + } + } + + /// The borrowed [`Provider`]. 
+ pub fn provider(&self) -> &'a Provider { + self.provider + } +} + +impl Debug for Visitor<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Visitor") + .field("provider", &self.provider) + .field("transient_ids", &self.transient_ids) + .finish_non_exhaustive() + } +} + +impl HasId for Visitor<'_> { + type Id = u32; +} + +impl HasElementRef for Visitor<'_> { + type ElementRef<'a> = &'a [f32]; +} + +impl BuildQueryComputer<&[f32]> for Visitor<'_> { + type QueryComputerError = Infallible; + type QueryComputer = ::QueryDistance; + + fn build_query_computer( + &self, + from: &[f32], + ) -> Result { + Ok(f32::query_distance(from, self.provider.metric)) + } +} + +impl OnElementsUnordered for Visitor<'_> { + type Error = AccessError; + + fn on_elements_unordered(&mut self, mut f: F) -> impl SendFuture> + where + F: Send + for<'a> FnMut(Self::Id, ::ElementRef<'a>), + { + async move { + for (i, vector) in self.provider.items.iter().enumerate() { + let id = i as u32; + if let Some(ids) = &self.transient_ids + && ids.contains(&id) + { + return Err(AccessError::Transient(TransientGetError::new(id))); + } + self.get_element.increment(); + f(id, vector.as_slice()); + } + Ok(()) + } + } +} + +impl DistancesUnordered<&[f32]> for Visitor<'_> {} + +////////////// +// Strategy // +////////////// + +/// Stateless factory of [`Visitor`]s. +#[derive(Clone, Debug, Default)] +pub struct Strategy { + transient_ids: Option>>, +} + +impl Strategy { + pub fn new() -> Self { + Self::default() + } + + /// Construct a strategy whose visitors return a transient error on `get_element` + /// for every id in `transient_ids`. 
+ pub fn with_transient(transient_ids: impl IntoIterator) -> Self { + Self { + transient_ids: Some(Arc::new(transient_ids.into_iter().collect())), + } + } +} + +impl SearchStrategy for Strategy { + type Visitor<'a> = Visitor<'a>; + type Error = Infallible; + + fn create_visitor<'a>( + &'a self, + provider: &'a Provider, + _context: &'a Context, + ) -> Result, Self::Error> { + let visitor = match &self.transient_ids { + Some(ids) => Visitor::flaky(provider, Cow::Borrowed(ids)), + None => Visitor::new(provider), + }; + Ok(visitor) + } +} diff --git a/diskann/test/generated/flat/test/cases/flat_knn_search/search_1_100.json b/diskann/test/generated/flat/test/cases/flat_knn_search/search_1_100.json new file mode 100644 index 000000000..3dda3b7c8 --- /dev/null +++ b/diskann/test/generated/flat/test/cases/flat_knn_search/search_1_100.json @@ -0,0 +1,264 @@ +{ + "file": "diskann/src/flat/test/cases/flat_knn_search.rs", + "test": "flat/test/cases/flat_knn_search/search_1_100", + "payload": [ + { + "comparisons": 100, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 0, + 1.0 + ] + ], + "k": 1, + "metrics": { + "get_element": 100 + }, + "query": [ + -1.0 + ], + "result_count": 1, + "top_k_distances": [ + 1.0 + ] + }, + { + "comparisons": 100, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 0, + 1.0 + ], + [ + 1, + 4.0 + ], + [ + 2, + 9.0 + ], + [ + 3, + 16.0 + ] + ], + "k": 4, + "metrics": { + "get_element": 100 + }, + "query": [ + -1.0 + ], + "result_count": 4, + "top_k_distances": [ + 1.0, + 4.0, + 9.0, + 16.0 + ] + }, + { + "comparisons": 100, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 0, + 1.0 + ], + [ + 1, + 4.0 + ], + [ + 2, + 9.0 + ], + [ + 3, + 16.0 + ], 
+ [ + 4, + 25.0 + ], + [ + 5, + 36.0 + ], + [ + 6, + 49.0 + ], + [ + 7, + 64.0 + ], + [ + 8, + 81.0 + ], + [ + 9, + 100.0 + ] + ], + "k": 10, + "metrics": { + "get_element": 100 + }, + "query": [ + -1.0 + ], + "result_count": 10, + "top_k_distances": [ + 1.0, + 4.0, + 9.0, + 16.0, + 25.0, + 36.0, + 49.0, + 64.0, + 81.0, + 100.0 + ] + }, + { + "comparisons": 100, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 99, + 0.0 + ] + ], + "k": 1, + "metrics": { + "get_element": 100 + }, + "query": [ + 99.0 + ], + "result_count": 1, + "top_k_distances": [ + 0.0 + ] + }, + { + "comparisons": 100, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 99, + 0.0 + ], + [ + 98, + 1.0 + ], + [ + 97, + 4.0 + ], + [ + 96, + 9.0 + ] + ], + "k": 4, + "metrics": { + "get_element": 100 + }, + "query": [ + 99.0 + ], + "result_count": 4, + "top_k_distances": [ + 0.0, + 1.0, + 4.0, + 9.0 + ] + }, + { + "comparisons": 100, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 99, + 0.0 + ], + [ + 98, + 1.0 + ], + [ + 97, + 4.0 + ], + [ + 96, + 9.0 + ], + [ + 95, + 16.0 + ], + [ + 94, + 25.0 + ], + [ + 93, + 36.0 + ], + [ + 92, + 49.0 + ], + [ + 91, + 64.0 + ], + [ + 90, + 81.0 + ] + ], + "k": 10, + "metrics": { + "get_element": 100 + }, + "query": [ + 99.0 + ], + "result_count": 10, + "top_k_distances": [ + 0.0, + 1.0, + 4.0, + 9.0, + 16.0, + 25.0, + 36.0, + 49.0, + 64.0, + 81.0 + ] + } + ] +} \ No newline at end of file diff --git a/diskann/test/generated/flat/test/cases/flat_knn_search/search_2_5.json b/diskann/test/generated/flat/test/cases/flat_knn_search/search_2_5.json new file mode 100644 index 000000000..3a978a520 --- /dev/null +++ b/diskann/test/generated/flat/test/cases/flat_knn_search/search_2_5.json @@ -0,0 +1,270 @@ 
+{ + "file": "diskann/src/flat/test/cases/flat_knn_search.rs", + "test": "flat/test/cases/flat_knn_search/search_2_5", + "payload": [ + { + "comparisons": 25, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 0, + 2.0 + ] + ], + "k": 1, + "metrics": { + "get_element": 25 + }, + "query": [ + -1.0, + -1.0 + ], + "result_count": 1, + "top_k_distances": [ + 2.0 + ] + }, + { + "comparisons": 25, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 0, + 2.0 + ], + [ + 1, + 5.0 + ], + [ + 5, + 5.0 + ], + [ + 6, + 8.0 + ] + ], + "k": 4, + "metrics": { + "get_element": 25 + }, + "query": [ + -1.0, + -1.0 + ], + "result_count": 4, + "top_k_distances": [ + 2.0, + 5.0, + 5.0, + 8.0 + ] + }, + { + "comparisons": 25, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 0, + 2.0 + ], + [ + 1, + 5.0 + ], + [ + 5, + 5.0 + ], + [ + 6, + 8.0 + ], + [ + 2, + 10.0 + ], + [ + 10, + 10.0 + ], + [ + 7, + 13.0 + ], + [ + 11, + 13.0 + ], + [ + 3, + 17.0 + ], + [ + 15, + 17.0 + ] + ], + "k": 10, + "metrics": { + "get_element": 25 + }, + "query": [ + -1.0, + -1.0 + ], + "result_count": 10, + "top_k_distances": [ + 2.0, + 5.0, + 5.0, + 8.0, + 10.0, + 10.0, + 13.0, + 13.0, + 17.0, + 17.0 + ] + }, + { + "comparisons": 25, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 24, + 0.0 + ] + ], + "k": 1, + "metrics": { + "get_element": 25 + }, + "query": [ + 4.0, + 4.0 + ], + "result_count": 1, + "top_k_distances": [ + 0.0 + ] + }, + { + "comparisons": 25, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 24, + 0.0 + ], + [ + 19, + 
1.0 + ], + [ + 23, + 1.0 + ], + [ + 18, + 2.0 + ] + ], + "k": 4, + "metrics": { + "get_element": 25 + }, + "query": [ + 4.0, + 4.0 + ], + "result_count": 4, + "top_k_distances": [ + 0.0, + 1.0, + 1.0, + 2.0 + ] + }, + { + "comparisons": 25, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 24, + 0.0 + ], + [ + 19, + 1.0 + ], + [ + 23, + 1.0 + ], + [ + 18, + 2.0 + ], + [ + 14, + 4.0 + ], + [ + 22, + 4.0 + ], + [ + 13, + 5.0 + ], + [ + 17, + 5.0 + ], + [ + 12, + 8.0 + ], + [ + 9, + 9.0 + ] + ], + "k": 10, + "metrics": { + "get_element": 25 + }, + "query": [ + 4.0, + 4.0 + ], + "result_count": 10, + "top_k_distances": [ + 0.0, + 1.0, + 1.0, + 2.0, + 4.0, + 4.0, + 5.0, + 5.0, + 8.0, + 9.0 + ] + } + ] +} \ No newline at end of file diff --git a/diskann/test/generated/flat/test/cases/flat_knn_search/search_3_4.json b/diskann/test/generated/flat/test/cases/flat_knn_search/search_3_4.json new file mode 100644 index 000000000..74e067761 --- /dev/null +++ b/diskann/test/generated/flat/test/cases/flat_knn_search/search_3_4.json @@ -0,0 +1,276 @@ +{ + "file": "diskann/src/flat/test/cases/flat_knn_search.rs", + "test": "flat/test/cases/flat_knn_search/search_3_4", + "payload": [ + { + "comparisons": 64, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 0, + 3.0 + ] + ], + "k": 1, + "metrics": { + "get_element": 64 + }, + "query": [ + -1.0, + -1.0, + -1.0 + ], + "result_count": 1, + "top_k_distances": [ + 3.0 + ] + }, + { + "comparisons": 64, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 0, + 3.0 + ], + [ + 1, + 6.0 + ], + [ + 4, + 6.0 + ], + [ + 16, + 6.0 + ] + ], + "k": 4, + "metrics": { + "get_element": 64 + }, + "query": [ + -1.0, + -1.0, + -1.0 + ], + "result_count": 4, + 
"top_k_distances": [ + 3.0, + 6.0, + 6.0, + 6.0 + ] + }, + { + "comparisons": 64, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 0, + 3.0 + ], + [ + 1, + 6.0 + ], + [ + 4, + 6.0 + ], + [ + 16, + 6.0 + ], + [ + 5, + 9.0 + ], + [ + 17, + 9.0 + ], + [ + 20, + 9.0 + ], + [ + 2, + 11.0 + ], + [ + 8, + 11.0 + ], + [ + 32, + 11.0 + ] + ], + "k": 10, + "metrics": { + "get_element": 64 + }, + "query": [ + -1.0, + -1.0, + -1.0 + ], + "result_count": 10, + "top_k_distances": [ + 3.0, + 6.0, + 6.0, + 6.0, + 9.0, + 9.0, + 9.0, + 11.0, + 11.0, + 11.0 + ] + }, + { + "comparisons": 64, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 63, + 0.0 + ] + ], + "k": 1, + "metrics": { + "get_element": 64 + }, + "query": [ + 3.0, + 3.0, + 3.0 + ], + "result_count": 1, + "top_k_distances": [ + 0.0 + ] + }, + { + "comparisons": 64, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 63, + 0.0 + ], + [ + 47, + 1.0 + ], + [ + 59, + 1.0 + ], + [ + 62, + 1.0 + ] + ], + "k": 4, + "metrics": { + "get_element": 64 + }, + "query": [ + 3.0, + 3.0, + 3.0 + ], + "result_count": 4, + "top_k_distances": [ + 0.0, + 1.0, + 1.0, + 1.0 + ] + }, + { + "comparisons": 64, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 63, + 0.0 + ], + [ + 47, + 1.0 + ], + [ + 59, + 1.0 + ], + [ + 62, + 1.0 + ], + [ + 43, + 2.0 + ], + [ + 46, + 2.0 + ], + [ + 58, + 2.0 + ], + [ + 42, + 3.0 + ], + [ + 31, + 4.0 + ], + [ + 55, + 4.0 + ] + ], + "k": 10, + "metrics": { + "get_element": 64 + }, + "query": [ + 3.0, + 3.0, + 3.0 + ], + "result_count": 10, + "top_k_distances": [ + 0.0, + 1.0, + 1.0, + 1.0, + 2.0, + 2.0, + 2.0, + 3.0, + 4.0, + 4.0 + ] + } + ] +} \ No 
newline at end of file From f338a8395e19699419abbd887ef4837f03664c02 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Fri, 8 May 2026 17:09:24 -0400 Subject: [PATCH 19/24] minor cleanup docs --- diskann/src/flat/test/cases/flat_knn_search.rs | 3 +-- diskann/src/flat/test/mod.rs | 4 ---- diskann/src/flat/test/provider.rs | 17 ----------------- 3 files changed, 1 insertion(+), 23 deletions(-) diff --git a/diskann/src/flat/test/cases/flat_knn_search.rs b/diskann/src/flat/test/cases/flat_knn_search.rs index b5bd47df1..76a00c33c 100644 --- a/diskann/src/flat/test/cases/flat_knn_search.rs +++ b/diskann/src/flat/test/cases/flat_knn_search.rs @@ -5,8 +5,7 @@ //! Baseline-cached regression sweep for [`crate::flat::FlatIndex::knn_search`]. //! -//! Mirrors [`crate::graph::test::cases::grid_search`]: builds a fresh index per -//! parameter combination, runs `knn_search` through the +//! Bbuilds a fresh index per parameter combination, runs `knn_search` through the //! [`crate::flat::test::harness`], snapshots the result + statistics into //! [`FlatKnnBaseline`], and compares the entire batch against the JSON committed under //! `diskann/test/generated/flat/test/cases/flat_knn_search/`. diff --git a/diskann/src/flat/test/mod.rs b/diskann/src/flat/test/mod.rs index 70856a964..8cbd01622 100644 --- a/diskann/src/flat/test/mod.rs +++ b/diskann/src/flat/test/mod.rs @@ -4,10 +4,6 @@ */ //! Test fixtures and helpers for the flat module. -//! -//! Mirrors the layout of [`crate::graph::test`]: shared visitor / strategy / data helpers -//! live at this level, while end-to-end baseline-cached tests live under [`cases`]. - pub(crate) mod harness; pub(crate) mod provider; diff --git a/diskann/src/flat/test/provider.rs b/diskann/src/flat/test/provider.rs index bc9f4767e..9890410a3 100644 --- a/diskann/src/flat/test/provider.rs +++ b/diskann/src/flat/test/provider.rs @@ -6,23 +6,6 @@ #![allow(dead_code)] //! Self-contained test provider for the flat-search module. -//! 
-//! Mirrors the spirit of [`crate::graph::test::provider`] but stripped down to just -//! the surface needed by [`crate::flat::FlatIndex`]: -//! -//! * no adjacency lists / max-degree / start points, -//! * no `Delete` / `SetElement` / multi-insert plumbing, -//! * a single per-element counter (`get_element`) instead of a 5-counter metrics struct, -//! * no element-buffering accessor — vectors live in a `Vec>` and are handed -//! to the visitor by reference. -//! -//! The three public types parallel their graph counterparts: -//! -//! | Graph | Flat | Role | -//! | :----------------------------------------- | :-------------------- | :------------------------------------------ | -//! | [`crate::graph::test::provider::Provider`] | [`Provider`] | Owns the data and per-element counter. | -//! | [`crate::graph::test::provider::Accessor`] | [`Visitor`] | Per-search visitor; carries fault state. | -//! | [`crate::graph::test::provider::Strategy`] | [`Strategy`] | Stateless factory of [`Visitor`]s. 
| use std::{ borrow::Cow, From 292840422980ae7a1df154af0aa1abef860f4ed3 Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Mon, 11 May 2026 17:09:22 -0700 Subject: [PATCH 20/24] Add HasQueryComputer trait --- .../src/search/provider/disk_provider.rs | 15 +++-- diskann-garnet/src/provider.rs | 19 ++++--- .../encoded_document_accessor.rs | 20 ++++--- .../inline_beta_search/inline_beta_filter.rs | 2 +- .../graph/provider/async_/bf_tree/provider.rs | 29 +++++++--- .../graph/provider/async_/caching/provider.rs | 21 ++++--- .../provider/async_/inmem/full_precision.rs | 17 ++++-- .../graph/provider/async_/inmem/product.rs | 18 ++++-- .../graph/provider/async_/inmem/scalar.rs | 21 ++++--- .../graph/provider/async_/inmem/spherical.rs | 20 ++++--- .../model/graph/provider/async_/inmem/test.rs | 15 +++-- .../graph/provider/async_/postprocess.rs | 4 +- .../model/graph/provider/layers/betafilter.rs | 30 ++++++---- diskann/src/flat/iterator.rs | 50 +++++++++-------- diskann/src/flat/strategy.rs | 16 ++++-- diskann/src/flat/test/provider.rs | 12 +++- diskann/src/graph/glue.rs | 28 ++++++---- diskann/src/graph/index.rs | 4 +- diskann/src/graph/search/multihop_search.rs | 4 +- diskann/src/graph/search/range_search.rs | 4 +- diskann/src/graph/test/provider.rs | 9 ++- diskann/src/provider.rs | 56 +++++++++++++------ 22 files changed, 267 insertions(+), 147 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 3dbe517d5..7470e8ef1 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -29,7 +29,7 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, - DistancesUnordered, HasElementRef, HasId, NeighborAccessor, NoopGuard, + DistancesUnordered, HasElementRef, HasId, HasQueryComputer, NeighborAccessor, NoopGuard, }, utils::{IntoUsize, VectorRepr}, 
ANNError, ANNResult, @@ -406,13 +406,20 @@ impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer { } } +impl HasQueryComputer for DiskAccessor<'_, Data, VP> +where + Data: GraphDataType, + VP: VertexProvider, +{ + type QueryComputer = DiskQueryComputer; +} + impl BuildQueryComputer<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> where Data: GraphDataType, VP: VertexProvider, { type QueryComputerError = ANNError; - type QueryComputer = DiskQueryComputer; fn build_query_computer( &self, @@ -429,7 +436,7 @@ where } } -impl DistancesUnordered<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> +impl DistancesUnordered for DiskAccessor<'_, Data, VP> where Data: GraphDataType, VP: VertexProvider, @@ -448,7 +455,7 @@ where } } -impl ExpandBeam<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> +impl ExpandBeam for DiskAccessor<'_, Data, VP> where Data: GraphDataType, VP: VertexProvider, diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 8350c6cb8..64389f2c2 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -18,8 +18,8 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DelegateNeighbor, - Delete, DistancesUnordered, ElementStatus, HasElementRef, HasId, NeighborAccessor, - NeighborAccessorMut, NoopGuard, SetElement, + Delete, DistancesUnordered, ElementStatus, HasElementRef, HasId, HasQueryComputer, + NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, }, utils::VectorRepr, }; @@ -467,7 +467,7 @@ impl SearchExt for FullAccessor<'_, T> { } } -impl ExpandBeam<&[T]> for FullAccessor<'_, T> { +impl ExpandBeam for FullAccessor<'_, T> { fn expand_beam( &mut self, ids: Itr, @@ -573,8 +573,11 @@ impl BuildDistanceComputer for FullAccessor<'_, T> { } } -impl BuildQueryComputer<&[T]> for FullAccessor<'_, T> { +impl HasQueryComputer for FullAccessor<'_, T> { type QueryComputer = T::QueryDistance; +} + +impl 
BuildQueryComputer<&[T]> for FullAccessor<'_, T> { type QueryComputerError = GarnetProviderError; fn build_query_computer( @@ -585,7 +588,7 @@ impl BuildQueryComputer<&[T]> for FullAccessor<'_, T> { } } -impl DistancesUnordered<&[T]> for FullAccessor<'_, T> {} +impl DistancesUnordered for FullAccessor<'_, T> {} /// An escape hatch for the blanket implementation of [`workingset::Fill`]. /// @@ -760,16 +763,14 @@ impl NeighborAccessorMut for DelegateNeighborAccessor<'_, '_, T> #[derive(Debug, Default, Clone, Copy)] pub struct CopyExternalIds; -impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], GarnetId> - for CopyExternalIds -{ +impl<'a, T: VectorRepr> SearchPostProcess, &[T], GarnetId> for CopyExternalIds { type Error = GarnetProviderError; fn post_process( &self, accessor: &mut FullAccessor<'a, T>, _query: &[T], - _computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, + _computer: & as HasQueryComputer>::QueryComputer, candidates: I, output: &mut B, ) -> impl Future> + Send diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 15a2b5ce9..cd25f1b5f 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -10,7 +10,7 @@ use diskann::{ graph::glue::{ExpandBeam, SearchExt}, provider::{ Accessor, AsNeighbor, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, - HasElementRef, HasId, + HasElementRef, HasId, HasQueryComputer, }, ANNError, ANNErrorKind, }; @@ -171,12 +171,18 @@ where } } +impl HasQueryComputer for EncodedDocumentAccessor +where + IA: HasQueryComputer + Accessor, +{ + type QueryComputer = InlineBetaComputer; +} + impl<'q, IA, Q> BuildQueryComputer<&'q FilteredQuery> for EncodedDocumentAccessor where IA: BuildQueryComputer<&'q Q> + Accessor, { type QueryComputerError = ANNError; - type QueryComputer = InlineBetaComputer; 
fn build_query_computer( &self, @@ -196,19 +202,17 @@ where } } -impl ExpandBeam for EncodedDocumentAccessor +impl ExpandBeam for EncodedDocumentAccessor where IA: Accessor, - EncodedDocumentAccessor: BuildQueryComputer + AsNeighbor, - Q: Clone, + EncodedDocumentAccessor: DistancesUnordered + AsNeighbor, { } -impl DistancesUnordered for EncodedDocumentAccessor +impl DistancesUnordered for EncodedDocumentAccessor where IA: Accessor, - EncodedDocumentAccessor: BuildQueryComputer, - Q: Clone, + EncodedDocumentAccessor: HasQueryComputer, { } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 10aef0109..012c74e6f 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -153,7 +153,7 @@ where &self, accessor: &mut EncodedDocumentAccessor, query: &'q FilteredQuery, - computer: &InlineBetaComputer<>::QueryComputer>, + computer: &InlineBetaComputer, candidates: I, output: &mut B, ) -> Result diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index e67c50d00..4f61b6900 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -29,7 +29,7 @@ use diskann::{ provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, Delete, DistancesUnordered, ElementStatus, HasElementRef, HasId, - NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, + HasQueryComputer, NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, }, utils::{IntoUsize, VectorRepr}, }; @@ -1040,6 +1040,15 @@ where } } +impl HasQueryComputer for FullAccessor<'_, T, Q, D> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, +{ + type 
QueryComputer = T::QueryDistance; +} + impl BuildQueryComputer<&[T]> for FullAccessor<'_, T, Q, D> where T: VectorRepr, @@ -1047,7 +1056,6 @@ where D: AsyncFriendly, { type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; fn build_query_computer( &self, @@ -1056,7 +1064,7 @@ where Ok(T::query_distance(from, self.provider.metric)) } } -impl ExpandBeam<&[T]> for FullAccessor<'_, T, Q, D> +impl ExpandBeam for FullAccessor<'_, T, Q, D> where T: VectorRepr, Q: AsyncFriendly, @@ -1064,7 +1072,7 @@ where { } -impl DistancesUnordered<&[T]> for FullAccessor<'_, T, Q, D> +impl DistancesUnordered for FullAccessor<'_, T, Q, D> where T: VectorRepr, Q: AsyncFriendly, @@ -1216,13 +1224,20 @@ where } } +impl HasQueryComputer for QuantAccessor<'_, T, D> +where + T: VectorRepr, + D: AsyncFriendly, +{ + type QueryComputer = pq::distance::QueryComputer>; +} + impl BuildQueryComputer<&[T]> for QuantAccessor<'_, T, D> where T: VectorRepr, D: AsyncFriendly, { type QueryComputerError = ANNError; - type QueryComputer = pq::distance::QueryComputer>; fn build_query_computer( &self, @@ -1232,14 +1247,14 @@ where } } -impl ExpandBeam<&[T]> for QuantAccessor<'_, T, D> +impl ExpandBeam for QuantAccessor<'_, T, D> where T: VectorRepr, D: AsyncFriendly, { } -impl DistancesUnordered<&[T]> for QuantAccessor<'_, T, D> +impl DistancesUnordered for QuantAccessor<'_, T, D> where T: VectorRepr, D: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs index 09899c781..e9962d6b8 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs @@ -75,7 +75,7 @@ use diskann::{ provider::{ Accessor, AsNeighbor, BuildDistanceComputer, BuildQueryComputer, CacheableAccessor, DataProvider, DelegateNeighbor, Delete, DistancesUnordered, ElementStatus, HasElementRef, - HasId, 
NeighborAccessor, NeighborAccessorMut, SetElement, + HasId, HasQueryComputer, NeighborAccessor, NeighborAccessorMut, SetElement, }, }; use diskann_utils::{ @@ -781,13 +781,20 @@ where } } +impl HasQueryComputer for CachingAccessor +where + A: HasQueryComputer + CacheableAccessor, + C: ElementCache, +{ + type QueryComputer = A::QueryComputer; +} + impl BuildQueryComputer for CachingAccessor where A: BuildQueryComputer + CacheableAccessor, C: ElementCache, { type QueryComputerError = A::QueryComputerError; - type QueryComputer = A::QueryComputer; fn build_query_computer( &self, @@ -825,16 +832,16 @@ where } } -impl ExpandBeam for CachingAccessor +impl ExpandBeam for CachingAccessor where - A: BuildQueryComputer + CacheableAccessor + AsNeighbor, + A: DistancesUnordered + CacheableAccessor + AsNeighbor, C: ElementCache + NeighborCache, { } -impl DistancesUnordered for CachingAccessor +impl DistancesUnordered for CachingAccessor where - A: BuildQueryComputer + CacheableAccessor + AsNeighbor, + A: HasQueryComputer + CacheableAccessor + AsNeighbor, C: ElementCache + NeighborCache, { } @@ -860,7 +867,7 @@ where next: &Next, accessor: &mut CachingAccessor, query: T, - computer: &>::QueryComputer, + computer: &::QueryComputer, candidates: I, output: &mut B, ) -> impl Future>> + Send diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 0b5eb1b72..5c6a3988e 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -20,7 +20,7 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - DistancesUnordered, ExecutionContext, HasElementRef, HasId, + DistancesUnordered, ExecutionContext, HasElementRef, HasId, HasQueryComputer, }, utils::{IntoUsize, VectorRepr}, }; @@ -296,6 
+296,16 @@ where } } +impl HasQueryComputer for FullAccessor<'_, T, Q, D, Ctx> +where + T: VectorRepr, + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type QueryComputer = T::QueryDistance; +} + impl BuildQueryComputer<&[T]> for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, @@ -304,7 +314,6 @@ where Ctx: ExecutionContext, { type QueryComputerError = Panics; - type QueryComputer = T::QueryDistance; fn build_query_computer( &self, @@ -314,7 +323,7 @@ where } } -impl ExpandBeam<&[T]> for FullAccessor<'_, T, Q, D, Ctx> +impl ExpandBeam for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, Q: AsyncFriendly, @@ -323,7 +332,7 @@ where { } -impl DistancesUnordered<&[T]> for FullAccessor<'_, T, Q, D, Ctx> +impl DistancesUnordered for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, Q: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index 477dc517d..6022d4520 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -17,7 +17,7 @@ use diskann::{ }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, - ExecutionContext, HasElementRef, HasId, + ExecutionContext, HasElementRef, HasId, HasQueryComputer, }, utils::{IntoUsize, VectorRepr}, }; @@ -205,6 +205,15 @@ where } } +impl HasQueryComputer for QuantAccessor<'_, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type QueryComputer = pq::distance::QueryComputer>; +} + impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> where T: VectorRepr, @@ -213,7 +222,6 @@ where Ctx: ExecutionContext, { type QueryComputerError = ANNError; - type QueryComputer = pq::distance::QueryComputer>; fn build_query_computer( &self, @@ -239,18 +247,16 @@ where } } -impl ExpandBeam<&[T]> for QuantAccessor<'_, V, 
D, Ctx> +impl ExpandBeam for QuantAccessor<'_, V, D, Ctx> where - T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, { } -impl DistancesUnordered<&[T]> for QuantAccessor<'_, V, D, Ctx> +impl DistancesUnordered for QuantAccessor<'_, V, D, Ctx> where - T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index c9c781fed..5e130e8f9 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -17,7 +17,7 @@ use diskann::{ }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, - ExecutionContext, HasElementRef, HasId, + ExecutionContext, HasElementRef, HasId, HasQueryComputer, }, utils::{IntoUsize, VectorRepr}, }; @@ -527,6 +527,17 @@ where } } +impl HasQueryComputer for QuantAccessor<'_, NBITS, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, + Unsigned: Representation, + QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, +{ + type QueryComputer = QueryComputer; +} + impl BuildQueryComputer<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx> where @@ -538,7 +549,6 @@ where QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { type QueryComputerError = ANNError; - type QueryComputer = QueryComputer; fn build_query_computer( &self, @@ -552,9 +562,8 @@ where } } -impl ExpandBeam<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx> +impl ExpandBeam for QuantAccessor<'_, NBITS, V, D, Ctx> where - T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, @@ -563,10 +572,8 @@ where { } -impl DistancesUnordered<&[T]> - for QuantAccessor<'_, NBITS, V, D, Ctx> +impl DistancesUnordered for QuantAccessor<'_, NBITS, V, D, Ctx> where - T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: 
ExecutionContext, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index 0edaf4a77..da6731b69 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -19,7 +19,7 @@ use diskann::{ }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, - ExecutionContext, HasElementRef, HasId, + ExecutionContext, HasElementRef, HasId, HasQueryComputer, }, utils::{IntoUsize, VectorRepr}, }; @@ -441,6 +441,16 @@ where } } +impl HasQueryComputer for QuantAccessor<'_, V, D, Ctx> +where + V: AsyncFriendly, + D: AsyncFriendly, + Ctx: ExecutionContext, +{ + type QueryComputer = + UnwrapErr; +} + impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> where T: VectorRepr, @@ -449,8 +459,6 @@ where Ctx: ExecutionContext, { type QueryComputerError = Bridge; - type QueryComputer = - UnwrapErr; fn build_query_computer( &self, @@ -464,18 +472,16 @@ where } } -impl ExpandBeam<&[T]> for QuantAccessor<'_, V, D, Ctx> +impl ExpandBeam for QuantAccessor<'_, V, D, Ctx> where - T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, { } -impl DistancesUnordered<&[T]> for QuantAccessor<'_, V, D, Ctx> +impl DistancesUnordered for QuantAccessor<'_, V, D, Ctx> where - T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs index 61a99498f..692c2913f 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs @@ -21,7 +21,7 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - DistancesUnordered, 
HasElementRef, HasId, + DistancesUnordered, HasElementRef, HasId, HasQueryComputer, }, utils::IntoUsize, }; @@ -218,10 +218,13 @@ impl<'a> BuildDistanceComputer for FlakyAccessor<'a> { } } +impl<'b> HasQueryComputer for FlakyAccessor<'b> { + type QueryComputer = as HasQueryComputer>::QueryComputer; +} + impl<'a, 'b> BuildQueryComputer<&'a [f32]> for FlakyAccessor<'b> { type QueryComputerError = as BuildQueryComputer<&'a [f32]>>::QueryComputerError; - type QueryComputer = as BuildQueryComputer<&'a [f32]>>::QueryComputer; fn build_query_computer( &self, @@ -231,9 +234,9 @@ impl<'a, 'b> BuildQueryComputer<&'a [f32]> for FlakyAccessor<'b> { } } -impl ExpandBeam<&[f32]> for FlakyAccessor<'_> {} +impl ExpandBeam for FlakyAccessor<'_> {} -impl DistancesUnordered<&[f32]> for FlakyAccessor<'_> {} +impl DistancesUnordered for FlakyAccessor<'_> {} impl<'a> DelegateNeighbor<'a> for FlakyAccessor<'_> { type Delegate = &'a SimpleNeighborProviderAsync; @@ -242,8 +245,8 @@ impl<'a> DelegateNeighbor<'a> for FlakyAccessor<'_> { } } -impl<'x> SearchStrategy for Flaky { - type QueryComputer = as BuildQueryComputer<&'x [f32]>>::QueryComputer; +impl SearchStrategy for Flaky { + type QueryComputer = as HasQueryComputer>::QueryComputer; type SearchAccessor<'a> = FlakyAccessor<'a>; type SearchAccessorError = ANNError; diff --git a/diskann-providers/src/model/graph/provider/async_/postprocess.rs b/diskann-providers/src/model/graph/provider/async_/postprocess.rs index dbbf08fa4..b652c8bdb 100644 --- a/diskann-providers/src/model/graph/provider/async_/postprocess.rs +++ b/diskann-providers/src/model/graph/provider/async_/postprocess.rs @@ -8,7 +8,7 @@ use diskann::{ graph::{SearchOutputBuffer, glue}, neighbor::Neighbor, - provider::{BuildQueryComputer, HasId}, + provider::{BuildQueryComputer, HasId, HasQueryComputer}, }; /// A bridge allowing `Accessors` to opt-in to [`RemoveDeletedIdsAndCopy`] by delegating to @@ -47,7 +47,7 @@ where &self, accessor: &mut A, _query: T, - _computer: 
&>::QueryComputer, + _computer: &::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index ff822dd9d..7719f5206 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -26,7 +26,7 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, AsNeighbor, BuildQueryComputer, DataProvider, DelegateNeighbor, - DistancesUnordered, HasElementRef, HasId, + DistancesUnordered, HasElementRef, HasId, HasQueryComputer, }, utils::VectorId, }; @@ -286,12 +286,18 @@ where } } -impl BuildQueryComputer for BetaAccessor +impl HasQueryComputer for BetaAccessor where - Inner: BuildQueryComputer + Accessor, + Inner: HasQueryComputer + Accessor, { /// Use a [`BetaComputer`] to apply filtering. type QueryComputer = BetaComputer; +} + +impl BuildQueryComputer for BetaAccessor +where + Inner: BuildQueryComputer + Accessor, +{ /// Use the same error as `Inner`. type QueryComputerError = Inner::QueryComputerError; @@ -305,15 +311,12 @@ where } } -impl ExpandBeam for BetaAccessor where - Inner: BuildQueryComputer + AsNeighbor + Accessor +impl ExpandBeam for BetaAccessor where + Inner: DistancesUnordered + AsNeighbor + Accessor { } -impl DistancesUnordered for BetaAccessor where - Inner: BuildQueryComputer + Accessor -{ -} +impl DistancesUnordered for BetaAccessor where Inner: HasQueryComputer + Accessor {} /// A [`PreprocessedDistanceFunction`] that applied `beta` filtering to the inner computer. 
pub struct BetaComputer { @@ -504,8 +507,11 @@ mod tests { } } - impl BuildQueryComputer for Doubler { + impl HasQueryComputer for Doubler { type QueryComputer = AddingComputer; + } + + impl BuildQueryComputer for Doubler { type QueryComputerError = ANNError; fn build_query_computer( @@ -516,9 +522,9 @@ mod tests { } } - impl ExpandBeam for Doubler {} + impl ExpandBeam for Doubler {} - impl DistancesUnordered for Doubler {} + impl DistancesUnordered for Doubler {} #[derive(Debug)] struct SimpleStrategy; diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index 9229a3a92..579be8be2 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -33,7 +33,7 @@ use diskann_vector::PreprocessedDistanceFunction; use crate::{ error::ToRanked, - provider::{BuildQueryComputer, HasElementRef, HasId}, + provider::{HasElementRef, HasId, HasQueryComputer}, }; /// Callback-driven sequential scan over the elements of a flat index. @@ -51,21 +51,19 @@ pub trait OnElementsUnordered: HasId + HasElementRef + Send + Sync { /// Extension of [`OnElementsUnordered`] that drives the scan with a query computer. /// -/// The computer is produced by the visitor's [`BuildQueryComputer`] impl, -/// and invokes a callback with `(id, distance)` pairs. -/// -/// This fuses the scan with a pre-processed query computer and runs over a -/// streaming visitor. It pulls the computer type from the implementor's own -/// [`BuildQueryComputer`] impl. +/// The computer type is named via the [`HasQueryComputer`] supertrait. +/// This trait does **not** require [`BuildQueryComputer`] — callers that also need to +/// *construct* a computer should bound on `DistancesUnordered + BuildQueryComputer` +/// at the call site. /// /// The default implementation delegates to [`OnElementsUnordered::on_elements_unordered`], /// calling `computer.evaluate_similarity` on each element. 
-pub trait DistancesUnordered: OnElementsUnordered + BuildQueryComputer { +pub trait DistancesUnordered: OnElementsUnordered + HasQueryComputer { /// Drive the entire scan, scoring each element with `computer` and invoking `f` with /// the resulting `(id, distance)` pair. fn distances_unordered( &mut self, - computer: &>::QueryComputer, + computer: &::QueryComputer, mut f: F, ) -> impl SendFuture::Error>> where @@ -178,7 +176,7 @@ mod tests { use crate::{ ANNError, always_escalate, error::Infallible, - provider::{BuildQueryComputer, HasElementRef, HasId}, + provider::{BuildQueryComputer, HasElementRef, HasId, HasQueryComputer}, utils::VectorRepr, }; @@ -214,9 +212,10 @@ mod tests { // Common impl macro // /////////////////////// - /// Implement [`HasId`], [`HasElementRef`], [`BuildQueryComputer`], and - /// [`DistancesUnordered`] for an iterator type. Every fixture in this module - /// shares these impls — only [`FlatIterator::Element`] varies. + /// Implement [`HasId`], [`HasElementRef`], [`HasQueryComputer`], + /// [`BuildQueryComputer`], and [`DistancesUnordered`] for an iterator type. + /// Every fixture in this module shares these impls — only + /// [`FlatIterator::Element`] varies. macro_rules! common_iterator_impls { ($T:ty) => { impl HasId for $T { @@ -227,9 +226,12 @@ mod tests { type ElementRef<'a> = &'a [f32]; } + impl HasQueryComputer for $T { + type QueryComputer = ::QueryDistance; + } + impl BuildQueryComputer<&[f32]> for $T { type QueryComputerError = Infallible; - type QueryComputer = ::QueryDistance; fn build_query_computer( &self, @@ -239,11 +241,15 @@ mod tests { } } - // Forward `BuildQueryComputer` through the `Iterated` adapter so the - // `DistancesUnordered` supertrait bound is satisfied. + // Forward `HasQueryComputer` and `BuildQueryComputer` through the + // `Iterated` adapter so the `DistancesUnordered` supertrait bound + // is satisfied. 
+ impl HasQueryComputer for Iterated<$T> { + type QueryComputer = ::QueryDistance; + } + impl BuildQueryComputer<&[f32]> for Iterated<$T> { type QueryComputerError = Infallible; - type QueryComputer = ::QueryDistance; fn build_query_computer( &self, @@ -253,7 +259,7 @@ mod tests { } } - impl DistancesUnordered<&[f32]> for Iterated<$T> {} + impl DistancesUnordered for Iterated<$T> {} }; } @@ -588,11 +594,9 @@ mod tests { I: FlatIterator + Send + Sync, I: for<'a> HasElementRef = &'a [f32]>, Iterated: HasId - + for<'q> BuildQueryComputer< - &'q [f32], - QueryComputerError = Infallible, - QueryComputer = ::QueryDistance, - > + for<'q> DistancesUnordered<&'q [f32]>, + + HasQueryComputer::QueryDistance> + + for<'q> BuildQueryComputer<&'q [f32], QueryComputerError = Infallible> + + DistancesUnordered, { let computer = visitor.build_query_computer(query).unwrap(); let mut seen: Vec<(u32, f32)> = Vec::new(); diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs index a3ff980b1..98742873b 100644 --- a/diskann/src/flat/strategy.rs +++ b/diskann/src/flat/strategy.rs @@ -6,10 +6,14 @@ //! [`SearchStrategy`] — glue between [`DataProvider`] and per-query //! [`DistancesUnordered`] visitors. -use crate::{error::StandardError, flat::DistancesUnordered, provider::DataProvider}; +use crate::{ + error::StandardError, + flat::DistancesUnordered, + provider::{BuildQueryComputer, DataProvider}, +}; /// Per-call configuration that knows how to construct a per-query -/// [`DistancesUnordered`] visitor for a provider. +/// [`DistancesUnordered`] visitor for a provider. /// /// `SearchStrategy` is the flat counterpart to [`crate::graph::glue::SearchStrategy`] /// (disambiguated by module path). A strategy instance carries the per-query setup @@ -17,8 +21,8 @@ use crate::{error::StandardError, flat::DistancesUnordered, provider::DataProvid /// strategy may be reused across many searches. 
/// /// The strategy itself is a pure factory; the visitor it produces carries the -/// query-preprocessing capability via [`crate::provider::BuildQueryComputer`] (a -/// super-trait of [`DistancesUnordered`]). +/// query-preprocessing capability via [`crate::provider::BuildQueryComputer`] +/// (bound alongside [`DistancesUnordered`]). pub trait SearchStrategy: Send + Sync where P: DataProvider, @@ -26,9 +30,9 @@ where /// The visitor type produced by [`Self::create_visitor`]. Borrows from `self` and the /// provider. /// - /// The visitor implements both the streaming [`DistancesUnordered`] primitive and + /// The visitor implements both the streaming [`DistancesUnordered`] primitive and /// the query preprocessor [`crate::provider::BuildQueryComputer`]. - type Visitor<'a>: DistancesUnordered + type Visitor<'a>: DistancesUnordered + BuildQueryComputer where Self: 'a, P: 'a; diff --git a/diskann/src/flat/test/provider.rs b/diskann/src/flat/test/provider.rs index 9890410a3..0de4499c2 100644 --- a/diskann/src/flat/test/provider.rs +++ b/diskann/src/flat/test/provider.rs @@ -25,7 +25,10 @@ use crate::{ flat::{DistancesUnordered, OnElementsUnordered, SearchStrategy}, graph::test::synthetic::Grid, internal::counter::{Counter, LocalCounter}, - provider::{self, BuildQueryComputer, ExecutionContext, HasElementRef, HasId, NoopGuard}, + provider::{ + self, BuildQueryComputer, ExecutionContext, HasElementRef, HasId, HasQueryComputer, + NoopGuard, + }, utils::VectorRepr, }; @@ -346,9 +349,12 @@ impl HasElementRef for Visitor<'_> { type ElementRef<'a> = &'a [f32]; } +impl HasQueryComputer for Visitor<'_> { + type QueryComputer = ::QueryDistance; +} + impl BuildQueryComputer<&[f32]> for Visitor<'_> { type QueryComputerError = Infallible; - type QueryComputer = ::QueryDistance; fn build_query_computer( &self, @@ -381,7 +387,7 @@ impl OnElementsUnordered for Visitor<'_> { } } -impl DistancesUnordered<&[f32]> for Visitor<'_> {} +impl DistancesUnordered for Visitor<'_> {} 
////////////// // Strategy // diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index 2f5bf5f0f..b7009b8ae 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -252,7 +252,7 @@ impl HybridPredicate for NotInMut<'_, T> where T: Clone + Eq + std::hash:: /// ## Error Handling /// /// Transient errors yielded by `distances_unordered` are acknowledged and not escalated. -pub trait ExpandBeam: DistancesUnordered + AsNeighbor + Sized { +pub trait ExpandBeam: DistancesUnordered + AsNeighbor + Sized { fn expand_beam( &mut self, ids: Itr, @@ -311,7 +311,8 @@ where /// The concrete type of the accessor that is used to access `Self` during the greedy /// graph search. The query will be provided to the accessor exactly once during search /// to construct the query computer. - type SearchAccessor<'a>: ExpandBeam + type SearchAccessor<'a>: ExpandBeam + + BuildQueryComputer + SearchExt; /// Construct and return the search accessor. @@ -396,7 +397,7 @@ where &self, accessor: &mut A, query: T, - computer: &>::QueryComputer, + computer: &A::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send @@ -455,7 +456,7 @@ where next: &Next, accessor: &mut A, query: T, - computer: &>::QueryComputer, + computer: &A::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future>> + Send @@ -550,7 +551,7 @@ where &self, accessor: &mut A, query: T, - computer: &>::QueryComputer, + computer: &A::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send @@ -786,7 +787,8 @@ where /// of associated types. /// /// Lifting the accessor all the way to the trait level makes the caching provider possible. - type DeleteSearchAccessor<'a>: ExpandBeam, Id = Provider::InternalId> + type DeleteSearchAccessor<'a>: ExpandBeam + + BuildQueryComputer> + SearchExt; /// The processor used during the delete-search phase. 
@@ -854,7 +856,10 @@ mod tests { use super::*; use crate::{ ANNResult, neighbor, - provider::{DelegateNeighbor, ExecutionContext, HasElementRef, HasId, NeighborAccessor}, + provider::{ + DelegateNeighbor, ExecutionContext, HasElementRef, HasId, HasQueryComputer, + NeighborAccessor, + }, }; // A really simple provider that just holds floats and uses the absolute value for its @@ -972,17 +977,20 @@ mod tests { } } + impl HasQueryComputer for Retriever<'_> { + type QueryComputer = QueryComputer; + } + impl BuildQueryComputer for Retriever<'_> { type QueryComputerError = ANNError; - type QueryComputer = QueryComputer; fn build_query_computer(&self, _from: f32) -> Result { Ok(QueryComputer) } } - impl ExpandBeam for Retriever<'_> {} + impl ExpandBeam for Retriever<'_> {} - impl DistancesUnordered for Retriever<'_> {} + impl DistancesUnordered for Retriever<'_> {} // This strategy explicitly does not define `post_process` so we can test the provided // implementation. diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index eca65b8e6..71f1949e7 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -2007,7 +2007,7 @@ where } // A is the accessor type, T is the query type used for BuildQueryComputer - pub(crate) fn search_internal( + pub(crate) fn search_internal( &self, beam_width: Option, start_ids: &[DP::InternalId], @@ -2017,7 +2017,7 @@ where search_record: &mut SR, ) -> impl SendFuture> where - A: ExpandBeam + SearchExt, + A: ExpandBeam + SearchExt, SR: SearchRecord + ?Sized, Q: NeighborQueue, { diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index aba0f44c5..fee3eab35 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -171,7 +171,7 @@ impl HybridPredicate for NotInMutWithLabelCheck<'_, K> where K: VectorId { /// /// Performs label-filtered search by expanding through non-matching nodes /// to find matching 
neighbors within two hops. -pub(crate) async fn multihop_search_internal( +pub(crate) async fn multihop_search_internal( max_degree_with_slack: usize, search_params: &Knn, accessor: &mut A, @@ -182,7 +182,7 @@ pub(crate) async fn multihop_search_internal( ) -> ANNResult where I: VectorId, - A: ExpandBeam + SearchExt, + A: ExpandBeam + SearchExt, SR: SearchRecord + ?Sized, { let beam_width = search_params.beam_width().get(); diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index b92fa1384..e0e1824db 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -323,7 +323,7 @@ where /// /// Expands the search frontier to find all points within the specified radius. /// Called after the initial graph search has identified starting candidates. -pub(crate) async fn range_search_internal( +pub(crate) async fn range_search_internal( max_degree_with_slack: usize, search_params: &Range, accessor: &mut A, @@ -332,7 +332,7 @@ pub(crate) async fn range_search_internal( ) -> ANNResult where I: crate::utils::VectorId, - A: ExpandBeam + SearchExt, + A: ExpandBeam + SearchExt, { let beam_width = search_params.beam_width().unwrap_or(1); diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 4969ebe20..9519b60ef 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -1068,9 +1068,12 @@ impl<'a> provider::DelegateNeighbor<'a> for Accessor<'_> { } } +impl provider::HasQueryComputer for Accessor<'_> { + type QueryComputer = ::QueryDistance; +} + impl provider::BuildQueryComputer<&[f32]> for Accessor<'_> { type QueryComputerError = Infallible; - type QueryComputer = ::QueryDistance; fn build_query_computer( &self, @@ -1080,7 +1083,7 @@ impl provider::BuildQueryComputer<&[f32]> for Accessor<'_> { } } -impl provider::DistancesUnordered<&[f32]> for Accessor<'_> {} +impl provider::DistancesUnordered for Accessor<'_> {} impl 
provider::BuildDistanceComputer for Accessor<'_> { type DistanceComputerError = Infallible; @@ -1106,7 +1109,7 @@ impl glue::SearchExt for Accessor<'_> { } } -impl glue::ExpandBeam<&[f32]> for Accessor<'_> {} +impl glue::ExpandBeam for Accessor<'_> {} impl glue::IdIterator> for Accessor<'_> { async fn id_iterator(&mut self) -> Result, ANNError> { diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index 82a5d9a3c..ff067d49c 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -71,11 +71,14 @@ //! * [`BuildDistanceComputer`]: A sub-trait of [`Accessor`] that allows for random-access //! distance computations on the retrieved elements. //! -//! * [`BuildQueryComputer`]: A sub-trait of [`HasElementRef`] that allows for specialized +//! * [`HasQueryComputer`]: Names the canonical query-computer type for an accessor, +//! analogous to [`HasElementRef`]. +//! +//! * [`BuildQueryComputer`]: A sub-trait of [`HasQueryComputer`] that allows for specialized //! query based computations. This allows a query to be pre-processed in a way that allows //! faster computations. //! -//! * [`DistancesUnordered`]: A sub-trait of [`Accessor`] and [`BuildQueryComputer`] that +//! * [`DistancesUnordered`]: A sub-trait of [`Accessor`] and [`HasQueryComputer`] that //! provides a fused iterate-and-compute primitive over a set of element ids using a //! pre-built query computer. //! @@ -533,28 +536,45 @@ pub trait BuildDistanceComputer: Accessor { ) -> Result; } +////////////////////// +// HasQueryComputer // +////////////////////// + +/// Declare the canonical query-computer type for this accessor. +/// +/// This is the query-computer analogue of [`HasElementRef`]: it names the associated +/// type without imposing any construction requirement. Traits that only *use* a +/// pre-built computer (e.g. [`DistancesUnordered`]) require this trait; traits that +/// also *build* one (e.g. [`BuildQueryComputer`]) extend it. 
+pub trait HasQueryComputer: HasElementRef { + /// The concrete type of the distance computer, which must be applicable for all + /// elements yielded by the [`Accessor`]. + type QueryComputer: for<'a> PreprocessedDistanceFunction, f32> + + Send + + Sync + + 'static; +} + +///////////////////////// +// BuildQueryComputer // +///////////////////////// + /// A trait that provides query computations for a query type `T`. /// /// Query computers are allowed to preprocess the query to enable more efficient distance /// computations. /// -/// This trait only requires [`HasElementRef`] (so the query computer's element type can be -/// named) so that it can be used with multiple access patterns - like [`Accessor`] and -/// [`crate::flat::FlastSearchStrategy`]. +/// This trait extends [`HasQueryComputer`] (which names the computer type) with a +/// factory method that constructs a computer from a query of type `T`. /// /// A fused iterate-and-compute primitive can be created as a sub-trait - -/// e.g. [`DistancesUnordered`], which requires both [`Accessor`] and `BuildQueryComputer`. -pub trait BuildQueryComputer: HasElementRef { +/// e.g. [`DistancesUnordered`], which requires [`Accessor`] and [`HasQueryComputer`]. +/// Callers that need both iteration *and* computer construction should bound on +/// `DistancesUnordered + BuildQueryComputer` at the call site. +pub trait BuildQueryComputer: HasQueryComputer { /// The error type (if any) associated with distance computer construction. type QueryComputerError: std::error::Error + Into + Send + Sync + 'static; - /// The concrete type of the distance computer, which must be applicable for all - /// elements yielded by the [`Accessor`]. - type QueryComputer: for<'a> PreprocessedDistanceFunction, f32> - + Send - + Sync - + 'static; - /// Build the query computer for this accessor. 
/// /// This method is encouraged to be as fast as possible, but will generally only be @@ -565,12 +585,16 @@ pub trait BuildQueryComputer: HasElementRef { ) -> Result; } -/// A sub-trait of [`Accessor`] and [`BuildQueryComputer`] that exposes the fused +/// A sub-trait of [`Accessor`] and [`HasQueryComputer`] that exposes the fused /// iterate-and-compute primitive `distances_unordered`. /// /// The default implementation uses [`Accessor::on_elements_unordered`] to iterate over the /// elements and computes their distances using the provided `computer`. -pub trait DistancesUnordered: Accessor + BuildQueryComputer { +/// +/// This trait does **not** require [`BuildQueryComputer`] — it only needs the computer +/// *type* via [`HasQueryComputer`]. Callers that also need to *construct* a computer +/// should bound on `DistancesUnordered + BuildQueryComputer` at the call site. +pub trait DistancesUnordered: Accessor + HasQueryComputer { /// Compute the distances for the elements in the iterator `vec_id_itr` using the /// `computer` and apply the closure `f` to each distance and ID. fn distances_unordered( From 3620c6d338cd65bb3b1e4c6b7e4be4a94ed6a2aa Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Mon, 11 May 2026 17:17:48 -0700 Subject: [PATCH 21/24] Revert "Add HasQueryComputer trait" This reverts commit 292840422980ae7a1df154af0aa1abef860f4ed3. 
--- .../src/search/provider/disk_provider.rs | 15 ++--- diskann-garnet/src/provider.rs | 19 +++---- .../encoded_document_accessor.rs | 20 +++---- .../inline_beta_search/inline_beta_filter.rs | 2 +- .../graph/provider/async_/bf_tree/provider.rs | 29 +++------- .../graph/provider/async_/caching/provider.rs | 21 +++---- .../provider/async_/inmem/full_precision.rs | 17 ++---- .../graph/provider/async_/inmem/product.rs | 18 ++---- .../graph/provider/async_/inmem/scalar.rs | 21 +++---- .../graph/provider/async_/inmem/spherical.rs | 20 +++---- .../model/graph/provider/async_/inmem/test.rs | 15 ++--- .../graph/provider/async_/postprocess.rs | 4 +- .../model/graph/provider/layers/betafilter.rs | 30 ++++------ diskann/src/flat/iterator.rs | 50 ++++++++--------- diskann/src/flat/strategy.rs | 16 ++---- diskann/src/flat/test/provider.rs | 12 +--- diskann/src/graph/glue.rs | 28 ++++------ diskann/src/graph/index.rs | 4 +- diskann/src/graph/search/multihop_search.rs | 4 +- diskann/src/graph/search/range_search.rs | 4 +- diskann/src/graph/test/provider.rs | 9 +-- diskann/src/provider.rs | 56 ++++++------------- 22 files changed, 147 insertions(+), 267 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 7470e8ef1..3dbe517d5 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -29,7 +29,7 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, - DistancesUnordered, HasElementRef, HasId, HasQueryComputer, NeighborAccessor, NoopGuard, + DistancesUnordered, HasElementRef, HasId, NeighborAccessor, NoopGuard, }, utils::{IntoUsize, VectorRepr}, ANNError, ANNResult, @@ -406,20 +406,13 @@ impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer { } } -impl HasQueryComputer for DiskAccessor<'_, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - 
type QueryComputer = DiskQueryComputer; -} - impl BuildQueryComputer<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> where Data: GraphDataType, VP: VertexProvider, { type QueryComputerError = ANNError; + type QueryComputer = DiskQueryComputer; fn build_query_computer( &self, @@ -436,7 +429,7 @@ where } } -impl DistancesUnordered for DiskAccessor<'_, Data, VP> +impl DistancesUnordered<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> where Data: GraphDataType, VP: VertexProvider, @@ -455,7 +448,7 @@ where } } -impl ExpandBeam for DiskAccessor<'_, Data, VP> +impl ExpandBeam<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP> where Data: GraphDataType, VP: VertexProvider, diff --git a/diskann-garnet/src/provider.rs b/diskann-garnet/src/provider.rs index 64389f2c2..8350c6cb8 100644 --- a/diskann-garnet/src/provider.rs +++ b/diskann-garnet/src/provider.rs @@ -18,8 +18,8 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DelegateNeighbor, - Delete, DistancesUnordered, ElementStatus, HasElementRef, HasId, HasQueryComputer, - NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, + Delete, DistancesUnordered, ElementStatus, HasElementRef, HasId, NeighborAccessor, + NeighborAccessorMut, NoopGuard, SetElement, }, utils::VectorRepr, }; @@ -467,7 +467,7 @@ impl SearchExt for FullAccessor<'_, T> { } } -impl ExpandBeam for FullAccessor<'_, T> { +impl ExpandBeam<&[T]> for FullAccessor<'_, T> { fn expand_beam( &mut self, ids: Itr, @@ -573,11 +573,8 @@ impl BuildDistanceComputer for FullAccessor<'_, T> { } } -impl HasQueryComputer for FullAccessor<'_, T> { - type QueryComputer = T::QueryDistance; -} - impl BuildQueryComputer<&[T]> for FullAccessor<'_, T> { + type QueryComputer = T::QueryDistance; type QueryComputerError = GarnetProviderError; fn build_query_computer( @@ -588,7 +585,7 @@ impl BuildQueryComputer<&[T]> for FullAccessor<'_, T> { } } -impl DistancesUnordered for FullAccessor<'_, 
T> {} +impl DistancesUnordered<&[T]> for FullAccessor<'_, T> {} /// An escape hatch for the blanket implementation of [`workingset::Fill`]. /// @@ -763,14 +760,16 @@ impl NeighborAccessorMut for DelegateNeighborAccessor<'_, '_, T> #[derive(Debug, Default, Clone, Copy)] pub struct CopyExternalIds; -impl<'a, T: VectorRepr> SearchPostProcess, &[T], GarnetId> for CopyExternalIds { +impl<'a, 'b, T: VectorRepr> SearchPostProcess, &'b [T], GarnetId> + for CopyExternalIds +{ type Error = GarnetProviderError; fn post_process( &self, accessor: &mut FullAccessor<'a, T>, _query: &[T], - _computer: & as HasQueryComputer>::QueryComputer, + _computer: & as BuildQueryComputer<&'b [T]>>::QueryComputer, candidates: I, output: &mut B, ) -> impl Future> + Send diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index cd25f1b5f..15a2b5ce9 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -10,7 +10,7 @@ use diskann::{ graph::glue::{ExpandBeam, SearchExt}, provider::{ Accessor, AsNeighbor, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, - HasElementRef, HasId, HasQueryComputer, + HasElementRef, HasId, }, ANNError, ANNErrorKind, }; @@ -171,18 +171,12 @@ where } } -impl HasQueryComputer for EncodedDocumentAccessor -where - IA: HasQueryComputer + Accessor, -{ - type QueryComputer = InlineBetaComputer; -} - impl<'q, IA, Q> BuildQueryComputer<&'q FilteredQuery> for EncodedDocumentAccessor where IA: BuildQueryComputer<&'q Q> + Accessor, { type QueryComputerError = ANNError; + type QueryComputer = InlineBetaComputer; fn build_query_computer( &self, @@ -202,17 +196,19 @@ where } } -impl ExpandBeam for EncodedDocumentAccessor +impl ExpandBeam for EncodedDocumentAccessor where IA: Accessor, - EncodedDocumentAccessor: DistancesUnordered + AsNeighbor, + 
EncodedDocumentAccessor: BuildQueryComputer + AsNeighbor, + Q: Clone, { } -impl DistancesUnordered for EncodedDocumentAccessor +impl DistancesUnordered for EncodedDocumentAccessor where IA: Accessor, - EncodedDocumentAccessor: HasQueryComputer, + EncodedDocumentAccessor: BuildQueryComputer, + Q: Clone, { } diff --git a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 012c74e6f..10aef0109 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -153,7 +153,7 @@ where &self, accessor: &mut EncodedDocumentAccessor, query: &'q FilteredQuery, - computer: &InlineBetaComputer, + computer: &InlineBetaComputer<>::QueryComputer>, candidates: I, output: &mut B, ) -> Result diff --git a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs index 4f61b6900..e67c50d00 100644 --- a/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs @@ -29,7 +29,7 @@ use diskann::{ provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, Delete, DistancesUnordered, ElementStatus, HasElementRef, HasId, - HasQueryComputer, NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, + NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement, }, utils::{IntoUsize, VectorRepr}, }; @@ -1040,15 +1040,6 @@ where } } -impl HasQueryComputer for FullAccessor<'_, T, Q, D> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, -{ - type QueryComputer = T::QueryDistance; -} - impl BuildQueryComputer<&[T]> for FullAccessor<'_, T, Q, D> where T: VectorRepr, @@ -1056,6 +1047,7 @@ where D: AsyncFriendly, { type QueryComputerError = Panics; + type QueryComputer = T::QueryDistance; 
fn build_query_computer( &self, @@ -1064,7 +1056,7 @@ where Ok(T::query_distance(from, self.provider.metric)) } } -impl ExpandBeam for FullAccessor<'_, T, Q, D> +impl ExpandBeam<&[T]> for FullAccessor<'_, T, Q, D> where T: VectorRepr, Q: AsyncFriendly, @@ -1072,7 +1064,7 @@ where { } -impl DistancesUnordered for FullAccessor<'_, T, Q, D> +impl DistancesUnordered<&[T]> for FullAccessor<'_, T, Q, D> where T: VectorRepr, Q: AsyncFriendly, @@ -1224,20 +1216,13 @@ where } } -impl HasQueryComputer for QuantAccessor<'_, T, D> -where - T: VectorRepr, - D: AsyncFriendly, -{ - type QueryComputer = pq::distance::QueryComputer>; -} - impl BuildQueryComputer<&[T]> for QuantAccessor<'_, T, D> where T: VectorRepr, D: AsyncFriendly, { type QueryComputerError = ANNError; + type QueryComputer = pq::distance::QueryComputer>; fn build_query_computer( &self, @@ -1247,14 +1232,14 @@ where } } -impl ExpandBeam for QuantAccessor<'_, T, D> +impl ExpandBeam<&[T]> for QuantAccessor<'_, T, D> where T: VectorRepr, D: AsyncFriendly, { } -impl DistancesUnordered for QuantAccessor<'_, T, D> +impl DistancesUnordered<&[T]> for QuantAccessor<'_, T, D> where T: VectorRepr, D: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs index e9962d6b8..09899c781 100644 --- a/diskann-providers/src/model/graph/provider/async_/caching/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/caching/provider.rs @@ -75,7 +75,7 @@ use diskann::{ provider::{ Accessor, AsNeighbor, BuildDistanceComputer, BuildQueryComputer, CacheableAccessor, DataProvider, DelegateNeighbor, Delete, DistancesUnordered, ElementStatus, HasElementRef, - HasId, HasQueryComputer, NeighborAccessor, NeighborAccessorMut, SetElement, + HasId, NeighborAccessor, NeighborAccessorMut, SetElement, }, }; use diskann_utils::{ @@ -781,20 +781,13 @@ where } } -impl HasQueryComputer for CachingAccessor -where - A: 
HasQueryComputer + CacheableAccessor, - C: ElementCache, -{ - type QueryComputer = A::QueryComputer; -} - impl BuildQueryComputer for CachingAccessor where A: BuildQueryComputer + CacheableAccessor, C: ElementCache, { type QueryComputerError = A::QueryComputerError; + type QueryComputer = A::QueryComputer; fn build_query_computer( &self, @@ -832,16 +825,16 @@ where } } -impl ExpandBeam for CachingAccessor +impl ExpandBeam for CachingAccessor where - A: DistancesUnordered + CacheableAccessor + AsNeighbor, + A: BuildQueryComputer + CacheableAccessor + AsNeighbor, C: ElementCache + NeighborCache, { } -impl DistancesUnordered for CachingAccessor +impl DistancesUnordered for CachingAccessor where - A: HasQueryComputer + CacheableAccessor + AsNeighbor, + A: BuildQueryComputer + CacheableAccessor + AsNeighbor, C: ElementCache + NeighborCache, { } @@ -867,7 +860,7 @@ where next: &Next, accessor: &mut CachingAccessor, query: T, - computer: &::QueryComputer, + computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl Future>> + Send diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 5c6a3988e..0b5eb1b72 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -20,7 +20,7 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - DistancesUnordered, ExecutionContext, HasElementRef, HasId, HasQueryComputer, + DistancesUnordered, ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -296,16 +296,6 @@ where } } -impl HasQueryComputer for FullAccessor<'_, T, Q, D, Ctx> -where - T: VectorRepr, - Q: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputer = T::QueryDistance; -} - impl BuildQueryComputer<&[T]> 
for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, @@ -314,6 +304,7 @@ where Ctx: ExecutionContext, { type QueryComputerError = Panics; + type QueryComputer = T::QueryDistance; fn build_query_computer( &self, @@ -323,7 +314,7 @@ where } } -impl ExpandBeam for FullAccessor<'_, T, Q, D, Ctx> +impl ExpandBeam<&[T]> for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, Q: AsyncFriendly, @@ -332,7 +323,7 @@ where { } -impl DistancesUnordered for FullAccessor<'_, T, Q, D, Ctx> +impl DistancesUnordered<&[T]> for FullAccessor<'_, T, Q, D, Ctx> where T: VectorRepr, Q: AsyncFriendly, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs index 6022d4520..477dc517d 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/product.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/product.rs @@ -17,7 +17,7 @@ use diskann::{ }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, - ExecutionContext, HasElementRef, HasId, HasQueryComputer, + ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -205,15 +205,6 @@ where } } -impl HasQueryComputer for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputer = pq::distance::QueryComputer>; -} - impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> where T: VectorRepr, @@ -222,6 +213,7 @@ where Ctx: ExecutionContext, { type QueryComputerError = ANNError; + type QueryComputer = pq::distance::QueryComputer>; fn build_query_computer( &self, @@ -247,16 +239,18 @@ where } } -impl ExpandBeam for QuantAccessor<'_, V, D, Ctx> +impl ExpandBeam<&[T]> for QuantAccessor<'_, V, D, Ctx> where + T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, { } -impl DistancesUnordered for QuantAccessor<'_, V, D, Ctx> +impl DistancesUnordered<&[T]> for 
QuantAccessor<'_, V, D, Ctx> where + T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs index 5e130e8f9..c9c781fed 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/scalar.rs @@ -17,7 +17,7 @@ use diskann::{ }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, - ExecutionContext, HasElementRef, HasId, HasQueryComputer, + ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -527,17 +527,6 @@ where } } -impl HasQueryComputer for QuantAccessor<'_, NBITS, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, - Unsigned: Representation, - QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, -{ - type QueryComputer = QueryComputer; -} - impl BuildQueryComputer<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx> where @@ -549,6 +538,7 @@ where QueryComputer: for<'a> PreprocessedDistanceFunction, f32>, { type QueryComputerError = ANNError; + type QueryComputer = QueryComputer; fn build_query_computer( &self, @@ -562,8 +552,9 @@ where } } -impl ExpandBeam for QuantAccessor<'_, NBITS, V, D, Ctx> +impl ExpandBeam<&[T]> for QuantAccessor<'_, NBITS, V, D, Ctx> where + T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, @@ -572,8 +563,10 @@ where { } -impl DistancesUnordered for QuantAccessor<'_, NBITS, V, D, Ctx> +impl DistancesUnordered<&[T]> + for QuantAccessor<'_, NBITS, V, D, Ctx> where + T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs index da6731b69..0edaf4a77 100644 --- 
a/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/spherical.rs @@ -19,7 +19,7 @@ use diskann::{ }, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DelegateNeighbor, DistancesUnordered, - ExecutionContext, HasElementRef, HasId, HasQueryComputer, + ExecutionContext, HasElementRef, HasId, }, utils::{IntoUsize, VectorRepr}, }; @@ -441,16 +441,6 @@ where } } -impl HasQueryComputer for QuantAccessor<'_, V, D, Ctx> -where - V: AsyncFriendly, - D: AsyncFriendly, - Ctx: ExecutionContext, -{ - type QueryComputer = - UnwrapErr; -} - impl BuildQueryComputer<&[T]> for QuantAccessor<'_, V, D, Ctx> where T: VectorRepr, @@ -459,6 +449,8 @@ where Ctx: ExecutionContext, { type QueryComputerError = Bridge; + type QueryComputer = + UnwrapErr; fn build_query_computer( &self, @@ -472,16 +464,18 @@ where } } -impl ExpandBeam for QuantAccessor<'_, V, D, Ctx> +impl ExpandBeam<&[T]> for QuantAccessor<'_, V, D, Ctx> where + T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, { } -impl DistancesUnordered for QuantAccessor<'_, V, D, Ctx> +impl DistancesUnordered<&[T]> for QuantAccessor<'_, V, D, Ctx> where + T: VectorRepr, V: AsyncFriendly, D: AsyncFriendly, Ctx: ExecutionContext, diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs index 692c2913f..61a99498f 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/test.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/test.rs @@ -21,7 +21,7 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, BuildDistanceComputer, BuildQueryComputer, DefaultContext, DelegateNeighbor, - DistancesUnordered, HasElementRef, HasId, HasQueryComputer, + DistancesUnordered, HasElementRef, HasId, }, utils::IntoUsize, }; @@ -218,13 +218,10 @@ impl<'a> BuildDistanceComputer for FlakyAccessor<'a> { } } -impl<'b> HasQueryComputer 
for FlakyAccessor<'b> { - type QueryComputer = as HasQueryComputer>::QueryComputer; -} - impl<'a, 'b> BuildQueryComputer<&'a [f32]> for FlakyAccessor<'b> { type QueryComputerError = as BuildQueryComputer<&'a [f32]>>::QueryComputerError; + type QueryComputer = as BuildQueryComputer<&'a [f32]>>::QueryComputer; fn build_query_computer( &self, @@ -234,9 +231,9 @@ impl<'a, 'b> BuildQueryComputer<&'a [f32]> for FlakyAccessor<'b> { } } -impl ExpandBeam for FlakyAccessor<'_> {} +impl ExpandBeam<&[f32]> for FlakyAccessor<'_> {} -impl DistancesUnordered for FlakyAccessor<'_> {} +impl DistancesUnordered<&[f32]> for FlakyAccessor<'_> {} impl<'a> DelegateNeighbor<'a> for FlakyAccessor<'_> { type Delegate = &'a SimpleNeighborProviderAsync; @@ -245,8 +242,8 @@ impl<'a> DelegateNeighbor<'a> for FlakyAccessor<'_> { } } -impl SearchStrategy for Flaky { - type QueryComputer = as HasQueryComputer>::QueryComputer; +impl<'x> SearchStrategy for Flaky { + type QueryComputer = as BuildQueryComputer<&'x [f32]>>::QueryComputer; type SearchAccessor<'a> = FlakyAccessor<'a>; type SearchAccessorError = ANNError; diff --git a/diskann-providers/src/model/graph/provider/async_/postprocess.rs b/diskann-providers/src/model/graph/provider/async_/postprocess.rs index b652c8bdb..dbbf08fa4 100644 --- a/diskann-providers/src/model/graph/provider/async_/postprocess.rs +++ b/diskann-providers/src/model/graph/provider/async_/postprocess.rs @@ -8,7 +8,7 @@ use diskann::{ graph::{SearchOutputBuffer, glue}, neighbor::Neighbor, - provider::{BuildQueryComputer, HasId, HasQueryComputer}, + provider::{BuildQueryComputer, HasId}, }; /// A bridge allowing `Accessors` to opt-in to [`RemoveDeletedIdsAndCopy`] by delegating to @@ -47,7 +47,7 @@ where &self, accessor: &mut A, _query: T, - _computer: &::QueryComputer, + _computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send diff --git a/diskann-providers/src/model/graph/provider/layers/betafilter.rs 
b/diskann-providers/src/model/graph/provider/layers/betafilter.rs index 7719f5206..ff822dd9d 100644 --- a/diskann-providers/src/model/graph/provider/layers/betafilter.rs +++ b/diskann-providers/src/model/graph/provider/layers/betafilter.rs @@ -26,7 +26,7 @@ use diskann::{ neighbor::Neighbor, provider::{ Accessor, AsNeighbor, BuildQueryComputer, DataProvider, DelegateNeighbor, - DistancesUnordered, HasElementRef, HasId, HasQueryComputer, + DistancesUnordered, HasElementRef, HasId, }, utils::VectorId, }; @@ -286,18 +286,12 @@ where } } -impl HasQueryComputer for BetaAccessor -where - Inner: HasQueryComputer + Accessor, -{ - /// Use a [`BetaComputer`] to apply filtering. - type QueryComputer = BetaComputer; -} - impl BuildQueryComputer for BetaAccessor where Inner: BuildQueryComputer + Accessor, { + /// Use a [`BetaComputer`] to apply filtering. + type QueryComputer = BetaComputer; /// Use the same error as `Inner`. type QueryComputerError = Inner::QueryComputerError; @@ -311,12 +305,15 @@ where } } -impl ExpandBeam for BetaAccessor where - Inner: DistancesUnordered + AsNeighbor + Accessor +impl ExpandBeam for BetaAccessor where + Inner: BuildQueryComputer + AsNeighbor + Accessor { } -impl DistancesUnordered for BetaAccessor where Inner: HasQueryComputer + Accessor {} +impl DistancesUnordered for BetaAccessor where + Inner: BuildQueryComputer + Accessor +{ +} /// A [`PreprocessedDistanceFunction`] that applied `beta` filtering to the inner computer. 
pub struct BetaComputer { @@ -507,11 +504,8 @@ mod tests { } } - impl HasQueryComputer for Doubler { - type QueryComputer = AddingComputer; - } - impl BuildQueryComputer for Doubler { + type QueryComputer = AddingComputer; type QueryComputerError = ANNError; fn build_query_computer( @@ -522,9 +516,9 @@ mod tests { } } - impl ExpandBeam for Doubler {} + impl ExpandBeam for Doubler {} - impl DistancesUnordered for Doubler {} + impl DistancesUnordered for Doubler {} #[derive(Debug)] struct SimpleStrategy; diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index 579be8be2..9229a3a92 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -33,7 +33,7 @@ use diskann_vector::PreprocessedDistanceFunction; use crate::{ error::ToRanked, - provider::{HasElementRef, HasId, HasQueryComputer}, + provider::{BuildQueryComputer, HasElementRef, HasId}, }; /// Callback-driven sequential scan over the elements of a flat index. @@ -51,19 +51,21 @@ pub trait OnElementsUnordered: HasId + HasElementRef + Send + Sync { /// Extension of [`OnElementsUnordered`] that drives the scan with a query computer. /// -/// The computer type is named via the [`HasQueryComputer`] supertrait. -/// This trait does **not** require [`BuildQueryComputer`] — callers that also need to -/// *construct* a computer should bound on `DistancesUnordered + BuildQueryComputer` -/// at the call site. +/// The computer is produced by the visitor's [`BuildQueryComputer`] impl, +/// and the scan invokes a callback with `(id, distance)` pairs. +/// +/// This fuses the scan with a pre-processed query computer and runs over a +/// streaming visitor. It pulls the computer type from the implementor's own +/// [`BuildQueryComputer`] impl. /// /// The default implementation delegates to [`OnElementsUnordered::on_elements_unordered`], /// calling `computer.evaluate_similarity` on each element. 
-pub trait DistancesUnordered: OnElementsUnordered + HasQueryComputer { +pub trait DistancesUnordered: OnElementsUnordered + BuildQueryComputer { /// Drive the entire scan, scoring each element with `computer` and invoking `f` with /// the resulting `(id, distance)` pair. fn distances_unordered( &mut self, - computer: &::QueryComputer, + computer: &>::QueryComputer, mut f: F, ) -> impl SendFuture::Error>> where @@ -176,7 +178,7 @@ mod tests { use crate::{ ANNError, always_escalate, error::Infallible, - provider::{BuildQueryComputer, HasElementRef, HasId, HasQueryComputer}, + provider::{BuildQueryComputer, HasElementRef, HasId}, utils::VectorRepr, }; @@ -212,10 +214,9 @@ mod tests { // Common impl macro // /////////////////////// - /// Implement [`HasId`], [`HasElementRef`], [`HasQueryComputer`], - /// [`BuildQueryComputer`], and [`DistancesUnordered`] for an iterator type. - /// Every fixture in this module shares these impls — only - /// [`FlatIterator::Element`] varies. + /// Implement [`HasId`], [`HasElementRef`], [`BuildQueryComputer`], and + /// [`DistancesUnordered`] for an iterator type. Every fixture in this module + /// shares these impls — only [`FlatIterator::Element`] varies. macro_rules! common_iterator_impls { ($T:ty) => { impl HasId for $T { @@ -226,12 +227,9 @@ mod tests { type ElementRef<'a> = &'a [f32]; } - impl HasQueryComputer for $T { - type QueryComputer = ::QueryDistance; - } - impl BuildQueryComputer<&[f32]> for $T { type QueryComputerError = Infallible; + type QueryComputer = ::QueryDistance; fn build_query_computer( &self, @@ -241,15 +239,11 @@ mod tests { } } - // Forward `HasQueryComputer` and `BuildQueryComputer` through the - // `Iterated` adapter so the `DistancesUnordered` supertrait bound - // is satisfied. - impl HasQueryComputer for Iterated<$T> { - type QueryComputer = ::QueryDistance; - } - + // Forward `BuildQueryComputer` through the `Iterated` adapter so the + // `DistancesUnordered` supertrait bound is satisfied. 
impl BuildQueryComputer<&[f32]> for Iterated<$T> { type QueryComputerError = Infallible; + type QueryComputer = ::QueryDistance; fn build_query_computer( &self, @@ -259,7 +253,7 @@ mod tests { } } - impl DistancesUnordered for Iterated<$T> {} + impl DistancesUnordered<&[f32]> for Iterated<$T> {} }; } @@ -594,9 +588,11 @@ mod tests { I: FlatIterator + Send + Sync, I: for<'a> HasElementRef = &'a [f32]>, Iterated: HasId - + HasQueryComputer::QueryDistance> - + for<'q> BuildQueryComputer<&'q [f32], QueryComputerError = Infallible> - + DistancesUnordered, + + for<'q> BuildQueryComputer< + &'q [f32], + QueryComputerError = Infallible, + QueryComputer = ::QueryDistance, + > + for<'q> DistancesUnordered<&'q [f32]>, { let computer = visitor.build_query_computer(query).unwrap(); let mut seen: Vec<(u32, f32)> = Vec::new(); diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs index 98742873b..a3ff980b1 100644 --- a/diskann/src/flat/strategy.rs +++ b/diskann/src/flat/strategy.rs @@ -6,14 +6,10 @@ //! [`SearchStrategy`] — glue between [`DataProvider`] and per-query //! [`DistancesUnordered`] visitors. -use crate::{ - error::StandardError, - flat::DistancesUnordered, - provider::{BuildQueryComputer, DataProvider}, -}; +use crate::{error::StandardError, flat::DistancesUnordered, provider::DataProvider}; /// Per-call configuration that knows how to construct a per-query -/// [`DistancesUnordered`] visitor for a provider. +/// [`DistancesUnordered`] visitor for a provider. /// /// `SearchStrategy` is the flat counterpart to [`crate::graph::glue::SearchStrategy`] /// (disambiguated by module path). A strategy instance carries the per-query setup @@ -21,8 +17,8 @@ use crate::{ /// strategy may be reused across many searches. /// /// The strategy itself is a pure factory; the visitor it produces carries the -/// query-preprocessing capability via [`crate::provider::BuildQueryComputer`] -/// (bound alongside [`DistancesUnordered`]). 
+/// query-preprocessing capability via [`crate::provider::BuildQueryComputer`] (a +/// super-trait of [`DistancesUnordered`]). pub trait SearchStrategy: Send + Sync where P: DataProvider, @@ -30,9 +26,9 @@ where /// The visitor type produced by [`Self::create_visitor`]. Borrows from `self` and the /// provider. /// - /// The visitor implements both the streaming [`DistancesUnordered`] primitive and + /// The visitor implements both the streaming [`DistancesUnordered`] primitive and /// the query preprocessor [`crate::provider::BuildQueryComputer`]. - type Visitor<'a>: DistancesUnordered + BuildQueryComputer + type Visitor<'a>: DistancesUnordered where Self: 'a, P: 'a; diff --git a/diskann/src/flat/test/provider.rs b/diskann/src/flat/test/provider.rs index 0de4499c2..9890410a3 100644 --- a/diskann/src/flat/test/provider.rs +++ b/diskann/src/flat/test/provider.rs @@ -25,10 +25,7 @@ use crate::{ flat::{DistancesUnordered, OnElementsUnordered, SearchStrategy}, graph::test::synthetic::Grid, internal::counter::{Counter, LocalCounter}, - provider::{ - self, BuildQueryComputer, ExecutionContext, HasElementRef, HasId, HasQueryComputer, - NoopGuard, - }, + provider::{self, BuildQueryComputer, ExecutionContext, HasElementRef, HasId, NoopGuard}, utils::VectorRepr, }; @@ -349,12 +346,9 @@ impl HasElementRef for Visitor<'_> { type ElementRef<'a> = &'a [f32]; } -impl HasQueryComputer for Visitor<'_> { - type QueryComputer = ::QueryDistance; -} - impl BuildQueryComputer<&[f32]> for Visitor<'_> { type QueryComputerError = Infallible; + type QueryComputer = ::QueryDistance; fn build_query_computer( &self, @@ -387,7 +381,7 @@ impl OnElementsUnordered for Visitor<'_> { } } -impl DistancesUnordered for Visitor<'_> {} +impl DistancesUnordered<&[f32]> for Visitor<'_> {} ////////////// // Strategy // diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index b7009b8ae..2f5bf5f0f 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -252,7 +252,7 @@ impl 
HybridPredicate for NotInMut<'_, T> where T: Clone + Eq + std::hash:: /// ## Error Handling /// /// Transient errors yielded by `distances_unordered` are acknowledged and not escalated. -pub trait ExpandBeam: DistancesUnordered + AsNeighbor + Sized { +pub trait ExpandBeam: DistancesUnordered + AsNeighbor + Sized { fn expand_beam( &mut self, ids: Itr, @@ -311,8 +311,7 @@ where /// The concrete type of the accessor that is used to access `Self` during the greedy /// graph search. The query will be provided to the accessor exactly once during search /// to construct the query computer. - type SearchAccessor<'a>: ExpandBeam - + BuildQueryComputer + type SearchAccessor<'a>: ExpandBeam + SearchExt; /// Construct and return the search accessor. @@ -397,7 +396,7 @@ where &self, accessor: &mut A, query: T, - computer: &A::QueryComputer, + computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send @@ -456,7 +455,7 @@ where next: &Next, accessor: &mut A, query: T, - computer: &A::QueryComputer, + computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future>> + Send @@ -551,7 +550,7 @@ where &self, accessor: &mut A, query: T, - computer: &A::QueryComputer, + computer: &>::QueryComputer, candidates: I, output: &mut B, ) -> impl std::future::Future> + Send @@ -787,8 +786,7 @@ where /// of associated types. /// /// Lifting the accessor all the way to the trait level makes the caching provider possible. - type DeleteSearchAccessor<'a>: ExpandBeam - + BuildQueryComputer> + type DeleteSearchAccessor<'a>: ExpandBeam, Id = Provider::InternalId> + SearchExt; /// The processor used during the delete-search phase. 
@@ -856,10 +854,7 @@ mod tests { use super::*; use crate::{ ANNResult, neighbor, - provider::{ - DelegateNeighbor, ExecutionContext, HasElementRef, HasId, HasQueryComputer, - NeighborAccessor, - }, + provider::{DelegateNeighbor, ExecutionContext, HasElementRef, HasId, NeighborAccessor}, }; // A really simple provider that just holds floats and uses the absolute value for its @@ -977,20 +972,17 @@ mod tests { } } - impl HasQueryComputer for Retriever<'_> { - type QueryComputer = QueryComputer; - } - impl BuildQueryComputer for Retriever<'_> { type QueryComputerError = ANNError; + type QueryComputer = QueryComputer; fn build_query_computer(&self, _from: f32) -> Result { Ok(QueryComputer) } } - impl ExpandBeam for Retriever<'_> {} + impl ExpandBeam for Retriever<'_> {} - impl DistancesUnordered for Retriever<'_> {} + impl DistancesUnordered for Retriever<'_> {} // This strategy explicitly does not define `post_process` so we can test the provided // implementation. diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index 71f1949e7..eca65b8e6 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -2007,7 +2007,7 @@ where } // A is the accessor type, T is the query type used for BuildQueryComputer - pub(crate) fn search_internal( + pub(crate) fn search_internal( &self, beam_width: Option, start_ids: &[DP::InternalId], @@ -2017,7 +2017,7 @@ where search_record: &mut SR, ) -> impl SendFuture> where - A: ExpandBeam + SearchExt, + A: ExpandBeam + SearchExt, SR: SearchRecord + ?Sized, Q: NeighborQueue, { diff --git a/diskann/src/graph/search/multihop_search.rs b/diskann/src/graph/search/multihop_search.rs index fee3eab35..aba0f44c5 100644 --- a/diskann/src/graph/search/multihop_search.rs +++ b/diskann/src/graph/search/multihop_search.rs @@ -171,7 +171,7 @@ impl HybridPredicate for NotInMutWithLabelCheck<'_, K> where K: VectorId { /// /// Performs label-filtered search by expanding through non-matching nodes /// to find matching 
neighbors within two hops. -pub(crate) async fn multihop_search_internal( +pub(crate) async fn multihop_search_internal( max_degree_with_slack: usize, search_params: &Knn, accessor: &mut A, @@ -182,7 +182,7 @@ pub(crate) async fn multihop_search_internal( ) -> ANNResult where I: VectorId, - A: ExpandBeam + SearchExt, + A: ExpandBeam + SearchExt, SR: SearchRecord + ?Sized, { let beam_width = search_params.beam_width().get(); diff --git a/diskann/src/graph/search/range_search.rs b/diskann/src/graph/search/range_search.rs index e0e1824db..b92fa1384 100644 --- a/diskann/src/graph/search/range_search.rs +++ b/diskann/src/graph/search/range_search.rs @@ -323,7 +323,7 @@ where /// /// Expands the search frontier to find all points within the specified radius. /// Called after the initial graph search has identified starting candidates. -pub(crate) async fn range_search_internal( +pub(crate) async fn range_search_internal( max_degree_with_slack: usize, search_params: &Range, accessor: &mut A, @@ -332,7 +332,7 @@ pub(crate) async fn range_search_internal( ) -> ANNResult where I: crate::utils::VectorId, - A: ExpandBeam + SearchExt, + A: ExpandBeam + SearchExt, { let beam_width = search_params.beam_width().unwrap_or(1); diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 9519b60ef..4969ebe20 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -1068,12 +1068,9 @@ impl<'a> provider::DelegateNeighbor<'a> for Accessor<'_> { } } -impl provider::HasQueryComputer for Accessor<'_> { - type QueryComputer = ::QueryDistance; -} - impl provider::BuildQueryComputer<&[f32]> for Accessor<'_> { type QueryComputerError = Infallible; + type QueryComputer = ::QueryDistance; fn build_query_computer( &self, @@ -1083,7 +1080,7 @@ impl provider::BuildQueryComputer<&[f32]> for Accessor<'_> { } } -impl provider::DistancesUnordered for Accessor<'_> {} +impl provider::DistancesUnordered<&[f32]> for Accessor<'_> {} impl 
provider::BuildDistanceComputer for Accessor<'_> { type DistanceComputerError = Infallible; @@ -1109,7 +1106,7 @@ impl glue::SearchExt for Accessor<'_> { } } -impl glue::ExpandBeam for Accessor<'_> {} +impl glue::ExpandBeam<&[f32]> for Accessor<'_> {} impl glue::IdIterator> for Accessor<'_> { async fn id_iterator(&mut self) -> Result, ANNError> { diff --git a/diskann/src/provider.rs b/diskann/src/provider.rs index ff067d49c..82a5d9a3c 100644 --- a/diskann/src/provider.rs +++ b/diskann/src/provider.rs @@ -71,14 +71,11 @@ //! * [`BuildDistanceComputer`]: A sub-trait of [`Accessor`] that allows for random-access //! distance computations on the retrieved elements. //! -//! * [`HasQueryComputer`]: Names the canonical query-computer type for an accessor, -//! analogous to [`HasElementRef`]. -//! -//! * [`BuildQueryComputer`]: A sub-trait of [`HasQueryComputer`] that allows for specialized +//! * [`BuildQueryComputer`]: A sub-trait of [`HasElementRef`] that allows for specialized //! query based computations. This allows a query to be pre-processed in a way that allows //! faster computations. //! -//! * [`DistancesUnordered`]: A sub-trait of [`Accessor`] and [`HasQueryComputer`] that +//! * [`DistancesUnordered`]: A sub-trait of [`Accessor`] and [`BuildQueryComputer`] that //! provides a fused iterate-and-compute primitive over a set of element ids using a //! pre-built query computer. //! @@ -536,45 +533,28 @@ pub trait BuildDistanceComputer: Accessor { ) -> Result; } -////////////////////// -// HasQueryComputer // -////////////////////// - -/// Declare the canonical query-computer type for this accessor. -/// -/// This is the query-computer analogue of [`HasElementRef`]: it names the associated -/// type without imposing any construction requirement. Traits that only *use* a -/// pre-built computer (e.g. [`DistancesUnordered`]) require this trait; traits that -/// also *build* one (e.g. [`BuildQueryComputer`]) extend it. 
-pub trait HasQueryComputer: HasElementRef { - /// The concrete type of the distance computer, which must be applicable for all - /// elements yielded by the [`Accessor`]. - type QueryComputer: for<'a> PreprocessedDistanceFunction, f32> - + Send - + Sync - + 'static; -} - -///////////////////////// -// BuildQueryComputer // -///////////////////////// - /// A trait that provides query computations for a query type `T`. /// /// Query computers are allowed to preprocess the query to enable more efficient distance /// computations. /// -/// This trait extends [`HasQueryComputer`] (which names the computer type) with a -/// factory method that constructs a computer from a query of type `T`. +/// This trait only requires [`HasElementRef`] (so the query computer's element type can be +/// named) so that it can be used with multiple access patterns - like [`Accessor`] and +/// [`crate::flat::SearchStrategy`]. /// /// A fused iterate-and-compute primitive can be created as a sub-trait - -/// e.g. [`DistancesUnordered`], which requires [`Accessor`] and [`HasQueryComputer`]. -/// Callers that need both iteration *and* computer construction should bound on -/// `DistancesUnordered + BuildQueryComputer` at the call site. -pub trait BuildQueryComputer: HasQueryComputer { +/// e.g. [`DistancesUnordered`], which requires both [`Accessor`] and `BuildQueryComputer`. +pub trait BuildQueryComputer: HasElementRef { /// The error type (if any) associated with distance computer construction. type QueryComputerError: std::error::Error + Into + Send + Sync + 'static; + /// The concrete type of the distance computer, which must be applicable for all + /// elements yielded by the [`Accessor`]. + type QueryComputer: for<'a> PreprocessedDistanceFunction, f32> + + Send + + Sync + + 'static; + /// Build the query computer for this accessor. 
/// /// This method is encouraged to be as fast as possible, but will generally only be @@ -585,16 +565,12 @@ pub trait BuildQueryComputer: HasQueryComputer { ) -> Result; } -/// A sub-trait of [`Accessor`] and [`HasQueryComputer`] that exposes the fused +/// A sub-trait of [`Accessor`] and [`BuildQueryComputer`] that exposes the fused /// iterate-and-compute primitive `distances_unordered`. /// /// The default implementation uses [`Accessor::on_elements_unordered`] to iterate over the /// elements and computes their distances using the provided `computer`. -/// -/// This trait does **not** require [`BuildQueryComputer`] — it only needs the computer -/// *type* via [`HasQueryComputer`]. Callers that also need to *construct* a computer -/// should bound on `DistancesUnordered + BuildQueryComputer` at the call site. -pub trait DistancesUnordered: Accessor + HasQueryComputer { +pub trait DistancesUnordered: Accessor + BuildQueryComputer { /// Compute the distances for the elements in the iterator `vec_id_itr` using the /// `computer` and apply the closure `f` to each distance and ID. 
fn distances_unordered( From 409d177c88fa7a4d131f5dc66347e864febbc787 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Wed, 13 May 2026 10:57:08 -0400 Subject: [PATCH 22/24] minor comments --- diskann/src/flat/test/harness.rs | 17 +++++++++++------ diskann/src/flat/test/provider.rs | 30 ++++-------------------------- 2 files changed, 15 insertions(+), 32 deletions(-) diff --git a/diskann/src/flat/test/harness.rs b/diskann/src/flat/test/harness.rs index af177378b..af8500511 100644 --- a/diskann/src/flat/test/harness.rs +++ b/diskann/src/flat/test/harness.rs @@ -21,7 +21,7 @@ use std::{cmp::Ordering, num::NonZeroUsize}; -use diskann_vector::PreprocessedDistanceFunction; +use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; use crate::{ ANNResult, @@ -88,7 +88,7 @@ impl KnnOracleRun { let top_k = top_k_sorted(&buf, stats.result_count as usize); let top_k_distances = top_k.iter().map(|(_, d)| *d).collect(); - let ground_truth = brute_force_topk(index.provider(), query, k); + let ground_truth = brute_force_topk(index.provider(), Metric::L2, query, k); Ok(Self { top_k, @@ -100,12 +100,17 @@ impl KnnOracleRun { } /// Compute the brute-force top-`k` `(id, distance)` pairs over every element of -/// `provider`. Iterates [`Provider::items`] directly and scores with a fresh -/// [`f32::query_distance`] computer, so the oracle is independent of the +/// `provider` under `metric`. Iterates [`Provider::items`] directly and scores with +/// a fresh [`f32::query_distance`] computer, so the oracle is independent of the /// [`crate::flat::test::provider::Visitor`] under test. Ties are broken by ascending /// id for determinism. 
-pub(crate) fn brute_force_topk(provider: &Provider, query: &[f32], k: usize) -> Vec<(u32, f32)> { - let computer = f32::query_distance(query, provider.metric()); +pub(crate) fn brute_force_topk( + provider: &Provider, + metric: Metric, + query: &[f32], + k: usize, +) -> Vec<(u32, f32)> { + let computer = f32::query_distance(query, metric); let mut neighbors: Vec> = provider .items() diff --git a/diskann/src/flat/test/provider.rs b/diskann/src/flat/test/provider.rs index 9890410a3..445438869 100644 --- a/diskann/src/flat/test/provider.rs +++ b/diskann/src/flat/test/provider.rs @@ -38,14 +38,13 @@ use crate::{ pub struct Provider { items: Vec>, dim: usize, - metric: Metric, get_element: Counter, } impl Provider { /// Construct a provider that owns `items`. Every vector must have the same /// (non-zero) length. - pub fn new(metric: Metric, items: impl IntoIterator>) -> Self { + pub fn new(items: impl IntoIterator>) -> Self { let items: Vec> = items.into_iter().collect(); assert!( !items.is_empty(), @@ -67,25 +66,19 @@ impl Provider { Self { items, dim, - metric, get_element: Counter::new(), } } /// Build a provider over the row vectors of [`Grid::data`]. IDs are `0..n` in - /// row-major order (last coordinate varies fastest). Uses [`Metric::L2`]. + /// row-major order (last coordinate varies fastest). /// /// Unlike the graph-side `Provider::grid`, this does *not* add a separate /// start-point row — flat search has no notion of one. pub fn grid(grid: Grid, size: usize) -> Self { let data = grid.data(size); let items: Vec> = data.row_iter().map(|row| row.to_vec()).collect(); - Self::new(Metric::L2, items) - } - - /// Dimensionality of every vector in the provider. - pub fn dim(&self) -> usize { - self.dim + Self::new(items) } /// Number of vectors in the provider. @@ -93,16 +86,6 @@ impl Provider { self.items.len() } - /// `true` if there are no vectors. 
- pub fn is_empty(&self) -> bool { - self.items.is_empty() - } - - /// Distance metric the provider was constructed with. - pub fn metric(&self) -> Metric { - self.metric - } - /// Snapshot of the per-provider counters. pub fn metrics(&self) -> Metrics { Metrics { @@ -322,11 +305,6 @@ impl<'a> Visitor<'a> { get_element: provider.get_element.local(), } } - - /// The borrowed [`Provider`]. - pub fn provider(&self) -> &'a Provider { - self.provider - } } impl Debug for Visitor<'_> { @@ -354,7 +332,7 @@ impl BuildQueryComputer<&[f32]> for Visitor<'_> { &self, from: &[f32], ) -> Result { - Ok(f32::query_distance(from, self.provider.metric)) + Ok(f32::query_distance(from, Metric::L2)) } } From 3a125f76c28e4a344dda6186ec7ece10a5729275 Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Thu, 14 May 2026 10:09:41 -0400 Subject: [PATCH 23/24] remove OnElementsUnordered --- diskann/src/flat/index.rs | 2 +- diskann/src/flat/iterator.rs | 230 ++++++++++++++---------------- diskann/src/flat/mod.rs | 3 +- diskann/src/flat/test/provider.rs | 21 +-- 4 files changed, 118 insertions(+), 138 deletions(-) diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs index 3a032536d..8ebec490d 100644 --- a/diskann/src/flat/index.rs +++ b/diskann/src/flat/index.rs @@ -31,7 +31,7 @@ pub struct SearchStats { /// A `'static` thin wrapper around a [`DataProvider`] used for flat search. /// /// The provider is owned by the index. The index is constructed once at process startup and -/// shared across requests; per-query state lives in the [`crate::flat::OnElementsUnordered`] +/// shared across requests; per-query state lives in the [`crate::flat::DistancesUnordered`] /// implementation that the [`SearchStrategy`] produces. #[derive(Debug)] pub struct FlatIndex { diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index 9229a3a92..eb2fd8cb4 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -8,23 +8,21 @@ //! 
This module defines the traits that flat-search algorithms use to walk every element //! of a [`DataProvider`](crate::provider::DataProvider) once. //! -//! * [`OnElementsUnordered`]: the lowest-level entry point and the only required trait -//! to implement. It is a single-method trait that applies a caller-supplied closure -//! to every `(id, element ref)` pair in the provider. The super-traits [`HasId`] and -//! [`HasElementRef`] define the concrete id and element reference types. -//! -//! * [`DistancesUnordered`]: a sub-trait of [`OnElementsUnordered`] that takes a query -//! computer and a closure (typically to filter results through a priority queue) and -//! applies the closure to the `(id, distance)` pair for every element, with each -//! distance computed using the supplied computer. +//! * [`DistancesUnordered`]: the single trait flat search consumes. It takes a +//! pre-built query computer and a callback, applies the callback to every +//! `(id, distance)` pair in the provider, and is the only trait an in-memory +//! visitor (such as [`crate::flat::test::provider::Visitor`]) needs to implement. +//! The super-traits [`HasId`] and [`BuildQueryComputer`] define the id and +//! query-computer types. //! //! * [`FlatIterator`]: a convenient entry point for backends whose natural shape is //! element-at-a-time iteration. The trait exposes a single `next` method and an //! associated `Element<'_>` type that must be [`Reborrow`]able to the `ElementRef<'_>` //! exposed via the [`HasElementRef`] super-trait. //! -//! * [`Iterated`]: bridges any [`FlatIterator`] implementation into an -//! [`OnElementsUnordered`] by looping over [`FlatIterator::next`]. +//! * [`Iterated`]: bridges any [`FlatIterator`] implementation into a +//! [`DistancesUnordered`] by looping over [`FlatIterator::next`] and scoring each +//! element with the supplied computer. 
use std::fmt::Debug; @@ -36,46 +34,25 @@ use crate::{ provider::{BuildQueryComputer, HasElementRef, HasId}, }; -/// Callback-driven sequential scan over the elements of a flat index. +/// Fused iterate-and-score primitive over the elements of a flat index. /// -/// Algorithms see only `(Id, ElementRef)` pairs and treat the stream as opaque. -pub trait OnElementsUnordered: HasId + HasElementRef + Send + Sync { - /// The error type yielded by [`Self::on_elements_unordered`]. +/// Implementations drive an entire scan over the underlying data, scoring each +/// element with the supplied [`BuildQueryComputer::QueryComputer`] and invoking +/// `f` with the resulting `(id, distance)` pair. The super-trait +/// [`BuildQueryComputer`] supplies the computer type. +pub trait DistancesUnordered: HasId + BuildQueryComputer + Send + Sync { + /// The error type yielded by [`Self::distances_unordered`]. type Error: ToRanked + Debug + Send + Sync + 'static; - /// Drive the entire scan, invoking `f` for each yielded element. - fn on_elements_unordered(&mut self, f: F) -> impl SendFuture> - where - F: Send + for<'a> FnMut(Self::Id, ::ElementRef<'a>); -} - -/// Extension of [`OnElementsUnordered`] that drives the scan with a query computer. -/// -/// The computer is produced by the visitor's [`BuildQueryComputer`] impl, -/// and invokes a callback with `(id, distance)` pairs. -/// -/// This fuses the scan with a pre-processed query computer and runs over a -/// streaming visitor. It pulls the computer type from the implementor's own -/// [`BuildQueryComputer`] impl. -/// -/// The default implementation delegates to [`OnElementsUnordered::on_elements_unordered`], -/// calling `computer.evaluate_similarity` on each element. -pub trait DistancesUnordered: OnElementsUnordered + BuildQueryComputer { - /// Drive the entire scan, scoring each element with `computer` and invoking `f` with - /// the resulting `(id, distance)` pair. 
+ /// Drive the entire scan, scoring each element with `computer` and invoking `f` + /// with the resulting `(id, distance)` pair. fn distances_unordered( &mut self, computer: &>::QueryComputer, - mut f: F, - ) -> impl SendFuture::Error>> + f: F, + ) -> impl SendFuture> where - F: Send + FnMut(::Id, f32), - { - self.on_elements_unordered(move |id, element| { - let dist = computer.evaluate_similarity(element); - f(id, dist); - }) - } + F: Send + FnMut(::Id, f32); } ////////////// @@ -112,18 +89,19 @@ pub trait FlatIterator: HasId + HasElementRef + Send + Sync { // Default // ///////////// -/// Bridges a [`FlatIterator`] into an [`OnElementsUnordered`] by looping over -/// [`FlatIterator::next`] and reborrowing each element into the closure. +/// Bridges a [`FlatIterator`] into a [`DistancesUnordered`] by looping over +/// [`FlatIterator::next`], reborrowing each element, and scoring it with the +/// supplied computer. /// -/// This is the default adapter for providers that implement element-at-a-time iteration. -/// Providers that can do better (prefetching, SIMD batching, bulk I/O) should implement -/// [`OnElementsUnordered`] directly. +/// This is the default adapter for providers that implement element-at-a-time +/// iteration. Providers that can do better (prefetching, SIMD batching, bulk +/// I/O) should implement [`DistancesUnordered`] directly. pub struct Iterated { inner: I, } impl Iterated { - /// Wrap an iterator to produce an [`OnElementsUnordered`] implementation. + /// Wrap an iterator to produce a [`DistancesUnordered`] implementation. pub fn new(inner: I) -> Self { Self { inner } } @@ -142,19 +120,44 @@ impl HasElementRef for Iterated { type ElementRef<'a> = I::ElementRef<'a>; } -impl OnElementsUnordered for Iterated +/// Forwards the inner iterator's [`BuildQueryComputer`] impl through the wrapper +/// so that callers (and the [`DistancesUnordered`] blanket below) can obtain the +/// query computer from the [`Iterated`] adapter directly. 
+impl BuildQueryComputer for Iterated where - I: FlatIterator + HasId + Send + Sync, + I: BuildQueryComputer, +{ + type QueryComputerError = I::QueryComputerError; + type QueryComputer = I::QueryComputer; + + fn build_query_computer( + &self, + from: T, + ) -> Result { + self.inner.build_query_computer(from) + } +} + +/// The blanket implementation of [`DistancesUnordered`] for any +/// [`FlatIterator`] that also exposes a [`BuildQueryComputer`]. +impl DistancesUnordered for Iterated +where + I: FlatIterator + BuildQueryComputer + Send + Sync, { type Error = I::Error; - fn on_elements_unordered(&mut self, mut f: F) -> impl SendFuture> + fn distances_unordered( + &mut self, + computer: &Self::QueryComputer, + mut f: F, + ) -> impl SendFuture> where - F: Send + for<'a> FnMut(Self::Id, Self::ElementRef<'a>), + F: Send + FnMut(::Id, f32), { async move { while let Some((id, element)) = self.inner.next().await? { - f(id, element.reborrow()); + let dist = computer.evaluate_similarity(element.reborrow()); + f(id, dist); } Ok(()) } @@ -214,9 +217,11 @@ mod tests { // Common impl macro // /////////////////////// - /// Implement [`HasId`], [`HasElementRef`], [`BuildQueryComputer`], and - /// [`DistancesUnordered`] for an iterator type. Every fixture in this module - /// shares these impls — only [`FlatIterator::Element`] varies. + /// Implement [`HasId`], [`HasElementRef`], and [`BuildQueryComputer`] for an + /// iterator type. Every fixture in this module shares these impls — only + /// [`FlatIterator::Element`] varies. The [`DistancesUnordered`] impl on + /// `Iterated<$T>` comes from the blanket impl in the parent module, so does + /// not need to be repeated here. macro_rules! common_iterator_impls { ($T:ty) => { impl HasId for $T { @@ -238,22 +243,6 @@ mod tests { Ok(f32::query_distance(from, Metric::L2)) } } - - // Forward `BuildQueryComputer` through the `Iterated` adapter so the - // `DistancesUnordered` supertrait bound is satisfied. 
- impl BuildQueryComputer<&[f32]> for Iterated<$T> { - type QueryComputerError = Infallible; - type QueryComputer = ::QueryDistance; - - fn build_query_computer( - &self, - from: &[f32], - ) -> Result { - Ok(f32::query_distance(from, Metric::L2)) - } - } - - impl DistancesUnordered<&[f32]> for Iterated<$T> {} }; } @@ -498,6 +487,18 @@ mod tests { type ElementRef<'a> = &'a [f32]; } + impl BuildQueryComputer<&[f32]> for Failing<'_> { + type QueryComputerError = Infallible; + type QueryComputer = ::QueryDistance; + + fn build_query_computer( + &self, + from: &[f32], + ) -> Result { + Ok(f32::query_distance(from, Metric::L2)) + } + } + impl FlatIterator for Failing<'_> { type Element<'a> = &'a [f32] @@ -527,72 +528,35 @@ mod tests { // Helpers // ///////////// - /// Drive `visitor.on_elements_unordered` to completion and assert the - /// yielded `(id, element)` pairs equal [`sample_items`] in iteration order. - async fn check_visitor(visitor: &mut V) - where - V: OnElementsUnordered + HasId, - V: for<'a> HasElementRef = &'a [f32]>, - V::Error: Debug, - { - let mut out = Vec::new(); - visitor - .on_elements_unordered(|id, e: &[f32]| out.push((id, e.to_vec()))) - .await - .unwrap(); - assert_eq!(out, sample_items()); + /// Build the canonical `(id, distance)` ground-truth list for a query under + /// L2, against [`sample_items`]. + fn expected_distances(query: &[f32]) -> Vec<(u32, f32)> { + let computer = f32::query_distance(query, Metric::L2); + sample_items() + .into_iter() + .map(|(id, v)| (id, computer.evaluate_similarity(v.as_slice()))) + .collect() } /////////// // Tests // /////////// - /// `Iterated::on_elements_unordered` is correct for every supported + /// The blanket [`DistancesUnordered`] impl on [`Iterated`] produces the + /// correct `(id, distance)` pairs for every supported /// [`FlatIterator::Element`] shape: owning, forwarding, guard-wrapped, and /// shared-buffer. 
#[tokio::test(flavor = "multi_thread", worker_threads = 4)] - async fn default_implementations() { - let store = Store::sample(); - - // Allocating: Element = Vec (owns). - check_visitor(&mut Iterated::new(Allocating::new(&store))).await; - - // Forwarding: Element = &'store [f32] (borrows from store). - check_visitor(&mut Iterated::new(Forwarding::new(&store))).await; - - let recovered = Iterated::new(Forwarding::new(&store)).into_inner(); - - check_visitor(&mut Iterated::new(recovered)).await; - - // Wrapping: Element = Wrapped<'a> (guard-shaped non-ref). - check_visitor(&mut Iterated::new(Wrapping::new(&store))).await; - - // Sharing: Element = &'a [f32] (per-call internal buffer). - check_visitor(&mut Iterated::new(Sharing::new(&store))).await; - } - - /// The default body of [`DistancesUnordered::distances_unordered`] produces - /// `(id, computer.evaluate_similarity(elem))` pairs for every element shape. - #[tokio::test] async fn distances_unordered() { let store = Store::sample(); let query = vec![0.5_f32, 0.9]; - let computer = f32::query_distance(&query, Metric::L2); - let expected = sample_items() - .into_iter() - .map(|(id, v)| (id, computer.evaluate_similarity(v.as_slice()))) - .collect::>(); + let expected = expected_distances(&query); async fn run(mut visitor: Iterated, query: &[f32], expected: &[(u32, f32)]) where I: FlatIterator + Send + Sync, I: for<'a> HasElementRef = &'a [f32]>, - Iterated: HasId - + for<'q> BuildQueryComputer< - &'q [f32], - QueryComputerError = Infallible, - QueryComputer = ::QueryDistance, - > + for<'q> DistancesUnordered<&'q [f32]>, + Iterated: HasId + for<'q> DistancesUnordered<&'q [f32]>, { let computer = visitor.build_query_computer(query).unwrap(); let mut seen: Vec<(u32, f32)> = Vec::new(); @@ -603,15 +567,26 @@ mod tests { assert_eq!(seen, expected); } + // Allocating: Element = Vec (owns). 
run(Iterated::new(Allocating::new(&store)), &query, &expected).await; + + // Forwarding: Element = &'store [f32] (borrows from store). run(Iterated::new(Forwarding::new(&store)), &query, &expected).await; + + // Round-trip through `Iterated::into_inner` to exercise the unwrap path. + let recovered = Iterated::new(Forwarding::new(&store)).into_inner(); + run(Iterated::new(recovered), &query, &expected).await; + + // Wrapping: Element = Wrapped<'a> (guard-shaped non-ref). run(Iterated::new(Wrapping::new(&store)), &query, &expected).await; + + // Sharing: Element = &'a [f32] (per-call internal buffer). run(Iterated::new(Sharing::new(&store)), &query, &expected).await; } /// An error returned mid-iteration by [`FlatIterator::next`] propagates up - /// through [`Iterated::on_elements_unordered`], and the closure stops being - /// invoked at the failure point. + /// through the [`Iterated`] adapter's [`DistancesUnordered`] impl, and the + /// closure stops being invoked at the failure point. #[tokio::test] async fn failures_midstream() { let store = Store::sample(); @@ -621,9 +596,12 @@ mod tests { fail_after: 1, // Yield item 0 successfully, fail on item 1. }); + let query = vec![0.0_f32, 0.0]; + let computer = visitor.build_query_computer(query.as_slice()).unwrap(); + let mut seen: Vec = Vec::new(); let err = visitor - .on_elements_unordered(|id, _e: &[f32]| seen.push(id)) + .distances_unordered(&computer, |id, _d| seen.push(id)) .await .expect_err("Failing iterator must surface its error"); diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs index fc06e21b9..8f3d66306 100644 --- a/diskann/src/flat/mod.rs +++ b/diskann/src/flat/mod.rs @@ -18,7 +18,6 @@ //! | :------------------------------------ | :----------------------------------------- |:--------- | //! | [`crate::provider::DataProvider`] | [`crate::provider::DataProvider`] | Yes | //! | [`crate::graph::DiskANNIndex`] | [`FlatIndex`] | No | -//! 
| [`crate::provider::Accessor`] | [`OnElementsUnordered`] | No | //! | [`crate::provider::DistancesUnordered`] | [`DistancesUnordered`] | No | //! | [`crate::graph::glue::SearchStrategy`] | [`SearchStrategy`] | No | //! | [`crate::graph::glue::SearchPostProcess`] | [`crate::graph::glue::SearchPostProcess`] | Yes | @@ -32,7 +31,7 @@ pub mod iterator; pub mod strategy; pub use index::{FlatIndex, SearchStats}; -pub use iterator::{DistancesUnordered, FlatIterator, Iterated, OnElementsUnordered}; +pub use iterator::{DistancesUnordered, FlatIterator, Iterated}; pub use strategy::SearchStrategy; #[cfg(test)] diff --git a/diskann/src/flat/test/provider.rs b/diskann/src/flat/test/provider.rs index 445438869..c6464b020 100644 --- a/diskann/src/flat/test/provider.rs +++ b/diskann/src/flat/test/provider.rs @@ -16,13 +16,13 @@ use std::{ }; use diskann_utils::future::SendFuture; -use diskann_vector::distance::Metric; +use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; use thiserror::Error; use crate::{ ANNError, always_escalate, error::{Infallible, RankedError, ToRanked, TransientError}, - flat::{DistancesUnordered, OnElementsUnordered, SearchStrategy}, + flat::{DistancesUnordered, SearchStrategy}, graph::test::synthetic::Grid, internal::counter::{Counter, LocalCounter}, provider::{self, BuildQueryComputer, ExecutionContext, HasElementRef, HasId, NoopGuard}, @@ -205,7 +205,7 @@ impl TransientError for TransientGetError { } } -/// Two-tier error for [`Visitor::on_elements_unordered`]: a critical [`InvalidId`] +/// Two-tier error for [`Visitor::distances_unordered`]: a critical [`InvalidId`] /// or a recoverable [`TransientGetError`]. 
#[derive(Debug)] pub enum AccessError { @@ -336,12 +336,16 @@ impl BuildQueryComputer<&[f32]> for Visitor<'_> { } } -impl OnElementsUnordered for Visitor<'_> { +impl DistancesUnordered<&[f32]> for Visitor<'_> { type Error = AccessError; - fn on_elements_unordered(&mut self, mut f: F) -> impl SendFuture> + fn distances_unordered( + &mut self, + computer: &Self::QueryComputer, + mut f: F, + ) -> impl SendFuture> where - F: Send + for<'a> FnMut(Self::Id, ::ElementRef<'a>), + F: Send + FnMut(Self::Id, f32), { async move { for (i, vector) in self.provider.items.iter().enumerate() { @@ -352,15 +356,14 @@ impl OnElementsUnordered for Visitor<'_> { return Err(AccessError::Transient(TransientGetError::new(id))); } self.get_element.increment(); - f(id, vector.as_slice()); + let dist = computer.evaluate_similarity(vector.as_slice()); + f(id, dist); } Ok(()) } } } -impl DistancesUnordered<&[f32]> for Visitor<'_> {} - ////////////// // Strategy // ////////////// From e322565bfef1d1c8bf57f5b09a5ce765198f4b7b Mon Sep 17 00:00:00 2001 From: "Aditya Krishnan (from Dev Box)" Date: Thu, 14 May 2026 16:42:19 -0400 Subject: [PATCH 24/24] remove OnElementsUnordered --- diskann/src/flat/iterator.rs | 6 ++-- rfcs/00983-flat-search.md | 69 ++++++++++++++---------------------- 2 files changed, 29 insertions(+), 46 deletions(-) diff --git a/diskann/src/flat/iterator.rs b/diskann/src/flat/iterator.rs index eb2fd8cb4..7bad2d131 100644 --- a/diskann/src/flat/iterator.rs +++ b/diskann/src/flat/iterator.rs @@ -62,8 +62,8 @@ pub trait DistancesUnordered: HasId + BuildQueryComputer + Send + Sync { /// A lending, asynchronous iterator over the elements of a flat index. /// /// Implementations provide element-at-a-time access via [`Self::next`]. Providers that -/// only implement `FlatIterator` can be wrapped in [`Iterated`] to obtain an -/// [`OnElementsUnordered`] implementation automatically. 
+/// only implement `FlatIterator` can be wrapped in [`Iterated`] to obtain a +/// default [`DistancesUnordered`] implementation. pub trait FlatIterator: HasId + HasElementRef + Send + Sync { /// The concrete element returned by [`Self::next`]. Reborrows to [`Self::ElementRef`]. type Element<'a>: for<'b> Reborrow<'b, Target = ::ElementRef<'b>> @@ -472,7 +472,7 @@ mod tests { /// `Element<'a> = &'a [f32]`, but `next()` returns `Err(Boom(id))` exactly /// once after `fail_after` successful yields. Used to verify error - /// propagation through [`Iterated::on_elements_unordered`]. + /// propagation through [`Iterated`]'s [`DistancesUnordered`] impl. struct Failing<'a> { store: &'a Store, cursor: usize, diff --git a/rfcs/00983-flat-search.md b/rfcs/00983-flat-search.md index 867f66f90..a4a4312e2 100644 --- a/rfcs/00983-flat-search.md +++ b/rfcs/00983-flat-search.md @@ -4,7 +4,7 @@ |------------------|--------------------------------| | **Authors** | Aditya Krishnan, Alex Razumov, Dongliang Wu | | **Created** | 2026-04-24 | -| **Updated** | 2026-05-05 | +| **Updated** | 2026-05-14 | ## 1. Motivation @@ -22,18 +22,17 @@ stuffing the algorithm or the backend through the `Accessor` trait surface. ### 1.3 Goals -1. Define a streaming access primitive — `OnElementsUnordered` — that mirrors the role - `Accessor` plays for graph search but exposes a callback-driven scan instead of - random access. +1. Define a fused iterate-and-score primitive — `flat::DistancesUnordered` — that + mirrors the role `Accessor` plays for graph search but exposes a sequential + scan-and-score operation instead of random access. 2. Provide flat-search algorithm implementations built on the new primitives, so consumers can use this against their own providers / backends. 3. Expose support for diferent distance computers and post-processing like re-ranking _out-of-the-box_ without having to reimplement these for the flat search path. ## 2. 
Proposal -The flat-search infrastructure is built on a small sequence of traits. The only traits a -backend *must* implement are `OnElementsUnordered` and its subtrait -`flat::DistancesUnordered`. A `flat::SearchStrategy` then instantiates them per -query. +The flat-search infrastructure is built on a small sequence of traits. The only trait a +backend *must* implement is `flat::DistancesUnordered`. A `flat::SearchStrategy` +then instantiates per-query visitors that implement it. An opt-in `FlatIterator` trait plus the `Iterated` adapter exist for convenience for backends that naturally expose element-at-a-time iteration. @@ -67,50 +66,34 @@ both graph and flat search can share common components as much as possible: - **`provider::DistancesUnordered: Accessor + BuildQueryComputer`** — drives the scan via the random-access `Accessor` machinery. Used by graph search. - - **`flat::DistancesUnordered: OnElementsUnordered + BuildQueryComputer`** — - drives the scan via the new sequential `OnElementsUnordered` primitive. This primitive - is used by flat search. More on it below. + - **`flat::DistancesUnordered: HasId + BuildQueryComputer`** — a + self-contained fused iterate-and-score trait used by flat search. Backends + implement the entire scan-and-score loop in a single method. More on it below. ### 2.2 Core traits for flat search -At the very core is the `OnElementsUnordered` trait, which is simply an API to implement -a callback on the entire index. Implementations choose iteration order, prefetching, and -any bulk reads if they want; algorithms see only `(Id, ElementRef)` pairs. 
-```rust -pub trait OnElementsUnordered: HasId + HasElementRef + Send + Sync { - type Error: StandardError; - - fn on_elements_unordered(&mut self, f: F) -> impl SendFuture> - where - F: Send + for<'a> FnMut(Self::Id, ::ElementRef<'a>); -} -``` - -`Id` and `ElementRef<'a>` come from the shared `HasId` / `HasElementRef` traits, so a -type that implements `Accessor` and `OnElementsUnordered` exposes the same id and -element types to both subsystems. - -For computing distance with a query specifically, we define a sub-trait of the above - `flat::DistancesUnordered`. +The single required trait for flat search is `flat::DistancesUnordered`. It fuses +iteration and scoring into one method: implementations drive an entire scan over their +underlying data, scoring each element with the supplied query computer and invoking a +callback with `(id, distance)` pairs. Implementations choose iteration order, +prefetching, and any bulk reads; algorithms see only `(Id, f32)` pairs. ```rust -pub trait DistancesUnordered: OnElementsUnordered + BuildQueryComputer { +pub trait DistancesUnordered: HasId + BuildQueryComputer + Send + Sync { + type Error: ToRanked + Debug + Send + Sync + 'static; + fn distances_unordered( &mut self, computer: &>::QueryComputer, f: F, - ) -> impl SendFuture::Error>> + ) -> impl SendFuture> where - F: Send + FnMut(::Id, f32), - { - // default delegates to on_elements_unordered + evaluate_similarity - } + F: Send + FnMut(::Id, f32); } ``` -The default implementation loops `on_elements_unordered` and calls `computer.evaluate_similarity` on each element; -backends that can fuse retrieval and scoring can override it. `DistancesUnordered` is scoped to a single query. We introduce a strategy that is the per-call -constructor that hands the algorithm a freshly-bound visitor. It is stateless, +constructor that hands the algorithm a freshly-bound visitor. It is meant to be stateless, cheap to construct, and lives only for the duration of one search. 
```rust @@ -208,7 +191,7 @@ columns dip into. │ │ ▼ ▼ ExpandBeam visitor DistancesUnordered visitor - (Accessor + BuildQueryComputer) (OnElementsUnordered + BuildQueryComputer) + (Accessor + BuildQueryComputer) (HasId + BuildQueryComputer) │ │ │ BuildQueryComputer │ ├─────────────────►::build_query_computer ◄──────────────────────┤ @@ -249,9 +232,9 @@ pub trait FlatIterator: HasId + HasElementRef + Send + Sync { } ``` -`Iterated` wraps any `FlatIterator` and implements `OnElementsUnordered` (and -`DistancesUnordered` by inheritance, when the inner type implements -`BuildQueryComputer`) by looping over `next()` and reborrowing each element. +`Iterated` wraps any `FlatIterator` and implements `DistancesUnordered` (when +the inner type also implements `BuildQueryComputer`) by looping over `next()`, +reborrowing each element, and scoring it with the supplied query computer. ## Trade-offs @@ -274,7 +257,7 @@ An alternative is to make `next()` yield a *batch* instead of a single vector re ### Intra-query parallelism -The current design of `OnElementsUnordered` does not allow an implementation to exploit parallelism within a query; since the trait requires a `&mut self`. Especially for a flat index, some implementations might want to parallelize within the scan for a query. Arguably we will need a more complex extension of this architecture to support this. +The current design of `DistancesUnordered` does not allow an implementation to exploit parallelism within a query, because the trait method takes `&mut self`. Especially for a flat index, some implementations might want to parallelize within the scan for a query. Arguably we will need a more complex extension of this architecture to support this. ## Future Work - Support for other flat-search algorithms like - filtered, range and diverse flat algorithms as additional methods on `FlatIndex`.