diff --git a/cmd/modelfile/generate.go b/cmd/modelfile/generate.go index 2ae2ad5f..4561b2b3 100644 --- a/cmd/modelfile/generate.go +++ b/cmd/modelfile/generate.go @@ -26,20 +26,65 @@ import ( configmodelfile "github.com/modelpack/modctl/pkg/config/modelfile" "github.com/modelpack/modctl/pkg/modelfile" + "github.com/modelpack/modctl/pkg/modelprovider" ) var generateConfig = configmodelfile.NewGenerateConfig() // generateCmd represents the modelfile tools command for generating modelfile. var generateCmd = &cobra.Command{ - Use: "generate [flags] ", - Short: "A command line tool for generating modelfile in the workspace, the workspace must be a directory including model files and model configuration files", - Args: cobra.ExactArgs(1), + Use: "generate [flags] []", + Short: "Generate a modelfile from a local workspace or remote model provider", + Long: `Generate a modelfile from either a local directory containing model files or by downloading a model from a supported provider. + +The workspace must be a directory including model files and model configuration files. +Alternatively, use --model-url to download a model from a supported provider (e.g., HuggingFace, ModelScope). + +For short-form URLs (owner/repo), you must explicitly specify the provider using --provider flag. +Full URLs with domain names will auto-detect the provider.`, + Example: ` # Generate from local directory + modctl modelfile generate ./my-model-dir + + # Generate from Hugging Face using full URL (auto-detects provider) + modctl modelfile generate --model-url https://huggingface.co/meta-llama/Llama-2-7b-hf + + # Generate from Hugging Face using short form (requires --provider) + modctl modelfile generate --model-url meta-llama/Llama-2-7b-hf --provider huggingface + + # Generate from ModelScope using full URL (auto-detects provider) + modctl modelfile generate --model-url https://modelscope.cn/models/qwen/Qwen-7B + + # Generate from ModelScope using short form (requires --provider) + modctl modelfile generate --model-url qwen/Qwen-7B --provider modelscope + + # Generate with custom download directory + modctl modelfile generate --model-url meta-llama/Llama-2-7b-hf --provider huggingface --download-dir $HOME/models + + # Generate with custom output path + modctl modelfile generate ./my-model-dir --output ./output/modelfile.yaml + + # Generate with metadata overrides + modctl modelfile generate ./my-model-dir --name my-custom-model --family llama3`, + Args: cobra.MaximumNArgs(1), DisableAutoGenTag: true, SilenceUsage: true, FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true}, RunE: func(cmd *cobra.Command, args []string) error { - if err := generateConfig.Convert(args[0]); err != nil { + // If model-url is provided, path is optional + workspace := "." + if len(args) > 0 { + workspace = args[0] + } + + // Validate that either path or model-url is provided + if generateConfig.ModelURL != "" && len(args) > 0 { + return fmt.Errorf("the argument and the --model-url flag are mutually exclusive") + } + if generateConfig.ModelURL == "" && len(args) == 0 { + return fmt.Errorf("either a argument or the --model-url flag must be provided") + } + + if err := generateConfig.Convert(workspace); err != nil { return err } @@ -64,6 +109,9 @@ func init() { flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory") flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace") flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile") + flags.StringVar(&generateConfig.ModelURL, "model-url", "", "download model from a supported provider (full URL or short-form with --provider)") + flags.StringVarP(&generateConfig.Provider, "provider", "p", "", "explicitly specify the provider for short-form URLs (huggingface, modelscope)") + flags.StringVar(&generateConfig.DownloadDir, "download-dir", "", "custom directory for downloading models (default: system temp directory)") flags.StringArrayVar(&generateConfig.ExcludePatterns, "exclude", []string{}, "specify glob patterns to exclude files/directories (e.g. *.log, checkpoints/*)") // Mark the ignore-unrecognized-file-types flag as deprecated and hidden @@ -76,7 +124,65 @@ func init() { } // runGenerate runs the generate modelfile. -func runGenerate(_ context.Context) error { +func runGenerate(ctx context.Context) error { + // If model URL is provided, download the model first + if generateConfig.ModelURL != "" { + fmt.Printf("Model URL provided: %s\n", generateConfig.ModelURL) + + // Get the appropriate provider for this URL + registry := modelprovider.GetRegistry() + provider, err := registry.SelectProvider(generateConfig.ModelURL, generateConfig.Provider) + if err != nil { + return fmt.Errorf("failed to select provider: %w", err) + } + + fmt.Printf("Using provider: %s\n", provider.Name()) + + // Check if user is authenticated with the provider + if err := provider.CheckAuth(); err != nil { + return fmt.Errorf("%s authentication check failed: %w", provider.Name(), err) + } + + // Determine the download directory + var downloadDir string + var cleanupDir bool + + if generateConfig.DownloadDir != "" { + // Use user-specified directory + downloadDir = generateConfig.DownloadDir + cleanupDir = false + + // Create the directory if it doesn't exist + if err := os.MkdirAll(downloadDir, 0755); err != nil { + return fmt.Errorf("failed to create download directory: %w", err) + } + fmt.Printf("Using custom download directory: %s\n", downloadDir) + } else { + // Create a temporary directory for downloading the model + tmpDir, err := os.MkdirTemp("", "modctl-model-downloads-*") + if err != nil { + return fmt.Errorf("failed to create temporary directory: %w", err) + } + downloadDir = tmpDir + cleanupDir = true + } + + // Clean up the directory only if it was a temporary directory + if cleanupDir { + defer os.RemoveAll(downloadDir) + } + + // Download the model + downloadPath, err := provider.DownloadModel(ctx, generateConfig.ModelURL, downloadDir) + if err != nil { + return fmt.Errorf("failed to download model from %s: %w", provider.Name(), err) + } + + // Update workspace to the downloaded model path + generateConfig.Workspace = downloadPath + fmt.Printf("Using downloaded model at: %s\n", downloadPath) + } + fmt.Printf("Generating modelfile for %s\n", generateConfig.Workspace) modelfile, err := modelfile.NewModelfileByWorkspace(generateConfig.Workspace, generateConfig) if err != nil { diff --git a/pkg/config/modelfile/modelfile.go b/pkg/config/modelfile/modelfile.go index 2b0fdcb5..a51789f1 100644 --- a/pkg/config/modelfile/modelfile.go +++ b/pkg/config/modelfile/modelfile.go @@ -39,6 +39,9 @@ type GenerateConfig struct { ParamSize string Precision string Quantization string + ModelURL string + Provider string // Explicit provider for short-form URLs (e.g., "huggingface", "modelscope") + DownloadDir string // Custom directory for downloading models (optional) ExcludePatterns []string } @@ -56,6 +59,9 @@ func NewGenerateConfig() *GenerateConfig { ParamSize: "", Precision: "", Quantization: "", + ModelURL: "", + Provider: "", + DownloadDir: "", ExcludePatterns: []string{}, } } diff --git a/pkg/modelprovider/huggingface/downloader.go b/pkg/modelprovider/huggingface/downloader.go new file mode 100644 index 00000000..3f5616e2 --- /dev/null +++ b/pkg/modelprovider/huggingface/downloader.go @@ -0,0 +1,127 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package huggingface + +import ( + "fmt" + "io" + "net/url" + "os" + "os/exec" + "path/filepath" + "strings" +) + +const ( + huggingFaceBaseURL = "https://huggingface.co" +) + +// parseModelURL parses a HuggingFace model URL and extracts the owner and repository name +func parseModelURL(modelURL string) (owner, repo string, err error) { + // Handle both full URLs and short-form repo names + modelURL = strings.TrimSpace(modelURL) + + // Remove trailing slashes + modelURL = strings.TrimSuffix(modelURL, "/") + + // If it's a full URL, parse it + if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") { + u, err := url.Parse(modelURL) + if err != nil { + return "", "", fmt.Errorf("invalid URL: %w", err) + } + + // Expected format: https://huggingface.co/owner/repo + parts := strings.Split(strings.Trim(u.Path, "/"), "/") + if len(parts) < 2 { + return "", "", fmt.Errorf("invalid HuggingFace URL format, expected https://huggingface.co/owner/repo") + } + + owner = parts[0] + repo = parts[1] + } else { + // Handle short-form like "owner/repo" + parts := strings.Split(modelURL, "/") + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid model identifier, expected format: owner/repo") + } + + owner = parts[0] + repo = parts[1] + } + + if owner == "" || repo == "" { + return "", "", fmt.Errorf("owner and repository name cannot be empty") + } + + return owner, repo, nil +} + +// checkHuggingFaceAuth checks if the user is authenticated with HuggingFace +func checkHuggingFaceAuth() error { + // Try to find the HF token + token := os.Getenv("HF_TOKEN") + if token != "" { + return nil + } + + // Check if the token file exists + homeDir, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("failed to get user home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, ".huggingface", "token") + if _, err := os.Stat(tokenPath); err == nil { + return nil + } + + // Try using whoami command + if _, err := exec.LookPath("huggingface-cli"); err == nil { + cmd := exec.Command("huggingface-cli", "whoami") + cmd.Stdout = io.Discard + cmd.Stderr = io.Discard + if err := cmd.Run(); err == nil { + return nil + } + } + + return fmt.Errorf("not authenticated with HuggingFace. Please run: huggingface-cli login") +} + +// getToken retrieves the HuggingFace token from environment or token file +func getToken() (string, error) { + // First check environment variable + token := os.Getenv("HF_TOKEN") + if token != "" { + return token, nil + } + + // Then check the token file + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, ".huggingface", "token") + data, err := os.ReadFile(tokenPath) + if err != nil { + return "", fmt.Errorf("failed to read token file: %w", err) + } + + return strings.TrimSpace(string(data)), nil +} diff --git a/pkg/modelprovider/huggingface/downloader_test.go b/pkg/modelprovider/huggingface/downloader_test.go new file mode 100644 index 00000000..f7c385c2 --- /dev/null +++ b/pkg/modelprovider/huggingface/downloader_test.go @@ -0,0 +1,168 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package huggingface + +import ( + "strings" + "testing" +) + +func TestParseModelURL(t *testing.T) { + tests := []struct { + name string + modelURL string + wantOwner string + wantRepo string + wantErr bool + errContains string + }{ + { + name: "full URL", + modelURL: "https://huggingface.co/meta-llama/Llama-2-7b-hf", + wantOwner: "meta-llama", + wantRepo: "Llama-2-7b-hf", + wantErr: false, + }, + { + name: "full URL with trailing slash", + modelURL: "https://huggingface.co/meta-llama/Llama-2-7b-hf/", + wantOwner: "meta-llama", + wantRepo: "Llama-2-7b-hf", + wantErr: false, + }, + { + name: "short form", + modelURL: "meta-llama/Llama-2-7b-hf", + wantOwner: "meta-llama", + wantRepo: "Llama-2-7b-hf", + wantErr: false, + }, + { + name: "http URL", + modelURL: "http://huggingface.co/openai/gpt-2", + wantOwner: "openai", + wantRepo: "gpt-2", + wantErr: false, + }, + { + name: "invalid format - missing repo", + modelURL: "https://huggingface.co/meta-llama", + wantErr: true, + errContains: "invalid HuggingFace URL format", + }, + { + name: "invalid format - only owner", + modelURL: "meta-llama", + wantErr: true, + errContains: "invalid model identifier", + }, + { + name: "empty URL", + modelURL: "", + wantErr: true, + errContains: "invalid model identifier", + }, + { + name: "URL with spaces (trimmed)", + modelURL: " meta-llama/Llama-2-7b-hf ", + wantOwner: "meta-llama", + wantRepo: "Llama-2-7b-hf", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + owner, repo, err := parseModelURL(tt.modelURL) + + if tt.wantErr { + if err == nil { + t.Errorf("parseModelURL() expected error but got nil") + return + } + if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("parseModelURL() error = %v, want error containing %v", err, tt.errContains) + } + return + } + + if err != nil { + t.Errorf("parseModelURL() unexpected error = %v", err) + return + } + + if owner != tt.wantOwner { + t.Errorf("parseModelURL() owner = %v, want %v", owner, tt.wantOwner) + } + + if repo != tt.wantRepo { + t.Errorf("parseModelURL() repo = %v, want %v", repo, tt.wantRepo) + } + }) + } +} + +func TestProvider_SupportsURL(t *testing.T) { + provider := New() + + tests := []struct { + name string + url string + want bool + }{ + { + name: "full HuggingFace URL", + url: "https://huggingface.co/meta-llama/Llama-2-7b-hf", + want: true, + }, + { + name: "short form repo (requires explicit --provider)", + url: "meta-llama/Llama-2-7b-hf", + want: false, + }, + { + name: "ModelScope URL", + url: "https://modelscope.cn/models/owner/repo", + want: false, + }, + { + name: "invalid format", + url: "just-a-string", + want: false, + }, + { + name: "HTTP URL", + url: "http://example.com/owner/repo", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := provider.SupportsURL(tt.url); got != tt.want { + t.Errorf("Provider.SupportsURL() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestProvider_Name(t *testing.T) { + provider := New() + if got := provider.Name(); got != "huggingface" { + t.Errorf("Provider.Name() = %v, want %v", got, "huggingface") + } +} diff --git a/pkg/modelprovider/huggingface/provider.go b/pkg/modelprovider/huggingface/provider.go new file mode 100644 index 00000000..a4cc7698 --- /dev/null +++ b/pkg/modelprovider/huggingface/provider.go @@ -0,0 +1,90 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package huggingface + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +// Provider implements the modelprovider.Provider interface for HuggingFace +type Provider struct{} + +// New creates a new HuggingFace provider instance +func New() *Provider { + return &Provider{} +} + +// Name returns the name of this provider +func (p *Provider) Name() string { + return "huggingface" +} + +// SupportsURL checks if this provider can handle the given URL +// It only supports full HuggingFace URLs with the huggingface.co domain +// For short-form repo identifiers (owner/repo), users must explicitly specify --provider huggingface +func (p *Provider) SupportsURL(url string) bool { + url = strings.TrimSpace(url) + + // Only support full HuggingFace URLs + return strings.Contains(url, "huggingface.co") +} + +// DownloadModel downloads a model from HuggingFace using the huggingface-cli +func (p *Provider) DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) { + owner, repo, err := parseModelURL(modelURL) + if err != nil { + return "", err + } + + repoID := fmt.Sprintf("%s/%s", owner, repo) + + // Check if huggingface-cli is available + if _, err := exec.LookPath("huggingface-cli"); err != nil { + return "", fmt.Errorf("huggingface-cli not found in PATH. Please install it using: pip install huggingface_hub[cli]") + } + + // Create destination directory if it doesn't exist + if err := os.MkdirAll(destDir, 0755); err != nil { + return "", fmt.Errorf("failed to create destination directory: %w", err) + } + + // Construct the download path + downloadPath := filepath.Join(destDir, repo) + + // Use huggingface-cli to download the model + // The --local-dir-use-symlinks=False flag ensures files are copied, not symlinked + cmd := exec.CommandContext(ctx, "huggingface-cli", "download", repoID, "--local-dir", downloadPath, "--local-dir-use-symlinks", "False") + + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("failed to download model using huggingface-cli: %w", err) + } + + return downloadPath, nil +} + +// CheckAuth verifies that the user is authenticated with HuggingFace +func (p *Provider) CheckAuth() error { + return checkHuggingFaceAuth() +} diff --git a/pkg/modelprovider/modelscope/downloader.go b/pkg/modelprovider/modelscope/downloader.go new file mode 100644 index 00000000..b7a85f48 --- /dev/null +++ b/pkg/modelprovider/modelscope/downloader.go @@ -0,0 +1,137 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelscope + +import ( + "fmt" + "net/url" + "os" + "os/exec" + "path/filepath" + "strings" +) + +const ( + modelScopeBaseURL = "https://modelscope.cn" +) + +// parseModelURL parses a ModelScope model URL and extracts the owner and repository name +func parseModelURL(modelURL string) (owner, repo string, err error) { + // Handle both full URLs and short-form repo names + modelURL = strings.TrimSpace(modelURL) + + // Remove trailing slashes + modelURL = strings.TrimSuffix(modelURL, "/") + + // If it's a full URL, parse it + if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") { + u, err := url.Parse(modelURL) + if err != nil { + return "", "", fmt.Errorf("invalid URL: %w", err) + } + + // Expected format: https://modelscope.cn/models/owner/repo + parts := strings.Split(strings.Trim(u.Path, "/"), "/") + + // Handle both formats: + // 1. https://modelscope.cn/models/owner/repo + // 2. https://modelscope.cn/owner/repo + if len(parts) >= 1 && parts[0] == "models" { + // Must have exactly 3 parts: models/owner/repo + if len(parts) < 3 { + return "", "", fmt.Errorf("invalid ModelScope URL format, expected https://modelscope.cn/models/owner/repo") + } + owner = parts[1] + repo = parts[2] + } else if len(parts) >= 2 { + // Direct format: owner/repo + owner = parts[0] + repo = parts[1] + } else { + return "", "", fmt.Errorf("invalid ModelScope URL format, expected https://modelscope.cn/models/owner/repo") + } + } else { + // Handle short-form like "owner/repo" + parts := strings.Split(modelURL, "/") + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid model identifier, expected format: owner/repo") + } + + owner = parts[0] + repo = parts[1] + } + + if owner == "" || repo == "" { + return "", "", fmt.Errorf("owner and repository name cannot be empty") + } + + return owner, repo, nil +} + +// checkModelScopeAuth checks if the user is authenticated with ModelScope +func checkModelScopeAuth() error { + // Try to find the ModelScope SDK token + token := os.Getenv("MODELSCOPE_SDK_TOKEN") + if token != "" { + return nil + } + + // Check if the token file exists in the default location + homeDir, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("failed to get user home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, ".modelscope", "token") + if _, err := os.Stat(tokenPath); err == nil { + return nil + } + + // Try using modelscope CLI to check auth + if _, err := exec.LookPath("modelscope"); err == nil { + // ModelScope CLI doesn't have a direct whoami equivalent, but we can check if login exists + // For now, we'll allow the download to proceed and let it fail if auth is needed + return nil + } + + // Warning: ModelScope authentication is optional for public models + // We'll return nil here and let the download command handle auth errors + return nil +} + +// getToken retrieves the ModelScope token from environment or token file +func getToken() (string, error) { + // First check environment variable + token := os.Getenv("MODELSCOPE_SDK_TOKEN") + if token != "" { + return token, nil + } + + // Then check the token file + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, ".modelscope", "token") + data, err := os.ReadFile(tokenPath) + if err != nil { + return "", fmt.Errorf("failed to read token file: %w", err) + } + + return strings.TrimSpace(string(data)), nil +} diff --git a/pkg/modelprovider/modelscope/downloader_test.go b/pkg/modelprovider/modelscope/downloader_test.go new file mode 100644 index 00000000..3e70e023 --- /dev/null +++ b/pkg/modelprovider/modelscope/downloader_test.go @@ -0,0 +1,180 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelscope + +import ( + "strings" + "testing" +) + +func TestParseModelURL(t *testing.T) { + tests := []struct { + name string + modelURL string + wantOwner string + wantRepo string + wantErr bool + errContains string + }{ + { + name: "full URL with models prefix", + modelURL: "https://modelscope.cn/models/qwen/Qwen-7B", + wantOwner: "qwen", + wantRepo: "Qwen-7B", + wantErr: false, + }, + { + name: "full URL without models prefix", + modelURL: "https://modelscope.cn/damo/nlp_structbert_backbone_base_std", + wantOwner: "damo", + wantRepo: "nlp_structbert_backbone_base_std", + wantErr: false, + }, + { + name: "full URL with trailing slash", + modelURL: "https://modelscope.cn/models/qwen/Qwen-7B/", + wantOwner: "qwen", + wantRepo: "Qwen-7B", + wantErr: false, + }, + { + name: "short form", + modelURL: "qwen/Qwen-7B", + wantOwner: "qwen", + wantRepo: "Qwen-7B", + wantErr: false, + }, + { + name: "http URL", + modelURL: "http://modelscope.cn/models/damo/nlp_structbert_backbone_base_std", + wantOwner: "damo", + wantRepo: "nlp_structbert_backbone_base_std", + wantErr: false, + }, + { + name: "invalid format - missing repo", + modelURL: "https://modelscope.cn/models/qwen", + wantErr: true, + errContains: "invalid ModelScope URL format", + }, + { + name: "invalid format - only owner", + modelURL: "qwen", + wantErr: true, + errContains: "invalid model identifier", + }, + { + name: "empty URL", + modelURL: "", + wantErr: true, + errContains: "invalid model identifier", + }, + { + name: "URL with spaces (trimmed)", + modelURL: " qwen/Qwen-7B ", + wantOwner: "qwen", + wantRepo: "Qwen-7B", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + owner, repo, err := parseModelURL(tt.modelURL) + + if tt.wantErr { + if err == nil { + t.Errorf("parseModelURL() expected error but got nil") + return + } + if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("parseModelURL() error = %v, want error containing %v", err, tt.errContains) + } + return + } + + if err != nil { + t.Errorf("parseModelURL() unexpected error = %v", err) + return + } + + if owner != tt.wantOwner { + t.Errorf("parseModelURL() owner = %v, want %v", owner, tt.wantOwner) + } + + if repo != tt.wantRepo { + t.Errorf("parseModelURL() repo = %v, want %v", repo, tt.wantRepo) + } + }) + } +} + +func TestProvider_SupportsURL(t *testing.T) { + provider := New() + + tests := []struct { + name string + url string + want bool + }{ + { + name: "full ModelScope URL", + url: "https://modelscope.cn/models/qwen/Qwen-7B", + want: true, + }, + { + name: "ModelScope URL without models prefix", + url: "https://modelscope.cn/damo/nlp_structbert_backbone_base_std", + want: true, + }, + { + name: "HuggingFace URL", + url: "https://huggingface.co/meta-llama/Llama-2-7b-hf", + want: false, + }, + { + name: "short form repo (ambiguous, returns false)", + url: "qwen/Qwen-7B", + want: false, + }, + { + name: "invalid format", + url: "just-a-string", + want: false, + }, + { + name: "HTTP URL", + url: "http://example.com/owner/repo", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := provider.SupportsURL(tt.url); got != tt.want { + t.Errorf("Provider.SupportsURL() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestProvider_Name(t *testing.T) { + provider := New() + if got := provider.Name(); got != "modelscope" { + t.Errorf("Provider.Name() = %v, want %v", got, "modelscope") + } +} diff --git a/pkg/modelprovider/modelscope/provider.go b/pkg/modelprovider/modelscope/provider.go new file mode 100644 index 00000000..8b16912d --- /dev/null +++ b/pkg/modelprovider/modelscope/provider.go @@ -0,0 +1,90 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelscope + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +// Provider implements the modelprovider.Provider interface for ModelScope +type Provider struct{} + +// New creates a new ModelScope provider instance +func New() *Provider { + return &Provider{} +} + +// Name returns the name of this provider +func (p *Provider) Name() string { + return "modelscope" +} + +// SupportsURL checks if this provider can handle the given URL +// It only supports full ModelScope URLs with the modelscope.cn domain +// For short-form repo identifiers (owner/repo), users must explicitly specify --provider modelscope +func (p *Provider) SupportsURL(url string) bool { + url = strings.TrimSpace(url) + + // Only support full ModelScope URLs + return strings.Contains(url, "modelscope.cn") +} + +// DownloadModel downloads a model from ModelScope using the modelscope CLI +func (p *Provider) DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) { + owner, repo, err := parseModelURL(modelURL) + if err != nil { + return "", err + } + + repoID := fmt.Sprintf("%s/%s", owner, repo) + + // Check if modelscope CLI is available + if _, err := exec.LookPath("modelscope"); err != nil { + return "", fmt.Errorf("modelscope CLI not found in PATH. Please install it using: pip install modelscope") + } + + // Create destination directory if it doesn't exist + if err := os.MkdirAll(destDir, 0755); err != nil { + return "", fmt.Errorf("failed to create destination directory: %w", err) + } + + // Construct the download path + downloadPath := filepath.Join(destDir, repo) + + // Use modelscope download command + // The modelscope CLI uses: modelscope download --model --local_dir + cmd := exec.CommandContext(ctx, "modelscope", "download", "--model", repoID, "--local_dir", downloadPath) + + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("failed to download model using modelscope CLI: %w", err) + } + + return downloadPath, nil +} + +// CheckAuth verifies that the user is authenticated with ModelScope +func (p *Provider) CheckAuth() error { + return checkModelScopeAuth() +} diff --git a/pkg/modelprovider/provider.go b/pkg/modelprovider/provider.go new file mode 100644 index 00000000..974dbfd9 --- /dev/null +++ b/pkg/modelprovider/provider.go @@ -0,0 +1,47 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelprovider + +import "context" + +// Provider defines the interface that all model providers must implement. +// A provider is responsible for downloading models from a specific source +// (e.g., HuggingFace, ModelScope, Civitai, etc.) +type Provider interface { + // Name returns the human-readable name of the provider + // Example: "huggingface", "modelscope", "civitai" + Name() string + + // SupportsURL checks if this provider can handle the given model URL + // This enables automatic provider detection based on full URL patterns (with domain) + // Short-form URLs (owner/repo) require explicit provider specification via GetProviderByName + SupportsURL(url string) bool + + // DownloadModel downloads a model from the provider and returns the local path + // Parameters: + // - ctx: context for cancellation and timeout + // - modelURL: the URL or identifier of the model to download + // - destDir: the destination directory where the model should be downloaded + // Returns: + // - string: the local path where the model was downloaded + // - error: any error that occurred during download + DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) + + // CheckAuth verifies that the user is authenticated with the provider + // Returns an error if authentication is missing or invalid + CheckAuth() error +} diff --git a/pkg/modelprovider/registry.go b/pkg/modelprovider/registry.go new file mode 100644 index 00000000..55de2b69 --- /dev/null +++ b/pkg/modelprovider/registry.go @@ -0,0 +1,104 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelprovider + +import ( + "fmt" + "sync" + + "github.com/modelpack/modctl/pkg/modelprovider/huggingface" + "github.com/modelpack/modctl/pkg/modelprovider/modelscope" +) + +// Registry manages all available model providers and provides +// functionality to select the appropriate provider for a given URL +type Registry struct { + providers []Provider +} + +var ( + instance *Registry + once sync.Once +) + +// GetRegistry returns the singleton instance of the registry +// This is thread-safe and will only create the instance once +func GetRegistry() *Registry { + once.Do(func() { + instance = &Registry{ + providers: []Provider{ + huggingface.New(), + modelscope.New(), + // Future providers can be added here: + // civitai.New(), + }, + } + }) + return instance +} + +// ResetRegistry resets the singleton instance +// This should only be used in tests to ensure isolation between test cases +func ResetRegistry() { + once = sync.Once{} + instance = nil +} + +// GetProvider returns the appropriate provider for the given model URL +// It iterates through all registered providers and returns the first one +// that supports the URL. This only works for full URLs with domain names. +// For short-form URLs (owner/repo), use GetProviderByName with an explicit provider +func (r *Registry) GetProvider(modelURL string) (Provider, error) { + for _, p := range r.providers { + if p.SupportsURL(modelURL) { + return p, nil + } + } + return nil, fmt.Errorf("no provider found for URL: %s. For short-form URLs (owner/repo), use --provider flag to specify the provider explicitly", modelURL) +} + +// SelectProvider returns the appropriate provider based on the URL and explicit provider name +// If providerName is specified, it uses GetProviderByName for short-form URLs +// Otherwise, it uses GetProvider for auto-detection with full URLs +func (r *Registry) SelectProvider(modelURL, providerName string) (Provider, error) { + if providerName != "" { + // Explicit provider specified, use it + return r.GetProviderByName(providerName) + } + // No explicit provider, try auto-detection + return r.GetProvider(modelURL) +} + +// GetProviderByName returns a specific provider by its name +// This is useful when you want to explicitly select a provider +func (r *Registry) GetProviderByName(name string) (Provider, error) { + for _, p := range r.providers { + if p.Name() == name { + return p, nil + } + } + return nil, fmt.Errorf("provider not found: %s", name) +} + +// ListProviders returns the names of all registered providers +func (r *Registry) ListProviders() []string { + names := make([]string, len(r.providers)) + for i, p := range r.providers { + names[i] = p.Name() + } + return names +} diff --git a/pkg/modelprovider/registry_test.go b/pkg/modelprovider/registry_test.go new file mode 100644 index 00000000..23d5ad46 --- /dev/null +++ b/pkg/modelprovider/registry_test.go @@ -0,0 +1,242 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelprovider + +import ( + "testing" +) + +func TestRegistry_GetProvider(t *testing.T) { + ResetRegistry() // Ensure clean state for test + registry := GetRegistry() + + tests := []struct { + name string + modelURL string + wantProvider string + wantErr bool + }{ + { + name: "HuggingFace full URL", + modelURL: "https://huggingface.co/meta-llama/Llama-2-7b-hf", + wantProvider: "huggingface", + wantErr: false, + }, + { + name: "HuggingFace short form (requires explicit provider)", + modelURL: "meta-llama/Llama-2-7b-hf", + wantErr: true, + }, + { + name: "ModelScope full URL", + modelURL: "https://modelscope.cn/models/qwen/Qwen-7B", + wantProvider: "modelscope", + wantErr: false, + }, + { + name: "ModelScope URL without models prefix", + modelURL: "https://modelscope.cn/damo/nlp_structbert_backbone_base_std", + wantProvider: "modelscope", + wantErr: false, + }, + { + name: "Unsupported URL", + modelURL: "https://example.com/model/repo", + wantErr: true, + }, + { + name: "Invalid format", + modelURL: "just-a-string", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := registry.GetProvider(tt.modelURL) + + if tt.wantErr { + if err == nil { + t.Errorf("GetProvider() expected error but got nil") + } + return + } + + if err != nil { + t.Errorf("GetProvider() unexpected error = %v", err) + return + } + + if provider.Name() != tt.wantProvider { + t.Errorf("GetProvider() provider name = %v, want %v", provider.Name(), tt.wantProvider) + } + }) + } +} + +func TestRegistry_GetProviderByName(t *testing.T) { + ResetRegistry() // Ensure clean state for test + registry := GetRegistry() + + tests := []struct { + name string + providerName string + wantErr bool + }{ + { + name: "Get HuggingFace provider", + providerName: "huggingface", + wantErr: false, + }, + { + name: "Get ModelScope provider", + providerName: "modelscope", + wantErr: false, + }, + { + name: "Get non-existent provider", + providerName: "civitai", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := registry.GetProviderByName(tt.providerName) + + if tt.wantErr { + if err == nil { + t.Errorf("GetProviderByName() expected error but got nil") + } + return + } + + if err != nil { + t.Errorf("GetProviderByName() unexpected error = %v", err) + return + } + + if provider.Name() != tt.providerName { + t.Errorf("GetProviderByName() provider name = %v, want %v", provider.Name(), tt.providerName) + } + }) + } +} + +func TestRegistry_ListProviders(t *testing.T) { + ResetRegistry() // Ensure clean state for test + registry := GetRegistry() + providers := registry.ListProviders() + + if len(providers) != 2 { + t.Errorf("ListProviders() returned %d providers, want 2", len(providers)) + } + + expectedProviders := map[string]bool{ + "huggingface": false, + "modelscope": false, + } + + for _, name := range providers { + if _, ok := expectedProviders[name]; !ok { + t.Errorf("ListProviders() returned unexpected provider: %s", name) + } + expectedProviders[name] = true + } + + for name, found := range expectedProviders { + if !found { + t.Errorf("ListProviders() missing expected provider: %s", name) + } + } +} + +func TestRegistry_SelectProvider(t *testing.T) { + ResetRegistry() // Ensure clean state for test + registry := GetRegistry() + + tests := []struct { + name string + modelURL string + providerName string + wantProvider string + wantErr bool + }{ + { + name: "Full URL with auto-detection (HuggingFace)", + modelURL: "https://huggingface.co/meta-llama/Llama-2-7b-hf", + providerName: "", + wantProvider: "huggingface", + wantErr: false, + }, + { + name: "Full URL with auto-detection (ModelScope)", + modelURL: "https://modelscope.cn/models/qwen/Qwen-7B", + providerName: "", + wantProvider: "modelscope", + wantErr: false, + }, + { + name: "Short-form with explicit provider (HuggingFace)", + modelURL: "meta-llama/Llama-2-7b-hf", + providerName: "huggingface", + wantProvider: "huggingface", + wantErr: false, + }, + { + name: "Short-form with explicit provider (ModelScope)", + modelURL: "qwen/Qwen-7B", + providerName: "modelscope", + wantProvider: "modelscope", + wantErr: false, + }, + { + name: "Short-form without explicit provider (should fail)", + modelURL: "owner/repo", + providerName: "", + wantErr: true, + }, + { + name: "Invalid provider name", + modelURL: "owner/repo", + providerName: "invalid-provider", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := registry.SelectProvider(tt.modelURL, tt.providerName) + + if tt.wantErr { + if err == nil { + t.Errorf("SelectProvider() expected error but got nil") + } + return + } + + if err != nil { + t.Errorf("SelectProvider() unexpected error = %v", err) + return + } + + if provider.Name() != tt.wantProvider { + t.Errorf("SelectProvider() provider name = %v, want %v", provider.Name(), tt.wantProvider) + } + }) + } +}