Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 133 additions & 65 deletions diskann-benchmark-core/src/recall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@ pub struct RecallMetrics {
pub recall_n: usize,
/// The number of queries.
pub num_queries: usize,
/// The average recall across all queries.
/// The average recall across queries with non-empty groundtruth.
/// Queries with zero groundtruth results are excluded from the average.
pub average: f64,
/// The minimum observed recall (max possible value: `recall_n`).
pub minimum: usize,
/// The maximum observed recall (max possible value: `recall_k`).
pub maximum: usize,
}

#[derive(Debug, Error)]
Expand Down Expand Up @@ -186,76 +183,79 @@ where
}
}

// The actual recall computation for fixed-size groundtruth
let mut recall_values: Vec<usize> = Vec::new();
// The actual recall computation for groundtruth
let mut recall_values: Vec<f64> = Vec::new();
let mut this_groundtruth = HashSet::new();
let mut this_results = HashSet::new();

let mut num_nonzero = 0;
Comment thread
magdalendobson marked this conversation as resolved.

for i in 0..results.nrows() {
let result = results.row(i);
if !allow_insufficient_results && result.len() < recall_n {
return Err(ComputeRecallError::NotEnoughResults(result.len(), recall_n));
}

let gt_row = groundtruth.row(i);
if gt_row.len() < recall_k {
return Err(ComputeRecallError::NotEnoughGroundTruth(
gt_row.len(),
recall_k,
));
}

// Populate the groundtruth using the top-k
this_groundtruth.clear();
this_groundtruth.extend(gt_row.iter().take(recall_k).cloned());

// If we have distances, then continue to append distances as long as the distance
// value is constant
if let Some(distances) = groundtruth_distances
&& recall_k > 0
{
let distances_row = distances.row(i);
if distances_row.len() > recall_k - 1 && gt_row.len() > recall_k - 1 {
let last_distance = distances_row[recall_k - 1];
for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(recall_k) {
if *d == last_distance {
this_groundtruth.insert(g.clone());
} else {
break;
// Groundtruth rows are not required to hold `recall_k` entries, so clamp the effective k for this row to the row's length.
let this_recall_k = gt_row.len().min(recall_k);

let recall = if this_recall_k > 0 {
num_nonzero += 1;
Comment thread
magdalendobson marked this conversation as resolved.

// Populate the groundtruth using the top-k
this_groundtruth.clear();
this_groundtruth.extend(gt_row.iter().take(this_recall_k).cloned());

// If we have distances, keep adding groundtruth ids past the top-k for as long as
// their distance equals the distance of the k-th entry (ties with the last top-k item)
if let Some(distances) = groundtruth_distances
&& this_recall_k > 0
{
let distances_row = distances.row(i);
if distances_row.len() > this_recall_k - 1 && gt_row.len() > this_recall_k - 1 {
let last_distance = distances_row[this_recall_k - 1];
for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(this_recall_k) {
if *d == last_distance {
this_groundtruth.insert(g.clone());
} else {
break;
}
}
}
}
}

this_results.clear();
this_results.extend(result.iter().take(recall_n).cloned());
this_results.clear();
this_results.extend(result.iter().take(recall_n).cloned());

// Count the overlap
let r = this_groundtruth
.iter()
.filter(|i| this_results.contains(i))
.count()
.min(recall_k);
// Count the overlap
let r = this_groundtruth
.iter()
.filter(|i| this_results.contains(i))
.count()
.min(this_recall_k);

recall_values.push(r);
}
(r as f64) / (this_recall_k as f64)
} else {
0.0
};
Comment thread
magdalendobson marked this conversation as resolved.

// Perform post-processing
let total: usize = recall_values.iter().sum();
let minimum = recall_values.iter().min().unwrap_or(&0);
let maximum = recall_values.iter().max().unwrap_or(&0);
recall_values.push(recall);
}

// We explicitly check that each groundtruth row has at least `recall_k` elements.
let div = recall_k * nrows;
let average = (total as f64) / (div as f64);
// Compute the average recall
let total: f64 = recall_values.iter().sum();
let average = if num_nonzero == 0 {
0.0
} else {
total / (num_nonzero as f64)
};

Ok(RecallMetrics {
recall_k,
recall_n,
num_queries: nrows,
average,
minimum: *minimum,
maximum: *maximum,
})
}

Expand Down Expand Up @@ -467,8 +467,6 @@ mod tests {
assert_eq!(recall.num_queries, our_results.nrows());
assert_eq!(recall.recall_k, expected.recall_k);
assert_eq!(recall.recall_n, expected.recall_n);
assert_eq!(recall.minimum, *expected.components.iter().min().unwrap());
assert_eq!(recall.maximum, *expected.components.iter().max().unwrap());
}

//-----------//
Expand Down Expand Up @@ -514,8 +512,6 @@ mod tests {
assert_eq!(recall.num_queries, our_results.nrows());
assert_eq!(recall.recall_k, expected.recall_k);
assert_eq!(recall.recall_n, expected.recall_n);
assert_eq!(recall.minimum, *expected.components.iter().min().unwrap());
assert_eq!(recall.maximum, *expected.components.iter().max().unwrap());
}
}

