|
85 | 85 | import threading |
86 | 86 | import time |
87 | 87 | import traceback |
88 | | -from typing import Any, AsyncIterator, Optional, Tuple, Union, cast |
| 88 | +from typing import Any, AsyncIterator, Optional, Tuple, cast |
89 | 89 |
|
90 | 90 | import grpc |
91 | 91 | import jax |
@@ -434,13 +434,6 @@ def place_request_on_prefill_queue(self, request: ActiveRequest): |
434 | 434 | self._prefill_backlog.put(request, block=False) |
435 | 435 | self._prefill_backlog_size_metric.set(self._prefill_backlog.qsize()) |
436 | 436 |
|
437 | | - def _load_cache_history(self, path: str) -> Union[None, Any]: |
438 | | - """Loads previous kv cache for a longer conversation.""" |
439 | | - if path: |
440 | | - raise NotImplementedError |
441 | | - else: |
442 | | - return None |
443 | | - |
444 | 437 | def _process_prefill_content( |
445 | 438 | self, |
446 | 439 | request: ActiveRequest, |
@@ -744,6 +737,60 @@ def _get_prefill_content( |
744 | 737 | True, |
745 | 738 | ) |
746 | 739 |
|
def process_client_side_tokenization_response(self, response: Any):
    """Wrap the token ids of each sample in a DecodeResponse message.

    Only ``token_ids`` is copied into each output sample; no text field is
    set, so turning ids back into text is left to the receiver.

    Args:
        response: Iterable of per-sample results, each exposing a
            ``token_ids`` attribute.

    Returns:
        A ``jetstream_pb2.DecodeResponse`` whose ``stream_content`` holds
        one ``Sample`` per element of ``response``, in the same order.
    """
    sample_cls = jetstream_pb2.DecodeResponse.StreamContent.Sample
    id_only_samples = [
        sample_cls(token_ids=result.token_ids) for result in response
    ]
    stream_content = jetstream_pb2.DecodeResponse.StreamContent(
        samples=id_only_samples
    )
    return jetstream_pb2.DecodeResponse(stream_content=stream_content)
| 753 | + |
def should_buffer_response(self, response: Any) -> bool:
    """Report whether this response must be held back before detokenizing.

    Args:
        response: Iterable of per-sample results, each exposing a ``text``
            sequence of text tokens.

    Returns:
        True if any sample has non-empty ``text`` whose last token is a
        byte token according to ``token_utils.is_byte_token``; False
        otherwise, including when ``response`` is empty.
    """
    # If any sample ends in bytes, this means we might still need to
    # decode more bytes to compose the string.
    # any() always yields a real bool, so the no-match case returns False
    # rather than falling off the end of the function and returning None.
    return any(
        item.text and token_utils.is_byte_token(item.text[-1])
        for item in response
    )
| 760 | + |
def process_server_side_tokenization_response(
    self, response: Any, buffered_response_list
):
    """Merge buffered responses with the current one into a DecodeResponse.

    Each sample's text tokens and token ids are concatenated across all
    buffered responses followed by the current response. The combined text
    tokens are converted to a string with
    ``token_utils.text_tokens_to_str``.

    Args:
        response: Current per-sample results; each element exposes ``text``
            and ``token_ids``.
        buffered_response_list: Earlier responses that were held back, each
            with the same per-sample layout as ``response``. It is not
            modified here.

    Returns:
        A ``jetstream_pb2.DecodeResponse`` whose ``stream_content`` holds
        one merged ``Sample`` per sample position.
    """
    # Transpose so each entry groups one sample position across time.
    # Empty buffer: [[s0_cur], [s1_cur], ...]
    # Has buffer:
    #   [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...]
    per_sample_history = cast(
        list[list[ReturnSample]],
        list(zip(*buffered_response_list, response)),
    )
    sample_cls = jetstream_pb2.DecodeResponse.StreamContent.Sample
    merged_samples = []
    for history in per_sample_history:
        # Flatten the buffered pieces and the current piece, oldest first.
        text_tokens = [tok for part in history for tok in part.text]
        merged_ids = [tid for part in history for tid in part.token_ids]
        merged_samples.append(
            sample_cls(
                text=token_utils.text_tokens_to_str(text_tokens),
                token_ids=merged_ids,
            )
        )
    stream_content = jetstream_pb2.DecodeResponse.StreamContent(
        samples=merged_samples
    )
    return jetstream_pb2.DecodeResponse(stream_content=stream_content)
| 793 | + |
747 | 794 | async def Decode( # pylint: disable=invalid-overridden-method |
748 | 795 | self, |
749 | 796 | request: jetstream_pb2.DecodeRequest, |
@@ -795,70 +842,24 @@ async def Decode( # pylint: disable=invalid-overridden-method |
795 | 842 | # The DecodeResponse stream should consume all generated tokens in |
796 | 843 | # return_channel when complete signal is received (AsyncMultifuture |
797 | 844 | # promises this). |
798 | | - if is_client_side_tokenization: |
799 | | - # If is_client_side_tokenization, the client should request with token |
800 | | - # ids, and the JetStream server will return token ids as response. |
801 | | - # The client should take care of tokenization and detokenization. |
802 | | - async for response in active_request.return_channel: |
803 | | - response = cast(list[ReturnSample], response) |
804 | | - samples = [] |
805 | | - for sample in response: |
806 | | - samples.append( |
807 | | - jetstream_pb2.DecodeResponse.StreamContent.Sample( |
808 | | - token_ids=sample.token_ids, |
809 | | - ) |
810 | | - ) |
811 | | - yield jetstream_pb2.DecodeResponse( |
812 | | - stream_content=jetstream_pb2.DecodeResponse.StreamContent( |
813 | | - samples=samples |
814 | | - ) |
815 | | - ) |
816 | | - else: |
817 | | - # Buffer response mechanism is used to handle streaming |
818 | | - # detokenization with special character (For some edge cases with |
819 | | - # SentencePiece tokenizer, it requires to decode a complete sequence |
820 | | - # instead of a single token). |
821 | | - buffered_response_list = [] |
822 | | - async for response in active_request.return_channel: |
823 | | - response = cast(list[ReturnSample], response) |
824 | | - buffered = False |
825 | | - for item in response: |
826 | | - if item.text and token_utils.is_byte_token(item.text[-1]): |
827 | | - # If any sample ends in bytes, this means we might still need to |
828 | | - # decode more bytes to compose the string. |
829 | | - buffered_response_list.append(response) |
830 | | - buffered = True |
831 | | - break |
832 | | - if buffered: |
| 845 | + buffered_response_list = [] |
| 846 | + async for response in active_request.return_channel: |
| 847 | + response = cast(list[ReturnSample], response) |
| 848 | + if is_client_side_tokenization: |
| 849 | + # If is_client_side_tokenization, the client should request with token |
| 850 | + # ids, and the JetStream server will return token ids as response. |
| 851 | + # The client should take care of tokenization and detokenization. |
| 852 | + yield self.process_client_side_tokenization_response(response) |
| 853 | + else: |
| 854 | + # Buffer response mechanism is used to handle streaming |
| 855 | + # detokenization with special character (For some edge cases with |
| 856 | + # SentencePiece tokenizer, it requires to decode a complete sequence |
| 857 | + # instead of a single token). |
| 858 | + if self.should_buffer_response(response): |
| 859 | + buffered_response_list.append(response) |
833 | 860 | continue |
834 | | - # Flush the buffered responses to each sample of current response. |
835 | | - current_response_with_flushed_buffer = list( |
836 | | - zip(*buffered_response_list, response) |
837 | | - ) |
838 | | - # Empty buffer: [[s0_cur], [s1_cur], ...] |
839 | | - # Has buffer: |
840 | | - # [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...] |
841 | | - current_response_with_flushed_buffer = cast( |
842 | | - list[list[ReturnSample]], current_response_with_flushed_buffer |
| 861 | + yield self.process_server_side_tokenization_response( |
| 862 | + response, buffered_response_list |
843 | 863 | ) |
844 | 864 | # Reset buffer after flushed. |
845 | 865 | buffered_response_list = [] |
846 | | - # Form correct sample(s) and return as StreamContent for this iteration. |
847 | | - samples = [] |
848 | | - for sample in current_response_with_flushed_buffer: |
849 | | - text = [] |
850 | | - token_ids = [] |
851 | | - for resp in sample: |
852 | | - text.extend(resp.text) |
853 | | - token_ids.extend(resp.token_ids) |
854 | | - samples.append( |
855 | | - jetstream_pb2.DecodeResponse.StreamContent.Sample( |
856 | | - text=token_utils.text_tokens_to_str(text), |
857 | | - token_ids=token_ids, |
858 | | - ) |
859 | | - ) |
860 | | - yield jetstream_pb2.DecodeResponse( |
861 | | - stream_content=jetstream_pb2.DecodeResponse.StreamContent( |
862 | | - samples=samples |
863 | | - ) |
864 | | - ) |
|
0 commit comments