diff --git a/tracker/clientcontext/tracker.go b/tracker/clientcontext/tracker.go index 62ac854..33d95ec 100644 --- a/tracker/clientcontext/tracker.go +++ b/tracker/clientcontext/tracker.go @@ -9,14 +9,20 @@ package clientcontext import ( + stdbufio "bufio" + "bytes" "context" "encoding/json" "fmt" + "io" "net" + "net/http" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/service" ) @@ -161,36 +167,68 @@ func (b *boundsRule) match(tag string) bool { type readConn struct { net.Conn ctx context.Context - info *ClientInfo + info ClientInfo logger log.ContextLogger + + reader io.Reader + n int + readErr error } // newReadConn creates a readConn and reads client info from it. If successful, the info is stored // in the context. func newReadConn(ctx context.Context, conn net.Conn, logger log.ContextLogger) net.Conn { - c := &readConn{Conn: conn, ctx: ctx} - info, err := c.readInfo() - if err != nil { - logger.Error("reading client info ", err) - return conn + c := &readConn{ + Conn: conn, + ctx: ctx, + reader: conn, + logger: logger, + } + if err := c.readInfo(); err != nil { + logger.Warn("reading client info: ", err) } - service.ContextWithPtr(ctx, info) return c } -// readInfo reads and decodes client info, then sends an OK response. -func (c *readConn) readInfo() (*ClientInfo, error) { +func (c *readConn) Read(b []byte) (n int, err error) { + if c.readErr != nil { + return c.n, c.readErr + } + return c.reader.Read(b) +} + +// readInfo reads and decodes client info, then sends an HTTP 200 OK response. +func (c *readConn) readInfo() error { + var buf [32]byte + n, err := c.Conn.Read(buf[:]) + if err != nil { + c.readErr = err + c.n = n + return err + } + reader := io.MultiReader(bytes.NewReader(buf[:n]), c.Conn) + if !bytes.HasPrefix(buf[:n], []byte("POST /clientinfo")) { + c.reader = reader + return nil + } + var info ClientInfo - if err := json.NewDecoder(c).Decode(&info); err != nil { - return nil, fmt.Errorf("decoding client info: %w", err) + req, err := http.ReadRequest(stdbufio.NewReader(reader)) + if err != nil { + return fmt.Errorf("reading HTTP request: %w", err) } - c.info = &info + defer req.Body.Close() + if err := json.NewDecoder(req.Body).Decode(&info); err != nil { + return fmt.Errorf("decoding client info: %w", err) + } + c.info = info - // send `OK` response - if _, err := c.Write([]byte("OK")); err != nil { - return nil, fmt.Errorf("writing OK response to client: %w", err) + resp := "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" + if _, err := c.Write([]byte(resp)); err != nil { + return fmt.Errorf("writing HTTP response: %w", err) } - return &info, nil + service.ContextWithPtr(c.ctx, &info) + return nil } // writeConn sends client info after handshake. @@ -213,70 +251,101 @@ func (c *writeConn) ConnHandshakeSuccess(conn net.Conn) error { return nil } -// sendInfo marshals and sends client info, then waits for OK. +// sendInfo marshals and sends client info as an HTTP POST, then waits for HTTP 200 OK. func (c *writeConn) sendInfo(conn net.Conn) error { buf, err := json.Marshal(c.info) if err != nil { return fmt.Errorf("marshaling client info: %w", err) } - if _, err = conn.Write(buf); err != nil { + // Write HTTP POST request + req := bytes.NewBuffer(nil) + fmt.Fprintf(req, "POST /clientinfo HTTP/1.1\r\n") + fmt.Fprintf(req, "Host: localhost\r\n") + fmt.Fprintf(req, "Content-Type: application/json\r\n") + fmt.Fprintf(req, "Content-Length: %d\r\n", len(buf)) + fmt.Fprintf(req, "\r\n") + req.Write(buf) + if _, err = conn.Write(req.Bytes()); err != nil { return fmt.Errorf("writing client info: %w", err) } - // wait for `OK` response - resp := make([]byte, 2) - if _, err := conn.Read(resp); err != nil { - return fmt.Errorf("reading server response: %w", err) + // wait for HTTP 200 OK response + reader := stdbufio.NewReader(conn) + resp, err := http.ReadResponse(reader, nil) + if err != nil { + return fmt.Errorf("reading HTTP response: %w", err) } - if string(resp) != "OK" { - return fmt.Errorf("invalid server response: %s", resp) + defer resp.Body.Close() + if resp.StatusCode != 200 { + return fmt.Errorf("invalid server response: %s", resp.Status) } return nil } +const prefix = "CLIENTINFO " + type readPacketConn struct { N.PacketConn ctx context.Context info *ClientInfo logger log.ContextLogger + + reader io.Reader + destination metadata.Socksaddr + readErr error } // newReadPacketConn creates a readPacketConn and reads client info from it. If successful, the // info is stored in the context. func newReadPacketConn(ctx context.Context, conn N.PacketConn, logger log.ContextLogger) N.PacketConn { - c := &readPacketConn{PacketConn: conn, ctx: ctx, logger: logger} - info, err := c.readInfo() - if err != nil { - logger.Error("reading client info ", err) - return conn + c := &readPacketConn{ + PacketConn: conn, + ctx: ctx, + logger: logger, + } + if err := c.readInfo(); err != nil { + logger.Warn("reading client info: ", err) } - - service.ContextWithPtr(ctx, info) return c } -// readInfo reads and decodes client info, then sends an OK response. -func (c *readPacketConn) readInfo() (*ClientInfo, error) { +func (c *readPacketConn) ReadPacket(b *buf.Buffer) (destination metadata.Socksaddr, err error) { + if c.readErr != nil { + return c.destination, c.readErr + } + return c.PacketConn.ReadPacket(b) +} + +// readInfo reads and decodes client info if the first packet is a CLIENTINFO packet, then sends an +// OK response. +func (c *readPacketConn) readInfo() error { buffer := buf.NewPacket() defer buffer.Release() destination, err := c.ReadPacket(buffer) if err != nil { - return nil, fmt.Errorf("reading packet from client: %w", err) + c.readErr = err + return err + } + data := buffer.Bytes() + if !bytes.HasPrefix(data, []byte(prefix)) { + // not a client info packet, wrap with cached packet conn so the packet can be read again + c.PacketConn = bufio.NewCachedPacketConn(c.PacketConn, buffer, destination) + return nil } var info ClientInfo - if err := json.Unmarshal(buffer.Bytes(), &info); err != nil { - return nil, fmt.Errorf("decoding client info: %w", err) + if err := json.Unmarshal(data[len(prefix):], &info); err != nil { + return fmt.Errorf("unmarshaling client info: %w", err) } c.info = &info - // send `OK` response buffer.Reset() buffer.WriteString("OK") if err := c.WritePacket(buffer, destination); err != nil { - return nil, fmt.Errorf("writing OK response to client: %w", err) + return fmt.Errorf("writing OK response: %w", err) } - return &info, nil + service.ContextWithPtr(c.ctx, &info) + return nil } type writePacketConn struct { @@ -311,13 +380,14 @@ func (c *writePacketConn) PacketConnHandshakeSuccess(conn net.PacketConn) error return nil } -// sendInfo marshals and sends client info, then waits for OK. +// sendInfo marshals and sends client info as a CLIENTINFO packet, then waits for OK. func (c *writePacketConn) sendInfo(conn net.PacketConn) error { buf, err := json.Marshal(c.info) if err != nil { return fmt.Errorf("marshaling client info: %w", err) } - _, err = conn.WriteTo(buf, c.metadata.Destination) + packet := append([]byte(prefix), buf...) + _, err = conn.WriteTo(packet, c.metadata.Destination) if err != nil { return fmt.Errorf("writing packet: %w", err) } diff --git a/tracker/clientcontext/tracker_test.go b/tracker/clientcontext/tracker_test.go index 93827b9..e733727 100644 --- a/tracker/clientcontext/tracker_test.go +++ b/tracker/clientcontext/tracker_test.go @@ -28,51 +28,68 @@ const testOptionsPath = "../../testdata/options" func TestIntegration(t *testing.T) { cInfo := ClientInfo{ - DeviceID: "sing-box-extensions", + DeviceID: "lantern-box", Platform: "linux", IsPro: false, CountryCode: "US", Version: "9.0", } - ctx := box.BoxContext() + ctx := box.BaseContext() logger := log.NewNOPFactory().NewLogger("") - clientTracker := NewClientContextTracker(cInfo, MatchBounds{[]string{"any"}, []string{"any"}}, logger) - clientOpts, clientBox := newTestBox(ctx, t, testOptionsPath+"/http_client.json", clientTracker) - - httpInbound, exists := clientBox.Inbound().Get("http-client") - require.True(t, exists, "http-client inbound should exist") - require.Equal(t, constant.TypeHTTP, httpInbound.Type(), "http-client should be a HTTP inbound") - - // this cannot actually be empty or we would have failed to create the box instance - proxyAddr := getProxyAddress(clientOpts.Inbounds) - serverTracker := NewClientContextReader(MatchBounds{[]string{"any"}, []string{"any"}}, logger) _, serverBox := newTestBox(ctx, t, testOptionsPath+"/http_server.json", serverTracker) mTracker := &mockTracker{} serverBox.Router().AppendTracker(mTracker) - require.NoError(t, clientBox.Start()) - defer clientBox.Close() require.NoError(t, serverBox.Start()) defer serverBox.Close() httpServer := startHTTPServer() defer httpServer.Close() + clientOpts, clientBox := newTestBox(ctx, t, testOptionsPath+"/http_client.json", nil) + + httpInbound, exists := clientBox.Inbound().Get("http-client") + require.True(t, exists, "http-client inbound should exist") + require.Equal(t, constant.TypeHTTP, httpInbound.Type(), "http-client should be a HTTP inbound") + + // this cannot actually be empty or we would have failed to create the box instance + proxyAddr := getProxyAddress(clientOpts.Inbounds) + + require.NoError(t, clientBox.Start()) + defer clientBox.Close() + proxyURL, _ := url.Parse("http://" + proxyAddr) httpClient := &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyURL(proxyURL), }, } - req, err := http.NewRequest("GET", httpServer.URL, nil) - require.NoError(t, err) + addr := httpServer.URL - _, err = httpClient.Do(req) - require.NoError(t, err) + t.Run("without ClientContext tracker", func(t *testing.T) { + req, err := http.NewRequest("GET", addr+"/ip", nil) + require.NoError(t, err) + + _, err = httpClient.Do(req) + require.NoError(t, err) - require.Equal(t, cInfo, *mTracker.info) + require.Nil(t, mTracker.info) + }) + t.Run("with ClientContext tracker", func(t *testing.T) { + clientTracker := NewClientContextTracker(cInfo, MatchBounds{[]string{"any"}, []string{"any"}}, logger) + clientBox.Router().AppendTracker(clientTracker) + req, err := http.NewRequest("GET", addr+"/ip", nil) + require.NoError(t, err) + + _, err = httpClient.Do(req) + require.NoError(t, err) + + info := mTracker.info + require.NotNil(t, info) + require.Equal(t, cInfo, *info) + }) } func getProxyAddress(inbounds []option.Inbound) string { @@ -106,7 +123,9 @@ func newTestBox(ctx context.Context, t *testing.T, configPath string, tracker *C }) require.NoError(t, err) - instance.Router().AppendTracker(tracker) + if tracker != nil { + instance.Router().AppendTracker(tracker) + } return options, instance }