diff --git a/.gitignore b/.gitignore index 4bd61cf0..775cea89 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,5 @@ cache __pycache__ *.pyc .pytest_cache -*.tgz \ No newline at end of file +*.tgz +.huggingface \ No newline at end of file diff --git a/api/core/v1alpha1/model_types.go b/api/core/v1alpha1/model_types.go index fe92f753..dd62484e 100644 --- a/api/core/v1alpha1/model_types.go +++ b/api/core/v1alpha1/model_types.go @@ -44,11 +44,18 @@ type ModelHub struct { // the whole repo which includes all kinds of quantized models. // TODO: this is only supported with Huggingface, add support for ModelScope // in the near future. + // Note: once filename is set, allowPatterns and ignorePatterns should be left unset. Filename *string `json:"filename,omitempty"` // Revision refers to a Git revision id which can be a branch name, a tag, or a commit hash. // +kubebuilder:default=main // +optional Revision *string `json:"revision,omitempty"` + // AllowPatterns, when set, restricts the download to files matching at least one pattern. + // +optional + AllowPatterns []string `json:"allowPatterns,omitempty"` + // IgnorePatterns, when set, skips files matching any of the patterns. + // +optional + IgnorePatterns []string `json:"ignorePatterns,omitempty"` } // URIProtocol represents the protocol of the URI. 
diff --git a/api/core/v1alpha1/zz_generated.deepcopy.go b/api/core/v1alpha1/zz_generated.deepcopy.go index d4da3b46..7c94dbca 100644 --- a/api/core/v1alpha1/zz_generated.deepcopy.go +++ b/api/core/v1alpha1/zz_generated.deepcopy.go @@ -127,6 +127,16 @@ func (in *ModelHub) DeepCopyInto(out *ModelHub) { *out = new(string) **out = **in } + if in.AllowPatterns != nil { + in, out := &in.AllowPatterns, &out.AllowPatterns + *out = make([]string, len(*in)) + copy(*out, *in) + } + if in.IgnorePatterns != nil { + in, out := &in.IgnorePatterns, &out.IgnorePatterns + *out = make([]string, len(*in)) + copy(*out, *in) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelHub. diff --git a/client-go/applyconfiguration/core/v1alpha1/modelhub.go b/client-go/applyconfiguration/core/v1alpha1/modelhub.go index bff87476..a11f133a 100644 --- a/client-go/applyconfiguration/core/v1alpha1/modelhub.go +++ b/client-go/applyconfiguration/core/v1alpha1/modelhub.go @@ -20,10 +20,12 @@ package v1alpha1 // ModelHubApplyConfiguration represents a declarative configuration of the ModelHub type for use // with apply. 
type ModelHubApplyConfiguration struct { - Name *string `json:"name,omitempty"` - ModelID *string `json:"modelID,omitempty"` - Filename *string `json:"filename,omitempty"` - Revision *string `json:"revision,omitempty"` + Name *string `json:"name,omitempty"` + ModelID *string `json:"modelID,omitempty"` + Filename *string `json:"filename,omitempty"` + Revision *string `json:"revision,omitempty"` + AllowPatterns []string `json:"allowPatterns,omitempty"` + IgnorePatterns []string `json:"ignorePatterns,omitempty"` } // ModelHubApplyConfiguration constructs a declarative configuration of the ModelHub type for use with @@ -63,3 +65,23 @@ func (b *ModelHubApplyConfiguration) WithRevision(value string) *ModelHubApplyCo b.Revision = &value return b } + +// WithAllowPatterns adds the given value to the AllowPatterns field in the declarative configuration +// and returns the receiver, so that objects can be build by chaining "With" function invocations. +// If called multiple times, values provided by each call will be appended to the AllowPatterns field. +func (b *ModelHubApplyConfiguration) WithAllowPatterns(values ...string) *ModelHubApplyConfiguration { + for i := range values { + b.AllowPatterns = append(b.AllowPatterns, values[i]) + } + return b +} + +// WithIgnorePatterns adds the given value to the IgnorePatterns field in the declarative configuration +// and returns the receiver, so that objects can be build by chaining "With" function invocations. +// If called multiple times, values provided by each call will be appended to the IgnorePatterns field. 
+func (b *ModelHubApplyConfiguration) WithIgnorePatterns(values ...string) *ModelHubApplyConfiguration { + for i := range values { + b.IgnorePatterns = append(b.IgnorePatterns, values[i]) + } + return b +} diff --git a/config/crd/bases/llmaz.io_openmodels.yaml b/config/crd/bases/llmaz.io_openmodels.yaml index 1d87e4b2..27a661fa 100644 --- a/config/crd/bases/llmaz.io_openmodels.yaml +++ b/config/crd/bases/llmaz.io_openmodels.yaml @@ -110,6 +110,12 @@ spec: description: ModelHub represents the model registry for model downloads. properties: + allowPatterns: + description: AllowPatterns, when set, restricts the download to files + matching at least one pattern. + items: + type: string + type: array filename: description: |- Filename refers to a specified model file rather than the whole repo. @@ -117,6 +123,12 @@ spec: the whole repo which includes all kinds of quantized models. in the near future. type: string + ignorePatterns: + description: IgnorePatterns, when set, skips files matching any of + the patterns. 
+ items: + type: string + type: array modelID: description: |- ModelID refers to the model identifier on model hub, diff --git a/llmaz/main.py b/llmaz/main.py index f3e1a208..5a0d08c8 100644 --- a/llmaz/main.py +++ b/llmaz/main.py @@ -17,32 +17,39 @@ import os from datetime import datetime +from llmaz.model_loader.constant import * + from llmaz.model_loader.objstore.objstore import model_download from llmaz.model_loader.model_hub.hub_factory import HubFactory -from llmaz.model_loader.model_hub.huggingface import HUGGING_FACE +from llmaz.model_loader.model_hub.huggingface import HUB_HUGGING_FACE from llmaz.util.logger import Logger - if __name__ == "__main__": - model_source_type = os.getenv("MODEL_SOURCE_TYPE") + model_source_type = os.getenv(ENV_HUB_MODEL_SOURCE_TYPE) start_time = datetime.now() if model_source_type == "modelhub": - hub_name = os.getenv("MODEL_HUB_NAME", HUGGING_FACE) - revision = os.getenv("REVISION") - model_id = os.getenv("MODEL_ID") - model_file_name = os.getenv("MODEL_FILENAME") + hub_name = os.getenv(ENV_HUB_MODEL_HUB_NAME, HUB_HUGGING_FACE) + revision = os.getenv(ENV_HUB_REVISION) + model_id = os.getenv(ENV_HUB_MODEL_ID) + model_file_name = os.getenv(ENV_HUB_MODEL_FILENAME) + model_allow_patterns = os.getenv(ENV_HUB_MODEL_ALLOW_PATTERNS) + model_ignore_patterns = os.getenv(ENV_HUB_MODEL_IGNORE_PATTERNS) if not model_id: raise EnvironmentError(f"Environment variable '{model_id}' not found.") - hub = HubFactory.new(hub_name) - hub.load_model(model_id, model_file_name, revision) + model_allow_patterns_list, model_ignore_patterns_list = [], [] + if model_allow_patterns: + model_allow_patterns_list = model_allow_patterns.split(',') + if model_ignore_patterns: + model_ignore_patterns_list = model_ignore_patterns.split(',') + hub.load_model(model_id, model_file_name, revision, model_allow_patterns_list, model_ignore_patterns_list) elif model_source_type == "objstore": - provider = os.getenv("PROVIDER") - endpoint = os.getenv("ENDPOINT") - bucket = 
os.getenv("BUCKET") - src = os.getenv("MODEL_PATH") + provider = os.getenv(ENV_OBJ_PROVIDER) + endpoint = os.getenv(ENV_OBJ_ENDPOINT) + bucket = os.getenv(ENV_OBJ_BUCKET) + src = os.getenv(ENV_OBJ_MODEL_PATH) model_download(provider=provider, endpoint=endpoint, bucket=bucket, src=src) else: diff --git a/llmaz/model_loader/constant.py b/llmaz/model_loader/constant.py new file mode 100644 index 00000000..08674898 --- /dev/null +++ b/llmaz/model_loader/constant.py @@ -0,0 +1,16 @@ +MODEL_LOCAL_DIR = "/workspace/models/" +HUB_HUGGING_FACE = "Huggingface" +HUB_MODEL_SCOPE = "ModelScope" + +ENV_HUB_MODEL_SOURCE_TYPE = "MODEL_SOURCE_TYPE" +ENV_HUB_MODEL_HUB_NAME = "MODEL_HUB_NAME" +ENV_HUB_REVISION = "REVISION" +ENV_HUB_MODEL_ID = "MODEL_ID" +ENV_HUB_MODEL_FILENAME = "MODEL_FILENAME" +ENV_HUB_MODEL_ALLOW_PATTERNS = "MODEL_ALLOW_PATTERNS" +ENV_HUB_MODEL_IGNORE_PATTERNS = "MODEL_IGNORE_PATTERNS" + +ENV_OBJ_PROVIDER = "PROVIDER" +ENV_OBJ_ENDPOINT = "ENDPOINT" +ENV_OBJ_BUCKET = "BUCKET" +ENV_OBJ_MODEL_PATH = "MODEL_PATH" diff --git a/llmaz/model_loader/defaults.py b/llmaz/model_loader/defaults.py deleted file mode 100644 index 1aab3d4f..00000000 --- a/llmaz/model_loader/defaults.py +++ /dev/null @@ -1 +0,0 @@ -MODEL_LOCAL_DIR = "/workspace/models/" diff --git a/llmaz/model_loader/model_hub/hub_factory.py b/llmaz/model_loader/model_hub/hub_factory.py index a70f8ec8..c078dded 100644 --- a/llmaz/model_loader/model_hub/hub_factory.py +++ b/llmaz/model_loader/model_hub/hub_factory.py @@ -13,19 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. 
""" - +from llmaz.model_loader.constant import HUB_HUGGING_FACE, HUB_MODEL_SCOPE from llmaz.model_loader.model_hub.model_hub import ModelHub -from llmaz.model_loader.model_hub.huggingface import HUGGING_FACE, Huggingface -from llmaz.model_loader.model_hub.modelscope import MODEL_SCOPE, ModelScope - +from llmaz.model_loader.model_hub.huggingface import Huggingface +from llmaz.model_loader.model_hub.modelscope import ModelScope SUPPORT_MODEL_HUBS = { - HUGGING_FACE: Huggingface, - MODEL_SCOPE: ModelScope, + HUB_HUGGING_FACE: Huggingface, + HUB_MODEL_SCOPE: ModelScope, } class HubFactory: + @classmethod def new(cls, hub_name: str) -> ModelHub: if hub_name not in SUPPORT_MODEL_HUBS.keys(): diff --git a/llmaz/model_loader/model_hub/huggingface.py b/llmaz/model_loader/model_hub/huggingface.py index 6b329365..fafd14d2 100644 --- a/llmaz/model_loader/model_hub/huggingface.py +++ b/llmaz/model_loader/model_hub/huggingface.py @@ -17,69 +17,51 @@ import concurrent.futures import os -from huggingface_hub import hf_hub_download, list_repo_files +from huggingface_hub import snapshot_download -from llmaz.model_loader.defaults import MODEL_LOCAL_DIR +from llmaz.model_loader.constant import MODEL_LOCAL_DIR, HUB_HUGGING_FACE from llmaz.model_loader.model_hub.model_hub import ( - HUGGING_FACE, - MAX_WORKERS, ModelHub, ) from llmaz.util.logger import Logger from llmaz.model_loader.model_hub.util import get_folder_total_size -from typing import Optional +from typing import Optional, List class Huggingface(ModelHub): @classmethod def name(cls) -> str: - return HUGGING_FACE + return HUB_HUGGING_FACE @classmethod def load_model( - cls, model_id: str, filename: Optional[str], revision: Optional[str] + cls, + model_id: str, + filename: Optional[str], + revision: Optional[str], + allow_patterns: Optional[List[str]], + ignore_patterns: Optional[List[str]], ) -> None: Logger.info( f"Start to download, model_id: {model_id}, filename: {filename}, revision: {revision}" ) - if filename: - 
hf_hub_download( - repo_id=model_id, - filename=filename, - local_dir=MODEL_LOCAL_DIR, - revision=revision, - ) - file_size = os.path.getsize(MODEL_LOCAL_DIR + filename) / (1024**3) - Logger.info( - f"The total size of {MODEL_LOCAL_DIR + filename} is {file_size: .2f} GB" - ) - return - local_dir = os.path.join( - MODEL_LOCAL_DIR, f"models--{model_id.replace('/','--')}" + MODEL_LOCAL_DIR, f"models--{model_id.replace('/', '--')}" ) - # # TODO: Should we verify the download is finished? - with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: - futures = [] - for file in list_repo_files(repo_id=model_id): - # TODO: support version management, right now we didn't distinguish with them. - futures.append( - executor.submit( - hf_hub_download, - repo_id=model_id, - filename=file, - local_dir=local_dir, - revision=revision, - ).add_done_callback(handle_completion) - ) + if filename: + allow_patterns.append(filename) + local_dir = MODEL_LOCAL_DIR + + snapshot_download( + repo_id=model_id, + revision=revision, + local_dir=local_dir, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) total_size = get_folder_total_size(local_dir) - Logger.info(f"The total size of {local_dir} is {total_size: .2f} GB") - - -def handle_completion(future): - filename = future.result() - Logger.info(f"Download completed for {filename}") + Logger.info(f"The total size of {local_dir} is {total_size: .2f} GB") \ No newline at end of file diff --git a/llmaz/model_loader/model_hub/model_hub.py b/llmaz/model_loader/model_hub/model_hub.py index a28a799c..d9eb97de 100644 --- a/llmaz/model_loader/model_hub/model_hub.py +++ b/llmaz/model_loader/model_hub/model_hub.py @@ -15,11 +15,7 @@ """ from abc import ABC, abstractmethod -from typing import Optional - -MAX_WORKERS = 4 -HUGGING_FACE = "Huggingface" -MODEL_SCOPE = "ModelScope" +from typing import Optional, List class ModelHub(ABC): @@ -31,6 +27,11 @@ def name(cls) -> str: @classmethod @abstractmethod def 
load_model( - cls, model_id: str, filename: Optional[str], revision: Optional[str] + cls, + model_id: str, + filename: Optional[str], + revision: Optional[str], + allow_patterns: Optional[List[str]], + ignore_patterns: Optional[List[str]], ) -> None: pass diff --git a/llmaz/model_loader/model_hub/modelscope.py b/llmaz/model_loader/model_hub/modelscope.py index 66c55532..f2ed1a9f 100644 --- a/llmaz/model_loader/model_hub/modelscope.py +++ b/llmaz/model_loader/model_hub/modelscope.py @@ -15,15 +15,12 @@ """ import os -import concurrent.futures -from typing import Optional +from typing import Optional, List from modelscope import snapshot_download -from llmaz.model_loader.defaults import MODEL_LOCAL_DIR +from llmaz.model_loader.constant import * from llmaz.model_loader.model_hub.model_hub import ( - MAX_WORKERS, - MODEL_SCOPE, ModelHub, ) from llmaz.util.logger import Logger @@ -33,36 +30,35 @@ class ModelScope(ModelHub): @classmethod def name(cls) -> str: - return MODEL_SCOPE + return HUB_MODEL_SCOPE - # TODO: support filename @classmethod def load_model( - cls, model_id: str, filename: Optional[str], revision: Optional[str] + cls, + model_id: str, + filename: Optional[str], + revision: Optional[str], + allow_patterns: Optional[List[str]], + ignore_patterns: Optional[List[str]], ) -> None: Logger.info( f"Start to download, model_id: {model_id}, filename: {filename}, revision: {revision}" ) local_dir = os.path.join( - MODEL_LOCAL_DIR, f"models--{model_id.replace('/','--')}" + MODEL_LOCAL_DIR, f"models--{model_id.replace('/', '--')}" ) - with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: - futures = [] - futures.append( - executor.submit( - snapshot_download, - model_id=model_id, - local_dir=local_dir, - revision=revision, - ).add_done_callback(handle_completion) - ) + if filename: + allow_patterns.append(filename) + local_dir = MODEL_LOCAL_DIR + snapshot_download( + model_id=model_id, + revision=revision, + local_dir=local_dir, + 
allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) total_size = get_folder_total_size(local_dir) Logger.info(f"The total size of {local_dir} is {total_size:.2f} GB") - - -def handle_completion(future): - filename = future.result() - Logger.info(f"Download completed for {filename}") diff --git a/llmaz/model_loader/model_hub/util.py b/llmaz/model_loader/model_hub/util.py index eafc0dd4..85aabfeb 100644 --- a/llmaz/model_loader/model_hub/util.py +++ b/llmaz/model_loader/model_hub/util.py @@ -17,18 +17,18 @@ import os from llmaz.util.logger import Logger + def get_folder_total_size(folder_path: str) -> float: total_size = 0 - for dirpath, _, filenames in os.walk(folder_path): + for dirPath, _, filenames in os.walk(folder_path): for filename in filenames: - file_path = os.path.join(dirpath, filename) + file_path = os.path.join(dirPath, filename) try: if os.path.exists(file_path): total_size += os.path.getsize(file_path) except OSError as e: Logger.error(f"Failed to get file {file_path} size, err is {e}") - - total_size_gb = total_size / (1024**3) + total_size_gb = total_size / (1024 ** 3) return total_size_gb diff --git a/llmaz/model_loader/objstore/objstore.py b/llmaz/model_loader/objstore/objstore.py index eed54479..af8d03ef 100644 --- a/llmaz/model_loader/objstore/objstore.py +++ b/llmaz/model_loader/objstore/objstore.py @@ -1,6 +1,6 @@ from omnistore.objstore import StoreFactory -from llmaz.model_loader.defaults import MODEL_LOCAL_DIR +from llmaz.model_loader.constant import MODEL_LOCAL_DIR def model_download(provider: str, endpoint: str, bucket: str, src: str): diff --git a/pkg/controller_helper/model_source/modelhub.go b/pkg/controller_helper/model_source/modelhub.go index 573c6a0a..cee1631e 100644 --- a/pkg/controller_helper/model_source/modelhub.go +++ b/pkg/controller_helper/model_source/modelhub.go @@ -29,18 +29,20 @@ import ( var _ ModelSourceProvider = &ModelHubProvider{} type ModelHubProvider struct { - modelName string - modelID string - 
modelHub string - fileName *string - modelRevision *string + modelName string + modelID string + modelHub string + fileName *string + modelRevision *string + modelAllowPatterns []string + modelIgnorePatterns []string } func (p *ModelHubProvider) ModelName() string { return p.modelName } -// Example 1: +// ModelPath Example 1: // - modelID: facebook/opt-125m // modelPath: /workspace/models/models--facebook--opt-125m // @@ -89,6 +91,16 @@ func (p *ModelHubProvider) InjectModelLoader(template *corev1.PodTemplateSpec, i corev1.EnvVar{Name: "REVISION", Value: *p.modelRevision}, ) } + if p.modelAllowPatterns != nil { + initContainer.Env = append(initContainer.Env, + corev1.EnvVar{Name: "MODEL_ALLOW_PATTERNS", Value: strings.Join(p.modelAllowPatterns, ",")}, + ) + } + if p.modelIgnorePatterns != nil { + initContainer.Env = append(initContainer.Env, + corev1.EnvVar{Name: "MODEL_IGNORE_PATTERNS", Value: strings.Join(p.modelIgnorePatterns, ",")}, + ) + } initContainer.Env = append(initContainer.Env, corev1.EnvVar{ Name: "HUGGING_FACE_HUB_TOKEN", // vllm diff --git a/pkg/controller_helper/model_source/modelsource.go b/pkg/controller_helper/model_source/modelsource.go index 75676e65..281779da 100644 --- a/pkg/controller_helper/model_source/modelsource.go +++ b/pkg/controller_helper/model_source/modelsource.go @@ -17,10 +17,9 @@ limitations under the License. 
package modelSource import ( - corev1 "k8s.io/api/core/v1" - coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" "github.com/inftyai/llmaz/pkg/util" + corev1 "k8s.io/api/core/v1" ) const ( @@ -60,11 +59,13 @@ type ModelSourceProvider interface { func NewModelSourceProvider(model *coreapi.OpenModel) ModelSourceProvider { if model.Spec.Source.ModelHub != nil { return &ModelHubProvider{ - modelName: model.Name, - modelID: model.Spec.Source.ModelHub.ModelID, - modelHub: *model.Spec.Source.ModelHub.Name, - fileName: model.Spec.Source.ModelHub.Filename, - modelRevision: model.Spec.Source.ModelHub.Revision, + modelName: model.Name, + modelID: model.Spec.Source.ModelHub.ModelID, + modelHub: *model.Spec.Source.ModelHub.Name, + fileName: model.Spec.Source.ModelHub.Filename, + modelRevision: model.Spec.Source.ModelHub.Revision, + modelAllowPatterns: model.Spec.Source.ModelHub.AllowPatterns, + modelIgnorePatterns: model.Spec.Source.ModelHub.IgnorePatterns, } } diff --git a/pkg/controller_helper/model_source/modelsource_test.go b/pkg/controller_helper/model_source/modelsource_test.go index f4bbb8a4..1fb7c8b9 100644 --- a/pkg/controller_helper/model_source/modelsource_test.go +++ b/pkg/controller_helper/model_source/modelsource_test.go @@ -39,7 +39,7 @@ func Test_ModelSourceProvider(t *testing.T) { }, { name: "modelhub with GGUF file", - model: wrapper.MakeModel("test-7b").FamilyName("test").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf", "").Obj(), + model: wrapper.MakeModel("test-7b").FamilyName("test").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf", "", nil, nil).Obj(), wantModelName: "test-7b", wantModelPath: "/workspace/models/qwen2-0_5b-instruct-q5_k_m.gguf", }, diff --git a/pkg/webhook/openmodel_webhook.go b/pkg/webhook/openmodel_webhook.go index e6b5d821..2d97c7ce 100644 --- a/pkg/webhook/openmodel_webhook.go 
+++ b/pkg/webhook/openmodel_webhook.go @@ -118,5 +118,15 @@ func (w *OpenModelWebhook) generateValidate(obj runtime.Object) field.ErrorList allErrs = append(allErrs, field.Invalid(sourcePath.Child("modelHub.filename"), *model.Spec.Source.ModelHub.Filename, "Filename can only set once modeHub is Huggingface")) } } + + if model.Spec.Source.ModelHub != nil && model.Spec.Source.ModelHub.Filename != nil { + if model.Spec.Source.ModelHub.AllowPatterns != nil { + allErrs = append(allErrs, field.Invalid(sourcePath.Child("modelHub.allowPatterns"), model.Spec.Source.ModelHub.AllowPatterns, "Once Filename is set, allowPatterns should be nil")) + } + if model.Spec.Source.ModelHub.IgnorePatterns != nil { + allErrs = append(allErrs, field.Invalid(sourcePath.Child("modelHub.ignorePatterns"), model.Spec.Source.ModelHub.IgnorePatterns, "Once Filename is set, ignorePatterns should be nil")) + } + } + return allErrs } diff --git a/test/e2e/playground_test.go b/test/e2e/playground_test.go index 5376bda9..215da365 100644 --- a/test/e2e/playground_test.go +++ b/test/e2e/playground_test.go @@ -48,7 +48,7 @@ var _ = ginkgo.Describe("playground e2e tests", func() { }) ginkgo.It("Deploy a huggingface model with llama.cpp", func() { - model := wrapper.MakeModel("qwen2-0-5b-gguf").FamilyName("qwen2").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf", "").Obj() + model := wrapper.MakeModel("qwen2-0-5b-gguf").FamilyName("qwen2").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf", "", nil, nil).Obj() gomega.Expect(k8sClient.Create(ctx, model)).To(gomega.Succeed()) defer func() { gomega.Expect(k8sClient.Delete(ctx, model)).To(gomega.Succeed()) @@ -73,7 +73,7 @@ var _ = ginkgo.Describe("playground e2e tests", func() { Request("cpu", "2").Request("memory", "4Gi").Limit("cpu", "4").Limit("memory", "4Gi").Obj() gomega.Expect(k8sClient.Create(ctx, 
backendRuntime)).To(gomega.Succeed()) - model := wrapper.MakeModel("qwen2-0-5b-gguf").FamilyName("qwen2").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf", "").Obj() + model := wrapper.MakeModel("qwen2-0-5b-gguf").FamilyName("qwen2").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf", "", nil, nil).Obj() gomega.Expect(k8sClient.Create(ctx, model)).To(gomega.Succeed()) defer func() { gomega.Expect(k8sClient.Delete(ctx, model)).To(gomega.Succeed()) @@ -92,12 +92,12 @@ var _ = ginkgo.Describe("playground e2e tests", func() { }) // TODO: add e2e tests. // ginkgo.It("SpeculativeDecoding with llama.cpp", func() { - // targetModel := wrapper.MakeModel("llama2-7b-q8-gguf").FamilyName("llama2").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("TheBloke/Llama-2-7B-GGUF", "llama-2-7b.Q8_0.gguf", "").Obj() + // targetModel := wrapper.MakeModel("llama2-7b-q8-gguf").FamilyName("llama2").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("TheBloke/Llama-2-7B-GGUF", "llama-2-7b.Q8_0.gguf", "", nil, nil).Obj() // gomega.Expect(k8sClient.Create(ctx, targetModel)).To(gomega.Succeed()) // defer func() { // gomega.Expect(k8sClient.Delete(ctx, targetModel)).To(gomega.Succeed()) // }() - // draftModel := wrapper.MakeModel("llama2-7b-q2-k-gguf").FamilyName("llama2").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("TheBloke/Llama-2-7B-GGUF", "llama-2-7b.Q2_K.gguf", "").Obj() + // draftModel := wrapper.MakeModel("llama2-7b-q2-k-gguf").FamilyName("llama2").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("TheBloke/Llama-2-7B-GGUF", "llama-2-7b.Q2_K.gguf", "", nil, nil).Obj() // gomega.Expect(k8sClient.Create(ctx, draftModel)).To(gomega.Succeed()) // defer func() { // gomega.Expect(k8sClient.Delete(ctx, draftModel)).To(gomega.Succeed()) diff --git 
a/test/integration/controller/inference/playground_test.go b/test/integration/controller/inference/playground_test.go index 3814aa52..15a6d87c 100644 --- a/test/integration/controller/inference/playground_test.go +++ b/test/integration/controller/inference/playground_test.go @@ -57,7 +57,7 @@ var _ = ginkgo.Describe("playground controller test", func() { gomega.Expect(k8sClient.Create(ctx, ns)).To(gomega.Succeed()) model = util.MockASampleModel() gomega.Expect(k8sClient.Create(ctx, model)).To(gomega.Succeed()) - draftModel = wrapper.MakeModel("llama3-2b").FamilyName("llama3").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("meta-llama/Meta-Llama-3-2B", "", "").Obj() + draftModel = wrapper.MakeModel("llama3-2b").FamilyName("llama3").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("meta-llama/Meta-Llama-3-2B", "", "", nil, nil).Obj() gomega.Expect(k8sClient.Create(ctx, draftModel)).To(gomega.Succeed()) }) ginkgo.AfterEach(func() { diff --git a/test/integration/webhook/model_test.go b/test/integration/webhook/model_test.go index 3e3c8b5a..193fe6e4 100644 --- a/test/integration/webhook/model_test.go +++ b/test/integration/webhook/model_test.go @@ -52,18 +52,18 @@ var _ = ginkgo.Describe("model default and validation", func() { }, ginkgo.Entry("apply model family name", &testDefaultingCase{ model: func() *coreapi.OpenModel { - return wrapper.MakeModel("llama3-8b").ModelSourceWithModelID("meta-llama/Meta-Llama-3-8B", "", "").FamilyName("llama3").Obj() + return wrapper.MakeModel("llama3-8b").ModelSourceWithModelID("meta-llama/Meta-Llama-3-8B", "", "", nil, nil).FamilyName("llama3").Obj() }, wantModel: func() *coreapi.OpenModel { - return wrapper.MakeModel("llama3-8b").ModelSourceWithModelID("meta-llama/Meta-Llama-3-8B", "", "main").ModelSourceWithModelHub("Huggingface").FamilyName("llama3").Label(coreapi.ModelFamilyNameLabelKey, "llama3").Obj() + return wrapper.MakeModel("llama3-8b").ModelSourceWithModelID("meta-llama/Meta-Llama-3-8B", "", 
"main", nil, nil).ModelSourceWithModelHub("Huggingface").FamilyName("llama3").Label(coreapi.ModelFamilyNameLabelKey, "llama3").Obj() }, }), ginkgo.Entry("apply modelscope model hub name", &testDefaultingCase{ model: func() *coreapi.OpenModel { - return wrapper.MakeModel("llama3-8b").FamilyName("llama3").ModelSourceWithModelHub("ModelScope").ModelSourceWithModelID("LLM-Research/Meta-Llama-3-8B", "", "").Obj() + return wrapper.MakeModel("llama3-8b").FamilyName("llama3").ModelSourceWithModelHub("ModelScope").ModelSourceWithModelID("LLM-Research/Meta-Llama-3-8B", "", "", nil, nil).Obj() }, wantModel: func() *coreapi.OpenModel { - return wrapper.MakeModel("llama3-8b").ModelSourceWithModelID("LLM-Research/Meta-Llama-3-8B", "", "main").ModelSourceWithModelHub("ModelScope").FamilyName("llama3").Label(coreapi.ModelFamilyNameLabelKey, "llama3").Obj() + return wrapper.MakeModel("llama3-8b").ModelSourceWithModelID("LLM-Research/Meta-Llama-3-8B", "", "main", nil, nil).ModelSourceWithModelHub("ModelScope").FamilyName("llama3").Label(coreapi.ModelFamilyNameLabelKey, "llama3").Obj() }, }), ) @@ -83,19 +83,19 @@ var _ = ginkgo.Describe("model default and validation", func() { }, ginkgo.Entry("default normal huggingface model creation", &testValidatingCase{ model: func() *coreapi.OpenModel { - return wrapper.MakeModel("llama3-8b").FamilyName("llama3").ModelSourceWithModelID("meta-llama/Meta-Llama-3-8B", "", "").Obj() + return wrapper.MakeModel("llama3-8b").FamilyName("llama3").ModelSourceWithModelID("meta-llama/Meta-Llama-3-8B", "", "", nil, nil).Obj() }, failed: false, }), ginkgo.Entry("normal modelScope model creation", &testValidatingCase{ model: func() *coreapi.OpenModel { - return wrapper.MakeModel("llama3-8b").FamilyName("llama3").ModelSourceWithModelHub("ModelScope").ModelSourceWithModelID("LLM-Research/Meta-Llama-3-8B", "", "").Obj() + return 
wrapper.MakeModel("llama3-8b").FamilyName("llama3").ModelSourceWithModelHub("ModelScope").ModelSourceWithModelID("LLM-Research/Meta-Llama-3-8B", "", "", nil, nil).Obj() }, failed: false, }), ginkgo.Entry("invalid model name", &testValidatingCase{ model: func() *coreapi.OpenModel { - return wrapper.MakeModel("qwen-2-0.5b").FamilyName("qwen2").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct", "", "").Obj() + return wrapper.MakeModel("qwen-2-0.5b").FamilyName("qwen2").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct", "", "", nil, nil).Obj() }, failed: true, }), @@ -131,13 +131,31 @@ var _ = ginkgo.Describe("model default and validation", func() { }), ginkgo.Entry("set filename when modelHub is Huggingface", &testValidatingCase{ model: func() *coreapi.OpenModel { - return wrapper.MakeModel("llama3-8b").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf", "").FamilyName("llama3").Obj() + return wrapper.MakeModel("llama3-8b").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf", "", nil, nil).FamilyName("llama3").Obj() + }, + failed: false, + }), + ginkgo.Entry("set filename and allowPatterns when modelHub is Huggingface", &testValidatingCase{ + model: func() *coreapi.OpenModel { + return wrapper.MakeModel("llama3-8b").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf", "", []string{"*"}, nil).FamilyName("llama3").Obj() + }, + failed: true, + }), + ginkgo.Entry("set filename and ignorePatterns when modelHub is Huggingface", &testValidatingCase{ + model: func() *coreapi.OpenModel { + return wrapper.MakeModel("llama3-8b").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf", "", nil, []string{"*"}).FamilyName("llama3").Obj() + }, + failed: true, + }), + ginkgo.Entry("set allowPatterns and ignorePatterns when modelHub is Huggingface", &testValidatingCase{ + model: func() *coreapi.OpenModel { + return 
wrapper.MakeModel("llama3-8b").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "", "", []string{"*"}, []string{"*.gguf"}).FamilyName("llama3").Obj() }, failed: false, }), ginkgo.Entry("set filename when modelHub is ModelScope", &testValidatingCase{ model: func() *coreapi.OpenModel { - return wrapper.MakeModel("llama3-8b").ModelSourceWithModelHub("ModelScope").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf", "").FamilyName("llama3").Obj() + return wrapper.MakeModel("llama3-8b").ModelSourceWithModelHub("ModelScope").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf", "", nil, nil).FamilyName("llama3").Obj() }, failed: true, }), diff --git a/test/util/mock.go b/test/util/mock.go index e3dd2f97..d0ebe7c0 100644 --- a/test/util/mock.go +++ b/test/util/mock.go @@ -26,7 +26,7 @@ const ( ) func MockASampleModel() *coreapi.OpenModel { - return wrapper.MakeModel(sampleModelName).FamilyName("llama3").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("meta-llama/Meta-Llama-3-8B", "", "").Obj() + return wrapper.MakeModel(sampleModelName).FamilyName("llama3").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("meta-llama/Meta-Llama-3-8B", "", "", nil, nil).Obj() } func MockASamplePlayground(ns string) *inferenceapi.Playground { diff --git a/test/util/validation/validate_service.go b/test/util/validation/validate_service.go index 3a834553..99eb6690 100644 --- a/test/util/validation/validate_service.go +++ b/test/util/validation/validate_service.go @@ -111,6 +111,12 @@ func ValidateModelLoader(model *coreapi.OpenModel, index int, workload *lws.Lead if model.Spec.Source.ModelHub.Revision != nil { envStrings = append(envStrings, "REVISION") } + if model.Spec.Source.ModelHub.AllowPatterns != nil { + envStrings = append(envStrings, "MODEL_ALLOW_PATTERNS") + } + if model.Spec.Source.ModelHub.IgnorePatterns != nil { + envStrings = append(envStrings, "MODEL_IGNORE_PATTERNS") + 
} } if model.Spec.Source.URI != nil { envStrings = []string{"MODEL_SOURCE_TYPE", "PROVIDER", "ENDPOINT", "BUCKET", "MODEL_PATH", "OSS_ACCESS_KEY_ID", "OSS_ACCESS_KEY_SECRET"} diff --git a/test/util/wrapper/model.go b/test/util/wrapper/model.go index 0e7b1f89..077ef00e 100644 --- a/test/util/wrapper/model.go +++ b/test/util/wrapper/model.go @@ -47,7 +47,7 @@ func (w *ModelWrapper) FamilyName(name string) *ModelWrapper { return w } -func (w *ModelWrapper) ModelSourceWithModelID(modelID string, filename string, revision string) *ModelWrapper { +func (w *ModelWrapper) ModelSourceWithModelID(modelID string, filename string, revision string, allowPatterns, ignorePatterns []string) *ModelWrapper { if modelID != "" { if w.Spec.Source.ModelHub == nil { w.Spec.Source.ModelHub = &coreapi.ModelHub{} @@ -61,6 +61,14 @@ func (w *ModelWrapper) ModelSourceWithModelID(modelID string, filename string, r if revision != "" { w.Spec.Source.ModelHub.Revision = &revision } + + if allowPatterns != nil { + w.Spec.Source.ModelHub.AllowPatterns = allowPatterns + } + + if ignorePatterns != nil { + w.Spec.Source.ModelHub.IgnorePatterns = ignorePatterns + } } return w }