From 63558c72a6bdcc1852c2b46ab00f16d49d8e697c Mon Sep 17 00:00:00 2001 From: Avinash Singh Date: Tue, 11 Nov 2025 15:03:08 +0530 Subject: [PATCH 1/5] Add support for HF model_url Signed-off-by: Avinash Singh --- cmd/modelfile/generate.go | 68 ++++++++- pkg/config/modelfile/modelfile.go | 2 + pkg/hfhub/download.go | 220 ++++++++++++++++++++++++++++++ pkg/hfhub/download_test.go | 131 ++++++++++++++++++ 4 files changed, 416 insertions(+), 5 deletions(-) create mode 100644 pkg/hfhub/download.go create mode 100644 pkg/hfhub/download_test.go diff --git a/cmd/modelfile/generate.go b/cmd/modelfile/generate.go index b88cc7ce..152e7b7e 100644 --- a/cmd/modelfile/generate.go +++ b/cmd/modelfile/generate.go @@ -20,11 +20,13 @@ import ( "context" "fmt" "os" + "path/filepath" "github.com/spf13/cobra" "github.com/spf13/viper" configmodelfile "github.com/modelpack/modctl/pkg/config/modelfile" + "github.com/modelpack/modctl/pkg/hfhub" "github.com/modelpack/modctl/pkg/modelfile" ) @@ -32,14 +34,43 @@ var generateConfig = configmodelfile.NewGenerateConfig() // generateCmd represents the modelfile tools command for generating modelfile. var generateCmd = &cobra.Command{ - Use: "generate [flags] ", - Short: "A command line tool for generating modelfile in the workspace, the workspace must be a directory including model files and model configuration files", - Args: cobra.ExactArgs(1), + Use: "generate [flags] []", + Short: "Generate a modelfile from a local workspace or Hugging Face model", + Long: `Generate a modelfile from either a local directory containing model files or by downloading a model from Hugging Face. + +The workspace must be a directory including model files and model configuration files. +Alternatively, use --model_url to download a model from Hugging Face Hub.`, + Example: ` # Generate from local directory + modctl modelfile generate ./my-model-dir + + # Generate from Hugging Face model URL + modctl modelfile generate --model_url https://huggingface.co/meta-llama/Llama-2-7b-hf + + # Generate from Hugging Face using short form + modctl modelfile generate --model_url meta-llama/Llama-2-7b-hf + + # Generate with custom output path + modctl modelfile generate ./my-model-dir --output ./output/modelfile.yaml + + # Generate with metadata overrides + modctl modelfile generate ./my-model-dir --name my-custom-model --family llama3`, + Args: cobra.MaximumNArgs(1), DisableAutoGenTag: true, SilenceUsage: true, FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true}, RunE: func(cmd *cobra.Command, args []string) error { - if err := generateConfig.Convert(args[0]); err != nil { + // If model_url is provided, path is optional + workspace := "." + if len(args) > 0 { + workspace = args[0] + } + + // Validate that either path or model_url is provided + if generateConfig.ModelURL == "" && len(args) == 0 { + return fmt.Errorf("either argument or --model_url flag must be provided") + } + + if err := generateConfig.Convert(workspace); err != nil { return err } @@ -64,6 +95,7 @@ func init() { flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory") flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace") flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile") + flags.StringVar(&generateConfig.ModelURL, "model_url", "", "download model from Hugging Face (format: owner/repo or full URL)") // Mark the ignore-unrecognized-file-types flag as deprecated and hidden flags.MarkDeprecated("ignore-unrecognized-file-types", "this flag will be removed in the next release") @@ -75,7 +107,33 @@ func init() { } // runGenerate runs the generate modelfile. -func runGenerate(_ context.Context) error { +func runGenerate(ctx context.Context) error { + // If model URL is provided, download the model first + if generateConfig.ModelURL != "" { + fmt.Printf("Model URL provided: %s\n", generateConfig.ModelURL) + + // Check if user is authenticated with Hugging Face + if err := hfhub.CheckHuggingFaceAuth(); err != nil { + return fmt.Errorf("authentication check failed: %w", err) + } + + // Create a temporary directory for downloading the model + tmpDir := filepath.Join(os.TempDir(), "modctl-hf-downloads") + if err := os.MkdirAll(tmpDir, 0755); err != nil { + return fmt.Errorf("failed to create temporary directory: %w", err) + } + + // Download the model + downloadPath, err := hfhub.DownloadModel(ctx, generateConfig.ModelURL, tmpDir) + if err != nil { + return fmt.Errorf("failed to download model: %w", err) + } + + // Update workspace to the downloaded model path + generateConfig.Workspace = downloadPath + fmt.Printf("Using downloaded model at: %s\n", downloadPath) + } + fmt.Printf("Generating modelfile for %s\n", generateConfig.Workspace) modelfile, err := modelfile.NewModelfileByWorkspace(generateConfig.Workspace, generateConfig) if err != nil { diff --git a/pkg/config/modelfile/modelfile.go b/pkg/config/modelfile/modelfile.go index f2025052..5723193a 100644 --- a/pkg/config/modelfile/modelfile.go +++ b/pkg/config/modelfile/modelfile.go @@ -39,6 +39,7 @@ type GenerateConfig struct { ParamSize string Precision string Quantization string + ModelURL string } func NewGenerateConfig() *GenerateConfig { @@ -55,6 +56,7 @@ func NewGenerateConfig() *GenerateConfig { ParamSize: "", Precision: "", Quantization: "", + ModelURL: "", } } diff --git a/pkg/hfhub/download.go b/pkg/hfhub/download.go new file mode 100644 index 00000000..87da7d65 --- /dev/null +++ b/pkg/hfhub/download.go @@ -0,0 +1,220 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package hfhub + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "strings" +) + +const ( + HuggingFaceBaseURL = "https://huggingface.co" +) + +// ParseModelURL parses a Hugging Face model URL and extracts the owner and repository name +func ParseModelURL(modelURL string) (owner, repo string, err error) { + // Handle both full URLs and short-form repo names + modelURL = strings.TrimSpace(modelURL) + + // Remove trailing slashes + modelURL = strings.TrimSuffix(modelURL, "/") + + // If it's a full URL, parse it + if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") { + u, err := url.Parse(modelURL) + if err != nil { + return "", "", fmt.Errorf("invalid URL: %w", err) + } + + // Expected format: https://huggingface.co/owner/repo + parts := strings.Split(strings.Trim(u.Path, "/"), "/") + if len(parts) < 2 { + return "", "", fmt.Errorf("invalid Hugging Face URL format, expected https://huggingface.co/owner/repo") + } + + owner = parts[0] + repo = parts[1] + } else { + // Handle short-form like "owner/repo" + parts := strings.Split(modelURL, "/") + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid model identifier, expected format: owner/repo") + } + + owner = parts[0] + repo = parts[1] + } + + if owner == "" || repo == "" { + return "", "", fmt.Errorf("owner and repository name cannot be empty") + } + + return owner, repo, nil +} + +// DownloadModel downloads a model from Hugging Face using the huggingface-cli +// It assumes the user is already logged in via `huggingface-cli login` +func DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) { + owner, repo, err := ParseModelURL(modelURL) + if err != nil { + return "", err + } + + repoID := fmt.Sprintf("%s/%s", owner, repo) + + // Check if huggingface-cli is available + if _, err := exec.LookPath("huggingface-cli"); err != nil { + return "", fmt.Errorf("huggingface-cli not found in PATH. Please install it using: pip install huggingface_hub[cli]") + } + + // Create destination directory if it doesn't exist + if err := os.MkdirAll(destDir, 0755); err != nil { + return "", fmt.Errorf("failed to create destination directory: %w", err) + } + + // Construct the download path + downloadPath := filepath.Join(destDir, repo) + + // Use huggingface-cli to download the model + // The --local-dir-use-symlinks=False flag ensures files are copied, not symlinked + cmd := exec.CommandContext(ctx, "huggingface-cli", "download", repoID, "--local-dir", downloadPath, "--local-dir-use-symlinks", "False") + + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + fmt.Printf("Downloading model %s to %s...\n", repoID, downloadPath) + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("failed to download model using huggingface-cli: %w", err) + } + + fmt.Printf("Successfully downloaded model to %s\n", downloadPath) + + return downloadPath, nil +} + +// CheckHuggingFaceAuth checks if the user is authenticated with Hugging Face +func CheckHuggingFaceAuth() error { + // Try to find the HF token + token := os.Getenv("HF_TOKEN") + if token != "" { + return nil + } + + // Check if the token file exists + homeDir, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("failed to get user home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, ".huggingface", "token") + if _, err := os.Stat(tokenPath); err == nil { + return nil + } + + // Try using whoami command + if _, err := exec.LookPath("huggingface-cli"); err == nil { + cmd := exec.Command("huggingface-cli", "whoami") + if err := cmd.Run(); err == nil { + return nil + } + } + + return fmt.Errorf("not authenticated with Hugging Face. Please run: huggingface-cli login") +} + +// GetToken retrieves the Hugging Face token from environment or token file +func GetToken() (string, error) { + // First check environment variable + token := os.Getenv("HF_TOKEN") + if token != "" { + return token, nil + } + + // Then check the token file + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, ".huggingface", "token") + data, err := os.ReadFile(tokenPath) + if err != nil { + return "", fmt.Errorf("failed to read token file: %w", err) + } + + return strings.TrimSpace(string(data)), nil +} + +// DownloadFile downloads a single file from Hugging Face +func DownloadFile(ctx context.Context, owner, repo, filename, destPath string) error { + token, err := GetToken() + if err != nil { + return fmt.Errorf("failed to get Hugging Face token: %w", err) + } + + // Construct the download URL + // Format: https://huggingface.co/{owner}/{repo}/resolve/main/{filename} + fileURL := fmt.Sprintf("%s/%s/%s/resolve/main/%s", HuggingFaceBaseURL, owner, repo, filename) + + req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + // Add authorization header + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to download file: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to download file, status code: %d", resp.StatusCode) + } + + // Create destination directory + destDir := filepath.Dir(destPath) + if err := os.MkdirAll(destDir, 0755); err != nil { + return fmt.Errorf("failed to create destination directory: %w", err) + } + + // Create the destination file + outFile, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer outFile.Close() + + // Copy the content + _, err = io.Copy(outFile, resp.Body) + if err != nil { + return fmt.Errorf("failed to write file: %w", err) + } + + return nil +} diff --git a/pkg/hfhub/download_test.go b/pkg/hfhub/download_test.go new file mode 100644 index 00000000..bafd8432 --- /dev/null +++ b/pkg/hfhub/download_test.go @@ -0,0 +1,131 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package hfhub + +import ( + "testing" +) + +func TestParseModelURL(t *testing.T) { + tests := []struct { + name string + modelURL string + wantOwner string + wantRepo string + wantErr bool + errContains string + }{ + { + name: "full URL", + modelURL: "https://huggingface.co/meta-llama/Llama-2-7b-hf", + wantOwner: "meta-llama", + wantRepo: "Llama-2-7b-hf", + wantErr: false, + }, + { + name: "full URL with trailing slash", + modelURL: "https://huggingface.co/meta-llama/Llama-2-7b-hf/", + wantOwner: "meta-llama", + wantRepo: "Llama-2-7b-hf", + wantErr: false, + }, + { + name: "short form", + modelURL: "meta-llama/Llama-2-7b-hf", + wantOwner: "meta-llama", + wantRepo: "Llama-2-7b-hf", + wantErr: false, + }, + { + name: "http URL", + modelURL: "http://huggingface.co/openai/gpt-2", + wantOwner: "openai", + wantRepo: "gpt-2", + wantErr: false, + }, + { + name: "invalid format - missing repo", + modelURL: "https://huggingface.co/meta-llama", + wantErr: true, + errContains: "invalid Hugging Face URL format", + }, + { + name: "invalid format - only owner", + modelURL: "meta-llama", + wantErr: true, + errContains: "invalid model identifier", + }, + { + name: "empty URL", + modelURL: "", + wantErr: true, + errContains: "invalid model identifier", + }, + { + name: "URL with spaces (trimmed)", + modelURL: " meta-llama/Llama-2-7b-hf ", + wantOwner: "meta-llama", + wantRepo: "Llama-2-7b-hf", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + owner, repo, err := ParseModelURL(tt.modelURL) + + if tt.wantErr { + if err == nil { + t.Errorf("ParseModelURL() expected error but got nil") + return + } + if tt.errContains != "" && err.Error() != tt.errContains && !contains(err.Error(), tt.errContains) { + t.Errorf("ParseModelURL() error = %v, want error containing %v", err, tt.errContains) + } + return + } + + if err != nil { + t.Errorf("ParseModelURL() unexpected error = %v", err) + return + } + + if owner != tt.wantOwner { + t.Errorf("ParseModelURL() owner = %v, want %v", owner, tt.wantOwner) + } + + if repo != tt.wantRepo { + t.Errorf("ParseModelURL() repo = %v, want %v", repo, tt.wantRepo) + } + }) + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && + (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || + findInString(s, substr))) +} + +func findInString(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} From 643b5a19fa7d59db9e9fac5f86d301c36c2b8b8c Mon Sep 17 00:00:00 2001 From: Avinash Singh Date: Tue, 11 Nov 2025 16:13:53 +0530 Subject: [PATCH 2/5] optimise as per gemini's review Signed-off-by: Avinash Singh --- cmd/modelfile/generate.go | 12 ++++++++---- pkg/hfhub/download.go | 6 ++---- pkg/hfhub/download_test.go | 18 ++---------------- 3 files changed, 12 insertions(+), 24 deletions(-) diff --git a/cmd/modelfile/generate.go b/cmd/modelfile/generate.go index 152e7b7e..b6454363 100644 --- a/cmd/modelfile/generate.go +++ b/cmd/modelfile/generate.go @@ -20,7 +20,6 @@ import ( "context" "fmt" "os" - "path/filepath" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -66,8 +65,11 @@ Alternatively, use --model_url to download a model from Hugging Face Hub.`, } // Validate that either path or model_url is provided + if generateConfig.ModelURL != "" && len(args) > 0 { + return fmt.Errorf("the argument and the --model_url flag are mutually exclusive") + } if generateConfig.ModelURL == "" && len(args) == 0 { - return fmt.Errorf("either argument or --model_url flag must be provided") + return fmt.Errorf("either a argument or the --model_url flag must be provided") } if err := generateConfig.Convert(workspace); err != nil { @@ -118,10 +120,12 @@ func runGenerate(ctx context.Context) error { } // Create a temporary directory for downloading the model - tmpDir := filepath.Join(os.TempDir(), "modctl-hf-downloads") - if err := os.MkdirAll(tmpDir, 0755); err != nil { + // Clean up the temporary directory after the function returns + tmpDir, err := os.MkdirTemp("", "modctl-hf-downloads-*") + if err != nil { return fmt.Errorf("failed to create temporary directory: %w", err) } + defer os.RemoveAll(tmpDir) // Download the model downloadPath, err := hfhub.DownloadModel(ctx, generateConfig.ModelURL, tmpDir) diff --git a/pkg/hfhub/download.go b/pkg/hfhub/download.go index 87da7d65..5418c623 100644 --- a/pkg/hfhub/download.go +++ b/pkg/hfhub/download.go @@ -103,14 +103,10 @@ func DownloadModel(ctx context.Context, modelURL, destDir string) (string, error cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - fmt.Printf("Downloading model %s to %s...\n", repoID, downloadPath) - if err := cmd.Run(); err != nil { return "", fmt.Errorf("failed to download model using huggingface-cli: %w", err) } - fmt.Printf("Successfully downloaded model to %s\n", downloadPath) - return downloadPath, nil } @@ -136,6 +132,8 @@ func CheckHuggingFaceAuth() error { // Try using whoami command if _, err := exec.LookPath("huggingface-cli"); err == nil { cmd := exec.Command("huggingface-cli", "whoami") + cmd.Stdout = io.Discard + cmd.Stderr = io.Discard if err := cmd.Run(); err == nil { return nil } diff --git a/pkg/hfhub/download_test.go b/pkg/hfhub/download_test.go index bafd8432..ba993a0a 100644 --- a/pkg/hfhub/download_test.go +++ b/pkg/hfhub/download_test.go @@ -17,6 +17,7 @@ package hfhub import ( + "strings" "testing" ) @@ -93,7 +94,7 @@ func TestParseModelURL(t *testing.T) { t.Errorf("ParseModelURL() expected error but got nil") return } - if tt.errContains != "" && err.Error() != tt.errContains && !contains(err.Error(), tt.errContains) { + if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { t.Errorf("ParseModelURL() error = %v, want error containing %v", err, tt.errContains) } return @@ -114,18 +115,3 @@ func TestParseModelURL(t *testing.T) { }) } } - -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && - (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || - findInString(s, substr))) -} - -func findInString(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} From 6d7fb9003371af7af96fbf6d6895afe4a584f851 Mon Sep 17 00:00:00 2001 From: Avinash Singh Date: Mon, 24 Nov 2025 15:25:51 +0530 Subject: [PATCH 3/5] add modelprovider interface and providers Signed-off-by: Avinash Singh --- cmd/modelfile/generate.go | 34 ++- .../huggingface/downloader.go} | 67 ++---- .../huggingface/downloader_test.go} | 67 +++++- pkg/modelprovider/huggingface/provider.go | 99 +++++++++ pkg/modelprovider/modelscope/downloader.go | 195 ++++++++++++++++++ .../modelscope/downloader_test.go | 180 ++++++++++++++++ pkg/modelprovider/modelscope/provider.go | 96 +++++++++ pkg/modelprovider/provider.go | 46 +++++ pkg/modelprovider/registry.go | 74 +++++++ pkg/modelprovider/registry_test.go | 164 +++++++++++++++ 10 files changed, 951 insertions(+), 71 deletions(-) rename pkg/{hfhub/download.go => modelprovider/huggingface/downloader.go} (64%) rename pkg/{hfhub/download_test.go => modelprovider/huggingface/downloader_test.go} (63%) create mode 100644 pkg/modelprovider/huggingface/provider.go create mode 100644 pkg/modelprovider/modelscope/downloader.go create mode 100644 pkg/modelprovider/modelscope/downloader_test.go create mode 100644 pkg/modelprovider/modelscope/provider.go create mode 100644 pkg/modelprovider/provider.go create mode 100644 pkg/modelprovider/registry.go create mode 100644 pkg/modelprovider/registry_test.go diff --git a/cmd/modelfile/generate.go b/cmd/modelfile/generate.go index b6454363..0c098808 100644 --- a/cmd/modelfile/generate.go +++ b/cmd/modelfile/generate.go @@ -25,8 +25,8 @@ import ( "github.com/spf13/viper" configmodelfile "github.com/modelpack/modctl/pkg/config/modelfile" - "github.com/modelpack/modctl/pkg/hfhub" "github.com/modelpack/modctl/pkg/modelfile" + "github.com/modelpack/modctl/pkg/modelprovider" ) var generateConfig = configmodelfile.NewGenerateConfig() @@ -34,11 +34,11 @@ var generateConfig = configmodelfile.NewGenerateConfig() // generateCmd represents the modelfile tools command for generating modelfile. var generateCmd = &cobra.Command{ Use: "generate [flags] []", - Short: "Generate a modelfile from a local workspace or Hugging Face model", - Long: `Generate a modelfile from either a local directory containing model files or by downloading a model from Hugging Face. + Short: "Generate a modelfile from a local workspace or remote model provider", + Long: `Generate a modelfile from either a local directory containing model files or by downloading a model from a supported provider. The workspace must be a directory including model files and model configuration files. -Alternatively, use --model_url to download a model from Hugging Face Hub.`, +Alternatively, use --model_url to download a model from a supported provider (e.g., HuggingFace, ModelScope).`, Example: ` # Generate from local directory modctl modelfile generate ./my-model-dir @@ -48,6 +48,9 @@ Alternatively, use --model_url to download a model from Hugging Face Hub.`, # Generate from Hugging Face using short form modctl modelfile generate --model_url meta-llama/Llama-2-7b-hf + # Generate from ModelScope + modctl modelfile generate --model_url https://modelscope.cn/models/qwen/Qwen-7B + # Generate with custom output path modctl modelfile generate ./my-model-dir --output ./output/modelfile.yaml @@ -97,7 +100,7 @@ func init() { flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory") flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace") flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile") - flags.StringVar(&generateConfig.ModelURL, "model_url", "", "download model from Hugging Face (format: owner/repo or full URL)") + flags.StringVar(&generateConfig.ModelURL, "model_url", "", "download model from a supported provider (HuggingFace: owner/repo or full URL, ModelScope: full URL)") // Mark the ignore-unrecognized-file-types flag as deprecated and hidden flags.MarkDeprecated("ignore-unrecognized-file-types", "this flag will be removed in the next release") @@ -114,23 +117,32 @@ func runGenerate(ctx context.Context) error { if generateConfig.ModelURL != "" { fmt.Printf("Model URL provided: %s\n", generateConfig.ModelURL) - // Check if user is authenticated with Hugging Face - if err := hfhub.CheckHuggingFaceAuth(); err != nil { - return fmt.Errorf("authentication check failed: %w", err) + // Get the appropriate provider for this URL + registry := modelprovider.NewRegistry() + provider, err := registry.GetProvider(generateConfig.ModelURL) + if err != nil { + return fmt.Errorf("unsupported model URL: %w", err) + } + + fmt.Printf("Using provider: %s\n", provider.Name()) + + // Check if user is authenticated with the provider + if err := provider.CheckAuth(); err != nil { + return fmt.Errorf("%s authentication check failed: %w", provider.Name(), err) } // Create a temporary directory for downloading the model // Clean up the temporary directory after the function returns - tmpDir, err := os.MkdirTemp("", "modctl-hf-downloads-*") + tmpDir, err := os.MkdirTemp("", "modctl-model-downloads-*") if err != nil { return fmt.Errorf("failed to create temporary directory: %w", err) } defer os.RemoveAll(tmpDir) // Download the model - downloadPath, err := hfhub.DownloadModel(ctx, generateConfig.ModelURL, tmpDir) + downloadPath, err := provider.DownloadModel(ctx, generateConfig.ModelURL, tmpDir) if err != nil { - return fmt.Errorf("failed to download model: %w", err) + return fmt.Errorf("failed to download model from %s: %w", provider.Name(), err) } // Update workspace to the downloaded model path diff --git a/pkg/hfhub/download.go b/pkg/modelprovider/huggingface/downloader.go similarity index 64% rename from pkg/hfhub/download.go rename to pkg/modelprovider/huggingface/downloader.go index 5418c623..910418a0 100644 --- a/pkg/hfhub/download.go +++ b/pkg/modelprovider/huggingface/downloader.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package hfhub +package huggingface import ( "context" @@ -29,11 +29,11 @@ import ( ) const ( - HuggingFaceBaseURL = "https://huggingface.co" + huggingFaceBaseURL = "https://huggingface.co" ) -// ParseModelURL parses a Hugging Face model URL and extracts the owner and repository name -func ParseModelURL(modelURL string) (owner, repo string, err error) { +// parseModelURL parses a HuggingFace model URL and extracts the owner and repository name +func parseModelURL(modelURL string) (owner, repo string, err error) { // Handle both full URLs and short-form repo names modelURL = strings.TrimSpace(modelURL) @@ -50,7 +50,7 @@ func ParseModelURL(modelURL string) (owner, repo string, err error) { // Expected format: https://huggingface.co/owner/repo parts := strings.Split(strings.Trim(u.Path, "/"), "/") if len(parts) < 2 { - return "", "", fmt.Errorf("invalid Hugging Face URL format, expected https://huggingface.co/owner/repo") + return "", "", fmt.Errorf("invalid HuggingFace URL format, expected https://huggingface.co/owner/repo") } owner = parts[0] @@ -73,45 +73,8 @@ func ParseModelURL(modelURL string) (owner, repo string, err error) { return owner, repo, nil } -// DownloadModel downloads a model from Hugging Face using the huggingface-cli -// It assumes the user is already logged in via `huggingface-cli login` -func DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) { - owner, repo, err := ParseModelURL(modelURL) - if err != nil { - return "", err - } - - repoID := fmt.Sprintf("%s/%s", owner, repo) - - // Check if huggingface-cli is available - if _, err := exec.LookPath("huggingface-cli"); err != nil { - return "", fmt.Errorf("huggingface-cli not found in PATH. Please install it using: pip install huggingface_hub[cli]") - } - - // Create destination directory if it doesn't exist - if err := os.MkdirAll(destDir, 0755); err != nil { - return "", fmt.Errorf("failed to create destination directory: %w", err) - } - - // Construct the download path - downloadPath := filepath.Join(destDir, repo) - - // Use huggingface-cli to download the model - // The --local-dir-use-symlinks=False flag ensures files are copied, not symlinked - cmd := exec.CommandContext(ctx, "huggingface-cli", "download", repoID, "--local-dir", downloadPath, "--local-dir-use-symlinks", "False") - - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - if err := cmd.Run(); err != nil { - return "", fmt.Errorf("failed to download model using huggingface-cli: %w", err) - } - - return downloadPath, nil -} - -// CheckHuggingFaceAuth checks if the user is authenticated with Hugging Face -func CheckHuggingFaceAuth() error { +// checkHuggingFaceAuth checks if the user is authenticated with HuggingFace +func checkHuggingFaceAuth() error { // Try to find the HF token token := os.Getenv("HF_TOKEN") if token != "" { @@ -139,11 +102,11 @@ func CheckHuggingFaceAuth() error { } } - return fmt.Errorf("not authenticated with Hugging Face. Please run: huggingface-cli login") + return fmt.Errorf("not authenticated with HuggingFace. Please run: huggingface-cli login") } -// GetToken retrieves the Hugging Face token from environment or token file -func GetToken() (string, error) { +// getToken retrieves the HuggingFace token from environment or token file +func getToken() (string, error) { // First check environment variable token := os.Getenv("HF_TOKEN") if token != "" { @@ -165,16 +128,16 @@ func GetToken() (string, error) { return strings.TrimSpace(string(data)), nil } -// DownloadFile downloads a single file from Hugging Face -func DownloadFile(ctx context.Context, owner, repo, filename, destPath string) error { - token, err := GetToken() +// downloadFile downloads a single file from HuggingFace +func downloadFile(ctx context.Context, owner, repo, filename, destPath string) error { + token, err := getToken() if err != nil { - return fmt.Errorf("failed to get Hugging Face token: %w", err) + return fmt.Errorf("failed to get HuggingFace token: %w", err) } // Construct the download URL // Format: https://huggingface.co/{owner}/{repo}/resolve/main/{filename} - fileURL := fmt.Sprintf("%s/%s/%s/resolve/main/%s", HuggingFaceBaseURL, owner, repo, filename) + fileURL := fmt.Sprintf("%s/%s/%s/resolve/main/%s", huggingFaceBaseURL, owner, repo, filename) req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil) if err != nil { diff --git a/pkg/hfhub/download_test.go b/pkg/modelprovider/huggingface/downloader_test.go similarity index 63% rename from pkg/hfhub/download_test.go rename to pkg/modelprovider/huggingface/downloader_test.go index ba993a0a..d4bb4b04 100644 --- a/pkg/hfhub/download_test.go +++ b/pkg/modelprovider/huggingface/downloader_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package hfhub +package huggingface import ( "strings" @@ -62,7 +62,7 @@ func TestParseModelURL(t *testing.T) { name: "invalid format - missing repo", modelURL: "https://huggingface.co/meta-llama", wantErr: true, - errContains: "invalid Hugging Face URL format", + errContains: "invalid HuggingFace URL format", }, { name: "invalid format - only owner", @@ -87,31 +87,82 @@ func TestParseModelURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - owner, repo, err := ParseModelURL(tt.modelURL) + owner, repo, err := parseModelURL(tt.modelURL) if tt.wantErr { if err == nil { - t.Errorf("ParseModelURL() expected error but got nil") + t.Errorf("parseModelURL() expected error but got nil") return } if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { - t.Errorf("ParseModelURL() error = %v, want error containing %v", err, tt.errContains) + t.Errorf("parseModelURL() error = %v, want error containing %v", err, tt.errContains) } return } if err != nil { - t.Errorf("ParseModelURL() unexpected error = %v", err) + t.Errorf("parseModelURL() unexpected error = %v", err) return } if owner != tt.wantOwner { - t.Errorf("ParseModelURL() owner = %v, want %v", owner, tt.wantOwner) + t.Errorf("parseModelURL() owner = %v, want %v", owner, tt.wantOwner) } if repo != tt.wantRepo { - t.Errorf("ParseModelURL() repo = %v, want %v", repo, tt.wantRepo) + t.Errorf("parseModelURL() repo = %v, want %v", repo, tt.wantRepo) } }) } } + +func TestProvider_SupportsURL(t *testing.T) { + provider := New() + + tests := []struct { + name string + url string + want bool + }{ + { + name: "full HuggingFace URL", + url: "https://huggingface.co/meta-llama/Llama-2-7b-hf", + want: true, + }, + { + name: "short form repo", + url: "meta-llama/Llama-2-7b-hf", + want: true, + }, + { + name: "ModelScope URL", + url: "https://modelscope.cn/models/owner/repo", + want: false, + }, + { + name: "invalid format", + url: "just-a-string", + want: false, + }, + { + name: "HTTP URL", + url: "http://example.com/owner/repo", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := provider.SupportsURL(tt.url); got != tt.want { + t.Errorf("Provider.SupportsURL() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestProvider_Name(t *testing.T) { + provider := New() + if got := provider.Name(); got != "huggingface" { + t.Errorf("Provider.Name() = %v, want %v", got, "huggingface") + } +} diff --git a/pkg/modelprovider/huggingface/provider.go b/pkg/modelprovider/huggingface/provider.go new file mode 100644 index 00000000..fb84f352 --- /dev/null +++ b/pkg/modelprovider/huggingface/provider.go @@ -0,0 +1,99 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package huggingface + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +// Provider implements the modelprovider.Provider interface for HuggingFace +type Provider struct{} + +// New creates a new HuggingFace provider instance +func New() *Provider { + return &Provider{} +} + +// Name returns the name of this provider +func (p *Provider) Name() string { + return "huggingface" +} + +// SupportsURL checks if this provider can handle the given URL +// It supports both full HuggingFace URLs and short-form repo identifiers +func (p *Provider) SupportsURL(url string) bool { + url = strings.TrimSpace(url) + + // Check for full HuggingFace URLs + if strings.Contains(url, "huggingface.co") { + return true + } + + // Check for short-form repo identifiers (owner/repo) + // Must have exactly one slash and no protocol + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + return strings.Count(url, "/") == 1 + } + + return false +} + +// DownloadModel downloads a model from HuggingFace using the huggingface-cli +func (p *Provider) DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) { + owner, repo, err := parseModelURL(modelURL) + if err != nil { + return "", err + } + + repoID := fmt.Sprintf("%s/%s", owner, repo) + + // Check if huggingface-cli is available + if _, err := exec.LookPath("huggingface-cli"); err != nil { + return "", fmt.Errorf("huggingface-cli not found in PATH. Please install it using: pip install huggingface_hub[cli]") + } + + // Create destination directory if it doesn't exist + if err := os.MkdirAll(destDir, 0755); err != nil { + return "", fmt.Errorf("failed to create destination directory: %w", err) + } + + // Construct the download path + downloadPath := filepath.Join(destDir, repo) + + // Use huggingface-cli to download the model + // The --local-dir-use-symlinks=False flag ensures files are copied, not symlinked + cmd := exec.CommandContext(ctx, "huggingface-cli", "download", repoID, "--local-dir", downloadPath, "--local-dir-use-symlinks", "False") + + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("failed to download model using huggingface-cli: %w", err) + } + + return downloadPath, nil +} + +// CheckAuth verifies that the user is authenticated with HuggingFace +func (p *Provider) CheckAuth() error { + return checkHuggingFaceAuth() +} diff --git a/pkg/modelprovider/modelscope/downloader.go b/pkg/modelprovider/modelscope/downloader.go new file mode 100644 index 00000000..eb5c489b --- /dev/null +++ b/pkg/modelprovider/modelscope/downloader.go @@ -0,0 +1,195 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelscope + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "strings" +) + +const ( + modelScopeBaseURL = "https://modelscope.cn" +) + +// parseModelURL parses a ModelScope model URL and extracts the owner and repository name +func parseModelURL(modelURL string) (owner, repo string, err error) { + // Handle both full URLs and short-form repo names + modelURL = strings.TrimSpace(modelURL) + + // Remove trailing slashes + modelURL = strings.TrimSuffix(modelURL, "/") + + // If it's a full URL, parse it + if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") { + u, err := url.Parse(modelURL) + if err != nil { + return "", "", fmt.Errorf("invalid URL: %w", err) + } + + // Expected format: https://modelscope.cn/models/owner/repo + parts := strings.Split(strings.Trim(u.Path, "/"), "/") + + // Handle both formats: + // 1. https://modelscope.cn/models/owner/repo + // 2. https://modelscope.cn/owner/repo + if len(parts) >= 1 && parts[0] == "models" { + // Must have exactly 3 parts: models/owner/repo + if len(parts) < 3 { + return "", "", fmt.Errorf("invalid ModelScope URL format, expected https://modelscope.cn/models/owner/repo") + } + owner = parts[1] + repo = parts[2] + } else if len(parts) >= 2 { + // Direct format: owner/repo + owner = parts[0] + repo = parts[1] + } else { + return "", "", fmt.Errorf("invalid ModelScope URL format, expected https://modelscope.cn/models/owner/repo") + } + } else { + // Handle short-form like "owner/repo" + parts := strings.Split(modelURL, "/") + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid model identifier, expected format: owner/repo") + } + + owner = parts[0] + repo = parts[1] + } + + if owner == "" || repo == "" { + return "", "", fmt.Errorf("owner and repository name cannot be empty") + } + + return owner, repo, nil +} + +// checkModelScopeAuth checks if the user is authenticated with ModelScope +func checkModelScopeAuth() error { + // Try to find the ModelScope SDK token + token := os.Getenv("MODELSCOPE_SDK_TOKEN") + if token != "" { + return nil + } + + // Check if the token file exists in the default location + homeDir, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("failed to get user home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, ".modelscope", "token") + if _, err := os.Stat(tokenPath); err == nil { + return nil + } + + // Try using modelscope CLI to check auth + if _, err := exec.LookPath("modelscope"); err == nil { + // ModelScope CLI doesn't have a direct whoami equivalent, but we can check if login exists + // For now, we'll allow the download to proceed and let it fail if auth is needed + return nil + } + + // Warning: ModelScope authentication is optional for public models + // We'll return nil here and let the download command handle auth errors + return nil +} + +// getToken retrieves the ModelScope token from environment or token file +func getToken() (string, error) { + // First check environment variable + token := os.Getenv("MODELSCOPE_SDK_TOKEN") + if token != "" { + return token, nil + } + + // Then check the token file + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, ".modelscope", "token") + data, err := os.ReadFile(tokenPath) + if err != nil { + return "", fmt.Errorf("failed to read token file: %w", err) + } + + return strings.TrimSpace(string(data)), nil +} + +// downloadFile downloads a single file from ModelScope +func downloadFile(ctx context.Context, owner, repo, filename, destPath string) error { + token, err := getToken() + if err != nil { + // Token is optional for public models, continue without it + token = "" + } + + // Construct the download URL + // Format: https://modelscope.cn/api/v1/models/{owner}/{repo}/repo?Revision=master&FilePath={filename} + fileURL := fmt.Sprintf("%s/api/v1/models/%s/%s/repo?Revision=master&FilePath=%s", modelScopeBaseURL, owner, repo, filename) + + req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + // Add authorization header if token is available + if token != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + } + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to download file: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to download file, status code: %d", resp.StatusCode) + } + + // Create destination directory + destDir := filepath.Dir(destPath) + if err := os.MkdirAll(destDir, 0755); err != nil { + return fmt.Errorf("failed to create destination directory: %w", err) + } + + // Create the destination file + outFile, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer outFile.Close() + + // Copy the content + _, err = io.Copy(outFile, resp.Body) + if err != nil { + return fmt.Errorf("failed to write file: %w", err) + } + + return nil +} diff --git a/pkg/modelprovider/modelscope/downloader_test.go b/pkg/modelprovider/modelscope/downloader_test.go new file mode 100644 index 00000000..3e70e023 --- /dev/null +++ b/pkg/modelprovider/modelscope/downloader_test.go @@ -0,0 +1,180 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelscope + +import ( + "strings" + "testing" +) + +func TestParseModelURL(t *testing.T) { + tests := []struct { + name string + modelURL string + wantOwner string + wantRepo string + wantErr bool + errContains string + }{ + { + name: "full URL with models prefix", + modelURL: "https://modelscope.cn/models/qwen/Qwen-7B", + wantOwner: "qwen", + wantRepo: "Qwen-7B", + wantErr: false, + }, + { + name: "full URL without models prefix", + modelURL: "https://modelscope.cn/damo/nlp_structbert_backbone_base_std", + wantOwner: "damo", + wantRepo: "nlp_structbert_backbone_base_std", + wantErr: false, + }, + { + name: "full URL with trailing slash", + modelURL: "https://modelscope.cn/models/qwen/Qwen-7B/", + wantOwner: "qwen", + wantRepo: "Qwen-7B", + wantErr: false, + }, + { + name: "short form", + modelURL: "qwen/Qwen-7B", + wantOwner: "qwen", + wantRepo: "Qwen-7B", + wantErr: false, + }, + { + name: "http URL", + modelURL: "http://modelscope.cn/models/damo/nlp_structbert_backbone_base_std", + wantOwner: "damo", + wantRepo: "nlp_structbert_backbone_base_std", + wantErr: false, + }, + { + name: "invalid format - missing repo", + modelURL: "https://modelscope.cn/models/qwen", + wantErr: true, + errContains: "invalid ModelScope URL format", + }, + { + name: "invalid format - only owner", + modelURL: "qwen", + wantErr: true, + errContains: "invalid model identifier", + }, + { + name: "empty URL", + modelURL: "", + wantErr: true, + errContains: "invalid model identifier", + }, + { + name: "URL with spaces (trimmed)", + modelURL: " qwen/Qwen-7B ", + wantOwner: "qwen", + wantRepo: "Qwen-7B", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + owner, repo, err := parseModelURL(tt.modelURL) + + if tt.wantErr { + if err == nil { + t.Errorf("parseModelURL() expected error but got nil") + return + } + if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("parseModelURL() error = %v, want error containing %v", err, tt.errContains) + } + return + } + + if err != nil { + t.Errorf("parseModelURL() unexpected error = %v", err) + return + } + + if owner != tt.wantOwner { + t.Errorf("parseModelURL() owner = %v, want %v", owner, tt.wantOwner) + } + + if repo != tt.wantRepo { + t.Errorf("parseModelURL() repo = %v, want %v", repo, tt.wantRepo) + } + }) + } +} + +func TestProvider_SupportsURL(t *testing.T) { + provider := New() + + tests := []struct { + name string + url string + want bool + }{ + { + name: "full ModelScope URL", + url: "https://modelscope.cn/models/qwen/Qwen-7B", + want: true, + }, + { + name: "ModelScope URL without models prefix", + url: "https://modelscope.cn/damo/nlp_structbert_backbone_base_std", + want: true, + }, + { + name: "HuggingFace URL", + url: "https://huggingface.co/meta-llama/Llama-2-7b-hf", + want: false, + }, + { + name: "short form repo (ambiguous, returns false)", + url: "qwen/Qwen-7B", + want: false, + }, + { + name: "invalid format", + url: "just-a-string", + want: false, + }, + { + name: "HTTP URL", + url: "http://example.com/owner/repo", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := provider.SupportsURL(tt.url); got != tt.want { + t.Errorf("Provider.SupportsURL() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestProvider_Name(t *testing.T) { + provider := New() + if got := provider.Name(); got != "modelscope" { + t.Errorf("Provider.Name() = %v, want %v", got, "modelscope") + } +} diff --git a/pkg/modelprovider/modelscope/provider.go b/pkg/modelprovider/modelscope/provider.go new file mode 100644 index 00000000..284d33a5 --- /dev/null +++ b/pkg/modelprovider/modelscope/provider.go @@ -0,0 +1,96 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelscope + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +// Provider implements the modelprovider.Provider interface for ModelScope +type Provider struct{} + +// New creates a new ModelScope provider instance +func New() *Provider { + return &Provider{} +} + +// Name returns the name of this provider +func (p *Provider) Name() string { + return "modelscope" +} + +// SupportsURL checks if this provider can handle the given URL +// It supports both full ModelScope URLs and short-form repo identifiers +func (p *Provider) SupportsURL(url string) bool { + url = strings.TrimSpace(url) + + // Check for full ModelScope URLs + if strings.Contains(url, "modelscope.cn") { + return true + } + + // Note: We don't auto-detect short-form for ModelScope to avoid conflicts with HuggingFace + // Users should use full URLs or explicitly specify the provider + + return false +} + +// DownloadModel downloads a model from ModelScope using the modelscope CLI +func (p *Provider) DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) { + owner, repo, err := parseModelURL(modelURL) + if err != nil { + return "", err + } + + repoID := fmt.Sprintf("%s/%s", owner, repo) + + // Check if modelscope CLI is available + if _, err := exec.LookPath("modelscope"); err != nil { + return "", fmt.Errorf("modelscope CLI not found in PATH. Please install it using: pip install modelscope") + } + + // Create destination directory if it doesn't exist + if err := os.MkdirAll(destDir, 0755); err != nil { + return "", fmt.Errorf("failed to create destination directory: %w", err) + } + + // Construct the download path + downloadPath := filepath.Join(destDir, repo) + + // Use modelscope download command + // The modelscope CLI uses: modelscope download --model --local_dir + cmd := exec.CommandContext(ctx, "modelscope", "download", "--model", repoID, "--local_dir", downloadPath) + + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("failed to download model using modelscope CLI: %w", err) + } + + return downloadPath, nil +} + +// CheckAuth verifies that the user is authenticated with ModelScope +func (p *Provider) CheckAuth() error { + return checkModelScopeAuth() +} diff --git a/pkg/modelprovider/provider.go b/pkg/modelprovider/provider.go new file mode 100644 index 00000000..331628f4 --- /dev/null +++ b/pkg/modelprovider/provider.go @@ -0,0 +1,46 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelprovider + +import "context" + +// Provider defines the interface that all model providers must implement. +// A provider is responsible for downloading models from a specific source +// (e.g., HuggingFace, ModelScope, Civitai, etc.) +type Provider interface { + // Name returns the human-readable name of the provider + // Example: "huggingface", "modelscope", "civitai" + Name() string + + // SupportsURL checks if this provider can handle the given model URL + // This enables automatic provider detection based on URL patterns + SupportsURL(url string) bool + + // DownloadModel downloads a model from the provider and returns the local path + // Parameters: + // - ctx: context for cancellation and timeout + // - modelURL: the URL or identifier of the model to download + // - destDir: the destination directory where the model should be downloaded + // Returns: + // - string: the local path where the model was downloaded + // - error: any error that occurred during download + DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) + + // CheckAuth verifies that the user is authenticated with the provider + // Returns an error if authentication is missing or invalid + CheckAuth() error +} diff --git a/pkg/modelprovider/registry.go b/pkg/modelprovider/registry.go new file mode 100644 index 00000000..8708a62c --- /dev/null +++ b/pkg/modelprovider/registry.go @@ -0,0 +1,74 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelprovider + +import ( + "fmt" + + "github.com/modelpack/modctl/pkg/modelprovider/huggingface" + "github.com/modelpack/modctl/pkg/modelprovider/modelscope" +) + +// Registry manages all available model providers and provides +// functionality to select the appropriate provider for a given URL +type Registry struct { + providers []Provider +} + +// NewRegistry creates a new provider registry with all available providers +func NewRegistry() *Registry { + return &Registry{ + providers: []Provider{ + huggingface.New(), + modelscope.New(), + // Future providers can be added here: + // civitai.New(), + }, + } +} + +// GetProvider returns the appropriate provider for the given model URL +// It iterates through all registered providers and returns the first one +// that supports the URL +func (r *Registry) GetProvider(modelURL string) (Provider, error) { + for _, p := range r.providers { + if p.SupportsURL(modelURL) { + return p, nil + } + } + return nil, fmt.Errorf("no provider found for URL: %s", modelURL) +} + +// GetProviderByName returns a specific provider by its name +// This is useful when you want to explicitly select a provider +func (r *Registry) GetProviderByName(name string) (Provider, error) { + for _, p := range r.providers { + if p.Name() == name { + return p, nil + } + } + return nil, fmt.Errorf("provider not found: %s", name) +} + +// ListProviders returns the names of all registered providers +func (r *Registry) ListProviders() []string { + names := make([]string, len(r.providers)) + for i, p := range r.providers { + names[i] = p.Name() + } + return names +} diff --git a/pkg/modelprovider/registry_test.go b/pkg/modelprovider/registry_test.go new file mode 100644 index 00000000..f2e9a95a --- /dev/null +++ b/pkg/modelprovider/registry_test.go @@ -0,0 +1,164 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelprovider + +import ( + "testing" +) + +func TestRegistry_GetProvider(t *testing.T) { + registry := NewRegistry() + + tests := []struct { + name string + modelURL string + wantProvider string + wantErr bool + }{ + { + name: "HuggingFace full URL", + modelURL: "https://huggingface.co/meta-llama/Llama-2-7b-hf", + wantProvider: "huggingface", + wantErr: false, + }, + { + name: "HuggingFace short form", + modelURL: "meta-llama/Llama-2-7b-hf", + wantProvider: "huggingface", + wantErr: false, + }, + { + name: "ModelScope full URL", + modelURL: "https://modelscope.cn/models/qwen/Qwen-7B", + wantProvider: "modelscope", + wantErr: false, + }, + { + name: "ModelScope URL without models prefix", + modelURL: "https://modelscope.cn/damo/nlp_structbert_backbone_base_std", + wantProvider: "modelscope", + wantErr: false, + }, + { + name: "Unsupported URL", + modelURL: "https://example.com/model/repo", + wantErr: true, + }, + { + name: "Invalid format", + modelURL: "just-a-string", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := registry.GetProvider(tt.modelURL) + + if tt.wantErr { + if err == nil { + t.Errorf("GetProvider() expected error but got nil") + } + return + } + + if err != nil { + t.Errorf("GetProvider() unexpected error = %v", err) + return + } + + if provider.Name() != tt.wantProvider { + t.Errorf("GetProvider() provider name = %v, want %v", provider.Name(), tt.wantProvider) + } + }) + } +} + +func TestRegistry_GetProviderByName(t *testing.T) { + registry := NewRegistry() + + tests := []struct { + name string + providerName string + wantErr bool + }{ + { + name: "Get HuggingFace provider", + providerName: "huggingface", + wantErr: false, + }, + { + name: "Get ModelScope provider", + providerName: "modelscope", + wantErr: false, + }, + { + name: "Get non-existent provider", + providerName: "civitai", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := registry.GetProviderByName(tt.providerName) + + if tt.wantErr { + if err == nil { + t.Errorf("GetProviderByName() expected error but got nil") + } + return + } + + if err != nil { + t.Errorf("GetProviderByName() unexpected error = %v", err) + return + } + + if provider.Name() != tt.providerName { + t.Errorf("GetProviderByName() provider name = %v, want %v", provider.Name(), tt.providerName) + } + }) + } +} + +func TestRegistry_ListProviders(t *testing.T) { + registry := NewRegistry() + providers := registry.ListProviders() + + if len(providers) != 2 { + t.Errorf("ListProviders() returned %d providers, want 2", len(providers)) + } + + expectedProviders := map[string]bool{ + "huggingface": false, + "modelscope": false, + } + + for _, name := range providers { + if _, ok := expectedProviders[name]; !ok { + t.Errorf("ListProviders() returned unexpected provider: %s", name) + } + expectedProviders[name] = true + } + + for name, found := range expectedProviders { + if !found { + t.Errorf("ListProviders() missing expected provider: %s", name) + } + } +} From f455aed313b3808da7dfed136b4b7ace1160f707 Mon Sep 17 00:00:00 2001 From: Avinash Singh Date: Wed, 26 Nov 2025 18:33:15 +0530 Subject: [PATCH 4/5] optimise code Signed-off-by: Avinash Singh --- cmd/modelfile/generate.go | 37 +++++--- pkg/config/modelfile/modelfile.go | 2 + pkg/modelprovider/huggingface/downloader.go | 54 ----------- .../huggingface/downloader_test.go | 4 +- pkg/modelprovider/huggingface/provider.go | 17 +--- pkg/modelprovider/modelscope/downloader.go | 58 ------------ pkg/modelprovider/modelscope/provider.go | 14 +-- pkg/modelprovider/provider.go | 3 +- pkg/modelprovider/registry.go | 54 ++++++++--- pkg/modelprovider/registry_test.go | 92 +++++++++++++++++-- 10 files changed, 163 insertions(+), 172 deletions(-) diff --git a/cmd/modelfile/generate.go b/cmd/modelfile/generate.go index 0c098808..0f5a8ec4 100644 --- a/cmd/modelfile/generate.go +++ b/cmd/modelfile/generate.go @@ -38,18 +38,24 @@ var generateCmd = &cobra.Command{ Long: `Generate a modelfile from either a local directory containing model files or by downloading a model from a supported provider. The workspace must be a directory including model files and model configuration files. -Alternatively, use --model_url to download a model from a supported provider (e.g., HuggingFace, ModelScope).`, +Alternatively, use --model-url to download a model from a supported provider (e.g., HuggingFace, ModelScope). + +For short-form URLs (owner/repo), you must explicitly specify the provider using --provider flag. +Full URLs with domain names will auto-detect the provider.`, Example: ` # Generate from local directory modctl modelfile generate ./my-model-dir - # Generate from Hugging Face model URL - modctl modelfile generate --model_url https://huggingface.co/meta-llama/Llama-2-7b-hf + # Generate from Hugging Face using full URL (auto-detects provider) + modctl modelfile generate --model-url https://huggingface.co/meta-llama/Llama-2-7b-hf + + # Generate from Hugging Face using short form (requires --provider) + modctl modelfile generate --model-url meta-llama/Llama-2-7b-hf --provider huggingface - # Generate from Hugging Face using short form - modctl modelfile generate --model_url meta-llama/Llama-2-7b-hf + # Generate from ModelScope using full URL (auto-detects provider) + modctl modelfile generate --model-url https://modelscope.cn/models/qwen/Qwen-7B - # Generate from ModelScope - modctl modelfile generate --model_url https://modelscope.cn/models/qwen/Qwen-7B + # Generate from ModelScope using short form (requires --provider) + modctl modelfile generate --model-url qwen/Qwen-7B --provider modelscope # Generate with custom output path modctl modelfile generate ./my-model-dir --output ./output/modelfile.yaml @@ -61,18 +67,18 @@ Alternatively, use --model_url to download a model from a supported provider (e. SilenceUsage: true, FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true}, RunE: func(cmd *cobra.Command, args []string) error { - // If model_url is provided, path is optional + // If model-url is provided, path is optional workspace := "." if len(args) > 0 { workspace = args[0] } - // Validate that either path or model_url is provided + // Validate that either path or model-url is provided if generateConfig.ModelURL != "" && len(args) > 0 { - return fmt.Errorf("the argument and the --model_url flag are mutually exclusive") + return fmt.Errorf("the argument and the --model-url flag are mutually exclusive") } if generateConfig.ModelURL == "" && len(args) == 0 { - return fmt.Errorf("either a argument or the --model_url flag must be provided") + return fmt.Errorf("either a argument or the --model-url flag must be provided") } if err := generateConfig.Convert(workspace); err != nil { @@ -100,7 +106,8 @@ func init() { flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory") flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace") flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile") - flags.StringVar(&generateConfig.ModelURL, "model_url", "", "download model from a supported provider (HuggingFace: owner/repo or full URL, ModelScope: full URL)") + flags.StringVar(&generateConfig.ModelURL, "model-url", "", "download model from a supported provider (full URL or short-form with --provider)") + flags.StringVarP(&generateConfig.Provider, "provider", "p", "", "explicitly specify the provider for short-form URLs (huggingface, modelscope)") // Mark the ignore-unrecognized-file-types flag as deprecated and hidden flags.MarkDeprecated("ignore-unrecognized-file-types", "this flag will be removed in the next release") @@ -118,10 +125,10 @@ func runGenerate(ctx context.Context) error { fmt.Printf("Model URL provided: %s\n", generateConfig.ModelURL) // Get the appropriate provider for this URL - registry := modelprovider.NewRegistry() - provider, err := registry.GetProvider(generateConfig.ModelURL) + registry := modelprovider.GetRegistry() + provider, err := registry.SelectProvider(generateConfig.ModelURL, generateConfig.Provider) if err != nil { - return fmt.Errorf("unsupported model URL: %w", err) + return fmt.Errorf("failed to select provider: %w", err) } fmt.Printf("Using provider: %s\n", provider.Name()) diff --git a/pkg/config/modelfile/modelfile.go b/pkg/config/modelfile/modelfile.go index 5723193a..ac785b97 100644 --- a/pkg/config/modelfile/modelfile.go +++ b/pkg/config/modelfile/modelfile.go @@ -40,6 +40,7 @@ type GenerateConfig struct { Precision string Quantization string ModelURL string + Provider string // Explicit provider for short-form URLs (e.g., "huggingface", "modelscope") } func NewGenerateConfig() *GenerateConfig { @@ -57,6 +58,7 @@ func NewGenerateConfig() *GenerateConfig { Precision: "", Quantization: "", ModelURL: "", + Provider: "", } } diff --git a/pkg/modelprovider/huggingface/downloader.go b/pkg/modelprovider/huggingface/downloader.go index 910418a0..3f5616e2 100644 --- a/pkg/modelprovider/huggingface/downloader.go +++ b/pkg/modelprovider/huggingface/downloader.go @@ -17,10 +17,8 @@ package huggingface import ( - "context" "fmt" "io" - "net/http" "net/url" "os" "os/exec" @@ -127,55 +125,3 @@ func getToken() (string, error) { return strings.TrimSpace(string(data)), nil } - -// downloadFile downloads a single file from HuggingFace -func downloadFile(ctx context.Context, owner, repo, filename, destPath string) error { - token, err := getToken() - if err != nil { - return fmt.Errorf("failed to get HuggingFace token: %w", err) - } - - // Construct the download URL - // Format: https://huggingface.co/{owner}/{repo}/resolve/main/{filename} - fileURL := fmt.Sprintf("%s/%s/%s/resolve/main/%s", huggingFaceBaseURL, owner, repo, filename) - - req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - // Add authorization header - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to download file: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to download file, status code: %d", resp.StatusCode) - } - - // Create destination directory - destDir := filepath.Dir(destPath) - if err := os.MkdirAll(destDir, 0755); err != nil { - return fmt.Errorf("failed to create destination directory: %w", err) - } - - // Create the destination file - outFile, err := os.Create(destPath) - if err != nil { - return fmt.Errorf("failed to create destination file: %w", err) - } - defer outFile.Close() - - // Copy the content - _, err = io.Copy(outFile, resp.Body) - if err != nil { - return fmt.Errorf("failed to write file: %w", err) - } - - return nil -} diff --git a/pkg/modelprovider/huggingface/downloader_test.go b/pkg/modelprovider/huggingface/downloader_test.go index d4bb4b04..f7c385c2 100644 --- a/pkg/modelprovider/huggingface/downloader_test.go +++ b/pkg/modelprovider/huggingface/downloader_test.go @@ -130,9 +130,9 @@ func TestProvider_SupportsURL(t *testing.T) { want: true, }, { - name: "short form repo", + name: "short form repo (requires explicit --provider)", url: "meta-llama/Llama-2-7b-hf", - want: true, + want: false, }, { name: "ModelScope URL", diff --git a/pkg/modelprovider/huggingface/provider.go b/pkg/modelprovider/huggingface/provider.go index fb84f352..a4cc7698 100644 --- a/pkg/modelprovider/huggingface/provider.go +++ b/pkg/modelprovider/huggingface/provider.go @@ -39,22 +39,13 @@ func (p *Provider) Name() string { } // SupportsURL checks if this provider can handle the given URL -// It supports both full HuggingFace URLs and short-form repo identifiers +// It only supports full HuggingFace URLs with the huggingface.co domain +// For short-form repo identifiers (owner/repo), users must explicitly specify --provider huggingface func (p *Provider) SupportsURL(url string) bool { url = strings.TrimSpace(url) - // Check for full HuggingFace URLs - if strings.Contains(url, "huggingface.co") { - return true - } - - // Check for short-form repo identifiers (owner/repo) - // Must have exactly one slash and no protocol - if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { - return strings.Count(url, "/") == 1 - } - - return false + // Only support full HuggingFace URLs + return strings.Contains(url, "huggingface.co") } // DownloadModel downloads a model from HuggingFace using the huggingface-cli diff --git a/pkg/modelprovider/modelscope/downloader.go b/pkg/modelprovider/modelscope/downloader.go index eb5c489b..b7a85f48 100644 --- a/pkg/modelprovider/modelscope/downloader.go +++ b/pkg/modelprovider/modelscope/downloader.go @@ -17,10 +17,7 @@ package modelscope import ( - "context" "fmt" - "io" - "net/http" "net/url" "os" "os/exec" @@ -138,58 +135,3 @@ func getToken() (string, error) { return strings.TrimSpace(string(data)), nil } - -// downloadFile downloads a single file from ModelScope -func downloadFile(ctx context.Context, owner, repo, filename, destPath string) error { - token, err := getToken() - if err != nil { - // Token is optional for public models, continue without it - token = "" - } - - // Construct the download URL - // Format: https://modelscope.cn/api/v1/models/{owner}/{repo}/repo?Revision=master&FilePath={filename} - fileURL := fmt.Sprintf("%s/api/v1/models/%s/%s/repo?Revision=master&FilePath=%s", modelScopeBaseURL, owner, repo, filename) - - req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - // Add authorization header if token is available - if token != "" { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - } - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to download file: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to download file, status code: %d", resp.StatusCode) - } - - // Create destination directory - destDir := filepath.Dir(destPath) - if err := os.MkdirAll(destDir, 0755); err != nil { - return fmt.Errorf("failed to create destination directory: %w", err) - } - - // Create the destination file - outFile, err := os.Create(destPath) - if err != nil { - return fmt.Errorf("failed to create destination file: %w", err) - } - defer outFile.Close() - - // Copy the content - _, err = io.Copy(outFile, resp.Body) - if err != nil { - return fmt.Errorf("failed to write file: %w", err) - } - - return nil -} diff --git a/pkg/modelprovider/modelscope/provider.go b/pkg/modelprovider/modelscope/provider.go index 284d33a5..8b16912d 100644 --- a/pkg/modelprovider/modelscope/provider.go +++ b/pkg/modelprovider/modelscope/provider.go @@ -39,19 +39,13 @@ func (p *Provider) Name() string { } // SupportsURL checks if this provider can handle the given URL -// It supports both full ModelScope URLs and short-form repo identifiers +// It only supports full ModelScope URLs with the modelscope.cn domain +// For short-form repo identifiers (owner/repo), users must explicitly specify --provider modelscope func (p *Provider) SupportsURL(url string) bool { url = strings.TrimSpace(url) - // Check for full ModelScope URLs - if strings.Contains(url, "modelscope.cn") { - return true - } - - // Note: We don't auto-detect short-form for ModelScope to avoid conflicts with HuggingFace - // Users should use full URLs or explicitly specify the provider - - return false + // Only support full ModelScope URLs + return strings.Contains(url, "modelscope.cn") } // DownloadModel downloads a model from ModelScope using the modelscope CLI diff --git a/pkg/modelprovider/provider.go b/pkg/modelprovider/provider.go index 331628f4..974dbfd9 100644 --- a/pkg/modelprovider/provider.go +++ b/pkg/modelprovider/provider.go @@ -27,7 +27,8 @@ type Provider interface { Name() string // SupportsURL checks if this provider can handle the given model URL - // This enables automatic provider detection based on URL patterns + // This enables automatic provider detection based on full URL patterns (with domain) + // Short-form URLs (owner/repo) require explicit provider specification via GetProviderByName SupportsURL(url string) bool // DownloadModel downloads a model from the provider and returns the local path diff --git a/pkg/modelprovider/registry.go b/pkg/modelprovider/registry.go index 8708a62c..55de2b69 100644 --- a/pkg/modelprovider/registry.go +++ b/pkg/modelprovider/registry.go @@ -18,6 +18,7 @@ package modelprovider import ( "fmt" + "sync" "github.com/modelpack/modctl/pkg/modelprovider/huggingface" "github.com/modelpack/modctl/pkg/modelprovider/modelscope" @@ -29,28 +30,57 @@ type Registry struct { providers []Provider } -// NewRegistry creates a new provider registry with all available providers -func NewRegistry() *Registry { - return &Registry{ - providers: []Provider{ - huggingface.New(), - modelscope.New(), - // Future providers can be added here: - // civitai.New(), - }, - } +var ( + instance *Registry + once sync.Once +) + +// GetRegistry returns the singleton instance of the registry +// This is thread-safe and will only create the instance once +func GetRegistry() *Registry { + once.Do(func() { + instance = &Registry{ + providers: []Provider{ + huggingface.New(), + modelscope.New(), + // Future providers can be added here: + // civitai.New(), + }, + } + }) + return instance +} + +// ResetRegistry resets the singleton instance +// This should only be used in tests to ensure isolation between test cases +func ResetRegistry() { + once = sync.Once{} + instance = nil } // GetProvider returns the appropriate provider for the given model URL // It iterates through all registered providers and returns the first one -// that supports the URL +// that supports the URL. This only works for full URLs with domain names. +// For short-form URLs (owner/repo), use GetProviderByName with an explicit provider func (r *Registry) GetProvider(modelURL string) (Provider, error) { for _, p := range r.providers { if p.SupportsURL(modelURL) { return p, nil } } - return nil, fmt.Errorf("no provider found for URL: %s", modelURL) + return nil, fmt.Errorf("no provider found for URL: %s. For short-form URLs (owner/repo), use --provider flag to specify the provider explicitly", modelURL) +} + +// SelectProvider returns the appropriate provider based on the URL and explicit provider name +// If providerName is specified, it uses GetProviderByName for short-form URLs +// Otherwise, it uses GetProvider for auto-detection with full URLs +func (r *Registry) SelectProvider(modelURL, providerName string) (Provider, error) { + if providerName != "" { + // Explicit provider specified, use it + return r.GetProviderByName(providerName) + } + // No explicit provider, try auto-detection + return r.GetProvider(modelURL) } // GetProviderByName returns a specific provider by its name diff --git a/pkg/modelprovider/registry_test.go b/pkg/modelprovider/registry_test.go index f2e9a95a..23d5ad46 100644 --- a/pkg/modelprovider/registry_test.go +++ b/pkg/modelprovider/registry_test.go @@ -21,7 +21,8 @@ import ( ) func TestRegistry_GetProvider(t *testing.T) { - registry := NewRegistry() + ResetRegistry() // Ensure clean state for test + registry := GetRegistry() tests := []struct { name string @@ -36,10 +37,9 @@ func TestRegistry_GetProvider(t *testing.T) { wantErr: false, }, { - name: "HuggingFace short form", - modelURL: "meta-llama/Llama-2-7b-hf", - wantProvider: "huggingface", - wantErr: false, + name: "HuggingFace short form (requires explicit provider)", + modelURL: "meta-llama/Llama-2-7b-hf", + wantErr: true, }, { name: "ModelScope full URL", @@ -89,7 +89,8 @@ func TestRegistry_GetProvider(t *testing.T) { } func TestRegistry_GetProviderByName(t *testing.T) { - registry := NewRegistry() + ResetRegistry() // Ensure clean state for test + registry := GetRegistry() tests := []struct { name string @@ -137,7 +138,8 @@ func TestRegistry_GetProviderByName(t *testing.T) { } func TestRegistry_ListProviders(t *testing.T) { - registry := NewRegistry() + ResetRegistry() // Ensure clean state for test + registry := GetRegistry() providers := registry.ListProviders() if len(providers) != 2 { @@ -162,3 +164,79 @@ func TestRegistry_ListProviders(t *testing.T) { } } } + +func TestRegistry_SelectProvider(t *testing.T) { + ResetRegistry() // Ensure clean state for test + registry := GetRegistry() + + tests := []struct { + name string + modelURL string + providerName string + wantProvider string + wantErr bool + }{ + { + name: "Full URL with auto-detection (HuggingFace)", + modelURL: "https://huggingface.co/meta-llama/Llama-2-7b-hf", + providerName: "", + wantProvider: "huggingface", + wantErr: false, + }, + { + name: "Full URL with auto-detection (ModelScope)", + modelURL: "https://modelscope.cn/models/qwen/Qwen-7B", + providerName: "", + wantProvider: "modelscope", + wantErr: false, + }, + { + name: "Short-form with explicit provider (HuggingFace)", + modelURL: "meta-llama/Llama-2-7b-hf", + providerName: "huggingface", + wantProvider: "huggingface", + wantErr: false, + }, + { + name: "Short-form with explicit provider (ModelScope)", + modelURL: "qwen/Qwen-7B", + providerName: "modelscope", + wantProvider: "modelscope", + wantErr: false, + }, + { + name: "Short-form without explicit provider (should fail)", + modelURL: "owner/repo", + providerName: "", + wantErr: true, + }, + { + name: "Invalid provider name", + modelURL: "owner/repo", + providerName: "invalid-provider", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := registry.SelectProvider(tt.modelURL, tt.providerName) + + if tt.wantErr { + if err == nil { + t.Errorf("SelectProvider() expected error but got nil") + } + return + } + + if err != nil { + t.Errorf("SelectProvider() unexpected error = %v", err) + return + } + + if provider.Name() != tt.wantProvider { + t.Errorf("SelectProvider() provider name = %v, want %v", provider.Name(), tt.wantProvider) + } + }) + } +} From 76743d410b4b6afa9798507812d22b3d0ac12d3c Mon Sep 17 00:00:00 2001 From: Avinash Singh Date: Mon, 15 Dec 2025 19:10:58 +0530 Subject: [PATCH 5/5] add optional param for download-dir Signed-off-by: Avinash Singh --- cmd/modelfile/generate.go | 39 +++++++++++++++++++++++++------ pkg/config/modelfile/modelfile.go | 2 ++ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/cmd/modelfile/generate.go b/cmd/modelfile/generate.go index 85699dae..4561b2b3 100644 --- a/cmd/modelfile/generate.go +++ b/cmd/modelfile/generate.go @@ -57,6 +57,9 @@ Full URLs with domain names will auto-detect the provider.`, # Generate from ModelScope using short form (requires --provider) modctl modelfile generate --model-url qwen/Qwen-7B --provider modelscope + # Generate with custom download directory + modctl modelfile generate --model-url meta-llama/Llama-2-7b-hf --provider huggingface --download-dir $HOME/models + # Generate with custom output path modctl modelfile generate ./my-model-dir --output ./output/modelfile.yaml @@ -108,6 +111,7 @@ func init() { flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile") flags.StringVar(&generateConfig.ModelURL, "model-url", "", "download model from a supported provider (full URL or short-form with --provider)") flags.StringVarP(&generateConfig.Provider, "provider", "p", "", "explicitly specify the provider for short-form URLs (huggingface, modelscope)") + flags.StringVar(&generateConfig.DownloadDir, "download-dir", "", "custom directory for downloading models (default: system temp directory)") flags.StringArrayVar(&generateConfig.ExcludePatterns, "exclude", []string{}, "specify glob patterns to exclude files/directories (e.g. *.log, checkpoints/*)") // Mark the ignore-unrecognized-file-types flag as deprecated and hidden @@ -139,16 +143,37 @@ func runGenerate(ctx context.Context) error { return fmt.Errorf("%s authentication check failed: %w", provider.Name(), err) } - // Create a temporary directory for downloading the model - // Clean up the temporary directory after the function returns - tmpDir, err := os.MkdirTemp("", "modctl-model-downloads-*") - if err != nil { - return fmt.Errorf("failed to create temporary directory: %w", err) + // Determine the download directory + var downloadDir string + var cleanupDir bool + + if generateConfig.DownloadDir != "" { + // Use user-specified directory + downloadDir = generateConfig.DownloadDir + cleanupDir = false + + // Create the directory if it doesn't exist + if err := os.MkdirAll(downloadDir, 0755); err != nil { + return fmt.Errorf("failed to create download directory: %w", err) + } + fmt.Printf("Using custom download directory: %s\n", downloadDir) + } else { + // Create a temporary directory for downloading the model + tmpDir, err := os.MkdirTemp("", "modctl-model-downloads-*") + if err != nil { + return fmt.Errorf("failed to create temporary directory: %w", err) + } + downloadDir = tmpDir + cleanupDir = true + } + + // Clean up the directory only if it was a temporary directory + if cleanupDir { + defer os.RemoveAll(downloadDir) } - defer os.RemoveAll(tmpDir) // Download the model - downloadPath, err := provider.DownloadModel(ctx, generateConfig.ModelURL, tmpDir) + downloadPath, err := provider.DownloadModel(ctx, generateConfig.ModelURL, downloadDir) if err != nil { return fmt.Errorf("failed to download model from %s: %w", provider.Name(), err) } diff --git a/pkg/config/modelfile/modelfile.go b/pkg/config/modelfile/modelfile.go index cf3cc652..a51789f1 100644 --- a/pkg/config/modelfile/modelfile.go +++ b/pkg/config/modelfile/modelfile.go @@ -41,6 +41,7 @@ type GenerateConfig struct { Quantization string ModelURL string Provider string // Explicit provider for short-form URLs (e.g., "huggingface", "modelscope") + DownloadDir string // Custom directory for downloading models (optional) ExcludePatterns []string } @@ -60,6 +61,7 @@ func NewGenerateConfig() *GenerateConfig { Quantization: "", ModelURL: "", Provider: "", + DownloadDir: "", ExcludePatterns: []string{}, } }