Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 54 additions & 18 deletions diskann-disk/src/search/provider/disk_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use std::{
collections::HashMap,
future::Future,
num::NonZeroUsize,
ops::Range,
sync::{
atomic::{AtomicU64, AtomicUsize},
Arc,
Expand All @@ -21,13 +20,12 @@ use diskann::{
graph::{
self,
glue::{
self, DefaultPostProcessor, ExpandBeam, IdIterator, SearchExt, SearchPostProcess,
SearchStrategy,
self, DefaultPostProcessor, ExpandBeam, SearchExt, SearchPostProcess, SearchStrategy,
},
search::Knn,
search_output_buffer, AdjacencyList, DiskANNIndex,
},
neighbor::Neighbor,
neighbor::{Neighbor, NeighborPriorityQueue},
provider::{
Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId,
NeighborAccessor, NoopGuard,
Expand Down Expand Up @@ -715,16 +713,6 @@ where
}
}

impl<Data, VP> IdIterator<Range<u32>> for DiskAccessor<'_, Data, VP>
where
Data: GraphDataType<VectorIdType = u32>,
VP: VertexProvider<Data>,
{
async fn id_iterator(&mut self) -> Result<Range<u32>, ANNError> {
Ok(0..self.provider.num_points as u32)
}
}

impl<'a, 'b, Data, VP> DelegateNeighbor<'a> for DiskAccessor<'b, Data, VP>
where
Data: GraphDataType<VectorIdType = u32>,
Expand Down Expand Up @@ -916,6 +904,55 @@ where
}
}

/// Perform a brute-force linear scan of all points in the index, returning the
/// nearest neighbors that pass `vector_filter`.
///
/// The top `neighbors_before_reranking` candidates from the quantized scan will be
/// provided to full-precision reranking.
async fn flat_search<OB>(
&self,
strategy: &DiskSearchStrategy<'_, Data, ProviderFactory>,
query: &[Data::VectorDataType],
vector_filter: &(dyn Fn(&u32) -> bool + Send + Sync),
neighbors_before_reranking: usize,
output: &mut OB,
) -> ANNResult<graph::index::SearchStats>
where
OB: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send,
{
let provider = self.index.provider();
let mut accessor = strategy
.search_accessor(provider, &DefaultContext)
.into_ann_result()?;
let computer = accessor.build_query_computer(query).into_ann_result()?;

let mut best = NeighborPriorityQueue::new(neighbors_before_reranking);
let mut cmps = 0u32;

let num_points = provider.num_points as u32;
for id in 0..num_points {
if vector_filter(&id) {
let element = accessor.get_element(id).await.into_ann_result()?;
let dist = computer.evaluate_similarity(element);
best.insert(Neighbor::new(id, dist));
cmps += 1;
}
}

let result_count = strategy
.default_post_processor()
.post_process(&mut accessor, query, &computer, best.iter(), output)
.await
.into_ann_result()?;

Ok(graph::index::SearchStats {
cmps,
hops: 0,
result_count: result_count as u32,
range_search_second_round: false,
})
}

