diff --git a/Cargo.lock b/Cargo.lock index c3971f503..f1c43e077 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -626,6 +626,7 @@ dependencies = [ "anyhow", "bytemuck", "dashmap", + "diskann-record", "diskann-utils", "diskann-vector", "diskann-wide", @@ -638,6 +639,7 @@ dependencies = [ "relative-path 2.0.1", "serde", "serde_json", + "tempfile", "thiserror 2.0.17", "tokio", "tracing", @@ -845,6 +847,7 @@ dependencies = [ "diskann-linalg", "diskann-platform", "diskann-quantization", + "diskann-record", "diskann-utils", "diskann-vector", "diskann-wide", @@ -892,6 +895,16 @@ dependencies = [ "trybuild", ] +[[package]] +name = "diskann-record" +version = "0.50.0" +dependencies = [ + "anyhow", + "serde", + "serde_json", + "tempfile", +] + [[package]] name = "diskann-tools" version = "0.50.0" @@ -947,11 +960,13 @@ dependencies = [ "approx", "cfg-if", "criterion", + "diskann-record", "diskann-wide", "half", "iai-callgrind", "rand 0.9.4", "rand_distr", + "tempfile", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index b1bbc7bd1..af0809985 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ members = [ "diskann-benchmark-simd", "diskann-benchmark", "diskann-tools", - "vectorset", + "vectorset", "diskann-record", ] default-members = [ @@ -54,6 +54,7 @@ diskann-linalg = { path = "diskann-linalg", version = "0.50.0" } diskann-utils = { path = "diskann-utils", default-features = false, version = "0.50.0" } diskann-quantization = { path = "diskann-quantization", default-features = false, version = "0.50.0" } diskann-platform = { path = "diskann-platform", version = "0.50.0" } +diskann-record = { path = "diskann-record", version = "0.50.0" } # Algorithm diskann = { path = "diskann", version = "0.50.0" } # Providers @@ -85,7 +86,6 @@ iai-callgrind = "0.14.0" itertools = "0.13.0" num-traits = "0.2.15" num_cpus = "1.16.0" -once_cell = "1.19.0" opentelemetry = "0.30.0" opentelemetry_sdk = "0.30.0" paste = "1.0.15" diff --git a/diskann-providers/Cargo.toml 
b/diskann-providers/Cargo.toml index 1e3fd2bcc..5494208c6 100644 --- a/diskann-providers/Cargo.toml +++ b/diskann-providers/Cargo.toml @@ -33,6 +33,7 @@ diskann-linalg = { workspace = true } diskann = { workspace = true } diskann-utils = { workspace = true } diskann-quantization = { workspace = true, features = ["rayon"] } +diskann-record = { workspace = true } tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } tempfile = { workspace = true, optional = true } bf-tree = { workspace = true, optional = true } diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index df554b5e8..212da725f 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -426,14 +426,15 @@ mod tests { }; use diskann_utils::test_data_root; use diskann_vector::distance::Metric; + use rand::Rng; - use super::DiskANNIndex; + use super::{DiskANNIndex, create_current_thread_runtime}; use crate::{ index::diskann_async, model::{ configuration::IndexConfiguration, graph::provider::async_::{ - common::{FullPrecision, TableBasedDeletes}, + common::{FullPrecision, NoDeletes, NoStore, TableBasedDeletes}, inmem::{self, CreateFullPrecision, DefaultProvider}, }, }, @@ -537,4 +538,97 @@ mod tests { assert_eq!(ids[0], 0); assert_eq!(distances[0], 0.0); } + + ///////////////////////////////// + // diskann-record round-trips // + ///////////////////////////////// + + #[test] + fn test_diskann_record_save_load_round_trip() { + // -- Build a `FullPrecisionProvider` index ---- + let dim = 8; + let max_points = 32; + let num_points = 24; + + // Deterministic synthetic data so the test is hermetic (no on-disk fixture). 
+ let mut rng = create_rnd_from_seed_in_tests(0x9c6a1c3b29f74e51); + let train_data: Vec> = (0..num_points) + .map(|_| (0..dim).map(|_| rng.random_range(-1.0..1.0)).collect()) + .collect(); + + let (build_config, parameters) = diskann_async::simplified_builder( + 20, + 16, + Metric::L2, + dim, + max_points, + |_| {}, + ) + .unwrap(); + + let fp_precursor = + CreateFullPrecision::new(parameters.dim, parameters.prefetch_cache_line_level); + let data_provider = + DefaultProvider::new_empty(parameters, fp_precursor, NoStore, NoDeletes).unwrap(); + + let index = + DiskANNIndex::new_with_current_thread_runtime(build_config.clone(), data_provider); + let ctx = DefaultContext; + for (i, v) in train_data.iter().enumerate() { + index + .insert(FullPrecision, &ctx, &(i as u32), v.as_slice()) + .unwrap(); + } + + // -- Search on the original index -------------------------------------- + let top_k = 5; + let search_l = 20; + let kind = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let query = train_data[0].as_slice(); + + let mut ids_orig = vec![0u32; top_k]; + let mut dists_orig = vec![0.0f32; top_k]; + let mut output_orig = + search_output_buffer::IdDistance::new(&mut ids_orig, &mut dists_orig); + let stats_orig = index + .search(kind, &FullPrecision, &ctx, query, &mut output_orig) + .unwrap(); + assert_eq!(stats_orig.result_count, top_k as u32); + // The query is itself in the dataset, so the nearest neighbor must be at distance 0. 
+ assert_eq!(ids_orig[0], 0); + assert_eq!(dists_orig[0], 0.0); + + // -- Save via diskann-record (synchronous) ----------------------------- + let dir = tempfile::tempdir().expect("tempdir"); + let manifest = dir.path().join("manifest.json"); + diskann_record::save::save_to_disk(&*index.inner, dir.path(), &manifest) + .expect("save_to_disk"); + + // -- Load via diskann-record into a fresh sync wrapper ----------------- + type TestProvider = inmem::FullPrecisionProvider; + let loaded_inner: graph::DiskANNIndex = + diskann_record::load::load_from_disk(&manifest, dir.path()) + .expect("load_from_disk"); + let (rt, handle) = create_current_thread_runtime(); + let loaded: DiskANNIndex = DiskANNIndex { + inner: Arc::new(loaded_inner), + _runtime: Some(rt), + handle, + }; + + // -- Search on the loaded index ---------------------------------------- + let kind = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let mut ids_loaded = vec![0u32; top_k]; + let mut dists_loaded = vec![0.0f32; top_k]; + let mut output_loaded = + search_output_buffer::IdDistance::new(&mut ids_loaded, &mut dists_loaded); + let stats_loaded = loaded + .search(kind, &FullPrecision, &ctx, query, &mut output_loaded) + .unwrap(); + + // -- Results must match the pre-save search ---------------------------- + assert_eq!(stats_orig.result_count, stats_loaded.result_count); + assert_eq!(ids_orig, ids_loaded); + assert_eq!(dists_orig, dists_loaded); + } } diff --git a/diskann-providers/src/model/configuration/index_configuration.rs b/diskann-providers/src/model/configuration/index_configuration.rs index d279e7397..a04a93c25 100644 --- a/diskann-providers/src/model/configuration/index_configuration.rs +++ b/diskann-providers/src/model/configuration/index_configuration.rs @@ -108,6 +108,78 @@ impl IndexConfiguration { } } +////////////////////////////////// +// diskann-record Save/Load // +////////////////////////////////// +// +// The wire format preserves the same fields that 
`IndexConfiguration::new` takes +// (`config`, `num_threads`, `dist_metric`, `dim`, `max_points`, `num_frozen_pts`) plus +// `random_seed`, because the seed is part of reproducibility. The prefetch tunables +// (`prefetch_lookahead`, `prefetch_cache_line_level`) are intentionally not persisted; +// they are deployment knobs, not part of the index itself, so loaders apply their own +// defaults (`None`). + +impl diskann_record::save::Save for IndexConfiguration { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + Ok(diskann_record::save_fields!( + self, + context, + [ + config, + num_threads, + dist_metric, + dim, + max_points, + num_frozen_pts, + random_seed, + ] + )) + } +} + +impl diskann_record::load::Load<'_> for IndexConfiguration { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn load( + object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + diskann_record::load_fields!( + object, + [ + config: Config, + num_threads: usize, + dist_metric: Metric, + dim: usize, + max_points: usize, + num_frozen_pts: NonZeroUsize, + random_seed: Option, + ] + ); + Ok(Self { + config, + num_threads, + dist_metric, + dim, + max_points, + num_frozen_pts, + prefetch_lookahead: None, + prefetch_cache_line_level: None, + random_seed, + }) + } + + fn load_legacy( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + #[cfg(test)] mod tests { use diskann::utils::ONE; @@ -177,4 +249,53 @@ mod tests { index_configuration.config.pruned_degree().get() ); } + + ///////////////////////////////// + // diskann-record round-trips // + ///////////////////////////////// + + fn round_trip_helper(value: &T) -> T + where + T: diskann_record::save::Saveable + for<'a> 
diskann_record::load::Loadable<'a>, + { + let dir = tempfile::tempdir().expect("tempdir"); + let manifest = dir.path().join("manifest.json"); + diskann_record::save::save_to_disk(value, dir.path(), &manifest) + .expect("save_to_disk"); + diskann_record::load::load_from_disk::(&manifest, dir.path()) + .expect("load_from_disk") + } + + #[test] + fn index_configuration_round_trips_minimal() { + let original = IndexConfiguration::new(Metric::L2, 128, 1000, ONE, 1, config()); + assert_eq!(original, round_trip_helper(&original)); + } + + #[test] + fn index_configuration_round_trips_preserves_random_seed() { + let original = IndexConfiguration::new(Metric::Cosine, 64, 500, ONE, 4, config()) + .with_pseudo_rng_from_seed(0xDEAD_BEEF_CAFE_F00D); + let restored = round_trip_helper(&original); + assert_eq!(original, restored); + assert_eq!(restored.random_seed, Some(0xDEAD_BEEF_CAFE_F00D)); + } + + #[test] + fn index_configuration_round_trips_drops_prefetch_fields() { + // Build a config with prefetch tunables set; they should NOT be persisted, so + // the loaded copy will differ from the original on those fields only. + let original = IndexConfiguration::new(Metric::L2, 128, 1000, ONE, 1, config()) + .with_prefetch_lookahead(NonZeroUsize::new(8)) + .with_prefetch_cache_line_level(Some(PrefetchCacheLineLevel::CacheLine8)); + let restored = round_trip_helper(&original); + assert_eq!(restored.prefetch_lookahead, None); + assert_eq!(restored.prefetch_cache_line_level, None); + + // Everything else still matches. 
+ let mut expected = original.clone(); + expected.prefetch_lookahead = None; + expected.prefetch_cache_line_level = None; + assert_eq!(expected, restored); + } } diff --git a/diskann-providers/src/model/graph/provider/async_/common.rs b/diskann-providers/src/model/graph/provider/async_/common.rs index d24900404..f0ed97c98 100644 --- a/diskann-providers/src/model/graph/provider/async_/common.rs +++ b/diskann-providers/src/model/graph/provider/async_/common.rs @@ -477,6 +477,149 @@ impl Default for TestCallCount { } } +////////////////////////////////// +// diskann-record Save/Load // +////////////////////////////////// + +impl diskann_record::save::Save for StartPoints { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + Ok(diskann_record::save_fields!(self, context, [start, end])) + } +} + +impl diskann_record::load::Load<'_> for StartPoints { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn load( + object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + diskann_record::load_fields!(object, [start: u32, end: u32]); + Ok(Self { start, end }) + } + + fn load_legacy( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + +impl diskann_record::save::Save for NoStore { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + _context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + Ok(diskann_record::save::Record::empty()) + } +} + +impl diskann_record::load::Load<'_> for NoStore { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn load( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Ok(Self) + } + + fn load_legacy( + 
_object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + +impl diskann_record::save::Save for NoDeletes { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + _context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + Ok(diskann_record::save::Record::empty()) + } +} + +impl diskann_record::load::Load<'_> for NoDeletes { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn load( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Ok(Self) + } + + fn load_legacy( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + +/// Stable wire names for [`PrefetchCacheLineLevel`] variants. +const PREFETCH_CACHE_LINE_4: &str = "CacheLine4"; +const PREFETCH_CACHE_LINE_8: &str = "CacheLine8"; +const PREFETCH_CACHE_LINE_16: &str = "CacheLine16"; +const PREFETCH_CACHE_LINE_ALL: &str = "All"; + +impl diskann_record::save::Save for PrefetchCacheLineLevel { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + _context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + Ok(diskann_record::save::Record::empty()) + } + + fn variant(&self) -> Option> { + Some( + match self { + Self::CacheLine4 => PREFETCH_CACHE_LINE_4, + Self::CacheLine8 => PREFETCH_CACHE_LINE_8, + Self::CacheLine16 => PREFETCH_CACHE_LINE_16, + Self::All => PREFETCH_CACHE_LINE_ALL, + } + .into(), + ) + } +} + +impl diskann_record::load::Load<'_> for PrefetchCacheLineLevel { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + const IS_ENUM: bool = true; + + fn load( + object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + let variant = object + 
.variant() + .ok_or(diskann_record::load::error::Kind::MissingVariant)?; + match variant { + PREFETCH_CACHE_LINE_4 => Ok(Self::CacheLine4), + PREFETCH_CACHE_LINE_8 => Ok(Self::CacheLine8), + PREFETCH_CACHE_LINE_16 => Ok(Self::CacheLine16), + PREFETCH_CACHE_LINE_ALL => Ok(Self::All), + other => Err(diskann_record::load::Error::message(format!( + "unknown PrefetchCacheLineLevel variant: {other:?}" + ))), + } + } + + fn load_legacy( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + #[cfg(test)] mod tests { use std::num::NonZeroUsize; @@ -508,4 +651,47 @@ mod tests { ); } } + + ///////////////////////////////// + // diskann-record round-trips // + ///////////////////////////////// + + fn round_trip_helper(value: &T) -> T + where + T: diskann_record::save::Saveable + for<'a> diskann_record::load::Loadable<'a>, + { + let dir = tempfile::tempdir().expect("tempdir"); + let manifest = dir.path().join("manifest.json"); + diskann_record::save::save_to_disk(value, dir.path(), &manifest) + .expect("save_to_disk"); + diskann_record::load::load_from_disk::(&manifest, dir.path()) + .expect("load_from_disk") + } + + #[test] + fn start_points_round_trips_through_record() { + let original = StartPoints::new(10, NonZeroUsize::new(5).unwrap()).unwrap(); + let restored = round_trip_helper(&original); + assert_eq!(restored.start(), original.start()); + assert_eq!(restored.end(), original.end()); + } + + #[test] + fn start_points_round_trips_at_boundary() { + // Boundary: valid_points = 0, single frozen point. 
+ let original = StartPoints::new(0, NonZeroUsize::new(1).unwrap()).unwrap(); + let restored = round_trip_helper(&original); + assert_eq!(restored.start(), 0); + assert_eq!(restored.end(), 1); + } + + #[test] + fn no_store_round_trips_through_record() { + let _restored = round_trip_helper(&NoStore); + } + + #[test] + fn no_deletes_round_trips_through_record() { + let _restored = round_trip_helper(&NoDeletes); + } } diff --git a/diskann-providers/src/model/graph/provider/async_/fast_memory_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/fast_memory_vector_provider.rs index 688438ea5..48038fe03 100644 --- a/diskann-providers/src/model/graph/provider/async_/fast_memory_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/fast_memory_vector_provider.rs @@ -46,6 +46,10 @@ pub struct FastMemoryVectorProviderAsync { // because the Mutex is never held across an await. write_locks: Vec>, + /// The metric used by `distance`. Retained so save/load can round-trip it; + /// `distance` itself is opaque and cannot be inverted back to a `Metric`. + metric: Metric, + // The distance object used to compare two vector representations. distance: ::Distance, @@ -76,6 +80,7 @@ impl FastMemoryVectorProviderAsync { max_vectors, vectors, write_locks, + metric, distance: Data::VectorDataType::distance(metric, Some(dim)), num_get_calls: TestCallCount::default(), prefetch_cache_line_level: prefetch_cache_line_level.unwrap_or_default(), @@ -327,6 +332,85 @@ impl storage::bin::GetData for FastMemoryVectorProviderAsyn } } +////////////////////////////////// +// diskann-record Save/Load // +////////////////////////////////// + +/// On-disk filename used for the vector blob inside the manifest directory. +/// +/// The current contract is "one vector store per manifest". Once we support +/// multiple sibling vector stores in the same manifest (e.g. full-precision + +/// quantized), this will need parent-scoped naming. 
+const VECTORS_ARTIFACT: &str = "vectors.bin"; + +impl diskann_record::save::Save for FastMemoryVectorProviderAsync +where + Data: GraphDataType, +{ + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + // `metric` is required to reconstruct the opaque `distance` on load. + // The prefetch tunables are part of the provider's configured state and + // are round-tripped so loaders observe the same behaviour as savers. + let mut record = diskann_record::save_fields!( + self, + context, + [metric, prefetch_cache_line_level, prefetch_lookahead] + ); + let mut writer = context.write(VECTORS_ARTIFACT)?; + { + let shim = crate::storage::SingleUseWriteProvider::new(VECTORS_ARTIFACT, &mut writer); + self.save_to_bin(&shim, VECTORS_ARTIFACT) + .map_err(diskann_record::save::Error::new)?; + } + let handle = writer.finish()?; + record.insert("vectors", handle)?; + Ok(record) + } +} + +impl diskann_record::load::Load<'_> for FastMemoryVectorProviderAsync +where + Data: GraphDataType, +{ + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn load( + object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + diskann_record::load_fields!( + object, + [ + metric: Metric, + prefetch_cache_line_level: PrefetchCacheLineLevel, + prefetch_lookahead: usize, + vectors: diskann_record::save::Handle, + ] + ); + let mut reader = object.read(&vectors)?; + let shim = crate::storage::SingleUseReadProvider::new(VECTORS_ARTIFACT, &mut reader) + .map_err(diskann_record::load::Error::new)?; + Self::load_from_bin( + &shim, + VECTORS_ARTIFACT, + metric, + Some(prefetch_cache_line_level), + Some(prefetch_lookahead), + ) + .map_err(diskann_record::load::Error::new) + } + + fn load_legacy( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + 
Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + #[cfg(test)] mod tests { use std::{num::NonZeroUsize, sync::Arc}; @@ -504,4 +588,84 @@ mod tests { check_providers_equal(&provider, &reloaded); } + + ///////////////////////////////// + // diskann-record round-trips // + ///////////////////////////////// + + type TestVectorProvider = FastMemoryVectorProviderAsync; + + fn round_trip_helper(provider: &TestVectorProvider) -> TestVectorProvider { + let dir = tempfile::tempdir().expect("tempdir"); + let manifest = dir.path().join("manifest.json"); + diskann_record::save::save_to_disk(provider, dir.path(), &manifest) + .expect("save_to_disk"); + diskann_record::load::load_from_disk::(&manifest, dir.path()) + .expect("load_from_disk") + } + + #[test] + fn fast_memory_vector_provider_round_trips_default_initialised() { + let provider = TestVectorProvider::new(4, 3, Metric::L2, None, None); + let restored = round_trip_helper(&provider); + check_providers_equal(&provider, &restored); + } + + #[test] + fn fast_memory_vector_provider_round_trips_populated() { + let provider = TestVectorProvider::new(5, 4, Metric::L2, None, None); + // SAFETY: Single-threaded test, no aliasing of mutable slices. + unsafe { + for i in 0..provider.total() { + let row: Vec = (0..provider.dim()) + .map(|j| ((i * 17 + j * 3) as f32) * 0.125) + .collect(); + provider.set_vector_sync(i, &row).unwrap(); + } + } + let restored = round_trip_helper(&provider); + check_providers_equal(&provider, &restored); + } + + #[test] + fn fast_memory_vector_provider_round_trips_with_singleton_dim() { + let provider = TestVectorProvider::new(7, 1, Metric::L2, None, None); + // SAFETY: Single-threaded test, no aliasing of mutable slices. 
+ unsafe { + for i in 0..provider.total() { + provider.set_vector_sync(i, &[i as f32 + 0.5]).unwrap(); + } + } + let restored = round_trip_helper(&provider); + check_providers_equal(&provider, &restored); + } + + #[test] + fn fast_memory_vector_provider_round_trips_preserves_metric() { + // Metric must be carried through the manifest because `distance` cannot + // be inverted back to a `Metric` value. + let provider = TestVectorProvider::new(3, 2, Metric::Cosine, None, None); + let restored = round_trip_helper(&provider); + assert_eq!(provider.metric, restored.metric); + check_providers_equal(&provider, &restored); + } + + #[test] + fn fast_memory_vector_provider_round_trips_preserves_prefetch_tunables() { + // Non-default prefetch tunables round-trip exactly through the manifest. + let provider = TestVectorProvider::new( + 3, + 2, + Metric::L2, + Some(PrefetchCacheLineLevel::CacheLine4), + Some(42), + ); + let restored = round_trip_helper(&provider); + assert_eq!( + provider.prefetch_cache_line_level, + restored.prefetch_cache_line_level + ); + assert_eq!(provider.prefetch_lookahead, restored.prefetch_lookahead); + check_providers_equal(&provider, &restored); + } } diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/provider.rs b/diskann-providers/src/model/graph/provider/async_/inmem/provider.rs index c05a0e077..c5ec48db1 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/provider.rs @@ -730,6 +730,85 @@ where } } +////////////////////////////////// +// diskann-record Save/Load // +////////////////////////////////// +// +// NOTE: Sub-providers are saved through the same `Context`, so artifact filenames must +// not collide. Today `FastMemoryVectorProviderAsync` writes `vectors.bin` and +// `SimpleNeighborProviderAsync` writes `graph.bin`. This is sufficient for the +// `FullPrecisionProvider` shape used by current tests. 
When a +// non-`NoStore` aux vector store is wired in, both `U` and `V` would attempt to write +// `vectors.bin` and `Context::write` would reject the second call. At that point we +// will need parent-scoped artifact names (e.g. UUIDs --> TODO). + +impl diskann_record::save::Save for DefaultProvider +where + U: diskann_record::save::Save, + V: diskann_record::save::Save, + D: diskann_record::save::Save, +{ + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + Ok(diskann_record::save_fields!( + self, + context, + [ + base_vectors, + aux_vectors, + neighbor_provider, + deleted, + metric, + start_points, + ] + )) + } +} + +impl<'a, U, V, D, Ctx> diskann_record::load::Load<'a> for DefaultProvider +where + U: diskann_record::load::Load<'a>, + V: diskann_record::load::Load<'a>, + D: diskann_record::load::Load<'a>, +{ + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn load( + object: diskann_record::load::Object<'a>, + ) -> diskann_record::load::Result { + diskann_record::load_fields!( + object, + [ + base_vectors: U, + aux_vectors: V, + neighbor_provider: SimpleNeighborProviderAsync, + deleted: D, + metric: Metric, + start_points: StartPoints, + ] + ); + Ok(Self { + base_vectors, + aux_vectors, + neighbor_provider, + deleted, + metric, + start_points, + context: std::marker::PhantomData, + }) + } + + fn load_legacy( + _object: diskann_record::load::Object<'a>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + /////////// // Tests // /////////// @@ -839,4 +918,97 @@ mod tests { .is_err() ); } + + ///////////////////////////////// + // diskann-record round-trips // + ///////////////////////////////// + + use crate::model::graph::provider::async_::inmem::FullPrecisionStore; + + type TestProvider = + DefaultProvider, NoStore, NoDeletes, 
DefaultContext>; + + fn build_full_precision_provider( + max_points: usize, + frozen_points: usize, + dim: usize, + max_degree: u32, + ) -> TestProvider { + DefaultProvider::<_, _, _, DefaultContext>::new_empty( + DefaultProviderParameters { + max_points, + frozen_points: NonZeroUsize::new(frozen_points).unwrap(), + dim, + metric: Metric::L2, + prefetch_lookahead: None, + max_degree, + prefetch_cache_line_level: None, + }, + CreateFullPrecision::::new(dim, None), + NoStore, + NoDeletes, + ) + .unwrap() + } + + fn populate_provider(provider: &TestProvider, dim: usize) { + let total = provider.total_points(); + for i in 0..total as u32 { + let v: Vec = (0..dim).map(|j| (i as f32) * 10.0 + j as f32).collect(); + provider.base_vectors.set_element(&i, &v).unwrap(); + let neighbors: Vec = + (0..3).map(|j| (i + j + 1) % total as u32).collect(); + provider + .neighbor_provider + .set_neighbors_sync(i as usize, &neighbors) + .unwrap(); + } + } + + fn assert_providers_match(left: &TestProvider, right: &TestProvider) { + assert_eq!(left.metric, right.metric); + assert_eq!(left.start_points.range(), right.start_points.range()); + assert_eq!(left.total_points(), right.total_points()); + + for i in 0..left.total_points() { + // SAFETY: Single-threaded test; no concurrent mutation of either provider. 
+ unsafe { + let a = left.base_vectors.get_vector_sync(i); + let b = right.base_vectors.get_vector_sync(i); + assert_eq!(a, b, "base_vectors differ at {i}"); + } + + let mut a = AdjacencyList::new(); + let mut b = AdjacencyList::new(); + left.neighbor_provider.get_neighbors_sync(i, &mut a).unwrap(); + right.neighbor_provider.get_neighbors_sync(i, &mut b).unwrap(); + assert_eq!(a, b, "adjacency list at {i} differs"); + } + } + + fn round_trip_helper(provider: &TestProvider) -> TestProvider { + let dir = tempfile::tempdir().expect("tempdir"); + let manifest = dir.path().join("manifest.json"); + diskann_record::save::save_to_disk(provider, dir.path(), &manifest) + .expect("save_to_disk"); + diskann_record::load::load_from_disk::(&manifest, dir.path()) + .expect("load_from_disk") + } + + #[test] + fn default_provider_round_trips_populated() { + let dim = 4; + let provider = build_full_precision_provider(5, 1, dim, 8); + populate_provider(&provider, dim); + let restored = round_trip_helper(&provider); + assert_providers_match(&provider, &restored); + } + + #[test] + fn default_provider_round_trips_default_initialised() { + let dim = 3; + let provider = build_full_precision_provider(4, 1, dim, 6); + let restored = round_trip_helper(&provider); + assert_providers_match(&provider, &restored); + } } diff --git a/diskann-providers/src/model/graph/provider/async_/memory_vector_provider.rs b/diskann-providers/src/model/graph/provider/async_/memory_vector_provider.rs index 993d9fb28..c3cf4aea9 100644 --- a/diskann-providers/src/model/graph/provider/async_/memory_vector_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/memory_vector_provider.rs @@ -185,6 +185,63 @@ impl storage::bin::GetData for MemoryVectorProviderAsync diskann_record::save::Save for MemoryVectorProviderAsync +where + Data: GraphDataType, +{ + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + context: diskann_record::save::Context<'_>, + ) 
-> diskann_record::save::Result> { + let mut writer = context.write(VECTORS_ARTIFACT)?; + { + let shim = crate::storage::SingleUseWriteProvider::new(VECTORS_ARTIFACT, &mut writer); + self.save_to_bin(&shim, VECTORS_ARTIFACT) + .map_err(diskann_record::save::Error::new)?; + } + let handle = writer.finish()?; + let mut record = diskann_record::save::Record::empty(); + record.insert("vectors", handle)?; + Ok(record) + } +} + +impl diskann_record::load::Load<'_> for MemoryVectorProviderAsync +where + Data: GraphDataType, +{ + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn load( + object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + diskann_record::load_fields!(object, [vectors: diskann_record::save::Handle]); + let mut reader = object.read(&vectors)?; + let shim = crate::storage::SingleUseReadProvider::new(VECTORS_ARTIFACT, &mut reader) + .map_err(diskann_record::load::Error::new)?; + Self::load_from_bin(&shim, VECTORS_ARTIFACT).map_err(diskann_record::load::Error::new) + } + + fn load_legacy( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + /////////// // Tests // /////////// @@ -345,4 +402,59 @@ mod tests { check_providers_equal(&provider, &reloaded); } + + ///////////////////////////////// + // diskann-record round-trips // + ///////////////////////////////// + + type TestVectorProvider = MemoryVectorProviderAsync; + + fn assert_providers_match(left: &TestVectorProvider, right: &TestVectorProvider) { + assert_eq!(left.total(), right.total()); + assert_eq!(left.dim(), right.dim()); + for i in 0..left.total() { + let l = left.get_vector_sync(i).unwrap(); + let r = right.get_vector_sync(i).unwrap(); + assert_eq!(&*l, &*r, "vectors at index {i} differ"); + } + } + + fn round_trip(provider: &TestVectorProvider) -> TestVectorProvider { + let dir = tempfile::tempdir().expect("tempdir"); + let 
manifest = dir.path().join("manifest.json"); + diskann_record::save::save_to_disk(provider, dir.path(), &manifest) + .expect("save_to_disk"); + diskann_record::load::load_from_disk::(&manifest, dir.path()) + .expect("load_from_disk") + } + + #[test] + fn memory_vector_provider_round_trips_default_initialised() { + let provider = TestVectorProvider::new(4, 3); + let restored = round_trip(&provider); + assert_providers_match(&provider, &restored); + } + + #[test] + fn memory_vector_provider_round_trips_populated() { + let provider = TestVectorProvider::new(5, 4); + for i in 0..provider.total() { + let row: Vec = (0..provider.dim()) + .map(|j| ((i * 17 + j * 3) as f32) * 0.125) + .collect(); + provider.set_vector_sync(i, &row).unwrap(); + } + let restored = round_trip(&provider); + assert_providers_match(&provider, &restored); + } + + #[test] + fn memory_vector_provider_round_trips_with_singleton_dim() { + let provider = TestVectorProvider::new(7, 1); + for i in 0..provider.total() { + provider.set_vector_sync(i, &[i as f32 + 0.5]).unwrap(); + } + let restored = round_trip(&provider); + assert_providers_match(&provider, &restored); + } } diff --git a/diskann-providers/src/model/graph/provider/async_/simple_neighbor_provider.rs b/diskann-providers/src/model/graph/provider/async_/simple_neighbor_provider.rs index 2584caa4f..a84636015 100644 --- a/diskann-providers/src/model/graph/provider/async_/simple_neighbor_provider.rs +++ b/diskann-providers/src/model/graph/provider/async_/simple_neighbor_provider.rs @@ -361,6 +361,58 @@ impl storage::bin::GetAdjacencyList for DiskAdaptor<'_> { } } +////////////////////////////////// +// diskann-record Save/Load // +////////////////////////////////// + +/// On-disk filename used for the adjacency-list blob inside the manifest directory. 
+const GRAPH_ARTIFACT: &str = "graph.bin"; + +impl diskann_record::save::Save for SimpleNeighborProviderAsync { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + let mut writer = context.write(GRAPH_ARTIFACT)?; + { + let shim = crate::storage::SingleUseWriteProvider::new(GRAPH_ARTIFACT, &mut writer); + // The canonical graph file format records a `start_point` in its header that + // `load_direct` reads but discards. The graph's start point is tracked + // separately by `StartPoints` and persisted at the `DefaultProvider` level, so + // we write `0` here as a placeholder. Round-tripped `SimpleNeighborProviderAsync` + // values do not carry a start point on their own. + self.save_direct(&shim, 0, GRAPH_ARTIFACT) + .map_err(diskann_record::save::Error::new)?; + } + let handle = writer.finish()?; + let mut record = diskann_record::save::Record::empty(); + record.insert("graph", handle)?; + Ok(record) + } +} + +impl diskann_record::load::Load<'_> for SimpleNeighborProviderAsync { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn load( + object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + diskann_record::load_fields!(object, [graph: diskann_record::save::Handle]); + let mut reader = object.read(&graph)?; + let shim = crate::storage::SingleUseReadProvider::new(GRAPH_ARTIFACT, &mut reader) + .map_err(diskann_record::load::Error::new)?; + Self::load_direct(&shim, GRAPH_ARTIFACT).map_err(diskann_record::load::Error::new) + } + + fn load_legacy( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + /////////// // Tests // /////////// @@ -445,4 +497,96 @@ mod tests { ); } } + + ///////////////////////////////// + // diskann-record round-trips // + 
///////////////////////////////// + + type TestNeighborProvider = SimpleNeighborProviderAsync; + + fn round_trip_helper(provider: &TestNeighborProvider) -> TestNeighborProvider { + let dir = tempfile::tempdir().expect("tempdir"); + let manifest = dir.path().join("manifest.json"); + diskann_record::save::save_to_disk(provider, dir.path(), &manifest) + .expect("save_to_disk"); + diskann_record::load::load_from_disk::(&manifest, dir.path()) + .expect("load_from_disk") + } + + fn assert_adjacency_lists_match( + left: &TestNeighborProvider, + right: &TestNeighborProvider, + total: usize, + ) { + for i in 0..total { + let mut l = AdjacencyList::new(); + let mut r = AdjacencyList::new(); + left.get_neighbors_sync(i, &mut l).unwrap(); + right.get_neighbors_sync(i, &mut r).unwrap(); + assert_eq!(l, r, "adjacency list for node {} differs after round-trip", i); + } + } + + #[test] + fn simple_neighbor_provider_round_trips_default_initialised() { + let max_points = 4; + let additional_points = 1; + let provider = TestNeighborProvider::new(max_points, additional_points, 5, 1.0); + let restored = round_trip_helper(&provider); + assert_adjacency_lists_match(&provider, &restored, max_points + additional_points); + } + + #[test] + fn simple_neighbor_provider_round_trips_populated() { + let max_points = 8; + let additional_points = 2; + let max_degree = 5; + let provider = + TestNeighborProvider::new(max_points, additional_points, max_degree, 1.0); + for i in 0..max_points + additional_points { + let neighbors: Vec = (1..4).map(|j| i as u32 + j).collect(); + provider.set_neighbors_sync(i, &neighbors).unwrap(); + } + let restored = round_trip_helper(&provider); + assert_adjacency_lists_match(&provider, &restored, max_points + additional_points); + } + + #[test] + fn simple_neighbor_provider_round_trips_preserves_row_width() { + // The on-disk format records `max_degree = dim - 1`, so a non-unit slack factor + // round-trips through `save_direct` / `load_direct` even though the loader 
uses + // `slack = 1.0`. + let max_points = 3; + let additional_points = 1; + let max_degree = 4; + let slack = 1.5; + let provider = + TestNeighborProvider::new(max_points, additional_points, max_degree, slack); + // dim = (max_degree * slack) as usize + 1 = 7 + let expected_dim = (max_degree as f32 * slack) as usize + 1; + // Fill each row to its inflated capacity to make sure the wider row width matters. + for i in 0..max_points + additional_points { + let neighbors: Vec = (0..(expected_dim - 1) as u32) + .map(|j| (i as u32 * 10) + j) + .collect(); + provider.set_neighbors_sync(i, &neighbors).unwrap(); + } + let restored = round_trip_helper(&provider); + assert_adjacency_lists_match(&provider, &restored, max_points + additional_points); + } + + #[test] + fn simple_neighbor_provider_round_trips_with_jagged_adjacency_lists() { + let max_points = 6; + let additional_points = 1; + let provider = TestNeighborProvider::new(max_points, additional_points, 5, 1.0); + // Mix of empty, short, and full adjacency lists. + provider.set_neighbors_sync(0, &[]).unwrap(); + provider.set_neighbors_sync(1, &[10]).unwrap(); + provider.set_neighbors_sync(2, &[20, 21]).unwrap(); + provider.set_neighbors_sync(3, &[30, 31, 32, 33]).unwrap(); + // Leave id=4, 5, 6 with the default-empty adjacency lists. + let restored = round_trip_helper(&provider); + assert_adjacency_lists_match(&provider, &restored, max_points + additional_points); + } } diff --git a/diskann-providers/src/storage/bin.rs b/diskann-providers/src/storage/bin.rs index bf1f89d14..fe4e213d6 100644 --- a/diskann-providers/src/storage/bin.rs +++ b/diskann-providers/src/storage/bin.rs @@ -13,7 +13,7 @@ use diskann::{ }; use diskann_utils::io::Metadata; -use crate::{model::graph::traits::AdHoc, utils::load_metadata_from_file}; +use crate::model::graph::traits::AdHoc; /// An simplified adaptor interface for allowing providers to use and [`load_graph`]. 
/// @@ -135,22 +135,25 @@ where S: SetData, T: VectorRepr, { - let metadata = load_metadata_from_file(provider, path).map_err(|err| { - ANNError::log_index_error(format_args!( - "failed to load data file \"{}\" due to the following error: {}", - path, err - )) - })?; + let itr = crate::utils::VectorDataIterator::<_, AdHoc>::new(path, None, provider) + .map_err(|err| { + ANNError::log_index_error(format_args!( + "failed to load data file \"{}\" due to the following error: {}", + path, err + )) + })?; + + let num_points = itr.get_num_points(); + let dimension = itr.get_dimension(); tracing::info!( "Loading {} vectors with dimension {} from storage system {} into dataset...", - metadata.npoints(), - metadata.ndims(), + num_points, + dimension, path ); - let mut data = create(metadata.npoints(), metadata.ndims())?; - let itr = crate::utils::VectorDataIterator::<_, AdHoc>::new(path, None, provider)?; + let mut data = create(num_points, dimension)?; for (i, (vector, _)) in itr.enumerate() { data.set_data(i.into_usize(), &vector)?; } diff --git a/diskann-providers/src/storage/mod.rs b/diskann-providers/src/storage/mod.rs index 1233b11f6..75dd80084 100644 --- a/diskann-providers/src/storage/mod.rs +++ b/diskann-providers/src/storage/mod.rs @@ -8,6 +8,11 @@ pub use storage_provider::{ DynWriteProvider, StorageReadProvider, StorageWriteProvider, WriteProviderWrapper, WriteSeek, }; +mod record_shim; +pub use record_shim::{ + ReadSeek, SingleUseReadProvider, SingleUseReader, SingleUseWriteProvider, SingleUseWriter, +}; + #[cfg(any(test, feature = "virtual_storage"))] mod virtual_storage_provider; #[cfg(any(test, feature = "virtual_storage"))] diff --git a/diskann-providers/src/storage/record_shim.rs b/diskann-providers/src/storage/record_shim.rs new file mode 100644 index 000000000..6c3474801 --- /dev/null +++ b/diskann-providers/src/storage/record_shim.rs @@ -0,0 +1,394 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! 
Adapters that expose a single in-hand [`Read`]/[`Write`] target as a +//! [`StorageReadProvider`]/[`StorageWriteProvider`]. +//! +//! These exist to bridge the new `diskann-record` save/load APIs (which produce +//! borrowed `Writer`/`Reader` handles backed by a manifest) and the existing +//! byte-level helpers (which expect a [`StorageReadProvider`] / +//! [`StorageWriteProvider`] and a path). +//! +//! Each adapter is a "single use" thing: it lets the wrapped writer/reader be +//! handed out exactly once, for a single known item name. A second call, or a +//! call with a mismatched name, returns an [`io::Error`]. This is the right +//! contract for our use case — every leaf component's `Save` impl produces one +//! artifact through one helper invocation. +//! +//! When the existing helpers are migrated to a direct `Read`/`Write` API, +//! these adapters (and the underlying `StorageRead/WriteProvider` traits) can +//! be deprecated. + +use std::{ + io::{self, Read, Seek, SeekFrom, Write}, + sync::Mutex, +}; + +use crate::storage::{StorageReadProvider, StorageWriteProvider, WriteSeek}; + +/// Trait alias for types that implement both [`Read`] and [`Seek`], mirroring +/// the existing [`WriteSeek`] alias used by the write side. +pub trait ReadSeek: Read + Seek {} +impl ReadSeek for T where T: Read + Seek {} + +////////////////////////// +// Write side // +////////////////////////// + +/// A [`StorageWriteProvider`] backed by a single borrowed writer. +/// +/// The wrapped writer is handed out exactly once via [`create_for_write`] +/// (or [`open_writer`]) for the configured `name`. Any other call — repeated, +/// or with a mismatched name — returns an [`io::Error`]. 
+/// +/// [`create_for_write`]: StorageWriteProvider::create_for_write +/// [`open_writer`]: StorageWriteProvider::open_writer +pub struct SingleUseWriteProvider<'w> { + name: String, + inner: Mutex>, +} + +impl<'w> SingleUseWriteProvider<'w> { + /// Wrap `writer` so that calls to `create_for_write(name)` (or + /// `open_writer(name)`) on this provider yield it exactly once. + pub fn new(name: impl Into, writer: &'w mut W) -> Self + where + W: WriteSeek + Send, + { + Self { + name: name.into(), + inner: Mutex::new(Some(writer)), + } + } + + fn take_for(&self, requested: &str) -> io::Result<&'w mut (dyn WriteSeek + Send)> { + if requested != self.name { + return Err(io::Error::new( + io::ErrorKind::NotFound, + format!( + "SingleUseWriteProvider only serves {:?}; got request for {:?}", + self.name, requested, + ), + )); + } + // Lint: PoisonError here would mean the underlying writer is in a + // broken state from a previous panicked write; surfacing it as an + // io::Error rather than panicking lets callers handle it. + let mut guard = self + .inner + .lock() + .map_err(|_| io::Error::other("SingleUseWriteProvider lock poisoned"))?; + guard.take().ok_or_else(|| { + io::Error::other(format!( + "SingleUseWriteProvider for {:?} has already been used", + self.name, + )) + }) + } +} + +impl std::fmt::Debug for SingleUseWriteProvider<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SingleUseWriteProvider") + .field("name", &self.name) + .finish_non_exhaustive() + } +} + +/// Borrowed writer produced by [`SingleUseWriteProvider`]. 
+pub struct SingleUseWriter<'w>(&'w mut (dyn WriteSeek + Send)); + +impl std::fmt::Debug for SingleUseWriter<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("SingleUseWriter(..)") + } +} + +impl Write for SingleUseWriter<'_> { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } + fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { + self.0.write_vectored(bufs) + } + fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { + self.0.write_all(buf) + } + fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> io::Result<()> { + self.0.write_fmt(fmt) + } +} + +impl Seek for SingleUseWriter<'_> { + fn seek(&mut self, pos: SeekFrom) -> io::Result { + self.0.seek(pos) + } + fn rewind(&mut self) -> io::Result<()> { + self.0.rewind() + } + fn stream_position(&mut self) -> io::Result { + self.0.stream_position() + } + fn seek_relative(&mut self, offset: i64) -> io::Result<()> { + self.0.seek_relative(offset) + } +} + +impl<'w> StorageWriteProvider for SingleUseWriteProvider<'w> { + type Writer = SingleUseWriter<'w>; + + fn open_writer(&self, item_identifier: &str) -> io::Result { + self.take_for(item_identifier).map(SingleUseWriter) + } + + fn create_for_write(&self, item_identifier: &str) -> io::Result { + self.take_for(item_identifier).map(SingleUseWriter) + } + + fn delete(&self, _item_identifier: &str) -> io::Result<()> { + Err(io::Error::new( + io::ErrorKind::Unsupported, + "SingleUseWriteProvider does not support delete", + )) + } +} + +////////////////////////// +// Read side // +////////////////////////// + +/// A [`StorageReadProvider`] backed by a single borrowed reader. +/// +/// Symmetric to [`SingleUseWriteProvider`]: hands out the wrapped reader +/// exactly once via [`open_reader`] for the configured `name`. 
+/// +/// [`open_reader`]: StorageReadProvider::open_reader +pub struct SingleUseReadProvider<'r> { + name: String, + length: u64, + inner: Mutex>, +} + +impl<'r> SingleUseReadProvider<'r> { + /// Wrap `reader` so that a call to `open_reader(name)` on this provider + /// yields it exactly once. + /// + /// The constructor probes the reader for its byte length (by seeking to + /// the end and restoring the original cursor) so that subsequent + /// [`StorageReadProvider::get_length`] calls can be answered without + /// re-touching the stream. Fails if either seek fails. + pub fn new(name: impl Into, reader: &'r mut R) -> io::Result + where + R: ReadSeek + Send, + { + let start = reader.stream_position()?; + let end = reader.seek(SeekFrom::End(0))?; + reader.seek(SeekFrom::Start(start))?; + Ok(Self { + name: name.into(), + length: end, + inner: Mutex::new(Some(reader)), + }) + } + + fn take_for(&self, requested: &str) -> io::Result<&'r mut (dyn ReadSeek + Send)> { + if requested != self.name { + return Err(io::Error::new( + io::ErrorKind::NotFound, + format!( + "SingleUseReadProvider only serves {:?}; got request for {:?}", + self.name, requested, + ), + )); + } + let mut guard = self + .inner + .lock() + .map_err(|_| io::Error::other("SingleUseReadProvider lock poisoned"))?; + guard.take().ok_or_else(|| { + io::Error::other(format!( + "SingleUseReadProvider for {:?} has already been used", + self.name, + )) + }) + } +} + +impl std::fmt::Debug for SingleUseReadProvider<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SingleUseReadProvider") + .field("name", &self.name) + .field("length", &self.length) + .finish_non_exhaustive() + } +} + +/// Borrowed reader produced by [`SingleUseReadProvider`]. 
+pub struct SingleUseReader<'r>(&'r mut (dyn ReadSeek + Send)); + +impl std::fmt::Debug for SingleUseReader<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("SingleUseReader(..)") + } +} + +impl Read for SingleUseReader<'_> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } + fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result { + self.0.read_vectored(bufs) + } + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + self.0.read_to_end(buf) + } + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + self.0.read_to_string(buf) + } + fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { + self.0.read_exact(buf) + } +} + +impl Seek for SingleUseReader<'_> { + fn seek(&mut self, pos: SeekFrom) -> io::Result { + self.0.seek(pos) + } + fn rewind(&mut self) -> io::Result<()> { + self.0.rewind() + } + fn stream_position(&mut self) -> io::Result { + self.0.stream_position() + } + fn seek_relative(&mut self, offset: i64) -> io::Result<()> { + self.0.seek_relative(offset) + } +} + +impl<'r> StorageReadProvider for SingleUseReadProvider<'r> { + type Reader = SingleUseReader<'r>; + + fn open_reader(&self, item_identifier: &str) -> io::Result { + self.take_for(item_identifier).map(SingleUseReader) + } + + fn get_length(&self, item_identifier: &str) -> io::Result { + if item_identifier != self.name { + return Err(io::Error::new( + io::ErrorKind::NotFound, + format!( + "SingleUseReadProvider only knows {:?}; got request for {:?}", + self.name, item_identifier, + ), + )); + } + Ok(self.length) + } + + fn exists(&self, item_identifier: &str) -> bool { + item_identifier == self.name + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use std::io::{Read, Seek, SeekFrom, Write}; + + use super::*; + + #[test] + fn write_provider_routes_one_create_to_the_writer() { + let mut buf = std::io::Cursor::new(Vec::::new()); + let provider = 
SingleUseWriteProvider::new("only.bin", &mut buf); + { + let mut w = provider + .create_for_write("only.bin") + .expect("first create_for_write"); + w.write_all(b"hello").unwrap(); + w.write_all(b" world").unwrap(); + w.flush().unwrap(); + } + drop(provider); + assert_eq!(buf.into_inner(), b"hello world"); + } + + #[test] + fn write_provider_rejects_second_take() { + let mut buf = std::io::Cursor::new(Vec::::new()); + let provider = SingleUseWriteProvider::new("only.bin", &mut buf); + let _first = provider.create_for_write("only.bin").unwrap(); + let err = provider + .create_for_write("only.bin") + .expect_err("second take must fail"); + assert!( + err.to_string().contains("already been used"), + "unexpected error: {err}" + ); + } + + #[test] + fn write_provider_rejects_unknown_name() { + let mut buf = std::io::Cursor::new(Vec::::new()); + let provider = SingleUseWriteProvider::new("only.bin", &mut buf); + let err = provider + .create_for_write("other.bin") + .expect_err("mismatched name must fail"); + assert_eq!(err.kind(), io::ErrorKind::NotFound); + } + + #[test] + fn write_provider_does_not_support_delete() { + let mut buf = std::io::Cursor::new(Vec::::new()); + let provider = SingleUseWriteProvider::new("only.bin", &mut buf); + let err = provider.delete("only.bin").expect_err("delete must fail"); + assert_eq!(err.kind(), io::ErrorKind::Unsupported); + } + + #[test] + fn read_provider_routes_one_open_to_the_reader() { + let mut src = std::io::Cursor::new(b"abcdefg".to_vec()); + let provider = SingleUseReadProvider::new("only.bin", &mut src).expect("new"); + + assert!(provider.exists("only.bin")); + assert!(!provider.exists("other.bin")); + assert_eq!(provider.get_length("only.bin").unwrap(), 7); + + let mut r = provider.open_reader("only.bin").unwrap(); + let mut out = Vec::new(); + r.read_to_end(&mut out).unwrap(); + assert_eq!(out, b"abcdefg"); + } + + #[test] + fn read_provider_rejects_second_take() { + let mut src = std::io::Cursor::new(b"abc".to_vec()); + 
let provider = SingleUseReadProvider::new("only.bin", &mut src).expect("new"); + let _first = provider.open_reader("only.bin").unwrap(); + let err = provider + .open_reader("only.bin") + .expect_err("second take must fail"); + assert!(err.to_string().contains("already been used")); + } + + #[test] + fn read_provider_preserves_cursor_after_construction() { + let mut src = std::io::Cursor::new(b"123456789".to_vec()); + src.seek(SeekFrom::Start(2)).unwrap(); + let provider = SingleUseReadProvider::new("only.bin", &mut src).expect("new"); + assert_eq!(provider.get_length("only.bin").unwrap(), 9); + let mut r = provider.open_reader("only.bin").unwrap(); + let mut out = Vec::new(); + r.read_to_end(&mut out).unwrap(); + // Cursor was restored to position 2, so we should read from there. + assert_eq!(out, b"3456789"); + } +} diff --git a/diskann-record/Cargo.toml b/diskann-record/Cargo.toml new file mode 100644 index 000000000..f9c9efe72 --- /dev/null +++ b/diskann-record/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "diskann-record" +version.workspace = true +description.workspace = true +authors.workspace = true +documentation.workspace = true +license.workspace = true +edition = "2024" + +[dependencies] +anyhow.workspace = true +serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true + +[dev-dependencies] +tempfile.workspace = true + +[lints] +workspace = true diff --git a/diskann-record/HANDOFF.md b/diskann-record/HANDOFF.md new file mode 100644 index 000000000..cf5f12f1b --- /dev/null +++ b/diskann-record/HANDOFF.md @@ -0,0 +1,212 @@ +# diskann-record: Handoff Document + +## Status + +**Working proof-of-concept.** The crate implements a full save→JSON→load round-trip for +nested structs with binary file artifacts. The end-to-end test in `src/lib.rs` passes. +Many parts are intentionally rough — this is a design exploration, not production code. + +See `rfc.md` at the repo root for the motivating RFC. 
+ +## Architecture Overview + +### Two-Trait Pattern (Save and Load) + +Each side has two traits: + +- **`Save`** / **`Load`** — versioned component traits. Users implement these for their + structs. Each carries a `const VERSION: Version`. +- **`Saveable`** / **`Loadable`** — universal dispatch traits used in generic bounds. + Primitives implement these directly. `Save`/`Load` types get them via blanket impls that + handle version tagging/dispatch automatically. + +This split exists because primitives (numbers, strings) shouldn't carry versions, but +versioned structs and primitives must be storable through the same `Value` tree. + +### Value Model + +`Value<'a>` is a bespoke enum decoupled from `serde_json`: + +```rust +pub enum Value<'a> { + Bool(bool), + Number(Number), // Copy enum: U64, I64, F64 + String(Cow<'a, str>), + Bytes(Cow<'a, [u8]>), + Array(Vec>), + Object(Versioned<'a>), + Handle(Handle), +} +``` + +The lifetime `'a` allows the save path to borrow from the data being saved (zero-copy +strings, etc.). The load path deserializes to `Value<'static>` with owned `Cow`s since +the JSON manifest is kilobytes and zero-copy deserialization isn't worth the ergonomic cost. + +Custom `Serialize`/`Deserialize` impls ensure `Number` serializes as a plain JSON number, +`Handle` as `{"$handle": "name"}`, and `Versioned` as a flat object with `$version` +alongside user fields. + +### Context Objects + +**Save side:** `Context<'a>` wraps `&'a ContextInner` (shared reference, `Clone`). This +enables parallel saves via `rayon::join`. `ContextInner` owns the output directory path and +a `Mutex>` for file name deduplication. `Context::write(name)` returns a +`Writer` (wrapping `BufWriter`), and `Writer::finish() -> Handle` provides a proof +token that the file was fully written. + +**Load side:** `Context<'a>` wraps `&'a ContextInner` + `&'a Value<'a>`. `Object<'a>` +provides typed field access via `field::(key)`. `Array<'a>` provides iteration. 
+Everything is by-reference into the deserialized `Value` tree ("parse once, probe many"). + +### Error Handling + +**Save errors:** `save::Error` wraps `anyhow::Error`. Simple — something went wrong. + +**Load errors:** `load::Error` has a light/heavy split: +- `Light(Kind)` — cheap, `Copy` enum (`TypeMismatch`, `MissingField`, etc.) for speculative + probing (try one version, fall back to another without allocating error context). +- `Heavy(anyhow::Error)` — rich diagnostics with context chaining for actual failures. + +### Macros + +- `save_fields!(self, context, [x, y, inner])` — saves named fields, wraps errors with + field name context. Inside an enum match arm (where the variant's payload has been + destructured), drop the first argument: `save_fields!(context, [weights])` reads from + the local bindings. +- `load_fields!(object, [x, y, inner, vector: Handle])` — loads named fields with optional + type annotations when inference fails. Works identically for structs and enum variants. + +### Reserved Keys + +Keys starting with `$` are reserved for infrastructure (`$version`, `$variant`, +`$handle`). Since Rust identifiers can't start with `$`, the `save_fields!` macro is +inherently safe — no runtime check needed in the macro path. 
+ +## What Works + +- Full round-trip: `save_to_disk(&t, dir, metadata)` → JSON + binary files → + `load_from_disk(metadata, dir) -> T` +- Nested structs with recursive save/load +- Binary file artifacts (write bytes, get Handle, store in manifest, read back on load) +- Version tagging in manifest (`$version` alongside fields) +- Custom serde for `Value`, `Number`, `Handle` (plain JSON output) +- Primitive `bool` support via `Value::Bool`, including nested `Vec` round-trips +- `Option` support via an explicit `Value::Null` variant +- Tempfile isolation in tests (`tempfile::tempdir()`) to ensure cleanup of artifacts and manifest +- Light/heavy error split on load path +- `Writer::finish()` consumes the inner `BufWriter` and propagates buffered write/flush errors via `save::Result` +- Compile-time `const` assertion in `src/lib.rs` that rejects targets where `usize::BITS != 64`. +- Propagate light errors for duplicate filenames/creation failure, manifest finish, attempting to write to reserved keys, missing files on `load`, out-of-range values for numerics (light errors) +- Enum support via internally-tagged objects (`$variant` alongside `$version`). Save side + exposes `Save::variant() -> Option>` (default `None` = struct), and + `save_fields!` has a two-argument form (`save_fields!(context, [...])`) for use inside + enum match arms after destructuring. Load side adds `Load::IS_ENUM` (default `false`) + plus `Object::variant()`. The blanket `Loadable` impl strictly enforces tag presence: + loading a tagged record as a struct yields `UnexpectedVariant`, loading an untagged + record as an enum yields `MissingVariant`. + +## Remaining Work + +### Value::Bytes Wiring + +`Value::Bytes` exists in the enum and serializes, but there's no end-to-end pattern for +using it (no `Saveable`/`Loadable` impls for `&[u8]` / `Vec` as inline bytes). Should +follow the `Handle` pattern — schema-aware (the Rust type tells the loader it's bytes, not +a JSON array of integers). 
+ +### Enum Support + +Unit and struct variants are supported via the internally-tagged representation +described above. Tuple variants are *not* directly supported: rename the payload +field(s) to a struct variant, or bind to a local before constructing the record. +Multi-field tuple variants are intentionally out of scope — name your fields. + +Open follow-ups: derive macro support for the `variant()` / `IS_ENUM` boilerplate, +and a clearer error for the `UnknownVariant` case (currently a light error with the +string `"unknown variant"` and no embedded name). + +### SemVer Version Dispatch + +`load_legacy()` currently receives everything that isn't an exact version match. There's no +actual SemVer compatibility logic (e.g., "minor bumps are backward compatible within the +same major version"). The blanket `Loadable` impl should route based on SemVer rules, not +just equality. + +### File Name Disambiguation + +`save::ContextInner` currently uses raw file names as-is. The RFC envisions UUID-based +naming to prevent collisions (e.g., `{uuid}-{user_name}.bin`). The `uuid` crate is not yet +in the dependencies. + +### ContextInner Generalization + +Currently the save and load `ContextInner` types are concrete (directory-backed). Future +work includes: +- Trait object behind `Context` for backend swappability +- Packed single-file backend (sequential artifact writes, offset table) +- VFS support (`vfs` crate already in workspace) +- `write_sized(name, size)` API for pre-allocated regions in packed format + +The current design has clean seams for this — `Writer::finish() -> Handle` and +`ContextInner::finish(Value)` are the extension points. + +### Specify runtimes state +`DiskANNIndex::save` does not save `scratch_pool` because it is part of runtime state. At load time, we're skipping this field. We need to add an optional `Aux` to the `load` interface to allow for loading code to specify this runtime state params? 
+ +### Manifest Improvements + +- Record file sizes in the manifest for integrity checking / pre-allocation on load +- Consider preserving field ordering (`HashMap` loses insertion order — + `IndexMap` or `Vec<(K, V)>` would preserve it for human-readable JSON) + +### Derive Macros + +The RFC envisions `#[derive(Save)]` and `#[derive(Load)]` to eliminate boilerplate. The +current macro-rules (`save_fields!`, `load_fields!`) are a stopgap. A proc macro crate +(`diskann-record-derive`) would be the next step. + +## File Map + +``` +diskann-record/ +├── Cargo.toml # deps: serde, serde_json, anyhow +├── HANDOFF.md # this document +└── src/ + ├── lib.rs # module structure, is_reserved(), round-trip test + ├── number.rs # Number enum with custom Serialize/Deserialize + ├── version.rs # Version { major, minor, patch } + ├── save/ + │ ├── mod.rs # Save/Saveable traits, blanket impl, save_fields! macro, + │ │ # primitive impls, save_to_disk entry point + │ ├── value.rs # Value, Record, Versioned, Handle + serde impls + │ ├── context.rs # ContextInner (dir-backed), Context, Writer + │ └── error.rs # save::Error (anyhow newtype) + └── load/ + ├── mod.rs # Load/Loadable traits, blanket impl, load_fields! macro, + │ # primitive impls, load_from_disk entry point + ├── context.rs # ContextInner, Context, Object, Array, Iter, Reader + └── error.rs # load::Error (light/heavy split), Kind enum +``` + +## Key Design Rationale + +Decisions that may not be obvious from the code alone: + +1. **Why two traits per side?** Primitives shouldn't carry versions, but versioned structs + and primitives must coexist in the same `Value` tree. The blanket impl bridges them. + +2. **Why `Cow` in `Value`?** Save path borrows from structs (`Cow::Borrowed`). Load path + owns from JSON (`Cow::Owned`). Same type, different usage patterns. + +3. **Why light/heavy errors on load?** Loading tries the current version first, falls back + to legacy. 
The first attempt's failure should be near-free (just a `Kind` enum) since + it's expected to fail for older data. Only the final failure needs rich diagnostics. + +4. **Why eager tree building?** The manifest is metadata (kilobytes). Lazy/deferred + serialization adds lifetime complexity for no practical gain. Artifacts (gigabytes) + stream directly to files — they're never in the tree. + +5. **Why separate save/load error types?** The light/heavy split only makes sense for + loading (speculative probing). Save errors are always "something went wrong." Unifying + them would force the save side to carry unused machinery. diff --git a/diskann-record/sample_output/graph.bin b/diskann-record/sample_output/graph.bin new file mode 100644 index 000000000..01499e60b Binary files /dev/null and b/diskann-record/sample_output/graph.bin differ diff --git a/diskann-record/sample_output/manifest.json b/diskann-record/sample_output/manifest.json new file mode 100644 index 000000000..aa50b3308 --- /dev/null +++ b/diskann-record/sample_output/manifest.json @@ -0,0 +1,120 @@ +{ + "files": [ + "graph.bin", + "vectors.bin" + ], + "value": { + "config": { + "pruned_degree": 16, + "prune_kind": { + "$version": { + "major": 0, + "minor": 0, + "patch": 0 + }, + "$variant": "TriangleInequality" + }, + "intra_batch_candidates": { + "$version": { + "major": 0, + "minor": 0, + "patch": 0 + }, + "$variant": "All" + }, + "max_occlusion_size": 750, + "alpha": 1.2000000476837158, + "max_minibatch_par": 1, + "l_build": 20, + "max_degree": 20, + "max_backedges": 16, + "saturate_after_prune": false, + "experimental_insert_retry": null, + "$version": { + "major": 0, + "minor": 0, + "patch": 0 + } + }, + "data_provider": { + "neighbor_provider": { + "graph": { + "$handle": "graph.bin" + }, + "$version": { + "major": 0, + "minor": 0, + "patch": 0 + } + }, + "metric": { + "$version": { + "major": 0, + "minor": 0, + "patch": 0 + }, + "$variant": "L2" + }, + "aux_vectors": { + "$version": { + "major": 0, 
+ "minor": 0, + "patch": 0 + } + }, + "start_points": { + "end": 33, + "start": 32, + "$version": { + "major": 0, + "minor": 0, + "patch": 0 + } + }, + "deleted": { + "$version": { + "major": 0, + "minor": 0, + "patch": 0 + } + }, + "base_vectors": { + "metric": { + "$version": { + "major": 0, + "minor": 0, + "patch": 0 + }, + "$variant": "L2" + }, + "prefetch_lookahead": 8, + "prefetch_cache_line_level": { + "$version": { + "major": 0, + "minor": 0, + "patch": 0 + }, + "$variant": "CacheLine16" + }, + "vectors": { + "$handle": "vectors.bin" + }, + "$version": { + "major": 0, + "minor": 0, + "patch": 0 + } + }, + "$version": { + "major": 0, + "minor": 0, + "patch": 0 + } + }, + "$version": { + "major": 0, + "minor": 0, + "patch": 0 + } + } +} \ No newline at end of file diff --git a/diskann-record/sample_output/vectors.bin b/diskann-record/sample_output/vectors.bin new file mode 100644 index 000000000..c06da6eb7 Binary files /dev/null and b/diskann-record/sample_output/vectors.bin differ diff --git a/diskann-record/src/lib.rs b/diskann-record/src/lib.rs new file mode 100644 index 000000000..ae8f956f1 --- /dev/null +++ b/diskann-record/src/lib.rs @@ -0,0 +1,388 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +mod number; +pub use number::Number; + +mod version; +pub use version::Version; + +pub mod load; +pub mod save; + +// Canonical wire width for `usize` and `isize` in manifests is 64 bits. Saving a value +// on a 64-bit platform and loading it on a 32-bit platform (or vice versa) could silently +// truncate values that exceed `u32::MAX` / `i32::MAX`. We therefore require a 64-bit +// platform at compile time. Loaders still range-check at runtime, but this check ensures +// the saver never emits values that the canonical width cannot represent. 
+const _: () = assert!( + usize::BITS == 64, + "diskann-record requires a 64-bit target: usize/isize MUST be 64 bits wide !!", +); + +/// Return `true` if `s` is a reserved string for purposes of saving and loading. +#[doc(hidden)] +pub const fn is_reserved(s: &str) -> bool { + if let Some(first) = s.as_bytes().first() + && *first == b"$"[0] + { + true + } else { + false + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::{Read, Write}; + + #[derive(Debug, PartialEq)] + struct Test { + x: String, + y: f32, + enabled: bool, + inner: Inner, + // We write this as a binary file. + vector: Vec, + nickname: Option, + absent: Option, + } + + #[derive(Debug, PartialEq)] + struct Inner { + z: usize, + w: Vec, + flags: Vec, + maybe_count: Option, + maybe_missing: Option, + sparse: Vec>, + } + + impl save::Save for Inner { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!( + self, + context, + [z, w, flags, maybe_count, maybe_missing, sparse] + )) + } + } + + impl save::Save for Test { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + // We save `x`, `y`, and `inner` directly into the manifest. + // The raw vector data we instead store in an auxiliary file. 
+ + let mut io = context.write("auxiliary.bin")?; + io.write_all(&self.vector).map_err(save::Error::new)?; + + let mut record = save_fields!(self, context, [x, y, enabled, inner, nickname, absent]); + record.insert("vector", io.finish()?)?; + Ok(record) + } + } + + impl load::Load<'_> for Test { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!( + object, + [ + x, + y, + enabled, + inner, + nickname: Option, + absent: Option, + vector: save::Handle, + ] + ); + + let mut io = object.read(&vector)?; + let mut vector = Vec::new(); + io.read_to_end(&mut vector).unwrap(); + + Ok(Self { + x, + y, + enabled, + inner, + vector, + nickname, + absent, + }) + } + + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + impl load::Load<'_> for Inner { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!( + object, + [ + z, + w, + flags, + maybe_count: Option, + maybe_missing: Option, + sparse: Vec>, + ] + ); + Ok(Self { + z, + w, + flags, + maybe_count, + maybe_missing, + sparse, + }) + } + + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[test] + fn round_trip_uses_isolated_temp_dir() -> anyhow::Result<()> { + let inner = Inner { + z: 10, + w: vec![-1, -2, -3], + flags: vec![true, false, true], + maybe_count: Some(42), + maybe_missing: None, + sparse: vec![Some(1), None, Some(-3), None], + }; + + let t = Test { + x: "hello".into(), + y: 5.0, + enabled: true, + inner, + vector: vec![0, 1, 2, 3, 4, 5], + nickname: Some("friend".into()), + absent: None, + }; + + // Keep the TempDir guard alive for the full round trip; Drop removes the + // manifest and auxiliary artifact after the assertion completes. 
+ let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&t, dir, &metadata)?; + let we_are_back: Test = load::load_from_disk(&metadata, dir)?; + + assert_eq!(t, we_are_back); + Ok(()) + } + + ///////////////////////// + // Enum support: round // + ///////////////////////// + + #[derive(Debug, PartialEq)] + enum Metric { + L2, + Cosine, + Weighted { weights: Vec }, + } + + impl save::Save for Metric { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(match self { + Self::L2 | Self::Cosine => save::Record::empty(), + Self::Weighted { weights } => save_fields!(context, [weights]), + }) + } + fn variant(&self) -> Option> { + Some(match self { + Self::L2 => "L2".into(), + Self::Cosine => "Cosine".into(), + Self::Weighted { .. } => "Weighted".into(), + }) + } + } + + impl load::Load<'_> for Metric { + const VERSION: Version = Version::new(0, 0, 0); + const IS_ENUM: bool = true; + fn load(object: load::Object<'_>) -> load::Result { + let variant = object.variant().ok_or(load::error::Kind::MissingVariant)?; + match variant { + "L2" => Ok(Self::L2), + "Cosine" => Ok(Self::Cosine), + "Weighted" => { + load_fields!(object, [weights: Vec]); + Ok(Self::Weighted { weights }) + } + _ => Err(load::error::Kind::UnknownVariant.into()), + } + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + Err(load::error::Kind::UnknownVersion.into()) + } + } + + #[derive(Debug, PartialEq)] + struct MetricBag { + primary: Metric, + alternatives: Vec, + } + + impl save::Save for MetricBag { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!(self, context, [primary, alternatives])) + } + } + + impl load::Load<'_> for MetricBag { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!(object, 
[primary: Metric, alternatives: Vec]); + Ok(Self { + primary, + alternatives, + }) + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[test] + fn enum_round_trip_through_disk() -> anyhow::Result<()> { + let bag = MetricBag { + primary: Metric::Weighted { + weights: vec![0.25, 0.5, 0.25], + }, + alternatives: vec![ + Metric::L2, + Metric::Cosine, + Metric::Weighted { weights: vec![1.0] }, + ], + }; + + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&bag, dir, &metadata)?; + let restored: MetricBag = load::load_from_disk(&metadata, dir)?; + + assert_eq!(bag, restored); + Ok(()) + } + + #[derive(Debug, PartialEq)] + struct StructShape { + x: i32, + } + + impl save::Save for StructShape { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(save_fields!(self, context, [x])) + } + } + + impl load::Load<'_> for StructShape { + const VERSION: Version = Version::new(0, 0, 0); + fn load(object: load::Object<'_>) -> load::Result { + load_fields!(object, [x: i32]); + Ok(Self { x }) + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[derive(Debug, PartialEq)] + enum EnumShape { + Only { x: i32 }, + } + + impl save::Save for EnumShape { + const VERSION: Version = Version::new(0, 0, 0); + fn save(&self, context: save::Context<'_>) -> save::Result> { + Ok(match self { + Self::Only { x } => save_fields!(context, [x]), + }) + } + fn variant(&self) -> Option> { + Some(match self { + Self::Only { .. 
} => "Only".into(), + }) + } + } + + impl load::Load<'_> for EnumShape { + const VERSION: Version = Version::new(0, 0, 0); + const IS_ENUM: bool = true; + fn load(object: load::Object<'_>) -> load::Result { + let variant = object.variant().ok_or(load::error::Kind::MissingVariant)?; + match variant { + "Only" => { + load_fields!(object, [x: i32]); + Ok(Self::Only { x }) + } + _ => Err(load::error::Kind::UnknownVariant.into()), + } + } + fn load_legacy(_object: load::Object<'_>) -> load::Result { + panic!("nope!"); + } + } + + #[test] + fn loading_enum_as_struct_is_rejected() -> anyhow::Result<()> { + let value = EnumShape::Only { x: 7 }; + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&value, dir, &metadata)?; + let err = load::load_from_disk::(&metadata, dir) + .expect_err("loading a tagged record as a struct should fail"); + let msg = format!("{err}"); + assert!( + msg.contains("unexpected variant"), + "expected UnexpectedVariant error, got: {msg}" + ); + Ok(()) + } + + #[test] + fn loading_struct_as_enum_is_rejected() -> anyhow::Result<()> { + let value = StructShape { x: 7 }; + let temp_dir = tempfile::tempdir()?; + let dir = temp_dir.path(); + let metadata = dir.join("metadata.json"); + + save::save_to_disk(&value, dir, &metadata)?; + let err = load::load_from_disk::(&metadata, dir) + .expect_err("loading an untagged record as an enum should fail"); + let msg = format!("{err}"); + assert!( + msg.contains("missing variant"), + "expected MissingVariant error, got: {msg}" + ); + Ok(()) + } +} diff --git a/diskann-record/src/load/context.rs b/diskann-record/src/load/context.rs new file mode 100644 index 000000000..8f7460000 --- /dev/null +++ b/diskann-record/src/load/context.rs @@ -0,0 +1,288 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use std::{ + collections::HashSet, + fs::File, + io::BufReader, + path::{Path, PathBuf}, +}; + +use crate::{ + Number, Version, + load::{Error, Loadable, Result, error}, + save, +}; + +#[derive(Debug, serde::Deserialize)] +pub(super) struct ContextInner { + dir: PathBuf, + files: HashSet, + value: save::Value<'static>, +} + +#[derive(Debug, serde::Deserialize)] +struct FileRepr { + files: HashSet, + value: save::Value<'static>, +} + +impl ContextInner { + pub(super) fn new(metadata: &Path, dir: &Path) -> Result { + let file = std::fs::File::open(metadata).map_err(|e| { + Error::new(e).context(format!("while trying to open {}", metadata.display())) + })?; + + let reader = std::io::BufReader::new(file); + let repr: FileRepr = serde_json::from_reader(reader) + .map_err(|e| Error::new(e).context("could not deserialize manifest"))?; + + let this = Self { + dir: dir.into(), + files: repr.files, + value: repr.value, + }; + Ok(this) + } + + pub(super) fn read(&self, key: &str) -> Result> { + let key_as_path: &Path = key.as_ref(); + if !self.files.contains(key_as_path) { + return Err(Error::from(error::Kind::MissingFile).context(format!( + "handle references file {:?} which is not registered in the manifest", + key, + ))); + } + + let full = self.dir.join(key); + let file = std::fs::File::open(&full).map_err(|err| { + Error::new(err).context(format!("while opening artifact file {}", full.display())) + })?; + let reader = Reader { + io: BufReader::new(file), + _lifetime: std::marker::PhantomData, + }; + + Ok(reader) + } + + pub(super) fn context(&self) -> Context<'_> { + Context::new(self, &self.value) + } +} + +pub struct Reader<'a> { + io: BufReader, + _lifetime: std::marker::PhantomData<&'a ()>, +} + +impl std::io::Read for Reader<'_> { + // Required method + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.io.read(buf) + } + + // Provided methods + fn read_vectored(&mut self, bufs: &mut [std::io::IoSliceMut<'_>]) -> std::io::Result { + 
self.io.read_vectored(bufs) + } + fn read_to_end(&mut self, buf: &mut Vec) -> std::io::Result { + self.io.read_to_end(buf) + } + fn read_to_string(&mut self, buf: &mut String) -> std::io::Result { + self.io.read_to_string(buf) + } + fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> { + self.io.read_exact(buf) + } +} + +impl std::io::Seek for Reader<'_> { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { + self.io.seek(pos) + } + + fn rewind(&mut self) -> std::io::Result<()> { + self.io.rewind() + } + fn stream_position(&mut self) -> std::io::Result { + self.io.stream_position() + } + fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> { + self.io.seek_relative(offset) + } +} + +/////////////////////// +// User facing types // +/////////////////////// + +#[derive(Debug, Clone)] +pub struct Context<'a> { + inner: &'a ContextInner, + value: &'a save::Value<'a>, +} + +impl<'a> Context<'a> { + fn new(inner: &'a ContextInner, value: &'a save::Value<'a>) -> Self { + Self { inner, value } + } + + fn context(&self) -> &'a ContextInner { + self.inner + } + + pub fn load(&self) -> Result + where + T: Loadable<'a>, + { + T::load(self.clone()) + } + + pub fn as_object(&self) -> Option> { + match self.value { + save::Value::Object(versioned) => { + let object = Object { + inner: self.inner, + record: versioned.record(), + version: versioned.version(), + variant: versioned.variant(), + }; + Some(object) + } + _ => None, + } + } + + pub fn as_str(&self) -> Option<&'a str> { + match self.value { + save::Value::String(s) => Some(s), + _ => None, + } + } + + pub fn as_array(&self) -> Option> { + match self.value { + save::Value::Array(array) => Some(Array::new(self.context(), array)), + _ => None, + } + } + + pub fn as_number(&self) -> Option { + match self.value { + save::Value::Number(number) => Some(*number), + _ => None, + } + } + + pub fn as_bool(&self) -> Option { + match self.value { + save::Value::Bool(value) => Some(*value), + _ => 
None, + } + } + + pub fn is_null(&self) -> bool { + matches!(self.value, save::Value::Null) + } + + pub(crate) fn as_handle(&self) -> Option<&save::Handle> { + match self.value { + save::Value::Handle(handle) => Some(handle), + _ => None, + } + } +} + +#[derive(Debug)] +pub struct Object<'a> { + inner: &'a ContextInner, + record: &'a save::Record<'a>, + version: Version, + variant: Option<&'a str>, +} + +impl<'a> Object<'a> { + pub fn version(&self) -> Version { + self.version + } + + /// Return the variant tag for enum-shaped records, or `None` for structs. + pub fn variant(&self) -> Option<&'a str> { + self.variant + } + + pub fn field(&self, key: &str) -> Result + where + T: Loadable<'a>, + { + match self.record.get(key) { + Some(value) => T::load(Context::new(self.context(), value)), + None => Err((error::Kind::MissingField).into()), + } + } + + pub fn read(&self, handle: &save::Handle) -> Result> { + self.inner.read(handle.as_str()) + } + + fn context(&self) -> &'a ContextInner { + self.inner + } +} + +#[derive(Debug)] +pub struct Array<'a> { + inner: &'a ContextInner, + array: &'a [save::Value<'a>], +} + +impl<'a> Array<'a> { + fn new(inner: &'a ContextInner, array: &'a [save::Value<'a>]) -> Self { + Self { inner, array } + } + + pub fn len(&self) -> usize { + self.array.len() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn iter(&self) -> Iter<'a> { + Iter::new(self.context(), self.array.iter()) + } + + fn context(&self) -> &'a ContextInner { + self.inner + } +} + +pub struct Iter<'a> { + inner: &'a ContextInner, + iter: std::slice::Iter<'a, save::Value<'a>>, +} + +impl<'a> Iter<'a> { + fn new(inner: &'a ContextInner, iter: std::slice::Iter<'a, save::Value<'a>>) -> Self { + Self { inner, iter } + } +} + +impl<'a> Iterator for Iter<'a> { + type Item = Context<'a>; + fn next(&mut self) -> Option { + self.iter + .next() + .map(|value| Context::new(self.inner, value)) + } + + fn size_hint(&self) -> (usize, Option) { + 
self.iter.size_hint() + } +} + +impl ExactSizeIterator for Iter<'_> {} diff --git a/diskann-record/src/load/error.rs b/diskann-record/src/load/error.rs new file mode 100644 index 000000000..15b8cd60e --- /dev/null +++ b/diskann-record/src/load/error.rs @@ -0,0 +1,117 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::fmt::{Debug, Display}; + +pub type Result = ::std::result::Result; + +#[derive(Debug)] +pub struct Error { + inner: ErrorInner, +} + +impl Error { + pub fn new(err: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Error { + inner: ErrorInner::Heavy(anyhow::Error::new(err)), + } + } + + pub fn message(message: D) -> Self + where + D: Display + Debug + Send + Sync + 'static, + { + Error { + inner: ErrorInner::Heavy(anyhow::Error::msg(message)), + } + } + + pub fn context(self, message: D) -> Self + where + D: Display + Send + Sync + 'static, + { + // TODO: Should we do something clever with "light" errors to avoid context + // proliferation? + match self.inner { + ErrorInner::Light(kind) => Self { + inner: ErrorInner::Light(kind), + }, + ErrorInner::Heavy(kind) => Self { + inner: ErrorInner::Heavy(kind.context(message)), + }, + } + } +} + +#[derive(Debug)] +enum ErrorInner { + Light(Kind), + Heavy(anyhow::Error), +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.inner { + ErrorInner::Light(kind) => write!(f, "Load Error: {}", kind), + ErrorInner::Heavy(error) => write!(f, "Load Error: {:?}", error), + } + } +} + +impl std::error::Error for Error {} + +#[derive(Debug, Clone, Copy)] +#[non_exhaustive] +pub enum Kind { + VersionMismatch, + MissingField, + TypeMismatch, + UnknownVersion, + /// Wire format has a `$variant` tag but the target type is not an enum. + UnexpectedVariant, + /// Target type is an enum but the wire format has no `$variant` tag. 
+ MissingVariant, + /// The wire format's `$variant` tag does not match any known variant of the + /// target enum. + UnknownVariant, + /// A numeric value in the manifest does not fit in the requested Rust type + /// (either out of range or would lose precision). + NumberOutOfRange, + /// A `$handle` references a file name that is not registered in the + /// manifest's `files` set. + MissingFile, +} + +impl Kind { + pub const fn as_str(self) -> &'static str { + match self { + Self::VersionMismatch => "version mismatch", + Self::MissingField => "missing field", + Self::TypeMismatch => "type mismatch", + Self::UnknownVersion => "unknown version", + Self::UnexpectedVariant => "unexpected variant tag on non-enum record", + Self::MissingVariant => "missing variant tag on enum record", + Self::UnknownVariant => "unknown variant", + Self::NumberOutOfRange => "number out of range for target type", + Self::MissingFile => "handle references a file not present in the manifest", + } + } +} + +impl std::fmt::Display for Kind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +impl From for Error { + fn from(kind: Kind) -> Self { + let inner = ErrorInner::Light(kind); + Self { inner } + } +} diff --git a/diskann-record/src/load/mod.rs b/diskann-record/src/load/mod.rs new file mode 100644 index 000000000..8aef25145 --- /dev/null +++ b/diskann-record/src/load/mod.rs @@ -0,0 +1,171 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +pub mod error; +pub use error::{Error, Result}; + +mod context; +pub use context::{Context, Object}; + +use std::path::Path; + +use crate::{Version, save}; + +pub fn load_from_disk(metadata: &Path, dir: &Path) -> Result +where + T: for<'a> Loadable<'a>, +{ + let inner = context::ContextInner::new(metadata, dir)?; + inner.context().load() +} + +pub trait Load<'a>: Sized { + const VERSION: Version; + /// Set to `true` for enum types. 
The framework checks at load time that the + /// wire format's `$variant` tag presence matches this constant and rejects + /// mismatches (an enum loader cannot consume an untagged record and a + /// struct loader cannot consume a tagged record). + const IS_ENUM: bool = false; + fn load(object: Object<'a>) -> Result; + fn load_legacy(object: Object<'a>) -> Result; +} + +pub trait Loadable<'a>: Sized { + fn load(context: Context<'a>) -> Result; +} + +impl<'a, T> Loadable<'a> for T +where + T: Load<'a>, +{ + fn load(context: Context<'a>) -> Result { + let object = context.as_object().ok_or(error::Kind::TypeMismatch)?; + match (T::IS_ENUM, object.variant().is_some()) { + (false, true) => return Err(error::Kind::UnexpectedVariant.into()), + (true, false) => return Err(error::Kind::MissingVariant.into()), + _ => {} + } + let version = object.version(); + if version == T::VERSION { + T::load(object) + } else { + T::load_legacy(object) + } + } +} + +//////////// +// Macros // +//////////// + +#[macro_export] +macro_rules! 
load_fields { + (@field $object:ident, $field:ident: $T:ty) => { + let $field: $T = $object.field(stringify!($field))?; + }; + (@field $object:ident, $field:ident) => { + let $field = $object.field(stringify!($field))?; + }; + ($object:ident, [$($field:ident $(: $ty:ty)?),+ $(,)?]) => { + $( + $crate::load_fields!(@field $object, $field $(: $ty)?); + )+ + }; +} + +/////////////// +// Bootstrap // +/////////////// + +impl<'a> Loadable<'a> for &'a str { + fn load(context: Context<'a>) -> Result { + context + .as_str() + .ok_or_else(|| error::Kind::TypeMismatch.into()) + } +} + +impl Loadable<'_> for String { + fn load(context: Context<'_>) -> Result { + context.load::<&str>().map(|s| s.into()) + } +} + +impl Loadable<'_> for save::Handle { + fn load(context: Context<'_>) -> Result { + context + .as_handle() + .cloned() + .ok_or_else(|| error::Kind::TypeMismatch.into()) + } +} + +impl Loadable<'_> for bool { + fn load(context: Context<'_>) -> Result { + context + .as_bool() + .ok_or_else(|| error::Kind::TypeMismatch.into()) + } +} + +impl<'a, T> Loadable<'a> for Option +where + T: Loadable<'a>, +{ + fn load(context: Context<'a>) -> Result { + if context.is_null() { + Ok(None) + } else { + T::load(context).map(Some) + } + } +} + +impl<'a, T> Loadable<'a> for Vec +where + T: Loadable<'a>, +{ + fn load(context: Context<'a>) -> Result { + match context.as_array() { + Some(array) => array.iter().map(T::load).collect(), + None => Err((error::Kind::TypeMismatch).into()), + } + } +} + +macro_rules! load_number { + ($T:ty) => { + impl Loadable<'_> for $T { + fn load(context: Context<'_>) -> Result { + match context.as_number() { + Some(n) => n.try_into().map_err(|_| error::Kind::NumberOutOfRange.into()), + None => Err((error::Kind::TypeMismatch).into()), + } + } + } + }; + ($($Ts:ty),+ $(,)?) 
=> { + $(load_number!($Ts);)+ + } +} + +load_number!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, f32, f64); + +// NonZero* primitives are loaded by deserializing the inner numeric type and then +// validating it is non-zero. A zero value produces a `NumberOutOfRange` light error. +macro_rules! load_nonzero { + ($T:ty, $Inner:ty) => { + impl Loadable<'_> for $T { + fn load(context: Context<'_>) -> Result { + let inner: $Inner = context.load()?; + <$T>::new(inner).ok_or_else(|| error::Kind::NumberOutOfRange.into()) + } + } + }; +} + +load_nonzero!(std::num::NonZeroU32, u32); +load_nonzero!(std::num::NonZeroU64, u64); +load_nonzero!(std::num::NonZeroUsize, usize); diff --git a/diskann-record/src/number.rs b/diskann-record/src/number.rs new file mode 100644 index 000000000..163cbf1e5 --- /dev/null +++ b/diskann-record/src/number.rs @@ -0,0 +1,163 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +#[derive(Debug, Clone, Copy)] +pub enum Number { + U64(u64), + I64(i64), + F64(f64), +} + +impl Serialize for Number { + fn serialize(&self, serializer: S) -> Result { + match *self { + Self::U64(v) => serializer.serialize_u64(v), + Self::I64(v) => serializer.serialize_i64(v), + Self::F64(v) => serializer.serialize_f64(v), + } + } +} + +impl<'de> Deserialize<'de> for Number { + fn deserialize>(deserializer: D) -> Result { + struct NumberVisitor; + + impl<'de> Visitor<'de> for NumberVisitor { + type Value = Number; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("a number") + } + + fn visit_u64(self, v: u64) -> Result { + Ok(Number::U64(v)) + } + + fn visit_i64(self, v: i64) -> Result { + Ok(Number::I64(v)) + } + + fn visit_f64(self, v: f64) -> Result { + Ok(Number::F64(v)) + } + } + + deserializer.deserialize_any(NumberVisitor) + } +} + +macro_rules! 
try_cast { + ($v:ident :$T:ty => $U:ty) => {{ + let c = $v as $U; + if c as $T == $v { Some(c) } else { None } + }}; +} + +macro_rules! int { + ($f:ident, $T:ty) => { + pub fn $f(self) -> Option<$T> { + match self { + Self::U64(v) => v.try_into().ok(), + Self::I64(v) => v.try_into().ok(), + Self::F64(v) => try_cast!(v:f64 => $T), + } + } + } +} + +macro_rules! float { + ($f:ident, $T:ty) => { + pub fn $f(self) -> Option<$T> { + match self { + Self::U64(v) => try_cast!(v:u64 => $T), + Self::I64(v) => try_cast!(v:i64 => $T), + Self::F64(v) => try_cast!(v:f64 => $T), + } + } + } +} + +impl Number { + int!(as_u8, u8); + int!(as_u16, u16); + int!(as_u32, u32); + int!(as_u64, u64); + int!(as_usize, usize); + + int!(as_i8, i8); + int!(as_i16, i16); + int!(as_i32, i32); + int!(as_i64, i64); + int!(as_isize, isize); + + float!(as_f32, f32); + float!(as_f64, f64); +} + +macro_rules! from { + ($T:ty => $variant:ident) => { + impl From<$T> for Number { + fn from(v: $T) -> Self { + Self::$variant(v.into()) + } + } + }; + ($($T:ty => $variant:ident),+ $(,)?) => { + $(from!($T => $variant);)+ + } +} + +from!( + u64 => U64, + u32 => U64, + u16 => U64, + u8 => U64, + i64 => I64, + i32 => I64, + i16 => I64, + i8 => I64, + f32 => F64, + f64 => F64, +); + +impl From for Number { + fn from(v: usize) -> Self { + Self::U64(v.try_into().unwrap()) + } +} + +macro_rules! try_from { + ($T:ty => $f:ident) => { + impl TryFrom for $T { + type Error = (); + fn try_from(number: Number) -> Result<$T, Self::Error> { + number.$f().ok_or(()) + } + } + }; + ($($T:ty => $f:ident),+ $(,)?) 
=> { + $(try_from!($T => $f);)+ + } +} + +try_from!( + u64 => as_u64, + u32 => as_u32, + u16 => as_u16, + u8 => as_u8, + usize => as_usize, + + i64 => as_i64, + i32 => as_i32, + i16 => as_i16, + i8 => as_i8, + isize => as_isize, + + f32 => as_f32, + f64 => as_f64, +); diff --git a/diskann-record/src/save/context.rs b/diskann-record/src/save/context.rs new file mode 100644 index 000000000..d748daa82 --- /dev/null +++ b/diskann-record/src/save/context.rs @@ -0,0 +1,155 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{collections::HashSet, fs::File, io::BufWriter, path::PathBuf, sync::Mutex}; + +use crate::save::{Error, Handle, Result, Value}; + +#[derive(Debug)] +pub(super) struct ContextInner { + dir: PathBuf, + metadata: PathBuf, + files: Mutex>, +} + +#[derive(serde::Serialize)] +struct Final<'a> { + files: Vec<&'a str>, + value: &'a Value<'a>, +} + +impl ContextInner { + // TODO: Error if the directory looks bad? + pub(super) fn new(dir: PathBuf, metadata: PathBuf) -> Self { + Self { + dir, + metadata, + files: Mutex::new(HashSet::new()), + } + } + + pub(super) fn context(&self) -> Context<'_> { + Context { inner: self } + } + + pub(super) fn write(&self, key: &str) -> Result> { + // TODO: Proper disambiguation - making UUIDs etc. 
+ let mut files = self + .files + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + if !files.insert(key.into()) { + return Err(Error::message(format!( + "file name {:?} has already been registered with this save context", + key, + ))); + } + let full = self.dir.join(key); + let file = std::fs::File::create_new(&full).map_err(|err| { + Error::new(err).context(format!("while creating new file {}", full.display())) + })?; + Ok(Writer { + io: BufWriter::new(file), + name: key.into(), + _lifetime: std::marker::PhantomData, + }) + } + + pub fn finish(self, value: Value<'_>) -> Result<()> { + let temp = format!("{}.temp", self.metadata.display()); + if std::path::Path::new(&temp).exists() { + return Err(Error::message(format!( + "Temporary file {} already exists. Aborting!", + temp + ))); + } + + let files = self + .files + .into_inner() + .unwrap_or_else(|poison| poison.into_inner()); + let f = Final { + files: files.iter().map(|k| &**k).collect(), + value: &value, + }; + + let buffer = std::fs::File::create(&temp).map_err(|err| { + Error::new(err).context(format!("while creating temp manifest file {}", temp)) + })?; + serde_json::to_writer_pretty(buffer, &f) + .map_err(|err| Error::new(err).context("while serializing manifest to JSON"))?; + std::fs::rename(&temp, &self.metadata).map_err(|err| { + Error::new(err).context(format!( + "while renaming temp manifest {} to final path {}", + temp, + self.metadata.display() + )) + })?; + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct Context<'a> { + inner: &'a ContextInner, +} + +impl<'a> Context<'a> { + pub fn write(&self, key: &str) -> Result> { + self.inner.write(key) + } +} + +#[derive(Debug)] +pub struct Writer<'a> { + io: BufWriter, + name: String, + _lifetime: std::marker::PhantomData<&'a ()>, +} + +impl Writer<'_> { + pub fn finish(self) -> Result { + // NOTE: self.io.into_inner() will flush the buffer and close the file. 
+ self.io + .into_inner() + .map_err(|err| Error::new(err.into_error()))?; + Ok(Handle::new(self.name)) + } +} + +impl std::io::Write for Writer<'_> { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.io.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.io.flush() + } + fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result { + self.io.write_vectored(bufs) + } + fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { + self.io.write_all(buf) + } + fn write_fmt(&mut self, args: std::fmt::Arguments<'_>) -> std::io::Result<()> { + self.io.write_fmt(args) + } +} + +impl std::io::Seek for Writer<'_> { + fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { + self.io.seek(pos) + } + + fn rewind(&mut self) -> std::io::Result<()> { + self.io.rewind() + } + fn stream_position(&mut self) -> std::io::Result { + self.io.stream_position() + } + fn seek_relative(&mut self, offset: i64) -> std::io::Result<()> { + self.io.seek_relative(offset) + } +} diff --git a/diskann-record/src/save/error.rs b/diskann-record/src/save/error.rs new file mode 100644 index 000000000..298f6065d --- /dev/null +++ b/diskann-record/src/save/error.rs @@ -0,0 +1,50 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use std::fmt::{Debug, Display}; + +pub type Result = ::std::result::Result; + +#[derive(Debug)] +pub struct Error { + inner: anyhow::Error, +} + +impl Error { + pub fn new(err: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Error { + inner: anyhow::Error::new(err), + } + } + + pub fn message(message: D) -> Self + where + D: Display + Debug + Send + Sync + 'static, + { + Error { + inner: anyhow::Error::msg(message), + } + } + + pub fn context(self, message: D) -> Self + where + D: Display + Send + Sync + 'static, + { + Error { + inner: self.inner.context(message), + } + } +} + +impl Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Save Error: {:?}", self.inner) + } +} + +impl std::error::Error for Error {} diff --git a/diskann-record/src/save/mod.rs b/diskann-record/src/save/mod.rs new file mode 100644 index 000000000..57a516071 --- /dev/null +++ b/diskann-record/src/save/mod.rs @@ -0,0 +1,211 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +mod value; +pub use value::{Handle, Record, Value, Versioned}; + +mod context; +pub use context::Context; + +mod error; +pub use error::{Error, Result}; + +use std::borrow::Cow; + +use crate::Version; + +pub fn save_to_disk( + x: &T, + dir: impl AsRef, + metadata: impl AsRef, +) -> Result<()> +where + T: Saveable, +{ + let inner = context::ContextInner::new(dir.as_ref().into(), metadata.as_ref().into()); + let value = x.save(inner.context())?; + inner.finish(value) +} + +/// Save objects! +pub trait Save { + const VERSION: Version; + fn save(&self, context: Context<'_>) -> Result>; + + /// Return the variant tag for enum types. Default: `None` (struct). + /// + /// Enum implementations must return `Some(variant_name)` for every variant. + /// The framework writes this into the manifest as `$variant` and enforces on + /// load that the tag's presence matches the corresponding [`Load::IS_ENUM`]. 
+ fn variant(&self) -> Option> { + None + } +} + +/// Save anything! +pub trait Saveable { + fn save(&self, context: Context<'_>) -> Result>; +} + +impl Saveable for T +where + T: Save, +{ + fn save(&self, context: Context<'_>) -> Result> { + let record = self.save(context)?; + let variant = ::variant(self); + let versioned = Versioned::new(record, T::VERSION, variant); + Ok(Value::Object(versioned)) + } +} + +////////////////// +// Random Stuff // +////////////////// + +impl Saveable for [T] +where + T: Saveable, +{ + fn save(&self, context: Context<'_>) -> Result> { + let values: Result> = self.iter().map(|t| t.save(context.clone())).collect(); + values.map(Value::Array) + } +} + +impl Saveable for Vec +where + T: Saveable, +{ + fn save(&self, context: Context<'_>) -> Result> { + self.as_slice().save(context) + } +} + +impl Saveable for str { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::String(self.into())) + } +} + +impl Saveable for String { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::String(self.as_str().into())) + } +} + +impl Saveable for Handle { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::Handle(self.clone())) + } +} + +impl Saveable for bool { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::Bool(*self)) + } +} + +impl Saveable for Option +where + T: Saveable, +{ + fn save(&self, context: Context<'_>) -> Result> { + match self { + None => Ok(Value::Null), + Some(t) => t.save(context), + } + } +} + +macro_rules! save_number { + ($T:ty) => { + impl Saveable for $T { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::Number((*self).into())) + } + } + }; + ($($Ts:ty),+ $(,)?) => { + $(save_number!($Ts);)+ + } +} + +save_number!(usize, u64, u32, u16, u8, i64, i32, i16, i8, f32, f64); + +// NonZero* primitives serialize as their inner numeric type. Loaders reject zero. +macro_rules! 
save_nonzero { + ($T:ty) => { + impl Saveable for $T { + fn save(&self, _: Context<'_>) -> Result> { + Ok(Value::Number(self.get().into())) + } + } + }; + ($($Ts:ty),+ $(,)?) => { + $(save_nonzero!($Ts);)+ + } +} + +save_nonzero!(std::num::NonZeroU32, std::num::NonZeroU64, std::num::NonZeroUsize); + +#[derive(Debug, Clone, Copy)] +#[doc(hidden)] +pub struct Serializing(pub &'static str); + +impl std::fmt::Display for Serializing { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "while serializing field \"{}\"", self.0) + } +} + +/// Build a [`Record`] from a list of fields. +/// +/// Two forms are supported: +/// +/// * `save_fields!(self, context, [a, b, c])` reads each field as `self.a`, +/// `self.b`, etc. Use this from `Save::save` for plain structs. +/// * `save_fields!(context, [a, b, c])` reads each field from a local binding of +/// the same name. Use this inside enum match arms where the variant's payload +/// has already been destructured into local bindings. Those bindings are +/// assumed to be references (which is automatic when matching against `&self`); +/// for an owned local, take a reference explicitly first. +#[macro_export] +macro_rules! save_fields { + ($me:ident, $context:ident, [$($field:ident),+ $(,)?]) => {{ + $crate::save::Record::from_iter( + [ + $( + ( + ::std::borrow::Cow::Borrowed(stringify!($field)), + <_ as $crate::save::Saveable>::save( + &$me.$field, + $context.clone() + ).map_err(|err| { + err.context($crate::save::Serializing(stringify!($field))) + })? + ), + )+ + ] + ) + }}; + ($context:ident, [$($field:ident),+ $(,)?]) => {{ + $crate::save::Record::from_iter( + [ + $( + ( + ::std::borrow::Cow::Borrowed(stringify!($field)), + <_ as $crate::save::Saveable>::save( + $field, + $context.clone() + ).map_err(|err| { + err.context($crate::save::Serializing(stringify!($field))) + })? 
+ ), + )+ + ] + ) + }}; +} diff --git a/diskann-record/src/save/value.rs b/diskann-record/src/save/value.rs new file mode 100644 index 000000000..0bf7418c6 --- /dev/null +++ b/diskann-record/src/save/value.rs @@ -0,0 +1,291 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::{borrow::Cow, collections::HashMap}; + +use serde::{ + Deserialize, Deserializer, Serialize, Serializer, + de::{self, MapAccess, SeqAccess, Visitor}, + ser::SerializeStruct, +}; + +use crate::{Number, Version, save::Error}; + +#[derive(Debug)] +pub enum Value<'a> { + Null, + Bool(bool), + Number(Number), + String(Cow<'a, str>), + Bytes(Cow<'a, [u8]>), + Array(Vec>), + Object(Versioned<'a>), + Handle(Handle), +} + +impl Serialize for Value<'_> { + fn serialize(&self, ser: S) -> Result { + match self { + Self::Null => ser.serialize_none(), + Self::Bool(b) => ser.serialize_bool(*b), + Self::Number(n) => n.serialize(ser), + Self::String(s) => ser.serialize_str(s), + Self::Bytes(b) => ser.serialize_bytes(b), + Self::Array(a) => a.serialize(ser), + Self::Object(v) => v.serialize(ser), + Self::Handle(h) => h.serialize(ser), + } + } +} + +impl<'de> Deserialize<'de> for Value<'static> { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct Inner; + + impl<'de> Visitor<'de> for Inner { + type Value = Value<'static>; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("a valid Value") + } + + fn visit_unit(self) -> Result, E> { + Ok(Value::Null) + } + + fn visit_none(self) -> Result, E> { + Ok(Value::Null) + } + + fn visit_some(self, deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + Value::deserialize(deserializer) + } + + fn visit_bool(self, v: bool) -> Result, E> { + Ok(Value::Bool(v)) + } + + fn visit_u64(self, v: u64) -> Result, E> { + Ok(Value::Number(Number::U64(v))) + } + + fn visit_i64(self, v: i64) -> Result, E> { + Ok(Value::Number(Number::I64(v))) + 
} + + fn visit_f64(self, v: f64) -> Result, E> { + Ok(Value::Number(Number::F64(v))) + } + + fn visit_str(self, v: &str) -> Result, E> { + Ok(Value::String(Cow::Owned(v.to_owned()))) + } + + fn visit_string(self, v: String) -> Result, E> { + Ok(Value::String(Cow::Owned(v))) + } + + fn visit_bytes(self, v: &[u8]) -> Result, E> { + Ok(Value::Bytes(Cow::Owned(v.to_owned()))) + } + + fn visit_byte_buf(self, v: Vec) -> Result, E> { + Ok(Value::Bytes(Cow::Owned(v))) + } + + fn visit_seq(self, mut seq: A) -> Result, A::Error> + where + A: SeqAccess<'de>, + { + let mut values = Vec::with_capacity(seq.size_hint().unwrap_or(0)); + while let Some(v) = seq.next_element()? { + values.push(v); + } + Ok(Value::Array(values)) + } + + fn visit_map(self, mut map: A) -> Result, A::Error> + where + A: MapAccess<'de>, + { + // TODO: Handle invariants that only one of our reserved words is present. + let mut version: Option = None; + let mut variant: Option> = None; + let mut handle_name: Option = None; + let mut fields: HashMap, Value<'static>> = HashMap::new(); + + while let Some(key) = map.next_key::()? 
{ + match key.as_str() { + "$version" => { + version = Some(map.next_value()?); + } + "$variant" => { + variant = Some(map.next_value()?); + } + "$handle" => { + handle_name = Some(map.next_value()?); + } + _ => { + let value = map.next_value()?; + fields.insert(Cow::Owned(key), value); + } + } + } + + if let Some(name) = handle_name { + return Ok(Value::Handle(Handle(name))); + } + + if let Some(version) = version { + let record = Record { record: fields }; + return Ok(Value::Object(Versioned { + record, + version, + variant, + })); + } + + Err(de::Error::custom( + "map must contain either \"$version\" or \"$handle\"", + )) + } + } + + deserializer.deserialize_any(Inner) + } +} + +impl From for Value<'_> { + fn from(handle: Handle) -> Self { + Self::Handle(handle) + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Record<'a> { + record: HashMap, Value<'a>>, +} + +impl<'a> Record<'a> { + /// Construct an empty record. Useful for unit enum variants. + pub fn empty() -> Self { + Self { + record: HashMap::new(), + } + } + + pub fn contains_key(&self, key: &str) -> bool { + self.record.contains_key(key) + } + + pub fn get(&self, key: &str) -> Option<&Value<'a>> { + self.record.get(key) + } + + pub fn insert(&mut self, key: K, value: V) -> crate::save::Result>> + where + K: Into>, + V: Into>, + { + let key = key.into(); + if crate::is_reserved(&key) { + return Err(Error::message(format!( + "record key {:?} is reserved (keys starting with `$` are reserved for the \ + save/load framework)", + key, + ))); + } + + Ok(self.record.insert(key, value.into())) + } +} + +impl<'a> FromIterator<(Cow<'a, str>, Value<'a>)> for Record<'a> { + fn from_iter, Value<'a>)>>(itr: I) -> Self { + Self { + record: itr.into_iter().collect(), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Versioned<'a> { + #[serde(flatten)] + record: Record<'a>, + #[serde(rename = "$version")] + version: Version, + #[serde( + rename = "$variant", + default, 
+ skip_serializing_if = "Option::is_none", + borrow + )] + variant: Option>, +} + +impl<'a> Versioned<'a> { + pub(crate) fn new(record: Record<'a>, version: Version, variant: Option>) -> Self { + Self { + record, + version, + variant, + } + } + + pub(crate) fn version(&self) -> Version { + self.version + } + + pub(crate) fn variant(&self) -> Option<&str> { + self.variant.as_deref() + } + + pub(crate) fn record(&self) -> &Record<'a> { + &self.record + } +} + +#[derive(Debug, Clone)] +pub struct Handle(String); + +impl Handle { + pub(crate) fn new(string: String) -> Self { + Self(string) + } + + pub(crate) fn as_str(&self) -> &str { + &self.0 + } +} + +impl Serialize for Handle { + fn serialize(&self, ser: S) -> Result { + let mut handle = ser.serialize_struct("Handle", 1)?; + handle.serialize_field("$handle", &self.0)?; + handle.end() + } +} + +impl<'de> Deserialize<'de> for Handle { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct Helper { + #[serde(rename = "$handle")] + handle: String, + } + let helper = Helper::deserialize(deserializer)?; + Ok(Handle(helper.handle)) + } +} diff --git a/diskann-record/src/version.rs b/diskann-record/src/version.rs new file mode 100644 index 000000000..1483bd2b1 --- /dev/null +++ b/diskann-record/src/version.rs @@ -0,0 +1,23 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. 
+ */ + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub struct Version { + pub major: u32, + pub minor: u32, + pub patch: u32, +} + +impl Version { + pub const fn new(major: u32, minor: u32, patch: u32) -> Self { + Self { + major, + minor, + patch, + } + } +} diff --git a/diskann-vector/Cargo.toml b/diskann-vector/Cargo.toml index 79b5239ad..5494ee7ee 100644 --- a/diskann-vector/Cargo.toml +++ b/diskann-vector/Cargo.toml @@ -15,6 +15,7 @@ license.workspace = true cfg-if.workspace = true half.workspace = true diskann-wide.workspace = true +diskann-record.workspace = true [dev-dependencies] approx.workspace = true @@ -23,6 +24,7 @@ rand.workspace = true criterion.workspace = true rand_distr.workspace = true half = { workspace = true, features = ["rand_distr", "num-traits"] } +tempfile.workspace = true [[bench]] name = "bench_main" diff --git a/diskann-vector/src/distance/metric.rs b/diskann-vector/src/distance/metric.rs index 92c77a87e..b34ef7131 100644 --- a/diskann-vector/src/distance/metric.rs +++ b/diskann-vector/src/distance/metric.rs @@ -42,10 +42,10 @@ impl TryFrom for Metric { fn try_from(value: i32) -> Result { match value { - x if x == Metric::Cosine.into() => Ok(Metric::Cosine), - x if x == Metric::InnerProduct.into() => Ok(Metric::InnerProduct), - x if x == Metric::L2.into() => Ok(Metric::L2), - x if x == Metric::CosineNormalized.into() => Ok(Metric::CosineNormalized), + x if x == i32::from(Metric::Cosine) => Ok(Metric::Cosine), + x if x == i32::from(Metric::InnerProduct) => Ok(Metric::InnerProduct), + x if x == i32::from(Metric::L2) => Ok(Metric::L2), + x if x == i32::from(Metric::CosineNormalized) => Ok(Metric::CosineNormalized), _ => Err(TryFromMetricError(value)), } } @@ -98,6 +98,68 @@ impl FromStr for Metric { } } +//////////////////////// +// diskann-record I/O // +//////////////////////// + +/// Stable wire names for [`Metric`] variants. 
Renaming a Rust variant must not change +/// these strings without bumping the saved version, or old manifests will fail to load. +const METRIC_VARIANT_COSINE: &str = "Cosine"; +const METRIC_VARIANT_INNER_PRODUCT: &str = "InnerProduct"; +const METRIC_VARIANT_L2: &str = "L2"; +const METRIC_VARIANT_COSINE_NORMALIZED: &str = "CosineNormalized"; + +impl diskann_record::save::Save for Metric { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + _context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + Ok(diskann_record::save::Record::empty()) + } + + fn variant(&self) -> Option> { + Some( + match self { + Self::Cosine => METRIC_VARIANT_COSINE, + Self::InnerProduct => METRIC_VARIANT_INNER_PRODUCT, + Self::L2 => METRIC_VARIANT_L2, + Self::CosineNormalized => METRIC_VARIANT_COSINE_NORMALIZED, + } + .into(), + ) + } +} + +impl diskann_record::load::Load<'_> for Metric { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + const IS_ENUM: bool = true; + + fn load( + object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + let variant = object + .variant() + .ok_or(diskann_record::load::error::Kind::MissingVariant)?; + match variant { + METRIC_VARIANT_COSINE => Ok(Self::Cosine), + METRIC_VARIANT_INNER_PRODUCT => Ok(Self::InnerProduct), + METRIC_VARIANT_L2 => Ok(Self::L2), + METRIC_VARIANT_COSINE_NORMALIZED => Ok(Self::CosineNormalized), + other => Err(diskann_record::load::Error::message(format!( + "unknown Metric variant: {other:?}" + ))), + } + } + + fn load_legacy( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + #[cfg(test)] mod tests { use std::str::FromStr; @@ -139,4 +201,45 @@ mod tests { assert_eq!(Metric::try_from(-1), Err(TryFromMetricError(-1))); assert_eq!(Metric::try_from(4), Err(TryFromMetricError(4))); } + + #[test] + fn 
metric_round_trips_through_record() { + for metric in [ + Metric::Cosine, + Metric::InnerProduct, + Metric::L2, + Metric::CosineNormalized, + ] { + let dir = tempfile::tempdir().expect("tempdir"); + let manifest = dir.path().join("metric.json"); + diskann_record::save::save_to_disk(&metric, dir.path(), &manifest) + .expect("save_to_disk"); + let restored: Metric = + diskann_record::load::load_from_disk(&manifest, dir.path()) + .expect("load_from_disk"); + assert_eq!(metric, restored); + } + } + + #[test] + fn loading_metric_from_unknown_variant_is_rejected() { + // Save a valid metric, hand-edit the JSON to use an unrecognised variant, + // and confirm we get a load error mentioning the bogus name. + let dir = tempfile::tempdir().expect("tempdir"); + let manifest = dir.path().join("metric.json"); + diskann_record::save::save_to_disk(&Metric::L2, dir.path(), &manifest) + .expect("save_to_disk"); + + let raw = std::fs::read_to_string(&manifest).expect("read manifest"); + let tampered = raw.replace("\"L2\"", "\"Bogus\""); + std::fs::write(&manifest, &tampered).expect("write manifest"); + + let err = diskann_record::load::load_from_disk::(&manifest, dir.path()) + .expect_err("load should fail for unknown variant"); + let msg = format!("{err}"); + assert!( + msg.contains("Bogus"), + "error should embed the unknown variant name, got: {msg}", + ); + } } diff --git a/diskann/Cargo.toml b/diskann/Cargo.toml index e19a62b87..cdefbfa95 100644 --- a/diskann/Cargo.toml +++ b/diskann/Cargo.toml @@ -27,6 +27,7 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } tracing = { workspace = true, optional = true } diskann-vector = { workspace = true } diskann-wide = { workspace = true } +diskann-record = { workspace = true } # Optional Dependencies dashmap = { workspace = true, optional = true } @@ -38,6 +39,7 @@ rand.workspace = true relative-path = "2.0.1" serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +tempfile.workspace = 
true tokio = { workspace = true, features = ["macros", "sync"] } dashmap = { workspace = true } diff --git a/diskann/src/graph/config/experimental.rs b/diskann/src/graph/config/experimental.rs index 2f5b6c7a9..328d2e626 100644 --- a/diskann/src/graph/config/experimental.rs +++ b/diskann/src/graph/config/experimental.rs @@ -71,3 +71,54 @@ impl InsertRetry { && num_candidates < self.retry_if_candidates_shorter_than().get() } } + +////////////////////////////////// +// diskann-record Save/Load // +////////////////////////////////// + +impl diskann_record::save::Save for InsertRetry { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + Ok(diskann_record::save_fields!( + self, + context, + [ + max_retries, + retry_if_candidates_shorter_than, + saturate_on_last_attempt, + ] + )) + } +} + +impl diskann_record::load::Load<'_> for InsertRetry { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn load( + object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + diskann_record::load_fields!( + object, + [ + max_retries: NonZeroU32, + retry_if_candidates_shorter_than: NonZeroU32, + saturate_on_last_attempt: bool, + ] + ); + Ok(Self { + max_retries, + retry_if_candidates_shorter_than, + saturate_on_last_attempt, + }) + } + + fn load_legacy( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} diff --git a/diskann/src/graph/config/mod.rs b/diskann/src/graph/config/mod.rs index 502c56530..6620af0a3 100644 --- a/diskann/src/graph/config/mod.rs +++ b/diskann/src/graph/config/mod.rs @@ -695,6 +695,192 @@ impl Builder { } } +////////////////////////////////// +// diskann-record Save/Load // +////////////////////////////////// + +/// Stable wire names for [`PruneKind`] variants. 
+const PRUNE_KIND_TRIANGLE_INEQUALITY: &str = "TriangleInequality"; +const PRUNE_KIND_OCCLUDING: &str = "Occluding"; + +impl diskann_record::save::Save for PruneKind { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + _context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + Ok(diskann_record::save::Record::empty()) + } + + fn variant(&self) -> Option> { + Some( + match self { + Self::TriangleInequality => PRUNE_KIND_TRIANGLE_INEQUALITY, + Self::Occluding => PRUNE_KIND_OCCLUDING, + } + .into(), + ) + } +} + +impl diskann_record::load::Load<'_> for PruneKind { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + const IS_ENUM: bool = true; + + fn load( + object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + let variant = object + .variant() + .ok_or(diskann_record::load::error::Kind::MissingVariant)?; + match variant { + PRUNE_KIND_TRIANGLE_INEQUALITY => Ok(Self::TriangleInequality), + PRUNE_KIND_OCCLUDING => Ok(Self::Occluding), + other => Err(diskann_record::load::Error::message(format!( + "unknown PruneKind variant: {other:?}" + ))), + } + } + + fn load_legacy( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + +/// Stable wire names for [`IntraBatchCandidates`] variants. 
+const INTRA_BATCH_NONE: &str = "None"; +const INTRA_BATCH_MAX: &str = "Max"; +const INTRA_BATCH_ALL: &str = "All"; + +impl diskann_record::save::Save for IntraBatchCandidates { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + Ok(match self { + Self::None | Self::All => diskann_record::save::Record::empty(), + Self::Max(max) => diskann_record::save_fields!(context, [max]), + }) + } + + fn variant(&self) -> Option> { + Some( + match self { + Self::None => INTRA_BATCH_NONE, + Self::Max(_) => INTRA_BATCH_MAX, + Self::All => INTRA_BATCH_ALL, + } + .into(), + ) + } +} + +impl diskann_record::load::Load<'_> for IntraBatchCandidates { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + const IS_ENUM: bool = true; + + fn load( + object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + let variant = object + .variant() + .ok_or(diskann_record::load::error::Kind::MissingVariant)?; + match variant { + INTRA_BATCH_NONE => Ok(Self::None), + INTRA_BATCH_ALL => Ok(Self::All), + INTRA_BATCH_MAX => { + diskann_record::load_fields!(object, [max: NonZeroU32]); + Ok(Self::Max(max)) + } + other => Err(diskann_record::load::Error::message(format!( + "unknown IntraBatchCandidates variant: {other:?}" + ))), + } + } + + fn load_legacy( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + +impl diskann_record::save::Save for Config { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + Ok(diskann_record::save_fields!( + self, + context, + [ + pruned_degree, + max_degree, + l_build, + alpha, + prune_kind, + max_occlusion_size, + max_backedges, + 
max_minibatch_par, + intra_batch_candidates, + saturate_after_prune, + experimental_insert_retry, + ] + )) + } +} + +impl diskann_record::load::Load<'_> for Config { + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn load( + object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + diskann_record::load_fields!( + object, + [ + pruned_degree: NonZeroU32, + max_degree: NonZeroU32, + l_build: NonZeroU32, + alpha: f32, + prune_kind: PruneKind, + max_occlusion_size: NonZeroU32, + max_backedges: NonZeroU32, + max_minibatch_par: NonZeroU32, + intra_batch_candidates: IntraBatchCandidates, + saturate_after_prune: bool, + experimental_insert_retry: Option, + ] + ); + Ok(Self { + pruned_degree, + max_degree, + l_build, + alpha, + prune_kind, + max_occlusion_size, + max_backedges, + max_minibatch_par, + intra_batch_candidates, + saturate_after_prune, + experimental_insert_retry, + }) + } + + fn load_legacy( + _object: diskann_record::load::Object<'_>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +} + /////////// // Tests // /////////// @@ -1273,4 +1459,64 @@ mod tests { 1.0 + OCCLUDING_MASK, ); } + + ///////////////////////////////// + // diskann-record round-trips // + ///////////////////////////////// + + fn round_trip_helper(value: &T) -> T + where + T: diskann_record::save::Saveable + for<'a> diskann_record::load::Loadable<'a>, + { + let dir = tempfile::tempdir().expect("tempdir"); + let manifest = dir.path().join("manifest.json"); + diskann_record::save::save_to_disk(value, dir.path(), &manifest) + .expect("save_to_disk"); + diskann_record::load::load_from_disk::(&manifest, dir.path()) + .expect("load_from_disk") + } + + #[test] + fn prune_kind_round_trips() { + for kind in [PruneKind::TriangleInequality, PruneKind::Occluding] { + let restored = round_trip_helper(&kind); + assert_eq!(kind, restored); + } + } + + #[test] + fn 
intra_batch_candidates_round_trip_all_variants() { + let cases = [ + IntraBatchCandidates::None, + IntraBatchCandidates::All, + IntraBatchCandidates::Max(NonZeroU32::new(7).unwrap()), + ]; + for c in cases { + assert_eq!(c, round_trip_helper(&c)); + } + } + + #[test] + fn config_round_trips_minimal() { + // Build a minimal config via the public Builder. + let cfg = Builder::new(8, MaxDegree::Same, 16, PruneKind::TriangleInequality) + .build() + .expect("Builder::build"); + assert_eq!(cfg, round_trip_helper(&cfg)); + } + + #[test] + fn config_round_trips_with_insert_retry() { + // Build a config and tack on an experimental InsertRetry to exercise + // Option through the wire. + let mut cfg = Builder::new(8, SLACK, 16, PruneKind::Occluding) + .build() + .expect("Builder::build"); + cfg.experimental_insert_retry = Some(experimental::InsertRetry::new( + NonZeroU32::new(3).unwrap(), + NonZeroU32::new(2).unwrap(), + true, + )); + assert_eq!(cfg, round_trip_helper(&cfg)); + } } diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index cd81ba5f1..375c73bc3 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -3164,3 +3164,50 @@ struct BatchIdMismatch { batch_len: usize, ids_len: usize, } + +////////////////////////////////// +// diskann-record Save/Load // +////////////////////////////////// + +impl diskann_record::save::Save for DiskANNIndex +where + DP: DataProvider + diskann_record::save::Save, +{ + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn save( + &self, + context: diskann_record::save::Context<'_>, + ) -> diskann_record::save::Result> { + Ok(diskann_record::save_fields!( + self, + context, + [config, data_provider] + )) + } +} + +impl<'a, DP> diskann_record::load::Load<'a> for DiskANNIndex +where + DP: DataProvider + diskann_record::load::Load<'a>, +{ + const VERSION: diskann_record::Version = diskann_record::Version::new(0, 0, 0); + + fn load( + object: 
diskann_record::load::Object<'a>, + ) -> diskann_record::load::Result { + diskann_record::load_fields!(object, [config: Config, data_provider: DP]); + // The scratch pool is transient runtime state; thread sizing is a deployment + // decision and not part of the persisted index. Loaders that want a specific + // thread count should construct `DataProvider` and `Config` directly and call + // `DiskANNIndex::new` + // TODO :: Add a way to specify runtime state like ScratchPool params in `load` + Ok(Self::new(config, data_provider, None)) + } + + fn load_legacy( + _object: diskann_record::load::Object<'a>, + ) -> diskann_record::load::Result { + Err(diskann_record::load::error::Kind::UnknownVersion.into()) + } +}