Skip to content

Commit d510957

Browse files
committed
Add ability to pull vllm-compatible hf models
This commit introduces native HuggingFace model support by adding a new HuggingFace client implementation that can download safetensors files directly from HuggingFace Hub repositories. The changes include: A new HuggingFace client with authentication, file listing, and download capabilities. The client handles LFS files, error responses, and rate limiting appropriately. A downloader component that manages parallel file downloads with progress reporting and temporary file storage. It includes progress tracking and concurrent download limiting. Model building functionality that downloads files from HuggingFace repositories and constructs OCI model artifacts using the existing builder framework. Repository utilities for file classification, filtering, and size calculations to identify safetensors and config files needed for model construction. Integration with the existing pull mechanism to detect HuggingFace references and attempt native pulling when no OCI manifest is found. This preserves existing OCI functionality while adding fallback support for raw HuggingFace repositories. Signed-off-by: Eric Curtin <eric.curtin@docker.com>
1 parent 754e0ed commit d510957

9 files changed

Lines changed: 1250 additions & 12 deletions

File tree

pkg/distribution/distribution/client.go

Lines changed: 128 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ import (
66
"fmt"
77
"io"
88
"net/http"
9+
"os"
910
"slices"
1011
"strings"
1112

13+
"github.com/docker/model-runner/pkg/distribution/huggingface"
1214
"github.com/docker/model-runner/pkg/distribution/internal/progress"
1315
"github.com/docker/model-runner/pkg/distribution/internal/store"
1416
"github.com/docker/model-runner/pkg/distribution/registry"
@@ -162,9 +164,10 @@ func (c *Client) normalizeModelName(model string) string {
162164
return model
163165
}
164166

165-
// Normalize HuggingFace model names (lowercase path)
167+
// Normalize HuggingFace model names
166168
if strings.HasPrefix(model, "hf.co/") {
167169
// Replace hf.co with huggingface.co to avoid losing the Authorization header on redirect.
170+
// Lowercase for OCI compatibility (repository names must be lowercase)
168171
model = "huggingface.co" + strings.ToLower(strings.TrimPrefix(model, "hf.co"))
169172
}
170173

@@ -261,21 +264,31 @@ func (c *Client) resolveID(id string) string {
261264

262265
// PullModel pulls a model from a registry and returns the local file path
263266
func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer, bearerToken ...string) error {
267+
// Store original reference before normalization (needed for case-sensitive HuggingFace API)
268+
originalReference := reference
264269
// Normalize the model reference
265270
reference = c.normalizeModelName(reference)
266271
c.log.Infoln("Starting model pull:", utils.SanitizeForLog(reference))
267272

268273
// Use the client's registry, or create a temporary one if bearer token is provided
269274
registryClient := c.registry
275+
var token string
270276
if len(bearerToken) > 0 && bearerToken[0] != "" {
277+
token = bearerToken[0]
271278
// Create a temporary registry client with bearer token authentication
272-
auth := &authn.Bearer{Token: bearerToken[0]}
279+
auth := &authn.Bearer{Token: token}
273280
registryClient = registry.FromClient(c.registry, registry.WithAuth(auth))
274281
}
275282

276283
// First, fetch the remote model to get the manifest
277284
remoteModel, err := registryClient.Model(ctx, reference)
278285
if err != nil {
286+
// Check if this is a HuggingFace reference and the error indicates no OCI manifest
287+
if isHuggingFaceReference(reference) && isNotOCIError(err) {
288+
c.log.Infoln("No OCI manifest found, attempting native HuggingFace pull")
289+
// Pass original reference to preserve case-sensitivity for HuggingFace API
290+
return c.pullNativeHuggingFace(ctx, originalReference, progressWriter, token)
291+
}
279292
return fmt.Errorf("reading model from registry: %w", err)
280293
}
281294

@@ -637,3 +650,116 @@ func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string,
637650

638651
return nil
639652
}
653+
654+
// isHuggingFaceReference checks if a reference is a HuggingFace model reference
655+
func isHuggingFaceReference(reference string) bool {
656+
return strings.HasPrefix(reference, "huggingface.co/")
657+
}
658+
659+
// isNotOCIError checks if the error indicates the model is not OCI-formatted
660+
// This happens when the HuggingFace repository doesn't have an OCI manifest
661+
func isNotOCIError(err error) bool {
662+
if err == nil {
663+
return false
664+
}
665+
666+
// Check for registry errors indicating no manifest
667+
var regErr *registry.Error
668+
if errors.As(err, &regErr) {
669+
if regErr.Code == "MANIFEST_UNKNOWN" || regErr.Code == "NAME_UNKNOWN" {
670+
return true
671+
}
672+
}
673+
674+
// Note: We intentionally don't treat ErrInvalidReference as "not OCI" - that's a format error
675+
// that should be reported to the user, not interpreted as a native HF model.
676+
// The model name is lowercased during normalization to ensure OCI compatibility.
677+
678+
// Also check error message for common patterns
679+
errStr := err.Error()
680+
return strings.Contains(errStr, "MANIFEST_UNKNOWN") ||
681+
strings.Contains(errStr, "NAME_UNKNOWN") ||
682+
strings.Contains(errStr, "manifest unknown") ||
683+
// HuggingFace returns this error for non-GGUF repositories
684+
strings.Contains(errStr, "Repository is not GGUF") ||
685+
strings.Contains(errStr, "not compatible with llama.cpp")
686+
}
687+
688+
// parseHFReference extracts repo and revision from a HF reference
689+
// e.g., "huggingface.co/org/model:revision" -> ("org/model", "revision")
690+
// e.g., "hf.co/org/model:latest" -> ("org/model", "main")
691+
// Note: This preserves the original case of the repo name for HuggingFace API compatibility
692+
func parseHFReference(reference string) (repo, revision string) {
693+
// Remove registry prefix (handle both hf.co and huggingface.co)
694+
ref := strings.TrimPrefix(reference, "huggingface.co/")
695+
ref = strings.TrimPrefix(ref, "hf.co/")
696+
697+
// Split by colon to get tag
698+
parts := strings.SplitN(ref, ":", 2)
699+
repo = parts[0]
700+
701+
revision = "main"
702+
if len(parts) == 2 && parts[1] != "" && parts[1] != "latest" {
703+
revision = parts[1]
704+
}
705+
706+
return repo, revision
707+
}
708+
709+
// pullNativeHuggingFace pulls a native HuggingFace repository (non-OCI format)
710+
// This is used when the model is stored as raw files (safetensors) on HuggingFace Hub
711+
func (c *Client) pullNativeHuggingFace(ctx context.Context, reference string, progressWriter io.Writer, token string) error {
712+
repo, revision := parseHFReference(reference)
713+
c.log.Infof("Pulling native HuggingFace model: repo=%s, revision=%s", utils.SanitizeForLog(repo), utils.SanitizeForLog(revision))
714+
715+
// Create HuggingFace client
716+
hfOpts := []huggingface.ClientOption{
717+
huggingface.WithUserAgent(registry.DefaultUserAgent),
718+
}
719+
if token != "" {
720+
hfOpts = append(hfOpts, huggingface.WithToken(token))
721+
}
722+
hfClient := huggingface.NewClient(hfOpts...)
723+
724+
// Create temp directory for downloads
725+
tempDir, err := os.MkdirTemp("", "hf-model-*")
726+
if err != nil {
727+
return fmt.Errorf("create temp dir: %w", err)
728+
}
729+
defer os.RemoveAll(tempDir)
730+
731+
// Build model from HuggingFace repository
732+
model, err := huggingface.BuildModel(ctx, hfClient, repo, revision, tempDir, progressWriter)
733+
if err != nil {
734+
// Convert HuggingFace errors to registry errors for consistent handling
735+
var authErr *huggingface.AuthError
736+
var notFoundErr *huggingface.NotFoundError
737+
if errors.As(err, &authErr) {
738+
return registry.ErrUnauthorized
739+
}
740+
if errors.As(err, &notFoundErr) {
741+
return registry.ErrModelNotFound
742+
}
743+
if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil {
744+
c.log.Warnf("Failed to write error message: %v", writeErr)
745+
}
746+
return fmt.Errorf("build model from HuggingFace: %w", err)
747+
}
748+
749+
// Write model to store
750+
// Lowercase the reference for storage since OCI tags don't allow uppercase
751+
storageTag := strings.ToLower(reference)
752+
c.log.Infof("Writing model to store with tag: %s", utils.SanitizeForLog(storageTag))
753+
if err := c.store.Write(model, []string{storageTag}, progressWriter); err != nil {
754+
if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil {
755+
c.log.Warnf("Failed to write error message: %v", writeErr)
756+
}
757+
return fmt.Errorf("writing model to store: %w", err)
758+
}
759+
760+
if err := progress.WriteSuccess(progressWriter, "Model pulled successfully"); err != nil {
761+
c.log.Warnf("Failed to write success message: %v", err)
762+
}
763+
764+
return nil
765+
}

pkg/distribution/distribution/normalize_test.go

Lines changed: 115 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ package distribution
22

33
import (
44
"context"
5+
"errors"
56
"io"
67
"path/filepath"
78
"strings"
89
"testing"
910

1011
"github.com/docker/model-runner/pkg/distribution/builder"
12+
"github.com/docker/model-runner/pkg/distribution/registry"
1113
"github.com/docker/model-runner/pkg/distribution/tarball"
1214
"github.com/sirupsen/logrus"
1315
)
@@ -66,7 +68,7 @@ func TestNormalizeModelName(t *testing.T) {
6668
expected: "registry.example.com/myorg/model:v1",
6769
},
6870

69-
// HuggingFace cases
71+
// HuggingFace cases (lowercased for OCI reference compatibility)
7072
{
7173
name: "huggingface short form lowercase",
7274
input: "hf.co/model",
@@ -75,12 +77,12 @@ func TestNormalizeModelName(t *testing.T) {
7577
{
7678
name: "huggingface short form uppercase",
7779
input: "hf.co/Model",
78-
expected: "huggingface.co/model:latest",
80+
expected: "huggingface.co/model:latest", // lowercased for OCI compatibility
7981
},
8082
{
8183
name: "huggingface short form with org",
8284
input: "hf.co/MyOrg/MyModel",
83-
expected: "huggingface.co/myorg/mymodel:latest",
85+
expected: "huggingface.co/myorg/mymodel:latest", // lowercased for OCI compatibility
8486
},
8587
{
8688
name: "huggingface with tag",
@@ -355,6 +357,116 @@ func createTestClient(t *testing.T) (*Client, func()) {
355357
return client, cleanup
356358
}
357359

360+
func TestIsHuggingFaceReference(t *testing.T) {
361+
tests := []struct {
362+
name string
363+
input string
364+
expected bool
365+
}{
366+
{"huggingface.co prefix", "huggingface.co/org/model:latest", true},
367+
{"huggingface.co without tag", "huggingface.co/org/model", true},
368+
{"not huggingface", "registry.example.com/model:latest", false},
369+
{"docker hub", "ai/gemma3:latest", false},
370+
{"hf.co prefix (not normalized)", "hf.co/org/model", false}, // This is the un-normalized form
371+
{"empty", "", false},
372+
}
373+
374+
for _, tt := range tests {
375+
t.Run(tt.name, func(t *testing.T) {
376+
result := isHuggingFaceReference(tt.input)
377+
if result != tt.expected {
378+
t.Errorf("isHuggingFaceReference(%q) = %v, want %v", tt.input, result, tt.expected)
379+
}
380+
})
381+
}
382+
}
383+
384+
func TestParseHFReference(t *testing.T) {
385+
tests := []struct {
386+
name string
387+
input string
388+
expectedRepo string
389+
expectedRev string
390+
}{
391+
{
392+
name: "basic with latest tag",
393+
input: "huggingface.co/org/model:latest",
394+
expectedRepo: "org/model",
395+
expectedRev: "main", // latest maps to main
396+
},
397+
{
398+
name: "with explicit revision",
399+
input: "huggingface.co/org/model:v1.0",
400+
expectedRepo: "org/model",
401+
expectedRev: "v1.0",
402+
},
403+
{
404+
name: "without tag",
405+
input: "huggingface.co/org/model",
406+
expectedRepo: "org/model",
407+
expectedRev: "main",
408+
},
409+
{
410+
name: "with commit hash as tag",
411+
input: "huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct:abc123",
412+
expectedRepo: "HuggingFaceTB/SmolLM2-135M-Instruct",
413+
expectedRev: "abc123",
414+
},
415+
{
416+
name: "single name (no org)",
417+
input: "huggingface.co/model:latest",
418+
expectedRepo: "model",
419+
expectedRev: "main",
420+
},
421+
}
422+
423+
for _, tt := range tests {
424+
t.Run(tt.name, func(t *testing.T) {
425+
repo, rev := parseHFReference(tt.input)
426+
if repo != tt.expectedRepo {
427+
t.Errorf("parseHFReference(%q) repo = %q, want %q", tt.input, repo, tt.expectedRepo)
428+
}
429+
if rev != tt.expectedRev {
430+
t.Errorf("parseHFReference(%q) rev = %q, want %q", tt.input, rev, tt.expectedRev)
431+
}
432+
})
433+
}
434+
}
435+
436+
func TestIsNotOCIError(t *testing.T) {
437+
tests := []struct {
438+
name string
439+
err error
440+
expected bool
441+
}{
442+
{"nil error", nil, false},
443+
{"generic error", errors.New("some error"), false},
444+
{"manifest unknown in message", errors.New("MANIFEST_UNKNOWN: manifest not found"), true},
445+
{"name unknown in message", errors.New("NAME_UNKNOWN: repository not found"), true},
446+
{"manifest unknown lowercase", errors.New("manifest unknown"), true},
447+
{"unrelated error", errors.New("network timeout"), false},
448+
{"HuggingFace not GGUF error", errors.New("Repository is not GGUF or is not compatible with llama.cpp"), true},
449+
{"HuggingFace llama.cpp incompatible", errors.New("not compatible with llama.cpp"), true},
450+
// registry.Error typed error cases
451+
{"registry error MANIFEST_UNKNOWN", &registry.Error{Code: "MANIFEST_UNKNOWN"}, true},
452+
{"registry error NAME_UNKNOWN", &registry.Error{Code: "NAME_UNKNOWN"}, true},
453+
{"registry error other code", &registry.Error{Code: "UNAUTHORIZED"}, false},
454+
// ErrInvalidReference is NOT treated as "not OCI" - it's a format error
455+
// that should be reported to the user. Model names are lowercased during
456+
// normalization to ensure OCI compatibility.
457+
{"invalid reference error", registry.ErrInvalidReference, false},
458+
}
459+
460+
for _, tt := range tests {
461+
t.Run(tt.name, func(t *testing.T) {
462+
result := isNotOCIError(tt.err)
463+
if result != tt.expected {
464+
t.Errorf("isNotOCIError(%v) = %v, want %v", tt.err, result, tt.expected)
465+
}
466+
})
467+
}
468+
}
469+
358470
// Helper function to load a test model and return its ID
359471
func loadTestModel(t *testing.T, client *Client, ggufPath string) string {
360472
t.Helper()

0 commit comments

Comments
 (0)