/// Perform a search on the disk index.
/// return the list of nearest neighbors and associated data.
pub fn search(
Expand Down Expand Up @@ -993,12 +1030,11 @@ where
let k = k_value;
let l = search_list_size as usize;
let stats = if is_flat_search {
self.runtime.block_on(self.index.flat_search(
self.runtime.block_on(self.flat_search(
&strategy,
&DefaultContext,
strategy.query,
query,
vector_filter,
&Knn::new(k, l, beam_width)?,
l,
&mut result_output_buffer,
))?
} else {
Expand Down
19 changes: 0 additions & 19 deletions diskann/src/graph/glue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ use crate::{
Accessor, AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer,
DataProvider, HasId, NeighborAccessor,
},
utils::VectorId,
};

/// A trait to override search constraints such as early termination based on constraints
Expand Down Expand Up @@ -819,24 +818,6 @@ where
) -> impl Future<Output = Result<Self::DeleteElementGuard, Self::DeleteElementError>> + Send;
}

/// Provides asynchronous access to an iterator over vector IDs.
///
/// This trait defines a method to asynchronously retrieve an iterator over vector IDs.
///
/// # Type Parameters
///
/// - `I`: The iterator type returned by the accessor. It must implement `Iterator` with items of type implementing `VectorId`.
///
/// # Errors
///
/// Returns an [`ANNError`] if the iterator cannot be retrieved successfully.
pub trait IdIterator<I>
where
I: Iterator<Item: VectorId>,
{
fn id_iterator(&mut self) -> impl std::future::Future<Output = Result<I, ANNError>>;
}

///////////
// Tests //
///////////
Expand Down
98 changes: 2 additions & 96 deletions diskann/src/graph/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ use tokio::task::JoinSet;
use super::{
AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, Search,
glue::{
self, Batch, ExpandBeam, IdIterator, InplaceDeleteStrategy, InsertStrategy,
MultiInsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy,
self, Batch, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy,
PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy,
},
internal::{BackedgeBuffer, SortedNeighbors, prune},
search::{
Knn,
record::{NoopSearchRecord, SearchRecord, VisitedSearchRecord},
scratch::{self, PriorityQueueConfiguration, SearchScratch, SearchScratchParams},
},
Expand Down Expand Up @@ -2183,99 +2182,6 @@ where
search_params.search(self, strategy, processor, context, query, output)
}

/// Performs a brute-force flat search over the points matching a provided filter function.
///
/// This method executes a linear scan through all points in the index, applying the provided
/// `vector_filter` to select candidate points. It computes the similarity between the query
/// vector and each candidate, returning the top results according to the provided search parameters.
///
/// # Arguments
///
/// * `strategy` - The search strategy to use for accessing and processing elements.
/// * `context` - The context to pass through to providers.
/// * `query` - The query vector for which nearest neighbors are sought.
/// * `vector_filter` - A predicate function used to filter candidate vectors based on their external IDs.
/// * `search_params` - Parameters controlling the search behavior, such as search depth (`l_value`).
/// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller.
///
/// # Returns
///
/// Returns search statistics including the number of distance computations performed.
///
/// # Errors
///
/// Returns an error if there is a failure accessing elements or if the provided parameters are invalid.
///
/// # Notes
///
/// This method is computationally expensive for large datasets, as it does not leverage the graph structure
/// and instead performs a linear scan of all filtered points.
pub async fn flat_search<'a, S, T, O, OB, I>(
&'a self,
strategy: &'a S,
context: &'a DP::Context,
query: T,
vector_filter: &(dyn Fn(&DP::ExternalId) -> bool + Send + Sync),
search_params: &Knn,
output: &mut OB,
) -> ANNResult<SearchStats>
where
T: Copy + Send,
S: glue::DefaultSearchStrategy<DP, T, O, SearchAccessor<'a>: IdIterator<I>>,
I: Iterator<Item = <DP as DataProvider>::InternalId>,
O: Send,
OB: search_output_buffer::SearchOutputBuffer<O> + Send,
{
let mut accessor = strategy
.search_accessor(&self.data_provider, context)
.into_ann_result()?;
let computer = accessor.build_query_computer(query).into_ann_result()?;

let mut scratch = {
let num_start_points = accessor.starting_points().await?.len();
self.search_scratch(search_params.l_value().get(), num_start_points)
};

let id_iterator = accessor.id_iterator().await?;
for id in id_iterator {
let external_id = self
.data_provider
.to_external_id(context, id)
.escalate("external id should be found")?;

if vector_filter(&external_id) {
scratch.visited.insert(id);
let element = accessor
.get_element(id)
.await
.escalate("matched point retrieval must succeed")?;
let dist = computer.evaluate_similarity(element.reborrow());
scratch.best.insert(Neighbor::new(id, dist));
scratch.cmps += 1;
}
}

let result_count = strategy
.default_post_processor()
.post_process(
&mut accessor,
query,
&computer,
scratch.best.iter().take(search_params.l_value().get()),
output,
)
.send()
.await
.into_ann_result()?;

Ok(SearchStats {
cmps: scratch.cmps,
hops: scratch.hops,
result_count: result_count as u32,
range_search_second_round: false,
})
}

