diff --git a/.gitignore b/.gitignore index 48172982cb..96e93332df 100644 --- a/.gitignore +++ b/.gitignore @@ -85,6 +85,7 @@ temp/ .cache/ .dev/ .serena/ +.sisyphus/ # =================== # 构建产物 @@ -129,4 +130,4 @@ deploy/docker-compose.override.yml .gocache/ vite.config.js docs/* -.serena/ \ No newline at end of file +.serena/ diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index d9ff788e21..b8f6577a8c 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -69,6 +69,7 @@ func provideCleanup( opsScheduledReport *service.OpsScheduledReportService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, + copilotModelRefresh *service.CopilotModelRefreshService, accountExpiry *service.AccountExpiryService, subscriptionExpiry *service.SubscriptionExpiryService, usageCleanup *service.UsageCleanupService, @@ -135,6 +136,12 @@ func provideCleanup( tokenRefresh.Stop() return nil }}, + {"CopilotModelRefreshService", func() error { + if copilotModelRefresh != nil { + copilotModelRefresh.Stop() + } + return nil + }}, {"AccountExpiryService", func() error { accountExpiry.Stop() return nil diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 5ccd797e95..043558f04a 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -127,15 +127,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityCache := repository.NewIdentityCache(redisClient) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) + gitHubCopilotTokenProvider := service.NewGitHubCopilotTokenProvider(geminiTokenCache, httpUpstream) gatewayCache := repository.NewGatewayCache(redisClient) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) - accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) + accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, gitHubCopilotTokenProvider, antigravityGatewayService, httpUpstream, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) + gitHubDeviceSessionStore := repository.NewGitHubDeviceSessionStore(redisClient) + gitHubDeviceAuthService := service.NewGitHubDeviceAuthService(gitHubDeviceSessionStore, httpUpstream) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) - accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator) + accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, gitHubDeviceAuthService, gitHubCopilotTokenProvider, sessionLimitCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) oAuthHandler := admin.NewOAuthHandler(oAuthService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) @@ -155,9 +158,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, gitHubCopilotTokenProvider, sessionLimitCache, digestSessionStore) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, gitHubCopilotTokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService) @@ -180,8 +183,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig) + openAIMessagesCompatService := service.NewOpenAIMessagesCompatService(openAIGatewayService) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, openAIMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, pricingService, configConfig) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, gatewayService, geminiMessagesCompatService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler) @@ -196,9 +200,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) + copilotModelRefreshService := service.ProvideCopilotModelRefreshService(accountRepository, gitHubCopilotTokenProvider, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, copilotModelRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ Server: httpServer, Cleanup: v, @@ -230,6 +235,7 @@ func provideCleanup( opsScheduledReport *service.OpsScheduledReportService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, + copilotModelRefresh *service.CopilotModelRefreshService, accountExpiry *service.AccountExpiryService, subscriptionExpiry *service.SubscriptionExpiryService, usageCleanup *service.UsageCleanupService, @@ -295,6 +301,12 @@ func provideCleanup( tokenRefresh.Stop() return nil }}, + {"CopilotModelRefreshService", func() error { + if copilotModelRefresh != nil { + copilotModelRefresh.Stop() + } + return nil + }}, {"AccountExpiryService", func() error { accountExpiry.Stop() return nil diff --git a/backend/go.mod b/backend/go.mod index 08d54b91aa..30a0041cef 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -75,6 +75,7 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect @@ -144,6 +145,7 @@ require ( golang.org/x/mod v0.31.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect + golang.org/x/tools v0.40.0 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 71e8f5048a..f6fdb85172 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -116,6 +116,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 91437ba8d4..932566918b 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -38,31 +38,32 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - Ops OpsConfig `mapstructure:"ops"` - JWT JWTConfig `mapstructure:"jwt"` - Totp TotpConfig `mapstructure:"totp"` - LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` - Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` - DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` - UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + Ops OpsConfig `mapstructure:"ops"` + JWT JWTConfig `mapstructure:"jwt"` + Totp TotpConfig `mapstructure:"totp"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` + DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` + UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + CopilotModelRefresh CopilotModelRefreshConfig `mapstructure:"copilot_model_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } type GeminiConfig struct { @@ -128,6 +129,12 @@ type TokenRefreshConfig struct { RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` } +type CopilotModelRefreshConfig struct { + Enabled bool `mapstructure:"enabled"` + CheckIntervalMinutes int `mapstructure:"check_interval_minutes"` + RequestTimeoutSeconds int `mapstructure:"request_timeout_seconds"` +} + type PricingConfig struct { // 价格数据远程URL(默认使用LiteLLM镜像) RemoteURL string `mapstructure:"remote_url"` @@ -724,6 +731,8 @@ func setDefaults() { viper.SetDefault("security.url_allowlist.upstream_hosts", []string{ "api.openai.com", "api.anthropic.com", + "*.githubcopilot.com", + "api.github.com", "api.kimi.com", "open.bigmodel.cn", "api.minimaxi.com", @@ -923,6 +932,10 @@ func setDefaults() { viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 + viper.SetDefault("copilot_model_refresh.enabled", true) + viper.SetDefault("copilot_model_refresh.check_interval_minutes", 360) + viper.SetDefault("copilot_model_refresh.request_timeout_seconds", 30) + // Gemini OAuth - configure via environment variables or config file // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET // Default: uses Gemini CLI public credentials (set via environment) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 05b5adc104..32b55415c7 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -16,14 +16,56 @@ const ( RoleUser = "user" ) -// Platform constants +// Platform constants (API protocol type) const ( PlatformAnthropic = "anthropic" PlatformOpenAI = "openai" + PlatformCopilot = "copilot" + PlatformAggregator = "aggregator" PlatformGemini = "gemini" PlatformAntigravity = "antigravity" ) +// Provider constants (actual service source) +// Provider identifies the upstream service provider for a model, +// enabling differentiation of context windows and pricing for the same model name. +const ( + ProviderOpenAI = "openai" // OpenAI 官方 API + ProviderAzure = "azure" // Azure OpenAI + ProviderCopilot = "copilot" // GitHub Copilot + ProviderAnthropic = "anthropic" // Anthropic 官方 API + ProviderGemini = "gemini" // Google Gemini 官方 API + ProviderVertexAI = "vertex" // Google Vertex AI + ProviderAntigravity = "antigravity" // Antigravity 服务 + ProviderBedrock = "bedrock" // AWS Bedrock + ProviderOpenRouter = "openrouter" // OpenRouter 聚合 + ProviderAggregator = "aggregator" // 通用聚合器 +) + +// ProviderToPlatform maps provider to the API protocol (platform) it uses. +// This enables automatic platform inference from provider namespace. +var ProviderToPlatform = map[string]string{ + ProviderOpenAI: PlatformOpenAI, + ProviderAzure: PlatformOpenAI, + ProviderCopilot: PlatformCopilot, + ProviderAnthropic: PlatformAnthropic, + ProviderGemini: PlatformGemini, + ProviderVertexAI: PlatformGemini, + ProviderAntigravity: PlatformAntigravity, + ProviderBedrock: PlatformOpenAI, // Bedrock uses OpenAI-compatible format + ProviderOpenRouter: PlatformOpenAI, + ProviderAggregator: PlatformAggregator, +} + +// GetPlatformFromProvider returns the platform (API protocol) for a given provider. +// Returns empty string if provider is unknown. +func GetPlatformFromProvider(provider string) string { + if platform, ok := ProviderToPlatform[provider]; ok { + return platform + } + return "" +} + // Account type constants const ( AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) diff --git a/backend/internal/handler/admin/account_data_handler_test.go b/backend/internal/handler/admin/account_data_handler_test.go index c8b04c2aec..e516b275f2 100644 --- a/backend/internal/handler/admin/account_data_handler_test.go +++ b/backend/internal/handler/admin/account_data_handler_test.go @@ -64,6 +64,8 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) { nil, nil, nil, + nil, + nil, ) router.GET("/api/v1/admin/accounts/data", h.ExportData) diff --git a/backend/internal/handler/admin/account_github_device_auth_test.go b/backend/internal/handler/admin/account_github_device_auth_test.go new file mode 100644 index 0000000000..80b8c400ea --- /dev/null +++ b/backend/internal/handler/admin/account_github_device_auth_test.go @@ -0,0 +1,162 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type githubDeviceAuthAdminService struct { + *stubAdminService + account service.Account + updatedInput *service.UpdateAccountInput +} + +func (s *githubDeviceAuthAdminService) GetAccount(_ context.Context, id int64) (*service.Account, error) { + acc := s.account + acc.ID = id + return &acc, nil +} + +func (s *githubDeviceAuthAdminService) UpdateAccount(_ context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) { + s.updatedInput = input + acc := s.account + acc.ID = id + if input != nil { + if input.Credentials != nil { + acc.Credentials = input.Credentials + } + if strings.TrimSpace(input.Name) != "" { + acc.Name = input.Name + } + } + return &acc, nil +} + +type fakeGitHubHTTPUpstreamForAdminHandler struct{} + +func (f *fakeGitHubHTTPUpstreamForAdminHandler) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + switch req.URL.String() { + case "https://github.com/login/device/code": + body := `{"device_code":"dc1","user_code":"uc1","verification_uri":"https://github.com/login/device","verification_uri_complete":"https://github.com/login/device?user_code=uc1","expires_in":900,"interval":5}` + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(body))}, nil + case "https://github.com/login/oauth/access_token": + body := `{"access_token":"gho_xxx","token_type":"bearer","scope":"read:user"}` + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(body))}, nil + default: + return &http.Response{StatusCode: 404, Body: io.NopCloser(strings.NewReader(`{"error":"not_found"}`))}, nil + } +} + +func (f *fakeGitHubHTTPUpstreamForAdminHandler) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, _ bool) (*http.Response, error) { + return f.Do(req, proxyURL, accountID, accountConcurrency) +} + +func setupAccountGitHubDeviceAuthRouter(t *testing.T) (*gin.Engine, *githubDeviceAuthAdminService) { + gin.SetMode(gin.TestMode) + router := gin.New() + + adminSvc := &githubDeviceAuthAdminService{ + stubAdminService: newStubAdminService(), + account: service.Account{ + ID: 3, + Name: "copilot", + Type: service.AccountTypeAPIKey, + Status: service.StatusActive, + Credentials: map[string]any{ + "base_url": "https://api.githubcopilot.com", + "api_key": "sk-test", + }, + }, + } + + store := service.NewInMemoryGitHubDeviceSessionStore() + upstream := &fakeGitHubHTTPUpstreamForAdminHandler{} + deviceAuth := service.NewGitHubDeviceAuthService(store, upstream) + + accountHandler := NewAccountHandler( + adminSvc, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + deviceAuth, + nil, + nil, + nil, + ) + + router.POST("/api/v1/admin/accounts/:id/github/device/start", accountHandler.StartGitHubDeviceAuth) + router.POST("/api/v1/admin/accounts/:id/github/device/poll", accountHandler.PollGitHubDeviceAuth) + router.POST("/api/v1/admin/accounts/:id/github/device/cancel", accountHandler.CancelGitHubDeviceAuth) + + return router, adminSvc +} + +func TestAccountGitHubDeviceAuth_StartPollStoresToken(t *testing.T) { + router, adminSvc := setupAccountGitHubDeviceAuthRouter(t) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/github/device/start", bytes.NewBufferString(`{}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var startEnv struct { + Code int `json:"code"` + Data struct { + SessionID string `json:"session_id"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &startEnv)) + require.Equal(t, 0, startEnv.Code) + require.NotEmpty(t, startEnv.Data.SessionID) + pollBody, _ := json.Marshal(map[string]any{"session_id": startEnv.Data.SessionID}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/github/device/poll", bytes.NewReader(pollBody)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + require.NotNil(t, adminSvc.updatedInput) + gh, ok := adminSvc.updatedInput.Credentials["github_token"].(string) + require.True(t, ok) + require.Equal(t, "gho_xxx", gh) +} + +func TestAccountGitHubDeviceAuth_Cancel(t *testing.T) { + router, _ := setupAccountGitHubDeviceAuthRouter(t) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/github/device/start", bytes.NewBufferString(`{}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var startEnv struct { + Code int `json:"code"` + Data struct { + SessionID string `json:"session_id"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &startEnv)) + require.NotEmpty(t, startEnv.Data.SessionID) + cancelBody, _ := json.Marshal(map[string]any{"session_id": startEnv.Data.SessionID}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/github/device/cancel", bytes.NewReader(cancelBody)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 85400c6fa3..76a0289ee0 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -43,6 +43,8 @@ type AccountHandler struct { rateLimitService *service.RateLimitService accountUsageService *service.AccountUsageService accountTestService *service.AccountTestService + githubDeviceAuthService *service.GitHubDeviceAuthService + githubCopilotToken *service.GitHubCopilotTokenProvider concurrencyService *service.ConcurrencyService crsSyncService *service.CRSSyncService sessionLimitCache service.SessionLimitCache @@ -61,6 +63,8 @@ func NewAccountHandler( accountTestService *service.AccountTestService, concurrencyService *service.ConcurrencyService, crsSyncService *service.CRSSyncService, + githubDeviceAuthService *service.GitHubDeviceAuthService, + githubCopilotTokenProvider *service.GitHubCopilotTokenProvider, sessionLimitCache service.SessionLimitCache, tokenCacheInvalidator service.TokenCacheInvalidator, ) *AccountHandler { @@ -73,6 +77,8 @@ func NewAccountHandler( rateLimitService: rateLimitService, accountUsageService: accountUsageService, accountTestService: accountTestService, + githubDeviceAuthService: githubDeviceAuthService, + githubCopilotToken: githubCopilotTokenProvider, concurrencyService: concurrencyService, crsSyncService: crsSyncService, sessionLimitCache: sessionLimitCache, @@ -457,6 +463,173 @@ func (h *AccountHandler) Test(c *gin.Context) { } } +type GitHubDeviceAuthStartRequest struct { + ClientID string `json:"client_id"` + Scope string `json:"scope"` +} + +func (h *AccountHandler) StartGitHubDeviceAuth(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + if h.githubDeviceAuthService == nil { + response.InternalError(c, "GitHub device auth service not configured") + return + } + + var req GitHubDeviceAuthStartRequest + _ = c.ShouldBindJSON(&req) + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + if account == nil { + response.BadRequest(c, "Account not found") + return + } + if account.Type != service.AccountTypeAPIKey { + response.BadRequest(c, "Device auth only supports APIKey accounts") + return + } + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + if !service.IsGitHubCopilotBaseURL(baseURL) { + response.BadRequest(c, "Account base_url is not a GitHub Copilot endpoint") + return + } + + result, err := h.githubDeviceAuthService.Start(c.Request.Context(), account, req.ClientID, req.Scope) + if err != nil { + response.InternalError(c, "Start device auth failed: "+err.Error()) + return + } + response.Success(c, result) +} + +type GitHubDeviceAuthPollRequest struct { + SessionID string `json:"session_id" binding:"required"` +} + +func (h *AccountHandler) PollGitHubDeviceAuth(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + if h.githubDeviceAuthService == nil { + response.InternalError(c, "GitHub device auth service not configured") + return + } + + var req GitHubDeviceAuthPollRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + pollResult, err := h.githubDeviceAuthService.Poll(c.Request.Context(), accountID, req.SessionID) + if err != nil { + response.InternalError(c, "Poll device auth failed: "+err.Error()) + return + } + if pollResult == nil { + response.InternalError(c, "Poll device auth failed: empty result") + return + } + if pollResult.Status != "success" { + response.Success(c, pollResult) + return + } + if strings.TrimSpace(pollResult.AccessToken) == "" { + response.InternalError(c, "Poll device auth failed: empty access_token") + return + } + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + if account == nil { + response.BadRequest(c, "Account not found") + return + } + + newCredentials := make(map[string]any) + for k, v := range account.Credentials { + newCredentials[k] = v + } + newCredentials["github_token"] = strings.TrimSpace(pollResult.AccessToken) + + newExtra := make(map[string]any) + for k, v := range account.Extra { + newExtra[k] = v + } + if h.githubCopilotToken != nil { + acc := *account + acc.Credentials = newCredentials + acc.Extra = newExtra + if models, err := h.githubCopilotToken.ListModels(c.Request.Context(), &acc); err == nil && len(models) > 0 { + ids := make([]string, 0, len(models)) + for _, m := range models { + if v := strings.TrimSpace(m.ID); v != "" { + ids = append(ids, v) + } + } + if len(ids) > 0 { + now := time.Now().Format(time.RFC3339) + newExtra[service.AccountExtraKeyAvailableModels] = ids + newExtra[service.AccountExtraKeyAvailableModelsUpdatedAt] = now + newExtra[service.AccountExtraKeyAvailableModelsSource] = "github_copilot" + } + } + } + + updated, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ + Credentials: newCredentials, + Extra: newExtra, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if h.githubCopilotToken != nil { + h.githubCopilotToken.Invalidate(c.Request.Context(), updated) + } + response.Success(c, dto.AccountFromService(updated)) +} + +type GitHubDeviceAuthCancelRequest struct { + SessionID string `json:"session_id" binding:"required"` +} + +func (h *AccountHandler) CancelGitHubDeviceAuth(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + if h.githubDeviceAuthService == nil { + response.InternalError(c, "GitHub device auth service not configured") + return + } + + var req GitHubDeviceAuthCancelRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if ok := h.githubDeviceAuthService.Cancel(c.Request.Context(), accountID, req.SessionID); !ok { + response.BadRequest(c, "Session not found") + return + } + response.Success(c, gin.H{"message": "Device auth session cancelled"}) +} + // SyncFromCRS handles syncing accounts from claude-relay-service (CRS) // POST /api/v1/admin/accounts/sync/crs func (h *AccountHandler) SyncFromCRS(c *gin.Context) { @@ -1212,18 +1385,79 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { return } - // Handle OpenAI accounts - if account.IsOpenAI() { + platformPrefix := strings.TrimSpace(account.Platform) + if service.IsGitHubCopilotAccount(account) { + platformPrefix = service.PlatformCopilot + } + prefix := strings.TrimSpace(platformPrefix) + if prefix != "" { + prefix = prefix + "/" + } + + if account.Platform == service.PlatformOpenAI || account.Platform == service.PlatformCopilot || account.Platform == service.PlatformAggregator || service.IsGitHubCopilotAccount(account) { // For OAuth accounts: return default OpenAI models if account.IsOAuth() { - response.Success(c, openai.DefaultModels) + out := make([]openai.Model, 0, len(openai.DefaultModels)) + for _, m := range openai.DefaultModels { + mm := m + mm.ID = prefix + mm.ID + out = append(out, mm) + } + response.Success(c, out) + return + } + + if account.Platform == service.PlatformCopilot || service.IsGitHubCopilotAccount(account) { + if ids := account.GetAvailableModels(); len(ids) > 0 { + defaults := make(map[string]openai.Model, len(openai.DefaultModels)) + for _, dm := range openai.DefaultModels { + defaults[dm.ID] = dm + } + models := make([]openai.Model, 0, len(ids)) + for _, id := range ids { + if dm, ok := defaults[id]; ok { + m := dm + m.ID = prefix + m.ID + models = append(models, m) + continue + } + models = append(models, openai.Model{ID: prefix + id, Object: "model", Type: "model", DisplayName: id}) + } + response.Success(c, models) + return + } + if h.githubCopilotToken != nil { + if models, err := h.githubCopilotToken.ListModels(c.Request.Context(), account); err == nil && len(models) > 0 { + out := make([]openai.Model, 0, len(models)) + for _, m := range models { + mm := m + mm.ID = prefix + mm.ID + out = append(out, mm) + } + response.Success(c, out) + return + } + } + out := make([]openai.Model, 0, len(openai.DefaultModels)) + for _, m := range openai.DefaultModels { + mm := m + mm.ID = prefix + mm.ID + out = append(out, mm) + } + response.Success(c, out) return } // For API Key accounts: check model_mapping mapping := account.GetModelMapping() if len(mapping) == 0 { - response.Success(c, openai.DefaultModels) + out := make([]openai.Model, 0, len(openai.DefaultModels)) + for _, m := range openai.DefaultModels { + mm := m + mm.ID = prefix + mm.ID + out = append(out, mm) + } + response.Success(c, out) return } @@ -1233,14 +1467,16 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { var found bool for _, dm := range openai.DefaultModels { if dm.ID == requestedModel { - models = append(models, dm) + m := dm + m.ID = prefix + m.ID + models = append(models, m) found = true break } } if !found { models = append(models, openai.Model{ - ID: requestedModel, + ID: prefix + requestedModel, Object: "model", Type: "model", DisplayName: requestedModel, @@ -1255,14 +1491,26 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { if account.IsGemini() { // For OAuth accounts: return default Gemini models if account.IsOAuth() { - response.Success(c, geminicli.DefaultModels) + out := make([]geminicli.Model, 0, len(geminicli.DefaultModels)) + for _, m := range geminicli.DefaultModels { + mm := m + mm.ID = prefix + mm.ID + out = append(out, mm) + } + response.Success(c, out) return } // For API Key accounts: return models based on model_mapping mapping := account.GetModelMapping() if len(mapping) == 0 { - response.Success(c, geminicli.DefaultModels) + out := make([]geminicli.Model, 0, len(geminicli.DefaultModels)) + for _, m := range geminicli.DefaultModels { + mm := m + mm.ID = prefix + mm.ID + out = append(out, mm) + } + response.Success(c, out) return } @@ -1271,14 +1519,16 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { var found bool for _, dm := range geminicli.DefaultModels { if dm.ID == requestedModel { - models = append(models, dm) + m := dm + m.ID = prefix + m.ID + models = append(models, m) found = true break } } if !found { models = append(models, geminicli.Model{ - ID: requestedModel, + ID: prefix + requestedModel, Type: "model", DisplayName: requestedModel, CreatedAt: "", @@ -1303,7 +1553,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { // 添加 Claude 模型 for _, m := range claude.DefaultModels { models = append(models, UnifiedModel{ - ID: m.ID, + ID: prefix + m.ID, Type: m.Type, DisplayName: m.DisplayName, }) @@ -1311,8 +1561,8 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { // 添加 Gemini 3 系列模型用于测试 geminiTestModels := []UnifiedModel{ - {ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash"}, - {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview"}, + {ID: prefix + "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash"}, + {ID: prefix + "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview"}, } models = append(models, geminiTestModels...) @@ -1323,7 +1573,13 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { // Handle Claude/Anthropic accounts // For OAuth and Setup-Token accounts: return default models if account.IsOAuth() { - response.Success(c, claude.DefaultModels) + out := make([]claude.Model, 0, len(claude.DefaultModels)) + for _, m := range claude.DefaultModels { + mm := m + mm.ID = prefix + mm.ID + out = append(out, mm) + } + response.Success(c, out) return } @@ -1331,7 +1587,13 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { mapping := account.GetModelMapping() if len(mapping) == 0 { // No mapping configured, return default models - response.Success(c, claude.DefaultModels) + out := make([]claude.Model, 0, len(claude.DefaultModels)) + for _, m := range claude.DefaultModels { + mm := m + mm.ID = prefix + mm.ID + out = append(out, mm) + } + response.Success(c, out) return } @@ -1342,7 +1604,9 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { var found bool for _, dm := range claude.DefaultModels { if dm.ID == requestedModel { - models = append(models, dm) + m := dm + m.ID = prefix + m.ID + models = append(models, m) found = true break } @@ -1350,7 +1614,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { // If not found in defaults, create a basic entry if !found { models = append(models, claude.Model{ - ID: requestedModel, + ID: prefix + requestedModel, Type: "model", DisplayName: requestedModel, CreatedAt: "", @@ -1361,6 +1625,64 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { response.Success(c, models) } +func (h *AccountHandler) RefreshAvailableModels(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + if h.githubCopilotToken == nil { + response.InternalError(c, "GitHub Copilot token provider not configured") + return + } + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.NotFound(c, "Account not found") + return + } + if account == nil { + response.BadRequest(c, "Account not found") + return + } + if account.Platform != service.PlatformCopilot && !service.IsGitHubCopilotAccount(account) { + response.BadRequest(c, "Only GitHub Copilot accounts support model refresh") + return + } + + models, err := h.githubCopilotToken.ListModels(c.Request.Context(), account) + if err != nil { + response.InternalError(c, "Refresh models failed: "+err.Error()) + return + } + ids := make([]string, 0, len(models)) + for _, m := range models { + if v := strings.TrimSpace(m.ID); v != "" { + ids = append(ids, v) + } + } + if len(ids) > 0 { + now := time.Now().Format(time.RFC3339) + newExtra := make(map[string]any) + for k, v := range account.Extra { + newExtra[k] = v + } + newExtra[service.AccountExtraKeyAvailableModels] = ids + newExtra[service.AccountExtraKeyAvailableModelsUpdatedAt] = now + newExtra[service.AccountExtraKeyAvailableModelsSource] = "github_copilot" + _, _ = h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{Extra: newExtra}) + } + + prefix := service.PlatformCopilot + "/" + out := make([]openai.Model, 0, len(models)) + for _, m := range models { + mm := m + mm.ID = prefix + mm.ID + out = append(out, mm) + } + response.Success(c, out) +} + // RefreshTier handles refreshing Google One tier for a single account // POST /api/v1/admin/accounts/:id/refresh-tier func (h *AccountHandler) RefreshTier(c *gin.Context) { diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 7daaf2811c..4a09e6cdfe 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler { type CreateGroupRequest struct { Name string `json:"name" binding:"required"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai copilot aggregator gemini antigravity"` RateMultiplier float64 `json:"rate_multiplier"` IsExclusive bool `json:"is_exclusive"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` @@ -55,7 +55,7 @@ type CreateGroupRequest struct { type UpdateGroupRequest struct { Name string `json:"name"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai copilot aggregator gemini antigravity"` RateMultiplier *float64 `json:"rate_multiplier"` IsExclusive *bool `json:"is_exclusive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"` diff --git a/backend/internal/handler/capture_writer.go b/backend/internal/handler/capture_writer.go new file mode 100644 index 0000000000..0a22914d57 --- /dev/null +++ b/backend/internal/handler/capture_writer.go @@ -0,0 +1,76 @@ +package handler + +import ( + "bufio" + "bytes" + "errors" + "net" + "net/http" + + "github.com/gin-gonic/gin" +) + +type captureWriter struct { + base gin.ResponseWriter + header http.Header + status int + size int + buf bytes.Buffer +} + +func newCaptureWriter(base gin.ResponseWriter) *captureWriter { + return &captureWriter{ + base: base, + header: make(http.Header), + status: http.StatusOK, + } +} + +func (w *captureWriter) Header() http.Header { return w.header } + +func (w *captureWriter) WriteHeader(code int) { w.status = code } + +func (w *captureWriter) WriteHeaderNow() {} + +func (w *captureWriter) Write(p []byte) (int, error) { + n, err := w.buf.Write(p) + w.size += n + return n, err +} + +func (w *captureWriter) WriteString(s string) (int, error) { + n, err := w.buf.WriteString(s) + w.size += n + return n, err +} + +func (w *captureWriter) Status() int { return w.status } + +func (w *captureWriter) Size() int { return w.size } + +func (w *captureWriter) Written() bool { return w.size > 0 } + +func (w *captureWriter) Flush() {} + +func (w *captureWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hj, ok := w.base.(http.Hijacker); ok { + return hj.Hijack() + } + return nil, nil, errors.New("hijack not supported") +} + +func (w *captureWriter) CloseNotify() <-chan bool { + if cn, ok := w.base.(interface{ CloseNotify() <-chan bool }); ok { + return cn.CloseNotify() + } + ch := make(chan bool) + close(ch) + return ch +} + +func (w *captureWriter) Pusher() http.Pusher { + if p, ok := w.base.(interface{ Pusher() http.Pusher }); ok { + return p.Pusher() + } + return nil +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index c2b6bf0960..a1f213baaa 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -15,7 +15,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" - "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" @@ -30,12 +29,14 @@ import ( type GatewayHandler struct { gatewayService *service.GatewayService geminiCompatService *service.GeminiMessagesCompatService + openaiCompatService *service.OpenAIMessagesCompatService antigravityGatewayService *service.AntigravityGatewayService userService *service.UserService billingCacheService *service.BillingCacheService usageService *service.UsageService apiKeyService *service.APIKeyService errorPassthroughService *service.ErrorPassthroughService + pricingService *service.PricingService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int maxAccountSwitchesGemini int @@ -45,6 +46,7 @@ type GatewayHandler struct { func NewGatewayHandler( gatewayService *service.GatewayService, geminiCompatService *service.GeminiMessagesCompatService, + openaiCompatService *service.OpenAIMessagesCompatService, antigravityGatewayService *service.AntigravityGatewayService, userService *service.UserService, concurrencyService *service.ConcurrencyService, @@ -52,6 +54,7 @@ func NewGatewayHandler( usageService *service.UsageService, apiKeyService *service.APIKeyService, errorPassthroughService *service.ErrorPassthroughService, + pricingService *service.PricingService, cfg *config.Config, ) *GatewayHandler { pingInterval := time.Duration(0) @@ -69,12 +72,14 @@ func NewGatewayHandler( return &GatewayHandler{ gatewayService: gatewayService, geminiCompatService: geminiCompatService, + openaiCompatService: openaiCompatService, antigravityGatewayService: antigravityGatewayService, userService: userService, billingCacheService: billingCacheService, usageService: usageService, apiKeyService: apiKeyService, errorPassthroughService: errorPassthroughService, + pricingService: pricingService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, @@ -145,6 +150,18 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + if ns := service.ParseModelNamespace(reqModel); ns.HasNamespace { + reqModel = ns.Model + parsedReq.Model = reqModel + body = replaceModelFieldRaw(body, reqModel) + parsedReq.Body = body + if ns.Platform != "" && !middleware2.HasForcePlatform(c) { + ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, ns.Platform) + c.Request = c.Request.WithContext(ctx) + c.Set(string(middleware2.ContextKeyForcePlatform), ns.Platform) + } + } + // Track if we've started streaming (for error handling) streamStarted := false @@ -156,6 +173,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 获取订阅信息(可能为nil)- 提前获取用于后续检查 subscription, _ := middleware2.GetSubscriptionFromContext(c) + effectiveAPIKey, err := h.resolveEffectiveAPIKey(c, apiKey, reqModel) + if err != nil { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No accessible groups: "+err.Error()) + return + } + apiKey = effectiveAPIKey + // 0. 检查wait队列是否已满 maxWait := service.CalculateMaxWait(subject.Concurrency) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) @@ -554,6 +578,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) + } else if account.Platform == service.PlatformCopilot { + if service.IsClaudeModelID(reqModel) { + result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) + } else { + result, err = h.openaiCompatService.Forward(requestCtx, c, account, parsedReq) + } + } else if account.Platform == service.PlatformOpenAI || account.Platform == service.PlatformAggregator { + result, err = h.openaiCompatService.Forward(requestCtx, c, account, parsedReq) } else { result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) } @@ -670,53 +702,107 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // Models handles listing available models // GET /v1/models -// Returns models based on account configurations (model_mapping whitelist) -// Falls back to default models if no whitelist is configured func (h *GatewayHandler) Models(c *gin.Context) { apiKey, _ := middleware2.GetAPIKeyFromContext(c) - var groupID *int64 - var platform string + var allowedGroups []int64 + if apiKey != nil && apiKey.User != nil { + allowedGroups = apiKey.User.AllowedGroups + } - if apiKey != nil && apiKey.Group != nil { - groupID = &apiKey.Group.ID - platform = apiKey.Group.Platform + groupIDs, err := h.gatewayService.GetAccessibleGroupIDs(c.Request.Context(), allowedGroups) + if err != nil { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get accessible groups") + return } - // Get available models from account configurations (without platform filter) - availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") + availableModels := h.gatewayService.GetAvailableModelsByGroupIDs(c.Request.Context(), groupIDs, "") if len(availableModels) > 0 { - // Build model list from whitelist - models := make([]claude.Model, 0, len(availableModels)) - for _, modelID := range availableModels { - models = append(models, claude.Model{ - ID: modelID, - Type: "model", - DisplayName: modelID, - CreatedAt: "2024-01-01T00:00:00Z", - }) - } - c.JSON(http.StatusOK, gin.H{ - "object": "list", - "data": models, - }) + models := h.buildModelListWithPricing("", availableModels) + c.JSON(http.StatusOK, gin.H{"object": "list", "data": models}) return } - // Fallback to default models - if platform == "openai" { - c.JSON(http.StatusOK, gin.H{ - "object": "list", - "data": openai.DefaultModels, - }) - return + models := h.buildDefaultModelList(domain.ProviderOpenAI) + c.JSON(http.StatusOK, gin.H{"object": "list", "data": models}) +} + +func (h *GatewayHandler) buildModelListWithPricing(provider string, modelIDs []string) []openai.Model { + models := make([]openai.Model, 0, len(modelIDs)) + defaults := make(map[string]openai.Model, len(openai.DefaultModels)) + for _, dm := range openai.DefaultModels { + defaults[dm.ID] = dm } - c.JSON(http.StatusOK, gin.H{ - "object": "list", - "data": claude.DefaultModels, - }) + for _, modelID := range modelIDs { + raw := strings.TrimSpace(modelID) + if raw == "" { + continue + } + + modelProvider := strings.TrimSpace(provider) + modelName := raw + if idx := strings.Index(raw, "/"); idx > 0 && idx < len(raw)-1 { + modelProvider = strings.ToLower(strings.TrimSpace(raw[:idx])) + modelName = strings.TrimSpace(raw[idx+1:]) + } + + nsID := namespaceModelID(modelProvider, modelName) + m := openai.Model{ID: nsID, Object: "model", Type: "model", DisplayName: modelName} + if dm, ok := defaults[modelName]; ok { + m.DisplayName = dm.DisplayName + } + + if h.pricingService != nil { + if info := h.pricingService.GetModelInfo(modelProvider, modelName); info != nil { + m.ContextWindow = info.ContextWindow + m.MaxOutputTokens = info.MaxOutputTokens + m.Source = info.Source + } + } + + models = append(models, m) + } + return models +} + +func (h *GatewayHandler) buildDefaultModelList(provider string) []openai.Model { + models := make([]openai.Model, 0, len(openai.DefaultModels)) + for _, m := range openai.DefaultModels { + mm := m + mm.ID = namespaceModelID(provider, mm.ID) + if h.pricingService != nil { + if info := h.pricingService.GetModelInfo(provider, m.ID); info != nil { + mm.ContextWindow = info.ContextWindow + mm.MaxOutputTokens = info.MaxOutputTokens + mm.Source = info.Source + } + } + models = append(models, mm) + } + return models +} + +func namespaceModelID(provider string, modelID string) string { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + return "" + } + if strings.Contains(modelID, "/") { + return modelID + } + if service.IsClaudeModelID(modelID) { + return domain.ProviderAnthropic + "/" + modelID + } + if service.IsGeminiModelID(modelID) { + return domain.ProviderGemini + "/" + modelID + } + provider = strings.TrimSpace(provider) + if provider == "" { + return modelID + } + return provider + "/" + modelID } // AntigravityModels 返回 Antigravity 支持的全部模型 @@ -739,6 +825,44 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service return &cloned } +func (h *GatewayHandler) resolveEffectiveAPIKey(c *gin.Context, apiKey *service.APIKey, requestedModel string) (*service.APIKey, error) { + if apiKey.GroupID != nil && apiKey.Group != nil { + return apiKey, nil + } + + allowedGroups := []int64{} + if apiKey.User != nil { + allowedGroups = apiKey.User.AllowedGroups + } + + group, err := h.gatewayService.ResolveGroupFromUserPermission(c.Request.Context(), allowedGroups, requestedModel) + if err != nil { + return nil, err + } + + return cloneAPIKeyWithGroup(apiKey, group), nil +} + +func replaceModelFieldRaw(body []byte, model string) []byte { + if len(body) == 0 { + return body + } + var req map[string]json.RawMessage + if err := json.Unmarshal(body, &req); err != nil { + return body + } + modelBytes, err := json.Marshal(model) + if err != nil { + return body + } + req["model"] = modelBytes + newBody, err := json.Marshal(req) + if err != nil { + return body + } + return newBody +} + // Usage handles getting account balance and usage statistics for CC Switch integration // GET /v1/usage func (h *GatewayHandler) Usage(c *gin.Context) { @@ -1103,8 +1227,26 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } + if ns := service.ParseModelNamespace(parsedReq.Model); ns.HasNamespace { + parsedReq.Model = ns.Model + body = replaceModelFieldRaw(body, ns.Model) + parsedReq.Body = body + if ns.Platform != "" && !middleware2.HasForcePlatform(c) { + ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, ns.Platform) + c.Request = c.Request.WithContext(ctx) + c.Set(string(middleware2.ContextKeyForcePlatform), ns.Platform) + } + } + setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body) + effectiveAPIKey, err := h.resolveEffectiveAPIKey(c, apiKey, parsedReq.Model) + if err != nil { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No accessible groups: "+err.Error()) + return + } + apiKey = effectiveAPIKey + // 获取订阅信息(可能为nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) @@ -1133,6 +1275,14 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { setOpsSelectedAccount(c, account.ID) // 转发请求(不记录使用量) + if account.Platform == service.PlatformOpenAI || account.Platform == service.PlatformCopilot || account.Platform == service.PlatformAggregator { + if err := h.openaiCompatService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil { + log.Printf("Forward count_tokens request failed: %v", err) + // 错误响应已在 ForwardCountTokens 中处理 + return + } + return + } if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil { log.Printf("Forward count_tokens request failed: %v", err) // 错误响应已在 ForwardCountTokens 中处理 diff --git a/backend/internal/handler/gateway_handler_count_tokens_groupless_test.go b/backend/internal/handler/gateway_handler_count_tokens_groupless_test.go new file mode 100644 index 0000000000..4967586f1b --- /dev/null +++ b/backend/internal/handler/gateway_handler_count_tokens_groupless_test.go @@ -0,0 +1,351 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + mw "github.com/Wei-Shaw/sub2api/internal/server/middleware" + svc "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type billingCacheStubForCountTokens struct { + balance float64 + sub *svc.SubscriptionCacheData +} + +func (s *billingCacheStubForCountTokens) GetUserBalance(ctx context.Context, userID int64) (float64, error) { + return s.balance, nil +} + +func (s *billingCacheStubForCountTokens) SetUserBalance(ctx context.Context, userID int64, balance float64) error { + return nil +} + +func (s *billingCacheStubForCountTokens) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { + return nil +} + +func (s *billingCacheStubForCountTokens) InvalidateUserBalance(ctx context.Context, userID int64) error { + return nil +} + +func (s *billingCacheStubForCountTokens) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*svc.SubscriptionCacheData, error) { + return s.sub, nil +} + +func (s *billingCacheStubForCountTokens) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *svc.SubscriptionCacheData) error { + return nil +} + +func (s *billingCacheStubForCountTokens) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { + return nil +} + +func (s *billingCacheStubForCountTokens) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { + return nil +} + +type accountRepoStubForCountTokens struct { + listByGroupAccounts []svc.Account +} + +func (s *accountRepoStubForCountTokens) Create(ctx context.Context, account *svc.Account) error { + panic("unexpected Create call") +} + +func (s *accountRepoStubForCountTokens) GetByID(ctx context.Context, id int64) (*svc.Account, error) { + panic("unexpected GetByID call") +} + +func (s *accountRepoStubForCountTokens) GetByIDs(ctx context.Context, ids []int64) ([]*svc.Account, error) { + panic("unexpected GetByIDs call") +} + +func (s *accountRepoStubForCountTokens) ExistsByID(ctx context.Context, id int64) (bool, error) { + panic("unexpected ExistsByID call") +} + +func (s *accountRepoStubForCountTokens) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*svc.Account, error) { + panic("unexpected GetByCRSAccountID call") +} + +func (s *accountRepoStubForCountTokens) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { + panic("unexpected ListCRSAccountIDs call") +} + +func (s *accountRepoStubForCountTokens) Update(ctx context.Context, account *svc.Account) error { + panic("unexpected Update call") +} + +func (s *accountRepoStubForCountTokens) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (s *accountRepoStubForCountTokens) List(ctx context.Context, params pagination.PaginationParams) ([]svc.Account, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *accountRepoStubForCountTokens) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]svc.Account, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *accountRepoStubForCountTokens) ListByGroup(ctx context.Context, groupID int64) ([]svc.Account, error) { + panic("unexpected ListByGroup call") +} + +func (s *accountRepoStubForCountTokens) ListActive(ctx context.Context) ([]svc.Account, error) { + panic("unexpected ListActive call") +} + +func (s *accountRepoStubForCountTokens) ListByPlatform(ctx context.Context, platform string) ([]svc.Account, error) { + panic("unexpected ListByPlatform call") +} + +func (s *accountRepoStubForCountTokens) UpdateLastUsed(ctx context.Context, id int64) error { + panic("unexpected UpdateLastUsed call") +} + +func (s *accountRepoStubForCountTokens) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { + panic("unexpected BatchUpdateLastUsed call") +} + +func (s *accountRepoStubForCountTokens) SetError(ctx context.Context, id int64, errorMsg string) error { + panic("unexpected SetError call") +} + +func (s *accountRepoStubForCountTokens) ClearError(ctx context.Context, id int64) error { + panic("unexpected ClearError call") +} + +func (s *accountRepoStubForCountTokens) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { + panic("unexpected SetSchedulable call") +} + +func (s *accountRepoStubForCountTokens) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) { + panic("unexpected AutoPauseExpiredAccounts call") +} + +func (s *accountRepoStubForCountTokens) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { + panic("unexpected BindGroups call") +} + +func (s *accountRepoStubForCountTokens) ListSchedulable(ctx context.Context) ([]svc.Account, error) { + panic("unexpected ListSchedulable call") +} + +func (s *accountRepoStubForCountTokens) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]svc.Account, error) { + return append([]svc.Account(nil), s.listByGroupAccounts...), nil +} + +func (s *accountRepoStubForCountTokens) ListSchedulableByGroupIDs(ctx context.Context, groupIDs []int64) ([]svc.Account, error) { + panic("unexpected ListSchedulableByGroupIDs call") +} + +func (s *accountRepoStubForCountTokens) ListSchedulableByPlatform(ctx context.Context, platform string) ([]svc.Account, error) { + panic("unexpected ListSchedulableByPlatform call") +} + +func (s *accountRepoStubForCountTokens) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]svc.Account, error) { + return nil, errors.New("boom") +} + +func (s *accountRepoStubForCountTokens) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]svc.Account, error) { + panic("unexpected ListSchedulableByPlatforms call") +} + +func (s *accountRepoStubForCountTokens) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]svc.Account, error) { + panic("unexpected ListSchedulableByGroupIDAndPlatforms call") +} + +func (s *accountRepoStubForCountTokens) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { + panic("unexpected SetRateLimited call") +} + +func (s *accountRepoStubForCountTokens) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { + panic("unexpected SetModelRateLimit call") +} + +func (s *accountRepoStubForCountTokens) SetOverloaded(ctx context.Context, id int64, until time.Time) error { + panic("unexpected SetOverloaded call") +} + +func (s *accountRepoStubForCountTokens) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + panic("unexpected SetTempUnschedulable call") +} + +func (s *accountRepoStubForCountTokens) ClearTempUnschedulable(ctx context.Context, id int64) error { + panic("unexpected ClearTempUnschedulable call") +} + +func (s *accountRepoStubForCountTokens) ClearRateLimit(ctx context.Context, id int64) error { + panic("unexpected ClearRateLimit call") +} + +func (s *accountRepoStubForCountTokens) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + panic("unexpected ClearAntigravityQuotaScopes call") +} + +func (s *accountRepoStubForCountTokens) ClearModelRateLimits(ctx context.Context, id int64) error { + panic("unexpected ClearModelRateLimits call") +} + +func (s *accountRepoStubForCountTokens) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { + panic("unexpected UpdateSessionWindow call") +} + +func (s *accountRepoStubForCountTokens) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { + panic("unexpected UpdateExtra call") +} + +func (s *accountRepoStubForCountTokens) BulkUpdate(ctx context.Context, ids []int64, updates svc.AccountBulkUpdate) (int64, error) { + panic("unexpected BulkUpdate call") +} + +type groupRepoStubForCountTokens struct { + group *svc.Group +} + +func (s *groupRepoStubForCountTokens) Create(ctx context.Context, group *svc.Group) error { + panic("unexpected Create call") +} + +func (s *groupRepoStubForCountTokens) GetByID(ctx context.Context, id int64) (*svc.Group, error) { + panic("unexpected GetByID call") +} + +func (s *groupRepoStubForCountTokens) GetByIDLite(ctx context.Context, id int64) (*svc.Group, error) { + if s.group != nil && s.group.ID == id { + return s.group, nil + } + return nil, svc.ErrGroupNotFound +} + +func (s *groupRepoStubForCountTokens) Update(ctx context.Context, group *svc.Group) error { + panic("unexpected Update call") +} + +func (s *groupRepoStubForCountTokens) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (s *groupRepoStubForCountTokens) DeleteCascade(ctx context.Context, id int64) ([]int64, error) { + panic("unexpected DeleteCascade call") +} + +func (s *groupRepoStubForCountTokens) List(ctx context.Context, params pagination.PaginationParams) ([]svc.Group, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *groupRepoStubForCountTokens) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]svc.Group, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *groupRepoStubForCountTokens) ListActive(ctx context.Context) ([]svc.Group, error) { + panic("unexpected ListActive call") +} + +func (s *groupRepoStubForCountTokens) ListActiveByPlatform(ctx context.Context, platform string) ([]svc.Group, error) { + panic("unexpected ListActiveByPlatform call") +} + +func (s *groupRepoStubForCountTokens) ListPublicGroupIDs(ctx context.Context) ([]int64, error) { + return nil, nil +} + +func (s *groupRepoStubForCountTokens) ExistsByName(ctx context.Context, name string) (bool, error) { + panic("unexpected ExistsByName call") +} + +func (s *groupRepoStubForCountTokens) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected GetAccountCount call") +} + +func (s *groupRepoStubForCountTokens) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected DeleteAccountGroupsByGroupID call") +} + +func (s *groupRepoStubForCountTokens) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) { + panic("unexpected GetAccountIDsByGroupIDs call") +} + +func (s *groupRepoStubForCountTokens) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error { + panic("unexpected BindAccountsToGroup call") +} + +func (s *groupRepoStubForCountTokens) UpdateSortOrders(ctx context.Context, updates []svc.GroupSortOrderUpdate) error { + panic("unexpected UpdateSortOrders call") +} + +func TestGatewayHandler_CountTokens_GroupLessKeyResolvesGroupBeforeBilling(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(7) + user := &svc.User{ID: 1, AllowedGroups: []int64{groupID}} + apiKey := &svc.APIKey{ID: 10, User: user} + + cache := &billingCacheStubForCountTokens{ + balance: 0, + sub: &svc.SubscriptionCacheData{ + Status: svc.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }, + } + billingSvc := svc.NewBillingCacheService(cache, nil, nil, &config.Config{RunMode: config.RunModeStandard}) + t.Cleanup(billingSvc.Stop) + + group := &svc.Group{ID: groupID, Platform: svc.PlatformOpenAI, SubscriptionType: svc.SubscriptionTypeSubscription, Status: svc.StatusActive} + groupRepo := &groupRepoStubForCountTokens{group: group} + accountRepo := &accountRepoStubForCountTokens{listByGroupAccounts: []svc.Account{{ID: 1, Platform: svc.PlatformOpenAI, Status: svc.StatusActive, Schedulable: true}}} + gatewaySvc := svc.NewGatewayService( + accountRepo, + groupRepo, + nil, + nil, + nil, + nil, + nil, + &config.Config{RunMode: config.RunModeStandard}, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + h := &GatewayHandler{ + gatewayService: gatewaySvc, + billingCacheService: billingSvc, + } + + body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + c.Set(string(mw.ContextKeyAPIKey), apiKey) + c.Set(string(mw.ContextKeyUser), mw.AuthSubject{UserID: user.ID, Concurrency: 1}) + + h.CountTokens(c) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) +} diff --git a/backend/internal/handler/gateway_handler_models_test.go b/backend/internal/handler/gateway_handler_models_test.go new file mode 100644 index 0000000000..bec9a02071 --- /dev/null +++ b/backend/internal/handler/gateway_handler_models_test.go @@ -0,0 +1,176 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + mw "github.com/Wei-Shaw/sub2api/internal/server/middleware" + svc "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type accountRepoStubForModels struct { + accountRepoStubForCountTokens + listByGroupIDsAccounts []svc.Account +} + +func (s *accountRepoStubForModels) ListSchedulableByGroupIDs(ctx context.Context, groupIDs []int64) ([]svc.Account, error) { + return append([]svc.Account(nil), s.listByGroupIDsAccounts...), nil +} + +func TestGatewayHandler_Models_ReturnsNamespacedModelsWithPricing(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(7) + user := &svc.User{ID: 1, AllowedGroups: []int64{groupID}} + apiKey := &svc.APIKey{ID: 10, User: user} + + fallbackFile := filepath.Join("..", "..", "resources", "model-pricing", "model_prices_and_context_window.json") + cfg := &config.Config{ + RunMode: config.RunModeStandard, + Pricing: config.PricingConfig{ + DataDir: t.TempDir(), + FallbackFile: fallbackFile, + UpdateIntervalHours: 24, + HashCheckIntervalMinutes: 10, + }, + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + Enabled: false, + AllowInsecureHTTP: false, + }, + }, + } + + pricingSvc := svc.NewPricingService(cfg, nil) + require.NoError(t, pricingSvc.Initialize()) + t.Cleanup(pricingSvc.Stop) + + group := &svc.Group{ID: groupID, Platform: svc.PlatformOpenAI, Status: svc.StatusActive} + groupRepo := &groupRepoStubForCountTokens{group: group} + + accounts := []svc.Account{ + { + ID: 1, + Platform: svc.PlatformOpenAI, + Type: svc.AccountTypeUpstream, + Status: svc.StatusActive, + Schedulable: true, + Credentials: map[string]any{ + "base_url": "https://example-resource.openai.azure.com", + "model_mapping": map[string]any{ + "gpt-5-chat": "gpt-5-chat", + }, + }, + }, + { + ID: 2, + Platform: svc.PlatformCopilot, + Type: svc.AccountTypeOAuth, + Status: svc.StatusActive, + Schedulable: true, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.2": "gpt-5.2", + }, + }, + }, + { + ID: 3, + Platform: svc.PlatformAggregator, + Type: svc.AccountTypeUpstream, + Status: svc.StatusActive, + Schedulable: true, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.2": "gpt-5.2", + }, + }, + }, + { + ID: 4, + Platform: svc.PlatformOpenAI, + Type: svc.AccountTypeUpstream, + Status: svc.StatusActive, + Schedulable: true, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.2": "gpt-5.2", + }, + }, + }, + } + accountRepo := &accountRepoStubForModels{listByGroupIDsAccounts: accounts} + + gatewaySvc := svc.NewGatewayService( + accountRepo, + groupRepo, + nil, + nil, + nil, + nil, + nil, + cfg, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + h := &GatewayHandler{gatewayService: gatewaySvc, pricingService: pricingSvc} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(mw.ContextKeyAPIKey), apiKey) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Object string `json:"object"` + Data []svc.ModelInfo `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "list", resp.Object) + + byID := make(map[string]svc.ModelInfo, len(resp.Data)) + for _, m := range resp.Data { + byID[m.ID] = m + } + + azure := byID["azure/gpt-5-chat"] + require.Equal(t, 272000, azure.ContextWindow) + require.Equal(t, 128000, azure.MaxOutputTokens) + require.Equal(t, "https://azure.microsoft.com/en-us/blog/gpt-5-in-azure-ai-foundry-the-future-of-ai-apps-and-agents-starts-here/", azure.Source) + + openai := byID["openai/gpt-5.2"] + require.Equal(t, 400000, openai.ContextWindow) + require.Equal(t, 128000, openai.MaxOutputTokens) + + copilot := byID["copilot/gpt-5.2"] + require.Equal(t, 400000, copilot.ContextWindow) + require.Equal(t, 128000, copilot.MaxOutputTokens) + + aggregator := byID["aggregator/gpt-5.2"] + require.Equal(t, 400000, aggregator.ContextWindow) + require.Equal(t, 128000, aggregator.MaxOutputTokens) +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 3d25505ba9..aecdbecf21 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -7,6 +7,7 @@ import ( "encoding/hex" "encoding/json" "errors" + "fmt" "io" "log" "net/http" @@ -144,18 +145,27 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { return } - // 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组 - if !middleware.HasForcePlatform(c) { + modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/")) + if err != nil { + googleError(c, http.StatusNotFound, err.Error()) + return + } + + hadForcePlatform := middleware.HasForcePlatform(c) + if !hadForcePlatform { if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini { googleError(c, http.StatusBadRequest, "API key group platform is not gemini") return } } - modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/")) - if err != nil { - googleError(c, http.StatusNotFound, err.Error()) - return + if ns := service.ParseModelNamespace(modelName); ns.HasNamespace { + modelName = ns.Model + if ns.Platform != "" && !hadForcePlatform { + ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, ns.Platform) + c.Request = c.Request.WithContext(ctx) + c.Set(string(middleware.ContextKeyForcePlatform), ns.Platform) + } } stream := action == "streamGenerateContent" @@ -179,6 +189,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // Get subscription (may be nil) subscription, _ := middleware.GetSubscriptionFromContext(c) + effectiveAPIKey, err := h.resolveEffectiveAPIKey(c, apiKey, modelName) + if err != nil { + googleError(c, http.StatusServiceUnavailable, "No accessible groups: "+err.Error()) + return + } + apiKey = effectiveAPIKey + // For Gemini native API, do not send Claude-style ping frames. geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0) @@ -434,8 +451,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) - } else { + } else if account.Platform == service.PlatformGemini { result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body) + } else { + result, err = h.forwardGeminiNativeCrossPlatform(requestCtx, c, account, modelName, action, stream, body) } if accountReleaseFunc != nil { accountReleaseFunc() @@ -486,23 +505,39 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } } - // 6) record usage async (Gemini 使用长上下文双倍计费) go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: ip, - LongContextThreshold: 200000, // Gemini 200K 阈值 - LongContextMultiplier: 2.0, // 超出部分双倍计费 - ForceCacheBilling: fcb, - APIKeyService: h.apiKeyService, + if usedAccount.Platform == service.PlatformGemini || usedAccount.Platform == service.PlatformAntigravity { + if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: ip, + LongContextThreshold: 200000, + LongContextMultiplier: 2.0, + ForceCacheBilling: fcb, + APIKeyService: h.apiKeyService, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + return + } + + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: ip, + ForceCacheBilling: fcb, + APIKeyService: h.apiKeyService, }); err != nil { log.Printf("Record usage failed: %v", err) } @@ -511,6 +546,120 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } } +func (h *GatewayHandler) forwardGeminiNativeCrossPlatform(ctx context.Context, c *gin.Context, account *service.Account, modelName string, action string, stream bool, body []byte) (*service.ForwardResult, error) { + startTime := time.Now() + + claudeBody, err := service.ConvertGeminiNativeRequestToClaudeMessages(modelName, body) + if err != nil { + googleError(c, http.StatusBadRequest, err.Error()) + return nil, err + } + + parsed, err := service.ParseGatewayRequest(claudeBody, domain.PlatformAnthropic) + if err != nil { + googleError(c, http.StatusBadRequest, "Failed to parse request") + return nil, err + } + + origWriter := c.Writer + cw := newCaptureWriter(origWriter) + c.Writer = cw + defer func() { c.Writer = origWriter }() + + if action == "countTokens" { + var countErr error + if account.Platform == service.PlatformOpenAI || account.Platform == service.PlatformAggregator || (account.Platform == service.PlatformCopilot && !service.IsClaudeModelID(modelName)) { + countErr = h.openaiCompatService.ForwardCountTokens(ctx, c, account, parsed) + } else { + countErr = h.gatewayService.ForwardCountTokens(ctx, c, account, parsed) + } + c.Writer = origWriter + + if countErr != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(countErr, &failoverErr) { + return nil, failoverErr + } + googleError(c, http.StatusBadGateway, "Upstream request failed") + return nil, countErr + } + + var parsedResp map[string]any + if json.Unmarshal(cw.buf.Bytes(), &parsedResp) != nil { + googleError(c, http.StatusBadGateway, "Failed to parse upstream response") + return nil, errors.New("failed to parse countTokens response") + } + inputTokens, _ := parsedResp["input_tokens"].(float64) + c.JSON(http.StatusOK, map[string]any{"totalTokens": int(inputTokens)}) + return &service.ForwardResult{ + RequestID: cw.header.Get("x-request-id"), + Usage: service.ClaudeUsage{}, + Model: modelName, + Stream: false, + Duration: time.Since(startTime), + FirstTokenMs: nil, + }, nil + } + + claudeStream := false + parsed.Stream = claudeStream + claudeReq := map[string]any{} + if json.Unmarshal(claudeBody, &claudeReq) == nil { + claudeReq["stream"] = false + claudeBody, _ = json.Marshal(claudeReq) + parsed.Body = claudeBody + } + + var result *service.ForwardResult + if account.Platform == service.PlatformOpenAI || account.Platform == service.PlatformAggregator || (account.Platform == service.PlatformCopilot && !service.IsClaudeModelID(modelName)) { + result, err = h.openaiCompatService.Forward(ctx, c, account, parsed) + } else { + result, err = h.gatewayService.Forward(ctx, c, account, parsed) + } + c.Writer = origWriter + + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + return nil, failoverErr + } + googleError(c, http.StatusBadGateway, "Upstream request failed") + return nil, err + } + + if rid := strings.TrimSpace(cw.header.Get("x-request-id")); rid != "" { + c.Header("x-request-id", rid) + } + + var claudeResp map[string]any + if err := json.Unmarshal(cw.buf.Bytes(), &claudeResp); err != nil { + googleError(c, http.StatusBadGateway, "Failed to parse upstream response") + return nil, err + } + geminiResp, err := service.ConvertClaudeMessageToGeminiResponse(claudeResp, &result.Usage) + if err != nil { + googleError(c, http.StatusBadGateway, "Failed to convert upstream response") + return nil, err + } + + if !stream { + c.JSON(http.StatusOK, geminiResp) + return result, nil + } + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if b, err := json.Marshal(geminiResp); err == nil { + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", b) + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } + } + return result, nil +} + func parseGeminiModelAction(rest string) (model string, action string, err error) { rest = strings.TrimSpace(rest) if rest == "" { diff --git a/backend/internal/handler/gemini_v1beta_handler_test.go b/backend/internal/handler/gemini_v1beta_handler_test.go index 82b30ee46e..1f0d18697a 100644 --- a/backend/internal/handler/gemini_v1beta_handler_test.go +++ b/backend/internal/handler/gemini_v1beta_handler_test.go @@ -30,16 +30,24 @@ func TestGeminiV1BetaHandler_PlatformRoutingInvariant(t *testing.T) { expectedService: "AntigravityGatewayService.ForwardGemini", description: "Antigravity 账户通过 CRS 中转,支持 Gemini 协议", }, + { + name: "非Gemini平台使用跨协议转发", + platform: service.PlatformOpenAI, + expectedService: "GatewayHandler.forwardGeminiNativeCrossPlatform", + description: "Gemini 原生协议可通过 namespace 调用其它平台模型", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 模拟 GeminiV1BetaModels 中的路由决策 (lines 199-205 in gemini_v1beta_handler.go) var routedService string - if tt.platform == service.PlatformAntigravity { + switch tt.platform { + case service.PlatformAntigravity: routedService = "AntigravityGatewayService.ForwardGemini" - } else { + case service.PlatformGemini: routedService = "GeminiMessagesCompatService.ForwardNative" + default: + routedService = "GatewayHandler.forwardGeminiNativeCrossPlatform" } require.Equal(t, tt.expectedService, routedService, diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index c08a8b0ec5..1a00b16571 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -2,6 +2,8 @@ package handler import ( "context" + "crypto/rand" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -12,6 +14,8 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/domain" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -22,7 +26,9 @@ import ( // OpenAIGatewayHandler handles OpenAI API gateway requests type OpenAIGatewayHandler struct { - gatewayService *service.OpenAIGatewayService + openaiGatewayService *service.OpenAIGatewayService + claudeGatewayService *service.GatewayService + geminiCompatService *service.GeminiMessagesCompatService billingCacheService *service.BillingCacheService apiKeyService *service.APIKeyService errorPassthroughService *service.ErrorPassthroughService @@ -32,7 +38,9 @@ type OpenAIGatewayHandler struct { // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler func NewOpenAIGatewayHandler( - gatewayService *service.OpenAIGatewayService, + openaiGatewayService *service.OpenAIGatewayService, + claudeGatewayService *service.GatewayService, + geminiCompatService *service.GeminiMessagesCompatService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, apiKeyService *service.APIKeyService, @@ -48,7 +56,9 @@ func NewOpenAIGatewayHandler( } } return &OpenAIGatewayHandler{ - gatewayService: gatewayService, + openaiGatewayService: openaiGatewayService, + claudeGatewayService: claudeGatewayService, + geminiCompatService: geminiCompatService, billingCacheService: billingCacheService, apiKeyService: apiKeyService, errorPassthroughService: errorPassthroughService, @@ -57,12 +67,34 @@ func NewOpenAIGatewayHandler( } } +func (h *OpenAIGatewayHandler) resolveEffectiveAPIKey(c *gin.Context, apiKey *service.APIKey, requestedModel string) (*service.APIKey, error) { + if apiKey.GroupID != nil && apiKey.Group != nil { + return apiKey, nil + } + + allowedGroups := []int64{} + if apiKey.User != nil { + allowedGroups = apiKey.User.AllowedGroups + } + + group, err := h.claudeGatewayService.ResolveGroupFromUserPermission(c.Request.Context(), allowedGroups, requestedModel) + if err != nil { + return nil, err + } + + cloned := *apiKey + groupID := group.ID + cloned.GroupID = &groupID + cloned.Group = group + return &cloned, nil +} + // Responses handles OpenAI Responses API endpoint // POST /openai/v1/responses func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Get apiKey and user from context (set by ApiKeyAuth middleware) apiKey, ok := middleware2.GetAPIKeyFromContext(c) - if !ok { + if !ok || apiKey == nil { h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") return } @@ -108,6 +140,22 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } + requestedModelForClient := strings.TrimSpace(reqModel) + if ns := service.ParseModelNamespace(reqModel); ns.HasNamespace { + reqModel = ns.Model + reqBody["model"] = reqModel + body, err = json.Marshal(reqBody) + if err != nil { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") + return + } + if ns.Platform != "" && !middleware2.HasForcePlatform(c) { + ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, ns.Platform) + c.Request = c.Request.WithContext(ctx) + c.Set(string(middleware2.ContextKeyForcePlatform), ns.Platform) + } + } + userAgent := c.GetHeader("User-Agent") if !openai.IsCodexCLIRequest(userAgent) { existingInstructions, _ := reqBody["instructions"].(string) @@ -157,6 +205,13 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Get subscription info (may be nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) + effectiveAPIKey, err := h.resolveEffectiveAPIKey(c, apiKey, reqModel) + if err != nil { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No accessible groups: "+err.Error()) + return + } + apiKey = effectiveAPIKey + // 0. Check if wait queue is full maxWait := service.CalculateMaxWait(subject.Concurrency) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) @@ -204,7 +259,27 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } // Generate session hash (header first; fallback to prompt_cache_key) - sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody) + sessionHash := h.openaiGatewayService.GenerateSessionHash(c, reqBody) + + targetPlatform := "" + if fp, ok := middleware2.GetForcePlatformFromContext(c); ok { + targetPlatform = strings.TrimSpace(fp) + } + if targetPlatform == "" && apiKey.Group != nil { + targetPlatform = strings.TrimSpace(apiKey.Group.Platform) + } + if strings.EqualFold(targetPlatform, "claude") { + targetPlatform = service.PlatformAnthropic + } + if targetPlatform == service.PlatformAnthropic || targetPlatform == service.PlatformGemini { + h.handleCrossPlatformResponses(c, apiKey, subscription, reqBody, body, reqModel, requestedModelForClient, reqStream, sessionHash, &streamStarted) + return + } + + openaiPlatform := service.PlatformOpenAI + if targetPlatform == service.PlatformCopilot || targetPlatform == service.PlatformAggregator { + openaiPlatform = targetPlatform + } maxAccountSwitches := h.maxAccountSwitches switchCount := 0 @@ -214,7 +289,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { for { // Select account supporting the requested model log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + selection, err := h.openaiGatewayService.SelectAccountWithLoadAwarenessForPlatform(c.Request.Context(), apiKey.GroupID, openaiPlatform, sessionHash, reqModel, failedAccountIDs) if err != nil { log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) if len(failedAccountIDs) == 0 { @@ -274,7 +349,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) accountWaitCounted = false } - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { + if err := h.openaiGatewayService.BindStickySessionForPlatform(c.Request.Context(), apiKey.GroupID, openaiPlatform, sessionHash, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } @@ -282,7 +357,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) // Forward request - result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) + result, err := h.openaiGatewayService.Forward(c.Request.Context(), c, account, body) if accountReleaseFunc != nil { accountReleaseFunc() } @@ -312,7 +387,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ + if err := h.openaiGatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, @@ -329,6 +404,360 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } } +func (h *OpenAIGatewayHandler) handleCrossPlatformResponses( + c *gin.Context, + apiKey *service.APIKey, + subscription *service.UserSubscription, + reqBody map[string]any, + body []byte, + reqModel string, + requestedModelForClient string, + reqStream bool, + sessionHash string, + streamStarted *bool, +) { + if h.claudeGatewayService == nil { + h.handleStreamingAwareError(c, http.StatusInternalServerError, "api_error", "Gateway service not configured", derefBool(streamStarted)) + return + } + if apiKey == nil { + h.handleStreamingAwareError(c, http.StatusUnauthorized, "authentication_error", "Invalid API key", derefBool(streamStarted)) + return + } + + platform := "" + if fp, ok := middleware2.GetForcePlatformFromContext(c); ok { + platform = strings.TrimSpace(fp) + } + if platform == "" && apiKey.Group != nil { + platform = strings.TrimSpace(apiKey.Group.Platform) + } + if strings.EqualFold(platform, "claude") { + platform = service.PlatformAnthropic + } + + sessionKey := sessionHash + if platform == service.PlatformGemini && strings.TrimSpace(sessionHash) != "" { + sessionKey = "gemini:" + sessionHash + } + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + var lastFailoverErr *service.UpstreamFailoverError + + for { + selection, err := h.claudeGatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") + if err != nil { + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), derefBool(streamStarted)) + return + } + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, derefBool(streamStarted)) + } else { + h.handleFailoverExhaustedSimple(c, 502, derefBool(streamStarted)) + } + return + } + account := selection.Account + setOpsSelectedAccount(c, account.ID) + + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", derefBool(streamStarted)) + return + } + accountWaitCounted := false + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", derefBool(streamStarted)) + return + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + }() + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + streamStarted, + ) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", derefBool(streamStarted)) + return + } + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } + if err := h.claudeGatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } + } + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + result, err := h.forwardCrossPlatformResponses(c.Request.Context(), c, account, reqBody, body, requestedModelForClient, reqStream, streamStarted) + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, derefBool(streamStarted)) + return + } + switchCount++ + continue + } + log.Printf("Account %d: Forward request failed: %v", account.ID, err) + return + } + + ua := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + go func(result *service.ForwardResult, usedAccount *service.Account, ua, ipAddr string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.claudeGatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: ipAddr, + APIKeyService: h.apiKeyService, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account, ua, clientIP) + return + } +} + +func (h *OpenAIGatewayHandler) forwardCrossPlatformResponses( + ctx context.Context, + c *gin.Context, + account *service.Account, + openaiReq map[string]any, + openaiBody []byte, + requestedModelForClient string, + reqStream bool, + streamStarted *bool, +) (*service.ForwardResult, error) { + claudeReq, convErr := service.ConvertOpenAIResponsesRequestToClaudeMessages(openaiReq) + if convErr != nil { + h.writeOpenAIResponsesError(c, reqStream, streamStarted, "invalid_request_error", convErr.Error()) + return nil, convErr + } + claudeReq["stream"] = false + claudeBody, err := json.Marshal(claudeReq) + if err != nil { + h.writeOpenAIResponsesError(c, reqStream, streamStarted, "api_error", "Failed to process request") + return nil, err + } + if c != nil { + c.Set(service.OpsUpstreamRequestBodyKey, string(openaiBody)) + } + + origWriter := c.Writer + cw := newCaptureWriter(origWriter) + c.Writer = cw + defer func() { c.Writer = origWriter }() + + var result *service.ForwardResult + if account.Platform == service.PlatformGemini { + if h.geminiCompatService == nil { + h.writeOpenAIResponsesError(c, reqStream, streamStarted, "api_error", "Gemini compat service not configured") + return nil, errors.New("gemini compat service not configured") + } + result, err = h.geminiCompatService.Forward(ctx, c, account, claudeBody) + } else { + parsed, perr := service.ParseGatewayRequest(claudeBody, domain.PlatformAnthropic) + if perr != nil { + h.writeOpenAIResponsesError(c, reqStream, streamStarted, "invalid_request_error", "Failed to parse request") + return nil, perr + } + result, err = h.claudeGatewayService.Forward(ctx, c, account, parsed) + } + + c.Writer = origWriter + + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + return nil, failoverErr + } + status := cw.Status() + errType := "upstream_error" + if status == http.StatusBadRequest { + errType = "invalid_request_error" + } + msg := extractClaudeErrorMessage(cw.buf.Bytes()) + if strings.TrimSpace(msg) == "" { + msg = "Upstream request failed" + } + h.writeOpenAIResponsesError(c, reqStream, streamStarted, errType, msg) + return nil, err + } + + var claudeResp map[string]any + if err := json.Unmarshal(cw.buf.Bytes(), &claudeResp); err != nil { + h.writeOpenAIResponsesError(c, reqStream, streamStarted, "upstream_error", "Failed to parse upstream response") + return nil, err + } + + if strings.TrimSpace(requestedModelForClient) == "" { + requestedModelForClient, _ = openaiReq["model"].(string) + } + + openaiResp, err := service.ConvertClaudeMessageToOpenAIResponsesResponse(claudeResp, &result.Usage, requestedModelForClient, "") + if err != nil { + h.writeOpenAIResponsesError(c, reqStream, streamStarted, "upstream_error", "Failed to convert upstream response") + return nil, err + } + + if rid := strings.TrimSpace(cw.header.Get("x-request-id")); rid != "" { + c.Header("x-request-id", rid) + } + + if !reqStream { + c.JSON(http.StatusOK, openaiResp) + return result, nil + } + + if streamStarted != nil { + *streamStarted = true + } + writeOpenAIResponsesSSE(c, openaiResp) + return result, nil +} + +func extractClaudeErrorMessage(body []byte) string { + var parsed map[string]any + if err := json.Unmarshal(body, &parsed); err != nil { + return "" + } + if errObj, ok := parsed["error"].(map[string]any); ok { + if msg, ok := errObj["message"].(string); ok { + return strings.TrimSpace(msg) + } + } + return "" +} + +func writeOpenAIResponsesSSE(c *gin.Context, resp map[string]any) { + if c == nil { + return + } + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + w := c.Writer + flusher, _ := w.(http.Flusher) + flush := func() { + if flusher != nil { + flusher.Flush() + } + } + + responseID, _ := resp["id"].(string) + if strings.TrimSpace(responseID) == "" { + responseID = "resp_" + randomHex(12) + } + + writeEvent := func(eventType string, payload map[string]any) { + b, err := json.Marshal(payload) + if err != nil { + return + } + _, _ = fmt.Fprintf(w, "event: %s\n", eventType) + _, _ = fmt.Fprintf(w, "data: %s\n\n", b) + flush() + } + + writeEvent("response.created", map[string]any{"type": "response.created", "response": map[string]any{"id": responseID}}) + + if output, ok := resp["output"].([]any); ok { + for _, item := range output { + writeEvent("response.output_item.done", map[string]any{"type": "response.output_item.done", "item": item}) + } + } + + completed := map[string]any{"id": responseID} + if usage := resp["usage"]; usage != nil { + completed["usage"] = usage + } + writeEvent("response.completed", map[string]any{"type": "response.completed", "response": completed}) +} + +func (h *OpenAIGatewayHandler) writeOpenAIResponsesError(c *gin.Context, stream bool, streamStarted *bool, errType string, message string) { + if stream { + if streamStarted != nil { + *streamStarted = true + } + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + payload := map[string]any{ + "type": "response.failed", + "response": map[string]any{ + "id": "resp_" + randomHex(12), + "error": map[string]any{ + "code": errType, + "message": message, + }, + }, + } + b, _ := json.Marshal(payload) + _, _ = fmt.Fprintf(c.Writer, "event: response.failed\n") + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", b) + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } + return + } + h.errorResponse(c, http.StatusBadGateway, errType, message) +} + +func derefBool(v *bool) bool { + if v == nil { + return false + } + return *v +} + +func randomHex(nBytes int) string { + if nBytes <= 0 { + return "" + } + b := make([]byte, nBytes) + if _, err := rand.Read(b); err != nil { + return "" + } + return hex.EncodeToString(b) +} + // handleConcurrencyError handles concurrency-related errors with proper 429 response func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", diff --git a/backend/internal/model/error_passthrough_rule.go b/backend/internal/model/error_passthrough_rule.go index 620736cd87..4d6977c1f4 100644 --- a/backend/internal/model/error_passthrough_rule.go +++ b/backend/internal/model/error_passthrough_rule.go @@ -34,13 +34,15 @@ const MatchModeAll = "all" const ( PlatformAnthropic = "anthropic" PlatformOpenAI = "openai" + PlatformCopilot = "copilot" + PlatformAggregator = "aggregator" PlatformGemini = "gemini" PlatformAntigravity = "antigravity" ) // AllPlatforms 返回所有支持的平台列表 func AllPlatforms() []string { - return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity} + return []string{PlatformAnthropic, PlatformOpenAI, PlatformCopilot, PlatformAggregator, PlatformGemini, PlatformAntigravity} } // Validate 验证规则配置的有效性 diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index ac32fae59b..3f7c1e7f49 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -218,6 +218,7 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (* } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // #nosec G704 -- request targets fixed Google OAuth token endpoint resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("token 交换请求失败: %w", err) @@ -255,6 +256,7 @@ func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenR } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // #nosec G704 -- request targets fixed Google OAuth token endpoint resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("token 刷新请求失败: %w", err) @@ -286,6 +288,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo } req.Header.Set("Authorization", "Bearer "+accessToken) + // #nosec G704 -- request targets fixed Google OAuth userinfo endpoint resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("用户信息请求失败: %w", err) @@ -335,6 +338,7 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", UserAgent) + // #nosec G704 -- request targets fixed Antigravity API host allowlist (BaseURLs) resp, err := c.httpClient.Do(req) if err != nil { lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err) @@ -414,6 +418,7 @@ func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (s req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", UserAgent) + // #nosec G704 -- request targets fixed Antigravity API host allowlist (BaseURLs) resp, err := c.httpClient.Do(req) if err != nil { lastErr = fmt.Errorf("onboardUser 请求失败: %w", err) @@ -534,6 +539,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", UserAgent) + // #nosec G704 -- request targets fixed Antigravity API host allowlist (BaseURLs) resp, err := c.httpClient.Do(req) if err != nil { lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err) diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index eecee11e1c..bd3494aa79 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -57,10 +57,12 @@ var DefaultHeaders = map[string]string{ // Model 表示一个 Claude 模型 type Model struct { - ID string `json:"id"` - Type string `json:"type"` - DisplayName string `json:"display_name"` - CreatedAt string `json:"created_at"` + ID string `json:"id"` + Type string `json:"type"` + DisplayName string `json:"display_name"` + CreatedAt string `json:"created_at"` + ContextWindow int `json:"context_window,omitempty"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` } // DefaultModels Claude Code 客户端支持的默认模型列表 diff --git a/backend/internal/pkg/geminicli/drive_client.go b/backend/internal/pkg/geminicli/drive_client.go index a6cbc3abab..7a3aca966f 100644 --- a/backend/internal/pkg/geminicli/drive_client.go +++ b/backend/internal/pkg/geminicli/drive_client.go @@ -43,8 +43,9 @@ func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL // Get HTTP client with proxy support client, err := httpclient.GetClient(httpclient.Options{ - ProxyURL: proxyURL, - Timeout: 10 * time.Second, + ProxyURL: proxyURL, + Timeout: 10 * time.Second, + ValidateResolvedIP: true, }) if err != nil { return nil, fmt.Errorf("failed to create HTTP client: %w", err) @@ -70,6 +71,7 @@ func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL return nil, fmt.Errorf("request cancelled: %w", ctx.Err()) } + // #nosec G704 -- request targets fixed Google Drive API host (driveAPIURL) resp, err = client.Do(req) if err != nil { // Network error retry diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index fd24b11d7e..9680d440b2 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -5,12 +5,15 @@ import _ "embed" // Model represents an OpenAI model type Model struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - OwnedBy string `json:"owned_by"` - Type string `json:"type"` - DisplayName string `json:"display_name"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` + Type string `json:"type"` + DisplayName string `json:"display_name"` + ContextWindow int `json:"context_window,omitempty"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` + Source string `json:"source,omitempty"` } // DefaultModels OpenAI models list diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index d28ae04252..d48f97bf55 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -749,6 +749,13 @@ func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupI }) } +func (r *accountRepository) ListSchedulableByGroupIDs(ctx context.Context, groupIDs []int64) ([]service.Account, error) { + return r.queryAccountsByGroupIDs(ctx, groupIDs, accountGroupQueryOptions{ + status: service.StatusActive, + schedulable: true, + }) +} + func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { now := time.Now() accounts, err := r.client.Account.Query(). @@ -1270,6 +1277,71 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in return r.accountsToService(ctx, accounts) } +func (r *accountRepository) queryAccountsByGroupIDs(ctx context.Context, groupIDs []int64, opts accountGroupQueryOptions) ([]service.Account, error) { + if len(groupIDs) == 0 { + return []service.Account{}, nil + } + + q := r.client.AccountGroup.Query(). + Where(dbaccountgroup.GroupIDIn(groupIDs...)) + + preds := make([]dbpredicate.Account, 0, 6) + preds = append(preds, dbaccount.DeletedAtIsNil()) + if opts.status != "" { + preds = append(preds, dbaccount.StatusEQ(opts.status)) + } + if len(opts.platforms) > 0 { + preds = append(preds, dbaccount.PlatformIn(opts.platforms...)) + } + if opts.schedulable { + now := time.Now() + preds = append(preds, + dbaccount.SchedulableEQ(true), + tempUnschedulablePredicate(), + notExpiredPredicate(now), + dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), + dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), + ) + } + + if len(preds) > 0 { + q = q.Where(dbaccountgroup.HasAccountWith(preds...)) + } + + groups, err := q. + Order( + dbaccountgroup.ByPriority(), + dbaccountgroup.ByAccountField(dbaccount.FieldPriority), + ). + WithAccount(). + All(ctx) + if err != nil { + return nil, err + } + + orderedIDs := make([]int64, 0, len(groups)) + accountMap := make(map[int64]*dbent.Account, len(groups)) + for _, ag := range groups { + if ag.Edges.Account == nil { + continue + } + if _, exists := accountMap[ag.AccountID]; exists { + continue + } + accountMap[ag.AccountID] = ag.Edges.Account + orderedIDs = append(orderedIDs, ag.AccountID) + } + + accounts := make([]*dbent.Account, 0, len(orderedIDs)) + for _, id := range orderedIDs { + if acc, ok := accountMap[id]; ok { + accounts = append(accounts, acc) + } + } + + return r.accountsToService(ctx, accounts) +} + func (r *accountRepository) accountsToService(ctx context.Context, accounts []*dbent.Account) ([]service.Account, error) { if len(accounts) == 0 { return []service.Account{}, nil diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 1198f4725d..957f088d5f 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -86,6 +86,7 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se client = &http.Client{Timeout: 30 * time.Second} } + // #nosec G704 -- usageURL is fixed (or test-injected) and not user-controlled resp, err = client.Do(req) if err != nil { return nil, fmt.Errorf("request failed: %w", err) diff --git a/backend/internal/repository/github_device_session_store.go b/backend/internal/repository/github_device_session_store.go new file mode 100644 index 0000000000..b27c88dc1c --- /dev/null +++ b/backend/internal/repository/github_device_session_store.go @@ -0,0 +1,54 @@ +package repository + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const gitHubDeviceSessionKeyPrefix = "github:device_session:" + +type gitHubDeviceSessionStore struct { + rdb *redis.Client +} + +func NewGitHubDeviceSessionStore(rdb *redis.Client) service.GitHubDeviceSessionStore { + return &gitHubDeviceSessionStore{rdb: rdb} +} + +func (s *gitHubDeviceSessionStore) Get(ctx context.Context, id string) (*service.GitHubDeviceSession, bool, error) { + key := fmt.Sprintf("%s%s", gitHubDeviceSessionKeyPrefix, id) + b, err := s.rdb.Get(ctx, key).Bytes() + if err == redis.Nil { + return nil, false, nil + } + if err != nil { + return nil, false, err + } + var sess service.GitHubDeviceSession + if err := json.Unmarshal(b, &sess); err != nil { + return nil, false, err + } + return &sess, true, nil +} + +func (s *gitHubDeviceSessionStore) Set(ctx context.Context, id string, sess *service.GitHubDeviceSession, ttl time.Duration) error { + key := fmt.Sprintf("%s%s", gitHubDeviceSessionKeyPrefix, id) + if ttl <= 0 { + return s.rdb.Del(ctx, key).Err() + } + b, err := json.Marshal(sess) + if err != nil { + return err + } + return s.rdb.Set(ctx, key, b, ttl).Err() +} + +func (s *gitHubDeviceSessionStore) Delete(ctx context.Context, id string) error { + key := fmt.Sprintf("%s%s", gitHubDeviceSessionKeyPrefix, id) + return s.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/github_device_session_store_integration_test.go b/backend/internal/repository/github_device_session_store_integration_test.go new file mode 100644 index 0000000000..e818ab0682 --- /dev/null +++ b/backend/internal/repository/github_device_session_store_integration_test.go @@ -0,0 +1,77 @@ +//go:build integration + +package repository + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type GitHubDeviceSessionStoreSuite struct { + IntegrationRedisSuite + store service.GitHubDeviceSessionStore +} + +func (s *GitHubDeviceSessionStoreSuite) SetupTest() { + s.IntegrationRedisSuite.SetupTest() + s.store = NewGitHubDeviceSessionStore(s.rdb) +} + +func (s *GitHubDeviceSessionStoreSuite) TestSetGetDelete() { + id := "session-1" + createdAt := time.Now().Unix() + expiresAt := time.Now().Add(2 * time.Minute).Unix() + + sess := &service.GitHubDeviceSession{ + AccountID: 123, + AccountConcurrency: 7, + ProxyURL: "http://proxy.local", + ClientID: "client-1", + Scope: "read:user", + DeviceCode: "device-code", + ExpiresAtUnix: expiresAt, + IntervalSeconds: 5, + CreatedAtUnix: createdAt, + } + + require.NoError(s.T(), s.store.Set(s.ctx, id, sess, 2*time.Minute)) + + got, ok, err := s.store.Get(s.ctx, id) + require.NoError(s.T(), err) + require.True(s.T(), ok) + require.Equal(s.T(), sess, got) + + key := gitHubDeviceSessionKeyPrefix + id + ttl, err := s.rdb.TTL(s.ctx, key).Result() + require.NoError(s.T(), err) + s.AssertTTLWithin(ttl, time.Minute, 2*time.Minute) + + require.NoError(s.T(), s.store.Delete(s.ctx, id)) + _, ok, err = s.store.Get(s.ctx, id) + require.NoError(s.T(), err) + require.False(s.T(), ok) +} + +func (s *GitHubDeviceSessionStoreSuite) TestGetMissing() { + _, ok, err := s.store.Get(s.ctx, "missing") + require.NoError(s.T(), err) + require.False(s.T(), ok) +} + +func (s *GitHubDeviceSessionStoreSuite) TestSetWithNonPositiveTTLDeletes() { + id := "session-ttl-0" + sess := &service.GitHubDeviceSession{AccountID: 1, ExpiresAtUnix: time.Now().Add(time.Minute).Unix()} + require.NoError(s.T(), s.store.Set(s.ctx, id, sess, time.Minute)) + require.NoError(s.T(), s.store.Set(s.ctx, id, sess, 0)) + _, ok, err := s.store.Get(s.ctx, id) + require.NoError(s.T(), err) + require.False(s.T(), ok) +} + +func TestGitHubDeviceSessionStoreSuite(t *testing.T) { + suite.Run(t, new(GitHubDeviceSessionStoreSuite)) +} diff --git a/backend/internal/repository/github_device_session_store_test.go b/backend/internal/repository/github_device_session_store_test.go new file mode 100644 index 0000000000..cb2f494b69 --- /dev/null +++ b/backend/internal/repository/github_device_session_store_test.go @@ -0,0 +1,39 @@ +//go:build unit + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestGitHubDeviceSessionStore_Set_RedisError(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:1", + DialTimeout: 50 * time.Millisecond, + ReadTimeout: 50 * time.Millisecond, + WriteTimeout: 50 * time.Millisecond, + }) + t.Cleanup(func() { + _ = rdb.Close() + }) + + store := NewGitHubDeviceSessionStore(rdb) + err := store.Set(context.Background(), "broken", &service.GitHubDeviceSession{ + AccountID: 1, + AccountConcurrency: 1, + ProxyURL: "", + ClientID: "cid", + Scope: "scope", + DeviceCode: "dc", + ExpiresAtUnix: time.Now().Add(time.Minute).Unix(), + IntervalSeconds: 5, + CreatedAtUnix: time.Now().Unix(), + }, time.Minute) + require.Error(t, err) +} diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index 03f8cc66bf..8977f07e4c 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -54,6 +54,7 @@ func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo strin req.Header.Set("Accept", "application/vnd.github.v3+json") req.Header.Set("User-Agent", "Sub2API-Updater") + // #nosec G704 -- request targets fixed GitHub API host resp, err := c.httpClient.Do(req) if err != nil { return nil, err @@ -79,6 +80,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string } // 使用预配置的下载客户端(已包含代理配置) + // #nosec G704 -- download URL comes from GitHub release metadata resp, err := c.downloadHTTPClient.Do(req) if err != nil { return err @@ -126,6 +128,7 @@ func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) return nil, err } + // #nosec G704 -- checksum URL comes from GitHub release metadata resp, err := c.httpClient.Do(req) if err != nil { return nil, err diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 4e7a836fce..4129a43cd3 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -523,3 +523,16 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic return nil } + +func (r *groupRepository) ListPublicGroupIDs(ctx context.Context) ([]int64, error) { + ids, err := r.client.Group.Query(). + Where( + group.StatusEQ(service.StatusActive), + group.IsExclusiveEQ(false), + ). + IDs(ctx) + if err != nil { + return nil, err + } + return ids, nil +} diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index b0f15f19a7..6046341cd8 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -134,6 +134,7 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i } // 执行请求 + // #nosec G704 -- req URL/host validated by validateRequestHost resp, err := entry.client.Do(req) if err != nil { // 请求失败,立即减少计数 @@ -206,6 +207,7 @@ func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, acco } // 执行请求 + // #nosec G704 -- req URL/host validated by validateRequestHost resp, err := entry.client.Do(req) if err != nil { // 请求失败,立即减少计数 diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go index 07d796b8cd..8bbdbe4e36 100644 --- a/backend/internal/repository/pricing_service.go +++ b/backend/internal/repository/pricing_service.go @@ -37,6 +37,7 @@ func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) return nil, err } + // #nosec G704 -- pricing URL is configured by operator/admin resp, err := c.httpClient.Do(req) if err != nil { return nil, err @@ -56,6 +57,7 @@ func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (st return "", err } + // #nosec G704 -- pricing URL is configured by operator/admin resp, err := c.httpClient.Do(req) if err != nil { return "", err diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 513e929cbb..cf21640d53 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -86,6 +86,7 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien return nil, 0, fmt.Errorf("failed to create request: %w", err) } + // #nosec G704 -- probe URL comes from a fixed allowlist resp, err := client.Do(req) if err != nil { return nil, 0, fmt.Errorf("proxy connection failed: %w", err) diff --git a/backend/internal/repository/simple_mode_default_groups.go b/backend/internal/repository/simple_mode_default_groups.go index 5630918400..779a107c42 100644 --- a/backend/internal/repository/simple_mode_default_groups.go +++ b/backend/internal/repository/simple_mode_default_groups.go @@ -17,6 +17,8 @@ func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) er requiredByPlatform := map[string]int{ service.PlatformAnthropic: 1, service.PlatformOpenAI: 1, + service.PlatformCopilot: 1, + service.PlatformAggregator: 1, service.PlatformGemini: 1, service.PlatformAntigravity: 2, } diff --git a/backend/internal/repository/turnstile_service.go b/backend/internal/repository/turnstile_service.go index 89748cd3d4..bdf9ffa5b6 100644 --- a/backend/internal/repository/turnstile_service.go +++ b/backend/internal/repository/turnstile_service.go @@ -48,6 +48,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // #nosec G704 -- request targets fixed Cloudflare Turnstile host resp, err := v.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("send request: %w", err) diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 3aed9d9cf2..9371af1e7d 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -83,6 +83,7 @@ var ProviderSet = wire.NewSet( NewRedeemCache, NewUpdateCache, NewGeminiTokenCache, + NewGitHubDeviceSessionStore, NewSchedulerCache, NewSchedulerOutboxRepository, NewProxyLatencyCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index fa6806ae0b..46b45ac920 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -612,7 +612,7 @@ func newContractDeps(t *testing.T) *contractDeps { apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) - adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ @@ -876,6 +876,17 @@ func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin return out, nil } +func (r *stubGroupRepo) ListPublicGroupIDs(ctx context.Context) ([]int64, error) { + ids := make([]int64, 0, len(r.active)) + for i := range r.active { + g := r.active[i] + if g.Status == service.StatusActive && !g.IsExclusive { + ids = append(ids, g.ID) + } + } + return ids, nil +} + func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) { return false, errors.New("not implemented") } @@ -988,6 +999,10 @@ func (s *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID return nil, errors.New("not implemented") } +func (s *stubAccountRepo) ListSchedulableByGroupIDs(ctx context.Context, groupIDs []int64) ([]service.Account, error) { + return nil, nil +} + func (s *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 2f739357ff..7c51a6891c 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -130,8 +130,27 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } + // 未绑定分组的 API Key:具体分组会在 handler 中根据模型等信息解析, + // 计费/订阅检查也必须延后到解析后再进行。 + if apiKey.GroupID == nil { + c.Set(string(ContextKeyAPIKey), apiKey) + c.Set(string(ContextKeyUser), AuthSubject{ + UserID: apiKey.User.ID, + Concurrency: apiKey.User.Concurrency, + }) + c.Set(string(ContextKeyUserRole), apiKey.User.Role) + setGroupContext(c, nil) + c.Next() + return + } + + if apiKey.Group == nil { + AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to resolve API key group") + return + } + // 判断计费方式:订阅模式 vs 余额模式 - isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() + isSubscriptionType := apiKey.Group.IsSubscriptionType() if isSubscriptionType && subscriptionService != nil { // 订阅模式:验证订阅 diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 9d51481884..a85ad58267 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -235,6 +235,47 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func TestAPIKeyAuth_GroupLessKeySkipsBillingPrecheck(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 0, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + Group: nil, + } + + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeStandard} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, nil, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) +} + func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 4509b4bc32..841fdb2cf1 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -213,6 +213,9 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.PUT("/:id", h.Admin.Account.Update) accounts.DELETE("/:id", h.Admin.Account.Delete) accounts.POST("/:id/test", h.Admin.Account.Test) + accounts.POST("/:id/github/device/start", h.Admin.Account.StartGitHubDeviceAuth) + accounts.POST("/:id/github/device/poll", h.Admin.Account.PollGitHubDeviceAuth) + accounts.POST("/:id/github/device/cancel", h.Admin.Account.CancelGitHubDeviceAuth) accounts.POST("/:id/refresh", h.Admin.Account.Refresh) accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier) accounts.GET("/:id/stats", h.Admin.Account.GetStats) @@ -224,6 +227,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable) accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels) + accounts.POST("/:id/models/refresh", h.Admin.Account.RefreshAvailableModels) accounts.POST("/batch", h.Admin.Account.BatchCreate) accounts.GET("/data", h.Admin.Account.ExportData) accounts.POST("/data", h.Admin.Account.ImportData) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 138d5bcb0b..1db96b8588 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -52,6 +52,14 @@ type Account struct { Groups []*Group } +const ( + AccountExtraKeyAvailableModels = "available_models" + AccountExtraKeyAvailableModelsUpdatedAt = "available_models_updated_at" + AccountExtraKeyAvailableModelsSource = "available_models_source" + AccountExtraKeyAvailableModelsError = "available_models_error" + AccountExtraKeyAvailableModelsErrorAt = "available_models_error_at" +) + type TempUnschedulableRule struct { ErrorCode int `json:"error_code"` Keywords []string `json:"keywords"` @@ -288,6 +296,10 @@ func parseTempUnschedString(value any) string { } func parseTempUnschedStrings(value any) []string { + return parseStringSlice(value) +} + +func parseStringSlice(value any) []string { if value == nil { return nil } @@ -382,9 +394,37 @@ func (a *Account) GetModelMapping() map[string]string { return nil } +func (a *Account) GetAvailableModels() []string { + if a == nil || a.Extra == nil { + return nil + } + raw, ok := a.Extra[AccountExtraKeyAvailableModels] + if !ok || raw == nil { + return nil + } + return parseStringSlice(raw) +} + // IsModelSupported 检查模型是否在 model_mapping 中(支持通配符) // 如果未配置 mapping,返回 true(允许所有模型) func (a *Account) IsModelSupported(requestedModel string) bool { + requestedModel = strings.TrimSpace(requestedModel) + if requestedModel == "" { + return false + } + + if isGitHubCopilotAccount(a) { + if models := a.GetAvailableModels(); len(models) > 0 { + for _, id := range models { + if id == requestedModel { + return true + } + } + return false + } + return true + } + mapping := a.GetModelMapping() if len(mapping) == 0 { return true // 无映射 = 允许所有 @@ -405,6 +445,9 @@ func (a *Account) IsModelSupported(requestedModel string) bool { // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) // 如果未配置 mapping,返回原始模型名 func (a *Account) GetMappedModel(requestedModel string) string { + if isGitHubCopilotAccount(a) { + return requestedModel + } mapping := a.GetModelMapping() if len(mapping) == 0 { return requestedModel diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 6c0cca31fd..15c40d51f3 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -47,6 +47,7 @@ type AccountRepository interface { ListSchedulable(ctx context.Context) ([]Account, error) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) + ListSchedulableByGroupIDs(ctx context.Context, groupIDs []int64) ([]Account, error) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) @@ -344,7 +345,7 @@ func (s *AccountService) TestCredentials(ctx context.Context, id int64) error { case PlatformAnthropic: // TODO: 测试Anthropic API凭证 return nil - case PlatformOpenAI: + case PlatformOpenAI, PlatformCopilot, PlatformAggregator: // TODO: 测试OpenAI API凭证 return nil case PlatformGemini: diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 25bd0576ce..23750eb360 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -127,6 +127,10 @@ func (s *accountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID panic("unexpected ListSchedulableByGroupID call") } +func (s *accountRepoStub) ListSchedulableByGroupIDs(ctx context.Context, groupIDs []int64) ([]Account, error) { + panic("unexpected ListSchedulableByGroupIDs call") +} + func (s *accountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) { panic("unexpected ListSchedulableByPlatform call") } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 899a44987d..607a6afc6c 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -44,27 +44,30 @@ type TestEvent struct { // AccountTestService handles account testing operations type AccountTestService struct { - accountRepo AccountRepository - geminiTokenProvider *GeminiTokenProvider - antigravityGatewayService *AntigravityGatewayService - httpUpstream HTTPUpstream - cfg *config.Config + accountRepo AccountRepository + geminiTokenProvider *GeminiTokenProvider + githubCopilotTokenProvider *GitHubCopilotTokenProvider + antigravityGatewayService *AntigravityGatewayService + httpUpstream HTTPUpstream + cfg *config.Config } // NewAccountTestService creates a new AccountTestService func NewAccountTestService( accountRepo AccountRepository, geminiTokenProvider *GeminiTokenProvider, + githubCopilotTokenProvider *GitHubCopilotTokenProvider, antigravityGatewayService *AntigravityGatewayService, httpUpstream HTTPUpstream, cfg *config.Config, ) *AccountTestService { return &AccountTestService{ - accountRepo: accountRepo, - geminiTokenProvider: geminiTokenProvider, - antigravityGatewayService: antigravityGatewayService, - httpUpstream: httpUpstream, - cfg: cfg, + accountRepo: accountRepo, + geminiTokenProvider: geminiTokenProvider, + githubCopilotTokenProvider: githubCopilotTokenProvider, + antigravityGatewayService: antigravityGatewayService, + httpUpstream: httpUpstream, + cfg: cfg, } } @@ -150,8 +153,20 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int return s.sendErrorAndEnd(c, "Account not found") } + if ns := ParseModelNamespace(modelID); ns.HasNamespace { + modelID = ns.Model + } + + isCopilot := account.Platform == PlatformCopilot || isGitHubCopilotAccount(account) + if isCopilot { + if strings.TrimSpace(modelID) != "" && IsClaudeModelID(modelID) { + return s.testClaudeAccountConnection(c, account, modelID) + } + return s.testOpenAIAccountConnection(c, account, modelID) + } + // Route to platform-specific test method - if account.IsOpenAI() { + if account.Platform == PlatformOpenAI || account.Platform == PlatformAggregator { return s.testOpenAIAccountConnection(c, account, modelID) } @@ -200,11 +215,23 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account return s.sendErrorAndEnd(c, "No access token available") } } else if account.Type == "apikey" { - // API Key - use x-api-key header - useBearer = false - authToken = account.GetCredential("api_key") + isGitHubCopilot := isGitHubCopilotAccount(account) + if isGitHubCopilot && s.githubCopilotTokenProvider != nil { + if token, err := s.githubCopilotTokenProvider.GetAccessToken(ctx, account); err == nil && strings.TrimSpace(token) != "" { + useBearer = true + authToken = token + } + } if authToken == "" { - return s.sendErrorAndEnd(c, "No API key available") + // API Key - use x-api-key header + useBearer = false + authToken = account.GetCredential("api_key") + if authToken == "" { + if isGitHubCopilot { + return s.sendErrorAndEnd(c, "No GitHub token or Copilot bearer token available") + } + return s.sendErrorAndEnd(c, "No API key available") + } } baseURL := account.GetBaseURL() @@ -246,19 +273,26 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account req.Header.Set("Content-Type", "application/json") req.Header.Set("anthropic-version", "2023-06-01") - // Apply Claude Code client headers - for key, value := range claude.DefaultHeaders { - req.Header.Set(key, value) + if account.Type == "apikey" && isGitHubCopilotAccount(account) { + applyGitHubCopilotHeaders(req, false, "user") + } else { + // Apply Claude Code client headers + for key, value := range claude.DefaultHeaders { + req.Header.Set(key, value) + } } // Set authentication header if useBearer { - req.Header.Set("anthropic-beta", claude.DefaultBetaHeader) + if account.Type != "apikey" || !isGitHubCopilotAccount(account) { + req.Header.Set("anthropic-beta", claude.DefaultBetaHeader) + } req.Header.Set("Authorization", "Bearer "+authToken) } else { req.Header.Set("anthropic-beta", claude.APIKeyBetaHeader) req.Header.Set("x-api-key", authToken) } + req.Header.Set("accept", "text/event-stream") // Get proxy URL proxyURL := "" @@ -306,6 +340,8 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account var apiURL string var isOAuth bool var chatgptAccountID string + var normalizedBaseURL string + var isGitHubCopilot bool if account.IsOAuth() { isOAuth = true @@ -320,20 +356,39 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account chatgptAccountID = account.GetChatGPTAccountID() } else if account.Type == "apikey" { // API Key - use Platform API - authToken = account.GetOpenAIApiKey() + isGitHubCopilot = isGitHubCopilotAccount(account) + if isGitHubCopilot && s.githubCopilotTokenProvider != nil { + if token, err := s.githubCopilotTokenProvider.GetAccessToken(ctx, account); err == nil && strings.TrimSpace(token) != "" { + authToken = token + } + } if authToken == "" { - return s.sendErrorAndEnd(c, "No API key available") + authToken = strings.TrimSpace(account.GetCredential("api_key")) + if authToken == "" { + if isGitHubCopilot { + return s.sendErrorAndEnd(c, "No GitHub token or Copilot bearer token available") + } + return s.sendErrorAndEnd(c, "No API key available") + } } - baseURL := account.GetOpenAIBaseURL() + baseURL := strings.TrimSpace(account.GetCredential("base_url")) if baseURL == "" { - baseURL = "https://api.openai.com" + switch account.Platform { + case PlatformCopilot: + baseURL = "https://api.githubcopilot.com" + case PlatformAggregator: + return s.sendErrorAndEnd(c, "Base URL is required") + default: + baseURL = "https://api.openai.com" + } } - normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + normalized, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) } - apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses" + normalizedBaseURL = normalized + apiURL = openaiResponsesURLFromBaseURL(normalizedBaseURL, isGitHubCopilot) } else { return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) } @@ -360,15 +415,18 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account // Set common headers req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("accept", "text/event-stream") // Set OAuth-specific headers for ChatGPT internal API if isOAuth { req.Host = "chatgpt.com" - req.Header.Set("accept", "text/event-stream") if chatgptAccountID != "" { req.Header.Set("chatgpt-account-id", chatgptAccountID) } } + if account.Type == "apikey" && isGitHubCopilotAccount(account) { + applyGitHubCopilotHeaders(req, false, "user") + } // Get proxy URL proxyURL := "" @@ -384,6 +442,65 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) + if account.Type == "apikey" && isGitHubCopilot { + rawUpstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + if isResponsesAPIUnsupportedError(rawUpstreamMsg, body) { + chatURL := openaiChatCompletionsURLFromBaseURL(normalizedBaseURL, isGitHubCopilot) + payload := map[string]any{ + "model": testModelID, + "messages": []any{ + map[string]any{"role": "system", "content": openai.DefaultInstructions}, + map[string]any{"role": "user", "content": "hi"}, + }, + "stream": false, + "max_tokens": 16, + } + payloadBytes, _ := json.Marshal(payload) + + req2, err2 := http.NewRequestWithContext(ctx, "POST", chatURL, bytes.NewReader(payloadBytes)) + if err2 != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + req2.Header.Set("Content-Type", "application/json") + req2.Header.Set("Authorization", "Bearer "+authToken) + req2.Header.Set("accept", "application/json") + applyGitHubCopilotHeaders(req2, false, "user") + + resp2, err2 := s.httpUpstream.DoWithTLS(req2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err2 != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err2.Error())) + } + defer func() { _ = resp2.Body.Close() }() + respBody2, _ := io.ReadAll(resp2.Body) + if resp2.StatusCode != http.StatusOK { + return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp2.StatusCode, string(respBody2))) + } + var data map[string]any + if err := json.Unmarshal(respBody2, &data); err != nil { + return s.sendErrorAndEnd(c, "Failed to parse upstream response") + } + if errData, ok := data["error"].(map[string]any); ok { + msg := "Unknown error" + if m, ok := errData["message"].(string); ok { + msg = m + } + return s.sendErrorAndEnd(c, msg) + } + choicesAny, _ := data["choices"].([]any) + if len(choicesAny) == 0 { + return s.sendErrorAndEnd(c, "Empty upstream response") + } + choice0, _ := choicesAny[0].(map[string]any) + msgAny, _ := choice0["message"].(map[string]any) + if msgAny != nil { + if text, _ := openAIChatMessageContentToText(msgAny["content"]); strings.TrimSpace(text) != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: text}) + } + } + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + } return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) } diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index 90e5b57312..352a867063 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -202,6 +202,28 @@ func TestAccountIsModelSupported(t *testing.T) { } } +func TestAccountIsModelSupported_CopilotAvailableModels(t *testing.T) { + account := &Account{ + Platform: PlatformCopilot, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + AccountExtraKeyAvailableModels: []any{"gpt-4o", "gpt-5.2"}, + }, + } + if !account.IsModelSupported("gpt-4o") { + t.Fatalf("expected gpt-4o supported") + } + if !account.IsModelSupported(" gpt-5.2 ") { + t.Fatalf("expected gpt-5.2 supported (trim)") + } + if account.IsModelSupported("o1") { + t.Fatalf("expected o1 not supported") + } + if account.IsModelSupported("") { + t.Fatalf("expected empty model not supported") + } +} + func TestAccountGetMappedModel(t *testing.T) { tests := []struct { name string diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 06354e1e04..a329e58caa 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -178,6 +178,21 @@ type CreateAccountInput struct { SkipMixedChannelCheck bool } +func credentialString(m map[string]any, key string) string { + if m == nil { + return "" + } + v, ok := m[key] + if !ok || v == nil { + return "" + } + s, ok := v.(string) + if !ok { + return "" + } + return strings.TrimSpace(s) +} + type UpdateAccountInput struct { Name string Notes *string @@ -1048,6 +1063,18 @@ func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([ } func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) { + if input == nil { + return nil, errors.New("input is required") + } + if strings.TrimSpace(input.Platform) == PlatformAggregator && (input.Type == AccountTypeAPIKey || input.Type == AccountTypeUpstream) { + if credentialString(input.Credentials, "base_url") == "" { + return nil, errors.New("base_url is required for aggregator accounts") + } + if credentialString(input.Credentials, "api_key") == "" { + return nil, errors.New("api_key is required for aggregator accounts") + } + } + // 绑定分组 groupIDs := input.GroupIDs // 如果没有指定分组,自动绑定对应平台的默认分组 @@ -1172,6 +1199,15 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.AutoPauseOnExpired = *input.AutoPauseOnExpired } + if account.Platform == PlatformAggregator && (account.Type == AccountTypeAPIKey || account.Type == AccountTypeUpstream) { + if strings.TrimSpace(account.GetCredential("base_url")) == "" { + return nil, errors.New("base_url is required for aggregator accounts") + } + if strings.TrimSpace(account.GetCredential("api_key")) == "" { + return nil, errors.New("api_key is required for aggregator accounts") + } + } + // 先验证分组是否存在(在任何写操作之前) if input.GroupIDs != nil { for _, groupID := range *input.GroupIDs { diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 60fa3d774f..c02e9c83e9 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -152,6 +152,10 @@ func (s *groupRepoStub) ListActiveByPlatform(ctx context.Context, platform strin panic("unexpected ListActiveByPlatform call") } +func (s *groupRepoStub) ListPublicGroupIDs(ctx context.Context) ([]int64, error) { + panic("unexpected ListPublicGroupIDs call") +} + func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, error) { panic("unexpected ExistsByName call") } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index ef77a98059..e752da5c9c 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -96,6 +96,10 @@ func (s *groupRepoStubForAdmin) ListActiveByPlatform(_ context.Context, _ string panic("unexpected ListActiveByPlatform call") } +func (s *groupRepoStubForAdmin) ListPublicGroupIDs(_ context.Context) ([]int64, error) { + panic("unexpected ListPublicGroupIDs call") +} + func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, error) { panic("unexpected ExistsByName call") } @@ -379,6 +383,10 @@ func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, panic("unexpected ListActiveByPlatform call") } +func (s *groupRepoStubForFallbackCycle) ListPublicGroupIDs(_ context.Context) ([]int64, error) { + panic("unexpected ListPublicGroupIDs call") +} + func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) { panic("unexpected ExistsByName call") } @@ -454,6 +462,10 @@ func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context. panic("unexpected ListActiveByPlatform call") } +func (s *groupRepoStubForInvalidRequestFallback) ListPublicGroupIDs(_ context.Context) ([]int64, error) { + panic("unexpected ListPublicGroupIDs call") +} + func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) { panic("unexpected ExistsByName call") } diff --git a/backend/internal/service/anthropic_messages_url.go b/backend/internal/service/anthropic_messages_url.go new file mode 100644 index 0000000000..d6ba32fafd --- /dev/null +++ b/backend/internal/service/anthropic_messages_url.go @@ -0,0 +1,28 @@ +package service + +import "strings" + +func anthropicMessagesURLFromBaseURL(normalizedBaseURL string) string { + base := strings.TrimRight(strings.TrimSpace(normalizedBaseURL), "/") + if strings.HasSuffix(base, "/v1/messages") { + return base + } + if strings.HasSuffix(base, "/v1") { + return base + "/messages" + } + return base + "/v1/messages" +} + +func anthropicCountTokensURLFromBaseURL(normalizedBaseURL string) string { + base := strings.TrimRight(strings.TrimSpace(normalizedBaseURL), "/") + if strings.HasSuffix(base, "/v1/messages/count_tokens") { + return base + } + if strings.HasSuffix(base, "/v1/messages") { + return base + "/count_tokens" + } + if strings.HasSuffix(base, "/v1") { + return base + "/messages/count_tokens" + } + return base + "/v1/messages/count_tokens" +} diff --git a/backend/internal/service/anthropic_messages_url_test.go b/backend/internal/service/anthropic_messages_url_test.go new file mode 100644 index 0000000000..7f6d76c71d --- /dev/null +++ b/backend/internal/service/anthropic_messages_url_test.go @@ -0,0 +1,136 @@ +//go:build unit + +package service + +import "testing" + +func TestAnthropicMessagesURLFromBaseURL(t *testing.T) { + tests := []struct { + name string + baseURL string + expected string + }{ + { + name: "root", + baseURL: "https://api.anthropic.com", + expected: "https://api.anthropic.com/v1/messages", + }, + { + name: "root trailing slash", + baseURL: "https://api.anthropic.com/", + expected: "https://api.anthropic.com/v1/messages", + }, + { + name: "v1", + baseURL: "https://api.anthropic.com/v1", + expected: "https://api.anthropic.com/v1/messages", + }, + { + name: "v1 trailing slash", + baseURL: "https://api.anthropic.com/v1/", + expected: "https://api.anthropic.com/v1/messages", + }, + { + name: "messages endpoint", + baseURL: "https://api.anthropic.com/v1/messages", + expected: "https://api.anthropic.com/v1/messages", + }, + { + name: "messages endpoint trailing slash", + baseURL: "https://api.anthropic.com/v1/messages/", + expected: "https://api.anthropic.com/v1/messages", + }, + { + name: "path prefix", + baseURL: "https://proxy.example.com/anthropic", + expected: "https://proxy.example.com/anthropic/v1/messages", + }, + { + name: "github copilot root", + baseURL: "https://api.githubcopilot.com", + expected: "https://api.githubcopilot.com/v1/messages", + }, + { + name: "github copilot v1", + baseURL: "https://api.githubcopilot.com/v1", + expected: "https://api.githubcopilot.com/v1/messages", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := anthropicMessagesURLFromBaseURL(tt.baseURL) + if got != tt.expected { + t.Fatalf("anthropicMessagesURLFromBaseURL(%q) = %q, want %q", tt.baseURL, got, tt.expected) + } + }) + } +} + +func TestAnthropicCountTokensURLFromBaseURL(t *testing.T) { + tests := []struct { + name string + baseURL string + expected string + }{ + { + name: "root", + baseURL: "https://api.anthropic.com", + expected: "https://api.anthropic.com/v1/messages/count_tokens", + }, + { + name: "root trailing slash", + baseURL: "https://api.anthropic.com/", + expected: "https://api.anthropic.com/v1/messages/count_tokens", + }, + { + name: "v1", + baseURL: "https://api.anthropic.com/v1", + expected: "https://api.anthropic.com/v1/messages/count_tokens", + }, + { + name: "v1 trailing slash", + baseURL: "https://api.anthropic.com/v1/", + expected: "https://api.anthropic.com/v1/messages/count_tokens", + }, + { + name: "messages endpoint", + baseURL: "https://api.anthropic.com/v1/messages", + expected: "https://api.anthropic.com/v1/messages/count_tokens", + }, + { + name: "count_tokens endpoint", + baseURL: "https://api.anthropic.com/v1/messages/count_tokens", + expected: "https://api.anthropic.com/v1/messages/count_tokens", + }, + { + name: "count_tokens endpoint trailing slash", + baseURL: "https://api.anthropic.com/v1/messages/count_tokens/", + expected: "https://api.anthropic.com/v1/messages/count_tokens", + }, + { + name: "path prefix", + baseURL: "https://proxy.example.com/anthropic", + expected: "https://proxy.example.com/anthropic/v1/messages/count_tokens", + }, + { + name: "github copilot root", + baseURL: "https://api.githubcopilot.com", + expected: "https://api.githubcopilot.com/v1/messages/count_tokens", + }, + { + name: "github copilot v1", + baseURL: "https://api.githubcopilot.com/v1", + expected: "https://api.githubcopilot.com/v1/messages/count_tokens", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := anthropicCountTokensURLFromBaseURL(tt.baseURL) + if got != tt.expected { + t.Fatalf("anthropicCountTokensURLFromBaseURL(%q) = %q, want %q", tt.baseURL, got, tt.expected) + } + }) + } +} diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index c09cafb993..d385032ebc 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -458,7 +458,7 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user } // 判断计费模式 - isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil + isSubscriptionMode := group != nil && group.IsSubscriptionType() if isSubscriptionMode { return s.checkSubscriptionEligibility(ctx, user.ID, group, subscription) diff --git a/backend/internal/service/billing_cache_service_eligibility_test.go b/backend/internal/service/billing_cache_service_eligibility_test.go new file mode 100644 index 0000000000..c56aed79a1 --- /dev/null +++ b/backend/internal/service/billing_cache_service_eligibility_test.go @@ -0,0 +1,68 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type billingCacheStubForEligibility struct { + balance float64 + sub *SubscriptionCacheData +} + +func (s *billingCacheStubForEligibility) GetUserBalance(ctx context.Context, userID int64) (float64, error) { + return s.balance, nil +} + +func (s *billingCacheStubForEligibility) SetUserBalance(ctx context.Context, userID int64, balance float64) error { + return nil +} + +func (s *billingCacheStubForEligibility) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { + return nil +} + +func (s *billingCacheStubForEligibility) InvalidateUserBalance(ctx context.Context, userID int64) error { + return nil +} + +func (s *billingCacheStubForEligibility) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) { + return s.sub, nil +} + +func (s *billingCacheStubForEligibility) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error { + return nil +} + +func (s *billingCacheStubForEligibility) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { + return nil +} + +func (s *billingCacheStubForEligibility) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { + return nil +} + +func TestBillingCacheService_CheckBillingEligibility_SubscriptionGroupWithoutSubscriptionObject(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeStandard} + cache := &billingCacheStubForEligibility{ + balance: 0, + sub: &SubscriptionCacheData{ + Status: SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }, + } + svc := NewBillingCacheService(cache, nil, nil, cfg) + t.Cleanup(svc.Stop) + + user := &User{ID: 1} + group := &Group{ID: 2, SubscriptionType: SubscriptionTypeSubscription, Status: StatusActive} + + err := svc.CheckBillingEligibility(context.Background(), user, &APIKey{}, group, nil) + require.NoError(t, err) +} diff --git a/backend/internal/service/copilot_model_refresh_service.go b/backend/internal/service/copilot_model_refresh_service.go new file mode 100644 index 0000000000..acbe6ff77b --- /dev/null +++ b/backend/internal/service/copilot_model_refresh_service.go @@ -0,0 +1,150 @@ +package service + +import ( + "context" + "log" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +type CopilotModelRefreshService struct { + accountRepo AccountRepository + githubCopilotToken *GitHubCopilotTokenProvider + cfg *config.CopilotModelRefreshConfig + + stopCh chan struct{} + wg sync.WaitGroup +} + +func NewCopilotModelRefreshService(accountRepo AccountRepository, githubCopilotToken *GitHubCopilotTokenProvider, cfg *config.Config) *CopilotModelRefreshService { + var c *config.CopilotModelRefreshConfig + if cfg != nil { + c = &cfg.CopilotModelRefresh + } + return &CopilotModelRefreshService{ + accountRepo: accountRepo, + githubCopilotToken: githubCopilotToken, + cfg: c, + stopCh: make(chan struct{}), + } +} + +func (s *CopilotModelRefreshService) Start() { + if s == nil || s.cfg == nil || !s.cfg.Enabled { + return + } + if s.githubCopilotToken == nil { + log.Println("[CopilotModelRefresh] GitHub Copilot token provider is nil") + return + } + + s.wg.Add(1) + go s.refreshLoop() + + log.Printf("[CopilotModelRefresh] Service started (check every %d minutes)", s.cfg.CheckIntervalMinutes) +} + +func (s *CopilotModelRefreshService) Stop() { + if s == nil { + return + } + close(s.stopCh) + s.wg.Wait() + log.Println("[CopilotModelRefresh] Service stopped") +} + +func (s *CopilotModelRefreshService) refreshLoop() { + defer s.wg.Done() + + interval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute + if interval < time.Minute { + interval = 6 * time.Hour + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + s.processRefresh() + + for { + select { + case <-ticker.C: + s.processRefresh() + case <-s.stopCh: + return + } + } +} + +func (s *CopilotModelRefreshService) processRefresh() { + ctx := context.Background() + + accounts, err := s.accountRepo.ListActive(ctx) + if err != nil { + log.Printf("[CopilotModelRefresh] Failed to list accounts: %v", err) + return + } + + refreshed, failed, skipped := 0, 0, 0 + + for i := range accounts { + acc := &accounts[i] + if !isGitHubCopilotAccount(acc) { + skipped++ + continue + } + + timeout := time.Duration(s.cfg.RequestTimeoutSeconds) * time.Second + if timeout <= 0 { + timeout = 30 * time.Second + } + reqCtx, cancel := context.WithTimeout(ctx, timeout) + models, fetchErr := s.githubCopilotToken.ListModels(reqCtx, acc) + cancel() + + now := time.Now().Format(time.RFC3339) + if fetchErr != nil || len(models) == 0 { + failed++ + msg := "" + if fetchErr != nil { + msg = fetchErr.Error() + } + _ = s.accountRepo.UpdateExtra(ctx, acc.ID, map[string]any{ + AccountExtraKeyAvailableModelsSource: "github_copilot", + AccountExtraKeyAvailableModelsError: msg, + AccountExtraKeyAvailableModelsErrorAt: now, + }) + continue + } + + ids := make([]string, 0, len(models)) + for _, m := range models { + if id := strings.TrimSpace(m.ID); id != "" { + ids = append(ids, id) + } + } + if len(ids) == 0 { + failed++ + _ = s.accountRepo.UpdateExtra(ctx, acc.ID, map[string]any{ + AccountExtraKeyAvailableModelsSource: "github_copilot", + AccountExtraKeyAvailableModelsError: "copilot models response contained no model ids", + AccountExtraKeyAvailableModelsErrorAt: now, + }) + continue + } + + refreshed++ + _ = s.accountRepo.UpdateExtra(ctx, acc.ID, map[string]any{ + AccountExtraKeyAvailableModels: ids, + AccountExtraKeyAvailableModelsUpdatedAt: now, + AccountExtraKeyAvailableModelsSource: "github_copilot", + AccountExtraKeyAvailableModelsError: "", + AccountExtraKeyAvailableModelsErrorAt: "", + }) + } + + log.Printf("[CopilotModelRefresh] Cycle complete: refreshed=%d failed=%d skipped=%d", refreshed, failed, skipped) +} diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index 040b2357b1..e6cbd27131 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -1163,6 +1163,7 @@ func crsLogin(ctx context.Context, client *http.Client, baseURL, username, passw } req.Header.Set("Content-Type", "application/json") + // #nosec G704 -- baseURL is admin-configured and validated by fetchCRSExport (allowlist when enabled) resp, err := client.Do(req) if err != nil { return "", err @@ -1198,6 +1199,7 @@ func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminT } req.Header.Set("Authorization", "Bearer "+adminToken) + // #nosec G704 -- baseURL is admin-configured and validated by fetchCRSExport (allowlist when enabled) resp, err := client.Do(req) if err != nil { return nil, err diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 0295c23bdc..3a4e64a1e8 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -22,10 +22,31 @@ const ( const ( PlatformAnthropic = domain.PlatformAnthropic PlatformOpenAI = domain.PlatformOpenAI + PlatformCopilot = domain.PlatformCopilot + PlatformAggregator = domain.PlatformAggregator PlatformGemini = domain.PlatformGemini PlatformAntigravity = domain.PlatformAntigravity ) +// Provider constants +const ( + ProviderOpenAI = domain.ProviderOpenAI + ProviderAzure = domain.ProviderAzure + ProviderCopilot = domain.ProviderCopilot + ProviderAnthropic = domain.ProviderAnthropic + ProviderGemini = domain.ProviderGemini + ProviderVertexAI = domain.ProviderVertexAI + ProviderAntigravity = domain.ProviderAntigravity + ProviderBedrock = domain.ProviderBedrock + ProviderOpenRouter = domain.ProviderOpenRouter + ProviderAggregator = domain.ProviderAggregator +) + +// GetPlatformFromProvider wraps domain.GetPlatformFromProvider +func GetPlatformFromProvider(provider string) string { + return domain.GetPlatformFromProvider(provider) +} + // Account type constants const ( AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index b4b93aceca..071f33373b 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -126,6 +126,9 @@ func (m *mockAccountRepoForPlatform) ListSchedulable(ctx context.Context) ([]Acc func (m *mockAccountRepoForPlatform) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) { return nil, nil } +func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDs(ctx context.Context, groupIDs []int64) ([]Account, error) { + return nil, nil +} func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { var result []Account platformSet := make(map[string]bool) @@ -256,6 +259,23 @@ func (m *mockGroupRepoForGateway) ListActive(ctx context.Context) ([]Group, erro func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) { return nil, nil } + +func (m *mockGroupRepoForGateway) ListPublicGroupIDs(ctx context.Context) ([]int64, error) { + ids := make([]int64, 0, len(m.groups)) + for id, g := range m.groups { + if g == nil { + continue + } + if g.Status == StatusActive && !g.IsExclusive { + if g.ID > 0 { + ids = append(ids, g.ID) + } else { + ids = append(ids, id) + } + } + } + return ids, nil +} func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) { return false, nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 56af4610bc..30d911cb92 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -413,6 +413,7 @@ type GatewayService struct { deferredService *DeferredService concurrencyService *ConcurrencyService claudeTokenProvider *ClaudeTokenProvider + githubCopilotToken *GitHubCopilotTokenProvider sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) } @@ -435,6 +436,7 @@ func NewGatewayService( httpUpstream HTTPUpstream, deferredService *DeferredService, claudeTokenProvider *ClaudeTokenProvider, + githubCopilotTokenProvider *GitHubCopilotTokenProvider, sessionLimitCache SessionLimitCache, digestStore *DigestSessionStore, ) *GatewayService { @@ -457,6 +459,7 @@ func NewGatewayService( httpUpstream: httpUpstream, deferredService: deferredService, claudeTokenProvider: claudeTokenProvider, + githubCopilotToken: githubCopilotTokenProvider, sessionLimitCache: sessionLimitCache, } } @@ -2562,6 +2565,20 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( // Both oauth and setup-token use OAuth token flow return s.getOAuthToken(ctx, account) case AccountTypeAPIKey: + if isGitHubCopilotAccount(account) && s.githubCopilotToken != nil { + copilotToken, err := s.githubCopilotToken.GetAccessToken(ctx, account) + if err == nil && strings.TrimSpace(copilotToken) != "" { + return copilotToken, "github_copilot", nil + } + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey != "" { + return apiKey, "apikey", nil + } + if err != nil { + return "", "", err + } + return "", "", errors.New("api_key not found in credentials") + } apiKey := account.GetCredential("api_key") if apiKey == "" { return "", "", errors.New("api_key not found in credentials") @@ -3017,6 +3034,15 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + isGitHubCopilot := isGitHubCopilotAccount(account) + copilotVision := false + copilotInitiator := "user" + if isGitHubCopilot { + copilotVision = githubCopilotVisionEnabledFromClaudeMessagesPayload(parsed.Messages) + copilotInitiator = githubCopilotInitiatorFromClaudeMessagesPayload(parsed.Messages) + } + thinkingEnabled := parsed.ThinkingEnabled + if shouldMimicClaudeCode { // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 @@ -3096,7 +3122,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) // Capture upstream request body for ops retry of this attempt. c.Set(OpsUpstreamRequestBodyKey, string(body)) - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode, isGitHubCopilot, copilotVision, copilotInitiator, thinkingEnabled) if err != nil { return nil, err } @@ -3128,6 +3154,28 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, fmt.Errorf("upstream request failed: %s", safeErr) } + if isGitHubCopilot && resp.StatusCode == http.StatusUnauthorized && s.githubCopilotToken != nil && tokenType == "github_copilot" { + refreshed := "" + s.githubCopilotToken.Invalidate(ctx, account) + if t, refreshErr := s.githubCopilotToken.GetAccessToken(ctx, account); refreshErr == nil { + refreshed = strings.TrimSpace(t) + } + if refreshed != "" { + retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, body, refreshed, tokenType, reqModel, reqStream, shouldMimicClaudeCode, isGitHubCopilot, copilotVision, copilotInitiator, thinkingEnabled) + if buildErr == nil { + retryResp, doErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if doErr == nil { + _ = resp.Body.Close() + resp = retryResp + token = refreshed + } + if doErr != nil && retryResp != nil && retryResp.Body != nil { + _ = retryResp.Body.Close() + } + } + } + } + // 优先检测thinking block签名错误(400)并重试一次 if resp.StatusCode == 400 { respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) @@ -3174,7 +3222,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // also downgrade tool_use/tool_result blocks to text. filteredBody := FilterThinkingBlocksForRetry(body) - retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode, isGitHubCopilot, copilotVision, copilotInitiator, thinkingEnabled) if buildErr == nil { retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { @@ -3206,7 +3254,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) - retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode, isGitHubCopilot, copilotVision, copilotInitiator, thinkingEnabled) if buildErr2 == nil { retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr2 == nil { @@ -3459,7 +3507,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }, nil } -func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { +func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool, isGitHubCopilot bool, copilotVision bool, copilotInitiator string, thinkingEnabled bool) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL if account.Type == AccountTypeAPIKey { @@ -3469,7 +3517,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages" + targetURL = anthropicMessagesURLFromBaseURL(validatedURL) } } @@ -3506,7 +3554,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } // 设置认证头 - if tokenType == "oauth" { + if tokenType == "oauth" || tokenType == "github_copilot" { req.Header.Set("authorization", "Bearer "+token) } else { req.Header.Set("x-api-key", token) @@ -3527,6 +3575,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex s.identityService.ApplyFingerprint(req, fingerprint) } + if isGitHubCopilot { + applyGitHubCopilotHeaders(req, copilotVision, copilotInitiator) + } + // 确保必要的headers存在 if req.Header.Get("content-type") == "" { req.Header.Set("content-type", "application/json") @@ -3537,6 +3589,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex if tokenType == "oauth" { applyClaudeOAuthHeaderDefaults(req, reqStream) } + if reqStream && strings.TrimSpace(req.Header.Get("accept")) == "" { + req.Header.Set("accept", "text/event-stream") + } // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) if tokenType == "oauth" { @@ -3567,6 +3622,21 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } + if isGitHubCopilot { + incomingBeta := req.Header.Get("anthropic-beta") + if strings.TrimSpace(incomingBeta) == "" { + if requestNeedsBetaFeatures(body) { + incomingBeta = defaultAPIKeyBetaHeader(body) + } + } + required := []string{} + if thinkingEnabled { + required = append(required, claude.BetaInterleavedThinking) + } + drop := map[string]struct{}{claude.BetaClaudeCode: {}} + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(required, incomingBeta, drop)) + } + // Always capture a compact fingerprint line for later error diagnostics. // We only print it when needed (or when the explicit debug flag is enabled). if c != nil && tokenType == "oauth" { @@ -4526,6 +4596,16 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu user := input.User account := input.Account subscription := input.Subscription + if subscription == nil && apiKey != nil && user != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() { + if s.userSubRepo == nil { + return errors.New("subscription repository not configured") + } + sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, apiKey.Group.ID) + if err != nil { + return fmt.Errorf("get active subscription: %w", err) + } + subscription = sub + } // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens // 用于粘性会话切换时的特殊计费处理 @@ -5032,7 +5112,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages/count_tokens" + targetURL = anthropicCountTokensURLFromBaseURL(validatedURL) } } @@ -5061,7 +5141,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } // 设置认证头 - if tokenType == "oauth" { + if tokenType == "oauth" || tokenType == "github_copilot" { req.Header.Set("authorization", "Bearer "+token) } else { req.Header.Set("x-api-key", token) @@ -5085,6 +5165,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } + if isGitHubCopilotAccount(account) { + applyGitHubCopilotHeaders(req, false, "user") + } + // 确保必要的 headers 存在 if req.Header.Get("content-type") == "" { req.Header.Set("content-type", "application/json") @@ -5125,6 +5209,12 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } + if isGitHubCopilotAccount(account) { + incomingBeta := req.Header.Get("anthropic-beta") + drop := map[string]struct{}{claude.BetaClaudeCode: {}} + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(nil, incomingBeta, drop)) + } + if c != nil && tokenType == "oauth" { c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) } @@ -5165,8 +5255,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { return normalized, nil } -// GetAvailableModels returns the list of models available for a group -// It aggregates model_mapping keys from all schedulable accounts in the group func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { var accounts []Account var err error @@ -5194,20 +5282,29 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, // Collect unique models from all accounts modelSet := make(map[string]struct{}) - hasAnyMapping := false + hasAnyModels := false for _, acc := range accounts { + if isGitHubCopilotAccount(&acc) { + if ids := acc.GetAvailableModels(); len(ids) > 0 { + hasAnyModels = true + for _, id := range ids { + modelSet[id] = struct{}{} + } + continue + } + } + mapping := acc.GetModelMapping() if len(mapping) > 0 { - hasAnyMapping = true + hasAnyModels = true for model := range mapping { modelSet[model] = struct{}{} } } } - // If no account has model_mapping, return nil (use default) - if !hasAnyMapping { + if !hasAnyModels { return nil } @@ -5220,6 +5317,91 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, return models } +func (s *GatewayService) GetAvailableModelsByGroupIDs(ctx context.Context, groupIDs []int64, platform string) []string { + if len(groupIDs) == 0 { + accounts, err := s.accountRepo.ListSchedulable(ctx) + if err != nil || len(accounts) == 0 { + return nil + } + return s.collectModelsFromAccounts(accounts, platform) + } + + accounts, err := s.accountRepo.ListSchedulableByGroupIDs(ctx, groupIDs) + if err != nil || len(accounts) == 0 { + return nil + } + + return s.collectModelsFromAccounts(accounts, platform) +} + +func (s *GatewayService) collectModelsFromAccounts(accounts []Account, platform string) []string { + if platform != "" { + filtered := make([]Account, 0) + for _, acc := range accounts { + if acc.Platform == platform { + filtered = append(filtered, acc) + } + } + accounts = filtered + } + + modelSet := make(map[string]struct{}) + hasAnyModels := false + + namespacedModel := func(provider, model string) string { + m := strings.TrimSpace(model) + if m == "" { + return "" + } + if strings.Contains(m, "/") { + return m + } + p := strings.TrimSpace(provider) + if p == "" { + return m + } + return p + "/" + m + } + + for _, acc := range accounts { + provider := inferProviderFromAccount(&acc) + if isGitHubCopilotAccount(&acc) { + if ids := acc.GetAvailableModels(); len(ids) > 0 { + hasAnyModels = true + for _, id := range ids { + nsID := namespacedModel(provider, id) + if nsID != "" { + modelSet[nsID] = struct{}{} + } + } + continue + } + } + + mapping := acc.GetModelMapping() + if len(mapping) > 0 { + hasAnyModels = true + for model := range mapping { + nsModel := namespacedModel(provider, model) + if nsModel != "" { + modelSet[nsModel] = struct{}{} + } + } + } + } + + if !hasAnyModels { + return nil + } + + models := make([]string, 0, len(modelSet)) + for model := range modelSet { + models = append(models, model) + } + + return models +} + // reconcileCachedTokens 兼容 Kimi 等上游: // 将 OpenAI 风格的 cached_tokens 映射到 Claude 标准的 cache_read_input_tokens func reconcileCachedTokens(usage map[string]any) bool { @@ -5237,3 +5419,72 @@ func reconcileCachedTokens(usage map[string]any) bool { usage["cache_read_input_tokens"] = cached return true } + +func (s *GatewayService) GetAccessibleGroupIDs(ctx context.Context, allowedGroups []int64) ([]int64, error) { + publicIDs, err := s.groupRepo.ListPublicGroupIDs(ctx) + if err != nil { + return nil, err + } + + seen := make(map[int64]struct{}) + for _, id := range publicIDs { + seen[id] = struct{}{} + } + for _, id := range allowedGroups { + seen[id] = struct{}{} + } + + result := make([]int64, 0, len(seen)) + for id := range seen { + result = append(result, id) + } + return result, nil +} + +func (s *GatewayService) GetAccountGroupForBilling(account *Account) *Group { + if account == nil || len(account.Groups) == 0 { + return nil + } + return account.Groups[0] +} + +func (s *GatewayService) ResolveGroupFromUserPermission(ctx context.Context, allowedGroups []int64, requestedModel string) (*Group, error) { + accessibleIDs, err := s.GetAccessibleGroupIDs(ctx, allowedGroups) + if err != nil { + return nil, fmt.Errorf("get accessible groups: %w", err) + } + if len(accessibleIDs) == 0 { + return nil, errors.New("no accessible groups") + } + + if requestedModel != "" { + for _, groupID := range accessibleIDs { + accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, groupID) + if err != nil || len(accounts) == 0 { + continue + } + for _, acc := range accounts { + mapping := acc.GetModelMapping() + if len(mapping) == 0 { + group, err := s.groupRepo.GetByIDLite(ctx, groupID) + if err == nil { + return group, nil + } + continue + } + if _, ok := mapping[requestedModel]; ok { + group, err := s.groupRepo.GetByIDLite(ctx, groupID) + if err == nil { + return group, nil + } + } + } + } + } + + group, err := s.groupRepo.GetByIDLite(ctx, accessibleIDs[0]) + if err != nil { + return nil, fmt.Errorf("get group %d: %w", accessibleIDs[0], err) + } + return group, nil +} diff --git a/backend/internal/service/gateway_service_models_provider_test.go b/backend/internal/service/gateway_service_models_provider_test.go new file mode 100644 index 0000000000..2e490e4ac9 --- /dev/null +++ b/backend/internal/service/gateway_service_models_provider_test.go @@ -0,0 +1,43 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGatewayService_CollectModelsFromAccounts_ProviderNamespaced(t *testing.T) { + svc := &GatewayService{} + accounts := []Account{ + { + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "base_url": "https://api.openai.com", + "model_mapping": map[string]any{"gpt-5.2": "gpt-5.2"}, + }, + }, + { + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "base_url": "https://foo.openai.azure.com", + "model_mapping": map[string]any{"gpt-5.2": "gpt-5.2"}, + }, + }, + { + Platform: PlatformCopilot, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + AccountExtraKeyAvailableModels: []any{"gpt-5.2"}, + }, + }, + } + + got := svc.collectModelsFromAccounts(accounts, "") + require.Contains(t, got, "openai/gpt-5.2") + require.Contains(t, got, "azure/gpt-5.2") + require.Contains(t, got, "copilot/gpt-5.2") +} diff --git a/backend/internal/service/gateway_service_record_usage_subscription_test.go b/backend/internal/service/gateway_service_record_usage_subscription_test.go new file mode 100644 index 0000000000..75f78d2ffb --- /dev/null +++ b/backend/internal/service/gateway_service_record_usage_subscription_test.go @@ -0,0 +1,296 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/stretchr/testify/require" +) + +type usageLogRepoStubForRecordUsage struct { + created []*UsageLog +} + +func (s *usageLogRepoStubForRecordUsage) Create(ctx context.Context, log *UsageLog) (inserted bool, err error) { + s.created = append(s.created, log) + return true, nil +} + +func (s *usageLogRepoStubForRecordUsage) GetByID(ctx context.Context, id int64) (*UsageLog, error) { + panic("unexpected GetByID call") +} + +func (s *usageLogRepoStubForRecordUsage) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (s *usageLogRepoStubForRecordUsage) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) { + panic("unexpected ListByUser call") +} + +func (s *usageLogRepoStubForRecordUsage) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) { + panic("unexpected ListByAPIKey call") +} + +func (s *usageLogRepoStubForRecordUsage) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) { + panic("unexpected ListByAccount call") +} + +func (s *usageLogRepoStubForRecordUsage) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) { + panic("unexpected ListByUserAndTimeRange call") +} + +func (s *usageLogRepoStubForRecordUsage) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) { + panic("unexpected ListByAPIKeyAndTimeRange call") +} + +func (s *usageLogRepoStubForRecordUsage) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) { + panic("unexpected ListByAccountAndTimeRange call") +} + +func (s *usageLogRepoStubForRecordUsage) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) { + panic("unexpected ListByModelAndTimeRange call") +} + +func (s *usageLogRepoStubForRecordUsage) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { + panic("unexpected GetAccountWindowStats call") +} + +func (s *usageLogRepoStubForRecordUsage) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) { + panic("unexpected GetAccountTodayStats call") +} + +func (s *usageLogRepoStubForRecordUsage) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { + panic("unexpected GetDashboardStats call") +} + +func (s *usageLogRepoStubForRecordUsage) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { + panic("unexpected GetUsageTrendWithFilters call") +} + +func (s *usageLogRepoStubForRecordUsage) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { + panic("unexpected GetModelStatsWithFilters call") +} + +func (s *usageLogRepoStubForRecordUsage) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { + panic("unexpected GetAPIKeyUsageTrend call") +} + +func (s *usageLogRepoStubForRecordUsage) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) { + panic("unexpected GetUserUsageTrend call") +} + +func (s *usageLogRepoStubForRecordUsage) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { + panic("unexpected GetBatchUserUsageStats call") +} + +func (s *usageLogRepoStubForRecordUsage) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { + panic("unexpected GetBatchAPIKeyUsageStats call") +} + +func (s *usageLogRepoStubForRecordUsage) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) { + panic("unexpected GetUserDashboardStats call") +} + +func (s *usageLogRepoStubForRecordUsage) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) { + panic("unexpected GetAPIKeyDashboardStats call") +} + +func (s *usageLogRepoStubForRecordUsage) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { + panic("unexpected GetUserUsageTrendByUserID call") +} + +func (s *usageLogRepoStubForRecordUsage) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) { + panic("unexpected GetUserModelStats call") +} + +func (s *usageLogRepoStubForRecordUsage) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *usageLogRepoStubForRecordUsage) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + panic("unexpected GetGlobalStats call") +} + +func (s *usageLogRepoStubForRecordUsage) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) { + panic("unexpected GetStatsWithFilters call") +} + +func (s *usageLogRepoStubForRecordUsage) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) { + panic("unexpected GetAccountUsageStats call") +} + +func (s *usageLogRepoStubForRecordUsage) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + panic("unexpected GetUserStatsAggregated call") +} + +func (s *usageLogRepoStubForRecordUsage) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + panic("unexpected GetAPIKeyStatsAggregated call") +} + +func (s *usageLogRepoStubForRecordUsage) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + panic("unexpected GetAccountStatsAggregated call") +} + +func (s *usageLogRepoStubForRecordUsage) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + panic("unexpected GetModelStatsAggregated call") +} + +func (s *usageLogRepoStubForRecordUsage) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) { + panic("unexpected GetDailyStatsAggregated call") +} + +type userRepoStubForRecordUsage struct { + userRepoStub + deductCalls int +} + +func (s *userRepoStubForRecordUsage) DeductBalance(ctx context.Context, id int64, amount float64) error { + s.deductCalls++ + return nil +} + +type userSubRepoStubForRecordUsage struct { + active *UserSubscription + getActiveCalls int + incrementCalls int + incrementSubIDs []int64 +} + +func (s *userSubRepoStubForRecordUsage) Create(ctx context.Context, sub *UserSubscription) error { + panic("unexpected Create call") +} + +func (s *userSubRepoStubForRecordUsage) GetByID(ctx context.Context, id int64) (*UserSubscription, error) { + panic("unexpected GetByID call") +} + +func (s *userSubRepoStubForRecordUsage) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*UserSubscription, error) { + panic("unexpected GetByUserIDAndGroupID call") +} + +func (s *userSubRepoStubForRecordUsage) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*UserSubscription, error) { + s.getActiveCalls++ + if s.active == nil { + return nil, errors.New("subscription not found") + } + return s.active, nil +} + +func (s *userSubRepoStubForRecordUsage) Update(ctx context.Context, sub *UserSubscription) error { + panic("unexpected Update call") +} + +func (s *userSubRepoStubForRecordUsage) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (s *userSubRepoStubForRecordUsage) ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) { + panic("unexpected ListByUserID call") +} + +func (s *userSubRepoStubForRecordUsage) ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) { + panic("unexpected ListActiveByUserID call") +} + +func (s *userSubRepoStubForRecordUsage) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) { + panic("unexpected ListByGroupID call") +} + +func (s *userSubRepoStubForRecordUsage) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *userSubRepoStubForRecordUsage) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { + panic("unexpected ExistsByUserIDAndGroupID call") +} + +func (s *userSubRepoStubForRecordUsage) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { + panic("unexpected ExtendExpiry call") +} + +func (s *userSubRepoStubForRecordUsage) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { + panic("unexpected UpdateStatus call") +} + +func (s *userSubRepoStubForRecordUsage) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { + panic("unexpected UpdateNotes call") +} + +func (s *userSubRepoStubForRecordUsage) ActivateWindows(ctx context.Context, id int64, start time.Time) error { + panic("unexpected ActivateWindows call") +} + +func (s *userSubRepoStubForRecordUsage) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + panic("unexpected ResetDailyUsage call") +} + +func (s *userSubRepoStubForRecordUsage) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + panic("unexpected ResetWeeklyUsage call") +} + +func (s *userSubRepoStubForRecordUsage) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + panic("unexpected ResetMonthlyUsage call") +} + +func (s *userSubRepoStubForRecordUsage) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + s.incrementCalls++ + s.incrementSubIDs = append(s.incrementSubIDs, id) + return nil +} + +func (s *userSubRepoStubForRecordUsage) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { + panic("unexpected BatchUpdateExpiredStatus call") +} + +func TestGatewayService_RecordUsage_SubscriptionGroupWithoutSubscriptionObject(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeStandard} + billingService := NewBillingService(cfg, nil) + + usageRepo := &usageLogRepoStubForRecordUsage{} + userRepo := &userRepoStubForRecordUsage{} + subRepo := &userSubRepoStubForRecordUsage{active: &UserSubscription{ID: 99}} + + svc := &GatewayService{ + usageLogRepo: usageRepo, + userRepo: userRepo, + userSubRepo: subRepo, + cfg: cfg, + billingService: billingService, + billingCacheService: &BillingCacheService{cfg: cfg}, + deferredService: &DeferredService{}, + } + + user := &User{ID: 1} + group := &Group{ID: 2, SubscriptionType: SubscriptionTypeSubscription, Status: StatusActive} + groupID := group.ID + apiKey := &APIKey{ID: 3, User: user, Group: group, GroupID: &groupID} + account := &Account{ID: 4, Platform: PlatformAnthropic, Status: StatusActive, Schedulable: true} + result := &ForwardResult{ + RequestID: "req_1", + Model: "claude-3-5-haiku", + Usage: ClaudeUsage{ + InputTokens: 100, + OutputTokens: 0, + }, + Duration: 1 * time.Second, + } + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: user, + Account: account, + Subscription: nil, + }) + require.NoError(t, err) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 1, subRepo.incrementCalls) +} diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 080352ba42..4515c79ec7 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -111,6 +111,9 @@ func (m *mockAccountRepoForGemini) ListSchedulable(ctx context.Context) ([]Accou func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) { return nil, nil } +func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDs(ctx context.Context, groupIDs []int64) ([]Account, error) { + return nil, nil +} func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { if m.listByPlatformFunc != nil { return m.listByPlatformFunc(ctx, platforms) @@ -208,6 +211,22 @@ func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) { return nil, nil } +func (m *mockGroupRepoForGemini) ListPublicGroupIDs(ctx context.Context) ([]int64, error) { + ids := make([]int64, 0, len(m.groups)) + for id, g := range m.groups { + if g == nil { + continue + } + if g.Status == StatusActive && !g.IsExclusive { + if g.ID > 0 { + ids = append(ids, g.ID) + } else { + ids = append(ids, id) + } + } + } + return ids, nil +} func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) { return false, nil } diff --git a/backend/internal/service/gemini_native_claude_compat.go b/backend/internal/service/gemini_native_claude_compat.go new file mode 100644 index 0000000000..823689da36 --- /dev/null +++ b/backend/internal/service/gemini_native_claude_compat.go @@ -0,0 +1,258 @@ +package service + +import ( + "encoding/json" + "errors" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" +) + +func ConvertGeminiNativeRequestToClaudeMessages(model string, body []byte) ([]byte, error) { + model = strings.TrimSpace(model) + if model == "" { + return nil, errors.New("missing model") + } + + var req antigravity.GeminiRequest + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + claudeReq := map[string]any{ + "model": model, + "max_tokens": 1024, + "stream": false, + "messages": geminiContentsToClaudeMessages(req.Contents), + } + + if req.GenerationConfig != nil { + if req.GenerationConfig.MaxOutputTokens > 0 { + claudeReq["max_tokens"] = req.GenerationConfig.MaxOutputTokens + } + if req.GenerationConfig.Temperature != nil { + claudeReq["temperature"] = *req.GenerationConfig.Temperature + } + if req.GenerationConfig.TopP != nil { + claudeReq["top_p"] = *req.GenerationConfig.TopP + } + if len(req.GenerationConfig.StopSequences) > 0 { + stop := make([]any, 0, len(req.GenerationConfig.StopSequences)) + for _, s := range req.GenerationConfig.StopSequences { + if strings.TrimSpace(s) != "" { + stop = append(stop, s) + } + } + if len(stop) > 0 { + claudeReq["stop_sequences"] = stop + } + } + } + + if sys := geminiSystemText(req.SystemInstruction); sys != "" { + claudeReq["system"] = sys + } + + if tools := geminiToolsToClaudeTools(req.Tools); len(tools) > 0 { + claudeReq["tools"] = tools + } + + return json.Marshal(claudeReq) +} + +func ConvertClaudeMessageToGeminiResponse(claudeResp map[string]any, usage *ClaudeUsage) (map[string]any, error) { + if claudeResp == nil { + return nil, errors.New("empty response") + } + + if usage == nil { + usage = extractClaudeUsageFromResponse(claudeResp) + if usage == nil { + usage = &ClaudeUsage{} + } + } + + finishReason := "STOP" + if sr, ok := claudeResp["stop_reason"].(string); ok { + if strings.EqualFold(strings.TrimSpace(sr), "max_tokens") { + finishReason = "MAX_TOKENS" + } + } + + parts := make([]any, 0) + if content, ok := claudeResp["content"].([]any); ok { + for _, b := range content { + bm, ok := b.(map[string]any) + if !ok { + continue + } + bt, _ := bm["type"].(string) + switch strings.ToLower(strings.TrimSpace(bt)) { + case "text": + if text, ok := bm["text"].(string); ok { + parts = append(parts, map[string]any{"text": text}) + } + case "tool_use": + name, _ := bm["name"].(string) + id, _ := bm["id"].(string) + call := map[string]any{ + "name": strings.TrimSpace(name), + "args": bm["input"], + } + if strings.TrimSpace(id) != "" { + call["id"] = strings.TrimSpace(id) + } + parts = append(parts, map[string]any{"functionCall": call}) + } + } + } else if s, ok := claudeResp["content"].(string); ok { + if strings.TrimSpace(s) != "" { + parts = append(parts, map[string]any{"text": s}) + } + } + + prompt := usage.InputTokens + usage.CacheReadInputTokens + resp := map[string]any{ + "candidates": []any{ + map[string]any{ + "content": map[string]any{ + "role": "model", + "parts": parts, + }, + "finishReason": finishReason, + "index": 0, + }, + }, + "usageMetadata": map[string]any{ + "promptTokenCount": prompt, + "candidatesTokenCount": usage.OutputTokens, + "cachedContentTokenCount": usage.CacheReadInputTokens, + "totalTokenCount": prompt + usage.OutputTokens, + }, + } + return resp, nil +} + +func geminiSystemText(sys *antigravity.GeminiContent) string { + if sys == nil { + return "" + } + parts := make([]string, 0, len(sys.Parts)) + for _, p := range sys.Parts { + if s := strings.TrimSpace(p.Text); s != "" { + parts = append(parts, s) + } + } + return strings.Join(parts, "\n") +} + +func geminiToolsToClaudeTools(tools []antigravity.GeminiToolDeclaration) []any { + out := make([]any, 0) + for _, td := range tools { + for _, fd := range td.FunctionDeclarations { + name := strings.TrimSpace(fd.Name) + if name == "" { + continue + } + params := fd.Parameters + if params == nil { + params = map[string]any{"type": "object", "properties": map[string]any{}} + } + out = append(out, map[string]any{ + "name": name, + "description": strings.TrimSpace(fd.Description), + "input_schema": params, + }) + } + } + if len(out) == 0 { + return nil + } + return out +} + +func geminiContentsToClaudeMessages(contents []antigravity.GeminiContent) []any { + nameToCallID := make(map[string]string) + out := make([]any, 0, len(contents)) + for _, c := range contents { + role := strings.ToLower(strings.TrimSpace(c.Role)) + claudeRole := "user" + if role == "model" { + claudeRole = "assistant" + } + + blocks := make([]any, 0) + for _, p := range c.Parts { + if s := strings.TrimSpace(p.Text); s != "" { + blocks = append(blocks, map[string]any{"type": "text", "text": p.Text}) + } + + if p.InlineData != nil { + mt := strings.TrimSpace(p.InlineData.MimeType) + data := strings.TrimSpace(p.InlineData.Data) + if mt != "" && data != "" { + blocks = append(blocks, map[string]any{ + "type": "image", + "source": map[string]any{ + "type": "base64", + "media_type": mt, + "data": data, + }, + }) + } + } + + if p.FunctionCall != nil { + name := strings.TrimSpace(p.FunctionCall.Name) + callID := strings.TrimSpace(p.FunctionCall.ID) + if callID == "" { + callID = nameToCallID[name] + } + if callID == "" { + callID = "toolu_" + randomHex(12) + } + if name != "" { + nameToCallID[name] = callID + } + blocks = append(blocks, map[string]any{ + "type": "tool_use", + "id": callID, + "name": name, + "input": p.FunctionCall.Args, + }) + } + + if p.FunctionResponse != nil { + name := strings.TrimSpace(p.FunctionResponse.Name) + callID := strings.TrimSpace(p.FunctionResponse.ID) + if callID == "" { + callID = nameToCallID[name] + } + outText := "" + if v, ok := p.FunctionResponse.Response["content"].(string); ok { + outText = v + } else if b, err := json.Marshal(p.FunctionResponse.Response); err == nil { + outText = string(b) + } + if callID != "" { + blocks = append(blocks, map[string]any{ + "type": "tool_result", + "tool_use_id": callID, + "content": outText, + }) + } else if strings.TrimSpace(outText) != "" { + blocks = append(blocks, map[string]any{"type": "text", "text": outText}) + } + } + } + + if len(blocks) == 0 { + continue + } + out = append(out, map[string]any{ + "role": claudeRole, + "content": blocks, + }) + } + return out +} diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index fd2932e6a1..da437b8856 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -1049,6 +1049,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR client = &http.Client{Timeout: 30 * time.Second} } + // #nosec G704 -- request targets fixed Google API host resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("resource manager request failed: %w", err) diff --git a/backend/internal/service/github_copilot_helpers.go b/backend/internal/service/github_copilot_helpers.go new file mode 100644 index 0000000000..0fad473a42 --- /dev/null +++ b/backend/internal/service/github_copilot_helpers.go @@ -0,0 +1,274 @@ +package service + +import ( + "net/http" + "net/url" + "strings" + + "github.com/google/uuid" +) + +const ( + githubCopilotDefaultVSCodeVersion = "1.109.2" + githubCopilotDefaultCopilotChatVersion = "0.37.4" + githubCopilotDefaultGitHubAPIVersionHeaderDate = "2025-10-01" + githubCopilotDefaultIntegrationID = "vscode-chat" + githubCopilotDefaultOpenAIIntent = "conversation-agent" + githubCopilotDefaultUserAgentLibraryVersion = "electron-fetch" +) + +func isGitHubCopilotBaseURL(raw string) bool { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return false + } + candidate := trimmed + if !strings.Contains(candidate, "://") { + candidate = "https://" + candidate + } + parsed, err := url.Parse(candidate) + if err != nil { + return false + } + host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) + if host == "" { + return false + } + if host == "api.githubcopilot.com" { + return true + } + return strings.HasSuffix(host, ".githubcopilot.com") +} + +func isGitHubCopilotAccount(account *Account) bool { + if account == nil { + return false + } + if account.Type != AccountTypeAPIKey { + return false + } + if account.Platform == PlatformCopilot { + return true + } + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + if baseURL == "" { + return false + } + return isGitHubCopilotBaseURL(baseURL) +} + +func githubCopilotDefaultEditorVersion() string { + return "vscode/" + githubCopilotDefaultVSCodeVersion +} + +func githubCopilotDefaultEditorPluginVersion() string { + return "copilot-chat/" + githubCopilotDefaultCopilotChatVersion +} + +func githubCopilotDefaultUserAgent() string { + return "GitHubCopilotChat/" + githubCopilotDefaultCopilotChatVersion +} + +func githubCopilotVisionEnabledFromResponsesPayload(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"] + if !ok || input == nil { + return false + } + items, ok := input.([]any) + if !ok { + return false + } + for _, item := range items { + m, ok := item.(map[string]any) + if !ok { + continue + } + if t, _ := m["type"].(string); t == "input_image" { + return true + } + content, ok := m["content"] + if !ok || content == nil { + continue + } + blocks, ok := content.([]any) + if !ok { + continue + } + for _, block := range blocks { + bm, ok := block.(map[string]any) + if !ok { + continue + } + bt, _ := bm["type"].(string) + if bt == "input_image" || bt == "image_url" { + return true + } + } + } + return false +} + +func githubCopilotInitiatorFromResponsesPayload(reqBody map[string]any) string { + if reqBody == nil { + return "user" + } + input, ok := reqBody["input"] + if !ok || input == nil { + return "user" + } + items, ok := input.([]any) + if !ok || len(items) == 0 { + return "user" + } + last, ok := items[len(items)-1].(map[string]any) + if !ok { + return "user" + } + if t, _ := last["type"].(string); t != "" { + switch t { + case "function_call", "function_call_output": + return "agent" + case "message": + } + } + role, _ := last["role"].(string) + switch role { + case "assistant", "tool": + return "agent" + default: + return "user" + } +} + +func githubCopilotVisionEnabledFromClaudeMessagesPayload(messages []any) bool { + if len(messages) == 0 { + return false + } + for _, msg := range messages { + m, ok := msg.(map[string]any) + if !ok { + continue + } + content, ok := m["content"] + if !ok || content == nil { + continue + } + blocks, ok := content.([]any) + if !ok { + continue + } + for _, block := range blocks { + bm, ok := block.(map[string]any) + if !ok { + continue + } + bt, _ := bm["type"].(string) + switch bt { + case "image", "image_url", "input_image": + return true + } + } + } + return false +} + +func githubCopilotInitiatorFromClaudeMessagesPayload(messages []any) string { + if len(messages) == 0 { + return "user" + } + last, ok := messages[len(messages)-1].(map[string]any) + if !ok { + return "user" + } + role, _ := last["role"].(string) + if role != "user" { + return "agent" + } + content, ok := last["content"] + if !ok || content == nil { + return "user" + } + blocks, ok := content.([]any) + if !ok { + return "user" + } + for _, block := range blocks { + bm, ok := block.(map[string]any) + if !ok { + return "user" + } + bt, _ := bm["type"].(string) + if bt != "tool_result" { + return "user" + } + } + return "agent" +} + +func IsGitHubCopilotBaseURL(raw string) bool { + return isGitHubCopilotBaseURL(raw) +} + +func IsGitHubCopilotAccount(account *Account) bool { + return isGitHubCopilotAccount(account) +} + +func githubCopilotModelsURLFromBaseURL(normalizedBaseURL string) string { + base := strings.TrimRight(strings.TrimSpace(normalizedBaseURL), "/") + if strings.HasSuffix(base, "/models") { + if strings.HasSuffix(base, "/v1/models") { + base = strings.TrimSuffix(base, "/v1/models") + base = strings.TrimRight(base, "/") + return base + "/models" + } + return base + } + if strings.HasSuffix(base, "/v1") { + base = strings.TrimSuffix(base, "/v1") + base = strings.TrimRight(base, "/") + } + return base + "/models" +} + +func applyGitHubCopilotHeaders(req *http.Request, vision bool, initiator string) { + if req == nil { + return + } + if initiator != "agent" { + initiator = "user" + } + req.Header.Set("copilot-integration-id", githubCopilotDefaultIntegrationID) + req.Header.Set("editor-version", githubCopilotDefaultEditorVersion()) + req.Header.Set("editor-plugin-version", githubCopilotDefaultEditorPluginVersion()) + req.Header.Set("user-agent", githubCopilotDefaultUserAgent()) + req.Header.Set("openai-intent", githubCopilotDefaultOpenAIIntent) + req.Header.Set("x-github-api-version", githubCopilotDefaultGitHubAPIVersionHeaderDate) + req.Header.Set("x-request-id", uuid.NewString()) + req.Header.Set("x-vscode-user-agent-library-version", githubCopilotDefaultUserAgentLibraryVersion) + req.Header.Set("X-Initiator", initiator) + if vision { + req.Header.Set("copilot-vision-request", "true") + } +} + +func applyGitHubCopilotTokenExchangeHeaders(req *http.Request, githubToken string) { + if req == nil { + return + } + gh := strings.TrimSpace(githubToken) + if gh != "" { + req.Header.Set("authorization", "token "+gh) + } + if strings.TrimSpace(req.Header.Get("accept")) == "" { + req.Header.Set("accept", "application/json") + } + req.Header.Set("copilot-integration-id", githubCopilotDefaultIntegrationID) + req.Header.Set("editor-version", githubCopilotDefaultEditorVersion()) + req.Header.Set("editor-plugin-version", githubCopilotDefaultEditorPluginVersion()) + req.Header.Set("user-agent", githubCopilotDefaultUserAgent()) + req.Header.Set("x-github-api-version", githubCopilotDefaultGitHubAPIVersionHeaderDate) + req.Header.Set("x-vscode-user-agent-library-version", githubCopilotDefaultUserAgentLibraryVersion) +} diff --git a/backend/internal/service/github_copilot_models.go b/backend/internal/service/github_copilot_models.go new file mode 100644 index 0000000000..a854c12579 --- /dev/null +++ b/backend/internal/service/github_copilot_models.go @@ -0,0 +1,26 @@ +package service + +import ( + "context" + "errors" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" +) + +func (s *GatewayService) ListGitHubCopilotModels(ctx context.Context, groupID *int64) ([]openai.Model, error) { + if s == nil { + return nil, errors.New("gateway service is nil") + } + if s.githubCopilotToken == nil { + return nil, errors.New("github copilot token provider not configured") + } + accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformCopilot, false) + if err != nil { + return nil, err + } + if len(accounts) == 0 { + return nil, errors.New("no available copilot accounts") + } + acc := &accounts[0] + return s.githubCopilotToken.ListModels(ctx, acc) +} diff --git a/backend/internal/service/github_copilot_token_provider.go b/backend/internal/service/github_copilot_token_provider.go new file mode 100644 index 0000000000..07703cfcf3 --- /dev/null +++ b/backend/internal/service/github_copilot_token_provider.go @@ -0,0 +1,308 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "sort" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" +) + +const ( + githubCopilotTokenExchangeURL = "https://api.github.com/copilot_internal/v2/token" + + githubCopilotTokenMinTTL = 30 * time.Second + githubCopilotTokenSkew = time.Minute + githubCopilotTokenLockTTL = 30 * time.Second + githubCopilotTokenLockWait = 200 * time.Millisecond + githubCopilotTokenMaxBodyLen = 2 << 20 +) + +type GitHubCopilotTokenProvider struct { + tokenCache GeminiTokenCache + httpUpstream HTTPUpstream +} + +func NewGitHubCopilotTokenProvider(tokenCache GeminiTokenCache, httpUpstream HTTPUpstream) *GitHubCopilotTokenProvider { + return &GitHubCopilotTokenProvider{ + tokenCache: tokenCache, + httpUpstream: httpUpstream, + } +} + +func (p *GitHubCopilotTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + if !isGitHubCopilotAccount(account) { + return "", errors.New("not a github copilot apikey account") + } + + githubToken := strings.TrimSpace(account.GetCredential("github_token")) + if githubToken == "" { + githubToken = strings.TrimSpace(account.GetCredential("gh_token")) + } + if githubToken == "" { + return "", errors.New("github_token not found in credentials") + } + + cacheKey := GitHubCopilotTokenCacheKey(account) + + if p.tokenCache != nil { + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + slog.Debug("github_copilot_token_cache_hit", "account_id", account.ID) + return token, nil + } else if err != nil { + slog.Warn("github_copilot_token_cache_get_failed", "account_id", account.ID, "error", err) + } + } + + slog.Debug("github_copilot_token_cache_miss", "account_id", account.ID) + + if p.tokenCache == nil { + return p.exchangeCopilotToken(ctx, account, githubToken) + } + + locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, githubCopilotTokenLockTTL) + if lockErr == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + return p.exchangeAndCacheCopilotToken(ctx, account, githubToken, cacheKey) + } + if lockErr != nil { + slog.Warn("github_copilot_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr) + return p.exchangeAndCacheCopilotToken(ctx, account, githubToken, cacheKey) + } + + timer := time.NewTimer(githubCopilotTokenLockWait) + defer timer.Stop() + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-timer.C: + } + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + slog.Debug("github_copilot_token_cache_hit_after_wait", "account_id", account.ID) + return token, nil + } + + return p.exchangeAndCacheCopilotToken(ctx, account, githubToken, cacheKey) +} + +func (p *GitHubCopilotTokenProvider) Invalidate(ctx context.Context, account *Account) { + if p == nil || p.tokenCache == nil || account == nil { + return + } + _ = p.tokenCache.DeleteAccessToken(ctx, GitHubCopilotTokenCacheKey(account)) +} + +func (p *GitHubCopilotTokenProvider) ListModels(ctx context.Context, account *Account) ([]openai.Model, error) { + if p == nil { + return nil, errors.New("github copilot token provider is nil") + } + if account == nil { + return nil, errors.New("account is nil") + } + if !isGitHubCopilotAccount(account) { + return nil, errors.New("not a github copilot apikey account") + } + if p.httpUpstream == nil { + return nil, errors.New("http upstream is nil") + } + + token, err := p.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + if baseURL == "" { + baseURL = "https://api.githubcopilot.com" + } + normalizedBaseURL, err := urlvalidator.ValidateHTTPSURL(baseURL, urlvalidator.ValidationOptions{ + AllowedHosts: []string{"api.githubcopilot.com", "*.githubcopilot.com"}, + RequireAllowlist: true, + AllowPrivate: false, + }) + if err != nil { + return nil, fmt.Errorf("invalid base_url: %w", err) + } + + modelsURL := githubCopilotModelsURLFromBaseURL(normalizedBaseURL) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil) + if err != nil { + return nil, err + } + applyGitHubCopilotHeaders(req, false, "user") + req.Header.Set("authorization", "Bearer "+strings.TrimSpace(token)) + if strings.TrimSpace(req.Header.Get("accept")) == "" { + req.Header.Set("accept", "application/json") + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := p.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, fmt.Errorf("copilot models request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, githubCopilotTokenMaxBodyLen)) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + msg := strings.TrimSpace(ExtractUpstreamErrorMessage(body)) + msg = sanitizeUpstreamErrorMessage(msg) + if msg == "" { + msg = strings.TrimSpace(string(body)) + msg = sanitizeUpstreamErrorMessage(msg) + } + if msg == "" { + msg = "models request failed" + } + return nil, fmt.Errorf("copilot models request failed: status=%d message=%s", resp.StatusCode, msg) + } + + type copilotModel struct { + ID string `json:"id"` + Name string `json:"name"` + } + var parsed struct { + Data []copilotModel `json:"data"` + Models []copilotModel `json:"models"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("parse copilot models response: %w", err) + } + src := parsed.Data + if len(src) == 0 { + src = parsed.Models + } + if len(src) == 0 { + return nil, errors.New("copilot models response is empty") + } + + seen := make(map[string]struct{}, len(src)) + result := make([]openai.Model, 0, len(src)) + for _, m := range src { + id := strings.TrimSpace(m.ID) + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + display := strings.TrimSpace(m.Name) + if display == "" { + display = id + } + result = append(result, openai.Model{ID: id, Object: "model", Type: "model", DisplayName: display}) + } + if len(result) == 0 { + return nil, errors.New("copilot models response contained no model ids") + } + + sort.Slice(result, func(i, j int) bool { + return result[i].ID < result[j].ID + }) + return result, nil +} + +func (p *GitHubCopilotTokenProvider) exchangeAndCacheCopilotToken(ctx context.Context, account *Account, githubToken, cacheKey string) (string, error) { + token, ttl, err := p.exchangeCopilotTokenWithTTL(ctx, account, githubToken) + if err != nil { + return "", err + } + if p.tokenCache != nil { + if err := p.tokenCache.SetAccessToken(ctx, cacheKey, token, ttl); err != nil { + slog.Warn("github_copilot_token_cache_set_failed", "account_id", account.ID, "error", err) + } + } + return token, nil +} + +func (p *GitHubCopilotTokenProvider) exchangeCopilotToken(ctx context.Context, account *Account, githubToken string) (string, error) { + token, _, err := p.exchangeCopilotTokenWithTTL(ctx, account, githubToken) + return token, err +} + +func (p *GitHubCopilotTokenProvider) exchangeCopilotTokenWithTTL(ctx context.Context, account *Account, githubToken string) (string, time.Duration, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, githubCopilotTokenExchangeURL, nil) + if err != nil { + return "", 0, err + } + applyGitHubCopilotTokenExchangeHeaders(req, githubToken) + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := p.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return "", 0, fmt.Errorf("copilot token exchange request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, githubCopilotTokenMaxBodyLen)) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + msg := strings.TrimSpace(ExtractUpstreamErrorMessage(body)) + msg = sanitizeUpstreamErrorMessage(msg) + if msg == "" { + msg = strings.TrimSpace(string(body)) + msg = sanitizeUpstreamErrorMessage(msg) + } + if msg == "" { + msg = "token exchange failed" + } + return "", 0, fmt.Errorf("copilot token exchange failed: status=%d message=%s", resp.StatusCode, msg) + } + + var parsed struct { + ExpiresAt int64 `json:"expires_at"` + RefreshIn int64 `json:"refresh_in"` + Token string `json:"token"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return "", 0, fmt.Errorf("parse copilot token exchange response: %w", err) + } + token := strings.TrimSpace(parsed.Token) + if token == "" { + return "", 0, errors.New("copilot token is empty") + } + + ttl := githubCopilotTokenTTL(time.Now(), parsed.ExpiresAt, parsed.RefreshIn) + if ttl < githubCopilotTokenMinTTL { + ttl = githubCopilotTokenMinTTL + } + return token, ttl, nil +} + +func githubCopilotTokenTTL(now time.Time, expiresAtSec, refreshInSec int64) time.Duration { + if refreshInSec > 0 { + ttl := time.Duration(refreshInSec)*time.Second - githubCopilotTokenSkew + if ttl > 0 { + return ttl + } + } + if expiresAtSec > 0 { + expiresAt := time.Unix(expiresAtSec, 0) + ttl := expiresAt.Sub(now) - githubCopilotTokenSkew + if ttl > 0 { + return ttl + } + } + return 10 * time.Minute +} diff --git a/backend/internal/service/github_device_auth_service.go b/backend/internal/service/github_device_auth_service.go new file mode 100644 index 0000000000..61d65ee1ea --- /dev/null +++ b/backend/internal/service/github_device_auth_service.go @@ -0,0 +1,306 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +const ( + gitHubDeviceCodeURL = "https://github.com/login/device/code" + gitHubAccessTokenURL = "https://github.com/login/oauth/access_token" + gitHubDeviceGrantType = "urn:ietf:params:oauth:grant-type:device_code" + + gitHubCopilotDefaultDeviceClientID = "Iv1.b507a08c87ecfe98" + gitHubCopilotDefaultDeviceScope = "read:user" + gitHubDeviceAuthSlowDownIncrement = 5 * time.Second + + gitHubDeviceAuthMaxBodyLen = 2 << 20 +) + +type GitHubDeviceAuthService struct { + httpUpstream HTTPUpstream + store GitHubDeviceSessionStore +} + +func NewGitHubDeviceAuthService(store GitHubDeviceSessionStore, httpUpstream HTTPUpstream) *GitHubDeviceAuthService { + if store == nil { + store = NewInMemoryGitHubDeviceSessionStore() + } + return &GitHubDeviceAuthService{ + httpUpstream: httpUpstream, + store: store, + } +} + +type GitHubDeviceAuthStartResult struct { + SessionID string `json:"session_id"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete,omitempty"` + ExpiresIn int64 `json:"expires_in"` + IntervalSeconds int64 `json:"interval"` +} + +type GitHubDeviceAuthPollResult struct { + Status string `json:"status"` + IntervalSeconds int64 `json:"interval,omitempty"` + AccessToken string `json:"access_token,omitempty"` + TokenType string `json:"token_type,omitempty"` + Scope string `json:"scope,omitempty"` + Error string `json:"error,omitempty"` + ErrorDesc string `json:"error_description,omitempty"` +} + +func (s *GitHubDeviceAuthService) Start(ctx context.Context, account *Account, clientID string, scope string) (*GitHubDeviceAuthStartResult, error) { + if s == nil || s.httpUpstream == nil { + return nil, errors.New("github device auth service not configured") + } + if account == nil { + return nil, errors.New("account is nil") + } + + clientID = strings.TrimSpace(clientID) + if clientID == "" { + clientID = gitHubCopilotDefaultDeviceClientID + } + scope = strings.TrimSpace(scope) + if scope == "" { + scope = gitHubCopilotDefaultDeviceScope + } + + form := url.Values{} + form.Set("client_id", clientID) + form.Set("scope", scope) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, gitHubDeviceCodeURL, strings.NewReader(form.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("content-type", "application/x-www-form-urlencoded") + req.Header.Set("accept", "application/json") + if strings.TrimSpace(req.Header.Get("user-agent")) == "" { + req.Header.Set("user-agent", githubCopilotDefaultUserAgent()) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, fmt.Errorf("github device code request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, gitHubDeviceAuthMaxBodyLen)) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("github device code request failed: status=%d body=%s", resp.StatusCode, sanitizeUpstreamErrorMessage(string(body))) + } + + var parsed struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + ExpiresIn int64 `json:"expires_in"` + Interval int64 `json:"interval"` + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("parse github device code response: %w", err) + } + if strings.TrimSpace(parsed.Error) != "" { + msg := strings.TrimSpace(parsed.ErrorDescription) + if msg == "" { + msg = parsed.Error + } + return nil, fmt.Errorf("github device code request failed: %s", sanitizeUpstreamErrorMessage(msg)) + } + if strings.TrimSpace(parsed.DeviceCode) == "" || strings.TrimSpace(parsed.UserCode) == "" || strings.TrimSpace(parsed.VerificationURI) == "" { + return nil, errors.New("github device code response is incomplete") + } + if parsed.ExpiresIn <= 0 { + parsed.ExpiresIn = 900 + } + if parsed.Interval <= 0 { + parsed.Interval = 5 + } + + sessionID, err := newSessionID() + if err != nil { + return nil, err + } + createdAt := time.Now() + expiresAt := createdAt.Add(time.Duration(parsed.ExpiresIn) * time.Second) + if err := s.store.Set(ctx, sessionID, &GitHubDeviceSession{ + AccountID: account.ID, + AccountConcurrency: account.Concurrency, + ProxyURL: proxyURL, + ClientID: clientID, + Scope: scope, + DeviceCode: parsed.DeviceCode, + ExpiresAtUnix: expiresAt.Unix(), + IntervalSeconds: parsed.Interval, + CreatedAtUnix: createdAt.Unix(), + }, time.Until(expiresAt)); err != nil { + return nil, fmt.Errorf("persist github device session failed: %w", err) + } + + return &GitHubDeviceAuthStartResult{ + SessionID: sessionID, + UserCode: parsed.UserCode, + VerificationURI: parsed.VerificationURI, + VerificationURIComplete: strings.TrimSpace(parsed.VerificationURIComplete), + ExpiresIn: parsed.ExpiresIn, + IntervalSeconds: parsed.Interval, + }, nil +} + +func (s *GitHubDeviceAuthService) Poll(ctx context.Context, accountID int64, sessionID string) (*GitHubDeviceAuthPollResult, error) { + if s == nil || s.httpUpstream == nil { + return nil, errors.New("github device auth service not configured") + } + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return nil, errors.New("session_id is required") + } + + sess, ok, err := s.store.Get(ctx, sessionID) + if err != nil { + return nil, fmt.Errorf("load device auth session failed: %w", err) + } + if !ok || sess == nil { + return nil, errors.New("device auth session not found") + } + if sess.AccountID != accountID { + return nil, errors.New("device auth session does not belong to this account") + } + now := time.Now() + expiresAt := time.Unix(sess.ExpiresAtUnix, 0) + if now.After(expiresAt) { + _ = s.store.Delete(ctx, sessionID) + return &GitHubDeviceAuthPollResult{Status: "error", Error: "expired_token", ErrorDesc: "device code expired"}, nil + } + + form := url.Values{} + form.Set("client_id", sess.ClientID) + form.Set("device_code", sess.DeviceCode) + form.Set("grant_type", gitHubDeviceGrantType) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, gitHubAccessTokenURL, strings.NewReader(form.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("content-type", "application/x-www-form-urlencoded") + req.Header.Set("accept", "application/json") + if strings.TrimSpace(req.Header.Get("user-agent")) == "" { + req.Header.Set("user-agent", githubCopilotDefaultUserAgent()) + } + + resp, err := s.httpUpstream.Do(req, sess.ProxyURL, sess.AccountID, sess.AccountConcurrency) + if err != nil { + return nil, fmt.Errorf("github device token poll failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, gitHubDeviceAuthMaxBodyLen)) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("github device token poll failed: status=%d body=%s", resp.StatusCode, sanitizeUpstreamErrorMessage(string(body))) + } + + var parsed struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + Interval int64 `json:"interval"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("parse github device token response: %w", err) + } + + if strings.TrimSpace(parsed.AccessToken) != "" { + _ = s.store.Delete(ctx, sessionID) + return &GitHubDeviceAuthPollResult{ + Status: "success", + AccessToken: strings.TrimSpace(parsed.AccessToken), + TokenType: strings.TrimSpace(parsed.TokenType), + Scope: strings.TrimSpace(parsed.Scope), + }, nil + } + + errCode := strings.TrimSpace(parsed.Error) + if errCode == "" { + return &GitHubDeviceAuthPollResult{Status: "error", Error: "unknown_error", ErrorDesc: "unexpected response"}, nil + } + + if errCode == "authorization_pending" { + return &GitHubDeviceAuthPollResult{Status: "pending", IntervalSeconds: sess.IntervalSeconds}, nil + } + if errCode == "slow_down" { + sess.IntervalSeconds = sess.IntervalSeconds + int64(gitHubDeviceAuthSlowDownIncrement/time.Second) + if parsed.Interval > 0 { + sess.IntervalSeconds = parsed.Interval + } + if err := s.store.Set(ctx, sessionID, sess, time.Until(expiresAt)); err != nil { + return nil, fmt.Errorf("persist slow_down interval failed: %w", err) + } + return &GitHubDeviceAuthPollResult{Status: "pending", IntervalSeconds: sess.IntervalSeconds, Error: errCode, ErrorDesc: strings.TrimSpace(parsed.ErrorDescription)}, nil + } + if errCode == "expired_token" || errCode == "access_denied" { + _ = s.store.Delete(ctx, sessionID) + return &GitHubDeviceAuthPollResult{Status: "error", Error: errCode, ErrorDesc: strings.TrimSpace(parsed.ErrorDescription)}, nil + } + + return &GitHubDeviceAuthPollResult{Status: "error", Error: errCode, ErrorDesc: strings.TrimSpace(parsed.ErrorDescription)}, nil +} + +func (s *GitHubDeviceAuthService) Cancel(ctx context.Context, accountID int64, sessionID string) bool { + if s == nil { + return false + } + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return false + } + sess, ok, err := s.store.Get(ctx, sessionID) + if err != nil { + return false + } + if !ok || sess == nil { + return false + } + now := time.Now() + expiresAt := time.Unix(sess.ExpiresAtUnix, 0) + if now.After(expiresAt) { + _ = s.store.Delete(ctx, sessionID) + return false + } + if sess.AccountID != accountID { + return false + } + if err := s.store.Delete(ctx, sessionID); err != nil { + return false + } + return true +} + +func newSessionID() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} diff --git a/backend/internal/service/github_device_auth_service_test.go b/backend/internal/service/github_device_auth_service_test.go new file mode 100644 index 0000000000..c5a499263f --- /dev/null +++ b/backend/internal/service/github_device_auth_service_test.go @@ -0,0 +1,83 @@ +//go:build unit + +package service + +import ( + "context" + "io" + "net/http" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +type fakeHTTPUpstream struct { + mu sync.Mutex + pollCount int +} + +func (f *fakeHTTPUpstream) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + f.mu.Lock() + defer f.mu.Unlock() + + switch req.URL.String() { + case gitHubDeviceCodeURL: + body := `{"device_code":"dc1","user_code":"uc1","verification_uri":"https://github.com/login/device","verification_uri_complete":"https://github.com/login/device?user_code=uc1","expires_in":900,"interval":5}` + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(body))}, nil + case gitHubAccessTokenURL: + f.pollCount++ + if f.pollCount == 1 { + body := `{"error":"authorization_pending","error_description":"pending"}` + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(body))}, nil + } + if f.pollCount == 2 { + body := `{"error":"slow_down","error_description":"slow","interval":10}` + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(body))}, nil + } + body := `{"access_token":"gho_xxx","token_type":"bearer","scope":"read:user"}` + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(body))}, nil + default: + return &http.Response{StatusCode: 404, Body: io.NopCloser(strings.NewReader(`{"error":"not_found"}`))}, nil + } +} + +func (f *fakeHTTPUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, _ bool) (*http.Response, error) { + return f.Do(req, proxyURL, accountID, accountConcurrency) +} + +func TestGitHubDeviceAuthService_StartPollFlow(t *testing.T) { + store := NewInMemoryGitHubDeviceSessionStore() + upstream := &fakeHTTPUpstream{} + svc := NewGitHubDeviceAuthService(store, upstream) + account := &Account{ID: 123, Type: AccountTypeAPIKey, Concurrency: 3} + + start, err := svc.Start(context.Background(), account, "", "") + require.NoError(t, err) + require.NotNil(t, start) + require.NotEmpty(t, start.SessionID) + require.Equal(t, int64(900), start.ExpiresIn) + require.Equal(t, int64(5), start.IntervalSeconds) + require.Equal(t, "uc1", start.UserCode) + + res, err := svc.Poll(context.Background(), account.ID, start.SessionID) + require.NoError(t, err) + require.Equal(t, "pending", res.Status) + require.Equal(t, int64(5), res.IntervalSeconds) + + res, err = svc.Poll(context.Background(), account.ID, start.SessionID) + require.NoError(t, err) + require.Equal(t, "pending", res.Status) + require.Equal(t, int64(10), res.IntervalSeconds) + require.Equal(t, "slow_down", res.Error) + + res, err = svc.Poll(context.Background(), account.ID, start.SessionID) + require.NoError(t, err) + require.Equal(t, "success", res.Status) + require.Equal(t, "gho_xxx", res.AccessToken) + + _, ok, err := store.Get(context.Background(), start.SessionID) + require.NoError(t, err) + require.False(t, ok) +} diff --git a/backend/internal/service/github_device_session_store.go b/backend/internal/service/github_device_session_store.go new file mode 100644 index 0000000000..c088a8ebbb --- /dev/null +++ b/backend/internal/service/github_device_session_store.go @@ -0,0 +1,72 @@ +package service + +import ( + "context" + "sync" + "time" +) + +type GitHubDeviceSession struct { + AccountID int64 `json:"account_id"` + AccountConcurrency int `json:"account_concurrency"` + ProxyURL string `json:"proxy_url"` + ClientID string `json:"client_id"` + Scope string `json:"scope"` + DeviceCode string `json:"device_code"` + + ExpiresAtUnix int64 `json:"expires_at_unix"` + IntervalSeconds int64 `json:"interval_seconds"` + CreatedAtUnix int64 `json:"created_at_unix"` +} + +type GitHubDeviceSessionStore interface { + Get(ctx context.Context, id string) (*GitHubDeviceSession, bool, error) + Set(ctx context.Context, id string, sess *GitHubDeviceSession, ttl time.Duration) error + Delete(ctx context.Context, id string) error +} + +type inMemoryGitHubDeviceSessionStore struct { + mu sync.Mutex + sessions map[string]inMemoryGitHubDeviceSession +} + +type inMemoryGitHubDeviceSession struct { + sess *GitHubDeviceSession + expiresAt time.Time +} + +func NewInMemoryGitHubDeviceSessionStore() GitHubDeviceSessionStore { + return &inMemoryGitHubDeviceSessionStore{sessions: make(map[string]inMemoryGitHubDeviceSession)} +} + +func (s *inMemoryGitHubDeviceSessionStore) Get(_ context.Context, id string) (*GitHubDeviceSession, bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + entry, ok := s.sessions[id] + if !ok || entry.sess == nil { + return nil, false, nil + } + if !entry.expiresAt.IsZero() && time.Now().After(entry.expiresAt) { + delete(s.sessions, id) + return nil, false, nil + } + return entry.sess, true, nil +} + +func (s *inMemoryGitHubDeviceSessionStore) Set(_ context.Context, id string, sess *GitHubDeviceSession, ttl time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + if ttl <= 0 || sess == nil { + delete(s.sessions, id) + return nil + } + s.sessions[id] = inMemoryGitHubDeviceSession{sess: sess, expiresAt: time.Now().Add(ttl)} + return nil +} + +func (s *inMemoryGitHubDeviceSessionStore) Delete(_ context.Context, id string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, id) + return nil +} diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 22a67edac8..1386a5c7b7 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -25,6 +25,7 @@ type GroupRepository interface { ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) ListActive(ctx context.Context) ([]Group, error) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) + ListPublicGroupIDs(ctx context.Context) ([]int64, error) ExistsByName(ctx context.Context, name string) (bool, error) GetAccountCount(ctx context.Context, groupID int64) (int64, error) diff --git a/backend/internal/service/model_namespace.go b/backend/internal/service/model_namespace.go new file mode 100644 index 0000000000..1504bfec01 --- /dev/null +++ b/backend/internal/service/model_namespace.go @@ -0,0 +1,121 @@ +package service + +import ( + "strings" + + "github.com/Wei-Shaw/sub2api/internal/domain" +) + +type ModelNamespace struct { + Provider string + Platform string + Model string + HasNamespace bool +} + +func ParseModelNamespace(modelID string) ModelNamespace { + trimmed := strings.TrimSpace(modelID) + if trimmed == "" { + return ModelNamespace{} + } + + idx := strings.Index(trimmed, "/") + if idx <= 0 || idx == len(trimmed)-1 { + return inferFromModelName(trimmed) + } + + prefix := strings.ToLower(strings.TrimSpace(trimmed[:idx])) + rest := strings.TrimSpace(trimmed[idx+1:]) + if rest == "" { + return ModelNamespace{Model: trimmed} + } + + provider := normalizeNamespaceProvider(prefix) + if provider == "" { + return inferFromModelName(trimmed) + } + + platform := domain.GetPlatformFromProvider(provider) + if platform == "" { + platform = inferPlatformFromModel(rest) + } + + return ModelNamespace{ + Provider: provider, + Platform: platform, + Model: rest, + HasNamespace: true, + } +} + +func inferFromModelName(modelID string) ModelNamespace { + platform := inferPlatformFromModel(modelID) + var provider string + switch platform { + case domain.PlatformAnthropic: + provider = domain.ProviderAnthropic + case domain.PlatformGemini: + provider = domain.ProviderGemini + case domain.PlatformOpenAI: + provider = domain.ProviderOpenAI + } + return ModelNamespace{ + Provider: provider, + Platform: platform, + Model: modelID, + } +} + +func inferPlatformFromModel(modelID string) string { + m := strings.ToLower(modelID) + switch { + case IsClaudeModelID(m): + return domain.PlatformAnthropic + case IsGeminiModelID(m): + return domain.PlatformGemini + default: + return domain.PlatformOpenAI + } +} + +func normalizeNamespaceProvider(prefix string) string { + p := strings.ToLower(strings.TrimSpace(prefix)) + switch p { + case domain.ProviderOpenAI, + domain.ProviderAzure, + domain.ProviderCopilot, + domain.ProviderAnthropic, + domain.ProviderGemini, + domain.ProviderVertexAI, + domain.ProviderAntigravity, + domain.ProviderBedrock, + domain.ProviderOpenRouter, + domain.ProviderAggregator: + return p + case "claude": + return domain.ProviderAnthropic + case "vertexai", "vertex-ai": + return domain.ProviderVertexAI + case "github", "github_copilot": + return domain.ProviderCopilot + default: + return "" + } +} + +func IsClaudeModelID(modelID string) bool { + m := strings.ToLower(strings.TrimSpace(modelID)) + return strings.HasPrefix(m, "claude-") || strings.HasPrefix(m, "claude_") +} + +func IsGeminiModelID(modelID string) bool { + m := strings.ToLower(strings.TrimSpace(modelID)) + return strings.HasPrefix(m, "gemini-") || strings.HasPrefix(m, "gemini_") +} + +func (n ModelNamespace) NamespacedModel() string { + if n.HasNamespace && n.Provider != "" { + return n.Provider + "/" + n.Model + } + return n.Model +} diff --git a/backend/internal/service/model_namespace_test.go b/backend/internal/service/model_namespace_test.go new file mode 100644 index 0000000000..27ca2558f7 --- /dev/null +++ b/backend/internal/service/model_namespace_test.go @@ -0,0 +1,211 @@ +package service + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/domain" + "github.com/stretchr/testify/assert" +) + +func TestParseModelNamespace(t *testing.T) { + tests := []struct { + name string + input string + expected ModelNamespace + }{ + { + name: "simple model name", + input: "gpt-4o", + expected: ModelNamespace{ + Provider: ProviderOpenAI, + Platform: PlatformOpenAI, + Model: "gpt-4o", + }, + }, + { + name: "claude model", + input: "claude-sonnet-4-5", + expected: ModelNamespace{ + Provider: ProviderAnthropic, + Platform: PlatformAnthropic, + Model: "claude-sonnet-4-5", + }, + }, + { + name: "gemini model", + input: "gemini-2.5-flash", + expected: ModelNamespace{ + Provider: ProviderGemini, + Platform: PlatformGemini, + Model: "gemini-2.5-flash", + }, + }, + { + name: "provider namespace - openai", + input: "openai/gpt-5.2", + expected: ModelNamespace{ + Provider: ProviderOpenAI, + Platform: PlatformOpenAI, + Model: "gpt-5.2", + HasNamespace: true, + }, + }, + { + name: "provider namespace - copilot", + input: "copilot/gpt-5.2", + expected: ModelNamespace{ + Provider: ProviderCopilot, + Platform: PlatformCopilot, + Model: "gpt-5.2", + HasNamespace: true, + }, + }, + { + name: "provider namespace - aggregator", + input: "aggregator/gpt-5.2", + expected: ModelNamespace{ + Provider: ProviderAggregator, + Platform: PlatformAggregator, + Model: "gpt-5.2", + HasNamespace: true, + }, + }, + { + name: "provider namespace - azure", + input: "azure/gpt-5.2", + expected: ModelNamespace{ + Provider: ProviderAzure, + Platform: PlatformOpenAI, + Model: "gpt-5.2", + HasNamespace: true, + }, + }, + { + name: "provider namespace - anthropic", + input: "anthropic/claude-sonnet-4-5", + expected: ModelNamespace{ + Provider: ProviderAnthropic, + Platform: PlatformAnthropic, + Model: "claude-sonnet-4-5", + HasNamespace: true, + }, + }, + { + name: "provider alias - claude", + input: "claude/claude-sonnet-4-5", + expected: ModelNamespace{ + Provider: ProviderAnthropic, + Platform: PlatformAnthropic, + Model: "claude-sonnet-4-5", + HasNamespace: true, + }, + }, + { + name: "provider alias - github", + input: "github/gpt-4o", + expected: ModelNamespace{ + Provider: ProviderCopilot, + Platform: PlatformCopilot, + Model: "gpt-4o", + HasNamespace: true, + }, + }, + { + name: "empty string", + input: "", + expected: ModelNamespace{ + Provider: "", + Platform: "", + Model: "", + }, + }, + { + name: "whitespace only", + input: " ", + expected: ModelNamespace{ + Provider: "", + Platform: "", + Model: "", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParseModelNamespace(tt.input) + assert.Equal(t, tt.expected.Provider, result.Provider) + assert.Equal(t, tt.expected.Platform, result.Platform) + assert.Equal(t, tt.expected.Model, result.Model) + assert.Equal(t, tt.expected.HasNamespace, result.HasNamespace) + }) + } +} + +func TestModelNamespace_NamespacedModel(t *testing.T) { + tests := []struct { + name string + ns ModelNamespace + expected string + }{ + { + name: "with namespace", + ns: ModelNamespace{ + Provider: ProviderOpenAI, + Model: "gpt-4o", + HasNamespace: true, + }, + expected: "openai/gpt-4o", + }, + { + name: "without namespace", + ns: ModelNamespace{ + Provider: ProviderOpenAI, + Model: "gpt-4o", + }, + expected: "gpt-4o", + }, + { + name: "empty provider", + ns: ModelNamespace{ + Model: "gpt-4o", + HasNamespace: true, + }, + expected: "gpt-4o", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.ns.NamespacedModel()) + }) + } +} + +func TestProviderToPlatform(t *testing.T) { + tests := []struct { + provider string + expectedOK bool + platform string + }{ + {domain.ProviderOpenAI, true, domain.PlatformOpenAI}, + {domain.ProviderAzure, true, domain.PlatformOpenAI}, + {domain.ProviderCopilot, true, domain.PlatformCopilot}, + {domain.ProviderAggregator, true, domain.PlatformAggregator}, + {domain.ProviderAnthropic, true, domain.PlatformAnthropic}, + {domain.ProviderGemini, true, domain.PlatformGemini}, + {domain.ProviderVertexAI, true, domain.PlatformGemini}, + {domain.ProviderAntigravity, true, domain.PlatformAntigravity}, + {"unknown", false, ""}, + } + + for _, tt := range tests { + t.Run(tt.provider, func(t *testing.T) { + platform := domain.GetPlatformFromProvider(tt.provider) + if tt.expectedOK { + assert.Equal(t, tt.platform, platform) + } else { + assert.Empty(t, platform) + } + }) + } +} diff --git a/backend/internal/service/model_provider.go b/backend/internal/service/model_provider.go new file mode 100644 index 0000000000..d885dcf524 --- /dev/null +++ b/backend/internal/service/model_provider.go @@ -0,0 +1,57 @@ +package service + +import ( + "net/url" + "strings" +) + +func inferProviderFromAccount(acc *Account) string { + if acc == nil { + return "" + } + if isGitHubCopilotAccount(acc) { + return ProviderCopilot + } + + platform := strings.ToLower(strings.TrimSpace(acc.Platform)) + switch platform { + case PlatformCopilot: + return ProviderCopilot + case PlatformAggregator: + return ProviderAggregator + case PlatformAntigravity: + return ProviderAntigravity + case PlatformAnthropic: + return ProviderAnthropic + case PlatformGemini: + return ProviderGemini + } + + host := normalizedHostname(acc.GetCredential("base_url")) + if host != "" { + if strings.HasSuffix(host, ".openai.azure.com") { + return ProviderAzure + } + if host == "openrouter.ai" || strings.HasSuffix(host, ".openrouter.ai") { + return ProviderOpenRouter + } + } + + return ProviderOpenAI +} + +func normalizedHostname(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + candidate := trimmed + if !strings.Contains(candidate, "://") { + candidate = "https://" + candidate + } + parsed, err := url.Parse(candidate) + if err != nil { + return "" + } + return strings.ToLower(strings.TrimSpace(parsed.Hostname())) +} diff --git a/backend/internal/service/model_provider_test.go b/backend/internal/service/model_provider_test.go new file mode 100644 index 0000000000..da2a24f2f7 --- /dev/null +++ b/backend/internal/service/model_provider_test.go @@ -0,0 +1,24 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestInferProviderFromAccount_OpenAIBaseURL_Azure(t *testing.T) { + acc := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Credentials: map[string]any{"base_url": "https://foo.openai.azure.com"}} + require.Equal(t, ProviderAzure, inferProviderFromAccount(acc)) +} + +func TestInferProviderFromAccount_OpenAIBaseURL_DefaultOpenAI(t *testing.T) { + acc := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Credentials: map[string]any{"base_url": "https://api.openai.com"}} + require.Equal(t, ProviderOpenAI, inferProviderFromAccount(acc)) +} + +func TestInferProviderFromAccount_CopilotPlatform(t *testing.T) { + acc := &Account{Platform: PlatformCopilot, Type: AccountTypeAPIKey} + require.Equal(t, ProviderCopilot, inferProviderFromAccount(acc)) +} diff --git a/backend/internal/service/openai_chat_completions_compat.go b/backend/internal/service/openai_chat_completions_compat.go new file mode 100644 index 0000000000..c1ca906ac9 --- /dev/null +++ b/backend/internal/service/openai_chat_completions_compat.go @@ -0,0 +1,381 @@ +package service + +import ( + "encoding/json" + "errors" + "fmt" + "strings" +) + +func convertClaudeToolsToOpenAIChatTools(tools any) []any { + arr, ok := tools.([]any) + if !ok || len(arr) == 0 { + return nil + } + out := make([]any, 0, len(arr)) + for _, t := range arr { + tm, ok := t.(map[string]any) + if !ok { + continue + } + + toolType, _ := tm["type"].(string) + toolType = strings.TrimSpace(toolType) + + name := "" + desc := "" + params := any(nil) + if toolType == "custom" { + name, _ = tm["name"].(string) + if custom, ok := tm["custom"].(map[string]any); ok { + desc, _ = custom["description"].(string) + params = custom["input_schema"] + } + } else { + name, _ = tm["name"].(string) + desc, _ = tm["description"].(string) + params = tm["input_schema"] + } + + name = strings.TrimSpace(name) + if name == "" { + continue + } + if params == nil { + params = map[string]any{"type": "object", "properties": map[string]any{}} + } + + out = append(out, map[string]any{ + "type": "function", + "function": map[string]any{ + "name": name, + "description": strings.TrimSpace(desc), + "parameters": params, + }, + }) + } + return out +} + +func convertClaudeMessagesToOpenAIChatCompletionsMessages(messages []any, system any) ([]any, error) { + out := make([]any, 0, len(messages)+1) + if systemText := extractClaudeSystemText(system); systemText != "" { + out = append(out, map[string]any{"role": "system", "content": systemText}) + } + + flushMessage := func(role string, sb *strings.Builder, parts []any, usingParts bool) { + role = strings.ToLower(strings.TrimSpace(role)) + if role == "" { + role = "user" + } + + if usingParts { + if sb.Len() > 0 { + parts = append(parts, map[string]any{"type": "text", "text": sb.String()}) + sb.Reset() + } + if len(parts) == 0 { + return + } + out = append(out, map[string]any{"role": role, "content": parts}) + return + } + + text := sb.String() + sb.Reset() + if strings.TrimSpace(text) == "" { + return + } + out = append(out, map[string]any{"role": role, "content": text}) + } + + for _, m := range messages { + mm, ok := m.(map[string]any) + if !ok { + continue + } + role, _ := mm["role"].(string) + role = strings.ToLower(strings.TrimSpace(role)) + if role == "" { + role = "user" + } + + switch content := mm["content"].(type) { + case string: + if strings.TrimSpace(content) != "" { + out = append(out, map[string]any{"role": role, "content": content}) + } + case []any: + var sb strings.Builder + parts := make([]any, 0) + usingParts := false + + appendText := func(text string) { + if usingParts { + parts = append(parts, map[string]any{"type": "text", "text": text}) + return + } + _, _ = sb.WriteString(text) + } + appendImageURL := func(url string) { + if !usingParts { + if sb.Len() > 0 { + parts = append(parts, map[string]any{"type": "text", "text": sb.String()}) + sb.Reset() + } + usingParts = true + } + parts = append(parts, map[string]any{"type": "image_url", "image_url": map[string]any{"url": url}}) + } + + for _, block := range content { + bm, ok := block.(map[string]any) + if !ok { + continue + } + bt, _ := bm["type"].(string) + bt = strings.ToLower(strings.TrimSpace(bt)) + + switch bt { + case "text": + if text, ok := bm["text"].(string); ok { + appendText(text) + } + case "thinking": + if t, ok := bm["thinking"].(string); ok && strings.TrimSpace(t) != "" { + appendText(t) + } + case "image": + if src, ok := bm["source"].(map[string]any); ok { + if srcType, _ := src["type"].(string); srcType == "base64" { + mediaType, _ := src["media_type"].(string) + data, _ := src["data"].(string) + mediaType = strings.TrimSpace(mediaType) + data = strings.TrimSpace(data) + if mediaType != "" && data != "" { + appendImageURL(fmt.Sprintf("data:%s;base64,%s", mediaType, data)) + } + } + } + case "tool_use": + flushMessage(role, &sb, parts, usingParts) + sb.Reset() + parts = make([]any, 0) + usingParts = false + + id, _ := bm["id"].(string) + name, _ := bm["name"].(string) + id = strings.TrimSpace(id) + name = strings.TrimSpace(name) + if id == "" { + id = "call_" + randomHex(12) + } + argsJSON, _ := json.Marshal(bm["input"]) + out = append(out, map[string]any{ + "role": "assistant", + "content": nil, + "tool_calls": []any{ + map[string]any{ + "id": id, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": string(argsJSON), + }, + }, + }, + }) + case "tool_result": + flushMessage(role, &sb, parts, usingParts) + sb.Reset() + parts = make([]any, 0) + usingParts = false + + toolUseID, _ := bm["tool_use_id"].(string) + toolUseID = strings.TrimSpace(toolUseID) + output := extractClaudeContentText(bm["content"]) + out = append(out, map[string]any{ + "role": "tool", + "tool_call_id": toolUseID, + "content": output, + }) + default: + if b, err := json.Marshal(bm); err == nil { + appendText(string(b)) + } + } + } + flushMessage(role, &sb, parts, usingParts) + default: + } + } + + return out, nil +} + +func convertOpenAIChatCompletionsJSONToClaude(openaiResp []byte, originalModel string) (map[string]any, *ClaudeUsage, string, error) { + var resp map[string]any + if err := json.Unmarshal(openaiResp, &resp); err != nil { + return nil, nil, "", err + } + + usage := &ClaudeUsage{} + if u, ok := resp["usage"].(map[string]any); ok { + if in, ok := asInt(u["prompt_tokens"]); ok { + usage.InputTokens = in + } + if out, ok := asInt(u["completion_tokens"]); ok { + usage.OutputTokens = out + } + } + + content := make([]any, 0) + stopReason := "end_turn" + + if choicesAny, ok := resp["choices"].([]any); ok && len(choicesAny) > 0 { + choice, _ := choicesAny[0].(map[string]any) + finishReason, _ := choice["finish_reason"].(string) + switch strings.TrimSpace(finishReason) { + case "length": + stopReason = "max_tokens" + case "tool_calls", "function_call": + stopReason = "tool_use" + } + + msg, _ := choice["message"].(map[string]any) + if msg != nil { + if text, _ := openAIChatMessageContentToText(msg["content"]); strings.TrimSpace(text) != "" { + content = append(content, map[string]any{"type": "text", "text": text}) + } + + if toolCallsAny, ok := msg["tool_calls"].([]any); ok && len(toolCallsAny) > 0 { + stopReason = "tool_use" + for _, tc := range toolCallsAny { + tcm, ok := tc.(map[string]any) + if !ok { + continue + } + callID, _ := tcm["id"].(string) + callID = strings.TrimSpace(callID) + if callID == "" { + callID = "call_" + randomHex(12) + } + + fn, _ := tcm["function"].(map[string]any) + name, _ := fn["name"].(string) + args, _ := fn["arguments"].(string) + + var input any + if strings.TrimSpace(args) != "" { + var parsed any + if json.Unmarshal([]byte(args), &parsed) == nil { + input = parsed + } else { + input = args + } + } + + content = append(content, map[string]any{ + "type": "tool_use", + "id": callID, + "name": strings.TrimSpace(name), + "input": input, + }) + } + } else if fc, ok := msg["function_call"].(map[string]any); ok { + stopReason = "tool_use" + name, _ := fc["name"].(string) + args, _ := fc["arguments"].(string) + callID := "call_" + randomHex(12) + + var input any + if strings.TrimSpace(args) != "" { + var parsed any + if json.Unmarshal([]byte(args), &parsed) == nil { + input = parsed + } else { + input = args + } + } + content = append(content, map[string]any{ + "type": "tool_use", + "id": callID, + "name": strings.TrimSpace(name), + "input": input, + }) + } + } + } + + if len(content) == 0 { + content = append(content, map[string]any{"type": "text", "text": ""}) + } + + msgID, _ := resp["id"].(string) + msgID = strings.TrimSpace(msgID) + if msgID == "" { + msgID = "msg_" + randomHex(12) + } + + claudeResp := map[string]any{ + "id": msgID, + "type": "message", + "role": "assistant", + "model": originalModel, + "content": content, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": usage.InputTokens, + "output_tokens": usage.OutputTokens, + "cache_creation_input_tokens": usage.CacheCreationInputTokens, + "cache_read_input_tokens": usage.CacheReadInputTokens, + }, + } + + return claudeResp, usage, stopReason, nil +} + +func extractOpenAIChatCompletionsInputTokens(openaiResp []byte) (int, error) { + var resp map[string]any + if err := json.Unmarshal(openaiResp, &resp); err != nil { + return 0, err + } + usageAny, ok := resp["usage"].(map[string]any) + if !ok { + return 0, errors.New("missing usage") + } + promptTokens, ok := asInt(usageAny["prompt_tokens"]) + if !ok { + return 0, errors.New("missing usage.prompt_tokens") + } + return promptTokens, nil +} + +func openAIChatMessageContentToText(v any) (string, bool) { + switch t := v.(type) { + case string: + return t, true + case []any: + var sb strings.Builder + for _, part := range t { + pm, ok := part.(map[string]any) + if !ok { + continue + } + pt, _ := pm["type"].(string) + if strings.EqualFold(strings.TrimSpace(pt), "text") { + if text, ok := pm["text"].(string); ok { + _, _ = sb.WriteString(text) + } + } + } + return sb.String(), true + default: + b, err := json.Marshal(t) + if err != nil { + return "", false + } + return string(b), true + } +} diff --git a/backend/internal/service/openai_chat_completions_url.go b/backend/internal/service/openai_chat_completions_url.go new file mode 100644 index 0000000000..84742a73bd --- /dev/null +++ b/backend/internal/service/openai_chat_completions_url.go @@ -0,0 +1,23 @@ +package service + +import "strings" + +func openaiChatCompletionsURLFromBaseURL(normalizedBaseURL string, isGitHubCopilot bool) string { + base := strings.TrimRight(strings.TrimSpace(normalizedBaseURL), "/") + if strings.HasSuffix(base, "/chat/completions") { + if isGitHubCopilot && strings.HasSuffix(base, "/v1/chat/completions") { + base = strings.TrimSuffix(base, "/v1/chat/completions") + base = strings.TrimRight(base, "/") + return base + "/chat/completions" + } + return base + } + if isGitHubCopilot { + base = strings.TrimSuffix(base, "/v1") + return base + "/chat/completions" + } + if strings.HasSuffix(base, "/v1") { + return base + "/chat/completions" + } + return base + "/v1/chat/completions" +} diff --git a/backend/internal/service/openai_chat_completions_url_test.go b/backend/internal/service/openai_chat_completions_url_test.go new file mode 100644 index 0000000000..cd8e1ce338 --- /dev/null +++ b/backend/internal/service/openai_chat_completions_url_test.go @@ -0,0 +1,109 @@ +//go:build unit + +package service + +import "testing" + +func TestOpenAIChatCompletionsURLFromBaseURL(t *testing.T) { + tests := []struct { + name string + baseURL string + isCopilot bool + expectedURL string + }{ + { + name: "openai root", + baseURL: "https://api.openai.com", + isCopilot: false, + expectedURL: "https://api.openai.com/v1/chat/completions", + }, + { + name: "openai root trailing slash", + baseURL: "https://api.openai.com/", + isCopilot: false, + expectedURL: "https://api.openai.com/v1/chat/completions", + }, + { + name: "openai v1", + baseURL: "https://api.openai.com/v1", + isCopilot: false, + expectedURL: "https://api.openai.com/v1/chat/completions", + }, + { + name: "openai v1 trailing slash", + baseURL: "https://api.openai.com/v1/", + isCopilot: false, + expectedURL: "https://api.openai.com/v1/chat/completions", + }, + { + name: "openai chat completions endpoint", + baseURL: "https://api.openai.com/v1/chat/completions", + isCopilot: false, + expectedURL: "https://api.openai.com/v1/chat/completions", + }, + { + name: "openai path prefix", + baseURL: "https://proxy.example.com/openai", + isCopilot: false, + expectedURL: "https://proxy.example.com/openai/v1/chat/completions", + }, + { + name: "openai path prefix v1", + baseURL: "https://proxy.example.com/openai/v1", + isCopilot: false, + expectedURL: "https://proxy.example.com/openai/v1/chat/completions", + }, + + { + name: "github copilot root", + baseURL: "https://api.githubcopilot.com", + isCopilot: true, + expectedURL: "https://api.githubcopilot.com/chat/completions", + }, + { + name: "github copilot root trailing slash", + baseURL: "https://api.githubcopilot.com/", + isCopilot: true, + expectedURL: "https://api.githubcopilot.com/chat/completions", + }, + { + name: "github copilot v1", + baseURL: "https://api.githubcopilot.com/v1", + isCopilot: true, + expectedURL: "https://api.githubcopilot.com/chat/completions", + }, + { + name: "github copilot v1 trailing slash", + baseURL: "https://api.githubcopilot.com/v1/", + isCopilot: true, + expectedURL: "https://api.githubcopilot.com/chat/completions", + }, + { + name: "github copilot chat completions endpoint", + baseURL: "https://api.githubcopilot.com/chat/completions", + isCopilot: true, + expectedURL: "https://api.githubcopilot.com/chat/completions", + }, + { + name: "github copilot v1 chat completions endpoint", + baseURL: "https://api.githubcopilot.com/v1/chat/completions", + isCopilot: true, + expectedURL: "https://api.githubcopilot.com/chat/completions", + }, + { + name: "github copilot enterprise subdomain", + baseURL: "https://api.business.githubcopilot.com", + isCopilot: true, + expectedURL: "https://api.business.githubcopilot.com/chat/completions", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := openaiChatCompletionsURLFromBaseURL(tt.baseURL, tt.isCopilot) + if got != tt.expectedURL { + t.Fatalf("openaiChatCompletionsURLFromBaseURL(%q, %v) = %q, want %q", tt.baseURL, tt.isCopilot, got, tt.expectedURL) + } + }) + } +} diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index cea81693cd..dd16e931a2 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -10,6 +10,9 @@ import ( "path/filepath" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) const ( @@ -549,7 +552,16 @@ func writeJSON(path string, value any) error { } func fetchWithETag(url, etag string) (string, string, int, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) + validatedURL, err := urlvalidator.ValidateHTTPSURL(url, urlvalidator.ValidationOptions{ + AllowedHosts: []string{"raw.githubusercontent.com"}, + RequireAllowlist: true, + AllowPrivate: false, + }) + if err != nil { + return "", "", 0, fmt.Errorf("invalid url: %w", err) + } + + req, err := http.NewRequest(http.MethodGet, validatedURL, nil) if err != nil { return "", "", 0, err } @@ -557,7 +569,17 @@ func fetchWithETag(url, etag string) (string, string, int, error) { if etag != "" { req.Header.Set("If-None-Match", etag) } - resp, err := http.DefaultClient.Do(req) + + client, err := httpclient.GetClient(httpclient.Options{ + Timeout: 10 * time.Second, + ValidateResolvedIP: true, + }) + if err != nil { + client = &http.Client{Timeout: 10 * time.Second} + } + + // #nosec G704 -- validatedURL allowlisted to raw.githubusercontent.com (private hosts blocked); resolved IP validated when available + resp, err := client.Do(req) if err != nil { return "", "", 0, err } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 6c4fe256ca..e58e284806 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -183,9 +183,29 @@ type OpenAIGatewayService struct { httpUpstream HTTPUpstream deferredService *DeferredService openAITokenProvider *OpenAITokenProvider + githubCopilotToken *GitHubCopilotTokenProvider toolCorrector *CodexToolCorrector } +func openaiProtocolPlatform(raw string) string { + platform := strings.ToLower(strings.TrimSpace(raw)) + switch platform { + case PlatformCopilot: + return PlatformCopilot + case PlatformAggregator: + return PlatformAggregator + default: + return PlatformOpenAI + } +} + +func openaiStickySessionKey(platform string, sessionHash string) string { + if strings.TrimSpace(sessionHash) == "" { + return "" + } + return openaiProtocolPlatform(platform) + ":" + sessionHash +} + // NewOpenAIGatewayService creates a new OpenAIGatewayService func NewOpenAIGatewayService( accountRepo AccountRepository, @@ -202,6 +222,7 @@ func NewOpenAIGatewayService( httpUpstream HTTPUpstream, deferredService *DeferredService, openAITokenProvider *OpenAITokenProvider, + githubCopilotToken *GitHubCopilotTokenProvider, ) *OpenAIGatewayService { return &OpenAIGatewayService{ accountRepo: accountRepo, @@ -218,6 +239,7 @@ func NewOpenAIGatewayService( httpUpstream: httpUpstream, deferredService: deferredService, openAITokenProvider: openAITokenProvider, + githubCopilotToken: githubCopilotToken, toolCorrector: NewCodexToolCorrector(), } } @@ -252,10 +274,15 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[s // BindStickySession sets session -> account binding with standard TTL. func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { - if sessionHash == "" || accountID <= 0 { + return s.BindStickySessionForPlatform(ctx, groupID, PlatformOpenAI, sessionHash, accountID) +} + +func (s *OpenAIGatewayService) BindStickySessionForPlatform(ctx context.Context, groupID *int64, platform string, sessionHash string, accountID int64) error { + cacheKey := openaiStickySessionKey(platform, sessionHash) + if cacheKey == "" || accountID <= 0 { return nil } - return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL) + return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, accountID, openaiStickySessionTTL) } // SelectAccount selects an OpenAI account with sticky session support @@ -271,24 +298,28 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. // SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。 func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - cacheKey := "openai:" + sessionHash + return s.SelectAccountForModelWithExclusionsForPlatform(ctx, groupID, PlatformOpenAI, sessionHash, requestedModel, excludedIDs) +} + +func (s *OpenAIGatewayService) SelectAccountForModelWithExclusionsForPlatform(ctx context.Context, groupID *int64, platform string, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { + cacheKey := openaiStickySessionKey(platform, sessionHash) // 1. 尝试粘性会话命中 // Try sticky session hit - if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil { + if account := s.tryStickySessionHit(ctx, groupID, platform, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil { return account, nil } // 2. 获取可调度的 OpenAI 账号 // Get schedulable OpenAI accounts - accounts, err := s.listSchedulableAccounts(ctx, groupID) + accounts, err := s.listSchedulableAccounts(ctx, groupID, platform) if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) } // 3. 按优先级 + LRU 选择最佳账号 // Select by priority + LRU - selected := s.selectBestAccount(accounts, requestedModel, excludedIDs) + selected := s.selectBestAccount(accounts, platform, requestedModel, excludedIDs) if selected == nil { if requestedModel != "" { @@ -299,7 +330,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // 4. 设置粘性会话绑定 // Set sticky session binding - if sessionHash != "" { + if cacheKey != "" { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL) } @@ -311,8 +342,8 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // // tryStickySessionHit attempts to get account from sticky session. // Returns account if hit and usable; clears session and returns nil if account is unavailable. -func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account { - if sessionHash == "" { +func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, platform string, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account { + if cacheKey == "" { return nil } @@ -339,7 +370,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 验证账号是否可用于当前请求 // Verify account is usable for current request - if !account.IsSchedulable() || !account.IsOpenAI() { + if !account.IsSchedulable() || account.Platform != openaiProtocolPlatform(platform) { return nil } if requestedModel != "" && !account.IsModelSupported(requestedModel) { @@ -357,8 +388,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // // selectBestAccount selects the best account from candidates (priority + LRU). // Returns nil if no available account. -func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { +func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, platform string, requestedModel string, excludedIDs map[int64]struct{}) *Account { var selected *Account + wantPlatform := openaiProtocolPlatform(platform) for i := range accounts { acc := &accounts[i] @@ -371,7 +403,7 @@ func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedMo // 调度器快照可能暂时过时,这里重新检查可调度性和平台 // Scheduler snapshots can be temporarily stale; re-check schedulability and platform - if !acc.IsSchedulable() || !acc.IsOpenAI() { + if !acc.IsSchedulable() || acc.Platform != wantPlatform { continue } @@ -431,15 +463,19 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { + return s.SelectAccountWithLoadAwarenessForPlatform(ctx, groupID, PlatformOpenAI, sessionHash, requestedModel, excludedIDs) +} + +func (s *OpenAIGatewayService) SelectAccountWithLoadAwarenessForPlatform(ctx context.Context, groupID *int64, platform string, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { cfg := s.schedulingConfig() var stickyAccountID int64 if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), openaiStickySessionKey(platform, sessionHash)); err == nil { stickyAccountID = accountID } } if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) + account, err := s.SelectAccountForModelWithExclusionsForPlatform(ctx, groupID, platform, sessionHash, requestedModel, excludedIDs) if err != nil { return nil, err } @@ -476,7 +512,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex }, nil } - accounts, err := s.listSchedulableAccounts(ctx, groupID) + accounts, err := s.listSchedulableAccounts(ctx, groupID, platform) if err != nil { return nil, err } @@ -494,19 +530,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex // ============ Layer 1: Sticky session ============ if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), openaiStickySessionKey(platform, sessionHash)) if err == nil && accountID > 0 && !isExcluded(accountID) { account, err := s.getSchedulableAccount(ctx, accountID) if err == nil { clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { - _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), openaiStickySessionKey(platform, sessionHash)) } - if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && + if !clearSticky && account.IsSchedulable() && account.Platform == openaiProtocolPlatform(platform) && (requestedModel == "" || account.IsModelSupported(requestedModel)) { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), openaiStickySessionKey(platform, sessionHash), openaiStickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, @@ -570,7 +606,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), openaiStickySessionKey(platform, sessionHash), acc.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: acc, @@ -620,7 +656,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), openaiStickySessionKey(platform, sessionHash), item.account.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: item.account, @@ -649,19 +685,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex return nil, errors.New("no available accounts") } -func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) { +func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string) ([]Account, error) { if s.schedulerSnapshot != nil { - accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false) + accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, openaiProtocolPlatform(platform), false) return accounts, err } var accounts []Account var err error if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, openaiProtocolPlatform(platform)) } else if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, openaiProtocolPlatform(platform)) } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, openaiProtocolPlatform(platform)) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) @@ -716,7 +752,21 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco } return accessToken, "oauth", nil case AccountTypeAPIKey: - apiKey := account.GetOpenAIApiKey() + if isGitHubCopilotAccount(account) && s.githubCopilotToken != nil { + copilotToken, err := s.githubCopilotToken.GetAccessToken(ctx, account) + if err == nil && strings.TrimSpace(copilotToken) != "" { + return copilotToken, "github_copilot", nil + } + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey == "" { + if err != nil { + return "", "", err + } + return "", "", errors.New("api_key not found in credentials") + } + return apiKey, "apikey", nil + } + apiKey := strings.TrimSpace(account.GetCredential("api_key")) if apiKey == "" { return "", "", errors.New("api_key not found in credentials") } @@ -763,6 +813,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco originalModel := reqModel isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) + isGitHubCopilot := isGitHubCopilotAccount(account) // 对所有请求执行模型映射(包含 Codex CLI)。 mappedModel := account.GetMappedModel(reqModel) @@ -772,15 +823,29 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco bodyModified = true } - // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 if model, ok := reqBody["model"].(string); ok { - normalizedModel := normalizeCodexModel(model) - if normalizedModel != "" && normalizedModel != model { - log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", - model, normalizedModel, account.Name, account.Type, isCodexCLI) - reqBody["model"] = normalizedModel - mappedModel = normalizedModel - bodyModified = true + if isGitHubCopilot { + stripped := strings.TrimSpace(model) + if strings.Contains(stripped, "/") { + parts := strings.Split(stripped, "/") + stripped = strings.TrimSpace(parts[len(parts)-1]) + } + if stripped != "" && stripped != model { + log.Printf("[OpenAI] GitHub Copilot model prefix stripped: %s -> %s (account: %s)", + model, stripped, account.Name) + reqBody["model"] = stripped + mappedModel = stripped + bodyModified = true + } + } else { + normalizedModel := normalizeCodexModel(model) + if normalizedModel != "" && normalizedModel != model { + log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", + model, normalizedModel, account.Name, account.Type, isCodexCLI) + reqBody["model"] = normalizedModel + mappedModel = normalizedModel + bodyModified = true + } } } @@ -808,15 +873,34 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // Handle max_output_tokens based on platform and account type if !isCodexCLI { - if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens { + if account.Type == AccountTypeAPIKey { switch account.Platform { - case PlatformOpenAI: - // For OpenAI API Key, remove max_output_tokens (not supported) - // For OpenAI OAuth (Responses API), keep it (supported) - if account.Type == AccountTypeAPIKey { - delete(reqBody, "max_output_tokens") + case PlatformOpenAI, PlatformCopilot, PlatformAggregator: + if _, has := reqBody["max_output_tokens"]; has { + if _, ok := reqBody["max_tokens"]; ok { + delete(reqBody, "max_tokens") + bodyModified = true + } + if _, ok := reqBody["max_completion_tokens"]; ok { + delete(reqBody, "max_completion_tokens") + bodyModified = true + } + } else if v, ok := reqBody["max_completion_tokens"]; ok { + reqBody["max_output_tokens"] = v + delete(reqBody, "max_completion_tokens") + delete(reqBody, "max_tokens") + bodyModified = true + } else if v, ok := reqBody["max_tokens"]; ok { + reqBody["max_output_tokens"] = v + delete(reqBody, "max_tokens") + delete(reqBody, "max_completion_tokens") bodyModified = true } + } + } + + if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens { + switch account.Platform { case PlatformAnthropic: // For Anthropic (Claude), convert to max_tokens delete(reqBody, "max_output_tokens") @@ -829,9 +913,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco delete(reqBody, "max_output_tokens") bodyModified = true default: - // For unknown platforms, remove to be safe - delete(reqBody, "max_output_tokens") - bodyModified = true + if account.Platform != PlatformOpenAI && account.Platform != PlatformCopilot && account.Platform != PlatformAggregator { + // For unknown platforms, remove to be safe + delete(reqBody, "max_output_tokens") + bodyModified = true + } } } @@ -850,6 +936,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco bodyModified = true } } + + if isGitHubCopilot { + // GitHub Copilot does not support service_tier; force JSON null (even if absent). + if v, ok := reqBody["service_tier"]; !ok || v != nil { + reqBody["service_tier"] = nil + bodyModified = true + } + } } // Re-serialize body only if modified @@ -867,8 +961,15 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco return nil, err } + copilotVision := false + copilotInitiator := "user" + if isGitHubCopilot { + copilotVision = githubCopilotVisionEnabledFromResponsesPayload(reqBody) + copilotInitiator = githubCopilotInitiatorFromResponsesPayload(reqBody) + } + // Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI, isGitHubCopilot, copilotVision, copilotInitiator) if err != nil { return nil, err } @@ -906,6 +1007,29 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco }) return nil, fmt.Errorf("upstream request failed: %s", safeErr) } + + if isGitHubCopilot && resp.StatusCode == http.StatusUnauthorized && s.githubCopilotToken != nil { + githubToken := strings.TrimSpace(account.GetCredential("github_token")) + if githubToken == "" { + githubToken = strings.TrimSpace(account.GetCredential("gh_token")) + } + if githubToken != "" { + s.githubCopilotToken.Invalidate(ctx, account) + if refreshed, refreshErr := s.githubCopilotToken.GetAccessToken(ctx, account); refreshErr == nil && strings.TrimSpace(refreshed) != "" { + retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, body, refreshed, reqStream, promptCacheKey, isCodexCLI, isGitHubCopilot, copilotVision, copilotInitiator) + if buildErr == nil { + retryResp, doErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) + if doErr == nil { + _ = resp.Body.Close() + resp = retryResp + } else if retryResp != nil && retryResp.Body != nil { + _ = retryResp.Body.Close() + } + } + } + } + } + defer func() { _ = resp.Body.Close() }() // Handle error response @@ -979,7 +1103,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco }, nil } -func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) { +func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool, isGitHubCopilot bool, copilotVision bool, copilotInitiator string) (*http.Request, error) { // Determine target URL based on account type var targetURL string switch account.Type { @@ -988,7 +1112,13 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. targetURL = chatgptCodexURL case AccountTypeAPIKey: // API Key accounts use Platform API or custom base URL - baseURL := account.GetOpenAIBaseURL() + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + if baseURL == "" && account.Platform == PlatformCopilot { + baseURL = "https://api.githubcopilot.com" + } + if baseURL == "" && account.Platform == PlatformAggregator { + return nil, errors.New("base_url is required for aggregator accounts") + } if baseURL == "" { targetURL = openaiPlatformAPIURL } else { @@ -996,12 +1126,16 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. if err != nil { return nil, err } - targetURL = validatedURL + "/responses" + targetURL = openaiResponsesURLFromBaseURL(validatedURL, isGitHubCopilot) } default: targetURL = openaiPlatformAPIURL } + return s.buildUpstreamRequestWithTargetURL(ctx, c, account, body, token, isStream, promptCacheKey, isCodexCLI, isGitHubCopilot, copilotVision, copilotInitiator, targetURL) +} + +func (s *OpenAIGatewayService) buildUpstreamRequestWithTargetURL(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool, isGitHubCopilot bool, copilotVision bool, copilotInitiator string, targetURL string) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -1050,6 +1184,14 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. req.Header.Set("user-agent", customUA) } + if isGitHubCopilot { + applyGitHubCopilotHeaders(req, copilotVision, copilotInitiator) + } + + if isStream && strings.TrimSpace(req.Header.Get("accept")) == "" { + req.Header.Set("accept", "text/event-stream") + } + // Ensure required headers exist if req.Header.Get("content-type") == "" { req.Header.Set("content-type", "application/json") @@ -1086,7 +1228,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht if status, errType, errMsg, matched := applyErrorPassthroughRule( c, - PlatformOpenAI, + account.Platform, resp.StatusCode, body, http.StatusBadGateway, diff --git a/backend/internal/service/openai_gateway_service_forward_token_limits_test.go b/backend/internal/service/openai_gateway_service_forward_token_limits_test.go new file mode 100644 index 0000000000..603e51593a --- /dev/null +++ b/backend/internal/service/openai_gateway_service_forward_token_limits_test.go @@ -0,0 +1,113 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestOpenAIGatewayService_Forward_NormalizesOutputTokenCaps(t *testing.T) { + gin.SetMode(gin.TestMode) + + platforms := []string{PlatformOpenAI, PlatformCopilot, PlatformAggregator} + scenarios := []struct { + name string + extraBody map[string]any + wantMaxOut float64 + }{ + { + name: "max_output_tokens_keeps_and_drops_legacy", + extraBody: map[string]any{"max_output_tokens": 77, "max_tokens": 88, "max_completion_tokens": 99}, + wantMaxOut: 77, + }, + { + name: "max_tokens_maps_to_max_output_tokens", + extraBody: map[string]any{"max_tokens": 55}, + wantMaxOut: 55, + }, + { + name: "max_completion_tokens_maps_to_max_output_tokens", + extraBody: map[string]any{"max_completion_tokens": 66}, + wantMaxOut: 66, + }, + } + + for _, platform := range platforms { + for _, scenario := range scenarios { + t.Run(platform+"_"+scenario.name, func(t *testing.T) { + type capturedReq struct { + Path string + JSON map[string]any + } + capCh := make(chan capturedReq, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var m map[string]any + _ = json.Unmarshal(body, &m) + capCh <- capturedReq{Path: r.URL.Path, JSON: m} + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "usage": map[string]any{"input_tokens": 1, "output_tokens": 1}, + }) + })) + defer server.Close() + + cfg := &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false, AllowInsecureHTTP: true}}} + upstream := &testHTTPUpstream{} + svc := NewOpenAIGatewayService(nil, nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, upstream, nil, nil, nil) + + account := &Account{ + ID: 1, + Name: "openai-test", + Platform: platform, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": server.URL, + }, + Concurrency: 1, + } + + reqBody := map[string]any{ + "model": "gpt-5.2", + "stream": false, + } + for k, v := range scenario.extraBody { + reqBody[k] = v + } + body, err := json.Marshal(reqBody) + require.NoError(t, err) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", bytes.NewReader(body)) + + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, http.StatusOK, w.Code) + + select { + case cap := <-capCh: + require.NotEmpty(t, cap.Path) + require.Equal(t, scenario.wantMaxOut, cap.JSON["max_output_tokens"]) + require.NotContains(t, cap.JSON, "max_tokens") + require.NotContains(t, cap.JSON, "max_completion_tokens") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for upstream request") + } + }) + } + } +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index ae69a9867b..c713771278 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -1085,7 +1085,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { Credentials: map[string]any{"base_url": "://invalid-url"}, } - _, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false) + _, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false, false, false, "user") if err == nil { t.Fatalf("expected error for invalid base_url when allowlist disabled") } diff --git a/backend/internal/service/openai_messages_compat_service.go b/backend/internal/service/openai_messages_compat_service.go new file mode 100644 index 0000000000..b6422eed21 --- /dev/null +++ b/backend/internal/service/openai_messages_compat_service.go @@ -0,0 +1,1331 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +type OpenAIMessagesCompatService struct { + openai *OpenAIGatewayService +} + +func NewOpenAIMessagesCompatService(openaiGateway *OpenAIGatewayService) *OpenAIMessagesCompatService { + return &OpenAIMessagesCompatService{openai: openaiGateway} +} + +func isResponsesAPIUnsupportedError(upstreamMsg string, upstreamBody []byte) bool { + msg := strings.ToLower(strings.TrimSpace(upstreamMsg)) + if strings.Contains(msg, "does not support responses api") { + return true + } + if strings.Contains(msg, "responses api") && strings.Contains(msg, "does not support") { + return true + } + lowerBody := strings.ToLower(string(upstreamBody)) + return strings.Contains(lowerBody, "does not support responses api") +} + +func (s *OpenAIMessagesCompatService) chatCompletionsURLForAccount(account *Account) (string, error) { + if s == nil || s.openai == nil { + return "", fmt.Errorf("openai gateway service not configured") + } + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + if baseURL == "" && account.Platform == PlatformCopilot { + baseURL = "https://api.githubcopilot.com" + } + if baseURL == "" { + baseURL = "https://api.openai.com" + } + validatedURL, err := s.openai.validateUpstreamBaseURL(baseURL) + if err != nil { + return "", err + } + return openaiChatCompletionsURLFromBaseURL(validatedURL, isGitHubCopilotAccount(account)), nil +} + +func (s *OpenAIMessagesCompatService) forwardViaChatCompletions(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest, claudeReq map[string]any, originalModel string, mappedModel string, token string, startTime time.Time) (*ForwardResult, error) { + openaiMessages, err := convertClaudeMessagesToOpenAIChatCompletionsMessages(parsed.Messages, parsed.System) + if err != nil { + writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + return nil, err + } + + openaiReq := map[string]any{ + "model": mappedModel, + "stream": false, + "messages": openaiMessages, + } + if parsed.MaxTokens > 0 { + openaiReq["max_tokens"] = parsed.MaxTokens + } + if tools := convertClaudeToolsToOpenAIChatTools(claudeReq["tools"]); len(tools) > 0 { + openaiReq["tools"] = tools + openaiReq["tool_choice"] = "auto" + } + if temp, ok := claudeReq["temperature"].(float64); ok { + openaiReq["temperature"] = temp + } + if topP, ok := claudeReq["top_p"].(float64); ok { + openaiReq["top_p"] = topP + } + if stopSeq, ok := claudeReq["stop_sequences"].([]any); ok && len(stopSeq) > 0 { + openaiReq["stop"] = stopSeq + } + + openaiBody, err := json.Marshal(openaiReq) + if err != nil { + writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to process request") + return nil, err + } + if c != nil { + c.Set(OpsUpstreamRequestBodyKey, string(openaiBody)) + } + + targetURL, err := s.chatCompletionsURLForAccount(account) + if err != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error()) + return nil, err + } + + upstreamReq, err := s.openai.buildUpstreamRequestWithTargetURL(ctx, c, account, openaiBody, token, false, "", false, isGitHubCopilotAccount(account), false, "user", targetURL) + if err != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error()) + return nil, err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.openai.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + if resp == nil || resp.Body == nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty upstream response") + return nil, fmt.Errorf("empty upstream response") + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.openai.cfg != nil && s.openai.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.openai.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + if upstreamDetail == "" && isGitHubCopilotAccount(account) { + upstreamDetail = truncateString(string(respBody), 2048) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + if s.openai.shouldFailoverUpstreamError(resp.StatusCode) { + if s.openai.rateLimitService != nil { + s.openai.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: strings.TrimSpace(resp.Header.Get("x-request-id")), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: strings.TrimSpace(resp.Header.Get("x-request-id")), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + respBody, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + writeClaudeError(c, status, errType, errMsg) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + + status, errType, errMsg := mapOpenAIUpstreamErrorToClaude(resp.StatusCode) + writeClaudeError(c, status, errType, errMsg) + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + + upstreamBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response") + return nil, readErr + } + + claudeResp, usage, stopReason, convErr := convertOpenAIChatCompletionsJSONToClaude(upstreamBody, originalModel) + if convErr != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + return nil, convErr + } + + if reqID := strings.TrimSpace(resp.Header.Get("x-request-id")); reqID != "" { + c.Header("x-request-id", reqID) + } + + if !parsed.Stream { + c.JSON(http.StatusOK, claudeResp) + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: originalModel, + Stream: false, + Duration: time.Since(startTime), + }, nil + } + + if err := writeClaudeStreamFromMessage(c, claudeResp, usage, originalModel, stopReason); err != nil { + return nil, err + } + + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: originalModel, + Stream: true, + Duration: time.Since(startTime), + }, nil +} + +func (s *OpenAIMessagesCompatService) forwardCountTokensViaChatCompletions(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest, claudeReq map[string]any, mappedModel string, token string) error { + openaiMessages, err := convertClaudeMessagesToOpenAIChatCompletionsMessages(parsed.Messages, parsed.System) + if err != nil { + writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + return err + } + + openaiReq := map[string]any{ + "model": mappedModel, + "stream": false, + "messages": openaiMessages, + "max_tokens": 1, + } + if tools := convertClaudeToolsToOpenAIChatTools(claudeReq["tools"]); len(tools) > 0 { + openaiReq["tools"] = tools + openaiReq["tool_choice"] = "auto" + } + + openaiBody, err := json.Marshal(openaiReq) + if err != nil { + writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to process request") + return err + } + if c != nil { + c.Set(OpsUpstreamRequestBodyKey, string(openaiBody)) + } + + targetURL, err := s.chatCompletionsURLForAccount(account) + if err != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error()) + return err + } + + upstreamReq, err := s.openai.buildUpstreamRequestWithTargetURL(ctx, c, account, openaiBody, token, false, "", false, isGitHubCopilotAccount(account), false, "user", targetURL) + if err != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error()) + return err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.openai.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + return fmt.Errorf("upstream request failed: %s", sanitizeUpstreamErrorMessage(err.Error())) + } + if resp == nil || resp.Body == nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty upstream response") + return fmt.Errorf("empty upstream response") + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.openai.cfg != nil && s.openai.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.openai.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + if upstreamDetail == "" && isGitHubCopilotAccount(account) { + upstreamDetail = truncateString(string(respBody), 2048) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + if s.openai.rateLimitService != nil { + s.openai.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: strings.TrimSpace(resp.Header.Get("x-request-id")), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + respBody, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + writeClaudeError(c, status, errType, errMsg) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + + status, errType, errMsg := mapOpenAIUpstreamErrorToClaude(resp.StatusCode) + writeClaudeError(c, status, errType, errMsg) + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + + upstreamBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response") + return readErr + } + + inputTokens, err := extractOpenAIChatCompletionsInputTokens(upstreamBody) + if err != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + return err + } + + if reqID := strings.TrimSpace(resp.Header.Get("x-request-id")); reqID != "" { + c.Header("x-request-id", reqID) + } + + c.JSON(http.StatusOK, gin.H{"input_tokens": inputTokens}) + return nil +} + +func (s *OpenAIMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { + startTime := time.Now() + if s == nil || s.openai == nil { + return nil, fmt.Errorf("openai compat service not configured") + } + if account == nil { + return nil, fmt.Errorf("missing account") + } + if parsed == nil { + return nil, fmt.Errorf("empty request") + } + + originalModel := strings.TrimSpace(parsed.Model) + if originalModel == "" { + return nil, fmt.Errorf("missing model") + } + + var claudeReq map[string]any + if err := json.Unmarshal(parsed.Body, &claudeReq); err != nil { + writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return nil, err + } + + mappedModel := account.GetMappedModel(originalModel) + if isGitHubCopilotAccount(account) { + stripped := strings.TrimSpace(mappedModel) + if strings.Contains(stripped, "/") { + parts := strings.Split(stripped, "/") + stripped = strings.TrimSpace(parts[len(parts)-1]) + } + if stripped != "" { + mappedModel = stripped + } + } else { + mappedModel = normalizeCodexModel(mappedModel) + } + + openaiInput, err := convertClaudeMessagesToOpenAIResponsesInput(parsed.Messages) + if err != nil { + writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + return nil, err + } + + openaiReq := map[string]any{ + "model": mappedModel, + "stream": false, + "store": false, + "input": openaiInput, + } + if parsed.MaxTokens > 0 { + if account.Type == AccountTypeOAuth { + openaiReq["max_tokens"] = parsed.MaxTokens + } else { + openaiReq["max_output_tokens"] = parsed.MaxTokens + } + } + + if systemText := extractClaudeSystemText(parsed.System); systemText != "" { + openaiReq["instructions"] = systemText + } + + if tools := convertClaudeToolsToOpenAITools(claudeReq["tools"]); len(tools) > 0 { + openaiReq["tools"] = tools + openaiReq["tool_choice"] = "auto" + } + + if temp, ok := claudeReq["temperature"].(float64); ok { + openaiReq["temperature"] = temp + } + if topP, ok := claudeReq["top_p"].(float64); ok { + openaiReq["top_p"] = topP + } + if stopSeq, ok := claudeReq["stop_sequences"].([]any); ok && len(stopSeq) > 0 { + openaiReq["stop"] = stopSeq + } + + upstreamStream := false + if account.Type == AccountTypeOAuth { + upstreamStream = true + openaiReq["stream"] = true + } + + openaiBody, err := json.Marshal(openaiReq) + if err != nil { + writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to process request") + return nil, err + } + if c != nil { + c.Set(OpsUpstreamRequestBodyKey, string(openaiBody)) + } + + token, _, err := s.openai.GetAccessToken(ctx, account) + if err != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error()) + return nil, err + } + + upstreamReq, err := s.openai.buildUpstreamRequest(ctx, c, account, openaiBody, token, upstreamStream, "", false, isGitHubCopilotAccount(account), false, "user") + if err != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error()) + return nil, err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.openai.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + if resp == nil || resp.Body == nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty upstream response") + return nil, fmt.Errorf("empty upstream response") + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + rawUpstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + if isGitHubCopilotAccount(account) && isResponsesAPIUnsupportedError(rawUpstreamMsg, respBody) { + return s.forwardViaChatCompletions(ctx, c, account, parsed, claudeReq, originalModel, mappedModel, token, startTime) + } + upstreamMsg := sanitizeUpstreamErrorMessage(rawUpstreamMsg) + upstreamDetail := "" + if s.openai.cfg != nil && s.openai.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.openai.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + if upstreamDetail == "" && isGitHubCopilotAccount(account) { + upstreamDetail = truncateString(string(respBody), 2048) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + if s.openai.shouldFailoverUpstreamError(resp.StatusCode) { + if s.openai.rateLimitService != nil { + s.openai.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: strings.TrimSpace(resp.Header.Get("x-request-id")), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: strings.TrimSpace(resp.Header.Get("x-request-id")), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + respBody, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + writeClaudeError(c, status, errType, errMsg) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + + status, errType, errMsg := mapOpenAIUpstreamErrorToClaude(resp.StatusCode) + writeClaudeError(c, status, errType, errMsg) + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + + upstreamBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response") + return nil, readErr + } + + bodyJSON := upstreamBody + if isEventStreamResponse(resp.Header) || bytes.Contains(upstreamBody, []byte("data:")) { + if final, ok := extractCodexFinalResponse(string(upstreamBody)); ok { + bodyJSON = final + } else { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream stream") + return nil, fmt.Errorf("failed to extract final openai response") + } + } + + claudeResp, usage, stopReason, convErr := convertOpenAIResponsesJSONToClaude(bodyJSON, originalModel) + if convErr != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + return nil, convErr + } + + if reqID := strings.TrimSpace(resp.Header.Get("x-request-id")); reqID != "" { + c.Header("x-request-id", reqID) + } + + if !parsed.Stream { + c.JSON(http.StatusOK, claudeResp) + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: originalModel, + Stream: false, + Duration: time.Since(startTime), + }, nil + } + + if err := writeClaudeStreamFromMessage(c, claudeResp, usage, originalModel, stopReason); err != nil { + return nil, err + } + + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: originalModel, + Stream: true, + Duration: time.Since(startTime), + }, nil +} + +func (s *OpenAIMessagesCompatService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { + if s == nil || s.openai == nil { + writeClaudeError(c, http.StatusInternalServerError, "api_error", "OpenAI compat service not configured") + return fmt.Errorf("openai compat service not configured") + } + if account == nil { + writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Missing account") + return fmt.Errorf("missing account") + } + if parsed == nil { + writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return fmt.Errorf("empty request") + } + + originalModel := strings.TrimSpace(parsed.Model) + if originalModel == "" { + writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return fmt.Errorf("missing model") + } + + var claudeReq map[string]any + if err := json.Unmarshal(parsed.Body, &claudeReq); err != nil { + writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return err + } + + mappedModel := account.GetMappedModel(originalModel) + if isGitHubCopilotAccount(account) { + stripped := strings.TrimSpace(mappedModel) + if strings.Contains(stripped, "/") { + parts := strings.Split(stripped, "/") + stripped = strings.TrimSpace(parts[len(parts)-1]) + } + if stripped != "" { + mappedModel = stripped + } + } else { + mappedModel = normalizeCodexModel(mappedModel) + } + + openaiInput, err := convertClaudeMessagesToOpenAIResponsesInput(parsed.Messages) + if err != nil { + writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + return err + } + + openaiReq := map[string]any{ + "model": mappedModel, + "stream": false, + "store": false, + "input": openaiInput, + } + if account.Type == AccountTypeOAuth { + openaiReq["max_tokens"] = 1 + } else { + openaiReq["max_output_tokens"] = 1 + } + + if systemText := extractClaudeSystemText(parsed.System); systemText != "" { + openaiReq["instructions"] = systemText + } + + if tools := convertClaudeToolsToOpenAITools(claudeReq["tools"]); len(tools) > 0 { + openaiReq["tools"] = tools + openaiReq["tool_choice"] = "auto" + } + + upstreamStream := false + if account.Type == AccountTypeOAuth { + upstreamStream = true + openaiReq["stream"] = true + } + + openaiBody, err := json.Marshal(openaiReq) + if err != nil { + writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to process request") + return err + } + if c != nil { + c.Set(OpsUpstreamRequestBodyKey, string(openaiBody)) + } + + token, _, err := s.openai.GetAccessToken(ctx, account) + if err != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error()) + return err + } + + upstreamReq, err := s.openai.buildUpstreamRequest(ctx, c, account, openaiBody, token, upstreamStream, "", false, isGitHubCopilotAccount(account), false, "user") + if err != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error()) + return err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.openai.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + return fmt.Errorf("upstream request failed: %s", safeErr) + } + if resp == nil || resp.Body == nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty upstream response") + return fmt.Errorf("empty upstream response") + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + rawUpstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + if isGitHubCopilotAccount(account) && isResponsesAPIUnsupportedError(rawUpstreamMsg, respBody) { + return s.forwardCountTokensViaChatCompletions(ctx, c, account, parsed, claudeReq, mappedModel, token) + } + upstreamMsg := sanitizeUpstreamErrorMessage(rawUpstreamMsg) + upstreamDetail := "" + if s.openai.cfg != nil && s.openai.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.openai.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + if upstreamDetail == "" && isGitHubCopilotAccount(account) { + upstreamDetail = truncateString(string(respBody), 2048) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + if s.openai.rateLimitService != nil { + s.openai.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: strings.TrimSpace(resp.Header.Get("x-request-id")), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + respBody, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + writeClaudeError(c, status, errType, errMsg) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + + status, errType, errMsg := mapOpenAIUpstreamErrorToClaude(resp.StatusCode) + writeClaudeError(c, status, errType, errMsg) + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + + upstreamBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response") + return readErr + } + + bodyJSON := upstreamBody + if isEventStreamResponse(resp.Header) || bytes.Contains(upstreamBody, []byte("data:")) { + if final, ok := extractCodexFinalResponse(string(upstreamBody)); ok { + bodyJSON = final + } else { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream stream") + return fmt.Errorf("failed to extract final openai response") + } + } + + inputTokens, err := extractOpenAIResponsesInputTokens(bodyJSON) + if err != nil { + writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + return err + } + + if reqID := strings.TrimSpace(resp.Header.Get("x-request-id")); reqID != "" { + c.Header("x-request-id", reqID) + } + + c.JSON(http.StatusOK, gin.H{"input_tokens": inputTokens}) + return nil +} + +func extractOpenAIResponsesInputTokens(openaiResp []byte) (int, error) { + var resp map[string]any + if err := json.Unmarshal(openaiResp, &resp); err != nil { + return 0, err + } + usageAny, ok := resp["usage"].(map[string]any) + if !ok { + return 0, errors.New("missing usage") + } + inputTokens, ok := asInt(usageAny["input_tokens"]) + if !ok { + return 0, errors.New("missing usage.input_tokens") + } + return inputTokens, nil +} + +func mapOpenAIUpstreamErrorToClaude(statusCode int) (int, string, string) { + switch statusCode { + case 401: + return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" + case 402: + return http.StatusBadGateway, "upstream_error", "Upstream payment required: insufficient balance or billing issue" + case 403: + return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" + case 429: + return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" + case 529: + return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" + default: + return http.StatusBadGateway, "upstream_error", "Upstream request failed" + } +} + +func writeClaudeError(c *gin.Context, status int, errType, message string) { + if c == nil { + return + } + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +func convertClaudeToolsToOpenAITools(tools any) []any { + arr, ok := tools.([]any) + if !ok || len(arr) == 0 { + return nil + } + out := make([]any, 0, len(arr)) + for _, t := range arr { + tm, ok := t.(map[string]any) + if !ok { + continue + } + + toolType, _ := tm["type"].(string) + toolType = strings.TrimSpace(toolType) + + name := "" + desc := "" + params := any(nil) + if toolType == "custom" { + name, _ = tm["name"].(string) + if custom, ok := tm["custom"].(map[string]any); ok { + desc, _ = custom["description"].(string) + params = custom["input_schema"] + } + } else { + name, _ = tm["name"].(string) + desc, _ = tm["description"].(string) + params = tm["input_schema"] + } + + name = strings.TrimSpace(name) + if name == "" { + continue + } + if params == nil { + params = map[string]any{"type": "object", "properties": map[string]any{}} + } + + out = append(out, map[string]any{ + "type": "function", + "name": name, + "description": strings.TrimSpace(desc), + "parameters": params, + }) + } + return out +} + +func convertClaudeMessagesToOpenAIResponsesInput(messages []any) ([]any, error) { + if len(messages) == 0 { + return []any{}, nil + } + + toolIDToName := make(map[string]string) + out := make([]any, 0, len(messages)) + + flushMessage := func(role string, content []any) { + role = strings.ToLower(strings.TrimSpace(role)) + if role == "" { + role = "user" + } + if len(content) == 0 { + return + } + out = append(out, map[string]any{ + "type": "message", + "role": role, + "content": content, + }) + } + + for _, m := range messages { + mm, ok := m.(map[string]any) + if !ok { + continue + } + role, _ := mm["role"].(string) + role = strings.ToLower(strings.TrimSpace(role)) + if role == "" { + role = "user" + } + + switch content := mm["content"].(type) { + case string: + flushMessage(role, []any{map[string]any{"type": "input_text", "text": content}}) + case []any: + msgContent := make([]any, 0) + for _, block := range content { + bm, ok := block.(map[string]any) + if !ok { + continue + } + bt, _ := bm["type"].(string) + bt = strings.ToLower(strings.TrimSpace(bt)) + + switch bt { + case "text": + if text, ok := bm["text"].(string); ok { + msgContent = append(msgContent, map[string]any{"type": "input_text", "text": text}) + } + case "thinking": + if t, ok := bm["thinking"].(string); ok && strings.TrimSpace(t) != "" { + msgContent = append(msgContent, map[string]any{"type": "input_text", "text": t}) + } + case "image": + if src, ok := bm["source"].(map[string]any); ok { + if srcType, _ := src["type"].(string); srcType == "base64" { + mediaType, _ := src["media_type"].(string) + data, _ := src["data"].(string) + mediaType = strings.TrimSpace(mediaType) + data = strings.TrimSpace(data) + if mediaType != "" && data != "" { + url := fmt.Sprintf("data:%s;base64,%s", mediaType, data) + msgContent = append(msgContent, map[string]any{ + "type": "input_image", + "image_url": map[string]any{"url": url}, + }) + } + } + } + case "tool_use": + flushMessage(role, msgContent) + msgContent = make([]any, 0) + + id, _ := bm["id"].(string) + name, _ := bm["name"].(string) + id = strings.TrimSpace(id) + name = strings.TrimSpace(name) + if id != "" && name != "" { + toolIDToName[id] = name + } + argsJSON, _ := json.Marshal(bm["input"]) + out = append(out, map[string]any{ + "type": "function_call", + "id": id, + "call_id": id, + "name": name, + "arguments": string(argsJSON), + }) + case "tool_result": + flushMessage(role, msgContent) + msgContent = make([]any, 0) + + toolUseID, _ := bm["tool_use_id"].(string) + toolUseID = strings.TrimSpace(toolUseID) + name := "" + if v, ok := bm["name"].(string); ok { + name = strings.TrimSpace(v) + } + if name == "" { + name = toolIDToName[toolUseID] + } + output := extractClaudeContentText(bm["content"]) + out = append(out, map[string]any{ + "type": "function_call_output", + "call_id": toolUseID, + "name": name, + "output": output, + }) + default: + if b, err := json.Marshal(bm); err == nil { + msgContent = append(msgContent, map[string]any{"type": "input_text", "text": string(b)}) + } + } + } + flushMessage(role, msgContent) + default: + } + } + return out, nil +} + +func convertOpenAIResponsesJSONToClaude(openaiResp []byte, originalModel string) (map[string]any, *ClaudeUsage, string, error) { + var resp map[string]any + if err := json.Unmarshal(openaiResp, &resp); err != nil { + return nil, nil, "", err + } + + usage := &ClaudeUsage{} + if u, ok := resp["usage"].(map[string]any); ok { + if in, ok := asInt(u["input_tokens"]); ok { + usage.InputTokens = in + } + if out, ok := asInt(u["output_tokens"]); ok { + usage.OutputTokens = out + } + if details, ok := u["input_tokens_details"].(map[string]any); ok { + if cached, ok := asInt(details["cached_tokens"]); ok { + usage.CacheReadInputTokens = cached + } + } + } + + content := make([]any, 0) + stopReason := "end_turn" + + if outputItems, ok := resp["output"].([]any); ok { + for _, item := range outputItems { + im, ok := item.(map[string]any) + if !ok { + continue + } + t, _ := im["type"].(string) + t = strings.TrimSpace(t) + switch t { + case "message": + if blocks, ok := im["content"].([]any); ok { + for _, b := range blocks { + bm, ok := b.(map[string]any) + if !ok { + continue + } + bt, _ := bm["type"].(string) + bt = strings.TrimSpace(bt) + switch bt { + case "output_text", "text": + if text, ok := bm["text"].(string); ok { + content = append(content, map[string]any{"type": "text", "text": text}) + } + } + } + } + case "function_call", "tool_call": + stopReason = "tool_use" + callID, _ := im["call_id"].(string) + if strings.TrimSpace(callID) == "" { + callID, _ = im["id"].(string) + } + name, _ := im["name"].(string) + argsAny := im["arguments"] + var input any + switch v := argsAny.(type) { + case string: + if strings.TrimSpace(v) != "" { + var parsed any + if json.Unmarshal([]byte(v), &parsed) == nil { + input = parsed + } else { + input = v + } + } + default: + input = v + } + content = append(content, map[string]any{ + "type": "tool_use", + "id": strings.TrimSpace(callID), + "name": strings.TrimSpace(name), + "input": input, + }) + } + } + } + + if len(content) == 0 { + content = append(content, map[string]any{"type": "text", "text": ""}) + } + + msgID, _ := resp["id"].(string) + msgID = strings.TrimSpace(msgID) + if msgID == "" { + msgID = "msg_" + randomHex(12) + } + + if stopReason != "tool_use" { + if inc, ok := resp["incomplete_details"].(map[string]any); ok { + if reason, _ := inc["reason"].(string); strings.EqualFold(strings.TrimSpace(reason), "max_output_tokens") { + stopReason = "max_tokens" + } + } + } + + claudeResp := map[string]any{ + "id": msgID, + "type": "message", + "role": "assistant", + "model": originalModel, + "content": content, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": usage.InputTokens, + "output_tokens": usage.OutputTokens, + "cache_creation_input_tokens": usage.CacheCreationInputTokens, + "cache_read_input_tokens": usage.CacheReadInputTokens, + }, + } + + return claudeResp, usage, stopReason, nil +} + +func writeClaudeStreamFromMessage(c *gin.Context, claudeResp map[string]any, usage *ClaudeUsage, model string, stopReason string) error { + if c == nil { + return errors.New("nil context") + } + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return errors.New("streaming not supported") + } + + id, _ := claudeResp["id"].(string) + if strings.TrimSpace(id) == "" { + id = "msg_" + randomHex(12) + } + + messageStart := map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": id, + "type": "message", + "role": "assistant", + "model": model, + "content": []any{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": usage.InputTokens, + "output_tokens": 0, + }, + }, + } + if err := writeClaudeSSEEvent(w, "message_start", messageStart); err != nil { + return err + } + flusher.Flush() + + blocksAny, _ := claudeResp["content"].([]any) + blockIndex := 0 + for _, block := range blocksAny { + bm, ok := block.(map[string]any) + if !ok { + continue + } + bt, _ := bm["type"].(string) + bt = strings.TrimSpace(bt) + if bt == "" { + continue + } + + start := map[string]any{ + "type": "content_block_start", + "index": blockIndex, + "content_block": bm, + } + if err := writeClaudeSSEEvent(w, "content_block_start", start); err != nil { + return err + } + + switch bt { + case "text": + text, _ := bm["text"].(string) + delta := map[string]any{ + "type": "content_block_delta", + "index": blockIndex, + "delta": map[string]any{"type": "text_delta", "text": text}, + } + if err := writeClaudeSSEEvent(w, "content_block_delta", delta); err != nil { + return err + } + case "thinking": + thinking, _ := bm["thinking"].(string) + if strings.TrimSpace(thinking) != "" { + delta := map[string]any{ + "type": "content_block_delta", + "index": blockIndex, + "delta": map[string]any{"type": "thinking_delta", "thinking": thinking}, + } + if err := writeClaudeSSEEvent(w, "content_block_delta", delta); err != nil { + return err + } + } + if sig, _ := bm["signature"].(string); strings.TrimSpace(sig) != "" { + delta := map[string]any{ + "type": "content_block_delta", + "index": blockIndex, + "delta": map[string]any{"type": "signature_delta", "signature": sig}, + } + if err := writeClaudeSSEEvent(w, "content_block_delta", delta); err != nil { + return err + } + } + case "tool_use": + inputJSON, _ := json.Marshal(bm["input"]) + delta := map[string]any{ + "type": "content_block_delta", + "index": blockIndex, + "delta": map[string]any{"type": "input_json_delta", "partial_json": string(inputJSON)}, + } + if err := writeClaudeSSEEvent(w, "content_block_delta", delta); err != nil { + return err + } + } + + stop := map[string]any{"type": "content_block_stop", "index": blockIndex} + if err := writeClaudeSSEEvent(w, "content_block_stop", stop); err != nil { + return err + } + flusher.Flush() + blockIndex++ + } + + messageDelta := map[string]any{ + "type": "message_delta", + "delta": map[string]any{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": map[string]any{ + "input_tokens": usage.InputTokens, + "output_tokens": usage.OutputTokens, + "cache_creation_input_tokens": usage.CacheCreationInputTokens, + "cache_read_input_tokens": usage.CacheReadInputTokens, + }, + } + if err := writeClaudeSSEEvent(w, "message_delta", messageDelta); err != nil { + return err + } + if err := writeClaudeSSEEvent(w, "message_stop", map[string]any{"type": "message_stop"}); err != nil { + return err + } + flusher.Flush() + return nil +} + +func writeClaudeSSEEvent(w io.Writer, event string, payload any) error { + b, err := json.Marshal(payload) + if err != nil { + return err + } + _, err = fmt.Fprintf(w, "event: %s\n", event) + if err != nil { + return err + } + _, err = fmt.Fprintf(w, "data: %s\n\n", string(b)) + return err +} diff --git a/backend/internal/service/openai_messages_compat_service_test.go b/backend/internal/service/openai_messages_compat_service_test.go new file mode 100644 index 0000000000..8c8034231b --- /dev/null +++ b/backend/internal/service/openai_messages_compat_service_test.go @@ -0,0 +1,451 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type testHTTPUpstream struct { + client *http.Client +} + +func (t *testHTTPUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + if t.client == nil { + t.client = http.DefaultClient + } + return t.client.Do(req) +} + +func (t *testHTTPUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + if t.client == nil { + t.client = http.DefaultClient + } + return t.client.Do(req) +} + +func TestConvertClaudeMessagesToOpenAIResponsesInput_ToolUseAndResult(t *testing.T) { + messages := []any{ + map[string]any{ + "role": "assistant", + "content": []any{ + map[string]any{ + "type": "tool_use", + "id": "toolu_1", + "name": "get_weather", + "input": map[string]any{"city": "SF"}, + }, + }, + }, + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "tool_result", + "tool_use_id": "toolu_1", + "content": []any{ + map[string]any{"type": "text", "text": "sunny"}, + }, + }, + }, + }, + } + + out, err := convertClaudeMessagesToOpenAIResponsesInput(messages) + require.NoError(t, err) + require.Len(t, out, 2) + + call, ok := out[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "function_call", call["type"]) + require.Equal(t, "toolu_1", call["call_id"]) + require.Equal(t, "get_weather", call["name"]) + + args, ok := call["arguments"].(string) + require.True(t, ok) + var parsedArgs map[string]any + require.NoError(t, json.Unmarshal([]byte(args), &parsedArgs)) + require.Equal(t, "SF", parsedArgs["city"]) + + callOut, ok := out[1].(map[string]any) + require.True(t, ok) + require.Equal(t, "function_call_output", callOut["type"]) + require.Equal(t, "toolu_1", callOut["call_id"]) + require.Equal(t, "get_weather", callOut["name"]) + require.Equal(t, "sunny", callOut["output"]) +} + +func TestConvertClaudeMessagesToOpenAIResponsesInput_ImageBlock(t *testing.T) { + messages := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "image", + "source": map[string]any{ + "type": "base64", + "media_type": "image/png", + "data": "aGVsbG8=", + }, + }, + }, + }, + } + + out, err := convertClaudeMessagesToOpenAIResponsesInput(messages) + require.NoError(t, err) + require.Len(t, out, 1) + + msg, ok := out[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "message", msg["type"]) + require.Equal(t, "user", msg["role"]) + + content, ok := msg["content"].([]any) + require.True(t, ok) + require.Len(t, content, 1) + block, ok := content[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "input_image", block["type"]) + + imageURL, ok := block["image_url"].(map[string]any) + require.True(t, ok) + require.Equal(t, "data:image/png;base64,aGVsbG8=", imageURL["url"]) +} + +func TestConvertOpenAIResponsesJSONToClaude_TextOnly(t *testing.T) { + resp := map[string]any{ + "id": "resp_123", + "output": []any{ + map[string]any{ + "type": "message", + "content": []any{ + map[string]any{"type": "output_text", "text": "hi"}, + }, + }, + }, + "usage": map[string]any{ + "input_tokens": 5, + "output_tokens": 7, + "input_tokens_details": map[string]any{ + "cached_tokens": 2, + }, + }, + } + b, _ := json.Marshal(resp) + + claudeResp, usage, stopReason, err := convertOpenAIResponsesJSONToClaude(b, "gpt-5.2") + require.NoError(t, err) + require.Equal(t, "end_turn", stopReason) + require.NotNil(t, usage) + require.Equal(t, 5, usage.InputTokens) + require.Equal(t, 7, usage.OutputTokens) + require.Equal(t, 2, usage.CacheReadInputTokens) + + require.Equal(t, "message", claudeResp["type"]) + require.Equal(t, "assistant", claudeResp["role"]) + require.Equal(t, "gpt-5.2", claudeResp["model"]) + require.Equal(t, "end_turn", claudeResp["stop_reason"]) + + content, ok := claudeResp["content"].([]any) + require.True(t, ok) + require.Len(t, content, 1) + block, ok := content[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "text", block["type"]) + require.Equal(t, "hi", block["text"]) +} + +func TestConvertOpenAIResponsesJSONToClaude_ToolCall(t *testing.T) { + resp := map[string]any{ + "id": "resp_456", + "output": []any{ + map[string]any{ + "type": "function_call", + "call_id": "call_1", + "name": "get_weather", + "arguments": "{\"city\":\"SF\"}", + }, + }, + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 2, + }, + } + b, _ := json.Marshal(resp) + + claudeResp, _, stopReason, err := convertOpenAIResponsesJSONToClaude(b, "gpt-5.2") + require.NoError(t, err) + require.Equal(t, "tool_use", stopReason) + require.Equal(t, "tool_use", claudeResp["stop_reason"]) + + content, ok := claudeResp["content"].([]any) + require.True(t, ok) + require.Len(t, content, 1) + block, ok := content[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "tool_use", block["type"]) + require.Equal(t, "call_1", block["id"]) + require.Equal(t, "get_weather", block["name"]) + + input, ok := block["input"].(map[string]any) + require.True(t, ok) + require.Equal(t, "SF", input["city"]) +} + +func TestExtractOpenAIResponsesInputTokens_OK(t *testing.T) { + resp := map[string]any{ + "usage": map[string]any{ + "input_tokens": 42, + }, + } + b, _ := json.Marshal(resp) + + inputTokens, err := extractOpenAIResponsesInputTokens(b) + require.NoError(t, err) + require.Equal(t, 42, inputTokens) +} + +func TestExtractOpenAIResponsesInputTokens_MissingUsage(t *testing.T) { + resp := map[string]any{} + b, _ := json.Marshal(resp) + + _, err := extractOpenAIResponsesInputTokens(b) + require.Error(t, err) +} + +func TestExtractOpenAIResponsesInputTokens_MissingInputTokens(t *testing.T) { + resp := map[string]any{ + "usage": map[string]any{}, + } + b, _ := json.Marshal(resp) + + _, err := extractOpenAIResponsesInputTokens(b) + require.Error(t, err) +} + +func TestOpenAIMessagesCompatService_ForwardCountTokens_JSONUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + + type capturedReq struct { + Path string + Headers http.Header + Body []byte + JSON map[string]any + } + capCh := make(chan capturedReq, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var m map[string]any + _ = json.Unmarshal(body, &m) + capCh <- capturedReq{Path: r.URL.Path, Headers: r.Header.Clone(), Body: body, JSON: m} + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "usage": map[string]any{"input_tokens": 123}, + }) + })) + defer server.Close() + + cfg := &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false, AllowInsecureHTTP: true}}} + upstream := &testHTTPUpstream{} + openaiSvc := NewOpenAIGatewayService(nil, nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, upstream, nil, nil, nil) + compat := NewOpenAIMessagesCompatService(openaiSvc) + + account := &Account{ + ID: 1, + Name: "openai-test", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": server.URL, + "user_agent": "sub2api-test", + }, + Concurrency: 1, + } + + claudeReq := map[string]any{ + "model": "gpt-5.2", + "max_tokens": 10, + "system": "You are helpful.", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + body, err := json.Marshal(claudeReq) + require.NoError(t, err) + parsed, err := ParseGatewayRequest(body, "") + require.NoError(t, err) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", bytes.NewReader(body)) + + err = compat.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + require.Equal(t, http.StatusOK, w.Code) + require.JSONEq(t, `{"input_tokens":123}`, w.Body.String()) + + select { + case cap := <-capCh: + require.Equal(t, "/v1/responses", cap.Path) + require.Equal(t, "Bearer sk-test", cap.Headers.Get("authorization")) + require.Equal(t, false, cap.JSON["store"]) + require.Equal(t, false, cap.JSON["stream"]) + require.Equal(t, "You are helpful.", cap.JSON["instructions"]) + require.Equal(t, float64(1), cap.JSON["max_output_tokens"]) + require.NotNil(t, cap.JSON["input"]) + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for upstream request") + } +} + +func TestOpenAIMessagesCompatService_ForwardCountTokens_SSEUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + + capCh := make(chan struct{}, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capCh <- struct{}{} + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":77}}}\n")) + _, _ = w.Write([]byte("data: [DONE]\n")) + })) + defer server.Close() + + cfg := &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false, AllowInsecureHTTP: true}}} + upstream := &testHTTPUpstream{} + openaiSvc := NewOpenAIGatewayService(nil, nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, upstream, nil, nil, nil) + compat := NewOpenAIMessagesCompatService(openaiSvc) + + account := &Account{ + ID: 1, + Name: "openai-test", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": server.URL, + }, + Concurrency: 1, + } + + claudeReq := map[string]any{ + "model": "gpt-5.2", + "max_tokens": 10, + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + body, err := json.Marshal(claudeReq) + require.NoError(t, err) + parsed, err := ParseGatewayRequest(body, "") + require.NoError(t, err) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", bytes.NewReader(body)) + + err = compat.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + require.Equal(t, http.StatusOK, w.Code) + require.JSONEq(t, `{"input_tokens":77}`, w.Body.String()) + + select { + case <-capCh: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for upstream request") + } +} + +func TestOpenAIMessagesCompatService_Forward_ForwardsMaxTokens(t *testing.T) { + gin.SetMode(gin.TestMode) + + type capturedReq struct { + Path string + Headers http.Header + JSON map[string]any + } + capCh := make(chan capturedReq, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var m map[string]any + _ = json.Unmarshal(body, &m) + capCh <- capturedReq{Path: r.URL.Path, Headers: r.Header.Clone(), JSON: m} + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "resp_1", + "output": []any{ + map[string]any{ + "type": "message", + "content": []any{ + map[string]any{"type": "output_text", "text": "hi"}, + }, + }, + }, + "usage": map[string]any{"input_tokens": 5, "output_tokens": 7}, + }) + })) + defer server.Close() + + cfg := &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false, AllowInsecureHTTP: true}}} + upstream := &testHTTPUpstream{} + openaiSvc := NewOpenAIGatewayService(nil, nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, upstream, nil, nil, nil) + compat := NewOpenAIMessagesCompatService(openaiSvc) + + account := &Account{ + ID: 1, + Name: "openai-test", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": server.URL, + }, + Concurrency: 1, + } + + claudeReq := map[string]any{ + "model": "gpt-5.2", + "max_tokens": 99, + "system": "You are helpful.", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + body, err := json.Marshal(claudeReq) + require.NoError(t, err) + parsed, err := ParseGatewayRequest(body, "") + require.NoError(t, err) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + + result, err := compat.Forward(context.Background(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, http.StatusOK, w.Code) + + select { + case cap := <-capCh: + require.Equal(t, "/v1/responses", cap.Path) + require.Equal(t, "Bearer sk-test", cap.Headers.Get("authorization")) + require.Equal(t, float64(99), cap.JSON["max_output_tokens"]) + require.Equal(t, "You are helpful.", cap.JSON["instructions"]) + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for upstream request") + } +} diff --git a/backend/internal/service/openai_responses_claude_compat.go b/backend/internal/service/openai_responses_claude_compat.go new file mode 100644 index 0000000000..8715ee60a7 --- /dev/null +++ b/backend/internal/service/openai_responses_claude_compat.go @@ -0,0 +1,438 @@ +package service + +import ( + "encoding/json" + "errors" + "strings" + "time" +) + +func ConvertOpenAIResponsesRequestToClaudeMessages(req map[string]any) (map[string]any, error) { + if req == nil { + return nil, errors.New("empty request") + } + + model, _ := req["model"].(string) + model = strings.TrimSpace(model) + if model == "" { + return nil, errors.New("model is required") + } + + systemParts := make([]string, 0, 2) + if instr, ok := req["instructions"].(string); ok { + if s := strings.TrimSpace(instr); s != "" { + systemParts = append(systemParts, s) + } + } + + msgs, sysFromInput, err := convertOpenAIResponsesInputToClaudeMessages(req["input"]) + if err != nil { + return nil, err + } + if strings.TrimSpace(sysFromInput) != "" { + systemParts = append(systemParts, sysFromInput) + } + + maxTokens := 0 + if v, ok := parseIntegralNumber(req["max_output_tokens"]); ok { + maxTokens = v + } else if v, ok := parseIntegralNumber(req["max_tokens"]); ok { + maxTokens = v + } + if maxTokens <= 0 { + maxTokens = 1024 + } + + claudeReq := map[string]any{ + "model": model, + "max_tokens": maxTokens, + "stream": false, + "messages": msgs, + } + if len(systemParts) > 0 { + claudeReq["system"] = strings.Join(systemParts, "\n") + } + + if tools := convertOpenAIResponsesToolsToClaudeTools(req["tools"]); len(tools) > 0 { + claudeReq["tools"] = tools + } + + if v, ok := req["temperature"].(float64); ok { + claudeReq["temperature"] = v + } + if v, ok := req["top_p"].(float64); ok { + claudeReq["top_p"] = v + } + if stop, ok := req["stop"].([]any); ok && len(stop) > 0 { + claudeReq["stop_sequences"] = stop + } else if stopStr, ok := req["stop"].(string); ok && strings.TrimSpace(stopStr) != "" { + claudeReq["stop_sequences"] = []any{stopStr} + } + + return claudeReq, nil +} + +func ConvertClaudeMessageToOpenAIResponsesResponse(claudeResp map[string]any, usage *ClaudeUsage, requestedModel, responseID string) (map[string]any, error) { + if claudeResp == nil { + return nil, errors.New("empty claude response") + } + if strings.TrimSpace(responseID) == "" { + responseID = "resp_" + randomHex(12) + } + createdAt := time.Now().Unix() + + if usage == nil { + usage = extractClaudeUsageFromResponse(claudeResp) + if usage == nil { + usage = &ClaudeUsage{} + } + } + + items, err := convertClaudeResponseContentToOpenAIOutputItems(claudeResp["content"]) + if err != nil { + return nil, err + } + + openaiUsage := map[string]any{ + "input_tokens": usage.InputTokens, + "output_tokens": usage.OutputTokens, + "total_tokens": usage.InputTokens + usage.OutputTokens, + } + if usage.CacheReadInputTokens > 0 { + openaiUsage["input_tokens_details"] = map[string]any{ + "cached_tokens": usage.CacheReadInputTokens, + } + } + + resp := map[string]any{ + "id": responseID, + "object": "response", + "created_at": createdAt, + "model": strings.TrimSpace(requestedModel), + "status": "completed", + "output": items, + "usage": openaiUsage, + } + + return resp, nil +} + +func extractClaudeUsageFromResponse(claudeResp map[string]any) *ClaudeUsage { + if claudeResp == nil { + return nil + } + u, ok := claudeResp["usage"].(map[string]any) + if !ok || u == nil { + return nil + } + in, _ := asInt(u["input_tokens"]) + out, _ := asInt(u["output_tokens"]) + cr, _ := asInt(u["cache_read_input_tokens"]) + cc, _ := asInt(u["cache_creation_input_tokens"]) + return &ClaudeUsage{ + InputTokens: in, + OutputTokens: out, + CacheReadInputTokens: cr, + CacheCreationInputTokens: cc, + } +} + +func convertClaudeResponseContentToOpenAIOutputItems(content any) ([]any, error) { + blocks, ok := content.([]any) + if !ok { + if s, ok := content.(string); ok { + s = strings.TrimSpace(s) + if s == "" { + return []any{}, nil + } + return []any{map[string]any{ + "type": "message", + "role": "assistant", + "content": []any{ + map[string]any{"type": "output_text", "text": s}, + }, + }}, nil + } + return []any{}, nil + } + + messageContent := make([]any, 0) + items := make([]any, 0) + + for _, b := range blocks { + bm, ok := b.(map[string]any) + if !ok { + continue + } + bt, _ := bm["type"].(string) + bt = strings.ToLower(strings.TrimSpace(bt)) + switch bt { + case "text": + if text, ok := bm["text"].(string); ok && strings.TrimSpace(text) != "" { + messageContent = append(messageContent, map[string]any{"type": "output_text", "text": text}) + } + case "tool_use": + callID, _ := bm["id"].(string) + name, _ := bm["name"].(string) + callID = strings.TrimSpace(callID) + name = strings.TrimSpace(name) + args := bm["input"] + argsJSON, _ := json.Marshal(args) + if callID == "" { + callID = "call_" + randomHex(12) + } + items = append(items, map[string]any{ + "type": "function_call", + "id": callID, + "call_id": callID, + "name": name, + "arguments": string(argsJSON), + }) + default: + } + } + + if len(messageContent) > 0 { + items = append([]any{map[string]any{ + "type": "message", + "role": "assistant", + "content": messageContent, + }}, items...) + } + + return items, nil +} + +func convertOpenAIResponsesToolsToClaudeTools(tools any) []any { + arr, ok := tools.([]any) + if !ok || len(arr) == 0 { + return nil + } + out := make([]any, 0, len(arr)) + for _, t := range arr { + tm, ok := t.(map[string]any) + if !ok { + continue + } + toolType, _ := tm["type"].(string) + toolType = strings.TrimSpace(toolType) + if toolType != "function" { + continue + } + + name, _ := tm["name"].(string) + desc, _ := tm["description"].(string) + params := tm["parameters"] + if fn, ok := tm["function"].(map[string]any); ok && fn != nil { + if v, ok := fn["name"].(string); ok { + name = v + } + if v, ok := fn["description"].(string); ok { + desc = v + } + if v := fn["parameters"]; v != nil { + params = v + } + } + + name = strings.TrimSpace(name) + if name == "" { + continue + } + if params == nil { + params = map[string]any{"type": "object", "properties": map[string]any{}} + } + out = append(out, map[string]any{ + "name": name, + "description": strings.TrimSpace(desc), + "input_schema": params, + }) + } + return out +} + +func convertOpenAIResponsesInputToClaudeMessages(input any) ([]any, string, error) { + systemParts := make([]string, 0) + messages := make([]any, 0) + + appendMessage := func(role string, content any) { + role = strings.ToLower(strings.TrimSpace(role)) + if role == "" { + role = "user" + } + messages = append(messages, map[string]any{"role": role, "content": content}) + } + + appendTextMessage := func(role, text string) { + text = strings.TrimSpace(text) + if text == "" { + return + } + appendMessage(role, text) + } + + convertMessageContent := func(content any) any { + switch v := content.(type) { + case string: + return v + case []any: + blocks := make([]any, 0, len(v)) + for _, it := range v { + im, ok := it.(map[string]any) + if !ok { + continue + } + itType, _ := im["type"].(string) + itType = strings.ToLower(strings.TrimSpace(itType)) + switch itType { + case "input_text", "output_text", "text": + if text, ok := im["text"].(string); ok { + text = strings.TrimSpace(text) + if text != "" { + blocks = append(blocks, map[string]any{"type": "text", "text": text}) + } + } + case "input_image": + if iu, ok := im["image_url"].(map[string]any); ok { + if urlStr, ok := iu["url"].(string); ok { + if mediaType, data, ok := parseDataURL(urlStr); ok { + blocks = append(blocks, map[string]any{ + "type": "image", + "source": map[string]any{ + "type": "base64", + "media_type": mediaType, + "data": data, + }, + }) + } else { + blocks = append(blocks, map[string]any{"type": "text", "text": urlStr}) + } + } + } + } + } + if len(blocks) == 0 { + return "" + } + return blocks + default: + return "" + } + } + + switch v := input.(type) { + case string: + appendTextMessage("user", v) + case []any: + for _, item := range v { + im, ok := item.(map[string]any) + if !ok { + continue + } + itType, _ := im["type"].(string) + itType = strings.ToLower(strings.TrimSpace(itType)) + switch itType { + case "message": + role, _ := im["role"].(string) + role = strings.ToLower(strings.TrimSpace(role)) + contentAny := convertMessageContent(im["content"]) + if role == "system" { + switch c := contentAny.(type) { + case string: + if s := strings.TrimSpace(c); s != "" { + systemParts = append(systemParts, s) + } + case []any: + if s := extractClaudeContentText(c); strings.TrimSpace(s) != "" { + systemParts = append(systemParts, strings.TrimSpace(s)) + } + } + continue + } + if role == "" { + role = "user" + } + if contentAny == "" { + continue + } + appendMessage(role, contentAny) + case "function_call", "tool_call": + callID, _ := im["call_id"].(string) + if strings.TrimSpace(callID) == "" { + callID, _ = im["id"].(string) + } + callID = strings.TrimSpace(callID) + name, _ := im["name"].(string) + name = strings.TrimSpace(name) + argsAny := im["arguments"] + var toolInput any + switch a := argsAny.(type) { + case string: + if strings.TrimSpace(a) != "" { + var parsed any + if json.Unmarshal([]byte(a), &parsed) == nil { + toolInput = parsed + } else { + toolInput = a + } + } + default: + toolInput = a + } + if callID == "" { + callID = "call_" + randomHex(12) + } + appendMessage("assistant", []any{map[string]any{ + "type": "tool_use", + "id": callID, + "name": name, + "input": toolInput, + }}) + case "function_call_output": + callID, _ := im["call_id"].(string) + callID = strings.TrimSpace(callID) + output, _ := im["output"].(string) + output = strings.TrimSpace(output) + if callID == "" { + continue + } + appendMessage("user", []any{map[string]any{ + "type": "tool_result", + "tool_use_id": callID, + "content": output, + }}) + case "input_text", "text": + if text, ok := im["text"].(string); ok { + appendTextMessage("user", text) + } + default: + } + } + default: + } + + return messages, strings.Join(systemParts, "\n"), nil +} + +func parseDataURL(urlStr string) (mediaType string, data string, ok bool) { + urlStr = strings.TrimSpace(urlStr) + if !strings.HasPrefix(urlStr, "data:") { + return "", "", false + } + comma := strings.Index(urlStr, ",") + if comma < 0 { + return "", "", false + } + header := urlStr[:comma] + payload := urlStr[comma+1:] + if !strings.Contains(header, ";base64") { + return "", "", false + } + mt := strings.TrimPrefix(header, "data:") + mt = strings.TrimSuffix(mt, ";base64") + mt = strings.TrimSpace(mt) + if mt == "" || strings.TrimSpace(payload) == "" { + return "", "", false + } + return mt, strings.TrimSpace(payload), true +} diff --git a/backend/internal/service/openai_responses_url.go b/backend/internal/service/openai_responses_url.go new file mode 100644 index 0000000000..b5b130435c --- /dev/null +++ b/backend/internal/service/openai_responses_url.go @@ -0,0 +1,24 @@ +package service + +import "strings" + +func openaiResponsesURLFromBaseURL(normalizedBaseURL string, isGitHubCopilot bool) string { + base := strings.TrimRight(strings.TrimSpace(normalizedBaseURL), "/") + if strings.HasSuffix(base, "/responses") { + // GitHub Copilot expects /responses without /v1. + if isGitHubCopilot && strings.HasSuffix(base, "/v1/responses") { + base = strings.TrimSuffix(base, "/v1/responses") + base = strings.TrimRight(base, "/") + return base + "/responses" + } + return base + } + if isGitHubCopilot { + base = strings.TrimSuffix(base, "/v1") + return base + "/responses" + } + if strings.HasSuffix(base, "/v1") { + return base + "/responses" + } + return base + "/v1/responses" +} diff --git a/backend/internal/service/openai_responses_url_test.go b/backend/internal/service/openai_responses_url_test.go new file mode 100644 index 0000000000..f79e7d5780 --- /dev/null +++ b/backend/internal/service/openai_responses_url_test.go @@ -0,0 +1,109 @@ +//go:build unit + +package service + +import "testing" + +func TestOpenAIResponsesURLFromBaseURL(t *testing.T) { + tests := []struct { + name string + baseURL string + isCopilot bool + expectedURL string + }{ + { + name: "openai root", + baseURL: "https://api.openai.com", + isCopilot: false, + expectedURL: "https://api.openai.com/v1/responses", + }, + { + name: "openai root trailing slash", + baseURL: "https://api.openai.com/", + isCopilot: false, + expectedURL: "https://api.openai.com/v1/responses", + }, + { + name: "openai v1", + baseURL: "https://api.openai.com/v1", + isCopilot: false, + expectedURL: "https://api.openai.com/v1/responses", + }, + { + name: "openai v1 trailing slash", + baseURL: "https://api.openai.com/v1/", + isCopilot: false, + expectedURL: "https://api.openai.com/v1/responses", + }, + { + name: "openai responses endpoint", + baseURL: "https://api.openai.com/v1/responses", + isCopilot: false, + expectedURL: "https://api.openai.com/v1/responses", + }, + { + name: "openai path prefix", + baseURL: "https://proxy.example.com/openai", + isCopilot: false, + expectedURL: "https://proxy.example.com/openai/v1/responses", + }, + { + name: "openai path prefix v1", + baseURL: "https://proxy.example.com/openai/v1", + isCopilot: false, + expectedURL: "https://proxy.example.com/openai/v1/responses", + }, + + { + name: "github copilot root", + baseURL: "https://api.githubcopilot.com", + isCopilot: true, + expectedURL: "https://api.githubcopilot.com/responses", + }, + { + name: "github copilot root trailing slash", + baseURL: "https://api.githubcopilot.com/", + isCopilot: true, + expectedURL: "https://api.githubcopilot.com/responses", + }, + { + name: "github copilot v1", + baseURL: "https://api.githubcopilot.com/v1", + isCopilot: true, + expectedURL: "https://api.githubcopilot.com/responses", + }, + { + name: "github copilot v1 trailing slash", + baseURL: "https://api.githubcopilot.com/v1/", + isCopilot: true, + expectedURL: "https://api.githubcopilot.com/responses", + }, + { + name: "github copilot responses endpoint", + baseURL: "https://api.githubcopilot.com/responses", + isCopilot: true, + expectedURL: "https://api.githubcopilot.com/responses", + }, + { + name: "github copilot v1 responses endpoint", + baseURL: "https://api.githubcopilot.com/v1/responses", + isCopilot: true, + expectedURL: "https://api.githubcopilot.com/responses", + }, + { + name: "github copilot enterprise subdomain", + baseURL: "https://api.business.githubcopilot.com", + isCopilot: true, + expectedURL: "https://api.business.githubcopilot.com/responses", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := openaiResponsesURLFromBaseURL(tt.baseURL, tt.isCopilot) + if got != tt.expectedURL { + t.Fatalf("openaiResponsesURLFromBaseURL(%q, %v) = %q, want %q", tt.baseURL, tt.isCopilot, got, tt.expectedURL) + } + }) + } +} diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index d8db0d67eb..cdd3fe5cfb 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -25,7 +25,6 @@ var ( ) // LiteLLMModelPricing LiteLLM价格数据结构 -// 只保留我们需要的字段,使用指针来处理可能缺失的值 type LiteLLMModelPricing struct { InputCostPerToken float64 `json:"input_cost_per_token"` OutputCostPerToken float64 `json:"output_cost_per_token"` @@ -34,7 +33,23 @@ type LiteLLMModelPricing struct { LiteLLMProvider string `json:"litellm_provider"` Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` - OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 + OutputCostPerImage float64 `json:"output_cost_per_image"` + MaxInputTokens int `json:"max_input_tokens"` + MaxOutputTokens int `json:"max_output_tokens"` + Source string `json:"source,omitempty"` +} + +type ModelInfo struct { + ID string `json:"id"` + Object string `json:"object"` + Type string `json:"type"` + DisplayName string `json:"display_name,omitempty"` + ContextWindow int `json:"context_window,omitempty"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` + InputPricePer1M float64 `json:"input_price_per_1m,omitempty"` + OutputPricePer1M float64 `json:"output_price_per_1m,omitempty"` + LiteLLMProvider string `json:"litellm_provider,omitempty"` + Source string `json:"source,omitempty"` } // PricingRemoteClient 远程价格数据获取接口 @@ -53,6 +68,10 @@ type LiteLLMRawEntry struct { Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` OutputCostPerImage *float64 `json:"output_cost_per_image"` + MaxInputTokens *int `json:"max_input_tokens"` + MaxOutputTokens *int `json:"max_output_tokens"` + MaxTokens *int `json:"max_tokens"` + Source string `json:"source"` } // PricingService 动态价格服务 @@ -307,6 +326,7 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel LiteLLMProvider: entry.LiteLLMProvider, Mode: entry.Mode, SupportsPromptCaching: entry.SupportsPromptCaching, + Source: strings.TrimSpace(entry.Source), } if entry.InputCostPerToken != nil { @@ -324,6 +344,14 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel if entry.OutputCostPerImage != nil { pricing.OutputCostPerImage = *entry.OutputCostPerImage } + if entry.MaxInputTokens != nil { + pricing.MaxInputTokens = *entry.MaxInputTokens + } + if entry.MaxOutputTokens != nil { + pricing.MaxOutputTokens = *entry.MaxOutputTokens + } else if entry.MaxTokens != nil { + pricing.MaxOutputTokens = *entry.MaxTokens + } result[modelName] = pricing } @@ -752,3 +780,99 @@ func isNumeric(s string) bool { } return true } + +// GetModelPricingWithProvider 获取带 provider 前缀的模型定价 +// 优先级: provider/model > model (模糊匹配) +func (s *PricingService) GetModelPricingWithProvider(provider, model string) *LiteLLMModelPricing { + // 1. 尝试 provider/model 格式 + if provider != "" { + key := provider + "/" + model + if pricing := s.GetModelPricing(key); pricing != nil { + return pricing + } + } + // 2. 回退到纯模型名查询 + return s.GetModelPricing(model) +} + +// GetModelInfo 获取模型完整信息(用于 Models API 响应) +func (s *PricingService) GetModelInfo(provider, model string) *ModelInfo { + pricing := s.GetModelPricingWithProvider(provider, model) + if pricing == nil { + return nil + } + + id := model + if provider != "" { + id = provider + "/" + model + } + + return &ModelInfo{ + ID: id, + Object: "model", + Type: "model", + DisplayName: model, + ContextWindow: pricing.MaxInputTokens, + MaxOutputTokens: pricing.MaxOutputTokens, + InputPricePer1M: pricing.InputCostPerToken * 1_000_000, + OutputPricePer1M: pricing.OutputCostPerToken * 1_000_000, + LiteLLMProvider: pricing.LiteLLMProvider, + Source: pricing.Source, + } +} + +// ListAllModelsWithProvider 列出所有模型,带 provider 前缀 +// 返回 provider 去重后的模型列表 +func (s *PricingService) ListAllModelsWithProvider() []ModelInfo { + s.mu.RLock() + defer s.mu.RUnlock() + + seen := make(map[string]bool) + var result []ModelInfo + + for key, pricing := range s.pricingData { + var provider, model string + idx := strings.Index(key, "/") + if idx > 0 && idx < len(key)-1 { + provider = key[:idx] + model = key[idx+1:] + } else { + model = key + provider = pricing.LiteLLMProvider + } + + id := key + if !strings.Contains(key, "/") && provider != "" { + id = provider + "/" + key + } + + if seen[id] { + continue + } + seen[id] = true + + result = append(result, ModelInfo{ + ID: id, + Object: "model", + Type: "model", + DisplayName: model, + ContextWindow: pricing.MaxInputTokens, + MaxOutputTokens: pricing.MaxOutputTokens, + InputPricePer1M: pricing.InputCostPerToken * 1_000_000, + OutputPricePer1M: pricing.OutputCostPerToken * 1_000_000, + LiteLLMProvider: pricing.LiteLLMProvider, + Source: pricing.Source, + }) + } + + return result +} + +// GetContextWindow 获取模型的上下文窗口大小 +func (s *PricingService) GetContextWindow(provider, model string) int { + pricing := s.GetModelPricingWithProvider(provider, model) + if pricing == nil { + return 0 + } + return pricing.MaxInputTokens +} diff --git a/backend/internal/service/pricing_service_source_test.go b/backend/internal/service/pricing_service_source_test.go new file mode 100644 index 0000000000..501cbf48bc --- /dev/null +++ b/backend/internal/service/pricing_service_source_test.go @@ -0,0 +1,30 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPricingService_ParsePricingData_PreservesSource(t *testing.T) { + s := &PricingService{} + body := []byte(`{ + "gpt-5.2": { + "input_cost_per_token": 0.000001, + "output_cost_per_token": 0.000002, + "litellm_provider": "openai", + "mode": "chat", + "supports_prompt_caching": true, + "max_input_tokens": 10, + "max_output_tokens": 20, + "source": "https://example.com/pricing" + } + }`) + + data, err := s.parsePricingData(body) + require.NoError(t, err) + require.NotNil(t, data["gpt-5.2"]) + require.Equal(t, "https://example.com/pricing", data["gpt-5.2"].Source) +} diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 52d455b81e..baa50f69b3 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -386,7 +386,7 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI if len(groupIDs) == 0 { return nil } - platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity} + platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformCopilot, PlatformAggregator, PlatformAntigravity} var firstErr error for _, platform := range platforms { if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason); err != nil && firstErr == nil { @@ -669,7 +669,7 @@ func (s *SchedulerSnapshotService) fullRebuildInterval() time.Duration { func (s *SchedulerSnapshotService) defaultBuckets(ctx context.Context) ([]SchedulerBucket, error) { buckets := make([]SchedulerBucket, 0) - platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity} + platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformCopilot, PlatformAggregator, PlatformAntigravity} for _, platform := range platforms { buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeSingle}) buckets = append(buckets, SchedulerBucket{GroupID: 0, Platform: platform, Mode: SchedulerModeForced}) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f5ba9d7103..9668d1e0ae 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -664,6 +664,9 @@ func (s *SettingService) GetFallbackModel(ctx context.Context, platform string) case PlatformOpenAI: key = SettingKeyFallbackModelOpenAI defaultModel = "gpt-4o" + case PlatformCopilot, PlatformAggregator: + key = SettingKeyFallbackModelOpenAI + defaultModel = "gpt-4o" case PlatformGemini: key = SettingKeyFallbackModelGemini defaultModel = "gemini-2.5-pro" diff --git a/backend/internal/service/token_cache_key.go b/backend/internal/service/token_cache_key.go index df0c025ee5..ec97b7d5e1 100644 --- a/backend/internal/service/token_cache_key.go +++ b/backend/internal/service/token_cache_key.go @@ -13,3 +13,7 @@ func OpenAITokenCacheKey(account *Account) string { func ClaudeTokenCacheKey(account *Account) string { return "claude:account:" + strconv.FormatInt(account.ID, 10) } + +func GitHubCopilotTokenCacheKey(account *Account) string { + return "copilot:account:" + strconv.FormatInt(account.ID, 10) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 87ca78974d..2b770d87f6 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -52,6 +52,12 @@ func ProvideTokenRefreshService( return svc } +func ProvideCopilotModelRefreshService(accountRepo AccountRepository, githubCopilotToken *GitHubCopilotTokenProvider, cfg *config.Config) *CopilotModelRefreshService { + svc := NewCopilotModelRefreshService(accountRepo, githubCopilotToken, cfg) + svc.Start() + return svc +} + // ProvideDashboardAggregationService 创建并启动仪表盘聚合服务 func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService { svc := NewDashboardAggregationService(repo, timingWheel, cfg) @@ -240,7 +246,10 @@ var ProviderSet = wire.NewSet( NewAntigravityOAuthService, NewGeminiTokenProvider, NewGeminiMessagesCompatService, + NewOpenAIMessagesCompatService, NewAntigravityTokenProvider, + NewGitHubCopilotTokenProvider, + NewGitHubDeviceAuthService, NewOpenAITokenProvider, NewClaudeTokenProvider, NewAntigravityGatewayService, @@ -264,6 +273,7 @@ var ProviderSet = wire.NewSet( NewCRSSyncService, ProvideUpdateService, ProvideTokenRefreshService, + ProvideCopilotModelRefreshService, ProvideAccountExpiryService, ProvideSubscriptionExpiryService, ProvideTimingWheelService, diff --git a/backend/internal/web/dist/.keep b/backend/internal/web/dist/.keep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index d9f5f2ab5b..b345bc3f8d 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -88,6 +88,8 @@ security: upstream_hosts: - "api.openai.com" - "api.anthropic.com" + - "*.githubcopilot.com" + - "api.github.com" - "api.kimi.com" - "open.bigmodel.cn" - "api.minimaxi.com" diff --git a/frontend/src/__tests__/integration/data-import.spec.ts b/frontend/src/__tests__/integration/data-import.spec.ts index 1fe870abf6..de552eab4d 100644 --- a/frontend/src/__tests__/integration/data-import.spec.ts +++ b/frontend/src/__tests__/integration/data-import.spec.ts @@ -1,5 +1,5 @@ import { describe, it, expect, vi, beforeEach } from 'vitest' -import { mount } from '@vue/test-utils' +import { mount, flushPromises } from '@vue/test-utils' import ImportDataModal from '@/components/admin/account/ImportDataModal.vue' const showError = vi.fn() @@ -64,6 +64,9 @@ describe('ImportDataModal', () => { await input.trigger('change') await wrapper.find('form').trigger('submit') + await flushPromises() + await new Promise((resolve) => setTimeout(resolve, 0)) + await flushPromises() expect(showError).toHaveBeenCalledWith('admin.accounts.dataImportParseFailed') }) diff --git a/frontend/src/__tests__/integration/proxy-data-import.spec.ts b/frontend/src/__tests__/integration/proxy-data-import.spec.ts index f0433898d3..dedce126c3 100644 --- a/frontend/src/__tests__/integration/proxy-data-import.spec.ts +++ b/frontend/src/__tests__/integration/proxy-data-import.spec.ts @@ -1,5 +1,5 @@ import { describe, it, expect, vi, beforeEach } from 'vitest' -import { mount } from '@vue/test-utils' +import { mount, flushPromises } from '@vue/test-utils' import ImportDataModal from '@/components/admin/proxy/ImportDataModal.vue' const showError = vi.fn() @@ -64,6 +64,9 @@ describe('Proxy ImportDataModal', () => { await input.trigger('change') await wrapper.find('form').trigger('submit') + await flushPromises() + await new Promise((resolve) => setTimeout(resolve, 0)) + await flushPromises() expect(showError).toHaveBeenCalledWith('admin.proxies.dataImportParseFailed') }) diff --git a/frontend/src/__tests__/setup.ts b/frontend/src/__tests__/setup.ts index decb2a370a..7ff9fc0918 100644 --- a/frontend/src/__tests__/setup.ts +++ b/frontend/src/__tests__/setup.ts @@ -5,6 +5,26 @@ import { config } from '@vue/test-utils' import { vi } from 'vitest' +// jsdom File/Blob polyfills +// Some jsdom versions do not implement File.prototype.text() even though browsers do. +if (typeof globalThis.File !== 'undefined' && typeof globalThis.File.prototype.text !== 'function') { + globalThis.File.prototype.text = async function text(): Promise { + const blob = this as unknown as Blob + + if (typeof blob.arrayBuffer === 'function') { + const buf = await blob.arrayBuffer() + return new TextDecoder().decode(buf) + } + + return await new Promise((resolve, reject) => { + const reader = new FileReader() + reader.onload = () => resolve(String(reader.result ?? '')) + reader.onerror = () => reject(reader.error) + reader.readAsText(blob) + }) + } +} + // Mock requestIdleCallback (Safari < 15 不支持) if (typeof globalThis.requestIdleCallback === 'undefined') { globalThis.requestIdleCallback = ((callback: IdleRequestCallback) => { diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 4cb1a6f214..6dd0861d3f 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -18,6 +18,22 @@ import type { AdminDataImportResult } from '@/types' +export interface GitHubDeviceAuthStartResult { + session_id: string + user_code: string + verification_uri: string + verification_uri_complete?: string + expires_in: number + interval: number +} + +export interface GitHubDeviceAuthPollResult { + status: 'pending' | 'error' + interval?: number + error?: string + error_description?: string +} + /** * List all accounts with pagination * @param page - Page number (default: 1) @@ -118,6 +134,42 @@ export async function testAccount(id: number): Promise<{ return data } +export async function startGitHubDeviceAuth( + id: number, + params?: { + client_id?: string + scope?: string + } +): Promise { + const { data } = await apiClient.post( + `/admin/accounts/${id}/github/device/start`, + params || {} + ) + return data +} + +export async function pollGitHubDeviceAuth( + id: number, + sessionId: string +): Promise { + const { data } = await apiClient.post( + `/admin/accounts/${id}/github/device/poll`, + { session_id: sessionId } + ) + return data +} + +export async function cancelGitHubDeviceAuth( + id: number, + sessionId: string +): Promise<{ message: string }> { + const { data } = await apiClient.post<{ message: string }>( + `/admin/accounts/${id}/github/device/cancel`, + { session_id: sessionId } + ) + return data +} + /** * Refresh account credentials * @param id - Account ID @@ -327,6 +379,11 @@ export async function getAvailableModels(id: number): Promise { return data } +export async function refreshAvailableModels(id: number): Promise { + const { data } = await apiClient.post(`/admin/accounts/${id}/models/refresh`) + return data +} + export interface CRSPreviewAccount { crs_account_id: string kind: string @@ -461,6 +518,9 @@ export const accountsAPI = { delete: deleteAccount, toggleStatus, testAccount, + startGitHubDeviceAuth, + pollGitHubDeviceAuth, + cancelGitHubDeviceAuth, refreshCredentials, getStats, clearError, @@ -471,6 +531,7 @@ export const accountsAPI = { resetTempUnschedulable, setSchedulable, getAvailableModels, + refreshAvailableModels, generateAuthUrl, exchangeCode, refreshOpenAIToken, diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index af06abcaaf..d225312026 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -70,12 +70,15 @@
-
+
+
-
+
@@ -2027,12 +2066,16 @@ const oauthStepTitle = computed(() => { // Platform-specific hints for API Key type const baseUrlHint = computed(() => { if (form.platform === 'openai') return t('admin.accounts.openai.baseUrlHint') + if (form.platform === 'copilot') return '' + if (form.platform === 'aggregator') return '' if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint') return t('admin.accounts.baseUrlHint') }) const apiKeyHint = computed(() => { if (form.platform === 'openai') return t('admin.accounts.openai.apiKeyHint') + if (form.platform === 'copilot') return '' + if (form.platform === 'aggregator') return '' if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint') return t('admin.accounts.apiKeyHint') }) @@ -2311,8 +2354,12 @@ watch( apiKeyBaseUrl.value = newPlatform === 'openai' ? 'https://api.openai.com' + : newPlatform === 'copilot' + ? 'https://api.githubcopilot.com' : newPlatform === 'gemini' ? 'https://generativelanguage.googleapis.com' + : newPlatform === 'aggregator' + ? '' : 'https://api.anthropic.com' // Clear model-related settings allowedModels.value = [] @@ -2331,6 +2378,11 @@ watch( antigravityModelMappings.value = [] antigravityModelRestrictionMode.value = 'mapping' } + + // Copilot/Aggregator are API-key only + if (newPlatform === 'copilot' || newPlatform === 'aggregator') { + accountCategory.value = 'apikey' + } // Reset Anthropic-specific settings when switching to other platforms if (newPlatform !== 'anthropic') { interceptWarmupRequests.value = false @@ -2730,27 +2782,40 @@ const handleSubmit = async () => { return } + if (form.platform === 'aggregator' && !apiKeyBaseUrl.value.trim()) { + appStore.showError(t('admin.accounts.upstream.pleaseEnterBaseUrl')) + return + } + // Determine default base URL based on platform const defaultBaseUrl = form.platform === 'openai' ? 'https://api.openai.com' + : form.platform === 'copilot' + ? 'https://api.githubcopilot.com' : form.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' + : form.platform === 'aggregator' + ? '' : 'https://api.anthropic.com' // Build credentials with optional model mapping const credentials: Record = { base_url: apiKeyBaseUrl.value.trim() || defaultBaseUrl, - api_key: apiKeyValue.value.trim() + ...(form.platform === 'copilot' + ? { github_token: apiKeyValue.value.trim() } + : { api_key: apiKeyValue.value.trim() }) } if (form.platform === 'gemini') { credentials.tier_id = geminiTierAIStudio.value } // Add model mapping if configured - const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) - if (modelMapping) { - credentials.model_mapping = modelMapping + if (form.platform !== 'copilot') { + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + credentials.model_mapping = modelMapping + } } // Add custom error codes if enabled diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 60575f568e..161aa139ee 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -37,6 +37,10 @@ :placeholder=" account.platform === 'openai' ? 'https://api.openai.com' + : account.platform === 'copilot' + ? 'https://api.githubcopilot.com' + : account.platform === 'aggregator' + ? 'https://example.com' : account.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' : account.platform === 'antigravity' @@ -55,6 +59,8 @@ :placeholder=" account.platform === 'openai' ? 'sk-proj-...' + : account.platform === 'copilot' + ? 'github_pat_...' : account.platform === 'gemini' ? 'AIza...' : account.platform === 'antigravity' @@ -66,7 +72,10 @@
-
+
@@ -1050,6 +1059,8 @@ const authStore = useAuthStore() const baseUrlHint = computed(() => { if (!props.account) return t('admin.accounts.baseUrlHint') if (props.account.platform === 'openai') return t('admin.accounts.openai.baseUrlHint') + if (props.account.platform === 'copilot') return '' + if (props.account.platform === 'aggregator') return '' if (props.account.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint') return t('admin.accounts.baseUrlHint') }) @@ -1138,7 +1149,9 @@ const tempUnschedPresets = computed(() => [ // Computed: default base URL based on platform const defaultBaseUrl = computed(() => { if (props.account?.platform === 'openai') return 'https://api.openai.com' + if (props.account?.platform === 'copilot') return 'https://api.githubcopilot.com' if (props.account?.platform === 'gemini') return 'https://generativelanguage.googleapis.com' + if (props.account?.platform === 'aggregator') return '' return 'https://api.anthropic.com' }) @@ -1233,8 +1246,12 @@ watch( const platformDefaultUrl = newAccount.platform === 'openai' ? 'https://api.openai.com' + : newAccount.platform === 'copilot' + ? 'https://api.githubcopilot.com' : newAccount.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' + : newAccount.platform === 'aggregator' + ? '' : 'https://api.anthropic.com' editBaseUrl.value = (credentials.base_url as string) || platformDefaultUrl @@ -1279,8 +1296,12 @@ watch( const platformDefaultUrl = newAccount.platform === 'openai' ? 'https://api.openai.com' + : newAccount.platform === 'copilot' + ? 'https://api.githubcopilot.com' : newAccount.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' + : newAccount.platform === 'aggregator' + ? '' : 'https://api.anthropic.com' editBaseUrl.value = platformDefaultUrl modelRestrictionMode.value = 'whitelist' @@ -1575,7 +1596,10 @@ const handleSubmit = async () => { if (props.account.type === 'apikey') { const currentCredentials = (props.account.credentials as Record) || {} const newBaseUrl = editBaseUrl.value.trim() || defaultBaseUrl.value - const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + const modelMapping = + props.account.platform === 'copilot' + ? null + : buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) // Always update credentials for apikey type to handle model mapping changes const newCredentials: Record = { @@ -1583,12 +1607,13 @@ const handleSubmit = async () => { } // Handle API key + const tokenField = props.account.platform === 'copilot' ? 'github_token' : 'api_key' if (editApiKey.value.trim()) { - // User provided a new API key - newCredentials.api_key = editApiKey.value.trim() - } else if (currentCredentials.api_key) { - // Preserve existing api_key - newCredentials.api_key = currentCredentials.api_key + // User provided a new key/token + newCredentials[tokenField] = editApiKey.value.trim() + } else if (currentCredentials[tokenField]) { + // Preserve existing + newCredentials[tokenField] = currentCredentials[tokenField] } else { appStore.showError(t('admin.accounts.apiKeyIsRequired')) submitting.value = false diff --git a/frontend/src/components/account/OAuthAuthorizationFlow.vue b/frontend/src/components/account/OAuthAuthorizationFlow.vue index 22e179ba99..0e5ebcec05 100644 --- a/frontend/src/components/account/OAuthAuthorizationFlow.vue +++ b/frontend/src/components/account/OAuthAuthorizationFlow.vue @@ -511,6 +511,7 @@ import { useI18n } from 'vue-i18n' import { useClipboard } from '@/composables/useClipboard' import Icon from '@/components/icons/Icon.vue' import type { AddMethod, AuthInputMethod } from '@/composables/useAccountOAuth' +import type { AccountPlatform } from '@/types' interface Props { addMethod: AddMethod @@ -524,7 +525,7 @@ interface Props { methodLabel?: string showCookieOption?: boolean // Whether to show cookie auto-auth option showRefreshTokenOption?: boolean // Whether to show refresh token input option (OpenAI only) - platform?: 'anthropic' | 'openai' | 'gemini' | 'antigravity' // Platform type for different UI/text + platform?: AccountPlatform // Platform type for different UI/text showProjectId?: boolean // New prop to control project ID visibility } diff --git a/frontend/src/components/admin/ErrorPassthroughRulesModal.vue b/frontend/src/components/admin/ErrorPassthroughRulesModal.vue index 2ed6ded3dc..77dd9c0f14 100644 --- a/frontend/src/components/admin/ErrorPassthroughRulesModal.vue +++ b/frontend/src/components/admin/ErrorPassthroughRulesModal.vue @@ -488,6 +488,8 @@ const matchModeOptions = computed(() => [ const platformOptions = [ { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, + { value: 'copilot', label: 'Copilot' }, + { value: 'aggregator', label: 'Aggregator' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' } ] diff --git a/frontend/src/components/admin/account/AccountActionMenu.vue b/frontend/src/components/admin/account/AccountActionMenu.vue index 2325f4b401..f846d23873 100644 --- a/frontend/src/components/admin/account/AccountActionMenu.vue +++ b/frontend/src/components/admin/account/AccountActionMenu.vue @@ -18,6 +18,14 @@ {{ t('admin.accounts.viewStats') }} + +