Skip to content

Commit 95d7594

Browse files
committed
Merge branch 'ckelly_develop' into sm_release
2 parents 466e171 + 2852e04 commit 95d7594

38 files changed

Lines changed: 629 additions & 358 deletions

app/Makefile.am

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
AM_CPPFLAGS = -I$(top_srcdir)/include -I$(top_srcdir)/3rdparty @PS_FLAGS@
2-
LDADD = ../src/libchimbuko.la -lstdc++fs
2+
LDADD = ../src/libchimbuko.la -lstdc++fs
33

44
bin_PROGRAMS = driver pclient pclient_stats hpserver pserver pshutdown sstSinker sst_view bpfile_replay
55

@@ -32,14 +32,17 @@ bpfile_replay_LDADD = $(LDADD)
3232

3333

3434
if ENABLE_PROVDB
35-
bin_PROGRAMS += provdb_admin provdb_query
35+
bin_PROGRAMS += provdb_admin provdb_query provdb_shutdown
3636

3737
provdb_admin_SOURCES = provdb_admin.cpp
3838
provdb_admin_LDADD = $(LDADD)
3939

4040
provdb_query_SOURCES = provdb_query.cpp
4141
provdb_query_LDADD = $(LDADD)
4242

43+
provdb_shutdown_SOURCES = provdb_shutdown.cpp
44+
provdb_shutdown_LDADD = $(LDADD)
45+
4346
else
4447
echo "Provenance DB not being built"
4548
endif

app/driver.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
// #include "chimbuko/AD.hpp"
21
#include "chimbuko/chimbuko.hpp"
32
#include "chimbuko/verbose.hpp"
43
#include "chimbuko/util/string.hpp"
54
#include "chimbuko/util/commandLineParser.hpp"
5+
#include "chimbuko/util/error.hpp"
66
#include <chrono>
77
#include <cstdlib>
88

@@ -285,14 +285,17 @@ int main(int argc, char ** argv){
285285
}
286286
catch (const std::invalid_argument &e){
287287
std::cout << '[' << getDateTime() << ", rank " << params.rank << "] Driver : caught invalid argument: " << e.what() << std::endl;
288+
if(params.err_outputpath.size()) recoverable_error(std::string("Driver : caught invalid argument: ") + e.what()); //ensure errors also written to error logs
288289
error = true;
289290
}
290291
catch (const std::ios_base::failure &e){
291292
std::cout << '[' << getDateTime() << ", rank " << params.rank << "] Driver : I/O base exception caught: " << e.what() << std::endl;
293+
if(params.err_outputpath.size()) recoverable_error(std::string("Driver : I/O base exception caught: ") + e.what());
292294
error = true;
293295
}
294296
catch (const std::exception &e){
295297
std::cout << '[' << getDateTime() << ", rank " << params.rank << "] Driver : Exception caught: " << e.what() << std::endl;
298+
if(params.err_outputpath.size()) recoverable_error(std::string("Driver : Exception caught: ") + e.what());
296299
error = true;
297300
}
298301

app/provdb_admin.cpp

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,17 @@ void pserver_goodbye(const tl::request& req) {
6666
progressStream << "ProvDB Admin: Pserver has said goodbye" << std::endl;
6767
}
6868

69+
bool cmd_shutdown = false; //true if a client has requested that the server shut down
70+
71+
void client_stop_rpc(const tl::request& req) {
72+
std::lock_guard<tl::mutex> lock(*mtx);
73+
cmd_shutdown = true;
74+
progressStream << "ProvDB Admin: Received shutdown request from client" << std::endl;
75+
}
76+
77+
78+
79+
6980

