Skip to content

Commit 87aa565

Browse files
authored
Unit test coverage cleanup (#81)
* Unit test coverage cleanup * Increase test coverage to 96% * Set CI to fail when coverage is lower than 96% * Update README unit test status badge to main branch
1 parent e4952fb commit 87aa565

10 files changed

Lines changed: 182 additions & 115 deletions

File tree

.github/workflows/unit_tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,4 @@ jobs:
8080
coverage run -m unittest -v
8181
- name: Create test coverage report
8282
run: |
83-
coverage report -m
83+
coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/third_party/*" --fail-under=96

README.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
[![Unit Tests](https://github.com/google/JetStream/actions/workflows/unit_tests.yaml/badge.svg)](https://github.com/google/JetStream/actions/workflows/unit_tests.yaml)
1+
[![Unit Tests](https://github.com/google/JetStream/actions/workflows/unit_tests.yaml/badge.svg?branch=main)](https://github.com/google/JetStream/actions/workflows/unit_tests.yaml?query=branch:main)
22
[![PyPI version](https://badge.fury.io/py/google-jetstream.svg)](https://badge.fury.io/py/google-jetstream)
33
[![PyPi downloads](https://img.shields.io/pypi/dm/google-jetstream?style=flat-square&logo=pypi&logoColor=white)](https://pypi.org/project/google-jetstream/)
44
[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
@@ -57,15 +57,16 @@ python -m jetstream.tools.load_tester
5757
### Test core modules
5858
```
5959
# Test JetStream core orchestrator
60-
python -m jetstream.tests.core.test_orchestrator
60+
python -m unittest -v jetstream.tests.core.test_orchestrator
6161
6262
# Test JetStream core server library
63-
python -m jetstream.tests.core.test_server
63+
python -m unittest -v jetstream.tests.core.test_server
6464
6565
# Test mock JetStream engine implementation
66-
python -m jetstream.tests.engine.test_mock_engine
66+
python -m unittest -v jetstream.tests.engine.test_mock_engine
6767
6868
# Test mock JetStream token utils
69-
python -m jetstream.tests.engine.test_utils
69+
python -m unittest -v jetstream.tests.engine.test_token_utils
70+
python -m unittest -v jetstream.tests.engine.test_utils
7071
7172
```

jetstream/core/config_lib.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,6 @@ class ServerConfig:
3838
generate_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
3939
interleaved_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
4040

41-
def get_slices_to_launch(self: "ServerConfig") -> str:
42-
"""Used when launching this config via xm config."""
43-
return ",".join(
44-
self.prefill_slices + self.generate_slices + self.interleaved_slices
45-
)
46-
4741

4842
@dataclasses.dataclass
4943
class InstantiatedEngines:

jetstream/core/orchestrator.py

Lines changed: 72 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
import threading
8686
import time
8787
import traceback
88-
from typing import Any, AsyncIterator, Optional, Tuple, Union, cast
88+
from typing import Any, AsyncIterator, Optional, Tuple, cast
8989

9090
import grpc
9191
import jax
@@ -434,13 +434,6 @@ def place_request_on_prefill_queue(self, request: ActiveRequest):
434434
self._prefill_backlog.put(request, block=False)
435435
self._prefill_backlog_size_metric.set(self._prefill_backlog.qsize())
436436

437-
def _load_cache_history(self, path: str) -> Union[None, Any]:
438-
"""Loads previous kv cache for a longer conversation."""
439-
if path:
440-
raise NotImplementedError
441-
else:
442-
return None
443-
444437
def _process_prefill_content(
445438
self,
446439
request: ActiveRequest,
@@ -744,6 +737,60 @@ def _get_prefill_content(
744737
True,
745738
)
746739

740+
def process_client_side_tokenization_response(self, response: Any):
741+
samples = []
742+
for sample in response:
743+
samples.append(
744+
jetstream_pb2.DecodeResponse.StreamContent.Sample(
745+
token_ids=sample.token_ids,
746+
)
747+
)
748+
return jetstream_pb2.DecodeResponse(
749+
stream_content=jetstream_pb2.DecodeResponse.StreamContent(
750+
samples=samples
751+
)
752+
)
753+
754+
def should_buffer_response(self, response: Any) -> bool:
755+
for item in response:
756+
if item.text and token_utils.is_byte_token(item.text[-1]):
757+
# If any sample ends in bytes, this means we might still need to
758+
# decode more bytes to compose the string.
759+
return True
760+
761+
def process_server_side_tokenization_response(
762+
self, response: Any, buffered_response_list
763+
):
764+
# Flush the buffered responses to each sample of current response.
765+
current_response_with_flushed_buffer = list(
766+
zip(*buffered_response_list, response)
767+
)
768+
# Empty buffer: [[s0_cur], [s1_cur], ...]
769+
# Has buffer:
770+
# [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...]
771+
current_response_with_flushed_buffer = cast(
772+
list[list[ReturnSample]], current_response_with_flushed_buffer
773+
)
774+
# Form correct sample(s) and return as StreamContent for this iteration.
775+
samples = []
776+
for sample in current_response_with_flushed_buffer:
777+
text = []
778+
token_ids = []
779+
for resp in sample:
780+
text.extend(resp.text)
781+
token_ids.extend(resp.token_ids)
782+
samples.append(
783+
jetstream_pb2.DecodeResponse.StreamContent.Sample(
784+
text=token_utils.text_tokens_to_str(text),
785+
token_ids=token_ids,
786+
)
787+
)
788+
return jetstream_pb2.DecodeResponse(
789+
stream_content=jetstream_pb2.DecodeResponse.StreamContent(
790+
samples=samples
791+
)
792+
)
793+
747794
async def Decode( # pylint: disable=invalid-overridden-method
748795
self,
749796
request: jetstream_pb2.DecodeRequest,
@@ -795,70 +842,24 @@ async def Decode( # pylint: disable=invalid-overridden-method
795842
# The DecodeResponse stream should consume all generated tokens in
796843
# return_channel when complete signal is received (AsyncMultifuture
797844
# promises this).
798-
if is_client_side_tokenization:
799-
# If is_client_side_tokenization, the client should request with token
800-
# ids, and the JetStream server will return token ids as response.
801-
# The client should take care of tokenization and detokenization.
802-
async for response in active_request.return_channel:
803-
response = cast(list[ReturnSample], response)
804-
samples = []
805-
for sample in response:
806-
samples.append(
807-
jetstream_pb2.DecodeResponse.StreamContent.Sample(
808-
token_ids=sample.token_ids,
809-
)
810-
)
811-
yield jetstream_pb2.DecodeResponse(
812-
stream_content=jetstream_pb2.DecodeResponse.StreamContent(
813-
samples=samples
814-
)
815-
)
816-
else:
817-
# Buffer response mechanism is used to handle streaming
818-
# detokenization with special character (For some edge cases with
819-
# SentencePiece tokenizer, it requires to decode a complete sequence
820-
# instead of a single token).
821-
buffered_response_list = []
822-
async for response in active_request.return_channel:
823-
response = cast(list[ReturnSample], response)
824-
buffered = False
825-
for item in response:
826-
if item.text and token_utils.is_byte_token(item.text[-1]):
827-
# If any sample ends in bytes, this means we might still need to
828-
# decode more bytes to compose the string.
829-
buffered_response_list.append(response)
830-
buffered = True
831-
break
832-
if buffered:
845+
buffered_response_list = []
846+
async for response in active_request.return_channel:
847+
response = cast(list[ReturnSample], response)
848+
if is_client_side_tokenization:
849+
# If is_client_side_tokenization, the client should request with token
850+
# ids, and the JetStream server will return token ids as response.
851+
# The client should take care of tokenization and detokenization.
852+
yield self.process_client_side_tokenization_response(response)
853+
else:
854+
# Buffer response mechanism is used to handle streaming
855+
# detokenization with special character (For some edge cases with
856+
# SentencePiece tokenizer, it requires to decode a complete sequence
857+
# instead of a single token).
858+
if self.should_buffer_response(response):
859+
buffered_response_list.append(response)
833860
continue
834-
# Flush the buffered responses to each sample of current response.
835-
current_response_with_flushed_buffer = list(
836-
zip(*buffered_response_list, response)
837-
)
838-
# Empty buffer: [[s0_cur], [s1_cur], ...]
839-
# Has buffer:
840-
# [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...]
841-
current_response_with_flushed_buffer = cast(
842-
list[list[ReturnSample]], current_response_with_flushed_buffer
861+
yield self.process_server_side_tokenization_response(
862+
response, buffered_response_list
843863
)
844864
# Reset buffer after flushed.
845865
buffered_response_list = []
846-
# Form correct sample(s) and return as StreamContent for this iteration.
847-
samples = []
848-
for sample in current_response_with_flushed_buffer:
849-
text = []
850-
token_ids = []
851-
for resp in sample:
852-
text.extend(resp.text)
853-
token_ids.extend(resp.token_ids)
854-
samples.append(
855-
jetstream_pb2.DecodeResponse.StreamContent.Sample(
856-
text=token_utils.text_tokens_to_str(text),
857-
token_ids=token_ids,
858-
)
859-
)
860-
yield jetstream_pb2.DecodeResponse(
861-
stream_content=jetstream_pb2.DecodeResponse.StreamContent(
862-
samples=samples
863-
)
864-
)

jetstream/tests/core/test_config_lib.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,18 @@
1414

1515
"""Unit test for config_lib.py."""
1616

17-
from absl.testing import absltest, parameterized
17+
import unittest
18+
from parameterized import parameterized
1819
from jetstream.core import config_lib
1920

2021

21-
class TestConfigLib(parameterized.TestCase):
22+
class TestConfigLib(unittest.TestCase):
2223

23-
@parameterized.parameters(
24-
("tpu=8", 8),
25-
("v5e-8", 8),
26-
("v5e=4", 4),
27-
("v4-8", 4),
28-
)
24+
@parameterized.expand([("tpu=8", 8), ("v5e-8", 8), ("v5e=4", 4), ("v4-8", 4)])
2925
def test_slice_to_num_chips(self, accelerator_slice, expected_num_devices):
3026
got = config_lib.slice_to_num_chips(accelerator_slice)
3127
self.assertEqual(got, expected_num_devices)
3228

33-
34-
if __name__ == "__main__":
35-
absltest.main()
29+
def test_get_engines_invalid(self):
30+
with self.assertRaises(ValueError):
31+
config_lib.get_engines(config_lib.InterleavedCPUTestServer, [])

jetstream/tests/core/test_orchestrator.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import unittest
4545
from jetstream.core import orchestrator
4646
from jetstream.core.proto import jetstream_pb2
47+
from jetstream.core.utils.return_sample import ReturnSample
4748
from jetstream.engine import mock_engine
4849

4950

@@ -131,6 +132,13 @@ async def test_orchestrator_interleaved_mode_client_tokenization(self):
131132
driver.stop()
132133
print("Orchestrator driver stopped.")
133134

134-
135-
if __name__ == "__main__":
136-
unittest.main()
135+
def test_should_buffer_response(self):
136+
driver = self._setup_driver_interleaved_mode()
137+
client = orchestrator.LLMOrchestrator(driver=driver)
138+
self.assertTrue(
139+
client.should_buffer_response(
140+
[ReturnSample(text=["<0xAB>"], token_ids=[13])]
141+
)
142+
)
143+
driver.stop()
144+
print("Orchestrator driver stopped.")

jetstream/tests/core/test_server.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,5 @@ async def test_server(
9696
counter += 1
9797
server.stop()
9898

99-
100-
if __name__ == "__main__":
101-
unittest.main()
99+
def test_get_devices(self):
100+
assert len(server_lib.get_devices()) == 1

jetstream/tests/engine/test_mock_engine.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@
3030
I.e. ['Ċ', 'Ə', 'ɖ'] when converted back with chr()
3131
"""
3232

33+
import unittest
3334
import jax.numpy as jnp
3435
import numpy as np
3536

3637
from jetstream.engine import mock_engine
3738
from jetstream.engine import token_utils
38-
from absl.testing import absltest
3939

4040

41-
class EngineTest(absltest.TestCase):
41+
class EngineTest(unittest.TestCase):
4242

4343
def _setup(self):
4444
"""Initialises a test engine."""
@@ -128,7 +128,3 @@ def test_generate(self, slot=1):
128128
token_data = sampled_tokens.get_result_at_slot(slot)
129129
tok = token_data.tokens
130130
assert tokenizer.IdToPiece(int(tok.item())) == "ɖ"
131-
132-
133-
if __name__ == "__main__":
134-
absltest.main()

jetstream/tests/engine/test_token_utils.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,67 @@ def test_tokenize_and_pad_np(self):
136136
)
137137
self.assertEqual(true_length, expected_true_length)
138138

139+
def test_tokenize_and_pad(self):
140+
jax.config.update("jax_platform_name", "cpu")
141+
self.setup_sentencepiece()
142+
s = "I believe the meaning of life is"
143+
vocab = self.jt_tokenizer.vocab
144+
max_prefill_length = 1024
145+
padded_tokens, true_length = token_utils.tokenize_and_pad(
146+
s,
147+
vocab,
148+
max_prefill_length=max_prefill_length,
149+
)
150+
expected_padded_tokens = jnp.array(
151+
[1, 306, 4658, 278, 6593, 310, 2834, 338, 0, 0, 0, 0, 0, 0, 0, 0]
152+
)
153+
expected_true_length = 8
154+
self.assertTrue(
155+
jnp.allclose(padded_tokens, expected_padded_tokens, atol=1e-7)
156+
)
157+
self.assertEqual(true_length, expected_true_length)
158+
159+
def test_pad_token_padding_less_than_zero(self):
160+
jax.config.update("jax_platform_name", "cpu")
161+
self.setup_sentencepiece()
162+
s = "I believe the meaning of life is having different experiences and "
163+
s += "enjoy everyday of my life."
164+
vocab = self.jt_tokenizer.vocab
165+
max_prefill_length = 16
166+
tokens = vocab.encode_tf(s)
167+
padded_tokens, true_length = token_utils.pad_tokens(
168+
tokens,
169+
bos_id=vocab.bos_id,
170+
pad_id=vocab.pad_id,
171+
max_prefill_length=max_prefill_length,
172+
)
173+
# Take the last N tokens if we have too many.
174+
expected_padded_tokens = jnp.array(
175+
[
176+
278,
177+
6593,
178+
310,
179+
2834,
180+
338,
181+
2534,
182+
1422,
183+
27482,
184+
322,
185+
13389,
186+
1432,
187+
3250,
188+
310,
189+
590,
190+
2834,
191+
29889,
192+
]
193+
)
194+
expected_true_length = 19
195+
self.assertTrue(
196+
jnp.allclose(padded_tokens, expected_padded_tokens, atol=1e-7)
197+
)
198+
self.assertEqual(true_length, expected_true_length)
199+
139200
def test_sentencepiece_tokenizer_encode(self):
140201
self.setup_sentencepiece()
141202
s = "I believe the meaning of life is"
@@ -559,7 +620,3 @@ def test_text_tokens_to_str(self):
559620
)
560621
== "你好�\n�hello"
561622
)
562-
563-
564-
if __name__ == "__main__":
565-
unittest.main()

0 commit comments

Comments (0)