@@ -46,6 +46,7 @@ func (m *mockModelBundle) RootDir() string {
4646func TestGetArgs (t * testing.T ) {
4747 tests := []struct {
4848 name string
49+ mode inference.BackendMode
4950 config * inference.BackendConfiguration
5051 bundle * mockModelBundle
5152 expected []string
@@ -356,12 +357,52 @@ func TestGetArgs(t *testing.T) {
356357 `{"model_type":"llama"}` ,
357358 },
358359 },
360+ {
361+ name : "embedding mode adds --runner pooling" ,
362+ mode : inference .BackendModeEmbedding ,
363+ bundle : & mockModelBundle {
364+ safetensorsPath : "/path/to/model" ,
365+ },
366+ config : nil ,
367+ expected : []string {
368+ "serve" ,
369+ "/path/to" ,
370+ "--uds" ,
371+ "/tmp/socket" ,
372+ "--runner" ,
373+ "pooling" ,
374+ },
375+ },
376+ {
377+ name : "embedding mode with other config" ,
378+ mode : inference .BackendModeEmbedding ,
379+ bundle : & mockModelBundle {
380+ safetensorsPath : "/path/to/model" ,
381+ },
382+ config : & inference.BackendConfiguration {
383+ ContextSize : int32ptr (4096 ),
384+ },
385+ expected : []string {
386+ "serve" ,
387+ "/path/to" ,
388+ "--uds" ,
389+ "/tmp/socket" ,
390+ "--runner" ,
391+ "pooling" ,
392+ "--max-model-len" ,
393+ "4096" ,
394+ },
395+ },
359396 }
360397
361398 for _ , tt := range tests {
362399 t .Run (tt .name , func (t * testing.T ) {
363400 config := NewDefaultVLLMConfig ()
364- args , err := config .GetArgs (tt .bundle , "/tmp/socket" , inference .BackendModeCompletion , tt .config )
401+ mode := tt .mode
402+ if mode == 0 {
403+ mode = inference .BackendModeCompletion
404+ }
405+ args , err := config .GetArgs (tt .bundle , "/tmp/socket" , mode , tt .config )
365406
366407 if tt .expectError {
367408 if err == nil {
0 commit comments