Skip to content

Commit eaf4b12

Browse files
committed
fix: use http client instead of transport
1 parent 1bfa322 commit eaf4b12

File tree

1 file changed

+78
-28
lines changed

1 file changed

+78
-28
lines changed

controllers/proxy_controller.go

Lines changed: 78 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"context"
66
"fmt"
77
"io"
8-
"net"
98
"net/http"
109
"net/url"
1110
"time"
@@ -27,13 +26,13 @@ type ProxyController interface {
2726
}
2827

2928
type proxyController struct {
30-
ctx context.Context
31-
cfg *config.Config
32-
logger logging.Logger
33-
metrics metric.MetricService
34-
target *url.URL
35-
mirror *url.URL
36-
transport *http.Transport
29+
ctx context.Context
30+
cfg *config.Config
31+
logger logging.Logger
32+
metrics metric.MetricService
33+
target *url.URL
34+
mirror *url.URL
35+
client *http.Client
3736
}
3837

3938
type requestType string
@@ -70,27 +69,62 @@ func NewProxyController(
7069
logger.Fatalf("invalid mirror URL: %v", err)
7170
}
7271
}
73-
// Configure transport with timeouts
74-
transport := &http.Transport{
75-
Proxy: http.ProxyFromEnvironment,
76-
DialContext: (&net.Dialer{
77-
Timeout: cfg.Proxy.DialTimeout,
78-
}).DialContext,
79-
ForceAttemptHTTP2: true,
80-
IdleConnTimeout: cfg.Proxy.IdleTimeout,
72+
73+
client := &http.Client{
74+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
75+
if len(via) >= 10 {
76+
return fmt.Errorf("stopped after 10 redirects")
77+
}
78+
return nil
79+
},
8180
}
8281

8382
return &proxyController{
84-
ctx: ctx,
85-
cfg: cfg,
86-
logger: logger,
87-
metrics: metrics,
88-
target: target,
89-
mirror: mirror,
90-
transport: transport,
83+
ctx: ctx,
84+
cfg: cfg,
85+
logger: logger,
86+
metrics: metrics,
87+
target: target,
88+
mirror: mirror,
89+
client: client,
9190
}
9291
}
9392

93+
func (_this proxyController) createRequest(
94+
ctx context.Context,
95+
originalReq *http.Request,
96+
bodyBytes []byte,
97+
) (*http.Request, error) {
98+
// Create new request with appropriate context
99+
newReq, err := http.NewRequestWithContext(
100+
ctx,
101+
originalReq.Method,
102+
originalReq.URL.String(),
103+
bytes.NewBuffer(bodyBytes),
104+
)
105+
if err != nil {
106+
return nil, fmt.Errorf("failed to create request: %w", err)
107+
}
108+
109+
// Copy headers from original request
110+
for k, vv := range originalReq.Header {
111+
for _, v := range vv {
112+
newReq.Header.Add(k, v)
113+
}
114+
}
115+
116+
// Add Content-Length header if body exists
117+
if len(bodyBytes) > 0 {
118+
newReq.Header.Add("Content-Length", fmt.Sprintf("%d", len(bodyBytes)))
119+
// Ensure content type is preserved
120+
if contentType := originalReq.Header.Get("Content-Type"); contentType != "" {
121+
newReq.Header.Set("Content-Type", contentType)
122+
}
123+
}
124+
125+
return newReq, nil
126+
}
127+
94128
func (_this proxyController) proxyRequest(c *gin.Context) {
95129
// Read the original request body
96130
bodyBytes, err := io.ReadAll(c.Request.Body)
@@ -99,14 +133,24 @@ func (_this proxyController) proxyRequest(c *gin.Context) {
99133
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read request"})
100134
return
101135
}
102-
// Ignore error since we are closing the body anyway
103136
_ = c.Request.Body.Close()
104137

138+
// Restore the request body for downstream middleware/handlers
139+
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
140+
141+
// Create proxy request
142+
proxyReq, err := _this.createRequest(c.Request.Context(), c.Request, bodyBytes)
143+
if err != nil {
144+
_this.logger.Errorw("failed to create proxy request", "error", err)
145+
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"})
146+
return
147+
}
148+
105149
// Handle proxy request
106150
_this.handleRequest(requestContext{
107151
reqType: proxyRequest,
108152
ginContext: c,
109-
request: c.Request,
153+
request: proxyReq,
110154
bodyBytes: bodyBytes,
111155
startTime: time.Now(),
112156
targetURL: _this.target,
@@ -118,9 +162,15 @@ func (_this proxyController) proxyRequest(c *gin.Context) {
118162
ctx, cancel := context.WithTimeout(_this.ctx, _this.cfg.Proxy.MirrorTimeout)
119163
defer cancel()
120164

165+
mirrorReq, err := _this.createRequest(ctx, c.Request, bodyBytes)
166+
if err != nil {
167+
_this.logger.Errorw("failed to create mirror request", "error", err)
168+
return
169+
}
170+
121171
_this.handleRequest(requestContext{
122172
reqType: mirrorRequest,
123-
request: c.Request.Clone(ctx),
173+
request: mirrorReq,
124174
bodyBytes: bodyBytes,
125175
startTime: time.Now(),
126176
targetURL: _this.mirror,
@@ -131,7 +181,7 @@ func (_this proxyController) proxyRequest(c *gin.Context) {
131181

132182
func (_this proxyController) handleRequest(reqCtx requestContext) {
133183
// Prepare the request
134-
req := reqCtx.request.Clone(reqCtx.request.Context())
184+
req := reqCtx.request
135185
req.URL.Scheme = reqCtx.targetURL.Scheme
136186
req.URL.Host = reqCtx.targetURL.Host
137187
req.Host = reqCtx.targetURL.Host
@@ -155,7 +205,7 @@ func (_this proxyController) handleRequest(reqCtx requestContext) {
155205
}
156206

157207
// Make the request
158-
resp, err := _this.transport.RoundTrip(req)
208+
resp, err := _this.client.Do(req)
159209
if err != nil {
160210
_this.logger.Errorw(fmt.Sprintf("%s error", reqCtx.reqType), "error", err)
161211
if reqCtx.reqType == proxyRequest {

0 commit comments

Comments
 (0)