Skip to content

Commit 23b1934

Browse files
committed
Moved the algorithm name/type into AlgoParams such that it contains enough information to instantiate the algorithm completely
Added JSON serialization to AlgoParams + unit test
1 parent be78652 commit 23b1934

13 files changed

Lines changed: 130 additions & 25 deletions

File tree

include/chimbuko/core/ad/ADOutlier.hpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#pragma once
22
#include <chimbuko_config.h>
3-
#include<array>
4-
#include<unordered_set>
3+
#include <array>
4+
#include <unordered_set>
5+
#include <nlohmann/json.hpp>
6+
57
#include "chimbuko/core/util/RunStats.hpp"
68
#include "chimbuko/core/param.hpp"
79
#include "chimbuko/core/param/sstd_param.hpp"
@@ -21,6 +23,8 @@ namespace chimbuko {
2123
* @brief Unified structure for passing the parameters of the AD algorithms to the factory method
2224
*/
2325
struct AlgoParams{
26+
std::string algorithm; /**< The string name of the algorithm: "sstd", "hbos", "copod" */
27+
2428
//SSTD
2529
double sstd_sigma; /**< The number of sigma that defines an outlier*/
2630

@@ -33,6 +37,26 @@ namespace chimbuko {
3337
int hbos_max_bins; /**< The maximum number of bins in a histogram */
3438

3539
AlgoParams();
40+
41+
/**
42+
* @brief Read the parameters from a json object. Note, only "algorithm" and the entries associated with the specific algorithm need to be set
43+
*/
44+
void setJson(const nlohmann::json &in);
45+
46+
/**
47+
* @brief Read the parameters from a json file. Note, only "algorithm" and the entries associated with the specific algorithm need to be set
48+
*/
49+
void loadJsonFile(const std::string &filename);
50+
51+
/**
52+
* @brief Return the parameters as a json object
53+
*/
54+
nlohmann::json getJson() const;
55+
56+
/**
57+
* @brief Equivalence operator
58+
*/
59+
bool operator==(const AlgoParams &r) const;
3660
};
3761

3862

@@ -50,7 +74,7 @@ namespace chimbuko {
5074
/**
5175
* @brief Factory method to select AD algorithm at runtime
5276
*/
53-
static ADOutlier *set_algorithm(int rank, const std::string & algorithm, const AlgoParams &params);
77+
static ADOutlier *set_algorithm(int rank, const AlgoParams &params);
5478

5579
/**
5680
* @brief check if the parameter server is in use

include/chimbuko/core/provdb/ProvDBpruneCore.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace chimbuko {
1010

1111
class ProvDBpruneCore{
1212
public:
13-
ProvDBpruneCore(const std::string &algorithm, const ADOutlier::AlgoParams &algo_params, const std::string &model_ser);
13+
ProvDBpruneCore(const ADOutlier::AlgoParams &algo_params, const std::string &model_ser);
1414

1515
void prune(sonata::Database &db);
1616

include/chimbuko/modules/factory.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@ namespace chimbuko{
1919

2020
/**
2121
*@brief A factory function for ProvDBpruneCore instances
22-
*@param algorithm The AD algorithm
2322
*@param algo_params Parameters for the algorithm
2423
*@param model_ser The serialized model
2524
*/
2625
std::unique_ptr<ProvDBpruneCore> factoryInstantiateProvDBprune(const std::string &module,
27-
const std::string &algorithm, const ADOutlier::AlgoParams &algo_params, const std::string &model_ser);
26+
const ADOutlier::AlgoParams &algo_params, const std::string &model_ser);
2827

2928
/**
3029
*@brief A factory function for PSmoduleDataManager instances

include/chimbuko/modules/performance_analysis/provdb/ProvDBprune.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace chimbuko {
1414

1515
class ProvDBprune: public ProvDBpruneCore{
1616
public:
17-
ProvDBprune(const std::string &algorithm, const ADOutlier::AlgoParams &algo_params, const std::string &model_ser): ProvDBpruneCore(algorithm,algo_params,model_ser){}
17+
ProvDBprune(const ADOutlier::AlgoParams &algo_params, const std::string &model_ser): ProvDBpruneCore(algo_params,model_ser){}
1818

1919
/**
2020
* @brief Prune the database shard. Both the anomalies and normalexecs will be updated

src/core/ad/ADOutlier.cpp

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,57 @@
1010
#include <boost/math/distributions/normal.hpp>
1111
#include <boost/math/distributions/empirical_cumulative_distribution_function.hpp>
1212
#include <limits>
13-
13+
#include <fstream>
1414
using namespace chimbuko;
1515

1616

17-
ADOutlier::AlgoParams::AlgoParams(): sstd_sigma(6.0), hbos_thres(0.99), glob_thres(true), hbos_max_bins(200){} //, func_threshold_file("")
17+
//Default parameters: algorithm "hbos", sstd_sigma 6.0, hbos_thres 0.99, glob_thres true, hbos_max_bins 200
ADOutlier::AlgoParams::AlgoParams(): algorithm("hbos"), sstd_sigma(6.0), hbos_thres(0.99), glob_thres(true), hbos_max_bins(200){} //, func_threshold_file("")
18+
19+
/**
 * @brief Member-wise equality: true only when the algorithm name and every parameter compare equal with ==
 */
bool ADOutlier::AlgoParams::operator==(const AlgoParams &r) const{
  const bool same_algorithm = algorithm == r.algorithm;
  const bool same_sstd = sstd_sigma == r.sstd_sigma;
  const bool same_hbos = hbos_thres == r.hbos_thres && glob_thres == r.glob_thres && hbos_max_bins == r.hbos_max_bins;
  return same_algorithm && same_sstd && same_hbos;
}
20+
21+
/**
 * @brief Populate the parameters from a JSON object.
 *
 * The "algorithm" key is mandatory, as are the keys used by the algorithm it names:
 *   - sstd:  sstd_sigma
 *   - hbos:  hbos_thres, glob_thres, hbos_max_bins
 *   - copod: hbos_thres
 * Every other recognised key that is present is read as well; keys that are absent leave the
 * corresponding member at its current value. No member is modified unless all required keys are present.
 *
 * @param in The JSON object to read from
 */
void ADOutlier::AlgoParams::setJson(const nlohmann::json &in){
  //Wrapped in do{}while(0) so each macro expands to exactly one statement
#define JSON_CHECK(to) do{ if(!in.contains(#to)) fatal_error("Expected key " #to); }while(0)
#define JSON_GET(to) do{ if(in.contains(#to)) to = in[#to].template get<decltype(to)>(); }while(0)

  //Check for required
  JSON_CHECK(algorithm);

  //The required keys depend on the algorithm named in the *input*. The member 'algorithm' still holds
  //the previous (possibly default) value at this point, so it must not be used to choose the checks.
  const std::string in_algorithm = in["algorithm"].get<std::string>();

  //ADOutlier::set_algorithm accepts both lower- and upper-case names, so validate both spellings
  const bool is_sstd  = in_algorithm == "sstd"  || in_algorithm == "SSTD";
  const bool is_hbos  = in_algorithm == "hbos"  || in_algorithm == "HBOS";
  const bool is_copod = in_algorithm == "copod" || in_algorithm == "COPOD";

  if(is_sstd){
    JSON_CHECK(sstd_sigma);
  }else if(is_hbos){
    JSON_CHECK(glob_thres);
    JSON_CHECK(hbos_max_bins);
  }
  if(is_hbos || is_copod){
    JSON_CHECK(hbos_thres);
  }

  //Get all available
  JSON_GET(algorithm);
  JSON_GET(sstd_sigma);
  JSON_GET(glob_thres);
  JSON_GET(hbos_max_bins);
  JSON_GET(hbos_thres);

#undef JSON_CHECK
#undef JSON_GET
}
45+
46+
void ADOutlier::AlgoParams::loadJsonFile(const std::string &filename){
47+
std::ifstream f(filename);
48+
nlohmann::json j; f >> j;
49+
setJson(j);
50+
}
51+
52+
/**
 * @brief Serialize the parameters to a JSON object, one entry per member keyed by the member's name
 * @return The JSON object holding algorithm, sstd_sigma, hbos_thres, glob_thres and hbos_max_bins
 */
nlohmann::json ADOutlier::AlgoParams::getJson() const{
  nlohmann::json result;
  result["algorithm"] = algorithm;
  result["sstd_sigma"] = sstd_sigma;
  result["hbos_thres"] = hbos_thres;
  result["glob_thres"] = glob_thres;
  result["hbos_max_bins"] = hbos_max_bins;
  return result;
}
63+
1864

1965

2066
/* ---------------------------------------------------------------------------
@@ -50,22 +96,22 @@ ADOutlier::~ADOutlier() {
5096
// }
5197

5298

53-
ADOutlier *ADOutlier::set_algorithm(int rank, const std::string & algorithm, const AlgoParams &params) {
54-
if (algorithm == "sstd" || algorithm == "SSTD") {
99+
ADOutlier *ADOutlier::set_algorithm(int rank, const AlgoParams &params) {
100+
if (params.algorithm == "sstd" || params.algorithm == "SSTD") {
55101
return new ADOutlierSSTD(rank,params.sstd_sigma);
56102
}
57-
else if (algorithm == "hbos" || algorithm == "HBOS") {
103+
else if (params.algorithm == "hbos" || params.algorithm == "HBOS") {
58104
ADOutlierHBOS* alg = new ADOutlierHBOS(rank,params.hbos_thres, params.glob_thres, params.hbos_max_bins);
59105
//loadPerFunctionThresholds(alg,params.func_threshold_file);
60106
return alg;
61107
}
62-
else if (algorithm == "copod" || algorithm == "COPOD") {
108+
else if (params.algorithm == "copod" || params.algorithm == "COPOD") {
63109
ADOutlierCOPOD* alg = new ADOutlierCOPOD(rank,params.hbos_thres);
64110
//loadPerFunctionThresholds(alg,params.func_threshold_file);
65111
return alg;
66112
}
67113
else{
68-
fatal_error("Invalid algorithm: " + algorithm);
114+
fatal_error("Invalid algorithm: " + params.algorithm);
69115
}
70116
}
71117

src/core/chimbuko.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,14 @@ void ChimbukoBase::init_net_client(){
144144

145145
void ChimbukoBase::init_outlier(){
146146
ADOutlier::AlgoParams params;
147+
params.algorithm = m_base_params.ad_algorithm;
147148
params.hbos_thres = m_base_params.hbos_threshold;
148149
params.glob_thres = m_base_params.hbos_use_global_threshold;
149150
params.sstd_sigma = m_base_params.outlier_sigma;
150151
params.hbos_max_bins = m_base_params.hbos_max_bins;
151152
//params.func_threshold_file = m_base_params.func_threshold_file;
152153

153-
m_outlier = ADOutlier::set_algorithm(m_base_params.rank, m_base_params.ad_algorithm, params);
154+
m_outlier = ADOutlier::set_algorithm(m_base_params.rank, params);
154155
if(m_net_client) m_outlier->linkNetworkClient(m_net_client);
155156
m_outlier->linkPerf(&m_perf);
156157
m_outlier->setGlobalModelSyncFrequency(m_base_params.global_model_sync_freq);

src/core/provdb/ProvDBpruneCore.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
using namespace chimbuko;
66

7-
ProvDBpruneCore::ProvDBpruneCore(const std::string &algorithm, const ADOutlier::AlgoParams &algo_params, const std::string &model_ser): m_outlier(ADOutlier::set_algorithm(0,algorithm,algo_params)){
7+
ProvDBpruneCore::ProvDBpruneCore(const ADOutlier::AlgoParams &algo_params, const std::string &model_ser): m_outlier(ADOutlier::set_algorithm(0,algo_params)){
88
m_outlier->setGlobalParameters(model_ser); //input model
99
m_outlier->setGlobalModelSyncFrequency(0); //fix model
1010
}

src/modules/factory.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ std::unique_ptr<chimbuko::PSmoduleDataManagerCore> chimbuko::modules::factoryIns
3030
}
3131

3232
std::unique_ptr<chimbuko::ProvDBpruneCore> chimbuko::modules::factoryInstantiateProvDBprune(const std::string &module,
33-
const std::string &algorithm, const ADOutlier::AlgoParams &algo_params, const std::string &model_ser){
33+
const ADOutlier::AlgoParams &algo_params, const std::string &model_ser){
3434
if(module == "performance_analysis"){
35-
return std::unique_ptr<ProvDBpruneCore>(new performance_analysis::ProvDBprune(algorithm,algo_params,model_ser) );
35+
return std::unique_ptr<ProvDBpruneCore>(new performance_analysis::ProvDBprune(algo_params,model_ser) );
3636
}else{
3737
fatal_error("Unknown module");
3838
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include<chimbuko/core/ad/ADOutlier.hpp>
2+
#include <fstream>
3+
#include "gtest/gtest.h"
4+
5+
using namespace chimbuko;
6+
7+
//Round-trip an AlgoParams instance through a JSON object and through a JSON file on disk;
//in both cases the instance read back must compare equal to the one written
TEST(ADOutlier, serializeAlgoParams){
  //Use non-default values for every member so that a member missed by the serialization is detected
  ADOutlier::AlgoParams expect;
  expect.algorithm = "copod";
  expect.sstd_sigma = 3.14;
  expect.hbos_thres = 0.33;
  expect.glob_thres = false;
  expect.hbos_max_bins = 1234;

  nlohmann::json serialized = expect.getJson();

  //In-memory round trip
  ADOutlier::AlgoParams from_json;
  from_json.setJson(serialized);

  EXPECT_EQ(expect,from_json);

  //Round trip via a file
  std::string path = "/tmp/algo_params.json";
  {
    std::ofstream out(path);
    out << serialized.dump(4);
  } //the stream goes out of scope here, before the file is read back
  ADOutlier::AlgoParams from_file;
  from_file.loadJsonFile(path);

  EXPECT_EQ(expect,from_file);
}

test/unit_tests/core/ad/Makefile.am

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ AM_CPPFLAGS = -I$(top_srcdir)/include -I$(top_srcdir)/3rdparty @PS_FLAGS@
22
LDADD = $(top_builddir)/src/libchimbuko.la -lgtest -lstdc++fs
33

44
testdir = $(prefix)/test/unit_tests/core/ad
5-
test_PROGRAMS = ADio utils
5+
test_PROGRAMS = ADio ADOutlier utils
66

77
ADio_SOURCES = ADio.cpp ../../unit_test_main_mpi.cpp
88
ADio_LDADD = $(LDADD)
99

10+
ADOutlier_SOURCES = ADOutlier.cpp ../../unit_test_main_mpi.cpp
11+
ADOutlier_LDADD = $(LDADD)
12+
1013
utils_SOURCES = utils.cpp ../../unit_test_main_mpi.cpp
1114
utils_LDADD = $(LDADD)

0 commit comments

Comments
 (0)