|
7 | 7 | #endif |
8 | 8 | #include "chimbuko/message.hpp" |
9 | 9 | #include "chimbuko/util/PerfStats.hpp" |
10 | | - |
| 10 | +#include "chimbuko/util/string.hpp" |
| 11 | +#include "chimbuko/util/error.hpp" |
| 12 | +#include "chimbuko/util/time.hpp" |
| 13 | + |
11 | 14 | namespace chimbuko{ |
12 | 15 |
|
13 | 16 |
|
@@ -113,6 +116,163 @@ namespace chimbuko{ |
113 | 116 |
|
114 | 117 |
|
115 | 118 |
|
| 119 | + //Actions performed by the worker thread |
| 120 | + struct ClientAction{ |
| 121 | + virtual void perform(ADNetClient &client) = 0; |
| 122 | + virtual bool do_delete() const = 0; //whether to delete the work object after completion |
| 123 | + virtual bool shutdown_worker() const{ return false; } //whether to shutdown the worker after completing the action |
| 124 | + virtual ~ClientAction(){} |
| 125 | + }; |
| 126 | + |
| 127 | + struct ClientActionConnect: public ClientAction{ |
| 128 | + int rank; |
| 129 | + int srank; |
| 130 | + std::string sname; |
| 131 | + |
| 132 | + ClientActionConnect(int rank, int srank, const std::string &sname): rank(rank), srank(srank), sname(sname){} |
| 133 | + |
| 134 | + void perform(ADNetClient &client){ |
| 135 | + std::cout << "Connecting to client" << std::endl; |
| 136 | + client.connect_ps(rank, srank, sname); |
| 137 | + } |
| 138 | + bool do_delete() const{ return true; } |
| 139 | + }; |
| 140 | + |
| 141 | + struct ClientActionDisconnect: public ClientAction{ |
| 142 | + void perform(ADNetClient &client){ |
| 143 | + std::cout << "Disconnecting from client" << std::endl; |
| 144 | + client.disconnect_ps(); |
| 145 | + } |
| 146 | + bool do_delete() const{ return true; } |
| 147 | + bool shutdown_worker() const{ return true; } |
| 148 | + }; |
| 149 | + |
| 150 | + //Make the worker wait for some time, for testing |
| 151 | + struct ClientActionWait: public ClientAction{ |
| 152 | + size_t wait_ms; |
| 153 | + ClientActionWait(size_t wait_ms): wait_ms(wait_ms){} |
| 154 | + |
| 155 | + void perform(ADNetClient &client){ |
| 156 | + std::cout << "Worker is waiting for "<< wait_ms << "ms" << std::endl; |
| 157 | + std::this_thread::sleep_for(std::chrono::milliseconds(wait_ms)); |
| 158 | + } |
| 159 | + bool do_delete() const{ return true; } |
| 160 | + }; |
| 161 | + |
| 162 | + struct ClientActionBlockingSendReceive: public ClientAction{ |
| 163 | + std::mutex m; |
| 164 | + std::condition_variable cv; |
| 165 | + Message *recv; |
| 166 | + Message const *send; //it's blocking so we know that the object will live long enough |
| 167 | + bool complete; |
| 168 | + |
| 169 | + ClientActionBlockingSendReceive(Message *recv, Message const *send): send(send), recv(recv), complete(false){} |
| 170 | + |
| 171 | + void perform(ADNetClient &client){ |
| 172 | + std::cout << "Performing blocking send and receive" << std::endl; |
| 173 | + client.send_and_receive(*recv, *send); |
| 174 | + complete = true; |
| 175 | + cv.notify_one(); |
| 176 | + } |
| 177 | + bool do_delete() const{ return false; } |
| 178 | + |
| 179 | + void wait_for(){ |
| 180 | + std::unique_lock<std::mutex> l(m); |
| 181 | + cv.wait(l, [&]{ return complete; }); |
| 182 | + } |
| 183 | + }; |
| 184 | + |
| 185 | + //Return message is just dumped |
| 186 | + struct ClientActionAsyncSend: public ClientAction{ |
| 187 | + Message send; //copy of send message because we don't know how long it will be before it sends |
| 188 | + |
| 189 | + ClientActionAsyncSend(const Message &send): send(send){} |
| 190 | + |
| 191 | + void perform(ADNetClient &client){ |
| 192 | + std::cout << "Performing non-blocking send and receive" << std::endl; |
| 193 | + Message recv; |
| 194 | + client.send_and_receive(recv, send); |
| 195 | + std::cout << "Non-blocking send returned " << recv.buf() << std::endl; |
| 196 | + } |
| 197 | + bool do_delete() const{ return true; } |
| 198 | + }; |
| 199 | + |
| 200 | + //ADNetClient inside a worker thread with blocking send/receive and non-blocking send |
| 201 | + class ADThreadNetClient{ |
| 202 | + std::thread worker; |
| 203 | + mutable std::mutex m; |
| 204 | + std::queue<ClientAction*> queue; |
| 205 | + |
| 206 | + size_t getNwork() const{ |
| 207 | + std::lock_guard<std::mutex> l(m); |
| 208 | + return queue.size(); |
| 209 | + } |
| 210 | + ClientAction* getWorkItem(){ |
| 211 | + std::lock_guard<std::mutex> l(m); |
| 212 | + ClientAction *work_item = queue.front(); |
| 213 | + queue.pop(); |
| 214 | + return work_item; |
| 215 | + } |
| 216 | + |
| 217 | + void run(){ |
| 218 | + std::cout << "Starting worker thread" << std::endl; |
| 219 | + worker = std::thread([&](){ |
| 220 | + ADNetClient client; |
| 221 | + bool shutdown = false; |
| 222 | + |
| 223 | + while(!shutdown){ |
| 224 | + size_t nwork = getNwork(); |
| 225 | + while(nwork > 0){ |
| 226 | + ClientAction* work_item = getWorkItem(); |
| 227 | + work_item->perform(client); |
| 228 | + shutdown = work_item->shutdown_worker(); |
| 229 | + |
| 230 | + if(work_item->do_delete()) delete work_item; |
| 231 | + nwork = getNwork(); |
| 232 | + } |
| 233 | + if(shutdown){ |
| 234 | + if(nwork > 0) fatal_error("Worker was shut down before emptying its queue!"); |
| 235 | + std::cout << "Worker received shutdown request" << std::endl; |
| 236 | + }else{ |
| 237 | + std::this_thread::sleep_for(std::chrono::milliseconds(10)); |
| 238 | + } |
| 239 | + } |
| 240 | + }); |
| 241 | + } |
| 242 | + |
| 243 | + public: |
| 244 | + ADThreadNetClient(){ |
| 245 | + run(); |
| 246 | + } |
| 247 | + |
| 248 | + //Use only if you know what you are doing! |
| 249 | + void enqueue_action(ClientAction *action){ |
| 250 | + std::lock_guard<std::mutex> l(m); |
| 251 | + queue.push(action); |
| 252 | + } |
| 253 | + |
| 254 | + void connect_ps(int rank, int srank = 0, std::string sname="MPINET"){ |
| 255 | + enqueue_action(new ClientActionConnect(rank, srank,sname)); |
| 256 | + } |
| 257 | + void disconnect_ps(){ |
| 258 | + enqueue_action(new ClientActionDisconnect()); |
| 259 | + } |
| 260 | + void send_and_receive(Message &recv, const Message &send){ |
| 261 | + ClientActionBlockingSendReceive action(&recv, &send); |
| 262 | + enqueue_action(&action); |
| 263 | + action.wait_for(); |
| 264 | + } |
| 265 | + void async_send(const Message &send){ |
| 266 | + enqueue_action(new ClientActionAsyncSend(send)); |
| 267 | + } |
| 268 | + |
| 269 | + ~ADThreadNetClient(){ |
| 270 | + std::cout << "Joining worker thread" << std::endl; |
| 271 | + worker.join(); |
| 272 | + } |
| 273 | + |
| 274 | + }; |
116 | 275 |
|
| 276 | + |
117 | 277 |
|
118 | 278 | }; |
0 commit comments