From 9e07b3eb35ddb4a04ad54d0fbdb9dd3621e0a80a Mon Sep 17 00:00:00 2001 From: Akshay Deo Date: Thu, 30 Oct 2025 04:09:32 -0700 Subject: [PATCH] adds support for gemini format (native and openai) --- go.mod | 2 + go.sum | 2 + logging/base.go | 1 + logging/gemini_test.go | 345 +++++++++++++++++++++++++++++++++++++++++ logging/gemni.go | 166 ++++++++++++++++++++ logging/generation.go | 13 ++ logging/parser.go | 2 + 7 files changed, 531 insertions(+) create mode 100644 logging/gemini_test.go create mode 100644 logging/gemni.go diff --git a/go.mod b/go.mod index 39915f4..b74f774 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,5 @@ module github.com/maximhq/maxim-go go 1.21 require github.com/google/uuid v1.6.0 + +require github.com/joho/godotenv v1.5.1 diff --git a/go.sum b/go.sum index 7790d7c..b36acff 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= diff --git a/logging/base.go b/logging/base.go index 528d57d..bbbc4dc 100644 --- a/logging/base.go +++ b/logging/base.go @@ -11,6 +11,7 @@ const ( ProviderAzure = "azure" ProviderAnthropic = "anthropic" ProviderBedrock = "aws" + ProviderGemini = "gemini" ) // baseConfig is the configuration for a base entity. diff --git a/logging/gemini_test.go b/logging/gemini_test.go new file mode 100644 index 0000000..1c86bfd --- /dev/null +++ b/logging/gemini_test.go @@ -0,0 +1,345 @@ +package logging_test + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "testing" + + "github.com/joho/godotenv" + "github.com/maximhq/maxim-go/logging" +) + +// TestParseGeminiNativeResponse tests parsing of native Gemini API response format +func TestParseGeminiNativeResponse(t *testing.T) { + // Sample native Gemini response + geminiResponse := map[string]interface{}{ + "candidates": []map[string]interface{}{ + { + "content": map[string]interface{}{ + "parts": []map[string]interface{}{ + { + "text": "Hello! How can I help you today?", + }, + }, + "role": "model", + }, + "finishReason": "STOP", + "index": 0, + }, + }, + "usageMetadata": map[string]interface{}{ + "promptTokenCount": 10, + "candidatesTokenCount": 20, + "totalTokenCount": 30, + }, + "modelVersion": "gemini-pro", + } + + jsonData, err := json.Marshal(geminiResponse) + if err != nil { + t.Fatalf("Failed to marshal test data: %v", err) + } + + result, err := logging.ParseGeminiResult(jsonData) + if err != nil { + t.Fatalf("ParseGeminiResult failed: %v", err) + } + + // Validate result + if result.Model != "gemini-pro" { + t.Errorf("Expected model 'gemini-pro', got '%s'", result.Model) + } + + if len(result.Choices) != 1 { + t.Fatalf("Expected 1 choice, got %d", len(result.Choices)) + } + + if result.Choices[0].Message.Content != "Hello! How can I help you today?" { + t.Errorf("Expected content 'Hello! How can I help you today?', got '%s'", result.Choices[0].Message.Content) + } + + if result.Choices[0].Message.Role != "assistant" { + t.Errorf("Expected role 'assistant', got '%s'", result.Choices[0].Message.Role) + } + + if result.Choices[0].FinishReason != "stop" { + t.Errorf("Expected finish reason 'stop', got '%s'", result.Choices[0].FinishReason) + } + + if result.Usage.PromptTokens != 10 { + t.Errorf("Expected 10 prompt tokens, got %d", result.Usage.PromptTokens) + } + + if result.Usage.CompletionTokens != 20 { + t.Errorf("Expected 20 completion tokens, got %d", result.Usage.CompletionTokens) + } + + if result.Usage.TotalTokens != 30 { + t.Errorf("Expected 30 total tokens, got %d", result.Usage.TotalTokens) + } +} + +// TestParseGeminiOpenAICompatible tests parsing of OpenAI-compatible Gemini response +func TestParseGeminiOpenAICompatible(t *testing.T) { + // Sample OpenAI-compatible response + openAIResponse := map[string]interface{}{ + "id": "chatcmpl-123", + "model": "gemini-pro", + "created": 1677652288, + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{ + "role": "assistant", + "content": "This is a test response", + }, + "finish_reason": "stop", + }, + }, + "usage": map[string]interface{}{ + "prompt_tokens": 15, + "completion_tokens": 25, + "total_tokens": 40, + }, + } + + jsonData, err := json.Marshal(openAIResponse) + if err != nil { + t.Fatalf("Failed to marshal test data: %v", err) + } + + result, err := logging.ParseGeminiResult(jsonData) + if err != nil { + t.Fatalf("ParseGeminiResult failed: %v", err) + } + + // Validate result + if result.ID != "chatcmpl-123" { + t.Errorf("Expected ID 'chatcmpl-123', got '%s'", result.ID) + } + + if result.Model != "gemini-pro" { + t.Errorf("Expected model 'gemini-pro', got '%s'", result.Model) + } + + if len(result.Choices) != 1 { + t.Fatalf("Expected 1 choice, got %d", len(result.Choices)) + } + + if result.Choices[0].Message.Content != "This is a test response" { + t.Errorf("Expected content 'This is a test response', got '%s'", result.Choices[0].Message.Content) + } +} + +// TestParseGeminiWithFunctionCall tests parsing of Gemini response with function calls +func TestParseGeminiWithFunctionCall(t *testing.T) { + // Sample Gemini response with function call + geminiResponse := map[string]interface{}{ + "candidates": []map[string]interface{}{ + { + "content": map[string]interface{}{ + "parts": []map[string]interface{}{ + { + "text": "Let me check the weather for you.", + }, + { + "functionCall": map[string]interface{}{ + "name": "get_weather", + "args": map[string]interface{}{ + "location": "San Francisco", + "unit": "celsius", + }, + }, + }, + }, + "role": "model", + }, + "finishReason": "STOP", + "index": 0, + }, + }, + "usageMetadata": map[string]interface{}{ + "promptTokenCount": 15, + "candidatesTokenCount": 30, + "totalTokenCount": 45, + }, + } + + jsonData, err := json.Marshal(geminiResponse) + if err != nil { + t.Fatalf("Failed to marshal test data: %v", err) + } + + result, err := logging.ParseGeminiResult(jsonData) + if err != nil { + t.Fatalf("ParseGeminiResult failed: %v", err) + } + + // Validate result + if len(result.Choices) != 1 { + t.Fatalf("Expected 1 choice, got %d", len(result.Choices)) + } + + if result.Choices[0].Message.Content != "Let me check the weather for you." { + t.Errorf("Expected text content, got '%s'", result.Choices[0].Message.Content) + } + + if len(result.Choices[0].Message.ToolCalls) != 1 { + t.Fatalf("Expected 1 tool call, got %d", len(result.Choices[0].Message.ToolCalls)) + } + + toolCall := result.Choices[0].Message.ToolCalls[0] + if toolCall.Function.Name != "get_weather" { + t.Errorf("Expected function name 'get_weather', got '%s'", toolCall.Function.Name) + } + + if toolCall.Type != "function" { + t.Errorf("Expected type 'function', got '%s'", toolCall.Type) + } +} + +// TestGeminiAPIIntegration tests real Gemini API integration +// This test requires GEMINI_API_KEY to be set in .env file +func TestGeminiAPIIntegration(t *testing.T) { + // Load .env file + err := godotenv.Load("../.env") + if err != nil { + t.Skip("Skipping integration test: .env file not found") + } + + apiKey := os.Getenv("GEMINI_API_KEY") + if apiKey == "" { + t.Skip("Skipping integration test: GEMINI_API_KEY not set in .env") + } + + // Make a real API call to Gemini + model := "gemini-pro" + url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent?key=%s", model, apiKey) + + requestBody := map[string]interface{}{ + "contents": []map[string]interface{}{ + { + "parts": []map[string]interface{}{ + { + "text": "Say hello in one sentence", + }, + }, + }, + }, + } + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("API request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("API returned non-200 status: %d, body: %s", resp.StatusCode, string(body)) + } + + // Parse the response using our parser + result, err := logging.ParseGeminiResult(body) + if err != nil { + t.Fatalf("ParseGeminiResult failed: %v, response: %s", err, string(body)) + } + + // Validate the parsed result + if result == nil { + t.Fatal("Result is nil") + } + + if len(result.Choices) == 0 { + t.Fatal("No choices in result") + } + + if result.Choices[0].Message.Content == "" { + t.Error("Content is empty") + } + + if result.Usage.TotalTokens == 0 { + t.Error("Total tokens is 0") + } + + // Log the result for inspection + t.Logf("Successfully parsed Gemini response:") + t.Logf("Model: %s", result.Model) + t.Logf("Content: %s", result.Choices[0].Message.Content) + t.Logf("Prompt tokens: %d", result.Usage.PromptTokens) + t.Logf("Completion tokens: %d", result.Usage.CompletionTokens) + t.Logf("Total tokens: %d", result.Usage.TotalTokens) +} + +// TestGeminiFinishReasonMapping tests the mapping of Gemini finish reasons +func TestGeminiFinishReasonMapping(t *testing.T) { + testCases := []struct { + geminiReason string + expectedReason string + }{ + {"STOP", "stop"}, + {"MAX_TOKENS", "length"}, + {"SAFETY", "content_filter"}, + {"RECITATION", "content_filter"}, + {"OTHER", "stop"}, + } + + for _, tc := range testCases { + t.Run(tc.geminiReason, func(t *testing.T) { + geminiResponse := map[string]interface{}{ + "candidates": []map[string]interface{}{ + { + "content": map[string]interface{}{ + "parts": []map[string]interface{}{ + {"text": "test"}, + }, + "role": "model", + }, + "finishReason": tc.geminiReason, + "index": 0, + }, + }, + "usageMetadata": map[string]interface{}{ + "promptTokenCount": 1, + "candidatesTokenCount": 1, + "totalTokenCount": 2, + }, + } + + jsonData, err := json.Marshal(geminiResponse) + if err != nil { + t.Fatalf("Failed to marshal test data: %v", err) + } + + result, err := logging.ParseGeminiResult(jsonData) + if err != nil { + t.Fatalf("ParseGeminiResult failed: %v", err) + } + + if result.Choices[0].FinishReason != tc.expectedReason { + t.Errorf("Expected finish reason '%s', got '%s'", tc.expectedReason, result.Choices[0].FinishReason) + } + }) + } +} + diff --git a/logging/gemni.go b/logging/gemni.go new file mode 100644 index 0000000..f34168a --- /dev/null +++ b/logging/gemni.go @@ -0,0 +1,166 @@ +// Package logging provides parsers for various LLM provider API responses. +package logging + +import ( + "encoding/json" + "fmt" + "time" +) + +// ParseGeminiResult parses a JSON response from Google's Gemini API and returns a MaximLLMResult. +// It handles both native Gemini format and OpenAI-compatible format. +func ParseGeminiResult(jsonData []byte) (*MaximLLMResult, error) { + // First, try to parse as OpenAI-compatible format + var openAIResp MaximLLMResult + if err := json.Unmarshal(jsonData, &openAIResp); err == nil { + // Check if it looks like an OpenAI format (has choices array) + if len(openAIResp.Choices) > 0 { + return &openAIResp, nil + } + } + + // If not OpenAI format, parse as native Gemini format + return parseNativeGeminiResult(jsonData) +} + +// parseNativeGeminiResult parses the native Gemini API response format +func parseNativeGeminiResult(jsonData []byte) (*MaximLLMResult, error) { + var geminiResp struct { + Candidates []struct { + Content struct { + Parts []struct { + Text string `json:"text"` + FunctionCall *map[string]interface{} `json:"functionCall,omitempty"` + } `json:"parts"` + Role string `json:"role"` + } `json:"content"` + FinishReason string `json:"finishReason"` + Index int `json:"index"` + SafetyRatings []struct { + Category string `json:"category"` + Probability string `json:"probability"` + } `json:"safetyRatings,omitempty"` + } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + } `json:"usageMetadata"` + ModelVersion string `json:"modelVersion,omitempty"` + } + + if err := json.Unmarshal(jsonData, &geminiResp); err != nil { + return nil, fmt.Errorf("failed to unmarshal Gemini completion: %w", err) + } + + // Validate response has candidates + if len(geminiResp.Candidates) == 0 { + return nil, fmt.Errorf("no candidates in Gemini response") + } + + resp := MaximLLMResult{} + + // Generate an ID (Gemini doesn't provide one in native format) + resp.ID = fmt.Sprintf("gemini-%d", time.Now().UnixNano()) + + // Set model from modelVersion or use a default + if geminiResp.ModelVersion != "" { + resp.Model = geminiResp.ModelVersion + } else { + resp.Model = "gemini" + } + + // Set creation timestamp + resp.Created = time.Now().Unix() + + // Initialize choices array + resp.Choices = make([]struct { + Message struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ChatCompletionToolCall `json:"tool_calls,omitempty"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + }, len(geminiResp.Candidates)) + + // Process each candidate + for i, candidate := range geminiResp.Candidates { + // Map Gemini role to OpenAI role format + role := "assistant" + if candidate.Content.Role == "model" { + role = "assistant" + } else if candidate.Content.Role == "user" { + role = "user" + } + resp.Choices[i].Message.Role = role + + // Concatenate all text parts + var fullContent string + var toolCalls []ChatCompletionToolCall + + for partIdx, part := range candidate.Content.Parts { + if part.Text != "" { + fullContent += part.Text + } + + // Handle function calls as tool calls + if part.FunctionCall != nil { + functionCall := *part.FunctionCall + if name, ok := functionCall["name"].(string); ok { + // Convert args to JSON string + var argsJSON string + if args, ok := functionCall["args"]; ok { + argsBytes, err := json.Marshal(args) + if err == nil { + argsJSON = string(argsBytes) + } + } + + toolCall := ChatCompletionToolCall{ + ID: fmt.Sprintf("call_%d_%d", i, partIdx), + Type: "function", + Function: ToolCallFunction{ + Name: name, + Arguments: argsJSON, + }, + } + toolCalls = append(toolCalls, toolCall) + } + } + } + + resp.Choices[i].Message.Content = fullContent + if len(toolCalls) > 0 { + resp.Choices[i].Message.ToolCalls = toolCalls + } + + // Map finish reason + finishReason := mapGeminiFinishReason(candidate.FinishReason) + resp.Choices[i].FinishReason = finishReason + } + + // Set usage information + resp.Usage.PromptTokens = geminiResp.UsageMetadata.PromptTokenCount + resp.Usage.CompletionTokens = geminiResp.UsageMetadata.CandidatesTokenCount + resp.Usage.TotalTokens = geminiResp.UsageMetadata.TotalTokenCount + + return &resp, nil +} + +// mapGeminiFinishReason maps Gemini finish reasons to OpenAI format +func mapGeminiFinishReason(reason string) string { + switch reason { + case "STOP": + return "stop" + case "MAX_TOKENS": + return "length" + case "SAFETY": + return "content_filter" + case "RECITATION": + return "content_filter" + case "OTHER": + return "stop" + default: + return "stop" + } +} diff --git a/logging/generation.go b/logging/generation.go index 89f228a..d029e81 100644 --- a/logging/generation.go +++ b/logging/generation.go @@ -2,6 +2,7 @@ package logging import ( "encoding/json" + "fmt" "log" "time" ) @@ -165,6 +166,11 @@ type Generation struct { } func newGeneration(c *GenerationConfig, w *writer) *Generation { + // Validating provider + if c.Provider != ProviderOpenAI && c.Provider != ProviderAzure && c.Provider != ProviderBedrock && c.Provider != ProviderAnthropic && c.Provider != ProviderGemini { + fmt.Printf("[MaximSDK] Invalid provider %s, allowed providers are %s, %s, %s, %s, %s\n", c.Provider, ProviderOpenAI, ProviderAzure, ProviderBedrock, ProviderAnthropic, ProviderGemini) + return nil + } return &Generation{ base: newBase(EntityGeneration, c.Id, &baseConfig{ SpanId: c.SpanId, @@ -219,6 +225,11 @@ func (g *Generation) handleAnthropicResult(jsonData []byte) (*MaximLLMResult, er return ParseAnthropicResult(jsonData) } +// handleGeminiResult extracts and logs data from a Gemini completion +func (g *Generation) handleGeminiResult(jsonData []byte) (*MaximLLMResult, error) { + return ParseGeminiResult(jsonData) +} + // handleAzure extracts and logs data from an Azure OpenAI completion func (g *Generation) handleAzure(jsonData []byte, _ time.Duration) (*MaximLLMResult, error) { // Azure OpenAI has the same response format as OpenAI @@ -255,6 +266,8 @@ func (g *Generation) SetResult(r interface{}) { finalResult, err = g.handleBedrockConverseResult(jsonData) case ProviderAnthropic: finalResult, err = g.handleAnthropicResult(jsonData) + case ProviderGemini: + finalResult, err = g.handleGeminiResult(jsonData) } if err != nil { log.Println("[MaximSDK] Failed to parse result", err) diff --git a/logging/parser.go b/logging/parser.go index d42312b..e4a7f5c 100644 --- a/logging/parser.go +++ b/logging/parser.go @@ -24,6 +24,8 @@ func ParseResult(provider, model string, r interface{}) (*MaximLLMResult, error) finalResult, err = ParseBedrockResult(model, jsonData) case ProviderAnthropic: finalResult, err = ParseAnthropicResult(jsonData) + case ProviderGemini: + finalResult, err = ParseGeminiResult(jsonData) } return finalResult, err }