diff --git a/middleware/proxy.go b/middleware/proxy.go index 2744bc4a8..85e48ca93 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -375,6 +375,13 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // that Balancer may have replaced with c.SetRequest. req = c.Request() + // Handle authorization from target URL (for both HTTP and WebSocket) + if tgt.URL.User != nil { + username := tgt.URL.User.Username() + password, _ := tgt.URL.User.Password() + req.SetBasicAuth(username, password) + } + // Proxy switch { case c.IsWebSocket(): diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index dbf07648b..569742774 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -1040,3 +1040,121 @@ func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) { assert.NoError(t, err) assert.Equal(t, sendMsg, recvMsg) } + +func TestProxyWithAuthorizationHeader(t *testing.T) { + // Scenario: + // A proxy target has user:pass in the url. + // The proxy should pass the Authorization header to the target. + + var receivedAuthHeader string + // Arrange + t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuthHeader = r.Header.Get("Authorization") + fmt.Fprint(w, "target 1") + })) + defer t1.Close() + url1, _ := url.Parse(t1.URL) + url1.User = url.UserPassword("user1", "pass1") + + e := echo.New() + tp := &testProvider{} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + + // Assert + assert.Equal(t, "target 1", rec.Body.String()) + assert.Equal(t, "Basic dXNlcjE6cGFzczE=", receivedAuthHeader) + + // Scenario: + // A proxy target does not have user:pass in the url. + // The proxy should not pass the Authorization header to the target. + + receivedAuthHeader = "" + // Arrange + t2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuthHeader = r.Header.Get("Authorization") + fmt.Fprint(w, "target 2") + })) + defer t2.Close() + url2, _ := url.Parse(t2.URL) + + e = echo.New() + tp = &testProvider{} + tp.target = &ProxyTarget{Name: "target 2", URL: url2} + e.Use(Proxy(tp)) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + + // Assert + assert.Equal(t, "target 2", rec.Body.String()) + assert.Equal(t, "", receivedAuthHeader) +} + +func TestProxyWithConfigWebSocketAuthorizationHeader(t *testing.T) { + // Capture the authorization header received by the WebSocket server + var receivedAuthHeader string + var authHeaderMutex sync.Mutex + + // Create a WebSocket server that captures the Authorization header + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeaderMutex.Lock() + receivedAuthHeader = r.Header.Get("Authorization") + authHeaderMutex.Unlock() + + wsHandler := func(conn *websocket.Conn) { + defer conn.Close() + for { + var msg string + err := websocket.Message.Receive(conn, &msg) + if err != nil { + return + } + // Echo message back to the client + websocket.Message.Send(conn, msg) + } + } + websocket.Server{Handler: wsHandler}.ServeHTTP(w, r) + })) + defer wsServer.Close() + + // Create proxy server with target URL containing user:pass credentials + targetURL, _ := url.Parse(wsServer.URL) + targetURL.User = url.UserPassword("wsuser", "wspass") + + e := echo.New() + balancer := NewRandomBalancer([]*ProxyTarget{{URL: targetURL}}) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer})) + + proxyServer := httptest.NewServer(e) + defer proxyServer.Close() + + // Connect to the proxy WebSocket + proxyURL, _ := url.Parse(proxyServer.URL) + proxyURL.Scheme = "ws" + proxyURL.Path = "/" + + wsConn, err := websocket.Dial(proxyURL.String(), "", "http://localhost/") + assert.NoError(t, err) + defer wsConn.Close() + + // Send message to verify WebSocket connection works + sendMsg := "Hello, WebSocket with Auth!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) + + // Verify authorization header was forwarded + authHeaderMutex.Lock() + expectedAuth := "Basic d3N1c2VyOndzcGFzcw==" // base64 of "wsuser:wspass" + assert.Equal(t, expectedAuth, receivedAuthHeader) + authHeaderMutex.Unlock() +}