diff --git a/cmd/cli/commands/package.go b/cmd/cli/commands/package.go index e793150b0..f7f0b4b03 100644 --- a/cmd/cli/commands/package.go +++ b/cmd/cli/commands/package.go @@ -15,6 +15,7 @@ import ( "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/tarball" "github.com/docker/model-runner/pkg/distribution/types" + "github.com/docker/model-runner/pkg/inference/models" "github.com/google/go-containerregistry/pkg/name" "github.com/spf13/cobra" @@ -313,7 +314,9 @@ func newModelRunnerTarget(client *desktop.Client, tag string) (*modelRunnerTarge } if tag != "" { var err error - target.tag, err = name.NewTag(tag) + // Normalize the tag to add default namespace (ai/) and tag (:latest) if missing + normalizedTag := models.NormalizeModelName(tag) + target.tag, err = name.NewTag(normalizedTag) if err != nil { return nil, fmt.Errorf("invalid tag: %w", err) } diff --git a/cmd/cli/commands/tag.go b/cmd/cli/commands/tag.go index f7197b363..265c51ad8 100644 --- a/cmd/cli/commands/tag.go +++ b/cmd/cli/commands/tag.go @@ -39,6 +39,8 @@ func newTagCmd() *cobra.Command { func tagModel(cmd *cobra.Command, desktopClient *desktop.Client, source, target string) error { // Normalize source model name to add default org and tag if missing source = models.NormalizeModelName(source) + // Normalize target model name to add default org and tag if missing + target = models.NormalizeModelName(target) // Ensure tag is valid tag, err := name.NewTag(target) if err != nil { diff --git a/cmd/cli/commands/utils_test.go b/cmd/cli/commands/utils_test.go index ad4a8e3c3..2f64232ac 100644 --- a/cmd/cli/commands/utils_test.go +++ b/cmd/cli/commands/utils_test.go @@ -146,3 +146,51 @@ func TestStripDefaultsFromModelName(t *testing.T) { }) } } + +// TestNormalizeModelNameConsistency verifies that locally packaged models +// (without namespace) get normalized the same way as other operations. +// This test documents the fix for the bug where `docker model package my-model` +// would create a model that couldn't be run with `docker model run my-model`. +func TestNormalizeModelNameConsistency(t *testing.T) { + tests := []struct { + name string + userProvidedName string + expectedNormalizedName string + description string + }{ + { + name: "locally packaged model without namespace", + userProvidedName: "my-model", + expectedNormalizedName: "ai/my-model:latest", + description: "When a user packages a local model as 'my-model', it should be normalized to 'ai/my-model:latest'", + }, + { + name: "locally packaged model without namespace but with tag", + userProvidedName: "my-model:v1.0", + expectedNormalizedName: "ai/my-model:v1.0", + description: "When a user packages a local model as 'my-model:v1.0', it should be normalized to 'ai/my-model:v1.0'", + }, + { + name: "model with explicit namespace", + userProvidedName: "myorg/my-model", + expectedNormalizedName: "myorg/my-model:latest", + description: "When a user packages a model with explicit org 'myorg/my-model', it should keep the org", + }, + { + name: "model with ai namespace explicitly set", + userProvidedName: "ai/my-model", + expectedNormalizedName: "ai/my-model:latest", + description: "When a user explicitly sets 'ai/' namespace, it should remain the same", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := models.NormalizeModelName(tt.userProvidedName) + if result != tt.expectedNormalizedName { + t.Errorf("%s: NormalizeModelName(%q) = %q, want %q", + tt.description, tt.userProvidedName, result, tt.expectedNormalizedName) + } + }) + } +} diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go index 01b7b080f..b3f41af27 100644 --- a/cmd/cli/desktop/desktop.go +++ b/cmd/cli/desktop/desktop.go @@ -799,8 +799,17 @@ func (c *Client) handleQueryError(err error, path string) error { return fmt.Errorf("error querying %s: %w", path, err) } +// normalizeHuggingFaceModelName converts Hugging Face model names to lowercase +func normalizeHuggingFaceModelName(model string) string { + if strings.HasPrefix(model, "hf.co/") { + return strings.ToLower(model) + } + + return model +} + func (c *Client) Tag(source, targetRepo, targetTag string) error { - source = dmrm.NormalizeModelName(source) + source = normalizeHuggingFaceModelName(source) // Check if the source is a model ID, and expand it if necessary if !strings.Contains(strings.Trim(source, "/"), "/") { // Do an extra API call to check if the model parameter might be a model ID diff --git a/cmd/cli/desktop/desktop_test.go b/cmd/cli/desktop/desktop_test.go index 32d7907ce..fb0c82262 100644 --- a/cmd/cli/desktop/desktop_test.go +++ b/cmd/cli/desktop/desktop_test.go @@ -179,7 +179,7 @@ func TestTagHuggingFaceModel(t *testing.T) { // Test case for tagging a Hugging Face model with mixed case sourceModel := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest" + expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" targetRepo := "myrepo" targetTag := "latest"