Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions tracker/clientcontext/injector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
package clientcontext

import (
"context"
"encoding/json"
"fmt"
"net"
"sync"

"github.com/sagernet/sing-box/adapter"
N "github.com/sagernet/sing/common/network"
)

var (
_ (adapter.ConnectionTracker) = (*ClientContextInjector)(nil)
_ (N.ConnHandshakeSuccess) = (*writeConn)(nil)
_ (N.PacketConnHandshakeSuccess) = (*writePacketConn)(nil)
)

// ClientContextInjector is a connection tracker that sends client info to a ClientContext Manager.
type ClientContextInjector struct {
getInfo GetClientInfoFn
inboundRule *boundsRule
outboundRule *boundsRule
ruleMu sync.RWMutex
}

// NewClientContextInjector creates a tracker for injecting client info.
func NewClientContextInjector(fn GetClientInfoFn, bounds MatchBounds) *ClientContextInjector {
return &ClientContextInjector{
inboundRule: newBoundsRule(bounds.Inbound),
outboundRule: newBoundsRule(bounds.Outbound),
getInfo: fn,
}
}

// RoutedConnection wraps the connection for writing client info.
func (t *ClientContextInjector) RoutedConnection(
ctx context.Context,
conn net.Conn,
metadata adapter.InboundContext,
matchedRule adapter.Rule,
matchOutbound adapter.Outbound,
) net.Conn {
if !t.match(metadata.Inbound, matchOutbound.Tag()) {
return conn
}
info := t.getInfo()
return newWriteConn(conn, &info)
}

// RoutedPacketConnection wraps the packet connection for writing client info.
func (t *ClientContextInjector) RoutedPacketConnection(
ctx context.Context,
conn N.PacketConn,
metadata adapter.InboundContext,
matchedRule adapter.Rule,
matchOutbound adapter.Outbound,
) N.PacketConn {
if !t.match(metadata.Inbound, matchOutbound.Tag()) {
return conn
}
info := t.getInfo()
return newWritePacketConn(conn, metadata, &info)
}

func (t *ClientContextInjector) match(inbound, outbound string) bool {
t.ruleMu.RLock()
defer t.ruleMu.RUnlock()
return t.inboundRule.match(inbound) && t.outboundRule.match(outbound)
}

func (t *ClientContextInjector) UpdateBounds(bounds MatchBounds) {
t.ruleMu.Lock()
t.inboundRule = newBoundsRule(bounds.Inbound)
t.outboundRule = newBoundsRule(bounds.Outbound)
t.ruleMu.Unlock()
}

// writeConn sends client info after handshake.
type writeConn struct {
net.Conn
info *ClientInfo
}

func newWriteConn(conn net.Conn, info *ClientInfo) net.Conn {
return &writeConn{Conn: conn, info: info}
}

// ConnHandshakeSuccess sends client info upon successful handshake with the server.
func (c *writeConn) ConnHandshakeSuccess(conn net.Conn) error {
if err := c.sendInfo(conn); err != nil {
return fmt.Errorf("sending client info: %w", err)
}
return nil
}

// sendInfo marshals and sends client info as an HTTP POST, then waits for HTTP 200 OK.
func (c *writeConn) sendInfo(conn net.Conn) error {
buf, err := json.Marshal(c.info)
if err != nil {
return fmt.Errorf("marshaling client info: %w", err)
}
packet := append([]byte(packetPrefix), buf...)
if _, err = conn.Write(packet); err != nil {
return fmt.Errorf("writing client info: %w", err)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could just use a normal http.Request here and then err := req.Write(conn) I think?


// wait for `OK` response
var resp [2]byte
if _, err := conn.Read(resp[:]); err != nil {
return fmt.Errorf("reading response: %w", err)
}
if string(resp[:]) != "OK" {
return fmt.Errorf("invalid response: %s", resp)
}
return nil
}

type writePacketConn struct {
N.PacketConn
metadata adapter.InboundContext
info *ClientInfo
}

func newWritePacketConn(
conn N.PacketConn,
metadata adapter.InboundContext,
info *ClientInfo,
) N.PacketConn {
return &writePacketConn{
PacketConn: conn,
metadata: metadata,
info: info,
}
}

// PacketConnHandshakeSuccess sends client info upon successful handshake.
func (c *writePacketConn) PacketConnHandshakeSuccess(conn net.PacketConn) error {
if err := c.sendInfo(conn); err != nil {
return fmt.Errorf("sending client info: %w", err)
}
return nil
}

// sendInfo marshals and sends client info as a CLIENTINFO packet, then waits for OK.
func (c *writePacketConn) sendInfo(conn net.PacketConn) error {
buf, err := json.Marshal(c.info)
if err != nil {
return fmt.Errorf("marshaling client info: %w", err)
}
packet := append([]byte(packetPrefix), buf...)
if _, err = conn.WriteTo(packet, c.metadata.Destination); err != nil {
return fmt.Errorf("writing packet: %w", err)
}

// wait for `OK` response
var resp [2]byte
if _, _, err := conn.ReadFrom(resp[:]); err != nil {
return fmt.Errorf("reading response: %w", err)
}
if string(resp[:]) != "OK" {
return fmt.Errorf("invalid response: %s", resp)
}
return nil
}
211 changes: 211 additions & 0 deletions tracker/clientcontext/manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
package clientcontext

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"sync"

"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

var _ (adapter.ConnectionTracker) = (*Manager)(nil)

type clientInfoKey struct{}

// ContextWithClientInfo returns a new context with the given ClientInfo.
func ContextWithClientInfo(ctx context.Context, info ClientInfo) context.Context {
return context.WithValue(ctx, clientInfoKey{}, info)
}

// ClientInfoFromContext retrieves the ClientInfo from the context.
func ClientInfoFromContext(ctx context.Context) (ClientInfo, bool) {
info, ok := ctx.Value(clientInfoKey{}).(ClientInfo)
return info, ok
}

// Manager is a ConnectionTracker that manages ClientInfo for connections.
type Manager struct {
logger log.ContextLogger
trackers []adapter.ConnectionTracker

inboundRule *boundsRule
outboundRule *boundsRule
ruleMu sync.RWMutex
}

// NewManager creates a new ClientContext Manager.
func NewManager(bounds MatchBounds, logger log.ContextLogger) *Manager {
return &Manager{
trackers: []adapter.ConnectionTracker{},
logger: logger,
inboundRule: newBoundsRule(bounds.Inbound),
outboundRule: newBoundsRule(bounds.Outbound),
}
}

// AppendTracker appends a ConnectionTracker to the Manager.
func (m *Manager) AppendTracker(tracker adapter.ConnectionTracker) {
m.trackers = append(m.trackers, tracker)
}

func (m *Manager) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn {
if !m.match(metadata.Inbound, matchOutbound.Tag()) {
return conn
}
c := &readConn{
Conn: conn,
reader: conn,
mgr: m,
}
info, err := c.readInfo()
if err != c.readErr {
m.logger.Error("failed to read client info ", "tag", "clientcontext-tracker", "error", err)
}
if err != nil {
return c
}
if info == nil {
return c
}
ctx = ContextWithClientInfo(ctx, *info)
conn = c
for _, tracker := range m.trackers {
conn = tracker.RoutedConnection(ctx, conn, metadata, matchedRule, matchOutbound)
}
return conn
}

func (m *Manager) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) N.PacketConn {
if !m.match(metadata.Inbound, matchOutbound.Tag()) {
return conn
}
c := &readPacketConn{
PacketConn: conn,
mgr: m,
}
info, err := c.readInfo()
if err != c.readErr {
m.logger.Error("failed to read client info ", "tag", "clientcontext-tracker", "error", err)
}
if err != nil {
return c
}
if info == nil {
return c
}
ctx = ContextWithClientInfo(ctx, *info)
conn = c
for _, tracker := range m.trackers {
conn = tracker.RoutedPacketConnection(ctx, conn, metadata, matchedRule, matchOutbound)
}
return conn
}

