Skip to content

Commit 0637309

Browse files
authored
Adds filtering for sharegpt based on conversation starter. (#17)
1 parent 1c153b1 commit 0637309

1 file changed

Lines changed: 23 additions & 3 deletions

File tree

benchmarks/benchmark_serving.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,17 @@ def sample_requests(
126126
num_requests: int,
127127
tokenizer: Any,
128128
max_output_length: int,
129+
conversation_starter: str,
129130
) -> List[InputRequest]:
130131
# Load the dataset.
131132
with open(dataset_path) as f:
132133
dataset = json.load(f)
133134
# Filter out the conversations with less than 2 turns.
134135
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
136+
137+
# Filter based on conversation starter
138+
if conversation_starter != "both":
139+
dataset = [data for data in dataset if data["conversations"][0]["from"] == conversation_starter]
135140
# Only keep the first two turns of each conversation.
136141
dataset = [
137142
(data["conversations"][0]["value"], data["conversations"][1]["value"])
@@ -169,8 +174,8 @@ def sample_requests(
169174
if prompt_len > 1024 or prompt_len + output_len > 2048:
170175
# Prune too long sequences.
171176
continue
172-
reqeust = InputRequest(prompt, prompt_len, output, max_output_length)
173-
filtered_dataset.append(reqeust)
177+
request = InputRequest(prompt, prompt_len, output, max_output_length)
178+
filtered_dataset.append(request)
174179

175180
# Sample the requests.
176181
sampled_requests = random.sample(filtered_dataset, num_requests)
@@ -409,7 +414,13 @@ def main(args: argparse.Namespace):
409414
if tokenizer == "test" or args.dataset == "test":
410415
input_requests = mock_requests(args.total_mock_requests) # e.g. [("AB", 2, "AB", 3)]
411416
else:
412-
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.max_output_length)
417+
input_requests = sample_requests(
418+
args.dataset,
419+
args.num_prompts,
420+
tokenizer,
421+
args.max_output_length,
422+
args.conversation_starter,
423+
)
413424

414425
if args.warmup_first:
415426
print('Warm up start:' )
@@ -597,6 +608,15 @@ def main(args: argparse.Namespace):
597608
"Whether to send warmup req first"
598609
),
599610
)
611+
parser.add_argument(
612+
"--conversation-starter",
613+
type=str,
614+
default="human",
615+
choices=["human", "gpt", "both"],
616+
help=(
617+
"What entity should be the one starting the conversations."
618+
),
619+
)
600620

601621
args = parser.parse_args()
602622
main(args)

0 commit comments

Comments
 (0)