diff --git a/README.md b/README.md index b9f0b4d5..6365cc01 100644 --- a/README.md +++ b/README.md @@ -27,11 +27,11 @@ Easy, advanced inference platform for large language models on Kubernetes ## Feature Overview - **Easy of Use**: People can quick deploy a LLM service with minimal configurations. -- **Broad Backend Support**: llmaz supports a wide range of advanced inference backends for high performance, like [vLLM](https://github.com/vllm-project/vllm), [SGLang](https://github.com/sgl-project/sglang), [llama.cpp](https://github.com/ggerganov/llama.cpp). Find the full list of supported backends [here](./docs/support-backends.md). +- **Broad Backend Support**: llmaz supports a wide range of advanced inference backends for different scenarios, like [vLLM](https://github.com/vllm-project/vllm), [SGLang](https://github.com/sgl-project/sglang), [llama.cpp](https://github.com/ggerganov/llama.cpp). Find the full list of supported backends [here](./docs/support-backends.md). - **Scaling Efficiency (WIP)**: llmaz works smoothly with autoscaling components like [Cluster-Autoscaler](https://github.com/kubernetes/autoscaler/tree/master/cluster-autoscaler) or [Karpenter](https://github.com/kubernetes-sigs/karpenter) to support elastic scenarios. - **Accelerator Fungibility (WIP)**: llmaz supports serving the same LLM with various accelerators to optimize cost and performance. -- **SOTA Inference (WIP)**: llmaz supports the latest cutting-edge researches like [Speculative Decoding](https://arxiv.org/abs/2211.17192) or [Splitwise](https://arxiv.org/abs/2311.18677) to run on Kubernetes. -- **Various Model Providers**: llmaz automatically loads models from various providers, such as [HuggingFace](https://huggingface.co/), [ModelScope](https://www.modelscope.cn), ObjectStores(aliyun OSS, more on the way). 
+- **SOTA Inference**: llmaz supports the latest cutting-edge research like [Speculative Decoding](https://arxiv.org/abs/2211.17192) or [Splitwise](https://arxiv.org/abs/2311.18677) (WIP) on Kubernetes. +- **Various Model Providers**: llmaz supports a wide range of model providers, such as [HuggingFace](https://huggingface.co/), [ModelScope](https://www.modelscope.cn), ObjectStores(aliyun OSS, more on the way). llmaz handles model loading automatically, requiring no effort from users. - **Multi-hosts Support**: llmaz supports both single-host and multi-hosts scenarios with [LWS](https://github.com/kubernetes-sigs/lws) from day 1. ## Quick Start @@ -110,10 +110,19 @@ curl http://localhost:8080/v1/completions \ ## Roadmap - Gateway support for traffic routing +- Metrics support - Serverless support for cloud-agnostic users - CLI tool support - Model training, fine tuning in the long-term +## Project Structure + +```structure +llmaz # root +├── llmaz # where the model loader logic lives +├── pkg # where the main logic for Kubernetes controllers lives +``` + ## Contributions 🚀 All kinds of contributions are welcomed ! Please follow [Contributing](./CONTRIBUTING.md). Thanks to all these contributors. diff --git a/api/core/v1alpha1/model_types.go b/api/core/v1alpha1/model_types.go index 7fbad86d..e0ca3d62 100644 --- a/api/core/v1alpha1/model_types.go +++ b/api/core/v1alpha1/model_types.go @@ -92,9 +92,9 @@ type Flavor struct { // the requests here will be covered. // +optional Requests v1.ResourceList `json:"requests,omitempty"` - // NodeSelector defines the labels to filter specified nodes, like - // cloud-provider.com/accelerator: nvidia-a100. - // NodeSelector will be auto injected to the Pods as scheduling primitives. + // NodeSelector represents the node candidates for Pod placements, if a node doesn't + // meet the nodeSelector, it will be filtered out in the resourceFungibility scheduler plugin. 
+ // If nodeSelector is empty, it means every node is a candidate. // +optional NodeSelector map[string]string `json:"nodeSelector,omitempty"` // Params stores other useful parameters and will be consumed by the autoscaling components @@ -107,39 +107,47 @@ type Flavor struct { type ModelName string -// ModelClaim represents the references to one model. -// It's a simple config for most of the cases compared to multiModelsClaim. +// ModelClaim represents claiming for one model, it's the standard claimMode +// of multiModelsClaim compared to other modes like SpeculativeDecoding. type ModelClaim struct { - // ModelName represents a list of models, there maybe multiple models here - // to support state-of-the-art technologies like speculative decoding. + // ModelName represents the name of the Model. ModelName ModelName `json:"modelName,omitempty"` - // InferenceFlavors represents a list of flavors with fungibility supports - // to serve the model. The flavor names should be a subset of the model - // configured flavors. If not set, will use the model configured flavors. + // InferenceFlavors represents a list of flavors with fungibility support + // to serve the model. + // If set, The flavor names should be a subset of the model configured flavors. + // If not set, Model configured flavors will be used by default. // +optional InferenceFlavors []FlavorName `json:"inferenceFlavors,omitempty"` } -// MultiModelsClaim represents the references to multiple models. -// It's an advanced and more complicated config comparing to modelClaim. +type InferenceMode string + +const ( + Standard InferenceMode = "Standard" + SpeculativeDecoding InferenceMode = "SpeculativeDecoding" +) + +// MultiModelsClaim represents claiming for multiple models with different claimModes, +// like standard or speculative-decoding to support different inference scenarios. 
type MultiModelsClaim struct { // ModelNames represents a list of models, there maybe multiple models here // to support state-of-the-art technologies like speculative decoding. + // If the inferenceMode is SpeculativeDecoding, the first model is the target model, + // and the second model is the draft model. // +kubebuilder:validation:MinItems=1 ModelNames []ModelName `json:"modelNames,omitempty"` + // InferenceMode represents the paradigm to serve the model, whether via a standard way + // or via an advanced technique like SpeculativeDecoding. + // +kubebuilder:default=Standard + // +kubebuilder:validation:Enum={Standard,SpeculativeDecoding} + // +optional + InferenceMode InferenceMode `json:"inferenceMode,omitempty"` // InferenceFlavors represents a list of flavors with fungibility supported // to serve the model. // - If not set, always apply with the 0-index model by default. // - If set, will lookup the flavor names following the model orders. // +optional InferenceFlavors []FlavorName `json:"inferenceFlavors,omitempty"` - // Rate works only when multiple claims declared, it represents the replicas rates of - // the sub-workload, like when claim1.rate:claim2.rate = 1:2 and 3 replicas defined in - // workload, then sub-workload1 will have 1 replica, and sub-workload2 will have 2 replicas. - // This is mostly designed for state-of-the-art technology called splitwise, the prefill - // and decode phase will be separated and requires different accelerators. - // The sum of the rates should be divisible by replicas. - Rate *int32 `json:"rate,omitempty"` } // ModelSpec defines the desired state of Model @@ -151,7 +159,8 @@ type ModelSpec struct { // the model such as loading from huggingface, OCI registry, s3, host path and so on. Source ModelSource `json:"source"` // InferenceFlavors represents the accelerator requirements to serve the model. - // Flavors are fungible following the priority of slice order. 
+ // Flavors are fungible following the priority represented by the slice order. + // +kubebuilder:validation:MaxItems=8 // +optional InferenceFlavors []Flavor `json:"inferenceFlavors,omitempty"` } diff --git a/api/core/v1alpha1/zz_generated.deepcopy.go b/api/core/v1alpha1/zz_generated.deepcopy.go index 00dfcf4e..8ad44d3e 100644 --- a/api/core/v1alpha1/zz_generated.deepcopy.go +++ b/api/core/v1alpha1/zz_generated.deepcopy.go @@ -195,11 +195,6 @@ func (in *MultiModelsClaim) DeepCopyInto(out *MultiModelsClaim) { *out = make([]FlavorName, len(*in)) copy(*out, *in) } - if in.Rate != nil { - in, out := &in.Rate, &out.Rate - *out = new(int32) - **out = **in - } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MultiModelsClaim. diff --git a/api/inference/v1alpha1/config_types.go b/api/inference/v1alpha1/config_types.go index d3078e75..ebe466e1 100644 --- a/api/inference/v1alpha1/config_types.go +++ b/api/inference/v1alpha1/config_types.go @@ -39,6 +39,7 @@ type BackendConfig struct { // +optional Version *string `json:"version,omitempty"` // Args represents the arguments passed to the backend. + // You can add new args or overwrite the default args. // +optional Args []string `json:"args,omitempty"` // Envs represents the environments set to the container. diff --git a/api/inference/v1alpha1/playground_types.go b/api/inference/v1alpha1/playground_types.go index 1d18f270..792fe421 100644 --- a/api/inference/v1alpha1/playground_types.go +++ b/api/inference/v1alpha1/playground_types.go @@ -28,19 +28,17 @@ type PlaygroundSpec struct { // +kubebuilder:default=1 // +optional Replicas *int32 `json:"replicas,omitempty"` - // ModelClaim represents one modelClaim, it's a simple configuration - // compared to multiModelsClaims only work for one model and one claim. - // ModelClaim and multiModelsClaims are exclusive configured. - // Note: properties (nodeSelectors, resources, e.g.) 
of the model flavors - // will be applied to the workload if not exist. + // ModelClaim represents claiming for one model, it's the standard claimMode + // of multiModelsClaim compared to other modes like SpeculativeDecoding. + // Most of the time, modelClaim is enough. + // ModelClaim and multiModelsClaim are exclusive configured. // +optional ModelClaim *coreapi.ModelClaim `json:"modelClaim,omitempty"` - // MultiModelsClaims represents multiple modelClaim, which is useful when different - // sub-workload has different accelerator requirements, like the state-of-the-art - // technology called splitwise, the workload template is shared by both. - // ModelClaim and multiModelsClaims are exclusive configured. + // MultiModelsClaim represents claiming for multiple models with different claimModes, + // like standard or speculative-decoding to support different inference scenarios. + // ModelClaim and multiModelsClaim are exclusive configured. // +optional - MultiModelsClaims []coreapi.MultiModelsClaim `json:"multiModelsClaims,omitempty"` + MultiModelsClaim *coreapi.MultiModelsClaim `json:"multiModelsClaim,omitempty"` // BackendConfig represents the inference backend configuration // under the hood, e.g. vLLM, which is the default backend. // +optional diff --git a/api/inference/v1alpha1/service_types.go b/api/inference/v1alpha1/service_types.go index 7af3f22a..9ab675b9 100644 --- a/api/inference/v1alpha1/service_types.go +++ b/api/inference/v1alpha1/service_types.go @@ -27,14 +27,9 @@ import ( // Service controller will maintain multi-flavor of workloads with // different accelerators for cost or performance considerations. type ServiceSpec struct { - // MultiModelsClaims represents multiple modelClaim, which is useful when different - // sub-workload has different accelerator requirements, like the state-of-the-art - // technology called splitwise, the workload template is shared by both. - // Most of the time, one modelClaim is enough. 
- // Note: properties (nodeSelectors, resources, e.g.) of the model flavors - // will be applied to the workload if not exist. - // +kubebuilder:validation:MinItems=1 - MultiModelsClaims []coreapi.MultiModelsClaim `json:"multiModelsClaims,omitempty"` + // MultiModelsClaim represents claiming for multiple models with different claimModes, + // like standard or speculative-decoding to support different inference scenarios. + MultiModelsClaim coreapi.MultiModelsClaim `json:"multiModelsClaim,omitempty"` // WorkloadTemplate defines the underlying workload layout and configuration. // Note: the LWS spec might be twisted with various LWS instances to support // accelerator fungibility or other cutting-edge researches. diff --git a/api/inference/v1alpha1/zz_generated.deepcopy.go b/api/inference/v1alpha1/zz_generated.deepcopy.go index d9d94588..cfdad843 100644 --- a/api/inference/v1alpha1/zz_generated.deepcopy.go +++ b/api/inference/v1alpha1/zz_generated.deepcopy.go @@ -166,12 +166,10 @@ func (in *PlaygroundSpec) DeepCopyInto(out *PlaygroundSpec) { *out = new(corev1alpha1.ModelClaim) (*in).DeepCopyInto(*out) } - if in.MultiModelsClaims != nil { - in, out := &in.MultiModelsClaims, &out.MultiModelsClaims - *out = make([]corev1alpha1.MultiModelsClaim, len(*in)) - for i := range *in { - (*in)[i].DeepCopyInto(&(*out)[i]) - } + if in.MultiModelsClaim != nil { + in, out := &in.MultiModelsClaim, &out.MultiModelsClaim + *out = new(corev1alpha1.MultiModelsClaim) + (*in).DeepCopyInto(*out) } if in.BackendConfig != nil { in, out := &in.BackendConfig, &out.BackendConfig @@ -303,13 +301,7 @@ func (in *ServiceList) DeepCopyObject() runtime.Object { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
func (in *ServiceSpec) DeepCopyInto(out *ServiceSpec) { *out = *in - if in.MultiModelsClaims != nil { - in, out := &in.MultiModelsClaims, &out.MultiModelsClaims - *out = make([]corev1alpha1.MultiModelsClaim, len(*in)) - for i := range *in { - (*in)[i].DeepCopyInto(&(*out)[i]) - } - } + in.MultiModelsClaim.DeepCopyInto(&out.MultiModelsClaim) in.WorkloadTemplate.DeepCopyInto(&out.WorkloadTemplate) if in.ElasticConfig != nil { in, out := &in.ElasticConfig, &out.ElasticConfig diff --git a/client-go/applyconfiguration/core/v1alpha1/multimodelsclaim.go b/client-go/applyconfiguration/core/v1alpha1/multimodelsclaim.go index f086f03e..3c6a8bc3 100644 --- a/client-go/applyconfiguration/core/v1alpha1/multimodelsclaim.go +++ b/client-go/applyconfiguration/core/v1alpha1/multimodelsclaim.go @@ -24,9 +24,9 @@ import ( // MultiModelsClaimApplyConfiguration represents an declarative configuration of the MultiModelsClaim type for use // with apply. type MultiModelsClaimApplyConfiguration struct { - ModelNames []v1alpha1.ModelName `json:"modelNames,omitempty"` - InferenceFlavors []v1alpha1.FlavorName `json:"inferenceFlavors,omitempty"` - Rate *int32 `json:"rate,omitempty"` + ModelNames []v1alpha1.ModelName `json:"modelNames,omitempty"` + InferenceMode *v1alpha1.InferenceMode `json:"inferenceMode,omitempty"` + InferenceFlavors []v1alpha1.FlavorName `json:"inferenceFlavors,omitempty"` } // MultiModelsClaimApplyConfiguration constructs an declarative configuration of the MultiModelsClaim type for use with @@ -45,6 +45,14 @@ func (b *MultiModelsClaimApplyConfiguration) WithModelNames(values ...v1alpha1.M return b } +// WithInferenceMode sets the InferenceMode field in the declarative configuration to the given value +// and returns the receiver, so that objects can be built by chaining "With" function invocations. +// If called multiple times, the InferenceMode field is set to the value of the last call. 
+func (b *MultiModelsClaimApplyConfiguration) WithInferenceMode(value v1alpha1.InferenceMode) *MultiModelsClaimApplyConfiguration { + b.InferenceMode = &value + return b +} + // WithInferenceFlavors adds the given value to the InferenceFlavors field in the declarative configuration // and returns the receiver, so that objects can be build by chaining "With" function invocations. // If called multiple times, values provided by each call will be appended to the InferenceFlavors field. @@ -54,11 +62,3 @@ func (b *MultiModelsClaimApplyConfiguration) WithInferenceFlavors(values ...v1al } return b } - -// WithRate sets the Rate field in the declarative configuration to the given value -// and returns the receiver, so that objects can be built by chaining "With" function invocations. -// If called multiple times, the Rate field is set to the value of the last call. -func (b *MultiModelsClaimApplyConfiguration) WithRate(value int32) *MultiModelsClaimApplyConfiguration { - b.Rate = &value - return b -} diff --git a/client-go/applyconfiguration/inference/v1alpha1/playgroundspec.go b/client-go/applyconfiguration/inference/v1alpha1/playgroundspec.go index 02604da3..6c39c925 100644 --- a/client-go/applyconfiguration/inference/v1alpha1/playgroundspec.go +++ b/client-go/applyconfiguration/inference/v1alpha1/playgroundspec.go @@ -24,10 +24,10 @@ import ( // PlaygroundSpecApplyConfiguration represents an declarative configuration of the PlaygroundSpec type for use // with apply. 
type PlaygroundSpecApplyConfiguration struct { - Replicas *int32 `json:"replicas,omitempty"` - ModelClaim *v1alpha1.ModelClaimApplyConfiguration `json:"modelClaim,omitempty"` - MultiModelsClaims []v1alpha1.MultiModelsClaimApplyConfiguration `json:"multiModelsClaims,omitempty"` - BackendConfig *BackendConfigApplyConfiguration `json:"backendConfig,omitempty"` + Replicas *int32 `json:"replicas,omitempty"` + ModelClaim *v1alpha1.ModelClaimApplyConfiguration `json:"modelClaim,omitempty"` + MultiModelsClaim *v1alpha1.MultiModelsClaimApplyConfiguration `json:"multiModelsClaim,omitempty"` + BackendConfig *BackendConfigApplyConfiguration `json:"backendConfig,omitempty"` } // PlaygroundSpecApplyConfiguration constructs an declarative configuration of the PlaygroundSpec type for use with @@ -52,16 +52,11 @@ func (b *PlaygroundSpecApplyConfiguration) WithModelClaim(value *v1alpha1.ModelC return b } -// WithMultiModelsClaims adds the given value to the MultiModelsClaims field in the declarative configuration -// and returns the receiver, so that objects can be build by chaining "With" function invocations. -// If called multiple times, values provided by each call will be appended to the MultiModelsClaims field. -func (b *PlaygroundSpecApplyConfiguration) WithMultiModelsClaims(values ...*v1alpha1.MultiModelsClaimApplyConfiguration) *PlaygroundSpecApplyConfiguration { - for i := range values { - if values[i] == nil { - panic("nil value passed to WithMultiModelsClaims") - } - b.MultiModelsClaims = append(b.MultiModelsClaims, *values[i]) - } +// WithMultiModelsClaim sets the MultiModelsClaim field in the declarative configuration to the given value +// and returns the receiver, so that objects can be built by chaining "With" function invocations. +// If called multiple times, the MultiModelsClaim field is set to the value of the last call. 
+func (b *PlaygroundSpecApplyConfiguration) WithMultiModelsClaim(value *v1alpha1.MultiModelsClaimApplyConfiguration) *PlaygroundSpecApplyConfiguration { + b.MultiModelsClaim = value return b } diff --git a/client-go/applyconfiguration/inference/v1alpha1/servicespec.go b/client-go/applyconfiguration/inference/v1alpha1/servicespec.go index 4095eb91..f31e425f 100644 --- a/client-go/applyconfiguration/inference/v1alpha1/servicespec.go +++ b/client-go/applyconfiguration/inference/v1alpha1/servicespec.go @@ -25,9 +25,9 @@ import ( // ServiceSpecApplyConfiguration represents an declarative configuration of the ServiceSpec type for use // with apply. type ServiceSpecApplyConfiguration struct { - MultiModelsClaims []v1alpha1.MultiModelsClaimApplyConfiguration `json:"multiModelsClaims,omitempty"` - WorkloadTemplate *v1.LeaderWorkerSetSpec `json:"workloadTemplate,omitempty"` - ElasticConfig *ElasticConfigApplyConfiguration `json:"elasticConfig,omitempty"` + MultiModelsClaim *v1alpha1.MultiModelsClaimApplyConfiguration `json:"multiModelsClaim,omitempty"` + WorkloadTemplate *v1.LeaderWorkerSetSpec `json:"workloadTemplate,omitempty"` + ElasticConfig *ElasticConfigApplyConfiguration `json:"elasticConfig,omitempty"` } // ServiceSpecApplyConfiguration constructs an declarative configuration of the ServiceSpec type for use with @@ -36,16 +36,11 @@ func ServiceSpec() *ServiceSpecApplyConfiguration { return &ServiceSpecApplyConfiguration{} } -// WithMultiModelsClaims adds the given value to the MultiModelsClaims field in the declarative configuration -// and returns the receiver, so that objects can be build by chaining "With" function invocations. -// If called multiple times, values provided by each call will be appended to the MultiModelsClaims field. 
-func (b *ServiceSpecApplyConfiguration) WithMultiModelsClaims(values ...*v1alpha1.MultiModelsClaimApplyConfiguration) *ServiceSpecApplyConfiguration { - for i := range values { - if values[i] == nil { - panic("nil value passed to WithMultiModelsClaims") - } - b.MultiModelsClaims = append(b.MultiModelsClaims, *values[i]) - } +// WithMultiModelsClaim sets the MultiModelsClaim field in the declarative configuration to the given value +// and returns the receiver, so that objects can be built by chaining "With" function invocations. +// If called multiple times, the MultiModelsClaim field is set to the value of the last call. +func (b *ServiceSpecApplyConfiguration) WithMultiModelsClaim(value *v1alpha1.MultiModelsClaimApplyConfiguration) *ServiceSpecApplyConfiguration { + b.MultiModelsClaim = value return b } diff --git a/config/crd/bases/inference.llmaz.io_playgrounds.yaml b/config/crd/bases/inference.llmaz.io_playgrounds.yaml index 69991078..766444dd 100644 --- a/config/crd/bases/inference.llmaz.io_playgrounds.yaml +++ b/config/crd/bases/inference.llmaz.io_playgrounds.yaml @@ -45,7 +45,9 @@ spec: under the hood, e.g. vLLM, which is the default backend. properties: args: - description: Args represents the arguments passed to the backend. + description: |- + Args represents the arguments passed to the backend. + You can add new args or overwrite the default args. items: type: string type: array @@ -222,66 +224,59 @@ spec: type: object modelClaim: description: |- - ModelClaim represents one modelClaim, it's a simple configuration - compared to multiModelsClaims only work for one model and one claim. - ModelClaim and multiModelsClaims are exclusive configured. - Note: properties (nodeSelectors, resources, e.g.) of the model flavors - will be applied to the workload if not exist. + ModelClaim represents claiming for one model, it's the standard claimMode + of multiModelsClaim compared to other modes like SpeculativeDecoding. + Most of the time, modelClaim is enough. 
+ ModelClaim and multiModelsClaim are exclusive configured. properties: inferenceFlavors: description: |- - InferenceFlavors represents a list of flavors with fungibility supports - to serve the model. The flavor names should be a subset of the model - configured flavors. If not set, will use the model configured flavors. + InferenceFlavors represents a list of flavors with fungibility support + to serve the model. + If set, The flavor names should be a subset of the model configured flavors. + If not set, Model configured flavors will be used by default. items: type: string type: array modelName: - description: |- - ModelName represents a list of models, there maybe multiple models here - to support state-of-the-art technologies like speculative decoding. + description: ModelName represents the name of the Model. type: string type: object - multiModelsClaims: + multiModelsClaim: description: |- - MultiModelsClaims represents multiple modelClaim, which is useful when different - sub-workload has different accelerator requirements, like the state-of-the-art - technology called splitwise, the workload template is shared by both. - ModelClaim and multiModelsClaims are exclusive configured. - items: - description: |- - MultiModelsClaim represents the references to multiple models. - It's an advanced and more complicated config comparing to modelClaim. - properties: - inferenceFlavors: - description: |- - InferenceFlavors represents a list of flavors with fungibility supported - to serve the model. - - If not set, always apply with the 0-index model by default. - - If set, will lookup the flavor names following the model orders. - items: - type: string - type: array - modelNames: - description: |- - ModelNames represents a list of models, there maybe multiple models here - to support state-of-the-art technologies like speculative decoding. 
- items: - type: string - minItems: 1 - type: array - rate: - description: |- - Rate works only when multiple claims declared, it represents the replicas rates of - the sub-workload, like when claim1.rate:claim2.rate = 1:2 and 3 replicas defined in - workload, then sub-workload1 will have 1 replica, and sub-workload2 will have 2 replicas. - This is mostly designed for state-of-the-art technology called splitwise, the prefill - and decode phase will be separated and requires different accelerators. - The sum of the rates should be divisible by replicas. - format: int32 - type: integer - type: object - type: array + MultiModelsClaim represents claiming for multiple models with different claimModes, + like standard or speculative-decoding to support different inference scenarios. + ModelClaim and multiModelsClaim are exclusive configured. + properties: + inferenceFlavors: + description: |- + InferenceFlavors represents a list of flavors with fungibility supported + to serve the model. + - If not set, always apply with the 0-index model by default. + - If set, will lookup the flavor names following the model orders. + items: + type: string + type: array + inferenceMode: + default: Standard + description: |- + InferenceMode represents the paradigm to serve the model, whether via a standard way + or via an advanced technique like SpeculativeDecoding. + enum: + - Standard + - SpeculativeDecoding + type: string + modelNames: + description: |- + ModelNames represents a list of models, there maybe multiple models here + to support state-of-the-art technologies like speculative decoding. + If the inferenceMode is SpeculativeDecoding, the first model is the target model, + and the second model is the draft model. + items: + type: string + minItems: 1 + type: array + type: object replicas: default: 1 description: Replicas represents the replica number of inference workloads. 
diff --git a/config/crd/bases/inference.llmaz.io_services.yaml b/config/crd/bases/inference.llmaz.io_services.yaml index 17e7ab57..e6bc503d 100644 --- a/config/crd/bases/inference.llmaz.io_services.yaml +++ b/config/crd/bases/inference.llmaz.io_services.yaml @@ -65,49 +65,40 @@ spec: format: int32 type: integer type: object - multiModelsClaims: + multiModelsClaim: description: |- - MultiModelsClaims represents multiple modelClaim, which is useful when different - sub-workload has different accelerator requirements, like the state-of-the-art - technology called splitwise, the workload template is shared by both. - Most of the time, one modelClaim is enough. - Note: properties (nodeSelectors, resources, e.g.) of the model flavors - will be applied to the workload if not exist. - items: - description: |- - MultiModelsClaim represents the references to multiple models. - It's an advanced and more complicated config comparing to modelClaim. - properties: - inferenceFlavors: - description: |- - InferenceFlavors represents a list of flavors with fungibility supported - to serve the model. - - If not set, always apply with the 0-index model by default. - - If set, will lookup the flavor names following the model orders. - items: - type: string - type: array - modelNames: - description: |- - ModelNames represents a list of models, there maybe multiple models here - to support state-of-the-art technologies like speculative decoding. - items: - type: string - minItems: 1 - type: array - rate: - description: |- - Rate works only when multiple claims declared, it represents the replicas rates of - the sub-workload, like when claim1.rate:claim2.rate = 1:2 and 3 replicas defined in - workload, then sub-workload1 will have 1 replica, and sub-workload2 will have 2 replicas. - This is mostly designed for state-of-the-art technology called splitwise, the prefill - and decode phase will be separated and requires different accelerators. 
- The sum of the rates should be divisible by replicas. - format: int32 - type: integer - type: object - minItems: 1 - type: array + MultiModelsClaim represents claiming for multiple models with different claimModes, + like standard or speculative-decoding to support different inference scenarios. + properties: + inferenceFlavors: + description: |- + InferenceFlavors represents a list of flavors with fungibility supported + to serve the model. + - If not set, always apply with the 0-index model by default. + - If set, will lookup the flavor names following the model orders. + items: + type: string + type: array + inferenceMode: + default: Standard + description: |- + InferenceMode represents the paradigm to serve the model, whether via a standard way + or via an advanced technique like SpeculativeDecoding. + enum: + - Standard + - SpeculativeDecoding + type: string + modelNames: + description: |- + ModelNames represents a list of models, there maybe multiple models here + to support state-of-the-art technologies like speculative decoding. + If the inferenceMode is SpeculativeDecoding, the first model is the target model, + and the second model is the draft model. + items: + type: string + minItems: 1 + type: array + type: object workloadTemplate: description: |- WorkloadTemplate defines the underlying workload layout and configuration. diff --git a/config/crd/bases/llmaz.io_openmodels.yaml b/config/crd/bases/llmaz.io_openmodels.yaml index e89cfd6a..7b3f0734 100644 --- a/config/crd/bases/llmaz.io_openmodels.yaml +++ b/config/crd/bases/llmaz.io_openmodels.yaml @@ -47,7 +47,7 @@ spec: inferenceFlavors: description: |- InferenceFlavors represents the accelerator requirements to serve the model. - Flavors are fungible following the priority of slice order. + Flavors are fungible following the priority represented by the slice order. 
items: description: |- Flavor defines the accelerator requirements for a model and the necessary parameters @@ -63,9 +63,9 @@ spec: additionalProperties: type: string description: |- - NodeSelector defines the labels to filter specified nodes, like - cloud-provider.com/accelerator: nvidia-a100. - NodeSelector will be auto injected to the Pods as scheduling primitives. + NodeSelector represents the node candidates for Pod placements, if a node doesn't + meet the nodeSelector, it will be filtered out in the resourceFungibility scheduler plugin. + If nodeSelector is empty, it means every node is a candidate. type: object params: additionalProperties: @@ -97,6 +97,7 @@ spec: required: - name type: object + maxItems: 8 type: array source: description: |- diff --git a/config/manager/kustomization.yaml b/config/manager/kustomization.yaml index beaa32a8..9714fe0d 100644 --- a/config/manager/kustomization.yaml +++ b/config/manager/kustomization.yaml @@ -4,5 +4,5 @@ apiVersion: kustomize.config.k8s.io/v1beta1 kind: Kustomization images: - name: controller - newName: inftyai/llmaz - newTag: main + newName: inftyai/llmaz-test + newTag: 0901-04 diff --git a/docs/assets/.DS_Store b/docs/assets/.DS_Store deleted file mode 100644 index e848b39a..00000000 Binary files a/docs/assets/.DS_Store and /dev/null differ diff --git a/docs/examples/README.md b/docs/examples/README.md index a61f425a..b690c225 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -9,6 +9,7 @@ We provide a set of examples to help you serve large language models, by default - [Deploy models from ObjectStore](#deploy-models-from-objectstore) - [Deploy models via SGLang](#deploy-models-via-sglang) - [Deploy models via llama.cpp](#deploy-models-via-llamacpp) +- [Speculative Decoding with vLLM](#speculative-decoding-with-vllm) ### Deploy models from Huggingface @@ -34,8 +35,12 @@ In theory, if we want to load the `Qwen2-7B` model, which occupies about 14.2 GB ### Deploy models via SGLang -By default, 
we use [vLLM](https://github.com/vllm-project/vllm) as the inference backend, however, if you want to use other backends like [SGLang](https://github.com/sgl-project/sglang), see [examples](./sglang/) here. +By default, we use [vLLM](https://github.com/vllm-project/vllm) as the inference backend, however, if you want to use other backends like [SGLang](https://github.com/sgl-project/sglang), see [example](./sglang/) here. ### Deploy models via llama.cpp -[llama.cpp](https://github.com/ggerganov/llama.cpp) can serve models on a wide variety of hardwares, such as CPU, see [examples](./llamacpp/) here. +[llama.cpp](https://github.com/ggerganov/llama.cpp) can serve models on a wide variety of hardwares, such as CPU, see [example](./llamacpp/) here. + +### Speculative Decoding with vLLM + +[Speculative Decoding](https://arxiv.org/abs/2211.17192) can improve inference performance efficiently, see [example](./speculative-decoding/vllm/) here. diff --git a/docs/examples/llamacpp/model.yaml b/docs/examples/llamacpp/model.yaml index 40e659ad..236f0f44 100644 --- a/docs/examples/llamacpp/model.yaml +++ b/docs/examples/llamacpp/model.yaml @@ -1,7 +1,7 @@ apiVersion: llmaz.io/v1alpha1 kind: OpenModel metadata: - name: qwen2-0-5b-gguf + name: qwen2-0--5b-gguf spec: familyName: qwen2 source: diff --git a/docs/examples/sglang/model.yaml b/docs/examples/sglang/model.yaml index 0658832f..fe0ef7c1 100644 --- a/docs/examples/sglang/model.yaml +++ b/docs/examples/sglang/model.yaml @@ -1,7 +1,7 @@ apiVersion: llmaz.io/v1alpha1 kind: OpenModel metadata: - name: qwen2-05b + name: qwen2-0--5b spec: familyName: qwen2 source: diff --git a/docs/examples/speculative-decoding/vllm/model.yaml b/docs/examples/speculative-decoding/vllm/model.yaml new file mode 100644 index 00000000..35b1e757 --- /dev/null +++ b/docs/examples/speculative-decoding/vllm/model.yaml @@ -0,0 +1,25 @@ +apiVersion: llmaz.io/v1alpha1 +kind: OpenModel +metadata: + name: opt-6--7b +spec: + familyName: opt + source: + 
modelHub: + modelID: facebook/opt-6.7b + inferenceFlavors: + - name: a10 # gpu type + requests: + nvidia.com/gpu: 1 +--- +apiVersion: llmaz.io/v1alpha1 +kind: OpenModel +metadata: + name: opt-125m +spec: + familyName: opt + source: + modelHub: + modelID: facebook/opt-125m + # Draft model's inferenceFlavors will not impact the speculative-decoding, + # only target model will be considered, so we ignore the flavor configurations here. diff --git a/docs/examples/speculative-decoding/vllm/playground.yaml b/docs/examples/speculative-decoding/vllm/playground.yaml new file mode 100644 index 00000000..25bf50d6 --- /dev/null +++ b/docs/examples/speculative-decoding/vllm/playground.yaml @@ -0,0 +1,18 @@ +apiVersion: inference.llmaz.io/v1alpha1 +kind: Playground +metadata: + name: speculator +spec: + replicas: 1 + multiModelsClaim: + inferenceMode: SpeculativeDecoding + modelNames: + - opt-6--7b # the target model, should be the first one + - opt-125m # the draft model + backendConfig: + args: + - --use-v2-block-manager + - -tp + - 1 + - --num_speculative_tokens + - 5 diff --git a/llmaz/README.md b/llmaz/README.md index ab1db060..6369077a 100644 --- a/llmaz/README.md +++ b/llmaz/README.md @@ -1,6 +1,6 @@ -# ModelLoader +# llmaz -ModelLoader maintains the codes to load model weights with various ways, such as from huggingface or from s3. +ModelLoader maintains the codes to load model weights with various ways, such as from Huggingface or from object stores. 
## Load Models From ModelHub diff --git a/pkg/controller/inference/playground_controller.go b/pkg/controller/inference/playground_controller.go index 33dea035..bcdc8538 100644 --- a/pkg/controller/inference/playground_controller.go +++ b/pkg/controller/inference/playground_controller.go @@ -102,17 +102,25 @@ func (r *PlaygroundReconciler) Reconcile(ctx context.Context, req ctrl.Request) var serviceApplyConfiguration *inferenceclientgo.ServiceApplyConfiguration - model := &coreapi.OpenModel{} + models := []*coreapi.OpenModel{} if playground.Spec.ModelClaim != nil { - modelName := playground.Spec.ModelClaim.ModelName - - if err := r.Get(ctx, types.NamespacedName{Name: string(modelName)}, model); err != nil { + model := &coreapi.OpenModel{} + if err := r.Get(ctx, types.NamespacedName{Name: string(playground.Spec.ModelClaim.ModelName)}, model); err != nil { return ctrl.Result{}, err } - serviceApplyConfiguration = buildServiceApplyConfiguration(model, playground) + models = append(models, model) + } else if playground.Spec.MultiModelsClaim != nil { + for _, modelName := range playground.Spec.MultiModelsClaim.ModelNames { + model := &coreapi.OpenModel{} + if err := r.Get(ctx, types.NamespacedName{Name: string(modelName)}, model); err != nil { + return ctrl.Result{}, err + } + models = append(models, model) + } } - // TODO: handle MultiModelsClaims in the future. 
+ serviceApplyConfiguration = buildServiceApplyConfiguration(models, playground) + if err := setControllerReferenceForService(playground, serviceApplyConfiguration, r.Scheme); err != nil { return ctrl.Result{}, err } @@ -182,21 +190,28 @@ func (r *PlaygroundReconciler) SetupWithManager(mgr ctrl.Manager) error { Complete(r) } -func buildServiceApplyConfiguration(model *coreapi.OpenModel, playground *inferenceapi.Playground) *inferenceclientgo.ServiceApplyConfiguration { +func buildServiceApplyConfiguration(models []*coreapi.OpenModel, playground *inferenceapi.Playground) *inferenceclientgo.ServiceApplyConfiguration { // Build metadata serviceApplyConfiguration := inferenceclientgo.Service(playground.Name, playground.Namespace) // Build spec. spec := inferenceclientgo.ServiceSpec() + claim := &coreclientgo.MultiModelsClaimApplyConfiguration{} if playground.Spec.ModelClaim != nil { - claim := coreclientgo.MultiModelsClaim(). + claim = coreclientgo.MultiModelsClaim(). WithModelNames(playground.Spec.ModelClaim.ModelName). - WithInferenceFlavors(playground.Spec.ModelClaim.InferenceFlavors...) - spec.WithMultiModelsClaims(claim) + WithInferenceFlavors(playground.Spec.ModelClaim.InferenceFlavors...). + WithInferenceMode(coreapi.Standard) + } else if playground.Spec.MultiModelsClaim != nil { + claim = coreclientgo.MultiModelsClaim(). + WithModelNames(playground.Spec.MultiModelsClaim.ModelNames...). + WithInferenceFlavors(playground.Spec.MultiModelsClaim.InferenceFlavors...). + WithInferenceMode(playground.Spec.MultiModelsClaim.InferenceMode) } - spec.WithWorkloadTemplate(buildWorkloadTemplate(model, playground)) + spec.WithMultiModelsClaim(claim) + spec.WithWorkloadTemplate(buildWorkloadTemplate(models, playground)) serviceApplyConfiguration.WithSpec(spec) return serviceApplyConfiguration @@ -208,7 +223,7 @@ func buildServiceApplyConfiguration(model *coreapi.OpenModel, playground *infere // to cover both single-host and multi-host cases. 
There're some shortages for lws like can not force rolling // update when one replica failed, we'll fix this in the kubernetes upstream. // Model flavors will not be considered but in inferenceService controller to support accelerator fungibility. -func buildWorkloadTemplate(model *coreapi.OpenModel, playground *inferenceapi.Playground) lws.LeaderWorkerSetSpec { +func buildWorkloadTemplate(models []*coreapi.OpenModel, playground *inferenceapi.Playground) lws.LeaderWorkerSetSpec { // TODO: this should be leaderWorkerSetTemplateSpec, we should support in the lws upstream. workload := lws.LeaderWorkerSetSpec{ // Use the default policy defined in lws. @@ -222,12 +237,12 @@ func buildWorkloadTemplate(model *coreapi.OpenModel, playground *inferenceapi.Pl // TODO: handle multi-host scenarios, e.g. nvidia.com/gpu: 32, means we'll split into 4 hosts. // Do we need another configuration for playground for multi-host use case? I guess no currently. - workload.LeaderWorkerTemplate.WorkerTemplate = buildWorkerTemplate(model, playground) + workload.LeaderWorkerTemplate.WorkerTemplate = buildWorkerTemplate(models, playground) return workload } -func buildWorkerTemplate(model *coreapi.OpenModel, playground *inferenceapi.Playground) corev1.PodTemplateSpec { +func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.Playground) corev1.PodTemplateSpec { backendName := inferenceapi.DefaultBackend if playground.Spec.BackendConfig != nil && playground.Spec.BackendConfig.Name != nil { backendName = *playground.Spec.BackendConfig.Name @@ -239,7 +254,12 @@ func buildWorkerTemplate(model *coreapi.OpenModel, playground *inferenceapi.Play version = *playground.Spec.BackendConfig.Version } - args := bkd.DefaultArgs(model) + mode := coreapi.Standard + if playground.Spec.MultiModelsClaim != nil { + mode = playground.Spec.MultiModelsClaim.InferenceMode + } + + args := bkd.Args(models, mode) var envs []corev1.EnvVar if playground.Spec.BackendConfig != nil { diff --git 
a/pkg/controller/inference/service_controller.go b/pkg/controller/inference/service_controller.go index 3febdff2..1fc0b6d3 100644 --- a/pkg/controller/inference/service_controller.go +++ b/pkg/controller/inference/service_controller.go @@ -80,14 +80,16 @@ func (r *ServiceReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct logger.V(10).Info("reconcile Service", "Playground", klog.KObj(service)) - model := &coreapi.OpenModel{} - // TODO: multiModelsClaim - modelName := service.Spec.MultiModelsClaims[0].ModelNames[0] - if err := r.Get(ctx, types.NamespacedName{Name: string(modelName)}, model); err != nil { - return ctrl.Result{}, err + models := []*coreapi.OpenModel{} + for _, modelName := range service.Spec.MultiModelsClaim.ModelNames { + model := &coreapi.OpenModel{} + if err := r.Get(ctx, types.NamespacedName{Name: string(modelName)}, model); err != nil { + return ctrl.Result{}, err + } + models = append(models, model) } - workloadApplyConfiguration := buildWorkloadApplyConfiguration(service, model) + workloadApplyConfiguration := buildWorkloadApplyConfiguration(service, models) if err := setControllerReferenceForLWS(service, workloadApplyConfiguration, r.Scheme); err != nil { return ctrl.Result{}, err } @@ -127,14 +129,14 @@ func (r *ServiceReconciler) SetupWithManager(mgr ctrl.Manager) error { Complete(r) } -func buildWorkloadApplyConfiguration(service *inferenceapi.Service, model *coreapi.OpenModel) *applyconfigurationv1.LeaderWorkerSetApplyConfiguration { +func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*coreapi.OpenModel) *applyconfigurationv1.LeaderWorkerSetApplyConfiguration { workload := applyconfigurationv1.LeaderWorkerSet(service.Name, service.Namespace) leaderWorkerTemplate := applyconfigurationv1.LeaderWorkerTemplate() leaderWorkerTemplate.WithWorkerTemplate(service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate) // The core logic to inject additional configurations. 
- injectModelProperties(leaderWorkerTemplate, model) + injectModelProperties(leaderWorkerTemplate, models) spec := applyconfigurationv1.LeaderWorkerSetSpec() spec.WithLeaderWorkerTemplate(leaderWorkerTemplate) @@ -144,17 +146,16 @@ func buildWorkloadApplyConfiguration(service *inferenceapi.Service, model *corea return workload } -func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateApplyConfiguration, model *coreapi.OpenModel) { - source := modelSource.NewModelSourceProvider(model) - - template.WorkerTemplate.Labels = util.MergeKVs(template.WorkerTemplate.Labels, modelLabels(model)) - - injectModelLoader(template, source) - injectModelFlavor(template, model) -} +func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateApplyConfiguration, models []*coreapi.OpenModel) { + for i, model := range models { + source := modelSource.NewModelSourceProvider(model) + source.InjectModelLoader(template.WorkerTemplate, i) + } -func injectModelLoader(template *applyconfigurationv1.LeaderWorkerTemplateApplyConfiguration, source modelSource.ModelSourceProvider) { - source.InjectModelLoader(template.WorkerTemplate) + // We treat the 0-index model as the main model, we only consider the main model's requirements, + // like label, flavor. 
+ template.WorkerTemplate.Labels = util.MergeKVs(template.WorkerTemplate.Labels, modelLabels(models[0])) + injectModelFlavor(template, models[0]) } func injectModelFlavor(template *applyconfigurationv1.LeaderWorkerTemplateApplyConfiguration, model *coreapi.OpenModel) { @@ -203,7 +204,6 @@ func injectModelFlavor(template *applyconfigurationv1.LeaderWorkerTemplateApplyC } template.WorkerTemplate.Spec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms = []corev1.NodeSelectorTerm{term} } - } func modelLabels(model *coreapi.OpenModel) map[string]string { diff --git a/pkg/controller_helper/backend/backend.go b/pkg/controller_helper/backend/backend.go index 5b3c7188..36d0db9e 100644 --- a/pkg/controller_helper/backend/backend.go +++ b/pkg/controller_helper/backend/backend.go @@ -38,8 +38,16 @@ type Backend interface { DefaultResources() inferenceapi.ResourceRequirements // DefaultCommands returns the default command to start the inference backend. DefaultCommands() []string - // DefaultArgs returns the default bootstrap arguments to start the backend. - DefaultArgs(*coreapi.OpenModel) []string + // Args returns the bootstrap arguments to start the backend. + Args([]*coreapi.OpenModel, coreapi.InferenceMode) []string + // defaultArgs returns the bootstrap arguments when inferenceMode is standard. + defaultArgs(*coreapi.OpenModel) []string +} + +// SpeculativeBackend represents backend supports speculativeDecoding inferenceMode. +type SpeculativeBackend interface { + // speculativeArgs returns the bootstrap arguments when inferenceMode is speculativeDecoding. 
+ speculativeArgs([]*coreapi.OpenModel) []string } func SwitchBackend(name inferenceapi.BackendName) Backend { diff --git a/pkg/controller_helper/backend/llamacpp.go b/pkg/controller_helper/backend/llamacpp.go index 9e821ad3..dd65ef40 100644 --- a/pkg/controller_helper/backend/llamacpp.go +++ b/pkg/controller_helper/backend/llamacpp.go @@ -64,7 +64,15 @@ func (l *LLAMACPP) DefaultCommands() []string { return []string{"./llama-server"} } -func (l *LLAMACPP) DefaultArgs(model *coreapi.OpenModel) []string { +func (l *LLAMACPP) Args(models []*coreapi.OpenModel, mode coreapi.InferenceMode) []string { + if mode == coreapi.Standard { + return l.defaultArgs(models[0]) + } + // We should not reach here. + return nil +} + +func (l *LLAMACPP) defaultArgs(model *coreapi.OpenModel) []string { source := modelSource.NewModelSourceProvider(model) return []string{ "-m", source.ModelPath(), diff --git a/pkg/controller_helper/backend/sglang.go b/pkg/controller_helper/backend/sglang.go index 21e7bab0..a34fbfb6 100644 --- a/pkg/controller_helper/backend/sglang.go +++ b/pkg/controller_helper/backend/sglang.go @@ -64,7 +64,15 @@ func (s *SGLANG) DefaultCommands() []string { return []string{"python3", "-m", "sglang.launch_server"} } -func (s *SGLANG) DefaultArgs(model *coreapi.OpenModel) []string { +func (s *SGLANG) Args(models []*coreapi.OpenModel, mode coreapi.InferenceMode) []string { + if mode == coreapi.Standard { + return s.defaultArgs(models[0]) + } + // We should not reach here. 
+ return nil +} + +func (s *SGLANG) defaultArgs(model *coreapi.OpenModel) []string { source := modelSource.NewModelSourceProvider(model) return []string{ "--model-path", source.ModelPath(), diff --git a/pkg/controller_helper/backend/vllm.go b/pkg/controller_helper/backend/vllm.go index dda12269..a41ac79a 100644 --- a/pkg/controller_helper/backend/vllm.go +++ b/pkg/controller_helper/backend/vllm.go @@ -28,6 +28,7 @@ import ( ) var _ Backend = (*VLLM)(nil) +var _ SpeculativeBackend = (*VLLM)(nil) type VLLM struct{} @@ -64,11 +65,35 @@ func (v *VLLM) DefaultCommands() []string { return []string{"python3", "-m", "vllm.entrypoints.openai.api_server"} } -func (v *VLLM) DefaultArgs(model *coreapi.OpenModel) []string { +func (v *VLLM) Args(models []*coreapi.OpenModel, mode coreapi.InferenceMode) []string { + if mode == coreapi.Standard { + return v.defaultArgs(models[0]) + } + if mode == coreapi.SpeculativeDecoding { + return v.speculativeArgs(models) + } + // We should not reach here. + return nil +} + +func (v *VLLM) defaultArgs(model *coreapi.OpenModel) []string { source := modelSource.NewModelSourceProvider(model) return []string{ "--model", source.ModelPath(), "--served-model-name", source.ModelName(), + "--host", "0.0.0.0", + "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), + } +} + +func (v *VLLM) speculativeArgs(models []*coreapi.OpenModel) []string { + targetModelSource := modelSource.NewModelSourceProvider(models[0]) + draftModelSource := modelSource.NewModelSourceProvider(models[1]) + return []string{ + "--model", targetModelSource.ModelPath(), + "--speculative_model", draftModelSource.ModelPath(), + "--served-model-name", targetModelSource.ModelName(), + "--host", "0.0.0.0", "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), } } diff --git a/pkg/controller_helper/model_source/modelhub.go b/pkg/controller_helper/model_source/modelhub.go index 3c5ceb05..573c6a0a 100644 --- a/pkg/controller_helper/model_source/modelhub.go +++ 
b/pkg/controller_helper/model_source/modelhub.go @@ -17,6 +17,7 @@ limitations under the License. package modelSource import ( + "strconv" "strings" corev1 "k8s.io/api/core/v1" @@ -54,10 +55,15 @@ func (p *ModelHubProvider) ModelPath() string { return CONTAINER_MODEL_PATH + "models--" + strings.ReplaceAll(p.modelID, "/", "--") } -func (p *ModelHubProvider) InjectModelLoader(template *corev1.PodTemplateSpec) { +func (p *ModelHubProvider) InjectModelLoader(template *corev1.PodTemplateSpec, index int) { + initContainerName := MODEL_LOADER_CONTAINER_NAME + if index != 0 { + initContainerName += "-" + strconv.Itoa(index) + } + // Handle initContainer. initContainer := &corev1.Container{ - Name: MODEL_LOADER_CONTAINER_NAME, + Name: initContainerName, Image: pkg.LOADER_IMAGE, VolumeMounts: []corev1.VolumeMount{ { @@ -110,6 +116,11 @@ func (p *ModelHubProvider) InjectModelLoader(template *corev1.PodTemplateSpec) { ) template.Spec.InitContainers = append(template.Spec.InitContainers, *initContainer) + // Return once not the main model, because all the below has already been injected. + if index != 0 { + return + } + // Handle container. for i := range template.Spec.Containers { diff --git a/pkg/controller_helper/model_source/modelsource.go b/pkg/controller_helper/model_source/modelsource.go index 08d9e937..a32573a9 100644 --- a/pkg/controller_helper/model_source/modelsource.go +++ b/pkg/controller_helper/model_source/modelsource.go @@ -52,7 +52,11 @@ const ( type ModelSourceProvider interface { ModelName() string ModelPath() string - InjectModelLoader(*corev1.PodTemplateSpec) + // InjectModelLoader will inject the model loader to the spec, + // initContainerOnly means whether to inject specs other than initContainers, + // just in case of rewriting the specs, + // index refers to the suffix of the initContainer name, like model-loader, model-loader-1. 
+ InjectModelLoader(spec *corev1.PodTemplateSpec, index int) } func NewModelSourceProvider(model *coreapi.OpenModel) ModelSourceProvider { diff --git a/pkg/controller_helper/model_source/uri.go b/pkg/controller_helper/model_source/uri.go index 36cf7354..092fd1ac 100644 --- a/pkg/controller_helper/model_source/uri.go +++ b/pkg/controller_helper/model_source/uri.go @@ -14,6 +14,7 @@ limitations under the License. package modelSource import ( + "strconv" "strings" corev1 "k8s.io/api/core/v1" @@ -56,10 +57,14 @@ func (p *URIProvider) ModelPath() string { return CONTAINER_MODEL_PATH + "models--" + splits[len(splits)-1] } -func (p *URIProvider) InjectModelLoader(template *corev1.PodTemplateSpec) { +func (p *URIProvider) InjectModelLoader(template *corev1.PodTemplateSpec, index int) { + initContainerName := MODEL_LOADER_CONTAINER_NAME + if index != 0 { + initContainerName += "-" + strconv.Itoa(index) + } // Handle initContainer. initContainer := &corev1.Container{ - Name: MODEL_LOADER_CONTAINER_NAME, + Name: initContainerName, Image: pkg.LOADER_IMAGE, VolumeMounts: []corev1.VolumeMount{ { @@ -107,6 +112,11 @@ func (p *URIProvider) InjectModelLoader(template *corev1.PodTemplateSpec) { template.Spec.InitContainers = append(template.Spec.InitContainers, *initContainer) + // Return once not the main model, because all the below has already been injected. + if index != 0 { + return + } + // Handle container. for i, container := range template.Spec.Containers { diff --git a/pkg/webhook/playground_webhook.go b/pkg/webhook/playground_webhook.go index fb52c9b6..723685ad 100644 --- a/pkg/webhook/playground_webhook.go +++ b/pkg/webhook/playground_webhook.go @@ -52,8 +52,10 @@ func (w *PlaygroundWebhook) Default(ctx context.Context, obj runtime.Object) err var modelName string if playground.Spec.ModelClaim != nil { modelName = string(playground.Spec.ModelClaim.ModelName) + } else if playground.Spec.MultiModelsClaim != nil { + // We choose the first model as the main model. 
+ modelName = string(playground.Spec.MultiModelsClaim.ModelNames[0]) } - // TODO: handle MultiModelsClaims in the future. if playground.Labels == nil { playground.Labels = map[string]string{} @@ -93,8 +95,19 @@ func (w *PlaygroundWebhook) generateValidate(obj runtime.Object) field.ErrorList specPath := field.NewPath("spec") var allErrs field.ErrorList - if playground.Spec.ModelClaim == nil && len(playground.Spec.MultiModelsClaims) == 0 { - allErrs = append(allErrs, field.Forbidden(specPath, "modelClaim and multiModelsClaims couldn't be both empty")) + if playground.Spec.ModelClaim == nil && playground.Spec.MultiModelsClaim == nil { + allErrs = append(allErrs, field.Forbidden(specPath, "modelClaim and multiModelsClaim couldn't be both nil")) + } + if playground.Spec.MultiModelsClaim != nil { + if playground.Spec.MultiModelsClaim.InferenceMode == coreapi.SpeculativeDecoding { + if playground.Spec.BackendConfig != nil && *playground.Spec.BackendConfig.Name != inferenceapi.VLLM { + allErrs = append(allErrs, field.Forbidden(specPath.Child("multiModelsClaim", "inferenceMode"), "only vLLM supports speculativeDecoding mode")) + } + if len(playground.Spec.MultiModelsClaim.ModelNames) != 2 { + allErrs = append(allErrs, field.Forbidden(specPath.Child("multiModelsClaim", "modelNames"), "only two models are allowed in speculativeDecoding mode")) + } + } + } return allErrs } diff --git a/test/integration/controller/inference/playground_test.go b/test/integration/controller/inference/playground_test.go index 0381d15a..50e0d9ec 100644 --- a/test/integration/controller/inference/playground_test.go +++ b/test/integration/controller/inference/playground_test.go @@ -40,6 +40,7 @@ var _ = ginkgo.Describe("playground controller test", func() { // Each test runs in a separate namespace. 
var ns *corev1.Namespace var model *coreapi.OpenModel + var draftModel *coreapi.OpenModel type update struct { playgroundUpdateFn func(*inferenceapi.Playground) @@ -56,10 +57,13 @@ var _ = ginkgo.Describe("playground controller test", func() { gomega.Expect(k8sClient.Create(ctx, ns)).To(gomega.Succeed()) model = util.MockASampleModel() gomega.Expect(k8sClient.Create(ctx, model)).To(gomega.Succeed()) + draftModel = wrapper.MakeModel("llama3-2b").FamilyName("llama3").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("meta-llama/Meta-Llama-3-2B", "").Obj() + gomega.Expect(k8sClient.Create(ctx, draftModel)).To(gomega.Succeed()) }) ginkgo.AfterEach(func() { gomega.Expect(k8sClient.Delete(ctx, ns)).To(gomega.Succeed()) gomega.Expect(k8sClient.Delete(ctx, model)).To(gomega.Succeed()) + gomega.Expect(k8sClient.Delete(ctx, draftModel)).To(gomega.Succeed()) }) type testValidatingCase struct { @@ -139,6 +143,23 @@ var _ = ginkgo.Describe("playground controller test", func() { }, }, }), + ginkgo.Entry("Playground with speculativeDecoding", &testValidatingCase{ + makePlayground: func() *inferenceapi.Playground { + return wrapper.MakePlayground("playground", ns.Name).MultiModelsClaim([]string{model.Name, draftModel.Name}, coreapi.SpeculativeDecoding).Label(coreapi.ModelNameLabelKey, model.Name). 
+ Obj() + }, + updates: []*update{ + { + playgroundUpdateFn: func(playground *inferenceapi.Playground) { + gomega.Expect(k8sClient.Create(ctx, playground)).To(gomega.Succeed()) + }, + checkPlayground: func(ctx context.Context, k8sClient client.Client, playground *inferenceapi.Playground) { + validation.ValidatePlayground(ctx, k8sClient, playground) + validation.ValidatePlaygroundStatusEqualTo(ctx, k8sClient, playground, inferenceapi.PlaygroundProgressing, "Pending", metav1.ConditionTrue) + }, + }, + }, + }), ginkgo.Entry("advance configured Playground with llamacpp", &testValidatingCase{ makePlayground: func() *inferenceapi.Playground { return wrapper.MakePlayground("playground", ns.Name).ModelClaim(model.Name).Label(coreapi.ModelNameLabelKey, model.Name). @@ -166,7 +187,7 @@ var _ = ginkgo.Describe("playground controller test", func() { playgroundUpdateFn: func(playground *inferenceapi.Playground) { // Create a service with the same name as the playground. service := wrapper.MakeService(playground.Name, playground.Namespace). - ModelsClaim([]string{"llama3-8b"}, []string{}, nil). + ModelsClaim([]string{"llama3-8b"}, coreapi.Standard, nil). WorkerTemplate(). Obj() gomega.Expect(k8sClient.Create(ctx, service)).To(gomega.Succeed()) @@ -180,7 +201,7 @@ var _ = ginkgo.Describe("playground controller test", func() { // Delete the service, playground should be updated to Pending. playgroundUpdateFn: func(playground *inferenceapi.Playground) { service := wrapper.MakeService(playground.Name, playground.Namespace). - ModelsClaim([]string{"llama3-8b"}, []string{}, nil). + ModelsClaim([]string{"llama3-8b"}, coreapi.Standard, nil). WorkerTemplate(). 
Obj() gomega.Expect(k8sClient.Delete(ctx, service)).To(gomega.Succeed()) diff --git a/test/integration/controller/inference/service_test.go b/test/integration/controller/inference/service_test.go index cb04d413..f8617aef 100644 --- a/test/integration/controller/inference/service_test.go +++ b/test/integration/controller/inference/service_test.go @@ -65,11 +65,12 @@ var _ = ginkgo.Describe("inferenceService controller test", func() { } }) + // TODO: Add more testCases to cover status update. + type testValidatingCase struct { makeService func() *inferenceapi.Service updates []*update } - // TODO: Add more testCases to cover updating. ginkgo.DescribeTable("test playground creation and update", func(tc *testValidatingCase) { service := tc.makeService() @@ -127,7 +128,26 @@ var _ = ginkgo.Describe("inferenceService controller test", func() { ginkgo.Entry("service created with URI configured Model", &testValidatingCase{ makeService: func() *inferenceapi.Service { return wrapper.MakeService("service-llama3-8b", ns.Name). - ModelsClaim([]string{"model-with-uri"}, []string{}, nil). + ModelsClaim([]string{"model-with-uri"}, coreapi.Standard, nil). + WorkerTemplate(). + Obj() + }, + updates: []*update{ + { + serviceUpdateFn: func(service *inferenceapi.Service) { + gomega.Expect(k8sClient.Create(ctx, service)).To(gomega.Succeed()) + }, + checkService: func(ctx context.Context, k8sClient client.Client, service *inferenceapi.Service) { + validation.ValidateService(ctx, k8sClient, service) + validation.ValidateServiceStatusEqualTo(ctx, k8sClient, service, inferenceapi.ServiceProgressing, "ServiceInProgress", metav1.ConditionTrue) + }, + }, + }, + }), + ginkgo.Entry("service created with speculativeDecoding mode", &testValidatingCase{ + makeService: func() *inferenceapi.Service { + return wrapper.MakeService("service-llama3-8b", ns.Name). + ModelsClaim([]string{"llama3-8b", "model-with-uri"}, coreapi.SpeculativeDecoding, nil). WorkerTemplate(). 
Obj() }, diff --git a/test/integration/webhook/playground_test.go b/test/integration/webhook/playground_test.go index 1bcb1c25..e61f2f62 100644 --- a/test/integration/webhook/playground_test.go +++ b/test/integration/webhook/playground_test.go @@ -87,12 +87,30 @@ var _ = ginkgo.Describe("playground default and validation", func() { }, failed: false, }), + ginkgo.Entry("speculativeDecoding with SGLang is not allowed", &testValidatingCase{ + playground: func() *inferenceapi.Playground { + return wrapper.MakePlayground("playground", ns.Name).Replicas(1).MultiModelsClaim([]string{"llama3-405b", "llama3-8b"}, coreapi.SpeculativeDecoding).Backend(string(inferenceapi.SGLANG)).Obj() + }, + failed: true, + }), + ginkgo.Entry("speculativeDecoding with three models claimed", &testValidatingCase{ + playground: func() *inferenceapi.Playground { + return wrapper.MakePlayground("playground", ns.Name).Replicas(1).MultiModelsClaim([]string{"llama3-405b", "llama3-8b", "llama3-2b"}, coreapi.SpeculativeDecoding).Obj() + }, + failed: true, + }), ginkgo.Entry("unknown backend configured", &testValidatingCase{ playground: func() *inferenceapi.Playground { return wrapper.MakePlayground("playground", ns.Name).Replicas(1).Backend("unknown").Obj() }, failed: true, }), + ginkgo.Entry("unknown inference mode", &testValidatingCase{ + playground: func() *inferenceapi.Playground { + return wrapper.MakePlayground("playground", ns.Name).Replicas(1).MultiModelsClaim([]string{"llama3-405b", "llama3-8b"}, coreapi.InferenceMode("unknown")).Obj() + }, + failed: true, + }), ) type testDefaultingCase struct { @@ -115,5 +133,17 @@ var _ = ginkgo.Describe("playground default and validation", func() { return wrapper.MakePlayground("playground", ns.Name).ModelClaim("llama3-8b").Replicas(1).Label(coreapi.ModelNameLabelKey, "llama3-8b").Obj() }, }), + ginkgo.Entry("defaulting inferenceMode with multiModelsClaim", &testDefaultingCase{ + playground: func() *inferenceapi.Playground { + playground := 
wrapper.MakePlayground("playground", ns.Name).Replicas(1).Obj() + playground.Spec.MultiModelsClaim = &coreapi.MultiModelsClaim{ + ModelNames: []coreapi.ModelName{"llama3-405b", "llama3-8b"}, + } + return playground + }, + wantPlayground: func() *inferenceapi.Playground { + return wrapper.MakePlayground("playground", ns.Name).MultiModelsClaim([]string{"llama3-405b", "llama3-8b"}, coreapi.Standard).Replicas(1).Label(coreapi.ModelNameLabelKey, "llama3-405b").Obj() + }, + }), ) }) diff --git a/test/integration/webhook/service_test.go b/test/integration/webhook/service_test.go index 2fd04e6f..6ed4c476 100644 --- a/test/integration/webhook/service_test.go +++ b/test/integration/webhook/service_test.go @@ -22,6 +22,7 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" "github.com/inftyai/llmaz/test/util" "github.com/inftyai/llmaz/test/util/wrapper" @@ -72,7 +73,7 @@ var _ = ginkgo.Describe("service default and validation", func() { ginkgo.Entry("model-runner container doesn't exist", &testValidatingCase{ service: func() *inferenceapi.Service { return wrapper.MakeService("service-llama3-8b", ns.Name). - ModelsClaim([]string{"llama3-8b"}, []string{}, nil). + ModelsClaim([]string{"llama3-8b"}, coreapi.Standard, nil). WorkerTemplate(). ContainerName("model-runner-fake"). Obj() diff --git a/test/util/mock.go b/test/util/mock.go index d96c7391..91febb94 100644 --- a/test/util/mock.go +++ b/test/util/mock.go @@ -35,7 +35,7 @@ func MockASamplePlayground(ns string) *inferenceapi.Playground { func MockASampleService(ns string) *inferenceapi.Service { return wrapper.MakeService("service-llama3-8b", ns). - ModelsClaim([]string{"llama3-8b"}, []string{}, nil). + ModelsClaim([]string{"llama3-8b"}, coreapi.Standard, nil). WorkerTemplate(). 
Obj() } diff --git a/test/util/validation/validate_playground.go b/test/util/validation/validate_playground.go index 1d159dc5..fbd3c30c 100644 --- a/test/util/validation/validate_playground.go +++ b/test/util/validation/validate_playground.go @@ -37,37 +37,53 @@ import ( "github.com/inftyai/llmaz/test/util/format" ) -func ValidatePlayground(ctx context.Context, k8sClient client.Client, playground *inferenceapi.Playground) { - gomega.Eventually(func() error { - service := inferenceapi.Service{} - if err := k8sClient.Get(ctx, types.NamespacedName{Name: playground.Name, Namespace: playground.Namespace}, &service); err != nil { - return errors.New("failed to get inferenceService") +func validateModelClaim(ctx context.Context, k8sClient client.Client, playground *inferenceapi.Playground, service inferenceapi.Service) error { + model := coreapi.OpenModel{} + + if playground.Spec.ModelClaim != nil { + if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(playground.Spec.ModelClaim.ModelName), Namespace: playground.Namespace}, &model); err != nil { + return errors.New("failed to get model") } - if *playground.Spec.Replicas != *service.Spec.WorkloadTemplate.Replicas { - return fmt.Errorf("expected replicas: %d, got %d", *playground.Spec.Replicas, *service.Spec.WorkloadTemplate.Replicas) + if playground.Spec.ModelClaim.ModelName != service.Spec.MultiModelsClaim.ModelNames[0] { + return fmt.Errorf("expected modelName %s, got %s", playground.Spec.ModelClaim.ModelName, service.Spec.MultiModelsClaim.ModelNames[0]) + } + if diff := cmp.Diff(playground.Spec.ModelClaim.InferenceFlavors, service.Spec.MultiModelsClaim.InferenceFlavors); diff != "" { + return fmt.Errorf("unexpected flavors, want %v, got %v", playground.Spec.ModelClaim.InferenceFlavors, service.Spec.MultiModelsClaim.InferenceFlavors) + } + } else if playground.Spec.MultiModelsClaim != nil { + if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(playground.Spec.MultiModelsClaim.ModelNames[0]), 
Namespace: playground.Namespace}, &model); err != nil { + return errors.New("failed to get model") } + if diff := cmp.Diff(playground.Spec.MultiModelsClaim.ModelNames, service.Spec.MultiModelsClaim.ModelNames); diff != "" { + return fmt.Errorf("expected modelNames, want %s, got %s", playground.Spec.MultiModelsClaim.ModelNames, service.Spec.MultiModelsClaim.ModelNames) + } + if diff := cmp.Diff(playground.Spec.MultiModelsClaim.InferenceFlavors, service.Spec.MultiModelsClaim.InferenceFlavors); diff != "" { + return fmt.Errorf("unexpected flavors, want %v, got %v", playground.Spec.MultiModelsClaim.InferenceFlavors, service.Spec.MultiModelsClaim.InferenceFlavors) + } + } - model := coreapi.OpenModel{} + if playground.Labels[coreapi.ModelNameLabelKey] != model.Name { + return fmt.Errorf("unexpected Playground label value, want %v, got %v", model.Name, playground.Labels[coreapi.ModelNameLabelKey]) + } - if playground.Spec.ModelClaim != nil { - if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(playground.Spec.ModelClaim.ModelName), Namespace: playground.Namespace}, &model); err != nil { - return errors.New("failed to get model") - } + return nil +} - if playground.Spec.ModelClaim.ModelName != service.Spec.MultiModelsClaims[0].ModelNames[0] { - return fmt.Errorf("expected modelName %s, got %s", playground.Spec.ModelClaim.ModelName, service.Spec.MultiModelsClaims[0].ModelNames[0]) - } - if diff := cmp.Diff(playground.Spec.ModelClaim.InferenceFlavors, service.Spec.MultiModelsClaims[0].InferenceFlavors); diff != "" { - return fmt.Errorf("unexpected flavors, want %v, got %v", playground.Spec.ModelClaim.InferenceFlavors, service.Spec.MultiModelsClaims[0].InferenceFlavors) - } +func ValidatePlayground(ctx context.Context, k8sClient client.Client, playground *inferenceapi.Playground) { + gomega.Eventually(func() error { + service := inferenceapi.Service{} + if err := k8sClient.Get(ctx, types.NamespacedName{Name: playground.Name, Namespace: playground.Namespace}, 
&service); err != nil { + return errors.New("failed to get inferenceService") } - if playground.Labels[coreapi.ModelNameLabelKey] != model.Name { - return fmt.Errorf("unexpected Playground label value, want %v, got %v", model.Name, playground.Labels[coreapi.ModelNameLabelKey]) + if err := validateModelClaim(ctx, k8sClient, playground, service); err != nil { + return err } - // TODO: MultiModelsClaim + if *playground.Spec.Replicas != *service.Spec.WorkloadTemplate.Replicas { + return fmt.Errorf("expected replicas: %d, got %d", *playground.Spec.Replicas, *service.Spec.WorkloadTemplate.Replicas) + } backendName := inferenceapi.DefaultBackend if playground.Spec.BackendConfig != nil && playground.Spec.BackendConfig.Name != nil { diff --git a/test/util/validation/validate_service.go b/test/util/validation/validate_service.go index 315e87c4..0dd08fb4 100644 --- a/test/util/validation/validate_service.go +++ b/test/util/validation/validate_service.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "strconv" "github.com/google/go-cmp/cmp" "github.com/onsi/gomega" @@ -47,30 +48,36 @@ func ValidateService(ctx context.Context, k8sClient client.Client, service *infe return fmt.Errorf("unexpected replicas %d, got %d", *service.Spec.WorkloadTemplate.Replicas, *workload.Spec.Replicas) } - // TODO: multiModelsClaim // TODO: multi-host - modelName := string(service.Spec.MultiModelsClaims[0].ModelNames[0]) - model := coreapi.OpenModel{} - if err := k8sClient.Get(ctx, types.NamespacedName{Name: modelName}, &model); err != nil { - return errors.New("failed to get model") + models := []*coreapi.OpenModel{} + modelNames := service.Spec.MultiModelsClaim.ModelNames + for _, modelName := range modelNames { + model := &coreapi.OpenModel{} + if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(modelName)}, model); err != nil { + return errors.New("failed to get model") + } + models = append(models, model) } - if 
workload.Spec.LeaderWorkerTemplate.WorkerTemplate.Labels[coreapi.ModelNameLabelKey] != model.Name { - return fmt.Errorf("unexpected model name %s in template, want %s", workload.Labels[coreapi.ModelNameLabelKey], model.Name) - } - if workload.Spec.LeaderWorkerTemplate.WorkerTemplate.Labels[coreapi.ModelFamilyNameLabelKey] != string(model.Spec.FamilyName) { - return fmt.Errorf("unexpected model family name %s in template, want %s", workload.Spec.LeaderWorkerTemplate.WorkerTemplate.Labels[coreapi.ModelFamilyNameLabelKey], model.Spec.FamilyName) + for index, model := range models { + // Validate injecting modelLoaders + if err := ValidateModelLoader(model, index, &workload, service); err != nil { + return err + } } - // Validate injecting modelLoaders - if err := ValidateModelLoader(&model, &workload, service); err != nil { - return err + mainModel := models[0] + if workload.Spec.LeaderWorkerTemplate.WorkerTemplate.Labels[coreapi.ModelNameLabelKey] != mainModel.Name { + return fmt.Errorf("unexpected model name %s in template, want %s", workload.Labels[coreapi.ModelNameLabelKey], mainModel.Name) + } + if workload.Spec.LeaderWorkerTemplate.WorkerTemplate.Labels[coreapi.ModelFamilyNameLabelKey] != string(mainModel.Spec.FamilyName) { + return fmt.Errorf("unexpected model family name %s in template, want %s", workload.Spec.LeaderWorkerTemplate.WorkerTemplate.Labels[coreapi.ModelFamilyNameLabelKey], mainModel.Spec.FamilyName) } // Validate injecting flavors. 
- if len(model.Spec.InferenceFlavors) != 0 { - if err := ValidateModelFlavor(&model, &workload); err != nil { + if len(mainModel.Spec.InferenceFlavors) != 0 { + if err := ValidateModelFlavor(mainModel, &workload); err != nil { return err } } @@ -79,15 +86,19 @@ func ValidateService(ctx context.Context, k8sClient client.Client, service *infe }, util.IntegrationTimeout, util.Interval).Should(gomega.Succeed()) } -func ValidateModelLoader(model *coreapi.OpenModel, workload *lws.LeaderWorkerSet, service *inferenceapi.Service) error { +func ValidateModelLoader(model *coreapi.OpenModel, index int, workload *lws.LeaderWorkerSet, service *inferenceapi.Service) error { if model.Spec.Source.ModelHub != nil || model.Spec.Source.URI != nil { - if len(workload.Spec.LeaderWorkerTemplate.WorkerTemplate.Spec.InitContainers) == 0 { + if len(workload.Spec.LeaderWorkerTemplate.WorkerTemplate.Spec.InitContainers) <= index { return errors.New("no initContainer configured") } - initContainer := workload.Spec.LeaderWorkerTemplate.WorkerTemplate.Spec.InitContainers[0] + initContainer := workload.Spec.LeaderWorkerTemplate.WorkerTemplate.Spec.InitContainers[index] - if initContainer.Name != modelSource.MODEL_LOADER_CONTAINER_NAME { + containerName := modelSource.MODEL_LOADER_CONTAINER_NAME + if index != 0 { + containerName += "-" + strconv.Itoa(index) + } + if initContainer.Name != containerName { - return fmt.Errorf("unexpected initContainer name, want %s, got %s", modelSource.MODEL_LOADER_CONTAINER_NAME, initContainer.Name) + return fmt.Errorf("unexpected initContainer name, want %s, got %s", containerName, initContainer.Name) } if initContainer.Image != pkg.LOADER_IMAGE { diff --git a/test/util/wrapper/playground.go b/test/util/wrapper/playground.go index 81123625..5160a0cb 100644 --- a/test/util/wrapper/playground.go +++ b/test/util/wrapper/playground.go @@ -71,6 +71,28 @@ func (w *PlaygroundWrapper) ModelClaim(modelName string, flavorNames ...string) return w } +// MultiModelsClaim sets Spec.MultiModelsClaim from the given model names and inference mode; flavor names are attached to that same claim only when at least one is provided. +func (w *PlaygroundWrapper) MultiModelsClaim(modelNames []string, mode coreapi.InferenceMode, flavorNames ...string) *PlaygroundWrapper { + mNames := []coreapi.ModelName{} + for _, name := range modelNames { + mNames = 
append(mNames, coreapi.ModelName(name)) + } + + fNames := []coreapi.FlavorName{} + for _, name := range flavorNames { + fNames = append(fNames, coreapi.FlavorName(name)) + } + w.Spec.MultiModelsClaim = &coreapi.MultiModelsClaim{ + InferenceMode: mode, + ModelNames: mNames, + } + + if len(fNames) > 0 { + w.Spec.MultiModelsClaim.InferenceFlavors = fNames + } + return w +} + func (w *PlaygroundWrapper) Backend(name string) *PlaygroundWrapper { if w.Spec.BackendConfig == nil { w.Spec.BackendConfig = &inferenceapi.BackendConfig{} diff --git a/test/util/wrapper/service.go b/test/util/wrapper/service.go index 16f898ed..512f074d 100644 --- a/test/util/wrapper/service.go +++ b/test/util/wrapper/service.go @@ -45,7 +45,7 @@ func (w *ServiceWrapper) Obj() *inferenceapi.Service { return &w.Service } -func (w *ServiceWrapper) ModelsClaim(modelNames []string, flavorNames []string, rate *int32) *ServiceWrapper { +func (w *ServiceWrapper) ModelsClaim(modelNames []string, mode coreapi.InferenceMode, flavorNames []string) *ServiceWrapper { names := []coreapi.ModelName{} for i := range modelNames { names = append(names, coreapi.ModelName(modelNames[i])) @@ -54,11 +54,11 @@ func (w *ServiceWrapper) ModelsClaim(modelNames []string, flavorNames []string, for i := range flavorNames { flavors = append(flavors, coreapi.FlavorName(flavorNames[i])) } - w.Spec.MultiModelsClaims = append(w.Spec.MultiModelsClaims, coreapi.MultiModelsClaim{ + w.Spec.MultiModelsClaim = coreapi.MultiModelsClaim{ ModelNames: names, + InferenceMode: mode, InferenceFlavors: flavors, - Rate: rate, - }) + } return w }