Skip to content

Commit b253b96

Browse files
authored
Merge pull request #167 from TheodoreEhrenborg/main
fix: Prevent OverflowError in Offline client
2 parents 3c2e139 + 5459a22 commit b253b96

1 file changed

Lines changed: 18 additions & 6 deletions

File tree

delphi/clients/offline.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,28 @@ async def process_func(
8686
Process a single request.
8787
"""
8888

89-
# This is actually stupid
89+
# Extract params from kwargs - must pass to constructor, not mutate after,
90+
# because SamplingParams.__post_init__ likely does some extra setup,
91+
# and mutation after construction skips this.
92+
logprobs = None
93+
prompt_logprobs = None
94+
max_tokens = self.sampling_params.max_tokens
95+
temperature = 1.0
9096
for kwarg in kwargs:
9197
if "logprobs" in kwarg:
92-
self.sampling_params.logprobs = kwarg["top_logprobs"]
98+
logprobs = kwarg["top_logprobs"]
9399
if "prompt_logprobs" in kwarg:
94-
self.sampling_params.prompt_logprobs = kwarg["prompt_logprobs"]
100+
prompt_logprobs = kwarg["prompt_logprobs"]
95101
if "max_tokens" in kwarg:
96-
self.sampling_params.max_tokens = kwarg["max_tokens"]
102+
max_tokens = kwarg["max_tokens"]
97103
if "temperature" in kwarg:
98-
self.sampling_params.temperature = kwarg["temperature"]
104+
temperature = kwarg["temperature"]
105+
sampling_params = SamplingParams(
106+
max_tokens=max_tokens,
107+
logprobs=logprobs,
108+
prompt_logprobs=prompt_logprobs,
109+
temperature=temperature,
110+
)
99111
loop = asyncio.get_running_loop()
100112
prompts = []
101113
statistics = []
@@ -124,7 +136,7 @@ async def process_func(
124136
partial(
125137
self.client.generate, # type: ignore
126138
prompts,
127-
sampling_params=self.sampling_params,
139+
sampling_params=sampling_params, # Use fresh sampling_params
128140
use_tqdm=False,
129141
),
130142
)

0 commit comments

Comments
 (0)