Skip to content

Commit 51e92f4

Browse files
committed
Add context cancellation support for Ctrl+C during model response
Add Ctrl+C handling to basic interactive mode for consistency
1 parent fb5e172 commit 51e92f4

2 files changed

Lines changed: 81 additions & 8 deletions

File tree

cmd/cli/commands/run.go

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@ package commands
22

33
import (
44
"bufio"
5+
"context"
56
"errors"
67
"fmt"
78
"io"
89
"os"
10+
"os/signal"
911
"strings"
12+
"syscall"
1013

1114
"github.com/charmbracelet/glamour"
1215
"github.com/docker/model-runner/cmd/cli/commands/completion"
@@ -201,8 +204,32 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.
201204
if sb.Len() > 0 && !multiline {
202205
userInput := sb.String()
203206

204-
if err := chatWithMarkdown(cmd, desktopClient, backend, model, userInput, apiKey); err != nil {
205-
cmd.PrintErr(handleClientError(err, "Failed to generate a response"))
207+
// Create a cancellable context for the chat request
208+
// This allows us to cancel the request if the user presses Ctrl+C during response generation
209+
chatCtx, cancelChat := context.WithCancel(context.Background())
210+
211+
// Set up signal handler to cancel the context on Ctrl+C
212+
sigChan := make(chan os.Signal, 1)
213+
signal.Notify(sigChan, syscall.SIGINT)
214+
go func() {
215+
<-sigChan
216+
cancelChat()
217+
}()
218+
219+
err := chatWithMarkdownContext(chatCtx, cmd, desktopClient, backend, model, userInput, apiKey)
220+
221+
// Clean up signal handler
222+
signal.Stop(sigChan)
223+
close(sigChan)
224+
cancelChat()
225+
226+
if err != nil {
227+
// Check if the error is due to context cancellation (Ctrl+C during response)
228+
if errors.Is(err, context.Canceled) {
229+
cmd.Println()
230+
} else {
231+
cmd.PrintErr(handleClientError(err, "Failed to generate a response"))
232+
}
206233
sb.Reset()
207234
continue
208235
}
@@ -233,8 +260,32 @@ func generateInteractiveBasic(cmd *cobra.Command, desktopClient *desktop.Client,
233260
continue
234261
}
235262

236-
if err := chatWithMarkdown(cmd, desktopClient, backend, model, userInput, apiKey); err != nil {
237-
cmd.PrintErr(handleClientError(err, "Failed to generate a response"))
263+
// Create a cancellable context for the chat request
264+
// This allows us to cancel the request if the user presses Ctrl+C during response generation
265+
chatCtx, cancelChat := context.WithCancel(context.Background())
266+
267+
// Set up signal handler to cancel the context on Ctrl+C
268+
sigChan := make(chan os.Signal, 1)
269+
signal.Notify(sigChan, syscall.SIGINT)
270+
go func() {
271+
<-sigChan
272+
cancelChat()
273+
}()
274+
275+
err = chatWithMarkdownContext(chatCtx, cmd, desktopClient, backend, model, userInput, apiKey)
276+
277+
// Clean up signal handler
278+
signal.Stop(sigChan)
279+
close(sigChan)
280+
cancelChat()
281+
282+
if err != nil {
283+
// Check if the error is due to context cancellation (Ctrl+C during response)
284+
if errors.Is(err, context.Canceled) {
285+
fmt.Println("\nUse Ctrl + d or /bye to exit.")
286+
} else {
287+
cmd.PrintErr(handleClientError(err, "Failed to generate a response"))
288+
}
238289
continue
239290
}
240291

@@ -425,21 +476,26 @@ func renderMarkdown(content string) (string, error) {
425476

426477
// chatWithMarkdown performs chat and streams the response with selective markdown rendering.
// It is a thin wrapper: every argument is forwarded unchanged to
// chatWithMarkdownContext, with the context returned by cmd.Context() used as
// the request context. The error from chatWithMarkdownContext is returned as-is.
427478
func chatWithMarkdown(cmd *cobra.Command, client *desktop.Client, backend, model, prompt, apiKey string) error {
479+
// NOTE(review): cobra documents that cmd.Context() can be nil when no context
// was set on the command — confirm every caller runs through Execute/ExecuteContext.
return chatWithMarkdownContext(cmd.Context(), cmd, client, backend, model, prompt, apiKey)
480+
}
481+
482+
// chatWithMarkdownContext performs chat with context support and streams the response with selective markdown rendering.
483+
func chatWithMarkdownContext(ctx context.Context, cmd *cobra.Command, client *desktop.Client, backend, model, prompt, apiKey string) error {
428484
colorMode, _ := cmd.Flags().GetString("color")
429485
useMarkdown := shouldUseMarkdown(colorMode)
430486
debug, _ := cmd.Flags().GetBool("debug")
431487

432488
if !useMarkdown {
433489
// Simple case: just stream as plain text
434-
return client.Chat(backend, model, prompt, apiKey, func(content string) {
490+
return client.ChatWithContext(ctx, backend, model, prompt, apiKey, func(content string) {
435491
cmd.Print(content)
436492
}, false)
437493
}
438494

439495
// For markdown: use streaming buffer to render code blocks as they complete
440496
markdownBuffer := NewStreamingMarkdownBuffer()
441497

442-
err := client.Chat(backend, model, prompt, apiKey, func(content string) {
498+
err := client.ChatWithContext(ctx, backend, model, prompt, apiKey, func(content string) {
443499
// Use the streaming markdown buffer to intelligently render content
444500
rendered, err := markdownBuffer.AddContent(content, true)
445501
if err != nil {

cmd/cli/desktop/desktop.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,11 @@ func (c *Client) fullModelID(id string) (string, error) {
366366

367367
// Chat performs a chat request and streams the response content with selective markdown rendering.
// It delegates to ChatWithContext using context.Background(), so callers of
// Chat cannot cancel the request; use ChatWithContext when cancellation is needed.
// All other arguments and the returned error are passed through unchanged.
368368
func (c *Client) Chat(backend, model, prompt, apiKey string, outputFunc func(string), shouldUseMarkdown bool) error {
369+
return c.ChatWithContext(context.Background(), backend, model, prompt, apiKey, outputFunc, shouldUseMarkdown)
370+
}
371+
372+
// ChatWithContext performs a chat request with context support for cancellation and streams the response content with selective markdown rendering.
373+
func (c *Client) ChatWithContext(ctx context.Context, backend, model, prompt, apiKey string, outputFunc func(string), shouldUseMarkdown bool) error {
369374
model = normalizeHuggingFaceModelName(model)
370375
if !strings.Contains(strings.Trim(model, "/"), "/") {
371376
// Do an extra API call to check if the model parameter isn't a model ID.
@@ -397,7 +402,8 @@ func (c *Client) Chat(backend, model, prompt, apiKey string, outputFunc func(str
397402
completionsPath = inference.InferencePrefix + "/v1/chat/completions"
398403
}
399404

400-
resp, err := c.doRequestWithAuth(
405+
resp, err := c.doRequestWithAuthContext(
406+
ctx,
401407
http.MethodPost,
402408
completionsPath,
403409
bytes.NewReader(jsonData),
@@ -432,6 +438,13 @@ func (c *Client) Chat(backend, model, prompt, apiKey string, outputFunc func(str
432438

433439
scanner := bufio.NewScanner(resp.Body)
434440
for scanner.Scan() {
441+
// Check if context was cancelled
442+
select {
443+
case <-ctx.Done():
444+
return ctx.Err()
445+
default:
446+
}
447+
435448
line := scanner.Text()
436449
if line == "" {
437450
continue
@@ -755,7 +768,11 @@ func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response,
755768

756769
// doRequestWithAuth is a helper function that performs HTTP requests with optional authentication.
// It forwards its arguments unchanged to doRequestWithAuthContext, supplying
// context.Background() as the request context, and returns that call's
// response and error as-is.
757770
func (c *Client) doRequestWithAuth(method, path string, body io.Reader, backend, apiKey string) (*http.Response, error) {
758-
req, err := http.NewRequest(method, c.modelRunner.URL(path), body)
771+
return c.doRequestWithAuthContext(context.Background(), method, path, body, backend, apiKey)
772+
}
773+
774+
func (c *Client) doRequestWithAuthContext(ctx context.Context, method, path string, body io.Reader, backend, apiKey string) (*http.Response, error) {
775+
req, err := http.NewRequestWithContext(ctx, method, c.modelRunner.URL(path), body)
759776
if err != nil {
760777
return nil, fmt.Errorf("error creating request: %w", err)
761778
}

0 commit comments

Comments
 (0)