|
73 | 73 | from jetstream.core.proto import jetstream_pb2 |
74 | 74 | from jetstream.core.proto import jetstream_pb2_grpc |
75 | 75 | from jetstream.engine.token_utils import load_vocab |
| 76 | +from jetstream.third_party.llama3 import llama3_tokenizer |
76 | 77 | import numpy as np |
77 | 78 | from tqdm.asyncio import tqdm # pytype: disable=pyi-error |
78 | 79 | import pandas |
@@ -130,10 +131,13 @@ def to_dict(self): |
130 | 131 | } |
131 | 132 |
|
132 | 133 |
|
133 | | -def get_tokenizer(tokenizer_name: str) -> Any: |
| 134 | +def get_tokenizer(model_id: str, tokenizer_name: str) -> Any: |
134 | 135 | """Return a tokenizer or a tokenizer placeholder.""" |
135 | 136 | if tokenizer_name == "test": |
136 | 137 | return "test" |
| 138 | + elif model_id == "llama-3": |
| 139 | + # Llama 3 uses a tiktoken tokenizer. |
| 140 | + return llama3_tokenizer.Tokenizer(tokenizer_name) |
137 | 141 | else: |
138 | 142 | # Use JetStream tokenizer util. It's using the sentencepiece wrapper in |
139 | 143 | # seqio library. |
@@ -195,18 +199,14 @@ def tokenize_dataset( |
195 | 199 | prompts = [] |
196 | 200 | outputs = [] |
197 | 201 | indices = [] |
198 | | - |
| 202 | + prompt_token_ids = [] |
| 203 | + outputs_token_ids = [] |
199 | 204 | for prompt, output, idx in dataset: |
200 | 205 | prompts.append(prompt) |
201 | 206 | outputs.append(output) |
202 | 207 | indices.append(idx) |
203 | | - |
204 | | - prompt_token_ids = tokenizer.encode( |
205 | | - prompts |
206 | | - ) # adjust this code based on tokenizer method |
207 | | - outputs_token_ids = tokenizer.encode( |
208 | | - outputs |
209 | | - ) # adjust this code based on tokenizer method |
| 208 | + prompt_token_ids.append(tokenizer.encode(prompt)) |
| 209 | + outputs_token_ids.append(tokenizer.encode(output)) |
210 | 210 |
|
211 | 211 | tokenized_dataset = [] |
212 | 212 | for i in range(n): |
@@ -549,7 +549,7 @@ def main(args: argparse.Namespace): |
549 | 549 |
|
550 | 550 | api_url = f"{args.server}:{args.port}" |
551 | 551 |
|
552 | | - tokenizer = get_tokenizer(tokenizer_id) |
| 552 | + tokenizer = get_tokenizer(model_id, tokenizer_id) |
553 | 553 | if tokenizer == "test" or args.dataset == "test": |
554 | 554 | input_requests = mock_requests( |
555 | 555 | args.total_mock_requests |
@@ -680,9 +680,10 @@ def main(args: argparse.Namespace): |
680 | 680 | type=str, |
681 | 681 | default="no_model", |
682 | 682 | help=( |
683 | | - "Name of the model. (it's just used to label the benchmark, the model" |
684 | | - " config is defined in config_lib, and passed as the server config" |
685 | | - " flag when we run the JetStream server)" |
| 683 | + "Name of the model like llama-2, llama-3, gemma. (it's just used to" |
| 684 | + " label the benchmark, pick the tokenizer, the model config is" |
| 685 | + " defined in config_lib, and passed as the server config flag when" |
| 686 | + " we run the JetStream server)" |
686 | 687 | ), |
687 | 688 | ) |
688 | 689 | parser.add_argument( |
|
0 commit comments