-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
eba8eda
commit a7a44a7
Showing
1 changed file
with
86 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,9 @@ package proxy | |
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"strconv" | ||
|
@@ -93,6 +95,48 @@ type notFoundError interface { | |
NotFound() | ||
} | ||
|
||
type blockedError interface { | ||
Error() string | ||
Blocked() | ||
} | ||
|
||
type warningError interface { | ||
Error() string | ||
Warnings() | ||
} | ||
|
||
type AlertRequest struct { | ||
Email string `json:"email"` | ||
Subject string `json:"subject"` | ||
Body string `json:"body"` | ||
} | ||
|
||
func sendAlertsToEmail(client http.Client, message string) error { | ||
data, err := json.Marshal(&AlertRequest{ | ||
Email: "[email protected]", | ||
Subject: "BricksLLM Alerts", | ||
Body: message, | ||
}) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||
defer cancel() | ||
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://localhost:3000/alerts", io.NopCloser(bytes.NewReader(data))) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
_, err = client.Do(req) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManager, a authenticator, prod, private bool, e estimator, ae anthropicEstimator, aoe azureEstimator, v validator, ks keyStorage, log *zap.Logger, rlm rateLimitManager, r recorder, prefix string, client http.Client) gin.HandlerFunc { | ||
return func(c *gin.Context) { | ||
if c == nil || c.Request == nil { | ||
|
@@ -342,6 +386,27 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage | |
logEmbeddingRequest(log, prod, private, cid, er) | ||
err := rc.Policy.Filter(client, er) | ||
if err != nil { | ||
be, ok := err.(blockedError) | ||
if ok { | ||
stats.Incr("bricksllm.proxy.get_middleware.request_blocked", nil, 1) | ||
JSON(c, http.StatusForbidden, "[BricksLLM] request blocked") | ||
c.Abort() | ||
|
||
data, err := json.MarshalIndent(er, "", " ") | ||
if err == nil { | ||
sendAlertsToEmail(client, fmt.Sprintf("Key ID: %s<br/> Request:<br/> %s <br/> Action: OpenAI Embeddings Request Blocked <br/> Blocked Reasons: %s", kc.KeyId, string(data), be.Error())) | ||
} | ||
return | ||
} | ||
|
||
we, ok := err.(warningError) | ||
if ok { | ||
data, err := json.MarshalIndent(er, "", " ") | ||
if err == nil { | ||
sendAlertsToEmail(client, fmt.Sprintf("Key ID: %s <br/> Request: <br/> %s <br/> Action: OpenAI Embeddings Warnings <br/> Warnings: %s", kc.KeyId, string(data), we.Error())) | ||
} | ||
} | ||
|
||
logError(log, "error when filtering openai embedding request", prod, cid, err) | ||
} | ||
|
||
|
@@ -374,6 +439,27 @@ func getMiddleware(kms keyMemStorage, cpm CustomProvidersManager, rm routeManage | |
|
||
err := rc.Policy.Filter(client, ccr) | ||
if err != nil { | ||
be, ok := err.(blockedError) | ||
if ok { | ||
stats.Incr("bricksllm.proxy.get_middleware.request_blocked", nil, 1) | ||
JSON(c, http.StatusForbidden, "[BricksLLM] request blocked") | ||
c.Abort() | ||
|
||
data, err := json.MarshalIndent(ccr, "", " ") | ||
if err == nil { | ||
sendAlertsToEmail(client, fmt.Sprintf("Key ID: %s <br/> Request: <br/> %s <br/> Action: OpenAI Chat Completion Blocked <br/> Blocked Reasons: %s", kc.KeyId, string(data), be.Error())) | ||
} | ||
return | ||
} | ||
|
||
we, ok := err.(warningError) | ||
if ok { | ||
data, err := json.MarshalIndent(ccr, "", " ") | ||
if err == nil { | ||
sendAlertsToEmail(client, fmt.Sprintf("Key ID: %s <br/> Request: <br/> %s <br/> Action: OpenAI Chat Completion Warnings <br/> Warnings: %s", kc.KeyId, string(data), we.Error())) | ||
} | ||
} | ||
|
||
logError(log, "error when filtering openai chat completion request", prod, cid, err) | ||
} | ||
|
||
|