Commit 5820c95 (1 parent: 2245876)
refactor to sample before tokenize (#26)

1 file changed: benchmarks/benchmark_serving.py (62 additions & 72 deletions)
```diff
@@ -131,10 +131,8 @@ def get_tokenizer(tokenizer_name: str) -> Any:

 def load_sharegpt_dataset(
     dataset_path: str,
-    tokenizer: Any,
     conversation_starter: str,
-    max_output_length: Optional[int] = None,
-) -> List[InputRequest]:
+) -> List[tuple[str]]:
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
```
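With the `tokenizer` and `max_output_length` parameters removed, the loaders no longer build `InputRequest` objects; they return plain `(prompt, output)` string pairs and defer tokenization to `sample_requests`. A minimal sketch of the new contract (the file name and `conversation_starter` value are placeholders, not values from the repo):

```python
# Hypothetical call: both dataset loaders now share this output shape,
# a List[tuple[str]] of (prompt, output) string pairs.
pairs = load_sharegpt_dataset("sharegpt.json", "human")
prompt, output = pairs[0]  # raw strings; nothing is tokenized yet
```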
```diff
@@ -150,4 +148,20 @@ def load_sharegpt_dataset(
         for data in dataset
     ]

+    return dataset
+
+
+def load_openorca_dataset(
+    dataset_path: str
+) -> List[tuple[str]]:
+    # Load the dataset.
+    with open(dataset_path) as f:
+        dataset = json.load(f)
+
     # Tokenize the prompts and completions.
+    prompts = dataset["prompts"]
+    outputs = dataset["results"]
+
+    return [(prompt, output) for prompt, output in zip(prompts, outputs)]
+
+
```
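The new `load_openorca_dataset` expects a JSON object with parallel `prompts` and `results` lists and zips them into pairs. A hedged sketch of a file it would accept (the contents are invented for illustration):

```python
import json

# Invented minimal OpenOrca-style file: "prompts" and "results" are
# parallel lists that the loader zips into (prompt, output) pairs.
sample = {
    "prompts": ["What is 2 + 2?", "Name a prime number."],
    "results": ["2 + 2 = 4.", "7 is a prime number."],
}
with open("openorca_sample.json", "w") as f:
    json.dump(sample, f)

pairs = load_openorca_dataset("openorca_sample.json")
assert pairs[0] == ("What is 2 + 2?", "2 + 2 = 4.")
```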
```diff
@@ -154,16 +168,26 @@
+def tokenize_dataset(
+    dataset: List[tuple[str]],
+    tokenizer: Any,
+) -> List[tuple[Any]]:
+
+    n = len(dataset)
+
     prompts = [prompt for prompt, _ in dataset]
+    outputs = [output for _, output in dataset]
+
     prompt_token_ids = tokenizer.tokenize(
         prompts
     ) # adjust this code based on tokenizer method
-    completions = [completion for _, completion in dataset]
-    completion_token_ids = tokenizer.tokenize(
-        completions
+    outputs_token_ids = tokenizer.tokenize(
+        outputs
     ) # adjust this code based on tokenizer method
+
     tokenized_dataset = []
-    for i in range(len(dataset)):
+    for i in range(n):
         prompt_len = len(prompt_token_ids[i])
-        completion_len = len(completion_token_ids[i])
+        output_len = len(outputs_token_ids[i])
         tokenized_dataset.append(
-            (prompts[i], prompt_token_ids[i], completions[i], prompt_len, completion_len)
+            (prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
         )
+    return tokenized_dataset

```
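`tokenize_dataset` batch-tokenizes only the rows that survive sampling, and the inline comments flag `tokenizer.tokenize` as a point to adapt. The interface it assumes is a callable that maps a batch of strings to one token sequence per string. A hedged adapter sketch for a Hugging Face fast tokenizer; this is an assumption for illustration, not the repo's actual `get_tokenizer` implementation:

```python
from typing import Any, List

class BatchTokenizer:
    # Adapter giving a Hugging Face tokenizer the .tokenize(batch)
    # interface that tokenize_dataset above assumes.
    def __init__(self, hf_tokenizer: Any):
        self.hf_tokenizer = hf_tokenizer

    def tokenize(self, texts: List[str]) -> List[List[int]]:
        # Batch-encode; .input_ids is a list of token-id lists, one per text,
        # so len() of each entry gives the per-sample token count.
        return self.hf_tokenizer(texts).input_ids
```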
```diff
@@ -170,70 +194,30 @@
-    # Filter out too long sequences.
-    filtered_dataset: List[InputRequest] = []
-
-    for prompt, prompt_token_ids, completion, prompt_len, completion_len in tokenized_dataset:
-        if prompt_len < 4 or completion_len < 4:
-            # Prune too short sequences.
-            # This is because TGI causes errors when the input or output length
-            # is too short.
-            continue
-        if prompt_len > 1024 or prompt_len + completion_len > 2048:
-            # Prune too long sequences.
-            continue
-        request = InputRequest(prompt, prompt_len, completion, max_output_length or completion_len)
-        filtered_dataset.append(request)

+def filter_dataset(
+    tokenized_dataset: List[tuple[Any]],
+    max_output_length: Optional[int] = None
+) -> List[InputRequest]:
     if max_output_length is None:
         print("In InputRequest, pass in actual output_length for each sample")
     else:
         print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")

-    print(f"The dataset contains {len(tokenized_dataset)} samples.")
-    print(f"The filtered dataset contains {len(filtered_dataset)} samples.")
-
-    return filtered_dataset
-
-
-def load_openorca_dataset(
-    dataset_path: str,
-    tokenizer: Any,
-    max_output_length: Optional[int] = None,
-) -> List[InputRequest]:
-
-    # Load the dataset.
-    with open(dataset_path) as f:
-        dataset = json.load(f)
-
-    # Tokenize the prompts and completions.
-    prompts = dataset["prompts"]
-    outputs = dataset["results"]
-    n = len(prompts)
-    prompt_token_ids = tokenizer.tokenize(prompts)
-    output_token_ids = tokenizer.tokenize(outputs)
-
-    tokenized_dataset = []
-    for i in range(n):
-        prompt_len = len(prompt_token_ids[i])
-        output_len = len(output_token_ids[i])
-        tokenized_dataset.append((prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len))
-
     # Filter out too long sequences.
     filtered_dataset: List[InputRequest] = []
-
     for prompt, prompt_token_ids, output, prompt_len, output_len in tokenized_dataset:
+        if prompt_len < 4 or output_len < 4:
+            # Prune too short sequences.
+            # This is because TGI causes errors when the input or output length
+            # is too short.
+            continue
         if prompt_len > 1024 or prompt_len + output_len > 2048:
             # Prune too long sequences.
             continue
         request = InputRequest(prompt, prompt_len, output, max_output_length or output_len)
         filtered_dataset.append(request)

-    if max_output_length is None:
-        print("In InputRequest, pass in actual output_length for each sample")
-    else:
-        print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
-
     print(f"The dataset contains {len(tokenized_dataset)} samples.")
     print(f"The filtered dataset contains {len(filtered_dataset)} samples.")

     return filtered_dataset


```
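The extracted `filter_dataset` keeps the old pruning thresholds: rows are dropped when either side is under 4 tokens (the TGI workaround noted in the comments), when the prompt exceeds 1024 tokens, or when prompt plus output exceeds 2048. A worked sketch with invented rows:

```python
# Invented (prompt, prompt_token_ids, output, prompt_len, output_len) rows
# exercising each branch of filter_dataset:
rows = [
    ("hi", [1, 2], "ok", 2, 2),          # dropped: both lengths < 4
    ("p", [0] * 1200, "o", 1200, 10),    # dropped: prompt_len > 1024
    ("p", [0] * 1000, "o", 1000, 1500),  # dropped: 1000 + 1500 > 2048
    ("p", [0] * 512, "o", 512, 512),     # kept: 512 + 512 <= 2048
]
kept = filter_dataset(rows, max_output_length=None)
assert len(kept) == 1  # only the last row becomes an InputRequest
```

Because the generation budget is `max_output_length or output_len`, passing `None` keeps each sample's golden output length.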
```diff
@@ -240,26 +224,36 @@
 def sample_requests(
-    dataset: List[InputRequest],
+    dataset: List[tuple[str]],
+    tokenizer: Any,
     num_requests: int,
+    max_output_length: Optional[int] = None,
     oversample_multiplier: float=1.2,
 ) -> List[InputRequest]:

+    # Original dataset size
+    n = len(dataset)
+
     # Create necessary number of requests even if bigger than dataset size
     sampled_indices = random.sample(
-        range(len(dataset)), min(int(num_requests * oversample_multiplier), len(dataset)))
+        range(n), min(int(num_requests * oversample_multiplier), n))

     if num_requests > len(sampled_indices):
-        print(f"Number of requests {num_requests} is larger than size of dataset {len(dataset)}.\n",
+        print(f"Number of requests {num_requests} is larger than size of dataset {n}.\n",
              f"Repeating data to meet number of requests.\n")
         sampled_indices = sampled_indices * int(np.ceil(num_requests / len(sampled_indices)))

     print(f"{len(sampled_indices)=}")
     # some of these will be filtered out, so sample more than we need
     dataset = [dataset[i] for i in sampled_indices]

+    tokenized_dataset = tokenize_dataset(dataset, tokenizer)
+
+    input_requests = filter_dataset(tokenized_dataset, max_output_length)
+
     # Sample the requests.
-    sampled_requests = random.sample(dataset, num_requests)
+    if len(input_requests) > num_requests:
+        input_requests = random.sample(input_requests, num_requests)

-    return sampled_requests
+    return input_requests


 async def get_request(
```
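`sample_requests` oversamples by `oversample_multiplier` (default 1.2x) precisely because tokenization and filtering now happen afterward and will drop some rows; a final `random.sample` trims any surplus back to `num_requests`. A quick worked example of the index math, with invented sizes:

```python
import numpy as np

# Invented sizes: 800 raw pairs, 1000 requested, default 1.2x oversample.
n, num_requests, oversample_multiplier = 800, 1000, 1.2
k = min(int(num_requests * oversample_multiplier), n)  # min(1200, 800) = 800
# num_requests > k, so the sampled indices are repeated to cover the gap:
repeats = int(np.ceil(num_requests / k))               # ceil(1000 / 800) = 2
print(k * repeats)  # 1600 candidate rows go on to tokenize_dataset / filter_dataset
```

Note that if filtering removes more rows than the oversampling buffer covers, fewer than `num_requests` requests come back: the trim only runs when the filtered set is larger than `num_requests`.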
```diff
@@ -490,25 +484,21 @@ def main(args: argparse.Namespace):
         input_requests = mock_requests(args.total_mock_requests) # e.g. [("AB", 2, "AB", 3)]
     else:
         if args.dataset == "openorca":
-            dataset = load_openorca_dataset(
-                args.dataset_path,
-                tokenizer,
-                args.max_output_length
-            )
+            dataset = load_openorca_dataset(args.dataset_path)
         elif args.dataset == "sharegpt":
             dataset = load_sharegpt_dataset(
                 args.dataset_path,
-                tokenizer,
                 args.conversation_starter,
-                args.max_output_length
             )

         # A given args.max_output_length value is the max generation step,
         # when the args.max_output_length is default to None, the sample's golden output length
         # will be used to decide the generation step
         input_requests = sample_requests(
-            dataset,
-            args.num_prompts,
+            dataset=dataset,
+            tokenizer=tokenizer,
+            num_requests=args.num_prompts,
+            max_output_length=args.max_output_length
         )

     if args.warmup_first:
```
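Putting the call-site changes together, `main` now loads raw pairs and hands the tokenizer and `args.max_output_length` to `sample_requests` instead of to the loaders. An end-to-end sketch of the refactored flow (the path and tokenizer name are placeholders, not values from the repo):

```python
dataset = load_openorca_dataset("open_orca.json")  # raw (prompt, output) pairs
tokenizer = get_tokenizer("my-model")              # get_tokenizer is defined earlier in this file
input_requests = sample_requests(
    dataset=dataset,
    tokenizer=tokenizer,
    num_requests=1000,
    max_output_length=None,  # None: each request keeps its sample's golden output length
)
```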
