diff --git a/cmd/server/main.go b/cmd/server/main.go index d63cfbd343..b642e52a01 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -505,6 +505,7 @@ func main() { misc.StartAntigravityVersionUpdater(context.Background()) if !localModel { registry.StartModelsUpdater(context.Background()) + registry.StartOpenRouterEnrichment(context.Background()) } hook := tui.NewLogHook(2000) hook.SetFormatter(&logging.LogFormatter{}) @@ -581,6 +582,7 @@ func main() { misc.StartAntigravityVersionUpdater(context.Background()) if !localModel { registry.StartModelsUpdater(context.Background()) + registry.StartOpenRouterEnrichment(context.Background()) } cmd.StartService(cfg, configFilePath, password) } diff --git a/internal/api/handlers/management/model_definitions.go b/internal/api/handlers/management/model_definitions.go index 85ff314bf4..20146ad1eb 100644 --- a/internal/api/handlers/management/model_definitions.go +++ b/internal/api/handlers/management/model_definitions.go @@ -3,6 +3,7 @@ package management import ( "net/http" "strings" + "time" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" @@ -31,3 +32,93 @@ func (h *Handler) GetStaticModelDefinitions(c *gin.Context) { "models": models, }) } + +// GetModelsHealth returns comprehensive health information for all registered models. 
+func (h *Handler) GetModelsHealth(c *gin.Context) { + globalRegistry := registry.GetGlobalRegistry() + + // Get ALL registered models including suspended/unhealthy for full operator visibility + models := globalRegistry.GetAllRegisteredModels("openai") + // Build enhanced model health with suspension and provider info + enhancedModels := make([]map[string]any, 0, len(models)) + for _, model := range models { + modelID, _ := model["id"].(string) + if modelID == "" { + continue + } + // Get health details for this model + providers, suspendedClients, count := globalRegistry.GetModelHealthDetails(modelID) + // Build suspension summary + suspensionSummary := make(map[string]any) + if len(suspendedClients) > 0 { + suspensionSummary["count"] = len(suspendedClients) + suspensionSummary["clients"] = suspendedClients + // Extract unique reasons + reasons := make(map[string]bool) + for _, reason := range suspendedClients { + if reason == "" { + reasons["unknown"] = true + } else { + reasons[strings.ToLower(reason)] = true + } + } + reasonList := make([]string, 0, len(reasons)) + for reason := range reasons { + reasonList = append(reasonList, reason) + } + suspensionSummary["reasons"] = reasonList + } else { + suspensionSummary["count"] = 0 + suspensionSummary["clients"] = map[string]string{} + suspensionSummary["reasons"] = []string{} + } + + // Build provider list + providerList := make([]string, 0, len(providers)) + for provider := range providers { + providerList = append(providerList, provider) + } + // Create enhanced model entry + enhancedModel := make(map[string]any) + for k, v := range model { + enhancedModel[k] = v + } + enhancedModel["total_clients"] = count + enhancedModel["providers"] = providerList + enhancedModel["provider_counts"] = providers + enhancedModel["suspension"] = suspensionSummary + enhancedModels = append(enhancedModels, enhancedModel) + } + // Build the sources map indicating where context_length came from + sources := 
registry.BuildModelSources(globalRegistry) + // Get last refresh timestamp from OpenRouter enrichment + lastRefresh := registry.GetOpenRouterLastRefresh() + lastRefreshStr := "" + if !lastRefresh.IsZero() { + lastRefreshStr = lastRefresh.Format(time.RFC3339) + } + response := gin.H{ + "models": enhancedModels, + "sources": sources, + "last_refresh": lastRefreshStr, + "refresh_interval": "24h", + } + c.JSON(http.StatusOK, response) +} + +// RefreshModels triggers an immediate refresh of model metadata from OpenRouter and the enrichment cache. +// This re-fetches context_length from OpenRouter's public /api/v1/models endpoint and enriches registered models that lack this data. +func (h *Handler) RefreshModels(c *gin.Context) { + count := registry.TriggerOpenRouterRefresh(c.Request.Context()) + if count == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "no models registered"}) + return + } + lastRefresh := registry.GetOpenRouterLastRefresh() + c.JSON(http.StatusOK, gin.H{ + "status": "refreshed", + "enriched_count": count, + "last_refresh": lastRefresh.Format(time.RFC3339), + "total_models": len(registry.GetGlobalRegistry().GetAvailableModels("openai")), + }) +} diff --git a/internal/api/server.go b/internal/api/server.go index 2bdc4ab095..e18bce2639 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -627,6 +627,8 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/auth-files", s.mgmt.ListAuthFiles) mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels) mgmt.GET("/model-definitions/:channel", s.mgmt.GetStaticModelDefinitions) + mgmt.GET("/models/health", s.mgmt.GetModelsHealth) + mgmt.POST("/models/refresh", s.mgmt.RefreshModels) mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile) mgmt.POST("/auth-files", s.mgmt.UploadAuthFile) mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile) @@ -773,8 +775,8 @@ func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, cl return func(c *gin.Context) { userAgent := 
c.GetHeader("User-Agent") - // Route to Claude handler if User-Agent starts with "claude-cli" - if strings.HasPrefix(userAgent, "claude-cli") { + // Route to Claude handler if User-Agent starts with "claude-cli" or "claude-code" + if strings.HasPrefix(userAgent, "claude-cli") || strings.HasPrefix(userAgent, "claude-code") || strings.HasPrefix(userAgent, "claude/") { // log.Debugf("Routing /v1/models to Claude handler for User-Agent: %s", userAgent) claudeHandler.ClaudeModels(c) } else { diff --git a/internal/config/config.go b/internal/config/config.go index 15847f57e0..5ed2f075a0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,9 +14,10 @@ import ( "syscall" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" - log "github.com/sirupsen/logrus" - "golang.org/x/crypto/bcrypt" - "gopkg.in/yaml.v3" + temporal "github.com/router-for-me/CLIProxyAPI/v6/internal/temporal" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/bcrypt" + "gopkg.in/yaml.v3" ) const ( @@ -126,6 +127,10 @@ type Config struct { // Payload defines default and override rules for provider payload parameters. Payload PayloadConfig `yaml:"payload" json:"payload"` + // Temporal controls temporal anti-drift injection into LLM request payloads. + // nil means use defaults (enabled, every request). 
+ Temporal *temporal.Config `yaml:"temporal" json:"temporal"` + legacyMigrationPending bool `yaml:"-" json:"-"` } diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index 3f3f530d27..2617b7dc98 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -243,6 +243,14 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ if model == nil || model.ID == "" { continue } + if staticInfo := LookupStaticModelInfo(model.ID); staticInfo != nil { + if model.ContextLength == 0 && staticInfo.ContextLength > 0 { + model.ContextLength = staticInfo.ContextLength + } + if model.MaxCompletionTokens == 0 && staticInfo.MaxCompletionTokens > 0 { + model.MaxCompletionTokens = staticInfo.MaxCompletionTokens + } + } rawModelIDs = append(rawModelIDs, model.ID) newCounts[model.ID]++ if _, exists := newModels[model.ID]; exists { @@ -875,6 +883,25 @@ func cloneModelMapValue(value any) any { // // Returns: // - []*ModelInfo: List of available models for the provider +// GetAllRegisteredModels returns metadata for every registered model regardless of +// suspension or quota state. Use this for administrative/health endpoints that need +// visibility into ALL models, including unhealthy ones. 
+func (r *ModelRegistry) GetAllRegisteredModels(handlerType string) []map[string]any { + r.mutex.RLock() + defer r.mutex.RUnlock() + models := make([]map[string]any, 0, len(r.models)) + for _, registration := range r.models { + if registration.Info == nil { + continue + } + model := r.convertModelToMap(registration.Info, handlerType) + if model != nil { + models = append(models, model) + } + } + return models +} + func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelInfo { provider = strings.ToLower(strings.TrimSpace(provider)) if provider == "" { @@ -1158,6 +1185,14 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) if model.DisplayName != "" { result["display_name"] = model.DisplayName } + if model.ContextLength > 0 { + result["context_length"] = model.ContextLength + result["context_window_size"] = model.ContextLength + } + if model.MaxCompletionTokens > 0 { + result["max_completion_tokens"] = model.MaxCompletionTokens + result["max_output_tokens"] = model.MaxCompletionTokens + } return result case "gemini": @@ -1315,3 +1350,42 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo { } return result } + +// GetModelHealthDetails returns detailed health information for a model including +// providers, suspended clients with reasons, and availability counts. 
+func (r *ModelRegistry) GetModelHealthDetails(modelID string) (providers map[string]int, suspendedClients map[string]string, count int) { + r.mutex.RLock() + defer r.mutex.RUnlock() + + if registration, exists := r.models[modelID]; exists && registration != nil { + // Clone providers map + if len(registration.Providers) > 0 { + providers = make(map[string]int, len(registration.Providers)) + for k, v := range registration.Providers { + providers[k] = v + } + } + // Clone suspended clients map + if len(registration.SuspendedClients) > 0 { + suspendedClients = make(map[string]string, len(registration.SuspendedClients)) + for k, v := range registration.SuspendedClients { + suspendedClients[k] = v + } + } + count = registration.Count + } + return +} + +// SetModelContextLength updates the context_length on the live model registration. +// This writes through to the actual stored ModelInfo, not a clone. +func (r *ModelRegistry) SetModelContextLength(modelID string, contextLength int) bool { + r.mutex.Lock() + defer r.mutex.Unlock() + reg, ok := r.models[modelID] + if !ok || reg == nil || reg.Info == nil { + return false + } + reg.Info.ContextLength = contextLength + return true +} diff --git a/internal/registry/models/models.json b/internal/registry/models/models.json index acf368ab5f..277668d172 100644 --- a/internal/registry/models/models.json +++ b/internal/registry/models/models.json @@ -2708,6 +2708,7 @@ "display_name": "Gemini 3.1 Flash Image", "name": "gemini-3.1-flash-image", "description": "Gemini 3.1 Flash Image", + "context_length": 1048576, "thinking": { "min": 128, "max": 32768, diff --git a/internal/registry/openrouter_enrichment.go b/internal/registry/openrouter_enrichment.go new file mode 100644 index 0000000000..8971c863ee --- /dev/null +++ b/internal/registry/openrouter_enrichment.go @@ -0,0 +1,295 @@ +// Package registry provides OpenRouter context length enrichment for model metadata. 
+package registry + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + openRouterFetchTimeout = 30 * time.Second + openRouterRefreshInterval = 24 * time.Hour + openRouterModelsURL = "https://openrouter.ai/api/v1/models" +) + +var enrichmentOnce sync.Once + +// openRouterModel represents a model in OpenRouter's API response +type openRouterModel struct { + ID string `json:"id"` + Name string `json:"name"` + ContextLength int `json:"context_length"` +} + +// openRouterModelsResponse is the response from OpenRouter's models endpoint +type openRouterModelsResponse struct { + Data []openRouterModel `json:"data"` +} + +// openRouterEnrichmentStore tracks enrichment state +type openRouterEnrichmentStore struct { + mu sync.RWMutex + lastRefresh time.Time + contextLength map[string]int // model ID -> context length +} + +var openRouterStore = &openRouterEnrichmentStore{ + contextLength: make(map[string]int), +} + +// StartOpenRouterEnrichment starts a background goroutine that fetches +// context_length metadata from OpenRouter's public models endpoint. +// Runs immediately on startup and then refreshes every 24 hours. +func StartOpenRouterEnrichment(ctx context.Context) { + enrichmentOnce.Do(func() { + go runOpenRouterEnrichment(ctx) + }) +} + +func runOpenRouterEnrichment(ctx context.Context) { + // Initial fetch + fetchAndEnrichOpenRouter(ctx) + + // Periodic refresh + ticker := time.NewTicker(openRouterRefreshInterval) + defer ticker.Stop() + log.Infof("OpenRouter enrichment started (interval=%s)", openRouterRefreshInterval) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + fetchAndEnrichOpenRouter(ctx) + } + } +} + +// fetchAndEnrichOpenRouter fetches models from OpenRouter and enriches +// registered models that lack context_length metadata. +// Returns the number of models actually enriched. 
+func fetchAndEnrichOpenRouter(ctx context.Context) int { + client := &http.Client{Timeout: openRouterFetchTimeout} + reqCtx, cancel := context.WithTimeout(ctx, openRouterFetchTimeout) + req, err := http.NewRequestWithContext(reqCtx, "GET", openRouterModelsURL, nil) + if err != nil { + cancel() + log.Debugf("OpenRouter enrichment: request creation failed: %v", err) + return 0 + } + + resp, err := client.Do(req) + if err != nil { + cancel() + log.Debugf("OpenRouter enrichment: fetch failed: %v", err) + return 0 + } + + if resp.StatusCode != 200 { + resp.Body.Close() + cancel() + log.Debugf("OpenRouter enrichment: returned status %d", resp.StatusCode) + return 0 + } + + data, err := io.ReadAll(resp.Body) + resp.Body.Close() + cancel() + + if err != nil { + log.Debugf("OpenRouter enrichment: read error: %v", err) + return 0 + } + + var parsed openRouterModelsResponse + if err := json.Unmarshal(data, &parsed); err != nil { + log.Warnf("OpenRouter enrichment: parse failed: %v", err) + return 0 + } + + return enrichModelsFromOpenRouter(parsed.Data) +} + +// enrichModelsFromOpenRouter updates the global registry with context_length +// values from OpenRouter for models that lack this metadata. +// Matches by exact model ID or by checking if the local ID is contained in +// or contains the OpenRouter ID (e.g., "gemini-3.1-pro" matches "google/gemini-3.1-pro-preview"). +// Returns the number of models actually enriched. 
+func enrichModelsFromOpenRouter(models []openRouterModel) int { + enrichedCount := 0 + registry := GetGlobalRegistry() + + // Build a map of model ID to context length from OpenRouter + openRouterContextLengths := make(map[string]int, len(models)) + for _, m := range models { + if m.ContextLength > 0 { + openRouterContextLengths[m.ID] = m.ContextLength + } + } + + // Get all registered models and enrich those lacking context_length + allModels := registry.GetAvailableModels("openai") + for _, modelMap := range allModels { + modelID, _ := modelMap["id"].(string) + if modelID == "" { + continue + } + + // Skip if already has context_length + if _, hasCL := modelMap["context_length"]; hasCL { + continue + } + + // Try to find a matching OpenRouter model ID + var ctxLen int + var found bool + + // First try exact match + if cl, ok := openRouterContextLengths[modelID]; ok { + ctxLen = cl + found = true + } else { + // Try substring matching: + // - Check if local ID is contained in OpenRouter ID (e.g., "gemini-3.1-pro" in "google/gemini-3.1-pro-preview") + // - Check if OpenRouter ID is contained in local ID + // - Check if OpenRouter ID suffix matches local ID + for orID, cl := range openRouterContextLengths { + if strings.Contains(orID, modelID) || strings.Contains(modelID, orID) { + // Extract the base name from OpenRouter ID (after last slash) + orBase := orID + if slashIdx := strings.LastIndex(orID, "/"); slashIdx >= 0 { + orBase = orID[slashIdx+1:] + } + // Check if local ID matches the base or is a prefix/suffix + if orBase == modelID || strings.HasPrefix(orBase, modelID) || strings.HasPrefix(modelID, orBase) { + ctxLen = cl + found = true + break + } + } + } + } + + if found { + // Update the live model registration (not a clone) + if registry.SetModelContextLength(modelID, ctxLen) { + enrichedCount++ + log.Debugf("enriched model %s with context_length=%d from openrouter", modelID, ctxLen) + } + } + } + + // Update store + openRouterStore.mu.Lock() + for 
modelID, ctxLen := range openRouterContextLengths { + openRouterStore.contextLength[modelID] = ctxLen + } + openRouterStore.lastRefresh = time.Now() + openRouterStore.mu.Unlock() + + if enrichedCount > 0 { + log.Infof("OpenRouter enrichment: enriched %d models with context_length", enrichedCount) + } + + return enrichedCount +} + +// GetOpenRouterContextLength returns the cached context_length for a model from OpenRouter. +// Returns 0 if not found. +func GetOpenRouterContextLength(modelID string) int { + openRouterStore.mu.RLock() + defer openRouterStore.mu.RUnlock() + return openRouterStore.contextLength[modelID] +} + +// TriggerOpenRouterRefresh forces an immediate refresh of OpenRouter context enrichment. +// Returns the number of models actually enriched (not the cache size). +func TriggerOpenRouterRefresh(ctx context.Context) int { + return fetchAndEnrichOpenRouter(ctx) +} + +// GetOpenRouterLastRefresh returns the last refresh time. +func GetOpenRouterLastRefresh() time.Time { + openRouterStore.mu.RLock() + defer openRouterStore.mu.RUnlock() + return openRouterStore.lastRefresh +} + +// GetOpenRouterContextLengthSource returns "openrouter" if the context_length +// for a model came from OpenRouter enrichment, empty string otherwise. +func GetOpenRouterContextLengthSource(modelID string) string { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + return "" + } + openRouterStore.mu.RLock() + defer openRouterStore.mu.RUnlock() + if _, ok := openRouterStore.contextLength[modelID]; ok { + return "openrouter" + } + return "" +} + +// GetOpenRouterEnrichedModels returns all model IDs that have been enriched +// by OpenRouter with their context_length values. 
+func GetOpenRouterEnrichedModels() map[string]int { + openRouterStore.mu.RLock() + defer openRouterStore.mu.RUnlock() + result := make(map[string]int, len(openRouterStore.contextLength)) + for k, v := range openRouterStore.contextLength { + result[k] = v + } + return result +} + +// BuildModelSources constructs a sources map indicating where each model's +// context_length originated (static, openrouter, provider, etc.). +func BuildModelSources(registry *ModelRegistry) map[string]string { + if registry == nil { + return nil + } + + sources := make(map[string]string) + registry.mutex.RLock() + defer registry.mutex.RUnlock() + + for modelID, registration := range registry.models { + if registration == nil || registration.Info == nil { + continue + } + + // Determine source based on model ID prefix and enrichment state + switch { + case strings.HasPrefix(modelID, "claude-"): + if GetOpenRouterContextLengthSource(modelID) != "" { + sources[modelID] = "openrouter" + } else { + sources[modelID] = "static" + } + case strings.HasPrefix(modelID, "gemini-"), strings.HasPrefix(modelID, "models/gemini-"): + sources[modelID] = "static" + case strings.HasPrefix(modelID, "gpt-"), strings.HasPrefix(modelID, "chatgpt-"), strings.HasPrefix(modelID, "o1-"): + if GetOpenRouterContextLengthSource(modelID) != "" { + sources[modelID] = "openrouter" + } else { + sources[modelID] = "provider" + } + default: + // For OpenRouter-hosted models, check enrichment + if GetOpenRouterContextLengthSource(modelID) != "" { + sources[modelID] = "openrouter" + } else { + sources[modelID] = "provider" + } + } + } + + return sources +} diff --git a/internal/temporal/temporal.go b/internal/temporal/temporal.go new file mode 100644 index 0000000000..257e22cb43 --- /dev/null +++ b/internal/temporal/temporal.go @@ -0,0 +1,157 @@ +package temporal + +import ( + "bytes" + "fmt" + "strings" + "sync/atomic" + "time" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// Config controls temporal 
// injection behavior.
//
// Enabled switches injection on or off. InjectInterval selects every Nth
// request; values <= 0 mean every request.
type Config struct {
	Enabled        bool `yaml:"enabled" json:"enabled"`
	InjectInterval int  `yaml:"inject_interval" json:"inject_interval"` // 0 = every request
}

// DefaultConfig returns sensible defaults: enabled, inject every request.
func DefaultConfig() Config {
	return Config{Enabled: true, InjectInterval: 0}
}

// requestCounter tracks how many requests have been seen for interval-based
// injection. It is a single package-level counter and must only be accessed
// through sync/atomic.
var requestCounter uint64

// ShouldInject returns true if a temporal tag should be injected for this request.
// When the interval is 0 (or negative), it injects every time. Otherwise it
// injects on every Nth call that reaches the counter; calls with injection
// disabled or with a non-positive interval do not advance the counter.
// Thread-safe: the counter is incremented and read atomically.
func ShouldInject(cfg Config) bool {
	if !cfg.Enabled {
		return false
	}
	if cfg.InjectInterval <= 0 {
		return true
	}
	newVal := atomic.AddUint64(&requestCounter, 1)
	return newVal%uint64(cfg.InjectInterval) == 0
}

// BuildTemporalTag generates a temporal XML tag with current date/time metadata.
// Ported from pi-rs/crates/pi-agent/src/system_prompt.rs TemporalDriftDetector pattern.
//
// The first group of values is formatted from now in its own location; the
// local_* values are formatted after converting now to the process-local zone.
//
// NOTE(review): the format string was empty in the reviewed text (the markup
// appears to have been stripped), which made fmt.Sprintf emit "%!(EXTRA ...)".
// The layout below is a reconstruction that keeps the original argument order;
// confirm the tag and attribute names against the upstream source and against
// the duplicate-tag check in InjectIntoPayload.
func BuildTemporalTag(now time.Time) string {
	local := now.Local()
	_, week := now.ISOWeek()
	return fmt.Sprintf(
		`<temporal day="%s" date="%s" time="%s" iso="%s" unix="%d" week="%d" day_of_year="%d" local_day="%s" local_date="%s" local_time="%s" timezone="%s"/>`,
		// time.Weekday.String yields the English day name and, unlike a
		// slice lookup, cannot panic on an out-of-range value.
		now.Weekday().String(),
		now.Format("2006-01-02"),
		now.Format("15:04:05"),
		now.Format(time.RFC3339),
		now.Unix(),
		week,
		now.YearDay(),
		local.Weekday().String(),
		local.Format("2006-01-02"),
		local.Format("15:04:05"),
		local.Location().String(),
	)
}

// IsImageModel returns true if the model name indicates an image generation model
// that uses a different request format where prepending system messages would break
// prompt extraction. The check is a case-insensitive substring match on
// "imagen" and "image-gen".
func IsImageModel(model string) bool {
	lower := strings.ToLower(model)
	return strings.Contains(lower, "imagen") || strings.Contains(lower, "image-gen")
}

// InjectIntoPayload injects temporal context into the request payload.
+// It auto-detects the payload format and injects appropriately: +// - Claude format (top-level "system" field): prepends to the system array +// - OpenAI format ("messages" array, no top-level "system"): prepends system message +// - Fallback: returns payload unchanged +// Skips injection for image generation models (imagen, image-gen). +func InjectIntoPayload(payload []byte, model string) []byte { + if IsImageModel(model) { + return payload + } + // Skip if payload already contains a temporal tag (e.g. from Claude Code's + // built-in injection). Avoids duplicate/conflicting temporal context. + if bytes.Contains(payload, []byte("%s", tag) + + // Claude format: top-level "system" field exists + if system := gjson.GetBytes(payload, "system"); system.Exists() { + return injectIntoClaudeSystem(payload, system, text) + } + + // OpenAI format: messages array exists + if messages := gjson.GetBytes(payload, "messages"); messages.IsArray() { + return injectIntoMessages(payload, messages, text) + } + + return payload +} + +// injectIntoClaudeSystem prepends a temporal text block to the Claude "system" field. +// Handles both array and scalar (string) system values. 
+func injectIntoClaudeSystem(payload []byte, system gjson.Result, text string) []byte { + temporalBlock := map[string]any{ + "type": "text", + "text": text, + } + var newSystem []any + newSystem = append(newSystem, temporalBlock) + + if system.IsArray() { + // Modern Claude format: system is an array of content blocks + system.ForEach(func(_, v gjson.Result) bool { + newSystem = append(newSystem, v.Value()) + return true + }) + } else { + // Legacy Claude format: system is a plain string + // Convert to structured format so all blocks are homogeneous + newSystem = append(newSystem, map[string]any{ + "type": "text", + "text": system.String(), + }) + } + + result, err := sjson.SetBytes(payload, "system", newSystem) + if err != nil { + return payload + } + return result +} + +func injectIntoMessages(payload []byte, messages gjson.Result, text string) []byte { + // Use plain string content for broad compatibility. + // Some providers reject structured content arrays in system messages. + temporalMsg := map[string]any{ + "role": "system", + "content": text, + } + var newMessages []any + newMessages = append(newMessages, temporalMsg) + messages.ForEach(func(_, v gjson.Result) bool { + newMessages = append(newMessages, v.Value()) + return true + }) + result, err := sjson.SetBytes(payload, "messages", newMessages) + if err != nil { + return payload + } + return result +} + +func dayName(w time.Weekday) string { + names := []string{"Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"} + return names[w] +} diff --git a/sdk/api/handlers/claude/code_handlers.go b/sdk/api/handlers/claude/code_handlers.go index 074ffc0d07..0a4f9e672c 100644 --- a/sdk/api/handlers/claude/code_handlers.go +++ b/sdk/api/handlers/claude/code_handlers.go @@ -130,6 +130,14 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) { // - c: The Gin context for the request. 
func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) { models := h.Models() + for _, m := range models { + if ctxLen, ok := m["context_length"]; ok { + m["context_window_size"] = ctxLen + } + if maxTok, ok := m["max_completion_tokens"]; ok { + m["max_output_tokens"] = maxTok + } + } firstID := "" lastID := "" if len(models) > 0 { diff --git a/sdk/api/handlers/openai/openai_handlers.go b/sdk/api/handlers/openai/openai_handlers.go index 4b4a9833bd..6b79f8d19e 100644 --- a/sdk/api/handlers/openai/openai_handlers.go +++ b/sdk/api/handlers/openai/openai_handlers.go @@ -59,34 +59,16 @@ func (h *OpenAIAPIHandler) Models() []map[string]any { // It returns a list of available AI models with their capabilities // and specifications in OpenAI-compatible format. func (h *OpenAIAPIHandler) OpenAIModels(c *gin.Context) { - // Get all available models allModels := h.Models() - - // Filter to only include the 4 required fields: id, object, created, owned_by - filteredModels := make([]map[string]any, len(allModels)) - for i, model := range allModels { - filteredModel := map[string]any{ - "id": model["id"], - "object": model["object"], - } - - // Add created field if it exists - if created, exists := model["created"]; exists { - filteredModel["created"] = created + for _, m := range allModels { + if ctxLen, ok := m["context_length"]; ok { + m["context_window_size"] = ctxLen } - - // Add owned_by field if it exists - if ownedBy, exists := model["owned_by"]; exists { - filteredModel["owned_by"] = ownedBy + if maxTok, ok := m["max_completion_tokens"]; ok { + m["max_output_tokens"] = maxTok } - - filteredModels[i] = filteredModel } - - c.JSON(http.StatusOK, gin.H{ - "object": "list", - "data": filteredModels, - }) + c.JSON(http.StatusOK, gin.H{"object": "list", "data": allModels}) } // ChatCompletions handles the /v1/chat/completions endpoint. 
diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 478c7921ff..b8ceb21165 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -19,6 +19,7 @@ import ( internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/temporal" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" @@ -277,6 +278,16 @@ func (m *Manager) SetConfig(cfg *internalconfig.Config) { m.rebuildAPIKeyModelAliasFromRuntimeConfig() } +// temporalConfig returns the temporal injection settings from the current runtime config. +// Falls back to DefaultConfig when no config is loaded or the temporal block is omitted. +func (m *Manager) temporalConfig() temporal.Config { + cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config) + if cfg != nil && cfg.Temporal != nil { + return *cfg.Temporal + } + return temporal.DefaultConfig() +} + func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) string { if m == nil { return "" @@ -733,6 +744,11 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi resultModel := m.stateModelForExecution(auth, routeModel, execModel, pooled) execReq := req execReq.Model = execModel + // Inject temporal context after model resolution so image-model + // safeguards see the actual upstream model, not a route alias. 
+ if temporal.ShouldInject(m.temporalConfig()) { + execReq.Payload = temporal.InjectIntoPayload(execReq.Payload, execModel) + } streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts) if errStream != nil { if errCtx := ctx.Err(); errCtx != nil { @@ -1213,6 +1229,11 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req resultModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled) execReq := req execReq.Model = upstreamModel + // Inject temporal context after model resolution so image-model + // safeguards see the actual upstream model, not a route alias. + if temporal.ShouldInject(m.temporalConfig()) { + execReq.Payload = temporal.InjectIntoPayload(execReq.Payload, upstreamModel) + } resp, errExec := executor.Execute(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: errExec == nil} if errExec != nil {