Expand Down Expand Up @@ -575,18 +571,90 @@ mod tests {
));
}

// Not enough groundtruth - dynamic
// Not enough groundtruth - dynamic: unlike the fixed-size matrix case, dynamic
// (variable-length) groundtruth rows with fewer than recall_k entries are valid
// and represent queries with limited results (e.g. filtered queries). Recall is
// computed using the available entries (this_recall_k = gt_row.len().min(recall_k)).
{
let groundtruth: Vec<_> = (0..10).map(|_| vec![0; 5]).collect();
let groundtruth: Vec<_> = (0..10).map(|_| vec![0u32; 5]).collect();
let results = Matrix::<u32>::new(0, 10, 10);
let err = knn(&groundtruth, None, &results, 10, 10, false).unwrap_err();
assert!(matches!(err, ComputeRecallError::NotEnoughGroundTruth(..)));
let err_allow_insufficient_results =
knn(&groundtruth, None, &results, 10, 10, true).unwrap_err();
assert!(matches!(
err_allow_insufficient_results,
ComputeRecallError::NotEnoughGroundTruth(..)
));
// Should succeed: each row uses this_recall_k = min(5, 10) = 5
let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap();
assert_eq!(recall.num_queries, 10);
}

// Dynamic groundtruth with fewer entries: verify correct recall values.
// groundtruth has 5 entries per row: [1, 2, 3, 4, 5].
// results has 10 entries per row: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].
// With recall_k=10, this_recall_k = min(5, 10) = 5. All 5 groundtruth
// entries appear in the results, so recall = 5/5 = 1.0.
{
let gt_row: Vec<u32> = (1..=5).collect();
let groundtruth: Vec<_> = (0..10).map(|_| gt_row.clone()).collect();
let mut results = Matrix::<u32>::new(0, 10, 10);
for i in 0..10 {
for (j, v) in (1u32..=10).enumerate() {
results[(i, j)] = v;
}
}
let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap();
assert!((recall.average - 1.0).abs() < 1e-10);
}

// Dynamic groundtruth with partial match: 3 of 5 groundtruth entries appear in results.
// recall = 3/5 = 0.6 per query.
{
// groundtruth: [1, 2, 3, 4, 5]; results contain [1, 2, 3, 6, 7, 8, 9, 10, 11, 12]
let gt_row: Vec<u32> = (1..=5).collect();
let groundtruth: Vec<_> = (0..10).map(|_| gt_row.clone()).collect();
let mut results = Matrix::<u32>::new(0, 10, 10);
let res_row: Vec<u32> = vec![1, 2, 3, 6, 7, 8, 9, 10, 11, 12];
for i in 0..10 {
for (j, &v) in res_row.iter().enumerate() {
results[(i, j)] = v;
}
}
let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap();
assert!((recall.average - 0.6).abs() < 1e-10);
}

// Mixed zero and non-zero groundtruth rows: verify denominator uses only non-zero rows.
// 5 queries with groundtruth [1, 2, 3, 4, 5] (all match → recall = 1.0 each)
// 5 queries with empty groundtruth [] (excluded from average)
// Expected average = (5 * 1.0) / 5 = 1.0
{
let mut groundtruth: Vec<Vec<u32>> = Vec::new();
// First 5 rows: non-empty groundtruth
for _ in 0..5 {
groundtruth.push((1..=5).collect());
}
// Last 5 rows: empty groundtruth
for _ in 0..5 {
groundtruth.push(vec![]);
}

let mut results = Matrix::<u32>::new(0, 10, 10);
for i in 0..10 {
for (j, v) in (1u32..=10).enumerate() {
results[(i, j)] = v;
}
}

let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap();
assert_eq!(recall.num_queries, 10);
assert!((recall.average - 1.0).abs() < 1e-10);
}

