Commit d2eac5f (1 parent: 42ca212)

improve DMR support

- fixes session title generation
- adds a 'context_size' provider_opt for DMR usage instead of giving 'max_tokens' double responsibility, to avoid confusion
- improves thinking budget support and fixes NoThinking()
- improves how flags are sent to the DMR model/runtime configuration endpoint
- clarifies docs on sampling/runtime params

Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>

4 files changed: 1432 additions & 209 deletions

docs/providers/dmr/index.md (94 additions & 12 deletions)
@@ -64,29 +64,111 @@ models:
     model: ai/qwen3
     max_tokens: 8192
     provider_opts:
-      runtime_flags: ["--ngl=33", "--top-p=0.9"]
+      runtime_flags: ["--threads", "8"]
 ```
 
 Runtime flags also accept a single string:
 
 ```yaml
 provider_opts:
-  runtime_flags: "--ngl=33 --top-p=0.9"
+  runtime_flags: "--threads 8"
 ```
 
-## Parameter Mapping
+Use only flags your Model Runner backend allows (see `docker model configure --help` and backend docs). **Do not** put sampling parameters (`temperature`, `top_p`, penalties) in `runtime_flags` — set them on the model (`temperature`, `top_p`, etc.); they are sent **per request** via the OpenAI-compatible chat API.
 
-docker-agent model config fields map to llama.cpp flags automatically:
+## Context size
 
-| Config              | llama.cpp Flag        |
-| ------------------- | --------------------- |
-| `temperature`       | `--temp`              |
-| `top_p`             | `--top-p`             |
-| `frequency_penalty` | `--frequency-penalty` |
-| `presence_penalty`  | `--presence-penalty`  |
-| `max_tokens`        | `--context-size`      |
+`max_tokens` controls the **maximum output tokens** per chat completion request. To set the engine's **total context window**, use `provider_opts.context_size`:
 
-`runtime_flags` always take priority over derived flags on conflict.
+```yaml
+models:
+  local:
+    provider: dmr
+    model: ai/qwen3
+    max_tokens: 4096        # max output tokens (per-request)
+    provider_opts:
+      context_size: 32768   # total context window (sent via _configure)
+```
+
+If `context_size` is omitted, Model Runner uses its default. `max_tokens` is **not** used as the context window.
+
+## Thinking / reasoning budget
+
+When using the **llama.cpp** backend, `thinking_budget` is sent as structured `llamacpp.reasoning-budget` on `_configure` (maps to `--reasoning-budget`). String efforts use the same token mapping as other providers; `adaptive` maps to unlimited (`-1`).
+
+When using the **vLLM** backend, `thinking_budget` is sent as `thinking_token_budget` in each chat completion request. Effort levels map to token counts using the same scale as other providers; `adaptive` maps to unlimited (`-1`).
+
+```yaml
+models:
+  local:
+    provider: dmr
+    model: ai/qwen3
+    thinking_budget: medium   # llama.cpp: reasoning-budget=8192; vLLM: thinking_token_budget=8192
+```
+
+On **MLX** and **SGLang** backends, `thinking_budget` is silently ignored — those engines do not currently expose a per-request reasoning token budget knob.
+
+## vLLM-specific configuration
+
+When running a model on the **vLLM** backend, additional engine-level settings can be passed via `provider_opts` and are forwarded to model-runner's `_configure` endpoint:
+
+- `gpu_memory_utilization` — fraction of GPU memory (0.0–1.0) vLLM may use. Values outside this range are rejected.
+- `hf_overrides` — map of Hugging Face config overrides applied when vLLM loads the model.
+
+```yaml
+models:
+  vllm-local:
+    provider: dmr
+    model: ai/some-model-safetensors
+    provider_opts:
+      gpu_memory_utilization: 0.9
+      hf_overrides:
+        max_model_len: 8192
+        dtype: bfloat16
+```
+
+`hf_overrides` keys (including nested ones) must match `^[a-zA-Z_][a-zA-Z0-9_]*$` — the same rule model-runner enforces server-side to block injection via flags. Invalid keys are rejected at client creation time so you fail fast instead of after a round-trip.
+
+These options are ignored on non-vLLM backends.
+
+## Keeping models resident in memory (`keep_alive`)
+
+By default model-runner unloads idle models after a few minutes. Override the idle timeout via `provider_opts.keep_alive`:
+
+```yaml
+models:
+  sticky:
+    provider: dmr
+    model: ai/qwen3
+    provider_opts:
+      keep_alive: "30m"   # duration string
+      # keep_alive: "0"   # unload immediately after each request
+      # keep_alive: "-1"  # keep loaded forever
+```
+
+Accepted values: any Go duration string (`"30s"`, `"5m"`, `"1h"`, `"2h30m"`), `"0"` (immediate unload), or `"-1"` (never unload). Invalid values are rejected before the configure request is sent.
+
+## Operating mode (`mode`)
+
+Model-runner normally infers the backend mode from the request path. You can pin it explicitly via `provider_opts.mode`:
+
+```yaml
+provider_opts:
+  mode: embedding   # one of: completion, embedding, reranking, image-generation
+```
+
+Most agents don't need this — leave it unset unless you know you need it.
+
+## Raw runtime flags (`raw_runtime_flags`)
+
+`runtime_flags` (a list) is the preferred way to pass flags. If you have a pre-built command-line string you'd rather ship verbatim, use `raw_runtime_flags` instead:
+
+```yaml
+provider_opts:
+  raw_runtime_flags: "--threads 8 --batch-size 512"
+```
+
+Model-runner parses the string with shell-style word splitting. `runtime_flags` and `raw_runtime_flags` are mutually exclusive — setting both is an error.
 
 ## Speculative Decoding
 
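The `hf_overrides` key rule stated in the docs above (`^[a-zA-Z_][a-zA-Z0-9_]*$`, applied to nested keys too, rejected at client creation time) can be sketched in Go. This is an illustrative sketch, not the commit's actual code; the helper name `validateHFOverrides` is made up for this example.

```go
package main

import (
	"fmt"
	"regexp"
)

// keyPattern mirrors the server-side rule from the docs: keys must match
// ^[a-zA-Z_][a-zA-Z0-9_]*$ to block flag injection via override names.
var keyPattern = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

// validateHFOverrides (hypothetical helper) walks an hf_overrides map,
// including nested maps, and rejects any key that fails the pattern, so a
// bad config fails fast instead of after a round-trip to model-runner.
func validateHFOverrides(overrides map[string]any) error {
	for k, v := range overrides {
		if !keyPattern.MatchString(k) {
			return fmt.Errorf("invalid hf_overrides key %q", k)
		}
		if nested, ok := v.(map[string]any); ok {
			if err := validateHFOverrides(nested); err != nil {
				return err
			}
		}
	}
	return nil
}

func main() {
	ok := map[string]any{"max_model_len": 8192, "dtype": "bfloat16"}
	bad := map[string]any{"dtype; --evil-flag": "bfloat16"}
	fmt.Println(validateHFOverrides(ok) == nil)  // true
	fmt.Println(validateHFOverrides(bad) != nil) // true
}
```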
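The accepted `keep_alive` forms (any Go duration string, `"0"`, or `"-1"`, validated before the configure request is sent) can likewise be checked with a small Go helper. `validateKeepAlive` is a hypothetical name for illustration, assuming the duration case is checked with `time.ParseDuration`.

```go
package main

import (
	"fmt"
	"time"
)

// validateKeepAlive (hypothetical helper) accepts the three forms the docs
// describe: "0" (immediate unload), "-1" (never unload), or any Go
// duration string such as "30s" or "2h30m".
func validateKeepAlive(v string) error {
	if v == "0" || v == "-1" {
		return nil
	}
	if _, err := time.ParseDuration(v); err != nil {
		return fmt.Errorf("invalid keep_alive %q: %w", v, err)
	}
	return nil
}

func main() {
	for _, v := range []string{"30s", "2h30m", "0", "-1", "forever"} {
		fmt.Println(v, validateKeepAlive(v) == nil)
	}
}
```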
pkg/model/provider/dmr/client.go (53 additions & 9 deletions)
@@ -7,6 +7,7 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
+	"maps"
 	"net/http"
 	"os"
 	"time"
@@ -54,6 +55,7 @@ type Client struct {
 	client     openai.Client
 	baseURL    string
 	httpClient *http.Client
+	engine     string
 }
 
 // NewClient creates a new DMR client from the provided configuration
@@ -103,18 +105,28 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt
 
 	clientOptions = append(clientOptions, option.WithBaseURL(baseURL), option.WithAPIKey("")) // DMR doesn't need auth
 
-	// Build runtime flags from ModelConfig and engine
-	contextSize, providerRuntimeFlags, specOpts := parseDMRProviderOpts(cfg)
-	configFlags := buildRuntimeFlagsFromModelConfig(engine, cfg)
-	finalFlags, warnings := mergeRuntimeFlagsPreferUser(configFlags, providerRuntimeFlags)
-	for _, w := range warnings {
-		slog.Warn(w)
+	parsed, err := parseDMRProviderOpts(engine, cfg)
+	if err != nil {
+		slog.Error("DMR provider_opts invalid", "error", err, "model", cfg.Model)
+		return nil, err
 	}
-	slog.Debug("DMR provider_opts parsed", "model", cfg.Model, "context_size", contextSize, "runtime_flags", finalFlags, "speculative_opts", specOpts, "engine", engine)
+	backendCfg := buildConfigureBackendConfig(parsed.contextSize, parsed.runtimeFlags, parsed.specOpts, parsed.llamaCpp, parsed.vllm, parsed.keepAlive)
+	slog.Debug("DMR provider_opts parsed",
+		"model", cfg.Model,
+		"engine", engine,
+		"context_size", derefInt64(parsed.contextSize),
+		"runtime_flags", parsed.runtimeFlags,
+		"raw_runtime_flags", parsed.rawRuntimeFlags,
+		"mode", derefString(parsed.mode),
+		"keep_alive", derefString(parsed.keepAlive),
+		"speculative_opts", parsed.specOpts,
+		"llamacpp", parsed.llamaCpp,
+		"vllm", parsed.vllm,
+	)
 	// Skip model configuration when generating titles to avoid reconfiguring the model
 	// with different settings (e.g., smaller max_tokens) that would affect the main agent.
 	if !globalOptions.GeneratingTitle() {
-		if err := configureModel(ctx, httpClient, baseURL, cfg.Model, contextSize, finalFlags, specOpts); err != nil {
+		if err := configureModel(ctx, httpClient, baseURL, cfg.Model, backendCfg, parsed.mode, parsed.rawRuntimeFlags); err != nil {
 			slog.Debug("model configure via API skipped or failed", "error", err)
 		}
 	}
@@ -129,6 +141,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt
 		client:     openai.NewClient(clientOptions...),
 		baseURL:    baseURL,
 		httpClient: httpClient,
+		engine:     engine,
 	}, nil
 }
 
@@ -214,6 +227,37 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat
 		}
 	}
 
+	// Collect per-request extra JSON fields. SetExtraFields replaces the map
+	// wholesale, so merge all contributors before a single Set call.
+	extraFields := map[string]any{}
+
+	// NoThinking: disable reasoning at the chat-template level. llama.cpp and
+	// vLLM both honor chat_template_kwargs.enable_thinking=false for Qwen3 /
+	// Hermes / DeepSeek-R1 style templates; other engines ignore unknown keys.
+	// Also enforce a max_tokens floor so that if the engine/template ignores
+	// the hint, internal reasoning can't starve visible output (e.g. session
+	// title generation requests max_tokens=20).
+	if c.ModelOptions.NoThinking() {
+		extraFields["chat_template_kwargs"] = map[string]any{"enable_thinking": false}
+		if c.ModelConfig.MaxTokens != nil && *c.ModelConfig.MaxTokens < noThinkingMinOutputTokens {
+			params.MaxTokens = openai.Int(noThinkingMinOutputTokens)
+			slog.Debug("DMR NoThinking: bumped max_tokens floor",
+				"from", *c.ModelConfig.MaxTokens, "to", noThinkingMinOutputTokens)
+		}
+	}
+
+	// vLLM-specific per-request fields (e.g. thinking_token_budget).
+	if c.engine == engineVLLM {
+		if fields := buildVLLMRequestFields(&c.ModelConfig); fields != nil {
+			maps.Copy(extraFields, fields)
+		}
+	}
+
+	if len(extraFields) > 0 {
+		params.SetExtraFields(extraFields)
+		slog.Debug("DMR extra request fields applied", "fields", extraFields)
+	}
+
 	// Log the request in JSON format for debugging
 	if requestJSON, err := json.Marshal(params); err == nil {
 		slog.Debug("DMR chat completion request", "request", string(requestJSON))
@@ -222,7 +266,7 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat
 	}
 
 	if structuredOutput := c.ModelOptions.StructuredOutput(); structuredOutput != nil {
-		slog.Debug("Adding structured output to DMR request", "structured_output", structuredOutput)
+		slog.Debug("Adding structured output to DMR request", "name", structuredOutput.Name, "strict", structuredOutput.Strict)
 
 		params.ResponseFormat.OfJSONSchema = &openai.ResponseFormatJSONSchemaParam{
 			JSONSchema: openai.ResponseFormatJSONSchemaJSONSchemaParam{
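The merge-then-set pattern in the hunk above (SetExtraFields replaces the whole map, so all contributors are merged before a single call) can be shown standalone. `buildExtraFields` is an illustrative helper, not code from the commit; it assumes the same field names the diff uses.

```go
package main

import (
	"fmt"
	"maps"
)

// buildExtraFields (hypothetical helper) merges every contributor of
// per-request extra JSON fields into one map, mirroring the commit's
// approach: the NoThinking chat-template hint first, then any
// engine-specific fields (e.g. vLLM's thinking_token_budget) copied on top.
func buildExtraFields(noThinking bool, engineFields map[string]any) map[string]any {
	extra := map[string]any{}
	if noThinking {
		// Disable reasoning at the chat-template level; engines that don't
		// recognize the key simply ignore it.
		extra["chat_template_kwargs"] = map[string]any{"enable_thinking": false}
	}
	if engineFields != nil {
		maps.Copy(extra, engineFields)
	}
	return extra
}

func main() {
	fields := buildExtraFields(true, map[string]any{"thinking_token_budget": 8192})
	fmt.Println(len(fields)) // 2
}
```

Merging first means a later contributor never silently discards an earlier one, which is exactly the hazard a replace-wholesale setter creates.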
