Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 155 additions & 56 deletions sdk/api/handlers/openai/openai_responses_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,19 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
}

allowCompactionReplayBypass := false
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
}
} else {
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if requestModelName == "" {
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
}
allowCompactionReplayBypass = h.websocketUpstreamSupportsCompactionReplayForModel(requestModelName)
}

var requestJSON []byte
var updatedLastRequest []byte
var errMsg *interfaces.ErrorMessage
Expand All @@ -124,6 +137,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
lastRequest,
lastResponseOutput,
allowIncrementalInputWithPreviousResponseID,
allowCompactionReplayBypass,
)
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
Expand Down Expand Up @@ -222,21 +236,21 @@ func websocketUpgradeHeaders(req *http.Request) http.Header {
}

func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true, true)
}

func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
switch requestType {
case wsRequestTypeCreate:
// log.Infof("responses websocket: response.create request")
if len(lastRequest) == 0 {
return normalizeResponseCreateRequest(rawJSON)
}
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
case wsRequestTypeAppend:
// log.Infof("responses websocket: response.append request")
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
default:
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Expand Down Expand Up @@ -265,7 +279,7 @@ func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces
return normalized, bytes.Clone(normalized), nil
}

func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
if len(lastRequest) == 0 {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Expand Down Expand Up @@ -315,20 +329,37 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
}
}

existingInput := gjson.GetBytes(lastRequest, "input")
mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
if errMerge != nil {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("invalid previous response output: %w", errMerge),
// When the client sends a compact replay for a downstream that can consume it
// directly, the input already carries the canonical history. In that case,
// skip merging with stale lastRequest/lastResponseOutput to avoid breaking
// function_call / function_call_output pairings.
// See: https://github.com/router-for-me/CLIProxyAPI/issues/2207
var mergedInput string
if allowCompactionReplayBypass && inputContainsFullTranscript(nextInput) {
log.Infof("responses websocket: full transcript detected, skipping stale merge (input items=%d)", len(nextInput.Array()))
mergedInput = nextInput.Raw
} else {
appendInputRaw := nextInput.Raw
if inputContainsFullTranscript(nextInput) {
appendInputRaw = inputWithoutCompactionItems(nextInput)
}
}

mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
if errMerge != nil {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("invalid request input: %w", errMerge),
existingInput := gjson.GetBytes(lastRequest, "input")
var errMerge error
mergedInput, errMerge = mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
if errMerge != nil {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("invalid previous response output: %w", errMerge),
}
}

mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, appendInputRaw)
if errMerge != nil {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("invalid request input: %w", errMerge),
}
}
}
dedupedInput, errDedupeFunctionCalls := dedupeFunctionCallsByCallID(mergedInput)
Expand Down Expand Up @@ -480,72 +511,104 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
}

func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
if h == nil || h.AuthManager == nil {
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
for _, auth := range auths {
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
return true
}
}
return false
}

// websocketUpstreamSupportsCompactionReplayForModel reports whether every auth
// returned by responsesWebsocketAvailableAuthsForModel for modelName passes
// responsesWebsocketAuthSupportsCompactionReplay.
//
// The check is deliberately "all", not "any": a single candidate that fails it
// makes the result false. An empty candidate list also yields false.
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsCompactionReplayForModel(modelName string) bool {
	candidates, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)

	// Start from "supported" only when there is at least one candidate, then
	// stop at the first candidate that does not support compaction replay.
	supported := len(candidates) > 0
	for i := 0; supported && i < len(candidates); i++ {
		supported = responsesWebsocketAuthSupportsCompactionReplay(candidates[i])
	}
	return supported
}

func (h *OpenAIResponsesAPIHandler) responsesWebsocketAvailableAuthsForModel(modelName string) ([]*coreauth.Auth, string) {
if h == nil || h.AuthManager == nil {
return nil, ""
}
resolvedModelName := responsesWebsocketResolvedModelName(modelName)
providerSet, modelKey := responsesWebsocketProviderSetForModel(resolvedModelName)
if len(providerSet) == 0 {
return nil, modelKey
}

resolvedModelName := modelName
registryRef := registry.GetGlobalRegistry()
now := time.Now()
auths := h.AuthManager.List()
available := make([]*coreauth.Auth, 0, len(auths))
for _, auth := range auths {
if !responsesWebsocketAuthMatchesModel(auth, providerSet, modelKey, registryRef, now) {
continue
}
available = append(available, auth)
}
return available, modelKey
}

