From 91c9b8d0626362105c878d4a5bd5465010204671 Mon Sep 17 00:00:00 2001 From: erio Date: Sat, 4 Apr 2026 11:00:55 +0800 Subject: [PATCH 01/67] =?UTF-8?q?feat(channel):=20=E6=B8=A0=E9=81=93?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E7=B3=BB=E7=BB=9F=20=E2=80=94=20=E5=A4=9A?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E5=AE=9A=E4=BB=B7=20+=20=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E8=AE=A1=E8=B4=B9=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cherry-picked from release/custom-0.1.106: a9117600 --- backend/cmd/server/wire_gen.go | 9 +- .../internal/handler/admin/channel_handler.go | 308 +++++++++ backend/internal/handler/handler.go | 1 + backend/internal/handler/wire.go | 3 + backend/internal/repository/channel_repo.go | 392 +++++++++++ .../repository/channel_repo_pricing.go | 285 ++++++++ backend/internal/repository/wire.go | 1 + backend/internal/server/routes/admin.go | 14 + backend/internal/service/billing_service.go | 184 ++++- backend/internal/service/channel.go | 171 +++++ backend/internal/service/channel_service.go | 338 ++++++++++ backend/internal/service/channel_test.go | 210 ++++++ .../service/gateway_record_usage_test.go | 1 + backend/internal/service/gateway_service.go | 25 +- .../service/model_pricing_resolver.go | 198 ++++++ .../service/model_pricing_resolver_test.go | 164 +++++ backend/internal/service/wire.go | 2 + backend/migrations/081_create_channels.sql | 56 ++ .../082_refactor_channel_pricing.sql | 67 ++ frontend/src/api/admin/channels.ts | 121 ++++ frontend/src/api/admin/index.ts | 7 +- .../components/admin/channel/IntervalRow.vue | 160 +++++ .../admin/channel/PricingEntryCard.vue | 260 ++++++++ .../src/components/admin/channel/types.ts | 59 ++ frontend/src/components/layout/AppSidebar.vue | 16 + frontend/src/router/index.ts | 10 + frontend/src/views/admin/ChannelsView.vue | 628 ++++++++++++++++++ 27 files changed, 3682 insertions(+), 8 deletions(-) create mode 100644 backend/internal/handler/admin/channel_handler.go create mode 100644 backend/internal/repository/channel_repo.go create mode 100644 backend/internal/repository/channel_repo_pricing.go create mode 100644 backend/internal/service/channel.go create mode 100644 backend/internal/service/channel_service.go create mode 100644 backend/internal/service/channel_test.go create mode 100644 backend/internal/service/model_pricing_resolver.go create mode 100644 backend/internal/service/model_pricing_resolver_test.go create mode 100644 backend/migrations/081_create_channels.sql create mode 100644 backend/migrations/082_refactor_channel_pricing.sql create mode 100644 frontend/src/api/admin/channels.ts create mode 100644 frontend/src/components/admin/channel/IntervalRow.vue create mode 100644 frontend/src/components/admin/channel/PricingEntryCard.vue create mode 100644 frontend/src/components/admin/channel/types.ts create mode 100644 frontend/src/views/admin/ChannelsView.vue diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ce898a4a90..cde870def4 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -49,6 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) groupRepository := repository.NewGroupRepository(client, db) + channelRepository := repository.NewChannelRepository(db) settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) @@ -175,7 +176,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService) + channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator) + modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) + _ = modelPricingResolver // Phase 4: 已注册,后续 Gateway 迁移时使用 + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) @@ -213,7 +217,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db) scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler) + channelHandler := admin.NewChannelHandler(channelService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go new file mode 100644 index 0000000000..fb6f7d0204 --- /dev/null +++ b/backend/internal/handler/admin/channel_handler.go @@ -0,0 +1,308 @@ +package admin + +import ( + "strconv" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// ChannelHandler handles admin channel management +type ChannelHandler struct { + channelService *service.ChannelService +} + +// NewChannelHandler creates a new admin channel handler +func NewChannelHandler(channelService *service.ChannelService) *ChannelHandler { + return &ChannelHandler{channelService: channelService} +} + +// --- Request / Response types --- + +type createChannelRequest struct { + Name string `json:"name" binding:"required,max=100"` + Description string `json:"description"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingRequest `json:"model_pricing"` +} + +type updateChannelRequest struct { + Name string `json:"name" binding:"omitempty,max=100"` + Description *string `json:"description"` + Status string `json:"status" binding:"omitempty,oneof=active disabled"` + GroupIDs *[]int64 `json:"group_ids"` + ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` +} + +type channelModelPricingRequest struct { + Models []string `json:"models" binding:"required,min=1,max=100"` + BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` + InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` + OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"` + CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"` + CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"` + ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"` + Intervals []pricingIntervalRequest `json:"intervals"` +} + +type pricingIntervalRequest struct { + MinTokens int `json:"min_tokens"` + MaxTokens *int `json:"max_tokens"` + TierLabel string `json:"tier_label"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + PerRequestPrice *float64 `json:"per_request_price"` + SortOrder int `json:"sort_order"` +} + +type channelResponse struct { + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingResponse `json:"model_pricing"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type channelModelPricingResponse struct { + ID int64 `json:"id"` + Models []string `json:"models"` + BillingMode string `json:"billing_mode"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + ImageOutputPrice *float64 `json:"image_output_price"` + Intervals []pricingIntervalResponse `json:"intervals"` +} + +type pricingIntervalResponse struct { + ID int64 `json:"id"` + MinTokens int `json:"min_tokens"` + MaxTokens *int `json:"max_tokens"` + TierLabel string `json:"tier_label,omitempty"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + PerRequestPrice *float64 `json:"per_request_price"` + SortOrder int `json:"sort_order"` +} + +func channelToResponse(ch *service.Channel) *channelResponse { + if ch == nil { + return nil + } + resp := &channelResponse{ + ID: ch.ID, + Name: ch.Name, + Description: ch.Description, + Status: ch.Status, + GroupIDs: ch.GroupIDs, + CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), + UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"), + } + if resp.GroupIDs == nil { + resp.GroupIDs = []int64{} + } + + resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing)) + for _, p := range ch.ModelPricing { + models := p.Models + if models == nil { + models = []string{} + } + billingMode := string(p.BillingMode) + if billingMode == "" { + billingMode = "token" + } + intervals := make([]pricingIntervalResponse, 0, len(p.Intervals)) + for _, iv := range p.Intervals { + intervals = append(intervals, pricingIntervalResponse{ + ID: iv.ID, + MinTokens: iv.MinTokens, + MaxTokens: iv.MaxTokens, + TierLabel: iv.TierLabel, + InputPrice: iv.InputPrice, + OutputPrice: iv.OutputPrice, + CacheWritePrice: iv.CacheWritePrice, + CacheReadPrice: iv.CacheReadPrice, + PerRequestPrice: iv.PerRequestPrice, + SortOrder: iv.SortOrder, + }) + } + resp.ModelPricing = append(resp.ModelPricing, channelModelPricingResponse{ + ID: p.ID, + Models: models, + BillingMode: billingMode, + InputPrice: p.InputPrice, + OutputPrice: p.OutputPrice, + CacheWritePrice: p.CacheWritePrice, + CacheReadPrice: p.CacheReadPrice, + ImageOutputPrice: p.ImageOutputPrice, + Intervals: intervals, + }) + } + return resp +} + +func pricingRequestToService(reqs []channelModelPricingRequest) []service.ChannelModelPricing { + result := make([]service.ChannelModelPricing, 0, len(reqs)) + for _, r := range reqs { + billingMode := service.BillingMode(r.BillingMode) + if billingMode == "" { + billingMode = service.BillingModeToken + } + intervals := make([]service.PricingInterval, 0, len(r.Intervals)) + for _, iv := range r.Intervals { + intervals = append(intervals, service.PricingInterval{ + MinTokens: iv.MinTokens, + MaxTokens: iv.MaxTokens, + TierLabel: iv.TierLabel, + InputPrice: iv.InputPrice, + OutputPrice: iv.OutputPrice, + CacheWritePrice: iv.CacheWritePrice, + CacheReadPrice: iv.CacheReadPrice, + PerRequestPrice: iv.PerRequestPrice, + SortOrder: iv.SortOrder, + }) + } + result = append(result, service.ChannelModelPricing{ + Models: r.Models, + BillingMode: billingMode, + InputPrice: r.InputPrice, + OutputPrice: r.OutputPrice, + CacheWritePrice: r.CacheWritePrice, + CacheReadPrice: r.CacheReadPrice, + ImageOutputPrice: r.ImageOutputPrice, + Intervals: intervals, + }) + } + return result +} + +// --- Handlers --- + +// List handles listing channels with pagination +// GET /api/v1/admin/channels +func (h *ChannelHandler) List(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + status := c.Query("status") + search := strings.TrimSpace(c.Query("search")) + if len(search) > 100 { + search = search[:100] + } + + channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search) + if err != nil { + response.ErrorFrom(c, err) + return + } + + out := make([]*channelResponse, 0, len(channels)) + for i := range channels { + out = append(out, channelToResponse(&channels[i])) + } + response.Paginated(c, out, pag.Total, page, pageSize) +} + +// GetByID handles getting a channel by ID +// GET /api/v1/admin/channels/:id +func (h *ChannelHandler) GetByID(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid channel ID") + return + } + + channel, err := h.channelService.GetByID(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, channelToResponse(channel)) +} + +// Create handles creating a new channel +// POST /api/v1/admin/channels +func (h *ChannelHandler) Create(c *gin.Context) { + var req createChannelRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ + Name: req.Name, + Description: req.Description, + GroupIDs: req.GroupIDs, + ModelPricing: pricingRequestToService(req.ModelPricing), + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, channelToResponse(channel)) +} + +// Update handles updating a channel +// PUT /api/v1/admin/channels/:id +func (h *ChannelHandler) Update(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid channel ID") + return + } + + var req updateChannelRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + input := &service.UpdateChannelInput{ + Name: req.Name, + Description: req.Description, + Status: req.Status, + GroupIDs: req.GroupIDs, + } + if req.ModelPricing != nil { + pricing := pricingRequestToService(*req.ModelPricing) + input.ModelPricing = &pricing + } + + channel, err := h.channelService.Update(c.Request.Context(), id, input) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, channelToResponse(channel)) +} + +// Delete handles deleting a channel +// DELETE /api/v1/admin/channels/:id +func (h *ChannelHandler) Delete(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid channel ID") + return + } + + if err := h.channelService.Delete(c.Request.Context(), id); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Channel deleted successfully"}) +} diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index b2467eacdb..ebf8d5f674 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -30,6 +30,7 @@ type AdminHandlers struct { TLSFingerprintProfile *admin.TLSFingerprintProfileHandler APIKey *admin.AdminAPIKeyHandler ScheduledTest *admin.ScheduledTestHandler + Channel *admin.ChannelHandler } // Handlers contains all HTTP handlers diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 02ddd03098..c917f24a0d 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -33,6 +33,7 @@ func ProvideAdminHandlers( tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler, apiKeyHandler *admin.AdminAPIKeyHandler, scheduledTestHandler *admin.ScheduledTestHandler, + channelHandler *admin.ChannelHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -59,6 +60,7 @@ func ProvideAdminHandlers( TLSFingerprintProfile: tlsFingerprintProfileHandler, APIKey: apiKeyHandler, ScheduledTest: scheduledTestHandler, + Channel: channelHandler, } } @@ -150,6 +152,7 @@ var ProviderSet = wire.NewSet( admin.NewTLSFingerprintProfileHandler, admin.NewAdminAPIKeyHandler, admin.NewScheduledTestHandler, + admin.NewChannelHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go new file mode 100644 index 0000000000..aa8696abff --- /dev/null +++ b/backend/internal/repository/channel_repo.go @@ -0,0 +1,392 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +type channelRepository struct { + db *sql.DB +} + +// NewChannelRepository 创建渠道数据访问实例 +func NewChannelRepository(db *sql.DB) service.ChannelRepository { + return &channelRepository{db: db} +} + +// runInTx 在事务中执行 fn,成功 commit,失败 rollback。 +func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) error) error { + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + if err := fn(tx); err != nil { + return err + } + return tx.Commit() +} + +func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error { + return r.runInTx(ctx, func(tx *sql.Tx) error { + err := tx.QueryRowContext(ctx, + `INSERT INTO channels (name, description, status) VALUES ($1, $2, $3) + RETURNING id, created_at, updated_at`, + channel.Name, channel.Description, channel.Status, + ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) + if err != nil { + if isUniqueViolation(err) { + return service.ErrChannelExists + } + return fmt.Errorf("insert channel: %w", err) + } + + // 设置分组关联 + if len(channel.GroupIDs) > 0 { + if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil { + return err + } + } + + // 设置模型定价 + if len(channel.ModelPricing) > 0 { + if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil { + return err + } + } + + return nil + }) +} + +func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) { + ch := &service.Channel{} + err := r.db.QueryRowContext(ctx, + `SELECT id, name, description, status, created_at, updated_at + FROM channels WHERE id = $1`, id, + ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt) + if err == sql.ErrNoRows { + return nil, service.ErrChannelNotFound + } + if err != nil { + return nil, fmt.Errorf("get channel: %w", err) + } + + groupIDs, err := r.GetGroupIDs(ctx, id) + if err != nil { + return nil, err + } + ch.GroupIDs = groupIDs + + pricing, err := r.ListModelPricing(ctx, id) + if err != nil { + return nil, err + } + ch.ModelPricing = pricing + + return ch, nil +} + +func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error { + return r.runInTx(ctx, func(tx *sql.Tx) error { + result, err := tx.ExecContext(ctx, + `UPDATE channels SET name = $1, description = $2, status = $3, updated_at = NOW() + WHERE id = $4`, + channel.Name, channel.Description, channel.Status, channel.ID, + ) + if err != nil { + if isUniqueViolation(err) { + return service.ErrChannelExists + } + return fmt.Errorf("update channel: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + return service.ErrChannelNotFound + } + + // 更新分组关联 + if channel.GroupIDs != nil { + if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil { + return err + } + } + + // 更新模型定价 + if channel.ModelPricing != nil { + if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil { + return err + } + } + + return nil + }) +} + +func (r *channelRepository) Delete(ctx context.Context, id int64) error { + result, err := r.db.ExecContext(ctx, `DELETE FROM channels WHERE id = $1`, id) + if err != nil { + return fmt.Errorf("delete channel: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + return service.ErrChannelNotFound + } + return nil +} + +func (r *channelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.Channel, *pagination.PaginationResult, error) { + where := []string{"1=1"} + args := []any{} + argIdx := 1 + + if status != "" { + where = append(where, fmt.Sprintf("c.status = $%d", argIdx)) + args = append(args, status) + argIdx++ + } + if search != "" { + where = append(where, fmt.Sprintf("(c.name ILIKE $%d OR c.description ILIKE $%d)", argIdx, argIdx)) + args = append(args, "%"+escapeLike(search)+"%") + argIdx++ + } + + whereClause := strings.Join(where, " AND ") + + // 计数 + var total int64 + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM channels c WHERE %s", whereClause) + if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, nil, fmt.Errorf("count channels: %w", err) + } + + pageSize := params.Limit() // 约束在 [1, 100] + page := params.Page + if page < 1 { + page = 1 + } + offset := (page - 1) * pageSize + + // 查询 channel 列表 + dataQuery := fmt.Sprintf( + `SELECT c.id, c.name, c.description, c.status, c.created_at, c.updated_at + FROM channels c WHERE %s ORDER BY c.id DESC LIMIT $%d OFFSET $%d`, + whereClause, argIdx, argIdx+1, + ) + args = append(args, pageSize, offset) + + rows, err := r.db.QueryContext(ctx, dataQuery, args...) + if err != nil { + return nil, nil, fmt.Errorf("query channels: %w", err) + } + defer rows.Close() + + var channels []service.Channel + var channelIDs []int64 + for rows.Next() { + var ch service.Channel + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + return nil, nil, fmt.Errorf("scan channel: %w", err) + } + channels = append(channels, ch) + channelIDs = append(channelIDs, ch.ID) + } + if err := rows.Err(); err != nil { + return nil, nil, fmt.Errorf("iterate channels: %w", err) + } + + // 批量加载分组 ID 和模型定价(避免 N+1) + if len(channelIDs) > 0 { + groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs) + if err != nil { + return nil, nil, err + } + pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs) + if err != nil { + return nil, nil, err + } + for i := range channels { + channels[i].GroupIDs = groupMap[channels[i].ID] + channels[i].ModelPricing = pricingMap[channels[i].ID] + } + } + + pages := 0 + if total > 0 { + pages = int((total + int64(pageSize) - 1) / int64(pageSize)) + } + + paginationResult := &pagination.PaginationResult{ + Total: total, + Page: page, + PageSize: pageSize, + Pages: pages, + } + + return channels, paginationResult, nil +} + +func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT id, name, description, status, created_at, updated_at FROM channels ORDER BY id`, + ) + if err != nil { + return nil, fmt.Errorf("query all channels: %w", err) + } + defer rows.Close() + + var channels []service.Channel + var channelIDs []int64 + for rows.Next() { + var ch service.Channel + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + return nil, fmt.Errorf("scan channel: %w", err) + } + channels = append(channels, ch) + channelIDs = append(channelIDs, ch.ID) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate channels: %w", err) + } + + if len(channelIDs) == 0 { + return channels, nil + } + + // 批量加载分组 ID + groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs) + if err != nil { + return nil, err + } + + // 批量加载模型定价 + pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs) + if err != nil { + return nil, err + } + + for i := range channels { + channels[i].GroupIDs = groupMap[channels[i].ID] + channels[i].ModelPricing = pricingMap[channels[i].ID] + } + + return channels, nil +} + +// --- 批量加载辅助方法 --- + +// batchLoadGroupIDs 批量加载多个渠道的分组 ID +func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs []int64) (map[int64][]int64, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT channel_id, group_id FROM channel_groups + WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`, + pq.Array(channelIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load group ids: %w", err) + } + defer rows.Close() + + groupMap := make(map[int64][]int64, len(channelIDs)) + for rows.Next() { + var channelID, groupID int64 + if err := rows.Scan(&channelID, &groupID); err != nil { + return nil, fmt.Errorf("scan group id: %w", err) + } + groupMap[channelID] = append(groupMap[channelID], groupID) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate group ids: %w", err) + } + return groupMap, nil +} + +func (r *channelRepository) ExistsByName(ctx context.Context, name string) (bool, error) { + var exists bool + err := r.db.QueryRowContext(ctx, + `SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`, name, + ).Scan(&exists) + return exists, err +} + +func (r *channelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) { + var exists bool + err := r.db.QueryRowContext(ctx, + `SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`, name, excludeID, + ).Scan(&exists) + return exists, err +} + +// --- 分组关联 --- + +func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`, channelID, + ) + if err != nil { + return nil, fmt.Errorf("get group ids: %w", err) + } + defer rows.Close() + + var ids []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, fmt.Errorf("scan group id: %w", err) + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate group ids: %w", err) + } + return ids, nil +} + +func (r *channelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error { + return setGroupIDsTx(ctx, r.db, channelID, groupIDs) +} + +func (r *channelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + var channelID int64 + err := r.db.QueryRowContext(ctx, + `SELECT channel_id FROM channel_groups WHERE group_id = $1`, groupID, + ).Scan(&channelID) + if err == sql.ErrNoRows { + return 0, nil + } + return channelID, err +} + +func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) { + if len(groupIDs) == 0 { + return nil, nil + } + rows, err := r.db.QueryContext(ctx, + `SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`, + pq.Array(groupIDs), channelID, + ) + if err != nil { + return nil, fmt.Errorf("get groups in other channels: %w", err) + } + defer rows.Close() + + var conflicting []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, fmt.Errorf("scan conflicting group id: %w", err) + } + conflicting = append(conflicting, id) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate conflicting group ids: %w", err) + } + return conflicting, nil +} diff --git a/backend/internal/repository/channel_repo_pricing.go b/backend/internal/repository/channel_repo_pricing.go new file mode 100644 index 0000000000..2e7ec6a311 --- /dev/null +++ b/backend/internal/repository/channel_repo_pricing.go @@ -0,0 +1,285 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +// --- 模型定价 --- + +func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at + FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID, + ) + if err != nil { + return nil, fmt.Errorf("list model pricing: %w", err) + } + defer rows.Close() + + result, pricingIDs, err := scanModelPricingRows(rows) + if err != nil { + return nil, err + } + + if len(pricingIDs) > 0 { + intervalMap, err := r.batchLoadIntervals(ctx, pricingIDs) + if err != nil { + return nil, err + } + for i := range result { + result[i].Intervals = intervalMap[result[i].ID] + } + } + + return result, nil +} + +func (r *channelRepository) CreateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error { + return createModelPricingExec(ctx, r.db, pricing) +} + +func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error { + modelsJSON, err := json.Marshal(pricing.Models) + if err != nil { + return fmt.Errorf("marshal models: %w", err) + } + billingMode := pricing.BillingMode + if billingMode == "" { + billingMode = service.BillingModeToken + } + result, err := r.db.ExecContext(ctx, + `UPDATE channel_model_pricing + SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, updated_at = NOW() + WHERE id = $8`, + modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, + pricing.ImageOutputPrice, pricing.ID, + ) + if err != nil { + return fmt.Errorf("update model pricing: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + return fmt.Errorf("pricing entry not found: %d", pricing.ID) + } + return nil +} + +func (r *channelRepository) DeleteModelPricing(ctx context.Context, id int64) error { + _, err := r.db.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE id = $1`, id) + if err != nil { + return fmt.Errorf("delete model pricing: %w", err) + } + return nil +} + +func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []service.ChannelModelPricing) error { + return r.runInTx(ctx, func(tx *sql.Tx) error { + return replaceModelPricingTx(ctx, tx, channelID, pricingList) + }) +} + +// --- 批量加载辅助方法 --- + +// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间) +func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at + FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`, + pq.Array(channelIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load model pricing: %w", err) + } + defer rows.Close() + + allPricing, allPricingIDs, err := scanModelPricingRows(rows) + if err != nil { + return nil, err + } + + // 按 channelID 分组 + pricingMap := make(map[int64][]service.ChannelModelPricing, len(channelIDs)) + for _, p := range allPricing { + pricingMap[p.ChannelID] = append(pricingMap[p.ChannelID], p) + } + + // 批量加载所有区间 + if len(allPricingIDs) > 0 { + intervalMap, err := r.batchLoadIntervals(ctx, allPricingIDs) + if err != nil { + return nil, err + } + for chID := range pricingMap { + for i := range pricingMap[chID] { + pricingMap[chID][i].Intervals = intervalMap[pricingMap[chID][i].ID] + } + } + } + + return pricingMap, nil +} + +// batchLoadIntervals 批量加载多个定价条目的区间 +func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT id, pricing_id, min_tokens, max_tokens, tier_label, + input_price, output_price, cache_write_price, cache_read_price, + per_request_price, sort_order, created_at, updated_at + FROM channel_pricing_intervals + WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`, + pq.Array(pricingIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load intervals: %w", err) + } + defer rows.Close() + + intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs)) + for rows.Next() { + var iv service.PricingInterval + if err := rows.Scan( + &iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel, + &iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice, + &iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan interval: %w", err) + } + intervalMap[iv.PricingID] = append(intervalMap[iv.PricingID], iv) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate intervals: %w", err) + } + return intervalMap, nil +} + +// --- 共享 scan 辅助 --- + +// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表 +func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int64, error) { + var result []service.ChannelModelPricing + var pricingIDs []int64 + for rows.Next() { + var p service.ChannelModelPricing + var modelsJSON []byte + if err := rows.Scan( + &p.ID, &p.ChannelID, &modelsJSON, &p.BillingMode, + &p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice, + &p.ImageOutputPrice, &p.CreatedAt, &p.UpdatedAt, + ); err != nil { + return nil, nil, fmt.Errorf("scan model pricing: %w", err) + } + if err := json.Unmarshal(modelsJSON, &p.Models); err != nil { + p.Models = []string{} + } + pricingIDs = append(pricingIDs, p.ID) + result = append(result, p) + } + if err := rows.Err(); err != nil { + return nil, nil, fmt.Errorf("iterate model pricing: %w", err) + } + return result, pricingIDs, nil +} + +// --- 事务内辅助方法 --- + +// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口 +type dbExec interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +func setGroupIDsTx(ctx context.Context, exec dbExec, channelID int64, groupIDs []int64) error { + if _, err := exec.ExecContext(ctx, `DELETE FROM channel_groups WHERE channel_id = $1`, channelID); err != nil { + return fmt.Errorf("delete old group associations: %w", err) + } + if len(groupIDs) == 0 { + return nil + } + _, err := exec.ExecContext(ctx, + `INSERT INTO channel_groups (channel_id, group_id) + SELECT $1, unnest($2::bigint[])`, + channelID, pq.Array(groupIDs), + ) + if err != nil { + return fmt.Errorf("insert group associations: %w", err) + } + return nil +} + +func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.ChannelModelPricing) error { + modelsJSON, err := json.Marshal(pricing.Models) + if err != nil { + return fmt.Errorf("marshal models: %w", err) + } + billingMode := pricing.BillingMode + if billingMode == "" { + billingMode = service.BillingModeToken + } + err = exec.QueryRowContext(ctx, + `INSERT INTO channel_model_pricing (channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, updated_at`, + pricing.ChannelID, modelsJSON, billingMode, + pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, + pricing.ImageOutputPrice, + ).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt) + if err != nil { + return fmt.Errorf("insert model pricing: %w", err) + } + + for i := range pricing.Intervals { + pricing.Intervals[i].PricingID = pricing.ID + if err := createIntervalExec(ctx, exec, &pricing.Intervals[i]); err != nil { + return err + } + } + + return nil +} + +func createIntervalExec(ctx context.Context, exec dbExec, iv *service.PricingInterval) error { + return exec.QueryRowContext(ctx, + `INSERT INTO channel_pricing_intervals + (pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`, + iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel, + iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice, + iv.PerRequestPrice, iv.SortOrder, + ).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt) +} + +func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pricingList []service.ChannelModelPricing) error { + if _, err := exec.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE channel_id = $1`, channelID); err != nil { + return fmt.Errorf("delete old model pricing: %w", err) + } + for i := range pricingList { + pricingList[i].ChannelID = channelID + if err := createModelPricingExec(ctx, exec, &pricingList[i]); err != nil { + return fmt.Errorf("insert model pricing: %w", err) + } + } + return nil +} + +// isUniqueViolation 检查 pq 唯一约束违反错误 +func isUniqueViolation(err error) bool { + if pqErr, ok := err.(*pq.Error); ok { + return pqErr.Code == "23505" + } + return false +} + +// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符 +func escapeLike(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, `%`, `\%`) + s = strings.ReplaceAll(s, `_`, `\_`) + return s +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 49d47bf63c..4548c02882 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet( NewUserGroupRateRepository, NewErrorPassthroughRepository, NewTLSFingerprintProfileRepository, + NewChannelRepository, // Cache implementations NewGatewayCache, diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index e04dae8521..abc28295fb 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -87,6 +87,9 @@ func RegisterAdminRoutes( // 定时测试计划 registerScheduledTestRoutes(admin, h) + + // 渠道管理 + registerChannelRoutes(admin, h) } } @@ -567,3 +570,14 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete) } } + +func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + channels := admin.Group("/channels") + { + channels.GET("", h.Admin.Channel.List) + channels.GET("/:id", h.Admin.Channel.GetByID) + channels.POST("", h.Admin.Channel.Create) + channels.PUT("/:id", h.Admin.Channel.Update) + channels.DELETE("/:id", h.Admin.Channel.Delete) + } +} diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 004511f5dd..58c86f36ee 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -371,13 +371,193 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { return nil, fmt.Errorf("pricing not found for model: %s", model) } +// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值 +// 仅覆盖渠道中非 nil 的价格字段,nil 字段使用默认定价 +func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing *ChannelModelPricing) (*ModelPricing, error) { + pricing, err := s.GetModelPricing(model) + if err != nil { + return nil, err + } + if channelPricing == nil { + return pricing, nil + } + if channelPricing.InputPrice != nil { + pricing.InputPricePerToken = *channelPricing.InputPrice + pricing.InputPricePerTokenPriority = *channelPricing.InputPrice + } + if channelPricing.OutputPrice != nil { + pricing.OutputPricePerToken = *channelPricing.OutputPrice + pricing.OutputPricePerTokenPriority = *channelPricing.OutputPrice + } + if channelPricing.CacheWritePrice != nil { + pricing.CacheCreationPricePerToken = *channelPricing.CacheWritePrice + pricing.CacheCreation5mPrice = *channelPricing.CacheWritePrice + pricing.CacheCreation1hPrice = *channelPricing.CacheWritePrice + } + if channelPricing.CacheReadPrice != nil { + pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice + pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice + } + return pricing, nil +} + +// CalculateCostWithChannel 使用渠道定价计算费用 +// Deprecated: 使用 CalculateCostUnified 代替 +func (s *BillingService) CalculateCostWithChannel(model string, tokens UsageTokens, rateMultiplier float64, channelPricing *ChannelModelPricing) (*CostBreakdown, error) { + return s.calculateCostInternal(model, tokens, rateMultiplier, "", channelPricing) +} + +// --- 统一计费入口 --- + +// CostInput 统一计费输入 +type CostInput struct { + Ctx context.Context + Model string + GroupID *int64 // 用于渠道定价查找 + Tokens UsageTokens + RequestCount int // 按次计费时使用 + SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等) + RateMultiplier float64 + ServiceTier string // "priority","flex","" 等 + Resolver *ModelPricingResolver // 定价解析器 +} + +// CalculateCostUnified 统一计费入口,支持三种计费模式。 +// 使用 ModelPricingResolver 解析定价,然后根据 BillingMode 分发计算。 +func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, error) { + if input.Resolver == nil { + // 无 Resolver,回退到旧路径 + return s.calculateCostInternal(input.Model, input.Tokens, input.RateMultiplier, input.ServiceTier, nil) + } + + resolved := input.Resolver.Resolve(input.Ctx, PricingInput{ + Model: input.Model, + GroupID: input.GroupID, + }) + + if input.RateMultiplier <= 0 { + input.RateMultiplier = 1.0 + } + + switch resolved.Mode { + case BillingModePerRequest, BillingModeImage: + return s.calculatePerRequestCost(resolved, input) + default: // BillingModeToken + return s.calculateTokenCost(resolved, input) + } +} + +// calculateTokenCost 按 token 区间计费 +func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) { + totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens + + pricing := input.Resolver.GetIntervalPricing(resolved, totalContext) + if pricing == nil { + return nil, fmt.Errorf("no pricing available for model: %s", input.Model) + } + + pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing) + + breakdown := &CostBreakdown{} + inputPricePerToken := pricing.InputPricePerToken + outputPricePerToken := pricing.OutputPricePerToken + cacheReadPricePerToken := pricing.CacheReadPricePerToken + tierMultiplier := 1.0 + + if usePriorityServiceTierPricing(input.ServiceTier, pricing) { + if pricing.InputPricePerTokenPriority > 0 { + inputPricePerToken = pricing.InputPricePerTokenPriority + } + if pricing.OutputPricePerTokenPriority > 0 { + outputPricePerToken = pricing.OutputPricePerTokenPriority + } + if pricing.CacheReadPricePerTokenPriority > 0 { + cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority + } + } else { + tierMultiplier = serviceTierCostMultiplier(input.ServiceTier) + } + + // 长上下文定价(仅在无区间定价时应用,区间定价已包含上下文分层) + if len(resolved.Intervals) == 0 && s.shouldApplySessionLongContextPricing(input.Tokens, pricing) { + inputPricePerToken *= pricing.LongContextInputMultiplier + outputPricePerToken *= pricing.LongContextOutputMultiplier + } + + breakdown.InputCost = float64(input.Tokens.InputTokens) * inputPricePerToken + breakdown.OutputCost = float64(input.Tokens.OutputTokens) * outputPricePerToken + + if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { + if input.Tokens.CacheCreation5mTokens == 0 && input.Tokens.CacheCreation1hTokens == 0 && input.Tokens.CacheCreationTokens > 0 { + breakdown.CacheCreationCost = float64(input.Tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice + } else { + breakdown.CacheCreationCost = float64(input.Tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice + + float64(input.Tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice + } + } else { + breakdown.CacheCreationCost = float64(input.Tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken + } + + breakdown.CacheReadCost = float64(input.Tokens.CacheReadTokens) * cacheReadPricePerToken + + if tierMultiplier != 1.0 { + breakdown.InputCost *= tierMultiplier + breakdown.OutputCost *= tierMultiplier + breakdown.CacheCreationCost *= tierMultiplier + breakdown.CacheReadCost *= tierMultiplier + } + + breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + + breakdown.CacheCreationCost + breakdown.CacheReadCost + breakdown.ActualCost = breakdown.TotalCost * input.RateMultiplier + + return breakdown, nil +} + +// calculatePerRequestCost 按次/图片计费 +func (s *BillingService) calculatePerRequestCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) { + count := input.RequestCount + if count <= 0 { + count = 1 + } + + var unitPrice float64 + + if input.SizeTier != "" { + unitPrice = input.Resolver.GetRequestTierPrice(resolved, input.SizeTier) + } + + if unitPrice == 0 { + totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens + unitPrice = input.Resolver.GetRequestTierPriceByContext(resolved, totalContext) + } + + totalCost := unitPrice * float64(count) + actualCost := totalCost * input.RateMultiplier + + return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + }, nil +} + // CalculateCost 计算使用费用 func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) { - return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "") + return s.calculateCostInternal(model, tokens, rateMultiplier, "", nil) } func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) { - pricing, err := s.GetModelPricing(model) + return s.calculateCostInternal(model, tokens, rateMultiplier, serviceTier, nil) +} + +func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string, channelPricing *ChannelModelPricing) (*CostBreakdown, error) { + var pricing *ModelPricing + var err error + if channelPricing != nil { + pricing, err = s.GetModelPricingWithChannel(model, channelPricing) + } else { + pricing, err = s.GetModelPricing(model) + } if err != nil { return nil, err } diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go new file mode 100644 index 0000000000..e3556edd0d --- /dev/null +++ b/backend/internal/service/channel.go @@ -0,0 +1,171 @@ +package service + +import ( + "strings" + "time" +) + +// BillingMode 计费模式 +type BillingMode string + +const ( + BillingModeToken BillingMode = "token" // 按 token 区间计费 + BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层) + BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费) +) + +// IsValid 检查 BillingMode 是否为合法值 +func (m BillingMode) IsValid() bool { + switch m { + case BillingModeToken, BillingModePerRequest, BillingModeImage, "": + return true + } + return false +} + +// Channel 渠道实体 +type Channel struct { + ID int64 + Name string + Description string + Status string + CreatedAt time.Time + UpdatedAt time.Time + + // 关联的分组 ID 列表 + GroupIDs []int64 + // 模型定价列表 + ModelPricing []ChannelModelPricing +} + +// ChannelModelPricing 渠道模型定价条目 +type ChannelModelPricing struct { + ID int64 + ChannelID int64 + Models []string // 绑定的模型列表 + BillingMode BillingMode // 计费模式 + InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 + OutputPrice *float64 // 每 token 输出价格(USD) + CacheWritePrice *float64 // 缓存写入价格 + CacheReadPrice *float64 // 缓存读取价格 + ImageOutputPrice *float64 // 图片输出价格(向后兼容) + Intervals []PricingInterval // 区间定价列表 + CreatedAt time.Time + UpdatedAt time.Time +} + +// PricingInterval 定价区间(token 区间 / 按次分层 / 图片分辨率分层) +type PricingInterval struct { + ID int64 + PricingID int64 + MinTokens int // 区间下界(含) + MaxTokens *int // 区间上界(不含),nil = 无上限 + TierLabel string // 层级标签(按次/图片模式:1K, 2K, 4K, HD 等) + InputPrice *float64 // token 模式:每 token 输入价 + OutputPrice *float64 // token 模式:每 token 输出价 + CacheWritePrice *float64 // token 模式:缓存写入价 + CacheReadPrice *float64 // token 模式:缓存读取价 + PerRequestPrice *float64 // 按次/图片模式:每次请求价格 + SortOrder int + CreatedAt time.Time + UpdatedAt time.Time +} + +// IsActive 判断渠道是否启用 +func (c *Channel) IsActive() bool { + return c.Status == StatusActive +} + +// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。 +// 优先精确匹配,然后通配符匹配(如 claude-opus-*)。大小写不敏感。 +// 返回值拷贝,不污染缓存。 +func (c *Channel) GetModelPricing(model string) *ChannelModelPricing { + modelLower := strings.ToLower(model) + + // 第一轮:精确匹配 + for i := range c.ModelPricing { + for _, m := range c.ModelPricing[i].Models { + if strings.ToLower(m) == modelLower { + cp := c.ModelPricing[i].Clone() + return &cp + } + } + } + + // 第二轮:通配符匹配(仅支持末尾 *) + for i := range c.ModelPricing { + for _, m := range c.ModelPricing[i].Models { + mLower := strings.ToLower(m) + if strings.HasSuffix(mLower, "*") { + prefix := strings.TrimSuffix(mLower, "*") + if strings.HasPrefix(modelLower, prefix) { + cp := c.ModelPricing[i].Clone() + return &cp + } + } + } + } + + return nil +} + +// FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。 +// 通用辅助函数,供 GetIntervalForContext、ModelPricingResolver 等复用。 +func FindMatchingInterval(intervals []PricingInterval, totalTokens int) *PricingInterval { + for i := range intervals { + iv := &intervals[i] + if totalTokens >= iv.MinTokens && (iv.MaxTokens == nil || totalTokens < *iv.MaxTokens) { + return iv + } + } + return nil +} + +// GetIntervalForContext 根据总 context token 数查找匹配的区间。 +func (p *ChannelModelPricing) GetIntervalForContext(totalTokens int) *PricingInterval { + return FindMatchingInterval(p.Intervals, totalTokens) +} + +// GetTierByLabel 根据标签查找层级(用于 per_request / image 模式) +func (p *ChannelModelPricing) GetTierByLabel(label string) *PricingInterval { + labelLower := strings.ToLower(label) + for i := range p.Intervals { + if strings.ToLower(p.Intervals[i].TierLabel) == labelLower { + return &p.Intervals[i] + } + } + return nil +} + +// Clone 返回 ChannelModelPricing 的拷贝(切片独立,指针字段共享,调用方只读安全) +func (p ChannelModelPricing) Clone() ChannelModelPricing { + cp := p + if p.Models != nil { + cp.Models = make([]string, len(p.Models)) + copy(cp.Models, p.Models) + } + if p.Intervals != nil { + cp.Intervals = make([]PricingInterval, len(p.Intervals)) + copy(cp.Intervals, p.Intervals) + } + return cp +} + +// Clone 返回 Channel 的深拷贝 +func (c *Channel) Clone() *Channel { + if c == nil { + return nil + } + cp := *c + if c.GroupIDs != nil { + cp.GroupIDs = make([]int64, len(c.GroupIDs)) + copy(cp.GroupIDs, c.GroupIDs) + } + if c.ModelPricing != nil { + cp.ModelPricing = make([]ChannelModelPricing, len(c.ModelPricing)) + for i := range c.ModelPricing { + cp.ModelPricing[i] = c.ModelPricing[i].Clone() + } + } + return &cp +} diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go new file mode 100644 index 0000000000..8f00481fd7 --- /dev/null +++ b/backend/internal/service/channel_service.go @@ -0,0 +1,338 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "sync/atomic" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "golang.org/x/sync/singleflight" +) + +var ( + ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found") + ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists") + ErrGroupAlreadyInChannel = infraerrors.Conflict( + "GROUP_ALREADY_IN_CHANNEL", + "one or more groups already belong to another channel", + ) +) + +// ChannelRepository 渠道数据访问接口 +type ChannelRepository interface { + Create(ctx context.Context, channel *Channel) error + GetByID(ctx context.Context, id int64) (*Channel, error) + Update(ctx context.Context, channel *Channel) error + Delete(ctx context.Context, id int64) error + List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) + ListAll(ctx context.Context) ([]Channel, error) + ExistsByName(ctx context.Context, name string) (bool, error) + ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) + + // 分组关联 + GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) + SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error + GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) + GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) + + // 模型定价 + ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) + CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error + UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error + DeleteModelPricing(ctx context.Context, id int64) error + ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error +} + +// channelCache 渠道缓存快照 +type channelCache struct { + // byID: channelID -> *Channel(含 ModelPricing) + byID map[int64]*Channel + // byGroupID: groupID -> channelID + byGroupID map[int64]int64 + loadedAt time.Time +} + +const ( + channelCacheTTL = 60 * time.Second + channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 + channelCacheDBTimeout = 10 * time.Second +) + +// ChannelService 渠道管理服务 +type ChannelService struct { + repo ChannelRepository + authCacheInvalidator APIKeyAuthCacheInvalidator + + cache atomic.Value // *channelCache + cacheSF singleflight.Group +} + +// NewChannelService 创建渠道服务实例 +func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService { + s := &ChannelService{ + repo: repo, + authCacheInvalidator: authCacheInvalidator, + } + return s +} + +// loadCache 加载或返回缓存的渠道数据 +func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, error) { + if cached, ok := s.cache.Load().(*channelCache); ok { + if time.Since(cached.loadedAt) < channelCacheTTL { + return cached, nil + } + } + + result, err, _ := s.cacheSF.Do("channel_cache", func() (any, error) { + // 双重检查 + if cached, ok := s.cache.Load().(*channelCache); ok { + if time.Since(cached.loadedAt) < channelCacheTTL { + return cached, nil + } + } + return s.buildCache(ctx) + }) + if err != nil { + return nil, err + } + return result.(*channelCache), nil +} + +// buildCache 从数据库构建渠道缓存。 +// 使用独立 context 避免请求取消导致空值被长期缓存。 +func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) { + // 断开请求取消链,避免客户端断连导致空值被长期缓存 + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout) + defer cancel() + + channels, err := s.repo.ListAll(dbCtx) + if err != nil { + // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 + slog.Warn("failed to build channel cache", "error", err) + errorCache := &channelCache{ + byID: make(map[int64]*Channel), + byGroupID: make(map[int64]int64), + loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL + } + s.cache.Store(errorCache) + return nil, fmt.Errorf("list all channels: %w", err) + } + + cache := &channelCache{ + byID: make(map[int64]*Channel, len(channels)), + byGroupID: make(map[int64]int64), + loadedAt: time.Now(), + } + + for i := range channels { + ch := &channels[i] + cache.byID[ch.ID] = ch + for _, gid := range ch.GroupIDs { + cache.byGroupID[gid] = ch.ID + } + } + + s.cache.Store(cache) + return cache, nil +} + +// invalidateCache 使缓存失效,让下次读取时自然重建 +func (s *ChannelService) invalidateCache() { + s.cache.Store((*channelCache)(nil)) + s.cacheSF.Forget("channel_cache") +} + +// GetChannelForGroup 获取分组关联的渠道(热路径,从缓存读取) +// 返回深拷贝,不污染缓存。 +func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) { + cache, err := s.loadCache(ctx) + if err != nil { + return nil, err + } + + channelID, ok := cache.byGroupID[groupID] + if !ok { + return nil, nil + } + + ch, ok := cache.byID[channelID] + if !ok { + return nil, nil + } + + if !ch.IsActive() { + return nil, nil + } + + return ch.Clone(), nil +} + +// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径) +func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing { + ch, err := s.GetChannelForGroup(ctx, groupID) + if err != nil { + slog.Warn("failed to get channel for group", "group_id", groupID, "error", err) + return nil + } + if ch == nil { + return nil + } + return ch.GetModelPricing(model) +} + +// --- CRUD --- + +// Create 创建渠道 +func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) (*Channel, error) { + exists, err := s.repo.ExistsByName(ctx, input.Name) + if err != nil { + return nil, fmt.Errorf("check channel exists: %w", err) + } + if exists { + return nil, ErrChannelExists + } + + // 检查分组冲突 + if len(input.GroupIDs) > 0 { + conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs) + if err != nil { + return nil, fmt.Errorf("check group conflicts: %w", err) + } + if len(conflicting) > 0 { + return nil, ErrGroupAlreadyInChannel + } + } + + channel := &Channel{ + Name: input.Name, + Description: input.Description, + Status: StatusActive, + GroupIDs: input.GroupIDs, + ModelPricing: input.ModelPricing, + } + + if err := s.repo.Create(ctx, channel); err != nil { + return nil, fmt.Errorf("create channel: %w", err) + } + + s.invalidateCache() + return s.repo.GetByID(ctx, channel.ID) +} + +// GetByID 获取渠道详情 +func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) { + return s.repo.GetByID(ctx, id) +} + +// Update 更新渠道 +func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChannelInput) (*Channel, error) { + channel, err := s.repo.GetByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("get channel: %w", err) + } + + if input.Name != "" && input.Name != channel.Name { + exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id) + if err != nil { + return nil, fmt.Errorf("check channel exists: %w", err) + } + if exists { + return nil, ErrChannelExists + } + channel.Name = input.Name + } + + if input.Description != nil { + channel.Description = *input.Description + } + + if input.Status != "" { + channel.Status = input.Status + } + + // 检查分组冲突 + if input.GroupIDs != nil { + conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs) + if err != nil { + return nil, fmt.Errorf("check group conflicts: %w", err) + } + if len(conflicting) > 0 { + return nil, ErrGroupAlreadyInChannel + } + channel.GroupIDs = *input.GroupIDs + } + + if input.ModelPricing != nil { + channel.ModelPricing = *input.ModelPricing + } + + if err := s.repo.Update(ctx, channel); err != nil { + return nil, fmt.Errorf("update channel: %w", err) + } + + s.invalidateCache() + + // 失效关联分组的 auth 缓存 + if s.authCacheInvalidator != nil { + groupIDs, err := s.repo.GetGroupIDs(ctx, id) + if err != nil { + slog.Warn("failed to get group IDs for cache invalidation", "channel_id", id, "error", err) + } + for _, gid := range groupIDs { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } + + return s.repo.GetByID(ctx, id) +} + +// Delete 删除渠道 +func (s *ChannelService) Delete(ctx context.Context, id int64) error { + // 先获取关联分组用于失效缓存 + groupIDs, err := s.repo.GetGroupIDs(ctx, id) + if err != nil { + slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err) + } + + if err := s.repo.Delete(ctx, id); err != nil { + return fmt.Errorf("delete channel: %w", err) + } + + s.invalidateCache() + + if s.authCacheInvalidator != nil { + for _, gid := range groupIDs { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } + + return nil +} + +// List 获取渠道列表 +func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) { + return s.repo.List(ctx, params, status, search) +} + +// --- Input types --- + +// CreateChannelInput 创建渠道输入 +type CreateChannelInput struct { + Name string + Description string + GroupIDs []int64 + ModelPricing []ChannelModelPricing +} + +// UpdateChannelInput 更新渠道输入 +type UpdateChannelInput struct { + Name string + Description *string + Status string + GroupIDs *[]int64 + ModelPricing *[]ChannelModelPricing +} diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go new file mode 100644 index 0000000000..004d06b1d7 --- /dev/null +++ b/backend/internal/service/channel_test.go @@ -0,0 +1,210 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func channelTestPtrFloat64(v float64) *float64 { return &v } +func channelTestPtrInt(v int) *int { return &v } + +func TestGetModelPricing(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)}, + {ID: 2, Models: []string{"claude-*"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(5e-6)}, + {ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest}, + }, + } + + tests := []struct { + name string + model string + wantID int64 + wantNil bool + }{ + {"exact match", "claude-sonnet-4", 1, false}, + {"case insensitive", "Claude-Sonnet-4", 1, false}, + {"wildcard match", "claude-opus-4-20250514", 2, false}, + {"exact takes priority over wildcard", "claude-sonnet-4", 1, false}, + {"not found", "gemini-3.1-pro", 0, true}, + {"per_request model", "gpt-5.1", 3, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ch.GetModelPricing(tt.model) + if tt.wantNil { + require.Nil(t, result) + return + } + require.NotNil(t, result) + require.Equal(t, tt.wantID, result.ID) + }) + } +} + +func TestGetModelPricing_ReturnsCopy(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: channelTestPtrFloat64(3e-6)}, + }, + } + + result := ch.GetModelPricing("claude-sonnet-4") + require.NotNil(t, result) + + // Modify the returned copy's slice — original should be unchanged + result.Models = append(result.Models, "hacked") + + // Original should be unchanged + require.Equal(t, 1, len(ch.ModelPricing[0].Models)) +} + +func TestGetModelPricing_EmptyPricing(t *testing.T) { + ch := &Channel{ModelPricing: nil} + require.Nil(t, ch.GetModelPricing("any-model")) + + ch2 := &Channel{ModelPricing: []ChannelModelPricing{}} + require.Nil(t, ch2.GetModelPricing("any-model")) +} + +func TestGetIntervalForContext(t *testing.T) { + p := &ChannelModelPricing{ + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: channelTestPtrInt(128000), InputPrice: channelTestPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: channelTestPtrFloat64(2e-6)}, + }, + } + + tests := []struct { + name string + tokens int + wantPrice *float64 + wantNil bool + }{ + {"first interval", 50000, channelTestPtrFloat64(1e-6), false}, + {"boundary: at min of second", 128000, channelTestPtrFloat64(2e-6), false}, + {"boundary: at max of first (exclusive)", 128000, channelTestPtrFloat64(2e-6), false}, + {"unbounded interval", 500000, channelTestPtrFloat64(2e-6), false}, + {"zero tokens", 0, channelTestPtrFloat64(1e-6), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := p.GetIntervalForContext(tt.tokens) + if tt.wantNil { + require.Nil(t, result) + return + } + require.NotNil(t, result) + require.InDelta(t, *tt.wantPrice, *result.InputPrice, 1e-12) + }) + } +} + +func TestGetIntervalForContext_NoMatch(t *testing.T) { + p := &ChannelModelPricing{ + Intervals: []PricingInterval{ + {MinTokens: 10000, MaxTokens: channelTestPtrInt(50000)}, + }, + } + require.Nil(t, p.GetIntervalForContext(5000)) + require.Nil(t, p.GetIntervalForContext(50000)) +} + +func TestGetIntervalForContext_Empty(t *testing.T) { + p := &ChannelModelPricing{Intervals: nil} + require.Nil(t, p.GetIntervalForContext(1000)) +} + +func TestGetTierByLabel(t *testing.T) { + p := &ChannelModelPricing{ + Intervals: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: channelTestPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: channelTestPtrFloat64(0.08)}, + {TierLabel: "HD", PerRequestPrice: channelTestPtrFloat64(0.12)}, + }, + } + + tests := []struct { + name string + label string + wantNil bool + want float64 + }{ + {"exact match", "1K", false, 0.04}, + {"case insensitive", "hd", false, 0.12}, + {"not found", "4K", true, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := p.GetTierByLabel(tt.label) + if tt.wantNil { + require.Nil(t, result) + return + } + require.NotNil(t, result) + require.InDelta(t, tt.want, *result.PerRequestPrice, 1e-12) + }) + } +} + +func TestGetTierByLabel_Empty(t *testing.T) { + p := &ChannelModelPricing{Intervals: nil} + require.Nil(t, p.GetTierByLabel("1K")) +} + +func TestChannelClone(t *testing.T) { + original := &Channel{ + ID: 1, + Name: "test", + GroupIDs: []int64{10, 20}, + ModelPricing: []ChannelModelPricing{ + { + ID: 100, + Models: []string{"model-a"}, + InputPrice: channelTestPtrFloat64(5e-6), + }, + }, + } + + cloned := original.Clone() + require.NotNil(t, cloned) + require.Equal(t, original.ID, cloned.ID) + require.Equal(t, original.Name, cloned.Name) + + // Modify clone slices — original should not change + cloned.GroupIDs[0] = 999 + require.Equal(t, int64(10), original.GroupIDs[0]) + + cloned.ModelPricing[0].Models[0] = "hacked" + require.Equal(t, "model-a", original.ModelPricing[0].Models[0]) +} + +func TestChannelClone_Nil(t *testing.T) { + var ch *Channel + require.Nil(t, ch.Clone()) +} + +func TestChannelModelPricingClone(t *testing.T) { + original := ChannelModelPricing{ + Models: []string{"a", "b"}, + Intervals: []PricingInterval{ + {MinTokens: 0, TierLabel: "tier1"}, + }, + } + + cloned := original.Clone() + + // Modify clone slices — original unchanged + cloned.Models[0] = "hacked" + require.Equal(t, "a", original.Models[0]) + + cloned.Intervals[0].TierLabel = "hacked" + require.Equal(t, "tier1", original.Intervals[0].TierLabel) +} diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 48488dc8c5..5df0b58c57 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -41,6 +41,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo nil, nil, nil, + nil, ) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 94e04d286d..69d218f81d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -568,6 +568,7 @@ type GatewayService struct { responseHeaderFilter *responseheaders.CompiledHeaderFilter debugModelRouting atomic.Bool debugClaudeMimic atomic.Bool + channelService *ChannelService debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set tlsFPProfileService *TLSFingerprintProfileService } @@ -597,6 +598,7 @@ func NewGatewayService( digestStore *DigestSessionStore, settingService *SettingService, tlsFPProfileService *TLSFingerprintProfileService, + channelService *ChannelService, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) @@ -629,6 +631,7 @@ func NewGatewayService( modelsListCacheTTL: modelsListTTL, responseHeaderFilter: compileResponseHeaderFilter(cfg), tlsFPProfileService: tlsFPProfileService, + channelService: channelService, } svc.userGroupRateResolver = newUserGroupRateResolver( userGroupRateRepo, @@ -7771,7 +7774,16 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error - cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) + // 渠道定价覆盖 + var chPricing *ChannelModelPricing + if s.channelService != nil && apiKey.Group != nil { + chPricing = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel) + } + if chPricing != nil { + cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing) + } else { + cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) + } if err != nil { logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} @@ -7959,7 +7971,16 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error - cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) + // 渠道定价覆盖 + var chPricing2 *ChannelModelPricing + if s.channelService != nil && apiKey.Group != nil { + chPricing2 = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel) + } + if chPricing2 != nil { + cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing2) + } else { + cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) + } if err != nil { logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} diff --git a/backend/internal/service/model_pricing_resolver.go b/backend/internal/service/model_pricing_resolver.go new file mode 100644 index 0000000000..67f2684cae --- /dev/null +++ b/backend/internal/service/model_pricing_resolver.go @@ -0,0 +1,198 @@ +package service + +import ( + "context" + "log/slog" +) + +// ResolvedPricing 统一定价解析结果 +type ResolvedPricing struct { + // Mode 计费模式 + Mode BillingMode + + // Token 模式:基础定价(来自 LiteLLM 或 fallback) + BasePricing *ModelPricing + + // Token 模式:区间定价列表(如有,覆盖 BasePricing 中的对应字段) + Intervals []PricingInterval + + // 按次/图片模式:分层定价 + RequestTiers []PricingInterval + + // 来源标识 + Source string // "channel", "litellm", "fallback" + + // 是否支持缓存细分 + SupportsCacheBreakdown bool +} + +// ModelPricingResolver 统一模型定价解析器。 +// 解析链:Channel → LiteLLM → Fallback。 +type ModelPricingResolver struct { + channelService *ChannelService + billingService *BillingService +} + +// NewModelPricingResolver 创建定价解析器实例 +func NewModelPricingResolver(channelService *ChannelService, billingService *BillingService) *ModelPricingResolver { + return &ModelPricingResolver{ + channelService: channelService, + billingService: billingService, + } +} + +// PricingInput 定价解析输入 +type PricingInput struct { + Model string + GroupID *int64 // nil 表示不检查渠道 +} + +// Resolve 解析模型定价。 +// 1. 获取基础定价(LiteLLM → Fallback) +// 2. 如果指定了 GroupID,查找渠道定价并覆盖 +func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing { + // 1. 获取基础定价 + basePricing, source := r.resolveBasePricing(input.Model) + + resolved := &ResolvedPricing{ + Mode: BillingModeToken, + BasePricing: basePricing, + Source: source, + SupportsCacheBreakdown: basePricing != nil && basePricing.SupportsCacheBreakdown, + } + + // 2. 如果有 GroupID,尝试渠道覆盖 + if input.GroupID != nil { + r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved) + } + + return resolved +} + +// resolveBasePricing 从 LiteLLM 或 Fallback 获取基础定价 +func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing, string) { + pricing, err := r.billingService.GetModelPricing(model) + if err != nil { + slog.Debug("failed to get model pricing from LiteLLM, using fallback", + "model", model, "error", err) + return nil, "fallback" + } + return pricing, "litellm" +} + +// applyChannelOverrides 应用渠道定价覆盖 +func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupID int64, model string, resolved *ResolvedPricing) { + chPricing := r.channelService.GetChannelModelPricing(ctx, groupID, model) + if chPricing == nil { + return + } + + resolved.Source = "channel" + resolved.Mode = chPricing.BillingMode + if resolved.Mode == "" { + resolved.Mode = BillingModeToken + } + + switch resolved.Mode { + case BillingModeToken: + r.applyTokenOverrides(chPricing, resolved) + case BillingModePerRequest, BillingModeImage: + r.applyRequestTierOverrides(chPricing, resolved) + } +} + +// applyTokenOverrides 应用 token 模式的渠道覆盖 +func (r *ModelPricingResolver) applyTokenOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) { + // 如果有区间定价,使用区间 + if len(chPricing.Intervals) > 0 { + resolved.Intervals = chPricing.Intervals + return + } + + // 否则用 flat 字段覆盖 BasePricing + if resolved.BasePricing == nil { + resolved.BasePricing = &ModelPricing{} + } + + if chPricing.InputPrice != nil { + resolved.BasePricing.InputPricePerToken = *chPricing.InputPrice + resolved.BasePricing.InputPricePerTokenPriority = *chPricing.InputPrice + } + if chPricing.OutputPrice != nil { + resolved.BasePricing.OutputPricePerToken = *chPricing.OutputPrice + resolved.BasePricing.OutputPricePerTokenPriority = *chPricing.OutputPrice + } + if chPricing.CacheWritePrice != nil { + resolved.BasePricing.CacheCreationPricePerToken = *chPricing.CacheWritePrice + resolved.BasePricing.CacheCreation5mPrice = *chPricing.CacheWritePrice + resolved.BasePricing.CacheCreation1hPrice = *chPricing.CacheWritePrice + } + if chPricing.CacheReadPrice != nil { + resolved.BasePricing.CacheReadPricePerToken = *chPricing.CacheReadPrice + resolved.BasePricing.CacheReadPricePerTokenPriority = *chPricing.CacheReadPrice + } +} + +// applyRequestTierOverrides 应用按次/图片模式的渠道覆盖 +func (r *ModelPricingResolver) applyRequestTierOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) { + resolved.RequestTiers = chPricing.Intervals +} + +// GetIntervalPricing 根据 context token 数获取区间定价。 +// 如果有区间列表,找到匹配区间并构造 ModelPricing;否则直接返回 BasePricing。 +func (r *ModelPricingResolver) GetIntervalPricing(resolved *ResolvedPricing, totalContextTokens int) *ModelPricing { + if len(resolved.Intervals) == 0 { + return resolved.BasePricing + } + + iv := FindMatchingInterval(resolved.Intervals, totalContextTokens) + if iv == nil { + return resolved.BasePricing + } + + return intervalToModelPricing(iv, resolved.SupportsCacheBreakdown) +} + +// intervalToModelPricing 将区间定价转换为 ModelPricing +func intervalToModelPricing(iv *PricingInterval, supportsCacheBreakdown bool) *ModelPricing { + pricing := &ModelPricing{ + SupportsCacheBreakdown: supportsCacheBreakdown, + } + if iv.InputPrice != nil { + pricing.InputPricePerToken = *iv.InputPrice + pricing.InputPricePerTokenPriority = *iv.InputPrice + } + if iv.OutputPrice != nil { + pricing.OutputPricePerToken = *iv.OutputPrice + pricing.OutputPricePerTokenPriority = *iv.OutputPrice + } + if iv.CacheWritePrice != nil { + pricing.CacheCreationPricePerToken = *iv.CacheWritePrice + pricing.CacheCreation5mPrice = *iv.CacheWritePrice + pricing.CacheCreation1hPrice = *iv.CacheWritePrice + } + if iv.CacheReadPrice != nil { + pricing.CacheReadPricePerToken = *iv.CacheReadPrice + pricing.CacheReadPricePerTokenPriority = *iv.CacheReadPrice + } + return pricing +} + +// GetRequestTierPrice 根据层级标签获取按次价格 +func (r *ModelPricingResolver) GetRequestTierPrice(resolved *ResolvedPricing, tierLabel string) float64 { + for _, tier := range resolved.RequestTiers { + if tier.TierLabel == tierLabel && tier.PerRequestPrice != nil { + return *tier.PerRequestPrice + } + } + return 0 +} + +// GetRequestTierPriceByContext 根据 context token 数获取按次价格 +func (r *ModelPricingResolver) GetRequestTierPriceByContext(resolved *ResolvedPricing, totalContextTokens int) float64 { + iv := FindMatchingInterval(resolved.RequestTiers, totalContextTokens) + if iv != nil && iv.PerRequestPrice != nil { + return *iv.PerRequestPrice + } + return 0 +} diff --git a/backend/internal/service/model_pricing_resolver_test.go b/backend/internal/service/model_pricing_resolver_test.go new file mode 100644 index 0000000000..5b4a0b135d --- /dev/null +++ b/backend/internal/service/model_pricing_resolver_test.go @@ -0,0 +1,164 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func resolverPtrFloat64(v float64) *float64 { return &v } +func resolverPtrInt(v int) *int { return &v } + +func newTestBillingServiceForResolver() *BillingService { + bs := &BillingService{ + fallbackPrices: make(map[string]*ModelPricing), + } + bs.fallbackPrices["claude-sonnet-4"] = &ModelPricing{ + InputPricePerToken: 3e-6, + OutputPricePerToken: 15e-6, + CacheCreationPricePerToken: 3.75e-6, + CacheReadPricePerToken: 0.3e-6, + SupportsCacheBreakdown: false, + } + return bs +} + +func TestResolve_NoGroupID(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: nil, + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModeToken, resolved.Mode) + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12) + // BillingService.GetModelPricing uses fallback internally, but resolveBasePricing + // reports "litellm" when GetModelPricing succeeds (regardless of internal source) + require.Equal(t, "litellm", resolved.Source) +} + +func TestResolve_UnknownModel(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "unknown-model-xyz", + GroupID: nil, + }) + + require.NotNil(t, resolved) + require.Nil(t, resolved.BasePricing) + // Unknown model: GetModelPricing returns error, source is "fallback" + require.Equal(t, "fallback", resolved.Source) +} + +func TestGetIntervalPricing_NoIntervals(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + basePricing := &ModelPricing{InputPricePerToken: 5e-6} + resolved := &ResolvedPricing{ + Mode: BillingModeToken, + BasePricing: basePricing, + Intervals: nil, + } + + result := r.GetIntervalPricing(resolved, 50000) + require.Equal(t, basePricing, result) +} + +func TestGetIntervalPricing_MatchesInterval(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModeToken, + BasePricing: &ModelPricing{InputPricePerToken: 5e-6}, + SupportsCacheBreakdown: true, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: resolverPtrInt(128000), InputPrice: resolverPtrFloat64(1e-6), OutputPrice: resolverPtrFloat64(2e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: resolverPtrFloat64(3e-6), OutputPrice: resolverPtrFloat64(6e-6)}, + }, + } + + result := r.GetIntervalPricing(resolved, 50000) + require.NotNil(t, result) + require.InDelta(t, 1e-6, result.InputPricePerToken, 1e-12) + require.InDelta(t, 2e-6, result.OutputPricePerToken, 1e-12) + require.True(t, result.SupportsCacheBreakdown) + + result2 := r.GetIntervalPricing(resolved, 200000) + require.NotNil(t, result2) + require.InDelta(t, 3e-6, result2.InputPricePerToken, 1e-12) +} + +func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + basePricing := &ModelPricing{InputPricePerToken: 99e-6} + resolved := &ResolvedPricing{ + Mode: BillingModeToken, + BasePricing: basePricing, + Intervals: []PricingInterval{ + {MinTokens: 10000, MaxTokens: resolverPtrInt(50000), InputPrice: resolverPtrFloat64(1e-6)}, + }, + } + + result := r.GetIntervalPricing(resolved, 5000) + require.Equal(t, basePricing, result) +} + +func TestGetRequestTierPrice(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: resolverPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: resolverPtrFloat64(0.08)}, + }, + } + + require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12) + require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12) + require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "4K"), 1e-12) +} + +func TestGetRequestTierPriceByContext(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{ + {MinTokens: 0, MaxTokens: resolverPtrInt(128000), PerRequestPrice: resolverPtrFloat64(0.05)}, + {MinTokens: 128000, MaxTokens: nil, PerRequestPrice: resolverPtrFloat64(0.10)}, + }, + } + + require.InDelta(t, 0.05, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12) + require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12) +} + +func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: nil}, + }, + } + + require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index d79a353124..717e22ffc5 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet( ProvideScheduledTestService, ProvideScheduledTestRunnerService, NewGroupCapacityService, + NewChannelService, + NewModelPricingResolver, ) diff --git a/backend/migrations/081_create_channels.sql b/backend/migrations/081_create_channels.sql new file mode 100644 index 0000000000..3059816b2e --- /dev/null +++ b/backend/migrations/081_create_channels.sql @@ -0,0 +1,56 @@ +-- Create channels table for managing pricing channels. +-- A channel groups multiple groups together and provides custom model pricing. + +SET LOCAL lock_timeout = '5s'; +SET LOCAL statement_timeout = '10min'; + +-- 渠道表 +CREATE TABLE IF NOT EXISTS channels ( + id BIGSERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + description TEXT DEFAULT '', + status VARCHAR(20) NOT NULL DEFAULT 'active', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- 渠道名称唯一索引 +CREATE UNIQUE INDEX IF NOT EXISTS idx_channels_name ON channels (name); +CREATE INDEX IF NOT EXISTS idx_channels_status ON channels (status); + +-- 渠道-分组关联表(每个分组只能属于一个渠道) +CREATE TABLE IF NOT EXISTS channel_groups ( + id BIGSERIAL PRIMARY KEY, + channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE, + group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_groups_group_id ON channel_groups (group_id); +CREATE INDEX IF NOT EXISTS idx_channel_groups_channel_id ON channel_groups (channel_id); + +-- 渠道模型定价表(一条定价可绑定多个模型) +CREATE TABLE IF NOT EXISTS channel_model_pricing ( + id BIGSERIAL PRIMARY KEY, + channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE, + models JSONB NOT NULL DEFAULT '[]', + input_price NUMERIC(20,12), + output_price NUMERIC(20,12), + cache_write_price NUMERIC(20,12), + cache_read_price NUMERIC(20,12), + image_output_price NUMERIC(20,8), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_channel_id ON channel_model_pricing (channel_id); + +COMMENT ON TABLE channels IS '渠道管理:关联多个分组,提供自定义模型定价'; +COMMENT ON TABLE channel_groups IS '渠道-分组关联表:每个分组最多属于一个渠道'; +COMMENT ON TABLE channel_model_pricing IS '渠道模型定价:一条定价可绑定多个模型,价格一致'; +COMMENT ON COLUMN channel_model_pricing.models IS '绑定的模型列表,JSON 数组,如 ["claude-opus-4-6","claude-opus-4-6-thinking"]'; +COMMENT ON COLUMN channel_model_pricing.input_price IS '每 token 输入价格(USD),NULL 表示使用默认'; +COMMENT ON COLUMN channel_model_pricing.output_price IS '每 token 输出价格(USD),NULL 表示使用默认'; +COMMENT ON COLUMN channel_model_pricing.cache_write_price IS '缓存写入每 token 价格,NULL 表示使用默认'; +COMMENT ON COLUMN channel_model_pricing.cache_read_price IS '缓存读取每 token 价格,NULL 表示使用默认'; +COMMENT ON COLUMN channel_model_pricing.image_output_price IS '图片输出价格(Gemini Image 等),NULL 表示使用默认'; diff --git a/backend/migrations/082_refactor_channel_pricing.sql b/backend/migrations/082_refactor_channel_pricing.sql new file mode 100644 index 0000000000..d0a5406246 --- /dev/null +++ b/backend/migrations/082_refactor_channel_pricing.sql @@ -0,0 +1,67 @@ +-- Extend channel_model_pricing with billing_mode and add context-interval child table. +-- Supports three billing modes: token (per-token with context intervals), +-- per_request (per-request with context-size tiers), and image (per-image). + +SET LOCAL lock_timeout = '5s'; +SET LOCAL statement_timeout = '10min'; + +-- 1. 为 channel_model_pricing 添加 billing_mode 列 +ALTER TABLE channel_model_pricing + ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20) NOT NULL DEFAULT 'token'; + +COMMENT ON COLUMN channel_model_pricing.billing_mode IS '计费模式:token(按 token 区间计费)、per_request(按次计费)、image(图片计费)'; + +-- 2. 创建区间定价子表 +CREATE TABLE IF NOT EXISTS channel_pricing_intervals ( + id BIGSERIAL PRIMARY KEY, + pricing_id BIGINT NOT NULL REFERENCES channel_model_pricing(id) ON DELETE CASCADE, + min_tokens INT NOT NULL DEFAULT 0, + max_tokens INT, + tier_label VARCHAR(50), + input_price NUMERIC(20,12), + output_price NUMERIC(20,12), + cache_write_price NUMERIC(20,12), + cache_read_price NUMERIC(20,12), + per_request_price NUMERIC(20,12), + sort_order INT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_channel_pricing_intervals_pricing_id + ON channel_pricing_intervals (pricing_id); + +COMMENT ON TABLE channel_pricing_intervals IS '渠道定价区间:支持按 token 区间、按次分层、图片分辨率分层'; +COMMENT ON COLUMN channel_pricing_intervals.min_tokens IS '区间下界(含),token 模式使用'; +COMMENT ON COLUMN channel_pricing_intervals.max_tokens IS '区间上界(不含),NULL 表示无上限'; +COMMENT ON COLUMN channel_pricing_intervals.tier_label IS '层级标签,按次/图片模式使用(如 1K、2K、4K、HD)'; +COMMENT ON COLUMN channel_pricing_intervals.input_price IS 'token 模式:每 token 输入价'; +COMMENT ON COLUMN channel_pricing_intervals.output_price IS 'token 模式:每 token 输出价'; +COMMENT ON COLUMN channel_pricing_intervals.cache_write_price IS 'token 模式:缓存写入价'; +COMMENT ON COLUMN channel_pricing_intervals.cache_read_price IS 'token 模式:缓存读取价'; +COMMENT ON COLUMN channel_pricing_intervals.per_request_price IS '按次/图片模式:每次请求价格'; + +-- 3. 迁移现有 flat 定价为单区间 [0, +inf) +-- 仅迁移有明确定价(至少一个价格字段非 NULL)的条目 +INSERT INTO channel_pricing_intervals (pricing_id, min_tokens, max_tokens, input_price, output_price, cache_write_price, cache_read_price, sort_order) +SELECT + cmp.id, + 0, + NULL, + cmp.input_price, + cmp.output_price, + cmp.cache_write_price, + cmp.cache_read_price, + 0 +FROM channel_model_pricing cmp +WHERE cmp.billing_mode = 'token' + AND (cmp.input_price IS NOT NULL OR cmp.output_price IS NOT NULL + OR cmp.cache_write_price IS NOT NULL OR cmp.cache_read_price IS NOT NULL) + AND NOT EXISTS ( + SELECT 1 FROM channel_pricing_intervals cpi WHERE cpi.pricing_id = cmp.id + ); + +-- 4. 迁移 image_output_price 为 image 模式的区间条目 +-- 将有 image_output_price 的现有条目复制为 billing_mode='image' 的独立条目 +-- 注意:这里不改变原条目的 billing_mode,而是将 image_output_price 作为向后兼容字段保留 +-- 实际的 image 计费在未来由独立的 billing_mode='image' 条目处理 diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts new file mode 100644 index 0000000000..0b86fcaa6f --- /dev/null +++ b/frontend/src/api/admin/channels.ts @@ -0,0 +1,121 @@ +/** + * Admin Channels API endpoints + * Handles channel management for administrators + */ + +import { apiClient } from '../client' + +export type BillingMode = 'token' | 'per_request' | 'image' + +export interface PricingInterval { + id?: number + min_tokens: number + max_tokens: number | null + tier_label: string + input_price: number | null + output_price: number | null + cache_write_price: number | null + cache_read_price: number | null + per_request_price: number | null + sort_order: number +} + +export interface ChannelModelPricing { + id?: number + models: string[] + billing_mode: BillingMode + input_price: number | null + output_price: number | null + cache_write_price: number | null + cache_read_price: number | null + image_output_price: number | null + intervals: PricingInterval[] +} + +export interface Channel { + id: number + name: string + description: string + status: string + group_ids: number[] + model_pricing: ChannelModelPricing[] + created_at: string + updated_at: string +} + +export interface CreateChannelRequest { + name: string + description?: string + group_ids?: number[] + model_pricing?: ChannelModelPricing[] +} + +export interface UpdateChannelRequest { + name?: string + description?: string + status?: string + group_ids?: number[] + model_pricing?: ChannelModelPricing[] +} + +interface PaginatedResponse { + items: T[] + total: number +} + +/** + * List channels with pagination + */ +export async function list( + page: number = 1, + pageSize: number = 20, + filters?: { + status?: string + search?: string + }, + options?: { signal?: AbortSignal } +): Promise> { + const { data } = await apiClient.get>('/admin/channels', { + params: { + page, + page_size: pageSize, + ...filters + }, + signal: options?.signal + }) + return data +} + +/** + * Get channel by ID + */ +export async function getById(id: number): Promise { + const { data } = await apiClient.get(`/admin/channels/${id}`) + return data +} + +/** + * Create a new channel + */ +export async function create(req: CreateChannelRequest): Promise { + const { data } = await apiClient.post('/admin/channels', req) + return data +} + +/** + * Update a channel + */ +export async function update(id: number, req: UpdateChannelRequest): Promise { + const { data } = await apiClient.put(`/admin/channels/${id}`, req) + return data +} + +/** + * Delete a channel + */ +export async function remove(id: number): Promise { + await apiClient.delete(`/admin/channels/${id}`) +} + +const channelsAPI = { list, getById, create, update, remove } +export default channelsAPI diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index 9a3fb8c510..da1e3cfa8f 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -25,6 +25,7 @@ import apiKeysAPI from './apiKeys' import scheduledTestsAPI from './scheduledTests' import backupAPI from './backup' import tlsFingerprintProfileAPI from './tlsFingerprintProfile' +import channelsAPI from './channels' /** * Unified admin API object for convenient access @@ -51,7 +52,8 @@ export const adminAPI = { apiKeys: apiKeysAPI, scheduledTests: scheduledTestsAPI, backup: backupAPI, - tlsFingerprintProfiles: tlsFingerprintProfileAPI + tlsFingerprintProfiles: tlsFingerprintProfileAPI, + channels: channelsAPI } export { @@ -76,7 +78,8 @@ export { apiKeysAPI, scheduledTestsAPI, backupAPI, - tlsFingerprintProfileAPI + tlsFingerprintProfileAPI, + channelsAPI } export default adminAPI diff --git a/frontend/src/components/admin/channel/IntervalRow.vue b/frontend/src/components/admin/channel/IntervalRow.vue new file mode 100644 index 0000000000..ad7447e831 --- /dev/null +++ b/frontend/src/components/admin/channel/IntervalRow.vue @@ -0,0 +1,160 @@ + + + diff --git a/frontend/src/components/admin/channel/PricingEntryCard.vue b/frontend/src/components/admin/channel/PricingEntryCard.vue new file mode 100644 index 0000000000..077aea883b --- /dev/null +++ b/frontend/src/components/admin/channel/PricingEntryCard.vue @@ -0,0 +1,260 @@ + + + diff --git a/frontend/src/components/admin/channel/types.ts b/frontend/src/components/admin/channel/types.ts new file mode 100644 index 0000000000..9a90a6b4e2 --- /dev/null +++ b/frontend/src/components/admin/channel/types.ts @@ -0,0 +1,59 @@ +import type { BillingMode, PricingInterval } from '@/api/admin/channels' + +export interface IntervalFormEntry { + min_tokens: number + max_tokens: number | null + tier_label: string + input_price: number | string | null + output_price: number | string | null + cache_write_price: number | string | null + cache_read_price: number | string | null + per_request_price: number | string | null + sort_order: number +} + +export interface PricingFormEntry { + modelsInput: string + billing_mode: BillingMode + input_price: number | string | null + output_price: number | string | null + cache_write_price: number | string | null + cache_read_price: number | string | null + per_request_price: number | string | null + image_output_price: number | string | null + intervals: IntervalFormEntry[] +} + +export function toNullableNumber(val: number | string | null | undefined): number | null { + if (val === null || val === undefined || val === '') return null + const num = Number(val) + return isNaN(num) ? null : num +} + +export function apiIntervalsToForm(intervals: PricingInterval[]): IntervalFormEntry[] { + return (intervals || []).map(iv => ({ + min_tokens: iv.min_tokens, + max_tokens: iv.max_tokens, + tier_label: iv.tier_label || '', + input_price: iv.input_price, + output_price: iv.output_price, + cache_write_price: iv.cache_write_price, + cache_read_price: iv.cache_read_price, + per_request_price: iv.per_request_price, + sort_order: iv.sort_order + })) +} + +export function formIntervalsToAPI(intervals: IntervalFormEntry[]): PricingInterval[] { + return (intervals || []).map(iv => ({ + min_tokens: iv.min_tokens, + max_tokens: iv.max_tokens, + tier_label: iv.tier_label, + input_price: toNullableNumber(iv.input_price), + output_price: toNullableNumber(iv.output_price), + cache_write_price: toNullableNumber(iv.cache_write_price), + cache_read_price: toNullableNumber(iv.cache_read_price), + per_request_price: toNullableNumber(iv.per_request_price), + sort_order: iv.sort_order + })) +} diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index 2e5babeba3..80a3eed8f3 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -287,6 +287,21 @@ const FolderIcon = { ) } +const ChannelIcon = { + render: () => + h( + 'svg', + { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, + [ + h('path', { + 'stroke-linecap': 'round', + 'stroke-linejoin': 'round', + d: 'M6.429 9.75L2.25 12l4.179 2.25m0-4.5l5.571 3 5.571-3m-11.142 0L2.25 7.5 12 2.25l9.75 5.25-4.179 2.25m0 0l4.179 2.25L12 17.25 2.25 12m15.321-2.25l4.179 2.25L12 17.25l-9.75-5.25' + }) + ] + ) +} + const CreditCardIcon = { render: () => h( @@ -568,6 +583,7 @@ const adminNavItems = computed((): NavItem[] => { : []), { path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true }, { path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true }, + { path: '/admin/channels', label: t('nav.channels', '渠道管理'), icon: ChannelIcon, hideInSimpleMode: true }, { path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true }, { path: '/admin/accounts', label: t('nav.accounts'), icon: GlobeIcon }, { path: '/admin/announcements', label: t('nav.announcements'), icon: BellIcon }, diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 0ffef1a372..67f52ea04b 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -278,6 +278,16 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'admin.groups.description' } }, + { + path: '/admin/channels', + name: 'AdminChannels', + component: () => import('@/views/admin/ChannelsView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Channel Management' + } + }, { path: '/admin/subscriptions', name: 'AdminSubscriptions', diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue new file mode 100644 index 0000000000..072efed1f4 --- /dev/null +++ b/frontend/src/views/admin/ChannelsView.vue @@ -0,0 +1,628 @@ + + + From 983fe5895924efbd441db06cfe1aa34ad1953b33 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 30 Mar 2026 02:14:30 +0800 Subject: [PATCH 02/67] =?UTF-8?q?fix:=20CI=20lint/test=20fixes=20=E2=80=94?= =?UTF-8?q?=20gofmt,=20errcheck,=20handler=20test=20args?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../internal/handler/admin/channel_handler.go | 70 +++++++++---------- ...eway_handler_warmup_intercept_unit_test.go | 1 + .../handler/sora_client_handler_test.go | 2 +- .../handler/sora_gateway_handler_test.go | 1 + backend/internal/repository/channel_repo.go | 10 +-- .../repository/channel_repo_pricing.go | 6 +- backend/internal/service/billing_service.go | 8 +-- backend/internal/service/channel.go | 18 ++--- 8 files changed, 59 insertions(+), 57 deletions(-) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index fb6f7d0204..492b6b8f72 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -24,29 +24,29 @@ func NewChannelHandler(channelService *service.ChannelService) *ChannelHandler { // --- Request / Response types --- type createChannelRequest struct { - Name string `json:"name" binding:"required,max=100"` - Description string `json:"description"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingRequest `json:"model_pricing"` + Name string `json:"name" binding:"required,max=100"` + Description string `json:"description"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingRequest `json:"model_pricing"` } type updateChannelRequest struct { - Name string `json:"name" binding:"omitempty,max=100"` - Description *string `json:"description"` - Status string `json:"status" binding:"omitempty,oneof=active disabled"` - GroupIDs *[]int64 `json:"group_ids"` - ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` + Name string `json:"name" binding:"omitempty,max=100"` + Description *string `json:"description"` + Status string `json:"status" binding:"omitempty,oneof=active disabled"` + GroupIDs *[]int64 `json:"group_ids"` + ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` } type channelModelPricingRequest struct { - Models []string `json:"models" binding:"required,min=1,max=100"` - BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` - InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` - OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"` - CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"` - CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"` - ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"` - Intervals []pricingIntervalRequest `json:"intervals"` + Models []string `json:"models" binding:"required,min=1,max=100"` + BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` + InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` + OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"` + CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"` + CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"` + ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"` + Intervals []pricingIntervalRequest `json:"intervals"` } type pricingIntervalRequest struct { @@ -62,26 +62,26 @@ type pricingIntervalRequest struct { } type channelResponse struct { - ID int64 `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Status string `json:"status"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingResponse `json:"model_pricing"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingResponse `json:"model_pricing"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` } type channelModelPricingResponse struct { - ID int64 `json:"id"` - Models []string `json:"models"` - BillingMode string `json:"billing_mode"` - InputPrice *float64 `json:"input_price"` - OutputPrice *float64 `json:"output_price"` - CacheWritePrice *float64 `json:"cache_write_price"` - CacheReadPrice *float64 `json:"cache_read_price"` - ImageOutputPrice *float64 `json:"image_output_price"` - Intervals []pricingIntervalResponse `json:"intervals"` + ID int64 `json:"id"` + Models []string `json:"models"` + BillingMode string `json:"billing_mode"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + ImageOutputPrice *float64 `json:"image_output_price"` + Intervals []pricingIntervalResponse `json:"intervals"` } type pricingIntervalResponse struct { @@ -106,7 +106,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { Name: ch.Name, Description: ch.Description, Status: ch.Status, - GroupIDs: ch.GroupIDs, + GroupIDs: ch.GroupIDs, CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"), } diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 69c8d1d582..7dc062df0a 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -161,6 +161,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // digestStore nil, // settingService nil, // tlsFPProfileService + nil, // channelService ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index fe035b6f7f..78e2d24bc2 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2224,7 +2224,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { return service.NewGatewayService( accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ) } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index c790a36c06..18e6e92971 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -465,6 +465,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, // digestStore nil, // settingService nil, // tlsFPProfileService + nil, // channelService ) soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go index aa8696abff..9259edd6cd 100644 --- a/backend/internal/repository/channel_repo.go +++ b/backend/internal/repository/channel_repo.go @@ -186,7 +186,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati if err != nil { return nil, nil, fmt.Errorf("query channels: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() var channels []service.Channel var channelIDs []int64 @@ -240,7 +240,7 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err if err != nil { return nil, fmt.Errorf("query all channels: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() var channels []service.Channel var channelIDs []int64 @@ -292,7 +292,7 @@ func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs [] if err != nil { return nil, fmt.Errorf("batch load group ids: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() groupMap := make(map[int64][]int64, len(channelIDs)) for rows.Next() { @@ -333,7 +333,7 @@ func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([ if err != nil { return nil, fmt.Errorf("get group ids: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() var ids []int64 for rows.Next() { @@ -375,7 +375,7 @@ func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channe if err != nil { return nil, fmt.Errorf("get groups in other channels: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() var conflicting []int64 for rows.Next() { diff --git a/backend/internal/repository/channel_repo_pricing.go b/backend/internal/repository/channel_repo_pricing.go index 2e7ec6a311..87c856f8ff 100644 --- a/backend/internal/repository/channel_repo_pricing.go +++ b/backend/internal/repository/channel_repo_pricing.go @@ -21,7 +21,7 @@ func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int6 if err != nil { return nil, fmt.Errorf("list model pricing: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() result, pricingIDs, err := scanModelPricingRows(rows) if err != nil { @@ -97,7 +97,7 @@ func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelID if err != nil { return nil, fmt.Errorf("batch load model pricing: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() allPricing, allPricingIDs, err := scanModelPricingRows(rows) if err != nil { @@ -139,7 +139,7 @@ func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs [ if err != nil { return nil, fmt.Errorf("batch load intervals: %w", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs)) for rows.Next() { diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 58c86f36ee..7deb1cf991 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -413,12 +413,12 @@ func (s *BillingService) CalculateCostWithChannel(model string, tokens UsageToke type CostInput struct { Ctx context.Context Model string - GroupID *int64 // 用于渠道定价查找 + GroupID *int64 // 用于渠道定价查找 Tokens UsageTokens - RequestCount int // 按次计费时使用 - SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等) + RequestCount int // 按次计费时使用 + SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等) RateMultiplier float64 - ServiceTier string // "priority","flex","" 等 + ServiceTier string // "priority","flex","" 等 Resolver *ModelPricingResolver // 定价解析器 } diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index e3556edd0d..f408f246dd 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -10,8 +10,8 @@ type BillingMode string const ( BillingModeToken BillingMode = "token" // 按 token 区间计费 - BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层) - BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费) + BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层) + BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费) ) // IsValid 检查 BillingMode 是否为合法值 @@ -42,13 +42,13 @@ type Channel struct { type ChannelModelPricing struct { ID int64 ChannelID int64 - Models []string // 绑定的模型列表 - BillingMode BillingMode // 计费模式 - InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 - OutputPrice *float64 // 每 token 输出价格(USD) - CacheWritePrice *float64 // 缓存写入价格 - CacheReadPrice *float64 // 缓存读取价格 - ImageOutputPrice *float64 // 图片输出价格(向后兼容) + Models []string // 绑定的模型列表 + BillingMode BillingMode // 计费模式 + InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 + OutputPrice *float64 // 每 token 输出价格(USD) + CacheWritePrice *float64 // 缓存写入价格 + CacheReadPrice *float64 // 缓存读取价格 + ImageOutputPrice *float64 // 图片输出价格(向后兼容) Intervals []PricingInterval // 区间定价列表 CreatedAt time.Time UpdatedAt time.Time From dca0054e93cdf45859a0a0c9ad873647baf620ab Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 30 Mar 2026 02:24:54 +0800 Subject: [PATCH 03/67] =?UTF-8?q?feat(channel):=20=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E6=A0=87=E7=AD=BE=E8=BE=93=E5=85=A5=20+=20$/MTok=20=E4=BB=B7?= =?UTF-8?q?=E6=A0=BC=E5=8D=95=E4=BD=8D=20+=20=E5=B7=A6=E5=BC=80=E5=8F=B3?= =?UTF-8?q?=E9=97=AD=E5=8C=BA=E9=97=B4=20+=20i18n?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 模型输入改为标签列表(输入回车添加,支持粘贴批量导入) - 价格显示单位改为 $/MTok(每百万 token),提交时自动转换 - Token 模式增加图片输出价格字段(适配 Gemini 图片模型按 token 计费) - 区间边界改为左开右闭 (min, max],右边界包含 - 默认价格作为未命中区间时的回退价格 - 添加完整中英文 i18n 翻译 --- backend/internal/service/channel.go | 5 +- backend/internal/service/channel_test.go | 15 +- .../components/admin/channel/IntervalRow.vue | 125 ++++---------- .../admin/channel/ModelTagInput.vue | 86 ++++++++++ .../admin/channel/PricingEntryCard.vue | 154 +++++++----------- .../src/components/admin/channel/types.ts | 34 ++-- frontend/src/i18n/locales/en.ts | 74 +++++++++ frontend/src/i18n/locales/zh.ts | 74 +++++++++ frontend/src/views/admin/ChannelsView.vue | 32 ++-- 9 files changed, 375 insertions(+), 224 deletions(-) create mode 100644 frontend/src/components/admin/channel/ModelTagInput.vue diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index f408f246dd..be82b997f7 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -110,11 +110,12 @@ func (c *Channel) GetModelPricing(model string) *ChannelModelPricing { } // FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。 -// 通用辅助函数,供 GetIntervalForContext、ModelPricingResolver 等复用。 +// 区间为左开右闭 (min, max]:min 不含,max 包含。 +// 第一个区间 min=0 时,0 token 不匹配任何区间(回退到默认价格)。 func FindMatchingInterval(intervals []PricingInterval, totalTokens int) *PricingInterval { for i := range intervals { iv := &intervals[i] - if totalTokens >= iv.MinTokens && (iv.MaxTokens == nil || totalTokens < *iv.MaxTokens) { + if totalTokens > iv.MinTokens && (iv.MaxTokens == nil || totalTokens <= *iv.MaxTokens) { return iv } } diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go index 004d06b1d7..0c055ce41c 100644 --- a/backend/internal/service/channel_test.go +++ b/backend/internal/service/channel_test.go @@ -87,10 +87,13 @@ func TestGetIntervalForContext(t *testing.T) { wantNil bool }{ {"first interval", 50000, channelTestPtrFloat64(1e-6), false}, - {"boundary: at min of second", 128000, channelTestPtrFloat64(2e-6), false}, - {"boundary: at max of first (exclusive)", 128000, channelTestPtrFloat64(2e-6), false}, + // (min, max] — 128000 在第一个区间的 max,包含,所以匹配第一个 + {"boundary: max of first (inclusive)", 128000, channelTestPtrFloat64(1e-6), false}, + // 128001 > 128000,匹配第二个区间 + {"boundary: just above first max", 128001, channelTestPtrFloat64(2e-6), false}, {"unbounded interval", 500000, channelTestPtrFloat64(2e-6), false}, - {"zero tokens", 0, channelTestPtrFloat64(1e-6), false}, + // (0, max] — 0 不匹配任何区间(左开) + {"zero tokens: no match", 0, nil, true}, } for _, tt := range tests { @@ -112,8 +115,10 @@ func TestGetIntervalForContext_NoMatch(t *testing.T) { {MinTokens: 10000, MaxTokens: channelTestPtrInt(50000)}, }, } - require.Nil(t, p.GetIntervalForContext(5000)) - require.Nil(t, p.GetIntervalForContext(50000)) + require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min + require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open) + require.NotNil(t, p.GetIntervalForContext(50000)) // 50000 <= 50000 (right-closed) + require.Nil(t, p.GetIntervalForContext(50001)) // 50001 > 50000 } func TestGetIntervalForContext_Empty(t *testing.T) { diff --git a/frontend/src/components/admin/channel/IntervalRow.vue b/frontend/src/components/admin/channel/IntervalRow.vue index ad7447e831..6f6e582673 100644 --- a/frontend/src/components/admin/channel/IntervalRow.vue +++ b/frontend/src/components/admin/channel/IntervalRow.vue @@ -1,125 +1,66 @@
- +
+ + +
+ + + + + +
+ + - +
-
- - -
- -
- {{ t('admin.channels.form.noPricingRules', 'No pricing rules yet. Click "Add" to create one.') }} -
- -
- -
-
- - -
- +