Skip to content

Commit

Permalink
add new openai endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
spikelu2016 committed Jul 17, 2024
1 parent 92a9a79 commit 60b64b3
Show file tree
Hide file tree
Showing 19 changed files with 1,166 additions and 22 deletions.
2 changes: 1 addition & 1 deletion cmd/bricksllm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ func main() {
scanner := pii.NewScanner(detector)
cd := custompolicy.NewOpenAiDetector(cfg.CustomPolicyDetectionTimeout, cfg.OpenAiApiKey)

ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ace, aoe, v, rec, messageBus, rlm, cfg.ProxyTimeout, accessCache, userAccessCache, pm, scanner, cd, die, um)
ps, err := proxy.NewProxyServer(log, *modePtr, *privacyPtr, c, m, rm, a, psm, cpm, store, memStore, ce, ace, aoe, v, rec, messageBus, rlm, cfg.ProxyTimeout, accessCache, userAccessCache, pm, scanner, cd, die, um, cfg.RemoveUserAgent)
if err != nil {
log.Sugar().Fatalf("error creating proxy http server: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ require (
github.com/mattn/go-colorable v0.1.13
github.com/pkoukk/tiktoken-go v0.1.7
github.com/redis/go-redis/v9 v9.0.5
github.com/sashabaranov/go-openai v1.24.0
github.com/sashabaranov/go-openai v1.26.3
github.com/stretchr/testify v1.8.4
github.com/tidwall/gjson v1.17.0
go.uber.org/zap v1.24.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/sashabaranov/go-openai v1.24.0 h1:4H4Pg8Bl2RH/YSnU8DYumZbuHnnkfioor/dtNlB20D4=
github.com/sashabaranov/go-openai v1.24.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.26.3 h1:Tjnh4rcvsSU68f66r05mys+Zou4vo4qyvkne6AIRJPI=
github.com/sashabaranov/go-openai v1.26.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type Config struct {
AmazonRegion string `koanf:"amazon_region" env:"AMAZON_REGION" envDefault:"us-west-2"`
AmazonRequestTimeout time.Duration `koanf:"amazon_request_timeout" env:"AMAZON_REQUEST_TIMEOUT" envDefault:"5s"`
AmazonConnectionTimeout time.Duration `koanf:"amazon_connection_timeout" env:"AMAZON_CONNECTION_TIMEOUT" envDefault:"10s"`
RemoveUserAgent bool `koanf:"remove_user_agent" env:"REMOVE_USER_AGENT" envDefault:"false"`
}

func prepareDotEnv(envFilePath string) error {
Expand Down
10 changes: 7 additions & 3 deletions internal/server/web/proxy/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@ type anthropicEstimator interface {
CountMessagesTokens(messages []anthropic.Message) int
}

func copyHttpHeaders(source *http.Request, dest *http.Request) {
// copyHttpHeaders copies the request headers of source onto dest so they can
// be forwarded with the outgoing request.
//
// The X-CUSTOM-EVENT-ID header is never forwarded. Its name is matched
// case-insensitively; the previous check compared a lower-cased key against an
// upper-case literal, which could never match, so the header was always
// copied.
//
// Only the first value of each header is copied (Header.Get / Header.Set).
// When removeUseAgent is true the User-Agent header is deleted from dest after
// copying. Accept-Encoding on dest is always overwritten with "*".
func copyHttpHeaders(source *http.Request, dest *http.Request, removeUseAgent bool) {
	for k := range source.Header {
		// Skip the event-id header regardless of how the key is cased.
		if strings.EqualFold(k, "X-CUSTOM-EVENT-ID") {
			continue
		}

		dest.Header.Set(k, source.Header.Get(k))
	}

	if removeUseAgent {
		// NOTE(review): deleting the key only drops the caller's value; confirm
		// whether the HTTP client then substitutes its own default User-Agent.
		dest.Header.Del("User-Agent")
	}

	dest.Header.Set("Accept-Encoding", "*")
}

Expand Down Expand Up @@ -66,7 +70,7 @@ func getCompletionHandler(prod, private bool, client http.Client, timeOut time.D
// return
// }

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

isStreaming := c.GetBool("stream")
if isStreaming {
Expand Down Expand Up @@ -320,7 +324,7 @@ func getMessagesHandler(prod, private bool, client http.Client, e anthropicEstim
return
}

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

isStreaming := c.GetBool("stream")
if isStreaming {
Expand Down
6 changes: 3 additions & 3 deletions internal/server/web/proxy/audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func getSpeechHandler(prod bool, client http.Client, timeOut time.Duration) gin.
return
}

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

start := time.Now()

Expand Down Expand Up @@ -187,7 +187,7 @@ func getTranscriptionsHandler(prod bool, client http.Client, timeOut time.Durati
return
}

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

var b bytes.Buffer
writer := multipart.NewWriter(&b)
Expand Down Expand Up @@ -351,7 +351,7 @@ func getTranslationsHandler(prod bool, client http.Client, timeOut time.Duration
return
}

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

var b bytes.Buffer
writer := multipart.NewWriter(&b)
Expand Down
2 changes: 1 addition & 1 deletion internal/server/web/proxy/azure_chat_completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func getAzureChatCompletionHandler(prod, private bool, client http.Client, aoe a
return
}

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

isStreaming := c.GetBool("stream")
if isStreaming {
Expand Down
2 changes: 1 addition & 1 deletion internal/server/web/proxy/azure_completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func getAzureCompletionsHandler(prod, private bool, client http.Client, aoe azur
return
}

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

isStreaming := c.GetBool("stream")
if isStreaming {
Expand Down
2 changes: 1 addition & 1 deletion internal/server/web/proxy/azure_embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func getAzureEmbeddingsHandler(prod, private bool, client http.Client, aoe azure
return
}

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

start := time.Now()

Expand Down
2 changes: 1 addition & 1 deletion internal/server/web/proxy/chat_completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func getChatCompletionHandler(prod, private bool, client http.Client, e estimato
return
}

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

isStreaming := c.GetBool("stream")
if isStreaming {
Expand Down
2 changes: 1 addition & 1 deletion internal/server/web/proxy/custom_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func getCustomProviderHandler(prod bool, client http.Client, timeOut time.Durati
return
}

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

isStreaming := c.GetBool("stream")
if isStreaming {
Expand Down
4 changes: 2 additions & 2 deletions internal/server/web/proxy/deepinfra.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func getDeepinfraCompletionsHandler(prod, private bool, client http.Client, time
return
}

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

isStreaming := c.GetBool("stream")
if isStreaming {
Expand Down Expand Up @@ -240,7 +240,7 @@ func getDeepinfraChatCompletionsHandler(prod, private bool, client http.Client,
return
}

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

isStreaming := c.GetBool("stream")
if isStreaming {
Expand Down
2 changes: 1 addition & 1 deletion internal/server/web/proxy/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func getEmbeddingHandler(prod, private bool, client http.Client, e estimator, ti
return
}

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

start := time.Now()

Expand Down
6 changes: 5 additions & 1 deletion internal/server/web/proxy/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ type CustomPolicyDetector interface {
Detect(input []string, requirements []string) (bool, error)
}

func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManager, a authenticator, prod, private bool, log *zap.Logger, pub publisher, prefix string, ac accessCache, uac userAccessCache, client http.Client, scanner Scanner, cd CustomPolicyDetector, um userManager) gin.HandlerFunc {
func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManager, a authenticator, prod, private bool, log *zap.Logger, pub publisher, prefix string, ac accessCache, uac userAccessCache, client http.Client, scanner Scanner, cd CustomPolicyDetector, um userManager, removeUserAgent bool) gin.HandlerFunc {
return func(c *gin.Context) {
if c == nil || c.Request == nil {
JSON(c, http.StatusInternalServerError, "[BricksLLM] request is empty")
Expand All @@ -181,6 +181,10 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
return
}

if removeUserAgent {
c.Set("removeUserAgent", removeUserAgent)
}

blw := &responseWriter{body: bytes.NewBufferString(""), ResponseWriter: c.Writer}
c.Writer = blw

Expand Down
44 changes: 41 additions & 3 deletions internal/server/web/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ func CorsMiddleware() gin.HandlerFunc {
}
}

func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, kms keyMemStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, pub publisher, rlm rateLimitManager, timeOut time.Duration, ac accessCache, uac userAccessCache, pm PoliciesManager, scanner Scanner, cd CustomPolicyDetector, die deepinfraEstimator, um userManager) (*ProxyServer, error) {
func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyManager, rm routeManager, a authenticator, psm ProviderSettingsManager, cpm CustomProvidersManager, ks keyStorage, kms keyMemStorage, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, r recorder, pub publisher, rlm rateLimitManager, timeOut time.Duration, ac accessCache, uac userAccessCache, pm PoliciesManager, scanner Scanner, cd CustomPolicyDetector, die deepinfraEstimator, um userManager, removeAgentHeaders bool) (*ProxyServer, error) {
router := gin.New()
prod := mode == "production"
private := privacyMode == "strict"

router.Use(CorsMiddleware())
router.Use(getMiddleware(cpm, rm, pm, a, prod, private, log, pub, "proxy", ac, uac, http.Client{}, scanner, cd, um))
router.Use(getMiddleware(cpm, rm, pm, a, prod, private, log, pub, "proxy", ac, uac, http.Client{}, scanner, cd, um, removeAgentHeaders))

client := http.Client{}

Expand Down Expand Up @@ -196,6 +196,25 @@ func NewProxyServer(log *zap.Logger, mode, privacyMode string, c cache, m KeyMan
// custom route
router.POST("/api/routes/*route", getRouteHandler(prod, c, aoe, e, client, r))

// vector store
router.POST("/api/providers/openai/v1/vector_stores", getCreateVectorStoreHandler(prod, client, timeOut))
router.GET("/api/providers/openai/v1/vector_stores", getListVectorStoresHandler(prod, client, timeOut))
router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id", getGetVectorStoreHandler(prod, client, timeOut))
router.POST("/api/providers/openai/v1/vector_stores/:vector_store_id", getModifyVectorStoreHandler(prod, client, timeOut))
router.DELETE("/api/providers/openai/v1/vector_stores/:vector_store_id", getDeleteVectorStoreHandler(prod, client, timeOut))

// vector store files
router.POST("/api/providers/openai/v1/vector_stores/:vector_store_id/files", getCreateVectorStoreFileHandler(prod, client, timeOut))
router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id/files", getListVectorStoreFilesHandler(prod, client, timeOut))
router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id/files/:file_id", getGetVectorStoreFileHandler(prod, client, timeOut))
router.DELETE("/api/providers/openai/v1/vector_stores/:vector_store_id/files/:file_id", getDeleteVectorStoreFileHandler(prod, client, timeOut))

// vector store file batches
router.POST("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches", getCreateVectorStoreFileBatchHandler(prod, client, timeOut))
router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id", getGetVectorStoreFileBatchHandler(prod, client, timeOut))
router.POST("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/cancel", getCancelVectorStoreFileBatchHandler(prod, client, timeOut))
router.GET("/api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/files", getListVectorStoreFileBatchFilesHandler(prod, client, timeOut))

srv := &http.Server{
Addr: ":8002",
Handler: router,
Expand Down Expand Up @@ -291,7 +310,7 @@ func getPassThroughHandler(prod, private bool, client http.Client, timeOut time.
// copy query params
req.URL.RawQuery = c.Request.URL.RawQuery

copyHttpHeaders(c.Request, req)
copyHttpHeaders(c.Request, req, c.GetBool("removeUserAgent"))

if c.FullPath() == "/api/providers/openai/v1/files" && c.Request.Method == http.MethodPost {
purpose := c.PostForm("purpose")
Expand Down Expand Up @@ -988,6 +1007,25 @@ func (ps *ProxyServer) Run() {
// custom route
ps.log.Info("PORT 8002 | POST | /api/routes/*route is ready for forwarding requests to a custom route")

// vector store
ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/vector_stores is ready for creating an openai vector store")
ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/vector_stores is ready for listing openai vector stores")
ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/vector_stores/:vector_store_id is ready for getting an openai vector store")
ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/vector_stores/:vector_store_id is ready for modifying an openai vector store")
ps.log.Info("PORT 8002 | DELETE | /api/providers/openai/v1/vector_stores/:vector_store_id is ready for deleting an openai vector store")

// vector store files
ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/vector_stores/:vector_store_id/files is ready for creating an openai vector store file")
ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/vector_stores/:vector_store_id/files is ready for listing openai vector store files")
ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/vector_stores/:vector_store_id/files/:file_id is ready for getting an openai vector store file")
ps.log.Info("PORT 8002 | DELETE | /api/providers/openai/v1/vector_stores/:vector_store_id/files/:file_id is ready for deleting an openai vector store file")

// vector store file batches
ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/vector_stores/:vector_store_id/file_batches is ready for creating an openai vector store file batch")
ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id is ready for getting an openai vector store file batch")
ps.log.Info("PORT 8002 | POST | /api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/cancel is ready for cancelling an openai vector store file batch")
ps.log.Info("PORT 8002 | GET | /api/providers/openai/v1/vector_stores/:vector_store_id/file_batches/:batch_id/files is ready for listing openai vector store file batch files")

if err := ps.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
ps.log.Sugar().Fatalf("error proxy server listening: %v", err)
return
Expand Down
Loading

0 comments on commit 60b64b3

Please sign in to comment.