Skip to content

Commit 4261d8d

Browse files
committed
Address reworks in geo_filters work
- Pass in RNG to helper functions - Add iterations to test harness
1 parent bd87a25 commit 4261d8d

5 files changed

Lines changed: 141 additions & 124 deletions

File tree

crates/geo_filters/src/config.rs

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -353,30 +353,31 @@ pub(crate) fn take_ref<I: Iterator>(iter: &mut I, n: usize) -> impl Iterator<Ite
353353

354354
#[cfg(test)]
355355
pub(crate) mod tests {
356-
use rand::RngCore;
356+
use rand::{rngs::StdRng, RngCore};
357357

358-
use crate::{test_rng::prng_test_harness, Count, Method};
358+
use crate::{Count, Method};
359359

360360
/// Runs estimation trials and returns the average precision and variance.
361-
pub(crate) fn test_estimate<M: Method, C: Count<M>>(f: impl Fn() -> C) -> (f32, f32) {
362-
prng_test_harness(|rnd| {
363-
let cnt = 10000usize;
364-
let mut avg_precision = 0.0;
365-
let mut avg_var = 0.0;
366-
let trials = 500;
367-
for _ in 0..trials {
368-
let mut m = f();
369-
// Insert cnt many random items.
370-
for _ in 0..cnt {
371-
m.push_hash(rnd.next_u64());
372-
}
373-
// Compute the relative error between estimate and actually inserted items.
374-
let high_precision = m.size() / cnt as f32 - 1.0;
375-
// Take the average over trials many attempts.
376-
avg_precision += high_precision / trials as f32;
377-
avg_var += high_precision.powf(2.0) / trials as f32;
361+
pub(crate) fn test_estimate<M: Method, C: Count<M>>(
362+
rnd: &mut StdRng,
363+
f: impl Fn() -> C,
364+
) -> (f32, f32) {
365+
let cnt = 10000usize;
366+
let mut avg_precision = 0.0;
367+
let mut avg_var = 0.0;
368+
let trials = 500;
369+
for _ in 0..trials {
370+
let mut m = f();
371+
// Insert cnt many random items.
372+
for _ in 0..cnt {
373+
m.push_hash(rnd.next_u64());
378374
}
379-
(avg_precision, avg_var)
380-
})
375+
// Compute the relative error between estimate and actually inserted items.
376+
let high_precision = m.size() / cnt as f32 - 1.0;
377+
// Take the average over trials many attempts.
378+
avg_precision += high_precision / trials as f32;
379+
avg_var += high_precision.powf(2.0) / trials as f32;
380+
}
381+
(avg_precision, avg_var)
381382
}
382383
}

crates/geo_filters/src/config/lookup.rs

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ impl HashToBucketLookup {
4545

4646
#[cfg(test)]
4747
mod tests {
48-
use rand::RngCore;
48+
use rand::{rngs::StdRng, RngCore};
4949

5050
use crate::{
5151
config::{hash_to_bucket, phi_f64},
@@ -56,30 +56,32 @@ mod tests {
5656

5757
#[test]
5858
fn test_lookup_7() {
59-
let var = lookup_random_hashes_variance::<7>(1 << 16);
60-
assert!(var < 1e-4, "variance {var} too large");
59+
prng_test_harness(1, |rnd| {
60+
let var = lookup_random_hashes_variance::<7>(rnd, 1 << 16);
61+
assert!(var < 1e-4, "variance {var} too large");
62+
});
6163
}
6264

6365
#[test]
6466
fn test_lookup_13() {
65-
let var = lookup_random_hashes_variance::<13>(1 << 16);
66-
assert!(var < 1e-4, "variance {var} too large");
67+
prng_test_harness(1, |rnd| {
68+
let var = lookup_random_hashes_variance::<13>(rnd, 1 << 16);
69+
assert!(var < 1e-4, "variance {var} too large");
70+
});
6771
}
6872

69-
fn lookup_random_hashes_variance<const B: usize>(n: u64) -> f64 {
73+
fn lookup_random_hashes_variance<const B: usize>(rnd: &mut StdRng, n: u64) -> f64 {
7074
let phi = phi_f64(B);
7175
let buckets = HashToBucketLookup::new(B);
7276

73-
prng_test_harness(|rnd| {
74-
let mut var = 0.0;
75-
for _ in 0..n {
76-
let hash = rnd.next_u64();
77-
let estimate = buckets.lookup(hash) as f64;
78-
let real = hash_to_bucket(phi, hash) as f64;
79-
let err = estimate - real; // assume the mean = 0.0
80-
var += err.powf(2.0) / n as f64;
81-
}
82-
var
83-
})
77+
let mut var = 0.0;
78+
for _ in 0..n {
79+
let hash = rnd.next_u64();
80+
let estimate = buckets.lookup(hash) as f64;
81+
let real = hash_to_bucket(phi, hash) as f64;
82+
let err = estimate - real; // assume the mean = 0.0
83+
var += err.powf(2.0) / n as f64;
84+
}
85+
var
8486
}
8587
}

crates/geo_filters/src/diff_count.rs

Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -432,35 +432,39 @@ mod tests {
432432

433433
#[test]
434434
fn test_estimate_fast() {
435-
let (avg_precision, avg_var) = test_estimate(GeoDiffCount7::default);
436-
println!(
437-
"avg precision: {} with standard deviation: {}",
438-
avg_precision,
439-
avg_var.sqrt(),
440-
);
441-
// Make sure that the estimate converges to the correct value.
442-
assert!(avg_precision.abs() < 0.04);
443-
// We should theoretically achieve a standard deviation of about 0.12
444-
assert!(avg_var.sqrt() < 0.14);
435+
prng_test_harness(1, |rnd| {
436+
let (avg_precision, avg_var) = test_estimate(rnd, GeoDiffCount7::default);
437+
println!(
438+
"avg precision: {} with standard deviation: {}",
439+
avg_precision,
440+
avg_var.sqrt(),
441+
);
442+
// Make sure that the estimate converges to the correct value.
443+
assert!(avg_precision.abs() < 0.04);
444+
// We should theoretically achieve a standard deviation of about 0.12
445+
assert!(avg_var.sqrt() < 0.14);
446+
})
445447
}
446448

447449
#[test]
448450
fn test_estimate_fast_low_precision() {
449-
let (avg_precision, avg_var) = test_estimate(GeoDiffCount7_50::default);
450-
println!(
451-
"avg precision: {} with standard deviation: {}",
452-
avg_precision,
453-
avg_var.sqrt(),
454-
);
455-
// Make sure that the estimate converges to the correct value.
456-
assert!(avg_precision.abs() < 0.15);
457-
// We should theoretically achieve a standard deviation of about 0.25
458-
assert!(avg_var.sqrt() < 0.4);
451+
prng_test_harness(1, |rnd| {
452+
let (avg_precision, avg_var) = test_estimate(rnd, GeoDiffCount7_50::default);
453+
println!(
454+
"avg precision: {} with standard deviation: {}",
455+
avg_precision,
456+
avg_var.sqrt(),
457+
);
458+
// Make sure that the estimate converges to the correct value.
459+
assert!(avg_precision.abs() < 0.15);
460+
// We should theoretically achieve a standard deviation of about 0.25
461+
assert!(avg_var.sqrt() < 0.4);
462+
});
459463
}
460464

461465
#[test]
462466
fn test_estimate_diff_size_fast() {
463-
prng_test_harness(|rnd| {
467+
prng_test_harness(1, |rnd| {
464468
let mut a_p = GeoDiffCount7_50::default();
465469
let mut a_hp = GeoDiffCount7::default();
466470
let mut b_p = GeoDiffCount7_50::default();
@@ -515,45 +519,41 @@ mod tests {
515519

516520
#[test]
517521
fn test_xor_plus_mask() {
518-
for _ in 0..1000 {
519-
prng_test_harness(|rnd| {
520-
let mask_size = 12;
521-
let mask = 0b100001100000;
522-
let mut a = GeoDiffCount7::default();
523-
for _ in 0..10000 {
524-
a.xor_bit(a.config.hash_to_bucket(rnd.next_u64()));
525-
}
526-
let mut expected = GeoDiffCount7::default();
527-
let mut b = a.clone();
528-
for _ in 0..1000 {
529-
let hash = rnd.next_u64();
530-
b.xor_bit(b.config.hash_to_bucket(hash));
531-
expected.xor_bit(expected.config.hash_to_bucket(hash));
532-
assert_eq!(expected, xor(&a, &b));
533-
let masked_a = masked(&a, mask, mask_size);
534-
let masked_b = masked(&b, mask, mask_size);
535-
let masked_expected = masked(&expected, mask, mask_size);
536-
assert_eq!(masked_expected, xor(&masked_a, &masked_b));
537-
}
538-
});
539-
}
522+
prng_test_harness(1000, |rnd| {
523+
let mask_size = 12;
524+
let mask = 0b100001100000;
525+
let mut a = GeoDiffCount7::default();
526+
for _ in 0..10000 {
527+
a.xor_bit(a.config.hash_to_bucket(rnd.next_u64()));
528+
}
529+
let mut expected = GeoDiffCount7::default();
530+
let mut b = a.clone();
531+
for _ in 0..1000 {
532+
let hash = rnd.next_u64();
533+
b.xor_bit(b.config.hash_to_bucket(hash));
534+
expected.xor_bit(expected.config.hash_to_bucket(hash));
535+
assert_eq!(expected, xor(&a, &b));
536+
let masked_a = masked(&a, mask, mask_size);
537+
let masked_b = masked(&b, mask, mask_size);
538+
let masked_expected = masked(&expected, mask, mask_size);
539+
assert_eq!(masked_expected, xor(&masked_a, &masked_b));
540+
}
541+
});
540542
}
541543

542544
#[test]
543545
fn test_bit_chunks() {
544-
for _ in 0..100 {
545-
prng_test_harness(|rnd| {
546-
let mut expected = GeoDiffCount7::default();
547-
for _ in 0..1000 {
548-
expected.push_hash(rnd.next_u64());
549-
}
550-
let actual = GeoDiffCount::from_bit_chunks(
551-
expected.config.clone(),
552-
expected.bit_chunks().peekable(),
553-
);
554-
assert_eq!(expected, actual);
555-
});
556-
}
546+
prng_test_harness(100, |rnd| {
547+
let mut expected = GeoDiffCount7::default();
548+
for _ in 0..1000 {
549+
expected.push_hash(rnd.next_u64());
550+
}
551+
let actual = GeoDiffCount::from_bit_chunks(
552+
expected.config.clone(),
553+
expected.bit_chunks().peekable(),
554+
);
555+
assert_eq!(expected, actual);
556+
});
557557
}
558558

559559
#[test]

crates/geo_filters/src/distinct_count.rs

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -308,21 +308,23 @@ mod tests {
308308

309309
#[test]
310310
fn test_estimate_fast() {
311-
let (avg_precision, avg_var) = test_estimate(GeoDistinctCount7::default);
312-
println!(
313-
"avg precision: {} with standard deviation: {}",
314-
avg_precision,
315-
avg_var.sqrt(),
316-
);
317-
// Make sure that the estimate converges to the correct value.
318-
assert!(avg_precision.abs() < 0.04);
319-
// We should theoretically achieve a standard deviation of about 0.065
320-
assert!(avg_var.sqrt() < 0.08);
311+
prng_test_harness(1, |rnd| {
312+
let (avg_precision, avg_var) = test_estimate(rnd, GeoDistinctCount7::default);
313+
println!(
314+
"avg precision: {} with standard deviation: {}",
315+
avg_precision,
316+
avg_var.sqrt(),
317+
);
318+
// Make sure that the estimate converges to the correct value.
319+
assert!(avg_precision.abs() < 0.04);
320+
// We should theoretically achieve a standard deviation of about 0.065
321+
assert!(avg_var.sqrt() < 0.08);
322+
})
321323
}
322324

323325
#[test]
324326
fn test_estimate_union_size_fast() {
325-
prng_test_harness(|rnd| {
327+
prng_test_harness(1, |rnd| {
326328
let mut a = GeoDistinctCount7::default();
327329
let mut b = GeoDistinctCount7::default();
328330
for _ in 0..10000 {
@@ -394,19 +396,15 @@ mod tests {
394396

395397
#[test]
396398
fn test_bit_chunks() {
397-
for _ in 0..100 {
398-
prng_test_harness(|rnd| {
399-
let mut expected = GeoDistinctCount7::default();
400-
for _ in 0..1000 {
401-
expected.push_hash(rnd.next_u64());
402-
}
403-
let actual = GeoDistinctCount::from_bit_chunks(
404-
expected.config.clone(),
405-
expected.bit_chunks(),
406-
);
407-
assert_eq!(expected, actual);
408-
})
409-
}
399+
prng_test_harness(100, |rnd| {
400+
let mut expected = GeoDistinctCount7::default();
401+
for _ in 0..1000 {
402+
expected.push_hash(rnd.next_u64());
403+
}
404+
let actual =
405+
GeoDistinctCount::from_bit_chunks(expected.config.clone(), expected.bit_chunks());
406+
assert_eq!(expected, actual);
407+
})
410408
}
411409

412410
#[test]

crates/geo_filters/src/test_rng.rs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,31 @@ use rand::{rngs::StdRng, SeedableRng as _};
66
/// degree of randomization. If the test panics the harness will print the
77
/// seed used for that run. You can then pass in this seed using the `TEST_SEED`
88
/// environment variable when running your tests.
9-
pub fn prng_test_harness<F, T>(test_fn: F) -> T
9+
///
10+
/// You can provide a number of `iterations` this harness will run with randomly
11+
/// generated seeds. If a manual seed is provided via the environment then the test
12+
/// is only run once with this seed.
13+
pub fn prng_test_harness<F>(iterations: usize, mut test_fn: F)
1014
where
11-
F: Fn(&mut StdRng) -> T,
15+
F: FnMut(&mut StdRng) -> (),
1216
{
13-
let seed = std::env::var("TEST_SEED")
17+
let maybe_manual_seed = std::env::var("TEST_SEED")
1418
.map(|s| s.parse::<u64>().expect("Parse TEST_SEED to u64"))
15-
.unwrap_or_else(|_| rand::random());
16-
let mut rng = StdRng::seed_from_u64(seed);
17-
let maybe_panic = catch_unwind(AssertUnwindSafe(|| test_fn(&mut rng)));
19+
.ok();
20+
let mut seed = 0;
21+
let maybe_panic = catch_unwind(AssertUnwindSafe(|| {
22+
if let Some(manual_seed) = maybe_manual_seed {
23+
seed = manual_seed;
24+
let mut rng = StdRng::seed_from_u64(seed);
25+
test_fn(&mut rng);
26+
} else {
27+
for _ in 0..iterations {
28+
seed = rand::random();
29+
let mut rng = StdRng::seed_from_u64(seed);
30+
test_fn(&mut rng);
31+
}
32+
}
33+
}));
1834
match maybe_panic {
1935
Ok(t) => t,
2036
Err(panic_info) => {

0 commit comments

Comments
 (0)