diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ce898a4a90..2b960e5fa6 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -219,6 +219,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) + copilotTokenProvider := service.NewCopilotTokenProvider(httpUpstream) + copilotGatewayService := service.NewCopilotGatewayService(httpUpstream, copilotTokenProvider, rateLimitService, configConfig) + copilotGatewayHandler := handler.NewCopilotGatewayHandler(copilotGatewayService, gatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool) soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig) @@ -228,7 +231,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { totpHandler := handler.NewTotpHandler(totpService) idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, copilotGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 4e69ca0252..8b44f78b6c 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -23,6 +23,7 @@ const ( PlatformGemini = "gemini" PlatformAntigravity = "antigravity" PlatformSora = "sora" + PlatformCopilot = "copilot" ) // Account type constants diff --git a/backend/internal/handler/copilot_gateway_handler.go b/backend/internal/handler/copilot_gateway_handler.go new file mode 100644 index 0000000000..b0aecc53f1 --- /dev/null +++ b/backend/internal/handler/copilot_gateway_handler.go @@ -0,0 +1,348 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// CopilotGatewayHandler handles GitHub Copilot API gateway requests. +// It exposes /copilot/v1/chat/completions and /copilot/v1/responses endpoints +// and can also be invoked through the automatic platform routing on /v1/*. +type CopilotGatewayHandler struct { + copilotService *service.CopilotGatewayService + gatewayService *service.GatewayService + billingCacheService *service.BillingCacheService + apiKeyService *service.APIKeyService + usageRecordWorkerPool *service.UsageRecordWorkerPool + concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int +} + +// NewCopilotGatewayHandler creates a new CopilotGatewayHandler. +func NewCopilotGatewayHandler( + copilotService *service.CopilotGatewayService, + gatewayService *service.GatewayService, + concurrencyService *service.ConcurrencyService, + billingCacheService *service.BillingCacheService, + apiKeyService *service.APIKeyService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, +) *CopilotGatewayHandler { + return &CopilotGatewayHandler{ + copilotService: copilotService, + gatewayService: gatewayService, + billingCacheService: billingCacheService, + apiKeyService: apiKeyService, + usageRecordWorkerPool: usageRecordWorkerPool, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, 0), + maxAccountSwitches: 5, + } +} + +// ChatCompletions handles POST /copilot/v1/chat/completions +func (h *CopilotGatewayHandler) ChatCompletions(c *gin.Context) { + h.handleRequest(c, "chat_completions", func(ctx *copilotRequestContext) (*service.CopilotForwardResult, error) { + return ctx.service.ForwardChatCompletions(c.Request.Context(), c, ctx.account, ctx.body) + }) +} + +// Responses handles POST /copilot/v1/responses +func (h *CopilotGatewayHandler) Responses(c *gin.Context) { + h.handleRequest(c, "responses", func(ctx *copilotRequestContext) (*service.CopilotForwardResult, error) { + return ctx.service.ForwardResponses(c.Request.Context(), c, ctx.account, ctx.body) + }) +} + +// Models handles GET /copilot/v1/models +func (h *CopilotGatewayHandler) Models(c *gin.Context) { + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + + reqLog := requestLogger(c, "handler.copilot_gateway.models", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + ) + + // Select a Copilot account to query models from. + account, err := h.selectAccount(c, apiKey, "") + if err != nil { + reqLog.Warn("copilot: no available account for models list", zap.Error(err)) + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available Copilot accounts") + return + } + + modelsJSON, err := h.copilotService.ListModels(c.Request.Context(), account) + if err != nil { + reqLog.Error("copilot: failed to list models", zap.Error(err)) + h.errorResponse(c, http.StatusBadGateway, "api_error", "Failed to list Copilot models") + return + } + + c.Data(http.StatusOK, "application/json", modelsJSON) +} + +// copilotRequestContext holds the pre-validated request state. +type copilotRequestContext struct { + service *service.CopilotGatewayService + account *service.Account + body []byte +} + +// handleRequest is the shared request processing logic for both +// ChatCompletions and Responses endpoints. It handles authentication, +// account selection, failover, and concurrency control. +func (h *CopilotGatewayHandler) handleRequest( + c *gin.Context, + endpoint string, + forwardFn func(ctx *copilotRequestContext) (*service.CopilotForwardResult, error), +) { + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + + reqLog := requestLogger(c, "handler.copilot_gateway."+endpoint, + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + // Read and validate request body. + body, err := httputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + if !gjson.ValidBytes(body) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Invalid JSON body") + return + } + + reqModel := gjson.GetBytes(body, "model").String() + if reqModel == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + + reqLog = reqLog.With(zap.String("model", reqModel)) + + // Acquire user concurrency slot. + reqStream := gjson.GetBytes(body, "stream").Bool() + streamStarted := false + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) + if err != nil { + reqLog.Warn("copilot: user slot acquire failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + // Check billing eligibility after acquiring slot. + subscription, _ := middleware2.GetSubscriptionFromContext(c) + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("copilot: billing eligibility check failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.errorResponse(c, status, code, message) + return + } + + // Account selection with failover loop. + var failedAccountIDs []int64 + for attempt := 0; attempt <= h.maxAccountSwitches; attempt++ { + account, err := h.selectAccountExcluding(c, apiKey, reqModel, failedAccountIDs) + if err != nil { + reqLog.Warn("copilot: no available accounts", + zap.Error(err), + zap.Int("attempt", attempt), + zap.Int64s("failed_ids", failedAccountIDs), + ) + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available Copilot accounts: "+err.Error()) + return + } + + result, fwdErr := forwardFn(&copilotRequestContext{ + service: h.copilotService, + account: account, + body: body, + }) + + if fwdErr == nil { + reqLog.Info("copilot: request forwarded successfully", + zap.Int64("account_id", account.ID), + zap.Duration("duration", time.Since(requestStart)), + ) + + // Submit usage record asynchronously. + // Note: Copilot streaming responses are forwarded as-is without + // token usage parsing, so we record a basic usage entry for + // billing/audit purposes. Token counts will be zero. + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: &service.ForwardResult{ + Model: reqModel, + UpstreamModel: result.UpstreamModel, + }, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, + }); err != nil { + logger.L().With( + zap.String("component", "handler.copilot_gateway."+endpoint), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("copilot: record usage failed", zap.Error(err)) + } + }) + return + } + + // Check if the error is eligible for failover. + var upstreamErr *service.CopilotUpstreamError + if errors.As(fwdErr, &upstreamErr) && service.ShouldFailoverCopilotUpstreamError(upstreamErr.StatusCode) { + reqLog.Info("copilot: upstream error, trying next account", + zap.Int64("account_id", account.ID), + zap.Int("status", upstreamErr.StatusCode), + zap.Int("attempt", attempt), + ) + failedAccountIDs = append(failedAccountIDs, account.ID) + continue + } + + // Non-failover error: return to client. + reqLog.Error("copilot: forward failed (non-failover)", + zap.Int64("account_id", account.ID), + zap.Error(fwdErr), + ) + if !c.Writer.Written() { + if upstreamErr != nil { + c.Data(upstreamErr.StatusCode, "application/json", upstreamErr.Body) + } else { + h.errorResponse(c, http.StatusBadGateway, "api_error", "Copilot request failed") + } + } + return + } + + // All accounts exhausted. + if !c.Writer.Written() { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "All Copilot accounts exhausted after failover") + } +} + +// selectAccount picks an available Copilot account for the given API key. +func (h *CopilotGatewayHandler) selectAccount(c *gin.Context, apiKey *service.APIKey, model string) (*service.Account, error) { + return h.selectAccountExcluding(c, apiKey, model, nil) +} + +// selectAccountExcluding picks an available Copilot account, excluding +// the specified account IDs (used during failover). +func (h *CopilotGatewayHandler) selectAccountExcluding(c *gin.Context, apiKey *service.APIKey, model string, excludeIDs []int64) (*service.Account, error) { + var excludeSet map[int64]struct{} + if len(excludeIDs) > 0 { + excludeSet = make(map[int64]struct{}, len(excludeIDs)) + for _, id := range excludeIDs { + excludeSet[id] = struct{}{} + } + } + + return h.gatewayService.SelectAccountForModelWithExclusions( + c.Request.Context(), + apiKey.GroupID, + "", // no sticky session for copilot + model, + excludeSet, + ) +} + +func (h *CopilotGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +func (h *CopilotGatewayHandler) handleConcurrencyError(c *gin.Context, err error, scope string, streamStarted bool) { + logger.L().Warn("copilot: concurrency error", + zap.String("scope", scope), + zap.Error(err), + ) + if streamStarted { + return + } + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many concurrent requests") +} + +// submitUsageRecordTask submits a usage recording task to the bounded worker +// pool. Falls back to synchronous execution if the pool is not injected. +func (h *CopilotGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { + if task == nil { + return + } + if h.usageRecordWorkerPool != nil { + h.usageRecordWorkerPool.Submit(task) + return + } + // Fallback: synchronous execution with timeout + panic recovery. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", "handler.copilot_gateway"), + zap.Any("panic", recovered), + ).Error("copilot: usage record task panic recovered") + } + }() + task(ctx) +} diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index b2467eacdb..1c27236194 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -34,20 +34,21 @@ type AdminHandlers struct { // Handlers contains all HTTP handlers type Handlers struct { - Auth *AuthHandler - User *UserHandler - APIKey *APIKeyHandler - Usage *UsageHandler - Redeem *RedeemHandler - Subscription *SubscriptionHandler - Announcement *AnnouncementHandler - Admin *AdminHandlers - Gateway *GatewayHandler - OpenAIGateway *OpenAIGatewayHandler - SoraGateway *SoraGatewayHandler - SoraClient *SoraClientHandler - Setting *SettingHandler - Totp *TotpHandler + Auth *AuthHandler + User *UserHandler + APIKey *APIKeyHandler + Usage *UsageHandler + Redeem *RedeemHandler + Subscription *SubscriptionHandler + Announcement *AnnouncementHandler + Admin *AdminHandlers + Gateway *GatewayHandler + OpenAIGateway *OpenAIGatewayHandler + CopilotGateway *CopilotGatewayHandler + SoraGateway *SoraGatewayHandler + SoraClient *SoraClientHandler + Setting *SettingHandler + Totp *TotpHandler } // BuildInfo contains build-time information diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 02ddd03098..8cc344cdf9 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -84,6 +84,7 @@ func ProvideHandlers( adminHandlers *AdminHandlers, gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, + copilotGatewayHandler *CopilotGatewayHandler, soraGatewayHandler *SoraGatewayHandler, soraClientHandler *SoraClientHandler, settingHandler *SettingHandler, @@ -92,20 +93,21 @@ func ProvideHandlers( _ *service.IdempotencyCleanupService, ) *Handlers { return &Handlers{ - Auth: authHandler, - User: userHandler, - APIKey: apiKeyHandler, - Usage: usageHandler, - Redeem: redeemHandler, - Subscription: subscriptionHandler, - Announcement: announcementHandler, - Admin: adminHandlers, - Gateway: gatewayHandler, - OpenAIGateway: openaiGatewayHandler, - SoraGateway: soraGatewayHandler, - SoraClient: soraClientHandler, - Setting: settingHandler, - Totp: totpHandler, + Auth: authHandler, + User: userHandler, + APIKey: apiKeyHandler, + Usage: usageHandler, + Redeem: redeemHandler, + Subscription: subscriptionHandler, + Announcement: announcementHandler, + Admin: adminHandlers, + Gateway: gatewayHandler, + OpenAIGateway: openaiGatewayHandler, + CopilotGateway: copilotGatewayHandler, + SoraGateway: soraGatewayHandler, + SoraClient: soraClientHandler, + Setting: settingHandler, + Totp: totpHandler, } } @@ -121,6 +123,7 @@ var ProviderSet = wire.NewSet( NewAnnouncementHandler, NewGatewayHandler, NewOpenAIGatewayHandler, + NewCopilotGatewayHandler, NewSoraGatewayHandler, NewTotpHandler, ProvideSettingHandler, diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 072cfdee37..5b5a2a1826 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -71,27 +71,41 @@ func RegisterGatewayRoutes( gateway.GET("/usage", h.Gateway.Usage) // OpenAI Responses API: auto-route based on group platform gateway.POST("/responses", func(c *gin.Context) { - if getGroupPlatform(c) == service.PlatformOpenAI { + switch getGroupPlatform(c) { + case service.PlatformOpenAI: h.OpenAIGateway.Responses(c) - return + case service.PlatformCopilot: + h.CopilotGateway.Responses(c) + default: + h.Gateway.Responses(c) } - h.Gateway.Responses(c) }) gateway.POST("/responses/*subpath", func(c *gin.Context) { - if getGroupPlatform(c) == service.PlatformOpenAI { + switch getGroupPlatform(c) { + case service.PlatformOpenAI: h.OpenAIGateway.Responses(c) - return + case service.PlatformCopilot: + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Responses subresources are not supported for Copilot platform", + }, + }) + default: + h.Gateway.Responses(c) } - h.Gateway.Responses(c) }) gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) // OpenAI Chat Completions API: auto-route based on group platform gateway.POST("/chat/completions", func(c *gin.Context) { - if getGroupPlatform(c) == service.PlatformOpenAI { + switch getGroupPlatform(c) { + case service.PlatformOpenAI: h.OpenAIGateway.ChatCompletions(c) - return + case service.PlatformCopilot: + h.CopilotGateway.ChatCompletions(c) + default: + h.Gateway.ChatCompletions(c) } - h.Gateway.ChatCompletions(c) }) } @@ -112,22 +126,43 @@ func RegisterGatewayRoutes( // OpenAI Responses API(不带v1前缀的别名)— auto-route based on group platform responsesHandler := func(c *gin.Context) { - if getGroupPlatform(c) == service.PlatformOpenAI { + switch getGroupPlatform(c) { + case service.PlatformOpenAI: h.OpenAIGateway.Responses(c) - return + case service.PlatformCopilot: + h.CopilotGateway.Responses(c) + default: + h.Gateway.Responses(c) + } + } + responsesSubpathHandler := func(c *gin.Context) { + switch getGroupPlatform(c) { + case service.PlatformOpenAI: + h.OpenAIGateway.Responses(c) + case service.PlatformCopilot: + c.JSON(http.StatusNotFound, gin.H{ + "error": gin.H{ + "type": "not_found_error", + "message": "Responses subresources are not supported for Copilot platform", + }, + }) + default: + h.Gateway.Responses(c) } - h.Gateway.Responses(c) } r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler) - r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler) + r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesSubpathHandler) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) // OpenAI Chat Completions API(不带v1前缀的别名)— auto-route based on group platform r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) { - if getGroupPlatform(c) == service.PlatformOpenAI { + switch getGroupPlatform(c) { + case service.PlatformOpenAI: h.OpenAIGateway.ChatCompletions(c) - return + case service.PlatformCopilot: + h.CopilotGateway.ChatCompletions(c) + default: + h.Gateway.ChatCompletions(c) } - h.Gateway.ChatCompletions(c) }) // Antigravity 模型列表 @@ -163,6 +198,21 @@ func RegisterGatewayRoutes( antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) } + // Copilot 专用路由(强制使用 copilot 平台) + copilotV1 := r.Group("/copilot/v1") + copilotV1.Use(bodyLimit) + copilotV1.Use(clientRequestID) + copilotV1.Use(opsErrorLogger) + copilotV1.Use(endpointNorm) + copilotV1.Use(middleware.ForcePlatform(service.PlatformCopilot)) + copilotV1.Use(gin.HandlerFunc(apiKeyAuth)) + copilotV1.Use(requireGroupAnthropic) + { + copilotV1.POST("/chat/completions", h.CopilotGateway.ChatCompletions) + copilotV1.POST("/responses", h.CopilotGateway.Responses) + copilotV1.GET("/models", h.CopilotGateway.Models) + } + // Sora 专用路由(强制使用 sora 平台) soraV1 := r.Group("/sora/v1") soraV1.Use(soraBodyLimit) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 512195e334..302c70b858 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -836,6 +836,27 @@ func (a *Account) IsAPIKeyOrBedrock() bool { return a.Type == AccountTypeAPIKey || a.Type == AccountTypeBedrock } +// IsCopilot returns true if the account belongs to the Copilot platform. +func (a *Account) IsCopilot() bool { + return a.Platform == PlatformCopilot +} + +// GetGitHubToken returns the GitHub OAuth access_token stored in credentials. +// This token is used to exchange for a short-lived Copilot API JWT. +func (a *Account) GetGitHubToken() string { + return a.GetCredential("github_token") +} + +// GetCopilotBaseURL returns the Copilot API base URL. +// If a custom base_url is set in credentials, it is used instead of the default. +func (a *Account) GetCopilotBaseURL() string { + baseURL := a.GetCredential("base_url") + if baseURL == "" { + return "https://api.githubcopilot.com" + } + return baseURL +} + func (a *Account) IsOpenAI() bool { return a.Platform == PlatformOpenAI } diff --git a/backend/internal/service/copilot_gateway_service.go b/backend/internal/service/copilot_gateway_service.go new file mode 100644 index 0000000000..53287533fd --- /dev/null +++ b/backend/internal/service/copilot_gateway_service.go @@ -0,0 +1,505 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +const ( + // Copilot API request header constants. + // These mimic the VS Code Copilot extension to ensure the upstream + // API accepts the request without flagging it as an unknown client. + copilotUserAgent = "GitHubCopilotChat/0.35.0" + copilotEditorVersion = "vscode/1.107.0" + copilotPluginVersion = "copilot-chat/0.35.0" + copilotIntegrationID = "vscode-chat" + copilotOpenAIIntent = "conversation-edits" + copilotGitHubAPIVer = "2025-04-01" + + // copilotChatPath and copilotResponsesPath are the upstream endpoint paths. + copilotChatPath = "/chat/completions" + copilotResponsesPath = "/responses" + + // copilotMaxUpstreamResponseSize caps the non-streaming response body + // size to prevent unbounded memory allocation (50 MB). + copilotMaxUpstreamResponseSize = 50 << 20 + + // copilotMaxSSELineSize is the maximum buffer for a single SSE line (20 MB). + copilotMaxSSELineSize = 20 << 20 +) + +// CopilotGatewayService handles forwarding requests to the GitHub Copilot API. +// +// Copilot uses the OpenAI-compatible API format (/chat/completions, /responses) +// but requires a two-phase authentication flow: +// 1. Exchange a GitHub OAuth access_token for a short-lived Copilot API JWT +// 2. Send requests to api.githubcopilot.com with the JWT and special headers +type CopilotGatewayService struct { + httpUpstream HTTPUpstream + copilotTokenProv *CopilotTokenProvider + rateLimitService *RateLimitService + cfg *config.Config +} + +// NewCopilotGatewayService creates a new CopilotGatewayService. +func NewCopilotGatewayService( + httpUpstream HTTPUpstream, + copilotTokenProv *CopilotTokenProvider, + rateLimitService *RateLimitService, + cfg *config.Config, +) *CopilotGatewayService { + return &CopilotGatewayService{ + httpUpstream: httpUpstream, + copilotTokenProv: copilotTokenProv, + rateLimitService: rateLimitService, + cfg: cfg, + } +} + +// CopilotForwardResult is the result of a successful forward to Copilot upstream. +type CopilotForwardResult struct { + StatusCode int + UpstreamModel string +} + +// ForwardChatCompletions forwards a request to the Copilot /chat/completions endpoint. +func (s *CopilotGatewayService) ForwardChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, +) (*CopilotForwardResult, error) { + return s.forward(ctx, c, account, body, copilotChatPath) +} + +// ForwardResponses forwards a request to the Copilot /responses endpoint. +func (s *CopilotGatewayService) ForwardResponses( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, +) (*CopilotForwardResult, error) { + return s.forward(ctx, c, account, body, copilotResponsesPath) +} + +// forward is the core forwarding logic shared by all Copilot endpoints. +func (s *CopilotGatewayService) forward( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + upstreamPath string, +) (*CopilotForwardResult, error) { + reqLog := logger.L().With( + zap.String("service", "copilot_gateway"), + zap.Int64("account_id", account.ID), + zap.String("account_name", account.Name), + ) + + // 1. Extract the GitHub access_token from account credentials. + githubToken := account.GetGitHubToken() + if githubToken == "" { + return nil, fmt.Errorf("copilot: account %d has no github_token in credentials", account.ID) + } + + // 2. Exchange GitHub token for Copilot API token (cached). + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + apiToken, apiEndpoint, err := s.copilotTokenProv.GetCopilotAPIToken(ctx, githubToken, proxyURL) + if err != nil { + reqLog.Error("copilot: token exchange failed", zap.Error(err)) + return nil, fmt.Errorf("copilot: token exchange failed: %w", err) + } + + // 3. Extract request metadata. + reqModel := gjson.GetBytes(body, "model").String() + reqStream := gjson.GetBytes(body, "stream").Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + // 4. Build upstream URL. + // Honor account-level base_url override; fall back to the endpoint + // returned by the token exchange (or its default). + baseURL := strings.TrimSpace(account.GetCopilotBaseURL()) + if baseURL == "" || baseURL == copilotDefaultBaseURL { + baseURL = apiEndpoint + } + targetURL := strings.TrimRight(baseURL, "/") + upstreamPath + + // 5. Build upstream HTTP request. + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("copilot: failed to create upstream request: %w", err) + } + applyCopilotHeaders(upstreamReq, apiToken, body) + + // 6. Send request to upstream. + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, fmt.Errorf("copilot: upstream request failed: %w", err) + } + + // 7. Handle error responses. + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + // Let rate limit service handle upstream errors for this account. + if s.rateLimitService != nil { + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + + reqLog.Warn("copilot: upstream error", + zap.Int("status", resp.StatusCode), + zap.String("body_preview", copilotTruncateString(string(respBody), 500)), + ) + + // Return a failover-compatible error for 401/429/5xx. + return &CopilotForwardResult{StatusCode: resp.StatusCode}, &CopilotUpstreamError{ + StatusCode: resp.StatusCode, + Body: respBody, + } + } + + // 8. Forward the successful response to the client. + if reqStream { + s.streamResponse(c, resp, reqLog) + } else { + if err := s.nonStreamResponse(c, resp, reqLog); err != nil { + return nil, err + } + } + + return &CopilotForwardResult{ + StatusCode: resp.StatusCode, + UpstreamModel: reqModel, + }, nil +} + +// streamResponse streams SSE data from the upstream response to the client. +func (s *CopilotGatewayService) streamResponse(c *gin.Context, resp *http.Response, reqLog *zap.Logger) { + defer func() { + if err := resp.Body.Close(); err != nil { + reqLog.Warn("copilot: close upstream response body error", zap.Error(err)) + } + }() + + // Copy relevant response headers. + for key, values := range resp.Header { + lk := strings.ToLower(key) + if lk == "content-length" || lk == "transfer-encoding" || lk == "connection" { + continue + } + for _, v := range values { + c.Header(key, v) + } + } + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Status(resp.StatusCode) + + flusher, _ := c.Writer.(http.Flusher) + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), copilotMaxSSELineSize) + + for scanner.Scan() { + line := scanner.Bytes() + _, _ = c.Writer.Write(line) + _, _ = c.Writer.Write([]byte("\n")) + if flusher != nil { + flusher.Flush() + } + } + + if err := scanner.Err(); err != nil { + reqLog.Warn("copilot: stream scanner error", zap.Error(err)) + } +} + +// nonStreamResponse reads the full upstream response and writes it to the client. +func (s *CopilotGatewayService) nonStreamResponse(c *gin.Context, resp *http.Response, reqLog *zap.Logger) error { + defer func() { + if err := resp.Body.Close(); err != nil { + reqLog.Warn("copilot: close upstream response body error", zap.Error(err)) + } + }() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, copilotMaxUpstreamResponseSize)) + if err != nil { + reqLog.Error("copilot: failed to read upstream response", zap.Error(err)) + return fmt.Errorf("copilot: failed to read upstream response: %w", err) + } + + for key, values := range resp.Header { + lk := strings.ToLower(key) + if lk == "content-length" || lk == "transfer-encoding" || lk == "connection" { + continue + } + for _, v := range values { + c.Header(key, v) + } + } + c.Data(resp.StatusCode, "application/json", respBody) + return nil +} + +// ShouldFailoverCopilotUpstreamError returns true if the upstream error +// should trigger account failover (try the next available Copilot account). +func ShouldFailoverCopilotUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 402, 403, 429: + return true + default: + return statusCode >= 500 + } +} + +// ListModels fetches the available models from the Copilot /models endpoint. +func (s *CopilotGatewayService) ListModels(ctx context.Context, account *Account) (json.RawMessage, error) { + githubToken := account.GetGitHubToken() + if githubToken == "" { + return nil, fmt.Errorf("copilot: account has no github_token") + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + apiToken, apiEndpoint, err := s.copilotTokenProv.GetCopilotAPIToken(ctx, githubToken, proxyURL) + if err != nil { + return nil, fmt.Errorf("copilot: token exchange failed: %w", err) + } + + // Honor account-level base_url override. + baseURL := strings.TrimSpace(account.GetCopilotBaseURL()) + if baseURL == "" || baseURL == copilotDefaultBaseURL { + baseURL = apiEndpoint + } + modelsURL := strings.TrimRight(baseURL, "/") + "/models" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil) + if err != nil { + return nil, fmt.Errorf("copilot: failed to create models request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+apiToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", copilotUserAgent) + req.Header.Set("Editor-Version", copilotEditorVersion) + req.Header.Set("Editor-Plugin-Version", copilotPluginVersion) + + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, fmt.Errorf("copilot: models request failed: %w", err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + logger.L().Warn("copilot: close models response body error", zap.Error(errClose)) + } + }() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if err != nil { + return nil, fmt.Errorf("copilot: failed to read models response: %w", err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("copilot: models request failed with status %d: %s", resp.StatusCode, string(respBody)) + } + + return json.RawMessage(respBody), nil +} + +// applyCopilotHeaders sets all required headers for Copilot API requests. +func applyCopilotHeaders(r *http.Request, apiToken string, body []byte) { + r.Header.Set("Content-Type", "application/json") + r.Header.Set("Authorization", "Bearer "+apiToken) + r.Header.Set("Accept", "application/json") + r.Header.Set("User-Agent", copilotUserAgent) + r.Header.Set("Editor-Version", copilotEditorVersion) + r.Header.Set("Editor-Plugin-Version", copilotPluginVersion) + r.Header.Set("Openai-Intent", copilotOpenAIIntent) + r.Header.Set("Copilot-Integration-Id", copilotIntegrationID) + r.Header.Set("X-Github-Api-Version", copilotGitHubAPIVer) + r.Header.Set("X-Request-Id", uuid.NewString()) + + // Determine X-Initiator header for Copilot billing: + // "user" → consumes premium request quota + // "agent" → free (tool loops, continuations) + initiator := "user" + if isCopilotAgentInitiated(body) { + initiator = "agent" + } + r.Header.Set("X-Initiator", initiator) + + // Vision detection + if detectCopilotVisionContent(body) { + r.Header.Set("Copilot-Vision-Request", "true") + } +} + +// isCopilotAgentInitiated detects whether the request is agent-initiated +// (tool callbacks, continuations) rather than user-initiated. Copilot uses +// the X-Initiator header for billing: +// - "user" → consumes premium request quota +// - "agent" → free (tool loops, continuations) +// +// The challenge: Claude Code sends tool results as role:"user" messages with +// content type "tool_result". After translation, the tool_result part may become +// a separate role:"tool" message, but if the original message also contained text, +// a role:"user" message is emitted AFTER the tool message, making the last message +// appear user-initiated when it's actually part of an agent tool loop. +// +// We detect agent status by checking: +// 1. Last message role is "assistant" or "tool" → agent +// 2. Last message is "user" but contains tool_result content → agent (tool loop) +// 3. Last message is "user" but preceding message is assistant with tool_use → agent +// 4. Responses API: any function_call / tool-related types in history → agent +func isCopilotAgentInitiated(body []byte) bool { + if len(body) == 0 { + return false + } + + // Chat Completions API: check messages array. + if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() { + arr := messages.Array() + if len(arr) == 0 { + return false + } + + // Find the last message with a role. + lastRole := "" + for i := len(arr) - 1; i >= 0; i-- { + if r := arr[i].Get("role").String(); r != "" { + lastRole = r + break + } + } + + // If last message is assistant or tool, clearly agent-initiated. + if lastRole == "assistant" || lastRole == "tool" { + return true + } + + // If last message is "user", check whether it contains tool results + // (indicating a tool-loop continuation) or if the preceding message + // is an assistant tool_use. + if lastRole == "user" { + // Check if the last user message contains tool_result content. + lastContent := arr[len(arr)-1].Get("content") + if lastContent.Exists() && lastContent.IsArray() { + for _, part := range lastContent.Array() { + if part.Get("type").String() == "tool_result" { + return true + } + } + } + // Check if the second-to-last message is an assistant with tool_use. + if len(arr) >= 2 { + prev := arr[len(arr)-2] + if prev.Get("role").String() == "assistant" { + prevContent := prev.Get("content") + if prevContent.Exists() && prevContent.IsArray() { + for _, part := range prevContent.Array() { + if part.Get("type").String() == "tool_use" { + return true + } + } + } + } + } + } + + return false + } + + // Responses API: check input array. + if inputs := gjson.GetBytes(body, "input"); inputs.Exists() && inputs.IsArray() { + arr := inputs.Array() + if len(arr) == 0 { + return false + } + + // Check last item — direct indicators. + last := arr[len(arr)-1] + if last.Get("role").String() == "assistant" { + return true + } + switch last.Get("type").String() { + case "function_call", "function_call_arguments", "computer_call": + return true + case "function_call_output", "function_call_response", "tool_result", "computer_call_output": + return true + } + + // If last item is user-role, check for prior non-user items + // that indicate this is a continuation rather than a fresh prompt. + for _, item := range arr { + if item.Get("role").String() == "assistant" { + return true + } + switch item.Get("type").String() { + case "function_call", "function_call_output", "function_call_response", + "function_call_arguments", "computer_call", "computer_call_output": + return true + } + } + } + + return false +} + +// detectCopilotVisionContent checks if the request body contains image content. +func detectCopilotVisionContent(body []byte) bool { + messages := gjson.GetBytes(body, "messages") + if !messages.Exists() || !messages.IsArray() { + return false + } + for _, msg := range messages.Array() { + content := msg.Get("content") + if content.IsArray() { + for _, block := range content.Array() { + blockType := block.Get("type").String() + if blockType == "image_url" || blockType == "image" { + return true + } + } + } + } + return false +} + +// CopilotUpstreamError represents an error response from the Copilot upstream. +type CopilotUpstreamError struct { + StatusCode int + Body []byte +} + +func (e *CopilotUpstreamError) Error() string { + return fmt.Sprintf("copilot upstream error: status %d", e.StatusCode) +} + +// copilotTruncateString truncates s to maxLen, appending "..." if truncated. +func copilotTruncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} diff --git a/backend/internal/service/copilot_test.go b/backend/internal/service/copilot_test.go new file mode 100644 index 0000000000..6b16988d60 --- /dev/null +++ b/backend/internal/service/copilot_test.go @@ -0,0 +1,425 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" +) + +// copilotHTTPUpstreamStub is a test double for HTTPUpstream that returns +// configurable responses. It is safe for concurrent use. +type copilotHTTPUpstreamStub struct { + mu sync.Mutex + handler func(req *http.Request) (*http.Response, error) + calls int + lastReq *http.Request +} + +func (s *copilotHTTPUpstreamStub) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + s.mu.Lock() + s.calls++ + s.lastReq = req + handler := s.handler + s.mu.Unlock() + if handler == nil { + return nil, fmt.Errorf("no handler configured") + } + return handler(req) +} + +func (s *copilotHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { + return s.Do(req, proxyURL, accountID, accountConcurrency) +} + +// makeCopilotTokenResponse creates a valid token exchange HTTP response. +func makeCopilotTokenResponse(token string, expiresAt int64, apiEndpoint string) *http.Response { + body := copilotAPITokenResponse{ + Token: token, + ExpiresAt: expiresAt, + } + body.Endpoints.API = apiEndpoint + data, _ := json.Marshal(body) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(data)), + Header: make(http.Header), + } +} + +// ---------- CopilotTokenProvider Tests ---------- + +func TestCopilotTokenProvider_BasicExchange(t *testing.T) { + upstream := &copilotHTTPUpstreamStub{ + handler: func(req *http.Request) (*http.Response, error) { + return makeCopilotTokenResponse("jwt-abc", time.Now().Add(30*time.Minute).Unix(), "https://custom.copilot.api"), nil + }, + } + + provider := NewCopilotTokenProvider(upstream) + token, endpoint, err := provider.GetCopilotAPIToken(context.Background(), "ghp_testtoken", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token != "jwt-abc" { + t.Errorf("expected token 'jwt-abc', got %q", token) + } + if endpoint != "https://custom.copilot.api" { + t.Errorf("expected endpoint 'https://custom.copilot.api', got %q", endpoint) + } +} + +func TestCopilotTokenProvider_CacheHit(t *testing.T) { + var callCount atomic.Int32 + upstream := &copilotHTTPUpstreamStub{ + handler: func(req *http.Request) (*http.Response, error) { + callCount.Add(1) + return makeCopilotTokenResponse("jwt-cached", time.Now().Add(30*time.Minute).Unix(), ""), nil + }, + } + + provider := NewCopilotTokenProvider(upstream) + + // First call — cache miss. + token1, _, err := provider.GetCopilotAPIToken(context.Background(), "ghp_test", "") + if err != nil { + t.Fatalf("first call error: %v", err) + } + + // Second call — should hit cache (no new HTTP request). + token2, _, err := provider.GetCopilotAPIToken(context.Background(), "ghp_test", "") + if err != nil { + t.Fatalf("second call error: %v", err) + } + + if token1 != token2 { + t.Errorf("tokens should match: %q vs %q", token1, token2) + } + if callCount.Load() != 1 { + t.Errorf("expected 1 HTTP call (cache hit), got %d", callCount.Load()) + } +} + +func TestCopilotTokenProvider_CacheExpiry(t *testing.T) { + var callCount atomic.Int32 + upstream := &copilotHTTPUpstreamStub{ + handler: func(req *http.Request) (*http.Response, error) { + callCount.Add(1) + // Token expires in 2 seconds (well within the 5-minute buffer). + return makeCopilotTokenResponse("jwt-expiring", time.Now().Add(2*time.Second).Unix(), ""), nil + }, + } + + provider := NewCopilotTokenProvider(upstream) + + // First call — cache miss. + _, _, err := provider.GetCopilotAPIToken(context.Background(), "ghp_test", "") + if err != nil { + t.Fatalf("first call error: %v", err) + } + + // Second call — token is within expiry buffer, should trigger refresh. + _, _, err = provider.GetCopilotAPIToken(context.Background(), "ghp_test", "") + if err != nil { + t.Fatalf("second call error: %v", err) + } + + if callCount.Load() != 2 { + t.Errorf("expected 2 HTTP calls (expired cache), got %d", callCount.Load()) + } +} + +func TestCopilotTokenProvider_Singleflight(t *testing.T) { + var callCount atomic.Int32 + upstream := &copilotHTTPUpstreamStub{ + handler: func(req *http.Request) (*http.Response, error) { + callCount.Add(1) + // Simulate some latency to ensure concurrent requests overlap. + time.Sleep(50 * time.Millisecond) + return makeCopilotTokenResponse("jwt-sf", time.Now().Add(30*time.Minute).Unix(), ""), nil + }, + } + + provider := NewCopilotTokenProvider(upstream) + + // Fire 10 concurrent requests — only 1 should hit upstream. + var wg sync.WaitGroup + errors := make([]error, 10) + for i := 0; i < 10; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + _, _, errors[idx] = provider.GetCopilotAPIToken(context.Background(), "ghp_concurrent", "") + }(i) + } + wg.Wait() + + for i, err := range errors { + if err != nil { + t.Errorf("goroutine %d returned error: %v", i, err) + } + } + + if callCount.Load() != 1 { + t.Errorf("expected 1 HTTP call (singleflight), got %d", callCount.Load()) + } +} + +func TestCopilotTokenProvider_EmptyToken(t *testing.T) { + provider := NewCopilotTokenProvider(nil) + _, _, err := provider.GetCopilotAPIToken(context.Background(), "", "") + if err == nil { + t.Fatal("expected error for empty token") + } +} + +func TestCopilotTokenProvider_WhitespaceToken(t *testing.T) { + provider := NewCopilotTokenProvider(nil) + _, _, err := provider.GetCopilotAPIToken(context.Background(), " ", "") + if err == nil { + t.Fatal("expected error for whitespace-only token") + } +} + +func TestCopilotTokenProvider_ExchangeFailure(t *testing.T) { + upstream := &copilotHTTPUpstreamStub{ + handler: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 401, + Body: io.NopCloser(bytes.NewBufferString(`{"message":"Bad credentials"}`)), + Header: make(http.Header), + }, nil + }, + } + + provider := NewCopilotTokenProvider(upstream) + _, _, err := provider.GetCopilotAPIToken(context.Background(), "ghp_invalid", "") + if err == nil { + t.Fatal("expected error for failed exchange") + } +} + +func TestCopilotTokenProvider_DefaultEndpoint(t *testing.T) { + upstream := &copilotHTTPUpstreamStub{ + handler: func(req *http.Request) (*http.Response, error) { + // No endpoint in response — should use default. + return makeCopilotTokenResponse("jwt-default", time.Now().Add(30*time.Minute).Unix(), ""), nil + }, + } + + provider := NewCopilotTokenProvider(upstream) + _, endpoint, err := provider.GetCopilotAPIToken(context.Background(), "ghp_test", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if endpoint != copilotDefaultBaseURL { + t.Errorf("expected default endpoint %q, got %q", copilotDefaultBaseURL, endpoint) + } +} + +func TestCopilotTokenProvider_EvictsExpiredEntries(t *testing.T) { + upstream := &copilotHTTPUpstreamStub{ + handler: func(req *http.Request) (*http.Response, error) { + return makeCopilotTokenResponse("jwt-evict", time.Now().Add(30*time.Minute).Unix(), ""), nil + }, + } + + provider := NewCopilotTokenProvider(upstream) + + // Manually insert an expired entry. + expiredKey := tokenFingerprint("ghp_expired") + provider.mu.Lock() + provider.cache[expiredKey] = &cachedCopilotToken{ + token: "old-jwt", + expiresAt: time.Now().Add(-1 * time.Hour), + } + provider.mu.Unlock() + + // Trigger a new exchange for a different token (causes eviction sweep). + _, _, err := provider.GetCopilotAPIToken(context.Background(), "ghp_new", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // The expired entry should have been cleaned up. + provider.mu.RLock() + _, exists := provider.cache[expiredKey] + provider.mu.RUnlock() + if exists { + t.Error("expired cache entry was not evicted") + } +} + +// ---------- ShouldFailoverCopilotUpstreamError Tests ---------- + +func TestShouldFailoverCopilotUpstreamError(t *testing.T) { + tests := []struct { + statusCode int + expected bool + }{ + {200, false}, + {400, false}, + {401, true}, + {402, true}, + {403, true}, + {404, false}, + {429, true}, + {500, true}, + {502, true}, + {503, true}, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf("status_%d", tc.statusCode), func(t *testing.T) { + result := ShouldFailoverCopilotUpstreamError(tc.statusCode) + if result != tc.expected { + t.Errorf("ShouldFailoverCopilotUpstreamError(%d) = %v, want %v", tc.statusCode, result, tc.expected) + } + }) + } +} + +// ---------- Helper Function Tests ---------- + +func TestIsCopilotAgentInitiated(t *testing.T) { + tests := []struct { + name string + body string + expected bool + }{ + { + name: "empty body", + body: "", + expected: false, + }, + { + name: "user message only", + body: `{"messages":[{"role":"user","content":"hello"}]}`, + expected: false, + }, + { + name: "tool message last", + body: `{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`, + expected: true, + }, + { + name: "assistant message last", + body: `{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"thinking..."}]}`, + expected: true, + }, + { + name: "last user contains tool_result content", + body: `{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu1","content":"file contents..."}]}]}`, + expected: true, + }, + { + name: "last user preceded by assistant with tool_use", + body: `{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":[{"type":"tool_use","id":"tu1","name":"Read","input":{}}]},{"role":"user","content":"some text"}]}`, + expected: true, + }, + { + name: "genuine multi-turn no tools", + body: `{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"what is 2+2?"}]}`, + expected: false, + }, + { + name: "genuine follow-up after completed tool history", + body: `{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":[{"type":"tool_use","id":"tu1","name":"Read","input":{}}]},{"role":"tool","tool_call_id":"tu1","content":"file data"},{"role":"assistant","content":"I read the file."},{"role":"user","content":"What did we do so far?"}]}`, + expected: false, + }, + { + name: "responses API function_call_output", + body: `{"input":[{"type":"function_call_output","output":"done"}]}`, + expected: true, + }, + { + name: "responses API function_call_arguments", + body: `{"input":[{"type":"function_call_arguments","arguments":"{}"}]}`, + expected: true, + }, + { + name: "responses API computer_call", + body: `{"input":[{"type":"computer_call","action":"click"}]}`, + expected: true, + }, + { + name: "responses API computer_call_output", + body: `{"input":[{"type":"computer_call_output","output":"screenshot"}]}`, + expected: true, + }, + { + name: "responses API user input only", + body: `{"input":[{"role":"user","content":"hello"}]}`, + expected: false, + }, + { + name: "responses API last user but history has assistant", + body: `{"input":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"I can help"}]},{"type":"message","role":"user","content":[{"type":"input_text","text":"Do X"}]}]}`, + expected: true, + }, + { + name: "responses API last user but history has function_call", + body: `{"input":[{"type":"function_call","name":"tool1","arguments":"{}"},{"type":"message","role":"user","content":[{"type":"input_text","text":"continue"}]}]}`, + expected: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := isCopilotAgentInitiated([]byte(tc.body)) + if result != tc.expected { + t.Errorf("isCopilotAgentInitiated(%q) = %v, want %v", tc.body, result, tc.expected) + } + }) + } +} + +func TestDetectCopilotVisionContent(t *testing.T) { + tests := []struct { + name string + body string + expected bool + }{ + { + name: "no vision", + body: `{"messages":[{"role":"user","content":"hello"}]}`, + expected: false, + }, + { + name: "with image_url", + body: `{"messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"https://example.com/img.png"}}]}]}`, + expected: true, + }, + { + name: "with image", + body: `{"messages":[{"role":"user","content":[{"type":"image","data":"base64..."}]}]}`, + expected: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := detectCopilotVisionContent([]byte(tc.body)) + if result != tc.expected { + t.Errorf("detectCopilotVisionContent(%q) = %v, want %v", tc.body, result, tc.expected) + } + }) + } +} + +func TestCopilotTruncateString(t *testing.T) { + if r := copilotTruncateString("short", 10); r != "short" { + t.Errorf("expected 'short', got %q", r) + } + if r := copilotTruncateString("a very long string", 5); r != "a ver..." { + t.Errorf("expected 'a ver...', got %q", r) + } +} diff --git a/backend/internal/service/copilot_token_provider.go b/backend/internal/service/copilot_token_provider.go new file mode 100644 index 0000000000..0e1a32ba3a --- /dev/null +++ b/backend/internal/service/copilot_token_provider.go @@ -0,0 +1,232 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" +) + +const ( + // copilotTokenExchangeURL is the GitHub API endpoint for exchanging + // a GitHub OAuth access_token into a short-lived Copilot API JWT. + copilotTokenExchangeURL = "https://api.github.com/copilot_internal/v2/token" + + // copilotDefaultBaseURL is the default Copilot API endpoint. + copilotDefaultBaseURL = "https://api.githubcopilot.com" + + // copilotTokenCacheTTL is the default cache duration when the token + // response does not include an explicit expires_at field. + copilotTokenCacheTTL = 25 * time.Minute + + // copilotTokenExpiryBuffer is the time before expiry at which we proactively + // refresh the cached token. This avoids sending requests with a token that + // is about to expire within seconds. + copilotTokenExpiryBuffer = 5 * time.Minute + + // copilotTokenExchangeUserAgent mimics the VS Code Copilot extension. + copilotTokenExchangeUserAgent = "GithubCopilot/1.0" + copilotTokenEditorVersion = "vscode/1.100.0" + copilotTokenPluginVersion = "copilot/1.300.0" + + // maxCopilotTokenResponseSize caps the response body from the token + // exchange endpoint to prevent unbounded memory allocation (1 MB). + maxCopilotTokenResponseSize = 1 << 20 +) + +// copilotAPITokenResponse represents the JSON response returned by +// the Copilot token exchange endpoint. +type copilotAPITokenResponse struct { + Token string `json:"token"` + ExpiresAt int64 `json:"expires_at"` + Endpoints struct { + API string `json:"api"` + } `json:"endpoints,omitempty"` + ErrorDetails *struct { + Message string `json:"message"` + } `json:"error_details,omitempty"` +} + +// cachedCopilotToken stores a Copilot API token together with its expiry +// and the resolved API endpoint. +type cachedCopilotToken struct { + token string + apiEndpoint string + expiresAt time.Time +} + +// CopilotTokenProvider manages the exchange of GitHub OAuth access_tokens +// for short-lived Copilot API JWTs. Tokens are cached per access_token +// (keyed by SHA-256 fingerprint) and automatically refreshed before expiry. +// +// The exchange flow follows the same protocol that the VS Code Copilot +// extension uses: +// +// GET https://api.github.com/copilot_internal/v2/token +// Authorization: token +// → { "token": "", "expires_at": , "endpoints": { "api": "..." } } +type CopilotTokenProvider struct { + httpUpstream HTTPUpstream + mu sync.RWMutex + cache map[string]*cachedCopilotToken // keyed by SHA-256 of github token + sf singleflight.Group +} + +// NewCopilotTokenProvider creates a new provider instance. +func NewCopilotTokenProvider(httpUpstream HTTPUpstream) *CopilotTokenProvider { + return &CopilotTokenProvider{ + httpUpstream: httpUpstream, + cache: make(map[string]*cachedCopilotToken), + } +} + +// tokenFingerprint returns a SHA-256 hex digest of the token, used as +// the cache key to avoid storing raw credentials in memory. +func tokenFingerprint(token string) string { + h := sha256.Sum256([]byte(token)) + return hex.EncodeToString(h[:]) +} + +// GetCopilotAPIToken returns a valid Copilot API token and the resolved +// API endpoint for the given GitHub access_token. Cached tokens are reused +// until they are within copilotTokenExpiryBuffer of expiry. +// +// Uses singleflight to deduplicate concurrent exchange requests for the +// same access_token, preventing thundering herd on cache miss/expiry. +func (p *CopilotTokenProvider) GetCopilotAPIToken(ctx context.Context, githubAccessToken string, proxyURL string) (token string, apiEndpoint string, err error) { + githubAccessToken = strings.TrimSpace(githubAccessToken) + if githubAccessToken == "" { + return "", "", fmt.Errorf("copilot: github access token is empty") + } + + cacheKey := tokenFingerprint(githubAccessToken) + + // Fast path: read-lock check for cached token. + p.mu.RLock() + if cached, ok := p.cache[cacheKey]; ok && cached.expiresAt.After(time.Now().Add(copilotTokenExpiryBuffer)) { + p.mu.RUnlock() + return cached.token, cached.apiEndpoint, nil + } + p.mu.RUnlock() + + // Slow path: use singleflight to deduplicate concurrent exchanges + // for the same token. + type tokenResult struct { + token string + apiEndpoint string + } + + val, err, _ := p.sf.Do(cacheKey, func() (any, error) { + // Double-check under singleflight: another goroutine may have + // already refreshed the cache while we were waiting. + p.mu.RLock() + if cached, ok := p.cache[cacheKey]; ok && cached.expiresAt.After(time.Now().Add(copilotTokenExpiryBuffer)) { + p.mu.RUnlock() + return &tokenResult{token: cached.token, apiEndpoint: cached.apiEndpoint}, nil + } + p.mu.RUnlock() + + // Exchange token. + apiToken, exchangeErr := p.exchangeToken(ctx, githubAccessToken, proxyURL) + if exchangeErr != nil { + return nil, exchangeErr + } + + resolvedEndpoint := copilotDefaultBaseURL + if ep := strings.TrimRight(apiToken.Endpoints.API, "/"); ep != "" { + resolvedEndpoint = ep + } + + expiresAt := time.Now().Add(copilotTokenCacheTTL) + if apiToken.ExpiresAt > 0 { + expiresAt = time.Unix(apiToken.ExpiresAt, 0) + } + + // Update cache and evict expired entries. + p.mu.Lock() + p.cache[cacheKey] = &cachedCopilotToken{ + token: apiToken.Token, + apiEndpoint: resolvedEndpoint, + expiresAt: expiresAt, + } + // Opportunistic eviction of expired entries to prevent unbounded growth. + now := time.Now() + for k, v := range p.cache { + if v.expiresAt.Before(now) { + delete(p.cache, k) + } + } + p.mu.Unlock() + + return &tokenResult{token: apiToken.Token, apiEndpoint: resolvedEndpoint}, nil + }) + + if err != nil { + return "", "", err + } + + result, ok := val.(*tokenResult) + if !ok { + return "", "", fmt.Errorf("copilot: unexpected token result type") + } + return result.token, result.apiEndpoint, nil +} + +// exchangeToken performs the actual HTTP request to exchange a GitHub +// access_token for a Copilot API token. +func (p *CopilotTokenProvider) exchangeToken(ctx context.Context, githubAccessToken string, proxyURL string) (*copilotAPITokenResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotTokenExchangeURL, nil) + if err != nil { + return nil, fmt.Errorf("copilot: failed to create token exchange request: %w", err) + } + + req.Header.Set("Authorization", "token "+githubAccessToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", copilotTokenExchangeUserAgent) + req.Header.Set("Editor-Version", copilotTokenEditorVersion) + req.Header.Set("Editor-Plugin-Version", copilotTokenPluginVersion) + + resp, err := p.httpUpstream.Do(req, proxyURL, 0, 0) + if err != nil { + return nil, fmt.Errorf("copilot: token exchange request failed: %w", err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + logger.L().Warn("copilot: close token exchange response body error", zap.Error(errClose)) + } + }() + + bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxCopilotTokenResponseSize)) + if err != nil { + return nil, fmt.Errorf("copilot: failed to read token exchange response: %w", err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("copilot: token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var apiToken copilotAPITokenResponse + if err = json.Unmarshal(bodyBytes, &apiToken); err != nil { + return nil, fmt.Errorf("copilot: failed to parse token exchange response: %w", err) + } + + if apiToken.Token == "" { + errMsg := "empty copilot api token" + if apiToken.ErrorDetails != nil && apiToken.ErrorDetails.Message != "" { + errMsg = apiToken.ErrorDetails.Message + } + return nil, fmt.Errorf("copilot: token exchange returned error: %s", errMsg) + } + + return &apiToken, nil +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index ecac0db0cf..ecc2e7b29e 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -25,6 +25,7 @@ const ( PlatformGemini = domain.PlatformGemini PlatformAntigravity = domain.PlatformAntigravity PlatformSora = domain.PlatformSora + PlatformCopilot = domain.PlatformCopilot ) // Account type constants diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index d79a353124..8fcffa3752 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -431,6 +431,8 @@ var ProviderSet = wire.NewSet( wire.Bind(new(SoraClient), new(*SoraSDKClient)), NewSoraGatewayService, NewOpenAIGatewayService, + NewCopilotTokenProvider, + NewCopilotGatewayService, NewOAuthService, NewOpenAIOAuthService, NewGeminiOAuthService,