Skip to content

Commit 14c5b08

Browse files
committed
Merge branch 'main' into sc-20250723-serialize
2 parents 0b63f37 + 08316ba commit 14c5b08

7 files changed

Lines changed: 204 additions & 157 deletions

File tree

crates/geo_filters/src/config.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,13 +353,15 @@ pub(crate) fn take_ref<I: Iterator>(iter: &mut I, n: usize) -> impl Iterator<Ite
353353

354354
#[cfg(test)]
355355
pub(crate) mod tests {
356-
use rand::{RngCore, SeedableRng};
356+
use rand::{rngs::StdRng, RngCore};
357357

358358
use crate::{Count, Method};
359359

360360
/// Runs estimation trials and returns the average precision and variance.
361-
pub(crate) fn test_estimate<M: Method, C: Count<M>>(f: impl Fn() -> C) -> (f32, f32) {
362-
let mut rnd = rand::rngs::StdRng::from_os_rng();
361+
pub(crate) fn test_estimate<M: Method, C: Count<M>>(
362+
rnd: &mut StdRng,
363+
f: impl Fn() -> C,
364+
) -> (f32, f32) {
363365
let cnt = 10000usize;
364366
let mut avg_precision = 0.0;
365367
let mut avg_var = 0.0;

crates/geo_filters/src/config/lookup.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,29 +45,36 @@ impl HashToBucketLookup {
4545

4646
#[cfg(test)]
4747
mod tests {
48-
use rand::{RngCore, SeedableRng};
48+
use rand::{rngs::StdRng, RngCore};
4949

50-
use crate::config::{hash_to_bucket, phi_f64};
50+
use crate::{
51+
config::{hash_to_bucket, phi_f64},
52+
test_rng::prng_test_harness,
53+
};
5154

5255
use super::HashToBucketLookup;
5356

5457
#[test]
5558
fn test_lookup_7() {
56-
let var = lookup_random_hashes_variance::<7>(1 << 16);
57-
assert!(var < 1e-4, "variance {var} too large");
59+
prng_test_harness(1, |rnd| {
60+
let var = lookup_random_hashes_variance::<7>(rnd, 1 << 16);
61+
assert!(var < 1e-4, "variance {var} too large");
62+
});
5863
}
5964

6065
#[test]
6166
fn test_lookup_13() {
62-
let var = lookup_random_hashes_variance::<13>(1 << 16);
63-
assert!(var < 1e-4, "variance {var} too large");
67+
prng_test_harness(1, |rnd| {
68+
let var = lookup_random_hashes_variance::<13>(rnd, 1 << 16);
69+
assert!(var < 1e-4, "variance {var} too large");
70+
});
6471
}
6572

66-
fn lookup_random_hashes_variance<const B: usize>(n: u64) -> f64 {
73+
fn lookup_random_hashes_variance<const B: usize>(rnd: &mut StdRng, n: u64) -> f64 {
6774
let phi = phi_f64(B);
6875
let buckets = HashToBucketLookup::new(B);
76+
6977
let mut var = 0.0;
70-
let mut rnd = rand::rngs::StdRng::from_os_rng();
7178
for _ in 0..n {
7279
let hash = rnd.next_u64();
7380
let estimate = buckets.lookup(hash) as f64;

crates/geo_filters/src/diff_count.rs

Lines changed: 109 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -96,21 +96,17 @@ impl<'a, C: GeoConfig<Diff>> GeoDiffCount<'a, C> {
9696
/// having to construct another iterator with the remaining `BitChunk`s.
9797
fn from_bit_chunks<I: Iterator<Item = BitChunk>>(config: C, chunks: I) -> Self {
9898
let mut ones = iter_ones::<C::BucketType, _>(chunks.peekable());
99-
10099
let mut msb = Vec::default();
101100
take_ref(&mut ones, config.max_msb_len() - 1).for_each(|bucket| {
102101
msb.push(bucket);
103102
});
104103
let smallest_msb = ones
105104
.next()
106-
.map(|bucket| {
107-
msb.push(bucket);
108-
bucket
105+
.inspect(|bucket| {
106+
msb.push(*bucket);
109107
})
110108
.unwrap_or_default();
111-
112109
let lsb = BitVec::from_bit_chunks(ones.into_bitchunks(), smallest_msb.into_usize());
113-
114110
let result = Self {
115111
config,
116112
msb: Cow::from(msb),
@@ -217,17 +213,16 @@ impl<'a, C: GeoConfig<Diff>> GeoDiffCount<'a, C> {
217213
match msb.binary_search_by(|k| bucket.cmp(k)) {
218214
Ok(idx) => {
219215
msb.remove(idx);
220-
let (first, second) = {
216+
let first = {
221217
let mut lsb = iter_ones(self.lsb.bit_chunks().peekable());
222-
(lsb.next(), lsb.next())
218+
lsb.next()
223219
};
224-
let new_smallest = if let Some(smallest) = first {
220+
if let Some(smallest) = first {
225221
msb.push(C::BucketType::from_usize(smallest));
226-
second.map(|_| smallest).unwrap_or(0)
222+
self.lsb.resize(smallest);
227223
} else {
228-
0
224+
self.lsb.resize(0);
229225
};
230-
self.lsb.resize(new_smallest);
231226
}
232227
Err(idx) => {
233228
msb.insert(idx, bucket);
@@ -245,6 +240,12 @@ impl<'a, C: GeoConfig<Diff>> GeoDiffCount<'a, C> {
245240
// ensure LSB bit vector has the space for `smallest`
246241
self.lsb.resize(new_smallest);
247242
self.lsb.toggle(smallest);
243+
} else if msb.len() == self.config.max_msb_len() {
244+
let smallest = msb
245+
.last()
246+
.expect("should have at least one element")
247+
.into_usize();
248+
self.lsb.resize(smallest);
248249
}
249250
}
250251
}
@@ -418,11 +419,12 @@ mod tests {
418419
use std::io::Write;
419420

420421
use itertools::Itertools;
421-
use rand::{seq::IteratorRandom, RngCore, SeedableRng};
422+
use rand::{rngs::StdRng, seq::IteratorRandom, RngCore};
422423

423424
use crate::{
424425
build_hasher::UnstableDefaultBuildHasher,
425426
config::{iter_ones, tests::test_estimate, FixedConfig},
427+
test_rng::prng_test_harness,
426428
};
427429

428430
use super::*;
@@ -493,57 +495,62 @@ mod tests {
493495

494496
#[test]
495497
fn test_estimate_fast() {
496-
let (avg_precision, avg_var) = test_estimate(GeoDiffCount7::default);
497-
println!(
498-
"avg precision: {} with standard deviation: {}",
499-
avg_precision,
500-
avg_var.sqrt(),
501-
);
502-
// Make sure that the estimate converges to the correct value.
503-
assert!(avg_precision.abs() < 0.04);
504-
// We should theoretically achieve a standard deviation of about 0.12
505-
assert!(avg_var.sqrt() < 0.14);
498+
prng_test_harness(1, |rnd| {
499+
let (avg_precision, avg_var) = test_estimate(rnd, GeoDiffCount7::default);
500+
println!(
501+
"avg precision: {} with standard deviation: {}",
502+
avg_precision,
503+
avg_var.sqrt(),
504+
);
505+
// Make sure that the estimate converges to the correct value.
506+
assert!(avg_precision.abs() < 0.04);
507+
// We should theoretically achieve a standard deviation of about 0.12
508+
assert!(avg_var.sqrt() < 0.14);
509+
})
506510
}
507511

508512
#[test]
509513
fn test_estimate_fast_low_precision() {
510-
let (avg_precision, avg_var) = test_estimate(GeoDiffCount7_50::default);
511-
println!(
512-
"avg precision: {} with standard deviation: {}",
513-
avg_precision,
514-
avg_var.sqrt(),
515-
);
516-
// Make sure that the estimate converges to the correct value.
517-
assert!(avg_precision.abs() < 0.15);
518-
// We should theoretically achieve a standard deviation of about 0.25
519-
assert!(avg_var.sqrt() < 0.4);
514+
prng_test_harness(1, |rnd| {
515+
let (avg_precision, avg_var) = test_estimate(rnd, GeoDiffCount7_50::default);
516+
println!(
517+
"avg precision: {} with standard deviation: {}",
518+
avg_precision,
519+
avg_var.sqrt(),
520+
);
521+
// Make sure that the estimate converges to the correct value.
522+
assert!(avg_precision.abs() < 0.15);
523+
// We should theoretically achieve a standard deviation of about 0.25
524+
assert!(avg_var.sqrt() < 0.4);
525+
});
520526
}
521527

522528
#[test]
523529
fn test_estimate_diff_size_fast() {
524-
let mut rnd = rand::rngs::StdRng::from_os_rng();
525-
let mut a_p = GeoDiffCount7_50::default();
526-
let mut a_hp = GeoDiffCount7::default();
527-
let mut b_p = GeoDiffCount7_50::default();
528-
let mut b_hp = GeoDiffCount7::default();
529-
for _ in 0..10000 {
530-
let hash = rnd.next_u64();
531-
a_p.push_hash(hash);
532-
a_hp.push_hash(hash);
533-
}
534-
for _ in 0..1000 {
535-
let hash = rnd.next_u64();
536-
b_p.push_hash(hash);
537-
b_hp.push_hash(hash);
538-
}
539-
let c_p = xor(&a_p, &b_p);
540-
let c_hp = xor(&a_hp, &b_hp);
530+
prng_test_harness(1, |rnd| {
531+
let mut a_p = GeoDiffCount7_50::default();
532+
let mut a_hp = GeoDiffCount7::default();
533+
let mut b_p = GeoDiffCount7_50::default();
534+
let mut b_hp = GeoDiffCount7::default();
535+
for _ in 0..10000 {
536+
let hash = rnd.next_u64();
537+
a_p.push_hash(hash);
538+
a_hp.push_hash(hash);
539+
}
540+
for _ in 0..1000 {
541+
let hash = rnd.next_u64();
542+
b_p.push_hash(hash);
543+
b_hp.push_hash(hash);
544+
}
545+
let c_p = xor(&a_p, &b_p);
546+
let c_hp = xor(&a_hp, &b_hp);
541547

542-
assert_eq!(c_p.size(), a_p.size_with_sketch(&b_p));
543-
assert_eq!(c_p.size(), b_p.size_with_sketch(&a_p));
548+
assert_eq!(c_p.size(), a_p.size_with_sketch(&b_p));
549+
assert_eq!(c_p.size(), b_p.size_with_sketch(&a_p));
544550

545-
assert_eq!(c_hp.size(), a_hp.size_with_sketch(&b_hp));
546-
assert_eq!(c_hp.size(), b_hp.size_with_sketch(&a_hp));
551+
assert_eq!(c_hp.size(), a_hp.size_with_sketch(&b_hp));
552+
assert_eq!(c_hp.size(), b_hp.size_with_sketch(&a_hp));
553+
});
547554
}
548555

549556
#[test]
@@ -575,45 +582,39 @@ mod tests {
575582

576583
#[test]
577584
fn test_xor_plus_mask() {
578-
let mut rnd = rand::rngs::StdRng::from_os_rng();
579-
let mask_size = 12;
580-
let mask = 0b100001100000;
581-
let mut a = GeoDiffCount7::default();
582-
for _ in 0..10000 {
583-
a.xor_bit(a.config.hash_to_bucket(rnd.next_u64()));
584-
}
585-
let mut expected = GeoDiffCount7::default();
586-
let mut b = a.clone();
587-
for _ in 0..1000 {
588-
let hash = rnd.next_u64();
589-
b.xor_bit(b.config.hash_to_bucket(hash));
590-
expected.xor_bit(expected.config.hash_to_bucket(hash));
591-
assert_eq!(expected, xor(&a, &b));
592-
593-
let masked_a = masked(&a, mask, mask_size);
594-
let masked_b = masked(&b, mask, mask_size);
595-
let masked_expected = masked(&expected, mask, mask_size);
596-
// FIXME: test failed once with:
597-
// left: ~12.37563 (msb: [390, 334, 263, 242, 222, 215, 164, 148, 100, 97, 66, 36], |lsb|: 36)
598-
// right: ~12.37563 (msb: [390, 334, 263, 242, 222, 215, 164, 148, 100, 97, 66, 36], |lsb|: 0)
599-
assert_eq!(masked_expected, xor(&masked_a, &masked_b));
600-
}
585+
prng_test_harness(10, |rnd| {
586+
let mask_size = 12;
587+
let mask = 0b100001100000;
588+
let mut a = GeoDiffCount7::default();
589+
for _ in 0..10000 {
590+
a.xor_bit(a.config.hash_to_bucket(rnd.next_u64()));
591+
}
592+
let mut expected = GeoDiffCount7::default();
593+
let mut b = a.clone();
594+
for _ in 0..1000 {
595+
let hash = rnd.next_u64();
596+
b.xor_bit(b.config.hash_to_bucket(hash));
597+
expected.xor_bit(expected.config.hash_to_bucket(hash));
598+
assert_eq!(expected, xor(&a, &b));
599+
let masked_a = masked(&a, mask, mask_size);
600+
let masked_b = masked(&b, mask, mask_size);
601+
let masked_expected = masked(&expected, mask, mask_size);
602+
assert_eq!(masked_expected, xor(&masked_a, &masked_b));
603+
}
604+
});
601605
}
602606

603607
#[test]
604608
fn test_bit_chunks() {
605-
let mut rnd = rand::rngs::StdRng::from_os_rng();
606-
for _ in 0..100 {
609+
prng_test_harness(100, |rnd| {
607610
let mut expected = GeoDiffCount7::default();
608611
for _ in 0..1000 {
609612
expected.push_hash(rnd.next_u64());
610613
}
611-
let actual = GeoDiffCount::from_bit_chunks(
612-
expected.config.clone(),
613-
expected.bit_chunks().peekable(),
614-
);
614+
let actual =
615+
GeoDiffCount::from_bit_chunks(expected.config.clone(), expected.bit_chunks());
615616
assert_eq!(expected, actual);
616-
}
617+
});
617618
}
618619

619620
#[test]
@@ -656,44 +657,44 @@ mod tests {
656657
// This helper exists in order to easily test serializing types with different
657658
// bucket types in the MSB sparse bit field representation. See tests below.
658659
#[cfg(target_endian = "little")]
659-
fn serialization_round_trip<C: GeoConfig<Diff> + Default>() {
660-
let mut rnd = rand::rngs::StdRng::from_os_rng();
660+
fn serialization_round_trip<C: GeoConfig<Diff> + Default>(rnd: &mut StdRng) {
661661
// Run 100 simulations of random values being put into
662662
// a diff counter. "Serializing" to a vector to emulate
663663
// writing to a disk, and then deserializing and asserting
664664
// the filters are equal.
665-
for _ in 0..100 {
666-
let mut before = GeoDiffCount::<'_, C>::default();
667-
// Select a random number of items to insert.
668-
let items = (1..1000).choose(&mut rnd).unwrap();
669-
for _ in 0..items {
670-
before.push_hash(rnd.next_u64());
671-
}
672-
let mut writer = vec![];
673-
// Insert some padding to emulate alignment issues with the slices.
674-
// A previous version of this test never panicked even though we were
675-
// violating the alignment preconditions for the `from_raw_parts` function.
676-
let padding = [0_u8; 8];
677-
let pad_amount = (0..8).choose(&mut rnd).unwrap();
678-
writer.write_all(&padding[..pad_amount]).unwrap();
679-
before.write(&mut writer).unwrap();
680-
let after =
681-
GeoDiffCount::<'_, C>::from_bytes(before.config.clone(), &writer[pad_amount..]);
682-
assert_eq!(before, after);
665+
let mut before = GeoDiffCount::<'_, C>::default();
666+
// Select a random number of items to insert.
667+
let items = (1..1000).choose(rnd).unwrap();
668+
for _ in 0..items {
669+
before.push_hash(rnd.next_u64());
683670
}
671+
let mut writer = vec![];
672+
// Insert some padding to emulate alignment issues with the slices.
673+
// A previous version of this test never panicked even though we were
674+
// violating the alignment preconditions for the `from_raw_parts` function.
675+
let padding = [0_u8; 8];
676+
let pad_amount = (0..8).choose(rnd).unwrap();
677+
writer.write_all(&padding[..pad_amount]).unwrap();
678+
before.write(&mut writer).unwrap();
679+
let after = GeoDiffCount::<'_, C>::from_bytes(before.config.clone(), &writer[pad_amount..]);
680+
assert_eq!(before, after);
684681
}
685682

686683
#[test]
687684
#[cfg(target_endian = "little")]
688685
fn test_serialization_round_trip_7() {
689-
// Uses a u16 for MSB buckets.
690-
serialization_round_trip::<GeoDiffConfig7>();
686+
prng_test_harness(100, |rnd| {
687+
// Uses a u16 for MSB buckets.
688+
serialization_round_trip::<GeoDiffConfig7>(rnd);
689+
});
691690
}
692691

693692
#[test]
694693
#[cfg(target_endian = "little")]
695694
fn test_serialization_round_trip_13() {
696-
// Uses a u32 for MSB buckets.
697-
serialization_round_trip::<GeoDiffConfig13>();
695+
prng_test_harness(100, |rnd| {
696+
// Uses a u32 for MSB buckets.
697+
serialization_round_trip::<GeoDiffConfig13>(rnd);
698+
});
698699
}
699700
}

0 commit comments

Comments
 (0)