Skip to content

Commit 0b53592

Browse files
committed
Fixed a bug caused by the bool 'local' parameter of ADThreadNetClient::run not being captured by copy, which left a dangling reference
ADThreadNetClient cleanup: moved function implementations to the cpp file; ClientAction is now a member class of ADThreadNetClient; added documentation for many functions
1 parent db8276a commit 0b53592

2 files changed

Lines changed: 265 additions & 178 deletions

File tree

include/chimbuko/ad/ADNetClient.hpp

Lines changed: 91 additions & 178 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <chimbuko_config.h>
34
#ifdef _USE_MPINET
45
#include "chimbuko/net/mpi_net.hpp"
56
#else
@@ -193,7 +194,8 @@ namespace chimbuko{
193194

194195
#endif
195196

196-
197+
198+
197199
/**
198200
* @brief Implementation of ADNetClient for intraprocess communications
199201
*/
@@ -224,208 +226,119 @@ namespace chimbuko{
224226
std::string send_and_receive(const Message &msg) const override;
225227
};
226228

227-
//Actions performed by the worker thread
228-
struct ClientAction{
229-
virtual void perform(ADNetClient &client) = 0;
230-
virtual bool do_delete() const = 0; //whether to delete the work object after completion
231-
virtual bool shutdown_worker() const{ return false; } //whether to shutdown the worker after completing the action
232-
virtual ~ClientAction(){}
233-
};
234-
235-
struct ClientActionConnect: public ClientAction{
236-
int rank;
237-
int srank;
238-
std::string sname;
239-
240-
ClientActionConnect(int rank, int srank, const std::string &sname): rank(rank), srank(srank), sname(sname){}
241-
242-
void perform(ADNetClient &client){
243-
std::cout << "Connecting to client" << std::endl;
244-
client.connect_ps(rank, srank, sname);
245-
}
246-
bool do_delete() const{ return true; }
247-
};
248-
249-
struct ClientActionDisconnect: public ClientAction{
250-
void perform(ADNetClient &client){
251-
std::cout << "Disconnecting from client" << std::endl;
252-
client.disconnect_ps();
253-
}
254-
bool do_delete() const{ return true; }
255-
bool shutdown_worker() const{ return true; }
256-
};
257-
258-
//Make the worker wait for some time, for testing
259-
struct ClientActionWait: public ClientAction{
260-
size_t wait_ms;
261-
ClientActionWait(size_t wait_ms): wait_ms(wait_ms){}
262-
263-
void perform(ADNetClient &client){
264-
std::this_thread::sleep_for(std::chrono::milliseconds(wait_ms));
265-
}
266-
bool do_delete() const{ return true; }
267-
};
268-
269-
struct ClientActionBlockingSendReceive: public ClientAction{
270-
std::mutex m;
271-
std::condition_variable cv;
272-
Message *recv;
273-
Message const *send; //it's blocking so we know that the object will live long enough
274-
bool complete;
275-
276-
ClientActionBlockingSendReceive(Message *recv, Message const *send): send(send), recv(recv), complete(false){}
277-
278-
void perform(ADNetClient &client){
279-
client.send_and_receive(*recv, *send);
280-
281-
{
282-
std::unique_lock<std::mutex> lk(m);
283-
complete = true;
284-
cv.notify_one();
285-
}
286-
}
287-
bool do_delete() const{ return false; }
288-
289-
void wait_for(){
290-
std::unique_lock<std::mutex> l(m);
291-
cv.wait(l, [&]{ return complete; });
292-
}
293-
};
294-
295-
//Return message is just dumped
296-
struct ClientActionAsyncSend: public ClientAction{
297-
Message send; //copy of send message because we don't know how long it will be before it sends
298-
299-
ClientActionAsyncSend(const Message &send): send(send){}
300-
301-
void perform(ADNetClient &client){
302-
Message recv;
303-
client.send_and_receive(recv, send);
304-
}
305-
bool do_delete() const{ return true; }
306-
};
307-
308-
struct ClientActionSetRecvTimeout: public ClientAction{
309-
int timeout;
310-
ClientActionSetRecvTimeout(const int timeout): timeout(timeout){}
311-
312-
void perform(ADNetClient &client){
313-
client.setRecvTimeout(timeout);
314-
}
315229

316-
bool do_delete() const{return true;}
317-
};
318-
319-
//ADNetClient inside a worker thread with blocking send/receive and non-blocking send
230+
/**
231+
* @brief ADNetClient inside a worker thread with blocking send/receive and non-blocking send
232+
*/
320233
class ADThreadNetClient{
234+
public:
235+
/**
236+
* @brief Abstract base class representing actions performed by the worker thread
237+
*/
238+
struct ClientAction{
239+
/**
240+
* @brief Perform the action utilizing the underlying net implementation
241+
*/
242+
virtual void perform(ADNetClient &client) = 0;
243+
244+
/**
245+
* @brief Whether to delete the work object (instance of ClientAction) after completion
246+
*/
247+
virtual bool do_delete() const = 0;
248+
249+
/**
250+
* @brief Whether to shutdown the worker thread after completing the action
251+
*/
252+
virtual bool shutdown_worker() const{ return false; }
253+
254+
virtual ~ClientAction(){}
255+
};
256+
257+
private:
321258
std::thread worker;
322259
mutable std::mutex m;
323-
std::queue<ClientAction*> queue;
324-
325-
size_t getNwork() const{
326-
std::lock_guard<std::mutex> l(m);
327-
return queue.size();
328-
}
329-
ClientAction* getWorkItem(){
330-
std::lock_guard<std::mutex> l(m);
331-
ClientAction *work_item = queue.front();
332-
queue.pop();
333-
return work_item;
334-
}
260+
std::queue<ClientAction*> queue; /**< The queue of net operations*/
261+
262+
int m_rank;
263+
int m_srank;
264+
bool m_use_ps;
265+
PerfStats * m_perf;
266+
267+
/**
268+
* @brief Get the number of outstanding net operations
269+
*/
270+
size_t getNwork() const;
271+
/**
272+
* @brief Get the next net operation
273+
*/
274+
ClientAction* getWorkItem();
335275

336276
/**
337277
* @brief Create the worker thread
338278
* @param local Use a local (in process) communicator if true, otherwise use the default network communicator
339279
*/
340-
void run(bool local = false){
341-
worker = std::thread([&](){
342-
ADNetClient *client = nullptr;
343-
if(local){
344-
client = new ADLocalNetClient;
345-
}else{
346-
#ifdef _USE_MPINET
347-
client = new ADMPINetClient;
348-
#else
349-
client = new ADZMQNetClient;
350-
#endif
351-
}
352-
bool shutdown = false;
353-
354-
while(!shutdown){
355-
size_t nwork = getNwork();
356-
while(nwork > 0){
357-
ClientAction* work_item = getWorkItem();
358-
work_item->perform(*client);
359-
shutdown = shutdown || work_item->shutdown_worker();
360-
361-
if(work_item->do_delete()) delete work_item;
362-
nwork = getNwork();
363-
}
364-
if(shutdown){
365-
if(nwork > 0) fatal_error("Worker was shut down before emptying its queue!");
366-
}else{
367-
std::this_thread::sleep_for(std::chrono::milliseconds(80));
368-
}
369-
}
370-
delete client;
371-
});
372-
}
280+
void run(bool local = false);
373281

374282
public:
375283
/**
376284
* @brief Constructor
377285
* @param local Use a local (in process) communicator if true, otherwise use the default network communicator
378286
*/
379-
ADThreadNetClient(bool local = false){
380-
run(local);
381-
}
382-
383-
//Use only if you know what you are doing!
384-
void enqueue_action(ClientAction *action){
385-
std::lock_guard<std::mutex> l(m);
386-
queue.push(action);
387-
}
388-
389-
void connect_ps(int rank, int srank = 0, std::string sname="MPINET"){
390-
m_rank = rank;
391-
m_srank = srank;
392-
enqueue_action(new ClientActionConnect(rank, srank,sname));
393-
m_use_ps = true;
394-
}
395-
void disconnect_ps(){
396-
enqueue_action(new ClientActionDisconnect());
397-
}
398-
void send_and_receive(Message &recv, const Message &send){
399-
ClientActionBlockingSendReceive action(&recv, &send);
400-
enqueue_action(&action);
401-
action.wait_for();
402-
}
403-
void async_send(const Message &send){
404-
enqueue_action(new ClientActionAsyncSend(send));
405-
}
287+
ADThreadNetClient(bool local = false);
288+
289+
/**
290+
* @brief Add an action to the queue
291+
*
292+
* Use only if you know what you are doing!
293+
*/
294+
void enqueue_action(ClientAction *action);
295+
296+
/**
297+
* @brief Connect to the parameter server
298+
*/
299+
void connect_ps(int rank, int srank = 0, std::string sname="MPINET");
300+
301+
/**
302+
* @brief Disconnect from the parameter server
303+
*/
304+
void disconnect_ps();
305+
306+
/**
307+
* @brief Perform a blocking send and receive operation
308+
*/
309+
void send_and_receive(Message &recv, const Message &send);
406310

311+
/**
312+
* @brief Perform a non-blocking send operation
313+
*/
314+
void async_send(const Message &send);
315+
316+
/**
317+
* @brief Is the parameter server in use / connected?
318+
*/
407319
bool use_ps() const { return m_use_ps; }
408320

321+
/**
322+
* @brief Link a performance monitoring instance
323+
*/
409324
void linkPerf(PerfStats* perf){ m_perf = perf; }
410325

326+
/**
327+
* @brief Get the MPI rank of the server (MPINet)
328+
*/
411329
int get_server_rank() const{ return m_srank; }
412330

331+
/**
332+
* @brief Get the MPI rank of the client
333+
*/
413334
int get_client_rank() const{ return m_rank; }
414335

415-
void setRecvTimeout(const int timeout_ms) {
416-
enqueue_action(new ClientActionSetRecvTimeout(timeout_ms));
417-
}
336+
/**
337+
* @brief Set a timeout (in ms) on receiving a response message
338+
*/
339+
void setRecvTimeout(const int timeout_ms);
418340

419-
~ADThreadNetClient(){
420-
disconnect_ps();
421-
worker.join();
422-
}
423-
424-
private:
425-
int m_rank;
426-
int m_srank;
427-
bool m_use_ps;
428-
PerfStats * m_perf;
341+
~ADThreadNetClient();
429342
};
430343

431344

0 commit comments

Comments
 (0)