Skip to content

Commit a77d682

Browse files
authored
Merge pull request #639 from doringeman/fix-package
fix: add daemon-side model repackaging for Linux support
2 parents c4e3a34 + c0186c1 commit a77d682

5 files changed

Lines changed: 292 additions & 9 deletions

File tree

cmd/cli/commands/package.go

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ type builderInitResult struct {
214214
}
215215

216216
// initializeBuilder creates a package builder from GGUF, Safetensors, DDUF, or existing model
217-
func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitResult, error) {
217+
func initializeBuilder(ctx context.Context, cmd *cobra.Command, client *desktop.Client, opts packageOptions) (*builderInitResult, error) {
218218
result := &builderInitResult{}
219219

220220
if opts.fromModel != "" {
@@ -238,10 +238,14 @@ func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitRes
238238
// Package from existing model
239239
cmd.PrintErrf("Reading model from store: %q\n", opts.fromModel)
240240

241-
// Get the model from the local store
242241
mdl, err := distClient.GetModel(opts.fromModel)
243242
if err != nil {
244-
return nil, fmt.Errorf("get model from store: %w", err)
243+
cmd.PrintErrf("Model not found in local store, fetching from daemon...\n")
244+
245+
mdl, result.distClient, result.cleanupFunc, err = fetchModelFromDaemon(ctx, cmd, client, opts.fromModel)
246+
if err != nil {
247+
return nil, fmt.Errorf("get model from store: %w", err)
248+
}
245249
}
246250

247251
// Type assert to ModelArtifact - the Model from store implements both interfaces
@@ -306,7 +310,74 @@ func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitRes
306310
return result, nil
307311
}
308312

313+
func fetchModelFromDaemon(ctx context.Context, cmd *cobra.Command, client *desktop.Client, modelRef string) (types.Model, *distribution.Client, func(), error) {
314+
exportReader, err := client.ExportModel(ctx, modelRef)
315+
if err != nil {
316+
return nil, nil, nil, fmt.Errorf("export model from daemon: %w", err)
317+
}
318+
defer exportReader.Close()
319+
320+
tempDir, err := os.MkdirTemp("", "docker-model-package-*")
321+
if err != nil {
322+
return nil, nil, nil, fmt.Errorf("create temp directory: %w", err)
323+
}
324+
cleanup := func() {
325+
os.RemoveAll(tempDir)
326+
}
327+
328+
tempClient, err := distribution.NewClient(distribution.WithStoreRootPath(tempDir))
329+
if err != nil {
330+
cleanup()
331+
return nil, nil, nil, fmt.Errorf("create temp distribution client: %w", err)
332+
}
333+
334+
cmd.PrintErrf("Loading model from daemon...\n")
335+
modelID, err := tempClient.LoadModel(exportReader, nil)
336+
if err != nil {
337+
cleanup()
338+
return nil, nil, nil, fmt.Errorf("load model into temp store: %w", err)
339+
}
340+
341+
mdl, err := tempClient.GetModel(modelID)
342+
if err != nil {
343+
cleanup()
344+
return nil, nil, nil, fmt.Errorf("get model from temp store: %w", err)
345+
}
346+
347+
return mdl, tempClient, cleanup, nil
348+
}
349+
309350
func packageModel(ctx context.Context, cmd *cobra.Command, client *desktop.Client, opts packageOptions) error {
351+
// Use daemon-side repackaging for simple config-only changes (no new layers)
352+
canUseDaemonRepackage := opts.fromModel != "" &&
353+
!opts.push &&
354+
len(opts.licensePaths) == 0 &&
355+
opts.chatTemplatePath == "" &&
356+
opts.mmprojPath == "" &&
357+
len(opts.dirTarPaths) == 0 &&
358+
cmd.Flags().Changed("context-size")
359+
360+
if canUseDaemonRepackage {
361+
cmd.PrintErrf("Reading model from daemon: %q\n", opts.fromModel)
362+
cmd.PrintErrf("Setting context size %d\n", opts.contextSize)
363+
cmd.PrintErrln("Creating lightweight model variant...")
364+
365+
// Ensure standalone runner is available
366+
if _, err := ensureStandaloneRunnerAvailable(ctx, asPrinter(cmd), false); err != nil {
367+
return fmt.Errorf("unable to initialize standalone model runner: %w", err)
368+
}
369+
370+
repackageOpts := desktop.RepackageOptions{
371+
ContextSize: &opts.contextSize,
372+
}
373+
if err := client.RepackageModel(ctx, opts.fromModel, opts.tag, repackageOpts); err != nil {
374+
return fmt.Errorf("failed to create lightweight model: %w", err)
375+
}
376+
377+
cmd.PrintErrln("Model variant created successfully")
378+
return nil
379+
}
380+
310381
var (
311382
target builder.Target
312383
err error
@@ -327,7 +398,7 @@ func packageModel(ctx context.Context, cmd *cobra.Command, client *desktop.Clien
327398
}
328399

329400
// Initialize the package builder based on model format
330-
initResult, err := initializeBuilder(cmd, opts)
401+
initResult, err := initializeBuilder(ctx, cmd, client, opts)
331402
if err != nil {
332403
return err
333404
}

cmd/cli/desktop/desktop.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,3 +938,66 @@ func (c *Client) LoadModel(ctx context.Context, r io.Reader) error {
938938
}
939939
return nil
940940
}
941+
942+
func (c *Client) ExportModel(ctx context.Context, model string) (io.ReadCloser, error) {
943+
exportPath := fmt.Sprintf("%s/%s/export", inference.ModelsPrefix, model)
944+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.modelRunner.URL(exportPath), http.NoBody)
945+
if err != nil {
946+
return nil, fmt.Errorf("failed to create request: %w", err)
947+
}
948+
req.Header.Set("User-Agent", "docker-model-cli/"+Version)
949+
950+
resp, err := c.modelRunner.Client().Do(req)
951+
if err != nil {
952+
return nil, c.handleQueryError(err, exportPath)
953+
}
954+
955+
if resp.StatusCode == http.StatusNotFound {
956+
resp.Body.Close()
957+
return nil, errors.Wrap(ErrNotFound, model)
958+
}
959+
if resp.StatusCode != http.StatusOK {
960+
body, _ := io.ReadAll(resp.Body)
961+
resp.Body.Close()
962+
return nil, fmt.Errorf("export failed with status %s: %s", resp.Status, string(body))
963+
}
964+
965+
return resp.Body, nil
966+
}
967+
968+
type RepackageOptions struct {
969+
ContextSize *uint64 `json:"context_size,omitempty"`
970+
}
971+
972+
func (c *Client) RepackageModel(ctx context.Context, source, target string, opts RepackageOptions) error {
973+
repackagePath := fmt.Sprintf("%s/%s/repackage", inference.ModelsPrefix, source)
974+
975+
reqBody := struct {
976+
Target string `json:"target"`
977+
ContextSize *uint64 `json:"context_size,omitempty"`
978+
}{
979+
Target: target,
980+
ContextSize: opts.ContextSize,
981+
}
982+
983+
jsonData, err := json.Marshal(reqBody)
984+
if err != nil {
985+
return fmt.Errorf("error marshaling request: %w", err)
986+
}
987+
988+
resp, err := c.doRequestWithAuthContext(ctx, http.MethodPost, repackagePath, bytes.NewReader(jsonData))
989+
if err != nil {
990+
return c.handleQueryError(err, repackagePath)
991+
}
992+
defer resp.Body.Close()
993+
994+
if resp.StatusCode == http.StatusNotFound {
995+
return errors.Wrap(ErrNotFound, source)
996+
}
997+
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
998+
body, _ := io.ReadAll(resp.Body)
999+
return fmt.Errorf("repackage failed with status %s: %s", resp.Status, string(body))
1000+
}
1001+
1002+
return nil
1003+
}

pkg/distribution/distribution/client.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strings"
1111

1212
"github.com/docker/model-runner/pkg/distribution/huggingface"
13+
"github.com/docker/model-runner/pkg/distribution/internal/mutate"
1314
"github.com/docker/model-runner/pkg/distribution/internal/progress"
1415
"github.com/docker/model-runner/pkg/distribution/internal/store"
1516
"github.com/docker/model-runner/pkg/distribution/oci"
@@ -615,6 +616,59 @@ func (c *Client) ResetStore() error {
615616
return nil
616617
}
617618

619+
func (c *Client) ExportModel(reference string, w io.Writer) error {
620+
c.log.Infoln("Exporting model:", utils.SanitizeForLog(reference))
621+
normalizedRef := c.normalizeModelName(reference)
622+
mdl, err := c.store.Read(normalizedRef)
623+
if err != nil {
624+
c.log.Errorln("Failed to get model for export:", err, "reference:", utils.SanitizeForLog(reference))
625+
return fmt.Errorf("get model '%q': %w", utils.SanitizeForLog(reference), err)
626+
}
627+
628+
target, err := tarball.NewTarget(w)
629+
if err != nil {
630+
return fmt.Errorf("create tarball target: %w", err)
631+
}
632+
633+
if err := target.Write(context.Background(), mdl, nil); err != nil {
634+
c.log.Errorln("Failed to export model:", err, "reference:", utils.SanitizeForLog(reference))
635+
return fmt.Errorf("export model: %w", err)
636+
}
637+
638+
c.log.Infoln("Successfully exported model:", utils.SanitizeForLog(reference))
639+
return nil
640+
}
641+
642+
type RepackageOptions struct {
643+
ContextSize *uint64
644+
}
645+
646+
func (c *Client) RepackageModel(sourceRef string, targetRef string, opts RepackageOptions) error {
647+
c.log.Infoln("Repackaging model:", utils.SanitizeForLog(sourceRef), "->", utils.SanitizeForLog(targetRef))
648+
649+
normalizedSource := c.normalizeModelName(sourceRef)
650+
normalizedTarget := c.normalizeModelName(targetRef)
651+
652+
mdl, err := c.store.Read(normalizedSource)
653+
if err != nil {
654+
c.log.Errorln("Failed to get model for repackaging:", err, "reference:", utils.SanitizeForLog(sourceRef))
655+
return fmt.Errorf("get model '%q': %w", utils.SanitizeForLog(sourceRef), err)
656+
}
657+
658+
var modifiedModel types.ModelArtifact = mdl
659+
if opts.ContextSize != nil {
660+
modifiedModel = mutate.ContextSize(modifiedModel, int32(*opts.ContextSize))
661+
}
662+
663+
if err := c.store.WriteLightweight(modifiedModel, []string{normalizedTarget}); err != nil {
664+
c.log.Errorln("Failed to write repackaged model:", err, "target:", utils.SanitizeForLog(targetRef))
665+
return fmt.Errorf("write repackaged model: %w", err)
666+
}
667+
668+
c.log.Infoln("Successfully repackaged model:", utils.SanitizeForLog(sourceRef), "->", utils.SanitizeForLog(targetRef))
669+
return nil
670+
}
671+
618672
// GetBundle returns a types.Bundle containing the model, creating one as necessary
619673
func (c *Client) GetBundle(ref string) (types.ModelBundle, error) {
620674
normalizedRef := c.normalizeModelName(ref)

pkg/inference/models/http_handler.go

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc {
8484
"POST " + inference.ModelsPrefix + "/create": h.handleCreateModel,
8585
"POST " + inference.ModelsPrefix + "/load": h.handleLoadModel,
8686
"GET " + inference.ModelsPrefix: h.handleGetModels,
87-
"GET " + inference.ModelsPrefix + "/{name...}": h.handleGetModel,
87+
"GET " + inference.ModelsPrefix + "/{nameAndAction...}": h.handleModelGetAction,
8888
"DELETE " + inference.ModelsPrefix + "/{name...}": h.handleDeleteModel,
8989
"POST " + inference.ModelsPrefix + "/{nameAndAction...}": h.handleModelAction,
9090
"DELETE " + inference.ModelsPrefix + "/purge": h.handlePurge,
@@ -142,6 +142,35 @@ func (h *HTTPHandler) handleLoadModel(w http.ResponseWriter, r *http.Request) {
142142
}
143143
}
144144

145+
func (h *HTTPHandler) handleModelGetAction(w http.ResponseWriter, r *http.Request) {
146+
nameAndAction := r.PathValue("nameAndAction")
147+
model, action := path.Split(nameAndAction)
148+
model = strings.TrimRight(model, "/")
149+
150+
if action == "export" {
151+
h.handleExportModel(w, r, model)
152+
return
153+
}
154+
155+
h.handleGetModelByRef(w, r, nameAndAction)
156+
}
157+
158+
func (h *HTTPHandler) handleExportModel(w http.ResponseWriter, r *http.Request, modelRef string) {
159+
w.Header().Set("Content-Type", "application/x-tar")
160+
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", modelRef+".tar"))
161+
162+
err := h.manager.Export(modelRef, w)
163+
if err != nil {
164+
if errors.Is(err, distribution.ErrModelNotFound) {
165+
http.Error(w, err.Error(), http.StatusNotFound)
166+
return
167+
}
168+
h.log.Warnln("Error while exporting model:", err)
169+
http.Error(w, err.Error(), http.StatusInternalServerError)
170+
return
171+
}
172+
}
173+
145174
// handleGetModels handles GET <inference-prefix>/models requests.
146175
func (h *HTTPHandler) handleGetModels(w http.ResponseWriter, r *http.Request) {
147176
apiModels, err := h.manager.List()
@@ -160,7 +189,10 @@ func (h *HTTPHandler) handleGetModels(w http.ResponseWriter, r *http.Request) {
160189
// handleGetModel handles GET <inference-prefix>/models/{name} requests.
161190
func (h *HTTPHandler) handleGetModel(w http.ResponseWriter, r *http.Request) {
162191
modelRef := r.PathValue("name")
192+
h.handleGetModelByRef(w, r, modelRef)
193+
}
163194

195+
func (h *HTTPHandler) handleGetModelByRef(w http.ResponseWriter, r *http.Request, modelRef string) {
164196
// Parse remote query parameter
165197
remote := false
166198
if r.URL.Query().Has("remote") {
@@ -355,10 +387,8 @@ func (h *HTTPHandler) handleOpenAIGetModel(w http.ResponseWriter, r *http.Reques
355387
}
356388
}
357389

358-
// handleTagModel handles POST <inference-prefix>/models/{nameAndAction} requests.
359-
// Action is one of:
360-
// - tag: tag the model with a repository and tag (e.g. POST <inference-prefix>/models/my-org/my-repo:latest/tag})
361-
// - push: pushes a tagged model to the registry
390+
// handleModelAction handles POST <inference-prefix>/models/{nameAndAction} requests.
391+
// Actions: tag, push, repackage
362392
func (h *HTTPHandler) handleModelAction(w http.ResponseWriter, r *http.Request) {
363393
model, action := path.Split(r.PathValue("nameAndAction"))
364394
model = strings.TrimRight(model, "/")
@@ -368,6 +398,8 @@ func (h *HTTPHandler) handleModelAction(w http.ResponseWriter, r *http.Request)
368398
h.handleTagModel(w, r, model)
369399
case "push":
370400
h.handlePushModel(w, r, model)
401+
case "repackage":
402+
h.handleRepackageModel(w, r, model)
371403
default:
372404
http.Error(w, fmt.Sprintf("unknown action %q", action), http.StatusNotFound)
373405
}
@@ -438,6 +470,49 @@ func (h *HTTPHandler) handlePushModel(w http.ResponseWriter, r *http.Request, mo
438470
}
439471
}
440472

473+
type RepackageRequest struct {
474+
Target string `json:"target"`
475+
ContextSize *uint64 `json:"context_size,omitempty"`
476+
}
477+
478+
func (h *HTTPHandler) handleRepackageModel(w http.ResponseWriter, r *http.Request, model string) {
479+
var req RepackageRequest
480+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
481+
http.Error(w, "invalid request body: "+err.Error(), http.StatusBadRequest)
482+
return
483+
}
484+
485+
if req.Target == "" {
486+
http.Error(w, "target is required", http.StatusBadRequest)
487+
return
488+
}
489+
490+
opts := RepackageOptions{
491+
ContextSize: req.ContextSize,
492+
}
493+
494+
if err := h.manager.Repackage(model, req.Target, opts); err != nil {
495+
if errors.Is(err, distribution.ErrModelNotFound) {
496+
http.Error(w, err.Error(), http.StatusNotFound)
497+
return
498+
}
499+
h.log.Warnf("Failed to repackage model %q: %v", utils.SanitizeForLog(model, -1), err)
500+
http.Error(w, err.Error(), http.StatusInternalServerError)
501+
return
502+
}
503+
504+
w.Header().Set("Content-Type", "application/json")
505+
w.WriteHeader(http.StatusCreated)
506+
response := map[string]string{
507+
"message": fmt.Sprintf("Model repackaged successfully as %q", req.Target),
508+
"source": model,
509+
"target": req.Target,
510+
}
511+
if err := json.NewEncoder(w).Encode(response); err != nil {
512+
h.log.Warnln("Error while encoding repackage response:", err)
513+
}
514+
}
515+
441516
// handlePurge handles DELETE <inference-prefix>/models/purge requests.
442517
func (h *HTTPHandler) handlePurge(w http.ResponseWriter, _ *http.Request) {
443518
err := h.manager.Purge()

pkg/inference/models/manager.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,23 @@ func (m *Manager) Purge() error {
416416
}
417417
return nil
418418
}
419+
420+
func (m *Manager) Export(ref string, w io.Writer) error {
421+
if m.distributionClient == nil {
422+
return fmt.Errorf("model distribution service unavailable")
423+
}
424+
return m.distributionClient.ExportModel(ref, w)
425+
}
426+
427+
type RepackageOptions struct {
428+
ContextSize *uint64 `json:"context_size,omitempty"`
429+
}
430+
431+
func (m *Manager) Repackage(sourceRef string, targetRef string, opts RepackageOptions) error {
432+
if m.distributionClient == nil {
433+
return fmt.Errorf("model distribution service unavailable")
434+
}
435+
return m.distributionClient.RepackageModel(sourceRef, targetRef, distribution.RepackageOptions{
436+
ContextSize: opts.ContextSize,
437+
})
438+
}

0 commit comments

Comments
 (0)