diff --git a/core/http/app.go b/core/http/app.go index d3c42b317fc1..328a9d8e9a18 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -205,6 +205,7 @@ func API(application *application.Application) (*echo.Echo, error) { routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application) routes.RegisterOpenAIRoutes(e, requestExtractor, application) + routes.RegisterAnthropicRoutes(e, requestExtractor, application) if !application.ApplicationConfig().DisableWebUI { routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application) routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService()) diff --git a/core/http/endpoints/anthropic/messages.go b/core/http/endpoints/anthropic/messages.go new file mode 100644 index 000000000000..b82ed3f3b6d1 --- /dev/null +++ b/core/http/endpoints/anthropic/messages.go @@ -0,0 +1,310 @@ +package anthropic + +import ( + "encoding/json" + "fmt" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/templates" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/xlog" +) + +// MessagesEndpoint is the Anthropic Messages API endpoint +// https://docs.anthropic.com/claude/reference/messages_post +// @Summary Generate a message response for the given messages and model. +// @Param request body schema.AnthropicRequest true "query params" +// @Success 200 {object} schema.AnthropicResponse "Response" +// @Router /v1/messages [post] +func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + id := uuid.New().String() + + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.AnthropicRequest) + if !ok || input.Model == "" { + return sendAnthropicError(c, 400, "invalid_request_error", "model is required") + } + + cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + return sendAnthropicError(c, 400, "invalid_request_error", "model configuration not found") + } + + if input.MaxTokens <= 0 { + return sendAnthropicError(c, 400, "invalid_request_error", "max_tokens is required and must be greater than 0") + } + + xlog.Debug("Anthropic Messages endpoint configuration read", "config", cfg) + + // Convert Anthropic messages to OpenAI format for internal processing + openAIMessages := convertAnthropicToOpenAIMessages(input) + + // Create an OpenAI-compatible request for internal processing + openAIReq := &schema.OpenAIRequest{ + PredictionOptions: schema.PredictionOptions{ + BasicModelRequest: schema.BasicModelRequest{Model: input.Model}, + Temperature: input.Temperature, + TopK: input.TopK, + TopP: input.TopP, + Maxtokens: &input.MaxTokens, + }, + Messages: openAIMessages, + Stream: input.Stream, + Context: input.Context, + Cancel: input.Cancel, + } + + // Set stop sequences + if len(input.StopSequences) > 0 { + openAIReq.Stop = input.StopSequences + } + + // Merge config settings + if input.Temperature != nil { + cfg.Temperature = input.Temperature + } + if input.TopK != nil { + cfg.TopK = input.TopK + } + if input.TopP != nil { + cfg.TopP = input.TopP + } + cfg.Maxtokens = &input.MaxTokens + if len(input.StopSequences) > 0 { + cfg.StopWords = append(cfg.StopWords, input.StopSequences...) + } + + // Template the prompt + predInput := evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, nil, false) + xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput) + + if input.Stream { + return handleAnthropicStream(c, id, input, cfg, ml, predInput) + } + + return handleAnthropicNonStream(c, id, input, cfg, ml, predInput, openAIReq) + } +} + +func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, predInput string, openAIReq *schema.OpenAIRequest) error { + images := []string{} + for _, m := range openAIReq.Messages { + images = append(images, m.StringImages...) + } + + predFunc, err := backend.ModelInference( + input.Context, predInput, openAIReq.Messages, images, nil, nil, ml, cfg, nil, nil, nil, "", "", nil, nil, nil) + if err != nil { + xlog.Error("Anthropic model inference failed", "error", err) + return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err)) + } + + prediction, err := predFunc() + if err != nil { + xlog.Error("Anthropic prediction failed", "error", err) + return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err)) + } + + result := backend.Finetune(*cfg, predInput, prediction.Response) + stopReason := "end_turn" + + resp := &schema.AnthropicResponse{ + ID: fmt.Sprintf("msg_%s", id), + Type: "message", + Role: "assistant", + Model: input.Model, + StopReason: &stopReason, + Content: []schema.AnthropicContentBlock{ + {Type: "text", Text: result}, + }, + Usage: schema.AnthropicUsage{ + InputTokens: prediction.Usage.Prompt, + OutputTokens: prediction.Usage.Completion, + }, + } + + if respData, err := json.Marshal(resp); err == nil { + xlog.Debug("Anthropic Response", "response", string(respData)) + } + + return c.JSON(200, resp) +} + +func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, predInput string) error { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + + // Create OpenAI messages for inference + openAIMessages := convertAnthropicToOpenAIMessages(input) + + images := []string{} + for _, m := range openAIMessages { + images = append(images, m.StringImages...) + } + + // Send message_start event + messageStart := schema.AnthropicStreamEvent{ + Type: "message_start", + Message: &schema.AnthropicStreamMessage{ + ID: fmt.Sprintf("msg_%s", id), + Type: "message", + Role: "assistant", + Content: []schema.AnthropicContentBlock{}, + Model: input.Model, + Usage: schema.AnthropicUsage{InputTokens: 0, OutputTokens: 0}, + }, + } + sendAnthropicSSE(c, messageStart) + + // Send content_block_start event + contentBlockStart := schema.AnthropicStreamEvent{ + Type: "content_block_start", + Index: 0, + ContentBlock: &schema.AnthropicContentBlock{Type: "text", Text: ""}, + } + sendAnthropicSSE(c, contentBlockStart) + + // Stream content deltas + tokenCallback := func(token string, usage backend.TokenUsage) bool { + delta := schema.AnthropicStreamEvent{ + Type: "content_block_delta", + Index: 0, + Delta: &schema.AnthropicStreamDelta{ + Type: "text_delta", + Text: token, + }, + } + sendAnthropicSSE(c, delta) + return true + } + + predFunc, err := backend.ModelInference( + input.Context, predInput, openAIMessages, images, nil, nil, ml, cfg, nil, nil, tokenCallback, "", "", nil, nil, nil) + if err != nil { + xlog.Error("Anthropic stream model inference failed", "error", err) + return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err)) + } + + prediction, err := predFunc() + if err != nil { + xlog.Error("Anthropic stream prediction failed", "error", err) + return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("prediction failed: %v", err)) + } + + // Send content_block_stop event + contentBlockStop := schema.AnthropicStreamEvent{ + Type: "content_block_stop", + Index: 0, + } + sendAnthropicSSE(c, contentBlockStop) + + // Send message_delta event with stop_reason + stopReason := "end_turn" + messageDelta := schema.AnthropicStreamEvent{ + Type: "message_delta", + Delta: &schema.AnthropicStreamDelta{ + StopReason: &stopReason, + }, + Usage: &schema.AnthropicUsage{ + OutputTokens: prediction.Usage.Completion, + }, + } + sendAnthropicSSE(c, messageDelta) + + // Send message_stop event + messageStop := schema.AnthropicStreamEvent{ + Type: "message_stop", + } + sendAnthropicSSE(c, messageStop) + + return nil +} + +func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) { + data, err := json.Marshal(event) + if err != nil { + xlog.Error("Failed to marshal SSE event", "error", err) + return + } + fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.Type, string(data)) + c.Response().Flush() +} + +func sendAnthropicError(c echo.Context, statusCode int, errorType, message string) error { + resp := schema.AnthropicErrorResponse{ + Type: "error", + Error: schema.AnthropicError{ + Type: errorType, + Message: message, + }, + } + return c.JSON(statusCode, resp) +} + +func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.Message { + var messages []schema.Message + + // Add system message if present + if input.System != "" { + messages = append(messages, schema.Message{ + Role: "system", + StringContent: input.System, + Content: input.System, + }) + } + + // Convert Anthropic messages to OpenAI format + for _, msg := range input.Messages { + openAIMsg := schema.Message{ + Role: msg.Role, + } + + // Handle content (can be string or array of content blocks) + switch content := msg.Content.(type) { + case string: + openAIMsg.StringContent = content + openAIMsg.Content = content + case []interface{}: + // Handle array of content blocks + var textContent string + var stringImages []string + + for _, block := range content { + if blockMap, ok := block.(map[string]interface{}); ok { + blockType, _ := blockMap["type"].(string) + switch blockType { + case "text": + if text, ok := blockMap["text"].(string); ok { + textContent += text + } + case "image": + // Handle image content + if source, ok := blockMap["source"].(map[string]interface{}); ok { + if sourceType, ok := source["type"].(string); ok && sourceType == "base64" { + if data, ok := source["data"].(string); ok { + mediaType, _ := source["media_type"].(string) + // Format as data URI + dataURI := fmt.Sprintf("data:%s;base64,%s", mediaType, data) + stringImages = append(stringImages, dataURI) + } + } + } + } + } + } + openAIMsg.StringContent = textContent + openAIMsg.Content = textContent + openAIMsg.StringImages = stringImages + } + + messages = append(messages, openAIMsg) + } + + return messages +} diff --git a/core/http/routes/anthropic.go b/core/http/routes/anthropic.go new file mode 100644 index 000000000000..9f7050c27edf --- /dev/null +++ b/core/http/routes/anthropic.go @@ -0,0 +1,108 @@ +package routes + +import ( + "context" + "fmt" + "net/http" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/endpoints/anthropic" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/xlog" +) + +func RegisterAnthropicRoutes(app *echo.Echo, + re *middleware.RequestExtractor, + application *application.Application) { + + // Anthropic Messages API endpoint + messagesHandler := anthropic.MessagesEndpoint( + application.ModelConfigLoader(), + application.ModelLoader(), + application.TemplatesEvaluator(), + application.ApplicationConfig(), + ) + + messagesMiddleware := []echo.MiddlewareFunc{ + middleware.TraceMiddleware(application), + re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.AnthropicRequest) }), + setAnthropicRequestContext(application.ApplicationConfig()), + } + + // Main Anthropic endpoint + app.POST("/v1/messages", messagesHandler, messagesMiddleware...) + + // Also support without version prefix for compatibility + app.POST("/messages", messagesHandler, messagesMiddleware...) +} + +// setAnthropicRequestContext sets up the context and cancel function for Anthropic requests +func setAnthropicRequestContext(appConfig *config.ApplicationConfig) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.AnthropicRequest) + if !ok || input.Model == "" { + return echo.NewHTTPError(http.StatusBadRequest, "model is required") + } + + cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || cfg == nil { + return echo.NewHTTPError(http.StatusBadRequest, "model configuration not found") + } + + // Extract or generate the correlation ID + // Anthropic uses x-request-id header + correlationID := c.Request().Header.Get("x-request-id") + if correlationID == "" { + correlationID = uuid.New().String() + } + c.Response().Header().Set("x-request-id", correlationID) + + // Set up context with cancellation + reqCtx := c.Request().Context() + c1, cancel := context.WithCancel(appConfig.Context) + + // Cancel when request context is cancelled (client disconnects) + go func() { + select { + case <-reqCtx.Done(): + cancel() + case <-c1.Done(): + // Already cancelled + } + }() + + // Add the correlation ID to the new context + ctxWithCorrelationID := context.WithValue(c1, middleware.CorrelationIDKey, correlationID) + + input.Context = ctxWithCorrelationID + input.Cancel = cancel + + if cfg.Model == "" { + xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) + cfg.Model = input.Model + } + + c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + // Log the Anthropic API version if provided + anthropicVersion := c.Request().Header.Get("anthropic-version") + if anthropicVersion != "" { + xlog.Debug("Anthropic API version", "version", anthropicVersion) + } + + // Validate max_tokens is provided + if input.MaxTokens <= 0 { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("max_tokens is required and must be greater than 0")) + } + + return next(c) + } + } +} diff --git a/core/schema/anthropic.go b/core/schema/anthropic.go new file mode 100644 index 000000000000..02398faa4af9 --- /dev/null +++ b/core/schema/anthropic.go @@ -0,0 +1,163 @@ +package schema + +import ( + "context" + "encoding/json" +) + +// AnthropicRequest represents a request to the Anthropic Messages API +// https://docs.anthropic.com/claude/reference/messages_post +type AnthropicRequest struct { + Model string `json:"model"` + Messages []AnthropicMessage `json:"messages"` + MaxTokens int `json:"max_tokens"` + Metadata map[string]string `json:"metadata,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream bool `json:"stream,omitempty"` + System string `json:"system,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopK *int `json:"top_k,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + + // Internal fields for request handling + Context context.Context `json:"-"` + Cancel context.CancelFunc `json:"-"` +} + +// ModelName implements the LocalAIRequest interface +func (ar *AnthropicRequest) ModelName(s *string) string { + if s != nil { + ar.Model = *s + } + return ar.Model +} + +// AnthropicMessage represents a message in the Anthropic format +type AnthropicMessage struct { + Role string `json:"role"` + Content interface{} `json:"content"` +} + +// AnthropicContentBlock represents a content block in an Anthropic message +type AnthropicContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Source *AnthropicImageSource `json:"source,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input map[string]interface{} `json:"input,omitempty"` +} + +// AnthropicImageSource represents an image source in Anthropic format +type AnthropicImageSource struct { + Type string `json:"type"` + MediaType string `json:"media_type"` + Data string `json:"data"` +} + +// AnthropicResponse represents a response from the Anthropic Messages API +type AnthropicResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []AnthropicContentBlock `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence,omitempty"` + Usage AnthropicUsage `json:"usage"` +} + +// AnthropicUsage represents token usage in Anthropic format +type AnthropicUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +// AnthropicStreamEvent represents a streaming event from the Anthropic API +type AnthropicStreamEvent struct { + Type string `json:"type"` + Index int `json:"index,omitempty"` + ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"` + Delta *AnthropicStreamDelta `json:"delta,omitempty"` + Message *AnthropicStreamMessage `json:"message,omitempty"` + Usage *AnthropicUsage `json:"usage,omitempty"` +} + +// AnthropicStreamDelta represents the delta in a streaming response +type AnthropicStreamDelta struct { + Type string `json:"type,omitempty"` + Text string `json:"text,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` +} + +// AnthropicStreamMessage represents the message object in streaming events +type AnthropicStreamMessage struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []AnthropicContentBlock `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence,omitempty"` + Usage AnthropicUsage `json:"usage"` +} + +// AnthropicErrorResponse represents an error response from the Anthropic API +type AnthropicErrorResponse struct { + Type string `json:"type"` + Error AnthropicError `json:"error"` +} + +// AnthropicError represents an error in the Anthropic format +type AnthropicError struct { + Type string `json:"type"` + Message string `json:"message"` +} + +// GetStringContent extracts the string content from an AnthropicMessage +// Content can be either a string or an array of content blocks +func (m *AnthropicMessage) GetStringContent() string { + switch content := m.Content.(type) { + case string: + return content + case []interface{}: + var result string + for _, block := range content { + if blockMap, ok := block.(map[string]interface{}); ok { + if blockMap["type"] == "text" { + if text, ok := blockMap["text"].(string); ok { + result += text + } + } + } + } + return result + } + return "" +} + +// GetContentBlocks extracts content blocks from an AnthropicMessage +func (m *AnthropicMessage) GetContentBlocks() []AnthropicContentBlock { + switch content := m.Content.(type) { + case string: + return []AnthropicContentBlock{{Type: "text", Text: content}} + case []interface{}: + var blocks []AnthropicContentBlock + for _, block := range content { + if blockMap, ok := block.(map[string]interface{}); ok { + cb := AnthropicContentBlock{} + data, err := json.Marshal(blockMap) + if err != nil { + continue + } + if err := json.Unmarshal(data, &cb); err != nil { + continue + } + blocks = append(blocks, cb) + } + } + return blocks + } + return nil +} diff --git a/core/schema/anthropic_test.go b/core/schema/anthropic_test.go new file mode 100644 index 000000000000..f4dbde16df77 --- /dev/null +++ b/core/schema/anthropic_test.go @@ -0,0 +1,145 @@ +package schema_test + +import ( + "encoding/json" + + "github.com/mudler/LocalAI/core/schema" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Anthropic Schema", func() { + Describe("AnthropicRequest", func() { + It("should unmarshal a valid request", func() { + jsonData := `{ + "model": "claude-3-sonnet-20240229", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello, world!"} + ], + "system": "You are a helpful assistant.", + "temperature": 0.7 + }` + + var req schema.AnthropicRequest + err := json.Unmarshal([]byte(jsonData), &req) + Expect(err).ToNot(HaveOccurred()) + Expect(req.Model).To(Equal("claude-3-sonnet-20240229")) + Expect(req.MaxTokens).To(Equal(1024)) + Expect(len(req.Messages)).To(Equal(1)) + Expect(req.System).To(Equal("You are a helpful assistant.")) + Expect(*req.Temperature).To(Equal(0.7)) + }) + + It("should implement LocalAIRequest interface", func() { + req := &schema.AnthropicRequest{Model: "test-model"} + Expect(req.ModelName(nil)).To(Equal("test-model")) + + newModel := "new-model" + Expect(req.ModelName(&newModel)).To(Equal("new-model")) + Expect(req.Model).To(Equal("new-model")) + }) + }) + + Describe("AnthropicMessage", func() { + It("should get string content from string content", func() { + msg := schema.AnthropicMessage{ + Role: "user", + Content: "Hello, world!", + } + Expect(msg.GetStringContent()).To(Equal("Hello, world!")) + }) + + It("should get string content from array content", func() { + msg := schema.AnthropicMessage{ + Role: "user", + Content: []interface{}{ + map[string]interface{}{"type": "text", "text": "Hello, "}, + map[string]interface{}{"type": "text", "text": "world!"}, + }, + } + Expect(msg.GetStringContent()).To(Equal("Hello, world!")) + }) + + It("should get content blocks from string content", func() { + msg := schema.AnthropicMessage{ + Role: "user", + Content: "Hello, world!", + } + blocks := msg.GetContentBlocks() + Expect(len(blocks)).To(Equal(1)) + Expect(blocks[0].Type).To(Equal("text")) + Expect(blocks[0].Text).To(Equal("Hello, world!")) + }) + + It("should get content blocks from array content", func() { + msg := schema.AnthropicMessage{ + Role: "user", + Content: []interface{}{ + map[string]interface{}{"type": "text", "text": "Hello"}, + map[string]interface{}{"type": "image", "source": map[string]interface{}{"type": "base64", "data": "abc123"}}, + }, + } + blocks := msg.GetContentBlocks() + Expect(len(blocks)).To(Equal(2)) + Expect(blocks[0].Type).To(Equal("text")) + Expect(blocks[0].Text).To(Equal("Hello")) + }) + }) + + Describe("AnthropicResponse", func() { + It("should marshal a valid response", func() { + stopReason := "end_turn" + resp := schema.AnthropicResponse{ + ID: "msg_123", + Type: "message", + Role: "assistant", + Model: "claude-3-sonnet-20240229", + StopReason: &stopReason, + Content: []schema.AnthropicContentBlock{ + {Type: "text", Text: "Hello!"}, + }, + Usage: schema.AnthropicUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + } + + data, err := json.Marshal(resp) + Expect(err).ToNot(HaveOccurred()) + + var result map[string]interface{} + err = json.Unmarshal(data, &result) + Expect(err).ToNot(HaveOccurred()) + + Expect(result["id"]).To(Equal("msg_123")) + Expect(result["type"]).To(Equal("message")) + Expect(result["role"]).To(Equal("assistant")) + Expect(result["stop_reason"]).To(Equal("end_turn")) + }) + }) + + Describe("AnthropicErrorResponse", func() { + It("should marshal an error response", func() { + resp := schema.AnthropicErrorResponse{ + Type: "error", + Error: schema.AnthropicError{ + Type: "invalid_request_error", + Message: "max_tokens is required", + }, + } + + data, err := json.Marshal(resp) + Expect(err).ToNot(HaveOccurred()) + + var result map[string]interface{} + err = json.Unmarshal(data, &result) + Expect(err).ToNot(HaveOccurred()) + + Expect(result["type"]).To(Equal("error")) + errorObj := result["error"].(map[string]interface{}) + Expect(errorObj["type"]).To(Equal("invalid_request_error")) + Expect(errorObj["message"]).To(Equal("max_tokens is required")) + }) + }) +}) diff --git a/go.mod b/go.mod index e634cd9cee41..d5ee68973502 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( fyne.io/fyne/v2 v2.7.1 github.com/Masterminds/sprig/v3 v3.3.0 github.com/alecthomas/kong v1.13.0 + github.com/anthropics/anthropic-sdk-go v1.19.0 github.com/charmbracelet/glamour v0.10.0 github.com/containerd/containerd v1.7.30 github.com/ebitengine/purego v0.9.1 @@ -58,6 +59,7 @@ require ( go.opentelemetry.io/otel/metric v1.39.0 go.opentelemetry.io/otel/sdk/metric v1.39.0 google.golang.org/grpc v1.78.0 + google.golang.org/protobuf v1.36.10 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 oras.land/oras-go/v2 v2.6.0 @@ -67,8 +69,11 @@ require ( github.com/ghodss/yaml v1.0.0 // indirect github.com/labstack/gommon v0.4.2 // indirect github.com/swaggo/files/v2 v2.0.2 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect - google.golang.org/protobuf v1.36.10 // indirect ) require ( diff --git a/go.sum b/go.sum index 17b8c6341150..d6928e0309d1 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/andybalholm/brotli v1.0.1/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/anthropics/anthropic-sdk-go v1.19.0 h1:mO6E+ffSzLRvR/YUH9KJC0uGw0uV8GjISIuzem//3KE= +github.com/anthropics/anthropic-sdk-go v1.19.0/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8= @@ -762,10 +764,12 @@ github.com/swaggo/swag v1.16.6/go.mod h1:ngP2etMK5a0P3QBizic5MEwpRmluJZPHjXcMoj4 github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU= github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= diff --git a/tests/e2e/e2e_anthropic_test.go b/tests/e2e/e2e_anthropic_test.go new file mode 100644 index 000000000000..8980ec355784 --- /dev/null +++ b/tests/e2e/e2e_anthropic_test.go @@ -0,0 +1,150 @@ +package e2e_test + +import ( + "context" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Anthropic API E2E test", func() { + var client anthropic.Client + + Context("API with Anthropic SDK", func() { + BeforeEach(func() { + // Create Anthropic client pointing to LocalAI + client = anthropic.NewClient( + option.WithBaseURL(localAIURL), + option.WithAPIKey("test-api-key"), // LocalAI doesn't require a real API key + ) + + // Wait for API to be ready by attempting a simple request + Eventually(func() error { + _, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ + Model: "gpt-4", + MaxTokens: 10, + Messages: []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock("Hi")), + }, + }) + return err + }, "2m").ShouldNot(HaveOccurred()) + }) + + Context("Non-streaming responses", func() { + It("generates a response for a simple message", func() { + message, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ + Model: "gpt-4", + MaxTokens: 1024, + Messages: []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock("How much is 2+2? Reply with just the number.")), + }, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(message.Content).ToNot(BeEmpty()) + // Role is a constant type that defaults to "assistant" + Expect(string(message.Role)).To(Equal("assistant")) + Expect(message.StopReason).To(Equal(anthropic.MessageStopReasonEndTurn)) + Expect(string(message.Type)).To(Equal("message")) + + // Check that content contains text block with expected answer + Expect(len(message.Content)).To(BeNumerically(">=", 1)) + textBlock := message.Content[0] + Expect(string(textBlock.Type)).To(Equal("text")) + Expect(textBlock.Text).To(Or(ContainSubstring("4"), ContainSubstring("four"))) + }) + + It("handles system prompts", func() { + message, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ + Model: "gpt-4", + MaxTokens: 1024, + System: []anthropic.TextBlockParam{ + {Text: "You are a helpful assistant. Always respond in uppercase letters."}, + }, + Messages: []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock("Say hello")), + }, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(message.Content).ToNot(BeEmpty()) + Expect(len(message.Content)).To(BeNumerically(">=", 1)) + }) + + It("returns usage information", func() { + message, err := client.Messages.New(context.TODO(), anthropic.MessageNewParams{ + Model: "gpt-4", + MaxTokens: 100, + Messages: []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock("Hello")), + }, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(message.Usage.InputTokens).To(BeNumerically(">", 0)) + Expect(message.Usage.OutputTokens).To(BeNumerically(">", 0)) + }) + }) + + Context("Streaming responses", func() { + It("streams tokens for a simple message", func() { + stream := client.Messages.NewStreaming(context.TODO(), anthropic.MessageNewParams{ + Model: "gpt-4", + MaxTokens: 1024, + Messages: []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock("Count from 1 to 5")), + }, + }) + + message := anthropic.Message{} + eventCount := 0 + hasContentDelta := false + + for stream.Next() { + event := stream.Current() + err := message.Accumulate(event) + Expect(err).ToNot(HaveOccurred()) + eventCount++ + + // Check for content block delta events + switch event.AsAny().(type) { + case anthropic.ContentBlockDeltaEvent: + hasContentDelta = true + } + } + + Expect(stream.Err()).ToNot(HaveOccurred()) + Expect(eventCount).To(BeNumerically(">", 0)) + Expect(hasContentDelta).To(BeTrue()) + + // Check accumulated message + Expect(message.Content).ToNot(BeEmpty()) + // Role is a constant type that defaults to "assistant" + Expect(string(message.Role)).To(Equal("assistant")) + }) + + It("streams with system prompt", func() { + stream := client.Messages.NewStreaming(context.TODO(), anthropic.MessageNewParams{ + Model: "gpt-4", + MaxTokens: 1024, + System: []anthropic.TextBlockParam{ + {Text: "You are a helpful assistant."}, + }, + Messages: []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock("Say hello")), + }, + }) + + message := anthropic.Message{} + for stream.Next() { + event := stream.Current() + err := message.Accumulate(event) + Expect(err).ToNot(HaveOccurred()) + } + + Expect(stream.Err()).ToNot(HaveOccurred()) + Expect(message.Content).ToNot(BeEmpty()) + }) + }) + }) +})