diff --git a/.github/lint-commit-message.sh b/.github/lint-commit-message.sh index 010a3328..d1194645 100755 --- a/.github/lint-commit-message.sh +++ b/.github/lint-commit-message.sh @@ -29,6 +29,7 @@ EndOfMessage exit 1 } + lint_commit_message() { if [[ "$(echo "$1" | awk 'NR == 2 {print $1;}' | wc -c)" -ne 1 ]]; then display_commit_message_error "$1" 'Separate subject from body with a blank line' diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 43608f13..2b6f3552 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -13,10 +13,10 @@ name: Test on: push: branches: - - master + - * pull_request: branches: - - master + - * jobs: test: runs-on: ubuntu-latest diff --git a/.github/workflows/tidy-check.yaml b/.github/workflows/tidy-check.yaml index 03b5189d..470e5abb 100644 --- a/.github/workflows/tidy-check.yaml +++ b/.github/workflows/tidy-check.yaml @@ -13,10 +13,10 @@ name: Go mod tidy on: pull_request: branches: - - master + - * push: branches: - - master + - * jobs: Check: diff --git a/agent.go b/agent.go index b4aba19f..aa2a1d6f 100644 --- a/agent.go +++ b/agent.go @@ -122,9 +122,10 @@ type Agent struct { loggerFactory logging.LoggerFactory log logging.LeveledLogger - net *vnet.Net - tcpMux TCPMux - udpMux UDPMux + net *vnet.Net + tcpMux TCPMux + udpMux UDPMux + udpMuxSrflx UniversalUDPMux interfaceFilter func(string) bool @@ -319,6 +320,7 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit a.tcpMux = newInvalidTCPMux() } a.udpMux = config.UDPMux + a.udpMuxSrflx = config.UDPMuxSrflx if a.net == nil { a.net = vnet.NewNet(nil) @@ -892,6 +894,9 @@ func (a *Agent) removeUfragFromMux() { if a.udpMux != nil { a.udpMux.RemoveConnByUfrag(a.localUfrag) } + if a.udpMuxSrflx != nil { + a.udpMuxSrflx.RemoveConnByUfrag(a.localUfrag) + } } // Close cleans up the Agent diff --git a/agent_config.go b/agent_config.go index e577af10..355e05ae 100644 --- a/agent_config.go +++ b/agent_config.go @@ -150,6 +150,11 @@ type AgentConfig struct { // defer to UDPMux for incoming connections UDPMux UDPMux + // UDPMux is used for multiplexing multiple incoming UDP connections on a single port + // when this is set, the agent ignores PortMin and PortMax configurations and will + // defer to UDPMux for incoming connections + UDPMuxSrflx UniversalUDPMux + // Proxy Dialer is a dialer that should be implemented by the user based on golang.org/x/net/proxy // dial interface in order to support corporate proxies ProxyDialer proxy.Dialer diff --git a/errors.go b/errors.go index 8ca9c2cd..2ef595e5 100644 --- a/errors.go +++ b/errors.go @@ -135,4 +135,8 @@ var ( errICEWriteSTUNMessage = errors.New("the ICE conn can't write STUN messages") errUDPMuxDisabled = errors.New("UDPMux is not enabled") errCandidateIPNotFound = errors.New("could not determine local IP for Mux candidate") + errNoXorAddrMapping = errors.New("no address mapping") + errSendSTUNPacket = errors.New("failed to send STUN packet") + errXORMappedAddrTimeout = errors.New("timeout while waiting for XORMappedAddr") + errNotImplemented = errors.New("not implemented yet") ) diff --git a/gather.go b/gather.go index e248c08d..8d2c40db 100644 --- a/gather.go +++ b/gather.go @@ -97,7 +97,11 @@ func (a *Agent) gatherCandidates(ctx context.Context) { case CandidateTypeServerReflexive: wg.Add(1) go func() { - a.gatherCandidatesSrflx(ctx, a.urls, a.networkTypes) + if a.udpMuxSrflx != nil { + a.gatherCandidatesSrflxUDPMux(ctx, a.urls, a.networkTypes) + } else { + a.gatherCandidatesSrflx(ctx, a.urls, a.networkTypes) + } wg.Done() }() if a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeServerReflexive { @@ -333,6 +337,68 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes [] } } +func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, networkTypes []NetworkType) { + var wg sync.WaitGroup + defer wg.Wait() + + for _, networkType := range networkTypes { + if networkType.IsTCP() { + continue + } + + for i := range urls { + wg.Add(1) + go func(url URL, network string) { + defer wg.Done() + + hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port) + serverAddr, err := a.net.ResolveUDPAddr(network, hostPort) + if err != nil { + a.log.Warnf("failed to resolve stun host: %s: %v", hostPort, err) + return + } + + xoraddr, err := a.udpMuxSrflx.GetXORMappedAddr(serverAddr, stunGatherTimeout) + if err != nil { + a.log.Warnf("could not get server reflexive address %s %s: %v\n", network, url, err) + return + } + + conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String()) + if err != nil { + a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v\n", network, url, err) + return + } + + ip := xoraddr.IP + port := xoraddr.Port + + laddr := conn.LocalAddr().(*net.UDPAddr) + srflxConfig := CandidateServerReflexiveConfig{ + Network: network, + Address: ip.String(), + Port: port, + Component: ComponentRTP, + RelAddr: laddr.IP.String(), + RelPort: laddr.Port, + } + c, err := NewCandidateServerReflexive(&srflxConfig) + if err != nil { + closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create server reflexive candidate: %s %s %d: %v\n", network, ip, port, err)) + return + } + + if err := a.addCandidate(ctx, c, conn); err != nil { + if closeErr := c.close(); closeErr != nil { + a.log.Warnf("Failed to close candidate: %v", closeErr) + } + a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v\n", err) + } + }(*urls[i], networkType.String()) + } + } +} + func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*URL, networkTypes []NetworkType) { var wg sync.WaitGroup defer wg.Wait() diff --git a/udp_mux.go b/udp_mux.go index 0bbe3b5c..f8fdd98a 100644 --- a/udp_mux.go +++ b/udp_mux.go @@ -1,6 +1,7 @@ package ice import ( + "github.com/pion/stun" "io" "net" "os" @@ -8,7 +9,6 @@ import ( "sync" "github.com/pion/logging" - "github.com/pion/stun" ) // UDPMux allows multiple connections to go over a single UDP port @@ -97,12 +97,16 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) { return c, nil } -// RemoveConnByUfrag stops and removes the muxed packet connection +// RemoveConnByUfrag stops and removes the all muxed packet connections with a ufrag prefix func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { + // it happens probably when we close not yet initialized agent + if ufrag == "" { + return + } m.mu.Lock() removedConns := make([]*udpMuxedConn, 0) for key := range m.conns { - if key != ufrag { + if !strings.HasPrefix(key, ufrag) { continue } diff --git a/udp_mux_test.go b/udp_mux_test.go index 3dd47f9a..67121de1 100644 --- a/udp_mux_test.go +++ b/udp_mux_test.go @@ -1,3 +1,4 @@ +//go:build !js // +build !js package ice @@ -59,7 +60,6 @@ func TestUDPMux(t *testing.T) { if ptrSize != 32 { testMuxConnection(t, udpMux, "ufrag3", "udp6") } - wg.Wait() require.NoError(t, udpMux.Close()) diff --git a/udp_mux_universal.go b/udp_mux_universal.go new file mode 100644 index 00000000..29653b0d --- /dev/null +++ b/udp_mux_universal.go @@ -0,0 +1,264 @@ +package ice + +import ( + "fmt" + "github.com/pion/logging" + "github.com/pion/stun" + "net" + "time" +) + +// UniversalUDPMux allows multiple connections to go over a single UDP port for +// host, server reflexive and relayed candidates. +// Actual connection muxing is happening in the UDPMux. +type UniversalUDPMux interface { + UDPMux + GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) + GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) + GetConnForURL(ufrag string, url string) (net.PacketConn, error) +} + +// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom. +// It the passes packets to the UDPMux that does the actual connection muxing. +type UniversalUDPMuxDefault struct { + *UDPMuxDefault + params UniversalUDPMuxParams + + // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents + // stun.XORMappedAddress indexed by the STUN server addr + xorMappedMap map[string]*xorMapped +} + +// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive. +type UniversalUDPMuxParams struct { + Logger logging.LeveledLogger + UDPConn net.PacketConn + XORMappedAddrCacheTTL time.Duration +} + +// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux +func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault { + if params.Logger == nil { + params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") + } + if params.XORMappedAddrCacheTTL == 0 { + params.XORMappedAddrCacheTTL = time.Second * 25 + } + + m := &UniversalUDPMuxDefault{ + params: params, + xorMappedMap: make(map[string]*xorMapped), + } + + // wrap UDP connection, process server reflexive messages + // before they are passed to the UDPMux connection handler (connWorker) + m.params.UDPConn = &udpConn{ + PacketConn: params.UDPConn, + mux: m, + logger: params.Logger, + } + + // embed UDPMux + udpMuxParams := UDPMuxParams{ + Logger: params.Logger, + UDPConn: m.params.UDPConn, + } + m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) + + return m +} + +// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets +type udpConn struct { + net.PacketConn + mux *UniversalUDPMuxDefault + logger logging.LeveledLogger +} + +// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr. +// Not implemented yet. +func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) { + return nil, errNotImplemented +} + +// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers +// and return a unique connection per server. +func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) { + return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url)) +} + +// ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address. +// It passes processed packets further to the UDPMux (maybe this is not really necessary). +func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return + } + + if stun.IsMessage(p[:n]) { + msg := &stun.Message{ + Raw: append([]byte{}, p[:n]...), + } + + if err = msg.Decode(); err != nil { + c.logger.Warnf("Failed to handle decode ICE from %s: %v\n", addr.String(), err) + return n, addr, nil + } + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + // message about this err will be logged in the UDPMux + return + } + + if c.mux.isXORMappedResponse(msg, udpAddr.String()) { + err = c.mux.handleXORMappedResponse(udpAddr, msg) + if err != nil { + c.logger.Debugf("%w: %v", errGetXorMappedAddrResponse, err) + return n, addr, nil + } + return + } + } + return n, addr, err +} + +// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. +func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool { + m.mu.Lock() + defer m.mu.Unlock() + // check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess + _, ok := m.xorMappedMap[stunAddr] + _, err := msg.Get(stun.AttrXORMappedAddress) + return err == nil && ok +} + +// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute +// and set the mapped address for the server +func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error { + m.mu.Lock() + defer m.mu.Unlock() + + mappedAddr, ok := m.xorMappedMap[stunAddr.String()] + if !ok { + return errNoXorAddrMapping + } + + var addr stun.XORMappedAddress + if err := addr.GetFrom(msg); err != nil { + return err + } + + m.xorMappedMap[stunAddr.String()] = mappedAddr + mappedAddr.SetAddr(&addr) + + return nil +} + +// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server. +// Makes a STUN binding request to discover mapped address otherwise. +// Blocks until the stun.XORMappedAddress has been discovered or deadline. +// Method is safe for concurrent use. +func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) { + m.mu.Lock() + mappedAddr, ok := m.xorMappedMap[serverAddr.String()] + // if we already have a mapping for this STUN server (address already received) + // and if it is not too old we return it without making a new request to STUN server + if ok { + if mappedAddr.expired() { + mappedAddr.closeWaiters() + delete(m.xorMappedMap, serverAddr.String()) + ok = false + } else if mappedAddr.pending() { + ok = false + } + } + m.mu.Unlock() + if ok { + return mappedAddr.addr, nil + } + + // otherwise, make a STUN request to discover the address + // or wait for already sent request to complete + waitAddrReceived, err := m.sendStun(serverAddr) + if err != nil { + return nil, errSendSTUNPacket + } + + // block until response was handled by the connWorker routine and XORMappedAddress was updated + select { + case <-waitAddrReceived: + // when channel closed, addr was obtained + m.mu.Lock() + mappedAddr := *m.xorMappedMap[serverAddr.String()] + m.mu.Unlock() + if mappedAddr.addr == nil { + return nil, errNoXorAddrMapping + } + return mappedAddr.addr, nil + case <-time.After(deadline): + return nil, errXORMappedAddrTimeout + } +} + +// sendStun sends a STUN request via UDP conn. +// +// The returned channel is closed when the STUN response has been received. +// Method is safe for concurrent use. +func (m *UniversalUDPMuxDefault) sendStun(serverAddr net.Addr) (chan struct{}, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // if record present in the map, we already sent a STUN request, + // just wait when waitAddrReceived will be closed + addrMap, ok := m.xorMappedMap[serverAddr.String()] + if !ok { + addrMap = &xorMapped{ + expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL), + waitAddrReceived: make(chan struct{}), + } + m.xorMappedMap[serverAddr.String()] = addrMap + } + + req, err := stun.Build(stun.BindingRequest, stun.TransactionID) + if err != nil { + return nil, err + } + + if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil { + return nil, err + } + + return addrMap.waitAddrReceived, nil +} + +type xorMapped struct { + addr *stun.XORMappedAddress + waitAddrReceived chan struct{} + expiresAt time.Time +} + +func (a *xorMapped) closeWaiters() { + select { + case <-a.waitAddrReceived: + // notify was close, ok, that means we received duplicate response + // just exit + break + default: + // notify tha twe have a new addr + close(a.waitAddrReceived) + } +} + +func (a *xorMapped) pending() bool { + return a.addr == nil +} + +func (a *xorMapped) expired() bool { + return a.expiresAt.Before(time.Now()) +} + +func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) { + a.addr = addr + a.closeWaiters() +} diff --git a/udp_mux_universal_test.go b/udp_mux_universal_test.go new file mode 100644 index 00000000..00930359 --- /dev/null +++ b/udp_mux_universal_test.go @@ -0,0 +1,128 @@ +//go:build !js +// +build !js + +package ice + +import ( + "github.com/pion/stun" + "github.com/stretchr/testify/require" + "net" + "sync" + "testing" + "time" +) + +func TestUniversalUDPMux(t *testing.T) { + conn, err := net.ListenUDP(udp, &net.UDPAddr{}) + require.NoError(t, err) + + udpMux := NewUniversalUDPMuxDefault(UniversalUDPMuxParams{ + Logger: nil, + UDPConn: conn, + }) + + require.NoError(t, err) + defer func() { + _ = udpMux.Close() + _ = conn.Close() + }() + + require.NotNil(t, udpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + testMuxSrflxConnection(t, udpMux, "ufrag4", udp) + }() + + wg.Wait() + +} + +func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) { + pktConn, err := udpMux.GetConn(ufrag) + require.NoError(t, err, "error retrieving muxed connection for ufrag") + defer func() { + _ = pktConn.Close() + }() + + remoteConn, err := net.DialUDP(network, nil, &net.UDPAddr{ + Port: udpMux.LocalAddr().(*net.UDPAddr).Port, + }) + require.NoError(t, err, "error dialing test udp connection") + defer func() { + _ = remoteConn.Close() + }() + + // use small value for TTL to check expiration of the address + udpMux.params.XORMappedAddrCacheTTL = time.Millisecond * 20 + testXORIP := net.ParseIP("213.141.156.236") + testXORPort := 21254 + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + address, e := udpMux.GetXORMappedAddr(remoteConn.LocalAddr(), time.Second) + require.NoError(t, e) + require.NotNil(t, address) + require.True(t, address.IP.Equal(testXORIP)) + require.Equal(t, address.Port, testXORPort) + }() + + // wait until GetXORMappedAddr calls sendStun method + time.Sleep(time.Millisecond) + + // check that mapped address filled correctly after sent stun + udpMux.mu.Lock() + mappedAddr, ok := udpMux.xorMappedMap[remoteConn.LocalAddr().String()] + require.True(t, ok) + require.NotNil(t, mappedAddr) + require.True(t, mappedAddr.pending()) + require.False(t, mappedAddr.expired()) + udpMux.mu.Unlock() + + // clean receiver read buffer + buf := make([]byte, receiveMTU) + _, err = remoteConn.Read(buf) + require.NoError(t, err) + + // write back to udpMux XOR message with address + msg := stun.New() + msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} + msg.Add(stun.AttrUsername, []byte(ufrag+":otherufrag")) + addr := &stun.XORMappedAddress{ + IP: testXORIP, + Port: testXORPort, + } + err = addr.AddTo(msg) + require.NoError(t, err) + + msg.Encode() + _, err = remoteConn.Write(msg.Raw) + require.NoError(t, err) + + // wait for the packet to be consumed and parsed by udpMux + wg.Wait() + + // we should get address immediately from the cached map + address, err := udpMux.GetXORMappedAddr(remoteConn.LocalAddr(), time.Second) + require.NoError(t, err) + require.NotNil(t, address) + + udpMux.mu.Lock() + // check mappedAddr is not pending, we didn't send stun twice + require.False(t, mappedAddr.pending()) + + // check expiration by TTL + time.Sleep(time.Millisecond * 21) + require.True(t, mappedAddr.expired()) + udpMux.mu.Unlock() + + // after expire, we send stun request again + // but we not receive response in 5 milliseconds and should get error here + address, err = udpMux.GetXORMappedAddr(remoteConn.LocalAddr(), time.Millisecond*5) + require.NotNil(t, err) + require.Nil(t, address) +}