|
2 | 2 | #include "gtest/gtest.h" |
3 | 3 | #include <cereal/archives/portable_binary.hpp> |
4 | 4 | #include <sstream> |
| 5 | +#include "../unit_test_common.hpp" |
5 | 6 |
|
6 | 7 | using namespace chimbuko; |
7 | 8 |
|
| 9 | +//Textbook definitions |
| 10 | +double mean(const std::vector<double> &a){ |
| 11 | + double r=0; |
| 12 | + for(double v:a) r+=v; |
| 13 | + return r/a.size(); |
| 14 | +} |
| 15 | +double variance(const std::vector<double> &a, bool incl_bessel = true){ |
| 16 | + double n = a.size(); |
| 17 | + double r=0,r2=0; |
| 18 | + for(double v:a){ r+=v; r2+=v*v; } |
| 19 | + r = (r2/n - r/n*r/n); |
| 20 | + if(incl_bessel) r *= n/(n-1); //include Bessel's correction by default |
| 21 | + return r; |
| 22 | +} |
| 23 | +double skewness(const std::vector<double> &a){ |
| 24 | + double mu = mean(a); |
| 25 | + double sigma = sqrt(variance(a,false)); |
| 26 | + double r = 0; |
| 27 | + for(double v:a) r += pow( (v - mu)/sigma, 3 ); |
| 28 | + return r/a.size(); |
| 29 | +} |
| 30 | +double kurtosis(const std::vector<double> &a){ //technically excess kurtosis |
| 31 | + double mu = mean(a); |
| 32 | + double sigma = sqrt(variance(a,false)); |
| 33 | + double r = 0; |
| 34 | + for(double v:a) r += pow( (v - mu)/sigma, 4 ); |
| 35 | + return r/a.size() - 3; |
| 36 | +} |
| 37 | + |
| 38 | + |
| 39 | + |
| 40 | +//Independent implementation using https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance |
| 41 | +struct statsTest{ |
| 42 | + double n; |
| 43 | + double mu; |
| 44 | + double M2; |
| 45 | + double M3; |
| 46 | + double M4; |
| 47 | + |
| 48 | + statsTest(): n(0), mu(0), M2(0), M3(0), M4(0){} |
| 49 | + statsTest(const std::vector<double> &v){ |
| 50 | + n = v.size(); |
| 51 | + mu = 0; |
| 52 | + M2 = 0; |
| 53 | + M3 = 0; |
| 54 | + M4 = 0; |
| 55 | + for(double e: v) |
| 56 | + mu += e; |
| 57 | + mu /= n; |
| 58 | + |
| 59 | + for(double e: v){ |
| 60 | + M2 += pow(e - mu,2); |
| 61 | + M3 += pow(e - mu,3); |
| 62 | + M4 += pow(e - mu,4); |
| 63 | + } |
| 64 | + |
| 65 | + } |
| 66 | + |
| 67 | + double variance() const{ //includes Bessel's correction |
| 68 | + return M2/(n-1.); |
| 69 | + } |
| 70 | + double mean() const{ |
| 71 | + return mu; |
| 72 | + } |
| 73 | + double skewness() const{ |
| 74 | + return M3/n/pow(M2/n,3./2.); |
| 75 | + } |
| 76 | + double kurtosis() const{ |
| 77 | + return M4/n/pow(M2/n,2) - 3; |
| 78 | + } |
| 79 | + |
| 80 | +}; |
| 81 | + |
| 82 | +statsTest operator+(const statsTest &a, const statsTest &b){ |
| 83 | + statsTest out; |
| 84 | + out.n = a.n + b.n; |
| 85 | + double delta = b.mu - a.mu; |
| 86 | + out.mu = a.mu + delta * b.n/out.n; |
| 87 | + out.M2 = a.M2 + b.M2 + pow(delta,2) * a.n * b.n / out.n; |
| 88 | + out.M3 = a.M3 + b.M3 + pow(delta,3) * a.n * b.n * (a.n - b.n) / pow(out.n,2) + 3*delta*(a.n*b.M2 - b.n*a.M2)/out.n; |
| 89 | + out.M4 = a.M4 + b.M4 + pow(delta,4) * a.n * b.n * (a.n*a.n - a.n*b.n + b.n*b.n)/pow(out.n,3) + 6*pow(delta,2) * (a.n*a.n*b.M2 + b.n*b.n*a.M2 ) / pow(out.n,2) |
| 90 | + + 4*delta*( a.n*b.M3 - b.n*a.M3 ) / out.n; |
| 91 | + |
| 92 | + return out; |
| 93 | +} |
| 94 | + |
| 95 | +bool compare(const statsTest &a, const statsTest &b, const double tol = 1e-12){ |
| 96 | + bool ret = true; |
| 97 | +#define COM(A) if(2.*fabs( a. A - b. A )/(a. A + b. A) > tol){ std::cout << #A << " " << a. A << " " << b. A << std::endl; ret = false; } |
| 98 | + COM(n); |
| 99 | + COM(mu); |
| 100 | + COM(M2); |
| 101 | + COM(M3); |
| 102 | + COM(M4); |
| 103 | + return ret; |
| 104 | +#undef COM |
| 105 | +} |
| 106 | + |
| 107 | + |
| 108 | +bool compare(const statsTest &a, const RunStats &b, const double tol = 1e-12){ |
| 109 | + const RunStats::State &sb = b.get_state(); |
| 110 | + |
| 111 | + // double eta; /**< mean */ |
| 112 | + // double rho; /**< = M2 = \sum_i (x_i - \bar x)^2 */ |
| 113 | + // double tau; /**< = M3 = \sum_i (x_i - \bar x)^3 */ |
| 114 | + // double phi; /**< = M4 = \sum_i (x_i - \bar x)^4 */ |
| 115 | + |
| 116 | + bool ret = true; |
| 117 | +#define COM(A,B) if( \ |
| 118 | + (fabs(a. A)<=tol && fabs(sb. B)>tol) || \ |
| 119 | + (fabs(sb. B)<=tol && fabs(a. A)>tol) || \ |
| 120 | + (fabs(a. A)>tol && fabs(sb. B) > tol && 2.*fabs( a. A - sb. B )/(a. A + sb. B) > tol) \ |
| 121 | + ){ std::cout << #A << " " << a. A << " " << #B << " " << sb. B << std::endl; ret = false; } |
| 122 | + COM(n,count); |
| 123 | + COM(mu,eta); |
| 124 | + COM(M2,rho); |
| 125 | + COM(M3,tau); |
| 126 | + COM(M4,phi); |
| 127 | + return ret; |
| 128 | +#undef COM |
| 129 | +} |
| 130 | + |
| 131 | + |
| 132 | +TEST(TestRunStats, TestIndependentImplementation){ |
| 133 | + //Test that summing two RunStats is the same as if the data were collected by a single RunStats instance |
| 134 | + std::vector<std::vector<double> > all_vals = { |
| 135 | + {160,150,140,122,103,77,33,22,19,7,1}, |
| 136 | + {77,33,22,19,7,1}, |
| 137 | + {77,33,22,19}, |
| 138 | + {-0.2, -0.5, 0.7, -0.4}, |
| 139 | + {3.14,6.28,9.99,10.22}, |
| 140 | + {1000,2000,3000,4000}, |
| 141 | + {22,-22,22,-22} |
| 142 | + }; |
| 143 | + for(auto const &vals: all_vals){ |
| 144 | + std::vector<double> data_a, data_b; |
| 145 | + |
| 146 | + int na = vals.size() / 2; |
| 147 | + int nb = vals.size() - na; |
| 148 | + for(int i=0;i<na;i++) |
| 149 | + data_a.push_back(vals[i]); |
| 150 | + for(int i=na;i<na+nb;i++) |
| 151 | + data_b.push_back(vals[i]); |
| 152 | + |
| 153 | + for(int i=0;i<vals.size();i++){ |
| 154 | + std::cout << vals[i] << " "; |
| 155 | + } |
| 156 | + std::cout << std::endl; |
| 157 | + |
| 158 | + statsTest a(data_a), b(data_b), c(vals); |
| 159 | + |
| 160 | + ASSERT_NEAR(c.mean(), mean(vals), 1e-3); |
| 161 | + ASSERT_NEAR(c.variance(), variance(vals), 1e-3); |
| 162 | + ASSERT_NEAR(c.skewness(), skewness(vals), 1e-3); |
| 163 | + ASSERT_NEAR(c.kurtosis(), kurtosis(vals), 1e-3); |
| 164 | + std::cout << "Full dist mean " << c.mean() << " var " << c.variance() << " skewness " << c.skewness() << " kurtosis " << c.kurtosis() << " match expected" << std::endl; |
| 165 | + |
| 166 | + statsTest sum = a + b; |
| 167 | + |
| 168 | + bool result = compare(c, sum, 1e-10); |
| 169 | + |
| 170 | + std::cout << "Result a+b: " << (result?"pass":"fail") << std::endl; |
| 171 | + |
| 172 | + EXPECT_EQ(result,true); |
| 173 | + ASSERT_NEAR(c.mean(), sum.mean(),1e-5); |
| 174 | + ASSERT_NEAR(c.variance(), sum.variance(),1e-5); |
| 175 | + ASSERT_NEAR(c.skewness(), sum.skewness(),1e-5); |
| 176 | + ASSERT_NEAR(c.kurtosis(), sum.kurtosis(),1e-5); |
| 177 | + } |
| 178 | + |
| 179 | +} |
| 180 | + |
| 181 | + |
| 182 | +TEST(TestRunStats, TestSumCombine){ |
| 183 | + //Test that summing two RunStats is the same as if the data were collected by a single RunStats instance |
| 184 | + std::vector<std::vector<double> > all_vals = { |
| 185 | + {160,150,140,122,103,77,33,22,19,7,1}, |
| 186 | + {77,33,22,19,7,1}, |
| 187 | + {77,33,22,19}, |
| 188 | + {-0.2, -0.5, 0.7, -0.4}, |
| 189 | + {3.14,6.28,9.99,10.22}, |
| 190 | + {1000,2000,3000,4000}, |
| 191 | + {22,-22,22,-22} |
| 192 | + }; |
| 193 | + for(auto const &vals: all_vals){ |
| 194 | + RunStats a(true),b(true),c(true); |
| 195 | + |
| 196 | + int na = vals.size() / 2; |
| 197 | + int nb = vals.size() - na; |
| 198 | + for(int i=0;i<na;i++) |
| 199 | + a.push(vals[i]); |
| 200 | + for(int i=na;i<na+nb;i++) |
| 201 | + b.push(vals[i]); |
| 202 | + |
| 203 | + for(int i=0;i<vals.size();i++){ |
| 204 | + std::cout << vals[i] << " "; |
| 205 | + c.push(vals[i]); |
| 206 | + } |
| 207 | + std::cout << std::endl; |
| 208 | + |
| 209 | + //Check against independent implementation |
| 210 | + std::vector<double> data_a, data_b; |
| 211 | + for(int i=0;i<na;i++) |
| 212 | + data_a.push_back(vals[i]); |
| 213 | + for(int i=na;i<na+nb;i++) |
| 214 | + data_b.push_back(vals[i]); |
| 215 | + |
| 216 | + statsTest ia(data_a), ib(data_b), ic(vals); |
| 217 | + |
| 218 | + std::cout << "Comparing distribution 'a' to independent implementation" << std::endl; |
| 219 | + bool result = compare(ia,a, 1e-10); |
| 220 | + std::cout << "Result: " << (result?"pass":"fail") << std::endl; |
| 221 | + ASSERT_EQ(result,true); |
| 222 | + |
| 223 | + std::cout << "Comparing distribution 'b' to independent implementation" << std::endl; |
| 224 | + result = compare(ib,b, 1e-10); |
| 225 | + std::cout << "Result: " << (result?"pass":"fail") << std::endl; |
| 226 | + ASSERT_EQ(result,true); |
| 227 | + |
| 228 | + std::cout << "Comparing distribution 'c' to independent implementation" << std::endl; |
| 229 | + result = compare(ic,c, 1e-10); |
| 230 | + std::cout << "Result: " << (result?"pass":"fail") << std::endl; |
| 231 | + ASSERT_EQ(result,true); |
| 232 | + |
| 233 | + std::cout << "Comparing distribution 'c' moments to textbook definitions" << std::endl; |
| 234 | + ASSERT_NEAR(c.mean(), mean(vals), 1e-3); |
| 235 | + ASSERT_NEAR(c.variance(), variance(vals), 1e-3); |
| 236 | + ASSERT_NEAR(c.skewness(), skewness(vals), 1e-3); |
| 237 | + ASSERT_NEAR(c.kurtosis(), kurtosis(vals), 1e-3); |
| 238 | + std::cout << "Full dist mean " << c.mean() << " var " << c.variance() << " skewness " << c.skewness() << " kurtosis " << c.kurtosis() << " match expected" << std::endl; |
| 239 | + |
| 240 | + statsTest isum = ia+ib; |
| 241 | + ASSERT_EQ(compare(ic,isum,1e-10),true); |
| 242 | + |
| 243 | + RunStats sum = a + b; |
| 244 | + std::cout << "Comparing combined distribution 'a+b' to independent implementation" << std::endl; |
| 245 | + result = compare(ic,sum, 1e-10); |
| 246 | + std::cout << "Result: " << (result?"pass":"fail") << std::endl; |
| 247 | + ASSERT_EQ(result,true); |
| 248 | + |
| 249 | + |
| 250 | + std::cout << "Comparing combined distribution 'a+b' to 'c'" << std::endl; |
| 251 | + result = compare(c, sum, 1e-10); |
| 252 | + std::cout << "Result: " << (result?"pass":"fail") << std::endl; |
| 253 | + EXPECT_EQ(result,true); |
| 254 | + |
| 255 | + sum = b + a; |
| 256 | + |
| 257 | + std::cout << "Comparing combined distribution 'b+a' to 'c'" << std::endl; |
| 258 | + result = compare(c, sum, 1e-10); |
| 259 | + std::cout << "Result: " << (result?"pass":"fail") << std::endl; |
| 260 | + EXPECT_EQ(result,true); |
| 261 | + } |
| 262 | +} |
| 263 | + |
| 264 | + |
8 | 265 | TEST(TestRunStats, TestStateToFromJSON){ |
9 | 266 | RunStats stats; |
10 | 267 | for(int i=0;i<100;i++) stats.push(i); |
@@ -53,7 +310,6 @@ TEST(TestRunStats, serialize){ |
53 | 310 | EXPECT_EQ(stats, stats_rd); |
54 | 311 | } |
55 | 312 |
|
56 | | - |
57 | 313 |
|
58 | 314 |
|
59 | 315 |
|
0 commit comments