Skip to content

Commit e4952fb

Browse files
authored
Update benchmark script to easily test llama-3 (#83)
* Update benchmark script to easily test llama-3
* fix lint
* Update benchmarks/README.md
1 parent 01c5a03 commit e4952fb

3 files changed

Lines changed: 29 additions & 15 deletions

File tree

benchmarks/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ python benchmark_serving.py \
2929
3030
```
3131

32+
### Run Benchmark for Llama 3
33+
34+
```
35+
python benchmark_serving.py \
36+
--tokenizer <llama3 tokenizer path> \
37+
--num-prompts 10 \
38+
--dataset sharegpt \
39+
--dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json \
40+
--max-output-length 1024 \
41+
--model llama-3
42+
43+
```
44+
3245
### Save request outputs in Benchmark
3346

3447
Please use `--save-request-outputs` flag to save predictions to a file.

benchmarks/benchmark_serving.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
from jetstream.core.proto import jetstream_pb2
7474
from jetstream.core.proto import jetstream_pb2_grpc
7575
from jetstream.engine.token_utils import load_vocab
76+
from jetstream.third_party.llama3 import llama3_tokenizer
7677
import numpy as np
7778
from tqdm.asyncio import tqdm # pytype: disable=pyi-error
7879
import pandas
@@ -130,10 +131,13 @@ def to_dict(self):
130131
}
131132

132133

133-
def get_tokenizer(tokenizer_name: str) -> Any:
134+
def get_tokenizer(model_id: str, tokenizer_name: str) -> Any:
134135
"""Return a tokenizer or a tokenizer placholder."""
135136
if tokenizer_name == "test":
136137
return "test"
138+
elif model_id == "llama-3":
139+
# Llama 3 uses a tiktoken tokenizer.
140+
return llama3_tokenizer.Tokenizer(tokenizer_name)
137141
else:
138142
# Use JetStream tokenizer util. It's using the sentencepiece wrapper in
139143
# seqio library.
@@ -195,18 +199,14 @@ def tokenize_dataset(
195199
prompts = []
196200
outputs = []
197201
indices = []
198-
202+
prompt_token_ids = []
203+
outputs_token_ids = []
199204
for prompt, output, idx in dataset:
200205
prompts.append(prompt)
201206
outputs.append(output)
202207
indices.append(idx)
203-
204-
prompt_token_ids = tokenizer.encode(
205-
prompts
206-
) # adjust this code based on tokenizer method
207-
outputs_token_ids = tokenizer.encode(
208-
outputs
209-
) # adjust this code based on tokenizer method
208+
prompt_token_ids.append(tokenizer.encode(prompt))
209+
outputs_token_ids.append(tokenizer.encode(output))
210210

211211
tokenized_dataset = []
212212
for i in range(n):
@@ -549,7 +549,7 @@ def main(args: argparse.Namespace):
549549

550550
api_url = f"{args.server}:{args.port}"
551551

552-
tokenizer = get_tokenizer(tokenizer_id)
552+
tokenizer = get_tokenizer(model_id, tokenizer_id)
553553
if tokenizer == "test" or args.dataset == "test":
554554
input_requests = mock_requests(
555555
args.total_mock_requests
@@ -680,9 +680,10 @@ def main(args: argparse.Namespace):
680680
type=str,
681681
default="no_model",
682682
help=(
683-
"Name of the model. (it's just used to label the benchmark, the model"
684-
" config is defined in config_lib, and passed as the server config"
685-
" flag when we run the JetStream server)"
683+
"Name of the model like llama-2, llama-3, gemma. (it's just used to"
684+
" label the benchmark, pick the tokenizer, the model config is"
685+
" defined in config_lib, and passed as the server config flag when"
686+
" we run the JetStream server)"
686687
),
687688
)
688689
parser.add_argument(

jetstream/third_party/llama3/llama3_tokenizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ def encode(
107107
self,
108108
s: str,
109109
*,
110-
bos: bool,
111-
eos: bool,
110+
bos: bool = False,
111+
eos: bool = False,
112112
allowed_special: Union[Literal["all"], AbstractSet[str]] | None = None,
113113
disallowed_special: Union[Literal["all"], Collection[str]] = (),
114114
) -> List[int]:

0 commit comments

Comments
 (0)