From 876fa3c162a5540fee143ec5bc7102f06dc36a7f Mon Sep 17 00:00:00 2001
From: Spike Lu
Date: Sun, 8 Sep 2024 22:23:30 -0700
Subject: [PATCH] add amazon bedrock integrations for claude

---
 CHANGELOG.md                            |   4 +
 go.mod                                  |  10 +-
 go.sum                                  |  12 +
 internal/authenticator/authenticator.go |   4 +
 internal/manager/provider_setting.go    |  19 +-
 internal/message/handler.go             |  41 ++
 internal/provider/anthropic/bedrock.go  |  44 ++
 internal/server/web/proxy/anthropic.go  |   7 +-
 internal/server/web/proxy/bedrock.go    | 513 ++++++++++++++++++++++++
 internal/server/web/proxy/middleware.go |  62 +++
 internal/server/web/proxy/proxy.go      |   8 +
 internal/util/util.go                   |  29 ++
 12 files changed, 742 insertions(+), 11 deletions(-)
 create mode 100644 internal/provider/anthropic/bedrock.go
 create mode 100644 internal/server/web/proxy/bedrock.go

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c7802f8..802f541 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 1.36.0 - 2024-08-10
+### Added
+- Added Amazon Bedrock integration for Claude models
+
 ## 1.35.2 - 2024-08-10
 ### Changed
 - Changed aggregated table column data types from `INT` to `BIGINT`
