|
76 | 76 |
|
77 | 77 | import dataclasses |
78 | 78 | import functools |
| 79 | +import itertools |
79 | 80 | import logging |
80 | 81 | import os |
81 | 82 | import queue |
|
100 | 101 |
|
101 | 102 | handler = logging.StreamHandler(sys.stdout) |
102 | 103 | handler.setLevel(logging.DEBUG) |
103 | | -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
| 104 | +formatter = logging.Formatter( |
| 105 | + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
| 106 | +) |
104 | 107 | handler.setFormatter(formatter) |
105 | 108 | root.addHandler(handler) |
106 | 109 |
|
| 110 | + |
107 | 111 | def delete_pytree(p): |
108 | 112 | def delete_leaf(leaf): |
109 | 113 | if isinstance(leaf, jax.Array): |
110 | 114 | leaf.delete() |
111 | 115 | del leaf |
| 116 | + |
112 | 117 | jax.tree_map(delete_leaf, p) |
113 | 118 |
|
114 | 119 |
|
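A note on the unchanged `delete_pytree` helper above: in recent JAX releases `jax.tree_map` is deprecated in favor of `jax.tree_util.tree_map` (or `jax.tree.map`). A self-contained sketch of the same idea against the current API; the dict-of-arrays "cache" here is purely illustrative, not taken from this repo:

```python
import jax
import jax.numpy as jnp

def delete_pytree(p):
  """Frees the device buffers of every jax.Array leaf in a pytree."""
  def delete_leaf(leaf):
    if isinstance(leaf, jax.Array):
      # delete() is what actually releases device memory; dropping the
      # local Python reference alone would only decrement a refcount.
      leaf.delete()
  # Non-deprecated spelling of jax.tree_map in recent JAX releases.
  jax.tree_util.tree_map(delete_leaf, p)

# Illustrative only: free a dict-of-arrays "KV cache" once the prefill
# slice no longer needs it.
cache = {'k': jnp.zeros((2, 4)), 'v': jnp.zeros((2, 4))}
delete_pytree(cache)
```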
@@ -185,19 +190,19 @@ class Driver: |
185 | 190 | _prefill_params: Optional[dict[int, Any]] = {} |
186 | 191 | _generate_params: Optional[dict[int, Any]] = {} |
187 | 192 | # Stage 1 |
188 | | - _prefill_backlog: queue.Queue[ActiveRequest] |
| 193 | + _prefill_backlog: queue.Queue[ActiveRequest | None] |
189 | 194 | # Stage 2 |
190 | 195 |   # We keep this as a dict to avoid a possibly expensive object comparison |
191 | 196 |   # when logging the index of the generate engine we send a prefill result |
192 | 197 |   # to; it lets us take the index directly from the min operation rather |
193 | 198 |   # than having to call .index(). |
194 | | - _generate_backlogs: dict[int, queue.Queue[ActiveRequest]] = {} |
| 199 | + _generate_backlogs: dict[int, queue.Queue[ActiveRequest | None]] = {} |
195 | 200 | # Stage 3 |
196 | 201 |   # This can be a list because we can pass it as an arg to the generate and |
197 | 202 |   # detokenize threads. Each entry is a queue of tokens to be detokenized. |
198 | 203 | _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = [] |
199 | 204 | _generate_slots: list[queue.Queue[int]] = [] |
200 | | - _active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = [] |
| 205 | + _active_requests: list[queue.Queue[tuple[int, ActiveRequest | None]]] = [] |
201 | 206 |
|
202 | 207 | def __init__( |
203 | 208 | self, |
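Two things happen in the annotations above: the queues are widened to `ActiveRequest | None` so that `None` can serve as a shutdown sentinel (consumed by the new `stop()` below), and, per the comment, the generate backlogs are keyed by engine index so a `min` over the dict yields the target engine's index directly. A small sketch of that second point, with a hypothetical engine count:

```python
import queue

NUM_GENERATE_ENGINES = 4  # Hypothetical engine count.

# Keyed by engine index: min() over the keys returns the index itself,
# with no Queue-object comparison and no list.index() lookup afterwards.
generate_backlogs: dict[int, queue.Queue] = {
    idx: queue.Queue() for idx in range(NUM_GENERATE_ENGINES)
}
generate_backlogs[2].put('warm request')

# Send the next prefill result to the least-loaded generate engine.
idx = min(generate_backlogs, key=lambda i: generate_backlogs[i].qsize())
assert idx != 2  # Engine 2 already has work queued; an empty one wins.
```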
@@ -296,11 +301,58 @@ def __init__( |
296 | 301 | ) |
297 | 302 | for idx, engine in enumerate(self._generate_engines) |
298 | 303 | ] |
| 304 | + self._all_threads = list( |
| 305 | + itertools.chain( |
| 306 | + self._prefill_threads, |
| 307 | + self._generate_threads, |
| 308 | + self.detokenize_threads, |
| 309 | + ) |
| 310 | + ) |
299 | 311 | self.live = True |
300 | 312 | # Kick off all threads |
301 | | - _ = [f.start() for f in self._prefill_threads] |
302 | | - _ = [f.start() for f in self._generate_threads] |
303 | | - _ = [f.start() for f in self.detokenize_threads] |
| 313 | + for t in self._all_threads: |
| 314 | + t.start() |
| 315 | + |
| 316 | + def stop(self): |
| 317 | + """Stops the driver and all background threads.""" |
| 318 | + # Signal to all threads that they should stop. |
| 319 | + self.live = False |
| 320 | + |
| 321 | + all_backlogs = list( |
| 322 | + itertools.chain( |
| 323 | + [self._prefill_backlog], |
| 324 | + self._generate_backlogs.values(), |
| 325 | + self._detokenize_backlogs, |
| 326 | + ) |
| 327 | + ) |
| 328 | + |
| 329 | + while any(t.is_alive() for t in self._all_threads): |
| 330 | + # Empty all backlogs and mark any remaining requests as cancelled. |
| 331 | + for q in all_backlogs: |
| 332 | + while True: |
| 333 | + try: |
| 334 | + r = q.get_nowait() |
| 335 | + if r is None: |
| 336 | + continue |
| 337 | + elif isinstance(r, ActiveRequest): |
| 338 | + r.return_channel = None |
| 339 | + else: # detokenize backlog |
| 340 | + _, r = r |
| 341 | + if isinstance(r, ActiveRequest): |
| 342 | + r.return_channel = None |
| 343 | + except queue.Empty: |
| 344 | + break |
| 345 | + |
| 346 | + # Put sentinels to unblock threads. |
| 347 | + for q in all_backlogs: |
| 348 | + try: |
| 349 | + q.put_nowait(None) |
| 350 | + except queue.Full: |
| 351 | + pass |
| 352 | + |
| 353 | + # Wait for all threads to stop. |
| 354 | + for t in self._all_threads: |
| 355 | + t.join() |
304 | 356 |
|
305 | 357 | def get_total_concurrent_requests(self) -> int: |
306 | 358 | """Returns the total number of concurrent requests the driver can service.""" |
@@ -338,10 +390,12 @@ def _prefill_thread( |
338 | 390 | while self.live: |
339 | 391 | # We don't want to keep lots of kv caches live in memory on the prefill |
340 | 392 | # slice that aren't about to be sent over to a generation slice. |
341 | | - if (self._generate_backlogs[idx].qsize() < generate_backpressure): |
| 393 | + if self._generate_backlogs[idx].qsize() < generate_backpressure: |
342 | 394 | # Check if there is anything on the prefill backlog, pop if so. |
343 | 395 | try: |
344 | 396 | request = self._prefill_backlog.get(block=True) |
| 397 | + if request is None: |
| 398 | + break |
345 | 399 | # TODO: Implement hot/cold cache for history. |
346 | 400 | history = self._load_cache_history(request.history_path) # pylint: disable = assignment-from-none |
347 | 401 | # Tokenize, and introduce a leading dimension |
@@ -372,7 +426,8 @@ def _prefill_thread( |
372 | 426 | # Once prefill is complete, place it on the generation queue. |
373 | 427 | self._generate_backlogs[idx].put(request) |
374 | 428 | logging.info( |
375 | | - f'Placed request on the generate queue, {self._generate_backlogs[idx].qsize()=}' |
| 429 | + 'Placed request on the generate queue,' |
| 430 | + f' {self._generate_backlogs[idx].qsize()=}' |
376 | 431 | ) |
377 | 432 | except queue.Empty: |
378 | 433 | # Otherwise, don't do anything! |
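One detail worth noting in the loop above: `Queue.get(block=True)` with no timeout never raises `queue.Empty`, so the `except queue.Empty` branch is only reachable if a timeout is added. A sketch of the backpressure gate in isolation; the names and the timeout are hypothetical, and `qsize()` is documented as approximate, which is acceptable for a soft cap like this:

```python
import queue

GENERATE_BACKPRESSURE = 8  # Hypothetical soft cap on queued prefill results.

def prefill_once(
    prefill_backlog: queue.Queue,
    generate_backlog: queue.Queue,
) -> bool:
  """Moves one request downstream; returns False on the shutdown sentinel."""
  # qsize() is approximate with concurrent producers/consumers, but this is
  # only a soft limit on KV caches parked in the generate backlog.
  if generate_backlog.qsize() < GENERATE_BACKPRESSURE:
    try:
      # A timeout makes the Empty branch reachable; a plain blocking get()
      # would wait forever and never raise queue.Empty.
      request = prefill_backlog.get(block=True, timeout=1.0)
    except queue.Empty:
      return True  # Nothing to do this round.
    if request is None:  # Shutdown sentinel pushed by stop().
      return False
    generate_backlog.put(request)  # Hand the prefill result downstream.
  return True
```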
@@ -410,13 +465,16 @@ def _generate_thread( |
410 | 465 | while self.live: |
411 | 466 | if (time.time() - time_of_last_print) > 1: |
412 | 467 | logging.info( |
413 | | - f'Generate thread making a decision with: prefill_backlog={self._prefill_backlog.qsize()} generate_free_slots={my_slots.qsize()}' |
| 468 | + 'Generate thread making a decision with:' |
| 469 | + f' prefill_backlog={self._prefill_backlog.qsize()} generate_free_slots={my_slots.qsize()}' |
414 | 470 | ) |
415 | 471 | time_of_last_print = time.time() |
416 | 472 | # Check if there are any free my_slots. |
417 | 473 | if not my_slots.empty() and not self._generate_backlogs[idx].empty(): |
418 | 474 | # Only get requests from the backlog corresponding to this engine. |
419 | 475 | new_request = self._generate_backlogs[idx].get() |
| 476 | + if new_request is None: |
| 477 | + break |
420 | 478 | slot = my_slots.get() |
421 | 479 | logging.info( |
422 | 480 | 'Generate slice %d slot %d step %d', |
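The generate loop admits work only when both a free decode slot and a queued request exist. A minimal sketch of that slot accounting with hypothetical names; the `empty()`-then-`get()` pattern is only safe here because, during normal operation, this thread is the sole consumer of its slot queue and backlog:

```python
import queue
from typing import Any, Optional

MAX_CONCURRENT_DECODES = 4  # Hypothetical number of decode slots per engine.

# Free slot indices live in a queue: taking one admits a request into that
# slot; putting it back (when the request finishes) frees the capacity.
free_slots: queue.Queue[int] = queue.Queue()
for slot in range(MAX_CONCURRENT_DECODES):
  free_slots.put(slot)

def try_admit(backlog: queue.Queue) -> Optional[tuple[int, Any]]:
  # Single-consumer assumption: nothing else drains free_slots or backlog
  # between the empty() checks and the get() calls below.
  if not free_slots.empty() and not backlog.empty():
    request = backlog.get()
    if request is None:  # Shutdown sentinel pushed by stop().
      return None
    return free_slots.get(), request
  return None
```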
@@ -475,6 +533,8 @@ def _detokenize_thread( |
475 | 533 | while self.live: |
476 | 534 | try: |
477 | 535 | data = my_detokenize_backlog.get(block=True) |
| 536 | + if data is None: |
| 537 | + break |
478 | 538 | start_detokenise_time = time.time() |
479 | 539 | if isinstance(data[1], engine_api.ResultTokens): |
480 | 540 | # We want to detokenise them. |
|