// All queries have zero groundtruth: should return average = 0.0 (not NaN/inf).
{
let groundtruth: Vec<Vec<u32>> = (0..10).map(|_| vec![]).collect();
let results = Matrix::<u32>::new(0, 10, 10);

let recall = knn(&groundtruth, None, &results, 10, 10, false).unwrap();
assert_eq!(recall.num_queries, 10);
assert_eq!(recall.average, 0.0);
assert!(!recall.average.is_nan());
assert!(!recall.average.is_infinite());
}

// Distance Row Mismatch
Expand Down
2 changes: 1 addition & 1 deletion diskann-benchmark/src/backend/exhaustive/minmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ mod imp {
f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?;

let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?;
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?;
let mut search_results = Vec::<SearchResults>::new();
let threadpool = rayon::ThreadPoolBuilder::new()
.num_threads(input.search.num_threads.get())
Expand Down
2 changes: 1 addition & 1 deletion diskann-benchmark/src/backend/exhaustive/product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ mod imp {
f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?;

let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?;
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?;

let search_progress =
make_progress_bar("running search", queries.nrows(), output.draw_target())?;
Expand Down
2 changes: 1 addition & 1 deletion diskann-benchmark/src/backend/exhaustive/spherical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ mod imp {
f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?;

let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?;
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth), None)?;

let search_progress = make_progress_bar(
"running search",
Expand Down
24 changes: 15 additions & 9 deletions diskann-benchmark/src/backend/index/benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,11 @@ where
let queries: Arc<Matrix<DP::Element>> =
Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?);

let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?;
// compute the maximum value of k used in any search
let max_k = topk.max_k();

let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?;

let knn = benchmark_core::search::graph::KNN::new(
index.clone(),
Expand Down Expand Up @@ -649,10 +653,8 @@ fn full_precision_streaming<T>(
where
T: bytemuck::Pod + VectorRepr + WithApproximateNorm + SampleableForStart,
{
let topk = match &input.search_phase {
SearchPhase::Topk(topk) => topk,
_ => anyhow::bail!("Only TopK is currently supported by the streaming index"),
};
let topk = input.search_phase.as_topk()?;

let consolidate_threshold: f32 = input.runbook_params.consolidate_threshold;

let data = datafiles::load_dataset::<T>(datafiles::BinFile(&input.build.data))?;
Expand Down Expand Up @@ -687,10 +689,14 @@ where

let managed = Managed::new(max_points, consolidate_threshold, managed_stream);

let layered = bigann::WithData::new(managed, data, queries, |path| {
Ok(Box::new(datafiles::load_groundtruth(datafiles::BinFile(
path,
))?))
// compute the maximum value of k used in any search
let max_k = topk.max_k();

let layered = bigann::WithData::new(managed, data, queries, move |path| {
Ok(Box::new(datafiles::load_groundtruth(
datafiles::BinFile(path),
Some(max_k),
)?))
});

Ok(layered)
Expand Down
8 changes: 6 additions & 2 deletions diskann-benchmark/src/backend/index/spherical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,10 +376,14 @@ mod imp {
) -> anyhow::Result<AggregatedSearchResults> {
let topk = phase.as_topk()?;

// compute the maximum value of k used in any search
let max_k = topk.max_k();

let queries: Arc<Matrix<f32>> =
Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?);

let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?;
let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?;

let steps = search::knn::SearchSteps::new(topk.reps, &topk.num_threads, &topk.runs);

Expand Down Expand Up @@ -516,7 +520,7 @@ mod imp {
))?);

let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&multihop.groundtruth))?;
datafiles::load_range_groundtruth(datafiles::BinFile(&multihop.groundtruth))?;

let steps =
search::knn::SearchSteps::new(multihop.reps, &multihop.num_threads, &multihop.runs);
Expand Down
6 changes: 6 additions & 0 deletions diskann-benchmark/src/inputs/graph_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ pub(crate) struct TopkSearchPhase {
pub(crate) runs: Vec<GraphSearch>,
}

impl TopkSearchPhase {
    /// Returns the largest `recall_k` requested by any run in this phase.
    ///
    /// Yields `0` when `runs` is empty.
    pub(crate) fn max_k(&self) -> usize {
        self.runs
            .iter()
            .fold(0, |largest, run| largest.max(run.recall_k))
    }
}

impl CheckDeserialization for TopkSearchPhase {
fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> {
// Check the validity of the input files.
Expand Down
Loading
Loading