//////////////////
// Paged Search //
//////////////////
Expand Down
81 changes: 0 additions & 81 deletions diskann/src/graph/test/cases/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,84 +281,3 @@ async fn test_drop_deleted_neighbors_noop() {
.unwrap();
assert_eq!(result, graph::ConsolidateKind::Complete);
}

#[tokio::test(flavor = "current_thread")]
async fn test_flat_search_basic() {
use crate::graph::search::Knn;
use crate::graph::search_output_buffer::IdDistance;

let adjacency_list = generate_2d_square_adjacency_list();
let index = setup_2d_square(adjacency_list, 4);
let strategy = test_provider::Strategy::new();
let ctx = test_provider::Context::new();

// Query near origin — node 0 at (0,0) is closest.
// l_value must cover all 5 points (4 data + 1 start) so the working set
// doesn't drop any before the post-processor runs.
let query = [0.1_f32, 0.1];
let params = Knn::new(4, 5, None).unwrap();

let mut ids = [0u32; 4];
let mut distances = [0.0f32; 4];
let mut output = IdDistance::new(&mut ids, &mut distances);

let stats = index
.flat_search(
&strategy,
&ctx,
query.as_slice(),
&|_| true,
&params,
&mut output,
)
.await
.unwrap();

// FilterStartPoints removes the start node, leaving 4 data nodes.
assert_eq!(stats.result_count, 4);
let results: std::collections::HashSet<u32> =
ids[..stats.result_count as usize].iter().copied().collect();
for id in 0..4u32 {
assert!(results.contains(&id), "data node {id} should be in results");
}
}

#[tokio::test(flavor = "current_thread")]
async fn test_flat_search_with_filter() {
use crate::graph::search::Knn;
use crate::graph::search_output_buffer::IdDistance;

let adjacency_list = generate_2d_square_adjacency_list();
let index = setup_2d_square(adjacency_list, 4);
let strategy = test_provider::Strategy::new();
let ctx = test_provider::Context::new();

// Query near origin, but filter out node 0.
let query = [0.1_f32, 0.1];
let params = Knn::new(2, 4, None).unwrap();

let mut ids = [0u32; 2];
let mut distances = [0.0f32; 2];
let mut output = IdDistance::new(&mut ids, &mut distances);

let stats = index
.flat_search(
&strategy,
&ctx,
query.as_slice(),
&|ext_id: &u32| *ext_id != 0,
&params,
&mut output,
)
.await
.unwrap();

assert_eq!(stats.result_count, 2);
assert!(
!ids[..stats.result_count as usize].contains(&0),
"node 0 should be filtered out"
);
// Nodes 1, 2, 3 remain — closest two to (0.1, 0.1) are 1 (1,0) and 2 (0,1).
assert!(ids.contains(&1), "node 1 should be present");
assert!(ids.contains(&2), "node 2 should be present");
}
7 changes: 0 additions & 7 deletions diskann/src/graph/test/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1111,13 +1111,6 @@ impl glue::SearchExt for Accessor<'_> {

impl glue::ExpandBeam<&[f32]> for Accessor<'_> {}

impl glue::IdIterator<std::vec::IntoIter<u32>> for Accessor<'_> {
async fn id_iterator(&mut self) -> Result<std::vec::IntoIter<u32>, ANNError> {
let ids: Vec<u32> = self.provider.terms.iter().map(|r| *r.key()).collect();
Ok(ids.into_iter())
}
}

#[derive(Debug, Clone)]
pub struct Strategy {
// Set this flag to enable reuse within the [`workingset::Map`]. For multi-threaded
Expand Down
Loading