diff --git a/go.mod b/go.mod
index 40b6f34..b099bce 100644
--- a/go.mod
+++ b/go.mod
@@ -24,6 +24,8 @@ require (
 )
 
 require (
+	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 // indirect
+	github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.16.2 // indirect
 	github.com/beorn7/perks v1.0.1 // indirect
 	github.com/fsnotify/fsnotify v1.6.0 // indirect
 	github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 // indirect
@@ -44,18 +46,18 @@
 	github.com/Microsoft/go-winio v0.5.0 // indirect
 	github.com/asticode/go-astikit v0.20.0 // indirect
 	github.com/asticode/go-astits v1.8.0 // indirect
-	github.com/aws/aws-sdk-go-v2 v1.25.3 // indirect
+	github.com/aws/aws-sdk-go-v2 v1.30.5 // indirect
 	github.com/aws/aws-sdk-go-v2/credentials v1.17.7 // indirect
 	github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.3 // indirect
-	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3 // indirect
-	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17 // indirect
+	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
 	github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1 // indirect
 	github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.5 // indirect
 	github.com/aws/aws-sdk-go-v2/service/sso v1.20.2 // indirect
 	github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.2 // indirect
 	github.com/aws/aws-sdk-go-v2/service/sts v1.28.4 // indirect
-	github.com/aws/smithy-go v1.20.1 // indirect
+	github.com/aws/smithy-go v1.20.4 // indirect
 	github.com/bytedance/sonic v1.9.1 // indirect
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
diff --git a/go.sum b/go.sum
index 85cb81a..aaf0e9b 100644
--- a/go.sum
+++ b/go.sum
@@ -10,6 +10,10 @@ github.com/asticode/go-astits v1.8.0 h1:rf6aiiGn/QhlFjNON1n5plqF3Fs025XLUwiQ0NB6
 github.com/asticode/go-astits v1.8.0/go.mod h1:DkOWmBNQpnr9mv24KfZjq4JawCFX1FCqjLVGvO0DygQ=
 github.com/aws/aws-sdk-go-v2 v1.25.3 h1:xYiLpZTQs1mzvz5PaI6uR0Wh57ippuEthxS4iK5v0n0=
 github.com/aws/aws-sdk-go-v2 v1.25.3/go.mod h1:35hUlJVYd+M++iLI3ALmVwMOyRYMmRqUXpTtRGW+K9I=
+github.com/aws/aws-sdk-go-v2 v1.30.5 h1:mWSRTwQAb0aLE17dSzztCVJWI9+cRMgqebndjwDyK0g=
+github.com/aws/aws-sdk-go-v2 v1.30.5/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 h1:70PVAiL15/aBMh5LThwgXdSQorVr91L127ttckI9QQU=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4/go.mod h1:/MQxMqci8tlqDH+pjmoLu1i0tbWCUP1hhyMRuFxpQCw=
 github.com/aws/aws-sdk-go-v2/config v1.27.7 h1:JSfb5nOQF01iOgxFI5OIKWwDiEXWTyTgg1Mm1mHi0A4=
 github.com/aws/aws-sdk-go-v2/config v1.27.7/go.mod h1:PH0/cNpoMO+B04qET699o5W92Ca79fVtbUnvMIZro4I=
 github.com/aws/aws-sdk-go-v2/credentials v1.17.7 h1:WJd+ubWKoBeRh7A5iNMnxEOs982SyVKOJD+K8HIezu4=
@@ -18,10 +22,16 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.3 h1:p+y7FvkK2dxS+FEwRIDHDe/
 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.3/go.mod h1:/fYB+FZbDlwlAiynK9KDXlzZl3ANI9JkD0Uhz5FjNT4=
 github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3 h1:ifbIbHZyGl1alsAhPIYsHOg5MuApgqOvVeI8wIugXfs=
 github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3/go.mod h1:oQZXg3c6SNeY6OZrDY+xHcF4VGIEoNotX2B4PrDeoJI=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17 h1:pI7Bzt0BJtYA0N/JEC6B8fJ4RBrEMi1LBrkMdFYNSnQ=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17/go.mod h1:Dh5zzJYMtxfIjYW+/evjQ8uj2OyR/ve2KROHGHlSFqE=
 github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3 h1:Qvodo9gHG9F3E8SfYOspPeBt0bjSbsevK8WhRAUHcoY=
 github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3/go.mod h1:vCKrdLXtybdf/uQd/YfVR2r5pcbNuEYKzMQpcxmeSJw=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17 h1:Mqr/V5gvrhA2gvgnF42Zh5iMiQNcOYthFYwCyrnuWlc=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17/go.mod h1:aLJpZlCmjE+V+KtN1q1uyZkfnUWpQGpbsn89XPKyzfU=
 github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
 github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
+github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.16.2 h1:hmzsX43PIJ8x+dwJwruqMjE2F8tZuCQMxVz9Vn0EZkc=
+github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.16.2/go.mod h1:emMKL0OTFG+l9pW11RMgfvJRxZ5e093OS1o102YEGoA=
 github.com/aws/aws-sdk-go-v2/service/comprehend v1.31.2 h1:iAnydKItgi2m2rOPFfyolvjXuZimVZgRPxGlYg6Vt5U=
 github.com/aws/aws-sdk-go-v2/service/comprehend v1.31.2/go.mod h1:4jJr/hungAbvS0vQqkZQvxBqxJ4oUSEpvezYM75q2e4=
 github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1 h1:EyBZibRTVAs6ECHZOw5/wlylS9OcTzwyjeQMudmREjE=
@@ -36,6 +46,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.28.4 h1:Ppup1nVNAOWbBOrcoOxaxPeEnSFB
 github.com/aws/aws-sdk-go-v2/service/sts v1.28.4/go.mod h1:+K1rNPVyGxkRuv9NNiaZ4YhBFuyw2MMA9SlIJ1Zlpz8=
 github.com/aws/smithy-go v1.20.1 h1:4SZlSlMr36UEqC7XOyRVb27XMeZubNcBNN+9IgEPIQw=
 github.com/aws/smithy-go v1.20.1/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
+github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4=
+github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
 github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
 github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go
index 25c4bb4..005b67c 100644
--- a/internal/authenticator/authenticator.go
+++ b/internal/authenticator/authenticator.go
@@ -87,6 +87,10 @@ func rewriteHttpAuthHeader(req *http.Request, setting *provider.Setting) error {
 	}
 
 	if len(apiKey) == 0 {
+		if setting.Provider == "bedrock" {
+			return nil
+		}
+
 		return errors.New("api key is empty in provider setting")
 	}
 
diff --git a/internal/manager/provider_setting.go b/internal/manager/provider_setting.go
index 29b0e6c..5b246d1 100644
--- a/internal/manager/provider_setting.go
+++ b/internal/manager/provider_setting.go
@@ -40,7 +40,7 @@ func NewProviderSettingsManager(s ProviderSettingsStorage, cache ProviderSetting
 }
 
 func isProviderNativelySupported(provider string) bool {
-	return provider == "openai" || provider == "anthropic" || provider == "azure" || provider == "vllm" || provider == "deepinfra"
+	return provider == "openai" || provider == "anthropic" || provider == "azure" || provider == "vllm" || provider == "deepinfra" || provider == "bedrock"
 }
 
 func findMissingAuthParams(providerName string, params map[string]string) string {
@@ -55,6 +55,23 @@ func findMissingAuthParams(providerName string, params map[string]string) string
 		return strings.Join(missingFields, " ,")
 	}
 
+	if providerName == "bedrock" {
+		val := params["awsAccessKeyId"]
+		if len(val) == 0 {
+			missingFields = append(missingFields, "awsAccessKeyId")
+		}
+
+		val = params["awsSecretAccessKey"]
+		if len(val) == 0 {
+			missingFields = append(missingFields, "awsSecretAccessKey")
+		}
+
+		val = params["awsRegion"]
+		if len(val) == 0 {
+			missingFields = append(missingFields, "awsRegion")
+		}
+	}
+
 	if providerName == "azure" {
 		val := params["resourceName"]
 		if len(val) == 0 {
diff --git a/internal/message/handler.go b/internal/message/handler.go
index a7f0ee5..e216820 100644
--- a/internal/message/handler.go
+++ b/internal/message/handler.go
@@ -14,6 +14,7 @@ import (
 	"github.com/bricks-cloud/bricksllm/internal/provider/vllm"
 	"github.com/bricks-cloud/bricksllm/internal/telemetry"
 	"github.com/bricks-cloud/bricksllm/internal/user"
+	"github.com/bricks-cloud/bricksllm/internal/util"
 
 	"github.com/tidwall/gjson"
 	"go.uber.org/zap"
@@ -468,6 +469,46 @@ func (h *Handler) decorateEvent(m Message) error {
 		}
 	}
 
+	if e.Event.Path == "/api/providers/bedrock/anthropic/v1/complete" {
+		cr, ok := e.Request.(*anthropic.CompletionRequest)
+		if !ok {
+			telemetry.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
+			h.log.Debug("event contains request that cannot be converted to anthropic completion request", zap.Any("data", m.Data))
+			return errors.New("event request data cannot be parsed as anthropic completion request")
+		}
+
+		if !cr.Stream {
+			tks := h.ae.Count(cr.Prompt)
+			tks += anthropicPromptMagicNum
+
+			model := cr.Model
+
+			translatedModel := util.TranslateBedrockModelToAnthropicModel(model)
+			cost, err := h.ae.EstimatePromptCost(translatedModel, tks)
+			if err != nil {
+				telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_prompt_cost_error", nil, 1)
+				h.log.Debug("error when estimating anthropic prompt cost", zap.Error(err))
+				return err
+			}
+
+			completiontks := h.ae.Count(e.Content)
+			completiontks += anthropicCompletionMagicNum
+
+			completionCost, err := h.ae.EstimateCompletionCost(translatedModel, completiontks)
+			if err != nil {
+				telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_completion_cost_error", nil, 1)
+				return err
+			}
+
+			e.Event.PromptTokenCount = tks
+
+			e.Event.CompletionTokenCount = completiontks
+			if e.Event.Status == http.StatusOK {
+				e.Event.CostInUsd = completionCost + cost
+			}
+		}
+	}
+
 	if strings.HasPrefix(e.Event.Path, "/api/providers/azure/openai/deployments") && strings.HasSuffix(e.Event.Path, "/chat/completions") {
 		ccr, ok := e.Request.(*goopenai.ChatCompletionRequest)
 		if !ok {
diff --git a/internal/provider/anthropic/bedrock.go b/internal/provider/anthropic/bedrock.go
new file mode 100644
index 0000000..d92929e
--- /dev/null
+++ b/internal/provider/anthropic/bedrock.go
@@ -0,0 +1,44 @@
+package anthropic
+
+type BedrockCompletionRequest struct {
+	Prompt            string   `json:"prompt"`
+	MaxTokensToSample int      `json:"max_tokens_to_sample"`
+	StopSequences     []string `json:"stop_sequences,omitempty"`
+	Temperature       float32  `json:"temperature,omitempty"`
+	TopP              int      `json:"top_p,omitempty"`
+	TopK              int      `json:"top_k,omitempty"`
+}
+
+type BedrockCompletionResponse struct {
+	Completion string          `json:"completion"`
+	StopReason string          `json:"stop_reason"`
+	Model      string          `json:"model"`
+	Metrics    *BedrockMetrics `json:"amazon-bedrock-invocationMetrics"`
+}
+
+type BedrockMessageRequest struct {
+	AnthropicVersion string    `json:"anthropic_version"`
+	Messages         []Message `json:"messages"`
+	MaxTokens        int       `json:"max_tokens"`
+	StopSequences    []string  `json:"stop_sequences,omitempty"`
+	Temperature      float32   `json:"temperature,omitempty"`
+	TopP             int       `json:"top_p,omitempty"`
+	TopK             int       `json:"top_k,omitempty"`
+	Metadata         *Metadata `json:"metadata,omitempty"`
+}
+
+type BedrockMessagesStopResponse struct {
+	Type    string          `json:"type"`
+	Metrics *BedrockMetrics `json:"amazon-bedrock-invocationMetrics"`
+}
+
+type BedrockMetrics struct {
+	InputTokenCount   int `json:"inputTokenCount"`
+	OutputTokenCount  int `json:"outputTokenCount"`
+	InvocationLatency int `json:"invocationLatency"`
+	FirstByteLatency  int `json:"firstByteLatency"`
+}
+
+type BedrockMessageType struct {
+	Type string `json:"type"`
+}
diff --git a/internal/server/web/proxy/anthropic.go b/internal/server/web/proxy/anthropic.go
index 97e79ed..9c4c902 100644
--- a/internal/server/web/proxy/anthropic.go
+++ b/internal/server/web/proxy/anthropic.go
@@ -386,11 +386,6 @@ func getMessagesHandler(prod, private bool, client http.Client, e anthropicEstim
 				telemetry.Incr("bricksllm.proxy.get_messages_handler.estimate_total_cost_error", nil, 1)
 				logError(log, "error when estimating anthropic cost", prod, err)
 			}
-
-			if err != nil {
-				telemetry.Incr("bricksllm.proxy.get_messages_handler.record_key_spend_error", nil, 1)
-				logError(log, "error when recording anthropic spend", prod, err)
-			}
 		}
 
 		c.Set("costInUsd", cost)
@@ -402,7 +397,7 @@ func getMessagesHandler(prod, private bool, client http.Client, e anthropicEstim
 	}
 
 	if res.StatusCode != http.StatusOK {
-		dur := time.Now().Sub(start)
+		dur := time.Since(start)
 		telemetry.Timing("bricksllm.proxy.get_messages_handler.error_latency", dur, nil, 1)
 		telemetry.Incr("bricksllm.proxy.get_messages_handler.error_response", nil, 1)
 		bytes, err := io.ReadAll(res.Body)
diff --git a/internal/server/web/proxy/bedrock.go b/internal/server/web/proxy/bedrock.go
new file mode 100644
index 0000000..37d9669
--- /dev/null
+++ b/internal/server/web/proxy/bedrock.go
@@ -0,0 +1,513 @@
+package proxy
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"io"
+	"net/http"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/credentials"
+	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
+	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
+	"github.com/bricks-cloud/bricksllm/internal/provider/anthropic"
+	"github.com/bricks-cloud/bricksllm/internal/telemetry"
+	"github.com/bricks-cloud/bricksllm/internal/util"
+	"github.com/gin-gonic/gin"
+	"go.uber.org/zap"
+	"go.uber.org/zap/zapcore"
+)
+
+func getBedrockCompletionHandler(prod bool, e anthropicEstimator, timeOut time.Duration) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		log := util.GetLogFromCtx(c)
+		telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.requests", nil, 1)
+
+		if c == nil || c.Request == nil {
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] context is empty")
+			return
+		}
+
+		body, err := io.ReadAll(c.Request.Body)
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.read_all_error", nil, 1)
+			log.Error("error when reading claude req data from body", []zapcore.Field{zap.Error(err)}...)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to read claude req data from body")
+			return
+		}
+
+		anthropicReq := &anthropic.CompletionRequest{}
+		err = json.Unmarshal(body, anthropicReq)
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.unmarshal_anthropic_completion_request_error", nil, 1)
+			log.Error("error when unmarshalling anthropic completion request", []zapcore.Field{zap.Error(err)}...)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to unmarshal anthropic completion request")
+			return
+		}
+
+		req := &anthropic.BedrockCompletionRequest{}
+		err = json.Unmarshal(body, req)
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.unmarshal_bedrock_completion_request_error", nil, 1)
+			log.Error("error when unmarshalling bedrock completion request", []zapcore.Field{zap.Error(err)}...)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to unmarshal bedrock completion request")
+			return
+		}
+
+		bs, err := json.Marshal(req)
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.marshal_bedrock_completion_request_error", nil, 1)
+			log.Error("error when marshalling bedrock completion request", []zapcore.Field{zap.Error(err)}...)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to marshal bedrock completion request")
+			return
+		}
+
+		keyId := c.GetString("awsAccessKeyId")
+		secretKey := c.GetString("awsSecretAccessKey")
+		region := c.GetString("awsRegion")
+
+		if len(keyId) == 0 || len(secretKey) == 0 || len(region) == 0 {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.auth_error", nil, 1)
+			log.Error("key id, secret key or region is missing", []zapcore.Field{zap.Error(err)}...)
+			JSON(c, http.StatusUnauthorized, "[BricksLLM] auth credentials are missing")
+			return
+		}
+
+		ctx, cancel := context.WithTimeout(context.Background(), timeOut)
+		defer cancel()
+		cfg, err := config.LoadDefaultConfig(ctx,
+			config.WithCredentialsProvider(credentials.StaticCredentialsProvider{
+				Value: aws.Credentials{
+					AccessKeyID: keyId, SecretAccessKey: secretKey,
+					Source: "BricksLLM Credentials",
+				},
+			}),
+			config.WithRegion(region))
+
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.aws_config_creation_error", nil, 1)
+			log.Error("error when creating aws config", []zapcore.Field{zap.Error(err)}...)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create aws config")
+			return
+		}
+
+		client := bedrockruntime.NewFromConfig(cfg)
+		stream := c.GetBool("stream")
+
+		ctx, cancel = context.WithTimeout(context.Background(), timeOut)
+		defer cancel()
+
+		start := time.Now()
+
+		if !stream {
+			output, err := client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{
+				ModelId:     &anthropicReq.Model,
+				ContentType: aws.String("application/json"),
+				Body:        bs,
+			})
+
+			if err != nil {
+				telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.error_response", nil, 1)
+				telemetry.Timing("bricksllm.proxy.get_bedrock_completion_handler.error_latency", time.Since(start), nil, 1)
+
+				log.Error("error when invoking bedrock model", []zapcore.Field{zap.Error(err)}...)
+				JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to invoke bedrock model")
+				return
+			}
+
+			completionRes := &anthropic.BedrockCompletionResponse{}
+			err = json.Unmarshal(output.Body, completionRes)
+			if err != nil {
+				telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.unmarshal_bedrock_completion_response_error", nil, 1)
+				logError(log, "error when unmarshalling bedrock anthropic completion response body", prod, err)
+			}
+
+			telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.success", nil, 1)
+			telemetry.Timing("bricksllm.proxy.get_bedrock_completion_handler.success_latency", time.Since(start), nil, 1)
+
+			c.Set("content", completionRes.Completion)
+
+			c.Data(http.StatusOK, "application/json", output.Body)
+			return
+		}
+
+		telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.streaming_requests", nil, 1)
+
+		streamOutput, err := client.InvokeModelWithResponseStream(ctx, &bedrockruntime.InvokeModelWithResponseStreamInput{
+			ModelId:     &anthropicReq.Model,
+			ContentType: aws.String("application/json"),
+			Body:        bs,
+		})
+
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.invoking_model_with_streaming_response_error", nil, 1)
+			log.Error("error when invoking bedrock model with streaming responses", []zapcore.Field{zap.Error(err)}...)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to invoke bedrock model with stream response")
+			return
+		}
+
+		streamingResponse := [][]byte{}
+		promptTokenCount := 0
+		completionTokenCount := 0
+
+		defer func() {
+			model := c.GetString("model")
+			translatedModel := util.TranslateBedrockModelToAnthropicModel(model)
+			completionCost, err := e.EstimateCompletionCost(translatedModel, completionTokenCount)
+			if err != nil {
+				telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.estimate_completion_cost_error", nil, 1)
+				logError(log, "error when estimating bedrock completion cost", prod, err)
+			}
+
+			promptCost, err := e.EstimatePromptCost(translatedModel, promptTokenCount)
+			if err != nil {
+				telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.estimate_prompt_cost_error", nil, 1)
+				logError(log, "error when estimating bedrock prompt cost", prod, err)
+			}
+
+			c.Set("costInUsd", completionCost+promptCost)
+			c.Set("promptTokenCount", promptTokenCount)
+			c.Set("completionTokenCount", completionTokenCount)
+			c.Set("streaming_response", bytes.Join(streamingResponse, []byte{'\n'}))
+		}()
+
+		eventName := ""
+		c.Stream(func(w io.Writer) bool {
+			for event := range streamOutput.GetStream().Events() {
+				switch v := event.(type) {
+				case *types.ResponseStreamMemberChunk:
+					raw := v.Value.Bytes
+					noSpaceLine := bytes.TrimSpace(raw)
+					if len(noSpaceLine) == 0 {
+						return true
+					}
+
+					eventName = getEventNameFromLine(noSpaceLine)
+					if len(eventName) == 0 {
+						return true
+					}
+
+					chatCompletionResp := &anthropic.BedrockCompletionResponse{}
+					if eventName == " completion" {
+						err := json.NewDecoder(bytes.NewReader(noSpaceLine)).Decode(&chatCompletionResp)
+						if err != nil {
+							telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.bedrock_completion_stream_response_unmarshall_error", nil, 1)
+							log.Error("error when unmarshalling bedrock streaming response chunks", []zapcore.Field{zap.Error(err)}...)
+							return false
+						}
+
+						if chatCompletionResp.Metrics != nil {
+							promptTokenCount = chatCompletionResp.Metrics.InputTokenCount
+							completionTokenCount = chatCompletionResp.Metrics.OutputTokenCount
+						}
+					}
+
+					noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
+					c.SSEvent(eventName, " "+string(noPrefixLine))
+
+					streamingResponse = append(streamingResponse, raw)
+					if len(chatCompletionResp.StopReason) != 0 {
+						return false
+					}
+				default:
+					telemetry.Incr("bricksllm.proxy.get_bedrock_completion_handler.bedrock_completion_stream_response_unknown_error", nil, 1)
+					return false
+				}
+			}
+
+			telemetry.Timing("bricksllm.proxy.get_bedrock_completion_handler.streaming_latency", time.Since(start), nil, 1)
+			return false
+		})
+	}
+}
+
+var (
+	bedrockEventMessageStart      = []byte(`{"type":"message_start"`)
+	bedrockEventMessageDelta      = []byte(`{"type":"message_delta"`)
+	bedrockEventMessageStop       = []byte(`{"type":"message_stop"`)
+	bedrockEventContentBlockStart = []byte(`{"type":"content_block_start"`)
+	bedrockEventContentBlockDelta = []byte(`{"type":"content_block_delta"`)
+	bedrockEventContentBlockStop  = []byte(`{"type":"content_block_stop"`)
+	bedrockEventPing              = []byte(`{"type":"ping"`)
+	bedrockEventError             = []byte(`{"type":"error"`)
+	bedrockEventCompletion        = []byte(`{"type":"completion"`)
+)
+
+func getEventNameFromLine(line []byte) string {
+	if bytes.HasPrefix(line, bedrockEventMessageStart) {
+		return " message_start"
+	}
+
+	if bytes.HasPrefix(line, bedrockEventMessageDelta) {
+		return " message_delta"
+	}
+
+	if bytes.HasPrefix(line, bedrockEventMessageStop) {
+		return " message_stop"
+	}
+
+	if bytes.HasPrefix(line, bedrockEventContentBlockStart) {
+		return " content_block_start"
+	}
+
+	if bytes.HasPrefix(line, bedrockEventContentBlockDelta) {
+		return " content_block_delta"
+	}
+
+	if bytes.HasPrefix(line, bedrockEventContentBlockStop) {
+		return " content_block_stop"
+	}
+
+	if bytes.HasPrefix(line, bedrockEventPing) {
+		return " ping"
+	}
+
+	if bytes.HasPrefix(line, bedrockEventError) {
+		return " error"
+	}
+
+	if bytes.HasPrefix(line, bedrockEventCompletion) {
+		return " completion"
+	}
+
+	return ""
+}
+
+func getBedrockMessagesHandler(prod bool, e anthropicEstimator, timeOut time.Duration) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		log := util.GetLogFromCtx(c)
+		telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.requests", nil, 1)
+
+		if c == nil || c.Request == nil {
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] context is empty")
+			return
+		}
+
+		body, err := io.ReadAll(c.Request.Body)
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.read_all_error", nil, 1)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to read claude req data from body")
+			return
+		}
+
+		anthropicReq := &anthropic.MessagesRequest{}
+		err = json.Unmarshal(body, anthropicReq)
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.unmarshal_anthropic_messages_request_error", nil, 1)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to unmarshal anthropic messages request")
+			return
+		}
+
+		req := &anthropic.BedrockMessageRequest{}
+		err = json.Unmarshal(body, req)
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.unmarshal_bedrock_messages_request_error", nil, 1)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to unmarshal bedrock messages request")
+			return
+		}
+
+		bs, err := json.Marshal(req)
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.marshal_error", nil, 1)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to marshal bedrock messages request")
+			return
+		}
+
+		keyId := c.GetString("awsAccessKeyId")
+		secretKey := c.GetString("awsSecretAccessKey")
+		region := c.GetString("awsRegion")
+
+		if len(keyId) == 0 || len(secretKey) == 0 || len(region) == 0 {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.auth_error", nil, 1)
+			log.Error("key id, secret key or region is missing", []zapcore.Field{zap.Error(err)}...)
+			JSON(c, http.StatusUnauthorized, "[BricksLLM] auth credentials are missing")
+			return
+		}
+
+		ctx, cancel := context.WithTimeout(context.Background(), timeOut)
+		defer cancel()
+		cfg, err := config.LoadDefaultConfig(ctx,
+			config.WithCredentialsProvider(credentials.StaticCredentialsProvider{
+				Value: aws.Credentials{
+					AccessKeyID: keyId, SecretAccessKey: secretKey,
+					Source: "BricksLLM Credentials",
+				},
+			}),
+			config.WithRegion(region))
+
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.aws_config_creation_error", nil, 1)
+			log.Error("error when creating aws config", []zapcore.Field{zap.Error(err)}...)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create aws config")
+			return
+		}
+
+		client := bedrockruntime.NewFromConfig(cfg)
+		stream := c.GetBool("stream")
+
+		ctx, cancel = context.WithTimeout(context.Background(), timeOut)
+		defer cancel()
+
+		start := time.Now()
+
+		if !stream {
+			output, err := client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{
+				ModelId:     &anthropicReq.Model,
+				ContentType: aws.String("application/json"),
+				Body:        bs,
+			})
+
+			if err != nil {
+				telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.error_response", nil, 1)
+				telemetry.Timing("bricksllm.proxy.get_bedrock_messages_handler.error_latency", time.Since(start), nil, 1)
+
+				log.Error("error when invoking bedrock model", []zapcore.Field{zap.Error(err)}...)
+				JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to invoke bedrock model")
+				return
+			}
+
+			var cost float64 = 0
+			completionTokens := 0
+			promptTokens := 0
+
+			messagesRes := &anthropic.MessagesResponse{}
+			telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.success", nil, 1)
+			telemetry.Timing("bricksllm.proxy.get_bedrock_messages_handler.success_latency", time.Since(start), nil, 1)
+
+			err = json.Unmarshal(output.Body, messagesRes)
+			if err != nil {
+				telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.unmarshal_bedrock_messages_response_error", nil, 1)
+				logError(log, "error when unmarshalling bedrock messages response body", prod, err)
+			}
+
+			if err == nil {
+				completionTokens = messagesRes.Usage.OutputTokens
+				promptTokens = messagesRes.Usage.InputTokens
+
+				model := c.GetString("model")
+				translated := util.TranslateBedrockModelToAnthropicModel(model)
+
+				cost, err = e.EstimateTotalCost(translated, promptTokens, completionTokens)
+				if err != nil {
+					telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.estimate_total_cost_error", nil, 1)
+					logError(log, "error when estimating anthropic cost", prod, err)
+				}
+			}
+
+			c.Set("costInUsd", cost)
+			c.Set("promptTokenCount", promptTokens)
+			c.Set("completionTokenCount", completionTokens)
+
+			c.Data(http.StatusOK, "application/json", output.Body)
+			return
+		}
+
+		telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.streaming_requests", nil, 1)
+
+		streamOutput, err := client.InvokeModelWithResponseStream(ctx, &bedrockruntime.InvokeModelWithResponseStreamInput{
+			ModelId:     &anthropicReq.Model,
+			ContentType: aws.String("application/json"),
+			Body:        bs,
+		})
+
+		if err != nil {
+			telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.invoking_model_with_streaming_response_error", nil, 1)
+
+			log.Error("error when invoking bedrock model with streaming responses", []zapcore.Field{zap.Error(err)}...)
+			JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to invoke model request with stream response")
+			return
+		}
+
+		streamingResponse := [][]byte{}
+		promptTokenCount := 0
+		completionTokenCount := 0
+
+		defer func() {
+			model := c.GetString("model")
+			translatedModel := util.TranslateBedrockModelToAnthropicModel(model)
+			completionCost, err := e.EstimateCompletionCost(translatedModel, completionTokenCount)
+			if err != nil {
+				telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.estimate_completion_cost_error", nil, 1)
+				logError(log, "error when estimating bedrock completion cost", prod, err)
+			}
+
+			promptCost, err := e.EstimatePromptCost(translatedModel, promptTokenCount)
+			if err != nil {
+				telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.estimate_prompt_cost_error", nil, 1)
+				logError(log, "error when estimating bedrock prompt cost", prod, err)
+			}
+
+			c.Set("costInUsd", completionCost+promptCost)
+			c.Set("promptTokenCount", promptTokenCount)
+			c.Set("completionTokenCount", completionTokenCount)
+			c.Set("streaming_response", bytes.Join(streamingResponse, []byte{'\n'}))
+		}()
+
+		eventName := ""
+		c.Stream(func(w io.Writer) bool {
+			content := ""
+			for event := range streamOutput.GetStream().Events() {
+				switch v := event.(type) {
+				case *types.ResponseStreamMemberChunk:
+					raw := v.Value.Bytes
+					streamingResponse = append(streamingResponse, raw)
+
+					noSpaceLine := bytes.TrimSpace(raw)
+					if len(noSpaceLine) == 0 {
+						return true
+					}
+
+					eventName = getEventNameFromLine(noSpaceLine)
+					if len(eventName) == 0 {
+						return true
+					}
+
+					if eventName == " message_stop" {
+						stopResp := &anthropic.BedrockMessagesStopResponse{}
+						err := json.NewDecoder(bytes.NewReader(raw)).Decode(&stopResp)
+						if err != nil {
+							telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.bedrock_messages_stop_response_unmarshall_error", nil, 1)
+							log.Error("error when unmarshalling bedrock messages stop response chunks", []zapcore.Field{zap.Error(err)}...)
+
+							return false
+						}
+
+						if stopResp.Metrics != nil {
+							promptTokenCount = stopResp.Metrics.InputTokenCount
+							completionTokenCount = stopResp.Metrics.OutputTokenCount
+						}
+					}
+
+					if eventName == " content_block_delta" {
+						chatCompletionResp := &anthropic.MessagesStreamBlockDelta{}
+						err := json.NewDecoder(bytes.NewReader(raw)).Decode(&chatCompletionResp)
+						if err != nil {
+							telemetry.Incr("bricksllm.proxy.get_bedrock_messages_handler.bedrock_messages_content_block_response_unmarshall_error", nil, 1)
+							log.Error("error when unmarshalling bedrock messages content block response chunks", []zapcore.Field{zap.Error(err)}...)
+
+							return false
+						}
+
+						content += chatCompletionResp.Delta.Text
+					}
+
+					c.SSEvent(eventName, " "+string(noSpaceLine))
+
+					if eventName == " message_stop" {
+						return false
+					}
+				default:
+
+					telemetry.Timing("bricksllm.proxy.get_bedrock_messages_handler.streaming_latency", time.Since(start), nil, 1)
+					return false
+				}
+			}
+
+			telemetry.Timing("bricksllm.proxy.get_bedrock_messages_handler.streaming_latency", time.Since(start), nil, 1)
+			return false
+		})
+	}
+}
diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go
index a389b7e..30b9eb0 100644
--- a/internal/server/web/proxy/middleware.go
+++ b/internal/server/web/proxy/middleware.go
@@ -348,6 +348,20 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
 			}
 		}
 
