Commit 81beb11 (1 parent: 970b529)

Update benchmark to run openorca dataset (#21)

* add openorca dataset
* update readme
* fix input_requests for test case

3 files changed: 146 additions & 48 deletions

.gitignore

Lines changed: 7 additions & 1 deletion
```diff
@@ -2,4 +2,10 @@ __pycache__
 .env*
 build/
 dist/
-google_jetstream.egg-info/
+google_jetstream.egg-info/
+
+# local folders
+data/
+logs/
+tmp/
+venv/
```

benchmarks/README.md

Lines changed: 10 additions & 6 deletions
````diff
@@ -1,31 +1,33 @@
 # JetStream Benchmark And Eval
 
-## Install Dependencies
+## Install Dependencies
 
 ```
 cd ~/JetStream/benchmarks
 pip install -r requirements.in
 ```
 
-## Benchmark
+## Benchmark
 
 ### Prepare DataSet
 
 ```
 cd ~/data
 wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
 
-```
+```
 
 ### Run Benchmark with maxtext tokenizer
 
 ```
 python benchmark_serving.py \
 --tokenizer /home/{username}/maxtext/assets/tokenizer \
 --num-prompts 10 \
---dataset ~/data/ShareGPT_V3_unfiltered_cleaned_split.json
+--dataset sharegpt \
+--dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json \
+--max-output-length 1024
 
-```
+```
 
 ### Save request outputs in Benchmark
 
@@ -35,7 +37,9 @@ Please use --save-request-outputs flag to enable this feature.
 python benchmark_serving.py \
 --tokenizer /home/{username}/maxtext/assets/tokenizer \
 --num-prompts 10 \
---dataset ~/data/ShareGPT_V3_unfiltered_cleaned_split.json \
+--dataset sharegpt \
+--dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json \
+--max-output-length 1024 \
 --save-request-outputs
 
 ```
````
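Note: the README examples above cover only the `sharegpt` dataset. Going by the `--dataset` choices and the `--dataset-path` flag this commit adds in `benchmark_serving.py` below, an equivalent OpenOrca run would presumably look like the following; the dataset file name here is a placeholder, not something this commit ships:

```
python benchmark_serving.py \
--tokenizer /home/{username}/maxtext/assets/tokenizer \
--num-prompts 10 \
--dataset openorca \
--dataset-path ~/data/{openorca_dataset}.json \
--max-output-length 1024
```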

benchmarks/benchmark_serving.py

Lines changed: 129 additions & 41 deletions
```diff
@@ -15,8 +15,8 @@
 """Benchmark JetStream online serving.
 
 On the server side, run one of the following commands:
-* For real server, you need to pass correct server config (include the model config that
-  being passed into your engine impl) to the command below. Refer to config_lib.py and
+* For real server, you need to pass correct server config (include the model config that
+  being passed into your engine impl) to the command below. Refer to config_lib.py and
   implementations/mock/config.py for config impl detail.
 
 (run with real server)
@@ -28,21 +28,29 @@
 
 On the client side, run:
 * For real server and shareGPT dataset, you need to pass the tokenizer, server config, and
-  dataset flags to the command below, and make some changes to the tokenizer logic in the
+  dataset flags to the command below, and make some changes to the tokenizer logic in the
   benchmark script (get_tokenizer and sample_requests func) to use your tokenizer correctly.
 * Add `--save-result` flag to save the benchmark result to a json file in current folder.
 * Add `--threads` flag to set the maximum number of threads used for request dispatching.
 
 (run with real model and engines)
 python -m benchmarks.benchmark_serving \
-    --tokenizer <your_tokenizer> --dataset <target_dataset_path> \
+    --tokenizer <your_tokenizer> \
+    --dataset <target_dataset_name> \
+    --dataset-path <target_dataset_path> \
     --request-rate <request_rate>
 
 (run with mock)
 python -m benchmarks.benchmark_serving \
     --request-rate 1
 
-e2e example: python3 benchmark_serving.py --tokenizer /home/rwitten/maxtext/assets/tokenizer --num-prompts 100 --dataset ~/ShareGPT_V3_unfiltered_cleaned_split.json
+e2e example:
+python3 benchmark_serving.py \
+    --tokenizer /home/{username}/maxtext/assets/tokenizer \
+    --num-prompts 100 \
+    --dataset sharegpt \
+    --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json
+
 """
 
 
@@ -51,13 +59,13 @@
 
 import argparse
 import asyncio
-from concurrent.futures import ThreadPoolExecutor
+
 from dataclasses import dataclass
 from datetime import datetime
 import json
 import random
 import time
-from typing import Any, AsyncGenerator, List, Tuple
+from typing import Any, AsyncGenerator, List, Optional
 import grpc
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
```
```diff
@@ -99,15 +107,15 @@ class RequestFuncOutput:
   prompt_len: int = 0
 
   # Flatten the structure and return only the necessary results
-  def to_dict(self):
+  def to_dict(self):
     return {
         "prompt": self.input_request.prompt,
         "original_output": self.input_request.output,
         "generated_text": self.generated_text,
         "success": self.success,
         "latency": self.latency,
         "prompt_len": self.prompt_len
-    }
+    }
 
 
 def get_tokenizer(tokenizer_name: str) -> Any:
```
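Note: with `--save-request-outputs`, each saved record is the flat dict produced by `to_dict` above. A sketch of one record, where every value is invented purely for illustration:

```python
# Hypothetical record matching the keys returned by to_dict();
# all values below are made up for illustration.
record = {
    "prompt": "What is the capital of France?",            # input prompt text
    "original_output": "Paris.",                           # golden output from the dataset
    "generated_text": "The capital of France is Paris.",   # model-generated text
    "success": True,                                       # request completed
    "latency": 0.42,                                       # end-to-end latency
    "prompt_len": 8,                                       # prompt length in tokens
}
```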
```diff
@@ -121,13 +129,11 @@ def get_tokenizer(tokenizer_name: str) -> Any:
       model=sp_model, add_bos=True, add_eos=False, reverse=False)
   return sp_tokenizer
 
-def sample_requests(
+def load_sharegpt_dataset(
     dataset_path: str,
-    num_requests: int,
     tokenizer: Any,
-    max_output_length: int,
     conversation_starter: str,
-    oversample_multiplier: float=1.2,
+    max_output_length: Optional[int] = None,
 ) -> List[InputRequest]:
   # Load the dataset.
   with open(dataset_path) as f:
@@ -144,18 +150,6 @@
     for data in dataset
   ]
 
-  # Create necessary number of requests even if bigger than dataset size
-  sampled_indices = random.sample(range(len(dataset)),
-      min(int(num_requests * oversample_multiplier), len(dataset)))
-  if num_requests > len(sampled_indices):
-    print(f"Number of requests {num_requests} is larger than size of dataset {len(dataset)}.\n",
-          f"Repeating data to meet number of requests.\n")
-    sampled_indices = sampled_indices * int(np.ceil(num_requests / len(sampled_indices)))
-
-  print(f"{len(sampled_indices)=}")
-  # some of these will be filtered out, so sample more than we need
-  dataset = [dataset[i] for i in sampled_indices]
-
   # Tokenize the prompts and completions.
   prompts = [prompt for prompt, _ in dataset]
   prompt_token_ids = tokenizer.tokenize(
@@ -167,27 +161,104 @@
   ) # adjust this code based on tokenizer method
   tokenized_dataset = []
   for i in range(len(dataset)):
-    output_len = len(completion_token_ids[i])
-    tokenized_dataset.append((prompts[i], prompt_token_ids[i], completions[i], output_len))
+    prompt_len = len(prompt_token_ids[i])
+    completion_len = len(completion_token_ids[i])
+    tokenized_dataset.append(
+        (prompts[i], prompt_token_ids[i], completions[i], prompt_len, completion_len)
+    )
 
   # Filter out too long sequences.
   filtered_dataset: List[InputRequest] = []
 
-  for prompt, prompt_token_ids, output, output_len in tokenized_dataset:
-    prompt_len = len(prompt_token_ids)
-    if prompt_len < 4 or output_len < 4:
+  for prompt, prompt_token_ids, completion, prompt_len, completion_len in tokenized_dataset:
+    if prompt_len < 4 or completion_len < 4:
       # Prune too short sequences.
       # This is because TGI causes errors when the input or output length
       # is too short.
       continue
+    if prompt_len > 1024 or prompt_len + completion_len > 2048:
+      # Prune too long sequences.
+      continue
+    request = InputRequest(prompt, prompt_len, completion, max_output_length or completion_len)
+    filtered_dataset.append(request)
+
+  if max_output_length is None:
+    print("In InputRequest, pass in actual output_length for each sample")
+  else:
+    print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
+
+  print(f"The dataset contains {len(tokenized_dataset)} samples.")
+  print(f"The filtered dataset contains {len(filtered_dataset)} samples.")
+
+  return filtered_dataset
+
+
+def load_openorca_dataset(
+    dataset_path: str,
+    tokenizer: Any,
+    max_output_length: Optional[int] = None,
+) -> List[InputRequest]:
+
+  # Load the dataset.
+  with open(dataset_path) as f:
+    dataset = json.load(f)
+
+  # Tokenize the prompts and completions.
+  prompts = dataset["prompts"]
+  outputs = dataset["results"]
+  n = len(prompts)
+  prompt_token_ids = tokenizer.tokenize(prompts)
+  output_token_ids = tokenizer.tokenize(outputs)
+
+  tokenized_dataset = []
+  for i in range(n):
+    prompt_len = len(prompt_token_ids[i])
+    output_len = len(output_token_ids[i])
+    tokenized_dataset.append((prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len))
+
+  # Filter out too long sequences.
+  filtered_dataset: List[InputRequest] = []
+
+  for prompt, prompt_token_ids, output, prompt_len, output_len in tokenized_dataset:
     if prompt_len > 1024 or prompt_len + output_len > 2048:
       # Prune too long sequences.
       continue
-    request = InputRequest(prompt, prompt_len, output, max_output_length)
+    request = InputRequest(prompt, prompt_len, output, max_output_length or output_len)
     filtered_dataset.append(request)
 
+  if max_output_length is None:
+    print("In InputRequest, pass in actual output_length for each sample")
+  else:
+    print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
+
+  print(f"The dataset contains {len(tokenized_dataset)} samples.")
+  print(f"The filtered dataset contains {len(filtered_dataset)} samples.")
+
+  return filtered_dataset
+
+
+def sample_requests(
+    dataset: List[InputRequest],
+    num_requests: int,
+    oversample_multiplier: float=1.2,
+) -> List[InputRequest]:
+
+  # Create necessary number of requests even if bigger than dataset size
+  sampled_indices = random.sample(
+      range(len(dataset)), min(int(num_requests * oversample_multiplier), len(dataset)))
+
+  if num_requests > len(sampled_indices):
+    print(f"Number of requests {num_requests} is larger than size of dataset {len(dataset)}.\n",
+          f"Repeating data to meet number of requests.\n")
+    sampled_indices = sampled_indices * int(np.ceil(num_requests / len(sampled_indices)))
+
+  print(f"{len(sampled_indices)=}")
+  # some of these will be filtered out, so sample more than we need
+  dataset = [dataset[i] for i in sampled_indices]
+
   # Sample the requests.
-  sampled_requests = random.sample(filtered_dataset, num_requests)
+  sampled_requests = random.sample(dataset, num_requests)
+
   return sampled_requests
```
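Note: `load_openorca_dataset` above calls `json.load` and then indexes `dataset["prompts"]` and `dataset["results"]`, so the expected input file appears to be two parallel lists of prompt and reference-output strings. A minimal sketch of a compatible file, written from Python; the sample strings and file name are invented:

```python
import json

# Assumed shape, inferred from load_openorca_dataset: parallel
# "prompts" and "results" lists of equal length.
openorca_like = {
    "prompts": ["What is the capital of France?"],
    "results": ["The capital of France is Paris."],
}

with open("openorca_sample.json", "w") as f:
    json.dump(openorca_like, f)
```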
```diff
@@ -396,7 +467,7 @@ def sample_warmup_requests(requests):
       256,
       512,
       1024,]
-
+
   for start, end in zip(interesting_buckets[:-1], interesting_buckets[1:]):
     for request in requests:
       if start < request.prompt_len <= end:
@@ -418,12 +489,26 @@ def main(args: argparse.Namespace):
   if tokenizer == "test" or args.dataset == "test":
     input_requests = mock_requests(args.total_mock_requests) # e.g. [("AB", 2, "AB", 3)]
   else:
+    if args.dataset == "openorca":
+      dataset = load_openorca_dataset(
+          args.dataset_path,
+          tokenizer,
+          args.max_output_length
+      )
+    elif args.dataset == "sharegpt":
+      dataset = load_sharegpt_dataset(
+          args.dataset_path,
+          tokenizer,
+          args.conversation_starter,
+          args.max_output_length
+      )
+
+    # A given args.max_output_length value is the max generation step,
+    # when the args.max_output_length is default to None, the sample's golden output length
+    # will be used to decide the generation step
     input_requests = sample_requests(
-        args.dataset,
-        args.num_prompts,
-        tokenizer,
-        args.max_output_length,
-        args.conversation_starter,
+        dataset,
+        args.num_prompts,
     )
 
   if args.warmup_first:
```
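Note: both loaders select each request's output length with `max_output_length or output_len`, which is what the comment in `main` above describes: a set `--max-output-length` caps every request, while the new default of `None` falls back to each sample's golden output length. A standalone illustration of that fallback (not JetStream code; names are invented):

```python
from typing import Optional

def pick_output_len(max_output_length: Optional[int], golden_len: int) -> int:
  # `or` returns the right operand when the left is falsy (None here),
  # so None means "use the sample's own golden output length".
  return max_output_length or golden_len

assert pick_output_len(None, 57) == 57     # default: per-sample golden length
assert pick_output_len(1024, 57) == 1024   # --max-output-length caps generation
```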
```diff
@@ -486,7 +571,7 @@ def main(args: argparse.Namespace):
   if args.save_request_outputs:
     file_path = args.request_outputs_file_path
     with open(file_path, "w") as output_file:
-      json.dump([output.to_dict() for output in request_outputs], output_file, indent=4)
+      json.dump([output.to_dict() for output in request_outputs], output_file, indent=4)
 
 
 if __name__ == "__main__":
@@ -501,7 +586,10 @@ def main(args: argparse.Namespace):
   )
   parser.add_argument("--port", type=str, default=9000)
   parser.add_argument(
-      "--dataset", type=str, default="test", help="Path to the dataset."
+      "--dataset", type=str, default="test", choices=["test", "sharegpt", "openorca"], help="The dataset name."
+  )
+  parser.add_argument(
+      "--dataset-path", type=str, help="Path to the dataset."
   )
   parser.add_argument(
       "--model",
@@ -558,7 +646,7 @@ def main(args: argparse.Namespace):
   parser.add_argument(
       "--max-output-length",
       type=int,
-      default=1024,
+      default=None,
       help="The maximum output length for reference request.",
   )
```