Skip to content

Commit

Permalink
add a flag to allow disabling the models processor
Browse files Browse the repository at this point in the history
Signed-off-by: Ignasi Barrera <[email protected]>
  • Loading branch information
nacx committed Feb 11, 2025
1 parent 6e2fc02 commit 2fb8c08
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 61 deletions.
85 changes: 58 additions & 27 deletions cmd/extproc/mainlib/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mainlib

import (
"context"
"errors"
"flag"
"fmt"
"log"
Expand All @@ -21,59 +22,88 @@ import (
"github.com/envoyproxy/ai-gateway/internal/version"
)

// parseAndValidateFlags parses and validates the flags passed to the external processor.
func parseAndValidateFlags(args []string) (configPath, addr string, logLevel slog.Level, err error) {
fs := flag.NewFlagSet("AI Gateway External Processor", flag.ContinueOnError)
configPathPtr := fs.String(
// extProcFlags is the struct that holds the flags passed to the external processor.
// It is the result returned by parseAndValidateFlags and is consumed by Main.
type extProcFlags struct {
// configPath is set from the -configPath flag; validation rejects an empty value.
configPath string // path to the configuration file
// extProcAddr is set from the -extProcAddr flag (default ":1063"),
// for example ":1063" or "unix:///tmp/ext_proc.sock".
extProcAddr string // gRPC address for the external processor
// logLevel is decoded from the -logLevel flag text (default "info") via slog.Level.UnmarshalText.
logLevel slog.Level // log level for the external processor
// processors is the -processors flag value split on commas (default "chatcompletions,models");
// every entry is validated against the keys of builtinProcessors.
processors []string // list of processors to enable
}

// builtinProcessors maps each processor name accepted by the -processors flag
// to the request path it is registered on and the factory that creates it.
var builtinProcessors = map[string]struct {
	path    string
	factory extproc.ProcessorFactory
}{
	"chatcompletions": {
		path:    "/v1/chat/completions",
		factory: extproc.NewChatCompletionProcessor,
	},
	"models": {
		path:    "/v1/models",
		factory: extproc.NewModelsProcessor,
	},
}

// parseAndValidateFlags parses and validates the flags passed to the external processor.
func parseAndValidateFlags(args []string) (extProcFlags, error) {
var (
flags extProcFlags
errs []error
fs = flag.NewFlagSet("AI Gateway External Processor", flag.ContinueOnError)
)

fs.StringVar(&flags.configPath,
"configPath",
"",
"path to the configuration file. The file must be in YAML format specified in filterapi.Config type. "+
"The configuration file is watched for changes.",
)
extProcAddrPtr := fs.String(
fs.StringVar(&flags.extProcAddr,
"extProcAddr",
":1063",
"gRPC address for the external processor. For example, :1063 or unix:///tmp/ext_proc.sock",
)
logLevelPtr := fs.String(
"logLevel",
"info", "log level for the external processor. One of 'debug', 'info', 'warn', or 'error'.",
"info",
"log level for the external processor. One of 'debug', 'info', 'warn', or 'error'.",
)
processorsPtr := fs.String(
"processors",
"chatcompletions,models",
"comma-separated list of processors to enable. Available processors: chatcompletions, models",
)

if err = fs.Parse(args); err != nil {
err = fmt.Errorf("failed to parse flags: %w", err)
return
if err := fs.Parse(args); err != nil {
return extProcFlags{}, fmt.Errorf("failed to parse extProcFlags: %w", err)
}

if *configPathPtr == "" {
err = fmt.Errorf("configPath must be provided")
return
if flags.configPath == "" {
errs = append(errs, fmt.Errorf("configPath must be provided"))
}
if err := flags.logLevel.UnmarshalText([]byte(*logLevelPtr)); err != nil {
errs = append(errs, fmt.Errorf("failed to unmarshal log level: %w", err))
}

if err = logLevel.UnmarshalText([]byte(*logLevelPtr)); err != nil {
err = fmt.Errorf("failed to unmarshal log level: %w", err)
return
flags.processors = strings.Split(*processorsPtr, ",")
for _, p := range flags.processors {
if _, ok := builtinProcessors[p]; !ok {
errs = append(errs, fmt.Errorf("invalid processor: %s", p))
}
}

configPath = *configPathPtr
addr = *extProcAddrPtr
return
return flags, errors.Join(errs...)
}

// Main is a main function for the external processor exposed
// for allowing users to build their own external processor.
func Main() {
configPath, extProcAddr, level, err := parseAndValidateFlags(os.Args[1:])
flags, err := parseAndValidateFlags(os.Args[1:])
if err != nil {
log.Fatalf("failed to parse and validate flags: %v", err)
log.Fatalf("failed to parse and validate extProcFlags: %v", err)
}

l := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))
l := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: flags.logLevel}))

l.Info("starting external processor",
slog.String("version", version.Version),
slog.String("address", extProcAddr),
slog.String("configPath", configPath),
slog.String("address", flags.extProcAddr),
slog.String("configPath", flags.configPath),
)

ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -84,7 +114,7 @@ func Main() {
cancel()
}()

lis, err := net.Listen(listenAddress(extProcAddr))
lis, err := net.Listen(listenAddress(flags.extProcAddr))
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
Expand All @@ -93,10 +123,11 @@ func Main() {
if err != nil {
log.Fatalf("failed to create external processor server: %v", err)
}
server.Register("/v1/chat/completions", extproc.NewChatCompletionProcessor)
server.Register("/v1/models", extproc.NewModelsProcessor)
for _, p := range flags.processors {
server.Register(builtinProcessors[p].path, builtinProcessors[p].factory)
}

if err := extproc.StartConfigWatcher(ctx, configPath, server, l, time.Second*5); err != nil {
if err := extproc.StartConfigWatcher(ctx, flags.configPath, server, l, time.Second*5); err != nil {
log.Fatalf("failed to start config watcher: %v", err)
}

Expand Down
68 changes: 37 additions & 31 deletions cmd/extproc/mainlib/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,91 +5,97 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_parseAndValidateFlags(t *testing.T) {
t.Run("ok flags", func(t *testing.T) {
t.Run("ok extProcFlags", func(t *testing.T) {
for _, tc := range []struct {
name string
args []string
configPath string
addr string
logLevel slog.Level
processors []string
}{
{
name: "minimal flags",
name: "minimal extProcFlags",
args: []string{"-configPath", "/path/to/config.yaml"},
configPath: "/path/to/config.yaml",
addr: ":1063",
logLevel: slog.LevelInfo,
processors: []string{"chatcompletions", "models"},
},
{
name: "custom addr",
args: []string{"-configPath", "/path/to/config.yaml", "-extProcAddr", "unix:///tmp/ext_proc.sock"},
configPath: "/path/to/config.yaml",
addr: "unix:///tmp/ext_proc.sock",
logLevel: slog.LevelInfo,
processors: []string{"chatcompletions", "models"},
},
{
name: "log level debug",
args: []string{"-configPath", "/path/to/config.yaml", "-logLevel", "debug"},
configPath: "/path/to/config.yaml",
addr: ":1063",
logLevel: slog.LevelDebug,
processors: []string{"chatcompletions", "models"},
},
{
name: "log level warn",
args: []string{"-configPath", "/path/to/config.yaml", "-logLevel", "warn"},
configPath: "/path/to/config.yaml",
addr: ":1063",
logLevel: slog.LevelWarn,
processors: []string{"chatcompletions", "models"},
},
{
name: "log level error",
args: []string{"-configPath", "/path/to/config.yaml", "-logLevel", "error"},
configPath: "/path/to/config.yaml",
addr: ":1063",
logLevel: slog.LevelError,
processors: []string{"chatcompletions", "models"},
},
{
name: "all flags",
args: []string{"-configPath", "/path/to/config.yaml", "-extProcAddr", "unix:///tmp/ext_proc.sock", "-logLevel", "debug"},
name: "custom processors",
args: []string{"-configPath", "/path/to/config.yaml", "-processors", "chatcompletions"},
configPath: "/path/to/config.yaml",
addr: ":1063",
logLevel: slog.LevelInfo,
processors: []string{"chatcompletions"},
},
{
name: "all extProcFlags",
args: []string{
"-configPath", "/path/to/config.yaml",
"-extProcAddr", "unix:///tmp/ext_proc.sock",
"-logLevel", "debug",
"-processors", "models",
},
configPath: "/path/to/config.yaml",
addr: "unix:///tmp/ext_proc.sock",
logLevel: slog.LevelDebug,
processors: []string{"models"},
},
} {
t.Run(tc.name, func(t *testing.T) {
configPath, addr, logLevel, err := parseAndValidateFlags(tc.args)
assert.Equal(t, tc.configPath, configPath)
assert.Equal(t, tc.addr, addr)
assert.Equal(t, tc.logLevel, logLevel)
assert.NoError(t, err)
flags, err := parseAndValidateFlags(tc.args)
require.NoError(t, err)
assert.Equal(t, tc.configPath, flags.configPath)
assert.Equal(t, tc.addr, flags.extProcAddr)
assert.Equal(t, tc.logLevel, flags.logLevel)
assert.Equal(t, tc.processors, flags.processors)
})
}
})
t.Run("invalid flags", func(t *testing.T) {
for _, tc := range []struct {
name string
flags []string
expErr string
}{
{
name: "missing configPath",
flags: []string{"-extProcAddr", ":1063"},
expErr: "configPath must be provided",
},
{
name: "invalid logLevel",
flags: []string{"-configPath", "/path/to/config.yaml", "-logLevel", "invalid"},
expErr: `failed to unmarshal log level: slog: level string "invalid": unknown name`,
},
} {
t.Run(tc.name, func(t *testing.T) {
_, _, _, err := parseAndValidateFlags(tc.flags)
assert.EqualError(t, err, tc.expErr)
})
}

t.Run("invalid extProcFlags", func(t *testing.T) {
_, err := parseAndValidateFlags([]string{"-logLevel", "invalid", "-processors", "undefined"})
assert.EqualError(t, err, `configPath must be provided
failed to unmarshal log level: slog: level string "invalid": unknown name
invalid processor: undefined`)
})
}

Expand Down
5 changes: 5 additions & 0 deletions internal/extproc/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package extproc

import (
"context"
"log/slog"

corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
Expand All @@ -28,11 +29,15 @@ type processorConfig struct {
declaredModels []string
}

// processorConfigRequestCost is the configuration for the request cost.
// It embeds *filterapi.LLMRequestCost, so that type's fields and methods are
// promoted onto this struct.
type processorConfigRequestCost struct {
*filterapi.LLMRequestCost
// NOTE(review): presumably the compiled CEL program used to compute this
// cost; confirm against the code that populates it.
celProg cel.Program
}

// ProcessorFactory is the signature shared by the constructor functions that
// create processors: given the processor configuration and a logger, it
// returns a new ProcessorIface instance.
type ProcessorFactory func(config *processorConfig, logger *slog.Logger) ProcessorIface

// ProcessorIface is the interface for the processor.
// This decouples the processor implementation detail from the server implementation.
type ProcessorIface interface {
Expand Down
6 changes: 3 additions & 3 deletions internal/extproc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ var sensitiveHeaderKeys = []string{"authorization"}
type Server struct {
logger *slog.Logger
config *processorConfig
processors map[string]func(*processorConfig, *slog.Logger) ProcessorIface
processors map[string]ProcessorFactory
}

// NewServer creates a new external processor server.
func NewServer(logger *slog.Logger) (*Server, error) {
srv := &Server{
logger: logger,
processors: make(map[string]func(*processorConfig, *slog.Logger) ProcessorIface),
processors: make(map[string]ProcessorFactory),
}
return srv, nil
}
Expand Down Expand Up @@ -122,7 +122,7 @@ func (s *Server) LoadConfig(ctx context.Context, config *filterapi.Config) error
}

// Register a new processor for the given request path.
func (s *Server) Register(path string, newProcessor func(*processorConfig, *slog.Logger) ProcessorIface) {
func (s *Server) Register(path string, newProcessor ProcessorFactory) {
s.processors[path] = newProcessor
}

Expand Down

0 comments on commit 2fb8c08

Please sign in to comment.