diff --git a/CHANGELOG.md b/CHANGELOG.md
index b2d56b7..e34b8e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,73 @@
+## 1.12.1 - 2024-02-28
+### Added
+- Added querying keys by `keyIds`
+- Increased default postgres DB read timeout to `15s` and write timeout to `5s`
+
+## 1.12.0 - 2024-02-28
+### Added
+- Added setting rotation feature to key
+
+## 1.11.0 - 2024-02-28
+### Added
+- Added cost tracking for OpenAI audio endpoints
+- Added inference cost tracking for OpenAI finetune models
+
+## 1.10.0 - 2024-02-21
+### Added
+- Added `userId` as a new filter option for get events API endpoint
+- Added option to store request and response using keys
+
+## 1.9.6 - 2024-02-18
+### Added
+- Added support for updating key cost limit and rate limit
+
+### Changed
+- Removed validation to updating revoked key field
+
+## 1.9.5 - 2024-02-18
+### Added
+- Added new model "gpt-4-turbo-preview" and "gpt-4-vision-preview" to the cost map
+
+## 1.9.4 - 2024-02-16
+### Added
+- Added support for calculating cost for the cheaper 3.5 turbo model
+- Added validation to updating revoked key field
+
+## 1.9.3 - 2024-02-13
+### Added
+- Added CORS support in the proxy
+
+## 1.9.2 - 2024-02-06
+### Fixed
+- Fixed custom route tokens recording issue incurred by the new architecture
+
+## 1.9.1 - 2024-02-06
+### Fixed
+- Fixed OpenAI chat completion endpoint being slow
+
+## 1.9.0 - 2024-02-06
+### Changed
+- Drastically improved performance through event driven architecture
+
+### Fixed
+- Fixed API calls that exceeds cost limit not being blocked bug
+
+## 1.8.2 - 2024-01-31
+### Added
+- Added support for new chat completion models
+- Added new querying options for metrics and events API
+
+## 1.8.1 - 2024-01-31
+### Changed
+- Extended default proxy request timeout to 10m
+
+### Fixed
+- Fixed streaming response stuck at context deadline exceeded error
+
+## 1.8.0 - 2024-01-26
+### Added
+- Added key authentication for admin endpoints
+
## 1.7.6 - 2024-01-17
### Fixed
- Changed code to string in OpenAI error response
diff --git a/README.md b/README.md
index 895c8ba..de3073f 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,9 @@
+> [!TIP]
+> A [managed version of BricksLLM](https://www.trybricks.ai?utm_source=github&utm_medium=repo&utm_campaign=bricksllm) is also available! It is production ready, and comes with a dashboard to make interacting with BricksLLM easier. Try us out for free today!
+
**BricksLLM** is a cloud native AI gateway written in Go. Currently, it provide native support for OpenAI, Anthropic and Azure OpenAI. We let you create API keys that have rate limits, cost limits and TTLs. The API keys can be used in both development and production to achieve fine-grained access control that is not provided by any of the foundational model providers. The proxy is designed to be 100% compatible with existing SDKs.
## Features
@@ -145,10 +148,18 @@ docker pull luyuanxin1995/bricksllm:1.4.0
> | `REDIS_WRITE_TIME_OUT` | optional | Timeout for Redis write operations | `500ms`
> | `IN_MEMORY_DB_UPDATE_INTERVAL` | optional | The interval BricksLLM API gateway polls Postgresql DB for latest key configurations | `1s`
> | `STATS_PROVIDER` | optional | This value can only be datadog. Required for integration with Datadog. |
-> | `PROXY_TIMEOUT` | optional | This value can only be datadog. Required for integration with Datadog. |
+> | `PROXY_TIMEOUT` | optional | Timeout for proxy HTTP requests. |
+> | `ADMIN_PASS` | optional | Simple password authentication for admin endpoints. |
## Configuration Endpoints
The configuration server runs on Port `8001`.
+
+##### Headers
+> | name | type | data type | description |
+> |--------|------------|----------------|------------------------------------------------------|
+> | `X-API-KEY` | optional | `string` | Key authentication header. |
+
+
Get keys: GET
/api/key-management/keys
@@ -161,7 +172,8 @@ This endpoint is set up for retrieving key configurations using a query param ca
> |--------|------------|----------------|------------------------------------------------------|
> | `tag` | optional | `string` | Identifier attached to a key configuration |
> | `tags` | optional | `[]string` | Identifiers attached to a key configuration |
-> | `provider` | optional | `string` | Provider attached to a key provider configuration. Its value can only be `openai`. |
+> | `provider` | optional | `string` | Provider attached to a key provider configuration. Its value can only be `openai`. |
+> | `keyIds` | optional | `[]string` | Unique identifiers for keys. |
##### Error Response
@@ -202,6 +214,9 @@ Fields of KeyConfiguration
> | allowedPaths | `[]PathConfig` | `[{ "path": "/api/providers/openai/v1/chat/completion", "method": "POST"}]` | Allowed paths that can be accessed using the key. |
> | settingId | `string` | `98daa3ae-961d-4253-bf6a-322a32fdca3d` | This field is DEPERCATED. Use `settingIds` field instead. |
> | settingIds | `string` | `[98daa3ae-961d-4253-bf6a-322a32fdca3d]` | Setting ids associated with the key. |
+> | shouldLogRequest | `bool` | `false` | Should request be stored. |
+> | shouldLogResponse | `bool` | `true` | Should response be stored. |
+> | rotationEnabled | `bool` | `false` | Whether the key should rotate through its associated settings when accessing third party endpoints in order to circumvent rate limits. |
@@ -236,6 +251,9 @@ PathConfig
> | rateLimitUnit | optional | `enum` | m | Time unit for rateLimitOverTime. Possible values are [`h`, `m`, `s`, `d`] |
> | ttl | optional | `string` | 2d | time to live. Available units are [`s`, `m`, `h`]. |
> | allowedPaths | optional | `[]PathConfig` | 2d | Pathes allowed for access. |
+> | shouldLogRequest | optional | `bool` | `false` | Should request be stored. |
+> | shouldLogResponse | optional | `bool` | `true` | Should response be stored. |
+> | rotationEnabled | optional | `bool` | `false` | Whether the key should rotate through its associated settings when accessing third party endpoints in order to circumvent rate limits. |
##### Error Response
@@ -272,6 +290,9 @@ PathConfig
> | allowedPaths | `[]PathConfig` | `[{ "path": "/api/providers/openai/v1/chat/completion", method: "POST"}]` | Allowed paths that can be accessed using the key. |
> | settingId | `string` | `98daa3ae-961d-4253-bf6a-322a32fdca3d` | This field is DEPERCATED. Use `settingIds` field instead. |
> | settingIds | `string` | `[98daa3ae-961d-4253-bf6a-322a32fdca3d]` | Setting ids associated with the key. |
+> | shouldLogRequest | `bool` | `false` | Should request be stored. |
+> | shouldLogResponse | `bool` | `true` | Should response be stored. |
+> | rotationEnabled | `bool` | `false` | Whether the key should rotate through its associated settings when accessing third party endpoints in order to circumvent rate limits. |
@@ -302,8 +323,16 @@ PathConfig
> | name | optional | `string` | spike's developer key | Name of the API key. |
> | tags | optional | `[]string` | `["org-tag-12345"]` | Identifiers associated with the key. |
> | revoked | optional | `boolean` | `true` | Indicator for whether the key is revoked. |
-> | revokedReason| optional | `string` | The key has expired | Reason for why the key is revoked. |
-> | allowedPaths | optional | `[]PathConfig` | 2d | Pathes allowed for access. |
+> | revokedReason | optional | `string` | The key has expired | Reason for why the key is revoked. |
+> | costLimitInUsd | optional | `float64` | `5.5` | Total spend limit of the API key. |
+> | costLimitInUsdOverTime | optional | `float64` | `2` | Total spend within period of time. This field is required if costLimitInUsdUnit is specified. |
+> | costLimitInUsdUnit | optional | `enum` | `d` | Time unit for costLimitInUsdOverTime. Possible values are [`m`, `h`, `d`, `mo`]. |
+> | rateLimitOverTime | optional | `int` | `2` | rate limit over period of time. This field is required if rateLimitUnit is specified. |
+> | rateLimitUnit | optional | `string` | `m` | Time unit for rateLimitOverTime. Possible values are [`h`, `m`, `s`, `d`] |
+> | allowedPaths | optional | `[]PathConfig` | `[{ "path": "/api/providers/openai/v1/chat/completions", "method": "POST"}]` | Paths allowed for access. |
+> | shouldLogRequest | optional | `bool` | `false` | Should request be stored. |
+> | shouldLogResponse | optional | `bool` | `true` | Should response be stored. |
+> | rotationEnabled | optional | `bool` | `false` | Whether the key should rotate through its associated settings when accessing third party endpoints in order to circumvent rate limits. |
##### Error Response
@@ -338,6 +367,9 @@ PathConfig
> | allowedPaths | `[]PathConfig` | `[{ "path": "/api/providers/openai/v1/chat/completion", method: "POST"}]` | Allowed paths that can be accessed using the key. |
> | settingId | `string` | `98daa3ae-961d-4253-bf6a-322a32fdca3d` | This field is DEPERCATED. Use `settingIds` field instead. |
> | settingIds | `string` | `[98daa3ae-961d-4253-bf6a-322a32fdca3d]` | Setting ids associated with the key. |
+> | shouldLogRequest | `bool` | `false` | Should request be stored. |
+> | shouldLogResponse | `bool` | `true` | Should response be stored. |
+> | rotationEnabled | `bool` | `false` | Whether the key should rotate through its associated settings when accessing third party endpoints in order to circumvent rate limits. |
@@ -489,7 +521,9 @@ This endpoint is retrieving aggregated metrics given an array of key ids and tag
> | Field | required | type | example | description |
> |---------------|-----------------------------------|-|-|-|
> | keyIds | required | `[]string` | `["key-1", "key-2", "key-3" ]` | Array of ids that specicify the keys that you want to aggregate stats from. |
-> | tags | required | `[]string` | `["tag-1", "tag-2"]` | Array of tags that specicify the keys that you want to aggregate stats from. |
+> | tags | required | `[]string` | `["tag-1", "tag-2"]` | Array of tags that specify the key tags that you want to aggregate stats from. |
+> | customIds | required | `[]string` | `["customId-1", "customId-2"]` | A list of custom IDs that you want to aggregate stats from. |
+> | filters | required | `[]string` | `["model", "keyId"]` | Group by data points through different filters (`model`, `keyId` or `customId`). |
> | start | required | `int64` | `1699933571` | Start timestamp for the requested timeseries data. |
> | end | required | `int64` | `1699933571` | End timestamp for the requested timeseries data. |
> | increment | required | `int` | `60` | This field is the increment in seconds for the requested timeseries data. |
@@ -526,6 +560,7 @@ Datapoint
> | successCount | `int` | `555` | Aggregated number of successful http requests over the given time increment. |
> | keyId | `int` | `555.7` | key Id associated with the event. |
> | model | `string` | `gpt-3.5-turbo` | model associated with the event. |
+> | customId | `string` | `customId` | customId associated with the event. |
@@ -538,7 +573,10 @@ This endpoint is for getting events.
##### Query Parameters
> | name | type | data type | description |
> |--------|------------|----------------|------------------------------------------------------|
-> | `customId` | optional | string | Custom identifier attached to an event |
+> | `customId` | optional | `string` | Custom identifier attached to an event. |
+> | `keyIds` | optional | `[]string` | A list of key IDs. |
+> | `start` | required if `keyIds` is specified | `int64` | Start timestamp. |
+> | `end` | required if `keyIds` is specified | `int64` | End timestamp. |
##### Error Response
> | http code | content-type |
@@ -575,6 +613,8 @@ Event
> | path | `string` | `/api/v1/chat/completion` | Provider setting name. |
> | method | `string` | `POST` | Http method for the assoicated proxu request. |
> | custom_id | `string` | `YOUR_CUSTOM_ID` | Custom Id passed by the user in the headers of proxy requests. |
+> | request | `[]byte` | `{}` | Logged request data in bytes. Only populated when `shouldLogRequest` is enabled on the key. |
+> | response | `[]byte` | `{}` | Logged response data in bytes. Only populated when `shouldLogResponse` is enabled on the key. |
diff --git a/cmd/bricksllm/main.go b/cmd/bricksllm/main.go
index 965a550..85746b4 100644
--- a/cmd/bricksllm/main.go
+++ b/cmd/bricksllm/main.go
@@ -14,6 +14,7 @@ import (
"github.com/bricks-cloud/bricksllm/internal/config"
"github.com/bricks-cloud/bricksllm/internal/logger/zap"
"github.com/bricks-cloud/bricksllm/internal/manager"
+ "github.com/bricks-cloud/bricksllm/internal/message"
"github.com/bricks-cloud/bricksllm/internal/provider/anthropic"
"github.com/bricks-cloud/bricksllm/internal/provider/azure"
"github.com/bricks-cloud/bricksllm/internal/provider/custom"
@@ -176,18 +177,31 @@ func main() {
log.Sugar().Fatalf("error connecting to api redis cache: %v", err)
}
+ accessRedisCache := redis.NewClient(&redis.Options{
+ Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort),
+ Password: cfg.RedisPassword,
+ DB: 4,
+ })
+
+ ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+	if err := accessRedisCache.Ping(ctx).Err(); err != nil {
+		log.Sugar().Fatalf("error connecting to access redis cache: %v", err)
+ }
+
rateLimitCache := redisStorage.NewCache(rateLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
costLimitCache := redisStorage.NewCache(costLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
costStorage := redisStorage.NewStore(costRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
apiCache := redisStorage.NewCache(apiRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
+ accessCache := redisStorage.NewAccessCache(accessRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
- m := manager.NewManager(store)
+ m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache)
krm := manager.NewReportingManager(costStorage, store, store)
psm := manager.NewProviderSettingsManager(store, psMemStore)
cpm := manager.NewCustomProvidersManager(store, cpMemStore)
rm := manager.NewRouteManager(store, store, rMemStore, psMemStore)
- as, err := admin.NewAdminServer(log, *modePtr, m, krm, psm, cpm, rm)
+ as, err := admin.NewAdminServer(log, *modePtr, m, krm, psm, cpm, rm, cfg.AdminPass)
if err != nil {
log.Sugar().Fatalf("error creating admin http server: %v", err)
}
@@ -214,7 +228,16 @@ func main() {
c := cache.NewCache(apiCache)
- ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ace, aoe, v, rec, rlm, cfg.ProxyTimeout)
+ messageBus := message.NewMessageBus()
+ eventMessageChan := make(chan message.Message)
+ messageBus.Subscribe("event", eventMessageChan)
+
+ handler := message.NewHandler(rec, log, ace, ce, aoe, v, m, rlm, accessCache)
+
+ eventConsumer := message.NewConsumer(eventMessageChan, log, 4, handler.HandleEventWithRequestAndResponse)
+ eventConsumer.StartEventMessageConsumers()
+
+ ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ace, aoe, v, rec, messageBus, rlm, cfg.ProxyTimeout, accessCache)
if err != nil {
log.Sugar().Fatalf("error creating proxy http server: %v", err)
}
@@ -225,6 +248,7 @@ func main() {
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
+ eventConsumer.Stop()
memStore.Stop()
psMemStore.Stop()
cpMemStore.Stop()
diff --git a/cmd/tool/main.go b/cmd/tool/main.go
deleted file mode 100644
index eade6ef..0000000
--- a/cmd/tool/main.go
+++ /dev/null
@@ -1,274 +0,0 @@
-package main
-
-import (
- "context"
- "flag"
- "fmt"
- "os"
- "os/signal"
- "syscall"
- "time"
-
- auth "github.com/bricks-cloud/bricksllm/internal/authenticator"
- "github.com/bricks-cloud/bricksllm/internal/cache"
- "github.com/bricks-cloud/bricksllm/internal/config"
- logger "github.com/bricks-cloud/bricksllm/internal/logger/zap"
- "github.com/bricks-cloud/bricksllm/internal/manager"
- "github.com/bricks-cloud/bricksllm/internal/provider/anthropic"
- "github.com/bricks-cloud/bricksllm/internal/provider/azure"
- "github.com/bricks-cloud/bricksllm/internal/provider/custom"
- "github.com/bricks-cloud/bricksllm/internal/provider/openai"
- "github.com/bricks-cloud/bricksllm/internal/recorder"
- "github.com/bricks-cloud/bricksllm/internal/server/web/admin"
- "github.com/bricks-cloud/bricksllm/internal/server/web/proxy"
- "github.com/bricks-cloud/bricksllm/internal/stats"
- "github.com/bricks-cloud/bricksllm/internal/storage/memdb"
- "github.com/bricks-cloud/bricksllm/internal/storage/postgresql"
- redisStorage "github.com/bricks-cloud/bricksllm/internal/storage/redis"
- "github.com/bricks-cloud/bricksllm/internal/validator"
- "github.com/gin-gonic/gin"
- "github.com/redis/go-redis/v9"
-)
-
-func main() {
- modePtr := flag.String("m", "dev", "select the mode that bricksllm runs in")
- privacyPtr := flag.String("p", "strict", "select the privacy mode that bricksllm runs in")
- flag.Parse()
-
- log := logger.NewZapLogger(*modePtr)
-
- gin.SetMode(gin.ReleaseMode)
-
- cfg, err := config.ParseEnvVariables()
- if err != nil {
- log.Sugar().Fatalf("cannot parse environment variables: %v", err)
- }
-
- err = stats.InitializeClient(cfg.StatsProvider)
- if err != nil {
- log.Sugar().Fatalf("cannot connect to telemetry provider: %v", err)
- }
-
- store, err := postgresql.NewStore(
- fmt.Sprintf("postgresql:///%s?sslmode=%s&user=%s&password=%s&host=%s&port=%s", cfg.PostgresqlDbName, cfg.PostgresqlSslMode, cfg.PostgresqlUsername, cfg.PostgresqlPassword, cfg.PostgresqlHosts, cfg.PostgresqlPort),
- cfg.PostgresqlWriteTimeout,
- cfg.PostgresqlReadTimeout,
- )
-
- if err != nil {
- log.Sugar().Fatalf("cannot connect to postgresql: %v", err)
- }
-
- err = store.CreateCustomProvidersTable()
- if err != nil {
- log.Sugar().Fatalf("error creating custom providers table: %v", err)
- }
-
- err = store.CreateRoutesTable()
- if err != nil {
- log.Sugar().Fatalf("error creating routes table: %v", err)
- }
-
- err = store.AlterRoutesTable()
- if err != nil {
- log.Sugar().Fatalf("error altering routes table: %v", err)
- }
-
- err = store.CreateKeysTable()
- if err != nil {
- log.Sugar().Fatalf("error creating keys table: %v", err)
- }
-
- err = store.AlterKeysTable()
- if err != nil {
- log.Sugar().Fatalf("error altering keys table: %v", err)
- }
-
- err = store.CreateEventsTable()
- if err != nil {
- log.Sugar().Fatalf("error creating events table: %v", err)
- }
-
- err = store.AlterEventsTable()
- if err != nil {
- log.Sugar().Fatalf("error altering events table: %v", err)
- }
-
- err = store.CreateProviderSettingsTable()
- if err != nil {
- log.Sugar().Fatalf("error creating provider settings table: %v", err)
- }
-
- err = store.AlterProviderSettingsTable()
- if err != nil {
- log.Sugar().Fatalf("error altering provider settings table: %v", err)
- }
-
- memStore, err := memdb.NewMemDb(store, log, cfg.InMemoryDbUpdateInterval)
- if err != nil {
- log.Sugar().Fatalf("cannot initialize memdb: %v", err)
- }
- memStore.Listen()
-
- psMemStore, err := memdb.NewProviderSettingsMemDb(store, log, cfg.InMemoryDbUpdateInterval)
- if err != nil {
- log.Sugar().Fatalf("cannot initialize provider settings memdb: %v", err)
- }
- psMemStore.Listen()
-
- cpMemStore, err := memdb.NewCustomProvidersMemDb(store, log, cfg.InMemoryDbUpdateInterval)
- if err != nil {
- log.Sugar().Fatalf("cannot initialize custom providers memdb: %v", err)
- }
- cpMemStore.Listen()
-
- rMemStore, err := memdb.NewRoutesMemDb(store, log, cfg.InMemoryDbUpdateInterval)
- if err != nil {
- log.Sugar().Fatalf("cannot initialize routes memdb: %v", err)
- }
- rMemStore.Listen()
-
- rateLimitRedisCache := redis.NewClient(&redis.Options{
- Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort),
- Password: cfg.RedisPassword,
- DB: 0,
- })
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- if err := rateLimitRedisCache.Ping(ctx).Err(); err != nil {
- log.Sugar().Fatalf("error connecting to rate limit redis cache: %v", err)
- }
-
- costLimitRedisCache := redis.NewClient(&redis.Options{
- Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort),
- Password: cfg.RedisPassword,
- DB: 1,
- })
-
- ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- if err := costLimitRedisCache.Ping(ctx).Err(); err != nil {
- log.Sugar().Fatalf("error connecting to cost limit redis cache: %v", err)
- }
-
- costRedisStorage := redis.NewClient(&redis.Options{
- Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort),
- Password: cfg.RedisPassword,
- DB: 2,
- })
-
- ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- if err := costRedisStorage.Ping(ctx).Err(); err != nil {
- log.Sugar().Fatalf("error connecting to cost limit redis storage: %v", err)
- }
-
- apiRedisCache := redis.NewClient(&redis.Options{
- Addr: fmt.Sprintf("%s:%s", cfg.RedisHosts, cfg.RedisPort),
- Password: cfg.RedisPassword,
- DB: 3,
- })
-
- ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- if err := apiRedisCache.Ping(ctx).Err(); err != nil {
- log.Sugar().Fatalf("error connecting to api redis cache: %v", err)
- }
-
- rateLimitCache := redisStorage.NewCache(rateLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
- costLimitCache := redisStorage.NewCache(costLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
- costStorage := redisStorage.NewStore(costRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
- apiCache := redisStorage.NewCache(apiRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
-
- m := manager.NewManager(store)
- krm := manager.NewReportingManager(costStorage, store, store)
- psm := manager.NewProviderSettingsManager(store, psMemStore)
- cpm := manager.NewCustomProvidersManager(store, cpMemStore)
- rm := manager.NewRouteManager(store, store, rMemStore, psMemStore)
-
- as, err := admin.NewAdminServer(log, *modePtr, m, krm, psm, cpm, rm)
- if err != nil {
- log.Sugar().Fatalf("error creating admin http server: %v", err)
- }
- as.Run()
-
- tc := openai.NewTokenCounter()
- custom.NewTokenCounter()
- atc, err := anthropic.NewTokenCounter()
- if err != nil {
- log.Sugar().Fatalf("error creating anthropic token counter: %v", err)
- }
-
- ae := anthropic.NewCostEstimator(atc)
-
- ce := openai.NewCostEstimator(openai.OpenAiPerThousandTokenCost, tc)
- v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage)
- rec := recorder.NewRecorder(costStorage, costLimitCache, ce, store)
- rlm := manager.NewRateLimitManager(rateLimitCache)
- a := auth.NewAuthenticator(psm, memStore, rm)
-
- c := cache.NewCache(apiCache)
-
- aoe := azure.NewCostEstimator()
-
- ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ae, aoe, v, rec, rlm, cfg.ProxyTimeout)
- if err != nil {
- log.Sugar().Fatalf("error creating proxy http server: %v", err)
- }
-
- ps.Run()
-
- quit := make(chan os.Signal)
- signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
- <-quit
-
- memStore.Stop()
- psMemStore.Stop()
- cpMemStore.Stop()
-
- log.Sugar().Info("shutting down server...")
-
- ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- if err := as.Shutdown(ctx); err != nil {
- log.Sugar().Debugf("admin server shutdown: %v", err)
- }
-
- ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- if err := ps.Shutdown(ctx); err != nil {
- log.Sugar().Debugf("proxy server shutdown: %v", err)
- }
-
- select {
- case <-ctx.Done():
- log.Sugar().Infof("timeout of 5 seconds")
- }
-
- err = store.DropKeysTable()
- if err != nil {
- log.Sugar().Fatalf("error dropping keys table: %v", err)
- }
-
- err = store.DropEventsTable()
- if err != nil {
- log.Sugar().Fatalf("error dropping events table: %v", err)
- }
-
- err = store.DropCustomProvidersTable()
- if err != nil {
- log.Sugar().Fatalf("error dropping custom providers table: %v", err)
- }
-
- err = store.DropProviderSettingsTable()
- if err != nil {
- log.Sugar().Fatalf("error dropping provider settings table: %v", err)
- }
-
- err = store.DropRoutesTable()
- if err != nil {
- log.Sugar().Fatalf("error dropping routes table: %v", err)
- }
-
- log.Sugar().Infof("server exited")
-}
diff --git a/docker-compose.yml b/docker-compose.yml
index 911b598..07f9717 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -18,22 +18,22 @@ services:
- '5432:5432'
volumes:
- postgresql:/var/lib/postgresql/data
- bricksllm:
- depends_on:
- - redis
- - postgresql
- image: luyuanxin1995/bricksllm
- environment:
- - POSTGRESQL_USERNAME=postgres
- - POSTGRESQL_PASSWORD=postgres
- - REDIS_PASSWORD=eYVX7EwVmmxKPCDmwMtyKVge8oLd2t81
- - POSTGRESQL_HOSTS=postgresql
- - REDIS_HOSTS=redis
- ports:
- - '8001:8001'
- - '8002:8002'
- command:
- - '-m=dev'
+ # bricksllm:
+ # depends_on:
+ # - redis
+ # - postgresql
+ # image: luyuanxin1995/bricksllm
+ # environment:
+ # - POSTGRESQL_USERNAME=postgres
+ # - POSTGRESQL_PASSWORD=postgres
+ # - REDIS_PASSWORD=eYVX7EwVmmxKPCDmwMtyKVge8oLd2t81
+ # - POSTGRESQL_HOSTS=postgresql
+ # - REDIS_HOSTS=redis
+ # ports:
+ # - '8001:8001'
+ # - '8002:8002'
+ # command:
+ # - '-m=dev'
volumes:
redis:
driver: local
diff --git a/go.mod b/go.mod
index 1b4a791..ac6d969 100644
--- a/go.mod
+++ b/go.mod
@@ -12,13 +12,16 @@ require (
github.com/mattn/go-colorable v0.1.13
github.com/pkoukk/tiktoken-go-loader v0.0.1
github.com/redis/go-redis/v9 v9.0.5
- github.com/sashabaranov/go-openai v1.17.7
+ github.com/sashabaranov/go-openai v1.20.1
github.com/stretchr/testify v1.8.4
go.uber.org/zap v1.24.0
)
require (
github.com/Microsoft/go-winio v0.5.0 // indirect
+ github.com/asticode/go-astikit v0.20.0 // indirect
+ github.com/asticode/go-astisub v0.26.2 // indirect
+ github.com/asticode/go-astits v1.8.0 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
diff --git a/go.sum b/go.sum
index fc57456..49b7bf3 100644
--- a/go.sum
+++ b/go.sum
@@ -2,6 +2,12 @@ github.com/DataDog/datadog-go/v5 v5.3.0 h1:2q2qjFOb3RwAZNU+ez27ZVDwErJv5/VpbBPpr
github.com/DataDog/datadog-go/v5 v5.3.0/go.mod h1:XRDJk1pTc00gm+ZDiBKsjh7oOOtJfYfglVCmFb8C2+Q=
github.com/Microsoft/go-winio v0.5.0 h1:Elr9Wn+sGKPlkaBvwu4mTrxtmOp3F3yV9qhaHbXGjwU=
github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84=
+github.com/asticode/go-astikit v0.20.0 h1:+7N+J4E4lWx2QOkRdOf6DafWJMv6O4RRfgClwQokrH8=
+github.com/asticode/go-astikit v0.20.0/go.mod h1:h4ly7idim1tNhaVkdVBeXQZEE3L0xblP7fCWbgwipF0=
+github.com/asticode/go-astisub v0.26.2 h1:cdEXcm+SUSmYCEPTQYbbfCECnmQoIFfH6pF8wDJhfVo=
+github.com/asticode/go-astisub v0.26.2/go.mod h1:WTkuSzFB+Bp7wezuSf2Oxulj5A8zu2zLRVFf6bIFQK8=
+github.com/asticode/go-astits v1.8.0 h1:rf6aiiGn/QhlFjNON1n5plqF3Fs025XLUwiQ0NB6oZg=
+github.com/asticode/go-astits v1.8.0/go.mod h1:DkOWmBNQpnr9mv24KfZjq4JawCFX1FCqjLVGvO0DygQ=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao=
github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y=
@@ -71,6 +77,7 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE=
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
@@ -85,6 +92,10 @@ github.com/sashabaranov/go-openai v1.17.1 h1:tapFKbKE8ep0/qGkKp5Q3TtxWUD7m9VIFe9
github.com/sashabaranov/go-openai v1.17.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx6+/ndloM=
github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
+github.com/sashabaranov/go-openai v1.19.2 h1:+dkuCADSnwXV02YVJkdphY8XD9AyHLUWwk6V7LB6EL8=
+github.com/sashabaranov/go-openai v1.19.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
+github.com/sashabaranov/go-openai v1.20.1 h1:cFnTixAtc0I0cCBFr8gkvEbGCm6Rjf2JyoVWCjXwy9g=
+github.com/sashabaranov/go-openai v1.20.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
@@ -92,6 +103,7 @@ github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
@@ -124,11 +136,13 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
@@ -137,6 +151,7 @@ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -148,6 +163,7 @@ golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
@@ -164,6 +180,7 @@ google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cn
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/internal/authenticator/authenticator.go b/internal/authenticator/authenticator.go
index f95fca9..94aa3c9 100644
--- a/internal/authenticator/authenticator.go
+++ b/internal/authenticator/authenticator.go
@@ -3,6 +3,7 @@ package auth
import (
"errors"
"fmt"
+ "math/rand"
"net/http"
"strings"
@@ -196,8 +197,12 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.Respons
}
if len(selected) != 0 {
- err := rewriteHttpAuthHeader(req, selected[0])
+ used := selected[0]
+ if key.RotationEnabled {
+ used = selected[rand.Intn(len(selected))]
+ }
+ err := rewriteHttpAuthHeader(req, used)
if err != nil {
return nil, nil, err
}
diff --git a/internal/config/config.go b/internal/config/config.go
index 9b675ea..0ebb895 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -7,24 +7,26 @@ import (
)
type Config struct {
- PostgresqlHosts string `env:"POSTGRESQL_HOSTS" envSeparator:":" envDefault:"localhost"`
- PostgresqlDbName string `env:"POSTGRESQL_DB_NAME"`
- PostgresqlUsername string `env:"POSTGRESQL_USERNAME"`
- PostgresqlPassword string `env:"POSTGRESQL_PASSWORD"`
- PostgresqlSslMode string `env:"POSTGRESQL_SSL_MODE" envDefault:"disable"`
- PostgresqlPort string `env:"POSTGRESQL_PORT" envDefault:"5432"`
- RedisHosts string `env:"REDIS_HOSTS" envSeparator:":" envDefault:"localhost"`
- RedisPort string `env:"REDIS_PORT" envDefault:"6379"`
- RedisUsername string `env:"REDIS_USERNAME"`
- RedisPassword string `env:"REDIS_PASSWORD"`
- RedisReadTimeout time.Duration `env:"REDIS_READ_TIME_OUT" envDefault:"1s"`
- RedisWriteTimeout time.Duration `env:"REDIS_WRITE_TIME_OUT" envDefault:"500ms"`
- PostgresqlReadTimeout time.Duration `env:"POSTGRESQL_READ_TIME_OUT" envDefault:"2s"`
- PostgresqlWriteTimeout time.Duration `env:"POSTGRESQL_WRITE_TIME_OUT" envDefault:"1s"`
- InMemoryDbUpdateInterval time.Duration `env:"IN_MEMORY_DB_UPDATE_INTERVAL" envDefault:"5s"`
- OpenAiKey string `env:"OPENAI_API_KEY"`
- StatsProvider string `env:"STATS_PROVIDER"`
- ProxyTimeout time.Duration `env:"PROXY_TIMEOUT" envDefault:"180s"`
+ PostgresqlHosts string `env:"POSTGRESQL_HOSTS" envSeparator:":" envDefault:"localhost"`
+ PostgresqlDbName string `env:"POSTGRESQL_DB_NAME"`
+ PostgresqlUsername string `env:"POSTGRESQL_USERNAME"`
+ PostgresqlPassword string `env:"POSTGRESQL_PASSWORD"`
+ PostgresqlSslMode string `env:"POSTGRESQL_SSL_MODE" envDefault:"disable"`
+ PostgresqlPort string `env:"POSTGRESQL_PORT" envDefault:"5432"`
+ RedisHosts string `env:"REDIS_HOSTS" envSeparator:":" envDefault:"localhost"`
+ RedisPort string `env:"REDIS_PORT" envDefault:"6379"`
+ RedisUsername string `env:"REDIS_USERNAME"`
+ RedisPassword string `env:"REDIS_PASSWORD"`
+ RedisReadTimeout time.Duration `env:"REDIS_READ_TIME_OUT" envDefault:"1s"`
+ RedisWriteTimeout time.Duration `env:"REDIS_WRITE_TIME_OUT" envDefault:"500ms"`
+ PostgresqlReadTimeout time.Duration `env:"POSTGRESQL_READ_TIME_OUT" envDefault:"15s"`
+ PostgresqlWriteTimeout time.Duration `env:"POSTGRESQL_WRITE_TIME_OUT" envDefault:"5s"`
+ InMemoryDbUpdateInterval time.Duration `env:"IN_MEMORY_DB_UPDATE_INTERVAL" envDefault:"5s"`
+ OpenAiKey string `env:"OPENAI_API_KEY"`
+ StatsProvider string `env:"STATS_PROVIDER"`
+ AdminPass string `env:"ADMIN_PASS"`
+ ProxyTimeout time.Duration `env:"PROXY_TIMEOUT" envDefault:"600s"`
+ NumberOfEventMessageConsumers int `env:"NUMBER_OF_EVENT_MESSAGE_CONSUMERS" envDefault:"3"`
}
func ParseEnvVariables() (*Config, error) {
diff --git a/internal/errors/cost_limit_err.go b/internal/errors/cost_limit_err.go
new file mode 100644
index 0000000..15c92de
--- /dev/null
+++ b/internal/errors/cost_limit_err.go
@@ -0,0 +1,17 @@
+package errors
+
+type CostLimitError struct {
+ message string
+}
+
+func NewCostLimitError(msg string) *CostLimitError {
+ return &CostLimitError{
+ message: msg,
+ }
+}
+
+func (cle *CostLimitError) Error() string {
+ return cle.message
+}
+
+func (cle *CostLimitError) CostLimit() {}
diff --git a/internal/event/event.go b/internal/event/event.go
index c0f0152..7f299ab 100644
--- a/internal/event/event.go
+++ b/internal/event/event.go
@@ -15,4 +15,7 @@ type Event struct {
Path string `json:"path"`
Method string `json:"method"`
CustomId string `json:"custom_id"`
+ Request []byte `json:"request"`
+ Response []byte `json:"response"`
+ UserId string `json:"userId"`
}
diff --git a/internal/event/event_with_request_and_response.go b/internal/event/event_with_request_and_response.go
new file mode 100644
index 0000000..8e99f0b
--- /dev/null
+++ b/internal/event/event_with_request_and_response.go
@@ -0,0 +1,16 @@
+package event
+
+import (
+ "github.com/bricks-cloud/bricksllm/internal/key"
+ "github.com/bricks-cloud/bricksllm/internal/provider/custom"
+)
+
+type EventWithRequestAndContent struct {
+ Event *Event
+ IsEmbeddingsRequest bool
+ RouteConfig *custom.RouteConfig
+ Request interface{}
+ Content string
+ Response interface{}
+ Key *key.ResponseKey
+}
diff --git a/internal/event/reporting.go b/internal/event/reporting.go
index 7ccb61c..91dbf6c 100644
--- a/internal/event/reporting.go
+++ b/internal/event/reporting.go
@@ -10,6 +10,8 @@ type DataPoint struct {
SuccessCount int `json:"successCount"`
Model string `json:"model"`
KeyId string `json:"keyId"`
+ CustomId string `json:"customId"`
+ UserId string `json:"userId"`
}
type ReportingResponse struct {
@@ -21,6 +23,8 @@ type ReportingResponse struct {
type ReportingRequest struct {
KeyIds []string `json:"keyIds"`
Tags []string `json:"tags"`
+ CustomIds []string `json:"customIds"`
+ UserIds []string `json:"userIds"`
Start int64 `json:"start"`
End int64 `json:"end"`
Increment int64 `json:"increment"`
diff --git a/internal/key/key.go b/internal/key/key.go
index 7405e30..a8b0be5 100644
--- a/internal/key/key.go
+++ b/internal/key/key.go
@@ -9,15 +9,25 @@ import (
internal_errors "github.com/bricks-cloud/bricksllm/internal/errors"
)
+const RevokedReasonExpired string = "expired"
+
type UpdateKey struct {
- Name string `json:"name"`
- UpdatedAt int64 `json:"updatedAt"`
- Tags []string `json:"tags"`
- Revoked *bool `json:"revoked"`
- RevokedReason string `json:"revokedReason"`
- SettingId string `json:"settingId"`
- SettingIds []string `json:"settingIds"`
- AllowedPaths *[]PathConfig `json:"allowedPaths,omitempty"`
+ Name string `json:"name"`
+ UpdatedAt int64 `json:"updatedAt"`
+ Tags []string `json:"tags"`
+ Revoked *bool `json:"revoked"`
+ RevokedReason string `json:"revokedReason"`
+ SettingId string `json:"settingId"`
+ SettingIds []string `json:"settingIds"`
+ CostLimitInUsd float64 `json:"costLimitInUsd"`
+ CostLimitInUsdOverTime float64 `json:"costLimitInUsdOverTime"`
+ CostLimitInUsdUnit TimeUnit `json:"costLimitInUsdUnit"`
+ RateLimitOverTime int `json:"rateLimitOverTime"`
+ RateLimitUnit TimeUnit `json:"rateLimitUnit"`
+ AllowedPaths *[]PathConfig `json:"allowedPaths,omitempty"`
+ ShouldLogRequest *bool `json:"shouldLogRequest"`
+ ShouldLogResponse *bool `json:"shouldLogResponse"`
+ RotationEnabled *bool `json:"rotationEnabled"`
}
func (uk *UpdateKey) Validate() error {
@@ -34,6 +44,10 @@ func (uk *UpdateKey) Validate() error {
}
}
+ if uk.CostLimitInUsd < 0 {
+ invalid = append(invalid, "costLimitInUsd")
+ }
+
if uk.UpdatedAt <= 0 {
invalid = append(invalid, "updatedAt")
}
@@ -64,6 +78,34 @@ func (uk *UpdateKey) Validate() error {
return internal_errors.NewValidationError(fmt.Sprintf("fields [%s] are invalid", strings.Join(invalid, ", ")))
}
+ if len(uk.RateLimitUnit) != 0 && uk.RateLimitOverTime == 0 {
+ return internal_errors.NewValidationError("rate limit over time can not be empty if rate limit unit is specified")
+ }
+
+ if len(uk.CostLimitInUsdUnit) != 0 && uk.CostLimitInUsdOverTime == 0 {
+ return internal_errors.NewValidationError("cost limit over time can not be empty if cost limit unit is specified")
+ }
+
+ if uk.RateLimitOverTime != 0 {
+ if len(uk.RateLimitUnit) == 0 {
+ return internal_errors.NewValidationError("rate limit unit can not be empty if rate limit over time is specified")
+ }
+
+ if uk.RateLimitUnit != HourTimeUnit && uk.RateLimitUnit != MinuteTimeUnit && uk.RateLimitUnit != SecondTimeUnit && uk.RateLimitUnit != DayTimeUnit {
+ return internal_errors.NewValidationError("rate limit unit can not be identified")
+ }
+ }
+
+ if uk.CostLimitInUsdOverTime != 0 {
+ if len(uk.CostLimitInUsdUnit) == 0 {
+ return internal_errors.NewValidationError("cost limit unit can not be empty if cost limit over time is specified")
+ }
+
+ if uk.CostLimitInUsdUnit != DayTimeUnit && uk.CostLimitInUsdUnit != HourTimeUnit && uk.CostLimitInUsdUnit != MonthTimeUnit && uk.CostLimitInUsdUnit != MinuteTimeUnit {
+ return internal_errors.NewValidationError("cost limit unit can not be identified")
+ }
+ }
+
return nil
}
@@ -88,6 +130,9 @@ type RequestKey struct {
SettingId string `json:"settingId"`
AllowedPaths []PathConfig `json:"allowedPaths"`
SettingIds []string `json:"settingIds"`
+ ShouldLogRequest bool `json:"shouldLogRequest"`
+ ShouldLogResponse bool `json:"shouldLogResponse"`
+ RotationEnabled bool `json:"rotationEnabled"`
}
func (rk *RequestKey) Validate() error {
@@ -232,6 +277,9 @@ type ResponseKey struct {
SettingId string `json:"settingId"`
AllowedPaths []PathConfig `json:"allowedPaths"`
SettingIds []string `json:"settingIds"`
+ ShouldLogRequest bool `json:"shouldLogRequest"`
+ ShouldLogResponse bool `json:"shouldLogResponse"`
+ RotationEnabled bool `json:"rotationEnabled"`
}
func (rk *ResponseKey) GetSettingIds() []string {
diff --git a/internal/manager/key.go b/internal/manager/key.go
index 185fec1..59292aa 100644
--- a/internal/manager/key.go
+++ b/internal/manager/key.go
@@ -8,8 +8,6 @@ import (
"github.com/bricks-cloud/bricksllm/internal/key"
"github.com/bricks-cloud/bricksllm/internal/provider"
"github.com/bricks-cloud/bricksllm/internal/util"
-
- internal_errors "github.com/bricks-cloud/bricksllm/internal/errors"
)
type Storage interface {
@@ -19,6 +17,19 @@ type Storage interface {
DeleteKey(id string) error
GetProviderSetting(id string) (*provider.Setting, error)
GetProviderSettings(withSecret bool, ids []string) ([]*provider.Setting, error)
+ GetKey(keyId string) (*key.ResponseKey, error)
+}
+
+type costLimitCache interface {
+ Delete(keyId string) error
+}
+
+type rateLimitCache interface {
+ Delete(keyId string) error
+}
+
+type accessCache interface {
+ Delete(keyId string) error
}
type Encrypter interface {
@@ -26,12 +37,18 @@ type Encrypter interface {
}
type Manager struct {
- s Storage
+ s Storage
+ clc costLimitCache
+ rlc rateLimitCache
+ ac accessCache
}
-func NewManager(s Storage) *Manager {
+func NewManager(s Storage, clc costLimitCache, rlc rateLimitCache, ac accessCache) *Manager {
return &Manager{
- s: s,
+ s: s,
+ clc: clc,
+ rlc: rlc,
+ ac: ac,
}
}
@@ -39,20 +56,6 @@ func (m *Manager) GetKeys(tags, keyIds []string, provider string) ([]*key.Respon
return m.s.GetKeys(tags, keyIds, provider)
}
-func (m *Manager) areProviderSettingsUniqueness(settings []*provider.Setting) bool {
- providerMap := map[string]bool{}
-
- for _, setting := range settings {
- if providerMap[setting.Provider] {
- return false
- }
-
- providerMap[setting.Provider] = true
- }
-
- return true
-}
-
func (m *Manager) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
rk.CreatedAt = time.Now().Unix()
rk.UpdatedAt = time.Now().Unix()
@@ -78,11 +81,6 @@ func (m *Manager) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
if len(existing) == 0 {
return nil, errors.New("provider settings not found")
}
-
- if !m.areProviderSettingsUniqueness(existing) {
- return nil, internal_errors.NewValidationError("key can only be assoicated with one setting per provider")
- }
-
}
return m.s.CreateKey(rk)
@@ -110,9 +108,26 @@ func (m *Manager) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, err
if len(existing) == 0 {
return nil, errors.New("provider settings not found")
}
+ }
+
+ if len(uk.CostLimitInUsdUnit) != 0 {
+ err := m.clc.Delete(id)
+ if err != nil {
+ return nil, err
+ }
+ }
- if !m.areProviderSettingsUniqueness(existing) {
- return nil, internal_errors.NewValidationError("key can only be assoicated with one setting per provider")
+ if len(uk.RateLimitUnit) != 0 {
+ err := m.rlc.Delete(id)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if len(uk.CostLimitInUsdUnit) != 0 || len(uk.RateLimitUnit) != 0 {
+ err := m.ac.Delete(id)
+ if err != nil {
+ return nil, err
}
}
diff --git a/internal/manager/provider_setting.go b/internal/manager/provider_setting.go
index 5ba2a7c..76299b6 100644
--- a/internal/manager/provider_setting.go
+++ b/internal/manager/provider_setting.go
@@ -44,7 +44,7 @@ func findMissingAuthParams(providerName string, params map[string]string) string
missingFields := []string{}
if providerName == "openai" || providerName == "anthropic" {
- val, _ := params["apikey"]
+ val := params["apikey"]
if len(val) == 0 {
missingFields = append(missingFields, "apikey")
}
@@ -53,12 +53,12 @@ func findMissingAuthParams(providerName string, params map[string]string) string
}
if providerName == "azure" {
- val, _ := params["resourceName"]
+ val := params["resourceName"]
if len(val) == 0 {
missingFields = append(missingFields, "resourceName")
}
- val, _ = params["apikey"]
+ val = params["apikey"]
if len(val) == 0 {
missingFields = append(missingFields, "apikey")
}
@@ -76,7 +76,7 @@ func (m *ProviderSettingsManager) validateSettings(providerName string, setting
}
if len(provider.AuthenticationParam) != 0 {
- val, _ := setting[provider.AuthenticationParam]
+ val := setting[provider.AuthenticationParam]
if len(val) == 0 {
return internal_errors.NewValidationError(fmt.Sprintf("provider %s is missing value for field %s", providerName, provider.AuthenticationParam))
}
diff --git a/internal/manager/reporting.go b/internal/manager/reporting.go
index 57a7ae6..ef9f4d6 100644
--- a/internal/manager/reporting.go
+++ b/internal/manager/reporting.go
@@ -1,9 +1,6 @@
package manager
import (
- "errors"
- "fmt"
-
internal_errors "github.com/bricks-cloud/bricksllm/internal/errors"
"github.com/bricks-cloud/bricksllm/internal/event"
"github.com/bricks-cloud/bricksllm/internal/key"
@@ -18,8 +15,8 @@ type keyStorage interface {
}
type eventStorage interface {
- GetEvents(customId string) ([]*event.Event, error)
- GetEventDataPoints(start, end, increment int64, tags, keyIds []string, filters []string) ([]*event.DataPoint, error)
+ GetEvents(userId, customId string, keyIds []string, start, end int64) ([]*event.Event, error)
+ GetEventDataPoints(start, end, increment int64, tags, keyIds, customIds, userIds []string, filters []string) ([]*event.DataPoint, error)
GetLatencyPercentiles(start, end int64, tags, keyIds []string) ([]float64, error)
}
@@ -38,7 +35,7 @@ func NewReportingManager(cs costStorage, ks keyStorage, es eventStorage) *Report
}
func (rm *ReportingManager) GetEventReporting(e *event.ReportingRequest) (*event.ReportingResponse, error) {
- dataPoints, err := rm.es.GetEventDataPoints(e.Start, e.End, e.Increment, e.Tags, e.KeyIds, e.Filters)
+ dataPoints, err := rm.es.GetEventDataPoints(e.Start, e.End, e.Increment, e.Tags, e.KeyIds, e.CustomIds, e.UserIds, e.Filters)
if err != nil {
return nil, err
}
@@ -80,25 +77,8 @@ func (rm *ReportingManager) GetKeyReporting(keyId string) (*key.KeyReporting, er
}, err
}
-func (rm *ReportingManager) GetEvent(customId string) (*event.Event, error) {
- if len(customId) == 0 {
- return nil, errors.New("customId cannot be empty")
- }
-
- events, err := rm.es.GetEvents(customId)
- if err != nil {
- return nil, err
- }
-
- if len(events) >= 1 {
- return events[0], nil
- }
-
- return nil, internal_errors.NewNotFoundError(fmt.Sprintf("event is not found for customId: %s", customId))
-}
-
-func (rm *ReportingManager) GetEvents(customId string) ([]*event.Event, error) {
- events, err := rm.es.GetEvents(customId)
+func (rm *ReportingManager) GetEvents(userId, customId string, keyIds []string, start, end int64) ([]*event.Event, error) {
+ events, err := rm.es.GetEvents(userId, customId, keyIds, start, end)
if err != nil {
return nil, err
}
diff --git a/internal/manager/route.go b/internal/manager/route.go
index b3c8d8b..eb3c4ed 100644
--- a/internal/manager/route.go
+++ b/internal/manager/route.go
@@ -228,16 +228,16 @@ func (m *RouteManager) validateRoute(r *route.Route) error {
}
if !contains(step.Provider, supportedProviders) {
- return errors.New(fmt.Sprintf("steps.[%d].provider is not supported. Only azure and openai are supported", index))
+ return fmt.Errorf("steps.[%d].provider is not supported. Only azure and openai are supported", index)
}
if step.Provider == "azure" {
- apiVersion, _ := step.Params["apiVersion"]
+ apiVersion := step.Params["apiVersion"]
if len(apiVersion) == 0 {
fields = append(fields, fmt.Sprintf("steps.[%d].params.apiVersion", index))
}
- deploymentId, _ := step.Params["deploymentId"]
+ deploymentId := step.Params["deploymentId"]
if len(deploymentId) == 0 {
fields = append(fields, fmt.Sprintf("steps.[%d].params.deploymentId", index))
}
@@ -248,11 +248,11 @@ func (m *RouteManager) validateRoute(r *route.Route) error {
}
if !contains(step.Model, supportedModels) {
- return errors.New(fmt.Sprintf("steps.[%d].model is not supported. Only chat completion and embeddings model are supported.", index))
+ return fmt.Errorf("steps.[%d].model is not supported. Only chat completion and embeddings model are supported", index)
}
if !checkModelValidity(step.Provider, step.Model) {
- return errors.New(fmt.Sprintf("model: %s is not supported for provider: %s.", step.Model, step.Provider))
+ return fmt.Errorf("model: %s is not supported for provider: %s", step.Model, step.Provider)
}
if !containAda && contains(step.Model, adaModels) {
@@ -262,11 +262,11 @@ func (m *RouteManager) validateRoute(r *route.Route) error {
for _, step := range r.Steps {
if containAda && !contains(step.Model, adaModels) {
- return errors.New("steps must have congruent models. Chat completion and embedding models cannot be in the same route config.")
+ return errors.New("steps must have congruent models. Chat completion and embedding models cannot be in the same route config")
}
if !containAda && !contains(step.Model, chatCompletionModels) {
- return errors.New("steps must have congruent models. Chat completion and embedding models cannot be in the same route config.")
+ return errors.New("steps must have congruent models. Chat completion and embedding models cannot be in the same route config")
}
}
diff --git a/internal/message/bus.go b/internal/message/bus.go
new file mode 100644
index 0000000..7fe0749
--- /dev/null
+++ b/internal/message/bus.go
@@ -0,0 +1,23 @@
+package message
+
+type MessageBus struct {
+ Subscribers map[string][]chan<- Message
+}
+
+func NewMessageBus() *MessageBus {
+ return &MessageBus{
+ Subscribers: make(map[string][]chan<- Message),
+ }
+}
+
+func (mb *MessageBus) Subscribe(messageType string, subscriber chan<- Message) {
+ mb.Subscribers[messageType] = append(mb.Subscribers[messageType], subscriber)
+}
+
+func (mb *MessageBus) Publish(ms Message) {
+ subscribers := mb.Subscribers[ms.Type]
+
+ for _, subscriber := range subscribers {
+ subscriber <- ms
+ }
+}
diff --git a/internal/message/consumer.go b/internal/message/consumer.go
new file mode 100644
index 0000000..e8e4e02
--- /dev/null
+++ b/internal/message/consumer.go
@@ -0,0 +1,58 @@
+package message
+
+import (
+ "github.com/bricks-cloud/bricksllm/internal/event"
+ "github.com/bricks-cloud/bricksllm/internal/key"
+ "go.uber.org/zap"
+)
+
+type Consumer struct {
+ messageChan <-chan Message
+ done chan bool
+ log *zap.Logger
+ numOfEventConsumers int
+ handle func(Message) error
+}
+
+type recorder interface {
+ RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error
+ RecordEvent(e *event.Event) error
+}
+
+func NewConsumer(mc <-chan Message, log *zap.Logger, num int, handle func(Message) error) *Consumer {
+ return &Consumer{
+ messageChan: mc,
+ done: make(chan bool),
+ log: log,
+ numOfEventConsumers: num,
+ handle: handle,
+ }
+}
+
+func (c *Consumer) StartEventMessageConsumers() {
+ for i := 0; i < c.numOfEventConsumers; i++ {
+ go func() {
+ for {
+ select {
+ case <-c.done:
+ c.log.Info("event message consumer stopped...")
+ return
+
+ case m := <-c.messageChan:
+ err := c.handle(m)
+ if err != nil {
+ continue
+ }
+
+ continue
+ }
+ }
+ }()
+ }
+}
+
+func (c *Consumer) Stop() {
+ c.log.Info("shutting down consumer...")
+
+ c.done <- true
+}
diff --git a/internal/message/handler.go b/internal/message/handler.go
new file mode 100644
index 0000000..5ceea1e
--- /dev/null
+++ b/internal/message/handler.go
@@ -0,0 +1,435 @@
+package message
+
+import (
+ "errors"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/bricks-cloud/bricksllm/internal/event"
+ "github.com/bricks-cloud/bricksllm/internal/key"
+ "github.com/bricks-cloud/bricksllm/internal/provider/anthropic"
+ "github.com/bricks-cloud/bricksllm/internal/provider/custom"
+ "github.com/bricks-cloud/bricksllm/internal/stats"
+ "github.com/tidwall/gjson"
+ "go.uber.org/zap"
+
+ goopenai "github.com/sashabaranov/go-openai"
+)
+
+type anthropicEstimator interface {
+ EstimateTotalCost(model string, promptTks, completionTks int) (float64, error)
+ EstimateCompletionCost(model string, tks int) (float64, error)
+ EstimatePromptCost(model string, tks int) (float64, error)
+ Count(input string) int
+}
+
+type estimator interface {
+ EstimateSpeechCost(input string, model string) (float64, error)
+ EstimateChatCompletionPromptCostWithTokenCounts(r *goopenai.ChatCompletionRequest) (int, float64, error)
+ EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error)
+ EstimateChatCompletionStreamCostWithTokenCounts(model, content string) (int, float64, error)
+ EstimateCompletionCost(model string, tks int) (float64, error)
+ EstimateTotalCost(model string, promptTks, completionTks int) (float64, error)
+ EstimateEmbeddingsInputCost(model string, tks int) (float64, error)
+ EstimateChatCompletionPromptTokenCounts(model string, r *goopenai.ChatCompletionRequest) (int, error)
+}
+
+type azureEstimator interface {
+ EstimateChatCompletionStreamCostWithTokenCounts(model, content string) (int, float64, error)
+ EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error)
+ EstimateCompletionCost(model string, tks int) (float64, error)
+ EstimatePromptCost(model string, tks int) (float64, error)
+ EstimateTotalCost(model string, promptTks, completionTks int) (float64, error)
+ EstimateEmbeddingsInputCost(model string, tks int) (float64, error)
+}
+
+type validator interface {
+ Validate(k *key.ResponseKey, promptCost float64) error
+}
+
+type keyManager interface {
+ UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error)
+}
+
+type rateLimitManager interface {
+ Increment(keyId string, timeUnit key.TimeUnit) error
+}
+
+type accessCache interface {
+ Set(key string, timeUnit key.TimeUnit) error
+}
+
+type Handler struct {
+ recorder recorder
+ log *zap.Logger
+ ae anthropicEstimator
+ e estimator
+ aze azureEstimator
+ v validator
+ km keyManager
+ rlm rateLimitManager
+ ac accessCache
+}
+
+func NewHandler(r recorder, log *zap.Logger, ae anthropicEstimator, e estimator, aze azureEstimator, v validator, km keyManager, rlm rateLimitManager, ac accessCache) *Handler {
+ return &Handler{
+ recorder: r,
+ log: log,
+ ae: ae,
+ e: e,
+ aze: aze,
+ v: v,
+ km: km,
+ rlm: rlm,
+ ac: ac,
+ }
+}
+
+func (h *Handler) HandleEvent(m Message) error {
+ stats.Incr("bricksllm.message.handler.handle_event.requests", nil, 1)
+
+ e, ok := m.Data.(*event.Event)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.handle_event.event_parsing_error", nil, 1)
+ h.log.Info("message contains data that cannot be converted to event format", zap.Any("data", m.Data))
+ return errors.New("message data cannot be parsed as event")
+ }
+
+ start := time.Now()
+
+ err := h.recorder.RecordEvent(e)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_event.record_event_error", nil, 1)
+ h.log.Sugar().Debugf("error when publish in event: %v", err)
+ return err
+ }
+
+ stats.Timing("bricksllm.message.handler.handle_event.record_event_latency", time.Since(start), nil, 1)
+ stats.Incr("bricksllm.message.handler.handle_event.success", nil, 1)
+
+ return nil
+}
+
+const (
+ anthropicPromptMagicNum int = 1
+ anthropicCompletionMagicNum int = 4
+)
+
+func countTokensFromJson(bytes []byte, contentLoc string) (int, error) {
+ content := getContentFromJson(bytes, contentLoc)
+ return custom.Count(content)
+}
+
+func getContentFromJson(bytes []byte, contentLoc string) string {
+ result := gjson.Get(string(bytes), contentLoc)
+ content := ""
+
+ if len(result.Str) != 0 {
+ content += result.Str
+ }
+
+ if result.IsArray() {
+ for _, val := range result.Array() {
+ if len(val.Str) != 0 {
+ content += val.Str
+ }
+ }
+ }
+
+ return content
+}
+
+type costLimitError interface {
+ Error() string
+ CostLimit()
+}
+
+type rateLimitError interface {
+ Error() string
+ RateLimit()
+}
+
+type expirationError interface {
+ Error() string
+ Reason() string
+}
+
+func (h *Handler) handleValidationResult(kc *key.ResponseKey, cost float64) error {
+ err := h.v.Validate(kc, cost)
+
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.handle_validation_result", nil, 1)
+
+ if _, ok := err.(expirationError); ok {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.expiration_error", nil, 1)
+
+ truePtr := true
+ _, err = h.km.UpdateKey(kc.KeyId, &key.UpdateKey{
+ Revoked: &truePtr,
+ RevokedReason: key.RevokedReasonExpired,
+ })
+
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.update_key_error", nil, 1)
+ return err
+ }
+
+ return nil
+ }
+
+ if _, ok := err.(rateLimitError); ok {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.rate_limit_error", nil, 1)
+
+ err = h.ac.Set(kc.KeyId, kc.RateLimitUnit)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.set_rate_limit_error", nil, 1)
+ return err
+ }
+
+ return nil
+ }
+
+ if _, ok := err.(costLimitError); ok {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.cost_limit_error", nil, 1)
+
+ err = h.ac.Set(kc.KeyId, kc.CostLimitInUsdUnit)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_validation_result.set_cost_limit_error", nil, 1)
+ return err
+ }
+
+ return nil
+ }
+
+ return err
+ }
+
+ return nil
+}
+
+func (h *Handler) HandleEventWithRequestAndResponse(m Message) error {
+ e, ok := m.Data.(*event.EventWithRequestAndContent)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.message_data_parsing_error", nil, 1)
+ h.log.Debug("message contains data that cannot be converted to event with request and response format", zap.Any("data", m.Data))
+ return errors.New("message data cannot be parsed as event with request and response")
+ }
+
+ if e.Key != nil && !e.Key.Revoked && e.Event != nil {
+ err := h.decorateEvent(m)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.decorate_event_error", nil, 1)
+ h.log.Debug("error when decorating event", zap.Error(err))
+ }
+
+ if e.Event.CostInUsd != 0 {
+ micros := int64(e.Event.CostInUsd * 1000000)
+ err = h.recorder.RecordKeySpend(e.Event.KeyId, micros, e.Key.CostLimitInUsdUnit)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.record_key_spend_error", nil, 1)
+ h.log.Debug("error when recording key spend", zap.Error(err))
+ }
+ }
+
+ if len(e.Key.RateLimitUnit) != 0 {
+ if err := h.rlm.Increment(e.Key.KeyId, e.Key.RateLimitUnit); err != nil {
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.rate_limit_increment_error", nil, 1)
+
+ h.log.Debug("error when incrementing rate limit", zap.Error(err))
+ }
+ }
+
+ err = h.handleValidationResult(e.Key, e.Event.CostInUsd)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.handle_validation_result_error", nil, 1)
+ h.log.Debug("error when handling validation result", zap.Error(err))
+ }
+ }
+
+ start := time.Now()
+ err := h.recorder.RecordEvent(e.Event)
+ if err != nil {
+ h.log.Debug("error when recording an event", zap.Error(err))
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.record_event_error", nil, 1)
+ return err
+ }
+
+ stats.Timing("bricksllm.message.handler.handle_event_with_request_and_response.latency", time.Since(start), nil, 1)
+ stats.Incr("bricksllm.message.handler.handle_event_with_request_and_response.success", nil, 1)
+
+ return nil
+}
+
+func (h *Handler) decorateEvent(m Message) error {
+ stats.Incr("bricksllm.message.handler.decorate_event.request", nil, 1)
+
+ e, ok := m.Data.(*event.EventWithRequestAndContent)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.message_data_parsing_error", nil, 1)
+ h.log.Debug("message contains data that cannot be converted to event with request and response format", zap.Any("data", m.Data))
+ return errors.New("message data cannot be parsed as event with request and response")
+ }
+
+ if e.Event.Path == "/api/providers/openai/v1/audio/speech" {
+ csr, ok := e.Request.(*goopenai.CreateSpeechRequest)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
+ h.log.Debug("event contains request that cannot be converted to create speech request", zap.Any("data", m.Data))
+ return errors.New("event request data cannot be parsed as create speech request")
+ }
+
+ if e.Event.Status == http.StatusOK {
+ cost, err := h.e.EstimateSpeechCost(csr.Input, string(csr.Model))
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.estimate_prompt_cost", nil, 1)
+ h.log.Debug("error when estimating speech cost", zap.Error(err))
+ return err
+ }
+
+ e.Event.CostInUsd = cost
+ }
+ }
+
+ if e.Event.Path == "/api/providers/anthropic/v1/complete" {
+ cr, ok := e.Request.(*anthropic.CompletionRequest)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
+ h.log.Debug("event contains request that cannot be converted to anthropic completion request", zap.Any("data", m.Data))
+ return errors.New("event request data cannot be parsed as anthropic completion request")
+ }
+
+ tks := h.ae.Count(cr.Prompt)
+ tks += anthropicPromptMagicNum
+
+ model := cr.Model
+ cost, err := h.ae.EstimatePromptCost(model, tks)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.estimate_prompt_cost", nil, 1)
+ h.log.Debug("error when estimating anthropic prompt cost", zap.Error(err))
+ return err
+ }
+
+ completiontks := h.ae.Count(e.Content)
+ completiontks += anthropicCompletionMagicNum
+
+ completionCost, err := h.ae.EstimateCompletionCost(model, completiontks)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.estimate_completion_cost_error", nil, 1)
+ return err
+ }
+
+ e.Event.PromptTokenCount = tks
+
+ e.Event.CompletionTokenCount = completiontks
+ e.Event.CostInUsd = completionCost + cost
+ }
+
+ if e.Event.Path == "/api/providers/azure/openai/deployments/:deployment_id/chat/completions" {
+ ccr, ok := e.Request.(*goopenai.ChatCompletionRequest)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
+ h.log.Debug("event contains data that cannot be converted to azure openai completion request", zap.Any("data", m.Data))
+ return errors.New("event request data cannot be parsed as azure openai completion request")
+ }
+
+ if ccr.Stream {
+ tks, err := h.e.EstimateChatCompletionPromptTokenCounts("gpt-3.5-turbo", ccr)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.estimate_chat_completion_prompt_token_counts_error", nil, 1)
+ return err
+ }
+
+ cost, err := h.aze.EstimatePromptCost(e.Event.Model, tks)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.estimate_prompt_cost_error", nil, 1)
+ return err
+ }
+
+ completiontks, completionCost, err := h.aze.EstimateChatCompletionStreamCostWithTokenCounts(e.Event.Model, e.Content)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.estimate_chat_completion_stream_cost_with_token_counts_error", nil, 1)
+ return err
+ }
+
+ e.Event.PromptTokenCount = tks
+ e.Event.CompletionTokenCount = completiontks
+ e.Event.CostInUsd = cost + completionCost
+ }
+ }
+
+ if e.Event.Path == "/api/providers/openai/v1/chat/completions" {
+ ccr, ok := e.Request.(*goopenai.ChatCompletionRequest)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.event_request_parsing_error", nil, 1)
+ h.log.Debug("event contains data that cannot be converted to openai completion request", zap.Any("data", m.Data))
+ return errors.New("event request data cannot be parsed as openai completion request")
+ }
+
+ if ccr.Stream {
+ tks, cost, err := h.e.EstimateChatCompletionPromptCostWithTokenCounts(ccr)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.estimate_chat_completion_prompt_cost_with_token_counts", nil, 1)
+ return err
+ }
+
+ completiontks, completionCost, err := h.e.EstimateChatCompletionStreamCostWithTokenCounts(e.Event.Model, e.Content)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.estimate_chat_completion_stream_cost_with_token_counts", nil, 1)
+ return err
+ }
+
+ e.Event.PromptTokenCount = tks
+ e.Event.CompletionTokenCount = completiontks
+ e.Event.CostInUsd = cost + completionCost
+ }
+ }
+
+ if strings.HasPrefix(e.Event.Path, "/api/custom/providers/:provider") && e.RouteConfig != nil {
+ body, ok := e.Request.([]byte)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.event_request_custom_provider_parsing_error", nil, 1)
+ h.log.Debug("event contains request that cannot be converted to bytes", zap.Any("data", m.Data))
+ return errors.New("event request data cannot be converted to bytes")
+ }
+
+ tks, err := countTokensFromJson(body, e.RouteConfig.RequestPromptLocation)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.count_tokens_from_json_error", nil, 1)
+
+ return err
+ }
+
+ e.Event.PromptTokenCount = tks
+
+ result := gjson.Get(string(body), e.RouteConfig.StreamLocation)
+ if result.IsBool() {
+ completiontks, err := custom.Count(e.Content)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.custom_count_error", nil, 1)
+ return err
+ }
+
+ e.Event.CompletionTokenCount = completiontks
+ }
+
+ if !result.IsBool() {
+ content, ok := e.Response.([]byte)
+ if !ok {
+ stats.Incr("bricksllm.message.handler.decorate_event.event_response_custom_provider_parsing_error", nil, 1)
+ h.log.Debug("event contains response that cannot be converted to bytes", zap.Any("data", m.Data))
+ return errors.New("event response data cannot be converted to bytes")
+ }
+
+ completiontks, err := countTokensFromJson(content, e.RouteConfig.ResponseCompletionLocation)
+ if err != nil {
+ stats.Incr("bricksllm.message.handler.decorate_event.count_tokens_from_json_error", nil, 1)
+ return err
+ }
+
+ e.Event.CompletionTokenCount = completiontks
+ }
+ }
+
+ return nil
+}
diff --git a/internal/message/message.go b/internal/message/message.go
new file mode 100644
index 0000000..a2e97a5
--- /dev/null
+++ b/internal/message/message.go
@@ -0,0 +1,6 @@
+package message
+
+type Message struct {
+ Type string
+ Data interface{}
+}
diff --git a/internal/policy/policy.go b/internal/policy/policy.go
index b119383..7119693 100644
--- a/internal/policy/policy.go
+++ b/internal/policy/policy.go
@@ -131,7 +131,7 @@ func (p *Policy) Filter(client http.Client, input any) error {
updatedContents, err := p.inspect(client, contents)
if err != nil {
- return nil
+ return err
}
if len(updatedContents) != len(converted.Messages) {
diff --git a/internal/provider/openai/cost.go b/internal/provider/openai/cost.go
index 96e35c4..299c525 100644
--- a/internal/provider/openai/cost.go
+++ b/internal/provider/openai/cost.go
@@ -4,86 +4,115 @@ import (
"encoding/json"
"errors"
"fmt"
+ "math"
"strings"
goopenai "github.com/sashabaranov/go-openai"
)
+func useFinetuneModel(model string) string {
+ if isFinetuneModel(model) {
+ return parseFinetuneModel(model)
+ }
+
+ return model
+}
+
+func isFinetuneModel(model string) bool {
+ return strings.HasPrefix(model, "ft:")
+}
+
+func parseFinetuneModel(model string) string {
+ parts := strings.Split(model, ":")
+ if len(parts) > 2 {
+ return "finetune-" + parts[1]
+ }
+
+ return model
+}
+
var OpenAiPerThousandTokenCost = map[string]map[string]float64{
"prompt": {
- "gpt-4-1106-preview": 0.01,
- "gpt-4-1106-vision-preview": 0.01,
- "gpt-4": 0.03,
- "gpt-4-0314": 0.03,
- "gpt-4-0613": 0.03,
- "gpt-4-32k": 0.06,
- "gpt-4-32k-0613": 0.06,
- "gpt-4-32k-0314": 0.06,
- "gpt-3.5-turbo": 0.0015,
- "gpt-3.5-turbo-1106": 0.001,
- "gpt-3.5-turbo-0301": 0.0015,
- "gpt-3.5-turbo-instruct": 0.0015,
- "gpt-3.5-turbo-0613": 0.0015,
- "gpt-3.5-turbo-16k": 0.0015,
- "gpt-3.5-turbo-16k-0613": 0.0015,
- "text-davinci-003": 0.12,
- "text-davinci-002": 0.12,
- "code-davinci-002": 0.12,
- "text-curie-001": 0.012,
- "text-babbage-001": 0.0024,
- "text-ada-001": 0.0016,
- "davinci": 0.12,
- "curie": 0.012,
- "babbage": 0.0024,
- "ada": 0.0016,
+ "gpt-4-1106-preview": 0.01,
+ "gpt-4-turbo-preview": 0.01,
+ "gpt-4-0125-preview": 0.01,
+ "gpt-4-1106-vision-preview": 0.01,
+ "gpt-4-vision-preview": 0.01,
+ "gpt-4": 0.03,
+ "gpt-4-0314": 0.03,
+ "gpt-4-0613": 0.03,
+ "gpt-4-32k": 0.06,
+ "gpt-4-32k-0613": 0.06,
+ "gpt-4-32k-0314": 0.06,
+ "gpt-3.5-turbo": 0.0015,
+ "gpt-3.5-turbo-1106": 0.001,
+ "gpt-3.5-turbo-0125": 0.0005,
+ "gpt-3.5-turbo-0301": 0.0015,
+ "gpt-3.5-turbo-instruct": 0.0015,
+ "gpt-3.5-turbo-0613": 0.0015,
+ "gpt-3.5-turbo-16k": 0.0015,
+ "gpt-3.5-turbo-16k-0613": 0.0015,
+ "text-davinci-003": 0.12,
+ "text-davinci-002": 0.12,
+ "code-davinci-002": 0.12,
+ "text-curie-001": 0.012,
+ "text-babbage-001": 0.0024,
+ "text-ada-001": 0.0016,
+ "davinci": 0.12,
+ "curie": 0.012,
+ "babbage": 0.0024,
+ "ada": 0.0016,
+ "finetune-gpt-4-0613": 0.045,
+ "finetune-gpt-3.5-turbo-0125": 0.003,
+ "finetune-gpt-3.5-turbo-1106": 0.003,
+ "finetune-gpt-3.5-turbo-0613": 0.003,
+ "finetune-babbage-002": 0.0016,
+ "finetune-davinci-002": 0.012,
},
- "fine_tune": {
- "text-davinci-003": 0.03,
- "text-davinci-002": 0.03,
- "code-davinci-002": 0.03,
- "text-curie-001": 0.03,
- "text-babbage-001": 0.0006,
- "text-ada-001": 0.0004,
- "davinci": 0.03,
- "curie": 0.03,
- "babbage": 0.0006,
- "ada": 0.0004,
+ "finetune": {
+ "gpt-4-0613": 0.09,
+ "gpt-3.5-turbo-0125": 0.008,
+ "gpt-3.5-turbo-1106": 0.008,
+ "gpt-3.5-turbo-0613": 0.008,
+ "babbage-002": 0.0004,
+ "davinci-002": 0.006,
},
"embeddings": {
- "text-embedding-ada-002": 0.0001,
- "text-similarity-ada-001": 0.004,
- "text-search-ada-doc-001": 0.004,
- "text-search-ada-query-001": 0.004,
- "code-search-ada-code-001": 0.004,
- "code-search-ada-text-001": 0.004,
- "code-search-babbage-code-001": 0.005,
- "code-search-babbage-text-001": 0.005,
- "text-similarity-babbage-001": 0.005,
- "text-search-babbage-doc-001": 0.005,
- "text-search-babbage-query-001": 0.005,
- "text-similarity-curie-001": 0.02,
- "text-search-curie-doc-001": 0.02,
- "text-search-curie-query-001": 0.02,
- "text-search-davinci-doc-001": 0.2,
- "text-search-davinci-query-001": 0.2,
- "text-similarity-davinci-001": 0.2,
+ "text-embedding-ada-002": 0.0001,
+ "text-embedding-3-small": 0.00002,
+ "text-embedding-3-large": 0.00013,
+ },
+ "audio": {
+ "whisper-1": 0.006,
+ "tts-1": 0.015,
+ "tts-1-hd": 0.03,
},
"completion": {
- "gpt-3.5-turbo-1106": 0.002,
- "gpt-4-1106-preview": 0.03,
- "gpt-4-1106-vision-preview": 0.03,
- "gpt-4": 0.06,
- "gpt-4-0314": 0.06,
- "gpt-4-0613": 0.06,
- "gpt-4-32k": 0.12,
- "gpt-4-32k-0613": 0.12,
- "gpt-4-32k-0314": 0.12,
- "gpt-3.5-turbo": 0.002,
- "gpt-3.5-turbo-0301": 0.002,
- "gpt-3.5-turbo-0613": 0.002,
- "gpt-3.5-turbo-instruct": 0.002,
- "gpt-3.5-turbo-16k": 0.004,
- "gpt-3.5-turbo-16k-0613": 0.004,
+ "gpt-3.5-turbo-1106": 0.002,
+ "gpt-4-turbo-preview": 0.03,
+ "gpt-4-1106-preview": 0.03,
+ "gpt-4-0125-preview": 0.03,
+ "gpt-4-1106-vision-preview": 0.03,
+ "gpt-4-vision-preview": 0.03,
+ "gpt-4": 0.06,
+ "gpt-4-0314": 0.06,
+ "gpt-4-0613": 0.06,
+ "gpt-4-32k": 0.12,
+ "gpt-4-32k-0613": 0.12,
+ "gpt-4-32k-0314": 0.12,
+ "gpt-3.5-turbo": 0.002,
+ "gpt-3.5-turbo-0125": 0.0015,
+ "gpt-3.5-turbo-0301": 0.002,
+ "gpt-3.5-turbo-0613": 0.002,
+ "gpt-3.5-turbo-instruct": 0.002,
+ "gpt-3.5-turbo-16k": 0.004,
+ "gpt-3.5-turbo-16k-0613": 0.004,
+ "finetune-gpt-4-0613": 0.09,
+ "finetune-gpt-3.5-turbo-0125": 0.006,
+ "finetune-gpt-3.5-turbo-1106": 0.006,
+ "finetune-gpt-3.5-turbo-0613": 0.006,
+ "finetune-babbage-002": 0.0016,
+ "finetune-davinci-002": 0.012,
},
}
@@ -124,7 +153,7 @@ func (ce *CostEstimator) EstimatePromptCost(model string, tks int) (float64, err
}
- cost, ok := costMap[model]
+ cost, ok := costMap[useFinetuneModel(model)]
if !ok {
return 0, fmt.Errorf("%s is not present in the cost map provided", model)
}
@@ -155,7 +184,7 @@ func (ce *CostEstimator) EstimateCompletionCost(model string, tks int) (float64,
return 0, errors.New("prompt token cost is not provided")
}
- cost, ok := costMap[model]
+ cost, ok := costMap[useFinetuneModel(model)]
if !ok {
return 0, errors.New("model is not present in the cost map provided")
}
@@ -209,8 +238,50 @@ func (ce *CostEstimator) EstimateChatCompletionStreamCostWithTokenCounts(model s
return tks, cost, nil
}
+func (ce *CostEstimator) EstimateTranscriptionCost(secs float64, model string) (float64, error) {
+ costMap, ok := ce.tokenCostMap["audio"]
+ if !ok {
+ return 0, errors.New("audio cost map is not provided")
+ }
+
+ cost, ok := costMap[model]
+ if !ok {
+ return 0, errors.New("model is not present in the audio cost map")
+ }
+
+ return math.Trunc(secs) / 60 * cost, nil
+}
+
+func (ce *CostEstimator) EstimateSpeechCost(input string, model string) (float64, error) {
+ costMap, ok := ce.tokenCostMap["audio"]
+ if !ok {
+ return 0, errors.New("audio cost map is not provided")
+ }
+
+ cost, ok := costMap[model]
+ if !ok {
+ return 0, errors.New("model is not present in the audio cost map")
+ }
+
+ return float64(len(input)) / 1000 * cost, nil
+}
+
+func (ce *CostEstimator) EstimateFinetuningCost(num int, model string) (float64, error) {
+ costMap, ok := ce.tokenCostMap["finetune"]
+ if !ok {
+ return 0, errors.New("finetune cost map is not provided")
+ }
+
+ cost, ok := costMap[model]
+ if !ok {
+ return 0, errors.New("model is not present in the finetune cost map")
+ }
+
+ return cost * float64(num), nil
+}
+
func (ce *CostEstimator) EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error) {
- if len(r.Model.String()) == 0 {
+ if len(string(r.Model)) == 0 {
return 0, errors.New("model is not provided")
}
@@ -222,7 +293,7 @@ func (ce *CostEstimator) EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (f
return 0, errors.New("input is not string")
}
- tks, err := ce.tc.Count(r.Model.String(), converted)
+ tks, err := ce.tc.Count(string(r.Model), converted)
if err != nil {
return 0, err
}
@@ -230,14 +301,14 @@ func (ce *CostEstimator) EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (f
total += tks
}
- return ce.EstimateEmbeddingsInputCost(r.Model.String(), total)
+ return ce.EstimateEmbeddingsInputCost(string(r.Model), total)
} else if input, ok := r.Input.(string); ok {
- tks, err := ce.tc.Count(r.Model.String(), input)
+ tks, err := ce.tc.Count(string(r.Model), input)
if err != nil {
return 0, err
}
- return ce.EstimateEmbeddingsInputCost(r.Model.String(), tks)
+ return ce.EstimateEmbeddingsInputCost(string(r.Model), tks)
}
return 0, errors.New("input format is not recognized")
diff --git a/internal/provider/openai/fetcher.go b/internal/provider/openai/fetcher.go
new file mode 100644
index 0000000..0aac709
--- /dev/null
+++ b/internal/provider/openai/fetcher.go
@@ -0,0 +1 @@
+package openai
diff --git a/internal/server/web/admin/admin.go b/internal/server/web/admin/admin.go
index dc67778..eee9e82 100644
--- a/internal/server/web/admin/admin.go
+++ b/internal/server/web/admin/admin.go
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
+ "strconv"
"time"
"github.com/bricks-cloud/bricksllm/internal/event"
@@ -37,7 +38,7 @@ type KeyManager interface {
type KeyReportingManager interface {
GetKeyReporting(keyId string) (*key.KeyReporting, error)
- GetEvent(customId string) (*event.Event, error)
+ GetEvents(userId, customId string, keyIds []string, start int64, end int64) ([]*event.Event, error)
GetEventReporting(e *event.ReportingRequest) (*event.ReportingResponse, error)
}
@@ -55,11 +56,11 @@ type AdminServer struct {
m KeyManager
}
-func NewAdminServer(log *zap.Logger, mode string, m KeyManager, krm KeyReportingManager, psm ProviderSettingsManager, cpm CustomProvidersManager, rm RouteManager) (*AdminServer, error) {
+func NewAdminServer(log *zap.Logger, mode string, m KeyManager, krm KeyReportingManager, psm ProviderSettingsManager, cpm CustomProvidersManager, rm RouteManager, adminPass string) (*AdminServer, error) {
router := gin.New()
prod := mode == "production"
- router.Use(getAdminLoggerMiddleware(log, "admin", prod))
+ router.Use(getAdminLoggerMiddleware(log, "admin", prod, adminPass))
router.GET("/api/health", getGetHealthCheckHandler())
@@ -142,17 +143,17 @@ func getGetKeysHandler(m KeyManager, log *zap.Logger, prod bool) gin.HandlerFunc
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_keys_handler.latency", dur, nil, 1)
}()
tag := c.Query("tag")
tags := c.QueryArray("tags")
+ keyIds := c.QueryArray("keyIds")
provider := c.Query("provider")
path := "/api/key-management/keys"
-
- if len(tags) == 0 && len(tag) == 0 && len(provider) == 0 {
+ if len(tags) == 0 && len(tag) == 0 && len(provider) == 0 && len(keyIds) == 0 {
c.JSON(http.StatusBadRequest, &ErrorResponse{
Type: "/errors/missing-filteres",
Title: "filters are not found",
@@ -176,7 +177,7 @@ func getGetKeysHandler(m KeyManager, log *zap.Logger, prod bool) gin.HandlerFunc
}
cid := c.GetString(correlationId)
- keys, err := m.GetKeys(selected, nil, provider)
+ keys, err := m.GetKeys(selected, keyIds, provider)
if err != nil {
stats.Incr("bricksllm.admin.get_get_keys_handler.get_keys_by_tag_err", nil, 1)
@@ -207,7 +208,7 @@ func getGetProviderSettingsHandler(m ProviderSettingsManager, log *zap.Logger, p
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_provider_settings.latency", dur, nil, 1)
}()
@@ -257,7 +258,7 @@ func getCreateProviderSettingHandler(m ProviderSettingsManager, log *zap.Logger,
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_create_provider_setting_handler.latency", dur, nil, 1)
}()
@@ -347,7 +348,7 @@ func getCreateKeyHandler(m KeyManager, log *zap.Logger, prod bool) gin.HandlerFu
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_create_key_handler.latency", dur, nil, 1)
}()
@@ -437,7 +438,7 @@ func getUpdateProviderSettingHandler(m ProviderSettingsManager, log *zap.Logger,
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_update_provider_setting_handler.latency", dur, nil, 1)
}()
@@ -539,7 +540,7 @@ func getUpdateKeyHandler(m KeyManager, log *zap.Logger, prod bool) gin.HandlerFu
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_update_key_handler.latency", dur, nil, 1)
}()
@@ -714,7 +715,7 @@ func getGetEventMetricsHandler(m KeyReportingManager, log *zap.Logger, prod bool
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_event_metrics.latency", dur, nil, 1)
}()
@@ -801,7 +802,7 @@ func getGetEventsHandler(m KeyReportingManager, log *zap.Logger, prod bool) gin.
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_events_handler.latency", dur, nil, 1)
}()
@@ -819,27 +820,90 @@ func getGetEventsHandler(m KeyReportingManager, log *zap.Logger, prod bool) gin.
}
cid := c.GetString(correlationId)
- customId, ok := c.GetQuery("customId")
- if !ok {
+ customId, ciok := c.GetQuery("customId")
+ userId, uiok := c.GetQuery("userId")
+ keyIds, kiok := c.GetQueryArray("keyIds")
+ if !ciok && !kiok && !uiok {
c.JSON(http.StatusBadRequest, &ErrorResponse{
- Type: "/errors/custom-id-empty",
- Title: "custom id is empty",
+ Type: "/errors/no-filters-empty",
+ Title: "none of customId, keyIds and userId is specified",
Status: http.StatusBadRequest,
- Detail: "query param customId is empty. it is required for retrieving an event.",
+ Detail: "customId, userId and keyIds are empty. one of them is required for retrieving events.",
Instance: path,
})
return
}
- ev, err := m.GetEvent(customId)
+ var qstart int64 = 0
+ var qend int64 = 0
+
+ if kiok {
+ startstr, sok := c.GetQuery("start")
+ if !sok {
+ c.JSON(http.StatusBadRequest, &ErrorResponse{
+ Type: "/errors/query-param-start-missing",
+ Title: "query param start is missing",
+ Status: http.StatusBadRequest,
+ Detail: "start query param is not provided",
+ Instance: path,
+ })
+
+ return
+ }
+
+ parsedStart, err := strconv.ParseInt(startstr, 10, 64)
+ if err != nil {
+ c.JSON(http.StatusBadRequest, &ErrorResponse{
+ Type: "/errors/bad-start-query-param",
+ Title: "start query cannot be parsed",
+ Status: http.StatusBadRequest,
+ Detail: "start query param must be int64",
+ Instance: path,
+ })
+
+ return
+ }
+
+ qstart = parsedStart
+
+ endstr, eoi := c.GetQuery("end")
+ if !eoi {
+ c.JSON(http.StatusBadRequest, &ErrorResponse{
+ Type: "/errors/query-param-end-missing",
+ Title: "query param end is missing",
+ Status: http.StatusBadRequest,
+ Detail: "end query param is not provided",
+ Instance: path,
+ })
+
+ return
+ }
+
+ parsedEnd, err := strconv.ParseInt(endstr, 10, 64)
+ if err != nil {
+ c.JSON(http.StatusBadRequest, &ErrorResponse{
+ Type: "/errors/bad-end-query-param",
+ Title: "end query cannot be parsed",
+ Status: http.StatusBadRequest,
+ Detail: "end query param must be int64",
+ Instance: path,
+ })
+
+ return
+ }
+
+ qend = parsedEnd
+ }
+
+ evs, err := m.GetEvents(userId, customId, keyIds, qstart, qend)
if err != nil {
- stats.Incr("bricksllm.admin.get_get_events_handler.get_event_error", nil, 1)
+ stats.Incr("bricksllm.admin.get_get_events_handler.get_events_error", nil, 1)
- logError(log, "error when getting an event", prod, cid, err)
+ logError(log, "error when getting events", prod, cid, err)
c.JSON(http.StatusInternalServerError, &ErrorResponse{
Type: "/errors/event-manager",
- Title: "getting an event error",
+ Title: "getting events error",
Status: http.StatusInternalServerError,
Detail: err.Error(),
Instance: path,
@@ -849,7 +913,7 @@ func getGetEventsHandler(m KeyReportingManager, log *zap.Logger, prod bool) gin.
stats.Incr("bricksllm.admin.get_get_events_handler.success", nil, 1)
- c.JSON(http.StatusOK, []*event.Event{ev})
+ c.JSON(http.StatusOK, evs)
}
}
@@ -859,7 +923,7 @@ func getGetKeyReportingHandler(m KeyReportingManager, log *zap.Logger, prod bool
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_key_reporting_hanlder.latency", dur, nil, 1)
}()
@@ -944,7 +1008,7 @@ func getCreateCustomProviderHandler(m CustomProvidersManager, log *zap.Logger, p
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_create_custom_provider_handler.latency", dur, nil, 1)
}()
@@ -1032,7 +1096,7 @@ func getGetCustomProvidersHandler(m CustomProvidersManager, log *zap.Logger, pro
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_custom_providers_handler.latency", dur, nil, 1)
}()
@@ -1080,7 +1144,7 @@ func getUpdateCustomProvidersHandler(m CustomProvidersManager, log *zap.Logger,
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_update_custom_providers_handler.latency", dur, nil, 1)
}()
diff --git a/internal/server/web/admin/middleware.go b/internal/server/web/admin/middleware.go
index ad6ad85..8704d44 100644
--- a/internal/server/web/admin/middleware.go
+++ b/internal/server/web/admin/middleware.go
@@ -8,12 +8,18 @@ import (
"go.uber.org/zap"
)
-func getAdminLoggerMiddleware(log *zap.Logger, prefix string, prod bool) gin.HandlerFunc {
+func getAdminLoggerMiddleware(log *zap.Logger, prefix string, prod bool, adminPass string) gin.HandlerFunc {
return func(c *gin.Context) {
+ if len(adminPass) != 0 && c.Request.Header.Get("X-API-KEY") != adminPass {
+ c.Status(200)
+ c.Abort()
+ return
+ }
+
c.Set(correlationId, util.NewUuid())
start := time.Now()
c.Next()
- latency := time.Now().Sub(start).Milliseconds()
+ latency := time.Since(start).Milliseconds()
if !prod {
log.Sugar().Infof("%s | %d | %s | %s | %dms", prefix, c.Writer.Status(), c.Request.Method, c.FullPath(), latency)
}
diff --git a/internal/server/web/admin/route.go b/internal/server/web/admin/route.go
index 6699ea8..99d3bca 100644
--- a/internal/server/web/admin/route.go
+++ b/internal/server/web/admin/route.go
@@ -24,7 +24,7 @@ func getCreateRouteHandler(m RouteManager, log *zap.Logger, prod bool) gin.Handl
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_create_route_handler.latency", dur, nil, 1)
}()
@@ -112,7 +112,7 @@ func getGetRouteHandler(m RouteManager, log *zap.Logger, prod bool) gin.HandlerF
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_route_handler.latency", dur, nil, 1)
}()
@@ -174,7 +174,7 @@ func getGetRoutesHandler(m RouteManager, log *zap.Logger, prod bool) gin.Handler
start := time.Now()
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.admin.get_get_routes_handler.latency", dur, nil, 1)
}()
diff --git a/internal/server/web/proxy/anthropic.go b/internal/server/web/proxy/anthropic.go
index fa96eb2..0123275 100644
--- a/internal/server/web/proxy/anthropic.go
+++ b/internal/server/web/proxy/anthropic.go
@@ -5,12 +5,12 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"io"
"net/http"
"strings"
"time"
- "github.com/bricks-cloud/bricksllm/internal/key"
"github.com/bricks-cloud/bricksllm/internal/provider/anthropic"
"github.com/bricks-cloud/bricksllm/internal/stats"
"github.com/gin-gonic/gin"
@@ -18,11 +18,6 @@ import (
"go.uber.org/zap/zapcore"
)
-const (
- anthropicPromptMagicNum int = 1
- anthropicCompletionMagicNum int = 4
-)
-
type anthropicEstimator interface {
EstimateTotalCost(model string, promptTks, completionTks int) (float64, error)
EstimateCompletionCost(model string, tks int) (float64, error)
@@ -40,7 +35,7 @@ func copyHttpHeaders(source *http.Request, dest *http.Request) {
dest.Header.Set("Accept-Encoding", "*")
}
-func getCompletionHandler(r recorder, prod, private bool, client http.Client, kms keyMemStorage, log *zap.Logger, e anthropicEstimator, timeOut time.Duration) gin.HandlerFunc {
+func getCompletionHandler(prod, private bool, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_completion_handler.requests", nil, 1)
@@ -60,13 +55,13 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
return
}
- raw, exists := c.Get("key")
- kc, ok := raw.(*key.ResponseKey)
- if !exists || !ok {
- stats.Incr("bricksllm.proxy.get_completion_handler.api_key_not_registered", nil, 1)
- JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
- return
- }
+ // raw, exists := c.Get("key")
+ // kc, ok := raw.(*key.ResponseKey)
+ // if !exists || !ok {
+ // stats.Incr("bricksllm.proxy.get_completion_handler.api_key_not_registered", nil, 1)
+ // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
+ // return
+ // }
copyHttpHeaders(c.Request, req)
@@ -95,10 +90,10 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
}
}
- model := c.GetString("model")
+ // model := c.GetString("model")
if !isStreaming && res.StatusCode == http.StatusOK {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_completion_handler.latency", dur, nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -108,8 +103,8 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
return
}
- var cost float64 = 0
- completionTokens := 0
+ // var cost float64 = 0
+ // completionTokens := 0
completionRes := &anthropic.CompletionResponse{}
stats.Incr("bricksllm.proxy.get_completion_handler.success", nil, 1)
stats.Timing("bricksllm.proxy.get_completion_handler.success_latency", dur, nil, 1)
@@ -119,34 +114,38 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
logError(log, "error when unmarshalling anthropic http completion response body", prod, cid, err)
}
- if err == nil {
- logCompletionResponse(log, bytes, prod, private, cid)
- completionTokens = e.Count(completionRes.Completion)
- completionTokens += anthropicCompletionMagicNum
- promptTokens := c.GetInt("promptTokenCount")
- cost, err = e.EstimateTotalCost(model, promptTokens, completionTokens)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_completion_handler.estimate_total_cost_error", nil, 1)
- logError(log, "error when estimating anthropic cost", prod, cid, err)
- }
+ logCompletionResponse(log, bytes, prod, private, cid)
- micros := int64(cost * 1000000)
- err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_completion_handler.record_key_spend_error", nil, 1)
- logError(log, "error when recording anthropic spend", prod, cid, err)
- }
- }
+ c.Set("content", completionRes.Completion)
- c.Set("costInUsd", cost)
- c.Set("completionTokenCount", completionTokens)
+ // if err == nil {
+ // logCompletionResponse(log, bytes, prod, private, cid)
+ // completionTokens = e.Count(completionRes.Completion)
+ // completionTokens += anthropicCompletionMagicNum
+ // promptTokens := c.GetInt("promptTokenCount")
+ // cost, err = e.EstimateTotalCost(model, promptTokens, completionTokens)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_completion_handler.estimate_total_cost_error", nil, 1)
+ // logError(log, "error when estimating anthropic cost", prod, cid, err)
+ // }
+
+ // micros := int64(cost * 1000000)
+ // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_completion_handler.record_key_spend_error", nil, 1)
+ // logError(log, "error when recording anthropic spend", prod, cid, err)
+ // }
+ // }
+
+ // c.Set("costInUsd", cost)
+ // c.Set("completionTokenCount", completionTokens)
c.Data(res.StatusCode, "application/json", bytes)
return
}
if res.StatusCode != http.StatusOK {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_completion_handler.error_latency", dur, nil, 1)
stats.Incr("bricksllm.proxy.get_completion_handler.error_response", nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -162,24 +161,24 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
}
buffer := bufio.NewReader(res.Body)
- var totalCost float64 = 0
+ // var totalCost float64 = 0
content := ""
- defer func() {
- tks := e.Count(content)
- model := c.GetString("model")
- cost, err := e.EstimateCompletionCost(model, tks)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_completion_handler.estimate_completion_cost_error", nil, 1)
- logError(log, "error when estimating anthropic completion stream cost", prod, cid, err)
- }
-
- estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd")
- totalCost = cost + estimatedPromptCost
-
- c.Set("costInUsd", totalCost)
- c.Set("completionTokenCount", tks+anthropicCompletionMagicNum)
- }()
+ // defer func() {
+ // tks := e.Count(content)
+ // model := c.GetString("model")
+ // cost, err := e.EstimateCompletionCost(model, tks)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_completion_handler.estimate_completion_cost_error", nil, 1)
+ // logError(log, "error when estimating anthropic completion stream cost", prod, cid, err)
+ // }
+
+ // estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd")
+ // totalCost = cost + estimatedPromptCost
+
+ // c.Set("costInUsd", totalCost)
+ // c.Set("completionTokenCount", tks+anthropicCompletionMagicNum)
+ // }()
stats.Incr("bricksllm.proxy.get_completion_handler.streaming_requests", nil, 1)
@@ -191,6 +190,13 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
return false
}
+ if errors.Is(err, context.DeadlineExceeded) {
+ stats.Incr("bricksllm.proxy.get_completion_handler.context_deadline_exceeded_error", nil, 1)
+ logError(log, "context deadline exceeded when reading bytes from anthropic completion response", prod, cid, err)
+
+ return false
+ }
+
stats.Incr("bricksllm.proxy.get_completion_handler.read_bytes_error", nil, 1)
logError(log, "error when reading bytes from anthropic streaming response", prod, cid, err)
@@ -249,7 +255,7 @@ func getCompletionHandler(r recorder, prod, private bool, client http.Client, km
return true
})
- stats.Timing("bricksllm.proxy.get_completion_handler.streaming_latency", time.Now().Sub(start), nil, 1)
+ stats.Timing("bricksllm.proxy.get_completion_handler.streaming_latency", time.Since(start), nil, 1)
}
}
diff --git a/internal/server/web/proxy/assistant.go b/internal/server/web/proxy/assistant.go
index 0b20368..1a7b61f 100644
--- a/internal/server/web/proxy/assistant.go
+++ b/internal/server/web/proxy/assistant.go
@@ -63,7 +63,7 @@ func logAssistantResponse(log *zap.Logger, data []byte, prod, private bool, cid
}
}
-func logRetrieveAssistantRequest(log *zap.Logger, data []byte, prod bool, cid, assistantId string) {
+func logRetrieveAssistantRequest(log *zap.Logger, prod bool, cid, assistantId string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -102,7 +102,7 @@ func logModifyAssistantRequest(log *zap.Logger, data []byte, prod, private bool,
}
}
-func logDeleteAssistantRequest(log *zap.Logger, data []byte, prod bool, cid, assistantId string) {
+func logDeleteAssistantRequest(log *zap.Logger, prod bool, cid, assistantId string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/assistant_file.go b/internal/server/web/proxy/assistant_file.go
index 82a45cf..32dca5e 100644
--- a/internal/server/web/proxy/assistant_file.go
+++ b/internal/server/web/proxy/assistant_file.go
@@ -60,7 +60,7 @@ func logRetrieveAssistantFileRequest(log *zap.Logger, prod bool, cid, fid, aid s
}
}
-func logDeleteAssistantFileRequest(log *zap.Logger, data []byte, prod bool, cid, fid, aid string) {
+func logDeleteAssistantFileRequest(log *zap.Logger, prod bool, cid, fid, aid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/audio.go b/internal/server/web/proxy/audio.go
index 33e4534..d513829 100644
--- a/internal/server/web/proxy/audio.go
+++ b/internal/server/web/proxy/audio.go
@@ -1,10 +1,493 @@
package proxy
import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/asticode/go-astisub"
+ "github.com/bricks-cloud/bricksllm/internal/stats"
+ "github.com/gin-gonic/gin"
+ goopenai "github.com/sashabaranov/go-openai"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
+func getSpeechHandler(prod bool, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ stats.Incr("bricksllm.proxy.get_speech_handler.requests", nil, 1)
+
+ if c == nil || c.Request == nil {
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] context is empty")
+ return
+ }
+
+ cid := c.GetString(correlationId)
+
+ ctx, cancel := context.WithTimeout(context.Background(), timeOut)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, c.Request.Method, "https://api.openai.com/v1/audio/speech", c.Request.Body)
+ if err != nil {
+ logError(log, "error when creating openai http request", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create openai http request")
+ return
+ }
+
+ copyHttpHeaders(c.Request, req)
+
+ start := time.Now()
+
+ res, err := client.Do(req)
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_speech_handler.http_client_error", nil, 1)
+
+ logError(log, "error when sending create speech request to openai", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to send create speech request to openai")
+ return
+ }
+ defer res.Body.Close()
+
+ dur := time.Since(start)
+ stats.Timing("bricksllm.proxy.get_speech_handler.latency", dur, nil, 1)
+
+ bytes, err := io.ReadAll(res.Body)
+ if err != nil {
+ logError(log, "error when reading openai create speech response body", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to read openai create speech response body")
+ return
+ }
+
+ if res.StatusCode == http.StatusOK {
+ stats.Incr("bricksllm.proxy.get_speech_handler.success", nil, 1)
+ stats.Timing("bricksllm.proxy.get_speech_handler.success_latency", dur, nil, 1)
+ }
+
+ if res.StatusCode != http.StatusOK {
+ stats.Timing("bricksllm.proxy.get_speech_handler.error_latency", dur, nil, 1)
+ stats.Incr("bricksllm.proxy.get_speech_handler.error_response", nil, 1)
+
+ errorRes := &goopenai.ErrorResponse{}
+ err = json.Unmarshal(bytes, errorRes)
+ if err != nil {
+ logError(log, "error when unmarshalling openai create speech error response body", prod, cid, err)
+ }
+
+ logOpenAiError(log, prod, cid, errorRes)
+ }
+
+ for name, values := range res.Header {
+ for _, value := range values {
+ c.Header(name, value)
+ }
+ }
+
+ c.Data(res.StatusCode, res.Header.Get("Content-Type"), bytes)
+ }
+}
+
+func convertVerboseJson(resp *goopenai.AudioResponse, format string) ([]byte, error) {
+ if format == "verbose_json" || format == "json" {
+ selected := resp
+ if format == "json" {
+ selected = &goopenai.AudioResponse{
+ Text: resp.Text,
+ }
+ }
+
+ data, err := json.Marshal(selected)
+ if err != nil {
+ return nil, err
+ }
+
+ return data, nil
+ }
+
+ if format == "text" {
+ return []byte(resp.Text + "\n"), nil
+ }
+
+ if format == "srt" || format == "vtt" {
+ sub := astisub.NewSubtitles()
+ items := []*astisub.Item{}
+
+ for _, seg := range resp.Segments {
+ item := &astisub.Item{
+ StartAt: time.Duration(seg.Start * float64(time.Second)),
+ EndAt: time.Duration(seg.End * float64(time.Second)),
+ Lines: []astisub.Line{
+ {
+ Items: []astisub.LineItem{
+ {Text: seg.Text},
+ },
+ },
+ },
+ }
+
+ items = append(items, item)
+ }
+
+ sub.Items = items
+
+ buf := bytes.NewBuffer([]byte{})
+
+ if format == "srt" {
+ err := sub.WriteToSRT(buf)
+ if err != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+ }
+
+ if format == "vtt" {
+ err := sub.WriteToWebVTT(buf)
+ if err != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+ }
+ }
+
+ return nil, nil
+}
+
+func getContentType(format string) string {
+ if format == "verbose_json" || format == "json" {
+ return "application/json"
+ }
+
+ return "text/plain; charset=utf-8"
+}
+
+func getTranscriptionsHandler(prod bool, client http.Client, log *zap.Logger, timeOut time.Duration, e estimator) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ stats.Incr("bricksllm.proxy.get_transcriptions_handler.requests", nil, 1)
+
+ if c == nil || c.Request == nil {
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] context is empty")
+ return
+ }
+
+ cid := c.GetString(correlationId)
+
+ ctx, cancel := context.WithTimeout(context.Background(), timeOut)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, c.Request.Method, "https://api.openai.com/v1/audio/transcriptions", c.Request.Body)
+ if err != nil {
+ logError(log, "error when creating transcriptions openai http request", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create openai transcriptions http request")
+ return
+ }
+
+ copyHttpHeaders(c.Request, req)
+
+ var b bytes.Buffer
+ writer := multipart.NewWriter(&b)
+
+ err = writeFieldToBuffer([]string{
+ "model",
+ "language",
+ "prompt",
+ "response_format",
+ "temperature",
+ }, c, writer, map[string]string{
+ "response_format": "verbose_json",
+ })
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_transcriptions_handler.write_field_to_buffer_error", nil, 1)
+ logError(log, "error when writing field to buffer", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot write field to buffer")
+ return
+ }
+
+ var form TransriptionForm
+ c.ShouldBind(&form)
+
+ if form.File != nil {
+ fieldWriter, err := writer.CreateFormFile("file", form.File.Filename)
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_transcriptions_handler.create_transcription_file_error", nil, 1)
+ logError(log, "error when creating transcription file", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot create transcription file")
+ return
+ }
+
+ opened, err := form.File.Open()
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_transcriptions_handler.open_transcription_file_error", nil, 1)
+ logError(log, "error when opening transcription file", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot open transcription file")
+ return
+ }
+
+ _, err = io.Copy(fieldWriter, opened)
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_transcriptions_handler.copy_transcription_file_error", nil, 1)
+ logError(log, "error when copying transcription file", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot copy transcription file")
+ return
+ }
+ }
+
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ writer.Close()
+ req.Body = io.NopCloser(&b)
+
+ start := time.Now()
+
+ res, err := client.Do(req)
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_transcriptions_handler.http_client_error", nil, 1)
+
+ logError(log, "error when sending transcriptions request to openai", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to send transcriptions request to openai")
+ return
+ }
+ defer res.Body.Close()
+
+ dur := time.Since(start)
+ stats.Timing("bricksllm.proxy.get_transcriptions_handler.latency", dur, nil, 1)
+
+ bytes, err := io.ReadAll(res.Body)
+ if err != nil {
+ logError(log, "error when reading openai transcriptions response body", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to read openai transcriptions response body")
+ return
+ }
+
+ format := c.PostForm("response_format")
+ for name, values := range res.Header {
+ for _, value := range values {
+ if strings.ToLower(name) == "content-type" {
+ c.Header(name, getContentType(format))
+ continue
+ }
+
+ c.Header(name, value)
+ }
+ }
+
+ if res.StatusCode == http.StatusOK {
+ stats.Incr("bricksllm.proxy.get_transcriptions_handler.success", nil, 1)
+ stats.Timing("bricksllm.proxy.get_transcriptions_handler.success_latency", dur, nil, 1)
+
+ ar := &goopenai.AudioResponse{}
+ err = json.Unmarshal(bytes, ar)
+ if err != nil {
+ logError(log, "error when unmarshalling openai http audio response body", prod, cid, err)
+ }
+
+ if err == nil {
+ cost, err := e.EstimateTranscriptionCost(ar.Duration, c.GetString("model"))
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_transcriptions_handler.estimate_total_cost_error", nil, 1)
+ logError(log, "error when estimating openai cost", prod, cid, err)
+ }
+
+ c.Set("costInUsd", cost)
+ }
+
+ data, err := convertVerboseJson(ar, format)
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_transcriptions_handler.convert_verbose_json_error", nil, 1)
+ logError(log, "error when converting verbose json", prod, cid, err)
+ }
+
+ c.Header("Content-Length", strconv.Itoa(len(data)))
+
+ c.Data(res.StatusCode, getContentType(format), data)
+ return
+ }
+
+ if res.StatusCode != http.StatusOK {
+ stats.Timing("bricksllm.proxy.get_transcriptions_handler.error_latency", dur, nil, 1)
+ stats.Incr("bricksllm.proxy.get_transcriptions_handler.error_response", nil, 1)
+
+ errorRes := &goopenai.ErrorResponse{}
+ err = json.Unmarshal(bytes, errorRes)
+ if err != nil {
+ logError(log, "error when unmarshalling openai transcriptions error response body", prod, cid, err)
+ }
+
+ logOpenAiError(log, prod, cid, errorRes)
+
+ c.Data(res.StatusCode, res.Header.Get("Content-Type"), bytes)
+
+ return
+ }
+ }
+}
+
+func getTranslationsHandler(prod bool, client http.Client, log *zap.Logger, timeOut time.Duration, e estimator) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ stats.Incr("bricksllm.proxy.get_translations_handler.requests", nil, 1)
+
+ if c == nil || c.Request == nil {
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] context is empty")
+ return
+ }
+
+ cid := c.GetString(correlationId)
+
+ ctx, cancel := context.WithTimeout(context.Background(), timeOut)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, c.Request.Method, "https://api.openai.com/v1/audio/translations", c.Request.Body)
+ if err != nil {
+ logError(log, "error when creating translations openai http request", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to create openai translations http request")
+ return
+ }
+
+ copyHttpHeaders(c.Request, req)
+
+ var b bytes.Buffer
+ writer := multipart.NewWriter(&b)
+
+ err = writeFieldToBuffer([]string{
+ "model",
+ "prompt",
+ "response_format",
+ "temperature",
+ }, c, writer, map[string]string{
+ "response_format": "verbose_json",
+ })
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_translations_handler.write_field_to_buffer_error", nil, 1)
+ logError(log, "error when writing field to buffer", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot write field to buffer")
+ return
+ }
+
+ var form TranslationForm
+ c.ShouldBind(&form)
+
+ if form.File != nil {
+ fieldWriter, err := writer.CreateFormFile("file", form.File.Filename)
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_translations_handler.create_translation_file_error", nil, 1)
+ logError(log, "error when creating translation file", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot create translation file")
+ return
+ }
+
+ opened, err := form.File.Open()
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_translations_handler.open_translation_file_error", nil, 1)
+ logError(log, "error when opening translation file", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot open translation file")
+ return
+ }
+
+ _, err = io.Copy(fieldWriter, opened)
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_translations_handler.copy_translation_file_error", nil, 1)
+ logError(log, "error when copying translation file", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot copy translation file")
+ return
+ }
+ }
+
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+
+ writer.Close()
+
+ req.Body = io.NopCloser(&b)
+
+ start := time.Now()
+
+ res, err := client.Do(req)
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_translations_handler.http_client_error", nil, 1)
+
+ logError(log, "error when sending translations request to openai", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to send translations request to openai")
+ return
+ }
+ defer res.Body.Close()
+
+ dur := time.Since(start)
+ stats.Timing("bricksllm.proxy.get_translations_handler.latency", dur, nil, 1)
+
+ bytes, err := io.ReadAll(res.Body)
+ if err != nil {
+ logError(log, "error when reading openai translations response body", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to read openai translations response body")
+ return
+ }
+
+ format := c.PostForm("response_format")
+ for name, values := range res.Header {
+ for _, value := range values {
+ if strings.ToLower(name) == "content-type" {
+ c.Header(name, getContentType(format))
+ continue
+ }
+
+ c.Header(name, value)
+ }
+ }
+
+ if res.StatusCode == http.StatusOK {
+ stats.Incr("bricksllm.proxy.get_translations_handler.success", nil, 1)
+ stats.Timing("bricksllm.proxy.get_translations_handler.success_latency", dur, nil, 1)
+
+ ar := &goopenai.AudioResponse{}
+ err = json.Unmarshal(bytes, ar)
+ if err != nil {
+ logError(log, "error when unmarshalling openai http audio response body", prod, cid, err)
+ }
+
+ if err == nil {
+ cost, err := e.EstimateTranscriptionCost(ar.Duration, c.GetString("model"))
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_translations_handler.estimate_total_cost_error", nil, 1)
+ logError(log, "error when estimating openai cost", prod, cid, err)
+ }
+
+ c.Set("costInUsd", cost)
+ }
+
+ data, err := convertVerboseJson(ar, format)
+ if err != nil {
+ stats.Incr("bricksllm.proxy.get_translations_handler.convert_verbose_json_error", nil, 1)
+ logError(log, "error when converting verbose json", prod, cid, err)
+ }
+
+ c.Header("Content-Length", strconv.Itoa(len(data)))
+
+ c.Data(res.StatusCode, getContentType(format), data)
+ return
+ }
+
+ if res.StatusCode != http.StatusOK {
+ stats.Timing("bricksllm.proxy.get_translations_handler.error_latency", dur, nil, 1)
+ stats.Incr("bricksllm.proxy.get_translations_handler.error_response", nil, 1)
+
+ errorRes := &goopenai.ErrorResponse{}
+ err = json.Unmarshal(bytes, errorRes)
+ if err != nil {
+ logError(log, "error when unmarshalling openai translations error response body", prod, cid, err)
+ }
+
+ logOpenAiError(log, prod, cid, errorRes)
+
+ c.Data(res.StatusCode, res.Header.Get("Content-Type"), bytes)
+
+ return
+ }
+ }
+}
+
type SpeechRequest struct {
Model string `json:"model"`
Input string `json:"input"`
@@ -13,24 +496,24 @@ type SpeechRequest struct {
Speed float64 `json:"speed"`
}
-func logCreateSpeechRequest(log *zap.Logger, sr *SpeechRequest, prod, private bool, cid string) {
+func logCreateSpeechRequest(log *zap.Logger, csr *goopenai.CreateSpeechRequest, prod, private bool, cid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
- zap.String("model", sr.Model),
- zap.String("voice", sr.Voice),
+ zap.String("model", string(csr.Model)),
+ zap.String("voice", string(csr.Voice)),
}
if !private {
- fields = append(fields, zap.String("input", sr.Input))
+ fields = append(fields, zap.String("input", csr.Input))
}
- if len(sr.ResponseFormat) != 0 {
- fields = append(fields, zap.String("response_format", sr.ResponseFormat))
+ if len(csr.ResponseFormat) != 0 {
+ fields = append(fields, zap.String("response_format", string(csr.ResponseFormat)))
}
- if sr.Speed != 0 {
- fields = append(fields, zap.Float64("speed", sr.Speed))
+ if csr.Speed != 0 {
+ fields = append(fields, zap.Float64("speed", csr.Speed))
}
log.Info("openai create speech request", fields...)
@@ -69,6 +552,7 @@ func logCreateTranslationRequest(log *zap.Logger, model, prompt, responseFormat
fields := []zapcore.Field{
zap.String(correlationId, cid),
zap.String("model", model),
+ zap.Float64("temperature", temperature),
}
if !private && len(prompt) == 0 {
diff --git a/internal/server/web/proxy/azure_chat_completion.go b/internal/server/web/proxy/azure_chat_completion.go
index d0b09aa..b5de639 100644
--- a/internal/server/web/proxy/azure_chat_completion.go
+++ b/internal/server/web/proxy/azure_chat_completion.go
@@ -5,12 +5,12 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"fmt"
"io"
"net/http"
"time"
- "github.com/bricks-cloud/bricksllm/internal/key"
"github.com/bricks-cloud/bricksllm/internal/stats"
"github.com/gin-gonic/gin"
goopenai "github.com/sashabaranov/go-openai"
@@ -25,7 +25,7 @@ func buildAzureUrl(path, deploymentId, apiVersion, resourceName string) string {
return fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/embeddings?api-version=%s", resourceName, deploymentId, apiVersion)
}
-func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettingsManager, client http.Client, kms keyMemStorage, log *zap.Logger, aoe azureEstimator, timeOut time.Duration) gin.HandlerFunc {
+func getAzureChatCompletionHandler(prod, private bool, client http.Client, log *zap.Logger, aoe azureEstimator, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.requests", nil, 1)
@@ -35,13 +35,6 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
}
cid := c.GetString(correlationId)
- raw, exists := c.Get("key")
- kc, ok := raw.(*key.ResponseKey)
- if !exists || !ok {
- stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.api_key_not_registered", nil, 1)
- JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
- return
- }
ctx, cancel := context.WithTimeout(context.Background(), timeOut)
defer cancel()
@@ -80,7 +73,7 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
}
if res.StatusCode == http.StatusOK && !isStreaming {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_azure_chat_completion_handler.latency", dur, nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -110,12 +103,12 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
logError(log, "error when estimating azure openai cost", prod, cid, err)
}
- micros := int64(cost * 1000000)
- err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.record_key_spend_error", nil, 1)
- logError(log, "error when recording azure openai spend", prod, cid, err)
- }
+ // micros := int64(cost * 1000000)
+ // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.record_key_spend_error", nil, 1)
+ // logError(log, "error when recording azure openai spend", prod, cid, err)
+ // }
}
c.Set("costInUsd", cost)
@@ -127,7 +120,7 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
}
if res.StatusCode != http.StatusOK {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_azure_chat_completion_handler.error_latency", dur, nil, 1)
stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.error_response", nil, 1)
@@ -144,8 +137,8 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
}
buffer := bufio.NewReader(res.Body)
- var totalCost float64 = 0
- var totalTokens int = 0
+ // var totalCost float64 = 0
+ // var totalTokens int = 0
content := ""
model := ""
@@ -154,24 +147,26 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
c.Set("model", model)
}
- tks, cost, err := aoe.EstimateChatCompletionStreamCostWithTokenCounts(model, content)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1)
- logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err)
- }
+ c.Set("content", content)
- estimatedPromptTokenCounts := c.GetInt("promptTokenCount")
- promptCost, err := aoe.EstimatePromptCost(model, estimatedPromptTokenCounts)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1)
- logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err)
- }
+ // tks, cost, err := aoe.EstimateChatCompletionStreamCostWithTokenCounts(model, content)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1)
+ // logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err)
+ // }
+
+ // estimatedPromptTokenCounts := c.GetInt("promptTokenCount")
+ // promptCost, err := aoe.EstimatePromptCost(model, estimatedPromptTokenCounts)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1)
+ // logError(log, "error when estimating azure openai chat completion stream cost with token counts", prod, cid, err)
+ // }
- totalCost = cost + promptCost
- totalTokens += tks
+ // totalCost = cost + promptCost
+ // totalTokens += tks
- c.Set("costInUsd", totalCost)
- c.Set("completionTokenCount", totalTokens)
+ // c.Set("costInUsd", totalCost)
+ // c.Set("completionTokenCount", totalTokens)
}()
stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.streaming_requests", nil, 1)
@@ -183,6 +178,13 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
return false
}
+ if errors.Is(err, context.DeadlineExceeded) {
+ stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.context_deadline_exceeded_error", nil, 1)
+ logError(log, "context deadline exceeded when reading bytes from azure openai chat completion response", prod, cid, err)
+
+ return false
+ }
+
stats.Incr("bricksllm.proxy.get_azure_chat_completion_handler.read_bytes_error", nil, 1)
logError(log, "error when reading bytes from azure openai chat completion response", prod, cid, err)
@@ -236,6 +238,6 @@ func getAzureChatCompletionHandler(r recorder, prod, private bool, psm ProviderS
return true
})
- stats.Timing("bricksllm.proxy.get_azure_chat_completion_handler.streaming_latency", time.Now().Sub(start), nil, 1)
+ stats.Timing("bricksllm.proxy.get_azure_chat_completion_handler.streaming_latency", time.Since(start), nil, 1)
}
}
diff --git a/internal/server/web/proxy/azure_embedding.go b/internal/server/web/proxy/azure_embedding.go
index c3c2830..ba963e5 100644
--- a/internal/server/web/proxy/azure_embedding.go
+++ b/internal/server/web/proxy/azure_embedding.go
@@ -7,14 +7,13 @@ import (
"net/http"
"time"
- "github.com/bricks-cloud/bricksllm/internal/key"
"github.com/bricks-cloud/bricksllm/internal/stats"
"github.com/gin-gonic/gin"
goopenai "github.com/sashabaranov/go-openai"
"go.uber.org/zap"
)
-func getAzureEmbeddingsHandler(r recorder, prod, private bool, psm ProviderSettingsManager, client http.Client, kms keyMemStorage, log *zap.Logger, aoe azureEstimator, timeOut time.Duration) gin.HandlerFunc {
+func getAzureEmbeddingsHandler(prod, private bool, client http.Client, log *zap.Logger, aoe azureEstimator, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.requests", nil, 1)
if c == nil || c.Request == nil {
@@ -23,13 +22,13 @@ func getAzureEmbeddingsHandler(r recorder, prod, private bool, psm ProviderSetti
}
cid := c.GetString(correlationId)
- raw, exists := c.Get("key")
- kc, ok := raw.(*key.ResponseKey)
- if !exists || !ok {
- stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.api_key_not_registered", nil, 1)
- JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
- return
- }
+ // raw, exists := c.Get("key")
+ // kc, ok := raw.(*key.ResponseKey)
+ // if !exists || !ok {
+ // stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.api_key_not_registered", nil, 1)
+ // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
+ // return
+ // }
ctx, cancel := context.WithTimeout(context.Background(), timeOut)
defer cancel()
@@ -55,7 +54,7 @@ func getAzureEmbeddingsHandler(r recorder, prod, private bool, psm ProviderSetti
}
defer res.Body.Close()
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_azure_embeddings_handler.latency", dur, nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -111,12 +110,12 @@ func getAzureEmbeddingsHandler(r recorder, prod, private bool, psm ProviderSetti
logError(log, "error when estimating azure openai cost for embedding", prod, cid, err)
}
- micros := int64(cost * 1000000)
- err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.record_key_spend_error", nil, 1)
- logError(log, "error when recording azure openai spend for embedding", prod, cid, err)
- }
+ // micros := int64(cost * 1000000)
+ // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_azure_embeddings_handler.record_key_spend_error", nil, 1)
+ // logError(log, "error when recording azure openai spend for embedding", prod, cid, err)
+ // }
}
}
diff --git a/internal/server/web/proxy/custom_provider.go b/internal/server/web/proxy/custom_provider.go
index 8ebcb6e..ee09c25 100644
--- a/internal/server/web/proxy/custom_provider.go
+++ b/internal/server/web/proxy/custom_provider.go
@@ -18,11 +18,6 @@ import (
"go.uber.org/zap"
)
-func countTokensFromJson(bytes []byte, contentLoc string) (int, error) {
- content := getContentFromJson(bytes, contentLoc)
- return custom.Count(content)
-}
-
func getContentFromJson(bytes []byte, contentLoc string) string {
result := gjson.Get(string(bytes), contentLoc)
content := ""
@@ -51,7 +46,7 @@ type ErrorResponse struct {
Error *Error `json:"error"`
}
-func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, cpm CustomProvidersManager, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
+func getCustomProviderHandler(prod bool, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
tags := []string{
fmt.Sprintf("path:%s", c.FullPath()),
@@ -110,7 +105,7 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c
defer res.Body.Close()
if res.StatusCode == http.StatusOK && !isStreaming {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_custom_provider_handler.latency", dur, tags, 1)
bytes, err := io.ReadAll(res.Body)
@@ -120,18 +115,20 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c
return
}
- tks, err := countTokensFromJson(bytes, rc.ResponseCompletionLocation)
- if err != nil {
- logError(log, "error when counting tokens for custom provider completion response", prod, cid, err)
- }
+ c.Set("response", bytes)
+
+ // tks, err := countTokensFromJson(bytes, rc.ResponseCompletionLocation)
+ // if err != nil {
+ // logError(log, "error when counting tokens for custom provider completion response", prod, cid, err)
+ // }
- c.Set("completionTokenCount", tks)
+ // c.Set("completionTokenCount", tks)
c.Data(res.StatusCode, "application/json", bytes)
return
}
if res.StatusCode != http.StatusOK {
- stats.Timing("bricksllm.proxy.get_custom_provider_handler.error_latency", time.Now().Sub(start), nil, 1)
+ stats.Timing("bricksllm.proxy.get_custom_provider_handler.error_latency", time.Since(start), nil, 1)
stats.Incr("bricksllm.proxy.get_custom_provider_handler.error_response", nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -149,13 +146,15 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c
buffer := bufio.NewReader(res.Body)
aggregated := ""
defer func() {
- tks, err := custom.Count(aggregated)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_custom_provider_handler.count_error", nil, 1)
- logError(log, "error when counting tokens for custom provider streaming response", prod, cid, err)
- }
+ c.Set("content", aggregated)
+
+ // tks, err := custom.Count(aggregated)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_custom_provider_handler.count_error", nil, 1)
+ // logError(log, "error when counting tokens for custom provider streaming response", prod, cid, err)
+ // }
- c.Set("completionTokenCount", tks)
+ // c.Set("completionTokenCount", tks)
}()
stats.Incr("bricksllm.proxy.get_custom_provider_handler.streaming_requests", nil, 1)
@@ -167,6 +166,13 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c
return false
}
+ if errors.Is(err, context.DeadlineExceeded) {
+ stats.Incr("bricksllm.proxy.get_custom_provider_handler.context_deadline_exceeded_error", nil, 1)
+ logError(log, "context deadline exceeded when reading bytes from custom provider response", prod, cid, err)
+
+ return false
+ }
+
stats.Incr("bricksllm.proxy.get_custom_provider_handler.read_bytes_error", nil, 1)
logError(log, "error when reading bytes from custom provider response", prod, cid, err)
@@ -206,6 +212,6 @@ func getCustomProviderHandler(prod, private bool, psm ProviderSettingsManager, c
return true
})
- stats.Timing("bricksllm.proxy.get_custom_provider_handler.streaming_latency", time.Now().Sub(start), nil, 1)
+ stats.Timing("bricksllm.proxy.get_custom_provider_handler.streaming_latency", time.Since(start), nil, 1)
}
}
diff --git a/internal/server/web/proxy/file.go b/internal/server/web/proxy/file.go
index cd2667d..55a1377 100644
--- a/internal/server/web/proxy/file.go
+++ b/internal/server/web/proxy/file.go
@@ -108,7 +108,7 @@ func logDeleteFileResponse(log *zap.Logger, data []byte, prod bool, cid string)
}
}
-func logRetrieveFileContentRequest(log *zap.Logger, data []byte, prod bool, cid, fid string) {
+func logRetrieveFileContentRequest(log *zap.Logger, prod bool, cid, fid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -119,7 +119,7 @@ func logRetrieveFileContentRequest(log *zap.Logger, data []byte, prod bool, cid,
}
}
-func logRetrieveFileContentResponse(log *zap.Logger, data []byte, prod bool, cid string) {
+func logRetrieveFileContentResponse(log *zap.Logger, prod bool, cid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/image.go b/internal/server/web/proxy/image.go
index edba584..7b8461d 100644
--- a/internal/server/web/proxy/image.go
+++ b/internal/server/web/proxy/image.go
@@ -63,7 +63,7 @@ func logEditImageRequest(log *zap.Logger, prompt, model string, n int, size, res
}
}
-func logImageVariationsRequest(log *zap.Logger, model string, n int, size, responseFormat, user string, prod, private bool, cid string) {
+func logImageVariationsRequest(log *zap.Logger, model string, n int, size, responseFormat, user string, prod bool, cid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/message.go b/internal/server/web/proxy/message.go
index 737eed3..3f44113 100644
--- a/internal/server/web/proxy/message.go
+++ b/internal/server/web/proxy/message.go
@@ -100,7 +100,7 @@ func logModifyMessageRequest(log *zap.Logger, data []byte, prod, private bool, c
}
}
-func logListMessagesRequest(log *zap.Logger, data []byte, prod bool, cid, tid string) {
+func logListMessagesRequest(log *zap.Logger, prod bool, cid, tid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/message_file.go b/internal/server/web/proxy/message_file.go
index 23518e6..8e85b7c 100644
--- a/internal/server/web/proxy/message_file.go
+++ b/internal/server/web/proxy/message_file.go
@@ -42,7 +42,7 @@ func logRetrieveMessageFileResponse(log *zap.Logger, data []byte, prod bool, cid
}
}
-func logListMessageFilesRequest(log *zap.Logger, data []byte, prod bool, cid, tid, mid string, params map[string]string) {
+func logListMessageFilesRequest(log *zap.Logger, prod bool, cid, tid, mid string, params map[string]string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go
index 79f7ded..77aaac6 100644
--- a/internal/server/web/proxy/middleware.go
+++ b/internal/server/web/proxy/middleware.go
@@ -13,6 +13,7 @@ import (
"github.com/bricks-cloud/bricksllm/internal/event"
"github.com/bricks-cloud/bricksllm/internal/key"
+ "github.com/bricks-cloud/bricksllm/internal/message"
"github.com/bricks-cloud/bricksllm/internal/provider"
"github.com/bricks-cloud/bricksllm/internal/provider/anthropic"
"github.com/bricks-cloud/bricksllm/internal/route"
@@ -25,16 +26,6 @@ import (
goopenai "github.com/sashabaranov/go-openai"
)
-type rateLimitError interface {
- Error() string
- RateLimit()
-}
-
-type expirationError interface {
- Error() string
- Reason() string
-}
-
type keyMemStorage interface {
GetKey(hash string) *key.ResponseKey
}
@@ -44,6 +35,8 @@ type keyStorage interface {
}
type estimator interface {
+ EstimateTranscriptionCost(secs float64, model string) (float64, error)
+ EstimateSpeechCost(input string, model string) (float64, error)
EstimateChatCompletionPromptCostWithTokenCounts(r *goopenai.ChatCompletionRequest) (int, float64, error)
EstimateEmbeddingsCost(r *goopenai.EmbeddingRequest) (float64, error)
EstimateChatCompletionStreamCostWithTokenCounts(model, content string) (int, float64, error)
@@ -74,8 +67,8 @@ type rateLimitManager interface {
Increment(keyId string, timeUnit key.TimeUnit) error
}
-type encrypter interface {
- Encrypt(secret string) string
+type accessCache interface {
+ GetAccessStatus(key string) bool
}
func JSON(c *gin.Context, code int, message string) {
@@ -137,7 +130,50 @@ func sendAlertsToEmail(client http.Client, message string) error {
return nil
}
-func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManager, a authenticator, prod, private bool, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, ks keyStorage, log *zap.Logger, rlm rateLimitManager, r recorder, prefix string, client http.Client) gin.HandlerFunc {
+type publisher interface {
+ Publish(message.Message)
+}
+
+func getProvider(c *gin.Context) string {
+ existing := c.GetString("provider")
+ if len(existing) != 0 {
+ return existing
+ }
+
+ parts := strings.Split(c.FullPath(), "/")
+
+ spaceRemoved := []string{}
+
+ for _, part := range parts {
+ if len(part) != 0 {
+ spaceRemoved = append(spaceRemoved, part)
+ }
+ }
+
+ if strings.HasPrefix(c.FullPath(), "/api/providers/") {
+ if len(spaceRemoved) >= 3 {
+ return spaceRemoved[2]
+ }
+ }
+
+ if strings.HasPrefix(c.FullPath(), "/api/custom/providers/") {
+ return c.Param("provider")
+ }
+
+ return ""
+}
+
+type responseWriter struct {
+ gin.ResponseWriter
+ body *bytes.Buffer
+}
+
+func (w responseWriter) Write(b []byte) (int, error) {
+ w.body.Write(b)
+ return w.ResponseWriter.Write(b)
+}
+
+func getMiddleware(cpm CustomProvidersManager, rm routeManager, a authenticator, prod, private bool, log *zap.Logger, pub publisher, prefix string, ac accessCache, client http.Client) gin.HandlerFunc {
return func(c *gin.Context) {
if c == nil || c.Request == nil {
JSON(c, http.StatusInternalServerError, "[BricksLLM] request is empty")
@@ -150,21 +186,22 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ blw := &responseWriter{body: bytes.NewBufferString(""), ResponseWriter: c.Writer}
+ c.Writer = blw
+
cid := util.NewUuid()
c.Set(correlationId, cid)
start := time.Now()
- selectedProvider := "openai"
+ enrichedEvent := &event.EventWithRequestAndContent{}
+ requestBytes := []byte(`{}`)
+ responseBytes := []byte(`{}`)
+ userId := ""
customId := c.Request.Header.Get("X-CUSTOM-EVENT-ID")
defer func() {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
latency := int(dur.Milliseconds())
- raw, exists := c.Get("key")
- var kc *key.ResponseKey
- if exists {
- kc = raw.(*key.ResponseKey)
- }
if !prod {
log.Sugar().Infof("%s | %d | %s | %s | %dms", prefix, c.Writer.Status(), c.Request.Method, c.FullPath(), latency)
@@ -173,16 +210,14 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
keyId := ""
tags := []string{}
- if kc != nil {
- keyId = kc.KeyId
- tags = kc.Tags
+ if enrichedEvent.Key != nil {
+ keyId = enrichedEvent.Key.KeyId
+ tags = enrichedEvent.Key.Tags
}
stats.Timing("bricksllm.proxy.get_middleware.proxy_latency_in_ms", dur, nil, 1)
- if len(c.GetString("provider")) != 0 {
- selectedProvider = c.GetString("provider")
- }
+ selectedProvider := getProvider(c)
if prod {
log.Info("response to proxy",
@@ -215,14 +250,26 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
Path: c.FullPath(),
Method: c.Request.Method,
CustomId: customId,
+ Request: requestBytes,
+ Response: responseBytes,
+ UserId: userId,
}
- err := r.RecordEvent(evt)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.record_event_error", nil, 1)
+ enrichedEvent.Event = evt
+ content := c.GetString("content")
+ if len(content) != 0 {
+ enrichedEvent.Content = content
+ }
- logError(log, "error when recording openai event", prod, cid, err)
+ resp, ok := c.Get("response")
+ if ok {
+ enrichedEvent.Response = resp
}
+
+ pub.Publish(message.Message{
+ Type: "event",
+ Data: enrichedEvent,
+ })
}()
if len(c.FullPath()) == 0 {
@@ -233,6 +280,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
kc, settings, err := a.AuthenticateHttpRequest(c.Request)
+ enrichedEvent.Key = kc
_, ok := err.(notAuthorizedError)
if ok {
stats.Incr("bricksllm.proxy.get_middleware.authentication_error", nil, 1)
@@ -276,16 +324,19 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ if kc.ShouldLogRequest {
+ requestBytes = body
+ }
+
if c.Request.Method != http.MethodGet {
c.Request.Body = io.NopCloser(bytes.NewReader(body))
}
- var cost float64 = 0
+ // var cost float64 = 0
if c.FullPath() == "/api/providers/anthropic/v1/complete" {
logCompletionRequest(log, body, prod, private, cid)
- selectedProvider = "anthropic"
cr := &anthropic.CompletionRequest{}
err = json.Unmarshal(body, cr)
if err != nil {
@@ -293,19 +344,25 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
- tks := ae.Count(cr.Prompt)
- tks += anthropicPromptMagicNum
- c.Set("promptTokenCount", tks)
-
- model := cr.Model
- cost, err = ae.EstimatePromptCost(model, tks)
- if err != nil {
- logError(log, "error when estimating anthropic completion prompt cost", prod, cid, err)
+ if cr.Metadata != nil {
+ userId = cr.Metadata.UserId
}
+ enrichedEvent.Request = cr
+
+ // tks := ae.Count(cr.Prompt)
+ // tks += anthropicPromptMagicNum
+ // c.Set("promptTokenCount", tks)
+
+ // model := cr.Model
+ // cost, err = ae.EstimatePromptCost(model, tks)
+ // if err != nil {
+ // logError(log, "error when estimating anthropic completion prompt cost", prod, cid, err)
+ // }
+
if cr.Stream {
c.Set("stream", cr.Stream)
- c.Set("estimatedPromptCostInUsd", cost)
+ // c.Set("estimatedPromptCostInUsd", cost)
}
if len(cr.Model) != 0 {
@@ -317,6 +374,8 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
providerName := c.Param("provider")
rc := cpm.GetRouteConfigFromMem(providerName, c.Param("wildcard"))
+ enrichedEvent.RouteConfig = rc
+
cp := cpm.GetCustomProviderFromMem(providerName)
if cp == nil {
stats.Incr("bricksllm.proxy.get_middleware.provider_not_found", nil, 1)
@@ -332,17 +391,22 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
- selectedProvider = cp.Provider
-
c.Set("provider", cp)
c.Set("route_config", rc)
- tks, err := countTokensFromJson(body, rc.RequestPromptLocation)
- if err != nil {
- logError(log, "error when counting tokens for custom provider request", prod, cid, err)
+ enrichedEvent.Request = body
+
+ customResponse, ok := c.Get("response")
+ if ok {
+ enrichedEvent.Response = customResponse
}
- c.Set("promptTokenCount", tks)
+ // tks, err := countTokensFromJson(body, rc.RequestPromptLocation)
+ // if err != nil {
+ // logError(log, "error when counting tokens for custom provider request", prod, cid, err)
+ // }
+
+ // c.Set("promptTokenCount", tks)
result := gjson.Get(string(body), rc.StreamLocation)
@@ -377,6 +441,8 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ userId = er.User
+
if rc.CacheConfig != nil && rc.CacheConfig.Enabled {
c.Set("cache_key", route.ComputeCacheKeyForEmbeddingsRequest(r, er))
}
@@ -418,12 +484,16 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
if !rc.ShouldRunEmbeddings() {
ccr := &goopenai.ChatCompletionRequest{}
+
err = json.Unmarshal(body, ccr)
if err != nil {
logError(log, "error when unmarshalling route chat completion request", prod, cid, err)
return
}
+ userId = ccr.User
+ enrichedEvent.Request = ccr
+
logRequest(log, prod, private, cid, ccr)
if ccr.Stream {
@@ -472,8 +542,6 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/azure/openai/deployments/:deployment_id/chat/completions" {
- selectedProvider = "azure"
-
ccr := &goopenai.ChatCompletionRequest{}
err = json.Unmarshal(body, ccr)
if err != nil {
@@ -481,23 +549,25 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ userId = ccr.User
+
+ enrichedEvent.Request = ccr
+
logRequest(log, prod, private, cid, ccr)
- tks, err := e.EstimateChatCompletionPromptTokenCounts("gpt-3.5-turbo", ccr)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.estimate_chat_completion_prompt_token_counts_error", nil, 1)
- logError(log, "error when estimating prompt cost", prod, cid, err)
- }
+ // tks, err := e.EstimateChatCompletionPromptTokenCounts("gpt-3.5-turbo", ccr)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_middleware.estimate_chat_completion_prompt_token_counts_error", nil, 1)
+ // logError(log, "error when estimating prompt cost", prod, cid, err)
+ // }
if ccr.Stream {
c.Set("stream", true)
- c.Set("promptTokenCount", tks)
+ // c.Set("promptTokenCount", tks)
}
}
if c.FullPath() == "/api/providers/azure/openai/deployments/:deployment_id/embeddings" {
- selectedProvider = "azure"
-
er := &goopenai.EmbeddingRequest{}
err = json.Unmarshal(body, er)
if err != nil {
@@ -505,16 +575,18 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ userId = er.User
+
c.Set("model", "ada")
c.Set("encoding_format", string(er.EncodingFormat))
logEmbeddingRequest(log, prod, private, cid, er)
- cost, err = aoe.EstimateEmbeddingsCost(er)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.estimate_azure_openai_embeddings_cost_error", nil, 1)
- logError(log, "error when estimating azure openai embeddings cost", prod, cid, err)
- }
+ // cost, err = aoe.EstimateEmbeddingsCost(er)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_middleware.estimate_azure_openai_embeddings_cost_error", nil, 1)
+ // logError(log, "error when estimating azure openai embeddings cost", prod, cid, err)
+ // }
}
if c.FullPath() == "/api/providers/openai/v1/chat/completions" {
@@ -525,21 +597,25 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
+ userId = ccr.User
+
+ enrichedEvent.Request = ccr
+
c.Set("model", ccr.Model)
logRequest(log, prod, private, cid, ccr)
- tks, cost, err := e.EstimateChatCompletionPromptCostWithTokenCounts(ccr)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.estimate_chat_completion_prompt_cost_with_token_counts_error", nil, 1)
+ // tks, cost, err := e.EstimateChatCompletionPromptCostWithTokenCounts(ccr)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_middleware.estimate_chat_completion_prompt_cost_with_token_counts_error", nil, 1)
- logError(log, "error when estimating prompt cost", prod, cid, err)
- }
+ // logError(log, "error when estimating prompt cost", prod, cid, err)
+ // }
if ccr.Stream {
c.Set("stream", true)
- c.Set("estimatedPromptCostInUsd", cost)
- c.Set("promptTokenCount", tks)
+ // c.Set("estimatedPromptCostInUsd", cost)
+ // c.Set("promptTokenCount", tks)
}
}
@@ -551,16 +627,18 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
return
}
- c.Set("model", er.Model.String())
+ userId = er.User
+
+ c.Set("model", string(er.Model))
c.Set("encoding_format", string(er.EncodingFormat))
logEmbeddingRequest(log, prod, private, cid, er)
- cost, err = e.EstimateEmbeddingsCost(er)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.estimate_embeddings_cost_error", nil, 1)
- logError(log, "error when estimating embeddings cost", prod, cid, err)
- }
+ // cost, err = e.EstimateEmbeddingsCost(er)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_middleware.estimate_embeddings_cost_error", nil, 1)
+ // logError(log, "error when estimating embeddings cost", prod, cid, err)
+ // }
}
if c.FullPath() == "/api/providers/openai/v1/images/generations" && c.Request.Method == http.MethodPost {
@@ -586,6 +664,9 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
model := c.PostForm("model")
size := c.PostForm("size")
user := c.PostForm("user")
+
+ userId = user
+
responseFormat := c.PostForm("response_format")
n, _ := strconv.Atoi(c.PostForm("n"))
@@ -602,6 +683,9 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
model := c.PostForm("model")
size := c.PostForm("size")
user := c.PostForm("user")
+
+ userId = user
+
responseFormat := c.PostForm("response_format")
n, _ := strconv.Atoi(c.PostForm("n"))
@@ -611,18 +695,20 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
c.Set("model", "dall-e-2")
}
- logImageVariationsRequest(log, model, n, size, responseFormat, user, prod, private, cid)
+ logImageVariationsRequest(log, model, n, size, responseFormat, user, prod, cid)
}
if c.FullPath() == "/api/providers/openai/v1/audio/speech" && c.Request.Method == http.MethodPost {
- sr := &SpeechRequest{}
+ sr := &goopenai.CreateSpeechRequest{}
err := json.Unmarshal(body, sr)
if err != nil {
logError(log, "error when unmarshalling create speech request", prod, cid, err)
return
}
- c.Set("model", sr.Model)
+ enrichedEvent.Request = sr
+
+ c.Set("model", string(sr.Model))
logCreateSpeechRequest(log, sr, prod, private, cid)
}
@@ -697,7 +783,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodGet {
- logRetrieveAssistantRequest(log, body, prod, cid, aid)
+ logRetrieveAssistantRequest(log, prod, cid, aid)
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodPost {
@@ -705,7 +791,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id" && c.Request.Method == http.MethodDelete {
- logDeleteAssistantRequest(log, body, prod, cid, aid)
+ logDeleteAssistantRequest(log, prod, cid, aid)
}
if c.FullPath() == "/api/providers/openai/v1/assistants" && c.Request.Method == http.MethodGet {
@@ -721,7 +807,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files/:file_id" && c.Request.Method == http.MethodDelete {
- logRetrieveAssistantFileRequest(log, prod, cid, fid, aid)
+ logDeleteAssistantFileRequest(log, prod, cid, fid, aid)
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files" && c.Request.Method == http.MethodGet {
@@ -733,7 +819,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodGet {
- logCreateThreadRequest(log, body, prod, private, cid)
+ logRetrieveThreadRequest(log, prod, cid, tid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id" && c.Request.Method == http.MethodPost {
@@ -757,7 +843,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages" && c.Request.Method == http.MethodGet {
- logListMessagesRequest(log, body, prod, cid, aid)
+ logListMessagesRequest(log, prod, cid, aid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files/:file_id" && c.Request.Method == http.MethodGet {
@@ -765,15 +851,15 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files" && c.Request.Method == http.MethodGet {
- logListAssistantFilesRequest(log, prod, cid, aid, qm)
+ logListMessageFilesRequest(log, prod, cid, tid, mid, qm)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodPost {
- logCreateRunRequest(log, body, prod, cid)
+ logCreateRunRequest(log, body, prod, private, cid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id" && c.Request.Method == http.MethodGet {
- logRetrieveRunRequest(log, body, prod, cid, tid, rid)
+ logRetrieveRunRequest(log, prod, cid, tid, rid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id" && c.Request.Method == http.MethodPost {
@@ -781,7 +867,7 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs" && c.Request.Method == http.MethodGet {
- logListRunsRequest(log, body, prod, cid, tid, qm)
+ logListRunsRequest(log, prod, cid, tid, qm)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs" && c.Request.Method == http.MethodPost {
@@ -789,19 +875,19 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/cancel" && c.Request.Method == http.MethodPost {
- logCancelARunRequest(log, body, prod, cid, tid, rid)
+ logCancelARunRequest(log, prod, cid, tid, rid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/runs" && c.Request.Method == http.MethodPost {
- logCreateThreadAndRunRequest(log, body, prod, private, cid, tid, rid)
+ logCreateThreadAndRunRequest(log, body, prod, private, cid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps/:step_id" && c.Request.Method == http.MethodGet {
- logRetrieveRunStepRequest(log, body, prod, cid, tid, rid, sid)
+ logRetrieveRunStepRequest(log, prod, cid, tid, rid, sid)
}
if c.FullPath() == "/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps" && c.Request.Method == http.MethodGet {
- logListRunStepsRequest(log, body, prod, cid, tid, rid, qm)
+ logListRunStepsRequest(log, prod, cid, tid, rid, qm)
}
if c.FullPath() == "/api/providers/openai/v1/moderations" && c.Request.Method == http.MethodPost {
@@ -813,11 +899,11 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/models/:model" && c.Request.Method == http.MethodGet {
- logRetrieveModelRequest(log, body, prod, cid, md)
+ logRetrieveModelRequest(log, prod, cid, md)
}
if c.FullPath() == "/api/providers/openai/v1/models/:model" && c.Request.Method == http.MethodDelete {
- logDeleteModelRequest(log, body, prod, cid, md)
+ logDeleteModelRequest(log, prod, cid, md)
}
if c.FullPath() == "/api/providers/openai/v1/files" && c.Request.Method == http.MethodGet {
@@ -838,52 +924,21 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage
}
if c.FullPath() == "/api/providers/openai/v1/files/:file_id/content" && c.Request.Method == http.MethodGet {
- logRetrieveFileContentRequest(log, body, prod, cid, fid)
+ logRetrieveFileContentRequest(log, prod, cid, fid)
}
- err = v.Validate(kc, cost)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.validation_error", nil, 1)
-
- if _, ok := err.(expirationError); ok {
- stats.Incr("bricksllm.proxy.get_middleware.key_expired", nil, 1)
-
- truePtr := true
- _, err = ks.UpdateKey(kc.KeyId, &key.UpdateKey{
- Revoked: &truePtr,
- RevokedReason: "Key has expired or exceeded set spend limit",
- })
-
- if err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.update_key_error", nil, 1)
- log.Sugar().Debugf("error when updating revoking the api key %s: %v", kc.KeyId, err)
- }
-
- JSON(c, http.StatusUnauthorized, "[BricksLLM] key has expired")
- c.Abort()
- return
- }
-
- if _, ok := err.(rateLimitError); ok {
- stats.Incr("bricksllm.proxy.get_middleware.rate_limited", nil, 1)
- JSON(c, http.StatusTooManyRequests, "[BricksLLM] too many requests")
- c.Abort()
- return
- }
-
- logError(log, "error when validating api key", prod, cid, err)
+ if ac.GetAccessStatus(kc.KeyId) {
+ stats.Incr("bricksllm.proxy.get_middleware.rate_limited", nil, 1)
+ JSON(c, http.StatusTooManyRequests, "[BricksLLM] too many requests")
+ c.Abort()
return
}
- if len(kc.RateLimitUnit) != 0 {
- if err := rlm.Increment(kc.KeyId, kc.RateLimitUnit); err != nil {
- stats.Incr("bricksllm.proxy.get_middleware.rate_limit_increment_error", nil, 1)
+ c.Next()
- logError(log, "error when incrementing rate limit counter", prod, cid, err)
- }
+ if kc.ShouldLogResponse {
+ responseBytes = blw.body.Bytes()
}
-
- c.Next()
}
}
@@ -924,21 +979,3 @@ func containsPath(arr []key.PathConfig, path, method string) bool {
return false
}
-
-func getAuthTokenFromHeader(c *gin.Context) string {
- if strings.HasPrefix(c.FullPath(), "/api/providers/anthropic") {
- return c.GetHeader("x-api-key")
- }
-
- if strings.HasPrefix(c.FullPath(), "/api/providers/azure") {
- return c.GetHeader("api-key")
- }
-
- split := strings.Split(c.Request.Header.Get("Authorization"), "Bearer ")
- if len(split) < 2 || len(split[1]) == 0 {
- return ""
- }
-
- return split[1]
-
-}
diff --git a/internal/server/web/proxy/models.go b/internal/server/web/proxy/models.go
index 640a068..ba6db48 100644
--- a/internal/server/web/proxy/models.go
+++ b/internal/server/web/proxy/models.go
@@ -26,7 +26,7 @@ func logListModelsResponse(log *zap.Logger, data []byte, prod bool, cid string)
}
}
-func logRetrieveModelRequest(log *zap.Logger, data []byte, prod bool, cid, model string) {
+func logRetrieveModelRequest(log *zap.Logger, prod bool, cid, model string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -58,7 +58,7 @@ func logRetrieveModelResponse(log *zap.Logger, data []byte, prod bool, cid strin
}
}
-func logDeleteModelRequest(log *zap.Logger, data []byte, prod bool, cid, model string) {
+func logDeleteModelRequest(log *zap.Logger, prod bool, cid, model string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/server/web/proxy/proxy.go b/internal/server/web/proxy/proxy.go
index b2321b7..53bd9cc 100644
--- a/internal/server/web/proxy/proxy.go
+++ b/internal/server/web/proxy/proxy.go
@@ -39,7 +39,7 @@ type ProxyServer struct {
}
type recorder interface {
- RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error
+ // RecordKeySpend(keyId string, micros int64, costLimitUnit key.TimeUnit) error
RecordEvent(e *event.Event) error
}
@@ -55,12 +55,32 @@ type CustomProvidersManager interface {
GetCustomProviderFromMem(name string) *custom.Provider
}
-func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, kms keyMemStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, rlm rateLimitManager, timeOut time.Duration) (*ProxyServer, error) {
+func CorsMiddleware() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ a_or_b := func(a, b string) string {
+ if a != "" {
+ return a
+ } else {
+ return b
+ }
+ }
+ c.Header("Access-Control-Allow-Origin", a_or_b(c.GetHeader("Origin"), "*"))
+ if c.Request.Method == "OPTIONS" {
+ c.Header("Access-Control-Allow-Methods", a_or_b(c.GetHeader("Access-Control-Request-Method"), "*"))
+ c.Header("Access-Control-Allow-Headers", a_or_b(c.GetHeader("Access-Control-Request-Headers"), "*"))
+ c.Header("Access-Control-Max-Age", "3600")
+ c.AbortWithStatus(204)
+ }
+ }
+}
+
+func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, kms keyMemStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, pub publisher, rlm rateLimitManager, timeOut time.Duration, ac accessCache) (*ProxyServer, error) {
router := gin.New()
prod := mode == "production"
private := privacyMode == "strict"
- router.Use(getMiddleware(kms, cpm, rm, a, prod, private, e, ae, aoe, v, ks, log, rlm, r, "proxy", http.Client{}))
+ router.Use(CorsMiddleware())
+ router.Use(getMiddleware(cpm, rm, a, prod, private, log, pub, "proxy", ac, http.Client{}))
client := http.Client{}
@@ -68,88 +88,88 @@ func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyMan
router.POST("/api/health", getGetHealthCheckHandler())
// audios
- router.POST("/api/providers/openai/v1/audio/speech", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/audio/transcriptions", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/audio/translations", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/audio/speech", getSpeechHandler(prod, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/audio/transcriptions", getTranscriptionsHandler(prod, client, log, timeOut, e))
+ router.POST("/api/providers/openai/v1/audio/translations", getTranslationsHandler(prod, client, log, timeOut, e))
// completions
- router.POST("/api/providers/openai/v1/chat/completions", getChatCompletionHandler(r, prod, private, psm, client, kms, log, e, timeOut))
+ router.POST("/api/providers/openai/v1/chat/completions", getChatCompletionHandler(prod, private, client, log, e, timeOut))
// embeddings
- router.POST("/api/providers/openai/v1/embeddings", getEmbeddingHandler(r, prod, private, psm, client, kms, log, e, timeOut))
+ router.POST("/api/providers/openai/v1/embeddings", getEmbeddingHandler(prod, private, client, log, e, timeOut))
// moderations
- router.POST("/api/providers/openai/v1/moderations", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/moderations", getPassThroughHandler(prod, private, client, log, timeOut))
// models
- router.GET("/api/providers/openai/v1/models", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/models/:model", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.DELETE("/api/providers/openai/v1/models/:model", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/models", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/models/:model", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.DELETE("/api/providers/openai/v1/models/:model", getPassThroughHandler(prod, private, client, log, timeOut))
// assistants
- router.POST("/api/providers/openai/v1/assistants", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.DELETE("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/assistants", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/assistants", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.DELETE("/api/providers/openai/v1/assistants/:assistant_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/assistants", getPassThroughHandler(prod, private, client, log, timeOut))
// assistant files
- router.POST("/api/providers/openai/v1/assistants/:assistant_id/files", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/assistants/:assistant_id/files/:file_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.DELETE("/api/providers/openai/v1/assistants/:assistant_id/files/:file_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/assistants/:assistant_id/files", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/assistants/:assistant_id/files", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/assistants/:assistant_id/files/:file_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.DELETE("/api/providers/openai/v1/assistants/:assistant_id/files/:file_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/assistants/:assistant_id/files", getPassThroughHandler(prod, private, client, log, timeOut))
// threads
- router.POST("/api/providers/openai/v1/threads", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.DELETE("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.DELETE("/api/providers/openai/v1/threads/:thread_id", getPassThroughHandler(prod, private, client, log, timeOut))
// messages
- router.POST("/api/providers/openai/v1/threads/:thread_id/messages", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/threads/:thread_id/messages/:message_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/messages", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id/messages", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id/messages/:message_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/messages", getPassThroughHandler(prod, private, client, log, timeOut))
// message files
- router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files/:file_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files/:file_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/messages/:message_id/files", getPassThroughHandler(prod, private, client, log, timeOut))
// runs
- router.POST("/api/providers/openai/v1/threads/:thread_id/runs", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/runs", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/cancel", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/threads/runs", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps/:step_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id/runs", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/runs", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/submit_tool_outputs", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/cancel", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/threads/runs", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps/:step_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/threads/:thread_id/runs/:run_id/steps", getPassThroughHandler(prod, private, client, log, timeOut))
// files
- router.GET("/api/providers/openai/v1/files", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/files", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.DELETE("/api/providers/openai/v1/files/:file_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/files/:file_id", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.GET("/api/providers/openai/v1/files/:file_id/content", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/files", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/files", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.DELETE("/api/providers/openai/v1/files/:file_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/files/:file_id", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.GET("/api/providers/openai/v1/files/:file_id/content", getPassThroughHandler(prod, private, client, log, timeOut))
// images
- router.POST("/api/providers/openai/v1/images/generations", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/images/edits", getPassThroughHandler(r, prod, private, client, log, timeOut))
- router.POST("/api/providers/openai/v1/images/variations", getPassThroughHandler(r, prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/images/generations", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/images/edits", getPassThroughHandler(prod, private, client, log, timeOut))
+ router.POST("/api/providers/openai/v1/images/variations", getPassThroughHandler(prod, private, client, log, timeOut))
// azure
- router.POST("/api/providers/azure/openai/deployments/:deployment_id/chat/completions", getAzureChatCompletionHandler(r, prod, private, psm, client, kms, log, aoe, timeOut))
- router.POST("/api/providers/azure/openai/deployments/:deployment_id/embeddings", getAzureEmbeddingsHandler(r, prod, private, psm, client, kms, log, aoe, timeOut))
+ router.POST("/api/providers/azure/openai/deployments/:deployment_id/chat/completions", getAzureChatCompletionHandler(prod, private, client, log, aoe, timeOut))
+ router.POST("/api/providers/azure/openai/deployments/:deployment_id/embeddings", getAzureEmbeddingsHandler(prod, private, client, log, aoe, timeOut))
// anthropic
- router.POST("/api/providers/anthropic/v1/complete", getCompletionHandler(r, prod, private, client, kms, log, ae, timeOut))
+ router.POST("/api/providers/anthropic/v1/complete", getCompletionHandler(prod, private, client, log, timeOut))
// custom provider
- router.POST("/api/custom/providers/:provider/*wildcard", getCustomProviderHandler(prod, private, psm, cpm, client, log, timeOut))
+ router.POST("/api/custom/providers/:provider/*wildcard", getCustomProviderHandler(prod, client, log, timeOut))
// custom route
- router.POST("/api/routes/*route", getRouteHandler(prod, private, rm, c, aoe, e, r, client, log, timeOut))
+ router.POST("/api/routes/*route", getRouteHandler(prod, c, aoe, e, client, log))
srv := &http.Server{
Addr: ":8002",
@@ -189,9 +209,16 @@ type TranslationForm struct {
File *multipart.FileHeader `form:"file" binding:"required"`
}
-func writeFieldToBuffer(fields []string, c *gin.Context, writer *multipart.Writer) error {
+func writeFieldToBuffer(fields []string, c *gin.Context, writer *multipart.Writer, overWrites map[string]string) error {
for _, field := range fields {
val := c.PostForm(field)
+
+ if len(overWrites) != 0 {
+ if ow := overWrites[field]; len(ow) != 0 {
+ val = ow
+ }
+ }
+
if len(val) != 0 {
err := writer.WriteField(field, val)
if err != nil {
@@ -203,7 +230,7 @@ func writeFieldToBuffer(fields []string, c *gin.Context, writer *multipart.Write
return nil
}
-func getPassThroughHandler(r recorder, prod, private bool, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
+func getPassThroughHandler(prod, private bool, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
tags := []string{
fmt.Sprintf("path:%s", c.FullPath()),
@@ -296,7 +323,7 @@ func getPassThroughHandler(r recorder, prod, private bool, client http.Client, l
"size",
"response_format",
"user",
- }, c, writer)
+ }, c, writer, nil)
if err != nil {
stats.Incr("bricksllm.proxy.get_pass_through_handler.write_field_to_buffer_error", tags, 1)
logError(log, "error when writing field to buffer", prod, cid, err)
@@ -376,7 +403,7 @@ func getPassThroughHandler(r recorder, prod, private bool, client http.Client, l
"size",
"response_format",
"user",
- }, c, writer)
+ }, c, writer, nil)
if err != nil {
stats.Incr("bricksllm.proxy.get_pass_through_handler.write_field_to_buffer_error", tags, 1)
logError(log, "error when writing field to buffer", prod, cid, err)
@@ -420,132 +447,25 @@ func getPassThroughHandler(r recorder, prod, private bool, client http.Client, l
req.Body = io.NopCloser(&b)
}
- if c.FullPath() == "/api/providers/openai/v1/audio/transcriptions" && c.Request.Method == http.MethodPost {
- var b bytes.Buffer
- writer := multipart.NewWriter(&b)
-
- err := writeFieldToBuffer([]string{
- "model",
- "language",
- "prompt",
- "response_format",
- "temperature",
- }, c, writer)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.write_field_to_buffer_error", tags, 1)
- logError(log, "error when writing field to buffer", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot write field to buffer")
- return
- }
-
- var form TransriptionForm
- c.ShouldBind(&form)
-
- if form.File != nil {
- fieldWriter, err := writer.CreateFormFile("file", form.File.Filename)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.create_transcription_file_error", tags, 1)
- logError(log, "error when creating transcription file", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot create transcription file")
- return
- }
-
- opened, err := form.File.Open()
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.open_transcription_file_error", tags, 1)
- logError(log, "error when openning transcription file", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot open transcription file")
- return
- }
-
- _, err = io.Copy(fieldWriter, opened)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.copy_transcription_file_error", tags, 1)
- logError(log, "error when copying transcription file", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot copy transcription file")
- return
- }
- }
-
- req.Header.Set("Content-Type", writer.FormDataContentType())
-
- writer.Close()
-
- req.Body = io.NopCloser(&b)
- }
-
- if c.FullPath() == "/api/providers/openai/v1/audio/translations" && c.Request.Method == http.MethodPost {
- var b bytes.Buffer
- writer := multipart.NewWriter(&b)
-
- err := writeFieldToBuffer([]string{
- "model",
- "prompt",
- "response_format",
- "temperature",
- }, c, writer)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.write_field_to_buffer_error", tags, 1)
- logError(log, "error when writing field to buffer", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot write field to buffer")
- return
- }
-
- var form TranslationForm
- c.ShouldBind(&form)
-
- if form.File != nil {
- fieldWriter, err := writer.CreateFormFile("file", form.File.Filename)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.create_translation_file_error", tags, 1)
- logError(log, "error when creating translation file", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot create translation file")
- return
- }
-
- opened, err := form.File.Open()
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.open_translation_file_error", tags, 1)
- logError(log, "error when openning translation file", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot open translation file")
- return
- }
-
- _, err = io.Copy(fieldWriter, opened)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_pass_through_handler.copy_translation_file_error", tags, 1)
- logError(log, "error when copying translation file", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] cannot copy translation file")
- return
- }
- }
-
- req.Header.Set("Content-Type", writer.FormDataContentType())
-
- writer.Close()
-
- req.Body = io.NopCloser(&b)
- }
-
start := time.Now()
res, err := client.Do(req)
if err != nil {
stats.Incr("bricksllm.proxy.get_pass_through_handler.http_client_error", tags, 1)
- logError(log, "error when sending embedding request to openai", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to send embedding request to openai")
+ logError(log, "error when sending pass through request to openai", prod, cid, err)
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to send pass through request to openai")
return
}
defer res.Body.Close()
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_pass_through_handler.latency", dur, tags, 1)
bytes, err := io.ReadAll(res.Body)
if err != nil {
logError(log, "error when reading openai embedding response body", prod, cid, err)
- JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to read openai embedding response body")
+ JSON(c, http.StatusInternalServerError, "[BricksLLM] failed to read openai pass through response body")
return
}
@@ -570,7 +490,7 @@ func getPassThroughHandler(r recorder, prod, private bool, client http.Client, l
}
if c.FullPath() == "/api/providers/openai/v1/assistants" && c.Request.Method == http.MethodGet {
- logListAssistantFilesResponse(log, bytes, prod, cid)
+ logListAssistantsResponse(log, bytes, prod, private, cid)
}
if c.FullPath() == "/api/providers/openai/v1/assistants/:assistant_id/files" && c.Request.Method == http.MethodPost {
@@ -698,7 +618,7 @@ func getPassThroughHandler(r recorder, prod, private bool, client http.Client, l
}
if c.FullPath() == "/api/providers/openai/v1/files/:file_id/content" && c.Request.Method == http.MethodGet {
- logRetrieveFileContentResponse(log, bytes, prod, cid)
+ logRetrieveFileContentResponse(log, prod, cid)
}
if c.FullPath() == "/api/providers/openai/v1/images/generations" && c.Request.Method == http.MethodPost {
@@ -934,7 +854,7 @@ type EmbeddingResponseBase64 struct {
Usage goopenai.Usage `json:"usage"`
}
-func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsManager, client http.Client, kms keyMemStorage, log *zap.Logger, e estimator, timeOut time.Duration) gin.HandlerFunc {
+func getEmbeddingHandler(prod, private bool, client http.Client, log *zap.Logger, e estimator, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_embedding_handler.requests", nil, 1)
if c == nil || c.Request == nil {
@@ -942,13 +862,13 @@ func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsMan
return
}
- raw, exists := c.Get("key")
- kc, ok := raw.(*key.ResponseKey)
- if !exists || !ok {
- stats.Incr("bricksllm.proxy.get_embedding_handler.api_key_not_registered", nil, 1)
- JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
- return
- }
+ // raw, exists := c.Get("key")
+ // kc, ok := raw.(*key.ResponseKey)
+ // if !exists || !ok {
+ // stats.Incr("bricksllm.proxy.get_embedding_handler.api_key_not_registered", nil, 1)
+ // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
+ // return
+ // }
id := c.GetString(correlationId)
@@ -976,7 +896,7 @@ func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsMan
}
defer res.Body.Close()
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_embedding_handler.latency", dur, nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -1032,12 +952,12 @@ func getEmbeddingHandler(r recorder, prod, private bool, psm ProviderSettingsMan
logError(log, "error when estimating openai cost for embedding", prod, id, err)
}
- micros := int64(cost * 1000000)
- err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_embedding_handler.record_key_spend_error", nil, 1)
- logError(log, "error when recording openai spend for embedding", prod, id, err)
- }
+ // micros := int64(cost * 1000000)
+ // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_embedding_handler.record_key_spend_error", nil, 1)
+ // logError(log, "error when recording openai spend for embedding", prod, id, err)
+ // }
}
}
@@ -1072,10 +992,9 @@ var (
eventCompletionPrefix = []byte("event: completion")
eventPingPrefix = []byte("event: ping")
eventErrorPrefix = []byte("event: error")
- errorPrefix = []byte(`data: {"error":`)
)
-func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettingsManager, client http.Client, kms keyMemStorage, log *zap.Logger, e estimator, timeOut time.Duration) gin.HandlerFunc {
+func getChatCompletionHandler(prod, private bool, client http.Client, log *zap.Logger, e estimator, timeOut time.Duration) gin.HandlerFunc {
return func(c *gin.Context) {
stats.Incr("bricksllm.proxy.get_chat_completion_handler.requests", nil, 1)
@@ -1085,13 +1004,13 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
}
cid := c.GetString(correlationId)
- raw, exists := c.Get("key")
- kc, ok := raw.(*key.ResponseKey)
- if !exists || !ok {
- stats.Incr("bricksllm.proxy.get_chat_completion_handler.api_key_not_registered", nil, 1)
- JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
- return
- }
+ // raw, exists := c.Get("key")
+ // kc, ok := raw.(*key.ResponseKey)
+ // if !exists || !ok {
+ // stats.Incr("bricksllm.proxy.get_chat_completion_handler.api_key_not_registered", nil, 1)
+ // JSON(c, http.StatusUnauthorized, "[BricksLLM] api key is not registered")
+ // return
+ // }
ctx, cancel := context.WithTimeout(context.Background(), timeOut)
defer cancel()
@@ -1133,7 +1052,7 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
model := c.GetString("model")
if res.StatusCode == http.StatusOK && !isStreaming {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_chat_completion_handler.latency", dur, nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -1161,12 +1080,12 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
logError(log, "error when estimating openai cost", prod, cid, err)
}
- micros := int64(cost * 1000000)
- err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_chat_completion_handler.record_key_spend_error", nil, 1)
- logError(log, "error when recording openai spend", prod, cid, err)
- }
+ // micros := int64(cost * 1000000)
+ // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_chat_completion_handler.record_key_spend_error", nil, 1)
+ // logError(log, "error when recording openai spend", prod, cid, err)
+ // }
}
c.Set("costInUsd", cost)
@@ -1178,7 +1097,7 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
}
if res.StatusCode != http.StatusOK {
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_chat_completion_handler.error_latency", dur, nil, 1)
stats.Incr("bricksllm.proxy.get_chat_completion_handler.error_response", nil, 1)
@@ -1195,22 +1114,24 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
}
buffer := bufio.NewReader(res.Body)
- var totalCost float64 = 0
- var totalTokens int = 0
+ // var totalCost float64 = 0
+ // var totalTokens int = 0
content := ""
defer func() {
- tks, cost, err := e.EstimateChatCompletionStreamCostWithTokenCounts(model, content)
- if err != nil {
- stats.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1)
- logError(log, "error when estimating chat completion stream cost with token counts", prod, cid, err)
- }
+ c.Set("content", content)
+
+ // tks, cost, err := e.EstimateChatCompletionStreamCostWithTokenCounts(model, content)
+ // if err != nil {
+ // stats.Incr("bricksllm.proxy.get_chat_completion_handler.estimate_chat_completion_cost_and_tokens_error", nil, 1)
+ // logError(log, "error when estimating chat completion stream cost with token counts", prod, cid, err)
+ // }
- estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd")
- totalCost = cost + estimatedPromptCost
- totalTokens += tks
+ // estimatedPromptCost := c.GetFloat64("estimatedPromptCostInUsd")
+ // totalCost = cost + estimatedPromptCost
+ // totalTokens += tks
- c.Set("costInUsd", totalCost)
- c.Set("completionTokenCount", totalTokens)
+ // c.Set("costInUsd", totalCost)
+ // c.Set("completionTokenCount", totalTokens)
}()
stats.Incr("bricksllm.proxy.get_chat_completion_handler.streaming_requests", nil, 1)
@@ -1222,6 +1143,13 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
return false
}
+ if errors.Is(err, context.DeadlineExceeded) {
+ stats.Incr("bricksllm.proxy.get_chat_completion_handler.context_deadline_exceeded_error", nil, 1)
+ logError(log, "context deadline exceeded when reading bytes from openai chat completion response", prod, cid, err)
+
+ return false
+ }
+
stats.Incr("bricksllm.proxy.get_chat_completion_handler.read_bytes_error", nil, 1)
logError(log, "error when reading bytes from openai chat completion response", prod, cid, err)
@@ -1271,7 +1199,7 @@ func getChatCompletionHandler(r recorder, prod, private bool, psm ProviderSettin
return true
})
- stats.Timing("bricksllm.proxy.get_chat_completion_handler.streaming_latency", time.Now().Sub(start), nil, 1)
+ stats.Timing("bricksllm.proxy.get_chat_completion_handler.streaming_latency", time.Since(start), nil, 1)
}
}
@@ -1515,7 +1443,7 @@ func logEmbeddingRequest(log *zap.Logger, prod, private bool, id string, r *goop
if prod {
fields := []zapcore.Field{
zap.String(correlationId, id),
- zap.String("model", r.Model.String()),
+ zap.String("model", string(r.Model)),
zap.String("encoding_format", string(r.EncodingFormat)),
zap.String("user", r.User),
}
diff --git a/internal/server/web/proxy/route.go b/internal/server/web/proxy/route.go
index 13357a9..2b56741 100644
--- a/internal/server/web/proxy/route.go
+++ b/internal/server/web/proxy/route.go
@@ -27,7 +27,7 @@ type cache interface {
GetBytes(key string) ([]byte, error)
}
-func getRouteHandler(prod, private bool, rm routeManager, ca cache, aoe azureEstimator, e estimator, r recorder, client http.Client, log *zap.Logger, timeOut time.Duration) gin.HandlerFunc {
+func getRouteHandler(prod bool, ca cache, aoe azureEstimator, e estimator, client http.Client, log *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
trueStart := time.Now()
@@ -64,7 +64,7 @@ func getRouteHandler(prod, private bool, rm routeManager, ca cache, aoe azureEst
bytes, err := ca.GetBytes(cacheKey)
if err == nil && len(bytes) != 0 {
stats.Incr("bricksllm.proxy.get_route_handeler.success", nil, 1)
- stats.Timing("bricksllm.proxy.get_route_handeler.success_latency", time.Now().Sub(trueStart), nil, 1)
+ stats.Timing("bricksllm.proxy.get_route_handeler.success_latency", time.Since(trueStart), nil, 1)
c.Set("provider", "cached")
c.Data(http.StatusOK, "application/json", bytes)
@@ -115,7 +115,7 @@ func getRouteHandler(prod, private bool, rm routeManager, ca cache, aoe azureEst
res := runRes.Response
defer res.Body.Close()
- dur := time.Now().Sub(start)
+ dur := time.Since(start)
stats.Timing("bricksllm.proxy.get_route_handeler.latency", dur, nil, 1)
bytes, err := io.ReadAll(res.Body)
@@ -144,7 +144,7 @@ func getRouteHandler(prod, private bool, rm routeManager, ca cache, aoe azureEst
}
- err := parseResult(c, ca, kc, rc.ShouldRunEmbeddings(), bytes, e, aoe, r, runRes.Model, runRes.Provider)
+ err := parseResult(c, rc.ShouldRunEmbeddings(), bytes, e, aoe, runRes.Model, runRes.Provider)
if err != nil {
logError(log, "error when parsing run steps result", prod, cid, err)
}
@@ -173,7 +173,7 @@ func getRouteHandler(prod, private bool, rm routeManager, ca cache, aoe azureEst
}
}
-func parseResult(c *gin.Context, ca cache, kc *key.ResponseKey, runEmbeddings bool, bytes []byte, e estimator, aoe azureEstimator, r recorder, model, provider string) error {
+func parseResult(c *gin.Context, runEmbeddings bool, bytes []byte, e estimator, aoe azureEstimator, model, provider string) error {
base64ChatRes := &EmbeddingResponseBase64{}
chatRes := &EmbeddingResponse{}
@@ -231,12 +231,12 @@ func parseResult(c *gin.Context, ca cache, kc *key.ResponseKey, runEmbeddings bo
cost = ecost
}
- micros := int64(cost * 1000000)
+ // micros := int64(cost * 1000000)
- err := r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- return err
- }
+ // err := r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // return err
+ // }
}
if !runEmbeddings {
@@ -262,11 +262,11 @@ func parseResult(c *gin.Context, ca cache, kc *key.ResponseKey, runEmbeddings bo
}
}
- micros := int64(cost * 1000000)
- err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
- if err != nil {
- return err
- }
+ // micros := int64(cost * 1000000)
+ // err = r.RecordKeySpend(kc.KeyId, micros, kc.CostLimitInUsdUnit)
+ // if err != nil {
+ // return err
+ // }
}
return nil
diff --git a/internal/server/web/proxy/run.go b/internal/server/web/proxy/run.go
index de4cfa5..a8558ff 100644
--- a/internal/server/web/proxy/run.go
+++ b/internal/server/web/proxy/run.go
@@ -8,7 +8,7 @@ import (
"go.uber.org/zap/zapcore"
)
-func logCreateRunRequest(log *zap.Logger, data []byte, prod bool, cid string) {
+func logCreateRunRequest(log *zap.Logger, data []byte, prod, private bool, cid string) {
rr := &goopenai.RunRequest{}
err := json.Unmarshal(data, rr)
if err != nil {
@@ -20,12 +20,15 @@ func logCreateRunRequest(log *zap.Logger, data []byte, prod bool, cid string) {
fields := []zapcore.Field{
zap.String(correlationId, cid),
zap.String("assistant_id", rr.AssistantID),
- zap.Stringp("instruction", rr.Instructions),
- zap.Stringp("model", rr.Model),
+ zap.String("model", rr.Model),
zap.Any("tools", rr.Tools),
zap.Any("metadata", rr.Metadata),
}
+ if !private {
+ fields = append(fields, zap.String("instruction", rr.Instructions))
+ }
+
log.Info("openai create run request", fields...)
}
}
@@ -68,7 +71,7 @@ func logRunResponse(log *zap.Logger, data []byte, prod, private bool, cid string
}
}
-func logRetrieveRunRequest(log *zap.Logger, data []byte, prod bool, cid, tid, rid string) {
+func logRetrieveRunRequest(log *zap.Logger, prod bool, cid, tid, rid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -100,7 +103,7 @@ func logModifyRunRequest(log *zap.Logger, data []byte, prod bool, cid, tid, rid
}
}
-func logListRunsRequest(log *zap.Logger, data []byte, prod bool, cid, tid string, params map[string]string) {
+func logListRunsRequest(log *zap.Logger, prod bool, cid, tid string, params map[string]string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -171,7 +174,7 @@ func logSubmitToolOutputsRequest(log *zap.Logger, data []byte, prod bool, cid, t
}
}
-func logCancelARunRequest(log *zap.Logger, data []byte, prod bool, cid, tid, rid string) {
+func logCancelARunRequest(log *zap.Logger, prod bool, cid, tid, rid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -183,7 +186,7 @@ func logCancelARunRequest(log *zap.Logger, data []byte, prod bool, cid, tid, rid
}
}
-func logCreateThreadAndRunRequest(log *zap.Logger, data []byte, prod, private bool, cid, tid, rid string) {
+func logCreateThreadAndRunRequest(log *zap.Logger, data []byte, prod, private bool, cid string) {
r := &goopenai.CreateThreadAndRunRequest{}
err := json.Unmarshal(data, r)
if err != nil {
@@ -196,20 +199,20 @@ func logCreateThreadAndRunRequest(log *zap.Logger, data []byte, prod, private bo
zap.String(correlationId, cid),
zap.String("assistant_id", r.AssistantID),
zap.Any("thread", r.Thread),
- zap.Stringp("model", r.Model),
+ zap.String("model", r.Model),
zap.Any("tools", r.Tools),
zap.Any("metadata", r.Metadata),
}
if !private {
- fields = append(fields, zap.Stringp("instructions", r.Instructions))
+ fields = append(fields, zap.String("instructions", r.Instructions))
}
log.Info("openai create thread and run request", fields...)
}
}
-func logRetrieveRunStepRequest(log *zap.Logger, data []byte, prod bool, cid, tid, rid, sid string) {
+func logRetrieveRunStepRequest(log *zap.Logger, prod bool, cid, tid, rid, sid string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
@@ -254,7 +257,7 @@ func logRetrieveRunStepResponse(log *zap.Logger, data []byte, prod bool, cid str
}
}
-func logListRunStepsRequest(log *zap.Logger, data []byte, prod bool, cid, tid, rid string, params map[string]string) {
+func logListRunStepsRequest(log *zap.Logger, prod bool, cid, tid, rid string, params map[string]string) {
if prod {
fields := []zapcore.Field{
zap.String(correlationId, cid),
diff --git a/internal/storage/postgresql/postgresql.go b/internal/storage/postgresql/postgresql.go
index 8de0710..f54a4a0 100644
--- a/internal/storage/postgresql/postgresql.go
+++ b/internal/storage/postgresql/postgresql.go
@@ -101,7 +101,7 @@ func (s *Store) AlterKeysTable() error {
END IF;
END
$$;
- ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[];
+ ALTER TABLE keys ADD COLUMN IF NOT EXISTS setting_id VARCHAR(255), ADD COLUMN IF NOT EXISTS allowed_paths JSONB, ADD COLUMN IF NOT EXISTS setting_ids VARCHAR(255)[] NOT NULL DEFAULT ARRAY[]::VARCHAR(255)[], ADD COLUMN IF NOT EXISTS should_log_request BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS should_log_response BOOLEAN NOT NULL DEFAULT FALSE, ADD COLUMN IF NOT EXISTS rotation_enabled BOOLEAN NOT NULL DEFAULT FALSE;
`
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt)
@@ -142,7 +142,7 @@ func (s *Store) CreateEventsTable() error {
func (s *Store) AlterEventsTable() error {
alterTableQuery := `
- ALTER TABLE events ADD COLUMN IF NOT EXISTS path VARCHAR(255), ADD COLUMN IF NOT EXISTS method VARCHAR(255), ADD COLUMN IF NOT EXISTS custom_id VARCHAR(255)
+ ALTER TABLE events ADD COLUMN IF NOT EXISTS path VARCHAR(255), ADD COLUMN IF NOT EXISTS method VARCHAR(255), ADD COLUMN IF NOT EXISTS custom_id VARCHAR(255), ADD COLUMN IF NOT EXISTS request JSONB, ADD COLUMN IF NOT EXISTS response JSONB, ADD COLUMN IF NOT EXISTS user_id VARCHAR(255) NOT NULL DEFAULT '';
`
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt)
@@ -211,8 +211,8 @@ func (s *Store) DropKeysTable() error {
func (s *Store) InsertEvent(e *event.Event) error {
query := `
- INSERT INTO events (event_id, created_at, tags, key_id, cost_in_usd, provider, model, status_code, prompt_token_count, completion_token_count, latency_in_ms, path, method, custom_id)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
+ INSERT INTO events (event_id, created_at, tags, key_id, cost_in_usd, provider, model, status_code, prompt_token_count, completion_token_count, latency_in_ms, path, method, custom_id, request, response, user_id)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
`
values := []any{
@@ -230,6 +230,9 @@ func (s *Store) InsertEvent(e *event.Event) error {
e.Path,
e.Method,
e.CustomId,
+ e.Request,
+ e.Response,
+ e.UserId,
}
ctx, cancel := context.WithTimeout(context.Background(), s.wt)
@@ -241,16 +244,62 @@ func (s *Store) InsertEvent(e *event.Event) error {
return nil
}
-func (s *Store) GetEvents(customId string) ([]*event.Event, error) {
+func shouldAddAnd(userId, customId string, keyIds []string) bool { // NOTE(review): appears unused in this diff — GetEvents builds its AND clauses inline; verify a caller exists before keeping
+ num := 0
+
+ if len(userId) == 0 {
+ num++
+ }
+
+ if len(customId) == 0 {
+ num++
+ }
+
+ if len(keyIds) == 0 {
+ num++
+ }
+
+ return num >= 2
+}
+
+func (s *Store) GetEvents(userId string, customId string, keyIds []string, start int64, end int64) ([]*event.Event, error) {
+ if len(customId) == 0 && len(keyIds) == 0 && len(userId) == 0 {
+ return nil, errors.New("none of customId, keyIds and userId is specified")
+ }
+
+ if len(keyIds) != 0 && (start == 0 || end == 0) {
+ return nil, errors.New("keyIds are provided but either start or end is not specified")
+ }
+
query := `
- SELECT * FROM events WHERE $1 = custom_id
+ SELECT * FROM events WHERE
`
+ if len(customId) != 0 {
+ query += fmt.Sprintf(" custom_id = '%s'", customId) // SECURITY(review): customId interpolated directly into SQL — injection risk; the removed version used a $1 placeholder, restore parameterized args here
+ }
+
+ if len(customId) > 0 && len(userId) > 0 {
+ query += " AND"
+ }
+
+ if len(userId) != 0 {
+ query += fmt.Sprintf(" user_id = '%s'", userId) // SECURITY(review): userId interpolated directly into SQL — injection risk; use a positional placeholder instead
+ }
+
+ if (len(customId) > 0 || len(userId) > 0) && len(keyIds) > 0 {
+ query += " AND"
+ }
+
+ if len(keyIds) != 0 {
+ query += fmt.Sprintf(" key_id = ANY('%s') AND created_at >= %d AND created_at <= %d", sliceToSqlStringArray(keyIds), start, end) // SECURITY(review): each keyId is string-interpolated — injection risk; prefer key_id = ANY($n) with pq.Array(keyIds)
+ }
+
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.rt)
defer cancel()
events := []*event.Event{}
- rows, err := s.db.QueryContext(ctxTimeout, query, customId)
+ rows, err := s.db.QueryContext(ctxTimeout, query)
if err != nil {
if err == sql.ErrNoRows {
return events, nil
@@ -280,6 +329,9 @@ func (s *Store) GetEvents(customId string) ([]*event.Event, error) {
&path,
&method,
&customId,
+ &e.Request,
+ &e.Response,
+ &e.UserId,
); err != nil {
return nil, err
}
@@ -356,7 +408,7 @@ func (s *Store) GetLatencyPercentiles(start, end int64, tags, keyIds []string) (
return data, nil
}
-func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds []string, filters []string) ([]*event.DataPoint, error) {
+func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds, customIds, userIds []string, filters []string) ([]*event.DataPoint, error) {
groupByQuery := "GROUP BY time_series_table.series"
selectQuery := "SELECT series AS time_stamp, COALESCE(COUNT(events_table.event_id),0) AS num_of_requests, COALESCE(SUM(events_table.cost_in_usd),0) AS cost_in_usd, COALESCE(SUM(events_table.latency_in_ms),0) AS latency_in_ms, COALESCE(SUM(events_table.prompt_token_count),0) AS prompt_token_count, COALESCE(SUM(events_table.completion_token_count),0) AS completion_token_count, COALESCE(SUM(CASE WHEN status_code = 200 THEN 1 END),0) AS success_count"
@@ -371,6 +423,16 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds []s
groupByQuery += ",events_table.key_id"
selectQuery += ",events_table.key_id as keyId"
}
+
+ if filter == "customId" {
+ groupByQuery += ",events_table.custom_id"
+ selectQuery += ",events_table.custom_id as customId"
+ }
+
+ if filter == "userId" {
+ groupByQuery += ",events_table.user_id"
+ selectQuery += ",events_table.user_id as userId"
+ }
}
}
@@ -406,6 +468,14 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds []s
conditionBlock += fmt.Sprintf("AND key_id = ANY('%s')", sliceToSqlStringArray(keyIds))
}
+ if len(customIds) != 0 {
+ conditionBlock += fmt.Sprintf("AND custom_id = ANY('%s')", sliceToSqlStringArray(customIds)) // SECURITY(review): customIds string-interpolated into SQL — injection risk; parameterize like the $1/$2 args above
+ }
+
+ if len(userIds) != 0 {
+ conditionBlock += fmt.Sprintf("AND user_id = ANY('%s')", sliceToSqlStringArray(userIds)) // SECURITY(review): userIds string-interpolated into SQL — injection risk; parameterize like the $1/$2 args above
+ }
+
eventSelectionBlock += conditionBlock
eventSelectionBlock += ")"
@@ -425,6 +495,8 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds []s
var e event.DataPoint
var model sql.NullString
var keyId sql.NullString
+ var customId sql.NullString
+ var userId sql.NullString
additional := []any{
&e.TimeStamp,
@@ -445,6 +517,14 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds []s
if filter == "keyId" {
additional = append(additional, &keyId)
}
+
+ if filter == "customId" {
+ additional = append(additional, &customId)
+ }
+
+ if filter == "userId" {
+ additional = append(additional, &userId)
+ }
}
}
@@ -457,6 +537,8 @@ func (s *Store) GetEventDataPoints(start, end, increment int64, tags, keyIds []s
pe := &e
pe.Model = model.String
pe.KeyId = keyId.String
+ pe.CustomId = customId.String
+ pe.UserId = userId.String
data = append(data, pe)
}
@@ -548,6 +630,9 @@ func (s *Store) GetKeys(tags, keyIds []string, provider string) ([]*key.Response
&settingId,
&data,
pq.Array(&k.SettingIds),
+ &k.ShouldLogRequest,
+ &k.ShouldLogResponse,
+ &k.RotationEnabled,
); err != nil {
return nil, err
}
@@ -604,6 +689,9 @@ func (s *Store) GetKey(keyId string) (*key.ResponseKey, error) {
&settingId,
&data,
pq.Array(&k.SettingIds),
+ &k.ShouldLogRequest,
+ &k.ShouldLogResponse,
+ &k.RotationEnabled,
); err != nil {
return nil, err
}
@@ -745,6 +833,9 @@ func (s *Store) GetAllKeys() ([]*key.ResponseKey, error) {
&settingId,
&data,
pq.Array(&k.SettingIds),
+ &k.ShouldLogRequest,
+ &k.ShouldLogResponse,
+ &k.RotationEnabled,
); err != nil {
return nil, err
}
@@ -839,6 +930,9 @@ func (s *Store) GetUpdatedKeys(updatedAt int64) ([]*key.ResponseKey, error) {
&settingId,
&data,
pq.Array(&k.SettingIds),
+ &k.ShouldLogRequest,
+ &k.ShouldLogResponse,
+ &k.RotationEnabled,
); err != nil {
return nil, err
}
@@ -927,14 +1021,50 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error
}
if uk.Revoked != nil {
+ if *uk.Revoked && len(uk.RevokedReason) != 0 {
+ values = append(values, uk.RevokedReason)
+ fields = append(fields, fmt.Sprintf("revoked_reason = $%d", counter))
+ counter++
+ }
+
+ if !*uk.Revoked {
+ values = append(values, "")
+ fields = append(fields, fmt.Sprintf("revoked_reason = $%d", counter))
+ counter++
+ }
+
values = append(values, uk.Revoked)
fields = append(fields, fmt.Sprintf("revoked = $%d", counter))
counter++
}
- if len(uk.RevokedReason) != 0 {
- values = append(values, uk.RevokedReason)
- fields = append(fields, fmt.Sprintf("revoked_reason = $%d", counter))
+ if uk.CostLimitInUsd != 0 {
+ values = append(values, uk.CostLimitInUsd)
+ fields = append(fields, fmt.Sprintf("cost_limit_in_usd = $%d", counter))
+ counter++
+ }
+
+ if uk.CostLimitInUsdOverTime != 0 {
+ values = append(values, uk.CostLimitInUsdOverTime)
+ fields = append(fields, fmt.Sprintf("cost_limit_in_usd_over_time = $%d", counter))
+ counter++
+ }
+
+ if len(uk.CostLimitInUsdUnit) != 0 {
+ values = append(values, uk.CostLimitInUsdUnit)
+ fields = append(fields, fmt.Sprintf("cost_limit_in_usd_unit = $%d", counter))
+ counter++
+ }
+
+ if uk.RateLimitOverTime != 0 {
+ values = append(values, uk.RateLimitOverTime)
+ fields = append(fields, fmt.Sprintf("rate_limit_over_time = $%d", counter))
+ counter++
+ }
+
+ if len(uk.RateLimitUnit) != 0 {
+ values = append(values, uk.RateLimitUnit)
+ fields = append(fields, fmt.Sprintf("rate_limit_unit = $%d", counter))
counter++
}
@@ -950,6 +1080,24 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error
counter++
}
+ if uk.ShouldLogRequest != nil {
+ values = append(values, *uk.ShouldLogRequest)
+ fields = append(fields, fmt.Sprintf("should_log_request = $%d", counter))
+ counter++
+ }
+
+ if uk.ShouldLogResponse != nil {
+ values = append(values, *uk.ShouldLogResponse)
+ fields = append(fields, fmt.Sprintf("should_log_response = $%d", counter))
+ counter++
+ }
+
+ if uk.RotationEnabled != nil {
+ values = append(values, *uk.RotationEnabled)
+ fields = append(fields, fmt.Sprintf("rotation_enabled = $%d", counter))
+ counter++
+ }
+
if uk.AllowedPaths != nil {
data, err := json.Marshal(uk.AllowedPaths)
if err != nil {
@@ -986,6 +1134,9 @@ func (s *Store) UpdateKey(id string, uk *key.UpdateKey) (*key.ResponseKey, error
&settingId,
&data,
pq.Array(&k.SettingIds),
+ &k.ShouldLogRequest,
+ &k.ShouldLogResponse,
+ &k.RotationEnabled,
); err != nil {
if err == sql.ErrNoRows {
return nil, internal_errors.NewNotFoundError(fmt.Sprintf("key not found for id: %s", id))
@@ -1125,8 +1276,8 @@ func (s *Store) CreateProviderSetting(setting *provider.Setting) (*provider.Sett
func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
query := `
- INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, setting_id, allowed_paths, setting_ids)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
+ INSERT INTO keys (name, created_at, updated_at, tags, revoked, key_id, key, revoked_reason, cost_limit_in_usd, cost_limit_in_usd_over_time, cost_limit_in_usd_unit, rate_limit_over_time, rate_limit_unit, ttl, setting_id, allowed_paths, setting_ids, should_log_request, should_log_response, rotation_enabled)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20)
RETURNING *;
`
@@ -1153,6 +1304,9 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
rk.SettingId,
rdata,
sliceToSqlStringArray(rk.SettingIds),
+ rk.ShouldLogRequest,
+ rk.ShouldLogResponse,
+ rk.RotationEnabled,
}
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.wt)
@@ -1180,6 +1334,9 @@ func (s *Store) CreateKey(rk *key.RequestKey) (*key.ResponseKey, error) {
&settingId,
&data,
pq.Array(&k.SettingIds),
+ &k.ShouldLogRequest,
+ &k.ShouldLogResponse,
+ &k.RotationEnabled,
); err != nil {
return nil, err
}
diff --git a/internal/storage/redis/access-cache.go b/internal/storage/redis/access-cache.go
new file mode 100644
index 0000000..6faa660
--- /dev/null
+++ b/internal/storage/redis/access-cache.go
@@ -0,0 +1,59 @@
+package redis
+
+import (
+ "context"
+ "time"
+
+ "github.com/bricks-cloud/bricksllm/internal/key"
+ "github.com/redis/go-redis/v9"
+)
+
+type AccessCache struct {
+ client *redis.Client
+ wt time.Duration
+ rt time.Duration
+}
+
+func NewAccessCache(c *redis.Client, wt time.Duration, rt time.Duration) *AccessCache {
+ return &AccessCache{
+ client: c,
+ wt: wt,
+ rt: rt,
+ }
+}
+
+func (ac *AccessCache) Delete(key string) error {
+ ctx, cancel := context.WithTimeout(context.Background(), ac.wt)
+ defer cancel()
+ err := ac.client.Del(ctx, key).Err()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (ac *AccessCache) Set(key string, timeUnit key.TimeUnit) error {
+ ttl, err := getCounterTtl(timeUnit)
+ if err != nil {
+ return err
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), ac.wt)
+ defer cancel()
+ err = ac.client.Set(ctx, key, true, time.Until(ttl)).Err()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (ac *AccessCache) GetAccessStatus(key string) bool {
+ ctx, cancel := context.WithTimeout(context.Background(), ac.rt)
+ defer cancel()
+
+ result := ac.client.Get(ctx, key)
+
+ return result.Err() != redis.Nil // NOTE(review): any non-Nil error (timeout, network) also reports "present" — confirm that failing open is intended
+}
diff --git a/internal/storage/redis/cache.go b/internal/storage/redis/cache.go
index 73e887d..a60f13b 100644
--- a/internal/storage/redis/cache.go
+++ b/internal/storage/redis/cache.go
@@ -35,6 +35,17 @@ func (c *Cache) Set(key string, value interface{}, ttl time.Duration) error {
return nil
}
+func (c *Cache) Delete(key string) error {
+ ctx, cancel := context.WithTimeout(context.Background(), c.wt)
+ defer cancel()
+ err := c.client.Del(ctx, key).Err()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
func (c *Cache) GetBytes(key string) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), c.rt)
defer cancel()
diff --git a/internal/storage/redis/thread-cache.go b/internal/storage/redis/thread-cache.go
new file mode 100644
index 0000000..eb780cf
--- /dev/null
+++ b/internal/storage/redis/thread-cache.go
@@ -0,0 +1,53 @@
+package redis
+
+import (
+ "context"
+ "time"
+
+ "github.com/redis/go-redis/v9"
+)
+
+type ThreadCache struct {
+ client *redis.Client
+ wt time.Duration
+ rt time.Duration
+}
+
+func NewThreadCache(c *redis.Client, wt time.Duration, rt time.Duration) *ThreadCache {
+ return &ThreadCache{
+ client: c,
+ wt: wt,
+ rt: rt,
+ }
+}
+
+func (tc *ThreadCache) Delete(key string) error {
+ ctx, cancel := context.WithTimeout(context.Background(), tc.wt)
+ defer cancel()
+ err := tc.client.Del(ctx, key).Err()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (tc *ThreadCache) Set(key string, dur time.Duration) error {
+ ctx, cancel := context.WithTimeout(context.Background(), tc.wt)
+ defer cancel()
+ err := tc.client.Set(ctx, key, true, dur).Err()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (tc *ThreadCache) GetThreadStatus(key string) bool {
+ ctx, cancel := context.WithTimeout(context.Background(), tc.rt)
+ defer cancel()
+
+ result := tc.client.Get(ctx, key)
+
+ return result.Err() != redis.Nil
+}
diff --git a/internal/validator/validator.go b/internal/validator/validator.go
index 9551261..d2fd4ba 100644
--- a/internal/validator/validator.go
+++ b/internal/validator/validator.go
@@ -58,12 +58,12 @@ func (v *Validator) Validate(k *key.ResponseKey, promptCost float64) error {
return err
}
- err = v.validateCostLimitOverTime(k.KeyId, k.CostLimitInUsdOverTime, k.CostLimitInUsdUnit, promptCost)
+ err = v.validateCostLimitOverTime(k.KeyId, k.CostLimitInUsdOverTime, k.CostLimitInUsdUnit)
if err != nil {
return err
}
- err = v.validateCostLimit(k.KeyId, k.CostLimitInUsd, promptCost)
+ err = v.validateCostLimit(k.KeyId, k.CostLimitInUsd)
if err != nil {
return err
}
@@ -96,14 +96,14 @@ func (v *Validator) validateRateLimitOverTime(keyId string, rateLimitOverTime in
return errors.New("failed to get rate limit counter")
}
- if c+1 > int64(rateLimitOverTime) {
+ if c >= int64(rateLimitOverTime) {
return internal_errors.NewRateLimitError(fmt.Sprintf("key exceeded rate limit %d requests per %s", rateLimitOverTime, rateLimitUnit))
}
return nil
}
-func (v *Validator) validateCostLimitOverTime(keyId string, costLimitOverTime float64, costLimitUnit key.TimeUnit, promptCost float64) error {
+func (v *Validator) validateCostLimitOverTime(keyId string, costLimitOverTime float64, costLimitUnit key.TimeUnit) error {
if costLimitOverTime == 0 {
return nil
}
@@ -113,8 +113,8 @@ func (v *Validator) validateCostLimitOverTime(keyId string, costLimitOverTime fl
return errors.New("failed to get cached token cost")
}
- if convertDollarToMicroDollars(promptCost)+cachedCost > convertDollarToMicroDollars(costLimitOverTime) {
- return internal_errors.NewExpirationError(fmt.Sprintf("cost limit: %f has been reached for the current time period: %s", costLimitOverTime, costLimitUnit), internal_errors.CostLimitExpiration)
+ if cachedCost >= convertDollarToMicroDollars(costLimitOverTime) {
+ return internal_errors.NewCostLimitError(fmt.Sprintf("cost limit: %f has been reached for the current time period: %s", costLimitOverTime, costLimitUnit))
}
return nil
@@ -124,7 +124,7 @@ func convertDollarToMicroDollars(dollar float64) int64 {
return int64(dollar * 1000000)
}
-func (v *Validator) validateCostLimit(keyId string, costLimit float64, promptCost float64) error {
+func (v *Validator) validateCostLimit(keyId string, costLimit float64) error {
if costLimit == 0 {
return nil
}
@@ -134,7 +134,7 @@ func (v *Validator) validateCostLimit(keyId string, costLimit float64, promptCos
return errors.New("failed to get total token cost")
}
- if convertDollarToMicroDollars(promptCost)+existingTotalCost > convertDollarToMicroDollars(costLimit) {
+ if existingTotalCost >= convertDollarToMicroDollars(costLimit) {
return internal_errors.NewExpirationError(fmt.Sprintf("total cost limit: %f has been reached", costLimit), internal_errors.CostLimitExpiration)
}