Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cmd/cli/commands/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/docker/model-runner/pkg/distribution/registry"
"github.com/docker/model-runner/pkg/distribution/tarball"
"github.com/docker/model-runner/pkg/distribution/types"
"github.com/docker/model-runner/pkg/inference/models"
"github.com/google/go-containerregistry/pkg/name"
"github.com/spf13/cobra"

Expand Down Expand Up @@ -313,7 +314,9 @@ func newModelRunnerTarget(client *desktop.Client, tag string) (*modelRunnerTarge
}
if tag != "" {
var err error
target.tag, err = name.NewTag(tag)
// Normalize the tag to add default namespace (ai/) and tag (:latest) if missing
normalizedTag := models.NormalizeModelName(tag)
target.tag, err = name.NewTag(normalizedTag)
if err != nil {
return nil, fmt.Errorf("invalid tag: %w", err)
}
Expand Down
2 changes: 2 additions & 0 deletions cmd/cli/commands/tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ func newTagCmd() *cobra.Command {
func tagModel(cmd *cobra.Command, desktopClient *desktop.Client, source, target string) error {
// Normalize source model name to add default org and tag if missing
source = models.NormalizeModelName(source)
// Normalize target model name to add default org and tag if missing
target = models.NormalizeModelName(target)
// Ensure tag is valid
tag, err := name.NewTag(target)
if err != nil {
Expand Down
48 changes: 48 additions & 0 deletions cmd/cli/commands/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,51 @@ func TestStripDefaultsFromModelName(t *testing.T) {
})
}
}

// TestNormalizeModelNameConsistency verifies that locally packaged models
// (without namespace) get normalized the same way as other operations.
// This test documents the fix for the bug where `docker model package my-model`
// would create a model that couldn't be run with `docker model run my-model`.
func TestNormalizeModelNameConsistency(t *testing.T) {
tests := []struct {
name string
userProvidedName string
expectedNormalizedName string
description string
}{
{
name: "locally packaged model without namespace",
userProvidedName: "my-model",
expectedNormalizedName: "ai/my-model:latest",
description: "When a user packages a local model as 'my-model', it should be normalized to 'ai/my-model:latest'",
},
{
name: "locally packaged model without namespace but with tag",
userProvidedName: "my-model:v1.0",
expectedNormalizedName: "ai/my-model:v1.0",
description: "When a user packages a local model as 'my-model:v1.0', it should be normalized to 'ai/my-model:v1.0'",
},
{
name: "model with explicit namespace",
userProvidedName: "myorg/my-model",
expectedNormalizedName: "myorg/my-model:latest",
description: "When a user packages a model with explicit org 'myorg/my-model', it should keep the org",
},
{
name: "model with ai namespace explicitly set",
userProvidedName: "ai/my-model",
expectedNormalizedName: "ai/my-model:latest",
description: "When a user explicitly sets 'ai/' namespace, it should remain the same",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := models.NormalizeModelName(tt.userProvidedName)
if result != tt.expectedNormalizedName {
t.Errorf("%s: NormalizeModelName(%q) = %q, want %q",
tt.description, tt.userProvidedName, result, tt.expectedNormalizedName)
}
})
}
}
Comment thread
ericcurtin marked this conversation as resolved.
11 changes: 10 additions & 1 deletion cmd/cli/desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,17 @@ func (c *Client) handleQueryError(err error, path string) error {
return fmt.Errorf("error querying %s: %w", path, err)
}

// normalizeHuggingFaceModelName converts Hugging Face model names to lowercase
func normalizeHuggingFaceModelName(model string) string {
if strings.HasPrefix(model, "hf.co/") {
return strings.ToLower(model)
}

return model
}

func (c *Client) Tag(source, targetRepo, targetTag string) error {
source = dmrm.NormalizeModelName(source)
source = normalizeHuggingFaceModelName(source)
// Check if the source is a model ID, and expand it if necessary
if !strings.Contains(strings.Trim(source, "/"), "/") {
// Do an extra API call to check if the model parameter might be a model ID
Expand Down
2 changes: 1 addition & 1 deletion cmd/cli/desktop/desktop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func TestTagHuggingFaceModel(t *testing.T) {

// Test case for tagging a Hugging Face model with mixed case
sourceModel := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest"
expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
targetRepo := "myrepo"
targetTag := "latest"

Expand Down
Loading