func responsesWebsocketResolvedModelName(modelName string) string {
initialSuffix := thinking.ParseSuffix(modelName)
if initialSuffix.ModelName == "auto" {
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
if initialSuffix.HasSuffix {
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
} else {
resolvedModelName = resolvedBase
return fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
}
} else {
resolvedModelName = util.ResolveAutoModel(modelName)
return resolvedBase
}
return util.ResolveAutoModel(modelName)
}

func responsesWebsocketProviderSetForModel(resolvedModelName string) (map[string]struct{}, string) {
parsed := thinking.ParseSuffix(resolvedModelName)
baseModel := strings.TrimSpace(parsed.ModelName)
providers := util.GetProviderName(baseModel)
if len(providers) == 0 && baseModel != resolvedModelName {
providers = util.GetProviderName(resolvedModelName)
}
if len(providers) == 0 {
return false
}

providerSet := make(map[string]struct{}, len(providers))
for i := 0; i < len(providers); i++ {
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
for _, provider := range providers {
providerKey := strings.TrimSpace(strings.ToLower(provider))
if providerKey == "" {
continue
}
providerSet[providerKey] = struct{}{}
}
if len(providerSet) == 0 {
return false
}

modelKey := baseModel
if modelKey == "" {
modelKey = strings.TrimSpace(resolvedModelName)
}
registryRef := registry.GetGlobalRegistry()
now := time.Now()
auths := h.AuthManager.List()
for i := 0; i < len(auths); i++ {
auth := auths[i]
if auth == nil {
continue
}
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
if _, ok := providerSet[providerKey]; !ok {
continue
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
continue
}
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
continue
}
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
return true
}
return providerSet, modelKey
}

func responsesWebsocketAuthMatchesModel(auth *coreauth.Auth, providerSet map[string]struct{}, modelKey string, registryRef *registry.ModelRegistry, now time.Time) bool {
if auth == nil {
return false
}
return false
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
if _, ok := providerSet[providerKey]; !ok {
return false
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
return false
}
return responsesWebsocketAuthAvailableForModel(auth, modelKey, now)
}

// responsesWebsocketAuthSupportsCompactionReplay reports whether auth's
// provider, after trimming surrounding whitespace, equals "codex" ignoring
// case. A nil auth is reported as unsupported.
func responsesWebsocketAuthSupportsCompactionReplay(auth *coreauth.Auth) bool {
	if auth != nil {
		provider := strings.TrimSpace(auth.Provider)
		return strings.EqualFold(provider, "codex")
	}
	return false
}

func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
Expand Down Expand Up @@ -691,6 +754,42 @@ func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
return string(out), nil
}

// inputContainsFullTranscript returns true when the input array carries compact
// replay markers that indicate the client already sent the full conversation
// transcript. Merging that input with stale lastRequest/lastResponseOutput
// would duplicate or break function_call/function_call_output pairings, so the
// caller should use the input as-is.
//
// Assistant messages alone are not enough to classify the payload as a replay:
// incremental websocket requests may legitimately append assistant items.
func inputContainsFullTranscript(input gjson.Result) bool {
if !input.IsArray() {
return false
}
for _, item := range input.Array() {
t := item.Get("type").String()
if t == "compaction" || t == "compaction_summary" {
return true
}
}
Comment on lines +769 to +774
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better readability and structure, you can refactor this loop to use a switch statement on the item's type. This makes the logic clearer and is more idiomatic in Go for handling multiple cases based on a string value.

Suggested change
for _, item := range input.Array() {
t := item.Get("type").String()
if t == "message" && item.Get("role").String() == "assistant" {
return true
}
if t == "compaction" || t == "compaction_summary" {
return true
}
}
for _, item := range input.Array() {
switch item.Get("type").String() {
case "message":
if item.Get("role").String() == "assistant" {
return true
}
case "compaction", "compaction_summary":
return true
}
}

return false
}

// inputWithoutCompactionItems returns input re-serialized as a raw JSON array
// with every item whose "type" is "compaction" or "compaction_summary"
// removed. Remaining items keep their original raw text and order.
//
// When input is not an array, its raw text is handed to normalizeJSONArrayRaw
// and that result is returned instead.
func inputWithoutCompactionItems(input gjson.Result) string {
	if !input.IsArray() {
		return normalizeJSONArrayRaw([]byte(input.Raw))
	}

	items := input.Array()
	kept := make([]string, 0, len(items))
	for i := range items {
		switch items[i].Get("type").String() {
		case "compaction", "compaction_summary":
			// Drop compaction items; everything else is kept verbatim.
		default:
			kept = append(kept, items[i].Raw)
		}
	}
	return "[" + strings.Join(kept, ",") + "]"
}

func normalizeJSONArrayRaw(raw []byte) string {
trimmed := strings.TrimSpace(string(raw))
if trimmed == "" {
Expand Down
Loading
Loading