diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ce898a4a90..513b7996db 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -49,6 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) groupRepository := repository.NewGroupRepository(client, db) + channelRepository := repository.NewChannelRepository(db) settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) @@ -138,11 +139,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache) internal500CounterCache := repository.NewInternal500CounterCache(redisClient) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client) tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, 
tlsFingerprintProfileService) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) @@ -175,9 +176,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService) + channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator) + modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, 
rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) @@ -213,7 +216,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db) scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, 
scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler) + channelHandler := admin.NewChannelHandler(channelService, billingService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 6c56f2d098..bdbb9fdddd 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -744,6 +744,10 @@ var ( {Name: "model", Type: field.TypeString, Size: 100}, {Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100}, {Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100}, + {Name: "channel_id", Type: field.TypeInt64, Nullable: true}, + {Name: "model_mapping_chain", Type: 
field.TypeString, Nullable: true, Size: 500}, + {Name: "billing_tier", Type: field.TypeString, Nullable: true, Size: 50}, + {Name: "billing_mode", Type: field.TypeString, Nullable: true, Size: 20}, {Name: "input_tokens", Type: field.TypeInt, Default: 0}, {Name: "output_tokens", Type: field.TypeInt, Default: 0}, {Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0}, @@ -783,31 +787,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[34]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[35]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[36]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[33]}, + Columns: []*schema.Column{UsageLogsColumns[37]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[34]}, + Columns: []*schema.Column{UsageLogsColumns[38]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -816,32 +820,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[33]}, + Columns: []*schema.Column{UsageLogsColumns[37]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[34]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: 
[]*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[35]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[36]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[34]}, + Columns: []*schema.Column{UsageLogsColumns[38]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, }, { Name: "usagelog_model", @@ -861,17 +865,17 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[33]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[34], UsageLogsColumns[33]}, }, { Name: "usagelog_group_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[33]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index a862209dec..28d9a0ef22 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -19725,6 +19725,11 @@ type UsageLogMutation struct { model *string requested_model *string upstream_model *string + channel_id *int64 + addchannel_id *int64 + model_mapping_chain *string + billing_tier *string + billing_mode *string input_tokens *int addinput_tokens *int output_tokens *int @@ -20160,6 +20165,223 @@ func (m *UsageLogMutation) ResetUpstreamModel() { delete(m.clearedFields, usagelog.FieldUpstreamModel) } +// SetChannelID sets the "channel_id" field. 
+func (m *UsageLogMutation) SetChannelID(i int64) { + m.channel_id = &i + m.addchannel_id = nil +} + +// ChannelID returns the value of the "channel_id" field in the mutation. +func (m *UsageLogMutation) ChannelID() (r int64, exists bool) { + v := m.channel_id + if v == nil { + return + } + return *v, true +} + +// OldChannelID returns the old "channel_id" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldChannelID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannelID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannelID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannelID: %w", err) + } + return oldValue.ChannelID, nil +} + +// AddChannelID adds i to the "channel_id" field. +func (m *UsageLogMutation) AddChannelID(i int64) { + if m.addchannel_id != nil { + *m.addchannel_id += i + } else { + m.addchannel_id = &i + } +} + +// AddedChannelID returns the value that was added to the "channel_id" field in this mutation. +func (m *UsageLogMutation) AddedChannelID() (r int64, exists bool) { + v := m.addchannel_id + if v == nil { + return + } + return *v, true +} + +// ClearChannelID clears the value of the "channel_id" field. +func (m *UsageLogMutation) ClearChannelID() { + m.channel_id = nil + m.addchannel_id = nil + m.clearedFields[usagelog.FieldChannelID] = struct{}{} +} + +// ChannelIDCleared returns if the "channel_id" field was cleared in this mutation. 
+func (m *UsageLogMutation) ChannelIDCleared() bool { + _, ok := m.clearedFields[usagelog.FieldChannelID] + return ok +} + +// ResetChannelID resets all changes to the "channel_id" field. +func (m *UsageLogMutation) ResetChannelID() { + m.channel_id = nil + m.addchannel_id = nil + delete(m.clearedFields, usagelog.FieldChannelID) +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (m *UsageLogMutation) SetModelMappingChain(s string) { + m.model_mapping_chain = &s +} + +// ModelMappingChain returns the value of the "model_mapping_chain" field in the mutation. +func (m *UsageLogMutation) ModelMappingChain() (r string, exists bool) { + v := m.model_mapping_chain + if v == nil { + return + } + return *v, true +} + +// OldModelMappingChain returns the old "model_mapping_chain" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldModelMappingChain(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModelMappingChain is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModelMappingChain requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModelMappingChain: %w", err) + } + return oldValue.ModelMappingChain, nil +} + +// ClearModelMappingChain clears the value of the "model_mapping_chain" field. +func (m *UsageLogMutation) ClearModelMappingChain() { + m.model_mapping_chain = nil + m.clearedFields[usagelog.FieldModelMappingChain] = struct{}{} +} + +// ModelMappingChainCleared returns if the "model_mapping_chain" field was cleared in this mutation. 
+func (m *UsageLogMutation) ModelMappingChainCleared() bool { + _, ok := m.clearedFields[usagelog.FieldModelMappingChain] + return ok +} + +// ResetModelMappingChain resets all changes to the "model_mapping_chain" field. +func (m *UsageLogMutation) ResetModelMappingChain() { + m.model_mapping_chain = nil + delete(m.clearedFields, usagelog.FieldModelMappingChain) +} + +// SetBillingTier sets the "billing_tier" field. +func (m *UsageLogMutation) SetBillingTier(s string) { + m.billing_tier = &s +} + +// BillingTier returns the value of the "billing_tier" field in the mutation. +func (m *UsageLogMutation) BillingTier() (r string, exists bool) { + v := m.billing_tier + if v == nil { + return + } + return *v, true +} + +// OldBillingTier returns the old "billing_tier" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldBillingTier(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBillingTier is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBillingTier requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBillingTier: %w", err) + } + return oldValue.BillingTier, nil +} + +// ClearBillingTier clears the value of the "billing_tier" field. +func (m *UsageLogMutation) ClearBillingTier() { + m.billing_tier = nil + m.clearedFields[usagelog.FieldBillingTier] = struct{}{} +} + +// BillingTierCleared returns if the "billing_tier" field was cleared in this mutation. +func (m *UsageLogMutation) BillingTierCleared() bool { + _, ok := m.clearedFields[usagelog.FieldBillingTier] + return ok +} + +// ResetBillingTier resets all changes to the "billing_tier" field. 
+func (m *UsageLogMutation) ResetBillingTier() { + m.billing_tier = nil + delete(m.clearedFields, usagelog.FieldBillingTier) +} + +// SetBillingMode sets the "billing_mode" field. +func (m *UsageLogMutation) SetBillingMode(s string) { + m.billing_mode = &s +} + +// BillingMode returns the value of the "billing_mode" field in the mutation. +func (m *UsageLogMutation) BillingMode() (r string, exists bool) { + v := m.billing_mode + if v == nil { + return + } + return *v, true +} + +// OldBillingMode returns the old "billing_mode" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldBillingMode(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBillingMode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBillingMode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBillingMode: %w", err) + } + return oldValue.BillingMode, nil +} + +// ClearBillingMode clears the value of the "billing_mode" field. +func (m *UsageLogMutation) ClearBillingMode() { + m.billing_mode = nil + m.clearedFields[usagelog.FieldBillingMode] = struct{}{} +} + +// BillingModeCleared returns if the "billing_mode" field was cleared in this mutation. +func (m *UsageLogMutation) BillingModeCleared() bool { + _, ok := m.clearedFields[usagelog.FieldBillingMode] + return ok +} + +// ResetBillingMode resets all changes to the "billing_mode" field. +func (m *UsageLogMutation) ResetBillingMode() { + m.billing_mode = nil + delete(m.clearedFields, usagelog.FieldBillingMode) +} + // SetGroupID sets the "group_id" field. 
func (m *UsageLogMutation) SetGroupID(i int64) { m.group = &i @@ -21781,7 +22003,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 34) + fields := make([]string, 0, 38) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -21803,6 +22025,18 @@ func (m *UsageLogMutation) Fields() []string { if m.upstream_model != nil { fields = append(fields, usagelog.FieldUpstreamModel) } + if m.channel_id != nil { + fields = append(fields, usagelog.FieldChannelID) + } + if m.model_mapping_chain != nil { + fields = append(fields, usagelog.FieldModelMappingChain) + } + if m.billing_tier != nil { + fields = append(fields, usagelog.FieldBillingTier) + } + if m.billing_mode != nil { + fields = append(fields, usagelog.FieldBillingMode) + } if m.group != nil { fields = append(fields, usagelog.FieldGroupID) } @@ -21906,6 +22140,14 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.RequestedModel() case usagelog.FieldUpstreamModel: return m.UpstreamModel() + case usagelog.FieldChannelID: + return m.ChannelID() + case usagelog.FieldModelMappingChain: + return m.ModelMappingChain() + case usagelog.FieldBillingTier: + return m.BillingTier() + case usagelog.FieldBillingMode: + return m.BillingMode() case usagelog.FieldGroupID: return m.GroupID() case usagelog.FieldSubscriptionID: @@ -21983,6 +22225,14 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldRequestedModel(ctx) case usagelog.FieldUpstreamModel: return m.OldUpstreamModel(ctx) + case usagelog.FieldChannelID: + return m.OldChannelID(ctx) + case usagelog.FieldModelMappingChain: + return m.OldModelMappingChain(ctx) + case usagelog.FieldBillingTier: + return m.OldBillingTier(ctx) + case usagelog.FieldBillingMode: + return m.OldBillingMode(ctx) case usagelog.FieldGroupID: return 
m.OldGroupID(ctx) case usagelog.FieldSubscriptionID: @@ -22095,6 +22345,34 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetUpstreamModel(v) return nil + case usagelog.FieldChannelID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetChannelID(v) + return nil + case usagelog.FieldModelMappingChain: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModelMappingChain(v) + return nil + case usagelog.FieldBillingTier: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBillingTier(v) + return nil + case usagelog.FieldBillingMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBillingMode(v) + return nil case usagelog.FieldGroupID: v, ok := value.(int64) if !ok { @@ -22292,6 +22570,9 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { // this mutation. func (m *UsageLogMutation) AddedFields() []string { var fields []string + if m.addchannel_id != nil { + fields = append(fields, usagelog.FieldChannelID) + } if m.addinput_tokens != nil { fields = append(fields, usagelog.FieldInputTokens) } @@ -22354,6 +22635,8 @@ func (m *UsageLogMutation) AddedFields() []string { // was not set, or was not defined in the schema. func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) { switch name { + case usagelog.FieldChannelID: + return m.AddedChannelID() case usagelog.FieldInputTokens: return m.AddedInputTokens() case usagelog.FieldOutputTokens: @@ -22399,6 +22682,13 @@ func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) { // type. 
func (m *UsageLogMutation) AddField(name string, value ent.Value) error { switch name { + case usagelog.FieldChannelID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddChannelID(v) + return nil case usagelog.FieldInputTokens: v, ok := value.(int) if !ok { @@ -22539,6 +22829,18 @@ func (m *UsageLogMutation) ClearedFields() []string { if m.FieldCleared(usagelog.FieldUpstreamModel) { fields = append(fields, usagelog.FieldUpstreamModel) } + if m.FieldCleared(usagelog.FieldChannelID) { + fields = append(fields, usagelog.FieldChannelID) + } + if m.FieldCleared(usagelog.FieldModelMappingChain) { + fields = append(fields, usagelog.FieldModelMappingChain) + } + if m.FieldCleared(usagelog.FieldBillingTier) { + fields = append(fields, usagelog.FieldBillingTier) + } + if m.FieldCleared(usagelog.FieldBillingMode) { + fields = append(fields, usagelog.FieldBillingMode) + } if m.FieldCleared(usagelog.FieldGroupID) { fields = append(fields, usagelog.FieldGroupID) } @@ -22586,6 +22888,18 @@ func (m *UsageLogMutation) ClearField(name string) error { case usagelog.FieldUpstreamModel: m.ClearUpstreamModel() return nil + case usagelog.FieldChannelID: + m.ClearChannelID() + return nil + case usagelog.FieldModelMappingChain: + m.ClearModelMappingChain() + return nil + case usagelog.FieldBillingTier: + m.ClearBillingTier() + return nil + case usagelog.FieldBillingMode: + m.ClearBillingMode() + return nil case usagelog.FieldGroupID: m.ClearGroupID() return nil @@ -22642,6 +22956,18 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldUpstreamModel: m.ResetUpstreamModel() return nil + case usagelog.FieldChannelID: + m.ResetChannelID() + return nil + case usagelog.FieldModelMappingChain: + m.ResetModelMappingChain() + return nil + case usagelog.FieldBillingTier: + m.ResetBillingTier() + return nil + case usagelog.FieldBillingMode: + m.ResetBillingMode() + return nil case usagelog.FieldGroupID: 
m.ResetGroupID() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index fd6be291d8..336b1f8243 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -875,92 +875,104 @@ func init() { usagelogDescUpstreamModel := usagelogFields[6].Descriptor() // usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error) + // usagelogDescModelMappingChain is the schema descriptor for model_mapping_chain field. + usagelogDescModelMappingChain := usagelogFields[8].Descriptor() + // usagelog.ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save. + usagelog.ModelMappingChainValidator = usagelogDescModelMappingChain.Validators[0].(func(string) error) + // usagelogDescBillingTier is the schema descriptor for billing_tier field. + usagelogDescBillingTier := usagelogFields[9].Descriptor() + // usagelog.BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save. + usagelog.BillingTierValidator = usagelogDescBillingTier.Validators[0].(func(string) error) + // usagelogDescBillingMode is the schema descriptor for billing_mode field. + usagelogDescBillingMode := usagelogFields[10].Descriptor() + // usagelog.BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save. + usagelog.BillingModeValidator = usagelogDescBillingMode.Validators[0].(func(string) error) // usagelogDescInputTokens is the schema descriptor for input_tokens field. - usagelogDescInputTokens := usagelogFields[9].Descriptor() + usagelogDescInputTokens := usagelogFields[13].Descriptor() // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. 
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) // usagelogDescOutputTokens is the schema descriptor for output_tokens field. - usagelogDescOutputTokens := usagelogFields[10].Descriptor() + usagelogDescOutputTokens := usagelogFields[14].Descriptor() // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. - usagelogDescCacheCreationTokens := usagelogFields[11].Descriptor() + usagelogDescCacheCreationTokens := usagelogFields[15].Descriptor() // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. - usagelogDescCacheReadTokens := usagelogFields[12].Descriptor() + usagelogDescCacheReadTokens := usagelogFields[16].Descriptor() // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. - usagelogDescCacheCreation5mTokens := usagelogFields[13].Descriptor() + usagelogDescCacheCreation5mTokens := usagelogFields[17].Descriptor() // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. 
- usagelogDescCacheCreation1hTokens := usagelogFields[14].Descriptor() + usagelogDescCacheCreation1hTokens := usagelogFields[18].Descriptor() // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) // usagelogDescInputCost is the schema descriptor for input_cost field. - usagelogDescInputCost := usagelogFields[15].Descriptor() + usagelogDescInputCost := usagelogFields[19].Descriptor() // usagelog.DefaultInputCost holds the default value on creation for the input_cost field. usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) // usagelogDescOutputCost is the schema descriptor for output_cost field. - usagelogDescOutputCost := usagelogFields[16].Descriptor() + usagelogDescOutputCost := usagelogFields[20].Descriptor() // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. - usagelogDescCacheCreationCost := usagelogFields[17].Descriptor() + usagelogDescCacheCreationCost := usagelogFields[21].Descriptor() // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. - usagelogDescCacheReadCost := usagelogFields[18].Descriptor() + usagelogDescCacheReadCost := usagelogFields[22].Descriptor() // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) // usagelogDescTotalCost is the schema descriptor for total_cost field. 
- usagelogDescTotalCost := usagelogFields[19].Descriptor() + usagelogDescTotalCost := usagelogFields[23].Descriptor() // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) // usagelogDescActualCost is the schema descriptor for actual_cost field. - usagelogDescActualCost := usagelogFields[20].Descriptor() + usagelogDescActualCost := usagelogFields[24].Descriptor() // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. - usagelogDescRateMultiplier := usagelogFields[21].Descriptor() + usagelogDescRateMultiplier := usagelogFields[25].Descriptor() // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) // usagelogDescBillingType is the schema descriptor for billing_type field. - usagelogDescBillingType := usagelogFields[23].Descriptor() + usagelogDescBillingType := usagelogFields[27].Descriptor() // usagelog.DefaultBillingType holds the default value on creation for the billing_type field. usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) // usagelogDescStream is the schema descriptor for stream field. - usagelogDescStream := usagelogFields[24].Descriptor() + usagelogDescStream := usagelogFields[28].Descriptor() // usagelog.DefaultStream holds the default value on creation for the stream field. usagelog.DefaultStream = usagelogDescStream.Default.(bool) // usagelogDescUserAgent is the schema descriptor for user_agent field. - usagelogDescUserAgent := usagelogFields[27].Descriptor() + usagelogDescUserAgent := usagelogFields[31].Descriptor() // usagelog.UserAgentValidator is a validator for the "user_agent" field. 
It is called by the builders before save. usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) // usagelogDescIPAddress is the schema descriptor for ip_address field. - usagelogDescIPAddress := usagelogFields[28].Descriptor() + usagelogDescIPAddress := usagelogFields[32].Descriptor() // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) // usagelogDescImageCount is the schema descriptor for image_count field. - usagelogDescImageCount := usagelogFields[29].Descriptor() + usagelogDescImageCount := usagelogFields[33].Descriptor() // usagelog.DefaultImageCount holds the default value on creation for the image_count field. usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) // usagelogDescImageSize is the schema descriptor for image_size field. - usagelogDescImageSize := usagelogFields[30].Descriptor() + usagelogDescImageSize := usagelogFields[34].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) // usagelogDescMediaType is the schema descriptor for media_type field. - usagelogDescMediaType := usagelogFields[31].Descriptor() + usagelogDescMediaType := usagelogFields[35].Descriptor() // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. - usagelogDescCacheTTLOverridden := usagelogFields[32].Descriptor() + usagelogDescCacheTTLOverridden := usagelogFields[36].Descriptor() // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. 
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[33].Descriptor() + usagelogDescCreatedAt := usagelogFields[37].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index 32c39e2511..f6c725a2b1 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -53,6 +53,10 @@ func (UsageLog) Fields() []ent.Field { MaxLen(100). Optional(). Nillable(), + field.Int64("channel_id").Optional().Nillable().Comment("渠道 ID"), + field.String("model_mapping_chain").MaxLen(500).Optional().Nillable().Comment("模型映射链"), + field.String("billing_tier").MaxLen(50).Optional().Nillable().Comment("计费层级标签"), + field.String("billing_mode").MaxLen(20).Optional().Nillable().Comment("计费模式:token/per_request/image"), field.Int64("group_id"). Optional(). Nillable(), diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index fb4ee1c5d7..b857afdbb5 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -36,6 +36,14 @@ type UsageLog struct { RequestedModel *string `json:"requested_model,omitempty"` // UpstreamModel holds the value of the "upstream_model" field. UpstreamModel *string `json:"upstream_model,omitempty"` + // 渠道 ID + ChannelID *int64 `json:"channel_id,omitempty"` + // 模型映射链 + ModelMappingChain *string `json:"model_mapping_chain,omitempty"` + // 计费层级标签 + BillingTier *string `json:"billing_tier,omitempty"` + // 计费模式:token/per_request/image + BillingMode *string `json:"billing_mode,omitempty"` // GroupID holds the value of the "group_id" field. GroupID *int64 `json:"group_id,omitempty"` // SubscriptionID holds the value of the "subscription_id" field. 
@@ -177,9 +185,9 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier: values[i] = new(sql.NullFloat64) - case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: + case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -248,6 +256,34 @@ func (_m *UsageLog) 
assignValues(columns []string, values []any) error { _m.UpstreamModel = new(string) *_m.UpstreamModel = value.String } + case usagelog.FieldChannelID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field channel_id", values[i]) + } else if value.Valid { + _m.ChannelID = new(int64) + *_m.ChannelID = value.Int64 + } + case usagelog.FieldModelMappingChain: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field model_mapping_chain", values[i]) + } else if value.Valid { + _m.ModelMappingChain = new(string) + *_m.ModelMappingChain = value.String + } + case usagelog.FieldBillingTier: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field billing_tier", values[i]) + } else if value.Valid { + _m.BillingTier = new(string) + *_m.BillingTier = value.String + } + case usagelog.FieldBillingMode: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field billing_mode", values[i]) + } else if value.Valid { + _m.BillingMode = new(string) + *_m.BillingMode = value.String + } case usagelog.FieldGroupID: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field group_id", values[i]) @@ -505,6 +541,26 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.ChannelID; v != nil { + builder.WriteString("channel_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.ModelMappingChain; v != nil { + builder.WriteString("model_mapping_chain=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.BillingTier; v != nil { + builder.WriteString("billing_tier=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.BillingMode; v != nil { + builder.WriteString("billing_mode=") + builder.WriteString(*v) + } + builder.WriteString(", ") if v := 
_m.GroupID; v != nil { builder.WriteString("group_id=") builder.WriteString(fmt.Sprintf("%v", *v)) diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index b534f19365..1567ad9b45 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -28,6 +28,14 @@ const ( FieldRequestedModel = "requested_model" // FieldUpstreamModel holds the string denoting the upstream_model field in the database. FieldUpstreamModel = "upstream_model" + // FieldChannelID holds the string denoting the channel_id field in the database. + FieldChannelID = "channel_id" + // FieldModelMappingChain holds the string denoting the model_mapping_chain field in the database. + FieldModelMappingChain = "model_mapping_chain" + // FieldBillingTier holds the string denoting the billing_tier field in the database. + FieldBillingTier = "billing_tier" + // FieldBillingMode holds the string denoting the billing_mode field in the database. + FieldBillingMode = "billing_mode" // FieldGroupID holds the string denoting the group_id field in the database. FieldGroupID = "group_id" // FieldSubscriptionID holds the string denoting the subscription_id field in the database. @@ -141,6 +149,10 @@ var Columns = []string{ FieldModel, FieldRequestedModel, FieldUpstreamModel, + FieldChannelID, + FieldModelMappingChain, + FieldBillingTier, + FieldBillingMode, FieldGroupID, FieldSubscriptionID, FieldInputTokens, @@ -189,6 +201,12 @@ var ( RequestedModelValidator func(string) error // UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. UpstreamModelValidator func(string) error + // ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save. + ModelMappingChainValidator func(string) error + // BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save. 
+ BillingTierValidator func(string) error + // BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save. + BillingModeValidator func(string) error // DefaultInputTokens holds the default value on creation for the "input_tokens" field. DefaultInputTokens int // DefaultOutputTokens holds the default value on creation for the "output_tokens" field. @@ -278,6 +296,26 @@ func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc() } +// ByChannelID orders the results by the channel_id field. +func ByChannelID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannelID, opts...).ToFunc() +} + +// ByModelMappingChain orders the results by the model_mapping_chain field. +func ByModelMappingChain(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldModelMappingChain, opts...).ToFunc() +} + +// ByBillingTier orders the results by the billing_tier field. +func ByBillingTier(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBillingTier, opts...).ToFunc() +} + +// ByBillingMode orders the results by the billing_mode field. +func ByBillingMode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBillingMode, opts...).ToFunc() +} + // ByGroupID orders the results by the group_id field. func ByGroupID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldGroupID, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index f95bceb753..a1fb36cbaa 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -90,6 +90,26 @@ func UpstreamModel(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) } +// ChannelID applies equality check predicate on the "channel_id" field. It's identical to ChannelIDEQ. 
+func ChannelID(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v)) +} + +// ModelMappingChain applies equality check predicate on the "model_mapping_chain" field. It's identical to ModelMappingChainEQ. +func ModelMappingChain(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v)) +} + +// BillingTier applies equality check predicate on the "billing_tier" field. It's identical to BillingTierEQ. +func BillingTier(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v)) +} + +// BillingMode applies equality check predicate on the "billing_mode" field. It's identical to BillingModeEQ. +func BillingMode(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v)) +} + // GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ. func GroupID(v int64) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) @@ -565,6 +585,281 @@ func UpstreamModelContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v)) } +// ChannelIDEQ applies the EQ predicate on the "channel_id" field. +func ChannelIDEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v)) +} + +// ChannelIDNEQ applies the NEQ predicate on the "channel_id" field. +func ChannelIDNEQ(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldChannelID, v)) +} + +// ChannelIDIn applies the In predicate on the "channel_id" field. +func ChannelIDIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldChannelID, vs...)) +} + +// ChannelIDNotIn applies the NotIn predicate on the "channel_id" field. 
+func ChannelIDNotIn(vs ...int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldChannelID, vs...)) +} + +// ChannelIDGT applies the GT predicate on the "channel_id" field. +func ChannelIDGT(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldChannelID, v)) +} + +// ChannelIDGTE applies the GTE predicate on the "channel_id" field. +func ChannelIDGTE(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldChannelID, v)) +} + +// ChannelIDLT applies the LT predicate on the "channel_id" field. +func ChannelIDLT(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldChannelID, v)) +} + +// ChannelIDLTE applies the LTE predicate on the "channel_id" field. +func ChannelIDLTE(v int64) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldChannelID, v)) +} + +// ChannelIDIsNil applies the IsNil predicate on the "channel_id" field. +func ChannelIDIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldChannelID)) +} + +// ChannelIDNotNil applies the NotNil predicate on the "channel_id" field. +func ChannelIDNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldChannelID)) +} + +// ModelMappingChainEQ applies the EQ predicate on the "model_mapping_chain" field. +func ModelMappingChainEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v)) +} + +// ModelMappingChainNEQ applies the NEQ predicate on the "model_mapping_chain" field. +func ModelMappingChainNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldModelMappingChain, v)) +} + +// ModelMappingChainIn applies the In predicate on the "model_mapping_chain" field. +func ModelMappingChainIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldModelMappingChain, vs...)) +} + +// ModelMappingChainNotIn applies the NotIn predicate on the "model_mapping_chain" field. 
+func ModelMappingChainNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldModelMappingChain, vs...)) +} + +// ModelMappingChainGT applies the GT predicate on the "model_mapping_chain" field. +func ModelMappingChainGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldModelMappingChain, v)) +} + +// ModelMappingChainGTE applies the GTE predicate on the "model_mapping_chain" field. +func ModelMappingChainGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldModelMappingChain, v)) +} + +// ModelMappingChainLT applies the LT predicate on the "model_mapping_chain" field. +func ModelMappingChainLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldModelMappingChain, v)) +} + +// ModelMappingChainLTE applies the LTE predicate on the "model_mapping_chain" field. +func ModelMappingChainLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldModelMappingChain, v)) +} + +// ModelMappingChainContains applies the Contains predicate on the "model_mapping_chain" field. +func ModelMappingChainContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldModelMappingChain, v)) +} + +// ModelMappingChainHasPrefix applies the HasPrefix predicate on the "model_mapping_chain" field. +func ModelMappingChainHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldModelMappingChain, v)) +} + +// ModelMappingChainHasSuffix applies the HasSuffix predicate on the "model_mapping_chain" field. +func ModelMappingChainHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldModelMappingChain, v)) +} + +// ModelMappingChainIsNil applies the IsNil predicate on the "model_mapping_chain" field. 
+func ModelMappingChainIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldModelMappingChain)) +} + +// ModelMappingChainNotNil applies the NotNil predicate on the "model_mapping_chain" field. +func ModelMappingChainNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldModelMappingChain)) +} + +// ModelMappingChainEqualFold applies the EqualFold predicate on the "model_mapping_chain" field. +func ModelMappingChainEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldModelMappingChain, v)) +} + +// ModelMappingChainContainsFold applies the ContainsFold predicate on the "model_mapping_chain" field. +func ModelMappingChainContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldModelMappingChain, v)) +} + +// BillingTierEQ applies the EQ predicate on the "billing_tier" field. +func BillingTierEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v)) +} + +// BillingTierNEQ applies the NEQ predicate on the "billing_tier" field. +func BillingTierNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldBillingTier, v)) +} + +// BillingTierIn applies the In predicate on the "billing_tier" field. +func BillingTierIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldBillingTier, vs...)) +} + +// BillingTierNotIn applies the NotIn predicate on the "billing_tier" field. +func BillingTierNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldBillingTier, vs...)) +} + +// BillingTierGT applies the GT predicate on the "billing_tier" field. +func BillingTierGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldBillingTier, v)) +} + +// BillingTierGTE applies the GTE predicate on the "billing_tier" field. 
+func BillingTierGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldBillingTier, v)) +} + +// BillingTierLT applies the LT predicate on the "billing_tier" field. +func BillingTierLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldBillingTier, v)) +} + +// BillingTierLTE applies the LTE predicate on the "billing_tier" field. +func BillingTierLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldBillingTier, v)) +} + +// BillingTierContains applies the Contains predicate on the "billing_tier" field. +func BillingTierContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldBillingTier, v)) +} + +// BillingTierHasPrefix applies the HasPrefix predicate on the "billing_tier" field. +func BillingTierHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingTier, v)) +} + +// BillingTierHasSuffix applies the HasSuffix predicate on the "billing_tier" field. +func BillingTierHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingTier, v)) +} + +// BillingTierIsNil applies the IsNil predicate on the "billing_tier" field. +func BillingTierIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldBillingTier)) +} + +// BillingTierNotNil applies the NotNil predicate on the "billing_tier" field. +func BillingTierNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldBillingTier)) +} + +// BillingTierEqualFold applies the EqualFold predicate on the "billing_tier" field. +func BillingTierEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldBillingTier, v)) +} + +// BillingTierContainsFold applies the ContainsFold predicate on the "billing_tier" field. 
+func BillingTierContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldBillingTier, v)) +} + +// BillingModeEQ applies the EQ predicate on the "billing_mode" field. +func BillingModeEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v)) +} + +// BillingModeNEQ applies the NEQ predicate on the "billing_mode" field. +func BillingModeNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldBillingMode, v)) +} + +// BillingModeIn applies the In predicate on the "billing_mode" field. +func BillingModeIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldBillingMode, vs...)) +} + +// BillingModeNotIn applies the NotIn predicate on the "billing_mode" field. +func BillingModeNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldBillingMode, vs...)) +} + +// BillingModeGT applies the GT predicate on the "billing_mode" field. +func BillingModeGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldBillingMode, v)) +} + +// BillingModeGTE applies the GTE predicate on the "billing_mode" field. +func BillingModeGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldBillingMode, v)) +} + +// BillingModeLT applies the LT predicate on the "billing_mode" field. +func BillingModeLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldBillingMode, v)) +} + +// BillingModeLTE applies the LTE predicate on the "billing_mode" field. +func BillingModeLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldBillingMode, v)) +} + +// BillingModeContains applies the Contains predicate on the "billing_mode" field. +func BillingModeContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldBillingMode, v)) +} + +// BillingModeHasPrefix applies the HasPrefix predicate on the "billing_mode" field. 
+func BillingModeHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingMode, v)) +} + +// BillingModeHasSuffix applies the HasSuffix predicate on the "billing_mode" field. +func BillingModeHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingMode, v)) +} + +// BillingModeIsNil applies the IsNil predicate on the "billing_mode" field. +func BillingModeIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldBillingMode)) +} + +// BillingModeNotNil applies the NotNil predicate on the "billing_mode" field. +func BillingModeNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldBillingMode)) +} + +// BillingModeEqualFold applies the EqualFold predicate on the "billing_mode" field. +func BillingModeEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldBillingMode, v)) +} + +// BillingModeContainsFold applies the ContainsFold predicate on the "billing_mode" field. +func BillingModeContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldBillingMode, v)) +} + // GroupIDEQ applies the EQ predicate on the "group_id" field. func GroupIDEQ(v int64) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index 6ae0bf595e..d15e231d9c 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -85,6 +85,62 @@ func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate { return _c } +// SetChannelID sets the "channel_id" field. +func (_c *UsageLogCreate) SetChannelID(v int64) *UsageLogCreate { + _c.mutation.SetChannelID(v) + return _c +} + +// SetNillableChannelID sets the "channel_id" field if the given value is not nil. 
+func (_c *UsageLogCreate) SetNillableChannelID(v *int64) *UsageLogCreate { + if v != nil { + _c.SetChannelID(*v) + } + return _c +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (_c *UsageLogCreate) SetModelMappingChain(v string) *UsageLogCreate { + _c.mutation.SetModelMappingChain(v) + return _c +} + +// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableModelMappingChain(v *string) *UsageLogCreate { + if v != nil { + _c.SetModelMappingChain(*v) + } + return _c +} + +// SetBillingTier sets the "billing_tier" field. +func (_c *UsageLogCreate) SetBillingTier(v string) *UsageLogCreate { + _c.mutation.SetBillingTier(v) + return _c +} + +// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableBillingTier(v *string) *UsageLogCreate { + if v != nil { + _c.SetBillingTier(*v) + } + return _c +} + +// SetBillingMode sets the "billing_mode" field. +func (_c *UsageLogCreate) SetBillingMode(v string) *UsageLogCreate { + _c.mutation.SetBillingMode(v) + return _c +} + +// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableBillingMode(v *string) *UsageLogCreate { + if v != nil { + _c.SetBillingMode(*v) + } + return _c +} + // SetGroupID sets the "group_id" field. 
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate { _c.mutation.SetGroupID(v) @@ -634,6 +690,21 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} } } + if v, ok := _c.mutation.ModelMappingChain(); ok { + if err := usagelog.ModelMappingChainValidator(v); err != nil { + return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)} + } + } + if v, ok := _c.mutation.BillingTier(); ok { + if err := usagelog.BillingTierValidator(v); err != nil { + return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)} + } + } + if v, ok := _c.mutation.BillingMode(); ok { + if err := usagelog.BillingModeValidator(v); err != nil { + return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)} + } + } if _, ok := _c.mutation.InputTokens(); !ok { return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)} } @@ -760,6 +831,22 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) _node.UpstreamModel = &value } + if value, ok := _c.mutation.ChannelID(); ok { + _spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value) + _node.ChannelID = &value + } + if value, ok := _c.mutation.ModelMappingChain(); ok { + _spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value) + _node.ModelMappingChain = &value + } + if value, ok := _c.mutation.BillingTier(); ok { + _spec.SetField(usagelog.FieldBillingTier, field.TypeString, value) + _node.BillingTier = &value + } + if value, ok := _c.mutation.BillingMode(); ok { + _spec.SetField(usagelog.FieldBillingMode, field.TypeString, 
value) + _node.BillingMode = &value + } if value, ok := _c.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) _node.InputTokens = value @@ -1093,6 +1180,84 @@ func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert { return u } +// SetChannelID sets the "channel_id" field. +func (u *UsageLogUpsert) SetChannelID(v int64) *UsageLogUpsert { + u.Set(usagelog.FieldChannelID, v) + return u +} + +// UpdateChannelID sets the "channel_id" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateChannelID() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldChannelID) + return u +} + +// AddChannelID adds v to the "channel_id" field. +func (u *UsageLogUpsert) AddChannelID(v int64) *UsageLogUpsert { + u.Add(usagelog.FieldChannelID, v) + return u +} + +// ClearChannelID clears the value of the "channel_id" field. +func (u *UsageLogUpsert) ClearChannelID() *UsageLogUpsert { + u.SetNull(usagelog.FieldChannelID) + return u +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (u *UsageLogUpsert) SetModelMappingChain(v string) *UsageLogUpsert { + u.Set(usagelog.FieldModelMappingChain, v) + return u +} + +// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateModelMappingChain() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldModelMappingChain) + return u +} + +// ClearModelMappingChain clears the value of the "model_mapping_chain" field. +func (u *UsageLogUpsert) ClearModelMappingChain() *UsageLogUpsert { + u.SetNull(usagelog.FieldModelMappingChain) + return u +} + +// SetBillingTier sets the "billing_tier" field. +func (u *UsageLogUpsert) SetBillingTier(v string) *UsageLogUpsert { + u.Set(usagelog.FieldBillingTier, v) + return u +} + +// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create. 
+func (u *UsageLogUpsert) UpdateBillingTier() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldBillingTier) + return u +} + +// ClearBillingTier clears the value of the "billing_tier" field. +func (u *UsageLogUpsert) ClearBillingTier() *UsageLogUpsert { + u.SetNull(usagelog.FieldBillingTier) + return u +} + +// SetBillingMode sets the "billing_mode" field. +func (u *UsageLogUpsert) SetBillingMode(v string) *UsageLogUpsert { + u.Set(usagelog.FieldBillingMode, v) + return u +} + +// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateBillingMode() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldBillingMode) + return u +} + +// ClearBillingMode clears the value of the "billing_mode" field. +func (u *UsageLogUpsert) ClearBillingMode() *UsageLogUpsert { + u.SetNull(usagelog.FieldBillingMode) + return u +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert { u.Set(usagelog.FieldGroupID, v) @@ -1724,6 +1889,97 @@ func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne { }) } +// SetChannelID sets the "channel_id" field. +func (u *UsageLogUpsertOne) SetChannelID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetChannelID(v) + }) +} + +// AddChannelID adds v to the "channel_id" field. +func (u *UsageLogUpsertOne) AddChannelID(v int64) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.AddChannelID(v) + }) +} + +// UpdateChannelID sets the "channel_id" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateChannelID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateChannelID() + }) +} + +// ClearChannelID clears the value of the "channel_id" field. 
+func (u *UsageLogUpsertOne) ClearChannelID() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearChannelID() + }) +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (u *UsageLogUpsertOne) SetModelMappingChain(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetModelMappingChain(v) + }) +} + +// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateModelMappingChain() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateModelMappingChain() + }) +} + +// ClearModelMappingChain clears the value of the "model_mapping_chain" field. +func (u *UsageLogUpsertOne) ClearModelMappingChain() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearModelMappingChain() + }) +} + +// SetBillingTier sets the "billing_tier" field. +func (u *UsageLogUpsertOne) SetBillingTier(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetBillingTier(v) + }) +} + +// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateBillingTier() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateBillingTier() + }) +} + +// ClearBillingTier clears the value of the "billing_tier" field. +func (u *UsageLogUpsertOne) ClearBillingTier() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearBillingTier() + }) +} + +// SetBillingMode sets the "billing_mode" field. +func (u *UsageLogUpsertOne) SetBillingMode(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetBillingMode(v) + }) +} + +// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create. 
+func (u *UsageLogUpsertOne) UpdateBillingMode() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateBillingMode() + }) +} + +// ClearBillingMode clears the value of the "billing_mode" field. +func (u *UsageLogUpsertOne) ClearBillingMode() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearBillingMode() + }) +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -2600,6 +2856,97 @@ func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk { }) } +// SetChannelID sets the "channel_id" field. +func (u *UsageLogUpsertBulk) SetChannelID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetChannelID(v) + }) +} + +// AddChannelID adds v to the "channel_id" field. +func (u *UsageLogUpsertBulk) AddChannelID(v int64) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.AddChannelID(v) + }) +} + +// UpdateChannelID sets the "channel_id" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateChannelID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateChannelID() + }) +} + +// ClearChannelID clears the value of the "channel_id" field. +func (u *UsageLogUpsertBulk) ClearChannelID() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearChannelID() + }) +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (u *UsageLogUpsertBulk) SetModelMappingChain(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetModelMappingChain(v) + }) +} + +// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create. 
+func (u *UsageLogUpsertBulk) UpdateModelMappingChain() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateModelMappingChain() + }) +} + +// ClearModelMappingChain clears the value of the "model_mapping_chain" field. +func (u *UsageLogUpsertBulk) ClearModelMappingChain() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearModelMappingChain() + }) +} + +// SetBillingTier sets the "billing_tier" field. +func (u *UsageLogUpsertBulk) SetBillingTier(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetBillingTier(v) + }) +} + +// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateBillingTier() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateBillingTier() + }) +} + +// ClearBillingTier clears the value of the "billing_tier" field. +func (u *UsageLogUpsertBulk) ClearBillingTier() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearBillingTier() + }) +} + +// SetBillingMode sets the "billing_mode" field. +func (u *UsageLogUpsertBulk) SetBillingMode(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetBillingMode(v) + }) +} + +// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateBillingMode() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateBillingMode() + }) +} + +// ClearBillingMode clears the value of the "billing_mode" field. +func (u *UsageLogUpsertBulk) ClearBillingMode() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearBillingMode() + }) +} + // SetGroupID sets the "group_id" field. 
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index 516407b92a..52f5a606dc 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -142,6 +142,93 @@ func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate { return _u } +// SetChannelID sets the "channel_id" field. +func (_u *UsageLogUpdate) SetChannelID(v int64) *UsageLogUpdate { + _u.mutation.ResetChannelID() + _u.mutation.SetChannelID(v) + return _u +} + +// SetNillableChannelID sets the "channel_id" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableChannelID(v *int64) *UsageLogUpdate { + if v != nil { + _u.SetChannelID(*v) + } + return _u +} + +// AddChannelID adds value to the "channel_id" field. +func (_u *UsageLogUpdate) AddChannelID(v int64) *UsageLogUpdate { + _u.mutation.AddChannelID(v) + return _u +} + +// ClearChannelID clears the value of the "channel_id" field. +func (_u *UsageLogUpdate) ClearChannelID() *UsageLogUpdate { + _u.mutation.ClearChannelID() + return _u +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (_u *UsageLogUpdate) SetModelMappingChain(v string) *UsageLogUpdate { + _u.mutation.SetModelMappingChain(v) + return _u +} + +// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableModelMappingChain(v *string) *UsageLogUpdate { + if v != nil { + _u.SetModelMappingChain(*v) + } + return _u +} + +// ClearModelMappingChain clears the value of the "model_mapping_chain" field. +func (_u *UsageLogUpdate) ClearModelMappingChain() *UsageLogUpdate { + _u.mutation.ClearModelMappingChain() + return _u +} + +// SetBillingTier sets the "billing_tier" field. 
+func (_u *UsageLogUpdate) SetBillingTier(v string) *UsageLogUpdate { + _u.mutation.SetBillingTier(v) + return _u +} + +// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableBillingTier(v *string) *UsageLogUpdate { + if v != nil { + _u.SetBillingTier(*v) + } + return _u +} + +// ClearBillingTier clears the value of the "billing_tier" field. +func (_u *UsageLogUpdate) ClearBillingTier() *UsageLogUpdate { + _u.mutation.ClearBillingTier() + return _u +} + +// SetBillingMode sets the "billing_mode" field. +func (_u *UsageLogUpdate) SetBillingMode(v string) *UsageLogUpdate { + _u.mutation.SetBillingMode(v) + return _u +} + +// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableBillingMode(v *string) *UsageLogUpdate { + if v != nil { + _u.SetBillingMode(*v) + } + return _u +} + +// ClearBillingMode clears the value of the "billing_mode" field. +func (_u *UsageLogUpdate) ClearBillingMode() *UsageLogUpdate { + _u.mutation.ClearBillingMode() + return _u +} + // SetGroupID sets the "group_id" field. 
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate { _u.mutation.SetGroupID(v) @@ -795,6 +882,21 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} } } + if v, ok := _u.mutation.ModelMappingChain(); ok { + if err := usagelog.ModelMappingChainValidator(v); err != nil { + return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)} + } + } + if v, ok := _u.mutation.BillingTier(); ok { + if err := usagelog.BillingTierValidator(v); err != nil { + return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)} + } + } + if v, ok := _u.mutation.BillingMode(); ok { + if err := usagelog.BillingModeValidator(v); err != nil { + return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)} + } + } if v, ok := _u.mutation.UserAgent(); ok { if err := usagelog.UserAgentValidator(v); err != nil { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} @@ -857,6 +959,33 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.UpstreamModelCleared() { _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) } + if value, ok := _u.mutation.ChannelID(); ok { + _spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedChannelID(); ok { + _spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value) + } + if _u.mutation.ChannelIDCleared() { + _spec.ClearField(usagelog.FieldChannelID, field.TypeInt64) + } + if value, ok := _u.mutation.ModelMappingChain(); ok { + _spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value) + } + if 
_u.mutation.ModelMappingChainCleared() { + _spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString) + } + if value, ok := _u.mutation.BillingTier(); ok { + _spec.SetField(usagelog.FieldBillingTier, field.TypeString, value) + } + if _u.mutation.BillingTierCleared() { + _spec.ClearField(usagelog.FieldBillingTier, field.TypeString) + } + if value, ok := _u.mutation.BillingMode(); ok { + _spec.SetField(usagelog.FieldBillingMode, field.TypeString, value) + } + if _u.mutation.BillingModeCleared() { + _spec.ClearField(usagelog.FieldBillingMode, field.TypeString) + } if value, ok := _u.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) } @@ -1279,6 +1408,93 @@ func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne { return _u } +// SetChannelID sets the "channel_id" field. +func (_u *UsageLogUpdateOne) SetChannelID(v int64) *UsageLogUpdateOne { + _u.mutation.ResetChannelID() + _u.mutation.SetChannelID(v) + return _u +} + +// SetNillableChannelID sets the "channel_id" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableChannelID(v *int64) *UsageLogUpdateOne { + if v != nil { + _u.SetChannelID(*v) + } + return _u +} + +// AddChannelID adds value to the "channel_id" field. +func (_u *UsageLogUpdateOne) AddChannelID(v int64) *UsageLogUpdateOne { + _u.mutation.AddChannelID(v) + return _u +} + +// ClearChannelID clears the value of the "channel_id" field. +func (_u *UsageLogUpdateOne) ClearChannelID() *UsageLogUpdateOne { + _u.mutation.ClearChannelID() + return _u +} + +// SetModelMappingChain sets the "model_mapping_chain" field. +func (_u *UsageLogUpdateOne) SetModelMappingChain(v string) *UsageLogUpdateOne { + _u.mutation.SetModelMappingChain(v) + return _u +} + +// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil. 
+func (_u *UsageLogUpdateOne) SetNillableModelMappingChain(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetModelMappingChain(*v) + } + return _u +} + +// ClearModelMappingChain clears the value of the "model_mapping_chain" field. +func (_u *UsageLogUpdateOne) ClearModelMappingChain() *UsageLogUpdateOne { + _u.mutation.ClearModelMappingChain() + return _u +} + +// SetBillingTier sets the "billing_tier" field. +func (_u *UsageLogUpdateOne) SetBillingTier(v string) *UsageLogUpdateOne { + _u.mutation.SetBillingTier(v) + return _u +} + +// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableBillingTier(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetBillingTier(*v) + } + return _u +} + +// ClearBillingTier clears the value of the "billing_tier" field. +func (_u *UsageLogUpdateOne) ClearBillingTier() *UsageLogUpdateOne { + _u.mutation.ClearBillingTier() + return _u +} + +// SetBillingMode sets the "billing_mode" field. +func (_u *UsageLogUpdateOne) SetBillingMode(v string) *UsageLogUpdateOne { + _u.mutation.SetBillingMode(v) + return _u +} + +// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableBillingMode(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetBillingMode(*v) + } + return _u +} + +// ClearBillingMode clears the value of the "billing_mode" field. +func (_u *UsageLogUpdateOne) ClearBillingMode() *UsageLogUpdateOne { + _u.mutation.ClearBillingMode() + return _u +} + // SetGroupID sets the "group_id" field. 
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne { _u.mutation.SetGroupID(v) @@ -1945,6 +2161,21 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} } } + if v, ok := _u.mutation.ModelMappingChain(); ok { + if err := usagelog.ModelMappingChainValidator(v); err != nil { + return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)} + } + } + if v, ok := _u.mutation.BillingTier(); ok { + if err := usagelog.BillingTierValidator(v); err != nil { + return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)} + } + } + if v, ok := _u.mutation.BillingMode(); ok { + if err := usagelog.BillingModeValidator(v); err != nil { + return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)} + } + } if v, ok := _u.mutation.UserAgent(); ok { if err := usagelog.UserAgentValidator(v); err != nil { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} @@ -2024,6 +2255,33 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.UpstreamModelCleared() { _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) } + if value, ok := _u.mutation.ChannelID(); ok { + _spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedChannelID(); ok { + _spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value) + } + if _u.mutation.ChannelIDCleared() { + _spec.ClearField(usagelog.FieldChannelID, field.TypeInt64) + } + if value, ok := _u.mutation.ModelMappingChain(); ok { + _spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value) + } + if 
_u.mutation.ModelMappingChainCleared() { + _spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString) + } + if value, ok := _u.mutation.BillingTier(); ok { + _spec.SetField(usagelog.FieldBillingTier, field.TypeString, value) + } + if _u.mutation.BillingTierCleared() { + _spec.ClearField(usagelog.FieldBillingTier, field.TypeString) + } + if value, ok := _u.mutation.BillingMode(); ok { + _spec.SetField(usagelog.FieldBillingMode, field.TypeString, value) + } + if _u.mutation.BillingModeCleared() { + _spec.ClearField(usagelog.FieldBillingMode, field.TypeString) + } if value, ok := _u.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) } diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go new file mode 100644 index 0000000000..563a27cea4 --- /dev/null +++ b/backend/internal/handler/admin/channel_handler.go @@ -0,0 +1,452 @@ +package admin + +import ( + "errors" + "fmt" + "strconv" + "strings" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// ChannelHandler handles admin channel management +type ChannelHandler struct { + channelService *service.ChannelService + billingService *service.BillingService +} + +// NewChannelHandler creates a new admin channel handler +func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler { + return &ChannelHandler{channelService: channelService, billingService: billingService} +} + +// --- Request / Response types --- + +type createChannelRequest struct { + Name string `json:"name" binding:"required,max=100"` + Description string `json:"description"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingRequest `json:"model_pricing"` + 
ModelMapping map[string]map[string]string `json:"model_mapping"` + BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` + RestrictModels bool `json:"restrict_models"` +} + +type updateChannelRequest struct { + Name string `json:"name" binding:"omitempty,max=100"` + Description *string `json:"description"` + Status string `json:"status" binding:"omitempty,oneof=active disabled"` + GroupIDs *[]int64 `json:"group_ids"` + ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` + RestrictModels *bool `json:"restrict_models"` +} + +type channelModelPricingRequest struct { + Platform string `json:"platform" binding:"omitempty,max=50"` + Models []string `json:"models" binding:"required,min=1,max=100"` + BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` + InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` + OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"` + CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"` + CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"` + ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"` + PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"` + Intervals []pricingIntervalRequest `json:"intervals"` +} + +type pricingIntervalRequest struct { + MinTokens int `json:"min_tokens"` + MaxTokens *int `json:"max_tokens"` + TierLabel string `json:"tier_label"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + PerRequestPrice *float64 `json:"per_request_price"` + SortOrder int 
`json:"sort_order"` +} + +type channelResponse struct { + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + BillingModelSource string `json:"billing_model_source"` + RestrictModels bool `json:"restrict_models"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingResponse `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type channelModelPricingResponse struct { + ID int64 `json:"id"` + Platform string `json:"platform"` + Models []string `json:"models"` + BillingMode string `json:"billing_mode"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + ImageOutputPrice *float64 `json:"image_output_price"` + PerRequestPrice *float64 `json:"per_request_price"` + Intervals []pricingIntervalResponse `json:"intervals"` +} + +type pricingIntervalResponse struct { + ID int64 `json:"id"` + MinTokens int `json:"min_tokens"` + MaxTokens *int `json:"max_tokens"` + TierLabel string `json:"tier_label,omitempty"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + PerRequestPrice *float64 `json:"per_request_price"` + SortOrder int `json:"sort_order"` +} + +func channelToResponse(ch *service.Channel) *channelResponse { + if ch == nil { + return nil + } + resp := &channelResponse{ + ID: ch.ID, + Name: ch.Name, + Description: ch.Description, + Status: ch.Status, + RestrictModels: ch.RestrictModels, + GroupIDs: ch.GroupIDs, + ModelMapping: ch.ModelMapping, + CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), + UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"), + } + 
resp.BillingModelSource = ch.BillingModelSource + if resp.BillingModelSource == "" { + resp.BillingModelSource = service.BillingModelSourceChannelMapped + } + if resp.GroupIDs == nil { + resp.GroupIDs = []int64{} + } + if resp.ModelMapping == nil { + resp.ModelMapping = map[string]map[string]string{} + } + + resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing)) + for _, p := range ch.ModelPricing { + resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p)) + } + return resp +} + +func pricingToResponse(p *service.ChannelModelPricing) channelModelPricingResponse { + models := p.Models + if models == nil { + models = []string{} + } + billingMode := string(p.BillingMode) + if billingMode == "" { + billingMode = string(service.BillingModeToken) + } + platform := p.Platform + if platform == "" { + platform = service.PlatformAnthropic + } + intervals := make([]pricingIntervalResponse, 0, len(p.Intervals)) + for _, iv := range p.Intervals { + intervals = append(intervals, intervalToResponse(iv)) + } + return channelModelPricingResponse{ + ID: p.ID, + Platform: platform, + Models: models, + BillingMode: billingMode, + InputPrice: p.InputPrice, + OutputPrice: p.OutputPrice, + CacheWritePrice: p.CacheWritePrice, + CacheReadPrice: p.CacheReadPrice, + ImageOutputPrice: p.ImageOutputPrice, + PerRequestPrice: p.PerRequestPrice, + Intervals: intervals, + } +} + +func intervalToResponse(iv service.PricingInterval) pricingIntervalResponse { + return pricingIntervalResponse{ + ID: iv.ID, + MinTokens: iv.MinTokens, + MaxTokens: iv.MaxTokens, + TierLabel: iv.TierLabel, + InputPrice: iv.InputPrice, + OutputPrice: iv.OutputPrice, + CacheWritePrice: iv.CacheWritePrice, + CacheReadPrice: iv.CacheReadPrice, + PerRequestPrice: iv.PerRequestPrice, + SortOrder: iv.SortOrder, + } +} + +func pricingRequestToService(reqs []channelModelPricingRequest) []service.ChannelModelPricing { + result := make([]service.ChannelModelPricing, 0, len(reqs)) + for _, r := 
range reqs { + billingMode := service.BillingMode(r.BillingMode) + if billingMode == "" { + billingMode = service.BillingModeToken + } + platform := r.Platform + if platform == "" { + platform = service.PlatformAnthropic + } + intervals := make([]service.PricingInterval, 0, len(r.Intervals)) + for _, iv := range r.Intervals { + intervals = append(intervals, service.PricingInterval{ + MinTokens: iv.MinTokens, + MaxTokens: iv.MaxTokens, + TierLabel: iv.TierLabel, + InputPrice: iv.InputPrice, + OutputPrice: iv.OutputPrice, + CacheWritePrice: iv.CacheWritePrice, + CacheReadPrice: iv.CacheReadPrice, + PerRequestPrice: iv.PerRequestPrice, + SortOrder: iv.SortOrder, + }) + } + result = append(result, service.ChannelModelPricing{ + Platform: platform, + Models: r.Models, + BillingMode: billingMode, + InputPrice: r.InputPrice, + OutputPrice: r.OutputPrice, + CacheWritePrice: r.CacheWritePrice, + CacheReadPrice: r.CacheReadPrice, + ImageOutputPrice: r.ImageOutputPrice, + PerRequestPrice: r.PerRequestPrice, + Intervals: intervals, + }) + } + return result +} + +// validatePricingBillingMode 校验计费配置 +func validatePricingBillingMode(pricing []service.ChannelModelPricing) error { + for _, p := range pricing { + // 按次/图片模式必须配置默认价格或区间 + if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage { + if p.PerRequestPrice == nil && len(p.Intervals) == 0 { + return errors.New("per-request price or intervals required for per_request/image billing mode") + } + } + // 校验价格不能为负 + if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil { + return err + } + if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil { + return err + } + if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil { + return err + } + if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil { + return err + } + if err := validatePriceNotNegative("image_output_price", 
p.ImageOutputPrice); err != nil { + return err + } + if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil { + return err + } + // 校验 interval:至少有一个价格字段非空 + for _, iv := range p.Intervals { + if iv.InputPrice == nil && iv.OutputPrice == nil && + iv.CacheWritePrice == nil && iv.CacheReadPrice == nil && + iv.PerRequestPrice == nil { + return fmt.Errorf("interval [%d, %s] has no price fields set for model %v", + iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models) + } + } + } + return nil +} + +func validatePriceNotNegative(field string, val *float64) error { + if val != nil && *val < 0 { + return fmt.Errorf("%s must be >= 0", field) + } + return nil +} + +func formatMaxTokens(max *int) string { + if max == nil { + return "∞" + } + return fmt.Sprintf("%d", *max) +} + +// --- Handlers --- + +// List handles listing channels with pagination +// GET /api/v1/admin/channels +func (h *ChannelHandler) List(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + status := c.Query("status") + search := strings.TrimSpace(c.Query("search")) + if len(search) > 100 { + search = search[:100] + } + + channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search) + if err != nil { + response.ErrorFrom(c, err) + return + } + + out := make([]*channelResponse, 0, len(channels)) + for i := range channels { + out = append(out, channelToResponse(&channels[i])) + } + response.Paginated(c, out, pag.Total, page, pageSize) +} + +// GetByID handles getting a channel by ID +// GET /api/v1/admin/channels/:id +func (h *ChannelHandler) GetByID(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID")) + return + } + + channel, err := h.channelService.GetByID(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + + 
response.Success(c, channelToResponse(channel)) +} + +// Create handles creating a new channel +// POST /api/v1/admin/channels +func (h *ChannelHandler) Create(c *gin.Context) { + var req createChannelRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + + pricing := pricingRequestToService(req.ModelPricing) + if err := validatePricingBillingMode(pricing); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + + channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ + Name: req.Name, + Description: req.Description, + GroupIDs: req.GroupIDs, + ModelPricing: pricing, + ModelMapping: req.ModelMapping, + BillingModelSource: req.BillingModelSource, + RestrictModels: req.RestrictModels, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, channelToResponse(channel)) +} + +// Update handles updating a channel +// PUT /api/v1/admin/channels/:id +func (h *ChannelHandler) Update(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID")) + return + } + + var req updateChannelRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + + input := &service.UpdateChannelInput{ + Name: req.Name, + Description: req.Description, + Status: req.Status, + GroupIDs: req.GroupIDs, + ModelMapping: req.ModelMapping, + BillingModelSource: req.BillingModelSource, + RestrictModels: req.RestrictModels, + } + if req.ModelPricing != nil { + pricing := pricingRequestToService(*req.ModelPricing) + if err := validatePricingBillingMode(pricing); err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) + return + } + 
input.ModelPricing = &pricing + } + + channel, err := h.channelService.Update(c.Request.Context(), id, input) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, channelToResponse(channel)) +} + +// Delete handles deleting a channel +// DELETE /api/v1/admin/channels/:id +func (h *ChannelHandler) Delete(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID")) + return + } + + if err := h.channelService.Delete(c.Request.Context(), id); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Channel deleted successfully"}) +} + +// GetModelDefaultPricing 获取模型的默认定价(用于前端自动填充) +// GET /api/v1/admin/channels/model-pricing?model=claude-sonnet-4 +func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) { + model := strings.TrimSpace(c.Query("model")) + if model == "" { + response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "model parameter is required"). 
+ WithMetadata(map[string]string{"param": "model"})) + return + } + + pricing, err := h.billingService.GetModelPricing(model) + if err != nil { + // 模型不在定价列表中 + response.Success(c, gin.H{"found": false}) + return + } + + response.Success(c, gin.H{ + "found": true, + "input_price": pricing.InputPricePerToken, + "output_price": pricing.OutputPricePerToken, + "cache_write_price": pricing.CacheCreationPricePerToken, + "cache_read_price": pricing.CacheReadPricePerToken, + "image_output_price": pricing.ImageOutputPricePerToken, + }) +} diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go new file mode 100644 index 0000000000..6f6ea5260e --- /dev/null +++ b/backend/internal/handler/admin/channel_handler_test.go @@ -0,0 +1,502 @@ +//go:build unit + +package admin + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +func float64Ptr(v float64) *float64 { return &v } +func intPtr(v int) *int { return &v } + +// --------------------------------------------------------------------------- +// 1. 
channelToResponse +// --------------------------------------------------------------------------- + +func TestChannelToResponse_NilInput(t *testing.T) { + require.Nil(t, channelToResponse(nil)) +} + +func TestChannelToResponse_FullChannel(t *testing.T) { + now := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC) + ch := &service.Channel{ + ID: 42, + Name: "test-channel", + Description: "desc", + Status: "active", + BillingModelSource: "upstream", + RestrictModels: true, + CreatedAt: now, + UpdatedAt: now.Add(time.Hour), + GroupIDs: []int64{1, 2, 3}, + ModelPricing: []service.ChannelModelPricing{ + { + ID: 10, + Platform: "openai", + Models: []string{"gpt-4"}, + BillingMode: service.BillingModeToken, + InputPrice: float64Ptr(0.01), + OutputPrice: float64Ptr(0.03), + CacheWritePrice: float64Ptr(0.005), + CacheReadPrice: float64Ptr(0.002), + PerRequestPrice: float64Ptr(0.5), + }, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-3-haiku": "claude-haiku-3"}, + }, + } + + resp := channelToResponse(ch) + require.NotNil(t, resp) + require.Equal(t, int64(42), resp.ID) + require.Equal(t, "test-channel", resp.Name) + require.Equal(t, "desc", resp.Description) + require.Equal(t, "active", resp.Status) + require.Equal(t, "upstream", resp.BillingModelSource) + require.True(t, resp.RestrictModels) + require.Equal(t, []int64{1, 2, 3}, resp.GroupIDs) + require.Equal(t, "2025-06-01T12:00:00Z", resp.CreatedAt) + require.Equal(t, "2025-06-01T13:00:00Z", resp.UpdatedAt) + + // model mapping + require.Len(t, resp.ModelMapping, 1) + require.Equal(t, "claude-haiku-3", resp.ModelMapping["anthropic"]["claude-3-haiku"]) + + // pricing + require.Len(t, resp.ModelPricing, 1) + p := resp.ModelPricing[0] + require.Equal(t, int64(10), p.ID) + require.Equal(t, "openai", p.Platform) + require.Equal(t, []string{"gpt-4"}, p.Models) + require.Equal(t, "token", p.BillingMode) + require.Equal(t, float64Ptr(0.01), p.InputPrice) + require.Equal(t, float64Ptr(0.03), p.OutputPrice) + 
require.Equal(t, float64Ptr(0.005), p.CacheWritePrice) + require.Equal(t, float64Ptr(0.002), p.CacheReadPrice) + require.Equal(t, float64Ptr(0.5), p.PerRequestPrice) + require.Empty(t, p.Intervals) +} + +func TestChannelToResponse_EmptyDefaults(t *testing.T) { + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + ch := &service.Channel{ + ID: 1, + Name: "ch", + BillingModelSource: "", + CreatedAt: now, + UpdatedAt: now, + GroupIDs: nil, + ModelMapping: nil, + ModelPricing: []service.ChannelModelPricing{ + { + Platform: "", + BillingMode: "", + Models: []string{"m1"}, + }, + }, + } + + resp := channelToResponse(ch) + require.Equal(t, "channel_mapped", resp.BillingModelSource) + require.NotNil(t, resp.GroupIDs) + require.Empty(t, resp.GroupIDs) + require.NotNil(t, resp.ModelMapping) + require.Empty(t, resp.ModelMapping) + + require.Len(t, resp.ModelPricing, 1) + require.Equal(t, "anthropic", resp.ModelPricing[0].Platform) + require.Equal(t, "token", resp.ModelPricing[0].BillingMode) +} + +func TestChannelToResponse_NilModels(t *testing.T) { + now := time.Now() + ch := &service.Channel{ + ID: 1, + Name: "ch", + CreatedAt: now, + UpdatedAt: now, + ModelPricing: []service.ChannelModelPricing{ + { + Models: nil, + }, + }, + } + + resp := channelToResponse(ch) + require.Len(t, resp.ModelPricing, 1) + require.NotNil(t, resp.ModelPricing[0].Models) + require.Empty(t, resp.ModelPricing[0].Models) +} + +func TestChannelToResponse_WithIntervals(t *testing.T) { + now := time.Now() + ch := &service.Channel{ + ID: 1, + Name: "ch", + CreatedAt: now, + UpdatedAt: now, + ModelPricing: []service.ChannelModelPricing{ + { + Models: []string{"m1"}, + BillingMode: service.BillingModePerRequest, + Intervals: []service.PricingInterval{ + { + ID: 100, + MinTokens: 0, + MaxTokens: intPtr(1000), + TierLabel: "1K", + InputPrice: float64Ptr(0.01), + OutputPrice: float64Ptr(0.02), + CacheWritePrice: float64Ptr(0.003), + CacheReadPrice: float64Ptr(0.001), + PerRequestPrice: float64Ptr(0.1), + 
SortOrder: 1, + }, + { + ID: 101, + MinTokens: 1000, + MaxTokens: nil, + TierLabel: "unlimited", + SortOrder: 2, + }, + }, + }, + }, + } + + resp := channelToResponse(ch) + require.Len(t, resp.ModelPricing, 1) + intervals := resp.ModelPricing[0].Intervals + require.Len(t, intervals, 2) + + iv0 := intervals[0] + require.Equal(t, int64(100), iv0.ID) + require.Equal(t, 0, iv0.MinTokens) + require.Equal(t, intPtr(1000), iv0.MaxTokens) + require.Equal(t, "1K", iv0.TierLabel) + require.Equal(t, float64Ptr(0.01), iv0.InputPrice) + require.Equal(t, float64Ptr(0.02), iv0.OutputPrice) + require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice) + require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice) + require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice) + require.Equal(t, 1, iv0.SortOrder) + + iv1 := intervals[1] + require.Equal(t, int64(101), iv1.ID) + require.Equal(t, 1000, iv1.MinTokens) + require.Nil(t, iv1.MaxTokens) + require.Equal(t, "unlimited", iv1.TierLabel) + require.Equal(t, 2, iv1.SortOrder) +} + +func TestChannelToResponse_MultipleEntries(t *testing.T) { + now := time.Now() + ch := &service.Channel{ + ID: 1, + Name: "multi", + CreatedAt: now, + UpdatedAt: now, + ModelPricing: []service.ChannelModelPricing{ + { + ID: 1, + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: service.BillingModeToken, + InputPrice: float64Ptr(0.003), + OutputPrice: float64Ptr(0.015), + }, + { + ID: 2, + Platform: "openai", + Models: []string{"gpt-4", "gpt-4o"}, + BillingMode: service.BillingModePerRequest, + PerRequestPrice: float64Ptr(1.0), + }, + { + ID: 3, + Platform: "gemini", + Models: []string{"gemini-2.5-pro"}, + BillingMode: service.BillingModeImage, + ImageOutputPrice: float64Ptr(0.05), + PerRequestPrice: float64Ptr(0.2), + }, + }, + } + + resp := channelToResponse(ch) + require.Len(t, resp.ModelPricing, 3) + + require.Equal(t, int64(1), resp.ModelPricing[0].ID) + require.Equal(t, "anthropic", resp.ModelPricing[0].Platform) + require.Equal(t, 
[]string{"claude-sonnet-4"}, resp.ModelPricing[0].Models) + require.Equal(t, "token", resp.ModelPricing[0].BillingMode) + + require.Equal(t, int64(2), resp.ModelPricing[1].ID) + require.Equal(t, "openai", resp.ModelPricing[1].Platform) + require.Equal(t, []string{"gpt-4", "gpt-4o"}, resp.ModelPricing[1].Models) + require.Equal(t, "per_request", resp.ModelPricing[1].BillingMode) + + require.Equal(t, int64(3), resp.ModelPricing[2].ID) + require.Equal(t, "gemini", resp.ModelPricing[2].Platform) + require.Equal(t, []string{"gemini-2.5-pro"}, resp.ModelPricing[2].Models) + require.Equal(t, "image", resp.ModelPricing[2].BillingMode) + require.Equal(t, float64Ptr(0.05), resp.ModelPricing[2].ImageOutputPrice) +} + +// --------------------------------------------------------------------------- +// 2. pricingRequestToService +// --------------------------------------------------------------------------- + +func TestPricingRequestToService_Defaults(t *testing.T) { + tests := []struct { + name string + req channelModelPricingRequest + wantField string // which default field to check + wantValue string + }{ + { + name: "empty billing mode defaults to token", + req: channelModelPricingRequest{ + Models: []string{"m1"}, + BillingMode: "", + }, + wantField: "BillingMode", + wantValue: string(service.BillingModeToken), + }, + { + name: "empty platform defaults to anthropic", + req: channelModelPricingRequest{ + Models: []string{"m1"}, + Platform: "", + }, + wantField: "Platform", + wantValue: "anthropic", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pricingRequestToService([]channelModelPricingRequest{tt.req}) + require.Len(t, result, 1) + switch tt.wantField { + case "BillingMode": + require.Equal(t, service.BillingMode(tt.wantValue), result[0].BillingMode) + case "Platform": + require.Equal(t, tt.wantValue, result[0].Platform) + } + }) + } +} + +func TestPricingRequestToService_WithAllFields(t *testing.T) { + reqs := 
[]channelModelPricingRequest{ + { + Platform: "openai", + Models: []string{"gpt-4", "gpt-4o"}, + BillingMode: "per_request", + InputPrice: float64Ptr(0.01), + OutputPrice: float64Ptr(0.03), + CacheWritePrice: float64Ptr(0.005), + CacheReadPrice: float64Ptr(0.002), + ImageOutputPrice: float64Ptr(0.04), + PerRequestPrice: float64Ptr(0.5), + }, + } + + result := pricingRequestToService(reqs) + require.Len(t, result, 1) + r := result[0] + require.Equal(t, "openai", r.Platform) + require.Equal(t, []string{"gpt-4", "gpt-4o"}, r.Models) + require.Equal(t, service.BillingModePerRequest, r.BillingMode) + require.Equal(t, float64Ptr(0.01), r.InputPrice) + require.Equal(t, float64Ptr(0.03), r.OutputPrice) + require.Equal(t, float64Ptr(0.005), r.CacheWritePrice) + require.Equal(t, float64Ptr(0.002), r.CacheReadPrice) + require.Equal(t, float64Ptr(0.04), r.ImageOutputPrice) + require.Equal(t, float64Ptr(0.5), r.PerRequestPrice) +} + +func TestPricingRequestToService_WithIntervals(t *testing.T) { + reqs := []channelModelPricingRequest{ + { + Models: []string{"m1"}, + BillingMode: "per_request", + Intervals: []pricingIntervalRequest{ + { + MinTokens: 0, + MaxTokens: intPtr(2000), + TierLabel: "small", + InputPrice: float64Ptr(0.01), + OutputPrice: float64Ptr(0.02), + CacheWritePrice: float64Ptr(0.003), + CacheReadPrice: float64Ptr(0.001), + PerRequestPrice: float64Ptr(0.1), + SortOrder: 1, + }, + { + MinTokens: 2000, + MaxTokens: nil, + TierLabel: "large", + SortOrder: 2, + }, + }, + }, + } + + result := pricingRequestToService(reqs) + require.Len(t, result, 1) + require.Len(t, result[0].Intervals, 2) + + iv0 := result[0].Intervals[0] + require.Equal(t, 0, iv0.MinTokens) + require.Equal(t, intPtr(2000), iv0.MaxTokens) + require.Equal(t, "small", iv0.TierLabel) + require.Equal(t, float64Ptr(0.01), iv0.InputPrice) + require.Equal(t, float64Ptr(0.02), iv0.OutputPrice) + require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice) + require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice) 
+ require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice) + require.Equal(t, 1, iv0.SortOrder) + + iv1 := result[0].Intervals[1] + require.Equal(t, 2000, iv1.MinTokens) + require.Nil(t, iv1.MaxTokens) + require.Equal(t, "large", iv1.TierLabel) + require.Equal(t, 2, iv1.SortOrder) +} + +func TestPricingRequestToService_EmptySlice(t *testing.T) { + result := pricingRequestToService([]channelModelPricingRequest{}) + require.NotNil(t, result) + require.Empty(t, result) +} + +func TestPricingRequestToService_NilPriceFields(t *testing.T) { + reqs := []channelModelPricingRequest{ + { + Models: []string{"m1"}, + BillingMode: "token", + // all price fields are nil by default + }, + } + + result := pricingRequestToService(reqs) + require.Len(t, result, 1) + r := result[0] + require.Nil(t, r.InputPrice) + require.Nil(t, r.OutputPrice) + require.Nil(t, r.CacheWritePrice) + require.Nil(t, r.CacheReadPrice) + require.Nil(t, r.ImageOutputPrice) + require.Nil(t, r.PerRequestPrice) +} + +// --------------------------------------------------------------------------- +// 3. 
validatePricingBillingMode +// --------------------------------------------------------------------------- + +func TestValidatePricingBillingMode(t *testing.T) { + tests := []struct { + name string + pricing []service.ChannelModelPricing + wantErr bool + }{ + { + name: "token mode - valid", + pricing: []service.ChannelModelPricing{ + {BillingMode: service.BillingModeToken}, + }, + wantErr: false, + }, + { + name: "per_request with price - valid", + pricing: []service.ChannelModelPricing{ + { + BillingMode: service.BillingModePerRequest, + PerRequestPrice: float64Ptr(0.5), + }, + }, + wantErr: false, + }, + { + name: "per_request with intervals - valid", + pricing: []service.ChannelModelPricing{ + { + BillingMode: service.BillingModePerRequest, + Intervals: []service.PricingInterval{ + {MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)}, + }, + }, + }, + wantErr: false, + }, + { + name: "per_request no price no intervals - invalid", + pricing: []service.ChannelModelPricing{ + {BillingMode: service.BillingModePerRequest}, + }, + wantErr: true, + }, + { + name: "image with price - valid", + pricing: []service.ChannelModelPricing{ + { + BillingMode: service.BillingModeImage, + PerRequestPrice: float64Ptr(0.2), + }, + }, + wantErr: false, + }, + { + name: "image no price no intervals - invalid", + pricing: []service.ChannelModelPricing{ + {BillingMode: service.BillingModeImage}, + }, + wantErr: true, + }, + { + name: "empty list - valid", + pricing: []service.ChannelModelPricing{}, + wantErr: false, + }, + { + name: "mixed modes with invalid image - invalid", + pricing: []service.ChannelModelPricing{ + { + BillingMode: service.BillingModeToken, + InputPrice: float64Ptr(0.01), + }, + { + BillingMode: service.BillingModePerRequest, + PerRequestPrice: float64Ptr(0.5), + }, + { + BillingMode: service.BillingModeImage, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := 
validatePricingBillingMode(tt.pricing) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "per-request price or intervals required") + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 2a214471a4..460f63578e 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -636,6 +636,40 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) { dim.Endpoint = c.Query("endpoint") dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound") + // Additional filter conditions + if v := c.Query("user_id"); v != "" { + if id, err := strconv.ParseInt(v, 10, 64); err == nil { + dim.UserID = id + } + } + if v := c.Query("api_key_id"); v != "" { + if id, err := strconv.ParseInt(v, 10, 64); err == nil { + dim.APIKeyID = id + } + } + if v := c.Query("account_id"); v != "" { + if id, err := strconv.ParseInt(v, 10, 64); err == nil { + dim.AccountID = id + } + } + if v := c.Query("request_type"); v != "" { + if rt, err := strconv.ParseInt(v, 10, 16); err == nil { + rtVal := int16(rt) + dim.RequestType = &rtVal + } + } + if v := c.Query("stream"); v != "" { + if s, err := strconv.ParseBool(v); err == nil { + dim.Stream = &s + } + } + if v := c.Query("billing_type"); v != "" { + if bt, err := strconv.ParseInt(v, 10, 8); err == nil { + btVal := int8(bt) + dim.BillingType = &btVal + } + } + limit := 50 if v := c.Query("limit"); v != "" { if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 { diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index 7a3135b88a..2967b3840e 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -110,6 +110,7 @@ func (h *UsageHandler) List(c *gin.Context) { } model := c.Query("model") + billingMode := 
strings.TrimSpace(c.Query("billing_mode")) var requestType *int16 var stream *bool @@ -174,6 +175,7 @@ func (h *UsageHandler) List(c *gin.Context) { RequestType: requestType, Stream: stream, BillingType: billingType, + BillingMode: billingMode, StartTime: startTime, EndTime: endTime, ExactTotal: exactTotal, @@ -234,6 +236,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { } model := c.Query("model") + billingMode := strings.TrimSpace(c.Query("billing_mode")) var requestType *int16 var stream *bool @@ -312,6 +315,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { RequestType: requestType, Stream: stream, BillingType: billingType, + BillingMode: billingMode, StartTime: &startTime, EndTime: &endTime, } diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index a8da92c020..d9d657836d 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -577,6 +577,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { MediaType: l.MediaType, UserAgent: l.UserAgent, CacheTTLOverridden: l.CacheTTLOverridden, + BillingMode: l.BillingMode, CreatedAt: l.CreatedAt, User: UserFromServiceShallow(l.User), APIKey: APIKeyFromService(l.APIKey), @@ -604,6 +605,9 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog { return &AdminUsageLog{ UsageLog: usageLogFromServiceUser(l), UpstreamModel: l.UpstreamModel, + ChannelID: l.ChannelID, + ModelMappingChain: l.ModelMappingChain, + BillingTier: l.BillingTier, AccountRateMultiplier: l.AccountRateMultiplier, IPAddress: l.IPAddress, Account: AccountSummaryFromService(l.Account), diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 46984044b3..56b67c8c4d 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -390,6 +390,9 @@ type UsageLog struct { // Cache TTL Override 标记 CacheTTLOverridden bool `json:"cache_ttl_overridden"` + // BillingMode 
计费模式:token/image + BillingMode *string `json:"billing_mode,omitempty"` + CreatedAt time.Time `json:"created_at"` User *User `json:"user,omitempty"` @@ -406,6 +409,13 @@ type AdminUsageLog struct { // Omitted when no mapping was applied (requested model was used as-is). UpstreamModel *string `json:"upstream_model,omitempty"` + // ChannelID 渠道 ID + ChannelID *int64 `json:"channel_id,omitempty"` + // ModelMappingChain 模型映射链,如 "a→b→c" + ModelMappingChain *string `json:"model_mapping_chain,omitempty"` + // BillingTier 计费层级标签(per_request/image 模式) + BillingTier *string `json:"billing_tier,omitempty"` + // AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理) AccountRateMultiplier *float64 `json:"account_rate_multiplier"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index a0d8b2e9f5..dfc9fb88b4 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -158,6 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqStream := parsedReq.Stream reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + // 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { @@ -292,7 +295,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制 if err != nil { if len(fs.FailedAccountIDs) == 0 { h.handleStreamingAwareError(c, 
http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -478,6 +481,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { RequestPayloadHash: requestPayloadHash, ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), @@ -514,7 +518,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0)) if err != nil { if len(fs.FailedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -660,6 +664,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { parsedReq.OnUpstreamAccepted = queueRelease // ===== 用户消息串行队列 END ===== + // 应用渠道模型映射到请求 + if channelMapping.Mapped { + parsedReq.Model = channelMapping.MappedModel + parsedReq.Body = h.gatewayService.ReplaceModelInBody(parsedReq.Body, channelMapping.MappedModel) + body = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() @@ -810,6 +821,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { RequestPayloadHash: requestPayloadHash, ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), diff --git 
a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index da376036d8..be267332d8 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -80,6 +80,9 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + // Claude Code only restriction if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly { h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error", @@ -154,7 +157,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { fs := NewFailoverState(h.maxAccountSwitches, false) for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "") + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0)) if err != nil { if len(fs.FailedAccountIDs) == 0 { h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) @@ -203,7 +206,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { // 5. 
Forward request writerSizeBeforeForward := c.Writer.Size() - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq) + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq) if accountReleaseFunc != nil { accountReleaseFunc() @@ -255,6 +262,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { reqLog.Error("gateway.cc.record_usage_failed", zap.Int64("account_id", account.ID), diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index d146d72467..e908eb9e0f 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -80,6 +80,9 @@ func (h *GatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + // Claude Code only restriction: // /v1/responses is never a Claude Code endpoint. // When claude_code_only is enabled, this endpoint is rejected. 
@@ -159,7 +162,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { fs := NewFailoverState(h.maxAccountSwitches, false) for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "") + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0)) if err != nil { if len(fs.FailedAccountIDs) == 0 { h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) @@ -208,7 +211,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) { // 5. Forward request writerSizeBeforeForward := c.Writer.Size() - result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq) + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, forwardBody, parsedReq) if accountReleaseFunc != nil { accountReleaseFunc() @@ -261,6 +268,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { reqLog.Error("gateway.responses.record_usage_failed", zap.Int64("account_id", account.ID), diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 69c8d1d582..4caef9551b 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -161,6 +161,8 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // digestStore nil, // 
settingService nil, // tlsFPProfileService + nil, // channelService + nil, // resolver ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 524c6b6de4..d200c17ce5 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -184,6 +184,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { setOpsRequestContext(c, modelName, stream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName) + reqModel := modelName // 保存映射前的原始模型名 + if channelMapping.Mapped { + modelName = channelMapping.MappedModel + } + // Get subscription (may be nil) subscription, _ := middleware.GetSubscriptionFromContext(c) @@ -353,7 +360,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制 if err != nil { if len(fs.FailedAccountIDs) == 0 { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) @@ -523,6 +530,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { LongContextMultiplier: 2.0, // 超出部分双倍计费 ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.gemini_v1beta.models"), diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go 
index b2467eacdb..ebf8d5f674 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -30,6 +30,7 @@ type AdminHandlers struct { TLSFingerprintProfile *admin.TLSFingerprintProfileHandler APIKey *admin.AdminAPIKeyHandler ScheduledTest *admin.ScheduledTestHandler + Channel *admin.ChannelHandler } // Handlers contains all HTTP handlers diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 0c94aa2173..991cbb91f7 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -79,6 +79,9 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) } @@ -183,7 +186,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { forwardStart := time.Now() defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model")) - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) + forwardBody := body + if channelMapping.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { @@ -257,16 +264,17 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := 
h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + InboundEndpoint: GetInboundEndpoint(c), + UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.chat_completions"), diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index ae70cee40e..4747ccfe1e 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -185,6 +185,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 if !h.validateFunctionCallOutputRequest(c, body, reqLog) { return @@ -284,7 +287,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Forward request service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) + // 应用渠道模型映射到请求体 + forwardBody := body + if channelMapping.Mapped { + forwardBody = 
h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) + } + result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody) forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { accountReleaseFunc() @@ -379,6 +387,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.responses"), @@ -549,6 +558,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) + // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) @@ -673,7 +685,12 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的 // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。 defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model")) - result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) + // 应用渠道模型映射到请求体 + forwardBody := body + if channelMappingMsg.Mapped { + forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel) + } + result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { @@ -759,6 +776,7 @@ func (h *OpenAIGatewayHandler) Messages(c 
*gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMappingMsg.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.messages"), @@ -1101,6 +1119,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsRequestContext(c, reqModel, true, firstMessage) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) + // 解析渠道级模型映射 + channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel) + var currentUserRelease func() var currentAccountRelease func() releaseTurnSlots := func() { @@ -1259,6 +1280,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), APIKeyService: h.apiKeyService, + ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { reqLog.Error("openai.websocket_record_usage_failed", zap.Int64("account_id", account.ID), @@ -1270,7 +1292,13 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { }, } - if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil { + // 应用渠道模型映射到 WebSocket 首条消息 + wsFirstMessage := firstMessage + if channelMappingWS.Mapped { + wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel) + } + + if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) closeStatus, closeReason := summarizeWSCloseErrorForLog(err) reqLog.Warn("openai.websocket_proxy_failed", diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index 
fe035b6f7f..5705578660 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2225,6 +2225,7 @@ func newMinimalGatewayService(accountRepo service.AccountRepository) *service.Ga return service.NewGatewayService( accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, ) } diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 5e50540932..d1e7e00fe5 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -30,6 +30,8 @@ import ( ) // SoraGatewayHandler handles Sora chat completions requests +// +// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。 type SoraGatewayHandler struct { gatewayService *service.GatewayService soraGatewayService *service.SoraGatewayService @@ -226,7 +228,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { var lastFailoverHeaders http.Header for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "", int64(0)) if err != nil { reqLog.Warn("sora.account_select_failed", zap.Error(err), diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index c790a36c06..e053b668d3 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -465,6 +465,8 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, // digestStore nil, // settingService nil, // tlsFPProfileService + nil, // channelService + nil, // resolver ) soraClient := &stubSoraClient{imageURLs: 
[]string{"https://example.com/a.png"}} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 02ddd03098..c917f24a0d 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -33,6 +33,7 @@ func ProvideAdminHandlers( tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler, apiKeyHandler *admin.AdminAPIKeyHandler, scheduledTestHandler *admin.ScheduledTestHandler, + channelHandler *admin.ChannelHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -59,6 +60,7 @@ func ProvideAdminHandlers( TLSFingerprintProfile: tlsFingerprintProfileHandler, APIKey: apiKeyHandler, ScheduledTest: scheduledTestHandler, + Channel: channelHandler, } } @@ -150,6 +152,7 @@ var ProviderSet = wire.NewSet( admin.NewTLSFingerprintProfileHandler, admin.NewAdminAPIKeyHandler, admin.NewScheduledTestHandler, + admin.NewChannelHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 8ea87f1883..ce144bb911 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -125,6 +125,7 @@ type ClaudeUsage struct { OutputTokens int `json:"output_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` + ImageOutputTokens int `json:"image_output_tokens,omitempty"` } // ClaudeError Claude 错误响应 diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index 1a0ca5bb61..033dccbd5d 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -149,13 +149,31 @@ type GeminiCandidate struct { GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"` } +// GeminiTokenDetail Gemini token 详情(按模态分类) 
+type GeminiTokenDetail struct { + Modality string `json:"modality"` + TokenCount int `json:"tokenCount"` +} + // GeminiUsageMetadata Gemini 用量元数据 type GeminiUsageMetadata struct { - PromptTokenCount int `json:"promptTokenCount,omitempty"` - CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` - CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` - TotalTokenCount int `json:"totalTokenCount,omitempty"` - ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费) + PromptTokenCount int `json:"promptTokenCount,omitempty"` + CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` + CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` + TotalTokenCount int `json:"totalTokenCount,omitempty"` + ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费) + CandidatesTokensDetails []GeminiTokenDetail `json:"candidatesTokensDetails,omitempty"` + PromptTokensDetails []GeminiTokenDetail `json:"promptTokensDetails,omitempty"` +} + +// ImageOutputTokens 从 CandidatesTokensDetails 中提取 IMAGE 模态的 token 数 +func (m *GeminiUsageMetadata) ImageOutputTokens() int { + for _, d := range m.CandidatesTokensDetails { + if d.Modality == "IMAGE" { + return d.TokenCount + } + } + return 0 } // GeminiGroundingMetadata Gemini grounding 元数据(Web Search) diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index f12effb6bd..bc1fd32e38 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount usage.CacheReadInputTokens = cached + usage.ImageOutputTokens = 
geminiResp.UsageMetadata.ImageOutputTokens() } // 生成响应 ID diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index deed5f922e..58982878de 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -32,9 +32,10 @@ type StreamingProcessor struct { groundingChunks []GeminiGroundingChunk // 累计 usage - inputTokens int - outputTokens int - cacheReadTokens int + inputTokens int + outputTokens int + cacheReadTokens int + imageOutputTokens int } // NewStreamingProcessor 创建流式响应处理器 @@ -87,6 +88,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount p.cacheReadTokens = cached + p.imageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens() } // 处理 parts @@ -127,6 +129,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) { InputTokens: p.inputTokens, OutputTokens: p.outputTokens, CacheReadInputTokens: p.cacheReadTokens, + ImageOutputTokens: p.imageOutputTokens, } if !p.messageStartSent { @@ -158,6 +161,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount usage.CacheReadInputTokens = cached + usage.ImageOutputTokens = v1Resp.Response.UsageMetadata.ImageOutputTokens() } responseID := v1Resp.ResponseID @@ -485,6 +489,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { InputTokens: p.inputTokens, OutputTokens: p.outputTokens, CacheReadInputTokens: p.cacheReadTokens, + ImageOutputTokens: p.imageOutputTokens, } deltaEvent := map[string]any{ diff --git 
a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 44cddb6ab6..5d1f7911a0 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -175,6 +175,13 @@ type UserBreakdownDimension struct { ModelType string // "requested", "upstream", or "mapping" Endpoint string // filter by endpoint value (non-empty to enable) EndpointType string // "inbound", "upstream", or "path" + // Additional filter conditions + UserID int64 // filter by user_id (>0 to enable) + APIKeyID int64 // filter by api_key_id (>0 to enable) + AccountID int64 // filter by account_id (>0 to enable) + RequestType *int16 // filter by request_type (non-nil to enable) + Stream *bool // filter by stream flag (non-nil to enable) + BillingType *int8 // filter by billing_type (non-nil to enable) } // APIKeyUsageTrendPoint represents API key usage trend data point @@ -230,6 +237,7 @@ type UsageLogFilters struct { RequestType *int16 Stream *bool BillingType *int8 + BillingMode string StartTime *time.Time EndTime *time.Time // ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging. 
diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go new file mode 100644 index 0000000000..1e2c2e4cf3 --- /dev/null +++ b/backend/internal/repository/channel_repo.go @@ -0,0 +1,461 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +type channelRepository struct { + db *sql.DB +} + +// NewChannelRepository 创建渠道数据访问实例 +func NewChannelRepository(db *sql.DB) service.ChannelRepository { + return &channelRepository{db: db} +} + +// runInTx 在事务中执行 fn,成功 commit,失败 rollback。 +func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) error) error { + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + if err := fn(tx); err != nil { + return err + } + return tx.Commit() +} + +func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error { + return r.runInTx(ctx, func(tx *sql.Tx) error { + modelMappingJSON, err := marshalModelMapping(channel.ModelMapping) + if err != nil { + return err + } + err = tx.QueryRowContext(ctx, + `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id, created_at, updated_at`, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, + ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) + if err != nil { + if isUniqueViolation(err) { + return service.ErrChannelExists + } + return fmt.Errorf("insert channel: %w", err) + } + + // 设置分组关联 + if len(channel.GroupIDs) > 0 { + if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil { + return err + } + } + + // 设置模型定价 + if len(channel.ModelPricing) > 0 { + if 
err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil { + return err + } + } + + return nil + }) +} + +func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) { + ch := &service.Channel{} + var modelMappingJSON []byte + err := r.db.QueryRowContext(ctx, + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at + FROM channels WHERE id = $1`, id, + ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt) + if err == sql.ErrNoRows { + return nil, service.ErrChannelNotFound + } + if err != nil { + return nil, fmt.Errorf("get channel: %w", err) + } + ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + + groupIDs, err := r.GetGroupIDs(ctx, id) + if err != nil { + return nil, err + } + ch.GroupIDs = groupIDs + + pricing, err := r.ListModelPricing(ctx, id) + if err != nil { + return nil, err + } + ch.ModelPricing = pricing + + return ch, nil +} + +func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error { + return r.runInTx(ctx, func(tx *sql.Tx) error { + modelMappingJSON, err := marshalModelMapping(channel.ModelMapping) + if err != nil { + return err + } + result, err := tx.ExecContext(ctx, + `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW() + WHERE id = $7`, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID, + ) + if err != nil { + if isUniqueViolation(err) { + return service.ErrChannelExists + } + return fmt.Errorf("update channel: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + return service.ErrChannelNotFound + } + + // 更新分组关联 + if channel.GroupIDs != nil { + if err := setGroupIDsTx(ctx, tx, channel.ID, 
channel.GroupIDs); err != nil { + return err + } + } + + // 更新模型定价 + if channel.ModelPricing != nil { + if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil { + return err + } + } + + return nil + }) +} + +func (r *channelRepository) Delete(ctx context.Context, id int64) error { + result, err := r.db.ExecContext(ctx, `DELETE FROM channels WHERE id = $1`, id) + if err != nil { + return fmt.Errorf("delete channel: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + return service.ErrChannelNotFound + } + return nil +} + +func (r *channelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.Channel, *pagination.PaginationResult, error) { + where := []string{"1=1"} + args := []any{} + argIdx := 1 + + if status != "" { + where = append(where, fmt.Sprintf("c.status = $%d", argIdx)) + args = append(args, status) + argIdx++ + } + if search != "" { + where = append(where, fmt.Sprintf("(c.name ILIKE $%d OR c.description ILIKE $%d)", argIdx, argIdx)) + args = append(args, "%"+escapeLike(search)+"%") + argIdx++ + } + + whereClause := strings.Join(where, " AND ") + + // 计数 + var total int64 + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM channels c WHERE %s", whereClause) + if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, nil, fmt.Errorf("count channels: %w", err) + } + + pageSize := params.Limit() // 约束在 [1, 100] + page := params.Page + if page < 1 { + page = 1 + } + offset := (page - 1) * pageSize + + // 查询 channel 列表 + dataQuery := fmt.Sprintf( + `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at + FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`, + whereClause, argIdx, argIdx+1, + ) + args = append(args, pageSize, offset) + + rows, err := r.db.QueryContext(ctx, dataQuery, args...) 
+ if err != nil { + return nil, nil, fmt.Errorf("query channels: %w", err) + } + defer func() { _ = rows.Close() }() + + var channels []service.Channel + var channelIDs []int64 + for rows.Next() { + var ch service.Channel + var modelMappingJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + return nil, nil, fmt.Errorf("scan channel: %w", err) + } + ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + channels = append(channels, ch) + channelIDs = append(channelIDs, ch.ID) + } + if err := rows.Err(); err != nil { + return nil, nil, fmt.Errorf("iterate channels: %w", err) + } + + // 批量加载分组 ID 和模型定价(避免 N+1) + if len(channelIDs) > 0 { + groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs) + if err != nil { + return nil, nil, err + } + pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs) + if err != nil { + return nil, nil, err + } + for i := range channels { + channels[i].GroupIDs = groupMap[channels[i].ID] + channels[i].ModelPricing = pricingMap[channels[i].ID] + } + } + + pages := 0 + if total > 0 { + pages = int((total + int64(pageSize) - 1) / int64(pageSize)) + } + + paginationResult := &pagination.PaginationResult{ + Total: total, + Page: page, + PageSize: pageSize, + Pages: pages, + } + + return channels, paginationResult, nil +} + +func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`, + ) + if err != nil { + return nil, fmt.Errorf("query all channels: %w", err) + } + defer func() { _ = rows.Close() }() + + var channels []service.Channel + var channelIDs []int64 + for rows.Next() { + var ch service.Channel + var modelMappingJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, 
&ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + return nil, fmt.Errorf("scan channel: %w", err) + } + ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + channels = append(channels, ch) + channelIDs = append(channelIDs, ch.ID) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate channels: %w", err) + } + + if len(channelIDs) == 0 { + return channels, nil + } + + // 批量加载分组 ID + groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs) + if err != nil { + return nil, err + } + + // 批量加载模型定价 + pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs) + if err != nil { + return nil, err + } + + for i := range channels { + channels[i].GroupIDs = groupMap[channels[i].ID] + channels[i].ModelPricing = pricingMap[channels[i].ID] + } + + return channels, nil +} + +// --- 批量加载辅助方法 --- + +// batchLoadGroupIDs 批量加载多个渠道的分组 ID +func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs []int64) (map[int64][]int64, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT channel_id, group_id FROM channel_groups + WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`, + pq.Array(channelIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load group ids: %w", err) + } + defer func() { _ = rows.Close() }() + + groupMap := make(map[int64][]int64, len(channelIDs)) + for rows.Next() { + var channelID, groupID int64 + if err := rows.Scan(&channelID, &groupID); err != nil { + return nil, fmt.Errorf("scan group id: %w", err) + } + groupMap[channelID] = append(groupMap[channelID], groupID) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate group ids: %w", err) + } + return groupMap, nil +} + +func (r *channelRepository) ExistsByName(ctx context.Context, name string) (bool, error) { + var exists bool + err := r.db.QueryRowContext(ctx, + `SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`, name, + ).Scan(&exists) + return exists, err +} 
+ +func (r *channelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) { + var exists bool + err := r.db.QueryRowContext(ctx, + `SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`, name, excludeID, + ).Scan(&exists) + return exists, err +} + +// --- 分组关联 --- + +func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`, channelID, + ) + if err != nil { + return nil, fmt.Errorf("get group ids: %w", err) + } + defer func() { _ = rows.Close() }() + + var ids []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, fmt.Errorf("scan group id: %w", err) + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate group ids: %w", err) + } + return ids, nil +} + +func (r *channelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error { + return setGroupIDsTx(ctx, r.db, channelID, groupIDs) +} + +func (r *channelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + var channelID int64 + err := r.db.QueryRowContext(ctx, + `SELECT channel_id FROM channel_groups WHERE group_id = $1`, groupID, + ).Scan(&channelID) + if err == sql.ErrNoRows { + return 0, nil + } + return channelID, err +} + +func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) { + if len(groupIDs) == 0 { + return nil, nil + } + rows, err := r.db.QueryContext(ctx, + `SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`, + pq.Array(groupIDs), channelID, + ) + if err != nil { + return nil, fmt.Errorf("get groups in other channels: %w", err) + } + defer func() { _ = rows.Close() }() + + var conflicting []int64 + for rows.Next() { + var id int64 + if err 
:= rows.Scan(&id); err != nil { + return nil, fmt.Errorf("scan conflicting group id: %w", err) + } + conflicting = append(conflicting, id) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate conflicting group ids: %w", err) + } + return conflicting, nil +} + +// marshalModelMapping 将 model mapping 序列化为嵌套 JSON 字节 +// 格式:{"platform": {"src": "dst"}, ...} +func marshalModelMapping(m map[string]map[string]string) ([]byte, error) { + if len(m) == 0 { + return []byte("{}"), nil + } + data, err := json.Marshal(m) + if err != nil { + return nil, fmt.Errorf("marshal model_mapping: %w", err) + } + return data, nil +} + +// unmarshalModelMapping 将 JSON 字节反序列化为嵌套 model mapping +func unmarshalModelMapping(data []byte) map[string]map[string]string { + if len(data) == 0 { + return nil + } + var m map[string]map[string]string + if err := json.Unmarshal(data, &m); err != nil { + return nil + } + return m +} + +// GetGroupPlatforms 批量查询分组 ID 对应的平台 +func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { + if len(groupIDs) == 0 { + return make(map[int64]string), nil + } + rows, err := r.db.QueryContext(ctx, + `SELECT id, platform FROM groups WHERE id = ANY($1)`, + pq.Array(groupIDs), + ) + if err != nil { + return nil, fmt.Errorf("get group platforms: %w", err) + } + defer rows.Close() //nolint:errcheck + + result := make(map[int64]string, len(groupIDs)) + for rows.Next() { + var id int64 + var platform string + if err := rows.Scan(&id, &platform); err != nil { + return nil, fmt.Errorf("scan group platform: %w", err) + } + result[id] = platform + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate group platforms: %w", err) + } + return result, nil +} diff --git a/backend/internal/repository/channel_repo_pricing.go b/backend/internal/repository/channel_repo_pricing.go new file mode 100644 index 0000000000..6dcf3c9175 --- /dev/null +++ 
b/backend/internal/repository/channel_repo_pricing.go @@ -0,0 +1,291 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +// --- 模型定价 --- + +func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at + FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID, + ) + if err != nil { + return nil, fmt.Errorf("list model pricing: %w", err) + } + defer func() { _ = rows.Close() }() + + result, pricingIDs, err := scanModelPricingRows(rows) + if err != nil { + return nil, err + } + + if len(pricingIDs) > 0 { + intervalMap, err := r.batchLoadIntervals(ctx, pricingIDs) + if err != nil { + return nil, err + } + for i := range result { + result[i].Intervals = intervalMap[result[i].ID] + } + } + + return result, nil +} + +func (r *channelRepository) CreateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error { + return createModelPricingExec(ctx, r.db, pricing) +} + +func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error { + modelsJSON, err := json.Marshal(pricing.Models) + if err != nil { + return fmt.Errorf("marshal models: %w", err) + } + billingMode := pricing.BillingMode + if billingMode == "" { + billingMode = service.BillingModeToken + } + result, err := r.db.ExecContext(ctx, + `UPDATE channel_model_pricing + SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, platform = $9, updated_at = NOW() + WHERE id = $10`, + modelsJSON, billingMode, 
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, + pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.Platform, pricing.ID, + ) + if err != nil { + return fmt.Errorf("update model pricing: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + return fmt.Errorf("pricing entry not found: %d", pricing.ID) + } + return nil +} + +func (r *channelRepository) DeleteModelPricing(ctx context.Context, id int64) error { + _, err := r.db.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE id = $1`, id) + if err != nil { + return fmt.Errorf("delete model pricing: %w", err) + } + return nil +} + +func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []service.ChannelModelPricing) error { + return r.runInTx(ctx, func(tx *sql.Tx) error { + return replaceModelPricingTx(ctx, tx, channelID, pricingList) + }) +} + +// --- 批量加载辅助方法 --- + +// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间) +func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at + FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`, + pq.Array(channelIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load model pricing: %w", err) + } + defer func() { _ = rows.Close() }() + + allPricing, allPricingIDs, err := scanModelPricingRows(rows) + if err != nil { + return nil, err + } + + // 按 channelID 分组 + pricingMap := make(map[int64][]service.ChannelModelPricing, len(channelIDs)) + for _, p := range allPricing { + pricingMap[p.ChannelID] = append(pricingMap[p.ChannelID], p) + } + + // 批量加载所有区间 + if len(allPricingIDs) > 0 { + intervalMap, err := r.batchLoadIntervals(ctx, allPricingIDs) 
+ if err != nil { + return nil, err + } + for chID := range pricingMap { + for i := range pricingMap[chID] { + pricingMap[chID][i].Intervals = intervalMap[pricingMap[chID][i].ID] + } + } + } + + return pricingMap, nil +} + +// batchLoadIntervals 批量加载多个定价条目的区间 +func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) { + rows, err := r.db.QueryContext(ctx, + `SELECT id, pricing_id, min_tokens, max_tokens, tier_label, + input_price, output_price, cache_write_price, cache_read_price, + per_request_price, sort_order, created_at, updated_at + FROM channel_pricing_intervals + WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`, + pq.Array(pricingIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load intervals: %w", err) + } + defer func() { _ = rows.Close() }() + + intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs)) + for rows.Next() { + var iv service.PricingInterval + if err := rows.Scan( + &iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel, + &iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice, + &iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan interval: %w", err) + } + intervalMap[iv.PricingID] = append(intervalMap[iv.PricingID], iv) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate intervals: %w", err) + } + return intervalMap, nil +} + +// --- 共享 scan 辅助 --- + +// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表 +func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int64, error) { + var result []service.ChannelModelPricing + var pricingIDs []int64 + for rows.Next() { + var p service.ChannelModelPricing + var modelsJSON []byte + if err := rows.Scan( + &p.ID, &p.ChannelID, &p.Platform, &modelsJSON, &p.BillingMode, + &p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice, + 
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt, + ); err != nil { + return nil, nil, fmt.Errorf("scan model pricing: %w", err) + } + if err := json.Unmarshal(modelsJSON, &p.Models); err != nil { + p.Models = []string{} + } + pricingIDs = append(pricingIDs, p.ID) + result = append(result, p) + } + if err := rows.Err(); err != nil { + return nil, nil, fmt.Errorf("iterate model pricing: %w", err) + } + return result, pricingIDs, nil +} + +// --- 事务内辅助方法 --- + +// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口 +type dbExec interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +func setGroupIDsTx(ctx context.Context, exec dbExec, channelID int64, groupIDs []int64) error { + if _, err := exec.ExecContext(ctx, `DELETE FROM channel_groups WHERE channel_id = $1`, channelID); err != nil { + return fmt.Errorf("delete old group associations: %w", err) + } + if len(groupIDs) == 0 { + return nil + } + _, err := exec.ExecContext(ctx, + `INSERT INTO channel_groups (channel_id, group_id) + SELECT $1, unnest($2::bigint[])`, + channelID, pq.Array(groupIDs), + ) + if err != nil { + return fmt.Errorf("insert group associations: %w", err) + } + return nil +} + +func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.ChannelModelPricing) error { + modelsJSON, err := json.Marshal(pricing.Models) + if err != nil { + return fmt.Errorf("marshal models: %w", err) + } + billingMode := pricing.BillingMode + if billingMode == "" { + billingMode = service.BillingModeToken + } + platform := pricing.Platform + if platform == "" { + platform = "anthropic" + } + err = exec.QueryRowContext(ctx, + `INSERT INTO channel_model_pricing (channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, 
per_request_price) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`, + pricing.ChannelID, platform, modelsJSON, billingMode, + pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, + pricing.ImageOutputPrice, pricing.PerRequestPrice, + ).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt) + if err != nil { + return fmt.Errorf("insert model pricing: %w", err) + } + + for i := range pricing.Intervals { + pricing.Intervals[i].PricingID = pricing.ID + if err := createIntervalExec(ctx, exec, &pricing.Intervals[i]); err != nil { + return err + } + } + + return nil +} + +func createIntervalExec(ctx context.Context, exec dbExec, iv *service.PricingInterval) error { + return exec.QueryRowContext(ctx, + `INSERT INTO channel_pricing_intervals + (pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`, + iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel, + iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice, + iv.PerRequestPrice, iv.SortOrder, + ).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt) +} + +func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pricingList []service.ChannelModelPricing) error { + if _, err := exec.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE channel_id = $1`, channelID); err != nil { + return fmt.Errorf("delete old model pricing: %w", err) + } + for i := range pricingList { + pricingList[i].ChannelID = channelID + if err := createModelPricingExec(ctx, exec, &pricingList[i]); err != nil { + return fmt.Errorf("insert model pricing: %w", err) + } + } + return nil +} + +// isUniqueViolation 检查 pq 唯一约束违反错误 +func isUniqueViolation(err error) bool { + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr != nil { + return pqErr.Code == "23505" + } + 
return false +} + +// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符 +func escapeLike(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, `%`, `\%`) + s = strings.ReplaceAll(s, `_`, `\_`) + return s +} diff --git a/backend/internal/repository/channel_repo_test.go b/backend/internal/repository/channel_repo_test.go new file mode 100644 index 0000000000..5a59948d25 --- /dev/null +++ b/backend/internal/repository/channel_repo_test.go @@ -0,0 +1,227 @@ +//go:build unit + +package repository + +import ( + "encoding/json" + "errors" + "fmt" + "testing" + + "github.com/lib/pq" + "github.com/stretchr/testify/require" +) + +// --- marshalModelMapping --- + +func TestMarshalModelMapping(t *testing.T) { + tests := []struct { + name string + input map[string]map[string]string + wantJSON string // expected JSON output (exact match) + }{ + { + name: "empty map", + input: map[string]map[string]string{}, + wantJSON: "{}", + }, + { + name: "nil map", + input: nil, + wantJSON: "{}", + }, + { + name: "populated map", + input: map[string]map[string]string{ + "openai": {"gpt-4": "gpt-4-turbo"}, + }, + }, + { + name: "nested values", + input: map[string]map[string]string{ + "openai": {"*": "gpt-5.4"}, + "anthropic": {"claude-old": "claude-new"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := marshalModelMapping(tt.input) + require.NoError(t, err) + + if tt.wantJSON != "" { + require.Equal(t, []byte(tt.wantJSON), result) + } else { + // round-trip: unmarshal and compare with input + var parsed map[string]map[string]string + require.NoError(t, json.Unmarshal(result, &parsed)) + require.Equal(t, tt.input, parsed) + } + }) + } +} + +// --- unmarshalModelMapping --- + +func TestUnmarshalModelMapping(t *testing.T) { + tests := []struct { + name string + input []byte + wantNil bool + want map[string]map[string]string + }{ + { + name: "nil data", + input: nil, + wantNil: true, + }, + { + name: "empty data", + input: 
[]byte{}, + wantNil: true, + }, + { + name: "invalid JSON", + input: []byte("not-json"), + wantNil: true, + }, + { + name: "type error - number", + input: []byte("42"), + wantNil: true, + }, + { + name: "type error - array", + input: []byte("[1,2,3]"), + wantNil: true, + }, + { + name: "valid JSON", + input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`), + want: map[string]map[string]string{ + "openai": {"gpt-4": "gpt-4-turbo"}, + "anthropic": {"old": "new"}, + }, + }, + { + name: "empty object", + input: []byte("{}"), + want: map[string]map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := unmarshalModelMapping(tt.input) + if tt.wantNil { + require.Nil(t, result) + } else { + require.NotNil(t, result) + require.Equal(t, tt.want, result) + } + }) + } +} + +// --- escapeLike --- + +func TestEscapeLike(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "no special chars", + input: "hello", + want: "hello", + }, + { + name: "backslash", + input: `a\b`, + want: `a\\b`, + }, + { + name: "percent", + input: "50%", + want: `50\%`, + }, + { + name: "underscore", + input: "a_b", + want: `a\_b`, + }, + { + name: "all special chars", + input: `a\b%c_d`, + want: `a\\b\%c\_d`, + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "consecutive special chars", + input: "%_%", + want: `\%\_\%`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, escapeLike(tt.input)) + }) + } +} + +// --- isUniqueViolation --- + +func TestIsUniqueViolation(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "unique violation code 23505", + err: &pq.Error{Code: "23505"}, + want: true, + }, + { + name: "different pq error code", + err: &pq.Error{Code: "23503"}, + want: false, + }, + { + name: "non-pq error", + err: errors.New("some generic error"), + want: 
false, + }, + { + name: "typed nil pq.Error", + err: func() error { + var pqErr *pq.Error + return pqErr + }(), + want: false, + }, + { + name: "bare nil", + err: nil, + want: false, + }, + { + name: "wrapped pq error with 23505", + err: fmt.Errorf("wrapped: %w", &pq.Error{Code: "23505"}), + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, isUniqueViolation(tt.err)) + }) + } +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index e4da825bdb..66d0b4ec19 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,7 +28,7 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, 
service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at" // usageLogInsertArgTypes must stay in the same order as: // 1. prepareUsageLogInsert().args @@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{ "integer", // cache_read_tokens "integer", // cache_creation_5m_tokens "integer", // cache_creation_1h_tokens + "integer", // image_output_tokens + "numeric", // image_output_cost "numeric", // input_cost "numeric", // output_cost "numeric", // cache_creation_cost @@ -77,6 +79,10 @@ var usageLogInsertArgTypes = [...]string{ "text", // inbound_endpoint "text", // upstream_endpoint "boolean", // cache_ttl_overridden + "bigint", // channel_id + "text", // model_mapping_chain + "text", // billing_tier + "text", // billing_mode "timestamptz", // created_at } @@ -326,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -350,14 +358,18 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, - $14, $15, - $16, $17, $18, $19, $20, $21, - $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 + $14, $15, $16, $17, + $18, $19, $20, $21, $22, $23, + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -758,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage cache_read_tokens, cache_creation_5m_tokens, 
cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -782,10 +796,14 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at ) AS (VALUES `) - args := make([]any, 0, len(keys)*39) + args := make([]any, 0, len(keys)*47) argPos := 1 for idx, key := range keys { if idx > 0 { @@ -829,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -853,6 +873,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at ) SELECT @@ -871,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -895,6 +921,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at FROM input ON CONFLICT (request_id, api_key_id) DO NOTHING @@ -953,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -977,10 +1009,14 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( inbound_endpoint, upstream_endpoint, 
cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*40) + args := make([]any, 0, len(preparedList)*46) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -1021,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -1045,6 +1083,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at ) SELECT @@ -1063,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -1087,6 +1131,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at FROM input ON CONFLICT (request_id, api_key_id) DO NOTHING @@ -1113,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, input_cost, output_cost, cache_creation_cost, @@ -1137,14 +1187,18 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared inbound_endpoint, upstream_endpoint, cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, created_at ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, - $14, $15, - $16, $17, $18, $19, $20, $21, - $22, $23, $24, $25, 
$26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40 + $14, $15, $16, $17, + $18, $19, $20, $21, $22, $23, + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) @@ -1176,6 +1230,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { reasoningEffort := nullString(log.ReasoningEffort) inboundEndpoint := nullString(log.InboundEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint) + channelID := nullInt64(log.ChannelID) + modelMappingChain := nullString(log.ModelMappingChain) + billingTier := nullString(log.BillingTier) + billingMode := nullString(log.BillingMode) requestedModel := strings.TrimSpace(log.RequestedModel) if requestedModel == "" { requestedModel = strings.TrimSpace(log.Model) @@ -1208,6 +1266,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { log.CacheReadTokens, log.CacheCreation5mTokens, log.CacheCreation1hTokens, + log.ImageOutputTokens, + log.ImageOutputCost, log.InputCost, log.OutputCost, log.CacheCreationCost, @@ -1232,6 +1292,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { inboundEndpoint, upstreamEndpoint, log.CacheTTLOverridden, + channelID, + modelMappingChain, + billingTier, + billingMode, createdAt, }, } @@ -2564,8 +2628,8 @@ type UsageLogFilters = usagestats.UsageLogFilters // ListWithFilters lists usage logs with optional filters (for admin) func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { - conditions := make([]string, 0, 8) - args := make([]any, 0, 8) + conditions := make([]string, 0, 9) + args := make([]any, 0, 9) if filters.UserID > 0 { conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1)) @@ -2589,6 +2653,10 @@ func (r 
*usageLogRepository) ListWithFilters(ctx context.Context, params paginat conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) } + if filters.BillingMode != "" { + conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1)) + args = append(args, filters.BillingMode) + } if filters.StartTime != nil { conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) args = append(args, *filters.StartTime) @@ -3096,6 +3164,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1) args = append(args, dim.Endpoint) } + if dim.UserID > 0 { + query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1) + args = append(args, dim.UserID) + } + if dim.APIKeyID > 0 { + query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1) + args = append(args, dim.APIKeyID) + } + if dim.AccountID > 0 { + query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1) + args = append(args, dim.AccountID) + } + if dim.RequestType != nil { + query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1) + args = append(args, *dim.RequestType) + } + if dim.Stream != nil { + query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1) + args = append(args, *dim.Stream) + } + if dim.BillingType != nil { + query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1) + args = append(args, *dim.BillingType) + } query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC" if limit > 0 { @@ -3256,6 +3348,10 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) } + if filters.BillingMode != "" { + conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1)) + args = append(args, filters.BillingMode) + } if filters.StartTime != 
nil { conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) args = append(args, *filters.StartTime) @@ -3935,6 +4031,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e cacheReadTokens int cacheCreation5m int cacheCreation1h int + imageOutputTokens int + imageOutputCost float64 inputCost float64 outputCost float64 cacheCreationCost float64 @@ -3959,6 +4057,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e inboundEndpoint sql.NullString upstreamEndpoint sql.NullString cacheTTLOverridden bool + channelID sql.NullInt64 + modelMappingChain sql.NullString + billingTier sql.NullString + billingMode sql.NullString createdAt time.Time ) @@ -3979,6 +4081,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &cacheReadTokens, &cacheCreation5m, &cacheCreation1h, + &imageOutputTokens, + &imageOutputCost, &inputCost, &outputCost, &cacheCreationCost, @@ -4003,6 +4107,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &inboundEndpoint, &upstreamEndpoint, &cacheTTLOverridden, + &channelID, + &modelMappingChain, + &billingTier, + &billingMode, &createdAt, ); err != nil { return nil, err @@ -4021,6 +4129,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e CacheReadTokens: cacheReadTokens, CacheCreation5mTokens: cacheCreation5m, CacheCreation1hTokens: cacheCreation1h, + ImageOutputTokens: imageOutputTokens, + ImageOutputCost: imageOutputCost, InputCost: inputCost, OutputCost: outputCost, CacheCreationCost: cacheCreationCost, @@ -4087,6 +4197,19 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if upstreamModel.Valid { log.UpstreamModel = &upstreamModel.String } + if channelID.Valid { + value := channelID.Int64 + log.ChannelID = &value + } + if modelMappingChain.Valid { + log.ModelMappingChain = &modelMappingChain.String + } + if billingTier.Valid { + 
log.BillingTier = &billingTier.String + } + if billingMode.Valid { + log.BillingMode = &billingMode.String + } return log, nil } diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index ebc8929a6a..77f695e378 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { log.CacheReadTokens, log.CacheCreation5mTokens, log.CacheCreation1hTokens, + log.ImageOutputTokens, + log.ImageOutputCost, log.InputCost, log.OutputCost, log.CacheCreationCost, @@ -80,6 +82,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { sqlmock.AnyArg(), // inbound_endpoint sqlmock.AnyArg(), // upstream_endpoint log.CacheTTLOverridden, + sqlmock.AnyArg(), // channel_id + sqlmock.AnyArg(), // model_mapping_chain + sqlmock.AnyArg(), // billing_tier + sqlmock.AnyArg(), // billing_mode createdAt, ). WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) @@ -129,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { log.CacheReadTokens, log.CacheCreation5mTokens, log.CacheCreation1hTokens, + log.ImageOutputTokens, + log.ImageOutputCost, log.InputCost, log.OutputCost, log.CacheCreationCost, @@ -153,6 +161,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { sqlmock.AnyArg(), sqlmock.AnyArg(), log.CacheTTLOverridden, + sqlmock.AnyArg(), // channel_id + sqlmock.AnyArg(), // model_mapping_chain + sqlmock.AnyArg(), // billing_tier + sqlmock.AnyArg(), // billing_mode createdAt, ). 
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt)) @@ -439,6 +451,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { 4, // cache_read_tokens 5, // cache_creation_5m_tokens 6, // cache_creation_1h_tokens + 0, // image_output_tokens + 0.0, // image_output_cost 0.1, // input_cost 0.2, // output_cost 0.3, // cache_creation_cost @@ -463,6 +477,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, sql.NullString{}, false, + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode now, }}) require.NoError(t, err) @@ -487,6 +505,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, + 0, 0.0, // image_output_tokens, image_output_cost 0.1, 0.2, 0.3, 0.4, 1.0, 0.9, 1.0, sql.NullFloat64{}, @@ -506,6 +525,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, sql.NullString{}, false, + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode now, }}) require.NoError(t, err) @@ -530,6 +553,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, + 0, 0.0, // image_output_tokens, image_output_cost 0.1, 0.2, 0.3, 0.4, 1.0, 0.9, 1.0, sql.NullFloat64{}, @@ -549,6 +573,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, sql.NullString{}, false, + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode now, }}) require.NoError(t, err) diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 49d47bf63c..4548c02882 100644 --- a/backend/internal/repository/wire.go +++ 
b/backend/internal/repository/wire.go @@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet( NewUserGroupRateRepository, NewErrorPassthroughRepository, NewTLSFingerprintProfileRepository, + NewChannelRepository, // Cache implementations NewGatewayCache, diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index e04dae8521..76f4c4b4d6 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -87,6 +87,9 @@ func RegisterAdminRoutes( // 定时测试计划 registerScheduledTestRoutes(admin, h) + + // 渠道管理 + registerChannelRoutes(admin, h) } } @@ -567,3 +570,15 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete) } } + +func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + channels := admin.Group("/channels") + { + channels.GET("", h.Admin.Channel.List) + channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing) + channels.GET("/:id", h.Admin.Channel.GetByID) + channels.POST("", h.Admin.Channel.Create) + channels.PUT("/:id", h.Admin.Channel.Update) + channels.DELETE("/:id", h.Admin.Channel.Delete) + } +} diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 004511f5dd..2fe13686a7 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -56,6 +56,7 @@ type ModelPricing struct { LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格 LongContextInputMultiplier float64 // 长上下文整次会话输入倍率 LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率 + ImageOutputPricePerToken float64 // 图片输出 token 价格 (USD) } const ( @@ -94,16 +95,19 @@ type UsageTokens struct { CacheReadTokens int CacheCreation5mTokens int CacheCreation1hTokens int + ImageOutputTokens int } // CostBreakdown 费用明细 type CostBreakdown struct { InputCost float64 OutputCost float64 + ImageOutputCost float64 CacheCreationCost float64 
CacheReadCost float64 TotalCost float64 ActualCost float64 // 应用倍率后的实际费用 + BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充 } // BillingService 计费服务 @@ -357,6 +361,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold, LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier, LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier, + ImageOutputPricePerToken: litellmPricing.OutputCostPerImageToken, }), nil } } @@ -371,81 +376,252 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { return nil, fmt.Errorf("pricing not found for model: %s", model) } -// CalculateCost 计算使用费用 -func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) { - return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "") -} - -func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) { +// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值 +// 仅覆盖渠道中非 nil 的价格字段,nil 字段使用默认定价 +func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing *ChannelModelPricing) (*ModelPricing, error) { pricing, err := s.GetModelPricing(model) if err != nil { return nil, err } + if channelPricing == nil { + return pricing, nil + } + if channelPricing.InputPrice != nil { + pricing.InputPricePerToken = *channelPricing.InputPrice + pricing.InputPricePerTokenPriority = *channelPricing.InputPrice + } + if channelPricing.OutputPrice != nil { + pricing.OutputPricePerToken = *channelPricing.OutputPrice + pricing.OutputPricePerTokenPriority = *channelPricing.OutputPrice + } + if channelPricing.CacheWritePrice != nil { + pricing.CacheCreationPricePerToken = *channelPricing.CacheWritePrice + pricing.CacheCreation5mPrice = 
*channelPricing.CacheWritePrice + pricing.CacheCreation1hPrice = *channelPricing.CacheWritePrice + } + if channelPricing.CacheReadPrice != nil { + pricing.CacheReadPricePerToken = *channelPricing.CacheReadPrice + pricing.CacheReadPricePerTokenPriority = *channelPricing.CacheReadPrice + } + if channelPricing.ImageOutputPrice != nil { + pricing.ImageOutputPricePerToken = *channelPricing.ImageOutputPrice + } + return pricing, nil +} + +// --- 统一计费入口 --- + +// CostInput 统一计费输入 +type CostInput struct { + Ctx context.Context + Model string + GroupID *int64 // 用于渠道定价查找 + Tokens UsageTokens + RequestCount int // 按次计费时使用 + SizeTier string // 按次/图片模式的层级标签("1K","2K","4K","HD" 等) + RateMultiplier float64 + ServiceTier string // "priority","flex","" 等 + Resolver *ModelPricingResolver // 定价解析器 + Resolved *ResolvedPricing // 可选:预解析的定价结果(避免重复 Resolve 调用) +} + +// CalculateCostUnified 统一计费入口,支持三种计费模式。 +// 使用 ModelPricingResolver 解析定价,然后根据 BillingMode 分发计算。 +func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, error) { + if input.Resolver == nil { + // 无 Resolver,回退到旧路径 + return s.calculateCostInternal(input.Model, input.Tokens, input.RateMultiplier, input.ServiceTier, nil) + } + + // 优先使用预解析结果,避免重复 Resolve 调用 + resolved := input.Resolved + if resolved == nil { + resolved = input.Resolver.Resolve(input.Ctx, PricingInput{ + Model: input.Model, + GroupID: input.GroupID, + }) + } + + if input.RateMultiplier <= 0 { + input.RateMultiplier = 1.0 + } - breakdown := &CostBreakdown{} - inputPricePerToken := pricing.InputPricePerToken - outputPricePerToken := pricing.OutputPricePerToken - cacheReadPricePerToken := pricing.CacheReadPricePerToken + var breakdown *CostBreakdown + var err error + switch resolved.Mode { + case BillingModePerRequest, BillingModeImage: + breakdown, err = s.calculatePerRequestCost(resolved, input) + default: // BillingModeToken + breakdown, err = s.calculateTokenCost(resolved, input) + } + if err == nil && breakdown != nil { + 
breakdown.BillingMode = string(resolved.Mode) + if breakdown.BillingMode == "" { + breakdown.BillingMode = string(BillingModeToken) + } + } + return breakdown, err +} + +// calculateTokenCost 按 token 区间计费 +func (s *BillingService) calculateTokenCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) { + totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens + + pricing := input.Resolver.GetIntervalPricing(resolved, totalContext) + if pricing == nil { + return nil, fmt.Errorf("no pricing available for model: %s", input.Model) + } + + pricing = s.applyModelSpecificPricingPolicy(input.Model, pricing) + + // 长上下文定价仅在无区间定价时应用(区间定价已包含上下文分层) + applyLongCtx := len(resolved.Intervals) == 0 + + return s.computeTokenBreakdown(pricing, input.Tokens, input.RateMultiplier, input.ServiceTier, applyLongCtx), nil +} + +// computeTokenBreakdown 是 token 计费的核心逻辑,由 calculateTokenCost 和 calculateCostInternal 共用。 +// applyLongCtx 控制是否检查长上下文定价(区间定价已自含上下文分层,不需要额外应用)。 +func (s *BillingService) computeTokenBreakdown( + pricing *ModelPricing, tokens UsageTokens, + rateMultiplier float64, serviceTier string, + applyLongCtx bool, +) *CostBreakdown { + if rateMultiplier <= 0 { + rateMultiplier = 1.0 + } + + inputPrice := pricing.InputPricePerToken + outputPrice := pricing.OutputPricePerToken + cacheReadPrice := pricing.CacheReadPricePerToken tierMultiplier := 1.0 + if usePriorityServiceTierPricing(serviceTier, pricing) { if pricing.InputPricePerTokenPriority > 0 { - inputPricePerToken = pricing.InputPricePerTokenPriority + inputPrice = pricing.InputPricePerTokenPriority } if pricing.OutputPricePerTokenPriority > 0 { - outputPricePerToken = pricing.OutputPricePerTokenPriority + outputPrice = pricing.OutputPricePerTokenPriority } if pricing.CacheReadPricePerTokenPriority > 0 { - cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority + cacheReadPrice = pricing.CacheReadPricePerTokenPriority } } else { tierMultiplier = serviceTierCostMultiplier(serviceTier) 
} - if s.shouldApplySessionLongContextPricing(tokens, pricing) { - inputPricePerToken *= pricing.LongContextInputMultiplier - outputPricePerToken *= pricing.LongContextOutputMultiplier + + if applyLongCtx && s.shouldApplySessionLongContextPricing(tokens, pricing) { + inputPrice *= pricing.LongContextInputMultiplier + outputPrice *= pricing.LongContextOutputMultiplier } - // 计算输入token费用(使用per-token价格) - breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken + bd := &CostBreakdown{} + bd.InputCost = float64(tokens.InputTokens) * inputPrice - // 计算输出token费用 - breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken + // 分离图片输出 token 与文本输出 token + textOutputTokens := tokens.OutputTokens - tokens.ImageOutputTokens + if textOutputTokens < 0 { + textOutputTokens = 0 + } + bd.OutputCost = float64(textOutputTokens) * outputPrice - // 计算缓存费用 + // 图片输出 token 费用(独立费率) + if tokens.ImageOutputTokens > 0 { + imgPrice := pricing.ImageOutputPricePerToken + if imgPrice == 0 { + imgPrice = outputPrice // 回退到常规输出价格 + } + bd.ImageOutputCost = float64(tokens.ImageOutputTokens) * imgPrice + } + + // 缓存创建费用 + bd.CacheCreationCost = s.computeCacheCreationCost(pricing, tokens) + + bd.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPrice + + if tierMultiplier != 1.0 { + bd.InputCost *= tierMultiplier + bd.OutputCost *= tierMultiplier + bd.ImageOutputCost *= tierMultiplier + bd.CacheCreationCost *= tierMultiplier + bd.CacheReadCost *= tierMultiplier + } + + bd.TotalCost = bd.InputCost + bd.OutputCost + bd.ImageOutputCost + + bd.CacheCreationCost + bd.CacheReadCost + bd.ActualCost = bd.TotalCost * rateMultiplier + + return bd +} + +// computeCacheCreationCost 计算缓存创建费用(支持 5m/1h 分类或标准计费)。 +func (s *BillingService) computeCacheCreationCost(pricing *ModelPricing, tokens UsageTokens) float64 { if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { - // 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token) if 
tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 { // API 未返回 ephemeral 明细,回退到全部按 5m 单价计费 - breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice - } else { - breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice + - float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice + return float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice } - } else { - // 标准缓存创建价格(per-token) - breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken + return float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice + + float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice } + return float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken +} - breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken +// calculatePerRequestCost 按次/图片计费 +func (s *BillingService) calculatePerRequestCost(resolved *ResolvedPricing, input CostInput) (*CostBreakdown, error) { + count := input.RequestCount + if count <= 0 { + count = 1 + } - if tierMultiplier != 1.0 { - breakdown.InputCost *= tierMultiplier - breakdown.OutputCost *= tierMultiplier - breakdown.CacheCreationCost *= tierMultiplier - breakdown.CacheReadCost *= tierMultiplier + var unitPrice float64 + + if input.SizeTier != "" { + unitPrice = input.Resolver.GetRequestTierPrice(resolved, input.SizeTier) } - // 计算总费用 - breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + - breakdown.CacheCreationCost + breakdown.CacheReadCost + if unitPrice == 0 { + totalContext := input.Tokens.InputTokens + input.Tokens.CacheReadTokens + unitPrice = input.Resolver.GetRequestTierPriceByContext(resolved, totalContext) + } - // 应用倍率计算实际费用 - if rateMultiplier <= 0 { - rateMultiplier = 1.0 + // 回退到默认按次价格 + if unitPrice == 0 { + unitPrice = resolved.DefaultPerRequestPrice } - 
breakdown.ActualCost = breakdown.TotalCost * rateMultiplier - return breakdown, nil + totalCost := unitPrice * float64(count) + actualCost := totalCost * input.RateMultiplier + + return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + }, nil +} + +// CalculateCost 计算使用费用 +func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) { + return s.calculateCostInternal(model, tokens, rateMultiplier, "", nil) +} + +func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) { + return s.calculateCostInternal(model, tokens, rateMultiplier, serviceTier, nil) +} + +func (s *BillingService) calculateCostInternal(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string, channelPricing *ChannelModelPricing) (*CostBreakdown, error) { + var pricing *ModelPricing + var err error + if channelPricing != nil { + pricing, err = s.GetModelPricingWithChannel(model, channelPricing) + } else { + pricing, err = s.GetModelPricing(model) + } + if err != nil { + return nil, err + } + + // 旧路径始终检查长上下文定价(无区间定价概念) + return s.computeTokenBreakdown(pricing, tokens, rateMultiplier, serviceTier, true), nil } func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing { @@ -541,6 +717,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage CacheReadTokens: inRangeCacheTokens, CacheCreation5mTokens: tokens.CacheCreation5mTokens, CacheCreation1hTokens: tokens.CacheCreation1hTokens, + ImageOutputTokens: tokens.ImageOutputTokens, } inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) if err != nil { @@ -561,6 +738,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage return &CostBreakdown{ InputCost: inRangeCost.InputCost + outRangeCost.InputCost, OutputCost: 
inRangeCost.OutputCost, + ImageOutputCost: inRangeCost.ImageOutputCost, CacheCreationCost: inRangeCost.CacheCreationCost, CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost, TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost, @@ -662,8 +840,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag actualCost := totalCost * rateMultiplier return &CostBreakdown{ - TotalCost: totalCost, - ActualCost: actualCost, + TotalCost: totalCost, + ActualCost: actualCost, + BillingMode: string(BillingModeImage), } } diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go new file mode 100644 index 0000000000..1697ed6f86 --- /dev/null +++ b/backend/internal/service/channel.go @@ -0,0 +1,277 @@ +package service + +import ( + "fmt" + "sort" + "strings" + "time" +) + +// BillingMode 计费模式 +type BillingMode string + +const ( + BillingModeToken BillingMode = "token" // 按 token 区间计费 + BillingModePerRequest BillingMode = "per_request" // 按次计费(支持上下文窗口分层) + BillingModeImage BillingMode = "image" // 图片计费(当前按次,预留 token 计费) +) + +// IsValid 检查 BillingMode 是否为合法值 +func (m BillingMode) IsValid() bool { + switch m { + case BillingModeToken, BillingModePerRequest, BillingModeImage, "": + return true + } + return false +} + +const ( + BillingModelSourceRequested = "requested" + BillingModelSourceUpstream = "upstream" + BillingModelSourceChannelMapped = "channel_mapped" +) + +// Channel 渠道实体 +type Channel struct { + ID int64 + Name string + Description string + Status string + BillingModelSource string // "requested", "upstream", or "channel_mapped" + RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) + CreatedAt time.Time + UpdatedAt time.Time + + // 关联的分组 ID 列表 + GroupIDs []int64 + // 模型定价列表(每条含 Platform 字段) + ModelPricing []ChannelModelPricing + // 渠道级模型映射(按平台分组:platform → {src→dst}) + ModelMapping map[string]map[string]string +} + +// ChannelModelPricing 渠道模型定价条目 +type ChannelModelPricing struct { + ID int64 + 
ChannelID int64 + Platform string // 所属平台(anthropic/openai/gemini/...) + Models []string // 绑定的模型列表 + BillingMode BillingMode // 计费模式 + InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 + OutputPrice *float64 // 每 token 输出价格(USD) + CacheWritePrice *float64 // 缓存写入价格 + CacheReadPrice *float64 // 缓存读取价格 + ImageOutputPrice *float64 // 图片输出价格(向后兼容) + PerRequestPrice *float64 // 默认按次计费价格(USD) + Intervals []PricingInterval // 区间定价列表 + CreatedAt time.Time + UpdatedAt time.Time +} + +// PricingInterval 定价区间(token 区间 / 按次分层 / 图片分辨率分层) +type PricingInterval struct { + ID int64 + PricingID int64 + MinTokens int // 区间下界(含) + MaxTokens *int // 区间上界(不含),nil = 无上限 + TierLabel string // 层级标签(按次/图片模式:1K, 2K, 4K, HD 等) + InputPrice *float64 // token 模式:每 token 输入价 + OutputPrice *float64 // token 模式:每 token 输出价 + CacheWritePrice *float64 // token 模式:缓存写入价 + CacheReadPrice *float64 // token 模式:缓存读取价 + PerRequestPrice *float64 // 按次/图片模式:每次请求价格 + SortOrder int + CreatedAt time.Time + UpdatedAt time.Time +} + +// IsActive 判断渠道是否启用 +func (c *Channel) IsActive() bool { + return c.Status == StatusActive +} + +// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。 +// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。 +func (c *Channel) GetModelPricing(model string) *ChannelModelPricing { + modelLower := strings.ToLower(model) + + for i := range c.ModelPricing { + for _, m := range c.ModelPricing[i].Models { + if strings.ToLower(m) == modelLower { + cp := c.ModelPricing[i].Clone() + return &cp + } + } + } + + return nil +} + +// FindMatchingInterval 在区间列表中查找匹配 totalTokens 的区间。 +// 区间为左开右闭 (min, max]:min 不含,max 包含。 +// 第一个区间 min=0 时,0 token 不匹配任何区间(回退到默认价格)。 +func FindMatchingInterval(intervals []PricingInterval, totalTokens int) *PricingInterval { + for i := range intervals { + iv := &intervals[i] + if totalTokens > iv.MinTokens && (iv.MaxTokens == nil || totalTokens <= *iv.MaxTokens) { + return iv + } + } + return nil +} + +// GetIntervalForContext 根据总 context token 数查找匹配的区间。 +func (p *ChannelModelPricing) 
GetIntervalForContext(totalTokens int) *PricingInterval { + return FindMatchingInterval(p.Intervals, totalTokens) +} + +// GetTierByLabel 根据标签查找层级(用于 per_request / image 模式) +func (p *ChannelModelPricing) GetTierByLabel(label string) *PricingInterval { + labelLower := strings.ToLower(label) + for i := range p.Intervals { + if strings.ToLower(p.Intervals[i].TierLabel) == labelLower { + return &p.Intervals[i] + } + } + return nil +} + +// Clone 返回 ChannelModelPricing 的拷贝(切片独立,指针字段共享,调用方只读安全) +func (p ChannelModelPricing) Clone() ChannelModelPricing { + cp := p + if p.Models != nil { + cp.Models = make([]string, len(p.Models)) + copy(cp.Models, p.Models) + } + if p.Intervals != nil { + cp.Intervals = make([]PricingInterval, len(p.Intervals)) + copy(cp.Intervals, p.Intervals) + } + return cp +} + +// Clone 返回 Channel 的深拷贝 +func (c *Channel) Clone() *Channel { + if c == nil { + return nil + } + cp := *c + if c.GroupIDs != nil { + cp.GroupIDs = make([]int64, len(c.GroupIDs)) + copy(cp.GroupIDs, c.GroupIDs) + } + if c.ModelPricing != nil { + cp.ModelPricing = make([]ChannelModelPricing, len(c.ModelPricing)) + for i := range c.ModelPricing { + cp.ModelPricing[i] = c.ModelPricing[i].Clone() + } + } + if c.ModelMapping != nil { + cp.ModelMapping = make(map[string]map[string]string, len(c.ModelMapping)) + for platform, mapping := range c.ModelMapping { + inner := make(map[string]string, len(mapping)) + for k, v := range mapping { + inner[k] = v + } + cp.ModelMapping[platform] = inner + } + } + return &cp +} + +// ValidateIntervals 校验区间列表的合法性。 +// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens; +// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义); +// 无界区间(MaxTokens=nil)必须是最后一个。间隙允许(回退默认价格)。 +func ValidateIntervals(intervals []PricingInterval) error { + if len(intervals) == 0 { + return nil + } + sorted := make([]PricingInterval, len(intervals)) + copy(sorted, intervals) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].MinTokens < sorted[j].MinTokens + }) + 
+ for i := range sorted { + if err := validateSingleInterval(&sorted[i], i); err != nil { + return err + } + } + return validateIntervalOverlap(sorted) +} + +// validateSingleInterval 校验单个区间的字段合法性 +func validateSingleInterval(iv *PricingInterval, idx int) error { + if iv.MinTokens < 0 { + return fmt.Errorf("interval #%d: min_tokens (%d) must be >= 0", idx+1, iv.MinTokens) + } + if iv.MaxTokens != nil { + if *iv.MaxTokens <= 0 { + return fmt.Errorf("interval #%d: max_tokens (%d) must be > 0", idx+1, *iv.MaxTokens) + } + if *iv.MaxTokens <= iv.MinTokens { + return fmt.Errorf("interval #%d: max_tokens (%d) must be > min_tokens (%d)", + idx+1, *iv.MaxTokens, iv.MinTokens) + } + } + return validateIntervalPrices(iv, idx) +} + +// validateIntervalPrices 校验区间内所有价格字段 >= 0 +func validateIntervalPrices(iv *PricingInterval, idx int) error { + prices := []struct { + name string + val *float64 + }{ + {"input_price", iv.InputPrice}, + {"output_price", iv.OutputPrice}, + {"cache_write_price", iv.CacheWritePrice}, + {"cache_read_price", iv.CacheReadPrice}, + {"per_request_price", iv.PerRequestPrice}, + } + for _, p := range prices { + if p.val != nil && *p.val < 0 { + return fmt.Errorf("interval #%d: %s must be >= 0", idx+1, p.name) + } + } + return nil +} + +// validateIntervalOverlap 校验排序后的区间列表无重叠,且无界区间在最后 +func validateIntervalOverlap(sorted []PricingInterval) error { + for i, iv := range sorted { + // 无界区间必须是最后一个 + if iv.MaxTokens == nil && i < len(sorted)-1 { + return fmt.Errorf("interval #%d: unbounded interval (max_tokens=null) must be the last one", + i+1) + } + if i == 0 { + continue + } + prev := sorted[i-1] + // 检查重叠:前一个区间的上界 > 当前区间的下界则重叠 + // (min, max] 语义:prev 覆盖 (prev.Min, prev.Max],cur 覆盖 (cur.Min, cur.Max] + if prev.MaxTokens == nil || *prev.MaxTokens > iv.MinTokens { + return fmt.Errorf("interval #%d and #%d overlap: prev max=%s > cur min=%d", + i, i+1, formatMaxTokensLabel(prev.MaxTokens), iv.MinTokens) + } + } + return nil +} + +func formatMaxTokensLabel(max 
*int) string { + if max == nil { + return "∞" + } + return fmt.Sprintf("%d", *max) +} + +// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中) +type ChannelUsageFields struct { + ChannelID int64 // 渠道 ID(0 = 无渠道) + OriginalModel string // 用户原始请求模型(渠道映射前) + ChannelMappedModel string // 渠道映射后的模型名(无映射时等于 OriginalModel) + BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped" + ModelMappingChain string // 映射链描述,如 "a→b→c" +} diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go new file mode 100644 index 0000000000..ec8310f67c --- /dev/null +++ b/backend/internal/service/channel_service.go @@ -0,0 +1,857 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "strings" + "sync/atomic" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "golang.org/x/sync/singleflight" +) + +var ( + ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found") + ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists") + ErrGroupAlreadyInChannel = infraerrors.Conflict( + "GROUP_ALREADY_IN_CHANNEL", + "one or more groups already belong to another channel", + ) +) + +// ChannelRepository 渠道数据访问接口 +type ChannelRepository interface { + Create(ctx context.Context, channel *Channel) error + GetByID(ctx context.Context, id int64) (*Channel, error) + Update(ctx context.Context, channel *Channel) error + Delete(ctx context.Context, id int64) error + List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) + ListAll(ctx context.Context) ([]Channel, error) + ExistsByName(ctx context.Context, name string) (bool, error) + ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) + + // 分组关联 + GetGroupIDs(ctx 
context.Context, channelID int64) ([]int64, error) + SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error + GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) + GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) + + // 分组平台查询 + GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) + + // 模型定价 + ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) + CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error + UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error + DeleteModelPricing(ctx context.Context, id int64) error + ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error +} + +// channelModelKey 渠道缓存复合键(显式包含 platform 防止跨平台同名模型冲突) +type channelModelKey struct { + groupID int64 + platform string // 平台标识 + model string // lowercase +} + +// channelGroupPlatformKey 通配符定价缓存键 +type channelGroupPlatformKey struct { + groupID int64 + platform string +} + +// wildcardPricingEntry 通配符定价条目 +type wildcardPricingEntry struct { + prefix string + pricing *ChannelModelPricing +} + +// wildcardMappingEntry 通配符映射条目 +type wildcardMappingEntry struct { + prefix string + target string +} + +// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找) +type channelCache struct { + // 热路径查找 + pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价 + wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序) + mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标 + wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序) + channelByGroupID map[int64]*Channel // groupID → 渠道 + groupPlatform map[int64]string // groupID → platform + + // 冷路径(CRUD 操作) + byID map[int64]*Channel + loadedAt time.Time 
+} + +// ChannelMappingResult 渠道映射查找结果 +type ChannelMappingResult struct { + MappedModel string // 映射后的模型名(无映射时等于原始模型名) + ChannelID int64 // 渠道 ID(0 = 无渠道关联) + Mapped bool // 是否发生了映射 + BillingModelSource string // 计费模型来源("requested" / "upstream" / "channel_mapped") +} + +// BuildModelMappingChain 根据映射结果和上游实际模型构建映射链描述。 +// reqModel: 客户端请求的原始模型名。 +// upstreamModel: 上游实际使用的模型名(ForwardResult.UpstreamModel)。 +// 返回空字符串表示无映射。 +func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel string) string { + if !r.Mapped { + if upstreamModel != "" && upstreamModel != reqModel { + return reqModel + "→" + upstreamModel + } + return "" + } + if upstreamModel != "" && upstreamModel != r.MappedModel { + return reqModel + "→" + r.MappedModel + "→" + upstreamModel + } + return reqModel + "→" + r.MappedModel +} + +// ToUsageFields 将渠道映射结果转为使用记录字段 +func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields { + channelMappedModel := reqModel + if r.Mapped { + channelMappedModel = r.MappedModel + } + return ChannelUsageFields{ + ChannelID: r.ChannelID, + OriginalModel: reqModel, + ChannelMappedModel: channelMappedModel, + BillingModelSource: r.BillingModelSource, + ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel), + } +} + +const ( + channelCacheTTL = 10 * time.Minute + channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 + channelCacheDBTimeout = 10 * time.Second +) + +// ChannelService 渠道管理服务 +type ChannelService struct { + repo ChannelRepository + authCacheInvalidator APIKeyAuthCacheInvalidator + + cache atomic.Value // *channelCache + cacheSF singleflight.Group +} + +// NewChannelService 创建渠道服务实例 +func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService { + s := &ChannelService{ + repo: repo, + authCacheInvalidator: authCacheInvalidator, + } + return s +} + +// loadCache 加载或返回缓存的渠道数据 +func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, 
error) { + if cached, ok := s.cache.Load().(*channelCache); ok && cached != nil { + if time.Since(cached.loadedAt) < channelCacheTTL { + return cached, nil + } + } + + result, err, _ := s.cacheSF.Do("channel_cache", func() (any, error) { + // 双重检查 + if cached, ok := s.cache.Load().(*channelCache); ok && cached != nil { + if time.Since(cached.loadedAt) < channelCacheTTL { + return cached, nil + } + } + return s.buildCache(ctx) + }) + if err != nil { + return nil, err + } + cache, ok := result.(*channelCache) + if !ok { + return nil, fmt.Errorf("unexpected cache type") + } + return cache, nil +} + +// newEmptyChannelCache 创建空的渠道缓存(所有 map 已初始化) +func newEmptyChannelCache() *channelCache { + return &channelCache{ + pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), + wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry), + mappingByGroupModel: make(map[channelModelKey]string), + wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry), + channelByGroupID: make(map[int64]*Channel), + groupPlatform: make(map[int64]string), + byID: make(map[int64]*Channel), + } +} + +// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。 +// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。 +// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台, +// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。 +// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。 +func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { + for j := range ch.ModelPricing { + pricing := &ch.ModelPricing[j] + if !isPlatformPricingMatch(platform, pricing.Platform) { + continue // 跳过非本平台的定价 + } + // 使用定价条目的原始平台作为缓存 key,防止跨平台同名模型冲突 + pricingPlatform := pricing.Platform + gpKey := channelGroupPlatformKey{groupID: gid, platform: pricingPlatform} + for _, model := range pricing.Models { + if strings.HasSuffix(model, "*") { + prefix := strings.ToLower(strings.TrimSuffix(model, "*")) + cache.wildcardByGroupPlatform[gpKey] = 
append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{ + prefix: prefix, + pricing: pricing, + }) + } else { + key := channelModelKey{groupID: gid, platform: pricingPlatform, model: strings.ToLower(model)} + cache.pricingByGroupModel[key] = pricing + } + } + } +} + +// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。 +// antigravity 平台同时服务 Claude 和 Gemini 模型。 +// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。 +func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { + for _, mappingPlatform := range matchingPlatforms(platform) { + platformMapping, ok := ch.ModelMapping[mappingPlatform] + if !ok { + continue + } + // 使用映射条目的原始平台作为缓存 key,防止跨平台同名映射冲突 + gpKey := channelGroupPlatformKey{groupID: gid, platform: mappingPlatform} + for src, dst := range platformMapping { + if strings.HasSuffix(src, "*") { + prefix := strings.ToLower(strings.TrimSuffix(src, "*")) + cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{ + prefix: prefix, + target: dst, + }) + } else { + key := channelModelKey{groupID: gid, platform: mappingPlatform, model: strings.ToLower(src)} + cache.mappingByGroupModel[key] = dst + } + } + } +} + +// buildCache 从数据库构建渠道缓存。 +// 使用独立 context 避免请求取消导致空值被长期缓存。 +func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) { + // 断开请求取消链,避免客户端断连导致空值被长期缓存 + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout) + defer cancel() + + channels, err := s.repo.ListAll(dbCtx) + if err != nil { + // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 + slog.Warn("failed to build channel cache", "error", err) + errorCache := newEmptyChannelCache() + errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL + s.cache.Store(errorCache) + return nil, fmt.Errorf("list all channels: %w", err) + } + + // 收集所有 groupID,批量查询 platform + var allGroupIDs []int64 + for i := range channels { + allGroupIDs = append(allGroupIDs, 
channels[i].GroupIDs...) + } + groupPlatforms := make(map[int64]string) + if len(allGroupIDs) > 0 { + groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs) + if err != nil { + slog.Warn("failed to load group platforms for channel cache", "error", err) + errorCache := newEmptyChannelCache() + errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) + s.cache.Store(errorCache) + return nil, fmt.Errorf("get group platforms: %w", err) + } + } + + cache := newEmptyChannelCache() + cache.groupPlatform = groupPlatforms + cache.byID = make(map[int64]*Channel, len(channels)) + cache.loadedAt = time.Now() + + for i := range channels { + ch := &channels[i] + cache.byID[ch.ID] = ch + + for _, gid := range ch.GroupIDs { + cache.channelByGroupID[gid] = ch + platform := groupPlatforms[gid] + expandPricingToCache(cache, ch, gid, platform) + expandMappingToCache(cache, ch, gid, platform) + } + } + + // 通配符条目保持配置顺序(最先匹配到优先) + + s.cache.Store(cache) + return cache, nil +} + +// invalidateCache 使缓存失效,让下次读取时自然重建 + +// isPlatformPricingMatch 判断定价条目的平台是否匹配分组平台。 +// antigravity 平台同时服务 Claude(anthropic)和 Gemini(gemini)模型, +// 因此 antigravity 分组应匹配 anthropic 和 gemini 的定价条目。 +func isPlatformPricingMatch(groupPlatform, pricingPlatform string) bool { + if groupPlatform == pricingPlatform { + return true + } + if groupPlatform == PlatformAntigravity { + return pricingPlatform == PlatformAnthropic || pricingPlatform == PlatformGemini + } + return false +} + +// matchingPlatforms 返回分组平台对应的所有可匹配平台列表。 +func matchingPlatforms(groupPlatform string) []string { + if groupPlatform == PlatformAntigravity { + return []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini} + } + return []string{groupPlatform} +} +func (s *ChannelService) invalidateCache() { + s.cache.Store((*channelCache)(nil)) + s.cacheSF.Forget("channel_cache") + + // 主动重建缓存,确保 CRUD 后立即生效 + if _, err := s.buildCache(context.Background()); err != nil { + slog.Warn("failed to rebuild channel cache 
after invalidation", "error", err) + } +} + +// matchWildcard 在通配符定价中查找匹配项(最先匹配到优先) +func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing { + gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform} + wildcards := c.wildcardByGroupPlatform[gpKey] + for _, wc := range wildcards { + if strings.HasPrefix(modelLower, wc.prefix) { + return wc.pricing + } + } + return nil +} + +// matchWildcardMapping 在通配符映射中查找匹配项(最先匹配到优先) +func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string { + gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform} + wildcards := c.wildcardMappingByGP[gpKey] + for _, wc := range wildcards { + if strings.HasPrefix(modelLower, wc.prefix) { + return wc.target + } + } + return "" +} + +// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。 +// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试 +// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini), +// 返回第一个命中的结果。非 antigravity 平台只尝试自身。 +func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing { + for _, p := range matchingPlatforms(groupPlatform) { + key := channelModelKey{groupID: groupID, platform: p, model: modelLower} + if pricing, ok := cache.pricingByGroupModel[key]; ok { + return pricing + } + } + // 精确查找全部失败,依次尝试通配符匹配 + for _, p := range matchingPlatforms(groupPlatform) { + if pricing := cache.matchWildcard(groupID, p, modelLower); pricing != nil { + return pricing + } + } + return nil +} + +// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。 +// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。 +func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string { + for _, p := range matchingPlatforms(groupPlatform) { + key := channelModelKey{groupID: groupID, platform: p, model: modelLower} + if mapped, ok := cache.mappingByGroupModel[key]; ok { + return mapped + } + } + for 
_, p := range matchingPlatforms(groupPlatform) { + if mapped := cache.matchWildcardMapping(groupID, p, modelLower); mapped != "" { + return mapped + } + } + return "" +} + +// GetChannelForGroup 获取分组关联的渠道(热路径 O(1)) +func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) { + cache, err := s.loadCache(ctx) + if err != nil { + return nil, err + } + + ch, ok := cache.channelByGroupID[groupID] + if !ok || !ch.IsActive() { + return nil, nil + } + + return ch.Clone(), nil +} + +// channelLookup 热路径公共查找结果 +type channelLookup struct { + cache *channelCache + channel *Channel + platform string +} + +// lookupGroupChannel 加载缓存并查找分组对应的渠道信息(公共热路径前置逻辑)。 +// 返回 nil 且 err==nil 表示分组无活跃渠道;err!=nil 表示缓存加载失败。 +func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) (*channelLookup, error) { + cache, err := s.loadCache(ctx) + if err != nil { + return nil, err + } + ch, ok := cache.channelByGroupID[groupID] + if !ok || !ch.IsActive() { + return nil, nil + } + return &channelLookup{ + cache: cache, + channel: ch, + platform: cache.groupPlatform[groupID], + }, nil +} + +// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。 +// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini), +// 确保跨平台同名模型各自独立匹配。 +func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing { + lk, err := s.lookupGroupChannel(ctx, groupID) + if err != nil { + slog.Warn("failed to load channel cache", "group_id", groupID, "error", err) + return nil + } + if lk == nil { + return nil + } + + modelLower := strings.ToLower(model) + pricing := lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower) + if pricing == nil { + return nil + } + + cp := pricing.Clone() + return &cp +} + +// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1)) +// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。 +func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { + 
lk, err := s.lookupGroupChannel(ctx, groupID) + if err != nil { + slog.Warn("failed to load channel cache for mapping", "group_id", groupID, "error", err) + } + if lk == nil { + return ChannelMappingResult{MappedModel: model} + } + return resolveMapping(lk, groupID, model) +} + +// IsModelRestricted 检查模型是否被渠道限制。 +// 返回 true 表示模型被限制(不在允许列表中)。 +// 如果渠道未启用模型限制或分组无渠道关联,返回 false。 +func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { + lk, _ := s.lookupGroupChannel(ctx, groupID) + if lk == nil { + return false + } + return checkRestricted(lk, groupID, model) +} + +// ResolveChannelMappingAndRestrict 解析渠道映射。 +// 返回映射结果。模型限制检查已移至调度阶段(GatewayService.checkChannelPricingRestriction), +// restricted 始终返回 false,保留签名兼容性。 +func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { + if groupID == nil { + return ChannelMappingResult{MappedModel: model}, false + } + lk, _ := s.lookupGroupChannel(ctx, *groupID) + if lk == nil { + return ChannelMappingResult{MappedModel: model}, false + } + return resolveMapping(lk, *groupID, model), false +} + +// resolveMapping 基于已查找的渠道信息解析模型映射。 +// antigravity 分组依次尝试所有匹配平台,确保跨平台同名映射各自独立。 +func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult { + result := ChannelMappingResult{ + MappedModel: model, + ChannelID: lk.channel.ID, + BillingModelSource: lk.channel.BillingModelSource, + } + if result.BillingModelSource == "" { + result.BillingModelSource = BillingModelSourceChannelMapped + } + + modelLower := strings.ToLower(model) + if mapped := lookupMappingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower); mapped != "" { + result.MappedModel = mapped + result.Mapped = true + } + + return result +} + +// checkRestricted 基于已查找的渠道信息检查模型是否被限制。 +// antigravity 分组依次尝试所有匹配平台的定价列表。 +func checkRestricted(lk *channelLookup, groupID int64, model string) bool { + if !lk.channel.RestrictModels 
{ + return false + } + modelLower := strings.ToLower(model) + // 使用与查找定价相同的跨平台逻辑 + if lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower) != nil { + return false + } + return true +} + +// ReplaceModelInBody 替换请求体 JSON 中的 model 字段。 +func ReplaceModelInBody(body []byte, newModel string) []byte { + if len(body) == 0 { + return body + } + if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel { + return body + } + newBody, err := sjson.SetBytes(body, "model", newModel) + if err != nil { + return body + } + return newBody +} + +// --- CRUD --- + +// Create 创建渠道 +func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) (*Channel, error) { + exists, err := s.repo.ExistsByName(ctx, input.Name) + if err != nil { + return nil, fmt.Errorf("check channel exists: %w", err) + } + if exists { + return nil, ErrChannelExists + } + + // 检查分组冲突 + if len(input.GroupIDs) > 0 { + conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs) + if err != nil { + return nil, fmt.Errorf("check group conflicts: %w", err) + } + if len(conflicting) > 0 { + return nil, ErrGroupAlreadyInChannel + } + } + + channel := &Channel{ + Name: input.Name, + Description: input.Description, + Status: StatusActive, + BillingModelSource: input.BillingModelSource, + RestrictModels: input.RestrictModels, + GroupIDs: input.GroupIDs, + ModelPricing: input.ModelPricing, + ModelMapping: input.ModelMapping, + } + if channel.BillingModelSource == "" { + channel.BillingModelSource = BillingModelSourceChannelMapped + } + + if err := validateNoConflictingModels(channel.ModelPricing); err != nil { + return nil, err + } + if err := validatePricingIntervals(channel.ModelPricing); err != nil { + return nil, err + } + if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { + return nil, err + } + + if err := s.repo.Create(ctx, channel); err != nil { + return nil, fmt.Errorf("create channel: %w", err) + } + + 
s.invalidateCache() + return s.repo.GetByID(ctx, channel.ID) +} + +// GetByID 获取渠道详情 +func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) { + return s.repo.GetByID(ctx, id) +} + +// Update 更新渠道 +func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChannelInput) (*Channel, error) { + channel, err := s.repo.GetByID(ctx, id) + if err != nil { + return nil, fmt.Errorf("get channel: %w", err) + } + + if input.Name != "" && input.Name != channel.Name { + exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id) + if err != nil { + return nil, fmt.Errorf("check channel exists: %w", err) + } + if exists { + return nil, ErrChannelExists + } + channel.Name = input.Name + } + + if input.Description != nil { + channel.Description = *input.Description + } + + if input.Status != "" { + channel.Status = input.Status + } + + if input.RestrictModels != nil { + channel.RestrictModels = *input.RestrictModels + } + + // 检查分组冲突 + if input.GroupIDs != nil { + conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs) + if err != nil { + return nil, fmt.Errorf("check group conflicts: %w", err) + } + if len(conflicting) > 0 { + return nil, ErrGroupAlreadyInChannel + } + channel.GroupIDs = *input.GroupIDs + } + + if input.ModelPricing != nil { + channel.ModelPricing = *input.ModelPricing + } + + if input.ModelMapping != nil { + channel.ModelMapping = input.ModelMapping + } + + if input.BillingModelSource != "" { + channel.BillingModelSource = input.BillingModelSource + } + + if err := validateNoConflictingModels(channel.ModelPricing); err != nil { + return nil, err + } + if err := validatePricingIntervals(channel.ModelPricing); err != nil { + return nil, err + } + if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { + return nil, err + } + + // 先获取旧分组,Update 后旧分组关联已删除,无法再查到 + var oldGroupIDs []int64 + if s.authCacheInvalidator != nil { + var err2 error + oldGroupIDs, err2 = 
s.repo.GetGroupIDs(ctx, id) + if err2 != nil { + slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2) + } + } + + if err := s.repo.Update(ctx, channel); err != nil { + return nil, fmt.Errorf("update channel: %w", err) + } + + s.invalidateCache() + + // 失效新旧分组的 auth 缓存 + if s.authCacheInvalidator != nil { + seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs)) + for _, gid := range oldGroupIDs { + if _, ok := seen[gid]; !ok { + seen[gid] = struct{}{} + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } + for _, gid := range channel.GroupIDs { + if _, ok := seen[gid]; !ok { + seen[gid] = struct{}{} + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } + } + + return s.repo.GetByID(ctx, id) +} + +// Delete 删除渠道 +func (s *ChannelService) Delete(ctx context.Context, id int64) error { + // 先获取关联分组用于失效缓存 + groupIDs, err := s.repo.GetGroupIDs(ctx, id) + if err != nil { + slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err) + } + + if err := s.repo.Delete(ctx, id); err != nil { + return fmt.Errorf("delete channel: %w", err) + } + + s.invalidateCache() + + if s.authCacheInvalidator != nil { + for _, gid := range groupIDs { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } + + return nil +} + +// List 获取渠道列表 +func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) { + return s.repo.List(ctx, params, status, search) +} + +// modelEntry 表示一个模型模式条目(用于冲突检测) +type modelEntry struct { + pattern string // 原始模式(如 "claude-*" 或 "claude-opus-4") + prefix string // lowercase 前缀(通配符去掉 *,精确名保持原样) + wildcard bool +} + +// conflictsBetween 检查两个模型模式是否冲突 +func conflictsBetween(a, b modelEntry) bool { + switch { + case !a.wildcard && !b.wildcard: + return a.prefix == b.prefix + case a.wildcard && !b.wildcard: + return 
strings.HasPrefix(b.prefix, a.prefix) + case !a.wildcard && b.wildcard: + return strings.HasPrefix(a.prefix, b.prefix) + default: + return strings.HasPrefix(a.prefix, b.prefix) || + strings.HasPrefix(b.prefix, a.prefix) + } +} + +// toModelEntry 将模型名转换为 modelEntry +func toModelEntry(pattern string) modelEntry { + lower := strings.ToLower(pattern) + isWild := strings.HasSuffix(lower, "*") + prefix := lower + if isWild { + prefix = strings.TrimSuffix(lower, "*") + } + return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild} +} + +// validateNoConflictingModels 检查定价列表中是否有冲突模型模式(同一平台下)。 +// 冲突包括:精确重复、通配符之间的前缀包含、通配符与精确名的前缀匹配。 +func validateNoConflictingModels(pricingList []ChannelModelPricing) error { + byPlatform := make(map[string][]modelEntry) + for _, p := range pricingList { + for _, model := range p.Models { + byPlatform[p.Platform] = append(byPlatform[p.Platform], toModelEntry(model)) + } + } + for platform, entries := range byPlatform { + if err := detectConflicts(entries, platform, "MODEL_PATTERN_CONFLICT", "model patterns"); err != nil { + return err + } + } + return nil +} + +// validateNoConflictingMappings 检查模型映射中是否有冲突的源模式 +func validateNoConflictingMappings(mapping map[string]map[string]string) error { + for platform, platformMapping := range mapping { + entries := make([]modelEntry, 0, len(platformMapping)) + for src := range platformMapping { + entries = append(entries, toModelEntry(src)) + } + if err := detectConflicts(entries, platform, "MAPPING_PATTERN_CONFLICT", "mapping source patterns"); err != nil { + return err + } + } + return nil +} + +func validatePricingIntervals(pricingList []ChannelModelPricing) error { + for _, pricing := range pricingList { + if err := ValidateIntervals(pricing.Intervals); err != nil { + return infraerrors.BadRequest( + "INVALID_PRICING_INTERVALS", + fmt.Sprintf("invalid pricing intervals for platform '%s' models %v: %v", + pricing.Platform, pricing.Models, err), + ) + } + } + return nil +} + +// 
detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误 +func detectConflicts(entries []modelEntry, platform, errCode, label string) error { + for i := 0; i < len(entries); i++ { + for j := i + 1; j < len(entries); j++ { + if conflictsBetween(entries[i], entries[j]) { + return infraerrors.BadRequest(errCode, + fmt.Sprintf("%s '%s' and '%s' conflict in platform '%s': overlapping match range", + label, entries[i].pattern, entries[j].pattern, platform)) + } + } + } + return nil +} + +// --- Input types --- + +// CreateChannelInput 创建渠道输入 +type CreateChannelInput struct { + Name string + Description string + GroupIDs []int64 + ModelPricing []ChannelModelPricing + ModelMapping map[string]map[string]string // platform → {src→dst} + BillingModelSource string + RestrictModels bool +} + +// UpdateChannelInput 更新渠道输入 +type UpdateChannelInput struct { + Name string + Description *string + Status string + GroupIDs *[]int64 + ModelPricing *[]ChannelModelPricing + ModelMapping map[string]map[string]string // platform → {src→dst} + BillingModelSource string + RestrictModels *bool +} diff --git a/backend/internal/service/channel_service_test.go b/backend/internal/service/channel_service_test.go new file mode 100644 index 0000000000..56bde56cc0 --- /dev/null +++ b/backend/internal/service/channel_service_test.go @@ -0,0 +1,2187 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Mock: ChannelRepository +// --------------------------------------------------------------------------- + +type mockChannelRepository struct { + listAllFn func(ctx context.Context) ([]Channel, error) + getGroupPlatformsFn func(ctx context.Context, groupIDs []int64) (map[int64]string, error) + createFn func(ctx context.Context, channel *Channel) error + getByIDFn func(ctx 
context.Context, id int64) (*Channel, error) + updateFn func(ctx context.Context, channel *Channel) error + deleteFn func(ctx context.Context, id int64) error + listFn func(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) + existsByNameFn func(ctx context.Context, name string) (bool, error) + existsByNameExcludingFn func(ctx context.Context, name string, excludeID int64) (bool, error) + getGroupIDsFn func(ctx context.Context, channelID int64) ([]int64, error) + setGroupIDsFn func(ctx context.Context, channelID int64, groupIDs []int64) error + getChannelIDByGroupIDFn func(ctx context.Context, groupID int64) (int64, error) + getGroupsInOtherChannelsFn func(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) + listModelPricingFn func(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) + createModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error + updateModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error + deleteModelPricingFn func(ctx context.Context, id int64) error + replaceModelPricingFn func(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error +} + +func (m *mockChannelRepository) Create(ctx context.Context, channel *Channel) error { + if m.createFn != nil { + return m.createFn(ctx, channel) + } + return nil +} + +func (m *mockChannelRepository) GetByID(ctx context.Context, id int64) (*Channel, error) { + if m.getByIDFn != nil { + return m.getByIDFn(ctx, id) + } + return nil, ErrChannelNotFound +} + +func (m *mockChannelRepository) Update(ctx context.Context, channel *Channel) error { + if m.updateFn != nil { + return m.updateFn(ctx, channel) + } + return nil +} + +func (m *mockChannelRepository) Delete(ctx context.Context, id int64) error { + if m.deleteFn != nil { + return m.deleteFn(ctx, id) + } + return nil +} + +func (m *mockChannelRepository) List(ctx context.Context, 
params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) { + if m.listFn != nil { + return m.listFn(ctx, params, status, search) + } + return nil, nil, nil +} + +func (m *mockChannelRepository) ListAll(ctx context.Context) ([]Channel, error) { + if m.listAllFn != nil { + return m.listAllFn(ctx) + } + return nil, nil +} + +func (m *mockChannelRepository) ExistsByName(ctx context.Context, name string) (bool, error) { + if m.existsByNameFn != nil { + return m.existsByNameFn(ctx, name) + } + return false, nil +} + +func (m *mockChannelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) { + if m.existsByNameExcludingFn != nil { + return m.existsByNameExcludingFn(ctx, name, excludeID) + } + return false, nil +} + +func (m *mockChannelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) { + if m.getGroupIDsFn != nil { + return m.getGroupIDsFn(ctx, channelID) + } + return nil, nil +} + +func (m *mockChannelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error { + if m.setGroupIDsFn != nil { + return m.setGroupIDsFn(ctx, channelID, groupIDs) + } + return nil +} + +func (m *mockChannelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + if m.getChannelIDByGroupIDFn != nil { + return m.getChannelIDByGroupIDFn(ctx, groupID) + } + return 0, nil +} + +func (m *mockChannelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) { + if m.getGroupsInOtherChannelsFn != nil { + return m.getGroupsInOtherChannelsFn(ctx, channelID, groupIDs) + } + return nil, nil +} + +func (m *mockChannelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { + if m.getGroupPlatformsFn != nil { + return m.getGroupPlatformsFn(ctx, groupIDs) + } + return nil, nil +} + +func (m *mockChannelRepository) ListModelPricing(ctx 
context.Context, channelID int64) ([]ChannelModelPricing, error) { + if m.listModelPricingFn != nil { + return m.listModelPricingFn(ctx, channelID) + } + return nil, nil +} + +func (m *mockChannelRepository) CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error { + if m.createModelPricingFn != nil { + return m.createModelPricingFn(ctx, pricing) + } + return nil +} + +func (m *mockChannelRepository) UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error { + if m.updateModelPricingFn != nil { + return m.updateModelPricingFn(ctx, pricing) + } + return nil +} + +func (m *mockChannelRepository) DeleteModelPricing(ctx context.Context, id int64) error { + if m.deleteModelPricingFn != nil { + return m.deleteModelPricingFn(ctx, id) + } + return nil +} + +func (m *mockChannelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error { + if m.replaceModelPricingFn != nil { + return m.replaceModelPricingFn(ctx, channelID, pricingList) + } + return nil +} + +// --------------------------------------------------------------------------- +// Mock: APIKeyAuthCacheInvalidator +// --------------------------------------------------------------------------- + +type mockChannelAuthCacheInvalidator struct { + invalidatedGroupIDs []int64 + invalidatedKeys []string + invalidatedUserIDs []int64 +} + +func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByKey(_ context.Context, key string) { + m.invalidatedKeys = append(m.invalidatedKeys, key) +} + +func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) { + m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID) +} + +func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context.Context, groupID int64) { + m.invalidatedGroupIDs = append(m.invalidatedGroupIDs, groupID) +} + +// --------------------------------------------------------------------------- +// Helpers +// 
--------------------------------------------------------------------------- + +func newTestChannelService(repo *mockChannelRepository) *ChannelService { + return NewChannelService(repo, nil) +} + +func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService { + return NewChannelService(repo, auth) +} + +// makeStandardRepo returns a repo that serves one active channel with anthropic pricing +// for group 1, with the given model pricing and model mapping. +func makeStandardRepo(ch Channel, groupPlatforms map[int64]string) *mockChannelRepository { + return &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return []Channel{ch}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return groupPlatforms, nil + }, + } +} + +// =========================================================================== +// 1. BuildModelMappingChain +// =========================================================================== + +func TestBuildModelMappingChain(t *testing.T) { + tests := []struct { + name string + result ChannelMappingResult + requestModel string + upstreamModel string + want string + }{ + { + name: "no mapping, no upstream diff", + result: ChannelMappingResult{Mapped: false, MappedModel: "claude-sonnet-4"}, + requestModel: "claude-sonnet-4", + upstreamModel: "claude-sonnet-4", + want: "", + }, + { + name: "no mapping, upstream differs", + result: ChannelMappingResult{Mapped: false, MappedModel: "claude-sonnet-4"}, + requestModel: "claude-sonnet-4", + upstreamModel: "claude-sonnet-4-20250514", + want: "claude-sonnet-4\u2192claude-sonnet-4-20250514", + }, + { + name: "mapped, upstream differs", + result: ChannelMappingResult{Mapped: true, MappedModel: "claude-sonnet-4-20250514"}, + requestModel: "my-model", + upstreamModel: "actual-upstream", + want: "my-model\u2192claude-sonnet-4-20250514\u2192actual-upstream", + }, + { + name: "mapped, 
upstream same as mapped", + result: ChannelMappingResult{Mapped: true, MappedModel: "claude-sonnet-4-20250514"}, + requestModel: "claude-sonnet-4", + upstreamModel: "claude-sonnet-4-20250514", + want: "claude-sonnet-4\u2192claude-sonnet-4-20250514", + }, + { + name: "mapped, upstream empty", + result: ChannelMappingResult{Mapped: true, MappedModel: "target-model"}, + requestModel: "my-model", + upstreamModel: "", + want: "my-model\u2192target-model", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.result.BuildModelMappingChain(tt.requestModel, tt.upstreamModel) + require.Equal(t, tt.want, got) + }) + } +} + +// =========================================================================== +// 2. ReplaceModelInBody +// =========================================================================== + +func TestReplaceModelInBody(t *testing.T) { + tests := []struct { + name string + body []byte + newModel string + check func(t *testing.T, result []byte) + }{ + { + name: "empty body", + body: []byte{}, + newModel: "new-model", + check: func(t *testing.T, result []byte) { + require.Equal(t, []byte{}, result) + }, + }, + { + name: "model already equal", + body: []byte(`{"model":"claude-sonnet-4","temperature":0.7}`), + newModel: "claude-sonnet-4", + check: func(t *testing.T, result []byte) { + require.Equal(t, []byte(`{"model":"claude-sonnet-4","temperature":0.7}`), result) + }, + }, + { + name: "model different", + body: []byte(`{"model":"claude-sonnet-4","temperature":0.7}`), + newModel: "claude-opus-4", + check: func(t *testing.T, result []byte) { + require.Contains(t, string(result), `"model":"claude-opus-4"`) + require.Contains(t, string(result), `"temperature"`) + }, + }, + { + name: "no model field", + body: []byte(`{"temperature":0.7}`), + newModel: "claude-opus-4", + check: func(t *testing.T, result []byte) { + require.Contains(t, string(result), `"model":"claude-opus-4"`) + require.Contains(t, string(result), `"temperature"`) 
+ }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ReplaceModelInBody(tt.body, tt.newModel) + tt.check(t, result) + }) + } +} + +// =========================================================================== +// 3. validateNoConflictingModels + validateNoConflictingMappings +// =========================================================================== + +func TestValidateNoConflictingModels(t *testing.T) { + tests := []struct { + name string + pricingList []ChannelModelPricing + wantErr bool + errContains string + }{ + { + name: "no duplicates", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4", "claude-opus-4"}}, + {Platform: "openai", Models: []string{"gpt-5.1"}}, + }, + wantErr: false, + }, + { + name: "same platform duplicate", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, + {Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, + }, + wantErr: true, + errContains: "claude-sonnet-4", + }, + { + name: "same model different platform", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"model-a"}}, + {Platform: "openai", Models: []string{"model-a"}}, + }, + wantErr: false, + }, + { + name: "case insensitive", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"Claude"}}, + {Platform: "anthropic", Models: []string{"claude"}}, + }, + wantErr: true, + }, + { + name: "empty list (nil)", + pricingList: nil, + wantErr: false, + }, + { + name: "wildcard_vs_wildcard_conflict", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-*"}}, + {Platform: "anthropic", Models: []string{"claude-opus-*"}}, + }, + wantErr: true, + errContains: "conflict", + }, + { + name: "wildcard_vs_exact_conflict", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-*"}}, + {Platform: "anthropic", Models: 
[]string{"claude-opus-4-6"}}, + }, + wantErr: true, + errContains: "conflict", + }, + { + name: "no_conflict_different_platform", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-*"}}, + {Platform: "openai", Models: []string{"claude-*"}}, + }, + wantErr: false, + }, + { + name: "no_conflict_same_platform_different_prefix", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-*"}}, + {Platform: "anthropic", Models: []string{"gpt-*"}}, + }, + wantErr: false, + }, + { + name: "catch_all_wildcard_conflicts_with_everything", + pricingList: []ChannelModelPricing{ + {Platform: "openai", Models: []string{"*"}}, + {Platform: "openai", Models: []string{"gpt-5"}}, + }, + wantErr: true, + errContains: "conflict", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateNoConflictingModels(tt.pricingList) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + require.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } + + // Additional sub-case: explicit empty slice + t.Run("empty list (empty slice)", func(t *testing.T) { + err := validateNoConflictingModels([]ChannelModelPricing{}) + require.NoError(t, err) + }) +} + +func TestValidateNoConflictingMappings(t *testing.T) { + tests := []struct { + name string + mapping map[string]map[string]string + wantErr bool + errContains string + }{ + { + name: "nil mapping", + mapping: nil, + wantErr: false, + }, + { + name: "empty mapping", + mapping: map[string]map[string]string{}, + wantErr: false, + }, + { + name: "no conflict", + mapping: map[string]map[string]string{ + "anthropic": {"claude-opus-*": "opus", "gpt-*": "gpt"}, + }, + wantErr: false, + }, + { + name: "wildcard vs wildcard conflict", + mapping: map[string]map[string]string{ + "anthropic": {"claude-*": "a", "claude-opus-*": "b"}, + }, + wantErr: true, + errContains: "conflict", + }, + { + 
name: "wildcard vs exact conflict", + mapping: map[string]map[string]string{ + "openai": {"gpt-*": "a", "gpt-4o": "b"}, + }, + wantErr: true, + errContains: "conflict", + }, + { + name: "exact duplicate conflict", + mapping: map[string]map[string]string{ + "anthropic": {"claude-opus-4": "a"}, + "openai": {"claude-opus-4": "b"}, + }, + wantErr: false, // different platforms + }, + { + name: "different platforms no conflict", + mapping: map[string]map[string]string{ + "anthropic": {"claude-*": "a"}, + "openai": {"claude-*": "b"}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateNoConflictingMappings(tt.mapping) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + require.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestConflictsBetween(t *testing.T) { + tests := []struct { + name string + a, b modelEntry + want bool + }{ + { + name: "exact same", + a: modelEntry{prefix: "claude-opus-4", wildcard: false}, + b: modelEntry{prefix: "claude-opus-4", wildcard: false}, + want: true, + }, + { + name: "exact different", + a: modelEntry{prefix: "claude-opus-4", wildcard: false}, + b: modelEntry{prefix: "gpt-4o", wildcard: false}, + want: false, + }, + { + name: "wildcard matches exact", + a: modelEntry{prefix: "claude-", wildcard: true}, + b: modelEntry{prefix: "claude-opus-4", wildcard: false}, + want: true, + }, + { + name: "exact does not match unrelated wildcard", + a: modelEntry{prefix: "gpt-4o", wildcard: false}, + b: modelEntry{prefix: "claude-", wildcard: true}, + want: false, + }, + { + name: "wildcard prefix overlap", + a: modelEntry{prefix: "claude-", wildcard: true}, + b: modelEntry{prefix: "claude-opus-", wildcard: true}, + want: true, + }, + { + name: "wildcards no overlap", + a: modelEntry{prefix: "claude-", wildcard: true}, + b: modelEntry{prefix: "gpt-", wildcard: true}, + want: false, + }, + { + name: "catch-all 
wildcard vs any",
			a:    modelEntry{prefix: "", wildcard: true},
			b:    modelEntry{prefix: "anything", wildcard: false},
			want: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			require.Equal(t, tt.want, conflictsBetween(tt.a, tt.b))
		})
	}
}

// ===========================================================================
// 4. Cache Building + Hot Path Methods
// ===========================================================================

// --- 4.1 GetChannelForGroup ---

func TestGetChannelForGroup_Success(t *testing.T) {
	ch := Channel{
		ID:       1,
		Name:     "test-channel",
		Status:   StatusActive,
		GroupIDs: []int64{10},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result, err := svc.GetChannelForGroup(context.Background(), 10)
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, int64(1), result.ID)
	require.Equal(t, "test-channel", result.Name)

	// returned value should be a clone
	result.Name = "mutated"
	result2, err := svc.GetChannelForGroup(context.Background(), 10)
	require.NoError(t, err)
	require.Equal(t, "test-channel", result2.Name)
}

func TestGetChannelForGroup_InactiveChannel(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusDisabled,
		GroupIDs: []int64{10},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result, err := svc.GetChannelForGroup(context.Background(), 10)
	require.NoError(t, err)
	require.Nil(t, result)
}

func TestGetChannelForGroup_NoChannel(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result, err := svc.GetChannelForGroup(context.Background(), 999)
	require.NoError(t, err)
	require.Nil(t, result)
}

func TestGetChannelForGroup_CacheError(t *testing.T) {
	repo := &mockChannelRepository{
		listAllFn: func(_ context.Context) ([]Channel, error) {
			return nil, errors.New("db connection failed")
		},
	}
	svc := newTestChannelService(repo)

	result, err := svc.GetChannelForGroup(context.Background(), 10)
	require.Error(t, err)
	require.Nil(t, result)
	require.Contains(t, err.Error(), "db connection failed")
}

// --- 4.2 GetChannelModelPricing ---

func TestGetChannelModelPricing_ExactMatch(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.NotNil(t, result)
	require.Equal(t, int64(100), result.ID)
	require.InDelta(t, 15e-6, *result.InputPrice, 1e-12)
}

func TestGetChannelModelPricing_CaseInsensitive(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.GetChannelModelPricing(context.Background(), 10, "Claude-Opus-4")
	require.NotNil(t, result)
	require.Equal(t, int64(100), result.ID)
}

func TestGetChannelModelPricing_WildcardMatch(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 200, Platform: "anthropic", Models: []string{"claude-*"}, InputPrice: testPtrFloat64(10e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4")
	require.NotNil(t, result)
	require.Equal(t, int64(200), result.ID)
}

func TestGetChannelModelPricing_WildcardFirstMatch(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 200, Platform: "anthropic", Models: []string{"claude-*"}, InputPrice: testPtrFloat64(10e-6)},
			{ID: 300, Platform: "anthropic", Models: []string{"claude-sonnet-*"}, InputPrice: testPtrFloat64(5e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4-20250514")
	require.NotNil(t, result)
	// "claude-*" is defined first, so it matches first regardless of prefix length
	require.Equal(t, int64(200), result.ID)
	require.InDelta(t, 10e-6, *result.InputPrice, 1e-12)
}

func TestGetChannelModelPricing_NoMatch(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1")
	require.Nil(t, result)
}

func TestGetChannelModelPricing_InactiveChannel(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusDisabled,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.Nil(t, result)
}

func TestGetChannelModelPricing_PlatformFiltering(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10, 20},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "openai", Models: []string{"gpt-5.1"}, InputPrice: testPtrFloat64(5e-6)},
			{ID: 200, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic", 20: "openai"})
	svc := newTestChannelService(repo)

	// Group 10 (anthropic) should NOT see openai pricing
	result := svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1")
	require.Nil(t, result)

	// Group 10 (anthropic) should see anthropic pricing
	result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.NotNil(t, result)
	require.Equal(t, int64(200), result.ID)

	// Group 20 (openai) should see openai pricing
	result = svc.GetChannelModelPricing(context.Background(), 20, "gpt-5.1")
	require.NotNil(t, result)
	require.Equal(t, int64(100), result.ID)

	// Group 20 (openai) should NOT see anthropic pricing
	result = svc.GetChannelModelPricing(context.Background(), 20, "claude-opus-4")
	require.Nil(t, result)
}

func TestGetChannelModelPricing_ReturnsCopy(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.NotNil(t, result)

	// Mutate the returned pricing's slice fields — original cache should not be affected
	// (Clone copies slices independently, pointer fields are shared per design)
	result.Models = append(result.Models, "hacked")
	result.ID = 999

	// Original cache should not be affected (slice independence + struct copy)
	result2 := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.NotNil(t, result2)
	require.Equal(t, 1, len(result2.Models))
	require.Equal(t, int64(100), result2.ID)
}

// --- 4.3 ResolveChannelMapping ---

func TestResolveChannelMapping_NoChannel(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	// Group 999 is not in any channel
	result := svc.ResolveChannelMapping(context.Background(), 999, "claude-opus-4")
	require.Equal(t, "claude-opus-4", result.MappedModel)
	require.False(t, result.Mapped)
	require.Equal(t, int64(0), result.ChannelID)
}

func TestResolveChannelMapping_ExactMapping(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelMapping: map[string]map[string]string{
			"anthropic": {
				"claude-sonnet-4": "claude-sonnet-4-20250514",
			},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4")
	require.True(t, result.Mapped)
	require.Equal(t, "claude-sonnet-4-20250514", result.MappedModel)
	require.Equal(t, int64(1), result.ChannelID)
}

func TestResolveChannelMapping_WildcardMapping(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelMapping: map[string]map[string]string{
			"anthropic": {
				"*": "gpt-5.4",
			},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.ResolveChannelMapping(context.Background(), 10, "any-model-name")
	require.True(t, result.Mapped)
	require.Equal(t, "gpt-5.4", result.MappedModel)
}

func TestResolveChannelMapping_WildcardFirstMatch(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelMapping: map[string]map[string]string{
			"anthropic": {
				"claude-*":        "target2",
				"claude-sonnet-*": "target1",
			},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4")
	require.True(t, result.Mapped)
	// map iteration order is non-deterministic, so the first-match depends on
	// insertion order which Go maps don't guarantee; verify that one of the
	// wildcard targets matched
	require.Contains(t, []string{"target1", "target2"}, result.MappedModel)
}

func TestResolveChannelMapping_NoMapping(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelMapping: map[string]map[string]string{
			"anthropic": {
				"claude-sonnet-4": "mapped",
			},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
	require.False(t, result.Mapped)
	require.Equal(t, "claude-opus-4", result.MappedModel)
	require.Equal(t, int64(1), result.ChannelID)
}

func TestResolveChannelMapping_DefaultBillingModelSource(t *testing.T) {
	ch := Channel{
		ID:                 1,
		Status:             StatusActive,
		GroupIDs:           []int64{10},
		BillingModelSource: "", // empty
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
	require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
}

func TestResolveChannelMapping_UpstreamBillingModelSource(t *testing.T) {
	ch := Channel{
		ID:                 1,
		Status:             StatusActive,
		GroupIDs:           []int64{10},
		BillingModelSource: BillingModelSourceUpstream,
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
	require.Equal(t, BillingModelSourceUpstream, result.BillingModelSource)
}

func TestResolveChannelMapping_InactiveChannel(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusDisabled,
		GroupIDs: []int64{10},
		ModelMapping: map[string]map[string]string{
			"anthropic": {
				"claude-sonnet-4": "mapped",
			},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4")
	require.False(t, result.Mapped)
	require.Equal(t, "claude-sonnet-4", result.MappedModel)
	require.Equal(t, int64(0), result.ChannelID) // no channel
}

// --- 4.4 IsModelRestricted ---

func TestIsModelRestricted_NoChannel(t *testing.T) {
	ch := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{10},
		RestrictModels: true,
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	// Group 999 is not in any channel
	restricted := svc.IsModelRestricted(context.Background(), 999, "claude-opus-4")
	require.False(t, restricted)
}

func TestIsModelRestricted_RestrictDisabled(t *testing.T) {
	ch := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{10},
		RestrictModels: false,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	// Even though model is not in pricing, RestrictModels=false
	restricted := svc.IsModelRestricted(context.Background(), 10, "nonexistent-model")
	require.False(t, restricted)
}

func TestIsModelRestricted_InactiveChannel(t *testing.T) {
	ch := Channel{
		ID:             1,
		Status:         StatusDisabled,
		GroupIDs:       []int64{10},
		RestrictModels: true,
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	restricted := svc.IsModelRestricted(context.Background(), 10, "any-model")
	require.False(t, restricted)
}

func TestIsModelRestricted_ModelInPricing(t *testing.T) {
	ch := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{10},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4", "claude-sonnet-4"}},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	restricted := svc.IsModelRestricted(context.Background(), 10, "claude-opus-4")
	require.False(t, restricted)
}

func TestIsModelRestricted_ModelInWildcard(t *testing.T) {
	ch := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{10},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-*"}},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	restricted := svc.IsModelRestricted(context.Background(), 10, "claude-sonnet-4")
	require.False(t, restricted)
}

func TestIsModelRestricted_ModelNotFound(t *testing.T) {
	ch := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{10},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	restricted := svc.IsModelRestricted(context.Background(), 10, "gpt-5.1")
	require.True(t, restricted)
}

func TestIsModelRestricted_CaseInsensitive(t *testing.T) {
	ch := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{10},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	restricted := svc.IsModelRestricted(context.Background(), 10, "Claude-Opus-4")
	require.False(t, restricted)
}

// --- 4.5 ResolveChannelMappingAndRestrict ---
// NOTE: model-restriction checks have moved to the scheduling phase
// (GatewayService.checkChannelPricingRestriction);
// ResolveChannelMappingAndRestrict only performs mapping, so the returned
// restricted flag is always false.

func TestResolveChannelMappingAndRestrict_NilGroupID(t *testing.T) {
	repo := &mockChannelRepository{
		listAllFn: func(_ context.Context) ([]Channel, error) {
			return nil, nil
		},
	}
	svc := newTestChannelService(repo)

	mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), nil, "claude-opus-4")
	require.False(t, restricted)
	require.False(t, mapping.Mapped)
	require.Equal(t, "claude-opus-4", mapping.MappedModel)
}

func TestResolveChannelMappingAndRestrict_WithMapping(t *testing.T) {
	ch := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{10},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
		},
		ModelMapping: map[string]map[string]string{
			"anthropic": {
				"claude-sonnet-4": "claude-sonnet-4-20250514",
			},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	gid := int64(10)
	mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "claude-sonnet-4")
	require.False(t, restricted) // restricted is always false; restriction is checked at scheduling time
	require.True(t, mapping.Mapped)
	require.Equal(t, "claude-sonnet-4-20250514", mapping.MappedModel)
}

func TestResolveChannelMappingAndRestrict_NoMapping(t *testing.T) {
	ch := Channel{
		ID:             1,
		Status:         StatusActive,
		GroupIDs:       []int64{10},
		RestrictModels: true,
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	gid := int64(10)
	mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "unknown-model")
	require.False(t, restricted) // restricted is always false; restriction is checked at scheduling time
	require.False(t, mapping.Mapped)
	require.Equal(t, "unknown-model", mapping.MappedModel)
}

// --- 4.6 Cache Building Specifics ---

func TestBuildCache_DBError(t *testing.T) {
	callCount := 0
	repo := &mockChannelRepository{
		listAllFn: func(_ context.Context) ([]Channel, error) {
			callCount++
			return nil, errors.New("database down")
		},
	}
	svc := newTestChannelService(repo)

	// First call should fail
	_, err := svc.GetChannelForGroup(context.Background(), 10)
	require.Error(t, err)
	require.Contains(t, err.Error(), "database down")
	require.Equal(t, 1, callCount)

	// Second call within the error TTL: buildCache stored an empty cache
	// entry with a short error TTL (a "don't retry immediately" mechanism)
	// and returned the error via singleflight. loadCache now serves that
	// empty-but-valid entry instead of retrying the DB, so GetChannelForGroup
	// finds nothing and returns nil, nil.
	result, err := svc.GetChannelForGroup(context.Background(), 10)
	require.NoError(t, err)
	require.Nil(t, result)
	// Should NOT have hit DB again (error-TTL cache is active)
	require.Equal(t, 1, callCount)
}

func TestBuildCache_GroupPlatformError(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	repo := &mockChannelRepository{
		listAllFn: func(_ context.Context) ([]Channel, error) {
			return []Channel{ch}, nil
		},
		getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
			return nil, errors.New("group platforms failed")
		},
	}
	svc := newTestChannelService(repo)

	// Should fail-close: error propagated when group platforms cannot be loaded
	result, err := svc.GetChannelForGroup(context.Background(), 10)
	require.Error(t, err)
	require.Nil(t, result)

	// Within error-TTL, second call should hit cache (empty) and return nil, nil
	result2, err2 := svc.GetChannelForGroup(context.Background(), 10)
	require.NoError(t, err2)
	require.Nil(t, result2)
}

func TestBuildCache_MultipleGroupsSameChannel(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10, 20, 30},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{
		10: "anthropic",
		20: "anthropic",
		30: "anthropic",
	})
	svc := newTestChannelService(repo)

	for _, gid := range []int64{10, 20, 30} {
		result := svc.GetChannelModelPricing(context.Background(), gid, "claude-opus-4")
		require.NotNil(t, result, "group %d should have pricing", gid)
		require.Equal(t, int64(100), result.ID)
	}
}

func TestBuildCache_PlatformFiltering(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10, 20},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
			{ID: 200, Platform: "openai", Models: []string{"gpt-5.1"}},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{
		10: "anthropic",
		20: "openai",
	})
	svc := newTestChannelService(repo)

	// anthropic group sees only anthropic models
	require.NotNil(t, svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4"))
	require.Nil(t, svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1"))

	// openai group sees only openai models
	require.NotNil(t, svc.GetChannelModelPricing(context.Background(), 20, "gpt-5.1"))
	require.Nil(t, svc.GetChannelModelPricing(context.Background(), 20, "claude-opus-4"))
}

func TestBuildCache_WildcardPreservesConfigOrder(t *testing.T) {
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			// Configuration order: shortest prefix first
			{ID: 100, Platform: "anthropic", Models: []string{"c-*"}, InputPrice: testPtrFloat64(1e-6)},
			{ID: 200, Platform: "anthropic", Models: []string{"c-son-*"}, InputPrice: testPtrFloat64(2e-6)},
			{ID: 300, Platform: "anthropic", Models: []string{"c-son-4-*"}, InputPrice: testPtrFloat64(3e-6)},
		},
	}
	repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
	svc := newTestChannelService(repo)

	// "c-son-4-xxx" matches all three wildcards, but "c-*" (ID=100) is first in config
	result := svc.GetChannelModelPricing(context.Background(), 10, "c-son-4-xxx")
	require.NotNil(t, result)
	require.Equal(t, int64(100), result.ID)

	// "c-son-yyy" matches "c-*" and "c-son-*", but "c-*" (ID=100) is first
	result = svc.GetChannelModelPricing(context.Background(), 10, "c-son-yyy")
	require.NotNil(t, result)
	require.Equal(t, int64(100), result.ID)

	// "c-other" only matches "c-*" (ID=100)
	result = svc.GetChannelModelPricing(context.Background(), 10, "c-other")
	require.NotNil(t, result)
	require.Equal(t, int64(100), result.ID)
}

// --- 4.7 invalidateCache ---

func TestInvalidateCache(t *testing.T) {
	callCount := 0
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	repo := &mockChannelRepository{
		listAllFn: func(_ context.Context) ([]Channel, error) {
			callCount++
			return []Channel{ch}, nil
		},
		getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
			return map[int64]string{10: "anthropic"}, nil
		},
	}
	svc := newTestChannelService(repo)

	// First load
	result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.NotNil(t, result)
	require.Equal(t, 1, callCount)

	// Second call should use cache
	result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.NotNil(t, result)
	require.Equal(t, 1, callCount) // no new DB call

	// Invalidate
	svc.invalidateCache()

	// Next call should rebuild from DB
	result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.NotNil(t, result)
	require.Equal(t, 2, callCount) // rebuilt
}

// ===========================================================================
// 5. CRUD Methods
// ===========================================================================

// --- 5.1 Create ---

func TestCreate_Success(t *testing.T) {
	createdID := int64(42)
	repo := &mockChannelRepository{
		existsByNameFn: func(_ context.Context, _ string) (bool, error) {
			return false, nil
		},
		getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) {
			return nil, nil
		},
		createFn: func(_ context.Context, ch *Channel) error {
			ch.ID = createdID
			return nil
		},
		getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
			return &Channel{ID: id, Name: "new-channel", Status: StatusActive}, nil
		},
		listAllFn: func(_ context.Context) ([]Channel, error) {
			return nil, nil
		},
	}
	svc := newTestChannelService(repo)

	result, err := svc.Create(context.Background(), &CreateChannelInput{
		Name:     "new-channel",
		GroupIDs: []int64{10},
	})
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, createdID, result.ID)
}

func TestCreate_NameExists(t *testing.T) {
	repo := &mockChannelRepository{
		existsByNameFn: func(_ context.Context, _ string) (bool, error) {
			return true, nil
		},
	}
	svc := newTestChannelService(repo)

	_, err := svc.Create(context.Background(), &CreateChannelInput{
		Name: "existing-channel",
	})
	require.Error(t, err)
	require.ErrorIs(t, err, ErrChannelExists)
}

func TestCreate_GroupConflict(t *testing.T) {
	repo := &mockChannelRepository{
		existsByNameFn: func(_ context.Context, _ string) (bool, error) {
			return false, nil
		},
		getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) {
			return []int64{10}, nil // group 10 already in another channel
		},
	}
	svc := newTestChannelService(repo)

	_, err := svc.Create(context.Background(), &CreateChannelInput{
		Name:     "new-channel",
		GroupIDs: []int64{10, 20},
	})
	require.Error(t, err)
	require.ErrorIs(t, err, ErrGroupAlreadyInChannel)
}

func TestCreate_DuplicateModel(t *testing.T) {
	repo := &mockChannelRepository{
		existsByNameFn: func(_ context.Context, _ string) (bool, error) {
			return false, nil
		},
	}
	svc := newTestChannelService(repo)

	_, err := svc.Create(context.Background(), &CreateChannelInput{
		Name: "new-channel",
		ModelPricing: []ChannelModelPricing{
			{Platform: "anthropic", Models: []string{"claude-opus-4"}},
			{Platform: "anthropic", Models: []string{"claude-opus-4"}}, // duplicate
		},
	})
	require.Error(t, err)
	require.Contains(t, err.Error(), "claude-opus-4")
}

func TestCreate_InvalidPricingIntervals(t *testing.T) {
	repo := &mockChannelRepository{
		existsByNameFn: func(_ context.Context, _ string) (bool, error) {
			return false, nil
		},
	}
	svc := newTestChannelService(repo)

	_, err := svc.Create(context.Background(), &CreateChannelInput{
		Name: "new-channel",
		ModelPricing: []ChannelModelPricing{
			{
				Platform: "anthropic",
				Models:   []string{"claude-opus-4"},
				Intervals: []PricingInterval{
					{MinTokens: 0, MaxTokens: testPtrInt(2000), InputPrice: testPtrFloat64(1e-6)},
					{MinTokens: 1000, MaxTokens: testPtrInt(3000), InputPrice: testPtrFloat64(2e-6)},
				},
			},
		},
	})
	require.Error(t, err)
	require.Contains(t, err.Error(), "INVALID_PRICING_INTERVALS")
	require.Contains(t, err.Error(), "overlap")
}

func TestCreate_DefaultBillingModelSource(t *testing.T) {
	var capturedChannel *Channel
	repo := &mockChannelRepository{
		existsByNameFn: func(_ context.Context, _ string) (bool, error) {
			return false, nil
		},
		createFn: func(_ context.Context, ch *Channel) error {
			capturedChannel = ch
			ch.ID = 1
			return nil
		},
		getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
			return capturedChannel, nil
		},
		listAllFn: func(_ context.Context) ([]Channel, error) {
			return nil, nil
		},
	}
	svc := newTestChannelService(repo)

	result, err := svc.Create(context.Background(), &CreateChannelInput{
		Name:               "new-channel",
		BillingModelSource: "", // empty, should default to "channel_mapped"
	})
	require.NoError(t, err)
	require.NotNil(t, result)
	require.Equal(t, BillingModelSourceChannelMapped, result.BillingModelSource)
}

func TestCreate_InvalidatesCache(t *testing.T) {
	loadCount := 0
	ch := Channel{
		ID:       1,
		Status:   StatusActive,
		GroupIDs: []int64{10},
		ModelPricing: []ChannelModelPricing{
			{ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}},
		},
	}
	repo := &mockChannelRepository{
		listAllFn: func(_ context.Context) ([]Channel, error) {
			loadCount++
			return []Channel{ch}, nil
		},
		getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
			return map[int64]string{10: "anthropic"}, nil
		},
		existsByNameFn: func(_ context.Context, _ string) (bool, error) {
			return false, nil
		},
		createFn: func(_ context.Context, c *Channel) error {
			c.ID = 2
			return nil
		},
		getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
			return &Channel{ID: id, Name: "new", Status: StatusActive}, nil
		},
	}
	svc := newTestChannelService(repo)

	// Load cache
	_ = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.Equal(t, 1, loadCount)

	// Create triggers cache invalidation
	_, err := svc.Create(context.Background(), &CreateChannelInput{Name: "new"})
	require.NoError(t, err)

	// Next cache access should rebuild
	_ = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")
	require.Equal(t, 2, loadCount)
}

// --- 5.2 Update ---

func TestUpdate_Success(t *testing.T) {
	existing := &Channel{
		ID:     1,
		Name:   "original",
		Status: StatusActive,
	}
	repo := &mockChannelRepository{
		getByIDFn: func(_ context.Context, id int64) (*Channel, error) {
			return existing.Clone(), nil
		},
		updateFn: func(_ context.Context, _ *Channel) error {
			return nil
		},
		getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
			return nil, nil
		},
		listAllFn: func(_ context.Context) ([]Channel, error) {
			return nil, nil
		},
	}
	svc := newTestChannelService(repo)

	result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
		Name:        "updated-name",
		Description: testPtrString("new desc"),
	})
	require.NoError(t, err)
	require.NotNil(t, result)
}

func TestUpdate_NotFound(t *testing.T) {
	repo := &mockChannelRepository{
		getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
			return nil, ErrChannelNotFound
		},
	}
	svc := newTestChannelService(repo)

	_, err := svc.Update(context.Background(), 999, &UpdateChannelInput{
		Name: "whatever",
	})
	require.Error(t, err)
	require.Contains(t, err.Error(), "channel")
}

func TestUpdate_NameConflict(t *testing.T) {
	existing := &Channel{
		ID:     1,
		Name:   "original",
		Status: StatusActive,
	}
	repo := &mockChannelRepository{
		getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
			return existing.Clone(), nil
		},
		existsByNameExcludingFn: func(_ context.Context, _ string, _ int64) (bool, error) {
			return true, nil // name conflicts with another channel
		},
	}
	svc := newTestChannelService(repo)

	_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
		Name: "conflicting-name",
	})
	require.Error(t, err)
	require.ErrorIs(t, err, ErrChannelExists)
}

func TestUpdate_GroupConflict(t *testing.T) {
	existing := &Channel{
		ID:     1,
		Name:   "original",
		Status: StatusActive,
	}
	repo := &mockChannelRepository{
		getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
			return existing.Clone(), nil
		},
		getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) {
			return []int64{20}, nil // group 20 in another channel
		},
	}
	svc := newTestChannelService(repo)

	newGroupIDs := []int64{10, 20}
	_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
		GroupIDs: &newGroupIDs,
	})
	require.Error(t, err)
	require.ErrorIs(t, err, ErrGroupAlreadyInChannel)
}

func TestUpdate_DuplicateModel(t *testing.T) {
	existing := &Channel{
		ID:     1,
		Name:   "original",
		Status: StatusActive,
	}
	repo := &mockChannelRepository{
		getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
			return existing.Clone(), nil
		},
	}
	svc := newTestChannelService(repo)

	dupPricing := []ChannelModelPricing{
		{Platform: "anthropic", Models: []string{"claude-opus-4"}},
		{Platform: "anthropic", Models: []string{"claude-opus-4"}},
	}
	_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
		ModelPricing: &dupPricing,
	})
	require.Error(t, err)
	require.Contains(t, err.Error(), "claude-opus-4")
}

func TestUpdate_InvalidPricingIntervals(t *testing.T) {
	existing := &Channel{
		ID:     1,
		Name:   "original",
		Status: StatusActive,
	}
	repo := &mockChannelRepository{
		getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
			return existing.Clone(), nil
		},
	}
	svc := newTestChannelService(repo)

	invalidPricing := []ChannelModelPricing{
		{
			Platform: "anthropic",
			Models:   []string{"claude-opus-4"},
			Intervals: []PricingInterval{
				{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
				{MinTokens: 2000, MaxTokens: testPtrInt(4000), InputPrice: testPtrFloat64(2e-6)},
			},
		},
	}
	_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
		ModelPricing: &invalidPricing,
	})
	require.Error(t, err)
	require.Contains(t, err.Error(), "INVALID_PRICING_INTERVALS")
	require.Contains(t, err.Error(), "unbounded")
}

func TestUpdate_InvalidatesChannelCache(t *testing.T) {
	existing := &Channel{
		ID:     1,
		Name:   "original",
		Status: StatusActive,
	}
	loadCount := 0
	repo := &mockChannelRepository{
		getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
			return existing.Clone(), nil
		},
		updateFn: func(_ context.Context, _ *Channel) error {
			return nil
		},
		getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
			return []int64{10, 20}, nil
		},
		listAllFn: func(_ context.Context) ([]Channel, error) {
			loadCount++
			return []Channel{*existing}, nil
		},
		getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
			return nil, nil
		},
	}
	svc := newTestChannelService(repo)

	// Load cache first
	_, _ = svc.GetChannelForGroup(context.Background(), 10)
	require.Equal(t, 1, loadCount)

	result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
		Description: testPtrString("updated"),
	})
	require.NoError(t, err)
	require.NotNil(t, result)

	// Channel cache should be invalidated (next access rebuilds)
	_, _ = svc.GetChannelForGroup(context.Background(), 10)
	require.Equal(t, 2, loadCount)
}

func TestUpdate_InvalidatesAuthCache(t *testing.T) {
	existing := &Channel{
		ID:     1,
		Name:   "original",
		Status: StatusActive,
	}
	auth := &mockChannelAuthCacheInvalidator{}
	repo := &mockChannelRepository{
		getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
			return existing.Clone(), nil
		},
		updateFn: func(_ context.Context, _ *Channel) error {
			return nil
		},
		getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
			return []int64{10, 20}, nil
		},
		listAllFn: func(_ context.Context) ([]Channel, error) {
			return nil, nil
		},
	}
	svc := newTestChannelServiceWithAuth(repo, auth)

	result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
		Description: testPtrString("updated"),
	})
	require.NoError(t, err)
	require.NotNil(t, result)

	// Auth cache should be invalidated for both groups
	require.ElementsMatch(t, []int64{10, 20}, auth.invalidatedGroupIDs)
}

// --- 5.3 Delete ---

func TestChannelDelete_Success(t *testing.T) {
	deleted := false
	repo := &mockChannelRepository{
		getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
			return nil, nil
		},
		deleteFn: func(_ context.Context, _ int64) error {
			deleted = true
			return nil
		},
		listAllFn: func(_ context.Context) ([]Channel, error) {
			return nil, nil
		},
	}
	svc := newTestChannelService(repo)

	err := svc.Delete(context.Background(), 1)
	require.NoError(t, err)
	require.True(t, deleted)
}

func TestChannelDelete_InvalidatesCaches(t *testing.T) {
	auth := &mockChannelAuthCacheInvalidator{}
	loadCount := 0
	repo := &mockChannelRepository{
		getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
			return []int64{10, 20}, nil
		},
		deleteFn: func(_ context.Context, _ int64) error {
			return nil
		},
		listAllFn: func(_ context.Context) ([]Channel, error) {
			loadCount++
			return []Channel{{ID: 1, Status: StatusActive, GroupIDs: []int64{10, 20}}}, nil
		},
		getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) {
			return nil, nil
		},
	}
	svc := newTestChannelServiceWithAuth(repo, auth)

	// Load cache first
	_, _ = svc.GetChannelForGroup(context.Background(), 10)
	require.Equal(t, 1, loadCount)

	err := svc.Delete(context.Background(), 1)
	require.NoError(t, err)

	// Auth cache invalidated for both groups
	require.ElementsMatch(t, []int64{10, 20}, auth.invalidatedGroupIDs)

	// Channel cache invalidated
	_, _ = svc.GetChannelForGroup(context.Background(), 10)
	require.Equal(t, 2, loadCount)
}

func TestChannelDelete_NotFound(t *testing.T) {
	repo := &mockChannelRepository{
		getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) {
			return nil, nil
		},
		deleteFn: func(_ context.Context, _ int64) error {
			return errors.New("record not found")
		},
	}
	svc := newTestChannelService(repo)

	err := svc.Delete(context.Background(), 999)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not found")
}

// ===========================================================================
// 6.
Edge Case Tests +// =========================================================================== + +// --- 6.1 Create with empty GroupIDs --- + +func TestCreate_NoGroups(t *testing.T) { + createdID := int64(55) + getGroupsInOtherChannelsCalled := false + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) { + getGroupsInOtherChannelsCalled = true + return nil, nil + }, + createFn: func(_ context.Context, ch *Channel) error { + ch.ID = createdID + return nil + }, + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return &Channel{ID: id, Name: "no-groups-channel", Status: StatusActive}, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + result, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "no-groups-channel", + GroupIDs: []int64{}, // empty slice + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, createdID, result.ID) + // GetGroupsInOtherChannels should NOT have been called (skipped by len(input.GroupIDs) > 0) + require.False(t, getGroupsInOtherChannelsCalled) +} + +// --- 6.2 Update only Status --- + +func TestUpdate_StatusOnly(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "test-channel", + Status: StatusActive, + } + var capturedChannel *Channel + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return existing.Clone(), nil + }, + updateFn: func(_ context.Context, ch *Channel) error { + capturedChannel = ch + return nil + }, + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return nil, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + result, err := svc.Update(context.Background(), 1, 
&UpdateChannelInput{ + Status: StatusDisabled, + }) + require.NoError(t, err) + require.NotNil(t, result) + // Verify that the channel passed to repo.Update has the new status + require.NotNil(t, capturedChannel) + require.Equal(t, StatusDisabled, capturedChannel.Status) + // Name should remain unchanged + require.Equal(t, "test-channel", capturedChannel.Name) +} + +// --- 6.3 Delete when GetGroupIDs fails --- + +func TestChannelDelete_GetGroupIDsError(t *testing.T) { + deleted := false + repo := &mockChannelRepository{ + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return nil, errors.New("group IDs lookup failed") + }, + deleteFn: func(_ context.Context, _ int64) error { + deleted = true + return nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + // Delete should still succeed even though GetGroupIDs returned error (degradation path L588-591) + err := svc.Delete(context.Background(), 1) + require.NoError(t, err) + require.True(t, deleted) +} + +// --- 6.4 ReplaceModelInBody with invalid JSON --- + +func TestReplaceModelInBody_InvalidJSON(t *testing.T) { + // Case 1: broken JSON object — gjson won't find "model", sjson does best-effort set + // (no panic, no error from sjson, but result is mutated garbage) + brokenBody := []byte("{broken") + result := ReplaceModelInBody(brokenBody, "new-model") + require.NotNil(t, result) + // sjson does not error on this input, so result differs from original — just verify no panic + + // Case 2: JSON array — sjson.SetBytes returns error on non-object, + // triggering the L447 error fallback path that returns original body. + arrayBody := []byte("[]") + result2 := ReplaceModelInBody(arrayBody, "new-model") + require.Equal(t, arrayBody, result2) +} + +// =========================================================================== +// 7. 
isPlatformPricingMatch +// =========================================================================== + +func TestIsPlatformPricingMatch(t *testing.T) { + tests := []struct { + name string + groupPlatform string + pricingPlatform string + want bool + }{ + {"antigravity matches anthropic", PlatformAntigravity, PlatformAnthropic, true}, + {"antigravity matches gemini", PlatformAntigravity, PlatformGemini, true}, + {"antigravity matches antigravity", PlatformAntigravity, PlatformAntigravity, true}, + {"antigravity does NOT match openai", PlatformAntigravity, PlatformOpenAI, false}, + {"anthropic matches anthropic", PlatformAnthropic, PlatformAnthropic, true}, + {"anthropic does NOT match antigravity", PlatformAnthropic, PlatformAntigravity, false}, + {"anthropic does NOT match gemini", PlatformAnthropic, PlatformGemini, false}, + {"gemini matches gemini", PlatformGemini, PlatformGemini, true}, + {"gemini does NOT match antigravity", PlatformGemini, PlatformAntigravity, false}, + {"gemini does NOT match anthropic", PlatformGemini, PlatformAnthropic, false}, + {"empty string matches nothing", "", PlatformAnthropic, false}, + {"empty string matches empty", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, isPlatformPricingMatch(tt.groupPlatform, tt.pricingPlatform)) + }) + } +} + +// =========================================================================== +// 8. 
matchingPlatforms +// =========================================================================== + +func TestMatchingPlatforms(t *testing.T) { + tests := []struct { + name string + groupPlatform string + want []string + }{ + {"antigravity returns all three", PlatformAntigravity, []string{PlatformAntigravity, PlatformAnthropic, PlatformGemini}}, + {"anthropic returns itself", PlatformAnthropic, []string{PlatformAnthropic}}, + {"gemini returns itself", PlatformGemini, []string{PlatformGemini}}, + {"openai returns itself", PlatformOpenAI, []string{PlatformOpenAI}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := matchingPlatforms(tt.groupPlatform) + require.Equal(t, tt.want, result) + }) + } +} + +// =========================================================================== +// 9. Antigravity cross-platform channel pricing +// =========================================================================== + +func TestGetChannelModelPricing_AntigravityCrossPlatform(t *testing.T) { + // Channel has anthropic pricing for claude-opus-4-6. + // Group 10 is antigravity — should see the anthropic pricing. + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: PlatformAnthropic, Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4-6") + require.NotNil(t, result, "antigravity group should see anthropic pricing") + require.Equal(t, int64(100), result.ID) + require.InDelta(t, 15e-6, *result.InputPrice, 1e-12) +} + +func TestGetChannelModelPricing_AnthropicCannotSeeAntigravityPricing(t *testing.T) { + // Channel has antigravity-platform pricing for claude-opus-4-6. 
+ // Group 10 is anthropic — should NOT see antigravity pricing (no cross-platform leakage). + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: PlatformAntigravity, Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAnthropic}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4-6") + require.Nil(t, result, "anthropic group should NOT see antigravity-platform pricing") +} + +// =========================================================================== +// 10. Antigravity cross-platform model mapping +// =========================================================================== + +func TestResolveChannelMapping_AntigravityCrossPlatform(t *testing.T) { + // Channel has anthropic model mapping: claude-opus-4-5 → claude-opus-4-6. + // Group 10 is antigravity — should apply the anthropic mapping. + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + PlatformAnthropic: { + "claude-opus-4-5": "claude-opus-4-6", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4-5") + require.True(t, result.Mapped, "antigravity group should apply anthropic mapping") + require.Equal(t, "claude-opus-4-6", result.MappedModel) + require.Equal(t, int64(1), result.ChannelID) +} + +// =========================================================================== +// 11. 
Antigravity cross-platform same-name model — no overwrite +// =========================================================================== + +func TestGetChannelModelPricing_AntigravitySameModelDifferentPlatforms(t *testing.T) { + // anthropic 和 gemini 都定义了同名模型 "shared-model",价格不同。 + // antigravity 分组应能分别查到各自的定价,而不是后者覆盖前者。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 200, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 201, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + // antigravity 分组查找 "shared-model":应命中第一个匹配(按 matchingPlatforms 顺序 antigravity→anthropic→gemini) + result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model") + require.NotNil(t, result, "antigravity group should find pricing for shared-model") + // 第一个匹配应该是 anthropic(matchingPlatforms 返回 [antigravity, anthropic, gemini]) + require.Equal(t, int64(200), result.ID) + require.InDelta(t, 10e-6, *result.InputPrice, 1e-12) +} + +func TestGetChannelModelPricing_AntigravityOnlyGeminiPricing(t *testing.T) { + // 只有 gemini 平台定义了模型 "gemini-model"。 + // antigravity 分组应能查到 gemini 的定价。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 300, Platform: PlatformGemini, Models: []string{"gemini-model"}, InputPrice: testPtrFloat64(2e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "gemini-model") + require.NotNil(t, result, "antigravity group should find gemini pricing") + require.Equal(t, int64(300), result.ID) + require.InDelta(t, 2e-6, *result.InputPrice, 1e-12) +} + +func 
TestGetChannelModelPricing_AntigravityWildcardCrossPlatformNoOverwrite(t *testing.T) { + // anthropic 和 gemini 都有 "shared-*" 通配符定价,价格不同。 + // antigravity 分组查找 "shared-model" 应命中第一个匹配而非被覆盖。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 400, Platform: PlatformAnthropic, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 401, Platform: PlatformGemini, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model") + require.NotNil(t, result, "antigravity group should find wildcard pricing for shared-model") + // 两个通配符都存在,应命中 anthropic 的(matchingPlatforms 顺序) + require.Equal(t, int64(400), result.ID) + require.InDelta(t, 10e-6, *result.InputPrice, 1e-12) +} + +func TestResolveChannelMapping_AntigravitySameModelDifferentPlatforms(t *testing.T) { + // anthropic 和 gemini 都定义了同名模型映射 "alias" → 不同目标。 + // antigravity 分组应命中 anthropic 的映射(按 matchingPlatforms 顺序)。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + PlatformAnthropic: {"alias": "anthropic-target"}, + PlatformGemini: {"alias": "gemini-target"}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "alias") + require.True(t, result.Mapped) + require.Equal(t, "anthropic-target", result.MappedModel) +} + +func TestCheckRestricted_AntigravitySameModelDifferentPlatforms(t *testing.T) { + // anthropic 和 gemini 都定义了同名模型 "shared-model"。 + // antigravity 分组启用了 RestrictModels,"shared-model" 应不被限制。 + ch := Channel{ + ID: 1, + Status: StatusActive, + RestrictModels: true, + GroupIDs: []int64{10}, + ModelPricing: 
[]ChannelModelPricing{ + {ID: 500, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 501, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "shared-model") + require.False(t, restricted, "shared-model should not be restricted for antigravity") + + // 未定义的模型应被限制 + restricted = svc.IsModelRestricted(context.Background(), 10, "unknown-model") + require.True(t, restricted, "unknown-model should be restricted for antigravity") +} + +func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) { + // 确保非 antigravity 平台的行为不受影响。 + // anthropic 分组只能看到 anthropic 的定价,看不到 gemini 的。 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10, 20}, + ModelPricing: []ChannelModelPricing{ + {ID: 600, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 601, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAnthropic, 20: PlatformGemini}) + svc := newTestChannelService(repo) + + // anthropic 分组应该只看到 anthropic 的定价 + result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model") + require.NotNil(t, result) + require.Equal(t, int64(600), result.ID) + require.InDelta(t, 10e-6, *result.InputPrice, 1e-12) + + // gemini 分组应该只看到 gemini 的定价 + result = svc.GetChannelModelPricing(context.Background(), 20, "shared-model") + require.NotNil(t, result) + require.Equal(t, int64(601), result.ID) + require.InDelta(t, 5e-6, *result.InputPrice, 1e-12) +} diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go new file mode 100644 index 0000000000..deac64d629 --- /dev/null +++ 
b/backend/internal/service/channel_test.go @@ -0,0 +1,435 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetModelPricing(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: testPtrFloat64(3e-6)}, + {ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest}, + }, + } + + tests := []struct { + name string + model string + wantID int64 + wantNil bool + }{ + {"exact match", "claude-sonnet-4", 1, false}, + {"case insensitive", "Claude-Sonnet-4", 1, false}, + {"not found", "gemini-3.1-pro", 0, true}, + {"wildcard pattern not matched", "claude-opus-4-20250514", 0, true}, + {"per_request model", "gpt-5.1", 3, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ch.GetModelPricing(tt.model) + if tt.wantNil { + require.Nil(t, result) + return + } + require.NotNil(t, result) + require.Equal(t, tt.wantID, result.ID) + }) + } +} + +func TestGetModelPricing_ReturnsCopy(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: testPtrFloat64(3e-6)}, + }, + } + + result := ch.GetModelPricing("claude-sonnet-4") + require.NotNil(t, result) + + // Modify the returned copy's slice — original should be unchanged + result.Models = append(result.Models, "hacked") + + // Original should be unchanged + require.Equal(t, 1, len(ch.ModelPricing[0].Models)) +} + +func TestGetModelPricing_EmptyPricing(t *testing.T) { + ch := &Channel{ModelPricing: nil} + require.Nil(t, ch.GetModelPricing("any-model")) + + ch2 := &Channel{ModelPricing: []ChannelModelPricing{}} + require.Nil(t, ch2.GetModelPricing("any-model")) +} + +func TestGetIntervalForContext(t *testing.T) { + p := &ChannelModelPricing{ + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: 
testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)}, + }, + } + + tests := []struct { + name string + tokens int + wantPrice *float64 + wantNil bool + }{ + {"first interval", 50000, testPtrFloat64(1e-6), false}, + // (min, max] — 128000 在第一个区间的 max,包含,所以匹配第一个 + {"boundary: max of first (inclusive)", 128000, testPtrFloat64(1e-6), false}, + // 128001 > 128000,匹配第二个区间 + {"boundary: just above first max", 128001, testPtrFloat64(2e-6), false}, + {"unbounded interval", 500000, testPtrFloat64(2e-6), false}, + // (0, max] — 0 不匹配任何区间(左开) + {"zero tokens: no match", 0, nil, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := p.GetIntervalForContext(tt.tokens) + if tt.wantNil { + require.Nil(t, result) + return + } + require.NotNil(t, result) + require.InDelta(t, *tt.wantPrice, *result.InputPrice, 1e-12) + }) + } +} + +func TestGetIntervalForContext_NoMatch(t *testing.T) { + p := &ChannelModelPricing{ + Intervals: []PricingInterval{ + {MinTokens: 10000, MaxTokens: testPtrInt(50000)}, + }, + } + require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min + require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open) + require.NotNil(t, p.GetIntervalForContext(50000)) // 50000 <= 50000 (right-closed) + require.Nil(t, p.GetIntervalForContext(50001)) // 50001 > 50000 +} + +func TestGetIntervalForContext_Empty(t *testing.T) { + p := &ChannelModelPricing{Intervals: nil} + require.Nil(t, p.GetIntervalForContext(1000)) +} + +func TestGetTierByLabel(t *testing.T) { + p := &ChannelModelPricing{ + Intervals: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)}, + {TierLabel: "HD", PerRequestPrice: testPtrFloat64(0.12)}, + }, + } + + tests := []struct { + name string + label string + wantNil bool + want float64 + }{ + {"exact match", "1K", false, 0.04}, + {"case insensitive", "hd", false, 
0.12}, + {"not found", "4K", true, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := p.GetTierByLabel(tt.label) + if tt.wantNil { + require.Nil(t, result) + return + } + require.NotNil(t, result) + require.InDelta(t, tt.want, *result.PerRequestPrice, 1e-12) + }) + } +} + +func TestGetTierByLabel_Empty(t *testing.T) { + p := &ChannelModelPricing{Intervals: nil} + require.Nil(t, p.GetTierByLabel("1K")) +} + +func TestChannelClone(t *testing.T) { + original := &Channel{ + ID: 1, + Name: "test", + GroupIDs: []int64{10, 20}, + ModelPricing: []ChannelModelPricing{ + { + ID: 100, + Models: []string{"model-a"}, + InputPrice: testPtrFloat64(5e-6), + }, + }, + } + + cloned := original.Clone() + require.NotNil(t, cloned) + require.Equal(t, original.ID, cloned.ID) + require.Equal(t, original.Name, cloned.Name) + + // Modify clone slices — original should not change + cloned.GroupIDs[0] = 999 + require.Equal(t, int64(10), original.GroupIDs[0]) + + cloned.ModelPricing[0].Models[0] = "hacked" + require.Equal(t, "model-a", original.ModelPricing[0].Models[0]) +} + +func TestChannelClone_Nil(t *testing.T) { + var ch *Channel + require.Nil(t, ch.Clone()) +} + +func TestChannelModelPricingClone(t *testing.T) { + original := ChannelModelPricing{ + Models: []string{"a", "b"}, + Intervals: []PricingInterval{ + {MinTokens: 0, TierLabel: "tier1"}, + }, + } + + cloned := original.Clone() + + // Modify clone slices — original unchanged + cloned.Models[0] = "hacked" + require.Equal(t, "a", original.Models[0]) + + cloned.Intervals[0].TierLabel = "hacked" + require.Equal(t, "tier1", original.Intervals[0].TierLabel) +} + +// --- BillingMode.IsValid --- + +func TestBillingModeIsValid(t *testing.T) { + tests := []struct { + name string + mode BillingMode + want bool + }{ + {"token", BillingModeToken, true}, + {"per_request", BillingModePerRequest, true}, + {"image", BillingModeImage, true}, + {"empty", BillingMode(""), true}, + {"unknown", 
BillingMode("unknown"), false}, + {"random", BillingMode("xyz"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, tt.mode.IsValid()) + }) + } +} + +// --- Channel.IsActive --- + +func TestChannelIsActive(t *testing.T) { + tests := []struct { + name string + status string + want bool + }{ + {"active", StatusActive, true}, + {"disabled", "disabled", false}, + {"empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := &Channel{Status: tt.status} + require.Equal(t, tt.want, ch.IsActive()) + }) + } +} + +// --- ChannelModelPricing.Clone edge cases --- + +func TestChannelModelPricingClone_EdgeCases(t *testing.T) { + t.Run("nil models", func(t *testing.T) { + original := ChannelModelPricing{Models: nil} + cloned := original.Clone() + require.Nil(t, cloned.Models) + }) + + t.Run("nil intervals", func(t *testing.T) { + original := ChannelModelPricing{Intervals: nil} + cloned := original.Clone() + require.Nil(t, cloned.Intervals) + }) + + t.Run("empty models", func(t *testing.T) { + original := ChannelModelPricing{Models: []string{}} + cloned := original.Clone() + require.NotNil(t, cloned.Models) + require.Empty(t, cloned.Models) + }) +} + +// --- Channel.Clone edge cases --- + +func TestChannelClone_EdgeCases(t *testing.T) { + t.Run("nil model mapping", func(t *testing.T) { + original := &Channel{ID: 1, ModelMapping: nil} + cloned := original.Clone() + require.Nil(t, cloned.ModelMapping) + }) + + t.Run("nil model pricing", func(t *testing.T) { + original := &Channel{ID: 1, ModelPricing: nil} + cloned := original.Clone() + require.Nil(t, cloned.ModelPricing) + }) + + t.Run("deep copy model mapping", func(t *testing.T) { + original := &Channel{ + ID: 1, + ModelMapping: map[string]map[string]string{ + "openai": {"gpt-4": "gpt-4-turbo"}, + }, + } + cloned := original.Clone() + + // Modify the cloned nested map + cloned.ModelMapping["openai"]["gpt-4"] = "hacked" + + // 
Original must remain unchanged + require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"]) + }) +} + +// --- ValidateIntervals --- + +func TestValidateIntervals_Empty(t *testing.T) { + require.NoError(t, ValidateIntervals(nil)) + require.NoError(t, ValidateIntervals([]PricingInterval{})) +} + +func TestValidateIntervals_ValidIntervals(t *testing.T) { + tests := []struct { + name string + intervals []PricingInterval + }{ + { + name: "single bounded interval", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + }, + }, + { + name: "two intervals with gap", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)}, + }, + }, + { + name: "two contiguous intervals", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)}, + }, + }, + { + name: "unsorted input (auto-sorted by validator)", + intervals: []PricingInterval{ + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)}, + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + }, + }, + { + name: "single unbounded interval", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.NoError(t, ValidateIntervals(tt.intervals)) + }) + } +} + +func TestValidateIntervals_NegativeMinTokens(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: -1, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "min_tokens") + require.Contains(t, err.Error(), ">= 0") +} + +func 
TestValidateIntervals_MaxTokensZero(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(0), InputPrice: testPtrFloat64(1e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "max_tokens") + require.Contains(t, err.Error(), "> 0") +} + +func TestValidateIntervals_MaxLessThanMin(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 100, MaxTokens: testPtrInt(50), InputPrice: testPtrFloat64(1e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "max_tokens") + require.Contains(t, err.Error(), "> min_tokens") +} + +func TestValidateIntervals_MaxEqualsMin(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 100, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "max_tokens") + require.Contains(t, err.Error(), "> min_tokens") +} + +func TestValidateIntervals_NegativePrice(t *testing.T) { + negPrice := -0.01 + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(100), InputPrice: &negPrice}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "input_price") + require.Contains(t, err.Error(), ">= 0") +} + +func TestValidateIntervals_OverlappingIntervals(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(200), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 100, MaxTokens: testPtrInt(300), InputPrice: testPtrFloat64(2e-6)}, + } + err := ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "overlap") +} + +func TestValidateIntervals_UnboundedNotLast(t *testing.T) { + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: testPtrInt(256000), InputPrice: testPtrFloat64(2e-6)}, + } + err 
:= ValidateIntervals(intervals) + require.Error(t, err) + require.Contains(t, err.Error(), "unbounded") + require.Contains(t, err.Error(), "last") +} diff --git a/backend/internal/service/gateway_channel_restriction_fallback_test.go b/backend/internal/service/gateway_channel_restriction_fallback_test.go new file mode 100644 index 0000000000..d319641906 --- /dev/null +++ b/backend/internal/service/gateway_channel_restriction_fallback_test.go @@ -0,0 +1,130 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestSelectAccountForModelWithExclusions_UsesFallbackGroupForChannelRestriction(t *testing.T) { + t.Parallel() + + groupID := int64(10) + fallbackID := int64(11) + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{fallbackID}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{ + fallbackID: PlatformAnthropic, + })) + accountRepo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range accountRepo.accounts { + accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i] + } + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + ClaudeCodeOnly: true, + FallbackGroupID: &fallbackID, + Hydrated: true, + }, + fallbackID: { + ID: fallbackID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + }, + }, + } + + svc := &GatewayService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + channelService: channelSvc, + cfg: testConfig(), + } + + ctx := 
context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID]) + account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-sonnet-4-6", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(1), account.ID) +} + +func TestSelectAccountWithLoadAwareness_UsesFallbackGroupForChannelRestriction(t *testing.T) { + t.Parallel() + + groupID := int64(10) + fallbackID := int64(11) + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{fallbackID}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{ + fallbackID: PlatformAnthropic, + })) + accountRepo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range accountRepo.accounts { + accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i] + } + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + ClaudeCodeOnly: true, + FallbackGroupID: &fallbackID, + Hydrated: true, + }, + fallbackID: { + ID: fallbackID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + }, + }, + } + + svc := &GatewayService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + channelService: channelSvc, + cfg: testConfig(), + } + + ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID]) + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-sonnet-4-6", nil, "", 0) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(1), result.Account.ID) +} diff --git 
a/backend/internal/service/gateway_channel_restriction_test.go b/backend/internal/service/gateway_channel_restriction_test.go new file mode 100644 index 0000000000..3a2ad2ff72 --- /dev/null +++ b/backend/internal/service/gateway_channel_restriction_test.go @@ -0,0 +1,293 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// --- billingModelForRestriction --- + +func TestBillingModelForRestriction_Requested(t *testing.T) { + t.Parallel() + got := billingModelForRestriction(BillingModelSourceRequested, "claude-sonnet-4-5", "claude-sonnet-4-6") + require.Equal(t, "claude-sonnet-4-5", got) +} + +func TestBillingModelForRestriction_ChannelMapped(t *testing.T) { + t.Parallel() + got := billingModelForRestriction(BillingModelSourceChannelMapped, "claude-sonnet-4-5", "claude-sonnet-4-6") + require.Equal(t, "claude-sonnet-4-6", got) +} + +func TestBillingModelForRestriction_Upstream(t *testing.T) { + t.Parallel() + got := billingModelForRestriction(BillingModelSourceUpstream, "claude-sonnet-4-5", "claude-sonnet-4-6") + require.Equal(t, "", got, "upstream should return empty (per-account check needed)") +} + +func TestBillingModelForRestriction_Empty(t *testing.T) { + t.Parallel() + got := billingModelForRestriction("", "claude-sonnet-4-5", "claude-sonnet-4-6") + require.Equal(t, "claude-sonnet-4-6", got, "empty source defaults to channel_mapped") +} + +// --- resolveAccountUpstreamModel --- + +func TestResolveAccountUpstreamModel_Antigravity(t *testing.T) { + t.Parallel() + account := &Account{ + Platform: PlatformAntigravity, + } + // Antigravity 平台使用 DefaultAntigravityModelMapping + got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6") + require.Equal(t, "claude-sonnet-4-6", got) +} + +func TestResolveAccountUpstreamModel_Antigravity_Unsupported(t *testing.T) { + t.Parallel() + account := &Account{ + Platform: PlatformAntigravity, + } + got := resolveAccountUpstreamModel(account, 
"totally-unknown-model") + require.Equal(t, "", got, "unsupported model should return empty") +} + +func TestResolveAccountUpstreamModel_NonAntigravity(t *testing.T) { + t.Parallel() + account := &Account{ + Platform: PlatformAnthropic, + } + got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6") + require.Equal(t, "claude-sonnet-4-6", got, "no mapping = passthrough") +} + +// --- checkChannelPricingRestriction --- + +func TestCheckChannelPricingRestriction_NilGroupID(t *testing.T) { + t.Parallel() + svc := &GatewayService{channelService: &ChannelService{}} + require.False(t, svc.checkChannelPricingRestriction(context.Background(), nil, "claude-sonnet-4")) +} + +func TestCheckChannelPricingRestriction_NilChannelService(t *testing.T) { + t.Parallel() + svc := &GatewayService{} + gid := int64(10) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4")) +} + +func TestCheckChannelPricingRestriction_EmptyModel(t *testing.T) { + t.Parallel() + svc := &GatewayService{channelService: &ChannelService{}} + gid := int64(10) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "")) +} + +func TestCheckChannelPricingRestriction_ChannelMapped_Restricted(t *testing.T) { + t.Parallel() + // 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6,但定价列表只有 claude-opus-4-6 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceChannelMapped, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(10) + require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"), + 
"mapped model claude-sonnet-4-6 is NOT in pricing → restricted") +} + +func TestCheckChannelPricingRestriction_ChannelMapped_Allowed(t *testing.T) { + t.Parallel() + // 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6,定价列表包含 claude-sonnet-4-6 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceChannelMapped, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(10) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"), + "mapped model claude-sonnet-4-6 IS in pricing → allowed") +} + +func TestCheckChannelPricingRestriction_Requested_Restricted(t *testing.T) { + t.Parallel() + // billing_model_source=requested,定价列表有 claude-sonnet-4-6 但请求的是 claude-sonnet-4-5 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceRequested, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(10) + require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"), + "requested model claude-sonnet-4-5 is NOT in pricing → restricted") +} + +func TestCheckChannelPricingRestriction_Requested_Allowed(t *testing.T) { + t.Parallel() + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceRequested, + ModelPricing: 
[]ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4-5"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(10) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"), + "requested model IS in pricing → allowed") +} + +func TestCheckChannelPricingRestriction_Upstream_SkipsPreCheck(t *testing.T) { + t.Parallel() + // upstream 模式:预检查始终跳过(返回 false),需逐账号检查 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceUpstream, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(10) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "unknown-model"), + "upstream mode should skip pre-check (per-account check needed)") +} + +func TestCheckChannelPricingRestriction_RestrictModelsDisabled(t *testing.T) { + t.Parallel() + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: false, // 未开启模型限制 + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(10) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"), + "RestrictModels=false → always allowed") +} + +func TestCheckChannelPricingRestriction_NoChannel(t *testing.T) { + t.Parallel() + // 分组没有关联渠道 + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { return nil, nil }, 
+ } + channelSvc := newTestChannelService(repo) + svc := &GatewayService{channelService: channelSvc} + + gid := int64(999) + require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"), + "no channel for group → allowed") +} + +// --- isUpstreamModelRestrictedByChannel --- + +func TestIsUpstreamModelRestrictedByChannel_Restricted(t *testing.T) { + t.Parallel() + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + account := &Account{Platform: PlatformAntigravity} + // claude-sonnet-4-6 在 DefaultAntigravityModelMapping 中,映射后仍为 claude-sonnet-4-6 + // 但定价列表只有 claude-opus-4-6 + require.True(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"), + "upstream model claude-sonnet-4-6 NOT in pricing → restricted") +} + +func TestIsUpstreamModelRestrictedByChannel_Allowed(t *testing.T) { + t.Parallel() + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + account := &Account{Platform: PlatformAntigravity} + require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"), + "upstream model claude-sonnet-4-6 IS in pricing → allowed") +} + +func TestIsUpstreamModelRestrictedByChannel_UnsupportedModel(t *testing.T) { + t.Parallel() + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: 
[]ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + } + channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"})) + svc := &GatewayService{channelService: channelSvc} + + account := &Account{Platform: PlatformAntigravity} + // totally-unknown-model 不在 DefaultAntigravityModelMapping 中 → 映射结果为空 + require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "totally-unknown-model"), + "unmappable model → upstream model empty → not restricted (account filter handles this)") +} diff --git a/backend/internal/service/gateway_hotpath_optimization_test.go b/backend/internal/service/gateway_hotpath_optimization_test.go index 161c4ba4b1..e5bf49b8b7 100644 --- a/backend/internal/service/gateway_hotpath_optimization_test.go +++ b/backend/internal/service/gateway_hotpath_optimization_test.go @@ -732,7 +732,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) { modelsListCacheTTL: time.Minute, } - result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -754,7 +754,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) { ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID) ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0)) - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -776,7 +776,7 @@ func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) { ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999)) ctx = 
context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77)) - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 2d16ad9429..728328373c 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -2031,7 +2031,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, // No concurrency service } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2084,7 +2084,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, // legacy path } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2116,7 +2116,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2148,7 +2148,7 @@ func 
TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { } excludedIDs := map[int64]struct{}{1: {}} - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2182,7 +2182,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2218,7 +2218,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2259,7 +2259,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2287,7 +2287,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := 
svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.Error(t, err) require.Nil(t, result) require.ErrorIs(t, err, ErrNoAvailableAccounts) @@ -2319,7 +2319,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2352,7 +2352,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2390,7 +2390,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.WaitPlan) @@ -2426,7 +2426,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "", 
int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2485,7 +2485,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.WaitPlan) @@ -2539,7 +2539,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2593,7 +2593,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2651,7 +2651,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, 
result) require.NotNil(t, result.Account) @@ -2709,7 +2709,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.WaitPlan) @@ -2767,7 +2767,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2804,7 +2804,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.WaitPlan) @@ -2856,7 +2856,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2934,7 +2934,7 @@ func 
TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { } excluded := map[int64]struct{}{1: {}} - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -2988,7 +2988,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) @@ -3021,7 +3021,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: nil, } - result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.Error(t, err) require.Nil(t, result) require.ErrorIs(t, err, ErrClaudeCodeOnly) @@ -3059,7 +3059,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.WaitPlan) @@ -3097,7 +3097,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { concurrencyService: NewConcurrencyService(concurrencyCache), } - result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", 
"claude-3-5-sonnet-20241022", nil, "") + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "", int64(0)) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Account) diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 48488dc8c5..97703a9d5e 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -41,6 +41,8 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo nil, nil, nil, + nil, + nil, ) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 94e04d286d..a95b62b133 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -60,6 +60,13 @@ const ( claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) +// MediaType 媒体类型常量 +const ( + MediaTypeImage = "image" + MediaTypeVideo = "video" + MediaTypePrompt = "prompt" +) + // ForceCacheBillingContextKey 强制缓存计费上下文键 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 type forceCacheBillingKeyType struct{} @@ -483,6 +490,7 @@ type ClaudeUsage struct { CacheReadInputTokens int `json:"cache_read_input_tokens"` CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象) CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象) + ImageOutputTokens int `json:"image_output_tokens,omitempty"` } // ForwardResult 转发结果 @@ -568,6 +576,8 @@ type GatewayService struct { responseHeaderFilter *responseheaders.CompiledHeaderFilter debugModelRouting atomic.Bool debugClaudeMimic atomic.Bool + channelService *ChannelService + resolver *ModelPricingResolver debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set tlsFPProfileService *TLSFingerprintProfileService } @@ -597,6 +607,8 @@ func NewGatewayService( digestStore 
*DigestSessionStore, settingService *SettingService, tlsFPProfileService *TLSFingerprintProfileService, + channelService *ChannelService, + resolver *ModelPricingResolver, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) @@ -629,6 +641,8 @@ func NewGatewayService( modelsListCacheTTL: modelsListTTL, responseHeaderFilter: compileResponseHeaderFilter(cfg), tlsFPProfileService: tlsFPProfileService, + channelService: channelService, + resolver: resolver, } svc.userGroupRateResolver = newUserGroupRateResolver( userGroupRateRepo, @@ -866,17 +880,7 @@ type anthropicMetadataPayload struct { // replaceModelInBody 替换请求体中的model字段 // 优先使用定点修改,尽量保持客户端原始字段顺序。 func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { - if len(body) == 0 { - return body - } - if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel { - return body - } - newBody, err := sjson.SetBytes(body, "model", newModel) - if err != nil { - return body - } - return newBody + return ReplaceModelInBody(body, newModel) } type claudeOAuthNormalizeOptions struct { @@ -1186,6 +1190,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context platform = PlatformAnthropic } + // Claude Code 限制可能已将 groupID 解析为 fallback group, + // 渠道限制预检查必须使用解析后的分组。 + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { @@ -1198,8 +1211,10 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context } // 
SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. -// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash -func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) { +// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。 +// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID +// sub2apiUserID: 系统用户 ID,用于二维亲和调度 +func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) { // 调试日志:记录调度入口参数 excludedIDsList := make([]int64, 0, len(excludedIDs)) for id := range excludedIDs { @@ -1220,6 +1235,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } ctx = s.withGroupContext(ctx, group) + // Claude Code 限制可能已将 groupID 解析为 fallback group, + // 渠道限制预检查必须使用解析后的分组。 + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + var stickyAccountID int64 if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { stickyAccountID = prefetch @@ -2945,6 +2969,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ctx = s.withRPMPrefetch(ctx, accounts) // 3. 
按优先级+最久未用选择(考虑模型支持) + // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查, + // 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。 + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { acc := &accounts[i] @@ -2965,6 +2992,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) { + continue + } if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } @@ -3197,6 +3227,8 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) + // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。 + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { acc := &accounts[i] @@ -3221,6 +3253,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) { + continue + } if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } @@ -7410,6 +7445,8 @@ type RecordUsageInput struct { RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 + + ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) } // APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage @@ -7439,6 +7476,18 @@ type postUsageBillingParams struct { APIKeyService APIKeyQuotaUpdater } +func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool 
{ + return p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil +} + +func (p *postUsageBillingParams) shouldUpdateRateLimits() bool { + return p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil +} + +func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool { + return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() +} + // postUsageBilling 统一处理使用量记录后的扣费逻辑: // - 订阅/余额扣费 // - API Key 配额更新 @@ -7468,21 +7517,21 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill } // 2. API Key 配额 - if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + if p.shouldDeductAPIKeyQuota() { if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) } } // 3. API Key 限速用量 - if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + if p.shouldUpdateRateLimits() { if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) } } // 4. 
账号配额用量(账号口径:TotalCost × 账号计费倍率) - if cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { + if p.shouldUpdateAccountQuota() { accountCost := cost.TotalCost * p.AccountRateMultiplier if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) @@ -7564,13 +7613,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage cmd.BalanceCost = p.Cost.ActualCost } - if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + if p.shouldDeductAPIKeyQuota() { cmd.APIKeyQuotaCost = p.Cost.ActualCost } - if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + if p.shouldUpdateRateLimits() { cmd.APIKeyRateLimitCost = p.Cost.ActualCost } - if p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() { + if p.shouldUpdateAccountQuota() { cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier } @@ -7694,8 +7743,108 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage } } +// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。 +type recordUsageOpts struct { + // Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入) + ParsedRequest *ParsedRequest + + // EnableClaudePath 启用 Claude 路径特有逻辑: + // - Claude Max 缓存计费策略 + // - Sora 媒体类型分支(image/video/prompt) + // - MediaType 字段写入使用日志 + EnableClaudePath bool + + // 长上下文计费(仅 Gemini 路径需要) + LongContextThreshold int + LongContextMultiplier float64 +} + // RecordUsage 记录使用量并扣费(或更新订阅用量) func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { + return s.recordUsageCore(ctx, &recordUsageCoreInput{ + Result: input.Result, + APIKey: input.APIKey, + User: input.User, + Account: input.Account, + Subscription: input.Subscription, + InboundEndpoint: input.InboundEndpoint, + 
UpstreamEndpoint: input.UpstreamEndpoint, + UserAgent: input.UserAgent, + IPAddress: input.IPAddress, + RequestPayloadHash: input.RequestPayloadHash, + ForceCacheBilling: input.ForceCacheBilling, + APIKeyService: input.APIKeyService, + ChannelUsageFields: input.ChannelUsageFields, + }, &recordUsageOpts{ + EnableClaudePath: true, + }) +} + +// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费) +type RecordUsageLongContextInput struct { + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + InboundEndpoint string // 入站端点(客户端请求路径) + UpstreamEndpoint string // 上游端点(标准化后的上游路径) + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + LongContextThreshold int // 长上下文阈值(如 200000) + LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) + + ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) +} + +// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) +func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error { + return s.recordUsageCore(ctx, &recordUsageCoreInput{ + Result: input.Result, + APIKey: input.APIKey, + User: input.User, + Account: input.Account, + Subscription: input.Subscription, + InboundEndpoint: input.InboundEndpoint, + UpstreamEndpoint: input.UpstreamEndpoint, + UserAgent: input.UserAgent, + IPAddress: input.IPAddress, + RequestPayloadHash: input.RequestPayloadHash, + ForceCacheBilling: input.ForceCacheBilling, + APIKeyService: input.APIKeyService, + ChannelUsageFields: input.ChannelUsageFields, + }, &recordUsageOpts{ + LongContextThreshold: input.LongContextThreshold, + LongContextMultiplier: input.LongContextMultiplier, + }) +} + +// recordUsageCoreInput 是 recordUsageCore 的公共输入字段,从两种输入结构体中提取。 +type recordUsageCoreInput 
struct { + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription + InboundEndpoint string + UpstreamEndpoint string + UserAgent string + IPAddress string + RequestPayloadHash string + ForceCacheBilling bool + APIKeyService APIKeyQuotaUpdater + ChannelUsageFields +} + +// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 +// opts 中的字段控制两者之间的差异行为: +// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略 +// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt) +// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext +func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { result := input.Result apiKey := input.APIKey user := input.User @@ -7728,56 +7877,24 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } - var cost *CostBreakdown + // 确定计费模型 billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) + if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" { + billingModel = input.ChannelMappedModel + } + if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { + billingModel = input.OriginalModel + } - // 根据请求类型选择计费方式 - if result.MediaType == "image" || result.MediaType == "video" { - var soraConfig *SoraPriceConfig - if apiKey.Group != nil { - soraConfig = &SoraPriceConfig{ - ImagePrice360: apiKey.Group.SoraImagePrice360, - ImagePrice540: apiKey.Group.SoraImagePrice540, - VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, - VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, - } - } - if result.MediaType == "image" { - cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) - } else { - cost = 
s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) - } - } else if result.MediaType == "prompt" { - cost = &CostBreakdown{} - } else if result.ImageCount > 0 { - // 图片生成计费 - var groupConfig *ImagePriceConfig - if apiKey.Group != nil { - groupConfig = &ImagePriceConfig{ - Price1K: apiKey.Group.ImagePrice1K, - Price2K: apiKey.Group.ImagePrice2K, - Price4K: apiKey.Group.ImagePrice4K, - } - } - cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) - } else { - // Token 计费 - tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, - CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - } - var err error - cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) - if err != nil { - logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) - cost = &CostBreakdown{ActualCost: 0} - } + // 确定 RequestedModel(渠道映射前的原始模型) + requestedModel := result.Model + if input.OriginalModel != "" { + requestedModel = input.OriginalModel } + // 计算费用 + cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts) + // 判断计费方式:订阅模式 vs 余额模式 isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() billingType := BillingTypeBalance @@ -7786,70 +7903,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } // 创建使用日志 - durationMs := int(result.Duration.Milliseconds()) - var imageSize *string - if result.ImageSize != "" { - imageSize = &result.ImageSize - } - var mediaType *string - if strings.TrimSpace(result.MediaType) != "" { - mediaType = &result.MediaType - } accountRateMultiplier := account.BillingRateMultiplier() - requestID := 
resolveUsageBillingRequestID(ctx, result.RequestID) - usageLog := &UsageLog{ - UserID: user.ID, - APIKeyID: apiKey.ID, - AccountID: account.ID, - RequestID: requestID, - Model: result.Model, - RequestedModel: result.Model, - UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), - ReasoningEffort: result.ReasoningEffort, - InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), - UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, - CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - InputCost: cost.InputCost, - OutputCost: cost.OutputCost, - CacheCreationCost: cost.CacheCreationCost, - CacheReadCost: cost.CacheReadCost, - TotalCost: cost.TotalCost, - ActualCost: cost.ActualCost, - RateMultiplier: multiplier, - AccountRateMultiplier: &accountRateMultiplier, - BillingType: billingType, - Stream: result.Stream, - DurationMs: &durationMs, - FirstTokenMs: result.FirstTokenMs, - ImageCount: result.ImageCount, - ImageSize: imageSize, - MediaType: mediaType, - CacheTTLOverridden: cacheTTLOverridden, - CreatedAt: time.Now(), - } - - // 添加 UserAgent - if input.UserAgent != "" { - usageLog.UserAgent = &input.UserAgent - } - - // 添加 IPAddress - if input.IPAddress != "" { - usageLog.IPAddress = &input.IPAddress - } - - // 添加分组和订阅关联 - if apiKey.GroupID != nil { - usageLog.GroupID = apiKey.GroupID - } - if subscription != nil { - usageLog.SubscriptionID = &subscription.ID - } + usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription, + requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { writeUsageLogBestEffort(ctx, s.usageLogRepo, 
usageLog, "service.gateway") @@ -7858,20 +7914,18 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu return nil } - billingErr := func() error { - _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ - Cost: cost, - User: user, - APIKey: apiKey, - Account: account, - Subscription: subscription, - RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), - IsSubscriptionBill: isSubscriptionBilling, - AccountRateMultiplier: accountRateMultiplier, - APIKeyService: input.APIKeyService, - }, s.billingDeps(), s.usageBillingRepo) - return err - }() + requestID := usageLog.RequestID + _, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ + Cost: cost, + User: user, + APIKey: apiKey, + Account: account, + Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), + IsSubscriptionBill: isSubscriptionBilling, + AccountRateMultiplier: accountRateMultiplier, + APIKeyService: input.APIKeyService, + }, s.billingDeps(), s.usageBillingRepo) if billingErr != nil { return billingErr @@ -7881,105 +7935,182 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu return nil } -// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费) -type RecordUsageLongContextInput struct { - Result *ForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription // 可选:订阅信息 - InboundEndpoint string // 入站端点(客户端请求路径) - UpstreamEndpoint string // 上游端点(标准化后的上游路径) - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 - LongContextThreshold int // 长上下文阈值(如 200000) - LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) - ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) - APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) +// calculateRecordUsageCost 根据请求类型和选项计算费用。 
+func (s *GatewayService) calculateRecordUsageCost( + ctx context.Context, + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, + opts *recordUsageOpts, +) *CostBreakdown { + // Sora 媒体类型分支(仅 Claude 路径启用) + if opts.EnableClaudePath { + if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo { + return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier) + } + if result.MediaType == MediaTypePrompt { + return &CostBreakdown{} + } + } + + // 图片生成计费 + if result.ImageCount > 0 { + return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier) + } + + // Token 计费 + return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts) } -// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) -func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error { - result := input.Result - apiKey := input.APIKey - user := input.User - account := input.Account - subscription := input.Subscription +// calculateSoraMediaCost 计算 Sora 图片/视频的费用。 +func (s *GatewayService) calculateSoraMediaCost( + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, +) *CostBreakdown { + var soraConfig *SoraPriceConfig + if apiKey.Group != nil { + soraConfig = &SoraPriceConfig{ + ImagePrice360: apiKey.Group.SoraImagePrice360, + ImagePrice540: apiKey.Group.SoraImagePrice540, + VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, + } + } + if result.MediaType == MediaTypeImage { + return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) + } + return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) +} - // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens - // 用于粘性会话切换时的特殊计费处理 - if input.ForceCacheBilling && result.Usage.InputTokens > 0 { - 
logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", - result.Usage.InputTokens, account.ID) - result.Usage.CacheReadInputTokens += result.Usage.InputTokens - result.Usage.InputTokens = 0 +// resolveChannelPricing 检查指定模型是否存在渠道级别定价。 +// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。 +func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { + if s.resolver == nil || apiKey.Group == nil { + return nil } + gid := apiKey.Group.ID + resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) + if resolved.Source == PricingSourceChannel { + return resolved + } + return nil +} - // Cache TTL Override: 确保计费时 token 分类与账号设置一致 - cacheTTLOverridden := false - if account.IsCacheTTLOverrideEnabled() { - applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) - cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 +// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。 +func (s *GatewayService) calculateImageCost( + ctx context.Context, + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, +) *CostBreakdown { + if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + ImageOutputTokens: result.Usage.ImageOutputTokens, + } + gid := apiKey.Group.ID + cost, err := s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + Resolver: s.resolver, + Resolved: resolved, + }) + if err != nil { + logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err) + return &CostBreakdown{ActualCost: 0} + } + return cost } - // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) - multiplier := 1.0 - if 
s.cfg != nil { - multiplier = s.cfg.Default.RateMultiplier + var groupConfig *ImagePriceConfig + if apiKey.Group != nil { + groupConfig = &ImagePriceConfig{ + Price1K: apiKey.Group.ImagePrice1K, + Price2K: apiKey.Group.ImagePrice2K, + Price4K: apiKey.Group.ImagePrice4K, + } } - if apiKey.GroupID != nil && apiKey.Group != nil { - groupDefault := apiKey.Group.RateMultiplier - multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) + return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) +} + +// calculateTokenCost 计算 Token 计费:根据 opts 决定走普通/长上下文/渠道统一计费。 +func (s *GatewayService) calculateTokenCost( + ctx context.Context, + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, + opts *recordUsageOpts, +) *CostBreakdown { + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, + ImageOutputTokens: result.Usage.ImageOutputTokens, } var cost *CostBreakdown - billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) + var err error - // 根据请求类型选择计费方式 - if result.ImageCount > 0 { - // 图片生成计费 - var groupConfig *ImagePriceConfig - if apiKey.Group != nil { - groupConfig = &ImagePriceConfig{ - Price1K: apiKey.Group.ImagePrice1K, - Price2K: apiKey.Group.ImagePrice2K, - Price4K: apiKey.Group.ImagePrice4K, - } - } - cost = s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) + // 优先尝试渠道定价 → CalculateCostUnified + if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { + gid := apiKey.Group.ID + cost, err = s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, 
+ Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + Resolver: s.resolver, + Resolved: resolved, + }) + } else if opts.LongContextThreshold > 0 { + // 长上下文双倍计费(如 Gemini 200K 阈值) + cost, err = s.billingService.CalculateCostWithLongContext( + billingModel, tokens, multiplier, + opts.LongContextThreshold, opts.LongContextMultiplier, + ) } else { - // Token 计费(使用长上下文计费方法) - tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, - CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - } - var err error - cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) - if err != nil { - logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) - cost = &CostBreakdown{ActualCost: 0} - } + cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) } - - // 判断计费方式:订阅模式 vs 余额模式 - isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() - billingType := BillingTypeBalance - if isSubscriptionBilling { - billingType = BillingTypeSubscription + if err != nil { + logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) + return &CostBreakdown{ActualCost: 0} } + return cost +} - // 创建使用日志 +// buildRecordUsageLog 构建使用日志并设置计费模式。 +func (s *GatewayService) buildRecordUsageLog( + ctx context.Context, + input *recordUsageCoreInput, + result *ForwardResult, + apiKey *APIKey, + user *User, + account *Account, + subscription *UserSubscription, + requestedModel string, + multiplier float64, + accountRateMultiplier float64, + billingType int8, + cacheTTLOverridden bool, + cost *CostBreakdown, + opts *recordUsageOpts, +) 
*UsageLog { durationMs := int(result.Duration.Milliseconds()) - var imageSize *string - if result.ImageSize != "" { - imageSize = &result.ImageSize - } - accountRateMultiplier := account.BillingRateMultiplier() requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, @@ -7987,7 +8118,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * AccountID: account.ID, RequestID: requestID, Model: result.Model, - RequestedModel: result.Model, + RequestedModel: requestedModel, UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), @@ -7998,70 +8129,168 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * CacheReadTokens: result.Usage.CacheReadInputTokens, CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - InputCost: cost.InputCost, - OutputCost: cost.OutputCost, - CacheCreationCost: cost.CacheCreationCost, - CacheReadCost: cost.CacheReadCost, - TotalCost: cost.TotalCost, - ActualCost: cost.ActualCost, + ImageOutputTokens: result.Usage.ImageOutputTokens, RateMultiplier: multiplier, AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, + BillingMode: resolveBillingMode(opts, result, cost), Stream: result.Stream, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, - ImageSize: imageSize, + ImageSize: optionalTrimmedStringPtr(result.ImageSize), + MediaType: resolveMediaType(opts, result), CacheTTLOverridden: cacheTTLOverridden, + ChannelID: optionalInt64Ptr(input.ChannelID), + ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), + UserAgent: optionalTrimmedStringPtr(input.UserAgent), + IPAddress: optionalTrimmedStringPtr(input.IPAddress), + GroupID: apiKey.GroupID, + SubscriptionID: 
optionalSubscriptionID(subscription), CreatedAt: time.Now(), } - - // 添加 UserAgent - if input.UserAgent != "" { - usageLog.UserAgent = &input.UserAgent + if cost != nil { + usageLog.InputCost = cost.InputCost + usageLog.OutputCost = cost.OutputCost + usageLog.ImageOutputCost = cost.ImageOutputCost + usageLog.CacheCreationCost = cost.CacheCreationCost + usageLog.CacheReadCost = cost.CacheReadCost + usageLog.TotalCost = cost.TotalCost + usageLog.ActualCost = cost.ActualCost } - // 添加 IPAddress - if input.IPAddress != "" { - usageLog.IPAddress = &input.IPAddress + return usageLog +} + +// resolveBillingMode 根据计费结果和请求类型确定计费模式。 +// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。 +func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string { + isSoraMedia := opts.EnableClaudePath && + (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt) + if isSoraMedia { + return nil } + var mode string + switch { + case cost != nil && cost.BillingMode != "": + mode = cost.BillingMode + case result.ImageCount > 0: + mode = string(BillingModeImage) + default: + mode = string(BillingModeToken) + } + return &mode +} - // 添加分组和订阅关联 - if apiKey.GroupID != nil { - usageLog.GroupID = apiKey.GroupID +func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string { + if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" { + return &result.MediaType } + return nil +} + +func optionalSubscriptionID(subscription *UserSubscription) *int64 { if subscription != nil { - usageLog.SubscriptionID = &subscription.ID + return &subscription.ID } + return nil +} - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") - logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) - 
s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil +// ResolveChannelMapping 委托渠道服务解析模型映射 +func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model} } + return s.channelService.ResolveChannelMapping(ctx, groupID, model) +} - billingErr := func() error { - _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ - Cost: cost, - User: user, - APIKey: apiKey, - Account: account, - Subscription: subscription, - RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), - IsSubscriptionBill: isSubscriptionBilling, - AccountRateMultiplier: accountRateMultiplier, - APIKeyService: input.APIKeyService, - }, s.billingDeps(), s.usageBillingRepo) - return err - }() +// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用) +func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { + return ReplaceModelInBody(body, newModel) +} - if billingErr != nil { - return billingErr +// IsModelRestricted 检查模型是否被渠道限制 +func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { + if s.channelService == nil { + return false } - writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") + return s.channelService.IsModelRestricted(ctx, groupID, model) +} - return nil +// ResolveChannelMappingAndRestrict 解析渠道映射。 +// 模型限制检查已移至调度阶段(checkChannelPricingRestriction),restricted 始终返回 false。 +func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model}, false + } + return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model) +} + +// checkChannelPricingRestriction 根据渠道计费基准检查模型是否受定价列表限制。 +// 供调度阶段预检查(requested / channel_mapped)。 +// upstream 需逐账号检查,此处返回 
false。 +func (s *GatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool { + if groupID == nil || s.channelService == nil || requestedModel == "" { + return false + } + mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel) + billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel) + if billingModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, *groupID, billingModel) +} + +// billingModelForRestriction 根据计费基准确定限制检查使用的模型。 +// upstream 返回空(需逐账号检查)。 +func billingModelForRestriction(source, requestedModel, channelMappedModel string) string { + switch source { + case BillingModelSourceRequested: + return requestedModel + case BillingModelSourceUpstream: + return "" + case BillingModelSourceChannelMapped: + return channelMappedModel + default: + return channelMappedModel + } +} + +// isUpstreamModelRestrictedByChannel 检查账号映射后的上游模型是否受渠道定价限制。 +// 仅在 BillingModelSource="upstream" 且 RestrictModels=true 时由调度循环调用。 +func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool { + if s.channelService == nil { + return false + } + upstreamModel := resolveAccountUpstreamModel(account, requestedModel) + if upstreamModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel) +} + +// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。 +func resolveAccountUpstreamModel(account *Account, requestedModel string) string { + if account.Platform == PlatformAntigravity { + return mapAntigravityModel(account, requestedModel) + } + return account.GetMappedModel(requestedModel) +} + +// needsUpstreamChannelRestrictionCheck 判断是否需要在调度循环中逐账号检查上游模型的渠道限制。 +func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool { + if groupID == nil || s.channelService == nil { + return 
false + } + ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) + if err != nil { + slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err) + return false + } + if ch == nil || !ch.RestrictModels { + return false + } + return ch.BillingModelSource == BillingModelSourceUpstream } // ForwardCountTokens 转发 count_tokens 请求到上游 API diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 5b1abc119f..b35ebce5c9 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2692,12 +2692,27 @@ func extractGeminiUsage(data []byte) *ClaudeUsage { cand := int(usage.Get("candidatesTokenCount").Int()) cached := int(usage.Get("cachedContentTokenCount").Int()) thoughts := int(usage.Get("thoughtsTokenCount").Int()) + + // 从 candidatesTokensDetails 提取 IMAGE 模态 token 数 + imageTokens := 0 + candidateDetails := usage.Get("candidatesTokensDetails") + if candidateDetails.Exists() { + candidateDetails.ForEach(func(_, detail gjson.Result) bool { + if detail.Get("modality").String() == "IMAGE" { + imageTokens = int(detail.Get("tokenCount").Int()) + return false + } + return true + }) + } + // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 return &ClaudeUsage{ InputTokens: prompt - cached, OutputTokens: cand + thoughts, CacheReadInputTokens: cached, + ImageOutputTokens: imageTokens, } } diff --git a/backend/internal/service/model_pricing_resolver.go b/backend/internal/service/model_pricing_resolver.go new file mode 100644 index 0000000000..b7ca4cb76e --- /dev/null +++ b/backend/internal/service/model_pricing_resolver.go @@ -0,0 +1,231 @@ +package service + +import ( + "context" + "log/slog" +) + +// PricingSource 定价来源标识 +const ( + PricingSourceChannel = "channel" + PricingSourceLiteLLM = "litellm" + 
PricingSourceFallback = "fallback" +) + +// ResolvedPricing 统一定价解析结果 +type ResolvedPricing struct { + // Mode 计费模式 + Mode BillingMode + + // Token 模式:基础定价(来自 LiteLLM 或 fallback) + BasePricing *ModelPricing + + // Token 模式:区间定价列表(如有,覆盖 BasePricing 中的对应字段) + Intervals []PricingInterval + + // 按次/图片模式:分层定价 + RequestTiers []PricingInterval + + // 按次/图片模式:默认价格(未命中层级时使用) + DefaultPerRequestPrice float64 + + // 来源标识 + Source string // "channel", "litellm", "fallback" + + // 是否支持缓存细分 + SupportsCacheBreakdown bool +} + +// ModelPricingResolver 统一模型定价解析器。 +// 解析链:Channel → LiteLLM → Fallback。 +type ModelPricingResolver struct { + channelService *ChannelService + billingService *BillingService +} + +// NewModelPricingResolver 创建定价解析器实例 +func NewModelPricingResolver(channelService *ChannelService, billingService *BillingService) *ModelPricingResolver { + return &ModelPricingResolver{ + channelService: channelService, + billingService: billingService, + } +} + +// PricingInput 定价解析输入 +type PricingInput struct { + Model string + GroupID *int64 // nil 表示不检查渠道 +} + +// Resolve 解析模型定价。 +// 1. 获取基础定价(LiteLLM → Fallback) +// 2. 如果指定了 GroupID,查找渠道定价并覆盖 +func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing { + // 1. 获取基础定价 + basePricing, source := r.resolveBasePricing(input.Model) + + resolved := &ResolvedPricing{ + Mode: BillingModeToken, + BasePricing: basePricing, + Source: source, + SupportsCacheBreakdown: basePricing != nil && basePricing.SupportsCacheBreakdown, + } + + // 2. 
如果有 GroupID,尝试渠道覆盖 + if input.GroupID != nil { + r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved) + } + + return resolved +} + +// resolveBasePricing 从 LiteLLM 或 Fallback 获取基础定价 +func (r *ModelPricingResolver) resolveBasePricing(model string) (*ModelPricing, string) { + pricing, err := r.billingService.GetModelPricing(model) + if err != nil { + slog.Debug("failed to get model pricing from LiteLLM, using fallback", + "model", model, "error", err) + return nil, PricingSourceFallback + } + return pricing, PricingSourceLiteLLM +} + +// applyChannelOverrides 应用渠道定价覆盖 +func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupID int64, model string, resolved *ResolvedPricing) { + chPricing := r.channelService.GetChannelModelPricing(ctx, groupID, model) + if chPricing == nil { + return + } + + resolved.Source = PricingSourceChannel + resolved.Mode = chPricing.BillingMode + if resolved.Mode == "" { + resolved.Mode = BillingModeToken + } + + switch resolved.Mode { + case BillingModeToken: + r.applyTokenOverrides(chPricing, resolved) + case BillingModePerRequest, BillingModeImage: + r.applyRequestTierOverrides(chPricing, resolved) + } +} + +// applyTokenOverrides 应用 token 模式的渠道覆盖 +func (r *ModelPricingResolver) applyTokenOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) { + // 过滤掉所有价格字段都为空的无效 interval + validIntervals := filterValidIntervals(chPricing.Intervals) + + // 如果有有效的区间定价,使用区间 + if len(validIntervals) > 0 { + resolved.Intervals = validIntervals + return + } + + // 否则用 flat 字段覆盖 BasePricing + if resolved.BasePricing == nil { + resolved.BasePricing = &ModelPricing{} + } + + if chPricing.InputPrice != nil { + resolved.BasePricing.InputPricePerToken = *chPricing.InputPrice + resolved.BasePricing.InputPricePerTokenPriority = *chPricing.InputPrice + } + if chPricing.OutputPrice != nil { + resolved.BasePricing.OutputPricePerToken = *chPricing.OutputPrice + resolved.BasePricing.OutputPricePerTokenPriority = 
*chPricing.OutputPrice + } + if chPricing.CacheWritePrice != nil { + resolved.BasePricing.CacheCreationPricePerToken = *chPricing.CacheWritePrice + resolved.BasePricing.CacheCreation5mPrice = *chPricing.CacheWritePrice + resolved.BasePricing.CacheCreation1hPrice = *chPricing.CacheWritePrice + } + if chPricing.CacheReadPrice != nil { + resolved.BasePricing.CacheReadPricePerToken = *chPricing.CacheReadPrice + resolved.BasePricing.CacheReadPricePerTokenPriority = *chPricing.CacheReadPrice + } + if chPricing.ImageOutputPrice != nil { + resolved.BasePricing.ImageOutputPricePerToken = *chPricing.ImageOutputPrice + } +} + +// applyRequestTierOverrides 应用按次/图片模式的渠道覆盖 +func (r *ModelPricingResolver) applyRequestTierOverrides(chPricing *ChannelModelPricing, resolved *ResolvedPricing) { + resolved.RequestTiers = filterValidIntervals(chPricing.Intervals) + if chPricing.PerRequestPrice != nil { + resolved.DefaultPerRequestPrice = *chPricing.PerRequestPrice + } +} + +// filterValidIntervals 过滤掉所有价格字段都为空的无效 interval。 +// 前端可能创建了只有 min/max 但无价格的空 interval。 +func filterValidIntervals(intervals []PricingInterval) []PricingInterval { + var valid []PricingInterval + for _, iv := range intervals { + if iv.InputPrice != nil || iv.OutputPrice != nil || + iv.CacheWritePrice != nil || iv.CacheReadPrice != nil || + iv.PerRequestPrice != nil { + valid = append(valid, iv) + } + } + return valid +} + +// GetIntervalPricing 根据 context token 数获取区间定价。 +// 如果有区间列表,找到匹配区间并构造 ModelPricing;否则直接返回 BasePricing。 +func (r *ModelPricingResolver) GetIntervalPricing(resolved *ResolvedPricing, totalContextTokens int) *ModelPricing { + if len(resolved.Intervals) == 0 { + return resolved.BasePricing + } + + iv := FindMatchingInterval(resolved.Intervals, totalContextTokens) + if iv == nil { + return resolved.BasePricing + } + + return intervalToModelPricing(iv, resolved.SupportsCacheBreakdown) +} + +// intervalToModelPricing 将区间定价转换为 ModelPricing +func intervalToModelPricing(iv *PricingInterval, 
supportsCacheBreakdown bool) *ModelPricing { + pricing := &ModelPricing{ + SupportsCacheBreakdown: supportsCacheBreakdown, + } + if iv.InputPrice != nil { + pricing.InputPricePerToken = *iv.InputPrice + pricing.InputPricePerTokenPriority = *iv.InputPrice + } + if iv.OutputPrice != nil { + pricing.OutputPricePerToken = *iv.OutputPrice + pricing.OutputPricePerTokenPriority = *iv.OutputPrice + } + if iv.CacheWritePrice != nil { + pricing.CacheCreationPricePerToken = *iv.CacheWritePrice + pricing.CacheCreation5mPrice = *iv.CacheWritePrice + pricing.CacheCreation1hPrice = *iv.CacheWritePrice + } + if iv.CacheReadPrice != nil { + pricing.CacheReadPricePerToken = *iv.CacheReadPrice + pricing.CacheReadPricePerTokenPriority = *iv.CacheReadPrice + } + return pricing +} + +// GetRequestTierPrice 根据层级标签获取按次价格 +func (r *ModelPricingResolver) GetRequestTierPrice(resolved *ResolvedPricing, tierLabel string) float64 { + for _, tier := range resolved.RequestTiers { + if tier.TierLabel == tierLabel && tier.PerRequestPrice != nil { + return *tier.PerRequestPrice + } + } + return 0 +} + +// GetRequestTierPriceByContext 根据 context token 数获取按次价格 +func (r *ModelPricingResolver) GetRequestTierPriceByContext(resolved *ResolvedPricing, totalContextTokens int) float64 { + iv := FindMatchingInterval(resolved.RequestTiers, totalContextTokens) + if iv != nil && iv.PerRequestPrice != nil { + return *iv.PerRequestPrice + } + return 0 +} diff --git a/backend/internal/service/model_pricing_resolver_test.go b/backend/internal/service/model_pricing_resolver_test.go new file mode 100644 index 0000000000..905c4df685 --- /dev/null +++ b/backend/internal/service/model_pricing_resolver_test.go @@ -0,0 +1,663 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func newTestBillingServiceForResolver() *BillingService { + bs := &BillingService{ + fallbackPrices: make(map[string]*ModelPricing), + } + 
bs.fallbackPrices["claude-sonnet-4"] = &ModelPricing{ + InputPricePerToken: 3e-6, + OutputPricePerToken: 15e-6, + CacheCreationPricePerToken: 3.75e-6, + CacheReadPricePerToken: 0.3e-6, + SupportsCacheBreakdown: false, + } + return bs +} + +func TestResolve_NoGroupID(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: nil, + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModeToken, resolved.Mode) + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12) + // BillingService.GetModelPricing uses fallback internally, but resolveBasePricing + // reports "litellm" when GetModelPricing succeeds (regardless of internal source) + require.Equal(t, "litellm", resolved.Source) +} + +func TestResolve_UnknownModel(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "unknown-model-xyz", + GroupID: nil, + }) + + require.NotNil(t, resolved) + require.Nil(t, resolved.BasePricing) + // Unknown model: GetModelPricing returns error, source is "fallback" + require.Equal(t, "fallback", resolved.Source) +} + +func TestGetIntervalPricing_NoIntervals(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + basePricing := &ModelPricing{InputPricePerToken: 5e-6} + resolved := &ResolvedPricing{ + Mode: BillingModeToken, + BasePricing: basePricing, + Intervals: nil, + } + + result := r.GetIntervalPricing(resolved, 50000) + require.Equal(t, basePricing, result) +} + +func TestGetIntervalPricing_MatchesInterval(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: 
BillingModeToken, + BasePricing: &ModelPricing{InputPricePerToken: 5e-6}, + SupportsCacheBreakdown: true, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(2e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(3e-6), OutputPrice: testPtrFloat64(6e-6)}, + }, + } + + result := r.GetIntervalPricing(resolved, 50000) + require.NotNil(t, result) + require.InDelta(t, 1e-6, result.InputPricePerToken, 1e-12) + require.InDelta(t, 2e-6, result.OutputPricePerToken, 1e-12) + require.True(t, result.SupportsCacheBreakdown) + + result2 := r.GetIntervalPricing(resolved, 200000) + require.NotNil(t, result2) + require.InDelta(t, 3e-6, result2.InputPricePerToken, 1e-12) +} + +func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + basePricing := &ModelPricing{InputPricePerToken: 99e-6} + resolved := &ResolvedPricing{ + Mode: BillingModeToken, + BasePricing: basePricing, + Intervals: []PricingInterval{ + {MinTokens: 10000, MaxTokens: testPtrInt(50000), InputPrice: testPtrFloat64(1e-6)}, + }, + } + + result := r.GetIntervalPricing(resolved, 5000) + require.Equal(t, basePricing, result) +} + +func TestGetRequestTierPrice(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)}, + }, + } + + require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12) + require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12) + require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "4K"), 1e-12) +} + +func TestGetRequestTierPriceByContext(t *testing.T) { + bs := 
newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)}, + {MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)}, + }, + } + + require.InDelta(t, 0.05, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12) + require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12) +} + +func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: nil}, + }, + } + + require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12) +} + +// =========================================================================== +// Channel override tests — exercises applyChannelOverrides via Resolve +// =========================================================================== + +// helper: creates a resolver wired to a ChannelService that returns the given +// channel (active, groupID=100, platform=anthropic) with the specified pricing. 
+func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelPricingResolver { + t.Helper() + const groupID = 100 + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return []Channel{{ + ID: 1, + Name: "test-channel", + Status: StatusActive, + GroupIDs: []int64{groupID}, + ModelPricing: pricing, + }}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return map[int64]string{groupID: "anthropic"}, nil + }, + } + cs := NewChannelService(repo, nil) + bs := newTestBillingServiceForResolver() + return NewModelPricingResolver(cs, bs) +} + +// groupIDPtr returns a pointer to groupID 100 (the test constant). +func groupIDPtr() *int64 { v := int64(100); return &v } + +// --------------------------------------------------------------------------- +// 1. Token mode overrides +// --------------------------------------------------------------------------- + +func TestResolve_WithChannelOverride_TokenFlat(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(10e-6), + OutputPrice: testPtrFloat64(50e-6), + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModeToken, resolved.Mode) + require.Equal(t, "channel", resolved.Source) + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerToken, 1e-12) + require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerTokenPriority, 1e-12) +} + +func TestResolve_WithChannelOverride_TokenPartialOverride(t *testing.T) { + // Channel only sets InputPrice; 
OutputPrice should remain from the base (LiteLLM/fallback). + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(20e-6), + // OutputPrice intentionally nil + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, "channel", resolved.Source) + require.NotNil(t, resolved.BasePricing) + // InputPrice overridden by channel + require.InDelta(t, 20e-6, resolved.BasePricing.InputPricePerToken, 1e-12) + // OutputPrice kept from base (fallback: 15e-6) + require.InDelta(t, 15e-6, resolved.BasePricing.OutputPricePerToken, 1e-12) +} + +func TestResolve_WithChannelOverride_TokenWithIntervals(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(8e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(4e-6), OutputPrice: testPtrFloat64(16e-6)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, "channel", resolved.Source) + require.Len(t, resolved.Intervals, 2) + + // GetIntervalPricing should use channel intervals + iv := r.GetIntervalPricing(resolved, 50000) + require.NotNil(t, iv) + require.InDelta(t, 2e-6, iv.InputPricePerToken, 1e-12) + require.InDelta(t, 8e-6, iv.OutputPricePerToken, 1e-12) + + iv2 := r.GetIntervalPricing(resolved, 200000) + require.NotNil(t, iv2) + require.InDelta(t, 4e-6, iv2.InputPricePerToken, 1e-12) + require.InDelta(t, 16e-6, iv2.OutputPricePerToken, 1e-12) +} + +func 
TestResolve_WithChannelOverride_TokenNilBasePricing(t *testing.T) { + // Base pricing is nil (unknown model), channel has flat prices → creates new BasePricing. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"unknown-model-xyz"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(7e-6), + OutputPrice: testPtrFloat64(21e-6), + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "unknown-model-xyz", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, "channel", resolved.Source) + // BasePricing was nil from resolveBasePricing but applyTokenOverrides creates a new one + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 7e-6, resolved.BasePricing.InputPricePerToken, 1e-12) + require.InDelta(t, 21e-6, resolved.BasePricing.OutputPricePerToken, 1e-12) +} + +// --------------------------------------------------------------------------- +// 2. Per-request mode overrides +// --------------------------------------------------------------------------- + +func TestResolve_WithChannelOverride_PerRequest(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0.05), + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.03)}, + {MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModePerRequest, resolved.Mode) + require.Equal(t, "channel", resolved.Source) + require.InDelta(t, 0.05, resolved.DefaultPerRequestPrice, 1e-12) + require.Len(t, resolved.RequestTiers, 2) + + // Verify tier lookups + require.InDelta(t, 0.03, 
r.GetRequestTierPriceByContext(resolved, 50000), 1e-12) + require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12) +} + +func TestResolve_WithChannelOverride_PerRequestNilPrice(t *testing.T) { + // PerRequestPrice nil → DefaultPerRequestPrice stays 0. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModePerRequest, + // PerRequestPrice intentionally nil + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.02)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModePerRequest, resolved.Mode) + require.InDelta(t, 0.0, resolved.DefaultPerRequestPrice, 1e-12) + require.Len(t, resolved.RequestTiers, 1) +} + +// --------------------------------------------------------------------------- +// 3. 
Image mode overrides +// --------------------------------------------------------------------------- + +func TestResolve_WithChannelOverride_Image(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeImage, + PerRequestPrice: testPtrFloat64(0.08), + Intervals: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)}, + {TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModeImage, resolved.Mode) + require.Equal(t, "channel", resolved.Source) + require.InDelta(t, 0.08, resolved.DefaultPerRequestPrice, 1e-12) + require.Len(t, resolved.RequestTiers, 3) +} + +func TestResolve_WithChannelOverride_ImageTierLabels(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeImage, + Intervals: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)}, + {TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12) + require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12) + require.InDelta(t, 0.16, r.GetRequestTierPrice(resolved, "4K"), 1e-12) + require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "8K"), 1e-12) // not found +} + +// --------------------------------------------------------------------------- +// 4. 
Source tracking & default mode +// --------------------------------------------------------------------------- + +func TestResolve_WithChannelOverride_SourceIsChannel(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(1e-6), + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.Equal(t, "channel", resolved.Source) +} + +func TestResolve_WithChannelOverride_DefaultMode(t *testing.T) { + // Channel pricing with empty BillingMode → defaults to BillingModeToken. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: "", // intentionally empty + InputPrice: testPtrFloat64(5e-6), + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.Equal(t, "channel", resolved.Source) + require.Equal(t, BillingModeToken, resolved.Mode) + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 5e-6, resolved.BasePricing.InputPricePerToken, 1e-12) +} + +// --------------------------------------------------------------------------- +// 5. GetIntervalPricing integration after channel override +// --------------------------------------------------------------------------- + +func TestGetIntervalPricing_WithChannelIntervals(t *testing.T) { + // Channel provides intervals that override the base pricing path. 
+ r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(5e-6)}, + {MinTokens: 100000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(10e-6)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + // Token count 50000 matches first interval + pricing := r.GetIntervalPricing(resolved, 50000) + require.NotNil(t, pricing) + require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12) + + // Token count 150000 matches second interval + pricing2 := r.GetIntervalPricing(resolved, 150000) + require.NotNil(t, pricing2) + require.InDelta(t, 2e-6, pricing2.InputPricePerToken, 1e-12) + require.InDelta(t, 10e-6, pricing2.OutputPricePerToken, 1e-12) +} + +func TestGetIntervalPricing_ChannelIntervalsNoMatch(t *testing.T) { + // Channel intervals don't match token count → falls back to BasePricing. 
+ r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + Intervals: []PricingInterval{ + // Only covers tokens > 50000 + {MinTokens: 50000, MaxTokens: testPtrInt(200000), InputPrice: testPtrFloat64(9e-6)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + // Token count 1000 doesn't match any interval (1000 <= 50000 minTokens) + pricing := r.GetIntervalPricing(resolved, 1000) + // Should fall back to BasePricing (from the billing service fallback) + require.NotNil(t, pricing) + require.Equal(t, resolved.BasePricing, pricing) + require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) // original base price +} + +// =========================================================================== +// 6. Error path tests +// =========================================================================== + +func TestResolve_WithChannelOverride_CacheError(t *testing.T) { + // When ListAll returns an error, the ChannelService cache build fails. + // Resolve should gracefully fall back to base pricing without panicking. 
+ repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, errors.New("database unavailable") + }, + } + cs := NewChannelService(repo, nil) + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(cs, bs) + + gid := int64(100) + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: &gid, + }) + + require.NotNil(t, resolved) + // Should NOT panic, should NOT have source "channel" + require.NotEqual(t, "channel", resolved.Source) + // Base pricing should still be present (from BillingService fallback) + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12) +} + +// =========================================================================== +// 7. GetRequestTierPriceByContext boundary tests +// =========================================================================== + +func TestGetRequestTierPriceByContext_EmptyTiers(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: nil, // empty + } + + price := r.GetRequestTierPriceByContext(resolved, 50000) + require.InDelta(t, 0.0, price, 1e-12) + + // Also test with explicit empty slice + resolved2 := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{}, + } + + price2 := r.GetRequestTierPriceByContext(resolved2, 50000) + require.InDelta(t, 0.0, price2, 1e-12) +} + +func TestGetRequestTierPriceByContext_ExactBoundary(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)}, + {MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)}, 
+ }, + } + + // totalContextTokens = 128000 exactly: + // FindMatchingInterval checks: totalTokens > MinTokens && totalTokens <= MaxTokens + // For first interval: 128000 > 0 (true) && 128000 <= 128000 (true) → matches first interval + price := r.GetRequestTierPriceByContext(resolved, 128000) + require.InDelta(t, 0.05, price, 1e-12) + + // totalContextTokens = 128001 should match second interval + // For first interval: 128001 > 0 (true) && 128001 <= 128000 (false) → no match + // For second interval: 128001 > 128000 (true) && MaxTokens == nil → matches + price2 := r.GetRequestTierPriceByContext(resolved, 128001) + require.InDelta(t, 0.10, price2, 1e-12) +} + +// =========================================================================== +// 8. filterValidIntervals +// =========================================================================== + +func TestFilterValidIntervals(t *testing.T) { + tests := []struct { + name string + intervals []PricingInterval + wantLen int + }{ + { + name: "empty list", + intervals: nil, + wantLen: 0, + }, + { + name: "all-nil interval filtered out", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000)}, + }, + wantLen: 0, + }, + { + name: "interval with only InputPrice kept", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + }, + wantLen: 1, + }, + { + name: "interval with only OutputPrice kept", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), OutputPrice: testPtrFloat64(2e-6)}, + }, + wantLen: 1, + }, + { + name: "interval with only CacheWritePrice kept", + intervals: []PricingInterval{ + {MinTokens: 0, CacheWritePrice: testPtrFloat64(3e-6)}, + }, + wantLen: 1, + }, + { + name: "interval with only CacheReadPrice kept", + intervals: []PricingInterval{ + {MinTokens: 0, CacheReadPrice: testPtrFloat64(0.5e-6)}, + }, + wantLen: 1, + }, + { + name: "interval with only PerRequestPrice kept", + intervals: 
[]PricingInterval{ + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + }, + wantLen: 1, + }, + { + name: "mixed valid and invalid", + intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: nil}, // all-nil → filtered out + {MinTokens: 256000, OutputPrice: testPtrFloat64(5e-6)}, + }, + wantLen: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := filterValidIntervals(tt.intervals) + require.Len(t, result, tt.wantLen) + }) + } +} diff --git a/backend/internal/service/openai_channel_restriction_test.go b/backend/internal/service/openai_channel_restriction_test.go new file mode 100644 index 0000000000..c9dbceabe0 --- /dev/null +++ b/backend/internal/service/openai_channel_restriction_test.go @@ -0,0 +1,140 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOpenAISelectAccountForModelWithExclusions_ChannelMappedRestrictionRejectsEarly(t *testing.T) { + t.Parallel() + + channelSvc := newTestChannelService(makeStandardRepo(Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceChannelMapped, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformOpenAI, Models: []string{"gpt-4o"}}, + }, + ModelMapping: map[string]map[string]string{ + PlatformOpenAI: {"gpt-4.1": "o3-mini"}, + }, + }, map[int64]string{10: PlatformOpenAI})) + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true}, + }}, + channelService: channelSvc, + } + + groupID := int64(10) + _, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil) + require.ErrorIs(t, err, ErrNoAvailableAccounts) + require.Contains(t, err.Error(), "channel pricing restriction") +} + 
+func TestOpenAISelectAccountForModelWithExclusions_UpstreamRestrictionSkipsDisallowedAccount(t *testing.T) { + t.Parallel() + + channelSvc := newTestChannelService(makeStandardRepo(Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceUpstream, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformOpenAI, Models: []string{"o3-mini"}}, + }, + }, map[int64]string{10: PlatformOpenAI})) + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{ + { + ID: 1, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Priority: 10, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gpt-4.1": "gpt-4o"}, + }, + }, + { + ID: 2, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Priority: 20, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gpt-4.1": "o3-mini"}, + }, + }, + }}, + channelService: channelSvc, + } + + groupID := int64(10) + account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(2), account.ID) +} + +func TestOpenAISelectAccountForModelWithExclusions_StickyRestrictedUpstreamFallsBack(t *testing.T) { + t.Parallel() + + channelSvc := newTestChannelService(makeStandardRepo(Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + BillingModelSource: BillingModelSourceUpstream, + ModelPricing: []ChannelModelPricing{ + {Platform: PlatformOpenAI, Models: []string{"o3-mini"}}, + }, + }, map[int64]string{10: PlatformOpenAI})) + + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{"openai:sticky-session": 1}, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{ + { + ID: 1, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Priority: 10, + 
Credentials: map[string]any{ + "model_mapping": map[string]any{"gpt-4.1": "gpt-4o"}, + }, + }, + { + ID: 2, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + Priority: 20, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gpt-4.1": "o3-mini"}, + }, + }, + }}, + channelService: channelSvc, + cache: cache, + } + + groupID := int64(10) + account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "sticky-session", "gpt-4.1", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, int64(2), account.ID) + require.Equal(t, 1, cache.deletedSessions["openai:sticky-session"]) + require.Equal(t, int64(2), cache.sessionBindings["openai:sticky-session"]) +} diff --git a/backend/internal/service/openai_compat_prompt_cache_key.go b/backend/internal/service/openai_compat_prompt_cache_key.go index 46381838a3..88e16a4db0 100644 --- a/backend/internal/service/openai_compat_prompt_cache_key.go +++ b/backend/internal/service/openai_compat_prompt_cache_key.go @@ -10,8 +10,8 @@ import ( const compatPromptCacheKeyPrefix = "compat_cc_" func shouldAutoInjectPromptCacheKeyForCompat(model string) bool { - switch resolveOpenAIUpstreamModel(strings.TrimSpace(model)) { - case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark": + switch normalizeCodexModel(strings.TrimSpace(model)) { + case "gpt-5.4", "gpt-5.3-codex": return true default: return false @@ -23,9 +23,9 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod return "" } - normalizedModel := resolveOpenAIUpstreamModel(strings.TrimSpace(mappedModel)) + normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel)) if normalizedModel == "" { - normalizedModel = resolveOpenAIUpstreamModel(strings.TrimSpace(req.Model)) + normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model)) } if normalizedModel == "" { normalizedModel = strings.TrimSpace(req.Model) diff --git 
a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 1d5bf0d0a4..be076cc082 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -46,7 +46,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( // 2. Resolve model mapping early so compat prompt_cache_key injection can // derive a stable seed from the final upstream model family. billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) - upstreamModel := resolveOpenAIUpstreamModel(billingModel) + upstreamModel := normalizeCodexModel(billingModel) promptCacheKey = strings.TrimSpace(promptCacheKey) compatPromptCacheInjected := false diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 8c389556f2..dd416269f4 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -62,7 +62,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( // 3. 
Model mapping billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel) - upstreamModel := resolveOpenAIUpstreamModel(billingModel) + upstreamModel := normalizeCodexModel(billingModel) responsesReq.Model = upstreamModel logger.L().Debug("openai messages: model mapping applied", diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 7a636afad1..e2b164c09d 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -145,6 +145,8 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U nil, &DeferredService{}, nil, + nil, + nil, ) svc.userGroupRateResolver = newUserGroupRateResolver( rateRepo, diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index e85f0705aa..28c4b1f4f5 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "log/slog" "math/rand" "net/http" "sort" @@ -204,6 +205,7 @@ type OpenAIUsage struct { OutputTokens int `json:"output_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` + ImageOutputTokens int `json:"image_output_tokens,omitempty"` } // OpenAIForwardResult represents the result of forwarding @@ -322,6 +324,8 @@ type OpenAIGatewayService struct { openAITokenProvider *OpenAITokenProvider toolCorrector *CodexToolCorrector openaiWSResolver OpenAIWSProtocolResolver + resolver *ModelPricingResolver + channelService *ChannelService openaiWSPoolOnce sync.Once openaiWSStateStoreOnce sync.Once @@ -357,6 +361,8 @@ func NewOpenAIGatewayService( httpUpstream HTTPUpstream, deferredService *DeferredService, openAITokenProvider *OpenAITokenProvider, + resolver 
*ModelPricingResolver, + channelService *ChannelService, ) *OpenAIGatewayService { svc := &OpenAIGatewayService{ accountRepo: accountRepo, @@ -384,6 +390,8 @@ func NewOpenAIGatewayService( openAITokenProvider: openAITokenProvider, toolCorrector: NewCodexToolCorrector(), openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + resolver: resolver, + channelService: channelService, responseHeaderFilter: compileResponseHeaderFilter(cfg), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } @@ -391,6 +399,74 @@ func NewOpenAIGatewayService( return svc } +// ResolveChannelMapping 解析渠道级模型映射(代理到 ChannelService) +func (s *OpenAIGatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model} + } + return s.channelService.ResolveChannelMapping(ctx, groupID, model) +} + +// IsModelRestricted 检查模型是否被渠道限制(代理到 ChannelService) +func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { + if s.channelService == nil { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, model) +} + +// ResolveChannelMappingAndRestrict 解析渠道映射。 +// 模型限制检查已移至调度阶段,restricted 始终返回 false。 +func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model}, false + } + return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model) +} + +func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool { + if groupID == nil || s.channelService == nil || requestedModel == "" { + return false + } + mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel) + billingModel := billingModelForRestriction(mapping.BillingModelSource, 
requestedModel, mapping.MappedModel) + if billingModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, *groupID, billingModel) +} + +func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool { + if s.channelService == nil { + return false + } + upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "") + if upstreamModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel) +} + +func (s *OpenAIGatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool { + if groupID == nil || s.channelService == nil { + return false + } + ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) + if err != nil { + slog.Warn("failed to check openai channel upstream restriction", "group_id", *groupID, "error", err) + return false + } + if ch == nil || !ch.RestrictModels { + return false + } + return ch.BillingModelSource == BillingModelSourceUpstream +} + +// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。 +func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { + return ReplaceModelInBody(body, newModel) +} + func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { if s != nil && s.codexSnapshotThrottle != nil { return s.codexSnapshotThrottle @@ -1125,6 +1201,13 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C } func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) { + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, fmt.Errorf("%w supporting 
model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // 1. 尝试粘性会话命中 // Try sticky session hit if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil { @@ -1140,7 +1223,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C // 3. 按优先级 + LRU 选择最佳账号 // Select by priority + LRU - selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs) + selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs) if selected == nil { if requestedModel != "" { @@ -1206,6 +1289,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil } + if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) && + s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + return nil + } // 刷新会话 TTL 并返回账号 // Refresh session TTL and return account @@ -1218,8 +1306,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // // selectBestAccount selects the best account from candidates (priority + LRU). // Returns nil if no available account. 
-func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { +func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { var selected *Account + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) for i := range accounts { acc := &accounts[i] @@ -1238,6 +1327,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [ if fresh == nil { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + continue + } // 选择优先级最高且最久未使用的账号 // Select highest priority and least recently used @@ -1289,7 +1381,15 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", requestedModel) + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + cfg := s.schedulingConfig() + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var stickyAccountID int64 if sessionHash != "" && s.cache != nil { if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil { @@ -1365,6 +1465,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) if account == nil { _ = 
s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + } else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } else { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { @@ -1410,6 +1512,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) { + continue + } candidates = append(candidates, acc) } @@ -1434,6 +1539,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if fresh == nil { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + continue + } result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { @@ -1488,6 +1596,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if fresh == nil { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + continue + } result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { @@ -1510,6 +1621,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if fresh == nil { continue } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { + continue + } return &AccountSelectionResult{ Account: fresh, WaitPlan: &AccountWaitPlan{ @@ -1825,7 +1939,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 if model, ok := reqBody["model"].(string); ok { - upstreamModel = 
resolveOpenAIUpstreamModel(model) + upstreamModel = normalizeCodexModel(model) if upstreamModel != "" && upstreamModel != model { logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", model, upstreamModel, account.Name, account.Type, isCodexCLI) @@ -4110,6 +4224,7 @@ type OpenAIRecordUsageInput struct { IPAddress string // 请求的客户端 IP 地址 RequestPayloadHash string APIKeyService APIKeyQuotaUpdater + ChannelUsageFields } // RecordUsage records usage and deducts balance @@ -4140,10 +4255,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, + ImageOutputTokens: result.Usage.ImageOutputTokens, } // Get rate multiplier - multiplier := s.cfg.Default.RateMultiplier + multiplier := 1.0 + if s.cfg != nil { + multiplier = s.cfg.Default.RateMultiplier + } if apiKey.GroupID != nil && apiKey.Group != nil { resolver := s.userGroupRateResolver if resolver == nil { @@ -4152,12 +4271,37 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) } + var cost *CostBreakdown + var err error billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) + if result.BillingModel != "" { + billingModel = strings.TrimSpace(result.BillingModel) + } + if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" { + billingModel = input.ChannelMappedModel + } + if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { + billingModel = input.OriginalModel + } serviceTier := "" if result.ServiceTier != nil { serviceTier = strings.TrimSpace(*result.ServiceTier) } - cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, 
multiplier, serviceTier) + if s.resolver != nil && apiKey.Group != nil { + gid := apiKey.Group.ID + cost, err = s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + ServiceTier: serviceTier, + Resolver: s.resolver, + }) + } else { + cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) + } if err != nil { cost = &CostBreakdown{ActualCost: 0} } @@ -4173,36 +4317,58 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec durationMs := int(result.Duration.Milliseconds()) accountRateMultiplier := account.BillingRateMultiplier() requestID := resolveUsageBillingRequestID(ctx, result.RequestID) + + // 确定 RequestedModel(渠道映射前的原始模型) + requestedModel := result.Model + if input.OriginalModel != "" { + requestedModel = input.OriginalModel + } + usageLog := &UsageLog{ - UserID: user.ID, - APIKeyID: apiKey.ID, - AccountID: account.ID, - RequestID: requestID, - Model: result.Model, - RequestedModel: result.Model, - UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), - ServiceTier: result.ServiceTier, - ReasoningEffort: result.ReasoningEffort, - InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), - UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), - InputTokens: actualInputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - InputCost: cost.InputCost, - OutputCost: cost.OutputCost, - CacheCreationCost: cost.CacheCreationCost, - CacheReadCost: cost.CacheReadCost, - TotalCost: cost.TotalCost, - ActualCost: cost.ActualCost, - RateMultiplier: multiplier, - AccountRateMultiplier: &accountRateMultiplier, - BillingType: billingType, - Stream: result.Stream, - OpenAIWSMode: result.OpenAIWSMode, - DurationMs: &durationMs, - 
FirstTokenMs: result.FirstTokenMs, - CreatedAt: time.Now(), + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: result.Model, + RequestedModel: requestedModel, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), + ServiceTier: result.ServiceTier, + ReasoningEffort: result.ReasoningEffort, + InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), + UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), + InputTokens: actualInputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + ImageOutputTokens: result.Usage.ImageOutputTokens, + } + if cost != nil { + usageLog.InputCost = cost.InputCost + usageLog.OutputCost = cost.OutputCost + usageLog.ImageOutputCost = cost.ImageOutputCost + usageLog.CacheCreationCost = cost.CacheCreationCost + usageLog.CacheReadCost = cost.CacheReadCost + usageLog.TotalCost = cost.TotalCost + usageLog.ActualCost = cost.ActualCost + } + usageLog.RateMultiplier = multiplier + usageLog.AccountRateMultiplier = &accountRateMultiplier + usageLog.BillingType = billingType + usageLog.Stream = result.Stream + usageLog.OpenAIWSMode = result.OpenAIWSMode + usageLog.DurationMs = &durationMs + usageLog.FirstTokenMs = result.FirstTokenMs + usageLog.CreatedAt = time.Now() + // 设置渠道信息 + usageLog.ChannelID = optionalInt64Ptr(input.ChannelID) + usageLog.ModelMappingChain = optionalTrimmedStringPtr(input.ModelMappingChain) + // 设置计费模式 + if cost != nil && cost.BillingMode != "" { + billingMode := cost.BillingMode + usageLog.BillingMode = &billingMode + } else { + billingMode := string(BillingModeToken) + usageLog.BillingMode = &billingMode } // 添加 UserAgent if input.UserAgent != "" { diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go index 4f8c094bcd..9bf3fba3b9 100644 --- 
a/backend/internal/service/openai_model_mapping.go +++ b/backend/internal/service/openai_model_mapping.go @@ -1,10 +1,8 @@ package service -import "strings" - -// resolveOpenAIForwardModel resolves the account/group mapping result for -// OpenAI-compatible forwarding. Group-level default mapping only applies when -// the account itself did not match any explicit model_mapping rule. +// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible +// forwarding. Group-level default mapping only applies when the account itself +// did not match any explicit model_mapping rule. func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { if account == nil { if defaultMappedModel != "" { @@ -19,23 +17,3 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo } return mappedModel } - -func resolveOpenAIUpstreamModel(model string) string { - if isBareGPT53CodexSparkModel(model) { - return "gpt-5.3-codex-spark" - } - return normalizeCodexModel(strings.TrimSpace(model)) -} - -func isBareGPT53CodexSparkModel(model string) bool { - modelID := strings.TrimSpace(model) - if modelID == "" { - return false - } - if strings.Contains(modelID, "/") { - parts := strings.Split(modelID, "/") - modelID = parts[len(parts)-1] - } - normalized := strings.ToLower(strings.TrimSpace(modelID)) - return normalized == "gpt-5.3-codex-spark" || normalized == "gpt 5.3 codex spark" -} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index 42f58b3741..5ce2602c1c 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -74,30 +74,28 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t * Credentials: map[string]any{}, } - withoutDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) + withoutDefault := 
normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) if withoutDefault != "gpt-5.1" { - t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withoutDefault, "gpt-5.1") + t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1") } - withDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) + withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) if withDefault != "gpt-5.4" { - t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withDefault, "gpt-5.4") + t.Fatalf("normalizeCodexModel(...) = %q, want %q", withDefault, "gpt-5.4") } } -func TestResolveOpenAIUpstreamModel(t *testing.T) { +func TestNormalizeCodexModel(t *testing.T) { cases := map[string]string{ - "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", - "gpt 5.3 codex spark": "gpt-5.3-codex-spark", - " openai/gpt-5.3-codex-spark ": "gpt-5.3-codex-spark", - "gpt-5.3-codex-spark-high": "gpt-5.3-codex", - "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", - "gpt-5.3": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3": "gpt-5.3-codex", } for input, expected := range cases { - if got := resolveOpenAIUpstreamModel(input); got != expected { - t.Fatalf("resolveOpenAIUpstreamModel(%q) = %q, want %q", input, got, expected) + if got := normalizeCodexModel(input); got != expected { + t.Fatalf("normalizeCodexModel(%q) = %q, want %q", input, got, expected) } } } diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 1ebe554236..6d45baab36 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -2515,7 +2515,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } normalized = next } - upstreamModel := 
resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) + upstreamModel := normalizeCodexModel(account.GetMappedModel(originalModel)) if upstreamModel != originalModel { next, setErr := applyPayloadMutation(normalized, "model", upstreamModel) if setErr != nil { @@ -2773,7 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( mappedModel := "" var mappedModelBytes []byte if originalModel != "" { - mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) + mappedModel = normalizeCodexModel(account.GetMappedModel(originalModel)) needModelReplace = mappedModel != "" && mappedModel != originalModel if needModelReplace { mappedModelBytes = []byte(mappedModel) diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index 8c5c936844..3834dcb785 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -615,6 +615,8 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, + nil, ) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index fdabbafde9..c0e814ab7b 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -519,7 +519,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry if s.gatewayService == nil { return nil, fmt.Errorf("gateway service not available") } - return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制 + return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "", int64(0)) // 重试不使用会话限制 default: return nil, fmt.Errorf("unsupported retry type: %s", reqType) } diff --git a/backend/internal/service/pricing_service.go 
b/backend/internal/service/pricing_service.go index 5623d4b742..3b3f31c309 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -70,7 +70,8 @@ type LiteLLMModelPricing struct { LiteLLMProvider string `json:"litellm_provider"` Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` - OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 + OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 + OutputCostPerImageToken float64 `json:"output_cost_per_image_token"` // 图片输出 token 价格 } // PricingRemoteClient 远程价格数据获取接口 @@ -94,6 +95,7 @@ type LiteLLMRawEntry struct { Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` OutputCostPerImage *float64 `json:"output_cost_per_image"` + OutputCostPerImageToken *float64 `json:"output_cost_per_image_token"` } // PricingService 动态价格服务 @@ -408,6 +410,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel if entry.OutputCostPerImage != nil { pricing.OutputCostPerImage = *entry.OutputCostPerImage } + if entry.OutputCostPerImageToken != nil { + pricing.OutputCostPerImageToken = *entry.OutputCostPerImageToken + } result[modelName] = pricing } diff --git a/backend/internal/service/testhelpers_test.go b/backend/internal/service/testhelpers_test.go new file mode 100644 index 0000000000..73750e2787 --- /dev/null +++ b/backend/internal/service/testhelpers_test.go @@ -0,0 +1,15 @@ +//go:build unit + +package service + +// testPtrFloat64 returns a pointer to the given float64 value. +func testPtrFloat64(v float64) *float64 { return &v } + +// testPtrInt returns a pointer to the given int value. +func testPtrInt(v int) *int { return &v } + +// testPtrString returns a pointer to the given string value. +func testPtrString(v string) *string { return &v } + +// testPtrBool returns a pointer to the given bool value. 
+func testPtrBool(v bool) *bool { return &v } diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index 576841facd..0f1ccc09ad 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -104,6 +104,14 @@ type UsageLog struct { // UpstreamModel is the actual model sent to the upstream provider after mapping. // Nil means no mapping was applied (requested model was used as-is). UpstreamModel *string + // ChannelID 渠道 ID + ChannelID *int64 + // ModelMappingChain 模型映射链,如 "a→b→c" + ModelMappingChain *string + // BillingTier 计费层级标签(per_request/image 模式) + BillingTier *string + // BillingMode 计费模式:token/image(sora 路径为 nil) + BillingMode *string // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". ServiceTier *string // ReasoningEffort is the request's reasoning effort level. @@ -126,6 +134,9 @@ type UsageLog struct { CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"` CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"` + ImageOutputTokens int + ImageOutputCost float64 + InputCost float64 OutputCost float64 CacheCreationCost float64 diff --git a/backend/internal/service/usage_log_helpers.go b/backend/internal/service/usage_log_helpers.go index a7bcae997f..7cc8a713cd 100644 --- a/backend/internal/service/usage_log_helpers.go +++ b/backend/internal/service/usage_log_helpers.go @@ -26,3 +26,10 @@ func forwardResultBillingModel(requestedModel, upstreamModel string) string { } return strings.TrimSpace(upstreamModel) } + +func optionalInt64Ptr(v int64) *int64 { + if v == 0 { + return nil + } + return &v +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index d79a353124..717e22ffc5 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -490,4 +490,6 @@ var ProviderSet = wire.NewSet( ProvideScheduledTestService, ProvideScheduledTestRunnerService, NewGroupCapacityService, 
+ NewChannelService, + NewModelPricingResolver, ) diff --git a/backend/migrations/081_create_channels.sql b/backend/migrations/081_create_channels.sql new file mode 100644 index 0000000000..3059816b2e --- /dev/null +++ b/backend/migrations/081_create_channels.sql @@ -0,0 +1,56 @@ +-- Create channels table for managing pricing channels. +-- A channel groups multiple groups together and provides custom model pricing. + +SET LOCAL lock_timeout = '5s'; +SET LOCAL statement_timeout = '10min'; + +-- 渠道表 +CREATE TABLE IF NOT EXISTS channels ( + id BIGSERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + description TEXT DEFAULT '', + status VARCHAR(20) NOT NULL DEFAULT 'active', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- 渠道名称唯一索引 +CREATE UNIQUE INDEX IF NOT EXISTS idx_channels_name ON channels (name); +CREATE INDEX IF NOT EXISTS idx_channels_status ON channels (status); + +-- 渠道-分组关联表(每个分组只能属于一个渠道) +CREATE TABLE IF NOT EXISTS channel_groups ( + id BIGSERIAL PRIMARY KEY, + channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE, + group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_groups_group_id ON channel_groups (group_id); +CREATE INDEX IF NOT EXISTS idx_channel_groups_channel_id ON channel_groups (channel_id); + +-- 渠道模型定价表(一条定价可绑定多个模型) +CREATE TABLE IF NOT EXISTS channel_model_pricing ( + id BIGSERIAL PRIMARY KEY, + channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE, + models JSONB NOT NULL DEFAULT '[]', + input_price NUMERIC(20,12), + output_price NUMERIC(20,12), + cache_write_price NUMERIC(20,12), + cache_read_price NUMERIC(20,12), + image_output_price NUMERIC(20,8), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_channel_id ON channel_model_pricing 
(channel_id); + +COMMENT ON TABLE channels IS '渠道管理:关联多个分组,提供自定义模型定价'; +COMMENT ON TABLE channel_groups IS '渠道-分组关联表:每个分组最多属于一个渠道'; +COMMENT ON TABLE channel_model_pricing IS '渠道模型定价:一条定价可绑定多个模型,价格一致'; +COMMENT ON COLUMN channel_model_pricing.models IS '绑定的模型列表,JSON 数组,如 ["claude-opus-4-6","claude-opus-4-6-thinking"]'; +COMMENT ON COLUMN channel_model_pricing.input_price IS '每 token 输入价格(USD),NULL 表示使用默认'; +COMMENT ON COLUMN channel_model_pricing.output_price IS '每 token 输出价格(USD),NULL 表示使用默认'; +COMMENT ON COLUMN channel_model_pricing.cache_write_price IS '缓存写入每 token 价格,NULL 表示使用默认'; +COMMENT ON COLUMN channel_model_pricing.cache_read_price IS '缓存读取每 token 价格,NULL 表示使用默认'; +COMMENT ON COLUMN channel_model_pricing.image_output_price IS '图片输出价格(Gemini Image 等),NULL 表示使用默认'; diff --git a/backend/migrations/082_refactor_channel_pricing.sql b/backend/migrations/082_refactor_channel_pricing.sql new file mode 100644 index 0000000000..d0a5406246 --- /dev/null +++ b/backend/migrations/082_refactor_channel_pricing.sql @@ -0,0 +1,67 @@ +-- Extend channel_model_pricing with billing_mode and add context-interval child table. +-- Supports three billing modes: token (per-token with context intervals), +-- per_request (per-request with context-size tiers), and image (per-image). + +SET LOCAL lock_timeout = '5s'; +SET LOCAL statement_timeout = '10min'; + +-- 1. 为 channel_model_pricing 添加 billing_mode 列 +ALTER TABLE channel_model_pricing + ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20) NOT NULL DEFAULT 'token'; + +COMMENT ON COLUMN channel_model_pricing.billing_mode IS '计费模式:token(按 token 区间计费)、per_request(按次计费)、image(图片计费)'; + +-- 2. 
创建区间定价子表 +CREATE TABLE IF NOT EXISTS channel_pricing_intervals ( + id BIGSERIAL PRIMARY KEY, + pricing_id BIGINT NOT NULL REFERENCES channel_model_pricing(id) ON DELETE CASCADE, + min_tokens INT NOT NULL DEFAULT 0, + max_tokens INT, + tier_label VARCHAR(50), + input_price NUMERIC(20,12), + output_price NUMERIC(20,12), + cache_write_price NUMERIC(20,12), + cache_read_price NUMERIC(20,12), + per_request_price NUMERIC(20,12), + sort_order INT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_channel_pricing_intervals_pricing_id + ON channel_pricing_intervals (pricing_id); + +COMMENT ON TABLE channel_pricing_intervals IS '渠道定价区间:支持按 token 区间、按次分层、图片分辨率分层'; +COMMENT ON COLUMN channel_pricing_intervals.min_tokens IS '区间下界(含),token 模式使用'; +COMMENT ON COLUMN channel_pricing_intervals.max_tokens IS '区间上界(不含),NULL 表示无上限'; +COMMENT ON COLUMN channel_pricing_intervals.tier_label IS '层级标签,按次/图片模式使用(如 1K、2K、4K、HD)'; +COMMENT ON COLUMN channel_pricing_intervals.input_price IS 'token 模式:每 token 输入价'; +COMMENT ON COLUMN channel_pricing_intervals.output_price IS 'token 模式:每 token 输出价'; +COMMENT ON COLUMN channel_pricing_intervals.cache_write_price IS 'token 模式:缓存写入价'; +COMMENT ON COLUMN channel_pricing_intervals.cache_read_price IS 'token 模式:缓存读取价'; +COMMENT ON COLUMN channel_pricing_intervals.per_request_price IS '按次/图片模式:每次请求价格'; + +-- 3. 
迁移现有 flat 定价为单区间 [0, +inf) +-- 仅迁移有明确定价(至少一个价格字段非 NULL)的条目 +INSERT INTO channel_pricing_intervals (pricing_id, min_tokens, max_tokens, input_price, output_price, cache_write_price, cache_read_price, sort_order) +SELECT + cmp.id, + 0, + NULL, + cmp.input_price, + cmp.output_price, + cmp.cache_write_price, + cmp.cache_read_price, + 0 +FROM channel_model_pricing cmp +WHERE cmp.billing_mode = 'token' + AND (cmp.input_price IS NOT NULL OR cmp.output_price IS NOT NULL + OR cmp.cache_write_price IS NOT NULL OR cmp.cache_read_price IS NOT NULL) + AND NOT EXISTS ( + SELECT 1 FROM channel_pricing_intervals cpi WHERE cpi.pricing_id = cmp.id + ); + +-- 4. 迁移 image_output_price 为 image 模式的区间条目 +-- 将有 image_output_price 的现有条目复制为 billing_mode='image' 的独立条目 +-- 注意:这里不改变原条目的 billing_mode,而是将 image_output_price 作为向后兼容字段保留 +-- 实际的 image 计费在未来由独立的 billing_mode='image' 条目处理 diff --git a/backend/migrations/083_channel_model_mapping.sql b/backend/migrations/083_channel_model_mapping.sql new file mode 100644 index 0000000000..68e2203f5d --- /dev/null +++ b/backend/migrations/083_channel_model_mapping.sql @@ -0,0 +1,5 @@ +SET LOCAL lock_timeout = '5s'; +SET LOCAL statement_timeout = '10min'; + +ALTER TABLE channels ADD COLUMN IF NOT EXISTS model_mapping JSONB DEFAULT '{}'; +COMMENT ON COLUMN channels.model_mapping IS '渠道级模型映射,在账号映射之前执行。格式:{"source_model": "target_model"}'; diff --git a/backend/migrations/084_channel_billing_model_source.sql b/backend/migrations/084_channel_billing_model_source.sql new file mode 100644 index 0000000000..bd615bacbd --- /dev/null +++ b/backend/migrations/084_channel_billing_model_source.sql @@ -0,0 +1,7 @@ +-- Add billing_model_source to channels (controls whether billing uses requested or upstream model) +ALTER TABLE channels ADD COLUMN IF NOT EXISTS billing_model_source VARCHAR(20) DEFAULT 'requested'; + +-- Add channel tracking fields to usage_logs +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS channel_id BIGINT; +ALTER TABLE usage_logs ADD COLUMN IF NOT 
EXISTS model_mapping_chain VARCHAR(500); +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_tier VARCHAR(50); diff --git a/backend/migrations/085_channel_restrict_and_per_request_price.sql b/backend/migrations/085_channel_restrict_and_per_request_price.sql new file mode 100644 index 0000000000..2f494c63cb --- /dev/null +++ b/backend/migrations/085_channel_restrict_and_per_request_price.sql @@ -0,0 +1,5 @@ +-- Add model restriction switch to channels +ALTER TABLE channels ADD COLUMN IF NOT EXISTS restrict_models BOOLEAN DEFAULT false; + +-- Add default per_request_price to channel_model_pricing (fallback when no tier matches) +ALTER TABLE channel_model_pricing ADD COLUMN IF NOT EXISTS per_request_price NUMERIC(20,10); diff --git a/backend/migrations/086_channel_platform_pricing.sql b/backend/migrations/086_channel_platform_pricing.sql new file mode 100644 index 0000000000..f2d0856279 --- /dev/null +++ b/backend/migrations/086_channel_platform_pricing.sql @@ -0,0 +1,21 @@ +-- 086_channel_platform_pricing.sql +-- 渠道按平台维度:model_pricing 加 platform 列,model_mapping 改为嵌套格式 + +-- 1. channel_model_pricing 加 platform 列 +ALTER TABLE channel_model_pricing + ADD COLUMN IF NOT EXISTS platform VARCHAR(50) NOT NULL DEFAULT 'anthropic'; + +CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_platform + ON channel_model_pricing (platform); + +-- 2. 
model_mapping: 从扁平 {"src":"dst"} 迁移为嵌套 {"anthropic":{"src":"dst"}} +-- 仅迁移非空、非 '{}' 的旧格式数据(通过检查第一个 value 是否为字符串来判断是否为旧格式) +UPDATE channels +SET model_mapping = jsonb_build_object('anthropic', model_mapping) +WHERE model_mapping IS NOT NULL + AND model_mapping::text NOT IN ('{}', 'null', '') + AND NOT EXISTS ( + SELECT 1 FROM jsonb_each(model_mapping) AS kv + WHERE jsonb_typeof(kv.value) = 'object' + LIMIT 1 + ); diff --git a/backend/migrations/087_usage_log_billing_mode.sql b/backend/migrations/087_usage_log_billing_mode.sql new file mode 100644 index 0000000000..8552be0bc8 --- /dev/null +++ b/backend/migrations/087_usage_log_billing_mode.sql @@ -0,0 +1,2 @@ +-- Add billing_mode to usage_logs (records the billing mode: token/per_request/image) +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20); diff --git a/backend/migrations/088_channel_billing_model_source_channel_mapped.sql b/backend/migrations/088_channel_billing_model_source_channel_mapped.sql new file mode 100644 index 0000000000..83f96b094f --- /dev/null +++ b/backend/migrations/088_channel_billing_model_source_channel_mapped.sql @@ -0,0 +1,3 @@ +-- Change default billing_model_source for new channels to 'channel_mapped' +-- Existing channels keep their current setting (no UPDATE on existing rows) +ALTER TABLE channels ALTER COLUMN billing_model_source SET DEFAULT 'channel_mapped'; diff --git a/backend/migrations/089_usage_log_image_output_tokens.sql b/backend/migrations/089_usage_log_image_output_tokens.sql new file mode 100644 index 0000000000..dd142b1516 --- /dev/null +++ b/backend/migrations/089_usage_log_image_output_tokens.sql @@ -0,0 +1,2 @@ +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_tokens INTEGER NOT NULL DEFAULT 0; +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_cost DECIMAL(20, 10) NOT NULL DEFAULT 0; diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts new file mode 100644 index 0000000000..5334dd473d --- 
/dev/null +++ b/frontend/src/api/admin/channels.ts @@ -0,0 +1,148 @@ +/** + * Admin Channels API endpoints + * Handles channel management for administrators + */ + +import { apiClient } from '../client' + +export type BillingMode = 'token' | 'per_request' | 'image' + +export interface PricingInterval { + id?: number + min_tokens: number + max_tokens: number | null + tier_label: string + input_price: number | null + output_price: number | null + cache_write_price: number | null + cache_read_price: number | null + per_request_price: number | null + sort_order: number +} + +export interface ChannelModelPricing { + id?: number + platform: string + models: string[] + billing_mode: BillingMode + input_price: number | null + output_price: number | null + cache_write_price: number | null + cache_read_price: number | null + image_output_price: number | null + per_request_price: number | null + intervals: PricingInterval[] +} + +export interface Channel { + id: number + name: string + description: string + status: string + billing_model_source: string // "requested" | "upstream" + restrict_models: boolean + group_ids: number[] + model_pricing: ChannelModelPricing[] + model_mapping: Record> // platform → {src→dst} + created_at: string + updated_at: string +} + +export interface CreateChannelRequest { + name: string + description?: string + group_ids?: number[] + model_pricing?: ChannelModelPricing[] + model_mapping?: Record> + billing_model_source?: string + restrict_models?: boolean +} + +export interface UpdateChannelRequest { + name?: string + description?: string + status?: string + group_ids?: number[] + model_pricing?: ChannelModelPricing[] + model_mapping?: Record> + billing_model_source?: string + restrict_models?: boolean +} + +interface PaginatedResponse { + items: T[] + total: number +} + +/** + * List channels with pagination + */ +export async function list( + page: number = 1, + pageSize: number = 20, + filters?: { + status?: string + search?: string + }, + 
options?: { signal?: AbortSignal } +): Promise> { + const { data } = await apiClient.get>('/admin/channels', { + params: { + page, + page_size: pageSize, + ...filters + }, + signal: options?.signal + }) + return data +} + +/** + * Get channel by ID + */ +export async function getById(id: number): Promise { + const { data } = await apiClient.get(`/admin/channels/${id}`) + return data +} + +/** + * Create a new channel + */ +export async function create(req: CreateChannelRequest): Promise { + const { data } = await apiClient.post('/admin/channels', req) + return data +} + +/** + * Update a channel + */ +export async function update(id: number, req: UpdateChannelRequest): Promise { + const { data } = await apiClient.put(`/admin/channels/${id}`, req) + return data +} + +/** + * Delete a channel + */ +export async function remove(id: number): Promise { + await apiClient.delete(`/admin/channels/${id}`) +} + +export interface ModelDefaultPricing { + found: boolean + input_price?: number // per-token price + output_price?: number + cache_write_price?: number + cache_read_price?: number + image_output_price?: number +} + +export async function getModelDefaultPricing(model: string): Promise { + const { data } = await apiClient.get('/admin/channels/model-pricing', { + params: { model } + }) + return data +} + +const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing } +export default channelsAPI diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index 15d1540fb8..49e487b744 100644 --- a/frontend/src/api/admin/dashboard.ts +++ b/frontend/src/api/admin/dashboard.ts @@ -167,6 +167,13 @@ export interface UserBreakdownParams { endpoint?: string endpoint_type?: 'inbound' | 'upstream' | 'path' limit?: number + // Additional filter conditions + user_id?: number + api_key_id?: number + account_id?: number + request_type?: number + stream?: boolean + billing_type?: number | null } export interface UserBreakdownResponse { 
diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index 9a3fb8c510..da1e3cfa8f 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -25,6 +25,7 @@ import apiKeysAPI from './apiKeys' import scheduledTestsAPI from './scheduledTests' import backupAPI from './backup' import tlsFingerprintProfileAPI from './tlsFingerprintProfile' +import channelsAPI from './channels' /** * Unified admin API object for convenient access @@ -51,7 +52,8 @@ export const adminAPI = { apiKeys: apiKeysAPI, scheduledTests: scheduledTestsAPI, backup: backupAPI, - tlsFingerprintProfiles: tlsFingerprintProfileAPI + tlsFingerprintProfiles: tlsFingerprintProfileAPI, + channels: channelsAPI } export { @@ -76,7 +78,8 @@ export { apiKeysAPI, scheduledTestsAPI, backupAPI, - tlsFingerprintProfileAPI + tlsFingerprintProfileAPI, + channelsAPI } export default adminAPI diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts index bd7e3e57b0..d21b28dcdb 100644 --- a/frontend/src/api/admin/usage.ts +++ b/frontend/src/api/admin/usage.ts @@ -80,6 +80,7 @@ export interface CreateUsageCleanupTaskRequest { export interface AdminUsageQueryParams extends UsageQueryParams { user_id?: number exact_total?: boolean + billing_mode?: string } // ==================== API Functions ==================== diff --git a/frontend/src/components/admin/channel/IntervalRow.vue b/frontend/src/components/admin/channel/IntervalRow.vue new file mode 100644 index 0000000000..21dcc90d4c --- /dev/null +++ b/frontend/src/components/admin/channel/IntervalRow.vue @@ -0,0 +1,113 @@ + + + diff --git a/frontend/src/components/admin/channel/ModelTagInput.vue b/frontend/src/components/admin/channel/ModelTagInput.vue new file mode 100644 index 0000000000..a1ce402293 --- /dev/null +++ b/frontend/src/components/admin/channel/ModelTagInput.vue @@ -0,0 +1,89 @@ + + + diff --git a/frontend/src/components/admin/channel/PricingEntryCard.vue 
b/frontend/src/components/admin/channel/PricingEntryCard.vue new file mode 100644 index 0000000000..e98853c333 --- /dev/null +++ b/frontend/src/components/admin/channel/PricingEntryCard.vue @@ -0,0 +1,354 @@ + + + + + diff --git a/frontend/src/components/admin/channel/types.ts b/frontend/src/components/admin/channel/types.ts new file mode 100644 index 0000000000..8d998911aa --- /dev/null +++ b/frontend/src/components/admin/channel/types.ts @@ -0,0 +1,190 @@ +import type { BillingMode, PricingInterval } from '@/api/admin/channels' + +export interface IntervalFormEntry { + min_tokens: number + max_tokens: number | null + tier_label: string + input_price: number | string | null + output_price: number | string | null + cache_write_price: number | string | null + cache_read_price: number | string | null + per_request_price: number | string | null + sort_order: number +} + +export interface PricingFormEntry { + models: string[] + billing_mode: BillingMode + input_price: number | string | null + output_price: number | string | null + cache_write_price: number | string | null + cache_read_price: number | string | null + image_output_price: number | string | null + per_request_price: number | string | null + intervals: IntervalFormEntry[] +} + +// 价格转换:后端存 per-token,前端显示 per-MTok ($/1M tokens) +const MTOK = 1_000_000 + +export function toNullableNumber(val: number | string | null | undefined): number | null { + if (val === null || val === undefined || val === '') return null + const num = Number(val) + return isNaN(num) ? null : num +} + +/** 前端显示值($/MTok) → 后端存储值(per-token) */ +export function mTokToPerToken(val: number | string | null | undefined): number | null { + const num = toNullableNumber(val) + return num === null ? 
null : parseFloat((num / MTOK).toPrecision(10)) +} + +/** 后端存储值(per-token) → 前端显示值($/MTok) */ +export function perTokenToMTok(val: number | null | undefined): number | null { + if (val === null || val === undefined) return null + // toPrecision(10) 消除 IEEE 754 浮点乘法精度误差,如 5e-8 * 1e6 = 0.04999...96 → 0.05 + return parseFloat((val * MTOK).toPrecision(10)) +} + +export function apiIntervalsToForm(intervals: PricingInterval[]): IntervalFormEntry[] { + return (intervals || []).map(iv => ({ + min_tokens: iv.min_tokens, + max_tokens: iv.max_tokens, + tier_label: iv.tier_label || '', + input_price: perTokenToMTok(iv.input_price), + output_price: perTokenToMTok(iv.output_price), + cache_write_price: perTokenToMTok(iv.cache_write_price), + cache_read_price: perTokenToMTok(iv.cache_read_price), + per_request_price: iv.per_request_price, + sort_order: iv.sort_order + })) +} + +export function formIntervalsToAPI(intervals: IntervalFormEntry[]): PricingInterval[] { + return (intervals || []).map(iv => ({ + min_tokens: iv.min_tokens, + max_tokens: iv.max_tokens, + tier_label: iv.tier_label, + input_price: mTokToPerToken(iv.input_price), + output_price: mTokToPerToken(iv.output_price), + cache_write_price: mTokToPerToken(iv.cache_write_price), + cache_read_price: mTokToPerToken(iv.cache_read_price), + per_request_price: toNullableNumber(iv.per_request_price), + sort_order: iv.sort_order + })) +} + +// ── 模型模式冲突检测 ────────────────────────────────────── + +interface ModelPattern { + pattern: string + prefix: string // lowercase, 通配符去掉尾部 * + wildcard: boolean +} + +function toModelPattern(model: string): ModelPattern { + const lower = model.toLowerCase() + const wildcard = lower.endsWith('*') + return { + pattern: model, + prefix: wildcard ? 
lower.slice(0, -1) : lower, + wildcard, + } +} + +function patternsConflict(a: ModelPattern, b: ModelPattern): boolean { + if (!a.wildcard && !b.wildcard) return a.prefix === b.prefix + if (a.wildcard && !b.wildcard) return b.prefix.startsWith(a.prefix) + if (!a.wildcard && b.wildcard) return a.prefix.startsWith(b.prefix) + // 双通配符:任一前缀是另一前缀的前缀即冲突 + return a.prefix.startsWith(b.prefix) || b.prefix.startsWith(a.prefix) +} + +/** 检测模型模式列表中的冲突,返回冲突的两个模式名;无冲突返回 null */ +export function findModelConflict(models: string[]): [string, string] | null { + const patterns = models.map(toModelPattern) + for (let i = 0; i < patterns.length; i++) { + for (let j = i + 1; j < patterns.length; j++) { + if (patternsConflict(patterns[i], patterns[j])) { + return [patterns[i].pattern, patterns[j].pattern] + } + } + } + return null +} + +// ── 区间校验 ────────────────────────────────────────────── + +/** 校验区间列表的合法性,返回错误消息;通过则返回 null */ +export function validateIntervals(intervals: IntervalFormEntry[]): string | null { + if (!intervals || intervals.length === 0) return null + + // 按 min_tokens 排序(不修改原数组) + const sorted = [...intervals].sort((a, b) => a.min_tokens - b.min_tokens) + + for (let i = 0; i < sorted.length; i++) { + const err = validateSingleInterval(sorted[i], i) + if (err) return err + } + return checkIntervalOverlap(sorted) +} + +function validateSingleInterval(iv: IntervalFormEntry, idx: number): string | null { + if (iv.min_tokens < 0) { + return `区间 #${idx + 1}: 最小 token 数 (${iv.min_tokens}) 不能为负数` + } + if (iv.max_tokens != null) { + if (iv.max_tokens <= 0) { + return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于 0` + } + if (iv.max_tokens <= iv.min_tokens) { + return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于最小 token 数 (${iv.min_tokens})` + } + } + return validateIntervalPrices(iv, idx) +} + +function validateIntervalPrices(iv: IntervalFormEntry, idx: number): string | null { + const prices: [string, number | string | null][] = [ + ['输入价格', iv.input_price], 
+ ['输出价格', iv.output_price], + ['缓存写入价格', iv.cache_write_price], + ['缓存读取价格', iv.cache_read_price], + ['单次价格', iv.per_request_price], + ] + for (const [name, val] of prices) { + if (val != null && val !== '' && Number(val) < 0) { + return `区间 #${idx + 1}: ${name}不能为负数` + } + } + return null +} + +function checkIntervalOverlap(sorted: IntervalFormEntry[]): string | null { + for (let i = 0; i < sorted.length; i++) { + // 无上限区间必须是最后一个 + if (sorted[i].max_tokens == null && i < sorted.length - 1) { + return `区间 #${i + 1}: 无上限区间(最大 token 数为空)只能是最后一个` + } + if (i === 0) continue + const prev = sorted[i - 1] + // (min, max] 语义:前一个区间上界 > 当前区间下界则重叠 + if (prev.max_tokens == null || prev.max_tokens > sorted[i].min_tokens) { + const prevMax = prev.max_tokens == null ? '∞' : String(prev.max_tokens) + return `区间 #${i} 和 #${i + 1} 重叠:前一个区间上界 (${prevMax}) 大于当前区间下界 (${sorted[i].min_tokens})` + } + } + return null +} + +/** 平台对应的模型 tag 样式(背景+文字) */ +export function getPlatformTagClass(platform: string): string { + switch (platform) { + case 'anthropic': return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400' + case 'openai': return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400' + case 'gemini': return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400' + case 'antigravity': return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400' + case 'sora': return 'bg-rose-100 text-rose-700 dark:bg-rose-900/30 dark:text-rose-400' + default: return 'bg-gray-100 text-gray-700 dark:bg-gray-900/30 dark:text-gray-400' + } +} diff --git a/frontend/src/components/admin/usage/UsageFilters.vue b/frontend/src/components/admin/usage/UsageFilters.vue index ee5020e796..66c2b4fa65 100644 --- a/frontend/src/components/admin/usage/UsageFilters.vue +++ b/frontend/src/components/admin/usage/UsageFilters.vue @@ -133,6 +133,12 @@ + +
@@ -232,6 +238,13 @@ const billingTypeOptions = ref([ { value: 1, label: t('admin.usage.billingTypeSubscription') } ]) +const billingModeOptions = ref([ + { value: null, label: t('admin.usage.allBillingModes') }, + { value: 'token', label: t('admin.usage.billingModeToken') }, + { value: 'per_request', label: t('admin.usage.billingModePerRequest') }, + { value: 'image', label: t('admin.usage.billingModeImage') } +]) + const emitChange = () => emit('change') const debounceUserSearch = () => { diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index 4a42ab05a2..9bbdb380b6 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -26,7 +26,15 @@ + + + +