Skip to content

Commit b5ce01a

Browse files
committed
Added RunStats unit tests comparing statistics to textbook definitions, comparing the underlying data to an independent implementation and checking the combination operation
Fixed a subtle sign error in the combination operation of two RunStats objects that resulted in incorrect values for the skewness and kurtosis of the result
1 parent f08a6a7 commit b5ce01a

3 files changed

Lines changed: 277 additions & 8 deletions

File tree

include/chimbuko/util/RunStats.hpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@ namespace chimbuko {
2424
/**
2525
* @brief Internal state of RunStats object
2626
*
27+
* Note the variables in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance are M2,M3,M4. The mappings are provided in the comments below.
2728
*/
2829
struct State {
2930
double count; /**< count of instances */
3031
double eta; /**< mean */
31-
double rho;
32-
double tau;
33-
double phi;
32+
double rho; /**< = M2 = \sum_i (x_i - \bar x)^2 */
33+
double tau; /**< = M3 = \sum_i (x_i - \bar x)^3 */
34+
double phi; /**< = M4 = \sum_i (x_i - \bar x)^4 */
3435
double min; /**< minimum */
3536
double max; /**< maximum */
3637
double acc; /**< sum */
@@ -231,6 +232,13 @@ namespace chimbuko {
231232
*/
232233
double accumulate() const;
233234
double mean() const;
235+
236+
/**
237+
* @brief Return the variance of the data
238+
*
239+
* If ddof=1 (default) the variance will include Bessel's correction, and represents an estimate of the population variance.
240+
* If ddof=0 the variance will be the variance of the sample
241+
*/
234242
double variance(double ddof=1.0) const;
235243
double stddev(double ddof=1.0) const;
236244
double skewness() const;
@@ -241,6 +249,11 @@ namespace chimbuko {
241249
*/
242250
void set_do_accumulate(bool do_accumulate) { m_do_accumulate = do_accumulate; }
243251

252+
/**
253+
* @brief Determine whether the sum of all values is to be maintained
254+
*/
255+
bool get_do_accumulate() const{ return m_do_accumulate; }
256+
244257
/**
245258
* @brief Get the current statistics as a JSON object
246259
*/
@@ -255,7 +268,7 @@ namespace chimbuko {
255268
/**
256269
* @brief Combine two RunStats instances such that the resulting statistics are the union of the two
257270
*/
258-
friend RunStats operator+(const RunStats a, const RunStats b);
271+
friend RunStats operator+(const RunStats &a, const RunStats &b);
259272

260273
/**
261274
* @brief Combine two RunStats instances such that the resulting statistics are the union of the two
@@ -277,7 +290,7 @@ namespace chimbuko {
277290
bool m_do_accumulate; /**< True if the sum of the input values are maintained */
278291
};
279292

280-
RunStats operator+(const RunStats a, const RunStats b);
293+
RunStats operator+(const RunStats &a, const RunStats &b);
281294
bool operator==(const RunStats& a, const RunStats& b);
282295
bool operator!=(const RunStats& a, const RunStats& b);
283296

src/util/RunStats.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,13 @@ double RunStats::kurtosis() const {
155155
return m_state.count * m_state.phi / (m_state.rho * m_state.rho) - 3.0;
156156
}
157157

158-
RunStats chimbuko::operator+(const RunStats a, const RunStats b)
158+
RunStats chimbuko::operator+(const RunStats &a, const RunStats &b)
159159
{
160160
double sum_count = a.m_state.count + b.m_state.count;
161161
if (sum_count == 0.0)
162162
return RunStats();
163163

164-
double delta = a.m_state.eta - b.m_state.eta;
164+
double delta = b.m_state.eta - a.m_state.eta;
165165
double delta2 = delta * delta;
166166
double delta3 = delta * delta2;
167167
double delta4 = delta2 * delta2;

test/unit_tests/util/RunStats.cpp

Lines changed: 257 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,266 @@
22
#include "gtest/gtest.h"
33
#include <cereal/archives/portable_binary.hpp>
44
#include <sstream>
5+
#include "../unit_test_common.hpp"
56

67
using namespace chimbuko;
78

9+
//Textbook definitions
10+
double mean(const std::vector<double> &a){
11+
double r=0;
12+
for(double v:a) r+=v;
13+
return r/a.size();
14+
}
15+
double variance(const std::vector<double> &a, bool incl_bessel = true){
16+
double n = a.size();
17+
double r=0,r2=0;
18+
for(double v:a){ r+=v; r2+=v*v; }
19+
r = (r2/n - r/n*r/n);
20+
if(incl_bessel) r *= n/(n-1); //include Bessel's correction by default
21+
return r;
22+
}
23+
double skewness(const std::vector<double> &a){
24+
double mu = mean(a);
25+
double sigma = sqrt(variance(a,false));
26+
double r = 0;
27+
for(double v:a) r += pow( (v - mu)/sigma, 3 );
28+
return r/a.size();
29+
}
30+
double kurtosis(const std::vector<double> &a){ //technically excess kurtosis
31+
double mu = mean(a);
32+
double sigma = sqrt(variance(a,false));
33+
double r = 0;
34+
for(double v:a) r += pow( (v - mu)/sigma, 4 );
35+
return r/a.size() - 3;
36+
}
37+
38+
39+
40+
//Independent implementation using https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
41+
struct statsTest{
42+
double n;
43+
double mu;
44+
double M2;
45+
double M3;
46+
double M4;
47+
48+
statsTest(): n(0), mu(0), M2(0), M3(0), M4(0){}
49+
statsTest(const std::vector<double> &v){
50+
n = v.size();
51+
mu = 0;
52+
M2 = 0;
53+
M3 = 0;
54+
M4 = 0;
55+
for(double e: v)
56+
mu += e;
57+
mu /= n;
58+
59+
for(double e: v){
60+
M2 += pow(e - mu,2);
61+
M3 += pow(e - mu,3);
62+
M4 += pow(e - mu,4);
63+
}
64+
65+
}
66+
67+
double variance() const{ //includes Bessel's correction
68+
return M2/(n-1.);
69+
}
70+
double mean() const{
71+
return mu;
72+
}
73+
double skewness() const{
74+
return M3/n/pow(M2/n,3./2.);
75+
}
76+
double kurtosis() const{
77+
return M4/n/pow(M2/n,2) - 3;
78+
}
79+
80+
};
81+
82+
statsTest operator+(const statsTest &a, const statsTest &b){
83+
statsTest out;
84+
out.n = a.n + b.n;
85+
double delta = b.mu - a.mu;
86+
out.mu = a.mu + delta * b.n/out.n;
87+
out.M2 = a.M2 + b.M2 + pow(delta,2) * a.n * b.n / out.n;
88+
out.M3 = a.M3 + b.M3 + pow(delta,3) * a.n * b.n * (a.n - b.n) / pow(out.n,2) + 3*delta*(a.n*b.M2 - b.n*a.M2)/out.n;
89+
out.M4 = a.M4 + b.M4 + pow(delta,4) * a.n * b.n * (a.n*a.n - a.n*b.n + b.n*b.n)/pow(out.n,3) + 6*pow(delta,2) * (a.n*a.n*b.M2 + b.n*b.n*a.M2 ) / pow(out.n,2)
90+
+ 4*delta*( a.n*b.M3 - b.n*a.M3 ) / out.n;
91+
92+
return out;
93+
}
94+
95+
bool compare(const statsTest &a, const statsTest &b, const double tol = 1e-12){
96+
bool ret = true;
97+
#define COM(A) if(2.*fabs( a. A - b. A )/(a. A + b. A) > tol){ std::cout << #A << " " << a. A << " " << b. A << std::endl; ret = false; }
98+
COM(n);
99+
COM(mu);
100+
COM(M2);
101+
COM(M3);
102+
COM(M4);
103+
return ret;
104+
#undef COM
105+
}
106+
107+
108+
bool compare(const statsTest &a, const RunStats &b, const double tol = 1e-12){
109+
const RunStats::State &sb = b.get_state();
110+
111+
// double eta; /**< mean */
112+
// double rho; /**< = M2 = \sum_i (x_i - \bar x)^2 */
113+
// double tau; /**< = M3 = \sum_i (x_i - \bar x)^3 */
114+
// double phi; /**< = M4 = \sum_i (x_i - \bar x)^4 */
115+
116+
bool ret = true;
117+
#define COM(A,B) if( \
118+
(fabs(a. A)<=tol && fabs(sb. B)>tol) || \
119+
(fabs(sb. B)<=tol && fabs(a. A)>tol) || \
120+
(fabs(a. A)>tol && fabs(sb. B) > tol && 2.*fabs( a. A - sb. B )/(a. A + sb. B) > tol) \
121+
){ std::cout << #A << " " << a. A << " " << #B << " " << sb. B << std::endl; ret = false; }
122+
COM(n,count);
123+
COM(mu,eta);
124+
COM(M2,rho);
125+
COM(M3,tau);
126+
COM(M4,phi);
127+
return ret;
128+
#undef COM
129+
}
130+
131+
132+
TEST(TestRunStats, TestIndependentImplementation){
133+
//Test that summing two RunStats is the same as if the data were collected by a single RunStats instance
134+
std::vector<std::vector<double> > all_vals = {
135+
{160,150,140,122,103,77,33,22,19,7,1},
136+
{77,33,22,19,7,1},
137+
{77,33,22,19},
138+
{-0.2, -0.5, 0.7, -0.4},
139+
{3.14,6.28,9.99,10.22},
140+
{1000,2000,3000,4000},
141+
{22,-22,22,-22}
142+
};
143+
for(auto const &vals: all_vals){
144+
std::vector<double> data_a, data_b;
145+
146+
int na = vals.size() / 2;
147+
int nb = vals.size() - na;
148+
for(int i=0;i<na;i++)
149+
data_a.push_back(vals[i]);
150+
for(int i=na;i<na+nb;i++)
151+
data_b.push_back(vals[i]);
152+
153+
for(int i=0;i<vals.size();i++){
154+
std::cout << vals[i] << " ";
155+
}
156+
std::cout << std::endl;
157+
158+
statsTest a(data_a), b(data_b), c(vals);
159+
160+
ASSERT_NEAR(c.mean(), mean(vals), 1e-3);
161+
ASSERT_NEAR(c.variance(), variance(vals), 1e-3);
162+
ASSERT_NEAR(c.skewness(), skewness(vals), 1e-3);
163+
ASSERT_NEAR(c.kurtosis(), kurtosis(vals), 1e-3);
164+
std::cout << "Full dist mean " << c.mean() << " var " << c.variance() << " skewness " << c.skewness() << " kurtosis " << c.kurtosis() << " match expected" << std::endl;
165+
166+
statsTest sum = a + b;
167+
168+
bool result = compare(c, sum, 1e-10);
169+
170+
std::cout << "Result a+b: " << (result?"pass":"fail") << std::endl;
171+
172+
EXPECT_EQ(result,true);
173+
ASSERT_NEAR(c.mean(), sum.mean(),1e-5);
174+
ASSERT_NEAR(c.variance(), sum.variance(),1e-5);
175+
ASSERT_NEAR(c.skewness(), sum.skewness(),1e-5);
176+
ASSERT_NEAR(c.kurtosis(), sum.kurtosis(),1e-5);
177+
}
178+
179+
}
180+
181+
182+
TEST(TestRunStats, TestSumCombine){
183+
//Test that summing two RunStats is the same as if the data were collected by a single RunStats instance
184+
std::vector<std::vector<double> > all_vals = {
185+
{160,150,140,122,103,77,33,22,19,7,1},
186+
{77,33,22,19,7,1},
187+
{77,33,22,19},
188+
{-0.2, -0.5, 0.7, -0.4},
189+
{3.14,6.28,9.99,10.22},
190+
{1000,2000,3000,4000},
191+
{22,-22,22,-22}
192+
};
193+
for(auto const &vals: all_vals){
194+
RunStats a(true),b(true),c(true);
195+
196+
int na = vals.size() / 2;
197+
int nb = vals.size() - na;
198+
for(int i=0;i<na;i++)
199+
a.push(vals[i]);
200+
for(int i=na;i<na+nb;i++)
201+
b.push(vals[i]);
202+
203+
for(int i=0;i<vals.size();i++){
204+
std::cout << vals[i] << " ";
205+
c.push(vals[i]);
206+
}
207+
std::cout << std::endl;
208+
209+
//Check against independent implementation
210+
std::vector<double> data_a, data_b;
211+
for(int i=0;i<na;i++)
212+
data_a.push_back(vals[i]);
213+
for(int i=na;i<na+nb;i++)
214+
data_b.push_back(vals[i]);
215+
216+
statsTest ia(data_a), ib(data_b), ic(vals);
217+
218+
std::cout << "Comparing distribution 'a' to independent implementation" << std::endl;
219+
bool result = compare(ia,a, 1e-10);
220+
std::cout << "Result: " << (result?"pass":"fail") << std::endl;
221+
ASSERT_EQ(result,true);
222+
223+
std::cout << "Comparing distribution 'b' to independent implementation" << std::endl;
224+
result = compare(ib,b, 1e-10);
225+
std::cout << "Result: " << (result?"pass":"fail") << std::endl;
226+
ASSERT_EQ(result,true);
227+
228+
std::cout << "Comparing distribution 'c' to independent implementation" << std::endl;
229+
result = compare(ic,c, 1e-10);
230+
std::cout << "Result: " << (result?"pass":"fail") << std::endl;
231+
ASSERT_EQ(result,true);
232+
233+
std::cout << "Comparing distribution 'c' moments to textbook definitions" << std::endl;
234+
ASSERT_NEAR(c.mean(), mean(vals), 1e-3);
235+
ASSERT_NEAR(c.variance(), variance(vals), 1e-3);
236+
ASSERT_NEAR(c.skewness(), skewness(vals), 1e-3);
237+
ASSERT_NEAR(c.kurtosis(), kurtosis(vals), 1e-3);
238+
std::cout << "Full dist mean " << c.mean() << " var " << c.variance() << " skewness " << c.skewness() << " kurtosis " << c.kurtosis() << " match expected" << std::endl;
239+
240+
statsTest isum = ia+ib;
241+
ASSERT_EQ(compare(ic,isum,1e-10),true);
242+
243+
RunStats sum = a + b;
244+
std::cout << "Comparing combined distribution 'a+b' to independent implementation" << std::endl;
245+
result = compare(ic,sum, 1e-10);
246+
std::cout << "Result: " << (result?"pass":"fail") << std::endl;
247+
ASSERT_EQ(result,true);
248+
249+
250+
std::cout << "Comparing combined distribution 'a+b' to 'c'" << std::endl;
251+
result = compare(c, sum, 1e-10);
252+
std::cout << "Result: " << (result?"pass":"fail") << std::endl;
253+
EXPECT_EQ(result,true);
254+
255+
sum = b + a;
256+
257+
std::cout << "Comparing combined distribution 'b+a' to 'c'" << std::endl;
258+
result = compare(c, sum, 1e-10);
259+
std::cout << "Result: " << (result?"pass":"fail") << std::endl;
260+
EXPECT_EQ(result,true);
261+
}
262+
}
263+
264+
8265
TEST(TestRunStats, TestStateToFromJSON){
9266
RunStats stats;
10267
for(int i=0;i<100;i++) stats.push(i);
@@ -53,7 +310,6 @@ TEST(TestRunStats, serialize){
53310
EXPECT_EQ(stats, stats_rd);
54311
}
55312

56-
57313

58314

59315

0 commit comments

Comments
 (0)