Revert "Read/Write RTP/RTCP packets with context"

This change caused a ~24% performance decrease

Relates to pion/webrtc#1564

This reverts commit 47a7a64898.
This commit is contained in:
Sean DuBois
2020-12-02 19:36:51 -08:00
parent 3da29b7c0c
commit 9715626a0c
39 changed files with 205 additions and 329 deletions

View File

@@ -3,7 +3,6 @@
package webrtc
import (
"context"
"io"
"testing"
"time"
@@ -113,19 +112,19 @@ func (s *testORTCStack) setSignal(sig *testORTCSignal, isOffer bool) error {
}
// Start the ICE transport
err = s.ice.Start(context.Background(), nil, sig.ICEParameters, &iceRole)
err = s.ice.Start(nil, sig.ICEParameters, &iceRole)
if err != nil {
return err
}
// Start the DTLS transport
err = s.dtls.Start(context.Background(), sig.DTLSParameters)
err = s.dtls.Start(sig.DTLSParameters)
if err != nil {
return err
}
// Start the SCTP transport
err = s.sctp.Start(context.Background(), sig.SCTPCapabilities)
err = s.sctp.Start(sig.SCTPCapabilities)
if err != nil {
return err
}

View File

@@ -3,7 +3,6 @@
package webrtc
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
@@ -18,7 +17,7 @@ import (
"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/crypto/fingerprint"
"github.com/pion/srtp/v2"
"github.com/pion/srtp"
"github.com/pion/webrtc/v3/internal/mux"
"github.com/pion/webrtc/v3/internal/util"
"github.com/pion/webrtc/v3/pkg/rtcerr"
@@ -147,7 +146,7 @@ func (t *DTLSTransport) GetRemoteCertificate() []byte {
return t.remoteCertificate
}
func (t *DTLSTransport) startSRTP(ctx context.Context) error {
func (t *DTLSTransport) startSRTP() error {
srtpConfig := &srtp.Config{
Profile: t.srtpProtectionProfile,
LoggerFactory: t.api.settingEngine.LoggerFactory,
@@ -186,12 +185,12 @@ func (t *DTLSTransport) startSRTP(ctx context.Context) error {
return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
}
srtpSession, err := srtp.NewSessionSRTP(ctx, t.srtpEndpoint, srtpConfig)
srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
if err != nil {
return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
}
srtcpSession, err := srtp.NewSessionSRTCP(ctx, t.srtcpEndpoint, srtpConfig)
srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig)
if err != nil {
return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err)
}
@@ -245,7 +244,7 @@ func (t *DTLSTransport) role() DTLSRole {
}
// Start DTLS transport negotiation with the parameters of the remote DTLS transport
func (t *DTLSTransport) Start(ctx context.Context, remoteParameters DTLSParameters) error {
func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
// Take lock and prepare connection, we must not hold the lock
// when connecting
prepareTransport := func() (DTLSRole, *dtls.Config, error) {
@@ -351,7 +350,7 @@ func (t *DTLSTransport) Start(ctx context.Context, remoteParameters DTLSParamete
return err
}
return t.startSRTP(ctx)
return t.startSRTP()
}
// Stop stops and closes the DTLSTransport object.

View File

@@ -100,7 +100,6 @@ func TestE2E_Audio(t *testing.T) {
go func() {
for {
if err := track.WriteSample(
context.Background(),
media.Sample{Data: silentOpusFrame, Duration: time.Millisecond * 20},
); err != nil {
t.Errorf("Failed to WriteSample: %v", err)

View File

@@ -3,7 +3,6 @@
package main
import (
"context"
"errors"
"fmt"
"io"
@@ -54,9 +53,7 @@ func main() { // nolint:gocognit
go func() {
ticker := time.NewTicker(rtcpPLIInterval)
for range ticker.C {
if rtcpSendErr := peerConnection.WriteRTCP(
context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())}},
); rtcpSendErr != nil {
if rtcpSendErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())}}); rtcpSendErr != nil {
fmt.Println(rtcpSendErr)
}
}
@@ -71,13 +68,13 @@ func main() { // nolint:gocognit
rtpBuf := make([]byte, 1400)
for {
i, readErr := remoteTrack.Read(context.TODO(), rtpBuf)
i, readErr := remoteTrack.Read(rtpBuf)
if readErr != nil {
panic(readErr)
}
// ErrClosedPipe means we don't have any subscribers, this is ok if no peers have connected yet
if _, err = localTrack.Write(context.TODO(), rtpBuf[:i]); err != nil && !errors.Is(err, io.ErrClosedPipe) {
if _, err = localTrack.Write(rtpBuf[:i]); err != nil && !errors.Is(err, io.ErrClosedPipe) {
panic(err)
}
}

View File

@@ -74,7 +74,7 @@ func main() {
}
time.Sleep(sleepTime)
if ivfErr = videoTrack.WriteSample(context.TODO(), media.Sample{Data: frame, Duration: time.Second}); ivfErr != nil {
if ivfErr = videoTrack.WriteSample(media.Sample{Data: frame, Duration: time.Second}); ivfErr != nil {
panic(ivfErr)
}
}

View File

@@ -3,7 +3,6 @@
package main
import (
"context"
"flag"
"fmt"
"time"
@@ -111,7 +110,7 @@ func main() {
}
// Start the ICE transport
err = ice.Start(context.TODO(), nil, remoteSignal.ICEParameters, &iceRole)
err = ice.Start(nil, remoteSignal.ICEParameters, &iceRole)
if err != nil {
panic(err)
}

View File

@@ -3,7 +3,6 @@
package main
import (
"context"
"flag"
"fmt"
"time"
@@ -112,19 +111,19 @@ func main() {
}
// Start the ICE transport
err = ice.Start(context.TODO(), nil, remoteSignal.ICEParameters, &iceRole)
err = ice.Start(nil, remoteSignal.ICEParameters, &iceRole)
if err != nil {
panic(err)
}
// Start the DTLS transport
err = dtls.Start(context.TODO(), remoteSignal.DTLSParameters)
err = dtls.Start(remoteSignal.DTLSParameters)
if err != nil {
panic(err)
}
// Start the SCTP transport
err = sctp.Start(context.TODO(), remoteSignal.SCTPCapabilities)
err = sctp.Start(remoteSignal.SCTPCapabilities)
if err != nil {
panic(err)
}

View File

@@ -3,7 +3,6 @@
package main
import (
"context"
"encoding/json"
"fmt"
"math/rand"
@@ -146,7 +145,7 @@ func writeVideoToTrack(t *webrtc.TrackLocalStaticSample) {
}
time.Sleep(sleepTime)
if err = t.WriteSample(context.TODO(), media.Sample{Data: frame, Duration: time.Second}); err != nil {
if err = t.WriteSample(media.Sample{Data: frame, Duration: time.Second}); err != nil {
fmt.Printf("Finish writing video track: %s ", err)
return
}

View File

@@ -87,7 +87,7 @@ func main() {
}
time.Sleep(sleepTime)
if ivfErr = videoTrack.WriteSample(context.TODO(), media.Sample{Data: frame, Duration: time.Second}); ivfErr != nil {
if ivfErr = videoTrack.WriteSample(media.Sample{Data: frame, Duration: time.Second}); ivfErr != nil {
panic(ivfErr)
}
}
@@ -138,7 +138,7 @@ func main() {
lastGranule = pageHeader.GranulePosition
sampleDuration := time.Duration((sampleCount/48000)*1000) * time.Millisecond
if oggErr = audioTrack.WriteSample(context.TODO(), media.Sample{Data: pageData, Duration: sampleDuration}); oggErr != nil {
if oggErr = audioTrack.WriteSample(media.Sample{Data: pageData, Duration: sampleDuration}); oggErr != nil {
panic(oggErr)
}

View File

@@ -3,7 +3,6 @@
package main
import (
"context"
"fmt"
"time"
@@ -73,9 +72,7 @@ func main() {
go func() {
ticker := time.NewTicker(time.Second * 3)
for range ticker.C {
errSend := peerConnection.WriteRTCP(
context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}},
)
errSend := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}})
if errSend != nil {
fmt.Println(errSend)
}
@@ -85,12 +82,12 @@ func main() {
fmt.Printf("Track has started, of type %d: %s \n", track.PayloadType(), track.Codec().MimeType)
for {
// Read RTP packets being sent to Pion
rtp, readErr := track.ReadRTP(context.TODO())
rtp, readErr := track.ReadRTP()
if readErr != nil {
panic(readErr)
}
if writeErr := outputTrack.WriteRTP(context.TODO(), rtp); writeErr != nil {
if writeErr := outputTrack.WriteRTP(rtp); writeErr != nil {
panic(writeErr)
}
}

View File

@@ -107,9 +107,7 @@ func main() {
go func() {
ticker := time.NewTicker(time.Second * 2)
for range ticker.C {
if rtcpErr := peerConnection.WriteRTCP(
context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}},
); rtcpErr != nil {
if rtcpErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}); rtcpErr != nil {
fmt.Println(rtcpErr)
}
}
@@ -118,7 +116,7 @@ func main() {
b := make([]byte, 1500)
for {
// Read
n, readErr := track.Read(context.TODO(), b)
n, readErr := track.Read(b)
if readErr != nil {
panic(readErr)
}

View File

@@ -3,7 +3,6 @@
package main
import (
"context"
"fmt"
"net"
@@ -104,7 +103,7 @@ func main() {
panic(err)
}
if _, writeErr := videoTrack.Write(context.TODO(), inboundRTPPacket[:n]); writeErr != nil {
if _, writeErr := videoTrack.Write(inboundRTPPacket[:n]); writeErr != nil {
panic(writeErr)
}
}

View File

@@ -3,7 +3,6 @@
package main
import (
"context"
"fmt"
"os"
"time"
@@ -24,7 +23,7 @@ func saveToDisk(i media.Writer, track *webrtc.TrackRemote) {
}()
for {
rtpPacket, err := track.ReadRTP(context.TODO())
rtpPacket, err := track.ReadRTP()
if err != nil {
panic(err)
}
@@ -97,9 +96,7 @@ func main() {
go func() {
ticker := time.NewTicker(time.Second * 3)
for range ticker.C {
errSend := peerConnection.WriteRTCP(
context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}},
)
errSend := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}})
if errSend != nil {
fmt.Println(errSend)
}

View File

@@ -3,7 +3,6 @@
package main
import (
"context"
"errors"
"fmt"
"io"
@@ -81,28 +80,24 @@ func main() {
ticker := time.NewTicker(3 * time.Second)
for range ticker.C {
fmt.Printf("Sending pli for stream with rid: %q, ssrc: %d\n", track.RID(), track.SSRC())
if writeErr := peerConnection.WriteRTCP(
context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}},
); writeErr != nil {
if writeErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}); writeErr != nil {
fmt.Println(writeErr)
}
// Send a remb message with a very high bandwidth to trigger chrome to send also the high bitrate stream
fmt.Printf("Sending remb for stream with rid: %q, ssrc: %d\n", track.RID(), track.SSRC())
if writeErr := peerConnection.WriteRTCP(
context.TODO(), []rtcp.Packet{&rtcp.ReceiverEstimatedMaximumBitrate{Bitrate: 10000000, SenderSSRC: uint32(track.SSRC())}},
); writeErr != nil {
if writeErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.ReceiverEstimatedMaximumBitrate{Bitrate: 10000000, SenderSSRC: uint32(track.SSRC())}}); writeErr != nil {
fmt.Println(writeErr)
}
}
}()
for {
// Read RTP packets being sent to Pion
packet, readErr := track.ReadRTP(context.TODO())
packet, readErr := track.ReadRTP()
if readErr != nil {
panic(readErr)
}
if writeErr := outputTracks[rid].WriteRTP(context.TODO(), packet); writeErr != nil && !errors.Is(writeErr, io.ErrClosedPipe) {
if writeErr := outputTracks[rid].WriteRTP(packet); writeErr != nil && !errors.Is(writeErr, io.ErrClosedPipe) {
panic(writeErr)
}
}

View File

@@ -3,7 +3,6 @@
package main
import (
"context"
"errors"
"fmt"
"io"
@@ -86,7 +85,7 @@ func main() { // nolint:gocognit
var isCurrTrack bool
for {
// Read RTP packets being sent to Pion
rtp, readErr := track.ReadRTP(context.TODO())
rtp, readErr := track.ReadRTP()
if readErr != nil {
panic(readErr)
}
@@ -105,9 +104,7 @@ func main() { // nolint:gocognit
// If just switched to this track, send PLI to get picture refresh
if !isCurrTrack {
isCurrTrack = true
if writeErr := peerConnection.WriteRTCP(
context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}},
); writeErr != nil {
if writeErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}); writeErr != nil {
fmt.Println(writeErr)
}
}
@@ -157,7 +154,7 @@ func main() { // nolint:gocognit
// Keep an increasing sequence number
packet.SequenceNumber = i
// Write out the packet, ignoring closed pipe if nobody is listening
if err := outputTrack.WriteRTP(context.TODO(), packet); err != nil && !errors.Is(err, io.ErrClosedPipe) {
if err := outputTrack.WriteRTP(packet); err != nil && !errors.Is(err, io.ErrClosedPipe) {
panic(err)
}
}

4
go.mod
View File

@@ -6,7 +6,7 @@ require (
github.com/pion/datachannel v1.4.21
github.com/pion/dtls/v2 v2.0.4
github.com/pion/ice/v2 v2.0.13
github.com/pion/interceptor v0.0.4
github.com/pion/interceptor v0.0.3
github.com/pion/logging v0.2.2
github.com/pion/quic v0.1.4
github.com/pion/randutil v0.1.0
@@ -14,7 +14,7 @@ require (
github.com/pion/rtp v1.6.1
github.com/pion/sctp v1.7.11
github.com/pion/sdp/v3 v3.0.3
github.com/pion/srtp/v2 v2.0.0-rc.1
github.com/pion/srtp v1.5.2
github.com/pion/transport v0.11.1
github.com/sclevine/agouti v3.0.0+incompatible
github.com/stretchr/testify v1.6.1

8
go.sum
View File

@@ -106,8 +106,8 @@ github.com/pion/dtls/v2 v2.0.4 h1:WuUcqi6oYMu/noNTz92QrF1DaFj4eXbhQ6dzaaAwOiI=
github.com/pion/dtls/v2 v2.0.4/go.mod h1:qAkFscX0ZHoI1E07RfYPoRw3manThveu+mlTDdOxoGI=
github.com/pion/ice/v2 v2.0.13 h1:lVe7g86tQ0vKdH430hQR/t7zV1oeXbK75130TUArrnw=
github.com/pion/ice/v2 v2.0.13/go.mod h1:mZlypgoynMn2ayhGsjrPY/G/WiRiYO8WCPC6gUeg1RA=
github.com/pion/interceptor v0.0.4 h1:FNq8cFDDv0i+Db9oEKccmMx4rerPm6pzBf4szjNaX2E=
github.com/pion/interceptor v0.0.4/go.mod h1:lPVrf5xfosI989ZcmgPS4WwwRhd+XAyTFaYI2wHf7nU=
github.com/pion/interceptor v0.0.3 h1:VQtmPts/2IgYQtb9sZLTp6B0kIdHE5zBMQ6tCcgdJcM=
github.com/pion/interceptor v0.0.3/go.mod h1:lPVrf5xfosI989ZcmgPS4WwwRhd+XAyTFaYI2wHf7nU=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/mdns v0.0.4 h1:O4vvVqr4DGX63vzmO6Fw9vpy3lfztVWHGCQfyw0ZLSY=
@@ -125,8 +125,8 @@ github.com/pion/sctp v1.7.11 h1:UCnj7MsobLKLuP/Hh+JMiI/6W5Bs/VF45lWKgHFjSIE=
github.com/pion/sctp v1.7.11/go.mod h1:EhpTUQu1/lcK3xI+eriS6/96fWetHGCvBi9MSsnaBN0=
github.com/pion/sdp/v3 v3.0.3 h1:gJK9hk+JFD2NGIM1nXmqNCq1DkVaIZ9dlA3u3otnkaw=
github.com/pion/sdp/v3 v3.0.3/go.mod h1:bNiSknmJE0HYBprTHXKPQ3+JjacTv5uap92ueJZKsRk=
github.com/pion/srtp/v2 v2.0.0-rc.1 h1:UGabuCAIE5Yn5qmFr9H8UGdzYdYIaMt/AZLskcBcumA=
github.com/pion/srtp/v2 v2.0.0-rc.1/go.mod h1:i3Oyh9/RPIdYRXywhs0NwTkMV/XYYr8NGj1E3gsRgk8=
github.com/pion/srtp v1.5.2 h1:25DmvH+fqKZDqvX64vTwnycVwL9ooJxHF/gkX16bDBY=
github.com/pion/srtp v1.5.2/go.mod h1:NiBff/MSxUwMUwx/fRNyD/xGE+dVvf8BOCeXhjCXZ9U=
github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg=
github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA=
github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRhSzR+CZ8=

View File

@@ -11,7 +11,6 @@ import (
"github.com/pion/ice/v2"
"github.com/pion/logging"
"github.com/pion/transport/connctx"
"github.com/pion/webrtc/v3/internal/mux"
)
@@ -71,7 +70,7 @@ func NewICETransport(gatherer *ICEGatherer, loggerFactory logging.LoggerFactory)
}
// Start incoming connectivity checks based on its configured role.
func (t *ICETransport) Start(ctx context.Context, gatherer *ICEGatherer, params ICEParameters, role *ICERole) error {
func (t *ICETransport) Start(gatherer *ICEGatherer, params ICEParameters, role *ICERole) error {
t.lock.Lock()
defer t.lock.Unlock()
@@ -143,11 +142,11 @@ func (t *ICETransport) Start(ctx context.Context, gatherer *ICEGatherer, params
t.conn = iceConn
config := mux.Config{
Conn: connctx.New(t.conn),
Conn: t.conn,
BufferSize: receiveMTU,
LoggerFactory: t.loggerFactory,
}
t.mux = mux.NewMux(ctx, config)
t.mux = mux.NewMux(config)
return nil
}

View File

@@ -3,7 +3,6 @@
package webrtc
import (
"context"
"sync"
"sync/atomic"
"testing"
@@ -26,19 +25,19 @@ type testInterceptor struct {
}
func (t *testInterceptor) BindLocalStream(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
return interceptor.RTPWriterFunc(func(ctx context.Context, p *rtp.Packet, attributes interceptor.Attributes) (int, error) {
return interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) {
// set extension on outgoing packet
p.Header.Extension = true
p.Header.ExtensionProfile = 0xBEDE
assert.NoError(t.t, p.Header.SetExtension(t.extensionID, []byte("write")))
return writer.Write(ctx, p, attributes)
return writer.Write(p, attributes)
})
}
func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
return interceptor.RTPReaderFunc(func(ctx context.Context) (*rtp.Packet, interceptor.Attributes, error) {
p, attributes, err := reader.Read(ctx)
return interceptor.RTPReaderFunc(func() (*rtp.Packet, interceptor.Attributes, error) {
p, attributes, err := reader.Read()
if err != nil {
return nil, nil, err
}
@@ -50,7 +49,7 @@ func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader
// write back a pli
rtcpWriter := t.rtcpWriter.Load().(interceptor.RTCPWriter)
pli := &rtcp.PictureLossIndication{SenderSSRC: info.SSRC, MediaSSRC: info.SSRC}
_, err = rtcpWriter.Write(ctx, []rtcp.Packet{pli}, make(interceptor.Attributes))
_, err = rtcpWriter.Write([]rtcp.Packet{pli}, make(interceptor.Attributes))
assert.NoError(t.t, err)
return p, attributes, nil
@@ -58,8 +57,8 @@ func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader
}
func (t *testInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader {
return interceptor.RTCPReaderFunc(func(ctx context.Context) ([]rtcp.Packet, interceptor.Attributes, error) {
pkts, attributes, err := reader.Read(ctx)
return interceptor.RTCPReaderFunc(func() ([]rtcp.Packet, interceptor.Attributes, error) {
pkts, attributes, err := reader.Read()
if err != nil {
return nil, nil, err
}
@@ -123,7 +122,7 @@ func TestPeerConnection_Interceptor(t *testing.T) {
wg.Add(1)
*pending++
receiverPC.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) {
p, readErr := track.ReadRTP(context.Background())
p, readErr := track.ReadRTP()
if readErr != nil {
t.Fatal(readErr)
}
@@ -134,7 +133,7 @@ func TestPeerConnection_Interceptor(t *testing.T) {
wg.Done()
for {
_, readErr = track.ReadRTP(context.Background())
_, readErr = track.ReadRTP()
if readErr != nil {
return
}
@@ -144,13 +143,13 @@ func TestPeerConnection_Interceptor(t *testing.T) {
wg.Add(1)
*pending++
go func() {
_, readErr := sender.ReadRTCP(context.Background())
_, readErr := sender.ReadRTCP()
assert.NoError(t, readErr)
atomic.AddInt32(pending, -1)
wg.Done()
for {
_, readErr = sender.ReadRTCP(context.Background())
_, readErr = sender.ReadRTCP()
if readErr != nil {
return
}
@@ -167,9 +166,7 @@ func TestPeerConnection_Interceptor(t *testing.T) {
defer wg.Done()
for {
time.Sleep(time.Millisecond * 100)
if routineErr := track.WriteSample(
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
); routineErr != nil {
if routineErr := track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}); routineErr != nil {
t.Error(routineErr)
return
}

View File

@@ -3,7 +3,6 @@
package webrtc
import (
"context"
"sync/atomic"
"github.com/pion/interceptor"
@@ -19,12 +18,12 @@ func (i *interceptorTrackLocalWriter) setRTPWriter(writer interceptor.RTPWriter)
i.rtpWriter.Store(writer)
}
func (i *interceptorTrackLocalWriter) WriteRTP(ctx context.Context, header *rtp.Header, payload []byte) (int, error) {
func (i *interceptorTrackLocalWriter) WriteRTP(header *rtp.Header, payload []byte) (int, error) {
writer := i.rtpWriter.Load().(interceptor.RTPWriter)
if writer == nil {
return 0, nil
}
return writer.Write(ctx, &rtp.Packet{Header: *header, Payload: payload}, make(interceptor.Attributes))
return writer.Write(&rtp.Packet{Header: *header, Payload: payload}, make(interceptor.Attributes))
}

View File

@@ -1,7 +1,6 @@
package mux
import (
"context"
"errors"
"io"
"net"
@@ -35,23 +34,12 @@ func (e *Endpoint) close() error {
// Read reads a packet of len(p) bytes from the underlying conn
// that are matched by the associated MuxFunc
func (e *Endpoint) Read(p []byte) (int, error) {
return e.buffer.ReadContext(context.Background(), p)
}
// ReadContext reads a packet of len(p) bytes from the underlying conn
// that are matched by the associated MuxFunc
func (e *Endpoint) ReadContext(ctx context.Context, p []byte) (int, error) {
return e.buffer.ReadContext(ctx, p)
return e.buffer.Read(p)
}
// Write writes len(p) bytes to the underlying conn
func (e *Endpoint) Write(p []byte) (int, error) {
return e.WriteContext(context.Background(), p)
}
// WriteContext writes len(p) bytes to the underlying conn
func (e *Endpoint) WriteContext(ctx context.Context, p []byte) (int, error) {
n, err := e.mux.nextConn.WriteContext(ctx, p)
n, err := e.mux.nextConn.Write(p)
if errors.Is(err, ice.ErrNoCandidatePairs) {
return 0, nil
} else if errors.Is(err, ice.ErrClosed) {

View File

@@ -2,11 +2,10 @@
package mux
import (
"context"
"net"
"sync"
"github.com/pion/logging"
"github.com/pion/transport/connctx"
"github.com/pion/transport/packetio"
)
@@ -16,7 +15,7 @@ const maxBufferSize = 1000 * 1000 // 1MB
// Config collects the arguments to mux.Mux construction into
// a single structure
type Config struct {
Conn connctx.ConnCtx
Conn net.Conn
BufferSize int
LoggerFactory logging.LoggerFactory
}
@@ -24,7 +23,7 @@ type Config struct {
// Mux allows multiplexing
type Mux struct {
lock sync.RWMutex
nextConn connctx.ConnCtx
nextConn net.Conn
endpoints map[*Endpoint]MatchFunc
bufferSize int
closedCh chan struct{}
@@ -33,7 +32,7 @@ type Mux struct {
}
// NewMux creates a new Mux
func NewMux(ctx context.Context, config Config) *Mux {
func NewMux(config Config) *Mux {
m := &Mux{
nextConn: config.Conn,
endpoints: make(map[*Endpoint]MatchFunc),
@@ -42,7 +41,7 @@ func NewMux(ctx context.Context, config Config) *Mux {
log: config.LoggerFactory.NewLogger("mux"),
}
go m.readLoop(ctx)
go m.readLoop()
return m
}
@@ -97,26 +96,26 @@ func (m *Mux) Close() error {
return nil
}
func (m *Mux) readLoop(ctx context.Context) {
func (m *Mux) readLoop() {
defer func() {
close(m.closedCh)
}()
buf := make([]byte, m.bufferSize)
for {
n, err := m.nextConn.ReadContext(ctx, buf)
n, err := m.nextConn.Read(buf)
if err != nil {
return
}
err = m.dispatch(ctx, buf[:n])
err = m.dispatch(buf[:n])
if err != nil {
return
}
}
}
func (m *Mux) dispatch(ctx context.Context, buf []byte) error {
func (m *Mux) dispatch(buf []byte) error {
var endpoint *Endpoint
m.lock.Lock()
@@ -137,7 +136,7 @@ func (m *Mux) dispatch(ctx context.Context, buf []byte) error {
return nil
}
_, err := endpoint.buffer.WriteContext(ctx, buf)
_, err := endpoint.buffer.Write(buf)
if err != nil {
return err
}

View File

@@ -1,33 +1,29 @@
package mux
import (
"context"
"net"
"testing"
"time"
"github.com/pion/logging"
"github.com/pion/transport/connctx"
"github.com/pion/transport/test"
)
func TestStressDuplex(t *testing.T) {
lim := test.TimeOut(time.Second * 30)
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
// Check for leaking routines
report := test.CheckRoutines(t)
defer report()
// Run the test
stressDuplex(ctx, t)
stressDuplex(t)
}
func stressDuplex(ctx context.Context, t *testing.T) {
ca, cb, stop := pipeMemory(ctx)
func stressDuplex(t *testing.T) {
ca, cb, stop := pipeMemory()
defer func() {
stop(t)
@@ -38,21 +34,13 @@ func stressDuplex(ctx context.Context, t *testing.T) {
MsgCount: 100,
}
t.Run("WithoutContext", func(t *testing.T) {
err := test.StressDuplex(ca, cb.Conn(), opt)
if err != nil {
t.Fatal(err)
}
})
t.Run("WithContext", func(t *testing.T) {
err := test.StressDuplexContext(context.Background(), ca, cb, opt)
if err != nil {
t.Fatal(err)
}
})
err := test.StressDuplex(ca, cb, opt)
if err != nil {
t.Fatal(err)
}
}
func pipeMemory(ctx context.Context) (*Endpoint, connctx.ConnCtx, func(*testing.T)) {
func pipeMemory() (*Endpoint, net.Conn, func(*testing.T)) {
// In memory pipe
ca, cb := net.Pipe()
@@ -61,12 +49,12 @@ func pipeMemory(ctx context.Context) (*Endpoint, connctx.ConnCtx, func(*testing.
}
config := Config{
Conn: connctx.New(ca),
Conn: ca,
BufferSize: 8192,
LoggerFactory: logging.NewDefaultLoggerFactory(),
}
m := NewMux(ctx, config)
m := NewMux(config)
e := m.NewEndpoint(matchAll)
m.RemoveEndpoint(e)
e = m.NewEndpoint(matchAll)
@@ -82,7 +70,7 @@ func pipeMemory(ctx context.Context) (*Endpoint, connctx.ConnCtx, func(*testing.
}
}
return e, connctx.New(cb), stop
return e, cb, stop
}
func TestNoEndpoints(t *testing.T) {
@@ -94,16 +82,13 @@ func TestNoEndpoints(t *testing.T) {
}
config := Config{
Conn: connctx.New(ca),
Conn: ca,
BufferSize: 8192,
LoggerFactory: logging.NewDefaultLoggerFactory(),
}
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
m := NewMux(ctx, config)
err = m.dispatch(ctx, make([]byte, 1))
m := NewMux(config)
err = m.dispatch(make([]byte, 1))
if err != nil {
t.Fatal(err)
}

View File

@@ -1,12 +1,11 @@
package webrtc
import (
"context"
"sync"
)
// Operation is a function
type operation func(ctx context.Context)
type operation func()
// Operations is a task executor.
type operations struct {
@@ -33,7 +32,7 @@ func (o *operations) Enqueue(op operation) {
o.mu.Unlock()
if !running {
go o.start(context.Background())
go o.start()
}
}
@@ -49,13 +48,13 @@ func (o *operations) IsEmpty() bool {
func (o *operations) Done() {
var wg sync.WaitGroup
wg.Add(1)
o.Enqueue(func(_ context.Context) {
o.Enqueue(func() {
wg.Done()
})
wg.Wait()
}
func (o *operations) pop() func(context.Context) {
func (o *operations) pop() func() {
o.mu.Lock()
defer o.mu.Unlock()
if len(o.ops) == 0 {
@@ -67,7 +66,7 @@ func (o *operations) pop() func(context.Context) {
return fn
}
func (o *operations) start(ctx context.Context) {
func (o *operations) start() {
defer func() {
o.mu.Lock()
defer o.mu.Unlock()
@@ -77,12 +76,12 @@ func (o *operations) start(ctx context.Context) {
}
// either a new operation was enqueued while we
// were busy, or an operation panicked
go o.start(ctx)
go o.start()
}()
fn := o.pop()
for fn != nil {
fn(ctx)
fn()
fn = o.pop()
}
}

View File

@@ -1,7 +1,6 @@
package webrtc
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
@@ -13,7 +12,7 @@ func TestOperations_Enqueue(t *testing.T) {
results := make([]int, 16)
for i := range results {
func(j int) {
ops.Enqueue(func(_ context.Context) {
ops.Enqueue(func() {
results[j] = j * j
})
}(i)

View File

@@ -3,11 +3,11 @@
package webrtc
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"fmt"
"io"
"strconv"
"strings"
"sync"
@@ -19,7 +19,6 @@ import (
"github.com/pion/logging"
"github.com/pion/rtcp"
"github.com/pion/sdp/v3"
"github.com/pion/transport/connctx"
"github.com/pion/webrtc/v3/internal/util"
"github.com/pion/webrtc/v3/pkg/rtcerr"
)
@@ -276,7 +275,7 @@ func (pc *PeerConnection) onNegotiationNeeded() {
pc.ops.Enqueue(pc.negotiationNeededOp)
}
func (pc *PeerConnection) negotiationNeededOp(ctx context.Context) {
func (pc *PeerConnection) negotiationNeededOp() {
// Don't run NegotiatedNeeded checks if OnNegotiationNeeded is not set
if handler := pc.onNegotiationNeededHandler.Load(); handler == nil {
return
@@ -946,8 +945,8 @@ func (pc *PeerConnection) SetLocalDescription(desc SessionDescription) error {
if err := pc.startRTPSenders(currentTransceivers); err != nil {
return err
}
pc.ops.Enqueue(func(ctx context.Context) {
pc.startRTP(ctx, haveLocalDescription, remoteDesc, currentTransceivers)
pc.ops.Enqueue(func() {
pc.startRTP(haveLocalDescription, remoteDesc, currentTransceivers)
})
}
@@ -1067,8 +1066,8 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
if err = pc.startRTPSenders(currentTransceivers); err != nil {
return err
}
pc.ops.Enqueue(func(ctx context.Context) {
pc.startRTP(ctx, true, &desc, currentTransceivers)
pc.ops.Enqueue(func() {
pc.startRTP(true, &desc, currentTransceivers)
})
}
return nil
@@ -1100,16 +1099,16 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
}
}
pc.ops.Enqueue(func(ctx context.Context) {
pc.startTransports(ctx, iceRole, dtlsRoleFromRemoteSDP(desc.parsed), remoteUfrag, remotePwd, fingerprint, fingerprintHash)
pc.ops.Enqueue(func() {
pc.startTransports(iceRole, dtlsRoleFromRemoteSDP(desc.parsed), remoteUfrag, remotePwd, fingerprint, fingerprintHash)
if weOffer {
pc.startRTP(ctx, false, &desc, currentTransceivers)
pc.startRTP(false, &desc, currentTransceivers)
}
})
return nil
}
func (pc *PeerConnection) startReceiver(ctx context.Context, incoming trackDetails, receiver *RTPReceiver) {
func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPReceiver) {
encodings := []RTPDecodingParameters{}
if incoming.ssrc != 0 {
encodings = append(encodings, RTPDecodingParameters{RTPCodingParameters{SSRC: incoming.ssrc}})
@@ -1138,7 +1137,7 @@ func (pc *PeerConnection) startReceiver(ctx context.Context, incoming trackDetai
}
go func() {
if err := receiver.Track().determinePayloadType(ctx); err != nil {
if err := receiver.Track().determinePayloadType(); err != nil {
pc.log.Warnf("Could not determine PayloadType for SSRC %d", receiver.Track().SSRC())
return
}
@@ -1161,7 +1160,7 @@ func (pc *PeerConnection) startReceiver(ctx context.Context, incoming trackDetai
}
// startRTPReceivers opens knows inbound SRTP streams from the RemoteDescription
func (pc *PeerConnection) startRTPReceivers(ctx context.Context, incomingTracks []trackDetails, currentTransceivers []*RTPTransceiver) { //nolint:gocognit
func (pc *PeerConnection) startRTPReceivers(incomingTracks []trackDetails, currentTransceivers []*RTPTransceiver) { //nolint:gocognit
localTransceivers := append([]*RTPTransceiver{}, currentTransceivers...)
remoteIsPlanB := false
@@ -1208,7 +1207,7 @@ func (pc *PeerConnection) startRTPReceivers(ctx context.Context, incomingTracks
continue
}
pc.startReceiver(ctx, incomingTrack, t.Receiver())
pc.startReceiver(incomingTrack, t.Receiver())
trackHandled = true
break
}
@@ -1227,7 +1226,7 @@ func (pc *PeerConnection) startRTPReceivers(ctx context.Context, incomingTracks
pc.log.Warnf("Could not add transceiver for remote SSRC %d: %s", incoming.ssrc, err)
continue
}
pc.startReceiver(ctx, incoming, t.Receiver())
pc.startReceiver(incoming, t.Receiver())
}
}
}
@@ -1254,9 +1253,9 @@ func (pc *PeerConnection) startRTPSenders(currentTransceivers []*RTPTransceiver)
}
// Start SCTP subsystem
func (pc *PeerConnection) startSCTP(ctx context.Context) {
func (pc *PeerConnection) startSCTP() {
// Start sctp
if err := pc.sctpTransport.Start(ctx, SCTPCapabilities{
if err := pc.sctpTransport.Start(SCTPCapabilities{
MaxMessageSize: 0,
}); err != nil {
pc.log.Warnf("Failed to start SCTP: %s", err)
@@ -1290,7 +1289,7 @@ func (pc *PeerConnection) startSCTP(ctx context.Context) {
pc.sctpTransport.lock.Unlock()
}
func (pc *PeerConnection) handleUndeclaredSSRC(ctx context.Context, rtpStream connctx.Reader, ssrc SSRC) error { //nolint:gocognit
func (pc *PeerConnection) handleUndeclaredSSRC(rtpStream io.Reader, ssrc SSRC) error { //nolint:gocognit
remoteDescription := pc.RemoteDescription()
if remoteDescription == nil {
return errPeerConnRemoteDescriptionNil
@@ -1319,7 +1318,7 @@ func (pc *PeerConnection) handleUndeclaredSSRC(ctx context.Context, rtpStream co
if err != nil {
return fmt.Errorf("%w: %d: %s", errPeerConnRemoteSSRCAddTransceiver, ssrc, err)
}
pc.startReceiver(ctx, incoming, t.Receiver())
pc.startReceiver(incoming, t.Receiver())
return nil
}
@@ -1336,7 +1335,7 @@ func (pc *PeerConnection) handleUndeclaredSSRC(ctx context.Context, rtpStream co
b := make([]byte, receiveMTU)
var mid, rid string
for readCount := 0; readCount <= simulcastProbeCount; readCount++ {
i, err := rtpStream.ReadContext(ctx, b)
i, err := rtpStream.Read(b)
if err != nil {
return err
}
@@ -1380,7 +1379,7 @@ func (pc *PeerConnection) handleUndeclaredSSRC(ctx context.Context, rtpStream co
}
// undeclaredMediaProcessor handles RTP/RTCP packets that don't match any a:ssrc lines
func (pc *PeerConnection) undeclaredMediaProcessor(ctx context.Context) {
func (pc *PeerConnection) undeclaredMediaProcessor() {
go func() {
for {
srtpSession, err := pc.dtlsTransport.getSRTPSession()
@@ -1395,7 +1394,7 @@ func (pc *PeerConnection) undeclaredMediaProcessor(ctx context.Context) {
return
}
if err := pc.handleUndeclaredSSRC(ctx, stream, SSRC(ssrc)); err != nil {
if err := pc.handleUndeclaredSSRC(stream, SSRC(ssrc)); err != nil {
pc.log.Errorf("Incoming unhandled RTP ssrc(%d), OnTrack will not be fired. %v", ssrc, err)
}
}
@@ -1754,12 +1753,12 @@ func (pc *PeerConnection) SetIdentityProvider(provider string) error {
// WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the
// packet is discarded. It also runs any configured interceptors.
func (pc *PeerConnection) WriteRTCP(ctx context.Context, pkts []rtcp.Packet) error {
_, err := pc.interceptorRTCPWriter.Write(ctx, pkts, make(interceptor.Attributes))
func (pc *PeerConnection) WriteRTCP(pkts []rtcp.Packet) error {
_, err := pc.interceptorRTCPWriter.Write(pkts, make(interceptor.Attributes))
return err
}
func (pc *PeerConnection) writeRTCP(ctx context.Context, pkts []rtcp.Packet, _ interceptor.Attributes) (int, error) {
func (pc *PeerConnection) writeRTCP(pkts []rtcp.Packet, _ interceptor.Attributes) (int, error) {
raw, err := rtcp.Marshal(pkts)
if err != nil {
return 0, err
@@ -1775,7 +1774,7 @@ func (pc *PeerConnection) writeRTCP(ctx context.Context, pkts []rtcp.Packet, _ i
return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err)
}
if n, err := writeStream.WriteContext(ctx, raw); err != nil {
if n, err := writeStream.Write(raw); err != nil {
return n, err
}
return 0, nil
@@ -1986,10 +1985,9 @@ func (pc *PeerConnection) GetStats() StatsReport {
}
// Start all transports. PeerConnection now has enough state
func (pc *PeerConnection) startTransports(ctx context.Context, iceRole ICERole, dtlsRole DTLSRole, remoteUfrag, remotePwd, fingerprint, fingerprintHash string) {
func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, remoteUfrag, remotePwd, fingerprint, fingerprintHash string) {
// Start the ice transport
err := pc.iceTransport.Start(
ctx,
pc.iceGatherer,
ICEParameters{
UsernameFragment: remoteUfrag,
@@ -2004,7 +2002,7 @@ func (pc *PeerConnection) startTransports(ctx context.Context, iceRole ICERole,
}
// Start the dtls transport
err = pc.dtlsTransport.Start(ctx, DTLSParameters{
err = pc.dtlsTransport.Start(DTLSParameters{
Role: dtlsRole,
Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}},
})
@@ -2015,7 +2013,7 @@ func (pc *PeerConnection) startTransports(ctx context.Context, iceRole ICERole,
}
}
func (pc *PeerConnection) startRTP(ctx context.Context, isRenegotiation bool, remoteDesc *SessionDescription, currentTransceivers []*RTPTransceiver) {
func (pc *PeerConnection) startRTP(isRenegotiation bool, remoteDesc *SessionDescription, currentTransceivers []*RTPTransceiver) {
trackDetails := trackDetailsFromSDP(pc.log, remoteDesc.parsed)
if isRenegotiation {
for _, t := range currentTransceivers {
@@ -2049,13 +2047,13 @@ func (pc *PeerConnection) startRTP(ctx context.Context, isRenegotiation bool, re
}
}
pc.startRTPReceivers(ctx, trackDetails, currentTransceivers)
pc.startRTPReceivers(trackDetails, currentTransceivers)
if haveApplicationMediaSection(remoteDesc.parsed) {
pc.startSCTP(ctx)
pc.startSCTP()
}
if !isRenegotiation {
pc.undeclaredMediaProcessor(ctx)
pc.undeclaredMediaProcessor()
}
}

View File

@@ -1137,7 +1137,7 @@ func TestPeerConnection_MassiveTracks(t *testing.T) {
<-connected
time.Sleep(1 * time.Second)
for _, track := range tracks {
assert.NoError(t, track.WriteRTP(context.Background(), samplePkt))
assert.NoError(t, track.WriteRTP(samplePkt))
}
// Ping trackRecords to see if any track event not received yet.
tooLong := time.After(timeoutDuration)

View File

@@ -5,7 +5,6 @@ package webrtc
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
@@ -89,10 +88,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
go func() {
for {
time.Sleep(time.Millisecond * 100)
if routineErr := pcAnswer.WriteRTCP(
context.Background(),
[]rtcp.Packet{&rtcp.RapidResynchronizationRequest{SenderSSRC: uint32(track.SSRC()), MediaSSRC: uint32(track.SSRC())}},
); routineErr != nil {
if routineErr := pcAnswer.WriteRTCP([]rtcp.Packet{&rtcp.RapidResynchronizationRequest{SenderSSRC: uint32(track.SSRC()), MediaSSRC: uint32(track.SSRC())}}); routineErr != nil {
awaitRTCPReceiverSend <- routineErr
return
}
@@ -107,7 +103,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
}()
go func() {
_, routineErr := receiver.Read(context.Background(), make([]byte, 1400))
_, routineErr := receiver.Read(make([]byte, 1400))
if routineErr != nil {
awaitRTCPReceiverRecv <- routineErr
} else {
@@ -117,7 +113,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
haveClosedAwaitRTPRecv := false
for {
p, routineErr := track.ReadRTP(context.Background())
p, routineErr := track.ReadRTP()
if routineErr != nil {
close(awaitRTPRecvClosed)
return
@@ -140,9 +136,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
go func() {
for {
time.Sleep(time.Millisecond * 100)
if routineErr := vp8Track.WriteSample(
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
); routineErr != nil {
if routineErr := vp8Track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}); routineErr != nil {
fmt.Println(routineErr)
}
@@ -158,10 +152,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
go func() {
for {
time.Sleep(time.Millisecond * 100)
if routineErr := pcOffer.WriteRTCP(
context.Background(),
[]rtcp.Packet{&rtcp.PictureLossIndication{SenderSSRC: uint32(sender.ssrc), MediaSSRC: uint32(sender.ssrc)}},
); routineErr != nil {
if routineErr := pcOffer.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{SenderSSRC: uint32(sender.ssrc), MediaSSRC: uint32(sender.ssrc)}}); routineErr != nil {
awaitRTCPSenderSend <- routineErr
}
@@ -175,7 +166,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
}()
go func() {
if _, routineErr := sender.Read(context.Background(), make([]byte, 1400)); routineErr == nil {
if _, routineErr := sender.Read(make([]byte, 1400)); routineErr == nil {
close(awaitRTCPSenderRecv)
}
}()
@@ -375,13 +366,9 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) {
t.Fatal(err)
}
for i := 0; i <= 5; i++ {
if rtpErr := vp8Track.WriteSample(
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
); rtpErr != nil {
if rtpErr := vp8Track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}); rtpErr != nil {
t.Fatal(rtpErr)
} else if rtcpErr := pcOffer.WriteRTCP(
context.Background(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: 0}},
); rtcpErr != nil {
} else if rtcpErr := pcOffer.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: 0}}); rtcpErr != nil {
t.Fatal(rtcpErr)
}
}
@@ -461,9 +448,7 @@ func TestUndeclaredSSRC(t *testing.T) {
go func() {
for {
assert.NoError(t, vp8Writer.WriteSample(
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
))
assert.NoError(t, vp8Writer.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
time.Sleep(time.Millisecond * 25)
select {
@@ -701,11 +686,11 @@ func TestRtpSenderReceiver_ReadClose_Error(t *testing.T) {
sender, receiver := tr.Sender(), tr.Receiver()
assert.NoError(t, sender.Stop())
_, err = sender.Read(context.Background(), make([]byte, 0, 1400))
_, err = sender.Read(make([]byte, 0, 1400))
assert.Error(t, err, io.ErrClosedPipe)
assert.NoError(t, receiver.Stop())
_, err = receiver.Read(context.Background(), make([]byte, 0, 1400))
_, err = receiver.Read(make([]byte, 0, 1400))
assert.Error(t, err, io.ErrClosedPipe)
assert.NoError(t, pc.Close())
@@ -852,9 +837,7 @@ func TestPlanBMediaExchange(t *testing.T) {
select {
case <-time.After(20 * time.Millisecond):
for _, track := range outboundTracks {
assert.NoError(t, track.WriteSample(
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
))
assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
}
case <-done:
return

View File

@@ -25,9 +25,7 @@ func sendVideoUntilDone(done <-chan struct{}, t *testing.T, tracks []*TrackLocal
select {
case <-time.After(20 * time.Millisecond):
for _, track := range tracks {
assert.NoError(t, track.WriteSample(
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
))
assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
}
case <-done:
return
@@ -102,9 +100,7 @@ func TestPeerConnection_Renegotiation_AddTrack(t *testing.T) {
// Send 10 packets, OnTrack MUST not be fired
for i := 0; i <= 10; i++ {
assert.NoError(t, vp8Track.WriteSample(
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
))
assert.NoError(t, vp8Track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
time.Sleep(20 * time.Millisecond)
}
@@ -364,7 +360,7 @@ func TestPeerConnection_Renegotiation_CodecChange(t *testing.T) {
pcAnswer.OnTrack(func(track *TrackRemote, r *RTPReceiver) {
tracksCh <- track
for {
if _, readErr := track.ReadRTP(context.Background()); readErr == io.EOF {
if _, readErr := track.ReadRTP(); readErr == io.EOF {
tracksClosed <- struct{}{}
return
}
@@ -454,7 +450,7 @@ func TestPeerConnection_Renegotiation_RemoveTrack(t *testing.T) {
onTrackFiredFunc()
for {
if _, err := track.ReadRTP(context.Background()); err == io.EOF {
if _, err := track.ReadRTP(); err == io.EOF {
trackClosedFunc()
return
}
@@ -842,9 +838,7 @@ func TestNegotiationNeededRemoveTrack(t *testing.T) {
sender, err := pcOffer.AddTrack(track)
assert.NoError(t, err)
assert.NoError(t, track.WriteSample(
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
))
assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
wg.Wait()

View File

@@ -5,7 +5,6 @@ package webrtc
import (
"bytes"
"context"
"encoding/binary"
"testing"
"time"
@@ -143,7 +142,7 @@ func (s *testQuicStack) setSignal(sig *testQuicSignal, isOffer bool) error {
}
// Start the ICE transport
err = s.ice.Start(context.Background(), nil, sig.ICEParameters, &iceRole)
err = s.ice.Start(nil, sig.ICEParameters, &iceRole)
if err != nil {
return err
}

View File

@@ -3,14 +3,13 @@
package webrtc
import (
"context"
"fmt"
"io"
"sync"
"github.com/pion/interceptor"
"github.com/pion/rtcp"
"github.com/pion/srtp/v2"
"github.com/pion/srtp"
)
// trackStreams maintains a mapping of RTP/RTCP streams to a specific track
@@ -133,45 +132,41 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error {
}
// Read reads incoming RTCP for this RTPReceiver
func (r *RTPReceiver) Read(ctx context.Context, b []byte) (n int, err error) {
func (r *RTPReceiver) Read(b []byte) (n int, err error) {
select {
case <-r.received:
return r.tracks[0].rtcpReadStream.ReadContext(ctx, b)
return r.tracks[0].rtcpReadStream.Read(b)
case <-r.closed:
return 0, io.ErrClosedPipe
case <-ctx.Done():
return 0, ctx.Err()
}
}
// ReadSimulcast reads incoming RTCP for this RTPReceiver for given rid
func (r *RTPReceiver) ReadSimulcast(ctx context.Context, b []byte, rid string) (n int, err error) {
func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, err error) {
select {
case <-r.received:
for _, t := range r.tracks {
if t.track != nil && t.track.rid == rid {
return t.rtcpReadStream.ReadContext(ctx, b)
return t.rtcpReadStream.Read(b)
}
}
return 0, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid)
case <-r.closed:
return 0, io.ErrClosedPipe
case <-ctx.Done():
return 0, ctx.Err()
}
}
// ReadRTCP is a convenience method that wraps Read and unmarshal for you.
// It also runs any configured interceptors.
func (r *RTPReceiver) ReadRTCP(ctx context.Context) ([]rtcp.Packet, error) {
pkts, _, err := r.interceptorRTCPReader.Read(ctx)
func (r *RTPReceiver) ReadRTCP() ([]rtcp.Packet, error) {
pkts, _, err := r.interceptorRTCPReader.Read()
return pkts, err
}
// ReadRTCP is a convenience method that wraps Read and unmarshal for you
func (r *RTPReceiver) readRTCP(ctx context.Context) ([]rtcp.Packet, interceptor.Attributes, error) {
func (r *RTPReceiver) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) {
b := make([]byte, receiveMTU)
i, err := r.Read(ctx, b)
i, err := r.Read(b)
if err != nil {
return nil, nil, err
}
@@ -185,9 +180,9 @@ func (r *RTPReceiver) readRTCP(ctx context.Context) ([]rtcp.Packet, interceptor.
}
// ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you
func (r *RTPReceiver) ReadSimulcastRTCP(ctx context.Context, rid string) ([]rtcp.Packet, error) {
func (r *RTPReceiver) ReadSimulcastRTCP(rid string) ([]rtcp.Packet, error) {
b := make([]byte, receiveMTU)
i, err := r.ReadSimulcast(ctx, b, rid)
i, err := r.ReadSimulcast(b, rid)
if err != nil {
return nil, err
}
@@ -246,10 +241,10 @@ func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams {
}
// readRTP should only be called by a track, this only exists so we can keep state in one place
func (r *RTPReceiver) readRTP(ctx context.Context, b []byte, reader *TrackRemote) (n int, err error) {
func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, err error) {
<-r.received
if t := r.streamsForTrack(reader); t != nil {
return t.rtpReadStream.ReadContext(ctx, b)
return t.rtpReadStream.Read(b)
}
return 0, fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC())

View File

@@ -3,7 +3,6 @@
package webrtc
import (
"context"
"io"
"sync"
@@ -176,8 +175,8 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error {
writeStream.setRTPWriter(
r.api.interceptor.BindLocalStream(
info,
interceptor.RTPWriterFunc(func(ctx context.Context, p *rtp.Packet, attributes interceptor.Attributes) (int, error) {
return r.srtpStream.WriteRTP(ctx, &p.Header, p.Payload)
interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) {
return r.srtpStream.WriteRTP(&p.Header, p.Payload)
}),
))
@@ -209,27 +208,25 @@ func (r *RTPSender) Stop() error {
}
// Read reads incoming RTCP for this RTPReceiver
func (r *RTPSender) Read(ctx context.Context, b []byte) (n int, err error) {
func (r *RTPSender) Read(b []byte) (n int, err error) {
select {
case <-r.sendCalled:
return r.srtpStream.ReadContext(ctx, b)
return r.srtpStream.Read(b)
case <-r.stopCalled:
return 0, io.ErrClosedPipe
case <-ctx.Done():
return 0, ctx.Err()
}
}
// ReadRTCP is a convenience method that wraps Read and unmarshals for you.
// It also runs any configured interceptors.
func (r *RTPSender) ReadRTCP(ctx context.Context) ([]rtcp.Packet, error) {
pkts, _, err := r.interceptorRTCPReader.Read(ctx)
func (r *RTPSender) ReadRTCP() ([]rtcp.Packet, error) {
pkts, _, err := r.interceptorRTCPReader.Read()
return pkts, err
}
func (r *RTPSender) readRTCP(ctx context.Context) ([]rtcp.Packet, interceptor.Attributes, error) {
func (r *RTPSender) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) {
b := make([]byte, receiveMTU)
i, err := r.Read(ctx, b)
i, err := r.Read(b)
if err != nil {
return nil, nil, err
}

View File

@@ -50,7 +50,7 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) {
assert.Equal(t, uint64(1), atomic.AddUint64(&onTrackCount, 1))
for {
pkt, err := track.ReadRTP(context.Background())
pkt, err := track.ReadRTP()
if err != nil {
assert.True(t, errors.Is(io.EOF, err))
return
@@ -74,9 +74,7 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) {
case <-seenPacketA.Done():
return
default:
assert.NoError(t, trackA.WriteSample(
context.Background(), media.Sample{Data: []byte{0xAA}, Duration: time.Second},
))
assert.NoError(t, trackA.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second}))
}
}
}()
@@ -90,9 +88,7 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) {
case <-seenPacketB.Done():
return
default:
assert.NoError(t, trackB.WriteSample(
context.Background(), media.Sample{Data: []byte{0xBB}, Duration: time.Second},
))
assert.NoError(t, trackB.WriteSample(media.Sample{Data: []byte{0xBB}, Duration: time.Second}))
}
}
}()
@@ -127,9 +123,7 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) {
case <-seenPacket.Done():
return
default:
assert.NoError(t, trackA.WriteSample(
context.Background(), media.Sample{Data: []byte{0xAA}, Duration: time.Second},
))
assert.NoError(t, trackA.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second}))
}
}
}()
@@ -140,14 +134,3 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) {
assert.NoError(t, receiver.Close())
})
}
func Test_RTPSender_ContextCancel(t *testing.T) {
sender := &RTPSender{
sendCalled: make(chan struct{}),
stopCalled: make(chan struct{}),
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := sender.Read(ctx, []byte{})
assert.Equal(t, context.Canceled, err)
}

View File

@@ -3,7 +3,6 @@
package webrtc
import (
"context"
"io"
"math"
"sync"
@@ -91,7 +90,7 @@ func (r *SCTPTransport) GetCapabilities() SCTPCapabilities {
// Start the SCTPTransport. Since both local and remote parties must mutually
// create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish
// a connection over SCTP.
func (r *SCTPTransport) Start(ctx context.Context, remoteCaps SCTPCapabilities) error {
func (r *SCTPTransport) Start(remoteCaps SCTPCapabilities) error {
if r.isStarted {
return nil
}

View File

@@ -3,12 +3,11 @@
package webrtc
import (
"context"
"io"
"sync/atomic"
"github.com/pion/rtp"
"github.com/pion/srtp/v2"
"github.com/pion/srtp"
)
// srtpWriterFuture blocks Read/Write calls until
@@ -19,14 +18,12 @@ type srtpWriterFuture struct {
rtpWriteStream atomic.Value // *srtp.WriteStreamSRTP
}
func (s *srtpWriterFuture) init(ctx context.Context, returnWhenNoSRTP bool) error {
func (s *srtpWriterFuture) init(returnWhenNoSRTP bool) error {
if returnWhenNoSRTP {
select {
case <-s.rtpSender.stopCalled:
return io.ErrClosedPipe
case <-s.rtpSender.transport.srtpReady:
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
@@ -35,8 +32,6 @@ func (s *srtpWriterFuture) init(ctx context.Context, returnWhenNoSRTP bool) erro
case <-s.rtpSender.stopCalled:
return io.ErrClosedPipe
case <-s.rtpSender.transport.srtpReady:
case <-ctx.Done():
return ctx.Err()
}
}
@@ -73,38 +68,38 @@ func (s *srtpWriterFuture) Close() error {
return nil
}
func (s *srtpWriterFuture) ReadContext(ctx context.Context, b []byte) (n int, err error) {
func (s *srtpWriterFuture) Read(b []byte) (n int, err error) {
if value := s.rtcpReadStream.Load(); value != nil {
return value.(*srtp.ReadStreamSRTCP).ReadContext(ctx, b)
return value.(*srtp.ReadStreamSRTCP).Read(b)
}
if err := s.init(ctx, false); err != nil || s.rtcpReadStream.Load() == nil {
if err := s.init(false); err != nil || s.rtcpReadStream.Load() == nil {
return 0, err
}
return s.ReadContext(ctx, b)
return s.Read(b)
}
func (s *srtpWriterFuture) WriteRTP(ctx context.Context, header *rtp.Header, payload []byte) (int, error) {
func (s *srtpWriterFuture) WriteRTP(header *rtp.Header, payload []byte) (int, error) {
if value := s.rtpWriteStream.Load(); value != nil {
return value.(*srtp.WriteStreamSRTP).WriteRTP(ctx, header, payload)
return value.(*srtp.WriteStreamSRTP).WriteRTP(header, payload)
}
if err := s.init(ctx, true); err != nil || s.rtpWriteStream.Load() == nil {
if err := s.init(true); err != nil || s.rtpWriteStream.Load() == nil {
return 0, err
}
return s.WriteRTP(ctx, header, payload)
return s.WriteRTP(header, payload)
}
func (s *srtpWriterFuture) Write(ctx context.Context, b []byte) (int, error) {
func (s *srtpWriterFuture) Write(b []byte) (int, error) {
if value := s.rtpWriteStream.Load(); value != nil {
return value.(*srtp.WriteStreamSRTP).WriteContext(ctx, b)
return value.(*srtp.WriteStreamSRTP).Write(b)
}
if err := s.init(ctx, true); err != nil || s.rtpWriteStream.Load() == nil {
if err := s.init(true); err != nil || s.rtpWriteStream.Load() == nil {
return 0, err
}
return s.Write(ctx, b)
return s.Write(b)
}

View File

@@ -1,18 +1,14 @@
package webrtc
import (
"context"
"github.com/pion/rtp"
)
import "github.com/pion/rtp"
// TrackLocalWriter is the Writer for outbound RTP Packets
type TrackLocalWriter interface {
// WriteRTP encrypts a RTP packet and writes to the connection
WriteRTP(ctx context.Context, header *rtp.Header, payload []byte) (int, error)
WriteRTP(header *rtp.Header, payload []byte) (int, error)
// Write encrypts and writes a full RTP packet
Write(ctx context.Context, b []byte) (int, error)
Write(b []byte) (int, error)
}
// TrackLocalContext is the Context passed when a TrackLocal has been Binded/Unbinded from a PeerConnection, and used

View File

@@ -3,7 +3,6 @@
package webrtc
import (
"context"
"strings"
"sync"
@@ -103,7 +102,7 @@ func (s *TrackLocalStaticRTP) Kind() RTPCodecType {
// If one PeerConnection fails the packets will still be sent to
// all PeerConnections. The error message will contain the ID of the failed
// PeerConnections so you can remove them
func (s *TrackLocalStaticRTP) WriteRTP(ctx context.Context, p *rtp.Packet) error {
func (s *TrackLocalStaticRTP) WriteRTP(p *rtp.Packet) error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -113,7 +112,7 @@ func (s *TrackLocalStaticRTP) WriteRTP(ctx context.Context, p *rtp.Packet) error
for _, b := range s.bindings {
outboundPacket.Header.SSRC = uint32(b.ssrc)
outboundPacket.Header.PayloadType = uint8(b.payloadType)
if _, err := b.writeStream.WriteRTP(ctx, &outboundPacket.Header, outboundPacket.Payload); err != nil {
if _, err := b.writeStream.WriteRTP(&outboundPacket.Header, outboundPacket.Payload); err != nil {
writeErrs = append(writeErrs, err)
}
}
@@ -125,13 +124,13 @@ func (s *TrackLocalStaticRTP) WriteRTP(ctx context.Context, p *rtp.Packet) error
// If one PeerConnection fails the packets will still be sent to
// all PeerConnections. The error message will contain the ID of the failed
// PeerConnections so you can remove them
func (s *TrackLocalStaticRTP) Write(ctx context.Context, b []byte) (n int, err error) {
func (s *TrackLocalStaticRTP) Write(b []byte) (n int, err error) {
packet := &rtp.Packet{}
if err = packet.Unmarshal(b); err != nil {
return 0, err
}
return len(b), s.WriteRTP(ctx, packet)
return len(b), s.WriteRTP(packet)
}
// TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples.
@@ -209,7 +208,7 @@ func (s *TrackLocalStaticSample) Unbind(t TrackLocalContext) error {
// If one PeerConnection fails the packets will still be sent to
// all PeerConnections. The error message will contain the ID of the failed
// PeerConnections so you can remove them
func (s *TrackLocalStaticSample) WriteSample(ctx context.Context, sample media.Sample) error {
func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error {
s.rtpTrack.mu.RLock()
p := s.packetizer
clockRate := s.clockRate
@@ -224,7 +223,7 @@ func (s *TrackLocalStaticSample) WriteSample(ctx context.Context, sample media.S
writeErrs := []error{}
for _, p := range packets {
if err := s.rtpTrack.WriteRTP(ctx, p); err != nil {
if err := s.rtpTrack.WriteRTP(p); err != nil {
writeErrs = append(writeErrs, err)
}
}

View File

@@ -185,7 +185,7 @@ func Test_TrackLocalStatic_Mutate_Input(t *testing.T) {
assert.NoError(t, signalPair(pcOffer, pcAnswer))
pkt := &rtp.Packet{Header: rtp.Header{SSRC: 1, PayloadType: 1}}
assert.NoError(t, vp8Writer.WriteRTP(context.Background(), pkt))
assert.NoError(t, vp8Writer.WriteRTP(pkt))
assert.Equal(t, pkt.Header.SSRC, uint32(1))
assert.Equal(t, pkt.Header.PayloadType, uint8(1))
@@ -224,7 +224,7 @@ func Test_TrackLocalStatic_Binding_NonBlocking(t *testing.T) {
assert.NoError(t, err)
assert.NoError(t, pcAnswer.SetLocalDescription(answer))
_, err = vp8Writer.Write(context.Background(), make([]byte, 20))
_, err = vp8Writer.Write(make([]byte, 20))
assert.NoError(t, err)
assert.NoError(t, pcOffer.Close())

View File

@@ -3,7 +3,6 @@
package webrtc
import (
"context"
"sync"
"github.com/pion/interceptor"
@@ -126,7 +125,7 @@ func (t *TrackRemote) Codec() RTPCodecParameters {
}
// Read reads data from the track.
func (t *TrackRemote) Read(ctx context.Context, b []byte) (n int, err error) {
func (t *TrackRemote) Read(b []byte) (n int, err error) {
t.mu.RLock()
r := t.receiver
peeked := t.peeked != nil
@@ -145,12 +144,12 @@ func (t *TrackRemote) Read(ctx context.Context, b []byte) (n int, err error) {
}
}
return r.readRTP(ctx, b, t)
return r.readRTP(b, t)
}
// peek is like Read, but it doesn't discard the packet read
func (t *TrackRemote) peek(ctx context.Context, b []byte) (n int, err error) {
n, err = t.Read(ctx, b)
func (t *TrackRemote) peek(b []byte) (n int, err error) {
n, err = t.Read(b)
if err != nil {
return
}
@@ -168,14 +167,14 @@ func (t *TrackRemote) peek(ctx context.Context, b []byte) (n int, err error) {
// ReadRTP is a convenience method that wraps Read and unmarshals for you.
// It also runs any configured interceptors.
func (t *TrackRemote) ReadRTP(ctx context.Context) (*rtp.Packet, error) {
p, _, err := t.interceptorRTPReader.Read(ctx)
func (t *TrackRemote) ReadRTP() (*rtp.Packet, error) {
p, _, err := t.interceptorRTPReader.Read()
return p, err
}
func (t *TrackRemote) readRTP(ctx context.Context) (*rtp.Packet, interceptor.Attributes, error) {
func (t *TrackRemote) readRTP() (*rtp.Packet, interceptor.Attributes, error) {
b := make([]byte, receiveMTU)
i, err := t.Read(ctx, b)
i, err := t.Read(b)
if err != nil {
return nil, nil, err
}
@@ -189,9 +188,9 @@ func (t *TrackRemote) readRTP(ctx context.Context) (*rtp.Packet, interceptor.Att
// determinePayloadType blocks and reads a single packet to determine the PayloadType for this Track
// this is useful because we can't announce it to the user until we know the payloadType
func (t *TrackRemote) determinePayloadType(ctx context.Context) error {
func (t *TrackRemote) determinePayloadType() error {
b := make([]byte, receiveMTU)
n, err := t.peek(ctx, b)
n, err := t.peek(b)
if err != nil {
return err
}