Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extproc: custom processors per path and serve /v1/models #325

Merged
merged 2 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 37 additions & 27 deletions cmd/extproc/mainlib/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mainlib

import (
"context"
"errors"
"flag"
"fmt"
"log"
Expand All @@ -21,59 +22,66 @@ import (
"github.com/envoyproxy/ai-gateway/internal/version"
)

// parseAndValidateFlags parses and validates the flags passed to the external processor.
func parseAndValidateFlags(args []string) (configPath, addr string, logLevel slog.Level, err error) {
fs := flag.NewFlagSet("AI Gateway External Processor", flag.ContinueOnError)
configPathPtr := fs.String(
// extProcFlags holds the command-line flags passed to the external processor.
// It is populated and validated by parseAndValidateFlags.
type extProcFlags struct {
configPath string // path to the configuration file.
extProcAddr string // gRPC address for the external processor.
logLevel slog.Level // log level for the external processor.
}

// parseAndValidateFlags parses and validates the flags passed to the external processor.
func parseAndValidateFlags(args []string) (extProcFlags, error) {
var (
flags extProcFlags
errs []error
fs = flag.NewFlagSet("AI Gateway External Processor", flag.ContinueOnError)
)

fs.StringVar(&flags.configPath,
"configPath",
"",
"path to the configuration file. The file must be in YAML format specified in filterapi.Config type. "+
"The configuration file is watched for changes.",
)
extProcAddrPtr := fs.String(
fs.StringVar(&flags.extProcAddr,
"extProcAddr",
":1063",
"gRPC address for the external processor. For example, :1063 or unix:///tmp/ext_proc.sock",
)
logLevelPtr := fs.String(
"logLevel",
"info", "log level for the external processor. One of 'debug', 'info', 'warn', or 'error'.",
"info",
"log level for the external processor. One of 'debug', 'info', 'warn', or 'error'.",
)

if err = fs.Parse(args); err != nil {
err = fmt.Errorf("failed to parse flags: %w", err)
return
if err := fs.Parse(args); err != nil {
return extProcFlags{}, fmt.Errorf("failed to parse extProcFlags: %w", err)
}

if *configPathPtr == "" {
err = fmt.Errorf("configPath must be provided")
return
if flags.configPath == "" {
errs = append(errs, fmt.Errorf("configPath must be provided"))
}

if err = logLevel.UnmarshalText([]byte(*logLevelPtr)); err != nil {
err = fmt.Errorf("failed to unmarshal log level: %w", err)
return
if err := flags.logLevel.UnmarshalText([]byte(*logLevelPtr)); err != nil {
errs = append(errs, fmt.Errorf("failed to unmarshal log level: %w", err))
}

configPath = *configPathPtr
addr = *extProcAddrPtr
return
return flags, errors.Join(errs...)
}

// Main is a main function for the external processor exposed
// for allowing users to build their own external processor.
func Main() {
configPath, extProcAddr, level, err := parseAndValidateFlags(os.Args[1:])
flags, err := parseAndValidateFlags(os.Args[1:])
if err != nil {
log.Fatalf("failed to parse and validate flags: %v", err)
log.Fatalf("failed to parse and validate extProcFlags: %v", err)
}

l := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
l := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: flags.logLevel}))

l.Info("starting external processor",
slog.String("version", version.Version),
slog.String("address", extProcAddr),
slog.String("configPath", configPath),
slog.String("address", flags.extProcAddr),
slog.String("configPath", flags.configPath),
)

ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -84,17 +92,19 @@ func Main() {
cancel()
}()

lis, err := net.Listen(listenAddress(extProcAddr))
lis, err := net.Listen(listenAddress(flags.extProcAddr))
if err != nil {
log.Fatalf("failed to listen: %v", err)
}

server, err := extproc.NewServer[*extproc.Processor](l, extproc.NewProcessor)
server, err := extproc.NewServer(l)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The generics were tying the server to a concrete processor type, but now it allows defining custom processors per request path, so the generics made no sense anymore. I've removed them and kept the code using the ProcessorIface

if err != nil {
log.Fatalf("failed to create external processor server: %v", err)
}
server.Register("/v1/chat/completions", extproc.NewChatCompletionProcessor)
server.Register("/v1/models", extproc.NewModelsProcessor)

if err := extproc.StartConfigWatcher(ctx, configPath, server, l, time.Second*5); err != nil {
if err := extproc.StartConfigWatcher(ctx, flags.configPath, server, l, time.Second*5); err != nil {
log.Fatalf("failed to start config watcher: %v", err)
}

Expand Down
50 changes: 19 additions & 31 deletions cmd/extproc/mainlib/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_parseAndValidateFlags(t *testing.T) {
t.Run("ok flags", func(t *testing.T) {
t.Run("ok extProcFlags", func(t *testing.T) {
for _, tc := range []struct {
name string
args []string
Expand All @@ -17,7 +18,7 @@ func Test_parseAndValidateFlags(t *testing.T) {
logLevel slog.Level
}{
{
name: "minimal flags",
name: "minimal extProcFlags",
args: []string{"-configPath", "/path/to/config.yaml"},
configPath: "/path/to/config.yaml",
addr: ":1063",
Expand Down Expand Up @@ -52,44 +53,31 @@ func Test_parseAndValidateFlags(t *testing.T) {
logLevel: slog.LevelError,
},
{
name: "all flags",
args: []string{"-configPath", "/path/to/config.yaml", "-extProcAddr", "unix:///tmp/ext_proc.sock", "-logLevel", "debug"},
name: "all extProcFlags",
args: []string{
"-configPath", "/path/to/config.yaml",
"-extProcAddr", "unix:///tmp/ext_proc.sock",
"-logLevel", "debug",
},
configPath: "/path/to/config.yaml",
addr: "unix:///tmp/ext_proc.sock",
logLevel: slog.LevelDebug,
},
} {
t.Run(tc.name, func(t *testing.T) {
configPath, addr, logLevel, err := parseAndValidateFlags(tc.args)
assert.Equal(t, tc.configPath, configPath)
assert.Equal(t, tc.addr, addr)
assert.Equal(t, tc.logLevel, logLevel)
assert.NoError(t, err)
flags, err := parseAndValidateFlags(tc.args)
require.NoError(t, err)
assert.Equal(t, tc.configPath, flags.configPath)
assert.Equal(t, tc.addr, flags.extProcAddr)
assert.Equal(t, tc.logLevel, flags.logLevel)
})
}
})
t.Run("invalid flags", func(t *testing.T) {
for _, tc := range []struct {
name string
flags []string
expErr string
}{
{
name: "missing configPath",
flags: []string{"-extProcAddr", ":1063"},
expErr: "configPath must be provided",
},
{
name: "invalid logLevel",
flags: []string{"-configPath", "/path/to/config.yaml", "-logLevel", "invalid"},
expErr: `failed to unmarshal log level: slog: level string "invalid": unknown name`,
},
} {
t.Run(tc.name, func(t *testing.T) {
_, _, _, err := parseAndValidateFlags(tc.flags)
assert.EqualError(t, err, tc.expErr)
})
}

t.Run("invalid extProcFlags", func(t *testing.T) {
_, err := parseAndValidateFlags([]string{"-logLevel", "invalid"})
assert.EqualError(t, err, `configPath must be provided
failed to unmarshal log level: slog: level string "invalid": unknown name`)
})
}

Expand Down
2 changes: 1 addition & 1 deletion filterapi/filterconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
)

func TestDefaultConfig(t *testing.T) {
server, err := extproc.NewServer(slog.Default(), extproc.NewProcessor)
server, err := extproc.NewServer(slog.Default())
require.NoError(t, err)
require.NotNil(t, server)

Expand Down
42 changes: 42 additions & 0 deletions internal/apischema/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ package openai
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
)

// Chat message role defined by the OpenAI API.
Expand Down Expand Up @@ -723,3 +725,43 @@ type ErrorType struct {
// The event_id of the client event that caused the error, if applicable.
EventID *string `json:"event_id,omitempty"`
}

// ModelList is the list-of-models object described in the OpenAI API documentation
// https://platform.openai.com/docs/api-reference/models/list
type ModelList struct {
// Data is a list of models. It is serialized under the "data" JSON key.
Data []Model `json:"data"`
// Object is the object type, which is always "list". It is serialized under the "object" JSON key.
Object string `json:"object"`
}

// Model is a single model object as described in the OpenAI API documentation
// https://platform.openai.com/docs/api-reference/models/object
type Model struct {
// ID is the model identifier, which can be referenced in the API endpoints.
ID string `json:"id"`
// Created is the Unix timestamp (in seconds) when the model was created.
// It is encoded in JSON as an integer number of seconds via JSONUNIXTime.
Created JSONUNIXTime `json:"created"`
// Object is the object type, which is always "model".
Object string `json:"object"`
// OwnedBy is the organization that owns the model.
OwnedBy string `json:"owned_by"`
}

// JSONUNIXTime is a time.Time that is represented in JSON as a UNIX timestamp,
// i.e. an integer number of seconds since the UNIX epoch.
type JSONUNIXTime time.Time

// MarshalJSON implements [json.Marshaler].
//
// It returns the wrapped time as a base-10 integer number of seconds since the
// UNIX epoch. Sub-second precision is dropped. It never returns an error.
func (t JSONUNIXTime) MarshalJSON() ([]byte, error) {
	return []byte(strconv.FormatInt(time.Time(t).Unix(), 10)), nil
}

// UnmarshalJSON implements [json.Unmarshaler].
//
// The input must be a base-10 integer number of seconds since the UNIX epoch.
// A JSON null leaves the receiver unchanged and returns nil, following the
// encoding/json convention for Unmarshalers. Any other input that is not a
// valid 64-bit integer yields the error returned by strconv.ParseInt.
func (t *JSONUNIXTime) UnmarshalJSON(s []byte) error {
	// encoding/json passes the literal null through to UnmarshalJSON; by
	// convention this must be a no-op rather than a parse error.
	if string(s) == "null" {
		return nil
	}
	secs, err := strconv.ParseInt(string(s), 10, 64)
	if err != nil {
		return err
	}
	*t = JSONUNIXTime(time.Unix(secs, 0))
	return nil
}
28 changes: 28 additions & 0 deletions internal/apischema/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai
import (
"encoding/json"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/openai/openai-go"
Expand Down Expand Up @@ -234,3 +235,30 @@ func TestOpenAIChatCompletionMessageUnmarshal(t *testing.T) {
})
}
}

// TestModelListMarshal checks that a ModelList marshals to the expected JSON
// document and that the same document unmarshals back into equivalent values.
func TestModelListMarshal(t *testing.T) {
	const raw = `{"object":"list","data":[{"id":"gpt-3.5-turbo","object":"model","owned_by":"tetrate","created":1735689600}]}`

	created := time.Date(2025, time.January, 1, 0, 0, 0, 0, time.UTC)
	model := Model{
		ID:      "gpt-3.5-turbo",
		Object:  "model",
		OwnedBy: "tetrate",
		Created: JSONUNIXTime(created),
	}
	list := ModelList{Object: "list", Data: []Model{model}}

	// Marshal: the produced JSON must be semantically equal to the expected document.
	encoded, err := json.Marshal(list)
	require.NoError(t, err)
	require.JSONEq(t, raw, string(encoded))

	// Unmarshal: the expected document must decode into the same field values.
	var decoded ModelList
	require.NoError(t, json.Unmarshal([]byte(raw), &decoded))
	require.Len(t, decoded.Data, 1)
	require.Equal(t, "list", decoded.Object)

	got := decoded.Data[0]
	require.Equal(t, model.ID, got.ID)
	require.Equal(t, model.Object, got.Object)
	require.Equal(t, model.OwnedBy, got.OwnedBy)
	// Unmarshalling initializes other time.Time fields we are not interested in,
	// so compare only the UNIX timestamps.
	require.Equal(t, time.Time(model.Created).Unix(), time.Time(got.Created).Unix())
}
Loading