Skip to content

Commit 4a7c8f6

Browse files
committed
fix(vllm): add pooling runner for embedding mode and update tests
1 parent 4ebad0e commit 4a7c8f6

2 files changed

Lines changed: 44 additions & 3 deletions

File tree

pkg/inference/backends/vllm/vllm_config.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,8 +45,8 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
4545
case inference.BackendModeCompletion:
4646
// Default mode for vLLM
4747
case inference.BackendModeEmbedding:
48-
// vLLM doesn't have a specific embedding flag like llama.cpp
49-
// Embedding models are detected automatically
48+
// Use pooling runner for embedding models
49+
args = append(args, "--runner", "pooling")
5050
case inference.BackendModeReranking:
5151
// vLLM does not have a specific flag for reranking
5252
case inference.BackendModeImageGeneration:

pkg/inference/backends/vllm/vllm_config_test.go

Lines changed: 42 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ func (m *mockModelBundle) RootDir() string {
4646
func TestGetArgs(t *testing.T) {
4747
tests := []struct {
4848
name string
49+
mode inference.BackendMode
4950
config *inference.BackendConfiguration
5051
bundle *mockModelBundle
5152
expected []string
@@ -356,12 +357,52 @@ func TestGetArgs(t *testing.T) {
356357
`{"model_type":"llama"}`,
357358
},
358359
},
360+
{
361+
name: "embedding mode adds --runner pooling",
362+
mode: inference.BackendModeEmbedding,
363+
bundle: &mockModelBundle{
364+
safetensorsPath: "/path/to/model",
365+
},
366+
config: nil,
367+
expected: []string{
368+
"serve",
369+
"/path/to",
370+
"--uds",
371+
"/tmp/socket",
372+
"--runner",
373+
"pooling",
374+
},
375+
},
376+
{
377+
name: "embedding mode with other config",
378+
mode: inference.BackendModeEmbedding,
379+
bundle: &mockModelBundle{
380+
safetensorsPath: "/path/to/model",
381+
},
382+
config: &inference.BackendConfiguration{
383+
ContextSize: int32ptr(4096),
384+
},
385+
expected: []string{
386+
"serve",
387+
"/path/to",
388+
"--uds",
389+
"/tmp/socket",
390+
"--runner",
391+
"pooling",
392+
"--max-model-len",
393+
"4096",
394+
},
395+
},
359396
}
360397

361398
for _, tt := range tests {
362399
t.Run(tt.name, func(t *testing.T) {
363400
config := NewDefaultVLLMConfig()
364-
args, err := config.GetArgs(tt.bundle, "/tmp/socket", inference.BackendModeCompletion, tt.config)
401+
mode := tt.mode
402+
if mode == 0 {
403+
mode = inference.BackendModeCompletion
404+
}
405+
args, err := config.GetArgs(tt.bundle, "/tmp/socket", mode, tt.config)
365406

366407
if tt.expectError {
367408
if err == nil {

0 commit comments

Comments
 (0)