Skip to content

Commit b6f1fb0

Browse files
committed
feat(tcpreuse): add options for sharing TCP listeners amongst TCP, WS, and WSS transports
1 parent 9038a72 commit b6f1fb0

File tree

10 files changed

+635
-24
lines changed

10 files changed

+635
-24
lines changed

p2p/transport/tcp/tcp.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/libp2p/go-libp2p/core/peer"
1414
"github.com/libp2p/go-libp2p/core/transport"
1515
"github.com/libp2p/go-libp2p/p2p/net/reuseport"
16+
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
1617

1718
logging "github.com/ipfs/go-log/v2"
1819
ma "github.com/multiformats/go-multiaddr"
@@ -33,6 +34,9 @@ type canKeepAlive interface {
3334

3435
var _ canKeepAlive = &net.TCPConn{}
3536

37+
// Deprecated: Use tcpreuse.ReuseportIsAvailable
38+
var ReuseportIsAvailable = tcpreuse.ReuseportIsAvailable
39+
3640
func tryKeepAlive(conn net.Conn, keepAlive bool) {
3741
keepAliveConn, ok := conn.(canKeepAlive)
3842
if !ok {
@@ -113,6 +117,13 @@ func WithMetrics() Option {
113117
}
114118
}
115119

120+
func WithSharedTCP(mgr *tcpreuse.ConnMgr) Option {
121+
return func(tr *TcpTransport) error {
122+
tr.sharedTcp = mgr
123+
return nil
124+
}
125+
}
126+
116127
// TcpTransport is the TCP transport.
117128
type TcpTransport struct {
118129
// Connection upgrader for upgrading insecure stream connections to
@@ -122,6 +133,9 @@ type TcpTransport struct {
122133
disableReuseport bool // Explicitly disable reuseport.
123134
enableMetrics bool
124135

136+
// share and demultiplex TCP listeners across multiple transports
137+
sharedTcp *tcpreuse.ConnMgr
138+
125139
// TCP connect timeout
126140
connectTimeout time.Duration
127141

@@ -168,6 +182,10 @@ func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Co
168182
defer cancel()
169183
}
170184

185+
if t.sharedTcp != nil {
186+
return t.sharedTcp.DialContext(ctx, raddr)
187+
}
188+
171189
if t.UseReuseport() {
172190
return t.reuse.DialContext(ctx, raddr)
173191
}
@@ -233,10 +251,10 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p
233251

234252
// UseReuseport returns true if reuseport is enabled and available.
235253
func (t *TcpTransport) UseReuseport() bool {
236-
return !t.disableReuseport && ReuseportIsAvailable()
254+
return !t.disableReuseport && tcpreuse.ReuseportIsAvailable()
237255
}
238256

239-
func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) {
257+
func (t *TcpTransport) unsharedMAListen(laddr ma.Multiaddr) (manet.Listener, error) {
240258
if t.UseReuseport() {
241259
return t.reuse.Listen(laddr)
242260
}
@@ -245,10 +263,18 @@ func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) {
245263

246264
// Listen listens on the given multiaddr.
247265
func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) {
248-
list, err := t.maListen(laddr)
266+
var list manet.Listener
267+
var err error
268+
269+
if t.sharedTcp == nil {
270+
list, err = t.unsharedMAListen(laddr)
271+
} else {
272+
list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.MultistreamSelect)
273+
}
249274
if err != nil {
250275
return nil, err
251276
}
277+
252278
if t.enableMetrics {
253279
list = newTracingListener(&tcpListener{list, 0})
254280
}

p2p/transport/tcp/tcp_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/libp2p/go-libp2p/core/transport"
1515
"github.com/libp2p/go-libp2p/p2p/muxer/yamux"
1616
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
17+
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
1718
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"
1819

1920
ma "github.com/multiformats/go-multiaddr"
@@ -41,9 +42,9 @@ func TestTcpTransport(t *testing.T) {
4142
zero := "/ip4/127.0.0.1/tcp/0"
4243
ttransport.SubtestTransport(t, ta, tb, zero, peerA)
4344

44-
envReuseportVal = false
45+
tcpreuse.EnvReuseportVal = false
4546
}
46-
envReuseportVal = true
47+
tcpreuse.EnvReuseportVal = true
4748
}
4849

4950
func TestTcpTransportWithMetrics(t *testing.T) {
@@ -126,9 +127,9 @@ func TestTcpTransportCantDialDNS(t *testing.T) {
126127
t.Fatal("shouldn't be able to dial dns")
127128
}
128129

129-
envReuseportVal = false
130+
tcpreuse.EnvReuseportVal = false
130131
}
131-
envReuseportVal = true
132+
tcpreuse.EnvReuseportVal = true
132133
}
133134

134135
func TestTcpTransportCantListenUtp(t *testing.T) {
@@ -143,9 +144,9 @@ func TestTcpTransportCantListenUtp(t *testing.T) {
143144
_, err = tpt.Listen(utpa)
144145
require.Error(t, err, "shouldn't be able to listen on utp addr with tcp transport")
145146

146-
envReuseportVal = false
147+
tcpreuse.EnvReuseportVal = false
147148
}
148-
envReuseportVal = true
149+
tcpreuse.EnvReuseportVal = true
149150
}
150151

151152
func TestDialWithUpdates(t *testing.T) {

p2p/transport/tcpreuse/demultiplex.go

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
package tcpreuse
2+
3+
import (
4+
"bufio"
5+
"errors"
6+
"fmt"
7+
"io"
8+
"math"
9+
"net"
10+
"time"
11+
12+
ma "github.com/multiformats/go-multiaddr"
13+
manet "github.com/multiformats/go-multiaddr/net"
14+
)
15+
16+
// peekAble is implemented by connections that can expose upcoming bytes
// without consuming them.
type peekAble interface {
	// Peek returns the next n bytes without advancing the reader. The
	// returned bytes are only valid until the next read call. When fewer
	// than n bytes are returned, the error explains why the read is short;
	// for a *bufio.Reader it is bufio.ErrBufferFull if n exceeds the
	// reader's buffer size.
	Peek(n int) ([]byte, error)
}

// Compile-time check that *bufio.Reader satisfies peekAble.
var _ peekAble = (*bufio.Reader)(nil)
25+
26+
// DemultiplexedConnType identifies the protocol family a connection was
// classified as, based on the first bytes read from it.
type DemultiplexedConnType int

const (
	// Unknown is the zero value; it stands for a connection whose protocol
	// was not recognized.
	Unknown DemultiplexedConnType = iota
	// MultistreamSelect marks a multistream-select connection.
	MultistreamSelect
	// HTTP marks a plain HTTP connection.
	HTTP
	// TLS marks a TLS connection.
	TLS
)

// String returns the name of the connection type. Values outside the declared
// constants (and Unknown itself) are rendered as "Unknown(<n>)".
func (t DemultiplexedConnType) String() string {
	switch t {
	case MultistreamSelect:
		return "MultistreamSelect"
	case HTTP:
		return "HTTP"
	case TLS:
		return "TLS"
	default:
		return fmt.Sprintf("Unknown(%d)", int(t))
	}
}

// IsKnown reports whether t is one of the recognized connection types, i.e.
// MultistreamSelect, HTTP or TLS. Unknown and any out-of-range value yield
// false.
func (t DemultiplexedConnType) IsKnown() bool {
	// Both bounds must hold; joining them with || would accept every value.
	return t >= MultistreamSelect && t <= TLS
}
51+
52+
func ConnTypeFromConn(c net.Conn) (DemultiplexedConnType, manet.Conn, error) {
53+
if err := c.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
54+
closeErr := c.Close()
55+
return 0, nil, errors.Join(err, closeErr)
56+
}
57+
58+
s, sc, err := ReadSampleFromConn(c)
59+
if err != nil {
60+
closeErr := c.Close()
61+
return 0, nil, errors.Join(err, closeErr)
62+
}
63+
64+
if err := c.SetReadDeadline(time.Time{}); err != nil {
65+
closeErr := c.Close()
66+
return 0, nil, errors.Join(err, closeErr)
67+
}
68+
69+
if IsMultistreamSelect(s) {
70+
return MultistreamSelect, sc, nil
71+
}
72+
if IsTLS(s) {
73+
return TLS, sc, nil
74+
}
75+
if IsHTTP(s) {
76+
return HTTP, sc, nil
77+
}
78+
return Unknown, sc, nil
79+
}
80+
81+
// ReadSampleFromConn read the sample and returns a reader which still include the sample, so it can be kept undamaged.
82+
// If an error occurs it only return the error.
83+
func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) {
84+
if peekAble, ok := c.(peekAble); ok {
85+
b, err := peekAble.Peek(len(Sample{}))
86+
switch {
87+
case err == nil:
88+
mac, err := manet.WrapNetConn(c)
89+
if err != nil {
90+
return Sample{}, nil, err
91+
}
92+
93+
return Sample(b), mac, nil
94+
case errors.Is(err, bufio.ErrBufferFull):
95+
// fallback to sampledConn
96+
default:
97+
return Sample{}, nil, err
98+
}
99+
}
100+
101+
tcpConnLike, ok := c.(tcpConnInterface)
102+
if !ok {
103+
return Sample{}, nil, fmt.Errorf("expected tcp-like connection")
104+
}
105+
106+
laddr, err := manet.FromNetAddr(c.LocalAddr())
107+
if err != nil {
108+
return Sample{}, nil, fmt.Errorf("failed to convert nconn.LocalAddr: %s", err)
109+
}
110+
111+
raddr, err := manet.FromNetAddr(c.RemoteAddr())
112+
if err != nil {
113+
return Sample{}, nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err)
114+
}
115+
116+
sc := &sampledConn{tcpConnInterface: tcpConnLike, maEndpoints: maEndpoints{laddr: laddr, raddr: raddr}}
117+
_, err = io.ReadFull(c, sc.s[:])
118+
if err != nil {
119+
return Sample{}, nil, err
120+
}
121+
122+
return sc.s, sc, nil
123+
}
124+
125+
// tcpConnInterface is our best effort at mimicking a TCPConn's method set.
//
// Note: `SyscallConn() (syscall.RawConn, error)` is skipped on purpose, since
// it can be misused given we've already read a few bytes from the connection.
// If this turns out to be an issue we can revisit the options.
type tcpConnInterface interface {
	net.Conn
	io.ReaderFrom
	io.WriterTo

	// Half-close controls.
	CloseRead() error
	CloseWrite() error

	// Socket tuning and introspection.
	MultipathTCP() (bool, error)
	SetKeepAlive(keepalive bool) error
	SetKeepAlivePeriod(d time.Duration) error
	SetLinger(sec int) error
	SetNoDelay(noDelay bool) error
}
143+
144+
type maEndpoints struct {
145+
laddr ma.Multiaddr
146+
raddr ma.Multiaddr
147+
}
148+
149+
// LocalMultiaddr returns the local address associated with
150+
// this connection
151+
func (c *maEndpoints) LocalMultiaddr() ma.Multiaddr {
152+
return c.laddr
153+
}
154+
155+
// RemoteMultiaddr returns the remote address associated with
156+
// this connection
157+
func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr {
158+
return c.raddr
159+
}
160+
161+
type sampledConn struct {
162+
tcpConnInterface
163+
maEndpoints
164+
165+
s Sample
166+
readFromSample uint8
167+
}
168+
169+
var _ = [math.MaxUint8]struct{}{}[len(Sample{})] // compiletime assert sampledConn.readFromSample wont overflow
170+
var _ io.ReaderFrom = (*sampledConn)(nil)
171+
var _ io.WriterTo = (*sampledConn)(nil)
172+
173+
func (sc *sampledConn) Read(b []byte) (int, error) {
174+
if int(sc.readFromSample) != len(sc.s) {
175+
red := copy(b, sc.s[sc.readFromSample:])
176+
sc.readFromSample += uint8(red)
177+
return red, nil
178+
}
179+
180+
return sc.tcpConnInterface.Read(b)
181+
}
182+
183+
// forward optimizations
184+
func (sc *sampledConn) ReadFrom(r io.Reader) (int64, error) {
185+
return io.Copy(sc.tcpConnInterface, r)
186+
}
187+
188+
// forward optimizations
189+
func (sc *sampledConn) WriteTo(w io.Writer) (total int64, err error) {
190+
if int(sc.readFromSample) != len(sc.s) {
191+
b := sc.s[sc.readFromSample:]
192+
written, err := w.Write(b)
193+
if written < 0 || len(b) < written {
194+
// buggy writer, harden against this
195+
sc.readFromSample = uint8(len(sc.s))
196+
total = int64(len(sc.s))
197+
} else {
198+
sc.readFromSample += uint8(written)
199+
total += int64(written)
200+
}
201+
if err != nil {
202+
return total, err
203+
}
204+
}
205+
206+
written, err := io.Copy(w, sc.tcpConnInterface)
207+
total += written
208+
return total, err
209+
}
210+
211+
// Matcher decides whether a Sample belongs to a given protocol.
type Matcher interface {
	Match(s Sample) bool
}

// Sample is the leading bytes of a connection used for protocol detection.
// Its size might evolve over time.
type Sample [3]byte

// Matchers are implemented here instead of in the transports so we can easily fuzz them together.

// IsMultistreamSelect reports whether s is exactly the bytes 0x13, '/', 'm'.
func IsMultistreamSelect(s Sample) bool {
	return s == Sample{0x13, '/', 'm'}
}

// IsHTTP reports whether s equals one of the three-byte prefixes
// "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA" or "PAT".
func IsHTTP(s Sample) bool {
	prefix := string(s[:])
	for _, method := range [...]string{"GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT"} {
		if prefix == method {
			return true
		}
	}
	return false
}

// IsTLS reports whether s is 0x16 0x03 followed by a byte in the range
// 0x01 through 0x04.
func IsTLS(s Sample) bool {
	return s[0] == 0x16 && s[1] == 0x03 && s[2] >= 0x01 && s[2] <= 0x04
}

0 commit comments

Comments
 (0)