-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
set up yaml parser, web server, cors and key auth
- Loading branch information
1 parent
c7c4347
commit a495e30
Showing
12 changed files
with
1,138 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
openai: | ||
api_credential: ${OPENAI_KEY} | ||
|
||
routes: | ||
- path: /travel | ||
provider: openai | ||
key_auth: | ||
key: ${API_KEY} | ||
cors: | ||
allowed_origins: ["*"] | ||
allowed_credentials: true | ||
input: | ||
plan: | ||
type: object | ||
properties: | ||
place: | ||
type: string | ||
openai_config: | ||
model: gpt-3.5-turbo | ||
prompts: | ||
- role: assitant | ||
content: say hi to {{ plan.place }} | ||
|
||
- path: /test | ||
provider: openai | ||
input: | ||
name: | ||
type: string | ||
openai_config: | ||
model: gpt-3.5-turbo | ||
prompts: | ||
- role: assitant | ||
content: say hi to {{ name }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"os" | ||
"os/signal" | ||
"syscall" | ||
"time" | ||
|
||
"github.com/bricks-cloud/atlas/config" | ||
"github.com/bricks-cloud/atlas/internal/server/web" | ||
"github.com/gin-gonic/gin" | ||
"go.uber.org/zap" | ||
) | ||
|
||
func main() { | ||
gin.SetMode(gin.ReleaseMode) | ||
|
||
rawJSON := []byte(`{ | ||
"level": "debug", | ||
"encoding": "json", | ||
"outputPaths": ["stdout", "/tmp/logs"], | ||
"errorOutputPaths": ["stderr"], | ||
"encoderConfig": { | ||
"messageKey": "message", | ||
"levelKey": "level", | ||
"levelEncoder": "lowercase" | ||
} | ||
}`) | ||
|
||
var cfg zap.Config | ||
|
||
if err := json.Unmarshal(rawJSON, &cfg); err != nil { | ||
panic(err) | ||
} | ||
|
||
logger := zap.Must(cfg.Build()) | ||
defer logger.Sync() | ||
|
||
filePath := "atlas.yaml" | ||
|
||
c, err := config.NewConfig(filePath) | ||
if err != nil { | ||
logger.Sugar().Fatalf("error parsing yaml config %s : %w", filePath, err) | ||
} | ||
|
||
logger.Sugar().Infof("successfuly parsed atlas yaml config file from path: %s", filePath) | ||
|
||
ws, err := web.NewWebServer(c, logger.Sugar()) | ||
if err != nil { | ||
logger.Sugar().Fatalf("error creating http server: %w", err) | ||
} | ||
|
||
ws.Run() | ||
|
||
quit := make(chan os.Signal) | ||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) | ||
<-quit | ||
|
||
logger.Sugar().Info("shutting down server...") | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||
defer cancel() | ||
if err := ws.Shutdown(ctx); err != nil { | ||
logger.Sugar().Fatalf("server shutdown: %w", err) | ||
} | ||
|
||
select { | ||
case <-ctx.Done(): | ||
logger.Sugar().Info("timeout of 5 seconds") | ||
} | ||
|
||
logger.Sugar().Info("server exited") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,282 @@ | ||
package config | ||
|
||
import ( | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"io/ioutil" | ||
"os" | ||
"strings" | ||
"time" | ||
|
||
"github.com/bricks-cloud/atlas/internal/util" | ||
"gopkg.in/yaml.v3" | ||
) | ||
|
||
type RateLimitConfig struct { | ||
Count string `yaml:"count"` | ||
Interval string `yaml:"interval"` | ||
} | ||
|
||
type Protocol string | ||
|
||
const ( | ||
Http Protocol = "http" | ||
Https Protocol = "https" | ||
) | ||
|
||
type Provider string | ||
|
||
const ( | ||
openaiProvider Provider = "openai" | ||
) | ||
|
||
type ServerConfig struct { | ||
Port int `yaml:"port"` | ||
} | ||
|
||
type DataType string | ||
|
||
const ( | ||
StringDataType DataType = "string" | ||
NumberDataType DataType = "number" | ||
ArrayDataType DataType = "array" | ||
ObjectDataType DataType = "object" | ||
BooleanDataType DataType = "boolean" | ||
) | ||
|
||
type InputValue struct { | ||
DataType DataType `yaml:"type"` | ||
Properties map[string]interface{} `yaml:"properties"` | ||
Items interface{} `yaml:"items"` | ||
} | ||
|
||
type CorsConfig struct { | ||
AllowedOrgins []string `yaml:"allowed_origins"` | ||
AllowedCredentials bool `yaml:"allowed_credentials"` | ||
} | ||
|
||
func (cc *CorsConfig) Enabled() bool { | ||
return cc != nil | ||
} | ||
|
||
func (cc *CorsConfig) GetAllowedOrigins() []string { | ||
return cc.AllowedOrgins | ||
} | ||
|
||
func (cc *CorsConfig) GetAllowedCredentials() bool { | ||
return cc.AllowedCredentials | ||
} | ||
|
||
type KeyAuthConfig struct { | ||
Key string `yaml:"key"` | ||
} | ||
|
||
func (kac *KeyAuthConfig) Enabled() bool { | ||
return kac != nil | ||
} | ||
|
||
func (kac *KeyAuthConfig) GetKey() string { | ||
return kac.Key | ||
} | ||
|
||
type RouteConfig struct { | ||
Path string `yaml:"path"` | ||
CorsConfig *CorsConfig `yaml:"cors"` | ||
Input map[string]InputValue `yaml:"input"` | ||
Provider Provider `yaml:"provider"` | ||
OpenAiConfig *OpenAiRouteConfig `yaml:"openai_config"` | ||
Description string `yaml:"description"` | ||
Protocol Protocol `yaml:"protocol"` | ||
CertFile string `yaml:"cert_file"` | ||
KeyFile string `yaml:"key_file"` | ||
KeyAuthConfig *KeyAuthConfig `yaml:"key_auth"` | ||
UpstreamSendTimeout time.Duration `yaml:"upstream_send_time"` | ||
} | ||
|
||
type OpenAiMessageRole string | ||
|
||
const ( | ||
system OpenAiMessageRole = "system" | ||
user OpenAiMessageRole = "user" | ||
assitant OpenAiMessageRole = "assitant" | ||
function OpenAiMessageRole = "function" | ||
) | ||
|
||
type OpenAiPrompt struct { | ||
Role string `yaml:"role"` | ||
Content string `yaml:"content"` | ||
} | ||
|
||
type OpenAiModel string | ||
|
||
const ( | ||
gpt35Turbo OpenAiModel = "gpt-3.5-turbo" | ||
) | ||
|
||
type OpenAiRouteConfig struct { | ||
ApiCredential string `yaml:"api_credential"` | ||
Model OpenAiModel `yaml:"model"` | ||
Prompts []*OpenAiPrompt `yaml:"prompts"` | ||
} | ||
|
||
type OpenAiConfig struct { | ||
ApiCredential string `yaml:"api_credential"` | ||
} | ||
|
||
type Config struct { | ||
Routes []*RouteConfig `yaml:"routes"` | ||
Server *ServerConfig `yaml:"server"` | ||
OpenAiConfig *OpenAiConfig `yaml:"openai"` | ||
} | ||
|
||
func NewConfig(filePath string) (*Config, error) { | ||
yamlFile, err := ioutil.ReadFile(filePath) | ||
if err != nil { | ||
return nil, fmt.Errorf("unable to read config file with path %s: %w", filePath, err) | ||
} | ||
|
||
yamlFile = []byte(os.ExpandEnv(string(yamlFile))) | ||
|
||
c := &Config{} | ||
err = yaml.Unmarshal(yamlFile, c) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to unmarshal yaml file with path %s: %w", filePath, err) | ||
} | ||
|
||
// default server port to 8080 | ||
if c.Server == nil { | ||
c.Server = &ServerConfig{ | ||
Port: 8080, | ||
} | ||
} | ||
|
||
if len(c.Routes) == 0 { | ||
return nil, fmt.Errorf("routes are not configured in config file %s", filePath) | ||
} | ||
|
||
apiCredentialConfigured := false | ||
if c.OpenAiConfig != nil && len(c.OpenAiConfig.ApiCredential) != 0 { | ||
apiCredentialConfigured = true | ||
} | ||
|
||
for _, route := range c.Routes { | ||
err = parseRouteConfig(route, apiCredentialConfigured) | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
return c, nil | ||
} | ||
|
||
func parseRouteConfig(rc *RouteConfig, isOpenAiConfigured bool) error { | ||
if len(rc.Path) == 0 { | ||
return errors.New("path is empty") | ||
} | ||
|
||
if len(rc.Provider) == 0 { | ||
return errors.New("provider is empty") | ||
} | ||
|
||
if rc.CorsConfig != nil { | ||
if len(rc.CorsConfig.AllowedOrgins) == 0 { | ||
return fmt.Errorf("cors config is present but allowed_origins is not specified for route: %s", rc.Path) | ||
} | ||
} | ||
|
||
if rc.KeyAuthConfig != nil { | ||
if len(rc.KeyAuthConfig.Key) == 0 { | ||
return fmt.Errorf("key_auth config is present but key is not specified for route: %s", rc.Path) | ||
} | ||
} | ||
|
||
if rc.Provider == openaiProvider { | ||
if rc.OpenAiConfig == nil { | ||
return errors.New("openai config is not provided") | ||
} | ||
|
||
for _, prompt := range rc.OpenAiConfig.Prompts { | ||
if len(prompt.Role) == 0 { | ||
return errors.New("role is not provided in openai prompt") | ||
} | ||
|
||
if !isOpenAiConfigured && len(rc.OpenAiConfig.ApiCredential) == 0 { | ||
return errors.New("openai api credential is not configrued") | ||
} | ||
|
||
if len(prompt.Content) == 0 { | ||
return errors.New("content is not provided in openai prompt") | ||
} | ||
|
||
variableMap := util.GetVariableMap(prompt.Content) | ||
err := validateInput(rc.Input, variableMap) | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
} | ||
|
||
if rc.Protocol == Https { | ||
if len(rc.CertFile) == 0 { | ||
return errors.New("cert file is not provided for https protocol") | ||
} | ||
|
||
if len(rc.KeyFile) == 0 { | ||
return errors.New("key file is not provided for https protocol") | ||
} | ||
} | ||
|
||
// defaut route protocol to http | ||
if len(rc.Protocol) == 0 { | ||
rc.Protocol = Http | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func validateInput(input map[string]InputValue, variableMap map[string]string) error { | ||
if len(variableMap) == 0 { | ||
return nil | ||
} | ||
|
||
for _, reference := range variableMap { | ||
parts := strings.Split(reference, ".") | ||
|
||
if len(parts) == 0 { | ||
return errors.New("no references found inside `{{ }}` syntax") | ||
} | ||
|
||
if len(parts) == 1 { | ||
if _, found := input[parts[0]]; found { | ||
continue | ||
} | ||
|
||
return errors.New("referenced value in prompt is not defined in input") | ||
} | ||
|
||
innerInput := input | ||
for index, part := range parts { | ||
value, found := innerInput[part] | ||
if !found { | ||
return errors.New("referenced value in prompt does not exist") | ||
} | ||
|
||
if index != len(parts)-1 && value.DataType != ObjectDataType { | ||
return errors.New("input value is not represented as object") | ||
} | ||
|
||
js, err := json.Marshal(value.Properties) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
innerInput = map[string]InputValue{} | ||
err = json.Unmarshal(js, &innerInput) | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
} | ||
|
||
return nil | ||
} |
Oops, something went wrong.