|
1 | 1 | #pragma once |
2 | 2 |
|
| 3 | +#include <chimbuko_config.h> |
3 | 4 | #ifdef _USE_MPINET |
4 | 5 | #include "chimbuko/net/mpi_net.hpp" |
5 | 6 | #else |
@@ -193,7 +194,8 @@ namespace chimbuko{ |
193 | 194 |
|
194 | 195 | #endif |
195 | 196 |
|
196 | | - |
| 197 | + |
| 198 | + |
197 | 199 | /** |
198 | 200 | * @brief Implementation of ADNetClient for intraprocess communications |
199 | 201 | */ |
@@ -224,208 +226,119 @@ namespace chimbuko{ |
224 | 226 | std::string send_and_receive(const Message &msg) const override; |
225 | 227 | }; |
226 | 228 |
|
227 | | - //Actions performed by the worker thread |
228 | | - struct ClientAction{ |
229 | | - virtual void perform(ADNetClient &client) = 0; |
230 | | - virtual bool do_delete() const = 0; //whether to delete the work object after completion |
231 | | - virtual bool shutdown_worker() const{ return false; } //whether to shutdown the worker after completing the action |
232 | | - virtual ~ClientAction(){} |
233 | | - }; |
234 | | - |
235 | | - struct ClientActionConnect: public ClientAction{ |
236 | | - int rank; |
237 | | - int srank; |
238 | | - std::string sname; |
239 | | - |
240 | | - ClientActionConnect(int rank, int srank, const std::string &sname): rank(rank), srank(srank), sname(sname){} |
241 | | - |
242 | | - void perform(ADNetClient &client){ |
243 | | - std::cout << "Connecting to client" << std::endl; |
244 | | - client.connect_ps(rank, srank, sname); |
245 | | - } |
246 | | - bool do_delete() const{ return true; } |
247 | | - }; |
248 | | - |
249 | | - struct ClientActionDisconnect: public ClientAction{ |
250 | | - void perform(ADNetClient &client){ |
251 | | - std::cout << "Disconnecting from client" << std::endl; |
252 | | - client.disconnect_ps(); |
253 | | - } |
254 | | - bool do_delete() const{ return true; } |
255 | | - bool shutdown_worker() const{ return true; } |
256 | | - }; |
257 | | - |
258 | | - //Make the worker wait for some time, for testing |
259 | | - struct ClientActionWait: public ClientAction{ |
260 | | - size_t wait_ms; |
261 | | - ClientActionWait(size_t wait_ms): wait_ms(wait_ms){} |
262 | | - |
263 | | - void perform(ADNetClient &client){ |
264 | | - std::this_thread::sleep_for(std::chrono::milliseconds(wait_ms)); |
265 | | - } |
266 | | - bool do_delete() const{ return true; } |
267 | | - }; |
268 | | - |
269 | | - struct ClientActionBlockingSendReceive: public ClientAction{ |
270 | | - std::mutex m; |
271 | | - std::condition_variable cv; |
272 | | - Message *recv; |
273 | | - Message const *send; //it's blocking so we know that the object will live long enough |
274 | | - bool complete; |
275 | | - |
276 | | - ClientActionBlockingSendReceive(Message *recv, Message const *send): send(send), recv(recv), complete(false){} |
277 | | - |
278 | | - void perform(ADNetClient &client){ |
279 | | - client.send_and_receive(*recv, *send); |
280 | | - |
281 | | - { |
282 | | - std::unique_lock<std::mutex> lk(m); |
283 | | - complete = true; |
284 | | - cv.notify_one(); |
285 | | - } |
286 | | - } |
287 | | - bool do_delete() const{ return false; } |
288 | | - |
289 | | - void wait_for(){ |
290 | | - std::unique_lock<std::mutex> l(m); |
291 | | - cv.wait(l, [&]{ return complete; }); |
292 | | - } |
293 | | - }; |
294 | | - |
295 | | - //Return message is just dumped |
296 | | - struct ClientActionAsyncSend: public ClientAction{ |
297 | | - Message send; //copy of send message because we don't know how long it will be before it sends |
298 | | - |
299 | | - ClientActionAsyncSend(const Message &send): send(send){} |
300 | | - |
301 | | - void perform(ADNetClient &client){ |
302 | | - Message recv; |
303 | | - client.send_and_receive(recv, send); |
304 | | - } |
305 | | - bool do_delete() const{ return true; } |
306 | | - }; |
307 | | - |
308 | | - struct ClientActionSetRecvTimeout: public ClientAction{ |
309 | | - int timeout; |
310 | | - ClientActionSetRecvTimeout(const int timeout): timeout(timeout){} |
311 | | - |
312 | | - void perform(ADNetClient &client){ |
313 | | - client.setRecvTimeout(timeout); |
314 | | - } |
315 | 229 |
|
316 | | - bool do_delete() const{return true;} |
317 | | - }; |
318 | | - |
319 | | - //ADNetClient inside a worker thread with blocking send/receive and non-blocking send |
| 230 | + /** |
| 231 | + * @brief Runs an ADNetClient inside a worker thread, providing blocking send/receive and non-blocking send |
| 232 | + */ |
320 | 233 | class ADThreadNetClient{ |
| 234 | + public: |
| 235 | + /** |
| 236 | + * @brief Abstract base class representing an action performed by the worker thread |
| 237 | + */ |
| 238 | + struct ClientAction{ |
| 239 | + /** |
| 240 | + * @brief Perform the action utilizing the underlying net implementation |
| 241 | + */ |
| 242 | + virtual void perform(ADNetClient &client) = 0; |
| 243 | + |
| 244 | + /** |
| 245 | + * @brief Whether the work object (instance of ClientAction) should be deleted by the worker after the action completes |
| 246 | + */ |
| 247 | + virtual bool do_delete() const = 0; |
| 248 | + |
| 249 | + /** |
| 250 | + * @brief Whether to shutdown the worker thread after completing the action |
| 251 | + */ |
| 252 | + virtual bool shutdown_worker() const{ return false; } |
| 253 | + |
| 254 | + virtual ~ClientAction(){} |
| 255 | + }; |
| 256 | + |
| 257 | + private: |
321 | 258 | std::thread worker; |
322 | 259 | mutable std::mutex m; |
323 | | - std::queue<ClientAction*> queue; |
324 | | - |
325 | | - size_t getNwork() const{ |
326 | | - std::lock_guard<std::mutex> l(m); |
327 | | - return queue.size(); |
328 | | - } |
329 | | - ClientAction* getWorkItem(){ |
330 | | - std::lock_guard<std::mutex> l(m); |
331 | | - ClientAction *work_item = queue.front(); |
332 | | - queue.pop(); |
333 | | - return work_item; |
334 | | - } |
| 260 | + std::queue<ClientAction*> queue; /**< The queue of net operations*/ |
| 261 | + |
| 262 | + int m_rank; |
| 263 | + int m_srank; |
| 264 | + bool m_use_ps; |
| 265 | + PerfStats * m_perf; |
| 266 | + |
| 267 | + /** |
| 268 | + * @brief Get the number of outstanding net operations |
| 269 | + */ |
| 270 | + size_t getNwork() const; |
| 271 | + /** |
| 272 | + * @brief Get the next net operation |
| 273 | + */ |
| 274 | + ClientAction* getWorkItem(); |
335 | 275 |
|
336 | 276 | /** |
337 | 277 | * @brief Create the worker thread |
338 | 278 | * @param local Use a local (in process) communicator if true, otherwise use the default network communicator |
339 | 279 | */ |
340 | | - void run(bool local = false){ |
341 | | - worker = std::thread([&](){ |
342 | | - ADNetClient *client = nullptr; |
343 | | - if(local){ |
344 | | - client = new ADLocalNetClient; |
345 | | - }else{ |
346 | | -#ifdef _USE_MPINET |
347 | | - client = new ADMPINetClient; |
348 | | -#else |
349 | | - client = new ADZMQNetClient; |
350 | | -#endif |
351 | | - } |
352 | | - bool shutdown = false; |
353 | | - |
354 | | - while(!shutdown){ |
355 | | - size_t nwork = getNwork(); |
356 | | - while(nwork > 0){ |
357 | | - ClientAction* work_item = getWorkItem(); |
358 | | - work_item->perform(*client); |
359 | | - shutdown = shutdown || work_item->shutdown_worker(); |
360 | | - |
361 | | - if(work_item->do_delete()) delete work_item; |
362 | | - nwork = getNwork(); |
363 | | - } |
364 | | - if(shutdown){ |
365 | | - if(nwork > 0) fatal_error("Worker was shut down before emptying its queue!"); |
366 | | - }else{ |
367 | | - std::this_thread::sleep_for(std::chrono::milliseconds(80)); |
368 | | - } |
369 | | - } |
370 | | - delete client; |
371 | | - }); |
372 | | - } |
| 280 | + void run(bool local = false); |
373 | 281 |
|
374 | 282 | public: |
375 | 283 | /** |
376 | 284 | * @brief Constructor |
377 | 285 | * @param local Use a local (in process) communicator if true, otherwise use the default network communicator |
378 | 286 | */ |
379 | | - ADThreadNetClient(bool local = false){ |
380 | | - run(local); |
381 | | - } |
382 | | - |
383 | | - //Use only if you know what you are doing! |
384 | | - void enqueue_action(ClientAction *action){ |
385 | | - std::lock_guard<std::mutex> l(m); |
386 | | - queue.push(action); |
387 | | - } |
388 | | - |
389 | | - void connect_ps(int rank, int srank = 0, std::string sname="MPINET"){ |
390 | | - m_rank = rank; |
391 | | - m_srank = srank; |
392 | | - enqueue_action(new ClientActionConnect(rank, srank,sname)); |
393 | | - m_use_ps = true; |
394 | | - } |
395 | | - void disconnect_ps(){ |
396 | | - enqueue_action(new ClientActionDisconnect()); |
397 | | - } |
398 | | - void send_and_receive(Message &recv, const Message &send){ |
399 | | - ClientActionBlockingSendReceive action(&recv, &send); |
400 | | - enqueue_action(&action); |
401 | | - action.wait_for(); |
402 | | - } |
403 | | - void async_send(const Message &send){ |
404 | | - enqueue_action(new ClientActionAsyncSend(send)); |
405 | | - } |
| 287 | + ADThreadNetClient(bool local = false); |
| 288 | + |
| 289 | + /** |
| 290 | + * @brief Add an action to the queue |
| 291 | + * |
| 292 | + * Use only if you know what you are doing! |
| 293 | + */ |
| 294 | + void enqueue_action(ClientAction *action); |
| 295 | + |
| 296 | + /** |
| 297 | + * @brief Connect to the parameter server |
| 298 | + */ |
| 299 | + void connect_ps(int rank, int srank = 0, std::string sname="MPINET"); |
| 300 | + |
| 301 | + /** |
| 302 | + * @brief Disconnect from the parameter server |
| 303 | + */ |
| 304 | + void disconnect_ps(); |
| 305 | + |
| 306 | + /** |
| 307 | + * @brief Perform a blocking send and receive operation |
| 308 | + */ |
| 309 | + void send_and_receive(Message &recv, const Message &send); |
406 | 310 |
|
| 311 | + /** |
| 312 | + * @brief Perform a non-blocking send operation |
| 313 | + */ |
| 314 | + void async_send(const Message &send); |
| 315 | + |
| 316 | + /** |
| 317 | + * @brief Is the parameter server in use / connected? |
| 318 | + */ |
407 | 319 | bool use_ps() const { return m_use_ps; } |
408 | 320 |
|
| 321 | + /** |
| 322 | + * @brief Link a performance monitoring instance |
| 323 | + */ |
409 | 324 | void linkPerf(PerfStats* perf){ m_perf = perf; } |
410 | 325 |
|
| 326 | + /** |
| 327 | + * @brief Get the MPI rank of the server (MPINet) |
| 328 | + */ |
411 | 329 | int get_server_rank() const{ return m_srank; } |
412 | 330 |
|
| 331 | + /** |
| 332 | + * @brief Get the MPI rank of the client |
| 333 | + */ |
413 | 334 | int get_client_rank() const{ return m_rank; } |
414 | 335 |
|
415 | | - void setRecvTimeout(const int timeout_ms) { |
416 | | - enqueue_action(new ClientActionSetRecvTimeout(timeout_ms)); |
417 | | - } |
| 336 | + /** |
| 337 | + * @brief Set a timeout (in ms) on receiving a response message |
| 338 | + */ |
| 339 | + void setRecvTimeout(const int timeout_ms); |
418 | 340 |
|
419 | | - ~ADThreadNetClient(){ |
420 | | - disconnect_ps(); |
421 | | - worker.join(); |
422 | | - } |
423 | | - |
424 | | - private: |
425 | | - int m_rank; |
426 | | - int m_srank; |
427 | | - bool m_use_ps; |
428 | | - PerfStats * m_perf; |
| 341 | + ~ADThreadNetClient(); |
429 | 342 | }; |
430 | 343 |
|
431 | 344 |
|
|
0 commit comments