Skip to content

Commit faf1595

Browse files
committed
Add /set command and context length configuration
Add a /set command to interactive mode so users can configure parameters such as num_ctx at runtime. Also add environment-variable support for the default context length (DMR_CONTEXT_LENGTH), and make the scheduler apply that default when configuring backend runners if a request does not specify its own context size.

Signed-off-by: Eric Curtin <eric.curtin@docker.com>
1 parent 3245480 commit faf1595

4 files changed

Lines changed: 112 additions & 35 deletions

File tree

cmd/cli/commands/run.go

Lines changed: 79 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@ import (
88
"io"
99
"os"
1010
"os/signal"
11+
"strconv"
1112
"strings"
1213
"syscall"
1314

1415
"github.com/charmbracelet/glamour"
1516
"github.com/docker/model-runner/cmd/cli/commands/completion"
1617
"github.com/docker/model-runner/cmd/cli/desktop"
1718
"github.com/docker/model-runner/cmd/cli/readline"
19+
"github.com/docker/model-runner/pkg/inference"
20+
"github.com/docker/model-runner/pkg/inference/scheduling"
1821
"github.com/fatih/color"
1922
"github.com/muesli/termenv"
2023
"github.com/spf13/cobra"
@@ -90,11 +93,12 @@ func readMultilineInput(cmd *cobra.Command, scanner *bufio.Scanner) (string, err
9093
func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.Client, model string) error {
9194
usage := func() {
9295
fmt.Fprintln(os.Stderr, "Available Commands:")
93-
fmt.Fprintln(os.Stderr, " /set system Set or update the system message")
9496
fmt.Fprintln(os.Stderr, " /bye Exit")
97+
fmt.Fprintln(os.Stderr, " /set Set a session variable")
9598
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
9699
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
97100
fmt.Fprintln(os.Stderr, " /? files Help for file inclusion with @ symbol")
101+
fmt.Fprintln(os.Stderr, " /? set Help for /set command")
98102
fmt.Fprintln(os.Stderr, "")
99103
fmt.Fprintln(os.Stderr, `Use """ to begin a multi-line message.`)
100104
fmt.Fprintln(os.Stderr, "")
@@ -134,6 +138,13 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.
134138
fmt.Fprintln(os.Stderr, "")
135139
}
136140

141+
usageSet := func() {
142+
fmt.Fprintln(os.Stderr, "Available /set commands:")
143+
fmt.Fprintln(os.Stderr, " /set system <message> Set system message for the conversation")
144+
fmt.Fprintln(os.Stderr, " /set parameter num_ctx <n> Set context window size (in tokens)")
145+
fmt.Fprintln(os.Stderr, "")
146+
}
147+
137148
scanner, err := readline.New(readline.Prompt{
138149
Prompt: "> ",
139150
AltPrompt: ". ",
@@ -204,36 +215,79 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.
204215
case scanner.Pasting:
205216
fmt.Fprintln(&sb, line)
206217
continue
207-
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
218+
case strings.HasPrefix(line, "/"):
208219
args := strings.Fields(line)
209-
if len(args) > 1 {
220+
switch args[0] {
221+
case "/help", "/?":
222+
if len(args) > 1 {
223+
switch args[1] {
224+
case "shortcut", "shortcuts":
225+
usageShortcuts()
226+
case "file", "files":
227+
usageFiles()
228+
case "set":
229+
usageSet()
230+
default:
231+
usage()
232+
}
233+
} else {
234+
usage()
235+
}
236+
case "/exit", "/bye":
237+
return nil
238+
case "/set":
239+
if len(args) < 2 {
240+
usageSet()
241+
continue
242+
}
210243
switch args[1] {
211-
case "shortcut", "shortcuts":
212-
usageShortcuts()
213-
case "file", "files":
214-
usageFiles()
244+
case "system":
245+
// Extract the system prompt text after "/set system"
246+
if len(args) > 2 {
247+
systemPrompt = strings.Join(args[2:], " ")
248+
} else {
249+
systemPrompt = ""
250+
}
251+
if systemPrompt == "" {
252+
fmt.Fprintln(os.Stderr, "Cleared system message.")
253+
} else {
254+
fmt.Fprintln(os.Stderr, "Set system message.")
255+
}
256+
case "parameter":
257+
if len(args) < 4 {
258+
fmt.Fprintln(os.Stderr, "Usage: /set parameter <name> <value>")
259+
fmt.Fprintln(os.Stderr, "Available parameters: num_ctx")
260+
continue
261+
}
262+
paramName, paramValue := args[2], args[3]
263+
switch paramName {
264+
case "num_ctx":
265+
if val, err := strconv.ParseInt(paramValue, 10, 32); err == nil && val > 0 {
266+
ctx := int32(val)
267+
if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
268+
Model: model,
269+
BackendConfiguration: inference.BackendConfiguration{
270+
ContextSize: &ctx,
271+
},
272+
}); err != nil {
273+
fmt.Fprintf(os.Stderr, "Failed to set num_ctx: %v\n", err)
274+
} else {
275+
fmt.Fprintf(os.Stderr, "Set num_ctx to %d\n", val)
276+
}
277+
} else {
278+
fmt.Fprintf(os.Stderr, "Invalid value for num_ctx: %s (must be a positive integer)\n", paramValue)
279+
}
280+
default:
281+
fmt.Fprintf(os.Stderr, "Unknown parameter: %s\n", paramName)
282+
fmt.Fprintln(os.Stderr, "Available parameters: num_ctx")
283+
}
215284
default:
216-
usage()
285+
usageSet()
217286
}
218-
} else {
219-
usage()
220-
}
221-
continue
222-
case strings.HasPrefix(line, "/set system ") || line == "/set system":
223-
// Extract the system prompt text after "/set system "
224-
systemPrompt = strings.TrimPrefix(line, "/set system ")
225-
systemPrompt = strings.TrimSpace(systemPrompt)
226-
if systemPrompt == "" {
227-
fmt.Fprintln(os.Stderr, "Cleared system message.")
228-
} else {
229-
fmt.Fprintln(os.Stderr, "Set system message.")
287+
default:
288+
fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
230289
}
231290
continue
232-
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
233-
return nil
234-
case strings.HasPrefix(line, "/"):
235-
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
236-
continue
237291
default:
238292
sb.WriteString(line)
239293
}

main.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99
"os/signal"
1010
"path/filepath"
11+
"strconv"
1112
"strings"
1213
"syscall"
1314
"time"
@@ -86,6 +87,18 @@ func main() {
8687
mlxServerPath := os.Getenv("MLX_SERVER_PATH")
8788
diffusersServerPath := os.Getenv("DIFFUSERS_SERVER_PATH")
8889

90+
// Parse default context length from environment
91+
var defaultContextLength *int32
92+
if ctxStr := os.Getenv("DMR_CONTEXT_LENGTH"); ctxStr != "" {
93+
if parsed, err := strconv.ParseInt(ctxStr, 10, 32); err == nil && parsed > 0 {
94+
ctx := int32(parsed)
95+
defaultContextLength = &ctx
96+
log.Infof("DMR_CONTEXT_LENGTH: %d", ctx)
97+
} else {
98+
log.Warnf("Invalid DMR_CONTEXT_LENGTH: %s (must be a positive integer)", ctxStr)
99+
}
100+
}
101+
89102
// Create a proxy-aware HTTP transport
90103
// Use a safe type assertion with fallback, and explicitly set Proxy to http.ProxyFromEnvironment
91104
var baseTransport *http.Transport
@@ -197,6 +210,7 @@ func main() {
197210
"",
198211
false,
199212
),
213+
defaultContextLength,
200214
)
201215

202216
// Create the HTTP handler for the scheduler

pkg/inference/scheduling/scheduler.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ type Scheduler struct {
4040
tracker *metrics.Tracker
4141
// openAIRecorder is used to record OpenAI API inference requests and responses.
4242
openAIRecorder *metrics.OpenAIRecorder
43+
// defaultContextLength is the default context length from environment variable.
44+
defaultContextLength *int32
4345
}
4446

4547
// NewScheduler creates a new inference scheduler.
@@ -50,19 +52,21 @@ func NewScheduler(
5052
modelManager *models.Manager,
5153
httpClient *http.Client,
5254
tracker *metrics.Tracker,
55+
defaultContextLength *int32,
5356
) *Scheduler {
5457
openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager)
5558

5659
// Create the scheduler.
5760
s := &Scheduler{
58-
log: log,
59-
backends: backends,
60-
defaultBackend: defaultBackend,
61-
modelManager: modelManager,
62-
installer: newInstaller(log, backends, httpClient),
63-
loader: newLoader(log, backends, modelManager, openAIRecorder),
64-
tracker: tracker,
65-
openAIRecorder: openAIRecorder,
61+
log: log,
62+
backends: backends,
63+
defaultBackend: defaultBackend,
64+
modelManager: modelManager,
65+
installer: newInstaller(log, backends, httpClient),
66+
loader: newLoader(log, backends, modelManager, openAIRecorder),
67+
tracker: tracker,
68+
openAIRecorder: openAIRecorder,
69+
defaultContextLength: defaultContextLength,
6670
}
6771

6872
// Scheduler successfully initialized.
@@ -253,7 +257,12 @@ func (s *Scheduler) ConfigureRunner(ctx context.Context, backend inference.Backe
253257

254258
// Build runner configuration with shared settings
255259
var runnerConfig inference.BackendConfiguration
256-
runnerConfig.ContextSize = req.ContextSize
260+
// Use request context size if provided, otherwise fall back to default from env var
261+
if req.ContextSize != nil {
262+
runnerConfig.ContextSize = req.ContextSize
263+
} else if s.defaultContextLength != nil {
264+
runnerConfig.ContextSize = s.defaultContextLength
265+
}
257266
runnerConfig.Speculative = req.Speculative
258267
runnerConfig.RuntimeFlags = runtimeFlags
259268

pkg/inference/scheduling/scheduler_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func TestCors(t *testing.T) {
3333
discard := logrus.New()
3434
discard.SetOutput(io.Discard)
3535
log := logrus.NewEntry(discard)
36-
s := NewScheduler(log, nil, nil, nil, nil, nil)
36+
s := NewScheduler(log, nil, nil, nil, nil, nil, nil)
3737
httpHandler := NewHTTPHandler(s, nil, []string{"*"})
3838
req := httptest.NewRequest(http.MethodOptions, "http://model-runner.docker.internal"+tt.path, http.NoBody)
3939
req.Header.Set("Origin", "docker.com")

0 commit comments

Comments
 (0)