Skip to content

Commit

Permalink
add amazon bedrock integrations for claude
Browse files Browse the repository at this point in the history
  • Loading branch information
spikelu2016 committed Sep 9, 2024
1 parent e34416f commit 876fa3c
Show file tree
Hide file tree
Showing 12 changed files with 742 additions and 11 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 1.36.0 - 2024-08-10
### Added
- Added Amazon Bedrock integration for Claude models

## 1.35.2 - 2024-08-10
### Changed
- Changed aggregated table column data types from `INT` to `BIGINT`
Expand Down
10 changes: 6 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ require (
)

require (
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 // indirect
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.16.2 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 // indirect
Expand All @@ -44,18 +46,18 @@ require (
github.com/Microsoft/go-winio v0.5.0 // indirect
github.com/asticode/go-astikit v0.20.0 // indirect
github.com/asticode/go-astits v1.8.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.25.3 // indirect
github.com/aws/aws-sdk-go-v2 v1.30.5 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.7 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.3 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.5 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.20.2 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.2 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.28.4 // indirect
github.com/aws/smithy-go v1.20.1 // indirect
github.com/aws/smithy-go v1.20.4 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
Expand Down
12 changes: 12 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ github.com/asticode/go-astits v1.8.0 h1:rf6aiiGn/QhlFjNON1n5plqF3Fs025XLUwiQ0NB6
github.com/asticode/go-astits v1.8.0/go.mod h1:DkOWmBNQpnr9mv24KfZjq4JawCFX1FCqjLVGvO0DygQ=
github.com/aws/aws-sdk-go-v2 v1.25.3 h1:xYiLpZTQs1mzvz5PaI6uR0Wh57ippuEthxS4iK5v0n0=
github.com/aws/aws-sdk-go-v2 v1.25.3/go.mod h1:35hUlJVYd+M++iLI3ALmVwMOyRYMmRqUXpTtRGW+K9I=
github.com/aws/aws-sdk-go-v2 v1.30.5 h1:mWSRTwQAb0aLE17dSzztCVJWI9+cRMgqebndjwDyK0g=
github.com/aws/aws-sdk-go-v2 v1.30.5/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 h1:70PVAiL15/aBMh5LThwgXdSQorVr91L127ttckI9QQU=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4/go.mod h1:/MQxMqci8tlqDH+pjmoLu1i0tbWCUP1hhyMRuFxpQCw=
github.com/aws/aws-sdk-go-v2/config v1.27.7 h1:JSfb5nOQF01iOgxFI5OIKWwDiEXWTyTgg1Mm1mHi0A4=
github.com/aws/aws-sdk-go-v2/config v1.27.7/go.mod h1:PH0/cNpoMO+B04qET699o5W92Ca79fVtbUnvMIZro4I=
github.com/aws/aws-sdk-go-v2/credentials v1.17.7 h1:WJd+ubWKoBeRh7A5iNMnxEOs982SyVKOJD+K8HIezu4=
Expand All @@ -18,10 +22,16 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.3 h1:p+y7FvkK2dxS+FEwRIDHDe/
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.3/go.mod h1:/fYB+FZbDlwlAiynK9KDXlzZl3ANI9JkD0Uhz5FjNT4=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3 h1:ifbIbHZyGl1alsAhPIYsHOg5MuApgqOvVeI8wIugXfs=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3/go.mod h1:oQZXg3c6SNeY6OZrDY+xHcF4VGIEoNotX2B4PrDeoJI=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17 h1:pI7Bzt0BJtYA0N/JEC6B8fJ4RBrEMi1LBrkMdFYNSnQ=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.17/go.mod h1:Dh5zzJYMtxfIjYW+/evjQ8uj2OyR/ve2KROHGHlSFqE=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3 h1:Qvodo9gHG9F3E8SfYOspPeBt0bjSbsevK8WhRAUHcoY=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3/go.mod h1:vCKrdLXtybdf/uQd/YfVR2r5pcbNuEYKzMQpcxmeSJw=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17 h1:Mqr/V5gvrhA2gvgnF42Zh5iMiQNcOYthFYwCyrnuWlc=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.17/go.mod h1:aLJpZlCmjE+V+KtN1q1uyZkfnUWpQGpbsn89XPKyzfU=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.16.2 h1:hmzsX43PIJ8x+dwJwruqMjE2F8tZuCQMxVz9Vn0EZkc=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.16.2/go.mod h1:emMKL0OTFG+l9pW11RMgfvJRxZ5e093OS1o102YEGoA=
github.com/aws/aws-sdk-go-v2/service/comprehend v1.31.2 h1:iAnydKItgi2m2rOPFfyolvjXuZimVZgRPxGlYg6Vt5U=
github.com/aws/aws-sdk-go-v2/service/comprehend v1.31.2/go.mod h1:4jJr/hungAbvS0vQqkZQvxBqxJ4oUSEpvezYM75q2e4=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1 h1:EyBZibRTVAs6ECHZOw5/wlylS9OcTzwyjeQMudmREjE=
Expand All @@ -36,6 +46,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.28.4 h1:Ppup1nVNAOWbBOrcoOxaxPeEnSFB
github.com/aws/aws-sdk-go-v2/service/sts v1.28.4/go.mod h1:+K1rNPVyGxkRuv9NNiaZ4YhBFuyw2MMA9SlIJ1Zlpz8=
github.com/aws/smithy-go v1.20.1 h1:4SZlSlMr36UEqC7XOyRVb27XMeZubNcBNN+9IgEPIQw=
github.com/aws/smithy-go v1.20.1/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4=
github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
Expand Down
4 changes: 4 additions & 0 deletions internal/authenticator/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ func rewriteHttpAuthHeader(req *http.Request, setting *provider.Setting) error {
}

if len(apiKey) == 0 {
if setting.Provider == "bedrock" {
return nil
}

return errors.New("api key is empty in provider setting")
}

Expand Down
19 changes: 18 additions & 1 deletion internal/manager/provider_setting.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewProviderSettingsManager(s ProviderSettingsStorage, cache ProviderSetting
}

// isProviderNativelySupported reports whether provider exactly matches one of
// the natively supported provider names: "openai", "anthropic", "azure",
// "vllm", "deepinfra" or "bedrock".
//
// The comparison is case-sensitive; any other value, including the empty
// string, yields false.
func isProviderNativelySupported(provider string) bool {
	switch provider {
	case "openai", "anthropic", "azure", "vllm", "deepinfra", "bedrock":
		return true
	default:
		return false
	}
}

func findMissingAuthParams(providerName string, params map[string]string) string {
Expand All @@ -55,6 +55,23 @@ func findMissingAuthParams(providerName string, params map[string]string) string
return strings.Join(missingFields, " ,")
}

if providerName == "bedrock" {
val := params["awsAccessKeyId"]
if len(val) == 0 {
missingFields = append(missingFields, "awsAccessKeyId")
}

val = params["awsSecretAccessKey"]
if len(val) == 0 {
missingFields = append(missingFields, "awsSecretAccessKey")
}

val = params["awsRegion"]
if len(val) == 0 {
missingFields = append(missingFields, "awsRegion")
}
}

if providerName == "azure" {
val := params["resourceName"]
if len(val) == 0 {
Expand Down
41 changes: 41 additions & 0 deletions internal/message/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/bricks-cloud/bricksllm/internal/provider/vllm"
"github.com/bricks-cloud/bricksllm/internal/telemetry"
"github.com/bricks-cloud/bricksllm/internal/user"
"github.com/bricks-cloud/bricksllm/internal/util"
"github.com/tidwall/gjson"
"go.uber.org/zap"

Expand Down Expand Up @@ -468,6 +469,46 @@ func (h *Handler) decorateEvent(m Message) error {
}
}

if e.Event.Path == "/api/providers/bedrock/anthropic/v1/complete" {
cr, ok := e.Request.(*anthropic.CompletionRequest)
if !ok {
telemetry.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
h.log.Debug("event contains request that cannot be converted to anthropic completion request", zap.Any("data", m.Data))
return errors.New("event request data cannot be parsed as anthropic completion request")
}

if !cr.Stream {
tks := h.ae.Count(cr.Prompt)
tks += anthropicPromptMagicNum

model := cr.Model

translatedModel := util.TranslateBedrockModelToAnthropicModel(model)
cost, err := h.ae.EstimatePromptCost(translatedModel, tks)
if err != nil {
telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_prompt_cost", nil, 1)
h.log.Debug("event contains request that cannot be converted to anthropic completion request", zap.Error(err))
return err
}

completiontks := h.ae.Count(e.Content)
completiontks += anthropicCompletionMagicNum

completionCost, err := h.ae.EstimateCompletionCost(translatedModel, completiontks)
if err != nil {
telemetry.Incr("bricksllm.message.handler.decorate_event.estimate_completion_cost_error", nil, 1)
return err
}

e.Event.PromptTokenCount = tks

e.Event.CompletionTokenCount = completiontks
if e.Event.Status == http.StatusOK {
e.Event.CostInUsd = completionCost + cost
}
}
}

if strings.HasPrefix(e.Event.Path, "/api/providers/azure/openai/deployments") && strings.HasSuffix(e.Event.Path, "/chat/completions") {
ccr, ok := e.Request.(*goopenai.ChatCompletionRequest)
if !ok {
Expand Down
44 changes: 44 additions & 0 deletions internal/provider/anthropic/bedrock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package anthropic

// BedrockCompletionRequest models the JSON body of a text completion request
// used by the Bedrock integration.
//
// Fields tagged with omitempty are dropped from the encoded JSON when they
// hold their zero value.
type BedrockCompletionRequest struct {
	// Prompt is encoded under the "prompt" key and is always emitted.
	Prompt string `json:"prompt"`

	// MaxTokensToSample is encoded under "max_tokens_to_sample" and is always
	// emitted, even when zero.
	MaxTokensToSample int `json:"max_tokens_to_sample"`

	// StopSequences is encoded under "stop_sequences"; omitted when empty.
	StopSequences []string `json:"stop_sequences,omitempty"`

	// Temperature is encoded under "temperature"; omitted when zero.
	Temperature float32 `json:"temperature,omitempty"`

	// TopP is encoded under "top_p"; omitted when zero.
	// NOTE(review): declared as int, so a fractional JSON number (e.g. 0.9)
	// cannot be decoded into it — confirm this matches what callers send.
	TopP int `json:"top_p,omitempty"`

	// TopK is encoded under "top_k"; omitted when zero.
	TopK int `json:"top_k,omitempty"`
}

// BedrockCompletionResponse models the JSON body of a text completion
// response handled by the Bedrock integration.
type BedrockCompletionResponse struct {
	// Completion is read from / written to the "completion" key.
	Completion string `json:"completion"`

	// StopReason is read from / written to the "stop_reason" key.
	StopReason string `json:"stop_reason"`

	// Model is read from / written to the "model" key.
	Model string `json:"model"`

	// Metrics maps to the "amazon-bedrock-invocationMetrics" key. It stays nil
	// when that key is absent from decoded JSON, so callers must nil-check it.
	Metrics *BedrockMetrics `json:"amazon-bedrock-invocationMetrics"`
}

// BedrockMessageRequest models the JSON body of a messages request used by
// the Bedrock integration.
//
// Fields tagged with omitempty are dropped from the encoded JSON when they
// hold their zero value (or are nil).
type BedrockMessageRequest struct {
	// AnthropicVersion is encoded under "anthropic_version" and is always
	// emitted.
	AnthropicVersion string `json:"anthropic_version"`

	// Messages is encoded under "messages" and is always emitted.
	Messages []Message `json:"messages"`

	// MaxTokens is encoded under "max_tokens" and is always emitted, even
	// when zero.
	MaxTokens int `json:"max_tokens"`

	// StopSequences is encoded under "stop_sequences"; omitted when empty.
	StopSequences []string `json:"stop_sequences,omitempty"`

	// Temperature is encoded under "temperature"; omitted when zero.
	Temperature float32 `json:"temperature,omitempty"`

	// TopP is encoded under "top_p"; omitted when zero.
	// NOTE(review): declared as int, so a fractional JSON number (e.g. 0.9)
	// cannot be decoded into it — confirm this matches what callers send.
	TopP int `json:"top_p,omitempty"`

	// TopK is encoded under "top_k"; omitted when zero.
	TopK int `json:"top_k,omitempty"`

	// Metadata is encoded under "metadata"; omitted when nil.
	Metadata *Metadata `json:"metadata,omitempty"`
}

// BedrockMessagesStopResponse models a JSON payload that carries a "type"
// discriminator together with Bedrock invocation metrics.
// NOTE(review): presumably the final event of a streamed messages response —
// confirm against the code that decodes it.
type BedrockMessagesStopResponse struct {
	// Type maps to the "type" key.
	Type string `json:"type"`

	// Metrics maps to the "amazon-bedrock-invocationMetrics" key. It stays nil
	// when that key is absent from decoded JSON, so callers must nil-check it.
	Metrics *BedrockMetrics `json:"amazon-bedrock-invocationMetrics"`
}

// BedrockMetrics holds the counters found under the
// "amazon-bedrock-invocationMetrics" key of a Bedrock payload.
type BedrockMetrics struct {
	// InputTokenCount maps to the "inputTokenCount" key.
	InputTokenCount int `json:"inputTokenCount"`

	// OutputTokenCount maps to the "outputTokenCount" key.
	OutputTokenCount int `json:"outputTokenCount"`

	// InvocationLatency maps to the "invocationLatency" key.
	// NOTE(review): unit is not visible here — presumably milliseconds; confirm.
	InvocationLatency int `json:"invocationLatency"`

	// FirstByteLatency maps to the "firstByteLatency" key.
	// NOTE(review): unit is not visible here — presumably milliseconds; confirm.
	FirstByteLatency int `json:"firstByteLatency"`
}

// BedrockMessageType captures only the "type" key of a JSON payload; every
// other key is ignored when decoding into it.
// NOTE(review): presumably used to inspect an event's type before decoding it
// into a fuller struct — confirm against callers.
type BedrockMessageType struct {
	// Type maps to the "type" key.
	Type string `json:"type"`
}
7 changes: 1 addition & 6 deletions internal/server/web/proxy/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,11 +386,6 @@ func getMessagesHandler(prod, private bool, client http.Client, e anthropicEstim
telemetry.Incr("bricksllm.proxy.get_messages_handler.estimate_total_cost_error", nil, 1)
logError(log, "error when estimating anthropic cost", prod, err)
}

if err != nil {
telemetry.Incr("bricksllm.proxy.get_messages_handler.record_key_spend_error", nil, 1)
logError(log, "error when recording anthropic spend", prod, err)
}
}

c.Set("costInUsd", cost)
Expand All @@ -402,7 +397,7 @@ func getMessagesHandler(prod, private bool, client http.Client, e anthropicEstim
}

if res.StatusCode != http.StatusOK {
dur := time.Now().Sub(start)
dur := time.Since(start)
telemetry.Timing("bricksllm.proxy.get_messages_handler.error_latency", dur, nil, 1)
telemetry.Incr("bricksllm.proxy.get_messages_handler.error_response", nil, 1)
bytes, err := io.ReadAll(res.Body)
Expand Down
Loading

0 comments on commit 876fa3c

Please sign in to comment.