From f45e673799ba7e93977209536674ae837a05f6b9 Mon Sep 17 00:00:00 2001
From: Anders Swanson
Date: Thu, 8 Aug 2024 13:52:12 -0700
Subject: [PATCH] feat: oci genai multi-model, serving mode

Look up the configured model's metadata at startup so the OCI GenAI
backend can serve base models via the on-demand serving mode and
fine-tuned models via a dedicated endpoint, and build vendor-specific
(cohere or meta) inference requests. Bumps oci-go-sdk to v65.71.0.

Signed-off-by: Anders Swanson
---
 go.mod             |   2 +-
 go.sum             |   2 +
 pkg/ai/ocigenai.go | 125 ++++++++++++++++++++++++++++++++++++---------
 3 files changed, 104 insertions(+), 25 deletions(-)

diff --git a/go.mod b/go.mod
index 74a9039548..9958684499 100644
--- a/go.mod
+++ b/go.mod
@@ -44,7 +44,7 @@ require (
 	github.com/hupe1980/go-huggingface v0.0.15
 	github.com/kyverno/policy-reporter-kyverno-plugin v1.6.3
 	github.com/olekukonko/tablewriter v0.0.5
-	github.com/oracle/oci-go-sdk/v65 v65.65.1
+	github.com/oracle/oci-go-sdk/v65 v65.71.0
 	github.com/prometheus/prometheus v0.53.1
 	github.com/pterm/pterm v0.12.79
 	google.golang.org/api v0.187.0
diff --git a/go.sum b/go.sum
index ea81b062a8..693ef98896 100644
--- a/go.sum
+++ b/go.sum
@@ -2118,6 +2118,8 @@ github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQ
 github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
 github.com/oracle/oci-go-sdk/v65 v65.65.1 h1:sv7uD844tJGa2Vc+2KaByoXQ0FllZDGV/2+9MdxN6nA=
 github.com/oracle/oci-go-sdk/v65 v65.65.1/go.mod h1:IBEV9l1qBzUpo7zgGaRUhbB05BVfcDGYRFBCPlTcPp0=
+github.com/oracle/oci-go-sdk/v65 v65.71.0 h1:eEnFD/CzcoqdAA0xu+EmK32kJL3jfV0oLYNWVzoKNyo=
+github.com/oracle/oci-go-sdk/v65 v65.71.0/go.mod h1:IBEV9l1qBzUpo7zgGaRUhbB05BVfcDGYRFBCPlTcPp0=
 github.com/ovh/go-ovh v1.5.1 h1:P8O+7H+NQuFK9P/j4sFW5C0fvSS2DnHYGPwdVCp45wI=
 github.com/ovh/go-ovh v1.5.1/go.mod h1:cTVDnl94z4tl8pP1uZ/8jlVxntjSIf09bNcQ5TJSC7c=
 github.com/owenrumney/squealer v1.2.1 h1:4ryMMT59aaz8VMsqsD+FDkarADJz0F1dcq2fd0DRR+c=
diff --git a/pkg/ai/ocigenai.go b/pkg/ai/ocigenai.go
index 53c7076289..5fa3f4a2cb 100644
--- a/pkg/ai/ocigenai.go
+++ b/pkg/ai/ocigenai.go
@@ -16,21 +16,33 @@ package ai
 
 import (
 	"context"
 	"errors"
+	"fmt"
 	"github.com/oracle/oci-go-sdk/v65/common"
+	"github.com/oracle/oci-go-sdk/v65/generativeai"
 	"github.com/oracle/oci-go-sdk/v65/generativeaiinference"
+	"reflect"
 	"strings"
 )
 
 const ociClientName = "oci"
 
+type ociModelVendor string
+
+const (
+	vendorCohere ociModelVendor = "cohere"
+	vendorMeta   ociModelVendor = "meta"
+)
+
 type OCIGenAIClient struct {
 	nopCloser
 
 	client        *generativeaiinference.GenerativeAiInferenceClient
-	model         string
+	model         *generativeai.Model
+	modelId       string
 	compartmentId string
 	temperature   float32
 	topP          float32
+	topK          int32
 	maxTokens     int
 }
 
@@ -40,9 +52,10 @@ func (c *OCIGenAIClient) GetName() string {
 
 func (c *OCIGenAIClient) Configure(config IAIConfig) error {
 	config.GetEndpointName()
-	c.model = config.GetModel()
+	c.modelId = config.GetModel()
 	c.temperature = config.GetTemperature()
 	c.topP = config.GetTopP()
+	c.topK = config.GetTopK()
 	c.maxTokens = config.GetMaxTokens()
 	c.compartmentId = config.GetCompartmentId()
 	provider := common.DefaultConfigProvider()
@@ -51,6 +64,12 @@ func (c *OCIGenAIClient) Configure(config IAIConfig) error {
 		return err
 	}
 	c.client = &client
+	model, err := c.getModel(provider)
+	if err != nil {
+		return err
+	}
+	c.model = model
+
 	return nil
 }
 
@@ -60,38 +79,96 @@ func (c *OCIGenAIClient) GetCompletion(ctx context.Context, prompt string) (str
 	if err != nil {
 		return "", err
 	}
-	return extractGeneratedText(generateTextResponse.InferenceResponse)
+	return c.extractGeneratedText(generateTextResponse.InferenceResponse)
 }
 
 func (c *OCIGenAIClient) newGenerateTextRequest(prompt string) generativeaiinference.GenerateTextRequest {
-	temperatureF64 := float64(c.temperature)
-	topPF64 := float64(c.topP)
 	return generativeaiinference.GenerateTextRequest{
 		GenerateTextDetails: generativeaiinference.GenerateTextDetails{
-			CompartmentId: &c.compartmentId,
-			ServingMode: generativeaiinference.OnDemandServingMode{
-				ModelId: &c.model,
-			},
-			InferenceRequest: generativeaiinference.CohereLlmInferenceRequest{
-				Prompt:      &prompt,
-				MaxTokens:   &c.maxTokens,
-				Temperature: &temperatureF64,
-				TopP:        &topPF64,
-			},
+			CompartmentId:    &c.compartmentId,
+			ServingMode:      c.getServingMode(),
+			InferenceRequest: c.getInferenceRequest(prompt),
 		},
 	}
 }
 
-func extractGeneratedText(llmInferenceResponse generativeaiinference.LlmInferenceResponse) (string, error) {
-	response, ok := llmInferenceResponse.(generativeaiinference.CohereLlmInferenceResponse)
-	if !ok {
-		return "", errors.New("failed to extract generated text from backed response")
+func (c *OCIGenAIClient) getServingMode() generativeaiinference.ServingMode {
+	if c.isBaseModel() {
+		return generativeaiinference.OnDemandServingMode{
+			ModelId: &c.modelId,
+		}
+	}
+	return generativeaiinference.DedicatedServingMode{
+		EndpointId: &c.modelId,
 	}
-	sb := strings.Builder{}
-	for _, text := range response.GeneratedTexts {
-		if text.Text != nil {
-			sb.WriteString(*text.Text)
+}
+
+func (c *OCIGenAIClient) getInferenceRequest(prompt string) generativeaiinference.LlmInferenceRequest {
+	temperatureF64 := float64(c.temperature)
+	topPF64 := float64(c.topP)
+	topK := int(c.topK)
+
+	switch c.getVendor() {
+	case vendorMeta:
+		return generativeaiinference.LlamaLlmInferenceRequest{
+			Prompt:      &prompt,
+			MaxTokens:   &c.maxTokens,
+			Temperature: &temperatureF64,
+			TopK:        &topK,
+			TopP:        &topPF64,
+		}
+	default: // Default to cohere
+		return generativeaiinference.CohereLlmInferenceRequest{
+			Prompt:      &prompt,
+			MaxTokens:   &c.maxTokens,
+			Temperature: &temperatureF64,
+			TopK:        &topK,
+			TopP:        &topPF64,
+		}
+	}
+}
+
+func (c *OCIGenAIClient) getModel(provider common.ConfigurationProvider) (*generativeai.Model, error) {
+	client, err := generativeai.NewGenerativeAiClientWithConfigurationProvider(provider)
+	if err != nil {
+		return nil, err
+	}
+	response, err := client.GetModel(context.Background(), generativeai.GetModelRequest{
+		ModelId: &c.modelId,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return &response.Model, nil
+}
+
+func (c *OCIGenAIClient) isBaseModel() bool {
+	return c.model != nil && c.model.Type == generativeai.ModelTypeBase
+}
+
+func (c *OCIGenAIClient) getVendor() ociModelVendor {
+	if c.model == nil || c.model.Vendor == nil {
+		return ""
+	}
+	return ociModelVendor(*c.model.Vendor)
+}
+
+func (c *OCIGenAIClient) extractGeneratedText(llmInferenceResponse generativeaiinference.LlmInferenceResponse) (string, error) {
+	switch response := llmInferenceResponse.(type) {
+	case generativeaiinference.LlamaLlmInferenceResponse:
+		if len(response.Choices) > 0 && response.Choices[0].Text != nil {
+			return *response.Choices[0].Text, nil
+		}
+		return "", errors.New("no text found in oci response")
+	case generativeaiinference.CohereLlmInferenceResponse:
+		sb := strings.Builder{}
+		for _, text := range response.GeneratedTexts {
+			if text.Text != nil {
+				sb.WriteString(*text.Text)
+			}
 		}
+		return sb.String(), nil
+	default:
+		return "", fmt.Errorf("unknown oci response type: %s", reflect.TypeOf(llmInferenceResponse).Name())
 	}
-	return sb.String(), nil
 }
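
Reviewer note: below is a minimal sketch of a unit test for the new extractGeneratedText type switch, not part of this patch. It assumes the oci-go-sdk element types generativeaiinference.GeneratedText and generativeaiinference.Choice (the diff only shows the slices they live in via response.GeneratedTexts and response.Choices), and the test name and placement are illustrative.

```go
package ai

import (
	"testing"

	"github.com/oracle/oci-go-sdk/v65/generativeaiinference"
)

// Sketch: exercises both vendor-specific response branches added in this patch.
func TestExtractGeneratedText(t *testing.T) {
	c := &OCIGenAIClient{}
	text := "healthy"

	// Cohere responses concatenate every generated text in order.
	cohere := generativeaiinference.CohereLlmInferenceResponse{
		GeneratedTexts: []generativeaiinference.GeneratedText{{Text: &text}, {Text: &text}},
	}
	if got, err := c.extractGeneratedText(cohere); err != nil || got != "healthyhealthy" {
		t.Fatalf("cohere: got %q, err %v", got, err)
	}

	// Meta (Llama) responses return the first choice's text.
	llama := generativeaiinference.LlamaLlmInferenceResponse{
		Choices: []generativeaiinference.Choice{{Text: &text}},
	}
	if got, err := c.extractGeneratedText(llama); err != nil || got != "healthy" {
		t.Fatalf("llama: got %q, err %v", got, err)
	}
}
```

The zero-value client suffices here because extractGeneratedText only inspects its argument; serving-mode and vendor selection depend on a populated generativeai.Model and would be better covered against a stubbed GetModel lookup.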