+		if strings.HasPrefix(c.FullPath(), "/api/providers/bedrock/anthropic") {
+			if selected != nil && len(selected.Setting["awsAccessKeyId"]) != 0 {
+				c.Set("awsAccessKeyId", selected.Setting["awsAccessKeyId"])
+			}
+
+			if selected != nil && len(selected.Setting["awsSecretAccessKey"]) != 0 {
+				c.Set("awsSecretAccessKey", selected.Setting["awsSecretAccessKey"])
+			}
+
+			if selected != nil && len(selected.Setting["awsRegion"]) != 0 {
+				c.Set("awsRegion", selected.Setting["awsRegion"])
+			}
+		}
+
 		if strings.HasPrefix(c.FullPath(), "/api/providers/vllm") {
 			if selected != nil && len(selected.Setting["url"]) != 0 {
 				c.Set("vllmUrl", selected.Setting["url"])
@@ -402,6 +416,54 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
 			policyInput = cr
 		}
 
+		if c.FullPath() == "/api/providers/bedrock/anthropic/v1/complete" {
+			logCompletionRequest(logWithCid, body, prod, private)
+
+			cr := &anthropic.CompletionRequest{}
+			err = json.Unmarshal(body, cr)
+			if err != nil {
+				logError(logWithCid, "error when unmarshalling bedrock anthropic completion request", prod, err)
+				return
+			}
+
+			if cr.Metadata != nil {
+				userId = cr.Metadata.UserId
+			}
+
+			enrichedEvent.Request = cr
+
+			if cr.Stream {
+				c.Set("stream", cr.Stream)
+			}
+
+			c.Set("model", cr.Model)
+
+			policyInput = cr
+		}
+
+		if c.FullPath() == "/api/providers/bedrock/anthropic/v1/messages" {
+			logCreateMessageRequest(logWithCid, body, prod, private)
+
+			mr := &anthropic.MessagesRequest{}
+			err = json.Unmarshal(body, mr)
+			if err != nil {
+				logError(logWithCid, "error when unmarshalling anthropic messages request", prod, err)
+				return
+			}
+
+			if mr.Metadata != nil {
+				userId = mr.Metadata.UserId
+			}
+
+			if mr.Stream {
+				c.Set("stream", mr.Stream)
+			}
+
+			c.Set("model", mr.Model)
+
+			policyInput = mr
+		}
+
 		if c.FullPath() == "/api/providers/anthropic/v1/messages" {
 			logCreateMessageRequest(logWithCid, body, prod, private)
 
diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go
index 5a92dc4..c7d9641 100644
--- a/internal/server/web/proxy/proxy.go
+++ b/internal/server/web/proxy/proxy.go
@@ -181,6 +181,10 @@ func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyMan
 	router.POST("/api/providers/anthropic/v1/complete", getCompletionHandler(prod, private, client, timeOut))
 	router.POST("/api/providers/anthropic/v1/messages", getMessagesHandler(prod, private, client, ae, timeOut))
 
+	// bedrock anthropic
+	router.POST("/api/providers/bedrock/anthropic/v1/complete", getBedrockCompletionHandler(prod, ae, timeOut))
+	router.POST("/api/providers/bedrock/anthropic/v1/messages", getBedrockMessagesHandler(prod, ae, timeOut))
+
 	// vllm
 	router.POST("/api/providers/vllm/v1/chat/completions", getVllmChatCompletionsHandler(prod, private, client, timeOut))
 	router.POST("/api/providers/vllm/v1/completions", getVllmCompletionsHandler(prod, private, client, timeOut))
@@ -992,6 +996,10 @@ func (ps *ProxyServer) Run() {
 	ps.log.Info("PORT 8002 | POST | /api/providers/anthropic/v1/complete is ready for forwarding completion requests to anthropic")
 	ps.log.Info("PORT 8002 | POST | /api/providers/anthropic/v1/messages is ready for forwarding message requests to anthropic")
 
+	// bedrock anthropic
+	ps.log.Info("PORT 8002 | POST | /api/providers/bedrock/anthropic/v1/complete is ready for forwarding completion requests to bedrock anthropic")
+	ps.log.Info("PORT 8002 | POST | /api/providers/bedrock/anthropic/v1/messages is ready for forwarding message requests to bedrock anthropic")
+
 	// vllm
 	ps.log.Info("PORT 8002 | POST | /api/providers/vllm/v1/chat/completions is ready for forwarding vllm chat completions requests")
 	ps.log.Info("PORT 8002 | POST | /api/providers/vllm/v1/completions is ready for forwarding vllm completions requests")
diff --git a/internal/util/util.go b/internal/util/util.go
index ae5d353..2ee40d5 100644
--- a/internal/util/util.go
+++ b/internal/util/util.go
@@ -3,6 +3,7 @@ package util
 import (
 	"context"
 	"errors"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 	"github.com/google/uuid"
@@ -50,3 +51,31 @@ func ConvertAnyToStr(input any) (string, error) {
 
 	return converted, nil
 }
+
+func TranslateBedrockModelToAnthropicModel(model string) string {
+	if strings.HasPrefix(model, "anthropic.claude-v2") {
+		return "claude"
+	}
+
+	if strings.HasPrefix(model, "anthropic.claude-3-haiku") {
+		return "claude-3-haiku"
+	}
+
+	if strings.HasPrefix(model, "anthropic.claude-3-sonnet") {
+		return "claude-3-sonnet"
+	}
+
+	if strings.HasPrefix(model, "anthropic.claude-3-opus") {
+		return "claude-3-opus"
+	}
+
+	if strings.HasPrefix(model, "anthropic.claude-3-5-sonnet") {
+		return "claude-3.5-sonnet"
+	}
+
+	if strings.HasPrefix(model, "anthropic.claude-instant") {
+		return "claude-instant"
+	}
+
+	return model
+}