@@ -6,9 +6,11 @@ import (
66 "fmt"
77 "io"
88 "net/http"
9+ "os"
910 "slices"
1011 "strings"
1112
13+ "github.com/docker/model-runner/pkg/distribution/huggingface"
1214 "github.com/docker/model-runner/pkg/distribution/internal/progress"
1315 "github.com/docker/model-runner/pkg/distribution/internal/store"
1416 "github.com/docker/model-runner/pkg/distribution/registry"
@@ -162,9 +164,10 @@ func (c *Client) normalizeModelName(model string) string {
162164 return model
163165 }
164166
165- // Normalize HuggingFace model names (lowercase path)
167+ // Normalize HuggingFace model names
166168 if strings .HasPrefix (model , "hf.co/" ) {
167169 // Replace hf.co with huggingface.co to avoid losing the Authorization header on redirect.
170+ // Lowercase for OCI compatibility (repository names must be lowercase)
168171 model = "huggingface.co" + strings .ToLower (strings .TrimPrefix (model , "hf.co" ))
169172 }
170173
@@ -261,21 +264,31 @@ func (c *Client) resolveID(id string) string {
261264
262265// PullModel pulls a model from a registry and returns the local file path
263266func (c * Client ) PullModel (ctx context.Context , reference string , progressWriter io.Writer , bearerToken ... string ) error {
267+ // Store original reference before normalization (needed for case-sensitive HuggingFace API)
268+ originalReference := reference
264269 // Normalize the model reference
265270 reference = c .normalizeModelName (reference )
266271 c .log .Infoln ("Starting model pull:" , utils .SanitizeForLog (reference ))
267272
268273 // Use the client's registry, or create a temporary one if bearer token is provided
269274 registryClient := c .registry
275+ var token string
270276 if len (bearerToken ) > 0 && bearerToken [0 ] != "" {
277+ token = bearerToken [0 ]
271278 // Create a temporary registry client with bearer token authentication
272- auth := & authn.Bearer {Token : bearerToken [ 0 ] }
279+ auth := & authn.Bearer {Token : token }
273280 registryClient = registry .FromClient (c .registry , registry .WithAuth (auth ))
274281 }
275282
276283 // First, fetch the remote model to get the manifest
277284 remoteModel , err := registryClient .Model (ctx , reference )
278285 if err != nil {
286+ // Check if this is a HuggingFace reference and the error indicates no OCI manifest
287+ if isHuggingFaceReference (reference ) && isNotOCIError (err ) {
288+ c .log .Infoln ("No OCI manifest found, attempting native HuggingFace pull" )
289+ // Pass original reference to preserve case-sensitivity for HuggingFace API
290+ return c .pullNativeHuggingFace (ctx , originalReference , progressWriter , token )
291+ }
279292 return fmt .Errorf ("reading model from registry: %w" , err )
280293 }
281294
@@ -637,3 +650,116 @@ func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string,
637650
638651 return nil
639652}
653+
654+ // isHuggingFaceReference checks if a reference is a HuggingFace model reference
655+ func isHuggingFaceReference (reference string ) bool {
656+ return strings .HasPrefix (reference , "huggingface.co/" )
657+ }
658+
659+ // isNotOCIError checks if the error indicates the model is not OCI-formatted
660+ // This happens when the HuggingFace repository doesn't have an OCI manifest
661+ func isNotOCIError (err error ) bool {
662+ if err == nil {
663+ return false
664+ }
665+
666+ // Check for registry errors indicating no manifest
667+ var regErr * registry.Error
668+ if errors .As (err , & regErr ) {
669+ if regErr .Code == "MANIFEST_UNKNOWN" || regErr .Code == "NAME_UNKNOWN" {
670+ return true
671+ }
672+ }
673+
674+ // Note: We intentionally don't treat ErrInvalidReference as "not OCI" - that's a format error
675+ // that should be reported to the user, not interpreted as a native HF model.
676+ // The model name is lowercased during normalization to ensure OCI compatibility.
677+
678+ // Also check error message for common patterns
679+ errStr := err .Error ()
680+ return strings .Contains (errStr , "MANIFEST_UNKNOWN" ) ||
681+ strings .Contains (errStr , "NAME_UNKNOWN" ) ||
682+ strings .Contains (errStr , "manifest unknown" ) ||
683+ // HuggingFace returns this error for non-GGUF repositories
684+ strings .Contains (errStr , "Repository is not GGUF" ) ||
685+ strings .Contains (errStr , "not compatible with llama.cpp" )
686+ }
687+
688+ // parseHFReference extracts repo and revision from a HF reference
689+ // e.g., "huggingface.co/org/model:revision" -> ("org/model", "revision")
690+ // e.g., "hf.co/org/model:latest" -> ("org/model", "main")
691+ // Note: This preserves the original case of the repo name for HuggingFace API compatibility
692+ func parseHFReference (reference string ) (repo , revision string ) {
693+ // Remove registry prefix (handle both hf.co and huggingface.co)
694+ ref := strings .TrimPrefix (reference , "huggingface.co/" )
695+ ref = strings .TrimPrefix (ref , "hf.co/" )
696+
697+ // Split by colon to get tag
698+ parts := strings .SplitN (ref , ":" , 2 )
699+ repo = parts [0 ]
700+
701+ revision = "main"
702+ if len (parts ) == 2 && parts [1 ] != "" && parts [1 ] != "latest" {
703+ revision = parts [1 ]
704+ }
705+
706+ return repo , revision
707+ }
708+
709+ // pullNativeHuggingFace pulls a native HuggingFace repository (non-OCI format)
710+ // This is used when the model is stored as raw files (safetensors) on HuggingFace Hub
711+ func (c * Client ) pullNativeHuggingFace (ctx context.Context , reference string , progressWriter io.Writer , token string ) error {
712+ repo , revision := parseHFReference (reference )
713+ c .log .Infof ("Pulling native HuggingFace model: repo=%s, revision=%s" , utils .SanitizeForLog (repo ), utils .SanitizeForLog (revision ))
714+
715+ // Create HuggingFace client
716+ hfOpts := []huggingface.ClientOption {
717+ huggingface .WithUserAgent (registry .DefaultUserAgent ),
718+ }
719+ if token != "" {
720+ hfOpts = append (hfOpts , huggingface .WithToken (token ))
721+ }
722+ hfClient := huggingface .NewClient (hfOpts ... )
723+
724+ // Create temp directory for downloads
725+ tempDir , err := os .MkdirTemp ("" , "hf-model-*" )
726+ if err != nil {
727+ return fmt .Errorf ("create temp dir: %w" , err )
728+ }
729+ defer os .RemoveAll (tempDir )
730+
731+ // Build model from HuggingFace repository
732+ model , err := huggingface .BuildModel (ctx , hfClient , repo , revision , tempDir , progressWriter )
733+ if err != nil {
734+ // Convert HuggingFace errors to registry errors for consistent handling
735+ var authErr * huggingface.AuthError
736+ var notFoundErr * huggingface.NotFoundError
737+ if errors .As (err , & authErr ) {
738+ return registry .ErrUnauthorized
739+ }
740+ if errors .As (err , & notFoundErr ) {
741+ return registry .ErrModelNotFound
742+ }
743+ if writeErr := progress .WriteError (progressWriter , fmt .Sprintf ("Error: %s" , err .Error ())); writeErr != nil {
744+ c .log .Warnf ("Failed to write error message: %v" , writeErr )
745+ }
746+ return fmt .Errorf ("build model from HuggingFace: %w" , err )
747+ }
748+
749+ // Write model to store
750+ // Lowercase the reference for storage since OCI tags don't allow uppercase
751+ storageTag := strings .ToLower (reference )
752+ c .log .Infof ("Writing model to store with tag: %s" , utils .SanitizeForLog (storageTag ))
753+ if err := c .store .Write (model , []string {storageTag }, progressWriter ); err != nil {
754+ if writeErr := progress .WriteError (progressWriter , fmt .Sprintf ("Error: %s" , err .Error ())); writeErr != nil {
755+ c .log .Warnf ("Failed to write error message: %v" , writeErr )
756+ }
757+ return fmt .Errorf ("writing model to store: %w" , err )
758+ }
759+
760+ if err := progress .WriteSuccess (progressWriter , "Model pulled successfully" ); err != nil {
761+ c .log .Warnf ("Failed to write success message: %v" , err )
762+ }
763+
764+ return nil
765+ }
0 commit comments