diff --git a/api/core/v1alpha1/model_types.go b/api/core/v1alpha1/model_types.go
index e0ca3d62..ec16682c 100644
--- a/api/core/v1alpha1/model_types.go
+++ b/api/core/v1alpha1/model_types.go
@@ -120,28 +120,35 @@ type ModelClaim struct {
InferenceFlavors []FlavorName `json:"inferenceFlavors,omitempty"`
}
-type InferenceMode string
+type ModelRole string
const (
- Standard InferenceMode = "Standard"
- SpeculativeDecoding InferenceMode = "SpeculativeDecoding"
+ // Main represents the main model, if only one model is required,
+ // it must be the main model. Only one main model is allowed.
+ MainRole ModelRole = "main"
+ // Draft represents the draft model in speculative decoding,
+ // the main model is the target model then.
+ DraftRole ModelRole = "draft"
)
-// MultiModelsClaim represents claiming for multiple models with different claimModes,
-// like standard or speculative-decoding to support different inference scenarios.
-type MultiModelsClaim struct {
- // ModelNames represents a list of models, there maybe multiple models here
- // to support state-of-the-art technologies like speculative decoding.
- // If the composedMode is SpeculativeDecoding, the first model is the target model,
- // and the second model is the draft model.
- // +kubebuilder:validation:MinItems=1
- ModelNames []ModelName `json:"modelNames,omitempty"`
- // Mode represents the paradigm to serve the model, whether via a standard way
- // or via an advanced technique like SpeculativeDecoding.
- // +kubebuilder:default=Standard
- // +kubebuilder:validation:Enum={Standard,SpeculativeDecoding}
+type ModelRepresentative struct {
+ // Name represents the model name.
+ Name ModelName `json:"name"`
+ // Role represents the model role once more than one model is required.
+ // +kubebuilder:validation:Enum={main,draft}
+ // +kubebuilder:default=main
// +optional
- InferenceMode InferenceMode `json:"inferenceMode,omitempty"`
+ Role *ModelRole `json:"role,omitempty"`
+}
+
+// ModelClaims represents multiple claims for different models.
+type ModelClaims struct {
+ // Models represents a list of models with roles specified, there maybe
+ // multiple models here to support state-of-the-art technologies like
+ // speculative decoding, then one model is main(target) model, another one
+ // is draft model.
+ // +kubebuilder:validation:MinItems=1
+ Models []ModelRepresentative `json:"models,omitempty"`
// InferenceFlavors represents a list of flavors with fungibility supported
// to serve the model.
// - If not set, always apply with the 0-index model by default.
diff --git a/api/core/v1alpha1/zz_generated.deepcopy.go b/api/core/v1alpha1/zz_generated.deepcopy.go
index 8ad44d3e..241c4c50 100644
--- a/api/core/v1alpha1/zz_generated.deepcopy.go
+++ b/api/core/v1alpha1/zz_generated.deepcopy.go
@@ -82,6 +82,33 @@ func (in *ModelClaim) DeepCopy() *ModelClaim {
return out
}
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *ModelClaims) DeepCopyInto(out *ModelClaims) {
+ *out = *in
+ if in.Models != nil {
+ in, out := &in.Models, &out.Models
+ *out = make([]ModelRepresentative, len(*in))
+ for i := range *in {
+ (*in)[i].DeepCopyInto(&(*out)[i])
+ }
+ }
+ if in.InferenceFlavors != nil {
+ in, out := &in.InferenceFlavors, &out.InferenceFlavors
+ *out = make([]FlavorName, len(*in))
+ copy(*out, *in)
+ }
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelClaims.
+func (in *ModelClaims) DeepCopy() *ModelClaims {
+ if in == nil {
+ return nil
+ }
+ out := new(ModelClaims)
+ in.DeepCopyInto(out)
+ return out
+}
+
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *ModelHub) DeepCopyInto(out *ModelHub) {
*out = *in
@@ -112,6 +139,26 @@ func (in *ModelHub) DeepCopy() *ModelHub {
return out
}
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *ModelRepresentative) DeepCopyInto(out *ModelRepresentative) {
+ *out = *in
+ if in.Role != nil {
+ in, out := &in.Role, &out.Role
+ *out = new(ModelRole)
+ **out = **in
+ }
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelRepresentative.
+func (in *ModelRepresentative) DeepCopy() *ModelRepresentative {
+ if in == nil {
+ return nil
+ }
+ out := new(ModelRepresentative)
+ in.DeepCopyInto(out)
+ return out
+}
+
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *ModelSource) DeepCopyInto(out *ModelSource) {
*out = *in
@@ -182,31 +229,6 @@ func (in *ModelStatus) DeepCopy() *ModelStatus {
return out
}
-// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
-func (in *MultiModelsClaim) DeepCopyInto(out *MultiModelsClaim) {
- *out = *in
- if in.ModelNames != nil {
- in, out := &in.ModelNames, &out.ModelNames
- *out = make([]ModelName, len(*in))
- copy(*out, *in)
- }
- if in.InferenceFlavors != nil {
- in, out := &in.InferenceFlavors, &out.InferenceFlavors
- *out = make([]FlavorName, len(*in))
- copy(*out, *in)
- }
-}
-
-// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MultiModelsClaim.
-func (in *MultiModelsClaim) DeepCopy() *MultiModelsClaim {
- if in == nil {
- return nil
- }
- out := new(MultiModelsClaim)
- in.DeepCopyInto(out)
- return out
-}
-
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *OpenModel) DeepCopyInto(out *OpenModel) {
*out = *in
diff --git a/api/inference/v1alpha1/playground_types.go b/api/inference/v1alpha1/playground_types.go
index 792fe421..2d0dd436 100644
--- a/api/inference/v1alpha1/playground_types.go
+++ b/api/inference/v1alpha1/playground_types.go
@@ -28,17 +28,16 @@ type PlaygroundSpec struct {
// +kubebuilder:default=1
// +optional
Replicas *int32 `json:"replicas,omitempty"`
- // ModelClaim represents claiming for one model, it's the standard claimMode
- // of multiModelsClaim compared to other modes like SpeculativeDecoding.
- // Most of the time, modelClaim is enough.
- // ModelClaim and multiModelsClaim are exclusive configured.
+ // ModelClaim represents claiming for one model, it's a simplified use case
+ // of modelClaims. Most of the time, modelClaim is enough.
+ // ModelClaim and modelClaims are exclusive configured.
// +optional
ModelClaim *coreapi.ModelClaim `json:"modelClaim,omitempty"`
- // MultiModelsClaim represents claiming for multiple models with different claimModes,
- // like standard or speculative-decoding to support different inference scenarios.
- // ModelClaim and multiModelsClaim are exclusive configured.
+ // ModelClaims represents claiming for multiple models for more complicated
+ // use cases like speculative-decoding.
+ // ModelClaims and modelClaim are exclusive configured.
// +optional
- MultiModelsClaim *coreapi.MultiModelsClaim `json:"multiModelsClaim,omitempty"`
+ ModelClaims *coreapi.ModelClaims `json:"modelClaims,omitempty"`
// BackendConfig represents the inference backend configuration
// under the hood, e.g. vLLM, which is the default backend.
// +optional
diff --git a/api/inference/v1alpha1/service_types.go b/api/inference/v1alpha1/service_types.go
index 9ab675b9..7de6087d 100644
--- a/api/inference/v1alpha1/service_types.go
+++ b/api/inference/v1alpha1/service_types.go
@@ -27,9 +27,8 @@ import (
// Service controller will maintain multi-flavor of workloads with
// different accelerators for cost or performance considerations.
type ServiceSpec struct {
- // MultiModelsClaim represents claiming for multiple models with different claimModes,
- // like standard or speculative-decoding to support different inference scenarios.
- MultiModelsClaim coreapi.MultiModelsClaim `json:"multiModelsClaim,omitempty"`
+ // ModelClaims represents multiple claims for different models.
+ ModelClaims coreapi.ModelClaims `json:"modelClaims,omitempty"`
// WorkloadTemplate defines the underlying workload layout and configuration.
// Note: the LWS spec might be twisted with various LWS instances to support
// accelerator fungibility or other cutting-edge researches.
diff --git a/api/inference/v1alpha1/zz_generated.deepcopy.go b/api/inference/v1alpha1/zz_generated.deepcopy.go
index cfdad843..dd373e47 100644
--- a/api/inference/v1alpha1/zz_generated.deepcopy.go
+++ b/api/inference/v1alpha1/zz_generated.deepcopy.go
@@ -166,9 +166,9 @@ func (in *PlaygroundSpec) DeepCopyInto(out *PlaygroundSpec) {
*out = new(corev1alpha1.ModelClaim)
(*in).DeepCopyInto(*out)
}
- if in.MultiModelsClaim != nil {
- in, out := &in.MultiModelsClaim, &out.MultiModelsClaim
- *out = new(corev1alpha1.MultiModelsClaim)
+ if in.ModelClaims != nil {
+ in, out := &in.ModelClaims, &out.ModelClaims
+ *out = new(corev1alpha1.ModelClaims)
(*in).DeepCopyInto(*out)
}
if in.BackendConfig != nil {
@@ -301,7 +301,7 @@ func (in *ServiceList) DeepCopyObject() runtime.Object {
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *ServiceSpec) DeepCopyInto(out *ServiceSpec) {
*out = *in
- in.MultiModelsClaim.DeepCopyInto(&out.MultiModelsClaim)
+ in.ModelClaims.DeepCopyInto(&out.ModelClaims)
in.WorkloadTemplate.DeepCopyInto(&out.WorkloadTemplate)
if in.ElasticConfig != nil {
in, out := &in.ElasticConfig, &out.ElasticConfig
diff --git a/client-go/applyconfiguration/core/v1alpha1/modelclaims.go b/client-go/applyconfiguration/core/v1alpha1/modelclaims.go
new file mode 100644
index 00000000..52760ef2
--- /dev/null
+++ b/client-go/applyconfiguration/core/v1alpha1/modelclaims.go
@@ -0,0 +1,58 @@
+/*
+Copyright 2024.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+// Code generated by applyconfiguration-gen. DO NOT EDIT.
+
+package v1alpha1
+
+import (
+ corev1alpha1 "github.com/inftyai/llmaz/api/core/v1alpha1"
+)
+
+// ModelClaimsApplyConfiguration represents an declarative configuration of the ModelClaims type for use
+// with apply.
+type ModelClaimsApplyConfiguration struct {
+ Models []ModelRepresentativeApplyConfiguration `json:"models,omitempty"`
+ InferenceFlavors []corev1alpha1.FlavorName `json:"inferenceFlavors,omitempty"`
+}
+
+// ModelClaimsApplyConfiguration constructs an declarative configuration of the ModelClaims type for use with
+// apply.
+func ModelClaims() *ModelClaimsApplyConfiguration {
+ return &ModelClaimsApplyConfiguration{}
+}
+
+// WithModels adds the given value to the Models field in the declarative configuration
+// and returns the receiver, so that objects can be build by chaining "With" function invocations.
+// If called multiple times, values provided by each call will be appended to the Models field.
+func (b *ModelClaimsApplyConfiguration) WithModels(values ...*ModelRepresentativeApplyConfiguration) *ModelClaimsApplyConfiguration {
+ for i := range values {
+ if values[i] == nil {
+ panic("nil value passed to WithModels")
+ }
+ b.Models = append(b.Models, *values[i])
+ }
+ return b
+}
+
+// WithInferenceFlavors adds the given value to the InferenceFlavors field in the declarative configuration
+// and returns the receiver, so that objects can be build by chaining "With" function invocations.
+// If called multiple times, values provided by each call will be appended to the InferenceFlavors field.
+func (b *ModelClaimsApplyConfiguration) WithInferenceFlavors(values ...corev1alpha1.FlavorName) *ModelClaimsApplyConfiguration {
+ for i := range values {
+ b.InferenceFlavors = append(b.InferenceFlavors, values[i])
+ }
+ return b
+}
diff --git a/client-go/applyconfiguration/core/v1alpha1/modelrepresentative.go b/client-go/applyconfiguration/core/v1alpha1/modelrepresentative.go
new file mode 100644
index 00000000..83477b22
--- /dev/null
+++ b/client-go/applyconfiguration/core/v1alpha1/modelrepresentative.go
@@ -0,0 +1,51 @@
+/*
+Copyright 2024.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+// Code generated by applyconfiguration-gen. DO NOT EDIT.
+
+package v1alpha1
+
+import (
+ v1alpha1 "github.com/inftyai/llmaz/api/core/v1alpha1"
+)
+
+// ModelRepresentativeApplyConfiguration represents an declarative configuration of the ModelRepresentative type for use
+// with apply.
+type ModelRepresentativeApplyConfiguration struct {
+ Name *v1alpha1.ModelName `json:"name,omitempty"`
+ Role *v1alpha1.ModelRole `json:"role,omitempty"`
+}
+
+// ModelRepresentativeApplyConfiguration constructs an declarative configuration of the ModelRepresentative type for use with
+// apply.
+func ModelRepresentative() *ModelRepresentativeApplyConfiguration {
+ return &ModelRepresentativeApplyConfiguration{}
+}
+
+// WithName sets the Name field in the declarative configuration to the given value
+// and returns the receiver, so that objects can be built by chaining "With" function invocations.
+// If called multiple times, the Name field is set to the value of the last call.
+func (b *ModelRepresentativeApplyConfiguration) WithName(value v1alpha1.ModelName) *ModelRepresentativeApplyConfiguration {
+ b.Name = &value
+ return b
+}
+
+// WithRole sets the Role field in the declarative configuration to the given value
+// and returns the receiver, so that objects can be built by chaining "With" function invocations.
+// If called multiple times, the Role field is set to the value of the last call.
+func (b *ModelRepresentativeApplyConfiguration) WithRole(value v1alpha1.ModelRole) *ModelRepresentativeApplyConfiguration {
+ b.Role = &value
+ return b
+}
diff --git a/client-go/applyconfiguration/core/v1alpha1/multimodelsclaim.go b/client-go/applyconfiguration/core/v1alpha1/multimodelsclaim.go
deleted file mode 100644
index 3c6a8bc3..00000000
--- a/client-go/applyconfiguration/core/v1alpha1/multimodelsclaim.go
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
-Copyright 2024.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-// Code generated by applyconfiguration-gen. DO NOT EDIT.
-
-package v1alpha1
-
-import (
- v1alpha1 "github.com/inftyai/llmaz/api/core/v1alpha1"
-)
-
-// MultiModelsClaimApplyConfiguration represents an declarative configuration of the MultiModelsClaim type for use
-// with apply.
-type MultiModelsClaimApplyConfiguration struct {
- ModelNames []v1alpha1.ModelName `json:"modelNames,omitempty"`
- InferenceMode *v1alpha1.InferenceMode `json:"inferenceMode,omitempty"`
- InferenceFlavors []v1alpha1.FlavorName `json:"inferenceFlavors,omitempty"`
-}
-
-// MultiModelsClaimApplyConfiguration constructs an declarative configuration of the MultiModelsClaim type for use with
-// apply.
-func MultiModelsClaim() *MultiModelsClaimApplyConfiguration {
- return &MultiModelsClaimApplyConfiguration{}
-}
-
-// WithModelNames adds the given value to the ModelNames field in the declarative configuration
-// and returns the receiver, so that objects can be build by chaining "With" function invocations.
-// If called multiple times, values provided by each call will be appended to the ModelNames field.
-func (b *MultiModelsClaimApplyConfiguration) WithModelNames(values ...v1alpha1.ModelName) *MultiModelsClaimApplyConfiguration {
- for i := range values {
- b.ModelNames = append(b.ModelNames, values[i])
- }
- return b
-}
-
-// WithInferenceMode sets the InferenceMode field in the declarative configuration to the given value
-// and returns the receiver, so that objects can be built by chaining "With" function invocations.
-// If called multiple times, the InferenceMode field is set to the value of the last call.
-func (b *MultiModelsClaimApplyConfiguration) WithInferenceMode(value v1alpha1.InferenceMode) *MultiModelsClaimApplyConfiguration {
- b.InferenceMode = &value
- return b
-}
-
-// WithInferenceFlavors adds the given value to the InferenceFlavors field in the declarative configuration
-// and returns the receiver, so that objects can be build by chaining "With" function invocations.
-// If called multiple times, values provided by each call will be appended to the InferenceFlavors field.
-func (b *MultiModelsClaimApplyConfiguration) WithInferenceFlavors(values ...v1alpha1.FlavorName) *MultiModelsClaimApplyConfiguration {
- for i := range values {
- b.InferenceFlavors = append(b.InferenceFlavors, values[i])
- }
- return b
-}
diff --git a/client-go/applyconfiguration/inference/v1alpha1/playgroundspec.go b/client-go/applyconfiguration/inference/v1alpha1/playgroundspec.go
index 6c39c925..b9692a36 100644
--- a/client-go/applyconfiguration/inference/v1alpha1/playgroundspec.go
+++ b/client-go/applyconfiguration/inference/v1alpha1/playgroundspec.go
@@ -24,10 +24,10 @@ import (
// PlaygroundSpecApplyConfiguration represents an declarative configuration of the PlaygroundSpec type for use
// with apply.
type PlaygroundSpecApplyConfiguration struct {
- Replicas *int32 `json:"replicas,omitempty"`
- ModelClaim *v1alpha1.ModelClaimApplyConfiguration `json:"modelClaim,omitempty"`
- MultiModelsClaim *v1alpha1.MultiModelsClaimApplyConfiguration `json:"multiModelsClaim,omitempty"`
- BackendConfig *BackendConfigApplyConfiguration `json:"backendConfig,omitempty"`
+ Replicas *int32 `json:"replicas,omitempty"`
+ ModelClaim *v1alpha1.ModelClaimApplyConfiguration `json:"modelClaim,omitempty"`
+ ModelClaims *v1alpha1.ModelClaimsApplyConfiguration `json:"modelClaims,omitempty"`
+ BackendConfig *BackendConfigApplyConfiguration `json:"backendConfig,omitempty"`
}
// PlaygroundSpecApplyConfiguration constructs an declarative configuration of the PlaygroundSpec type for use with
@@ -52,11 +52,11 @@ func (b *PlaygroundSpecApplyConfiguration) WithModelClaim(value *v1alpha1.ModelC
return b
}
-// WithMultiModelsClaim sets the MultiModelsClaim field in the declarative configuration to the given value
+// WithModelClaims sets the ModelClaims field in the declarative configuration to the given value
// and returns the receiver, so that objects can be built by chaining "With" function invocations.
-// If called multiple times, the MultiModelsClaim field is set to the value of the last call.
-func (b *PlaygroundSpecApplyConfiguration) WithMultiModelsClaim(value *v1alpha1.MultiModelsClaimApplyConfiguration) *PlaygroundSpecApplyConfiguration {
- b.MultiModelsClaim = value
+// If called multiple times, the ModelClaims field is set to the value of the last call.
+func (b *PlaygroundSpecApplyConfiguration) WithModelClaims(value *v1alpha1.ModelClaimsApplyConfiguration) *PlaygroundSpecApplyConfiguration {
+ b.ModelClaims = value
return b
}
diff --git a/client-go/applyconfiguration/inference/v1alpha1/servicespec.go b/client-go/applyconfiguration/inference/v1alpha1/servicespec.go
index f31e425f..1ba4aa2a 100644
--- a/client-go/applyconfiguration/inference/v1alpha1/servicespec.go
+++ b/client-go/applyconfiguration/inference/v1alpha1/servicespec.go
@@ -25,9 +25,9 @@ import (
// ServiceSpecApplyConfiguration represents an declarative configuration of the ServiceSpec type for use
// with apply.
type ServiceSpecApplyConfiguration struct {
- MultiModelsClaim *v1alpha1.MultiModelsClaimApplyConfiguration `json:"multiModelsClaim,omitempty"`
- WorkloadTemplate *v1.LeaderWorkerSetSpec `json:"workloadTemplate,omitempty"`
- ElasticConfig *ElasticConfigApplyConfiguration `json:"elasticConfig,omitempty"`
+ ModelClaims *v1alpha1.ModelClaimsApplyConfiguration `json:"modelClaims,omitempty"`
+ WorkloadTemplate *v1.LeaderWorkerSetSpec `json:"workloadTemplate,omitempty"`
+ ElasticConfig *ElasticConfigApplyConfiguration `json:"elasticConfig,omitempty"`
}
// ServiceSpecApplyConfiguration constructs an declarative configuration of the ServiceSpec type for use with
@@ -36,11 +36,11 @@ func ServiceSpec() *ServiceSpecApplyConfiguration {
return &ServiceSpecApplyConfiguration{}
}
-// WithMultiModelsClaim sets the MultiModelsClaim field in the declarative configuration to the given value
+// WithModelClaims sets the ModelClaims field in the declarative configuration to the given value
// and returns the receiver, so that objects can be built by chaining "With" function invocations.
-// If called multiple times, the MultiModelsClaim field is set to the value of the last call.
-func (b *ServiceSpecApplyConfiguration) WithMultiModelsClaim(value *v1alpha1.MultiModelsClaimApplyConfiguration) *ServiceSpecApplyConfiguration {
- b.MultiModelsClaim = value
+// If called multiple times, the ModelClaims field is set to the value of the last call.
+func (b *ServiceSpecApplyConfiguration) WithModelClaims(value *v1alpha1.ModelClaimsApplyConfiguration) *ServiceSpecApplyConfiguration {
+ b.ModelClaims = value
return b
}
diff --git a/client-go/applyconfiguration/utils.go b/client-go/applyconfiguration/utils.go
index 0bb10ec7..1ede1792 100644
--- a/client-go/applyconfiguration/utils.go
+++ b/client-go/applyconfiguration/utils.go
@@ -54,16 +54,18 @@ func ForKind(kind schema.GroupVersionKind) interface{} {
return &applyconfigurationcorev1alpha1.FlavorApplyConfiguration{}
case corev1alpha1.SchemeGroupVersion.WithKind("ModelClaim"):
return &applyconfigurationcorev1alpha1.ModelClaimApplyConfiguration{}
+ case corev1alpha1.SchemeGroupVersion.WithKind("ModelClaims"):
+ return &applyconfigurationcorev1alpha1.ModelClaimsApplyConfiguration{}
case corev1alpha1.SchemeGroupVersion.WithKind("ModelHub"):
return &applyconfigurationcorev1alpha1.ModelHubApplyConfiguration{}
+ case corev1alpha1.SchemeGroupVersion.WithKind("ModelRepresentative"):
+ return &applyconfigurationcorev1alpha1.ModelRepresentativeApplyConfiguration{}
case corev1alpha1.SchemeGroupVersion.WithKind("ModelSource"):
return &applyconfigurationcorev1alpha1.ModelSourceApplyConfiguration{}
case corev1alpha1.SchemeGroupVersion.WithKind("ModelSpec"):
return &applyconfigurationcorev1alpha1.ModelSpecApplyConfiguration{}
case corev1alpha1.SchemeGroupVersion.WithKind("ModelStatus"):
return &applyconfigurationcorev1alpha1.ModelStatusApplyConfiguration{}
- case corev1alpha1.SchemeGroupVersion.WithKind("MultiModelsClaim"):
- return &applyconfigurationcorev1alpha1.MultiModelsClaimApplyConfiguration{}
case corev1alpha1.SchemeGroupVersion.WithKind("OpenModel"):
return &applyconfigurationcorev1alpha1.OpenModelApplyConfiguration{}
diff --git a/config/crd/bases/inference.llmaz.io_playgrounds.yaml b/config/crd/bases/inference.llmaz.io_playgrounds.yaml
index 766444dd..ae35cf74 100644
--- a/config/crd/bases/inference.llmaz.io_playgrounds.yaml
+++ b/config/crd/bases/inference.llmaz.io_playgrounds.yaml
@@ -224,10 +224,9 @@ spec:
type: object
modelClaim:
description: |-
- ModelClaim represents claiming for one model, it's the standard claimMode
- of multiModelsClaim compared to other modes like SpeculativeDecoding.
- Most of the time, modelClaim is enough.
- ModelClaim and multiModelsClaim are exclusive configured.
+ ModelClaim represents claiming for one model, it's a simplified use case
+ of modelClaims. Most of the time, modelClaim is enough.
+ ModelClaim and modelClaims are exclusive configured.
properties:
inferenceFlavors:
description: |-
@@ -242,11 +241,11 @@ spec:
description: ModelName represents the name of the Model.
type: string
type: object
- multiModelsClaim:
+ modelClaims:
description: |-
- MultiModelsClaim represents claiming for multiple models with different claimModes,
- like standard or speculative-decoding to support different inference scenarios.
- ModelClaim and multiModelsClaim are exclusive configured.
+ ModelClaims represents claiming for multiple models for more complicated
+ use cases like speculative-decoding.
+ ModelClaims and modelClaim are exclusive configured.
properties:
inferenceFlavors:
description: |-
@@ -257,23 +256,28 @@ spec:
items:
type: string
type: array
- inferenceMode:
- default: Standard
+ models:
description: |-
- Mode represents the paradigm to serve the model, whether via a standard way
- or via an advanced technique like SpeculativeDecoding.
- enum:
- - Standard
- - SpeculativeDecoding
- type: string
- modelNames:
- description: |-
- ModelNames represents a list of models, there maybe multiple models here
- to support state-of-the-art technologies like speculative decoding.
- If the composedMode is SpeculativeDecoding, the first model is the target model,
- and the second model is the draft model.
+ Models represents a list of models with roles specified, there maybe
+ multiple models here to support state-of-the-art technologies like
+ speculative decoding, then one model is main(target) model, another one
+ is draft model.
items:
- type: string
+ properties:
+ name:
+ description: Name represents the model name.
+ type: string
+ role:
+ default: main
+ description: Role represents the model role once more than
+ one model is required.
+ enum:
+ - main
+ - draft
+ type: string
+ required:
+ - name
+ type: object
minItems: 1
type: array
type: object
diff --git a/config/crd/bases/inference.llmaz.io_services.yaml b/config/crd/bases/inference.llmaz.io_services.yaml
index e6bc503d..f00ce468 100644
--- a/config/crd/bases/inference.llmaz.io_services.yaml
+++ b/config/crd/bases/inference.llmaz.io_services.yaml
@@ -65,10 +65,9 @@ spec:
format: int32
type: integer
type: object
- multiModelsClaim:
- description: |-
- MultiModelsClaim represents claiming for multiple models with different claimModes,
- like standard or speculative-decoding to support different inference scenarios.
+ modelClaims:
+ description: ModelClaims represents multiple claims for different
+ models.
properties:
inferenceFlavors:
description: |-
@@ -79,23 +78,28 @@ spec:
items:
type: string
type: array
- inferenceMode:
- default: Standard
+ models:
description: |-
- Mode represents the paradigm to serve the model, whether via a standard way
- or via an advanced technique like SpeculativeDecoding.
- enum:
- - Standard
- - SpeculativeDecoding
- type: string
- modelNames:
- description: |-
- ModelNames represents a list of models, there maybe multiple models here
- to support state-of-the-art technologies like speculative decoding.
- If the composedMode is SpeculativeDecoding, the first model is the target model,
- and the second model is the draft model.
+ Models represents a list of models with roles specified, there maybe
+ multiple models here to support state-of-the-art technologies like
+ speculative decoding, then one model is main(target) model, another one
+ is draft model.
items:
- type: string
+ properties:
+ name:
+ description: Name represents the model name.
+ type: string
+ role:
+ default: main
+ description: Role represents the model role once more than
+ one model is required.
+ enum:
+ - main
+ - draft
+ type: string
+ required:
+ - name
+ type: object
minItems: 1
type: array
type: object
diff --git a/docs/examples/speculative-decoding/llamacpp/playground.yaml b/docs/examples/speculative-decoding/llamacpp/playground.yaml
index 5ab223eb..e237503e 100644
--- a/docs/examples/speculative-decoding/llamacpp/playground.yaml
+++ b/docs/examples/speculative-decoding/llamacpp/playground.yaml
@@ -1,4 +1,4 @@
-# This is just an example, because it doesn't make any sense
+# This is just an toy example, because it doesn't make any sense
# in real world, drafting tokens for the model with similar size.
apiVersion: inference.llmaz.io/v1alpha1
@@ -7,11 +7,12 @@ metadata:
name: llamacpp-speculator
spec:
replicas: 1
- multiModelsClaim:
- inferenceMode: SpeculativeDecoding
- modelNames:
- - llama2-7b-q8-gguf # the target model, should be the first one
- - llama2-7b-q2-k-gguf # the draft model
+ modelClaims:
+ models:
+ - name: llama2-7b-q8-gguf # the target model
+ role: main
+ - name: llama2-7b-q2-k-gguf # the draft model
+ role: draft
backendConfig:
name: llamacpp
args:
diff --git a/docs/examples/speculative-decoding/vllm/playground.yaml b/docs/examples/speculative-decoding/vllm/playground.yaml
index 40f1c431..152f08d7 100644
--- a/docs/examples/speculative-decoding/vllm/playground.yaml
+++ b/docs/examples/speculative-decoding/vllm/playground.yaml
@@ -4,15 +4,14 @@ metadata:
name: vllm-speculator
spec:
replicas: 1
- multiModelsClaim:
- inferenceMode: SpeculativeDecoding
- modelNames:
- - opt-6--7b # the target model, should be the first one
- - opt-125m # the draft model
+ modelClaims:
+ models:
+ - name: opt-6--7b # the target model
+ role: main
+ - name: opt-125m # the draft model
+ role: draft
backendConfig:
args:
- --use-v2-block-manager
- - -tp
- - 1
- - --num_speculative_tokens
- - 5
+ - --num_speculative_tokens 5
+ - -tp 1
diff --git a/pkg/controller/inference/playground_controller.go b/pkg/controller/inference/playground_controller.go
index 48f0de40..2fbadbb0 100644
--- a/pkg/controller/inference/playground_controller.go
+++ b/pkg/controller/inference/playground_controller.go
@@ -106,10 +106,10 @@ func (r *PlaygroundReconciler) Reconcile(ctx context.Context, req ctrl.Request)
return ctrl.Result{}, err
}
models = append(models, model)
- } else if playground.Spec.MultiModelsClaim != nil {
- for _, modelName := range playground.Spec.MultiModelsClaim.ModelNames {
+ } else if playground.Spec.ModelClaims != nil {
+ for _, mr := range playground.Spec.ModelClaims.Models {
model := &coreapi.OpenModel{}
- if err := r.Get(ctx, types.NamespacedName{Name: string(modelName)}, model); err != nil {
+ if err := r.Get(ctx, types.NamespacedName{Name: string(mr.Name)}, model); err != nil {
if apierrors.IsNotFound(err) && handleUnexpectedCondition(playground, false, false) {
return ctrl.Result{}, r.Client.Status().Update(ctx, playground)
}
@@ -192,20 +192,28 @@ func buildServiceApplyConfiguration(models []*coreapi.OpenModel, playground *inf
// Build spec.
spec := inferenceclientgo.ServiceSpec()
- claim := &coreclientgo.MultiModelsClaimApplyConfiguration{}
+ claim := &coreclientgo.ModelClaimsApplyConfiguration{}
if playground.Spec.ModelClaim != nil {
- claim = coreclientgo.MultiModelsClaim().
- WithModelNames(playground.Spec.ModelClaim.ModelName).
- WithInferenceFlavors(playground.Spec.ModelClaim.InferenceFlavors...).
- WithInferenceMode(coreapi.Standard)
- } else if playground.Spec.MultiModelsClaim != nil {
- claim = coreclientgo.MultiModelsClaim().
- WithModelNames(playground.Spec.MultiModelsClaim.ModelNames...).
- WithInferenceFlavors(playground.Spec.MultiModelsClaim.InferenceFlavors...).
- WithInferenceMode(playground.Spec.MultiModelsClaim.InferenceMode)
+ claim = coreclientgo.ModelClaims().
+ WithModels(coreclientgo.ModelRepresentative().WithName(playground.Spec.ModelClaim.ModelName).WithRole(coreapi.MainRole)).
+ WithInferenceFlavors(playground.Spec.ModelClaim.InferenceFlavors...)
+ } else if playground.Spec.ModelClaims != nil {
+ mrs := []*coreclientgo.ModelRepresentativeApplyConfiguration{}
+ for _, model := range playground.Spec.ModelClaims.Models {
+ role := coreapi.MainRole
+ if model.Role != nil {
+ role = *model.Role
+ }
+ mr := coreclientgo.ModelRepresentative().WithName(model.Name).WithRole(role)
+ mrs = append(mrs, mr)
+ }
+
+ claim = coreclientgo.ModelClaims().
+ WithModels(mrs...).
+ WithInferenceFlavors(playground.Spec.ModelClaims.InferenceFlavors...)
}
- spec.WithMultiModelsClaim(claim)
+ spec.WithModelClaims(claim)
spec.WithWorkloadTemplate(buildWorkloadTemplate(models, playground))
serviceApplyConfiguration.WithSpec(spec)
@@ -237,6 +245,20 @@ func buildWorkloadTemplate(models []*coreapi.OpenModel, playground *inferenceapi
return workload
}
+func involveRole(playground *inferenceapi.Playground) coreapi.ModelRole {
+ if playground.Spec.ModelClaim != nil {
+ return coreapi.MainRole
+ } else if playground.Spec.ModelClaims != nil {
+ for _, mr := range playground.Spec.ModelClaims.Models {
+ if *mr.Role != coreapi.MainRole {
+ return *mr.Role
+ }
+ }
+ }
+
+ return coreapi.MainRole
+}
+
func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.Playground) corev1.PodTemplateSpec {
backendName := inferenceapi.DefaultBackend
if playground.Spec.BackendConfig != nil && playground.Spec.BackendConfig.Name != nil {
@@ -249,12 +271,7 @@ func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.P
version = *playground.Spec.BackendConfig.Version
}
- mode := coreapi.Standard
- if playground.Spec.MultiModelsClaim != nil {
- mode = playground.Spec.MultiModelsClaim.InferenceMode
- }
-
- args := bkd.Args(models, mode)
+ args := bkd.Args(models, involveRole(playground))
var envs []corev1.EnvVar
if playground.Spec.BackendConfig != nil {
@@ -285,7 +302,7 @@ func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.P
Name: modelSource.MODEL_RUNNER_CONTAINER_NAME,
Image: bkd.Image(version),
Resources: resources,
- Command: bkd.Command(),
+ Command: bkd.DefaultCommand(),
Args: args,
Env: envs,
Ports: []corev1.ContainerPort{
diff --git a/pkg/controller/inference/service_controller.go b/pkg/controller/inference/service_controller.go
index 0f1f998e..fb95ec93 100644
--- a/pkg/controller/inference/service_controller.go
+++ b/pkg/controller/inference/service_controller.go
@@ -81,9 +81,9 @@ func (r *ServiceReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct
logger.V(10).Info("reconcile Service", "Playground", klog.KObj(service))
models := []*coreapi.OpenModel{}
- for _, modelName := range service.Spec.MultiModelsClaim.ModelNames {
+ for _, mr := range service.Spec.ModelClaims.Models {
model := &coreapi.OpenModel{}
- if err := r.Get(ctx, types.NamespacedName{Name: string(modelName)}, model); err != nil {
+ if err := r.Get(ctx, types.NamespacedName{Name: string(mr.Name)}, model); err != nil {
return ctrl.Result{}, err
}
models = append(models, model)
@@ -153,7 +153,7 @@ func injectModelProperties(template *applyconfigurationv1.LeaderWorkerTemplateAp
}
// We treat the 0-index model as the main model, we only consider the main model's requirements,
- // like label, flavor.
+ // like label, flavor. Note: this may change in the future, let's see.
template.WorkerTemplate.Labels = util.MergeKVs(template.WorkerTemplate.Labels, modelLabels(models[0]))
injectModelFlavor(template, models[0])
}
diff --git a/pkg/controller_helper/backend/backend.go b/pkg/controller_helper/backend/backend.go
index d5b842bf..249ec9ca 100644
--- a/pkg/controller_helper/backend/backend.go
+++ b/pkg/controller_helper/backend/backend.go
@@ -36,13 +36,11 @@ type Backend interface {
DefaultVersion() string
// DefaultResources returns the default resources set for the container.
DefaultResources() inferenceapi.ResourceRequirements
- // Command returns the command to start the inference backend.
- Command() []string
+ // DefaultCommand returns the command to start the inference backend.
+ DefaultCommand() []string
// Args returns the bootstrap arguments to start the backend.
- Args([]*coreapi.OpenModel, coreapi.InferenceMode) []string
-
- // defaultArgs returns the bootstrap arguments when inferenceMode is standard.
- defaultArgs(*coreapi.OpenModel) []string
+ // The second parameter represents which particular modelRole involved, like draft.
+ Args([]*coreapi.OpenModel, coreapi.ModelRole) []string
}
// SpeculativeBackend represents backend supports speculativeDecoding inferenceMode.
diff --git a/pkg/controller_helper/backend/llamacpp.go b/pkg/controller_helper/backend/llamacpp.go
index e4404aac..cc2de38d 100644
--- a/pkg/controller_helper/backend/llamacpp.go
+++ b/pkg/controller_helper/backend/llamacpp.go
@@ -28,7 +28,6 @@ import (
)
var _ Backend = (*LLAMACPP)(nil)
-var _ SpeculativeBackend = (*LLAMACPP)(nil)
type LLAMACPP struct{}
@@ -61,37 +60,26 @@ func (l *LLAMACPP) DefaultResources() inferenceapi.ResourceRequirements {
}
}
-func (l *LLAMACPP) Command() []string {
+func (l *LLAMACPP) DefaultCommand() []string {
return []string{"./llama-server"}
}
-func (l *LLAMACPP) Args(models []*coreapi.OpenModel, mode coreapi.InferenceMode) []string {
- if mode == coreapi.Standard {
- return l.defaultArgs(models[0])
- }
- if mode == coreapi.SpeculativeDecoding {
- return l.speculativeArgs(models)
- }
- // We should not reach here.
- return nil
-}
+func (l *LLAMACPP) Args(models []*coreapi.OpenModel, involvedRole coreapi.ModelRole) []string {
+ targetModelSource := modelSource.NewModelSourceProvider(models[0])
-func (l *LLAMACPP) defaultArgs(model *coreapi.OpenModel) []string {
- source := modelSource.NewModelSourceProvider(model)
- return []string{
- "-m", source.ModelPath(),
- "--port", strconv.Itoa(DEFAULT_BACKEND_PORT),
- "--host", "0.0.0.0",
+ if involvedRole == coreapi.DraftRole {
+ draftModelSource := modelSource.NewModelSourceProvider(models[1])
+ return []string{
+ "-m", targetModelSource.ModelPath(),
+ "-md", draftModelSource.ModelPath(),
+ "--host", "0.0.0.0",
+ "--port", strconv.Itoa(DEFAULT_BACKEND_PORT),
+ }
}
-}
-func (l *LLAMACPP) speculativeArgs(models []*coreapi.OpenModel) []string {
- targetModelSource := modelSource.NewModelSourceProvider(models[0])
- draftModelSource := modelSource.NewModelSourceProvider(models[1])
return []string{
"-m", targetModelSource.ModelPath(),
- "-md", draftModelSource.ModelPath(),
- "--port", strconv.Itoa(DEFAULT_BACKEND_PORT),
"--host", "0.0.0.0",
+ "--port", strconv.Itoa(DEFAULT_BACKEND_PORT),
}
}
diff --git a/pkg/controller_helper/backend/llamacpp_test.go b/pkg/controller_helper/backend/llamacpp_test.go
index f3382653..b7402597 100644
--- a/pkg/controller_helper/backend/llamacpp_test.go
+++ b/pkg/controller_helper/backend/llamacpp_test.go
@@ -57,41 +57,41 @@ func Test_llamacpp(t *testing.T) {
}
testCases := []struct {
- name string
- mode coreapi.InferenceMode
- wantCommand []string
- wantArgs []string
+ name string
+ involvedRole coreapi.ModelRole
+ wantCommand []string
+ wantArgs []string
}{
{
- name: "standard mode",
- mode: coreapi.Standard,
- wantCommand: []string{"./llama-server"},
+ name: "one main model",
+ involvedRole: coreapi.MainRole,
+ wantCommand: []string{"./llama-server"},
wantArgs: []string{
"-m", "/workspace/models/models--hub--model-1",
- "--port", "8080",
"--host", "0.0.0.0",
+ "--port", "8080",
},
},
{
- name: "speculative decoding",
- mode: coreapi.SpeculativeDecoding,
- wantCommand: []string{"./llama-server"},
+ name: "speculative decoding",
+ involvedRole: coreapi.DraftRole,
+ wantCommand: []string{"./llama-server"},
wantArgs: []string{
"-m", "/workspace/models/models--hub--model-1",
"-md", "/workspace/models/models--hub--model-2",
- "--port", "8080",
"--host", "0.0.0.0",
+ "--port", "8080",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- if diff := cmp.Diff(backend.Command(), tc.wantCommand); diff != "" {
- t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.Command())
+ if diff := cmp.Diff(backend.DefaultCommand(), tc.wantCommand); diff != "" {
+ t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.DefaultCommand())
}
- if diff := cmp.Diff(backend.Args(models, tc.mode), tc.wantArgs); diff != "" {
- t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.mode))
+ if diff := cmp.Diff(backend.Args(models, tc.involvedRole), tc.wantArgs); diff != "" {
+ t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.involvedRole))
}
})
}
diff --git a/pkg/controller_helper/backend/sglang.go b/pkg/controller_helper/backend/sglang.go
index f9463f1c..12a7307b 100644
--- a/pkg/controller_helper/backend/sglang.go
+++ b/pkg/controller_helper/backend/sglang.go
@@ -60,23 +60,21 @@ func (s *SGLANG) DefaultResources() inferenceapi.ResourceRequirements {
}
}
-func (s *SGLANG) Command() []string {
+func (s *SGLANG) DefaultCommand() []string {
return []string{"python3", "-m", "sglang.launch_server"}
}
-func (s *SGLANG) Args(models []*coreapi.OpenModel, mode coreapi.InferenceMode) []string {
- if mode == coreapi.Standard {
- return s.defaultArgs(models[0])
+func (s *SGLANG) Args(models []*coreapi.OpenModel, involvedRole coreapi.ModelRole) []string {
+ targetModelSource := modelSource.NewModelSourceProvider(models[0])
+
+ if involvedRole == coreapi.DraftRole {
+ // TODO: support speculative decoding
+ return nil
}
- // We should not reach here.
- return nil
-}
-func (s *SGLANG) defaultArgs(model *coreapi.OpenModel) []string {
- source := modelSource.NewModelSourceProvider(model)
return []string{
- "--model-path", source.ModelPath(),
- "--served-model-name", source.ModelName(),
+ "--model-path", targetModelSource.ModelPath(),
+ "--served-model-name", targetModelSource.ModelName(),
"--host", "0.0.0.0",
"--port", strconv.Itoa(DEFAULT_BACKEND_PORT),
}
diff --git a/pkg/controller_helper/backend/sglang_test.go b/pkg/controller_helper/backend/sglang_test.go
index ecbf1535..ce1bb50a 100644
--- a/pkg/controller_helper/backend/sglang_test.go
+++ b/pkg/controller_helper/backend/sglang_test.go
@@ -57,15 +57,15 @@ func Test_SGLANG(t *testing.T) {
}
testCases := []struct {
- name string
- mode coreapi.InferenceMode
- wantCommand []string
- wantArgs []string
+ name string
+ involvedRole coreapi.ModelRole
+ wantCommand []string
+ wantArgs []string
}{
{
- name: "standard mode",
- mode: coreapi.Standard,
- wantCommand: []string{"python3", "-m", "sglang.launch_server"},
+ name: "one main model",
+ involvedRole: coreapi.MainRole,
+ wantCommand: []string{"python3", "-m", "sglang.launch_server"},
wantArgs: []string{
"--model-path", "/workspace/models/models--hub--model-1",
"--served-model-name", "model-1",
@@ -74,20 +74,20 @@ func Test_SGLANG(t *testing.T) {
},
},
{
- name: "speculative decoding",
- mode: coreapi.SpeculativeDecoding,
- wantCommand: []string{"python3", "-m", "sglang.launch_server"},
- wantArgs: nil,
+ name: "speculative decoding",
+ involvedRole: coreapi.DraftRole,
+ wantCommand: []string{"python3", "-m", "sglang.launch_server"},
+ wantArgs: nil,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- if diff := cmp.Diff(backend.Command(), tc.wantCommand); diff != "" {
- t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.Command())
+ if diff := cmp.Diff(backend.DefaultCommand(), tc.wantCommand); diff != "" {
+ t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.DefaultCommand())
}
- if diff := cmp.Diff(backend.Args(models, tc.mode), tc.wantArgs); diff != "" {
- t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.mode))
+ if diff := cmp.Diff(backend.Args(models, tc.involvedRole), tc.wantArgs); diff != "" {
+ t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.involvedRole))
}
})
}
diff --git a/pkg/controller_helper/backend/vllm.go b/pkg/controller_helper/backend/vllm.go
index 8334af73..467bfbc4 100644
--- a/pkg/controller_helper/backend/vllm.go
+++ b/pkg/controller_helper/backend/vllm.go
@@ -28,7 +28,6 @@ import (
)
var _ Backend = (*VLLM)(nil)
-var _ SpeculativeBackend = (*VLLM)(nil)
type VLLM struct{}
@@ -61,37 +60,26 @@ func (v *VLLM) DefaultResources() inferenceapi.ResourceRequirements {
}
}
-func (v *VLLM) Command() []string {
+func (v *VLLM) DefaultCommand() []string {
return []string{"python3", "-m", "vllm.entrypoints.openai.api_server"}
}
-func (v *VLLM) Args(models []*coreapi.OpenModel, mode coreapi.InferenceMode) []string {
- if mode == coreapi.Standard {
- return v.defaultArgs(models[0])
- }
- if mode == coreapi.SpeculativeDecoding {
- return v.speculativeArgs(models)
- }
- // We should not reach here.
- return nil
-}
+func (v *VLLM) Args(models []*coreapi.OpenModel, involvedRole coreapi.ModelRole) []string {
+ targetModelSource := modelSource.NewModelSourceProvider(models[0])
-func (v *VLLM) defaultArgs(model *coreapi.OpenModel) []string {
- source := modelSource.NewModelSourceProvider(model)
- return []string{
- "--model", source.ModelPath(),
- "--served-model-name", source.ModelName(),
- "--host", "0.0.0.0",
- "--port", strconv.Itoa(DEFAULT_BACKEND_PORT),
+ if involvedRole == coreapi.DraftRole {
+ draftModelSource := modelSource.NewModelSourceProvider(models[1])
+ return []string{
+ "--model", targetModelSource.ModelPath(),
+ "--speculative_model", draftModelSource.ModelPath(),
+ "--served-model-name", targetModelSource.ModelName(),
+ "--host", "0.0.0.0",
+ "--port", strconv.Itoa(DEFAULT_BACKEND_PORT),
+ }
}
-}
-func (v *VLLM) speculativeArgs(models []*coreapi.OpenModel) []string {
- targetModelSource := modelSource.NewModelSourceProvider(models[0])
- draftModelSource := modelSource.NewModelSourceProvider(models[1])
return []string{
"--model", targetModelSource.ModelPath(),
- "--speculative_model", draftModelSource.ModelPath(),
"--served-model-name", targetModelSource.ModelName(),
"--host", "0.0.0.0",
"--port", strconv.Itoa(DEFAULT_BACKEND_PORT),
diff --git a/pkg/controller_helper/backend/vllm_test.go b/pkg/controller_helper/backend/vllm_test.go
index 7b8d0629..d75fe4e1 100644
--- a/pkg/controller_helper/backend/vllm_test.go
+++ b/pkg/controller_helper/backend/vllm_test.go
@@ -57,15 +57,15 @@ func Test_vllm(t *testing.T) {
}
testCases := []struct {
- name string
- mode coreapi.InferenceMode
- wantCommand []string
- wantArgs []string
+ name string
+ involvedRole coreapi.ModelRole
+ wantCommand []string
+ wantArgs []string
}{
{
- name: "standard mode",
- mode: coreapi.Standard,
- wantCommand: []string{"python3", "-m", "vllm.entrypoints.openai.api_server"},
+ name: "one main model",
+ involvedRole: coreapi.MainRole,
+ wantCommand: []string{"python3", "-m", "vllm.entrypoints.openai.api_server"},
wantArgs: []string{
"--model", "/workspace/models/models--hub--model-1",
"--served-model-name", "model-1",
@@ -74,9 +74,9 @@ func Test_vllm(t *testing.T) {
},
},
{
- name: "speculative decoding",
- mode: coreapi.SpeculativeDecoding,
- wantCommand: []string{"python3", "-m", "vllm.entrypoints.openai.api_server"},
+ name: "speculative decoding",
+ involvedRole: coreapi.DraftRole,
+ wantCommand: []string{"python3", "-m", "vllm.entrypoints.openai.api_server"},
wantArgs: []string{
"--model", "/workspace/models/models--hub--model-1",
"--speculative_model", "/workspace/models/models--hub--model-2",
@@ -89,11 +89,11 @@ func Test_vllm(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- if diff := cmp.Diff(backend.Command(), tc.wantCommand); diff != "" {
- t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.Command())
+ if diff := cmp.Diff(backend.DefaultCommand(), tc.wantCommand); diff != "" {
+ t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.DefaultCommand())
}
- if diff := cmp.Diff(backend.Args(models, tc.mode), tc.wantArgs); diff != "" {
- t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.mode))
+ if diff := cmp.Diff(backend.Args(models, tc.involvedRole), tc.wantArgs); diff != "" {
+ t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.involvedRole))
}
})
}
diff --git a/pkg/controller_helper/model_source/modelsource.go b/pkg/controller_helper/model_source/modelsource.go
index a32573a9..75676e65 100644
--- a/pkg/controller_helper/model_source/modelsource.go
+++ b/pkg/controller_helper/model_source/modelsource.go
@@ -53,8 +53,6 @@ type ModelSourceProvider interface {
ModelName() string
ModelPath() string
// InjectModelLoader will inject the model loader to the spec,
- // initContainerOnly means whether to inject specs other than initContainers,
- // just in case of rewriting the specs,
// index refers to the suffix of the initContainer name, like model-loader, model-loader-1.
InjectModelLoader(spec *corev1.PodTemplateSpec, index int)
}
diff --git a/pkg/webhook/playground_webhook.go b/pkg/webhook/playground_webhook.go
index 2fa8240e..4f42be57 100644
--- a/pkg/webhook/playground_webhook.go
+++ b/pkg/webhook/playground_webhook.go
@@ -52,9 +52,12 @@ func (w *PlaygroundWebhook) Default(ctx context.Context, obj runtime.Object) err
var modelName string
if playground.Spec.ModelClaim != nil {
modelName = string(playground.Spec.ModelClaim.ModelName)
- } else if playground.Spec.MultiModelsClaim != nil {
- // We choose the first model as the main model.
- modelName = string(playground.Spec.MultiModelsClaim.ModelNames[0])
+ } else if playground.Spec.ModelClaims != nil {
+ for _, model := range playground.Spec.ModelClaims.Models {
+ if model.Role == nil || *model.Role == coreapi.MainRole {
+ modelName = string(model.Name)
+ }
+ }
}
if playground.Labels == nil {
@@ -95,22 +98,34 @@ func (w *PlaygroundWebhook) generateValidate(obj runtime.Object) field.ErrorList
specPath := field.NewPath("spec")
var allErrs field.ErrorList
- if playground.Spec.ModelClaim == nil && playground.Spec.MultiModelsClaim == nil {
- allErrs = append(allErrs, field.Forbidden(specPath, "modelClaim and multiModelsClaim couldn't be both nil"))
+ if playground.Spec.ModelClaim == nil && playground.Spec.ModelClaims == nil {
+ allErrs = append(allErrs, field.Forbidden(specPath, "modelClaim and modelClaims couldn't be both nil"))
}
- if playground.Spec.MultiModelsClaim != nil {
- if playground.Spec.MultiModelsClaim.InferenceMode == coreapi.SpeculativeDecoding {
- // if playground.Spec.BackendConfig != nil && !(*playground.Spec.BackendConfig.Name == inferenceapi.VLLM || *playground.Spec.BackendConfig.Name == inferenceapi.LLAMACPP) {
- // allErrs = append(allErrs, field.Forbidden(specPath.Child("multiModelsClaim", "inferenceMode"), "only vLLM and llama.cpp supports speculativeDecoding mode"))
- // }
- if playground.Spec.BackendConfig != nil && *playground.Spec.BackendConfig.Name != inferenceapi.VLLM {
- allErrs = append(allErrs, field.Forbidden(specPath.Child("multiModelsClaim", "inferenceMode"), "only vLLM supports speculativeDecoding mode"))
+ if playground.Spec.ModelClaims != nil {
+ mainModelCount := 0
+ var speculativeDecoding bool
+
+ for _, model := range playground.Spec.ModelClaims.Models {
+ if model.Name == coreapi.ModelName(coreapi.MainRole) {
+ mainModelCount += 1
+ }
+ if *model.Role == coreapi.DraftRole {
+ speculativeDecoding = true
}
- if len(playground.Spec.MultiModelsClaim.ModelNames) != 2 {
- allErrs = append(allErrs, field.Forbidden(specPath.Child("multiModelsClaim", "modelNames"), "only two models are allowed in speculativeDecoding mode"))
+ }
+
+ if speculativeDecoding {
+ if len(playground.Spec.ModelClaims.Models) != 2 {
+ allErrs = append(allErrs, field.Forbidden(specPath.Child("modelClaims", "models"), "only two models are allowed in speculativeDecoding mode"))
+ }
+ if playground.Spec.BackendConfig != nil && *playground.Spec.BackendConfig.Name != inferenceapi.VLLM {
+ allErrs = append(allErrs, field.Forbidden(specPath.Child("backendConfig", "name"), "only vLLM supports speculativeDecoding mode"))
}
}
+ if mainModelCount > 1 {
+ allErrs = append(allErrs, field.Forbidden(specPath.Child("modelClaims", "models"), "only one main model is allowed"))
+ }
}
return allErrs
}
diff --git a/pkg/webhook/service_webhook.go b/pkg/webhook/service_webhook.go
index 3756c96e..fd21b2e5 100644
--- a/pkg/webhook/service_webhook.go
+++ b/pkg/webhook/service_webhook.go
@@ -26,6 +26,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/webhook"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
+ coreapi "github.com/inftyai/llmaz/api/core/v1alpha1"
inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1"
modelSource "github.com/inftyai/llmaz/pkg/controller_helper/model_source"
)
@@ -56,7 +57,7 @@ var _ webhook.CustomValidator = &ServiceWebhook{}
// ValidateCreate implements webhook.Validator so a webhook will be registered for the type
func (w *ServiceWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
- allErrs := field.ErrorList{}
+ allErrs := w.generateValidate(obj)
service := obj.(*inferenceapi.Service)
for _, err := range validation.IsDNS1123Label(service.Name) {
allErrs = append(allErrs, field.Invalid(field.NewPath("metadata.name"), service.Name, err))
@@ -78,10 +79,38 @@ func (w *ServiceWebhook) ValidateCreate(ctx context.Context, obj runtime.Object)
// ValidateUpdate implements webhook.Validator so a webhook will be registered for the type
func (w *ServiceWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) {
- return nil, nil
+ allErrs := w.generateValidate(newObj)
+ return nil, allErrs.ToAggregate()
}
// ValidateDelete implements webhook.Validator so a webhook will be registered for the type
func (w *ServiceWebhook) ValidateDelete(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
return nil, nil
}
+
+func (w *ServiceWebhook) generateValidate(obj runtime.Object) field.ErrorList {
+ service := obj.(*inferenceapi.Service)
+ specPath := field.NewPath("spec")
+ var allErrs field.ErrorList
+
+ mainModelCount := 0
+ var speculativeDecoding bool
+ for _, model := range service.Spec.ModelClaims.Models {
+ if model.Role == nil || *model.Role == coreapi.MainRole {
+ mainModelCount += 1
+ }
+ if model.Role != nil && *model.Role == coreapi.DraftRole {
+ speculativeDecoding = true
+ }
+ }
+
+ if speculativeDecoding {
+ if len(service.Spec.ModelClaims.Models) != 2 {
+ allErrs = append(allErrs, field.Forbidden(specPath.Child("modelClaims", "models"), "only two models are allowed in speculativeDecoding mode"))
+ }
+ if mainModelCount != 1 {
+ allErrs = append(allErrs, field.Forbidden(specPath.Child("modelClaims", "models"), "main model is required"))
+ }
+ }
+ return allErrs
+}
diff --git a/test/integration/controller/inference/playground_test.go b/test/integration/controller/inference/playground_test.go
index 65913390..db6b5662 100644
--- a/test/integration/controller/inference/playground_test.go
+++ b/test/integration/controller/inference/playground_test.go
@@ -182,7 +182,7 @@ var _ = ginkgo.Describe("playground controller test", func() {
}),
ginkgo.Entry("Playground with speculativeDecoding", &testValidatingCase{
makePlayground: func() *inferenceapi.Playground {
- return wrapper.MakePlayground("playground", ns.Name).MultiModelsClaim([]string{model.Name, draftModel.Name}, coreapi.SpeculativeDecoding).Label(coreapi.ModelNameLabelKey, model.Name).
+ return wrapper.MakePlayground("playground", ns.Name).ModelClaims([]string{model.Name, draftModel.Name}, []string{"main", "draft"}).Label(coreapi.ModelNameLabelKey, model.Name).
Obj()
},
updates: []*update{
@@ -242,7 +242,7 @@ var _ = ginkgo.Describe("playground controller test", func() {
updateFunc: func(playground *inferenceapi.Playground) {
// Create a service with the same name as the playground.
service := wrapper.MakeService(playground.Name, playground.Namespace).
- ModelsClaim([]string{"llama3-8b"}, coreapi.Standard, nil).
+ ModelClaims([]string{"llama3-8b"}, []string{"main"}).
WorkerTemplate().
Obj()
gomega.Expect(k8sClient.Create(ctx, service)).To(gomega.Succeed())
@@ -256,7 +256,7 @@ var _ = ginkgo.Describe("playground controller test", func() {
// Delete the service, playground should be updated to Pending.
updateFunc: func(playground *inferenceapi.Playground) {
service := wrapper.MakeService(playground.Name, playground.Namespace).
- ModelsClaim([]string{"llama3-8b"}, coreapi.Standard, nil).
+ ModelClaims([]string{"llama3-8b"}, []string{"main"}).
WorkerTemplate().
Obj()
gomega.Expect(k8sClient.Delete(ctx, service)).To(gomega.Succeed())
diff --git a/test/integration/controller/inference/service_test.go b/test/integration/controller/inference/service_test.go
index 4afaae3d..cb18eb72 100644
--- a/test/integration/controller/inference/service_test.go
+++ b/test/integration/controller/inference/service_test.go
@@ -157,7 +157,7 @@ var _ = ginkgo.Describe("inferenceService controller test", func() {
ginkgo.Entry("service created with URI configured Model", &testValidatingCase{
makeService: func() *inferenceapi.Service {
return wrapper.MakeService("service-llama3-8b", ns.Name).
- ModelsClaim([]string{"model-with-uri"}, coreapi.Standard, nil).
+ ModelClaims([]string{"model-with-uri"}, []string{"main"}).
WorkerTemplate().
Obj()
},
@@ -185,7 +185,7 @@ var _ = ginkgo.Describe("inferenceService controller test", func() {
ginkgo.Entry("service created with speculativeDecoding mode", &testValidatingCase{
makeService: func() *inferenceapi.Service {
return wrapper.MakeService("service-llama3-8b", ns.Name).
- ModelsClaim([]string{"llama3-8b", "model-with-uri"}, coreapi.SpeculativeDecoding, nil).
+ ModelClaims([]string{"llama3-8b", "model-with-uri"}, []string{"main", "draft"}).
WorkerTemplate().
Obj()
},
diff --git a/test/integration/webhook/playground_test.go b/test/integration/webhook/playground_test.go
index e61f2f62..0f244be2 100644
--- a/test/integration/webhook/playground_test.go
+++ b/test/integration/webhook/playground_test.go
@@ -89,13 +89,13 @@ var _ = ginkgo.Describe("playground default and validation", func() {
}),
ginkgo.Entry("speculativeDecoding with SGLang is not allowed", &testValidatingCase{
playground: func() *inferenceapi.Playground {
- return wrapper.MakePlayground("playground", ns.Name).Replicas(1).MultiModelsClaim([]string{"llama3-405b", "llama3-8b"}, coreapi.SpeculativeDecoding).Backend(string(inferenceapi.SGLANG)).Obj()
+ return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaims([]string{"llama3-405b", "llama3-8b"}, []string{"main", "draft"}).Backend(string(inferenceapi.SGLANG)).Obj()
},
failed: true,
}),
- ginkgo.Entry("speculativeDecoding with three models claimed", &testValidatingCase{
+ ginkgo.Entry("speculativeDecoding with three models is not allowed", &testValidatingCase{
playground: func() *inferenceapi.Playground {
- return wrapper.MakePlayground("playground", ns.Name).Replicas(1).MultiModelsClaim([]string{"llama3-405b", "llama3-8b", "llama3-2b"}, coreapi.SpeculativeDecoding).Obj()
+ return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaims([]string{"llama3-405b", "llama3-8b", "llama3-2b"}, []string{"main", "draft", "draft"}).Obj()
},
failed: true,
}),
@@ -105,9 +105,9 @@ var _ = ginkgo.Describe("playground default and validation", func() {
},
failed: true,
}),
- ginkgo.Entry("unknown inference mode", &testValidatingCase{
+ ginkgo.Entry("no main model", &testValidatingCase{
playground: func() *inferenceapi.Playground {
- return wrapper.MakePlayground("playground", ns.Name).Replicas(1).MultiModelsClaim([]string{"llama3-405b", "llama3-8b"}, coreapi.InferenceMode("unknown")).Obj()
+ return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaims([]string{"llama3-8b"}, []string{"draft"}).Obj()
},
failed: true,
}),
@@ -133,16 +133,25 @@ var _ = ginkgo.Describe("playground default and validation", func() {
return wrapper.MakePlayground("playground", ns.Name).ModelClaim("llama3-8b").Replicas(1).Label(coreapi.ModelNameLabelKey, "llama3-8b").Obj()
},
}),
- ginkgo.Entry("defaulting inferenceMode with multiModelsClaim", &testDefaultingCase{
+ ginkgo.Entry("defaulting model role with modelClaims", &testDefaultingCase{
playground: func() *inferenceapi.Playground {
playground := wrapper.MakePlayground("playground", ns.Name).Replicas(1).Obj()
- playground.Spec.MultiModelsClaim = &coreapi.MultiModelsClaim{
- ModelNames: []coreapi.ModelName{"llama3-405b", "llama3-8b"},
+ draftRole := coreapi.DraftRole
+ playground.Spec.ModelClaims = &coreapi.ModelClaims{
+ Models: []coreapi.ModelRepresentative{
+ {
+ Name: "llama3-405b",
+ },
+ {
+ Name: "llama3-8b",
+ Role: &draftRole,
+ },
+ },
}
return playground
},
wantPlayground: func() *inferenceapi.Playground {
- return wrapper.MakePlayground("playground", ns.Name).MultiModelsClaim([]string{"llama3-405b", "llama3-8b"}, coreapi.Standard).Replicas(1).Label(coreapi.ModelNameLabelKey, "llama3-405b").Obj()
+ return wrapper.MakePlayground("playground", ns.Name).ModelClaims([]string{"llama3-405b", "llama3-8b"}, []string{"main", "draft"}).Replicas(1).Label(coreapi.ModelNameLabelKey, "llama3-405b").Obj()
},
}),
)
diff --git a/test/integration/webhook/service_test.go b/test/integration/webhook/service_test.go
index 6ed4c476..abd9ada3 100644
--- a/test/integration/webhook/service_test.go
+++ b/test/integration/webhook/service_test.go
@@ -22,7 +22,6 @@ import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
- coreapi "github.com/inftyai/llmaz/api/core/v1alpha1"
inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1"
"github.com/inftyai/llmaz/test/util"
"github.com/inftyai/llmaz/test/util/wrapper"
@@ -73,12 +72,42 @@ var _ = ginkgo.Describe("service default and validation", func() {
ginkgo.Entry("model-runner container doesn't exist", &testValidatingCase{
service: func() *inferenceapi.Service {
return wrapper.MakeService("service-llama3-8b", ns.Name).
- ModelsClaim([]string{"llama3-8b"}, coreapi.Standard, nil).
+ ModelClaims([]string{"llama3-8b"}, []string{"main"}).
WorkerTemplate().
ContainerName("model-runner-fake").
Obj()
},
failed: true,
}),
+ ginkgo.Entry("speculative-decoding with three models", &testValidatingCase{
+ service: func() *inferenceapi.Service {
+ return wrapper.MakeService("service-llama3-8b", ns.Name).
+ ModelClaims([]string{"llama3-405b", "llama3-8b", "llama3-2b"}, []string{"main", "draft", "draft"}).
+ WorkerTemplate().
+ Obj()
+ },
+ failed: true,
+ }),
+ ginkgo.Entry("modelClaims with nil role", &testValidatingCase{
+ service: func() *inferenceapi.Service {
+ service := wrapper.MakeService("service-llama3-8b", ns.Name).
+ ModelClaims([]string{"llama3-405b", "llama3-8b"}, []string{"main", "draft"}).
+ WorkerTemplate().
+ Obj()
+ // Set the role to nil
+ service.Spec.ModelClaims.Models[0].Role = nil
+ return service
+ },
+ failed: false,
+ }),
+ ginkgo.Entry("no main model", &testValidatingCase{
+ service: func() *inferenceapi.Service {
+ return wrapper.MakeService("service-llama3-8b", ns.Name).
+ ModelClaims([]string{"llama3-8b"}, []string{"draft"}).
+ WorkerTemplate().
+ Obj()
+ },
+ failed: true,
+ }),
)
})
diff --git a/test/util/mock.go b/test/util/mock.go
index 91febb94..7d774b23 100644
--- a/test/util/mock.go
+++ b/test/util/mock.go
@@ -35,7 +35,7 @@ func MockASamplePlayground(ns string) *inferenceapi.Playground {
func MockASampleService(ns string) *inferenceapi.Service {
return wrapper.MakeService("service-llama3-8b", ns).
- ModelsClaim([]string{"llama3-8b"}, coreapi.Standard, nil).
+ ModelClaims([]string{"llama3-8b"}, []string{"main"}).
WorkerTemplate().
Obj()
}
diff --git a/test/util/validation/validate_playground.go b/test/util/validation/validate_playground.go
index f5ec1a48..3a5cea13 100644
--- a/test/util/validation/validate_playground.go
+++ b/test/util/validation/validate_playground.go
@@ -46,21 +46,18 @@ func validateModelClaim(ctx context.Context, k8sClient client.Client, playground
return errors.New("failed to get model")
}
- if playground.Spec.ModelClaim.ModelName != service.Spec.MultiModelsClaim.ModelNames[0] {
- return fmt.Errorf("expected modelName %s, got %s", playground.Spec.ModelClaim.ModelName, service.Spec.MultiModelsClaim.ModelNames[0])
+ if playground.Spec.ModelClaim.ModelName != service.Spec.ModelClaims.Models[0].Name {
+ return fmt.Errorf("expected modelName %s, got %s", playground.Spec.ModelClaim.ModelName, service.Spec.ModelClaims.Models[0].Name)
}
- if diff := cmp.Diff(playground.Spec.ModelClaim.InferenceFlavors, service.Spec.MultiModelsClaim.InferenceFlavors); diff != "" {
- return fmt.Errorf("unexpected flavors, want %v, got %v", playground.Spec.ModelClaim.InferenceFlavors, service.Spec.MultiModelsClaim.InferenceFlavors)
+ if diff := cmp.Diff(playground.Spec.ModelClaim.InferenceFlavors, service.Spec.ModelClaims.InferenceFlavors); diff != "" {
+ return fmt.Errorf("unexpected flavors, want %v, got %v", playground.Spec.ModelClaim.InferenceFlavors, service.Spec.ModelClaims.InferenceFlavors)
}
- } else if playground.Spec.MultiModelsClaim != nil {
- if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(playground.Spec.MultiModelsClaim.ModelNames[0]), Namespace: playground.Namespace}, &model); err != nil {
+ } else if playground.Spec.ModelClaims != nil {
+ if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(playground.Spec.ModelClaims.Models[0].Name), Namespace: playground.Namespace}, &model); err != nil {
return errors.New("failed to get model")
}
- if diff := cmp.Diff(playground.Spec.MultiModelsClaim.ModelNames, service.Spec.MultiModelsClaim.ModelNames); diff != "" {
- return fmt.Errorf("expected modelNames, want %s, got %s", playground.Spec.MultiModelsClaim.ModelNames, service.Spec.MultiModelsClaim.ModelNames)
- }
- if diff := cmp.Diff(playground.Spec.MultiModelsClaim.InferenceFlavors, service.Spec.MultiModelsClaim.InferenceFlavors); diff != "" {
- return fmt.Errorf("unexpected flavors, want %v, got %v", playground.Spec.MultiModelsClaim.InferenceFlavors, service.Spec.MultiModelsClaim.InferenceFlavors)
+ if diff := cmp.Diff(*playground.Spec.ModelClaims, service.Spec.ModelClaims); diff != "" {
+ return fmt.Errorf("expected modelClaims, want %v, got %v", *playground.Spec.ModelClaims, service.Spec.ModelClaims)
}
}
@@ -95,7 +92,7 @@ func ValidatePlayground(ctx context.Context, k8sClient client.Client, playground
if service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Name != modelSource.MODEL_RUNNER_CONTAINER_NAME {
return fmt.Errorf("container name not right, want %s, got %s", modelSource.MODEL_RUNNER_CONTAINER_NAME, service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Name)
}
- if diff := cmp.Diff(bkd.Command(), service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Command); diff != "" {
+ if diff := cmp.Diff(bkd.DefaultCommand(), service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Command); diff != "" {
return errors.New("command not right")
}
if playground.Spec.BackendConfig != nil {
diff --git a/test/util/validation/validate_service.go b/test/util/validation/validate_service.go
index 717465ac..3a834553 100644
--- a/test/util/validation/validate_service.go
+++ b/test/util/validation/validate_service.go
@@ -51,10 +51,9 @@ func ValidateService(ctx context.Context, k8sClient client.Client, service *infe
// TODO: multi-host
models := []*coreapi.OpenModel{}
- modelNames := service.Spec.MultiModelsClaim.ModelNames
- for _, modelName := range modelNames {
+ for _, mr := range service.Spec.ModelClaims.Models {
model := &coreapi.OpenModel{}
- if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(modelName)}, model); err != nil {
+ if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(mr.Name)}, model); err != nil {
return errors.New("failed to get model")
}
models = append(models, model)
diff --git a/test/util/wrapper/playground.go b/test/util/wrapper/playground.go
index 5160a0cb..4f7e7023 100644
--- a/test/util/wrapper/playground.go
+++ b/test/util/wrapper/playground.go
@@ -71,23 +71,22 @@ func (w *PlaygroundWrapper) ModelClaim(modelName string, flavorNames ...string)
return w
}
-func (w *PlaygroundWrapper) MultiModelsClaim(modelNames []string, mode coreapi.InferenceMode, flavorNames ...string) *PlaygroundWrapper {
- mNames := []coreapi.ModelName{}
- for _, name := range modelNames {
- mNames = append(mNames, coreapi.ModelName(name))
+func (w *PlaygroundWrapper) ModelClaims(modelNames []string, roles []string, flavorNames ...string) *PlaygroundWrapper {
+ models := []coreapi.ModelRepresentative{}
+ for i, name := range modelNames {
+ models = append(models, coreapi.ModelRepresentative{Name: coreapi.ModelName(name), Role: (*coreapi.ModelRole)(&roles[i])})
+ }
+ w.Spec.ModelClaims = &coreapi.ModelClaims{
+ Models: models,
}
fNames := []coreapi.FlavorName{}
for _, name := range flavorNames {
fNames = append(fNames, coreapi.FlavorName(name))
}
- w.Spec.MultiModelsClaim = &coreapi.MultiModelsClaim{
- InferenceMode: mode,
- ModelNames: mNames,
- }
if len(fNames) > 0 {
- w.Spec.ModelClaim.InferenceFlavors = fNames
+ w.Spec.ModelClaims.InferenceFlavors = fNames
}
return w
}
diff --git a/test/util/wrapper/service.go b/test/util/wrapper/service.go
index 512f074d..e3d4dc5d 100644
--- a/test/util/wrapper/service.go
+++ b/test/util/wrapper/service.go
@@ -45,19 +45,22 @@ func (w *ServiceWrapper) Obj() *inferenceapi.Service {
return &w.Service
}
-func (w *ServiceWrapper) ModelsClaim(modelNames []string, mode coreapi.InferenceMode, flavorNames []string) *ServiceWrapper {
- names := []coreapi.ModelName{}
- for i := range modelNames {
- names = append(names, coreapi.ModelName(modelNames[i]))
+func (w *ServiceWrapper) ModelClaims(modelNames []string, roles []string, flavorNames ...string) *ServiceWrapper {
+ models := []coreapi.ModelRepresentative{}
+ for i, name := range modelNames {
+ models = append(models, coreapi.ModelRepresentative{Name: coreapi.ModelName(name), Role: (*coreapi.ModelRole)(&roles[i])})
}
- flavors := []coreapi.FlavorName{}
- for i := range flavorNames {
- flavors = append(flavors, coreapi.FlavorName(flavorNames[i]))
+ w.Spec.ModelClaims = coreapi.ModelClaims{
+ Models: models,
}
- w.Spec.MultiModelsClaim = coreapi.MultiModelsClaim{
- ModelNames: names,
- InferenceMode: mode,
- InferenceFlavors: flavors,
+
+ fNames := []coreapi.FlavorName{}
+ for _, name := range flavorNames {
+ fNames = append(fNames, coreapi.FlavorName(name))
+ }
+
+ if len(fNames) > 0 {
+ w.Spec.ModelClaims.InferenceFlavors = fNames
}
return w
}