From 996bf56e32548f7fc30542772be048716ed294a9 Mon Sep 17 00:00:00 2001 From: sicko Date: Sat, 24 Jan 2026 10:49:48 +1100 Subject: [PATCH] fix: prevent superfluous WriteHeader error in SSE streaming Added trackingResponseWriter wrapper to prevent the 'superfluous response.WriteHeader call' error that occurs when: 1. SSE streaming starts (WriteHeader is called) 2. An error occurs during streaming 3. Error handler tries to call WriteHeader again via http.Error() The fix wraps http.ResponseWriter to track if headers have been written and gracefully handles errors that occur after streaming has started. --- server/adkrest/controllers/handlers.go | 54 ++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/server/adkrest/controllers/handlers.go b/server/adkrest/controllers/handlers.go index 97e01525a..3b3465544 100644 --- a/server/adkrest/controllers/handlers.go +++ b/server/adkrest/controllers/handlers.go @@ -17,11 +17,42 @@ package controllers import ( "encoding/json" + "log" "net/http" ) // TODO: Move to an internal package, controllers doesn't have to be public API. +// trackingResponseWriter wraps http.ResponseWriter to track if headers have been written. +// This prevents the "superfluous WriteHeader" error when errors occur after streaming starts. +type trackingResponseWriter struct { + http.ResponseWriter + headerWritten bool +} + +// WriteHeader tracks that headers have been written and delegates to the underlying writer. +func (w *trackingResponseWriter) WriteHeader(statusCode int) { + if w.headerWritten { + // Headers already written, log and skip to avoid superfluous WriteHeader + log.Printf("ADK: Skipping duplicate WriteHeader call (status %d) - headers already sent", statusCode) + return + } + w.headerWritten = true + w.ResponseWriter.WriteHeader(statusCode) +} + +// Write delegates to the underlying writer and marks headers as written +// (Go's http.ResponseWriter implicitly calls WriteHeader(200) on first Write if not called) +func (w *trackingResponseWriter) Write(data []byte) (int, error) { + w.headerWritten = true + return w.ResponseWriter.Write(data) +} + +// Unwrap returns the underlying ResponseWriter for http.ResponseController compatibility +func (w *trackingResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + // EncodeJSONResponse uses the json encoder to write an interface to the http response with an optional status code func EncodeJSONResponse(i any, status int, w http.ResponseWriter) { wHeader := w.Header() @@ -32,6 +63,11 @@ func EncodeJSONResponse(i any, status int, w http.ResponseWriter) { if i != nil { err := json.NewEncoder(w).Encode(i) if err != nil { + // Only attempt error response if headers haven't been written yet + if tw, ok := w.(*trackingResponseWriter); ok && tw.headerWritten { + log.Printf("ADK: Failed to encode JSON response after headers written: %v", err) + return + } http.Error(w, err.Error(), http.StatusInternalServerError) } } @@ -40,14 +76,26 @@ func EncodeJSONResponse(i any, status int, w http.ResponseWriter) { type errorHandler func(http.ResponseWriter, *http.Request) error // NewErrorHandler writes the error code returned from the http handler. +// It uses trackingResponseWriter to prevent "superfluous WriteHeader" errors +// when handlers return errors after already starting to write a response (e.g., SSE streaming). func NewErrorHandler(fn errorHandler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - err := fn(w, r) + // Wrap the response writer to track if headers have been written + tw := &trackingResponseWriter{ResponseWriter: w} + + err := fn(tw, r) if err != nil { + // Only write error response if headers haven't been sent yet + if tw.headerWritten { + // Headers already written (e.g., during SSE streaming), just log the error + log.Printf("ADK: Error occurred after response started: %v", err) + return + } + if statusErr, ok := err.(statusError); ok { - http.Error(w, statusErr.Error(), statusErr.Status()) + http.Error(tw, statusErr.Error(), statusErr.Status()) } else { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(tw, err.Error(), http.StatusInternalServerError) } } }