From e8fea3791fa914ad8b5db45a1fb03f3c6f69a368 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Thu, 12 Sep 2024 17:02:09 +0800 Subject: [PATCH] Refactor BackendRuntime structure Signed-off-by: kerthcet --- .github/workflows/publish-helm-chart.yaml | 12 ------- Makefile | 3 +- api/core/v1alpha1/model_types.go | 8 +++-- api/core/v1alpha1/zz_generated.deepcopy.go | 10 +++--- .../v1alpha1/backendruntime_types.go | 24 ++++++------- chart/Chart.yaml | 4 +-- chart/crds/backendruntime-crd.yaml | 20 +++++++---- chart/crds/playground-crd.yaml | 9 +++-- chart/crds/service-crd.yaml | 9 +++-- chart/templates/backends/llamacpp.yaml | 6 ++-- chart/templates/backends/sglang.yaml | 4 ++- chart/templates/backends/vllm.yaml | 6 ++-- .../core/v1alpha1/modelclaims.go | 6 ++-- .../{modelrepresentative.go => modelrefer.go} | 14 ++++---- client-go/applyconfiguration/utils.go | 4 +-- .../inference.llmaz.io_backendruntimes.yaml | 20 +++++++---- .../bases/inference.llmaz.io_playgrounds.yaml | 9 +++-- .../bases/inference.llmaz.io_services.yaml | 9 +++-- docs/installation.md | 6 ++-- index.yaml | 12 ++++++- .../inference/playground_controller.go | 10 +++--- pkg/controller_helper/backendruntime.go | 13 +++---- pkg/controller_helper/helper.go | 24 +++++++++---- pkg/webhook/backendruntime_webhook.go | 15 +++----- pkg/webhook/playground_webhook.go | 11 ++---- test/config/backends/llamacpp.yaml | 4 +-- test/config/backends/sglang.yaml | 2 +- test/config/backends/vllm.yaml | 4 +-- test/e2e/playground_test.go | 2 +- .../webhook/backendruntime_test.go | 36 +++++++++---------- test/integration/webhook/playground_test.go | 8 +---- test/util/mock.go | 2 +- test/util/wrapper/backend.go | 4 +-- test/util/wrapper/playground.go | 4 +-- test/util/wrapper/service.go | 4 +-- 35 files changed, 183 insertions(+), 155 deletions(-) delete mode 100644 .github/workflows/publish-helm-chart.yaml rename client-go/applyconfiguration/core/v1alpha1/{modelrepresentative.go => modelrefer.go} (68%) diff --git a/.github/workflows/publish-helm-chart.yaml b/.github/workflows/publish-helm-chart.yaml deleted file mode 100644 index 870634f..0000000 --- a/.github/workflows/publish-helm-chart.yaml +++ /dev/null @@ -1,12 +0,0 @@ -name: Publish Helm Chart - -on: - workflow_dispatch: - -jobs: - publish: - uses: kerthcet/github-workflow-as-kube/.github/workflows/workflow-helm-chart.yaml@main - secrets: - AGENT_TOKEN: ${{ secrets.AGENT_TOKEN }} - with: - repo_url: "https://inftyai.github.io/llmaz" diff --git a/Makefile b/Makefile index 2ffb79e..50ec9ac 100644 --- a/Makefile +++ b/Makefile @@ -295,7 +295,6 @@ $(HELMIFY): $(LOCALBIN) .PHONY: helm helm: manifests kustomize helmify - $(KUBECTL) create namespace llmaz-system --dry-run=client -o yaml | $(KUBECTL) apply -f - $(KUSTOMIZE) build config/default | $(HELMIFY) -crd-dir .PHONY: helm-install @@ -303,7 +302,7 @@ helm-install: helm helm upgrade --install llmaz ./chart --namespace llmaz-system --create-namespace -f ./chart/values.global.yaml .PHONY: helm-package -helm-package: +helm-package: helm # Make sure will alwasy start with a new line. printf "\n" >> ./chart/values.yaml cat ./chart/values.global.yaml >> ./chart/values.yaml diff --git a/api/core/v1alpha1/model_types.go b/api/core/v1alpha1/model_types.go index 199510a..976a56e 100644 --- a/api/core/v1alpha1/model_types.go +++ b/api/core/v1alpha1/model_types.go @@ -131,10 +131,14 @@ const ( DraftRole ModelRole = "draft" ) -type ModelRepresentative struct { +// ModelRefer refers to a created Model with it's role. 
+type ModelRefer struct { // Name represents the model name. Name ModelName `json:"name"` // Role represents the model role once more than one model is required. + // Such as a draft role, which means running with SpeculativeDecoding, + // and default arguments for backend will be searched in backendRuntime + // with the name of speculative-decoding. // +kubebuilder:validation:Enum={main,draft} // +kubebuilder:default=main // +optional @@ -148,7 +152,7 @@ type ModelClaims struct { // speculative decoding, then one model is main(target) model, another one // is draft model. // +kubebuilder:validation:MinItems=1 - Models []ModelRepresentative `json:"models,omitempty"` + Models []ModelRefer `json:"models,omitempty"` // InferenceFlavors represents a list of flavors with fungibility supported // to serve the model. // - If not set, always apply with the 0-index model by default. diff --git a/api/core/v1alpha1/zz_generated.deepcopy.go b/api/core/v1alpha1/zz_generated.deepcopy.go index 241c4c5..d4da3b4 100644 --- a/api/core/v1alpha1/zz_generated.deepcopy.go +++ b/api/core/v1alpha1/zz_generated.deepcopy.go @@ -87,7 +87,7 @@ func (in *ModelClaims) DeepCopyInto(out *ModelClaims) { *out = *in if in.Models != nil { in, out := &in.Models, &out.Models - *out = make([]ModelRepresentative, len(*in)) + *out = make([]ModelRefer, len(*in)) for i := range *in { (*in)[i].DeepCopyInto(&(*out)[i]) } @@ -140,7 +140,7 @@ func (in *ModelHub) DeepCopy() *ModelHub { } // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. -func (in *ModelRepresentative) DeepCopyInto(out *ModelRepresentative) { +func (in *ModelRefer) DeepCopyInto(out *ModelRefer) { *out = *in if in.Role != nil { in, out := &in.Role, &out.Role @@ -149,12 +149,12 @@ func (in *ModelRepresentative) DeepCopyInto(out *ModelRepresentative) { } } -// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelRepresentative. -func (in *ModelRepresentative) DeepCopy() *ModelRepresentative { +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelRefer. +func (in *ModelRefer) DeepCopy() *ModelRefer { if in == nil { return nil } - out := new(ModelRepresentative) + out := new(ModelRefer) in.DeepCopyInto(out) return out } diff --git a/api/inference/v1alpha1/backendruntime_types.go b/api/inference/v1alpha1/backendruntime_types.go index db1bb81..4167284 100644 --- a/api/inference/v1alpha1/backendruntime_types.go +++ b/api/inference/v1alpha1/backendruntime_types.go @@ -21,16 +21,15 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) -type InferenceMode string - -const ( - DefaultInferenceMode InferenceMode = "Default" - SpeculativeDecodingInferenceMode InferenceMode = "SpeculativeDecoding" -) - +// BackendRuntimeArg is preset arguments for easy to use. +// Do not edit the preset names unless set the argument name explicitly +// in Playground backendRuntimeConfig. type BackendRuntimeArg struct { - Mode InferenceMode `json:"mode"` - Flags []string `json:"flags,omitempty"` + // Name represents the identifier of the backendRuntime argument. + Name string `json:"name"` + // Flags represents all the preset configurations. + // Flag around with {{ .CONFIG }} is a configuration waiting for render. + Flags []string `json:"flags,omitempty"` } // BackendRuntimeSpec defines the desired state of BackendRuntime @@ -43,11 +42,8 @@ type BackendRuntimeSpec struct { // Version represents the default version of the backendRuntime. 
// It will be appended to the image as a tag. Version string `json:"version"` - // Args represents the args of the backendRuntime. - // They can be appended or overwritten by the Playground args. - // The key is the inference option, like default one or advanced - // speculativeDecoding, the values are the corresponding args. - // Flag around with {{ .XXX }} is a flag waiting for render. + // Args represents the preset arguments of the backendRuntime. + // They can be appended or overwritten by the Playground backendRuntimeConfig. Args []BackendRuntimeArg `json:"args,omitempty"` // Envs represents the environments set to the container. // +optional diff --git a/chart/Chart.yaml b/chart/Chart.yaml index 176f869..6cc3844 100644 --- a/chart/Chart.yaml +++ b/chart/Chart.yaml @@ -13,9 +13,9 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.0.2 +version: 0.0.3 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to # follow Semantic Versioning. They should reflect the version the application is using. # It is recommended to use it with quotes. -appVersion: "0.0.6" +appVersion: 0.0.7 diff --git a/chart/crds/backendruntime-crd.yaml b/chart/crds/backendruntime-crd.yaml index 212d8b6..6d4e967 100644 --- a/chart/crds/backendruntime-crd.yaml +++ b/chart/crds/backendruntime-crd.yaml @@ -42,21 +42,27 @@ spec: properties: args: description: |- - Args represents the args of the backendRuntime. - They can be appended or overwritten by the Playground args. - The key is the inference option, like default one or advanced - speculativeDecoding, the values are the corresponding args. - Flag around with {{ .XXX }} is a flag waiting for render. + Args represents the preset arguments of the backendRuntime. + They can be appended or overwritten by the Playground backendRuntimeConfig. items: + description: |- + BackendRuntimeArg is preset arguments for easy to use. + Do not edit the preset names unless set the argument name explicitly + in Playground backendRuntimeConfig. properties: flags: + description: |- + Flags represents all the preset configurations. + Flag around with {{ .CONFIG }} is a configuration waiting for render. items: type: string type: array - mode: + name: + description: Name represents the identifier of the backendRuntime + argument. type: string required: - - mode + - name type: object type: array commands: diff --git a/chart/crds/playground-crd.yaml b/chart/crds/playground-crd.yaml index 7984ce9..1432085 100644 --- a/chart/crds/playground-crd.yaml +++ b/chart/crds/playground-crd.yaml @@ -259,14 +259,19 @@ spec: speculative decoding, then one model is main(target) model, another one is draft model. items: + description: ModelRefer refers to a created Model with it's + role. properties: name: description: Name represents the model name. type: string role: default: main - description: Role represents the model role once more than - one model is required. + description: |- + Role represents the model role once more than one model is required. + Such as a draft role, which means running with SpeculativeDecoding, + and default arguments for backend will be searched in backendRuntime + with the name of speculative-decoding. 
enum: - main - draft diff --git a/chart/crds/service-crd.yaml b/chart/crds/service-crd.yaml index 18cc558..d8169c8 100644 --- a/chart/crds/service-crd.yaml +++ b/chart/crds/service-crd.yaml @@ -84,14 +84,19 @@ spec: speculative decoding, then one model is main(target) model, another one is draft model. items: + description: ModelRefer refers to a created Model with it's + role. properties: name: description: Name represents the model name. type: string role: default: main - description: Role represents the model role once more than - one model is required. + description: |- + Role represents the model role once more than one model is required. + Such as a draft role, which means running with SpeculativeDecoding, + and default arguments for backend will be searched in backendRuntime + with the name of speculative-decoding. enum: - main - draft diff --git a/chart/templates/backends/llamacpp.yaml b/chart/templates/backends/llamacpp.yaml index ea54d92..1fc0402 100644 --- a/chart/templates/backends/llamacpp.yaml +++ b/chart/templates/backends/llamacpp.yaml @@ -12,8 +12,10 @@ spec: - ./llama-server image: ghcr.io/ggerganov/llama.cpp version: server + # Do not edit the preset argument name unless you know what you're doing. + # Free to add more arguments with your requirements. args: - - mode: Default + - name: default flags: - -m - "{{`{{ .ModelPath }}`}}" @@ -21,7 +23,7 @@ spec: - "0.0.0.0" - --port - "8080" - - mode: SpeculativeDecoding + - name: speculative-decoding flags: - -m - "{{`{{ .ModelPath }}`}}" diff --git a/chart/templates/backends/sglang.yaml b/chart/templates/backends/sglang.yaml index fdc51a2..811981f 100644 --- a/chart/templates/backends/sglang.yaml +++ b/chart/templates/backends/sglang.yaml @@ -14,8 +14,10 @@ spec: - sglang.launch_server image: lmsysorg/sglang version: v0.2.10-cu121 + # Do not edit the preset argument name unless you know what you're doing. + # Free to add more arguments with your requirements. args: - - mode: Default + - name: default flags: - --model-path - "{{`{{ .ModelPath }}`}}" diff --git a/chart/templates/backends/vllm.yaml b/chart/templates/backends/vllm.yaml index 1126b74..ea0f239 100644 --- a/chart/templates/backends/vllm.yaml +++ b/chart/templates/backends/vllm.yaml @@ -14,8 +14,10 @@ spec: - vllm.entrypoints.openai.api_server image: vllm/vllm-openai version: v0.6.0 + # Do not edit the preset argument name unless you know what you're doing. + # Free to add more arguments with your requirements. args: - - mode: Default + - name: default flags: - --model - "{{`{{ .ModelPath }}`}}" @@ -25,7 +27,7 @@ spec: - "0.0.0.0" - --port - "8080" - - mode: SpeculativeDecoding + - name: speculative-decoding flags: - --model - "{{`{{ .ModelPath }}`}}" diff --git a/client-go/applyconfiguration/core/v1alpha1/modelclaims.go b/client-go/applyconfiguration/core/v1alpha1/modelclaims.go index 52760ef..6840022 100644 --- a/client-go/applyconfiguration/core/v1alpha1/modelclaims.go +++ b/client-go/applyconfiguration/core/v1alpha1/modelclaims.go @@ -24,8 +24,8 @@ import ( // ModelClaimsApplyConfiguration represents an declarative configuration of the ModelClaims type for use // with apply. 
type ModelClaimsApplyConfiguration struct { - Models []ModelRepresentativeApplyConfiguration `json:"models,omitempty"` - InferenceFlavors []corev1alpha1.FlavorName `json:"inferenceFlavors,omitempty"` + Models []ModelReferApplyConfiguration `json:"models,omitempty"` + InferenceFlavors []corev1alpha1.FlavorName `json:"inferenceFlavors,omitempty"` } // ModelClaimsApplyConfiguration constructs an declarative configuration of the ModelClaims type for use with @@ -37,7 +37,7 @@ func ModelClaims() *ModelClaimsApplyConfiguration { // WithModels adds the given value to the Models field in the declarative configuration // and returns the receiver, so that objects can be build by chaining "With" function invocations. // If called multiple times, values provided by each call will be appended to the Models field. -func (b *ModelClaimsApplyConfiguration) WithModels(values ...*ModelRepresentativeApplyConfiguration) *ModelClaimsApplyConfiguration { +func (b *ModelClaimsApplyConfiguration) WithModels(values ...*ModelReferApplyConfiguration) *ModelClaimsApplyConfiguration { for i := range values { if values[i] == nil { panic("nil value passed to WithModels") diff --git a/client-go/applyconfiguration/core/v1alpha1/modelrepresentative.go b/client-go/applyconfiguration/core/v1alpha1/modelrefer.go similarity index 68% rename from client-go/applyconfiguration/core/v1alpha1/modelrepresentative.go rename to client-go/applyconfiguration/core/v1alpha1/modelrefer.go index 83477b2..9acc944 100644 --- a/client-go/applyconfiguration/core/v1alpha1/modelrepresentative.go +++ b/client-go/applyconfiguration/core/v1alpha1/modelrefer.go @@ -21,23 +21,23 @@ import ( v1alpha1 "github.com/inftyai/llmaz/api/core/v1alpha1" ) -// ModelRepresentativeApplyConfiguration represents an declarative configuration of the ModelRepresentative type for use +// ModelReferApplyConfiguration represents an declarative configuration of the ModelRefer type for use // with apply. -type ModelRepresentativeApplyConfiguration struct { +type ModelReferApplyConfiguration struct { Name *v1alpha1.ModelName `json:"name,omitempty"` Role *v1alpha1.ModelRole `json:"role,omitempty"` } -// ModelRepresentativeApplyConfiguration constructs an declarative configuration of the ModelRepresentative type for use with +// ModelReferApplyConfiguration constructs an declarative configuration of the ModelRefer type for use with // apply. -func ModelRepresentative() *ModelRepresentativeApplyConfiguration { - return &ModelRepresentativeApplyConfiguration{} +func ModelRefer() *ModelReferApplyConfiguration { + return &ModelReferApplyConfiguration{} } // WithName sets the Name field in the declarative configuration to the given value // and returns the receiver, so that objects can be built by chaining "With" function invocations. // If called multiple times, the Name field is set to the value of the last call. -func (b *ModelRepresentativeApplyConfiguration) WithName(value v1alpha1.ModelName) *ModelRepresentativeApplyConfiguration { +func (b *ModelReferApplyConfiguration) WithName(value v1alpha1.ModelName) *ModelReferApplyConfiguration { b.Name = &value return b } @@ -45,7 +45,7 @@ func (b *ModelRepresentativeApplyConfiguration) WithName(value v1alpha1.ModelNam // WithRole sets the Role field in the declarative configuration to the given value // and returns the receiver, so that objects can be built by chaining "With" function invocations. // If called multiple times, the Role field is set to the value of the last call. 
-func (b *ModelRepresentativeApplyConfiguration) WithRole(value v1alpha1.ModelRole) *ModelRepresentativeApplyConfiguration { +func (b *ModelReferApplyConfiguration) WithRole(value v1alpha1.ModelRole) *ModelReferApplyConfiguration { b.Role = &value return b } diff --git a/client-go/applyconfiguration/utils.go b/client-go/applyconfiguration/utils.go index 2245365..6a9c4d0 100644 --- a/client-go/applyconfiguration/utils.go +++ b/client-go/applyconfiguration/utils.go @@ -58,8 +58,8 @@ func ForKind(kind schema.GroupVersionKind) interface{} { return &applyconfigurationcorev1alpha1.ModelClaimsApplyConfiguration{} case corev1alpha1.SchemeGroupVersion.WithKind("ModelHub"): return &applyconfigurationcorev1alpha1.ModelHubApplyConfiguration{} - case corev1alpha1.SchemeGroupVersion.WithKind("ModelRepresentative"): - return &applyconfigurationcorev1alpha1.ModelRepresentativeApplyConfiguration{} + case corev1alpha1.SchemeGroupVersion.WithKind("ModelRefer"): + return &applyconfigurationcorev1alpha1.ModelReferApplyConfiguration{} case corev1alpha1.SchemeGroupVersion.WithKind("ModelSource"): return &applyconfigurationcorev1alpha1.ModelSourceApplyConfiguration{} case corev1alpha1.SchemeGroupVersion.WithKind("ModelSpec"): diff --git a/config/crd/bases/inference.llmaz.io_backendruntimes.yaml b/config/crd/bases/inference.llmaz.io_backendruntimes.yaml index a9654fb..1b2d645 100644 --- a/config/crd/bases/inference.llmaz.io_backendruntimes.yaml +++ b/config/crd/bases/inference.llmaz.io_backendruntimes.yaml @@ -43,21 +43,27 @@ spec: properties: args: description: |- - Args represents the args of the backendRuntime. - They can be appended or overwritten by the Playground args. - The key is the inference option, like default one or advanced - speculativeDecoding, the values are the corresponding args. - Flag around with {{ .XXX }} is a flag waiting for render. + Args represents the preset arguments of the backendRuntime. + They can be appended or overwritten by the Playground backendRuntimeConfig. items: + description: |- + BackendRuntimeArg is preset arguments for easy to use. + Do not edit the preset names unless set the argument name explicitly + in Playground backendRuntimeConfig. properties: flags: + description: |- + Flags represents all the preset configurations. + Flag around with {{ .CONFIG }} is a configuration waiting for render. items: type: string type: array - mode: + name: + description: Name represents the identifier of the backendRuntime + argument. type: string required: - - mode + - name type: object type: array commands: diff --git a/config/crd/bases/inference.llmaz.io_playgrounds.yaml b/config/crd/bases/inference.llmaz.io_playgrounds.yaml index ee9d86e..db2ddb0 100644 --- a/config/crd/bases/inference.llmaz.io_playgrounds.yaml +++ b/config/crd/bases/inference.llmaz.io_playgrounds.yaml @@ -260,14 +260,19 @@ spec: speculative decoding, then one model is main(target) model, another one is draft model. items: + description: ModelRefer refers to a created Model with it's + role. properties: name: description: Name represents the model name. type: string role: default: main - description: Role represents the model role once more than - one model is required. + description: |- + Role represents the model role once more than one model is required. + Such as a draft role, which means running with SpeculativeDecoding, + and default arguments for backend will be searched in backendRuntime + with the name of speculative-decoding. 
enum: - main - draft diff --git a/config/crd/bases/inference.llmaz.io_services.yaml b/config/crd/bases/inference.llmaz.io_services.yaml index f00ce46..116d54d 100644 --- a/config/crd/bases/inference.llmaz.io_services.yaml +++ b/config/crd/bases/inference.llmaz.io_services.yaml @@ -85,14 +85,19 @@ spec: speculative decoding, then one model is main(target) model, another one is draft model. items: + description: ModelRefer refers to a created Model with it's + role. properties: name: description: Name represents the model name. type: string role: default: main - description: Role represents the model role once more than - one model is required. + description: |- + Role represents the model role once more than one model is required. + Such as a draft role, which means running with SpeculativeDecoding, + and default arguments for backend will be searched in backendRuntime + with the name of speculative-decoding. enum: - main - draft diff --git a/docs/installation.md b/docs/installation.md index 5559d23..f301eec 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -22,7 +22,8 @@ helm uninstall llmaz kubectl delete ns llmaz-system ``` -If you want to delete the CRDs as well, run (ignore the error) +If you want to delete the CRDs as well, run + ```cmd kubectl delete crd \ openmodels.llmaz.io \ @@ -47,7 +48,8 @@ helm uninstall llmaz kubectl delete ns llmaz-system ``` -If you want to delete the CRDs as well, run (ignore the error) +If you want to delete the CRDs as well, run + ```cmd kubectl delete crd \ openmodels.llmaz.io \ diff --git a/index.yaml b/index.yaml index 054dccd..260a81f 100644 --- a/index.yaml +++ b/index.yaml @@ -1,6 +1,16 @@ apiVersion: v1 entries: llmaz: + - apiVersion: v2 + appVersion: 0.0.7 + created: "2024-09-12T16:49:31.224669+08:00" + description: A Helm chart for llmaz + digest: 2f4f376184d7e8971bcfc10a8d110307a989df4a9bd3aaf31f05fc738aa8c5cf + name: llmaz + type: application + urls: + - https://inftyai.github.io/llmaz/llmaz-0.0.3.tgz + version: 0.0.3 - apiVersion: v2 appVersion: 0.0.6 created: "2024-09-11T21:18:24.980219+08:00" @@ -21,4 +31,4 @@ entries: urls: - https://inftyai.github.io/llmaz/llmaz-0.0.1.tgz version: 0.0.1 -generated: "2024-09-11T21:18:24.967532+08:00" +generated: "2024-09-12T16:49:31.210833+08:00" diff --git a/pkg/controller/inference/playground_controller.go b/pkg/controller/inference/playground_controller.go index 61ed10b..20a846d 100644 --- a/pkg/controller/inference/playground_controller.go +++ b/pkg/controller/inference/playground_controller.go @@ -114,7 +114,7 @@ func (r *PlaygroundReconciler) Reconcile(ctx context.Context, req ctrl.Request) serviceApplyConfiguration, err := buildServiceApplyConfiguration(models, playground, backendRuntime) if err != nil { - logger.Error(err, "failed to get build inference Service") + logger.Error(err, "failed to build inference Service") return ctrl.Result{}, err } @@ -195,16 +195,16 @@ func buildServiceApplyConfiguration(models []*coreapi.OpenModel, playground *inf var claim *coreclientgo.ModelClaimsApplyConfiguration if playground.Spec.ModelClaim != nil { claim = coreclientgo.ModelClaims(). - WithModels(coreclientgo.ModelRepresentative().WithName(playground.Spec.ModelClaim.ModelName).WithRole(coreapi.MainRole)). + WithModels(coreclientgo.ModelRefer().WithName(playground.Spec.ModelClaim.ModelName).WithRole(coreapi.MainRole)). WithInferenceFlavors(playground.Spec.ModelClaim.InferenceFlavors...) 
} else { - mrs := []*coreclientgo.ModelRepresentativeApplyConfiguration{} + mrs := []*coreclientgo.ModelReferApplyConfiguration{} for _, model := range playground.Spec.ModelClaims.Models { role := coreapi.MainRole if model.Role != nil { role = *model.Role } - mr := coreclientgo.ModelRepresentative().WithName(model.Name).WithRole(role) + mr := coreclientgo.ModelRefer().WithName(model.Name).WithRole(role) mrs = append(mrs, mr) } @@ -257,7 +257,7 @@ func buildWorkloadTemplate(models []*coreapi.OpenModel, playground *inferenceapi func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.Playground, backendRuntime *inferenceapi.BackendRuntime) (corev1.PodTemplateSpec, error) { parser := helper.NewBackendRuntimeParser(backendRuntime) - args, err := parser.Args(helper.InferenceMode(playground), models) + args, err := parser.Args(helper.PlaygroundInferenceMode(playground), models) if err != nil { return corev1.PodTemplateSpec{}, err } diff --git a/pkg/controller_helper/backendruntime.go b/pkg/controller_helper/backendruntime.go index 61f6463..ded7c24 100644 --- a/pkg/controller_helper/backendruntime.go +++ b/pkg/controller_helper/backendruntime.go @@ -44,14 +44,14 @@ func (p *BackendRuntimeParser) Envs() []corev1.EnvVar { return p.backendRuntime.Spec.Envs } -func (p *BackendRuntimeParser) Args(mode inferenceapi.InferenceMode, models []*coreapi.OpenModel) ([]string, error) { - if mode == inferenceapi.SpeculativeDecodingInferenceMode && len(models) != 2 { +func (p *BackendRuntimeParser) Args(mode InferenceMode, models []*coreapi.OpenModel) ([]string, error) { + if mode == SpeculativeDecodingInferenceMode && len(models) != 2 { return nil, fmt.Errorf("models number not right, want 2, got %d", len(models)) } modelInfo := map[string]string{} - if mode == inferenceapi.DefaultInferenceMode { + if mode == DefaultInferenceMode { source := modelSource.NewModelSourceProvider(models[0]) modelInfo = map[string]string{ "ModelPath": source.ModelPath(), @@ -59,7 +59,7 @@ func (p *BackendRuntimeParser) Args(mode inferenceapi.InferenceMode, models []*c } } - if mode == inferenceapi.SpeculativeDecodingInferenceMode { + if mode == SpeculativeDecodingInferenceMode { targetSource := modelSource.NewModelSourceProvider(models[0]) draftSource := modelSource.NewModelSourceProvider(models[1]) modelInfo = map[string]string{ @@ -70,12 +70,13 @@ func (p *BackendRuntimeParser) Args(mode inferenceapi.InferenceMode, models []*c } for _, arg := range p.backendRuntime.Spec.Args { - if arg.Mode == mode { + if InferenceMode(arg.Name) == mode { return renderFlags(arg.Flags, modelInfo) } } + // We should not reach here. - return nil, fmt.Errorf("backendRuntime %s not supported", p.backendRuntime.Name) + return nil, fmt.Errorf("failed to parse backendRuntime %s", p.backendRuntime.Name) } func (p *BackendRuntimeParser) Image(version string) string { diff --git a/pkg/controller_helper/helper.go b/pkg/controller_helper/helper.go index 5e353c0..7405e2c 100644 --- a/pkg/controller_helper/helper.go +++ b/pkg/controller_helper/helper.go @@ -25,21 +25,31 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) -func InferenceMode(playground *inferenceapi.Playground) inferenceapi.InferenceMode { +type InferenceMode string + +// These two modes are preset. 
+const ( + DefaultInferenceMode InferenceMode = "default" + SpeculativeDecodingInferenceMode InferenceMode = "speculative-decoding" +) + +// PlaygroundInferenceMode gets the mode of inference process, supports default +// or speculative-decoding for now, which is aligned with backendRuntime. +func PlaygroundInferenceMode(playground *inferenceapi.Playground) InferenceMode { if playground.Spec.ModelClaim != nil { - return inferenceapi.DefaultInferenceMode + return DefaultInferenceMode } if playground.Spec.ModelClaims != nil { for _, mr := range playground.Spec.ModelClaims.Models { if *mr.Role == coreapi.DraftRole { - return inferenceapi.SpeculativeDecodingInferenceMode + return SpeculativeDecodingInferenceMode } } } // We should not reach here. - return inferenceapi.DefaultInferenceMode + return DefaultInferenceMode } func FetchModelsByService(ctx context.Context, k8sClient client.Client, service *inferenceapi.Service) (models []*coreapi.OpenModel, err error) { @@ -48,10 +58,10 @@ func FetchModelsByService(ctx context.Context, k8sClient client.Client, service func FetchModelsByPlayground(ctx context.Context, k8sClient client.Client, playground *inferenceapi.Playground) (models []*coreapi.OpenModel, err error) { mainRole := coreapi.MainRole - mrs := []coreapi.ModelRepresentative{} + mrs := []coreapi.ModelRefer{} if playground.Spec.ModelClaim != nil { - mrs = append(mrs, coreapi.ModelRepresentative{Name: playground.Spec.ModelClaim.ModelName, Role: &mainRole}) + mrs = append(mrs, coreapi.ModelRefer{Name: playground.Spec.ModelClaim.ModelName, Role: &mainRole}) } else { mrs = playground.Spec.ModelClaims.Models } @@ -59,7 +69,7 @@ func FetchModelsByPlayground(ctx context.Context, k8sClient client.Client, playg return fetchModels(ctx, k8sClient, mrs) } -func fetchModels(ctx context.Context, k8sClient client.Client, mrs []coreapi.ModelRepresentative) (models []*coreapi.OpenModel, err error) { +func fetchModels(ctx context.Context, k8sClient client.Client, mrs []coreapi.ModelRefer) (models []*coreapi.OpenModel, err error) { for _, mr := range mrs { model := &coreapi.OpenModel{} if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(mr.Name)}, model); err != nil { diff --git a/pkg/webhook/backendruntime_webhook.go b/pkg/webhook/backendruntime_webhook.go index 6f35354..ad55b4a 100644 --- a/pkg/webhook/backendruntime_webhook.go +++ b/pkg/webhook/backendruntime_webhook.go @@ -71,7 +71,6 @@ func (w *BackendRuntimeWebhook) ValidateDelete(ctx context.Context, obj runtime. return nil, nil } -// TODO: the mode name should not be duplicated. func (w *BackendRuntimeWebhook) generateValidate(obj runtime.Object) field.ErrorList { backend := obj.(*inferenceapi.BackendRuntime) specPath := field.NewPath("spec") @@ -87,18 +86,12 @@ func (w *BackendRuntimeWebhook) generateValidate(obj runtime.Object) field.Error } } - modes := []string{} - + names := []string{} for _, arg := range backend.Spec.Args { - if util.In(modes, string(arg.Mode)) { - allErrs = append(allErrs, field.Forbidden(specPath.Child("args", "mode"), fmt.Sprintf("duplicated mode %s", arg.Mode))) - } - // TODO: this may change in the future if user wants to customized there flags for easy usage. 
- // See https://github.com/InftyAI/llmaz/issues/140 - if !(arg.Mode == inferenceapi.DefaultInferenceMode || arg.Mode == inferenceapi.SpeculativeDecodingInferenceMode) { - allErrs = append(allErrs, field.Forbidden(specPath.Child("args", "mode"), fmt.Sprintf("inferenceMode of %s is forbidden", arg.Mode))) + if util.In(names, arg.Name) { + allErrs = append(allErrs, field.Forbidden(specPath.Child("args", "name"), fmt.Sprintf("duplicated name %s", arg.Name))) } - modes = append(modes, string(arg.Mode)) + names = append(names, arg.Name) } return allErrs } diff --git a/pkg/webhook/playground_webhook.go b/pkg/webhook/playground_webhook.go index 0cacc23..f8af447 100644 --- a/pkg/webhook/playground_webhook.go +++ b/pkg/webhook/playground_webhook.go @@ -28,6 +28,7 @@ import ( coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" + helper "github.com/inftyai/llmaz/pkg/controller_helper" ) type PlaygroundWebhook struct{} @@ -103,24 +104,18 @@ func (w *PlaygroundWebhook) generateValidate(obj runtime.Object) field.ErrorList } if playground.Spec.ModelClaims != nil { mainModelCount := 0 - var speculativeDecoding bool for _, model := range playground.Spec.ModelClaims.Models { if model.Name == coreapi.ModelName(coreapi.MainRole) { mainModelCount += 1 } - if *model.Role == coreapi.DraftRole { - speculativeDecoding = true - } } - if speculativeDecoding { + mode := helper.PlaygroundInferenceMode(playground) + if mode == helper.SpeculativeDecodingInferenceMode { if len(playground.Spec.ModelClaims.Models) != 2 { allErrs = append(allErrs, field.Forbidden(specPath.Child("modelClaims", "models"), "only two models are allowed in speculativeDecoding mode")) } - if playground.Spec.BackendRuntimeConfig != nil && *playground.Spec.BackendRuntimeConfig.Name != inferenceapi.VLLM { - allErrs = append(allErrs, field.Forbidden(specPath.Child("backendRuntimeConfig", "name"), "only vLLM supports speculativeDecoding mode")) - } } if mainModelCount > 1 { diff --git a/test/config/backends/llamacpp.yaml b/test/config/backends/llamacpp.yaml index 2360973..da57e3d 100644 --- a/test/config/backends/llamacpp.yaml +++ b/test/config/backends/llamacpp.yaml @@ -12,7 +12,7 @@ spec: image: ghcr.io/ggerganov/llama.cpp version: server args: - - mode: Default + - name: default flags: - -m - "{{ .ModelPath }}" @@ -20,7 +20,7 @@ spec: - "0.0.0.0" - --port - "8080" - - mode: SpeculativeDecoding + - name: speculative-decoding flags: - -m - "{{ .ModelPath }}" diff --git a/test/config/backends/sglang.yaml b/test/config/backends/sglang.yaml index 7716d95..8d5b4ea 100644 --- a/test/config/backends/sglang.yaml +++ b/test/config/backends/sglang.yaml @@ -14,7 +14,7 @@ spec: image: lmsysorg/sglang version: v0.2.10-cu121 args: - - mode: Default + - name: default flags: - --model-path - "{{ .ModelPath }}" diff --git a/test/config/backends/vllm.yaml b/test/config/backends/vllm.yaml index 14ca8b7..7a6b564 100644 --- a/test/config/backends/vllm.yaml +++ b/test/config/backends/vllm.yaml @@ -14,7 +14,7 @@ spec: image: vllm/vllm-openai version: v0.6.0 args: - - mode: Default + - name: default flags: - --model - "{{ .ModelPath }}" @@ -24,7 +24,7 @@ spec: - "0.0.0.0" - --port - "8080" - - mode: SpeculativeDecoding + - name: speculative-decoding flags: - --model - "{{ .ModelPath }}" diff --git a/test/e2e/playground_test.go b/test/e2e/playground_test.go index 344fe2d..2b2f2e3 100644 --- a/test/e2e/playground_test.go +++ b/test/e2e/playground_test.go @@ -69,7 +69,7 @@ var _ = ginkgo.Describe("playground e2e 
tests", func() { backendRuntime := wrapper.MakeBackendRuntime("llmaz-llamacpp"). Image("ghcr.io/ggerganov/llama.cpp").Version("server"). Command([]string{"./llama-server"}). - Arg("Default", []string{"-m", "{{.ModelPath}}", "--host", "0.0.0.0", "--port", "8080"}). + Arg("default", []string{"-m", "{{.ModelPath}}", "--host", "0.0.0.0", "--port", "8080"}). Request("cpu", "2").Request("memory", "4Gi").Limit("cpu", "4").Limit("memory", "4Gi").Obj() gomega.Expect(k8sClient.Create(ctx, backendRuntime)).To(gomega.Succeed()) diff --git a/test/integration/webhook/backendruntime_test.go b/test/integration/webhook/backendruntime_test.go index a9f55d5..9e62b6f 100644 --- a/test/integration/webhook/backendruntime_test.go +++ b/test/integration/webhook/backendruntime_test.go @@ -20,10 +20,22 @@ import ( "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" + inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" "github.com/inftyai/llmaz/test/util" ) var _ = ginkgo.Describe("BackendRuntime default and validation", func() { + + // Delete all backendRuntimes for each case. + ginkgo.AfterEach(func() { + var runtimes inferenceapi.BackendRuntimeList + gomega.Expect(k8sClient.List(ctx, &runtimes)).To(gomega.Succeed()) + + for _, runtime := range runtimes.Items { + gomega.Expect(k8sClient.Delete(ctx, &runtime)).To(gomega.Succeed()) + } + }) + type testValidatingCase struct { creationFunc func() error failed bool @@ -43,13 +55,6 @@ var _ = ginkgo.Describe("BackendRuntime default and validation", func() { }, failed: false, }), - ginkgo.Entry("BackendRuntime creation with no image", &testValidatingCase{ - creationFunc: func() error { - runtime := util.MockASampleBackendRuntime().Image("").Obj() - return k8sClient.Create(ctx, runtime) - }, - failed: true, - }), ginkgo.Entry("BackendRuntime creation with limits less than requests", &testValidatingCase{ creationFunc: func() error { runtime := util.MockASampleBackendRuntime().Limit("cpu", "1").Obj() @@ -57,24 +62,17 @@ var _ = ginkgo.Describe("BackendRuntime default and validation", func() { }, failed: true, }), - ginkgo.Entry("BackendRuntime creation with unsupported inferenceMode", &testValidatingCase{ + ginkgo.Entry("BackendRuntime creation with unknown argument name", &testValidatingCase{ creationFunc: func() error { runtime := util.MockASampleBackendRuntime().Arg("unknown", []string{"foo", "bar"}).Obj() return k8sClient.Create(ctx, runtime) }, - failed: true, + failed: false, }), - ginkgo.Entry("BackendRuntime creation with duplicated inferenceMode", &testValidatingCase{ + ginkgo.Entry("BackendRuntime creation with duplicated argument name", &testValidatingCase{ creationFunc: func() error { - runtime := util.MockASampleBackendRuntime().Obj() - if err := k8sClient.Create(ctx, runtime); err != nil { - return err - } - anotherRuntime := util.MockASampleBackendRuntime().Name("another-vllm").Obj() - if err := k8sClient.Create(ctx, anotherRuntime); err != nil { - return err - } - return nil + runtime := util.MockASampleBackendRuntime().Arg("default", []string{"foo", "bar"}).Obj() + return k8sClient.Create(ctx, runtime) }, failed: true, }), diff --git a/test/integration/webhook/playground_test.go b/test/integration/webhook/playground_test.go index 8df0f8f..aa2f8aa 100644 --- a/test/integration/webhook/playground_test.go +++ b/test/integration/webhook/playground_test.go @@ -87,12 +87,6 @@ var _ = ginkgo.Describe("Playground default and validation", func() { }, failed: false, }), - ginkgo.Entry("speculativeDecoding with SGLang is not allowed", &testValidatingCase{ 
- playground: func() *inferenceapi.Playground { - return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaims([]string{"llama3-405b", "llama3-8b"}, []string{"main", "draft"}).BackendRuntime(string(inferenceapi.SGLANG)).Obj() - }, - failed: true, - }), ginkgo.Entry("speculativeDecoding with three models is not allowed", &testValidatingCase{ playground: func() *inferenceapi.Playground { return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaims([]string{"llama3-405b", "llama3-8b", "llama3-2b"}, []string{"main", "draft", "draft"}).Obj() @@ -138,7 +132,7 @@ var _ = ginkgo.Describe("Playground default and validation", func() { playground := wrapper.MakePlayground("playground", ns.Name).Replicas(1).Obj() draftRole := coreapi.DraftRole playground.Spec.ModelClaims = &coreapi.ModelClaims{ - Models: []coreapi.ModelRepresentative{ + Models: []coreapi.ModelRefer{ { Name: "llama3-405b", }, diff --git a/test/util/mock.go b/test/util/mock.go index 6642b9b..54b1432 100644 --- a/test/util/mock.go +++ b/test/util/mock.go @@ -44,6 +44,6 @@ func MockASampleBackendRuntime() *wrapper.BackendRuntimeWrapper { return wrapper.MakeBackendRuntime("vllm"). Image("vllm/vllm-openai").Version("v0.6.0"). Command([]string{"python3", "-m", "vllm.entrypoints.openai.api_server"}). - Arg("Default", []string{"--model", "{{.ModelPath}}", "--served-model-name", "{{.ModelName}}", "--host", "0.0.0.0", "--port", "8080"}). + Arg("default", []string{"--model", "{{.ModelPath}}", "--served-model-name", "{{.ModelName}}", "--host", "0.0.0.0", "--port", "8080"}). Request("cpu", "4").Limit("cpu", "4") } diff --git a/test/util/wrapper/backend.go b/test/util/wrapper/backend.go index 66c4fab..4deb1b5 100644 --- a/test/util/wrapper/backend.go +++ b/test/util/wrapper/backend.go @@ -62,9 +62,9 @@ func (w *BackendRuntimeWrapper) Command(commands []string) *BackendRuntimeWrappe return w } -func (w *BackendRuntimeWrapper) Arg(mode string, flags []string) *BackendRuntimeWrapper { +func (w *BackendRuntimeWrapper) Arg(name string, flags []string) *BackendRuntimeWrapper { w.Spec.Args = append(w.Spec.Args, inferenceapi.BackendRuntimeArg{ - Mode: inferenceapi.InferenceMode(mode), + Name: name, Flags: flags, }) return w diff --git a/test/util/wrapper/playground.go b/test/util/wrapper/playground.go index f5ec5e2..54aabab 100644 --- a/test/util/wrapper/playground.go +++ b/test/util/wrapper/playground.go @@ -72,9 +72,9 @@ func (w *PlaygroundWrapper) ModelClaim(modelName string, flavorNames ...string) } func (w *PlaygroundWrapper) ModelClaims(modelNames []string, roles []string, flavorNames ...string) *PlaygroundWrapper { - models := []coreapi.ModelRepresentative{} + models := []coreapi.ModelRefer{} for i, name := range modelNames { - models = append(models, coreapi.ModelRepresentative{Name: coreapi.ModelName(name), Role: (*coreapi.ModelRole)(&roles[i])}) + models = append(models, coreapi.ModelRefer{Name: coreapi.ModelName(name), Role: (*coreapi.ModelRole)(&roles[i])}) } w.Spec.ModelClaims = &coreapi.ModelClaims{ Models: models, diff --git a/test/util/wrapper/service.go b/test/util/wrapper/service.go index e3d4dc5..37b6268 100644 --- a/test/util/wrapper/service.go +++ b/test/util/wrapper/service.go @@ -46,9 +46,9 @@ func (w *ServiceWrapper) Obj() *inferenceapi.Service { } func (w *ServiceWrapper) ModelClaims(modelNames []string, roles []string, flavorNames ...string) *ServiceWrapper { - models := []coreapi.ModelRepresentative{} + models := []coreapi.ModelRefer{} for i, name := range modelNames { - models = append(models, 
coreapi.ModelRepresentative{Name: coreapi.ModelName(name), Role: (*coreapi.ModelRole)(&roles[i])}) + models = append(models, coreapi.ModelRefer{Name: coreapi.ModelName(name), Role: (*coreapi.ModelRole)(&roles[i])}) } w.Spec.ModelClaims = coreapi.ModelClaims{ Models: models,
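
Note for readers of this patch: the core change is that BackendRuntime presets are now keyed by a free-form name (e.g. "default", "speculative-decoding") instead of the removed InferenceMode enum, and the controller selects a preset by name before rendering the "{{ .CONFIG }}" placeholders in its flags. The program below is a minimal, self-contained sketch of that lookup-and-render step; it is NOT the project's actual BackendRuntimeParser/renderFlags code, and the model path value is a made-up placeholder. The preset name "default" and the ModelPath/ModelName keys mirror the manifests in this patch.

// sketch only: illustrates name-based preset lookup plus Go text/template
// rendering of flag placeholders, under the assumptions stated above.
package main

import (
	"bytes"
	"fmt"
	"text/template"
)

// arg mirrors the new BackendRuntimeArg shape: a preset identified by Name.
type arg struct {
	Name  string
	Flags []string
}

// renderPreset finds the preset with the given name and substitutes
// {{ .Key }} placeholders in its flags from the provided model info map.
func renderPreset(args []arg, name string, modelInfo map[string]string) ([]string, error) {
	for _, a := range args {
		if a.Name != name {
			continue
		}
		rendered := make([]string, 0, len(a.Flags))
		for _, f := range a.Flags {
			tmpl, err := template.New("flag").Parse(f)
			if err != nil {
				return nil, err
			}
			var buf bytes.Buffer
			if err := tmpl.Execute(&buf, modelInfo); err != nil {
				return nil, err
			}
			rendered = append(rendered, buf.String())
		}
		return rendered, nil
	}
	return nil, fmt.Errorf("no preset named %q", name)
}

func main() {
	presets := []arg{
		{Name: "default", Flags: []string{
			"--model", "{{ .ModelPath }}",
			"--served-model-name", "{{ .ModelName }}",
			"--port", "8080",
		}},
	}
	flags, err := renderPreset(presets, "default", map[string]string{
		"ModelPath": "/workspace/models/llama3-8b", // hypothetical path
		"ModelName": "llama3-8b",
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(flags)
}

Go's text/template is used here because the "{{ .ModelPath }}" / "{{ .ModelName }}" placeholders in the chart templates and test configs follow Go template syntax; the real controller's rendering may differ in details.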