func (m *Manager) match(inbound, outbound string) bool {
m.ruleMu.RLock()
defer m.ruleMu.RUnlock()
return m.inboundRule.match(inbound) && m.outboundRule.match(outbound)
}

func (m *Manager) UpdateBounds(bounds MatchBounds) {
m.ruleMu.Lock()
m.inboundRule = newBoundsRule(bounds.Inbound)
m.outboundRule = newBoundsRule(bounds.Outbound)
m.ruleMu.Unlock()
}

// readConn reads client info from the connection on creation.
type readConn struct {
net.Conn
mgr *Manager
reader io.Reader
n int
readErr error
}

func (c *readConn) Read(b []byte) (n int, err error) {
if c.readErr != nil {
return c.n, c.readErr
}
return c.reader.Read(b)
}

// readInfo reads and decodes client info, then sends an HTTP 200 OK response.
func (c *readConn) readInfo() (*ClientInfo, error) {
var buf [32]byte
n, err := c.Conn.Read(buf[:])
if err != nil {
c.readErr = err
c.n = n
return nil, err
}
if !bytes.HasPrefix(buf[:n], []byte(packetPrefix)) {
c.reader = io.MultiReader(bytes.NewReader(buf[:n]), c.Conn)
return nil, nil
}

var info ClientInfo
reader := io.MultiReader(bytes.NewReader(buf[len(packetPrefix):n]), c.Conn)
if err := json.NewDecoder(reader).Decode(&info); err != nil {
return nil, fmt.Errorf("decoding client info: %w", err)
}

if _, err := c.Write([]byte("OK")); err != nil {
return nil, fmt.Errorf("writing OK response: %w", err)
}
return &info, nil
}

type readPacketConn struct {
N.PacketConn
mgr *Manager
destination metadata.Socksaddr
readErr error
}

func (c *readPacketConn) ReadPacket(b *buf.Buffer) (destination metadata.Socksaddr, err error) {
if c.readErr != nil {
return c.destination, c.readErr
}
return c.PacketConn.ReadPacket(b)
}

// readInfo reads and decodes client info if the first packet is a CLIENTINFO packet, then sends an
// OK response.
func (c *readPacketConn) readInfo() (*ClientInfo, error) {
buffer := buf.NewPacket()
defer buffer.Release()

destination, err := c.ReadPacket(buffer)
if err != nil {
c.destination = destination
c.readErr = err
return nil, err
}
data := buffer.Bytes()
if !bytes.HasPrefix(data, []byte(packetPrefix)) {
// not a client info packet, wrap with cached packet conn so the packet can be read again
c.PacketConn = bufio.NewCachedPacketConn(c.PacketConn, buffer, destination)
return nil, nil
}
var info ClientInfo
if err := json.Unmarshal(data[len(packetPrefix):], &info); err != nil {
return nil, fmt.Errorf("unmarshaling client info: %w", err)
}

buffer.Reset()
buffer.WriteString("OK")
if err := c.WritePacket(buffer, destination); err != nil {
return nil, fmt.Errorf("writing OK response: %w", err)
}
return &info, nil
}
Loading
Loading