diff --git a/peerconnection.go b/peerconnection.go index 63bd397b2e6..acdffbf9f86 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -12,7 +12,6 @@ import ( "crypto/rand" "errors" "fmt" - "io" "slices" "strconv" "strings" @@ -1688,7 +1687,7 @@ func (pc *PeerConnection) handleNonMediaBandwidthProbe() { } } -func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop +func (pc *PeerConnection) handleIncomingSSRC(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) error { //nolint:gocyclo,gocognit,cyclop,lll remoteDescription := pc.RemoteDescription() if remoteDescription == nil { return errPeerConnRemoteDescriptionNil @@ -1725,7 +1724,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err // We read the RTP packet to determine the payload type b := make([]byte, pc.api.settingEngine.getReceiveMTU()) - i, err := rtpStream.Read(b) + i, err := rtpStream.Peek(b) if err != nil { return err } @@ -1802,6 +1801,8 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err return err } + peekedPackets := []*peekedPacket{} + // if the first packet didn't contain simuilcast IDs, then probe more packets var paddingOnly bool for readCount := 0; readCount <= simulcastProbeCount; readCount++ { @@ -1811,11 +1812,16 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err readCount-- } - i, _, err := interceptor.Read(b, nil) + i, attributes, err := interceptor.Read(b, nil) if err != nil { return err } + peekedPackets = append(peekedPackets, &peekedPacket{ + payload: slices.Clone(b[:i]), + attributes: attributes, + }) + if paddingOnly, err = handleUnknownRTPPacket( b[:i], uint8(midExtensionID), //nolint:gosec // G115 uint8(streamIDExtensionID), //nolint:gosec // G115 @@ -1851,6 +1857,7 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err interceptor, rtcpReadStream, rtcpInterceptor, + peekedPackets, ) if err != nil { return err @@ -1930,7 +1937,7 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { //nolint:cyclop continue } - go func(rtpStream io.Reader, ssrc SSRC) { + go func(rtpStream *srtp.ReadStreamSRTP, ssrc SSRC) { if err := pc.handleIncomingSSRC(rtpStream, ssrc); err != nil { pc.log.Errorf(incomingUnhandledRTPSsrc, ssrc, err) } diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index cde1c0f4729..ed995142d75 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -10,6 +10,7 @@ import ( "bufio" "bytes" "context" + "crypto/rand" "errors" "fmt" "io" @@ -2062,14 +2063,13 @@ func TestPeerConnection_Simulcast_RTX(t *testing.T) { //nolint:cyclop assert.NotZero(t, ridID) assert.NotZero(t, rsid) - err = signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string { + assert.NoError(t, signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string { // Original chrome sdp contains no ssrc info https://pastebin.com/raw/JTjX6zg6 re := regexp.MustCompile("(?m)[\r\n]+^.*a=ssrc.*$") res := re.ReplaceAllString(sdp, "") return res - }) - assert.NoError(t, err) + })) // padding only packets should not affect simulcast probe var sequenceNumber uint16 @@ -2493,3 +2493,104 @@ func Test_PeerConnection_RTX_E2E(t *testing.T) { //nolint:cyclop closePairNow(t, pcOffer, pcAnswer) assert.NoError(t, wan.Stop()) } + +// Assert that we don't drop any packets during the probe. +func TestPeerConnection_Simulcast_Probe_PacketLoss(t *testing.T) { //nolint:cyclop + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const rtpPktCount = 10 + pcOffer, pcAnswer, wan := createVNetPair(t, nil) + + rids := []string{"a", "b", "c"} + vp8WriterA, err := NewTrackLocalStaticRTP( + RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[0]), + ) + assert.NoError(t, err) + + vp8WriterB, err := NewTrackLocalStaticRTP( + RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[1]), + ) + assert.NoError(t, err) + + vp8WriterC, err := NewTrackLocalStaticRTP( + RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[2]), + ) + assert.NoError(t, err) + + sender, err := pcOffer.AddTrack(vp8WriterA) + assert.NoError(t, err) + assert.NotNil(t, sender) + + assert.NoError(t, sender.AddEncoding(vp8WriterB)) + assert.NoError(t, sender.AddEncoding(vp8WriterC)) + + expectedBuffer := make([]byte, outboundMTU*rtpPktCount) + _, err = rand.Read(expectedBuffer) + assert.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + pcAnswer.OnTrack(func(trackRemote *TrackRemote, _ *RTPReceiver) { + actualBuffer := []byte{} + + for i := 0; i < rtpPktCount; i++ { + pkt, _, err := trackRemote.ReadRTP() + assert.NoError(t, err) + + actualBuffer = append(actualBuffer, pkt.Payload...) + } + + assert.Equal(t, actualBuffer, expectedBuffer) + cancel() + }) + + var midID, ridID uint8 + for _, extension := range sender.GetParameters().HeaderExtensions { + switch extension.URI { + case sdp.SDESMidURI: + midID = uint8(extension.ID) //nolint:gosec // G115 + case sdp.SDESRTPStreamIDURI: + ridID = uint8(extension.ID) //nolint:gosec // G115 + } + } + assert.NotZero(t, midID) + assert.NotZero(t, ridID) + + assert.NoError(t, signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string { + // Original chrome sdp contains no ssrc info https://pastebin.com/raw/JTjX6zg6 + re := regexp.MustCompile("(?m)[\r\n]+^.*a=ssrc.*$") + res := re.ReplaceAllString(sdp, "") + + return res + })) + + peerConnectionConnected := untilConnectionState(PeerConnectionStateConnected, pcOffer, pcAnswer) + peerConnectionConnected.Wait() + + for sequenceNumber := uint16(0); sequenceNumber < rtpPktCount; sequenceNumber++ { + pkt := &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + SequenceNumber: sequenceNumber, + }, + } + + // Make sure that packets for Stream received before MID/RID don't get dropped + if sequenceNumber > 3 { + assert.NoError(t, pkt.SetExtension(midID, []byte("0"))) + assert.NoError(t, pkt.SetExtension(ridID, []byte(vp8WriterA.RID()))) + } + + offset := int(sequenceNumber) * outboundMTU + pkt.Payload = expectedBuffer[offset : offset+outboundMTU] + assert.NoError(t, vp8WriterA.WriteRTP(pkt)) + } + + <-ctx.Done() + assert.NoError(t, wan.Stop()) + closePairNow(t, pcOffer, pcAnswer) +} diff --git a/peerconnection_renegotiation_test.go b/peerconnection_renegotiation_test.go index b6adcec7186..252fc8d0951 100644 --- a/peerconnection_renegotiation_test.go +++ b/peerconnection_renegotiation_test.go @@ -1105,12 +1105,6 @@ func TestPeerConnection_Renegotiation_Simulcast(t *testing.T) { for _, track := range trackMap { _, _, err := track.ReadRTP() - - // Ignore first Read, this was our peeked data - if err == nil { - _, _, err = track.ReadRTP() - } - assert.Equal(t, err, io.EOF) } } diff --git a/rtpreceiver.go b/rtpreceiver.go index aa0312db2e0..d627668a6f8 100644 --- a/rtpreceiver.go +++ b/rtpreceiver.go @@ -12,6 +12,7 @@ import ( "io" "math" "sync" + "sync/atomic" "time" "github.com/pion/interceptor" @@ -64,8 +65,9 @@ type RTPReceiver struct { tracks []trackStreams - closed, received chan any - mu sync.RWMutex + closed atomic.Bool + closedChan, received chan any + mu sync.RWMutex tr *RTPTransceiver @@ -84,12 +86,12 @@ func (api *API) NewRTPReceiver(kind RTPCodecType, transport *DTLSTransport) (*RT } rtpReceiver := &RTPReceiver{ - kind: kind, - transport: transport, - api: api, - closed: make(chan any), - received: make(chan any), - tracks: []trackStreams{}, + kind: kind, + transport: transport, + api: api, + closedChan: make(chan any), + received: make(chan any), + tracks: []trackStreams{}, rtxPool: sync.Pool{New: func() any { return make([]byte, api.settingEngine.getReceiveMTU()) }}, @@ -290,7 +292,7 @@ func (r *RTPReceiver) Read(b []byte) (n int, a interceptor.Attributes, err error } return r.tracks[0].rtcpInterceptor.Read(b, a) - case <-r.closed: + case <-r.closedChan: return 0, nil, io.ErrClosedPipe } } @@ -315,7 +317,7 @@ func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, a interceptor. return rtcpInterceptor.Read(b, a) - case <-r.closed: + case <-r.closedChan: return 0, nil, io.ErrClosedPipe } } @@ -359,6 +361,10 @@ func (r *RTPReceiver) haveReceived() bool { } } +func (r *RTPReceiver) haveClosed() bool { + return r.closed.Load() +} + // Stop irreversibly stops the RTPReceiver. func (r *RTPReceiver) Stop() error { //nolint:cyclop r.mu.Lock() @@ -366,7 +372,7 @@ func (r *RTPReceiver) Stop() error { //nolint:cyclop var err error select { - case <-r.closed: + case <-r.closedChan: return err default: } @@ -405,7 +411,8 @@ func (r *RTPReceiver) Stop() error { //nolint:cyclop default: } - close(r.closed) + close(r.closedChan) + r.closed.Store(true) return err } @@ -519,7 +526,7 @@ func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams { func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, a interceptor.Attributes, err error) { select { case <-r.received: - case <-r.closed: + case <-r.closedChan: return 0, nil, io.EOF } @@ -540,6 +547,7 @@ func (r *RTPReceiver) receiveForRid( rtpInterceptor interceptor.RTPReader, rtcpReadStream *srtp.ReadStreamSRTCP, rtcpInterceptor interceptor.RTCPReader, + peekedPackets []*peekedPacket, ) (*TrackRemote, error) { r.mu.Lock() defer r.mu.Unlock() @@ -551,6 +559,7 @@ func (r *RTPReceiver) receiveForRid( r.tracks[i].track.codec = params.Codecs[0] r.tracks[i].track.params = params r.tracks[i].track.ssrc = SSRC(streamInfo.SSRC) + r.tracks[i].track.peekedPackets = peekedPackets r.tracks[i].track.mu.Unlock() r.tracks[i].streamInfo = streamInfo @@ -651,7 +660,7 @@ func (r *RTPReceiver) receiveForRtx( copy(b[headerLength:i-2], b[headerLength+2:i]) select { - case <-r.closed: + case <-r.closedChan: r.rtxPool.Put(b) // nolint:staticcheck return diff --git a/stats_go.go b/stats_go.go index 95bd20d86bf..e7975632679 100644 --- a/stats_go.go +++ b/stats_go.go @@ -215,7 +215,7 @@ func (p *defaultAudioPlayoutStatsProvider) AddTrack(track *TrackRemote) error { } select { - case <-receiver.closed: + case <-receiver.closedChan: p.removeTrackInternal(track) case <-ctx.Done(): return diff --git a/stats_go_test.go b/stats_go_test.go index 77089062c48..af6f2bee9b0 100644 --- a/stats_go_test.go +++ b/stats_go_test.go @@ -2349,7 +2349,7 @@ func TestDefaultAudioPlayoutStatsProvider_AccumulateSnapshot(t *testing.T) { } func TestDefaultAudioPlayoutStatsProvider_AddRemoveTrack(t *testing.T) { - receiver := &RTPReceiver{closed: make(chan any)} + receiver := &RTPReceiver{closedChan: make(chan any)} track := newTrackRemote(RTPCodecTypeAudio, 1234, 0, "", receiver) samplesPerBatch := 960 @@ -2371,7 +2371,7 @@ func TestDefaultAudioPlayoutStatsProvider_AddRemoveTrack(t *testing.T) { } func TestDefaultAudioPlayoutStatsProvider_MultipleProviders(t *testing.T) { - receiver := &RTPReceiver{closed: make(chan any)} + receiver := &RTPReceiver{closedChan: make(chan any)} track := newTrackRemote(RTPCodecTypeAudio, 5555, 0, "", receiver) samplesPerBatch := 960 diff --git a/track_local_static_test.go b/track_local_static_test.go index 093d177054c..48f4b52400e 100644 --- a/track_local_static_test.go +++ b/track_local_static_test.go @@ -827,8 +827,7 @@ func Test_TrackRemote_ReadRTP_UnmarshalError(t *testing.T) { tr := newTrackRemote(RTPCodecTypeVideo, 0, 0, "", recv) tr.mu.Lock() - tr.peeked = []byte{0x80, 96} - tr.peekedAttributes = nil + tr.peekedPackets = []*peekedPacket{{payload: []byte{0x80, 96}}} tr.mu.Unlock() pkt, attrs, err := tr.ReadRTP() diff --git a/track_remote.go b/track_remote.go index ca6160f2e49..4cbecdc9af9 100644 --- a/track_remote.go +++ b/track_remote.go @@ -8,6 +8,7 @@ package webrtc import ( "fmt" + "io" "sync" "time" @@ -15,6 +16,11 @@ import ( "github.com/pion/rtp" ) +type peekedPacket struct { + payload []byte + attributes interceptor.Attributes +} + // TrackRemote represents a single inbound source of media. type TrackRemote struct { mu sync.RWMutex @@ -30,9 +36,9 @@ type TrackRemote struct { params RTPParameters rid string - receiver *RTPReceiver - peeked []byte - peekedAttributes interceptor.Attributes + receiver *RTPReceiver + + peekedPackets []*peekedPacket audioPlayoutStatsProviders []AudioPlayoutStatsProvider } @@ -116,25 +122,22 @@ func (t *TrackRemote) Codec() RTPCodecParameters { func (t *TrackRemote) Read(b []byte) (n int, attributes interceptor.Attributes, err error) { t.mu.RLock() receiver := t.receiver - peeked := t.peeked != nil + var peekedPkt *peekedPacket + if len(t.peekedPackets) != 0 { + peekedPkt = t.peekedPackets[0] + t.peekedPackets = t.peekedPackets[1:] + } t.mu.RUnlock() - if peeked { - t.mu.Lock() - data := t.peeked - attributes = t.peekedAttributes - - t.peeked = nil - t.peekedAttributes = nil - t.mu.Unlock() - // someone else may have stolen our packet when we - // released the lock. Deal with it. - if data != nil { - n = copy(b, data) - err = t.checkAndUpdateTrack(b) - - return n, attributes, err - } + if receiver.haveClosed() { + return 0, nil, io.EOF + } + + if peekedPkt != nil { + n = copy(b, peekedPkt.payload) + err = t.checkAndUpdateTrack(b) + + return n, peekedPkt.attributes, err } // If there's a separate RTX track and an RTX packet is available, return that @@ -142,17 +145,15 @@ func (t *TrackRemote) Read(b []byte) (n int, attributes interceptor.Attributes, n = copy(b, rtxPacketReceived.pkt) attributes = rtxPacketReceived.attributes rtxPacketReceived.release() - err = nil - } else { - // If there's no separate RTX track (or there's a separate RTX track but no RTX packet waiting), wait for and return - // a packet from the main track - n, attributes, err = receiver.readRTP(b, t) - if err != nil { - return n, attributes, err - } - err = t.checkAndUpdateTrack(b) + return n, attributes, nil + } + + n, attributes, err = receiver.readRTP(b, t) + if err != nil { + return n, attributes, err } + err = t.checkAndUpdateTrack(b) return n, attributes, err } @@ -212,8 +213,7 @@ func (t *TrackRemote) peek(b []byte) (n int, a interceptor.Attributes, err error // that case. data := make([]byte, n) n = copy(data, b[:n]) - t.peeked = data - t.peekedAttributes = a + t.peekedPackets = append(t.peekedPackets, &peekedPacket{payload: data, attributes: a}) t.mu.Unlock() return