Skip to content

Commit b3dc52e

Browse files
committed
Refactor geo_filters tests to use PRNG test harness
1 parent 1013140 commit b3dc52e

6 files changed

Lines changed: 111 additions & 104 deletions

File tree

crates/geo_filters/src/config.rs

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -353,29 +353,30 @@ pub(crate) fn take_ref<I: Iterator>(iter: &mut I, n: usize) -> impl Iterator<Ite
353353

354354
#[cfg(test)]
355355
pub(crate) mod tests {
356-
use rand::{RngCore, SeedableRng};
356+
use rand::RngCore;
357357

358-
use crate::{Count, Method};
358+
use crate::{test_rng::prng_test_harness, Count, Method};
359359

360360
/// Runs estimation trials and returns the average precision and variance.
361361
pub(crate) fn test_estimate<M: Method, C: Count<M>>(f: impl Fn() -> C) -> (f32, f32) {
362-
let mut rnd = rand::rngs::StdRng::from_os_rng();
363-
let cnt = 10000usize;
364-
let mut avg_precision = 0.0;
365-
let mut avg_var = 0.0;
366-
let trials = 500;
367-
for _ in 0..trials {
368-
let mut m = f();
369-
// Insert cnt many random items.
370-
for _ in 0..cnt {
371-
m.push_hash(rnd.next_u64());
362+
prng_test_harness(|rnd| {
363+
let cnt = 10000usize;
364+
let mut avg_precision = 0.0;
365+
let mut avg_var = 0.0;
366+
let trials = 500;
367+
for _ in 0..trials {
368+
let mut m = f();
369+
// Insert cnt many random items.
370+
for _ in 0..cnt {
371+
m.push_hash(rnd.next_u64());
372+
}
373+
// Compute the relative error between estimate and actually inserted items.
374+
let high_precision = m.size() / cnt as f32 - 1.0;
375+
// Take the average over trials many attempts.
376+
avg_precision += high_precision / trials as f32;
377+
avg_var += high_precision.powf(2.0) / trials as f32;
372378
}
373-
// Compute the relative error between estimate and actually inserted items.
374-
let high_precision = m.size() / cnt as f32 - 1.0;
375-
// Take the average over trials many attempts.
376-
avg_precision += high_precision / trials as f32;
377-
avg_var += high_precision.powf(2.0) / trials as f32;
378-
}
379-
(avg_precision, avg_var)
379+
(avg_precision, avg_var)
380+
})
380381
}
381382
}

crates/geo_filters/src/config/lookup.rs

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,12 @@ impl HashToBucketLookup {
4545

4646
#[cfg(test)]
4747
mod tests {
48-
use rand::{RngCore, SeedableRng};
48+
use rand::RngCore;
4949

50-
use crate::config::{hash_to_bucket, phi_f64};
50+
use crate::{
51+
config::{hash_to_bucket, phi_f64},
52+
test_rng::prng_test_harness,
53+
};
5154

5255
use super::HashToBucketLookup;
5356

@@ -66,15 +69,17 @@ mod tests {
6669
fn lookup_random_hashes_variance<const B: usize>(n: u64) -> f64 {
6770
let phi = phi_f64(B);
6871
let buckets = HashToBucketLookup::new(B);
69-
let mut var = 0.0;
70-
let mut rnd = rand::rngs::StdRng::from_os_rng();
71-
for _ in 0..n {
72-
let hash = rnd.next_u64();
73-
let estimate = buckets.lookup(hash) as f64;
74-
let real = hash_to_bucket(phi, hash) as f64;
75-
let err = estimate - real; // assume the mean = 0.0
76-
var += err.powf(2.0) / n as f64;
77-
}
78-
var
72+
73+
prng_test_harness(|rnd| {
74+
let mut var = 0.0;
75+
for _ in 0..n {
76+
let hash = rnd.next_u64();
77+
let estimate = buckets.lookup(hash) as f64;
78+
let real = hash_to_bucket(phi, hash) as f64;
79+
let err = estimate - real; // assume the mean = 0.0
80+
var += err.powf(2.0) / n as f64;
81+
}
82+
var
83+
})
7984
}
8085
}

crates/geo_filters/src/diff_count.rs

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ impl<C: GeoConfig<Diff>> Count<Diff> for GeoDiffCount<'_, C> {
358358
#[cfg(test)]
359359
mod tests {
360360
use itertools::Itertools;
361-
use rand::{RngCore, SeedableRng};
361+
use rand::RngCore;
362362

363363
use crate::{
364364
build_hasher::UnstableDefaultBuildHasher,
@@ -462,29 +462,30 @@ mod tests {
462462

463463
#[test]
464464
fn test_estimate_diff_size_fast() {
465-
let mut rnd = rand::rngs::StdRng::from_os_rng();
466-
let mut a_p = GeoDiffCount7_50::default();
467-
let mut a_hp = GeoDiffCount7::default();
468-
let mut b_p = GeoDiffCount7_50::default();
469-
let mut b_hp = GeoDiffCount7::default();
470-
for _ in 0..10000 {
471-
let hash = rnd.next_u64();
472-
a_p.push_hash(hash);
473-
a_hp.push_hash(hash);
474-
}
475-
for _ in 0..1000 {
476-
let hash = rnd.next_u64();
477-
b_p.push_hash(hash);
478-
b_hp.push_hash(hash);
479-
}
480-
let c_p = xor(&a_p, &b_p);
481-
let c_hp = xor(&a_hp, &b_hp);
465+
prng_test_harness(|rnd| {
466+
let mut a_p = GeoDiffCount7_50::default();
467+
let mut a_hp = GeoDiffCount7::default();
468+
let mut b_p = GeoDiffCount7_50::default();
469+
let mut b_hp = GeoDiffCount7::default();
470+
for _ in 0..10000 {
471+
let hash = rnd.next_u64();
472+
a_p.push_hash(hash);
473+
a_hp.push_hash(hash);
474+
}
475+
for _ in 0..1000 {
476+
let hash = rnd.next_u64();
477+
b_p.push_hash(hash);
478+
b_hp.push_hash(hash);
479+
}
480+
let c_p = xor(&a_p, &b_p);
481+
let c_hp = xor(&a_hp, &b_hp);
482482

483-
assert_eq!(c_p.size(), a_p.size_with_sketch(&b_p));
484-
assert_eq!(c_p.size(), b_p.size_with_sketch(&a_p));
483+
assert_eq!(c_p.size(), a_p.size_with_sketch(&b_p));
484+
assert_eq!(c_p.size(), b_p.size_with_sketch(&a_p));
485485

486-
assert_eq!(c_hp.size(), a_hp.size_with_sketch(&b_hp));
487-
assert_eq!(c_hp.size(), b_hp.size_with_sketch(&a_hp));
486+
assert_eq!(c_hp.size(), a_hp.size_with_sketch(&b_hp));
487+
assert_eq!(c_hp.size(), b_hp.size_with_sketch(&a_hp));
488+
});
488489
}
489490

490491
#[test]
@@ -517,24 +518,19 @@ mod tests {
517518
#[test]
518519
fn test_xor_plus_mask() {
519520
for _ in 0..1000 {
520-
prng_test_harness(|mut rng| {
521+
prng_test_harness(|rnd| {
521522
let mask_size = 12;
522523
let mask = 0b100001100000;
523524
let mut a = GeoDiffCount7::default();
524525
for _ in 0..10000 {
525-
a.xor_bit(a.config.hash_to_bucket(rng.next_u64()));
526+
a.xor_bit(a.config.hash_to_bucket(rnd.next_u64()));
526527
}
527528
let mut expected = GeoDiffCount7::default();
528529
let mut b = a.clone();
529530
for _ in 0..1000 {
530-
let hash = rng.next_u64();
531+
let hash = rnd.next_u64();
531532
b.xor_bit(b.config.hash_to_bucket(hash));
532533
expected.xor_bit(expected.config.hash_to_bucket(hash));
533-
534-
println!("a -> {:?}", a);
535-
println!("b -> {:?}", b);
536-
println!("exp -> {:?}", expected);
537-
538534
assert_eq!(expected, xor(&a, &b));
539535
let masked_a = masked(&a, mask, mask_size);
540536
let masked_b = masked(&b, mask, mask_size);
@@ -547,17 +543,18 @@ mod tests {
547543

548544
#[test]
549545
fn test_bit_chunks() {
550-
let mut rnd = rand::rngs::StdRng::from_os_rng();
551546
for _ in 0..100 {
552-
let mut expected = GeoDiffCount7::default();
553-
for _ in 0..1000 {
554-
expected.push_hash(rnd.next_u64());
555-
}
556-
let actual = GeoDiffCount::from_bit_chunks(
557-
expected.config.clone(),
558-
expected.bit_chunks().peekable(),
559-
);
560-
assert_eq!(expected, actual);
547+
prng_test_harness(|rnd| {
548+
let mut expected = GeoDiffCount7::default();
549+
for _ in 0..1000 {
550+
expected.push_hash(rnd.next_u64());
551+
}
552+
let actual = GeoDiffCount::from_bit_chunks(
553+
expected.config.clone(),
554+
expected.bit_chunks().peekable(),
555+
);
556+
assert_eq!(expected, actual);
557+
});
561558
}
562559
}
563560

crates/geo_filters/src/diff_count/bitvec.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,6 @@ impl BitVec<'_> {
101101
assert!(index < self.num_bits);
102102
let (block_idx, bit_idx) = index.into_index_and_bit();
103103
self.blocks.to_mut()[block_idx] ^= bit_idx.into_block();
104-
105-
if self.blocks.to_mut()[block_idx] & bit_idx.into_block() == 0 {}
106104
}
107105

108106
/// Returns an iterator over all blocks in reverse order.

crates/geo_filters/src/distinct_count.rs

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,12 @@ fn or<C: GeoConfig<Distinct>>(
232232
#[cfg(test)]
233233
mod tests {
234234
use itertools::Itertools;
235-
use rand::{RngCore, SeedableRng};
235+
use rand::RngCore;
236236

237237
use crate::build_hasher::UnstableDefaultBuildHasher;
238238
use crate::config::{iter_ones, tests::test_estimate, FixedConfig, VariableConfig};
239239
use crate::evaluation::simulation::simulate;
240+
use crate::test_rng::prng_test_harness;
240241

241242
use super::*;
242243

@@ -321,19 +322,20 @@ mod tests {
321322

322323
#[test]
323324
fn test_estimate_union_size_fast() {
324-
let mut rnd = rand::rngs::StdRng::from_os_rng();
325-
let mut a = GeoDistinctCount7::default();
326-
let mut b = GeoDistinctCount7::default();
327-
for _ in 0..10000 {
328-
a.push_hash(rnd.next_u64());
329-
}
330-
for _ in 0..1000 {
331-
b.push_hash(rnd.next_u64());
332-
}
333-
let c = or(&a, &b);
325+
prng_test_harness(|rnd| {
326+
let mut a = GeoDistinctCount7::default();
327+
let mut b = GeoDistinctCount7::default();
328+
for _ in 0..10000 {
329+
a.push_hash(rnd.next_u64());
330+
}
331+
for _ in 0..1000 {
332+
b.push_hash(rnd.next_u64());
333+
}
334+
let c = or(&a, &b);
334335

335-
assert_eq!(c.size(), a.size_with_sketch(&b));
336-
assert_eq!(c.size(), b.size_with_sketch(&a));
336+
assert_eq!(c.size(), a.size_with_sketch(&b));
337+
assert_eq!(c.size(), b.size_with_sketch(&a));
338+
})
337339
}
338340

339341
fn golden_section_min<F: Fn(f32) -> f32>(min: f32, max: f32, f: F) -> f32 {
@@ -392,15 +394,18 @@ mod tests {
392394

393395
#[test]
394396
fn test_bit_chunks() {
395-
let mut rnd = rand::rngs::StdRng::from_os_rng();
396397
for _ in 0..100 {
397-
let mut expected = GeoDistinctCount7::default();
398-
for _ in 0..1000 {
399-
expected.push_hash(rnd.next_u64());
400-
}
401-
let actual =
402-
GeoDistinctCount::from_bit_chunks(expected.config.clone(), expected.bit_chunks());
403-
assert_eq!(expected, actual);
398+
prng_test_harness(|rnd| {
399+
let mut expected = GeoDistinctCount7::default();
400+
for _ in 0..1000 {
401+
expected.push_hash(rnd.next_u64());
402+
}
403+
let actual = GeoDistinctCount::from_bit_chunks(
404+
expected.config.clone(),
405+
expected.bit_chunks(),
406+
);
407+
assert_eq!(expected, actual);
408+
})
404409
}
405410
}
406411

crates/geo_filters/src/test_rng.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
1-
use std::panic::{catch_unwind, resume_unwind, UnwindSafe};
1+
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
22

33
use rand::{rngs::StdRng, SeedableRng as _};
44

55
/// Provides a seeded random number generator to tests which require some
66
/// degree of randomization. If the test panics the harness will print the
77
/// seed used for that run. You can then pass in this seed using the `TEST_SEED`
88
/// environment variable when running your tests.
9-
pub fn prng_test_harness<F>(test_fn: F)
9+
pub fn prng_test_harness<F, T>(test_fn: F) -> T
1010
where
11-
F: Fn(StdRng) -> () + UnwindSafe,
11+
F: Fn(&mut StdRng) -> T,
1212
{
1313
let seed = std::env::var("TEST_SEED")
1414
.map(|s| s.parse::<u64>().expect("Parse TEST_SEED to u64"))
1515
.unwrap_or_else(|_| rand::random());
16-
let rng = StdRng::seed_from_u64(seed);
17-
let maybe_panic = catch_unwind(move || {
18-
test_fn(rng);
19-
});
20-
if let Err(panic_info) = maybe_panic {
21-
eprintln!("Test failed! Reproduce with: TEST_SEED={}", seed);
22-
resume_unwind(panic_info);
16+
let mut rng = StdRng::seed_from_u64(seed);
17+
let maybe_panic = catch_unwind(AssertUnwindSafe(|| test_fn(&mut rng)));
18+
match maybe_panic {
19+
Ok(t) => t,
20+
Err(panic_info) => {
21+
eprintln!("Test failed! Reproduce with: TEST_SEED={}", seed);
22+
resume_unwind(panic_info);
23+
}
2324
}
2425
}

0 commit comments

Comments
 (0)