7081
struct ProvdbArgs{
7182
std::string ip;
@@ -75,8 +86,9 @@ struct ProvdbArgs{
7586
int nthreads;
7687
std::string db_type;
7788
unsigned long db_commit_freq;
89+
std::string db_write_dir;
7890

79-
ProvdbArgs(): engine("ofi+tcp"), autoshutdown(true), nshards(1), db_type("unqlite"), nthreads(1), db_commit_freq(10000){}
91+
ProvdbArgs(): engine("ofi+tcp"), autoshutdown(true), nshards(1), db_type("unqlite"), nthreads(1), db_commit_freq(10000), db_write_dir("."){}
8092
};
8193

8294

@@ -88,7 +100,7 @@ int main(int argc, char** argv) {
88100
progressStream << "ProvDB Admin: Enabling verbose debug output" << std::endl;
89101
enableVerboseLogging() = true;
90102
spdlog::set_level(spdlog::level::trace); //enable logging of Sonata
91-
}
103+
}
92104

93105
//argv[1] should specify the ip address and port (the only way to fix the port that I'm aware of)
94106
//Should be of form <ip address>:<port> eg. 127.0.0.1:1234
@@ -101,6 +113,7 @@ int main(int argc, char** argv) {
101113
addOptionalCommandLineArg(parser, nthreads, "Specify the number of RPC handler threads (default 1)");
102114
addOptionalCommandLineArg(parser, db_type, "Specify the Sonata database type (default \"unqlite\")");
103115
addOptionalCommandLineArg(parser, db_commit_freq, "Specify the frequency at which the database flushes to disk in ms (default 10000)");
116+
addOptionalCommandLineArg(parser, db_write_dir, "Specify the directory in which the database shards will be written (default \".\")");
104117

105118
if(argc-1 < parser.nMandatoryArgs() || (argc == 2 && std::string(argv[1]) == "-help")){
106119
parser.help(std::cout);
@@ -120,7 +133,7 @@ int main(int argc, char** argv) {
120133
progressStream << "ProvDB Admin: initializing thallium with address: " << eng_opt << std::endl;
121134

122135
//Initialize provider engine
123-
tl::engine engine(eng_opt, THALLIUM_SERVER_MODE, true, args.nthreads);
136+
tl::engine engine(eng_opt, THALLIUM_SERVER_MODE, true, args.nthreads);
124137

125138
#ifdef _PERF_METRIC
126139
//Get Margo to output profiling information
@@ -141,29 +154,30 @@ int main(int argc, char** argv) {
141154
engine.define("client_goodbye",client_goodbye).disable_response();
142155
engine.define("pserver_hello",pserver_hello).disable_response();
143156
engine.define("pserver_goodbye",pserver_goodbye).disable_response();
157+
engine.define("stop_server",client_stop_rpc).disable_response();
144158

145159
std::string addr = (std::string)engine.self(); //ip and port of admin
146160

147161
{ //Scope in which provider is active
148162

149163
//Initialize provider
150164
sonata::Provider provider(engine, 0);
151-
165+
152166
progressStream << "ProvDB Admin: Provider is running on " << addr << std::endl;
153167

154168
{ //Scope in which admin object is active
155169
sonata::Admin admin(engine);
156170
progressStream << "ProvDB Admin: creating global data database" << std::endl;
157171
std::string glob_db_name = "provdb.global";
158-
std::string glob_db_config = stringize("{ \"path\" : \"./%s.unqlite\" }", glob_db_name.c_str());
172+
std::string glob_db_config = stringize("{ \"path\" : \"%s/%s.unqlite\" }", args.db_write_dir.c_str(), glob_db_name.c_str());
159173
admin.createDatabase(addr, 0, glob_db_name, args.db_type, glob_db_config);
160-
174+
161175
progressStream << "ProvDB Admin: creating " << args.nshards << " database shards" << std::endl;
162176

163177
std::vector<std::string> db_shard_names(args.nshards);
164178
for(int s=0;s<args.nshards;s++){
165179
std::string db_name = stringize("provdb.%d",s);
166-
std::string config = stringize("{ \"path\" : \"./%s.unqlite\" }", db_name.c_str());
180+
std::string config = stringize("{ \"path\" : \"%s/%s.unqlite\" }", args.db_write_dir.c_str(), db_name.c_str());
167181
progressStream << "ProvDB Admin: Shard " << s << ": " << db_name << " " << config << " " << args.db_type << std::endl;
168182
admin.createDatabase(addr, 0, db_name, args.db_type, config);
169183
db_shard_names[s] = db_name;
@@ -172,7 +186,7 @@ int main(int argc, char** argv) {
172186
//Create the collections
173187
{ //scope in which client is active
174188
sonata::Client client(engine);
175-
189+
176190
//Initialize the provdb shards
177191
std::vector<sonata::Database> db(args.nshards);
178192
for(int s=0;s<args.nshards;s++){
@@ -201,7 +215,7 @@ int main(int argc, char** argv) {
201215

202216

203217
//Spin quietly until SIGTERM sent
204-
signal(SIGTERM, termSignalHandler);
218+
signal(SIGTERM, termSignalHandler);
205219
progressStream << "ProvDB Admin: main thread waiting for completion" << std::endl;
206220
while(!stop_wait_loop) { //stop wait loop will be set by SIGTERM handler
207221
tl::thread::sleep(engine, 1000); //Thallium engine sleeps but listens for rpc requests
@@ -214,11 +228,13 @@ int main(int argc, char** argv) {
214228
glob_db.commit();
215229
commit_timer_start = Clock::now();
216230
}
217-
231+
218232
//If at least one client has previously connected but none are now connected, shutdown the server
219233
//If all clients disconnected we must also wait for the pserver to disconnect (if it is connected)
234+
235+
//If args.autoshutdown is disabled we can force shutdown via a "stop_server" RPC
220236
if(
221-
args.autoshutdown &&
237+
(args.autoshutdown || cmd_shutdown) &&
222238
( a_client_has_connected && connected.size() == 0 ) &&
223239
( !pserver_has_connected || (pserver_has_connected && !pserver_connected) )
224240
){
@@ -233,18 +249,18 @@ int main(int argc, char** argv) {
233249
progressStream << "ProvDB Admin: destroying pserver database as it didn't connect (connection is optional)" << std::endl;
234250
admin.destroyDatabase(addr, 0, glob_db_name);
235251
}
236-
252+
237253
progressStream << "ProvDB Admin: ending admin scope" << std::endl;
238254
}//admin scope
239255

240256
progressStream << "ProvDB Admin: ending provider scope" << std::endl;
241257
}//provider scope
242258

243259
progressStream << "ProvDB Admin: shutting down server engine" << std::endl;
244-
delete mtx; //delete mutex prior to engine finalize
260+
delete mtx; //delete mutex prior to engine finalize
245261
engine.finalize();
246-
progressStream << "ProvDB Admin: finished, exiting engine scope" << std::endl;
262+
progressStream << "ProvDB Admin: finished, exiting engine scope" << std::endl;
247263
}
248-
progressStream << "ProvDB Admin: finished, exiting main scope" << std::endl;
264+
progressStream << "ProvDB Admin: finished, exiting main scope" << std::endl;
249265
return 0;
250266
}

app/pserver.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ int main (int argc, char ** argv){
131131
nlohmann::json in_p;
132132
in >> in_p;
133133
global_func_index_map.deserialize(in_p["func_index_map"]);
134-
param->assign(in_p["alg_params"].dump()); //param.assign(in_p["alg_params"].dump());
134+
param->assign(in_p["alg_params"].dump());
135135
}
136136

137137
#ifdef _USE_MPINET
@@ -181,8 +181,8 @@ int main (int argc, char ** argv){
181181
if(args.stat_outputdir.size()) std::cout << "(dir @ " << args.stat_outputdir << ")";
182182
}
183183

184-
net.add_payload(new NetPayloadUpdateParams(param, args.freeze_params)); //new NetPayloadUpdateParams(&param, args.freeze_params));
185-
net.add_payload(new NetPayloadGetParams(param)); //new NetPayloadGetParams(&param));
184+
net.add_payload(new NetPayloadUpdateParams(param, args.freeze_params));
185+
net.add_payload(new NetPayloadGetParams(param));
186186
net.add_payload(new NetPayloadUpdateAnomalyStats(&global_func_stats));
187187
net.add_payload(new NetPayloadUpdateCounterStats(&global_counter_stats));
188188
net.add_payload(new NetPayloadGlobalFunctionIndexMapBatched(&global_func_index_map));
@@ -227,7 +227,7 @@ int main (int argc, char ** argv){
227227
o.open(args.logdir + "/parameters.txt");
228228
if (o.is_open())
229229
{
230-
param->show(o); //param.show(o);
230+
param->show(o);
231231
o.close();
232232
}
233233
}
@@ -256,13 +256,10 @@ int main (int argc, char ** argv){
256256
if(!out.good()) throw std::runtime_error("Could not write anomaly algorithm parameters to the file provided");
257257
nlohmann::json out_p;
258258
out_p["func_index_map"] = global_func_index_map.serialize();
259-
out_p["alg_params"] = nlohmann::json::parse(param->serialize()); //param.serialize());
259+
out_p["alg_params"] = nlohmann::json::parse(param->serialize());
260260
out << out_p;
261261
}
262262

263263
progressStream << "Pserver: finished" << std::endl;
264-
265-
delete param;
266-
267264
return 0;
268265
}

app/pshutdown.cpp

Lines changed: 20 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,23 @@
1-
#include "chimbuko/net/zmq_net.hpp"
2-
#include "chimbuko/message.hpp"
1+
#include "chimbuko/verbose.hpp"
2+
#include "chimbuko/util/error.hpp"
3+
#include "chimbuko/ad/ADNetClient.hpp"
34

45
using namespace chimbuko;
56

6-
int main (int argc, char ** argv)
7-
{
8-
void * context;
9-
void * socket;
10-
std::string addr;
11-
Message msg;
12-
std::string strmsg;
13-
14-
if (argc > 1)
15-
addr = argv[1];
16-
17-
context = zmq_ctx_new();
18-
socket = zmq_socket(context, ZMQ_REQ);
19-
zmq_connect(socket, addr.c_str());
20-
21-
// test connection
22-
msg.clear();
23-
msg.set_info(-1, -1, MessageType::REQ_ECHO, MessageKind::DEFAULT);
24-
msg.set_msg("Hello!");
25-
26-
ZMQNet::send(socket, msg.data());
27-
28-
msg.clear();
29-
ZMQNet::recv(socket, strmsg);
30-
msg.set_msg(strmsg, true);
31-
32-
// if (msg.data_buffer().compare("Hello!>I am ZMQNET!") != 0)
33-
// {
34-
// std::cerr << "Connect error to parameter server (ZMQNET)!\n";
35-
// exit(1);
36-
// }
37-
38-
// shutdown
39-
zmq_send(socket, nullptr, 0, 0);
40-
41-
// finalize
42-
zmq_close(socket);
43-
zmq_ctx_term(context);
44-
45-
return EXIT_SUCCESS;
46-
}
7+
int main(int argc, char** argv){
8+
#ifdef _USE_ZMQNET
9+
if(argc != 2){
10+
std::cout << "Usage: pshutdown <server address e.g. tcp://localhost:5559>" << std::endl;
11+
return 1;
12+
}
13+
ADNetClient ad;
14+
ad.connect_ps(0,0,argv[1]);
15+
ad.stopServer();
16+
ad.disconnect_ps();
17+
18+
return 0;
19+
#else
20+
std::cout << "Not implemented for MPINET" << std::endl;
21+
return 1;
22+
#endif
23+
}

0 commit comments

Comments
 (0)