diff --git a/diskann-benchmark-runner/src/any.rs b/diskann-benchmark-runner/src/any.rs deleted file mode 100644 index d25a58a9e..000000000 --- a/diskann-benchmark-runner/src/any.rs +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -/// An refinement of [`std::any::Any`] with an associated name (tag) and serialization. -/// -/// This type represents deserialized inputs returned from [`crate::Input::try_deserialize`] -/// and is passed to beckend benchmarks for matching and execution. -#[derive(Debug)] -pub struct Any { - any: Box, - tag: &'static str, -} - -impl Any { - /// Construct a new [`Any`] around `any` and associate it with the name `tag`. - /// - /// The tag is included as merely a debugging and readability aid and usually should - /// belong to a [`crate::Input::tag`] that generated `any`. - pub fn new(any: T, tag: &'static str) -> Self - where - T: serde::Serialize + std::fmt::Debug + 'static, - { - Self { - any: Box::new(any), - tag, - } - } - - /// A lower level API for constructing an [`Any`] that decouples the serialized - /// representation from the inmemory representation. - /// - /// When serialized, the **exact** representation of `repr` will be used. - /// - /// This is useful in some contexts where as part of input resolution, a fully resolved - /// input struct contains elements that are not serializable. - /// - /// Like [`Any::new`], the tag is included for debugging and readability. - pub fn raw(any: T, repr: serde_json::Value, tag: &'static str) -> Self - where - T: std::fmt::Debug + 'static, - { - Self { - any: Box::new(Raw::new(any, repr)), - tag, - } - } - - /// Return the benchmark tag associated with this benchmarks. - pub fn tag(&self) -> &'static str { - self.tag - } - - /// Return the Rust [`std::any::TypeId`] for the contained object. - pub fn type_id(&self) -> std::any::TypeId { - self.any.as_any().type_id() - } - - /// Return `true` if the runtime value is `T`. Otherwise, return false. - /// - /// ```rust - /// use diskann_benchmark_runner::any::Any; - /// - /// let value = Any::new(42usize, "usize"); - /// assert!(value.is::()); - /// assert!(!value.is::()); - /// ``` - #[must_use = "this function has no side effects"] - pub fn is(&self) -> bool - where - T: std::any::Any, - { - self.any.as_any().is::() - } - - /// Return a reference to the contained object if it's runtime type is `T`. - /// - /// Otherwise return `None`. - /// - /// ```rust - /// use diskann_benchmark_runner::any::Any; - /// - /// let value = Any::new(42usize, "usize"); - /// assert_eq!(*value.downcast_ref::().unwrap(), 42); - /// assert!(value.downcast_ref::().is_none()); - /// ``` - pub fn downcast_ref(&self) -> Option<&T> - where - T: std::any::Any, - { - self.any.as_any().downcast_ref::() - } - - /// Serialize the contained object to a [`serde_json::Value`]. - pub fn serialize(&self) -> Result { - self.any.dump() - } -} - -trait SerializableAny: std::fmt::Debug { - fn as_any(&self) -> &dyn std::any::Any; - fn dump(&self) -> Result; -} - -impl SerializableAny for T -where - T: std::any::Any + serde::Serialize + std::fmt::Debug, -{ - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn dump(&self) -> Result { - serde_json::to_value(self) - } -} - -// A backend type that allows users to decouple the serialized representation from the -// actual type. -#[derive(Debug)] -struct Raw { - value: T, - repr: serde_json::Value, -} - -impl Raw { - fn new(value: T, repr: serde_json::Value) -> Self { - Self { value, repr } - } -} - -impl SerializableAny for Raw -where - T: std::any::Any + std::fmt::Debug, -{ - fn as_any(&self) -> &dyn std::any::Any { - &self.value - } - - fn dump(&self) -> Result { - Ok(self.repr.clone()) - } -} - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_new() { - let x = Any::new(42usize, "my-tag"); - assert_eq!(x.tag(), "my-tag"); - assert_eq!(x.type_id(), std::any::TypeId::of::()); - assert!(x.is::()); - assert!(!x.is::()); - assert_eq!(*x.downcast_ref::().unwrap(), 42); - assert!(x.downcast_ref::().is_none()); - - assert!(!x.is::>()); - assert!(!x.is::>()); - assert!(x.downcast_ref::>().is_none()); - assert!(x.downcast_ref::>().is_none()); - - assert_eq!( - x.serialize().unwrap(), - serde_json::Value::Number(serde_json::value::Number::from(42usize)) - ); - } - - #[test] - fn test_raw() { - let repr = serde_json::json!(1.5); - let x = Any::raw(42usize, repr, "my-tag"); - assert_eq!(x.tag(), "my-tag"); - assert_eq!(x.type_id(), std::any::TypeId::of::()); - assert!(x.is::()); - assert!(!x.is::()); - assert_eq!(*x.downcast_ref::().unwrap(), 42); - assert!(x.downcast_ref::().is_none()); - - assert!(!x.is::>()); - assert!(!x.is::>()); - assert!(x.downcast_ref::>().is_none()); - assert!(x.downcast_ref::>().is_none()); - - assert_eq!( - x.serialize().unwrap(), - serde_json::Value::Number(serde_json::value::Number::from_f64(1.5).unwrap()) - ); - } -} diff --git a/diskann-benchmark-runner/src/app.rs b/diskann-benchmark-runner/src/app.rs index b46e4fb78..42f146ca5 100644 --- a/diskann-benchmark-runner/src/app.rs +++ b/diskann-benchmark-runner/src/app.rs @@ -293,12 +293,12 @@ impl App { let serialized = jobs .iter() .map(|job| { - serde_json::to_value(jobs::Unprocessed::new( + Ok(serde_json::to_value(jobs::Unprocessed::new( job.tag().into(), job.serialize()?, - )) + ))?) }) - .collect::, serde_json::Error>>()?; + .collect::>>()?; for (i, job) in jobs.iter().enumerate() { let prefix: &str = if i != 0 { "\n\n" } else { "" }; writeln!( diff --git a/diskann-benchmark-runner/src/benchmark.rs b/diskann-benchmark-runner/src/benchmark.rs index 39f48fd54..dbdfe8063 100644 --- a/diskann-benchmark-runner/src/benchmark.rs +++ b/diskann-benchmark-runner/src/benchmark.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; -use crate::{Any, Checkpoint, Input, Output}; +use crate::{Checkpoint, Input, Output}; /// A registered benchmark. /// @@ -134,6 +134,8 @@ pub enum PassFail { pub(crate) mod internal { use super::*; + use crate::input::internal::Any; + use anyhow::Context; use thiserror::Error; @@ -179,7 +181,7 @@ pub(crate) mod internal { pub(crate) type CheckedPassFail = PassFail; pub(crate) trait Regression { - fn tolerance(&self) -> &dyn crate::input::DynInput; + fn tolerance(&self) -> &dyn crate::input::internal::DynInput; fn input_tag(&self) -> &'static str; fn check( &self, @@ -228,8 +230,8 @@ pub(crate) mod internal { where T: super::Regression, { - fn tolerance(&self) -> &dyn crate::input::DynInput { - &crate::input::Wrapper::::INSTANCE + fn tolerance(&self) -> &dyn crate::input::internal::DynInput { + &crate::input::internal::Wrapper::::INSTANCE } fn input_tag(&self) -> &'static str { diff --git a/diskann-benchmark-runner/src/checker.rs b/diskann-benchmark-runner/src/checker.rs index 4b3dda556..03ca6dec0 100644 --- a/diskann-benchmark-runner/src/checker.rs +++ b/diskann-benchmark-runner/src/checker.rs @@ -8,8 +8,6 @@ use std::{ path::{Path, PathBuf}, }; -use crate::Any; - /// Shared context for resolving input and output files paths post deserialization. #[derive(Debug)] pub struct Checker { @@ -29,12 +27,6 @@ pub struct Checker { /// /// This ensures that each job uses a distinct output directory to avoid conflicts. current_outputs: HashSet, - - /// This crate-private variable is used to store the current input deserialization - /// tag and is referenced when creating new `Any` objects. - /// - /// Ensure that the correct tag is present before invoking [`Input::try_deserialize`]. - tag: Option<&'static str>, } impl Checker { @@ -44,23 +36,9 @@ impl Checker { search_directories, output_directory, current_outputs: HashSet::new(), - tag: None, } } - /// Invoke [`CheckDeserialization`] on `value` and if successful, package it in [`Any`]. - pub fn any(&mut self, mut value: T) -> anyhow::Result - where - T: serde::Serialize + CheckDeserialization + std::fmt::Debug + 'static, - { - value.check_deserialization(self)?; - #[expect( - clippy::expect_used, - reason = "crate infrastructure ensures an untagged Checker is not leaked" - )] - Ok(Any::new(value, self.tag.expect("tag must be set"))) - } - /// Return the ordered list of search directories registered with the [`Checker`]. pub fn search_directories(&self) -> &[PathBuf] { &self.search_directories @@ -167,17 +145,6 @@ impl Checker { self.search_directories(), ))) } - - pub(crate) fn set_tag(&mut self, tag: &'static str) { - let _ = self.tag.insert(tag); - } -} - -/// Perform post-process resolution of input and output files paths. -pub trait CheckDeserialization { - /// Perform any necessary resolution of file paths, returning an error if a problem is - /// discovered. - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error>; } /////////// diff --git a/diskann-benchmark-runner/src/files.rs b/diskann-benchmark-runner/src/files.rs index 355b47010..1672f6f62 100644 --- a/diskann-benchmark-runner/src/files.rs +++ b/diskann-benchmark-runner/src/files.rs @@ -7,7 +7,7 @@ use std::path::{Path, PathBuf}; use serde::{Deserialize, Serialize}; -use super::checker::{CheckDeserialization, Checker}; +use super::Checker; /// A file that is used as an input to for a benchmark. /// @@ -37,17 +37,8 @@ impl InputFile { path: PathBuf::from(path), } } -} - -impl std::ops::Deref for InputFile { - type Target = Path; - fn deref(&self) -> &Self::Target { - &self.path - } -} -impl CheckDeserialization for InputFile { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + pub fn resolve(&mut self, checker: &mut Checker) -> anyhow::Result<()> { let checked_path = checker.check_path(self); match checked_path { Ok(p) => { @@ -59,6 +50,13 @@ impl CheckDeserialization for InputFile { } } +impl std::ops::Deref for InputFile { + type Target = Path; + fn deref(&self) -> &Self::Target { + &self.path + } +} + /////////// // Tests // /////////// @@ -86,7 +84,7 @@ mod tests { } #[test] - fn test_check_deserialization() { + fn test_resolve() { // We create a directory that looks like this: // // dir/ @@ -113,13 +111,13 @@ mod tests { let absolute = path.join("file_a.txt"); let mut file = InputFile::new(absolute.clone()); let mut checker = Checker::new(Vec::new(), None); - file.check_deserialization(&mut checker).unwrap(); + file.resolve(&mut checker).unwrap(); assert_eq!(file.path, absolute); let absolute = path.join("dir0/file_b.txt"); let mut file = InputFile::new(absolute.clone()); let mut checker = Checker::new(Vec::new(), None); - file.check_deserialization(&mut checker).unwrap(); + file.resolve(&mut checker).unwrap(); assert_eq!(file.path, absolute); } @@ -128,7 +126,7 @@ mod tests { let absolute = path.join("dir0/file_c.txt"); let mut file = InputFile::new(absolute.clone()); let mut checker = Checker::new(Vec::new(), None); - let err = file.check_deserialization(&mut checker).unwrap_err(); + let err = file.resolve(&mut checker).unwrap_err(); let message = err.to_string(); assert!(message.contains("input file with absolute path")); assert!(message.contains("either does not exist or is not a file")); @@ -143,23 +141,23 @@ mod tests { // Directories are searched in order. let mut file = InputFile::new("file_c.txt"); - file.check_deserialization(&mut checker).unwrap(); + file.resolve(&mut checker).unwrap(); assert_eq!(file.path, path.join("dir1/dir0/file_c.txt")); let mut file = InputFile::new("file_b.txt"); - file.check_deserialization(&mut checker).unwrap(); + file.resolve(&mut checker).unwrap(); assert_eq!(file.path, path.join("dir0/file_b.txt")); // Directory search can fail. let mut file = InputFile::new("file_a.txt"); - let err = file.check_deserialization(&mut checker).unwrap_err(); + let err = file.resolve(&mut checker).unwrap_err(); let message = err.to_string(); assert!(message.contains("could not find input file")); assert!(message.contains("in the search directories")); // If we give an absolute path, no directory search is performed. let mut file = InputFile::new(path.join("file_c.txt")); - let err = file.check_deserialization(&mut checker).unwrap_err(); + let err = file.resolve(&mut checker).unwrap_err(); let message = err.to_string(); assert!(message.starts_with("input file with absolute path")); } diff --git a/diskann-benchmark-runner/src/input.rs b/diskann-benchmark-runner/src/input.rs index 019298542..29a78b1ea 100644 --- a/diskann-benchmark-runner/src/input.rs +++ b/diskann-benchmark-runner/src/input.rs @@ -3,9 +3,21 @@ * Licensed under the MIT license. */ -use crate::{Any, Checker}; +use crate::Checker; + +/// Inputs to [`Benchmarks`](crate::Benchmark). +/// +/// These begin as [`raw`](Self::Raw) data transfer objects before final construction via +/// [`from_raw`](Self::from_raw). +pub trait Input: Sized + std::fmt::Debug + 'static { + /// The raw form of this input that is deserialized from input files and serialized as + /// [`examples`](Self::example). The raw nature of this type reflects that no input + /// validation has been performed beyond the checks performed by its + /// [`Deserialize`](serde::Deserialize) implementation. + /// + /// Final object validation is performed via [`from_raw`](Self::from_raw). + type Raw: serde::de::DeserializeOwned + serde::Serialize; -pub trait Input { /// Return the discriminant associated with this type. /// /// This is used to map inputs types to their respective parsers. @@ -13,36 +25,22 @@ pub trait Input { /// Well formed implementations should always return the same result. fn tag() -> &'static str; - /// Attempt to deserialize an opaque object from the raw `serialized` representation. - /// - /// Deserialized values can be constructed and returned via [`Checker::any`], - /// [`Any::new`] or [`Any::raw`]. - /// - /// If using the [`Any`] constructors directly, implementations should associate - /// [`Self::tag`] with the returned `Any`. If [`Checker::any`] is used - this will - /// happen automatically. - /// - /// Implementations are **strongly** encouraged to implement - /// [`CheckDeserialization`](crate::CheckDeserialization) and use this API to ensure - /// shared resources (like input files or output files) are correctly resolved and - /// properly shared among all jobs in a benchmark run. - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result; + /// Construct `Self` from the raw deserialized representation, performing any necessary + /// validation checks (e.g., resolving file paths via the [`Checker`]). + fn from_raw(raw: Self::Raw, checker: &mut Checker) -> anyhow::Result; + + /// Serialize `self` to a [`serde_json::Value`]. + fn serialize(&self) -> anyhow::Result; - /// Print an example JSON representation of objects this input is expected to parse. + /// Return an example of a raw input for this [`Input`]. /// - /// Well-formed implementations should ensure that passing the returned - /// [`serde_json::Value`] back to [`Self::try_deserialize`] correctly deserializes, - /// though it need not necessarily pass - /// [`CheckDeserialization`](crate::CheckDeserialization). - fn example() -> anyhow::Result; + /// This is used to supply sample JSON layouts in the benchmark CLI. + fn example() -> Self::Raw; } /// A registered input. See [`crate::Registry::input`]. #[derive(Clone, Copy)] -pub struct Registered<'a>(pub(crate) &'a dyn DynInput); +pub struct Registered<'a>(pub(crate) &'a dyn internal::DynInput); impl Registered<'_> { /// Return the input tag of the registered input. @@ -54,12 +52,12 @@ impl Registered<'_> { /// Try to deserialize raw JSON into the dynamic type of the input. /// - /// See: [`Input::try_deserialize`]. - pub fn try_deserialize( + /// See: [`Input::from_raw`]. + pub(crate) fn try_deserialize( &self, serialized: &serde_json::Value, checker: &mut Checker, - ) -> anyhow::Result { + ) -> anyhow::Result { self.0.try_deserialize(serialized, checker) } @@ -79,64 +77,123 @@ impl std::fmt::Debug for Registered<'_> { } } -////////////// -// Internal // -////////////// +pub(crate) mod internal { + use super::*; -#[derive(Debug)] -pub(crate) struct Wrapper(std::marker::PhantomData); + /// Runtime representation of a deserialized [`Input`]. + #[derive(Debug)] + pub(crate) struct Any { + any: Box, + } -impl Wrapper { - pub(crate) const INSTANCE: Self = Self::new(); + impl Any { + pub(crate) fn new(input: T) -> Self + where + T: Input, + { + Self { + any: Box::new(input), + } + } + + #[must_use = "this function has no side effects"] + pub(crate) fn tag(&self) -> &'static str { + self.any.tag() + } + + #[must_use = "this function has no side effects"] + pub(crate) fn downcast_ref(&self) -> Option<&T> + where + T: std::any::Any, + { + self.any.as_any().downcast_ref::() + } + + #[must_use = "this function has no side effects"] + pub(crate) fn serialize(&self) -> anyhow::Result { + self.any.serialize() + } + } - pub(crate) const fn new() -> Self { - Self(std::marker::PhantomData) + trait RuntimeAny: std::fmt::Debug { + fn tag(&self) -> &'static str; + fn as_any(&self) -> &dyn std::any::Any; + fn serialize(&self) -> anyhow::Result; } -} -impl Clone for Wrapper { - fn clone(&self) -> Self { - *self + impl RuntimeAny for T + where + T: Input, + { + fn tag(&self) -> &'static str { + ::tag() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn serialize(&self) -> anyhow::Result { + ::serialize(self) + } } -} -impl Copy for Wrapper {} + #[derive(Debug)] + pub(crate) struct Wrapper(std::marker::PhantomData); -pub(crate) trait DynInput { - fn tag(&self) -> &'static str; - fn try_deserialize( - &self, - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result; - fn example(&self) -> anyhow::Result; - - // reflection - fn as_any(&self) -> &dyn std::any::Any; - fn type_name(&self) -> &'static str; -} + impl Wrapper { + pub(crate) const INSTANCE: Self = Self::new(); -impl DynInput for Wrapper -where - T: Input + 'static, -{ - fn tag(&self) -> &'static str { - T::tag() - } - fn try_deserialize( - &self, - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - T::try_deserialize(serialized, checker) + pub(crate) const fn new() -> Self { + Self(std::marker::PhantomData) + } } - fn example(&self) -> anyhow::Result { - T::example() + + impl Clone for Wrapper { + fn clone(&self) -> Self { + *self + } } - fn as_any(&self) -> &dyn std::any::Any { - self + + impl Copy for Wrapper {} + + pub(crate) trait DynInput { + fn tag(&self) -> &'static str; + fn try_deserialize( + &self, + serialized: &serde_json::Value, + checker: &mut Checker, + ) -> anyhow::Result; + fn example(&self) -> anyhow::Result; + + // reflection + fn as_any(&self) -> &dyn std::any::Any; + fn type_name(&self) -> &'static str; } - fn type_name(&self) -> &'static str { - std::any::type_name::() + + impl DynInput for Wrapper + where + T: Input, + { + fn tag(&self) -> &'static str { + T::tag() + } + fn try_deserialize( + &self, + serialized: &serde_json::Value, + checker: &mut Checker, + ) -> anyhow::Result { + let raw = >::deserialize(serialized)?; + Ok(Any::new(T::from_raw(raw, checker)?)) + } + fn example(&self) -> anyhow::Result { + Ok(serde_json::to_value(T::example())?) + } + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn type_name(&self) -> &'static str { + std::any::type_name::() + } } } diff --git a/diskann-benchmark-runner/src/internal/regression.rs b/diskann-benchmark-runner/src/internal/regression.rs index c1d1838ee..2df2d922a 100644 --- a/diskann-benchmark-runner/src/internal/regression.rs +++ b/diskann-benchmark-runner/src/internal/regression.rs @@ -100,8 +100,9 @@ use serde_json::Value; use crate::{ benchmark::{internal::CheckedPassFail, PassFail}, + input::internal::Any, internal::load_from_disk, - jobs, registry, result, Any, Checker, + jobs, registry, result, Checker, }; //////////// @@ -348,7 +349,6 @@ impl Raw { .with_context(context); } - checker.set_tag(entry.tolerance.tag()); let tolerance = entry .tolerance .try_deserialize(&unprocessed.tolerance.content, &mut checker) diff --git a/diskann-benchmark-runner/src/jobs.rs b/diskann-benchmark-runner/src/jobs.rs index e7ca99a2a..0cc75e751 100644 --- a/diskann-benchmark-runner/src/jobs.rs +++ b/diskann-benchmark-runner/src/jobs.rs @@ -8,22 +8,22 @@ use std::path::{Path, PathBuf}; use anyhow::Context; use serde::{Deserialize, Serialize}; -use crate::{checker::Checker, input, Any, Registry}; +use crate::{checker::Checker, input, Registry}; #[derive(Debug)] pub(crate) struct Jobs { /// The benchmark jobs to execute. - jobs: Vec, + jobs: Vec, } impl Jobs { /// Return the jobs associated with this benchmark run. - pub(crate) fn jobs(&self) -> &[Any] { + pub(crate) fn jobs(&self) -> &[input::internal::Any] { &self.jobs } /// Consume `self`, returning the contained list of jobs. - pub(crate) fn into_inner(self) -> Vec { + pub(crate) fn into_inner(self) -> Vec { self.jobs } @@ -51,7 +51,7 @@ impl Jobs { ); let num_jobs = partial.jobs.len(); - let jobs: anyhow::Result> = partial + let jobs: anyhow::Result> = partial .jobs .iter() .enumerate() @@ -71,7 +71,6 @@ impl Jobs { }) .with_context(context)?; - checker.set_tag(input.tag()); input .try_deserialize(&unprocessed.content, &mut checker) .with_context(context) diff --git a/diskann-benchmark-runner/src/lib.rs b/diskann-benchmark-runner/src/lib.rs index e0c3b2791..724a827f6 100644 --- a/diskann-benchmark-runner/src/lib.rs +++ b/diskann-benchmark-runner/src/lib.rs @@ -11,7 +11,6 @@ mod internal; mod jobs; mod result; -pub mod any; pub mod app; pub mod files; pub mod input; @@ -19,10 +18,9 @@ pub mod output; pub mod registry; pub mod utils; -pub use any::Any; pub use app::App; pub use benchmark::Benchmark; -pub use checker::{CheckDeserialization, Checker}; +pub use checker::Checker; pub use input::Input; pub use output::Output; pub use registry::{Registry, RegistryError}; diff --git a/diskann-benchmark-runner/src/registry.rs b/diskann-benchmark-runner/src/registry.rs index f19bb93d3..4cc4aaa1b 100644 --- a/diskann-benchmark-runner/src/registry.rs +++ b/diskann-benchmark-runner/src/registry.rs @@ -9,7 +9,7 @@ use thiserror::Error; use crate::{ benchmark::{self, Benchmark, FailureScore, MatchScore, Regression}, - input, Any, Checkpoint, Input, Output, + input, Checkpoint, Input, Output, }; /// A registered benchmark entry: a name paired with a type-erased benchmark. @@ -41,7 +41,7 @@ impl RegisteredBenchmark { /// A collection of registered inputs and benchmarks. pub struct Registry { // Inputs keyed by their tag type. - inputs: HashMap<&'static str, Box>, + inputs: HashMap<&'static str, Box>, benchmarks: Vec, } @@ -105,7 +105,7 @@ impl Registry { } /// Return `true` if `job` matches with any registered benchmark. Otherwise, return `false`. - pub fn has_match(&self, job: &Any) -> bool { + pub(crate) fn has_match(&self, job: &input::internal::Any) -> bool { self.find_best_match(job).is_some() } @@ -114,9 +114,9 @@ impl Registry { /// Returns the results of the benchmark if successful. /// /// Errors if a suitable method could not be found or if the invoked benchmark failed. - pub fn call( + pub(crate) fn call( &self, - job: &Any, + job: &input::internal::Any, checkpoint: Checkpoint<'_>, output: &mut dyn Output, ) -> anyhow::Result { @@ -132,7 +132,11 @@ impl Registry { /// reasons. /// /// Returns `Ok(())` if a match was found. - pub fn debug(&self, job: &Any, max_methods: usize) -> Result<(), Vec> { + pub(crate) fn debug( + &self, + job: &input::internal::Any, + max_methods: usize, + ) -> Result<(), Vec> { if self.has_match(job) { return Ok(()); } @@ -166,7 +170,7 @@ impl Registry { } /// Find the best matching benchmark for `job` by score. - fn find_best_match(&self, job: &Any) -> Option<&RegisteredBenchmark> { + fn find_best_match(&self, job: &input::internal::Any) -> Option<&RegisteredBenchmark> { self.benchmarks .iter() .filter_map(|entry| { @@ -180,7 +184,7 @@ impl Registry { .map(|(entry, _)| entry) } - fn _input(&self, tag: &str) -> Option<&dyn input::DynInput> { + fn _input(&self, tag: &str) -> Option<&dyn input::internal::DynInput> { self.inputs.get(tag).map(|v| &**v) } @@ -189,16 +193,16 @@ impl Registry { T: Input + 'static, { let tag = T::tag(); - let wrapper = crate::input::Wrapper::::new(); + let wrapper = crate::input::internal::Wrapper::::new(); match self.inputs.entry(tag) { Entry::Vacant(v) => { v.insert(Box::new(wrapper)); Ok(()) } Entry::Occupied(o) => { - use input::DynInput; + use input::internal::DynInput; - if o.get().as_any().is::>() { + if o.get().as_any().is::>() { Ok(()) } else { Err(RegistryError { @@ -334,14 +338,17 @@ impl RegressionBenchmark<'_> { self.regression.input_tag() } - pub(crate) fn try_match(&self, input: &Any) -> Result { + pub(crate) fn try_match( + &self, + input: &input::internal::Any, + ) -> Result { self.benchmark.benchmark().try_match(input) } pub(crate) fn check( &self, - tolerance: &Any, - input: &Any, + tolerance: &input::internal::Any, + input: &input::internal::Any, before: &serde_json::Value, after: &serde_json::Value, ) -> anyhow::Result { @@ -360,7 +367,10 @@ pub(crate) struct RegisteredTolerance<'a> { } /// Helper to capture a `Benchmark::description` call into a `String` via `Display`. -struct Capture<'a>(&'a dyn benchmark::internal::Benchmark, Option<&'a Any>); +struct Capture<'a>( + &'a dyn benchmark::internal::Benchmark, + Option<&'a input::internal::Any>, +); impl std::fmt::Display for Capture<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -386,19 +396,21 @@ mod tests { macro_rules! input { ($T:ident, $tag:literal) => { + #[derive(Debug)] struct $T; impl Input for $T { + type Raw = (); fn tag() -> &'static str { $tag } - fn try_deserialize( - _serialized: &serde_json::Value, - _checker: &mut Checker, - ) -> anyhow::Result { + fn from_raw(_raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result<$T> { + unimplemented!("this struct is for test only"); + } + fn serialize(&self) -> anyhow::Result { unimplemented!("this struct is for test only"); } - fn example() -> anyhow::Result { + fn example() -> Self::Raw { unimplemented!("this struct is for test only"); } } @@ -422,7 +434,7 @@ mod tests { { let a = registry._input(A::tag()).unwrap(); - assert!(a.as_any().is::>()); + assert!(a.as_any().is::>()); let name = a.type_name(); assert!(name.contains("A"), "{}", name); @@ -430,7 +442,7 @@ mod tests { { let b = registry._input(B::tag()).unwrap(); - assert!(b.as_any().is::>()); + assert!(b.as_any().is::>()); let name = b.type_name(); assert!(name.contains("B"), "{}", name); diff --git a/diskann-benchmark-runner/src/result.rs b/diskann-benchmark-runner/src/result.rs index cd8e34bb8..f80b44581 100644 --- a/diskann-benchmark-runner/src/result.rs +++ b/diskann-benchmark-runner/src/result.rs @@ -267,9 +267,9 @@ mod tests { let savepath = path.join("output.json"); let inputs = [ - TypeInput::new(DataType::Float32, 1, false), - TypeInput::new(DataType::Float16, 2, false), - TypeInput::new(DataType::Float64, 3, false), + TypeInput::new(DataType::Float32, 1), + TypeInput::new(DataType::Float16, 2), + TypeInput::new(DataType::Float64, 3), ]; let serialized: Vec<_> = inputs diff --git a/diskann-benchmark-runner/src/test/dim.rs b/diskann-benchmark-runner/src/test/dim.rs index edc033a38..d70456883 100644 --- a/diskann-benchmark-runner/src/test/dim.rs +++ b/diskann-benchmark-runner/src/test/dim.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use crate::{ benchmark::{FailureScore, MatchScore, PassFail, Regression}, - Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, + Benchmark, Checker, Checkpoint, Input, Output, }; /////////// @@ -32,25 +32,22 @@ impl DimInput { } impl Input for DimInput { + type Raw = Self; + fn tag() -> &'static str { "test-input-dim" } - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(DimInput::deserialize(serialized)?) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + Ok(raw) } - fn example() -> anyhow::Result { - Ok(serde_json::to_value(DimInput::new(Some(128)))?) + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) } -} -impl CheckDeserialization for DimInput { - fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { - Ok(()) + fn example() -> Self::Raw { + DimInput::new(Some(128)) } } @@ -65,23 +62,25 @@ pub(super) struct Tolerance { } impl Input for Tolerance { + type Raw = Self; + fn tag() -> &'static str { "test-input-dim-tolerance" } - fn try_deserialize( - serialized: &serde_json::Value, - _checker: &mut Checker, - ) -> anyhow::Result { - Ok(Any::new(Self::deserialize(serialized)?, Self::tag())) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + Ok(raw) } - fn example() -> anyhow::Result { - let this = Self { + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) + } + + fn example() -> Self::Raw { + Self { succeed: true, error_in_check: false, - }; - Ok(serde_json::to_value(this)?) + } } } diff --git a/diskann-benchmark-runner/src/test/typed.rs b/diskann-benchmark-runner/src/test/typed.rs index b34d4301d..f737f875c 100644 --- a/diskann-benchmark-runner/src/test/typed.rs +++ b/diskann-benchmark-runner/src/test/typed.rs @@ -10,31 +10,34 @@ use serde::{Deserialize, Serialize}; use crate::{ benchmark::{FailureScore, MatchScore, PassFail, Regression}, utils::datatype::{AsDataType, DataType}, - Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, + Benchmark, Checker, Checkpoint, Input, Output, }; /////////// // Input // /////////// -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub(crate) struct TypeInput { pub(super) data_type: DataType, pub(super) dim: usize, - // Should we return an error when `check_deserialization` is called? - pub(super) error_when_checked: bool, - // A flag to verify that [`CheckDeserialization`] has run. - #[serde(skip)] - pub(crate) checked: bool, + error_when_checked: bool, +} + +#[derive(Serialize, Deserialize)] +pub(crate) struct TypeInputRaw { + data_type: DataType, + dim: usize, + // Should we return an error when deserializing? + error_when_checked: bool, } impl TypeInput { - pub(crate) fn new(data_type: DataType, dim: usize, error_when_checked: bool) -> Self { + pub(crate) fn new(data_type: DataType, dim: usize) -> Self { Self { data_type, dim, - error_when_checked, - checked: false, + error_when_checked: false, } } @@ -44,33 +47,29 @@ impl TypeInput { } impl Input for TypeInput { + type Raw = TypeInputRaw; + fn tag() -> &'static str { "test-input-types" } - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(TypeInput::deserialize(serialized)?) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + if raw.error_when_checked { + Err(anyhow::anyhow!("test input erroring when checked")) + } else { + Ok(Self::new(raw.data_type, raw.dim)) + } } - fn example() -> anyhow::Result { - Ok(serde_json::to_value(TypeInput::new( - DataType::Float32, - 128, - false, - ))?) + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) } -} -impl CheckDeserialization for TypeInput { - fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { - if self.error_when_checked { - Err(anyhow::anyhow!("test input erroring when checked")) - } else { - self.checked = true; - Ok(()) + fn example() -> Self::Raw { + TypeInputRaw { + data_type: DataType::Float32, + dim: 128, + error_when_checked: false, } } } @@ -81,42 +80,32 @@ impl CheckDeserialization for TypeInput { #[derive(Debug, Serialize, Deserialize)] pub(super) struct Tolerance { - // Should we return an error when `check_deserialization` is called? + // Should we return an error when `from_raw` is called? pub(super) error_when_checked: bool, - - // A flag to verify that [`CheckDeserialization`] has run. - #[serde(skip)] - pub(crate) checked: bool, } impl Input for Tolerance { + type Raw = Self; + fn tag() -> &'static str { "test-input-types-tolerance" } - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(Self::deserialize(serialized)?) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + if raw.error_when_checked { + Err(anyhow::anyhow!("test input erroring when checked")) + } else { + Ok(raw) + } } - fn example() -> anyhow::Result { - let this = Self { - error_when_checked: false, - checked: false, - }; - Ok(serde_json::to_value(this)?) + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) } -} -impl CheckDeserialization for Tolerance { - fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { - if self.error_when_checked { - Err(anyhow::anyhow!("test input erroring when checked")) - } else { - self.checked = true; - Ok(()) + fn example() -> Self::Raw { + Self { + error_when_checked: false, } } } diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs index 9fe99e9e9..cf12ae373 100644 --- a/diskann-benchmark-simd/src/lib.rs +++ b/diskann-benchmark-simd/src/lib.rs @@ -26,7 +26,7 @@ use diskann_benchmark_runner::{ num::{relative_change, NonNegativeFinite}, percentiles, MicroSeconds, }, - Any, Benchmark, CheckDeserialization, Checker, Input, Registry, + Benchmark, Checker, Input, Registry, }; //////////////// @@ -118,12 +118,6 @@ pub struct SimdOp { runs: Vec, } -impl CheckDeserialization for SimdOp { - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { - Ok(()) - } -} - macro_rules! write_field { ($f:ident, $field:tt, $($expr:tt)*) => { writeln!($f, "{:>18}: {}", $field, $($expr)*) @@ -149,18 +143,21 @@ impl std::fmt::Display for SimdOp { } impl Input for SimdOp { + type Raw = Self; + fn tag() -> &'static str { "simd-op" } - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(Self::deserialize(serialized)?) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + Ok(raw) } - fn example() -> anyhow::Result { + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) + } + + fn example() -> Self::Raw { const DIM: [NonZeroUsize; 2] = [ NonZeroUsize::new(128).unwrap(), NonZeroUsize::new(150).unwrap(), @@ -191,12 +188,12 @@ impl Input for SimdOp { }, ]; - Ok(serde_json::to_value(&Self { + Self { query_type: DataType::Float32, data_type: DataType::Float32, arch: Arch::X86_64_V3, runs, - })?) + } } } @@ -213,33 +210,30 @@ struct SimdTolerance { min_time_regression: NonNegativeFinite, } -impl CheckDeserialization for SimdTolerance { - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { - Ok(()) - } -} - impl Input for SimdTolerance { + type Raw = Self; + fn tag() -> &'static str { "simd-tolerance" } - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(Self::deserialize(serialized)?) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + Ok(raw) } - fn example() -> anyhow::Result { + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) + } + + fn example() -> Self { const EXAMPLE: NonNegativeFinite = match NonNegativeFinite::new(0.10) { Ok(v) => v, Err(_) => panic!("use a non-negative finite please"), }; - Ok(serde_json::to_value(SimdTolerance { + SimdTolerance { min_time_regression: EXAMPLE, - })?) + } } } diff --git a/diskann-benchmark/README.md b/diskann-benchmark/README.md index 923bb9e27..b48fdf18a 100644 --- a/diskann-benchmark/README.md +++ b/diskann-benchmark/README.md @@ -206,20 +206,20 @@ this is usually easily done with a small code change. With the example of adding Range search to the `f16` index, the registration site: ```rust -benchmarks.register( +registry.register( "async-full-precision-f16", FullPrecision::::new() .search(plugins::Topk), -); +)?; ``` Can be updated to: ```rust -benchmarks.register( +registry.register( "async-full-precision-f16", FullPrecision::::new() .search(plugins::Topk) .search(plugins::Range), -); +)?; ``` This will both compile the range search implementation and make it available for benchmark matching. @@ -337,53 +337,49 @@ pub(crate) struct ComputeGroundTruth { pub(crate) num_nearest_neighbors: usize, } ``` -We need to implement a few traits related to this input type: +We need to implement `diskann_benchmark_runner::Input` for the type. This trait associates +a tag name used for deserialization and benchmark matching, a `Raw` type for JSON +serialization/deserialization, a `from_raw` constructor that performs post-deserialization +validation (e.g., resolving file paths via the `Checker`), and an `example` that supplies +sample JSON layouts for the CLI. -* `diskann_benchmark_runner::Input`: A type-name for this input that is used to identify it for - deserialization and benchmark matching. To make this easier, `benchmark` defines - `benchmark::inputs::Input` that can be used to express type level implementation (shown - below) - -* `CheckDeserialization`: This trait performs post-deserialization invariant checking. - In the context of the `ComputeGroundTruth` type, we use this to check that the input - files are valid. +In the context of the `ComputeGroundTruth` type, we use `from_raw` to check that the input +files are valid. ```rust -impl diskann_benchmark_runner::Input for crate::inputs::Input { +impl diskann_benchmark_runner::Input for ComputeGroundTruth { + // The raw form is just `Self` since the struct is directly deserializable. + type Raw = Self; + // This gets associated with the JSON representation returned by `example` and at run - // time, inputs tagged with this value will be given to `try_deserialize`. + // time, inputs tagged with this value will be given to `from_raw`. fn tag() -> &'static str { "compute_groundtruth" } - // Attempt to deserialize `Self` from raw JSON. - // - // Implementors can assume that `serialized` looks similar in structure to what is - // returned by `example`. - fn try_deserialize( - serialized: &serde_json::Value, + // Construct from the raw deserialized form, performing file path resolution. + fn from_raw( + mut raw: Self::Raw, checker: &mut diskann_benchmark_runner::Checker, - ) -> anyhow::Result { - checker.any(ComputeGroundTruth::deserialize(serialized)?) + ) -> anyhow::Result { + raw.data.resolve(checker)?; + raw.queries.resolve(checker)?; + Ok(raw) } - // Return a serialized representation of `self` to help users create an input file. - fn example() -> anyhow::Result { - serde_json::to_value(Self { + // Serialize `self` to JSON. + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) + } + + // Return an example input to help users create an input file. + fn example() -> Self { + Self { data_type: DataType::Float32, data: InputFile::new("path/to/data"), queries: InputFile::new("path/to/queries"), num_nearest_neighbors: 100, - }) - } -} - -impl CheckDeserialization for ComputeGroundTruth { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Forward the deserialization check to the input files. - self.data.check_deserialization(checker)?; - self.queries.check_deserialization(checker)?; - Ok(()) + } } } ``` @@ -409,8 +405,8 @@ To implement benchmarks, we register them with the `diskann_benchmark_runner::Re The simplest thing we can do is something like this: ```rust use diskann_benchmark_runner::{ - dispatcher::{MatchScore, FailureScore}, - Any, Benchmark, Checkpoint, Output, Registry, + benchmark::{MatchScore, FailureScore}, + Benchmark, Checkpoint, Output, }; // Benchmarks can be stateful. @@ -429,6 +425,15 @@ impl Benchmark for RunGroundTruth { Ok(MatchScore::new(0)) } + // Describe the benchmark for CLI display and debugging. + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + _input: Option<&Self::Input>, + ) -> std::fmt::Result { + write!(f, "compute groundtruth") + } + // Run the benchmark (for this example, nothing happens). fn run( &self, diff --git a/diskann-benchmark/src/backend/disk_index/benchmarks.rs b/diskann-benchmark/src/backend/disk_index/benchmarks.rs index f41e39346..c81022b97 100644 --- a/diskann-benchmark/src/backend/disk_index/benchmarks.rs +++ b/diskann-benchmark/src/backend/disk_index/benchmarks.rs @@ -15,7 +15,7 @@ use diskann_benchmark_runner::{ fmt::Table, num::{relative_change, NonNegativeFinite}, }, - Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Registry, + Benchmark, Checker, Checkpoint, Input, Registry, }; use diskann_providers::storage::FileStorageProvider; use half::f16; @@ -165,25 +165,22 @@ impl DiskIndexTolerance { } } -impl CheckDeserialization for DiskIndexTolerance { - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { - Ok(()) - } -} - impl Input for DiskIndexTolerance { + type Raw = Self; + fn tag() -> &'static str { Self::tag() } - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(Self::deserialize(serialized)?) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + Ok(raw) } - fn example() -> anyhow::Result { + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) + } + + fn example() -> Self { const DEFAULT: NonNegativeFinite = match NonNegativeFinite::new(0.10) { Ok(v) => v, Err(_) => panic!("use a non-negative finite value"), @@ -193,7 +190,7 @@ impl Input for DiskIndexTolerance { Err(_) => panic!("use a non-negative finite value"), }; - Ok(serde_json::to_value(DiskIndexTolerance { + DiskIndexTolerance { build_time_regression: DEFAULT, qps_regression: DEFAULT, recall_regression: RECALL, @@ -201,7 +198,7 @@ impl Input for DiskIndexTolerance { mean_comps_regression: DEFAULT, mean_latency_regression: DEFAULT, p95_latency_regression: DEFAULT, - })?) + } } } diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index 00f6067d4..473d7982b 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -6,9 +6,7 @@ use std::{fmt, num::NonZeroUsize, path::Path}; use anyhow::Context; -use diskann_benchmark_runner::{ - files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, -}; +use diskann_benchmark_runner::{files::InputFile, utils::datatype::DataType, Checker}; #[cfg(feature = "disk-index")] use diskann_disk::QuantizationType; use diskann_providers::storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file}; @@ -88,29 +86,19 @@ impl DiskIndexOperation { pub(crate) const fn tag() -> &'static str { "disk-index" } -} -/////////////////////////// -// Check Deserialization // -/////////////////////////// - -impl CheckDeserialization for DiskIndexOperation { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // validate the source + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { match &mut self.source { - DiskIndexSource::Load(load) => load.check_deserialization(checker)?, - DiskIndexSource::Build(build) => build.check_deserialization(checker)?, + DiskIndexSource::Load(load) => load.validate(checker)?, + DiskIndexSource::Build(build) => build.validate(checker)?, } - - // validate the search phase - self.search_phase.check_deserialization(checker)?; - + self.search_phase.validate(checker)?; Ok(()) } } -impl CheckDeserialization for DiskIndexLoad { - fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { +impl DiskIndexLoad { + pub(crate) fn validate(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { let files = [ (get_pq_pivot_file(&self.load_path), "pq pivot file"), ( @@ -131,12 +119,9 @@ impl CheckDeserialization for DiskIndexLoad { } } -impl CheckDeserialization for DiskIndexBuild { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // file input - self.data - .check_deserialization(checker) - .context("invalid data file")?; +impl DiskIndexBuild { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.data.resolve(checker).context("invalid data file")?; // basic constraints if self.dim == 0 { @@ -183,18 +168,16 @@ impl CheckDeserialization for DiskIndexBuild { } } -impl CheckDeserialization for DiskSearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // inputs +impl DiskSearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { self.queries - .check_deserialization(checker) + .resolve(checker) .context("invalid queries file")?; self.groundtruth - .check_deserialization(checker) + .resolve(checker) .context("invalid groundtruth file")?; if let Some(vf) = self.vector_filters_file.as_mut() { - vf.check_deserialization(checker) - .context("invalid vector_filters_file")?; + vf.resolve(checker).context("invalid vector_filters_file")?; } // basic numeric sanity checks diff --git a/diskann-benchmark/src/inputs/exhaustive.rs b/diskann-benchmark/src/inputs/exhaustive.rs index d73bc1491..20583de85 100644 --- a/diskann-benchmark/src/inputs/exhaustive.rs +++ b/diskann-benchmark/src/inputs/exhaustive.rs @@ -6,9 +6,7 @@ use std::num::NonZeroUsize; use anyhow::{anyhow, Context}; -use diskann_benchmark_runner::{ - files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, -}; +use diskann_benchmark_runner::{files::InputFile, utils::datatype::DataType, Checker}; use serde::{Deserialize, Serialize}; use crate::{ @@ -41,8 +39,8 @@ pub(crate) struct SearchValues { pub(crate) recall_n: Vec, } -impl CheckDeserialization for SearchValues { - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { +impl SearchValues { + pub(crate) fn validate(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { // Ensure that both `recall_k` and `recall_n` are non-empty. if self.recall_k.is_empty() { return Err(anyhow!("field `recall_k` cannot be empty")); @@ -96,12 +94,11 @@ pub(crate) struct SearchPhase { pub(crate) recalls: SearchValues, } -impl CheckDeserialization for SearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. - self.queries.check_deserialization(checker)?; - self.groundtruth.check_deserialization(checker)?; - self.recalls.check_deserialization(checker)?; +impl SearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.resolve(checker)?; + self.groundtruth.resolve(checker)?; + self.recalls.validate(checker)?; Ok(()) } } @@ -219,14 +216,12 @@ impl Product { pub(crate) const fn tag() -> &'static str { "exhaustive-product-quantization" } -} -impl CheckDeserialization for Product { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.data.check_deserialization(checker)?; - self.search.check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.data.resolve(checker)?; + self.search.validate(checker)?; - // Chcck that provided data type is compatible with `f32`. + // Check that provided data type is compatible with `f32`. f32::check_converting_load(self.data_type)?; let num_centers = self.num_pq_centers.get(); @@ -368,8 +363,8 @@ impl std::fmt::Display for PreScale { } } -impl CheckDeserialization for PreScale { - fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { +impl PreScale { + pub(crate) fn validate(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { if let Self::Some(v) = self { if *v <= 0.0 { anyhow::bail!("pre-scaling {} must be positive", v); @@ -401,12 +396,10 @@ impl Spherical { pub(crate) const fn tag() -> &'static str { "exhaustive-spherical-quantization" } -} -impl CheckDeserialization for Spherical { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.data.check_deserialization(checker)?; - self.search.check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.data.resolve(checker)?; + self.search.validate(checker)?; // Chcck that provided data type is compatible with `f32`. f32::check_converting_load(self.data_type)?; @@ -422,7 +415,7 @@ impl CheckDeserialization for Spherical { })?; } - self.pre_scale.check_deserialization(checker)?; + self.pre_scale.validate(checker)?; Ok(()) } } @@ -504,12 +497,10 @@ impl MinMax { pub(crate) const fn tag() -> &'static str { "exhaustive-minmax-quantization" } -} -impl CheckDeserialization for MinMax { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.data.check_deserialization(checker)?; - self.search.check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.data.resolve(checker)?; + self.search.validate(checker)?; // Chcck that provided data type is compatible with `f32`. f32::check_converting_load(self.data_type)?; diff --git a/diskann-benchmark/src/inputs/filters.rs b/diskann-benchmark/src/inputs/filters.rs index 09fdaf919..942c6da12 100644 --- a/diskann-benchmark/src/inputs/filters.rs +++ b/diskann-benchmark/src/inputs/filters.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use diskann_benchmark_runner::{files::InputFile, CheckDeserialization, Checker}; +use diskann_benchmark_runner::{files::InputFile, Checker}; use serde::{Deserialize, Serialize}; use crate::inputs::{as_input, Example}; @@ -55,17 +55,10 @@ impl MetadataIndexBuild { pub(crate) const fn tag() -> &'static str { "metadata-index-build" } -} -impl CheckDeserialization for MetadataIndexBuild { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Validate filter parameters (which include the paths to queries and label files) - self.filter_params - .data_labels - .check_deserialization(checker)?; - self.filter_params - .query_predicates - .check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.filter_params.data_labels.resolve(checker)?; + self.filter_params.query_predicates.resolve(checker)?; Ok(()) } } diff --git a/diskann-benchmark/src/inputs/graph_index.rs b/diskann-benchmark/src/inputs/graph_index.rs index 95cb89484..9df194382 100644 --- a/diskann-benchmark/src/inputs/graph_index.rs +++ b/diskann-benchmark/src/inputs/graph_index.rs @@ -11,9 +11,7 @@ use diskann::{ utils::IntoUsize, }; use diskann_benchmark_core::streaming::executors::bigann; -use diskann_benchmark_runner::{ - files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, -}; +use diskann_benchmark_runner::{files::InputFile, utils::datatype::DataType, Checker}; use diskann_providers::{ model::{ configuration::IndexConfiguration, @@ -50,8 +48,8 @@ pub(crate) struct GraphSearch { pub(crate) recall_k: usize, } -impl CheckDeserialization for GraphSearch { - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { +impl GraphSearch { + pub(crate) fn validate(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { for (i, l) in self.search_l.iter().enumerate() { if *l < self.search_n { return Err(anyhow!( @@ -97,9 +95,8 @@ impl GraphRangeSearch { } } -impl CheckDeserialization for GraphRangeSearch { - // all necessary checks are carried out when Range is initialized - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { +impl GraphRangeSearch { + pub(crate) fn validate(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { self.construct_params() .context("invalid range search params")?; @@ -117,14 +114,12 @@ pub(crate) struct TopkSearchPhase { pub(crate) runs: Vec, } -impl CheckDeserialization for TopkSearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. - self.queries.check_deserialization(checker)?; - - self.groundtruth.check_deserialization(checker)?; +impl TopkSearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.resolve(checker)?; + self.groundtruth.resolve(checker)?; for (i, run) in self.runs.iter_mut().enumerate() { - run.check_deserialization(checker) + run.validate(checker) .with_context(|| format!("search run {}", i))?; } @@ -169,14 +164,12 @@ pub(crate) struct RangeSearchPhase { pub(crate) runs: Vec, } -impl CheckDeserialization for RangeSearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. - self.queries.check_deserialization(checker)?; - - self.groundtruth.check_deserialization(checker)?; +impl RangeSearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.resolve(checker)?; + self.groundtruth.resolve(checker)?; for (i, run) in self.runs.iter_mut().enumerate() { - run.check_deserialization(checker) + run.validate(checker) .with_context(|| format!("search run {}", i))?; } @@ -197,13 +190,11 @@ pub(crate) struct BetaSearchPhase { pub(crate) runs: Vec, } -impl CheckDeserialization for BetaSearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. - self.queries.check_deserialization(checker)?; - - self.query_predicates.check_deserialization(checker)?; - self.data_labels.check_deserialization(checker)?; +impl BetaSearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.resolve(checker)?; + self.query_predicates.resolve(checker)?; + self.data_labels.resolve(checker)?; if self.beta <= 0.0 || self.beta > 1.0 { return Err(anyhow::anyhow!( @@ -212,9 +203,9 @@ impl CheckDeserialization for BetaSearchPhase { )); } - self.groundtruth.check_deserialization(checker)?; + self.groundtruth.resolve(checker)?; for (i, run) in self.runs.iter_mut().enumerate() { - run.check_deserialization(checker) + run.validate(checker) .with_context(|| format!("search run {}", i))?; } @@ -234,17 +225,14 @@ pub(crate) struct MultiHopSearchPhase { pub(crate) runs: Vec, } -impl CheckDeserialization for MultiHopSearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. - self.queries.check_deserialization(checker)?; - - self.query_predicates.check_deserialization(checker)?; - self.data_labels.check_deserialization(checker)?; - - self.groundtruth.check_deserialization(checker)?; +impl MultiHopSearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.resolve(checker)?; + self.query_predicates.resolve(checker)?; + self.data_labels.resolve(checker)?; + self.groundtruth.resolve(checker)?; for (i, run) in self.runs.iter_mut().enumerate() { - run.check_deserialization(checker) + run.validate(checker) .with_context(|| format!("search run {}", i))?; } @@ -394,13 +382,13 @@ impl SearchPhase { } } -impl CheckDeserialization for SearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { +impl SearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { match self { - SearchPhase::Topk(phase) => phase.check_deserialization(checker), - SearchPhase::Range(phase) => phase.check_deserialization(checker), - SearchPhase::TopkBetaFilter(phase) => phase.check_deserialization(checker), - SearchPhase::TopkMultihopFilter(phase) => phase.check_deserialization(checker), + SearchPhase::Topk(phase) => phase.validate(checker), + SearchPhase::Range(phase) => phase.validate(checker), + SearchPhase::TopkBetaFilter(phase) => phase.validate(checker), + SearchPhase::TopkMultihopFilter(phase) => phase.validate(checker), } } } @@ -482,10 +470,8 @@ impl IndexLoad { write_field!(f, "Load Path", self.load_path)?; Ok(()) } -} -impl CheckDeserialization for IndexLoad { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { // Check if the file exists (allowing for relative paths with respect to the current // directory. // @@ -652,12 +638,9 @@ impl IndexBuild { } Ok(()) } -} -impl CheckDeserialization for IndexBuild { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. - self.data.check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.data.resolve(checker)?; // We allow overwriting of already existing save paths, since users like to do this // The save path must either (1) be an absolute path, in which case we check that its parent directory exists @@ -720,18 +703,14 @@ impl IndexSource { IndexSource::Build(build) => &build.data_type, } } -} -impl CheckDeserialization for IndexSource { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { match self { - IndexSource::Load(load) => load.check_deserialization(checker), - IndexSource::Build(build) => build.check_deserialization(checker), + IndexSource::Load(load) => load.validate(checker), + IndexSource::Build(build) => build.validate(checker), } } -} -impl IndexSource { fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { IndexSource::Load(load) => load.summarize_fields(f), @@ -750,13 +729,10 @@ impl IndexOperation { pub(crate) const fn tag() -> &'static str { "graph-index-build" } -} -impl CheckDeserialization for IndexOperation { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. - self.source.check_deserialization(checker)?; - self.search_phase.check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.source.validate(checker)?; + self.search_phase.validate(checker)?; Ok(()) } @@ -832,11 +808,9 @@ impl IndexPQOperation { IndexSource::Build(b) => Ok(b.inmem_parameters(num_points, dim)), } } -} -impl CheckDeserialization for IndexPQOperation { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.index_operation.check_deserialization(checker) + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.index_operation.validate(checker) } } @@ -911,10 +885,8 @@ impl IndexSQOperation { IndexSource::Build(b) => Ok(b.inmem_parameters(num_points, dim)), } } -} -impl CheckDeserialization for IndexSQOperation { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { if self.standard_deviations <= 0.0 { return Err(anyhow::anyhow!( "scalar quantization standard deviations ({}) must be strictly positive", @@ -922,7 +894,7 @@ impl CheckDeserialization for IndexSQOperation { )); } - self.index_operation.check_deserialization(checker) + self.index_operation.validate(checker) } } @@ -994,12 +966,10 @@ impl SphericalQuantBuild { ) -> DefaultProviderParameters { self.build.inmem_parameters(num_points, dim) } -} -impl CheckDeserialization for SphericalQuantBuild { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.build.check_deserialization(checker)?; - self.search_phase.check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.build.validate(checker)?; + self.search_phase.validate(checker)?; if self.build.save_path.is_some() { return Err(anyhow::anyhow!( @@ -1021,7 +991,7 @@ impl CheckDeserialization for SphericalQuantBuild { } if let Some(pre_scale) = &mut self.pre_scale { - pre_scale.check_deserialization(checker)?; + pre_scale.validate(checker)?; } Ok(()) @@ -1125,9 +1095,9 @@ pub(crate) struct DynamicRunbookParams { // 1. The runbook file can be parsed // 2. The dataset_name exists in the runbook // 3. All required ground truth files exist in gt_directory -impl CheckDeserialization for DynamicRunbookParams { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.runbook_path.check_deserialization(checker)?; +impl DynamicRunbookParams { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.runbook_path.resolve(checker)?; // Validate consolidate_threshold is greater than 0 if self.consolidate_threshold <= 0.0 { @@ -1256,6 +1226,13 @@ impl DynamicIndexRun { "graph-index-dynamic-run" } + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.build.validate(checker)?; + self.runbook_params.validate(checker)?; + self.search_phase.validate(checker)?; + Ok(()) + } + pub(crate) fn try_as_config(&self, insert_l: usize) -> anyhow::Result { let mut builder = self.build.try_as_config()?; builder.l_build(insert_l); @@ -1271,15 +1248,6 @@ impl DynamicIndexRun { } } -impl CheckDeserialization for DynamicIndexRun { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.build.check_deserialization(checker)?; - self.runbook_params.check_deserialization(checker)?; - self.search_phase.check_deserialization(checker)?; - Ok(()) - } -} - impl Example for DynamicIndexRun { fn example() -> Self { let build = IndexBuild::example(); diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index 7875beb1d..492f0b9c1 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -14,29 +14,38 @@ pub(crate) trait Example { fn example() -> Self; } -// NOTE: The input registration and dispatching isn't prefect. It uses a pattern (like -// the use of `'static` on the benchmark types) as a byproduct of older ways of doing -// benchmark selection. -// -// In the future, these can be migrated to reduce this legacy cruft. +/// Implement [`diskann_benchmark_runner::Input`] for `$T` using `Raw = $T`. +/// +/// Requires `$T` to: +/// - implement [`Example`]; +/// - provide an inherent `fn tag() -> &'static str` method; +/// - provide a +/// `fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()>` method; and +/// - implement the serde traits required by +/// [`diskann_benchmark_runner::Input`] and `serde_json::to_value(self)`. macro_rules! as_input { ($T:ty) => { impl diskann_benchmark_runner::Input for $T { + type Raw = $T; + fn tag() -> &'static str { <$T>::tag() } - fn try_deserialize( - serialized: &serde_json::Value, + fn from_raw( + mut raw: Self::Raw, checker: &mut diskann_benchmark_runner::Checker, - ) -> anyhow::Result { - checker.any(<$T as serde::Deserialize>::deserialize(serialized)?) + ) -> anyhow::Result { + raw.validate(checker)?; + Ok(raw) + } + + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) } - fn example() -> anyhow::Result { - Ok(serde_json::to_value( - <$T as $crate::inputs::Example>::example(), - )?) + fn example() -> Self { + <$T as $crate::inputs::Example>::example() } } };