Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions server/adkrest/controllers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,42 @@ package controllers

import (
"encoding/json"
"log"
"net/http"
)

// TODO: Move to an internal package, controllers doesn't have to be public API.

// trackingResponseWriter wraps http.ResponseWriter to track if headers have been written.
// This prevents the "superfluous WriteHeader" error when errors occur after streaming starts.
type trackingResponseWriter struct {
http.ResponseWriter
headerWritten bool
}

// WriteHeader tracks that headers have been written and delegates to the underlying writer.
func (w *trackingResponseWriter) WriteHeader(statusCode int) {
if w.headerWritten {
// Headers already written, log and skip to avoid superfluous WriteHeader
log.Printf("ADK: Skipping duplicate WriteHeader call (status %d) - headers already sent", statusCode)
return
}
w.headerWritten = true
w.ResponseWriter.WriteHeader(statusCode)
}

// Write delegates to the underlying writer and marks headers as written
// (Go's http.ResponseWriter implicitly calls WriteHeader(200) on first Write if not called)
func (w *trackingResponseWriter) Write(data []byte) (int, error) {
w.headerWritten = true
return w.ResponseWriter.Write(data)
}

// Unwrap returns the underlying ResponseWriter for http.ResponseController compatibility
func (w *trackingResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

// EncodeJSONResponse uses the json encoder to write an interface to the http response with an optional status code
func EncodeJSONResponse(i any, status int, w http.ResponseWriter) {
wHeader := w.Header()
Expand All @@ -32,6 +63,11 @@ func EncodeJSONResponse(i any, status int, w http.ResponseWriter) {
if i != nil {
err := json.NewEncoder(w).Encode(i)
if err != nil {
// Only attempt error response if headers haven't been written yet
if tw, ok := w.(*trackingResponseWriter); ok && tw.headerWritten {
log.Printf("ADK: Failed to encode JSON response after headers written: %v", err)
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
Expand All @@ -40,14 +76,26 @@ func EncodeJSONResponse(i any, status int, w http.ResponseWriter) {
type errorHandler func(http.ResponseWriter, *http.Request) error

// NewErrorHandler writes the error code returned from the http handler.
// It uses trackingResponseWriter to prevent "superfluous WriteHeader" errors
// when handlers return errors after already starting to write a response (e.g., SSE streaming).
func NewErrorHandler(fn errorHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
err := fn(w, r)
// Wrap the response writer to track if headers have been written
tw := &trackingResponseWriter{ResponseWriter: w}

err := fn(tw, r)
if err != nil {
// Only write error response if headers haven't been sent yet
if tw.headerWritten {
// Headers already written (e.g., during SSE streaming), just log the error
log.Printf("ADK: Error occurred after response started: %v", err)
return
}

if statusErr, ok := err.(statusError); ok {
http.Error(w, statusErr.Error(), statusErr.Status())
http.Error(tw, statusErr.Error(), statusErr.Status())
} else {
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Error(tw, err.Error(), http.StatusInternalServerError)
}
}
}
Expand Down