Skip to content

Commit 9e7fc32

Browse files
authored
Add an optional parameter for sampling in prefill / sample. (#133)
* Add an optional parameter for sampling in prefill / sample. This is needed because we want to enable per-request sampling parameters. This allows jetstream to be used as backend for HuggingFace TGI. * lint
1 parent 647ab24 commit 9e7fc32

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

jetstream/engine/engine_api.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"""
2020

2121
import abc
22-
from typing import Any, Optional, Tuple, Union
22+
from typing import Any, Optional, Tuple, Union, Callable
2323

2424
from flax import struct
2525
import jax
@@ -142,18 +142,24 @@ def prefill(
142142
existing_prefix: Optional[Prefix] = None,
143143
padded_tokens: jax.Array,
144144
true_length: int,
145+
sampler: Optional[Callable[[Any], Any]] = None,
145146
) -> Tuple[Prefix, ResultTokens]:
146147
"""Computes a kv-cache for a set of tokens conditional on existing cache.
147148
148149
existing_prefix (if provided) represents a prefix that has already been
149150
processed by the underlying model. tokens is logically appended
150151
to the text represented by `existing_prefix`. This method returns a new
151152
kv_cache (typically) for the resulting text.
153+
154+
If sampler is passed, then the engine should use it do sample next token.
152155
"""
153156

154157
@abc.abstractmethod
155158
def generate(
156-
self, params: Params, decode_state: DecodeState
159+
self,
160+
params: Params,
161+
decode_state: DecodeState,
162+
sampler: Optional[Callable[[Any], Any]] = None,
157163
) -> Tuple[DecodeState, ResultTokens]:
158164
"""Generates tokens for each sequence being decoded in parallel.
159165
@@ -165,6 +171,8 @@ def generate(
165171
consists of each microbatch progressing through every stage), in
166172
non-pipelined code this is a full forward pass. In both cases, this accounts
167173
for a full embed-layerstack-unembed-sample operation.
174+
175+
If sampler is passed, then the engine should use it do sample next token.
168176
"""
169177

170178
@abc.abstractmethod

0 commit comments

Comments
 (0)