mirror of
https://github.com/pion/webrtc.git
synced 2025-09-27 03:25:58 +08:00
Read/Write RTP/RTCP packets with context
Control cancel/timeout by context.
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -112,19 +113,19 @@ func (s *testORTCStack) setSignal(sig *testORTCSignal, isOffer bool) error {
|
||||
}
|
||||
|
||||
// Start the ICE transport
|
||||
err = s.ice.Start(nil, sig.ICEParameters, &iceRole)
|
||||
err = s.ice.Start(context.Background(), nil, sig.ICEParameters, &iceRole)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start the DTLS transport
|
||||
err = s.dtls.Start(sig.DTLSParameters)
|
||||
err = s.dtls.Start(context.Background(), sig.DTLSParameters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start the SCTP transport
|
||||
err = s.sctp.Start(sig.SCTPCapabilities)
|
||||
err = s.sctp.Start(context.Background(), sig.SCTPCapabilities)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
@@ -17,7 +18,7 @@ import (
|
||||
|
||||
"github.com/pion/dtls/v2"
|
||||
"github.com/pion/dtls/v2/pkg/crypto/fingerprint"
|
||||
"github.com/pion/srtp"
|
||||
"github.com/pion/srtp/v2"
|
||||
"github.com/pion/webrtc/v3/internal/mux"
|
||||
"github.com/pion/webrtc/v3/internal/util"
|
||||
"github.com/pion/webrtc/v3/pkg/rtcerr"
|
||||
@@ -146,7 +147,7 @@ func (t *DTLSTransport) GetRemoteCertificate() []byte {
|
||||
return t.remoteCertificate
|
||||
}
|
||||
|
||||
func (t *DTLSTransport) startSRTP() error {
|
||||
func (t *DTLSTransport) startSRTP(ctx context.Context) error {
|
||||
srtpConfig := &srtp.Config{
|
||||
Profile: t.srtpProtectionProfile,
|
||||
LoggerFactory: t.api.settingEngine.LoggerFactory,
|
||||
@@ -185,12 +186,12 @@ func (t *DTLSTransport) startSRTP() error {
|
||||
return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
|
||||
}
|
||||
|
||||
srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
|
||||
srtpSession, err := srtp.NewSessionSRTP(ctx, t.srtpEndpoint, srtpConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
|
||||
}
|
||||
|
||||
srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig)
|
||||
srtcpSession, err := srtp.NewSessionSRTCP(ctx, t.srtcpEndpoint, srtpConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err)
|
||||
}
|
||||
@@ -244,7 +245,7 @@ func (t *DTLSTransport) role() DTLSRole {
|
||||
}
|
||||
|
||||
// Start DTLS transport negotiation with the parameters of the remote DTLS transport
|
||||
func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
|
||||
func (t *DTLSTransport) Start(ctx context.Context, remoteParameters DTLSParameters) error {
|
||||
// Take lock and prepare connection, we must not hold the lock
|
||||
// when connecting
|
||||
prepareTransport := func() (DTLSRole, *dtls.Config, error) {
|
||||
@@ -350,7 +351,7 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return t.startSRTP()
|
||||
return t.startSRTP(ctx)
|
||||
}
|
||||
|
||||
// Stop stops and closes the DTLSTransport object.
|
||||
|
@@ -100,6 +100,7 @@ func TestE2E_Audio(t *testing.T) {
|
||||
go func() {
|
||||
for {
|
||||
if err := track.WriteSample(
|
||||
context.Background(),
|
||||
media.Sample{Data: silentOpusFrame, Duration: time.Millisecond * 20},
|
||||
); err != nil {
|
||||
t.Errorf("Failed to WriteSample: %v", err)
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -53,7 +54,9 @@ func main() { // nolint:gocognit
|
||||
go func() {
|
||||
ticker := time.NewTicker(rtcpPLIInterval)
|
||||
for range ticker.C {
|
||||
if rtcpSendErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())}}); rtcpSendErr != nil {
|
||||
if rtcpSendErr := peerConnection.WriteRTCP(
|
||||
context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())}},
|
||||
); rtcpSendErr != nil {
|
||||
fmt.Println(rtcpSendErr)
|
||||
}
|
||||
}
|
||||
@@ -68,13 +71,13 @@ func main() { // nolint:gocognit
|
||||
|
||||
rtpBuf := make([]byte, 1400)
|
||||
for {
|
||||
i, readErr := remoteTrack.Read(rtpBuf)
|
||||
i, readErr := remoteTrack.Read(context.TODO(), rtpBuf)
|
||||
if readErr != nil {
|
||||
panic(readErr)
|
||||
}
|
||||
|
||||
// ErrClosedPipe means we don't have any subscribers, this is ok if no peers have connected yet
|
||||
if _, err = localTrack.Write(rtpBuf[:i]); err != nil && !errors.Is(err, io.ErrClosedPipe) {
|
||||
if _, err = localTrack.Write(context.TODO(), rtpBuf[:i]); err != nil && !errors.Is(err, io.ErrClosedPipe) {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
@@ -74,7 +74,7 @@ func main() {
|
||||
}
|
||||
|
||||
time.Sleep(sleepTime)
|
||||
if ivfErr = videoTrack.WriteSample(media.Sample{Data: frame, Duration: time.Second}); ivfErr != nil {
|
||||
if ivfErr = videoTrack.WriteSample(context.TODO(), media.Sample{Data: frame, Duration: time.Second}); ivfErr != nil {
|
||||
panic(ivfErr)
|
||||
}
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"time"
|
||||
@@ -110,7 +111,7 @@ func main() {
|
||||
}
|
||||
|
||||
// Start the ICE transport
|
||||
err = ice.Start(nil, remoteSignal.ICEParameters, &iceRole)
|
||||
err = ice.Start(context.TODO(), nil, remoteSignal.ICEParameters, &iceRole)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"time"
|
||||
@@ -111,19 +112,19 @@ func main() {
|
||||
}
|
||||
|
||||
// Start the ICE transport
|
||||
err = ice.Start(nil, remoteSignal.ICEParameters, &iceRole)
|
||||
err = ice.Start(context.TODO(), nil, remoteSignal.ICEParameters, &iceRole)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Start the DTLS transport
|
||||
err = dtls.Start(remoteSignal.DTLSParameters)
|
||||
err = dtls.Start(context.TODO(), remoteSignal.DTLSParameters)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Start the SCTP transport
|
||||
err = sctp.Start(remoteSignal.SCTPCapabilities)
|
||||
err = sctp.Start(context.TODO(), remoteSignal.SCTPCapabilities)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
@@ -145,7 +146,7 @@ func writeVideoToTrack(t *webrtc.TrackLocalStaticSample) {
|
||||
}
|
||||
|
||||
time.Sleep(sleepTime)
|
||||
if err = t.WriteSample(media.Sample{Data: frame, Duration: time.Second}); err != nil {
|
||||
if err = t.WriteSample(context.TODO(), media.Sample{Data: frame, Duration: time.Second}); err != nil {
|
||||
fmt.Printf("Finish writing video track: %s ", err)
|
||||
return
|
||||
}
|
||||
|
@@ -87,7 +87,7 @@ func main() {
|
||||
}
|
||||
|
||||
time.Sleep(sleepTime)
|
||||
if ivfErr = videoTrack.WriteSample(media.Sample{Data: frame, Duration: time.Second}); ivfErr != nil {
|
||||
if ivfErr = videoTrack.WriteSample(context.TODO(), media.Sample{Data: frame, Duration: time.Second}); ivfErr != nil {
|
||||
panic(ivfErr)
|
||||
}
|
||||
}
|
||||
@@ -138,7 +138,7 @@ func main() {
|
||||
lastGranule = pageHeader.GranulePosition
|
||||
sampleDuration := time.Duration((sampleCount/48000)*1000) * time.Millisecond
|
||||
|
||||
if oggErr = audioTrack.WriteSample(media.Sample{Data: pageData, Duration: sampleDuration}); oggErr != nil {
|
||||
if oggErr = audioTrack.WriteSample(context.TODO(), media.Sample{Data: pageData, Duration: sampleDuration}); oggErr != nil {
|
||||
panic(oggErr)
|
||||
}
|
||||
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -72,7 +73,9 @@ func main() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second * 3)
|
||||
for range ticker.C {
|
||||
errSend := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}})
|
||||
errSend := peerConnection.WriteRTCP(
|
||||
context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}},
|
||||
)
|
||||
if errSend != nil {
|
||||
fmt.Println(errSend)
|
||||
}
|
||||
@@ -82,12 +85,12 @@ func main() {
|
||||
fmt.Printf("Track has started, of type %d: %s \n", track.PayloadType(), track.Codec().MimeType)
|
||||
for {
|
||||
// Read RTP packets being sent to Pion
|
||||
rtp, readErr := track.ReadRTP()
|
||||
rtp, readErr := track.ReadRTP(context.TODO())
|
||||
if readErr != nil {
|
||||
panic(readErr)
|
||||
}
|
||||
|
||||
if writeErr := outputTrack.WriteRTP(rtp); writeErr != nil {
|
||||
if writeErr := outputTrack.WriteRTP(context.TODO(), rtp); writeErr != nil {
|
||||
panic(writeErr)
|
||||
}
|
||||
}
|
||||
|
@@ -107,7 +107,9 @@ func main() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second * 2)
|
||||
for range ticker.C {
|
||||
if rtcpErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}); rtcpErr != nil {
|
||||
if rtcpErr := peerConnection.WriteRTCP(
|
||||
context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}},
|
||||
); rtcpErr != nil {
|
||||
fmt.Println(rtcpErr)
|
||||
}
|
||||
}
|
||||
@@ -116,7 +118,7 @@ func main() {
|
||||
b := make([]byte, 1500)
|
||||
for {
|
||||
// Read
|
||||
n, readErr := track.Read(b)
|
||||
n, readErr := track.Read(context.TODO(), b)
|
||||
if readErr != nil {
|
||||
panic(readErr)
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
@@ -103,7 +104,7 @@ func main() {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if _, writeErr := videoTrack.Write(inboundRTPPacket[:n]); writeErr != nil {
|
||||
if _, writeErr := videoTrack.Write(context.TODO(), inboundRTPPacket[:n]); writeErr != nil {
|
||||
panic(writeErr)
|
||||
}
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
@@ -23,7 +24,7 @@ func saveToDisk(i media.Writer, track *webrtc.TrackRemote) {
|
||||
}()
|
||||
|
||||
for {
|
||||
rtpPacket, err := track.ReadRTP()
|
||||
rtpPacket, err := track.ReadRTP(context.TODO())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -96,7 +97,9 @@ func main() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second * 3)
|
||||
for range ticker.C {
|
||||
errSend := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}})
|
||||
errSend := peerConnection.WriteRTCP(
|
||||
context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}},
|
||||
)
|
||||
if errSend != nil {
|
||||
fmt.Println(errSend)
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -80,24 +81,28 @@ func main() {
|
||||
ticker := time.NewTicker(3 * time.Second)
|
||||
for range ticker.C {
|
||||
fmt.Printf("Sending pli for stream with rid: %q, ssrc: %d\n", track.RID(), track.SSRC())
|
||||
if writeErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}); writeErr != nil {
|
||||
if writeErr := peerConnection.WriteRTCP(
|
||||
context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}},
|
||||
); writeErr != nil {
|
||||
fmt.Println(writeErr)
|
||||
}
|
||||
// Send a remb message with a very high bandwidth to trigger chrome to send also the high bitrate stream
|
||||
fmt.Printf("Sending remb for stream with rid: %q, ssrc: %d\n", track.RID(), track.SSRC())
|
||||
if writeErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.ReceiverEstimatedMaximumBitrate{Bitrate: 10000000, SenderSSRC: uint32(track.SSRC())}}); writeErr != nil {
|
||||
if writeErr := peerConnection.WriteRTCP(
|
||||
context.TODO(), []rtcp.Packet{&rtcp.ReceiverEstimatedMaximumBitrate{Bitrate: 10000000, SenderSSRC: uint32(track.SSRC())}},
|
||||
); writeErr != nil {
|
||||
fmt.Println(writeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
for {
|
||||
// Read RTP packets being sent to Pion
|
||||
packet, readErr := track.ReadRTP()
|
||||
packet, readErr := track.ReadRTP(context.TODO())
|
||||
if readErr != nil {
|
||||
panic(readErr)
|
||||
}
|
||||
|
||||
if writeErr := outputTracks[rid].WriteRTP(packet); writeErr != nil && !errors.Is(writeErr, io.ErrClosedPipe) {
|
||||
if writeErr := outputTracks[rid].WriteRTP(context.TODO(), packet); writeErr != nil && !errors.Is(writeErr, io.ErrClosedPipe) {
|
||||
panic(writeErr)
|
||||
}
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -85,7 +86,7 @@ func main() { // nolint:gocognit
|
||||
var isCurrTrack bool
|
||||
for {
|
||||
// Read RTP packets being sent to Pion
|
||||
rtp, readErr := track.ReadRTP()
|
||||
rtp, readErr := track.ReadRTP(context.TODO())
|
||||
if readErr != nil {
|
||||
panic(readErr)
|
||||
}
|
||||
@@ -104,7 +105,9 @@ func main() { // nolint:gocognit
|
||||
// If just switched to this track, send PLI to get picture refresh
|
||||
if !isCurrTrack {
|
||||
isCurrTrack = true
|
||||
if writeErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}); writeErr != nil {
|
||||
if writeErr := peerConnection.WriteRTCP(
|
||||
context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}},
|
||||
); writeErr != nil {
|
||||
fmt.Println(writeErr)
|
||||
}
|
||||
}
|
||||
@@ -154,7 +157,7 @@ func main() { // nolint:gocognit
|
||||
// Keep an increasing sequence number
|
||||
packet.SequenceNumber = i
|
||||
// Write out the packet, ignoring closed pipe if nobody is listening
|
||||
if err := outputTrack.WriteRTP(packet); err != nil && !errors.Is(err, io.ErrClosedPipe) {
|
||||
if err := outputTrack.WriteRTP(context.TODO(), packet); err != nil && !errors.Is(err, io.ErrClosedPipe) {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
6
go.mod
6
go.mod
@@ -6,7 +6,7 @@ require (
|
||||
github.com/pion/datachannel v1.4.21
|
||||
github.com/pion/dtls/v2 v2.0.4
|
||||
github.com/pion/ice/v2 v2.0.13
|
||||
github.com/pion/interceptor v0.0.3
|
||||
github.com/pion/interceptor v0.0.4
|
||||
github.com/pion/logging v0.2.2
|
||||
github.com/pion/quic v0.1.4
|
||||
github.com/pion/randutil v0.1.0
|
||||
@@ -14,8 +14,8 @@ require (
|
||||
github.com/pion/rtp v1.6.1
|
||||
github.com/pion/sctp v1.7.11
|
||||
github.com/pion/sdp/v3 v3.0.3
|
||||
github.com/pion/srtp v1.5.2
|
||||
github.com/pion/transport v0.11.0
|
||||
github.com/pion/srtp/v2 v2.0.0-rc.1
|
||||
github.com/pion/transport v0.11.1
|
||||
github.com/sclevine/agouti v3.0.0+incompatible
|
||||
github.com/stretchr/testify v1.6.1
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b
|
||||
|
10
go.sum
10
go.sum
@@ -106,8 +106,8 @@ github.com/pion/dtls/v2 v2.0.4 h1:WuUcqi6oYMu/noNTz92QrF1DaFj4eXbhQ6dzaaAwOiI=
|
||||
github.com/pion/dtls/v2 v2.0.4/go.mod h1:qAkFscX0ZHoI1E07RfYPoRw3manThveu+mlTDdOxoGI=
|
||||
github.com/pion/ice/v2 v2.0.13 h1:lVe7g86tQ0vKdH430hQR/t7zV1oeXbK75130TUArrnw=
|
||||
github.com/pion/ice/v2 v2.0.13/go.mod h1:mZlypgoynMn2ayhGsjrPY/G/WiRiYO8WCPC6gUeg1RA=
|
||||
github.com/pion/interceptor v0.0.3 h1:VQtmPts/2IgYQtb9sZLTp6B0kIdHE5zBMQ6tCcgdJcM=
|
||||
github.com/pion/interceptor v0.0.3/go.mod h1:lPVrf5xfosI989ZcmgPS4WwwRhd+XAyTFaYI2wHf7nU=
|
||||
github.com/pion/interceptor v0.0.4 h1:FNq8cFDDv0i+Db9oEKccmMx4rerPm6pzBf4szjNaX2E=
|
||||
github.com/pion/interceptor v0.0.4/go.mod h1:lPVrf5xfosI989ZcmgPS4WwwRhd+XAyTFaYI2wHf7nU=
|
||||
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
|
||||
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
|
||||
github.com/pion/mdns v0.0.4 h1:O4vvVqr4DGX63vzmO6Fw9vpy3lfztVWHGCQfyw0ZLSY=
|
||||
@@ -125,8 +125,8 @@ github.com/pion/sctp v1.7.11 h1:UCnj7MsobLKLuP/Hh+JMiI/6W5Bs/VF45lWKgHFjSIE=
|
||||
github.com/pion/sctp v1.7.11/go.mod h1:EhpTUQu1/lcK3xI+eriS6/96fWetHGCvBi9MSsnaBN0=
|
||||
github.com/pion/sdp/v3 v3.0.3 h1:gJK9hk+JFD2NGIM1nXmqNCq1DkVaIZ9dlA3u3otnkaw=
|
||||
github.com/pion/sdp/v3 v3.0.3/go.mod h1:bNiSknmJE0HYBprTHXKPQ3+JjacTv5uap92ueJZKsRk=
|
||||
github.com/pion/srtp v1.5.2 h1:25DmvH+fqKZDqvX64vTwnycVwL9ooJxHF/gkX16bDBY=
|
||||
github.com/pion/srtp v1.5.2/go.mod h1:NiBff/MSxUwMUwx/fRNyD/xGE+dVvf8BOCeXhjCXZ9U=
|
||||
github.com/pion/srtp/v2 v2.0.0-rc.1 h1:UGabuCAIE5Yn5qmFr9H8UGdzYdYIaMt/AZLskcBcumA=
|
||||
github.com/pion/srtp/v2 v2.0.0-rc.1/go.mod h1:i3Oyh9/RPIdYRXywhs0NwTkMV/XYYr8NGj1E3gsRgk8=
|
||||
github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg=
|
||||
github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA=
|
||||
github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRhSzR+CZ8=
|
||||
@@ -135,6 +135,8 @@ github.com/pion/transport v0.10.1 h1:2W+yJT+0mOQ160ThZYUx5Zp2skzshiNgxrNE9GUfhJM
|
||||
github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+DwKOzf4A=
|
||||
github.com/pion/transport v0.11.0 h1:Z1RhzqrWPPYj5Xed8P7pirTKTvXFoxDI3uJuuKu6akM=
|
||||
github.com/pion/transport v0.11.0/go.mod h1:ORH8Ouyl1enoJyHwU+MwMeQocWbeorEk5068FOsHjog=
|
||||
github.com/pion/transport v0.11.1 h1:z+6FJ3T7R4pX87efsiFLmIRLYVLLk7XZI76kHkI5V10=
|
||||
github.com/pion/transport v0.11.1/go.mod h1:ORH8Ouyl1enoJyHwU+MwMeQocWbeorEk5068FOsHjog=
|
||||
github.com/pion/turn/v2 v2.0.5 h1:iwMHqDfPEDEOFzwWKT56eFmh6DYC6o/+xnLAEzgISbA=
|
||||
github.com/pion/turn/v2 v2.0.5/go.mod h1:APg43CFyt/14Uy7heYUOGWdkem/Wu4PhCO/bjyrTqMw=
|
||||
github.com/pion/udp v0.1.0 h1:uGxQsNyrqG3GLINv36Ff60covYmfrLoxzwnCsIYspXI=
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/pion/ice/v2"
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/transport/connctx"
|
||||
"github.com/pion/webrtc/v3/internal/mux"
|
||||
)
|
||||
|
||||
@@ -70,7 +71,7 @@ func NewICETransport(gatherer *ICEGatherer, loggerFactory logging.LoggerFactory)
|
||||
}
|
||||
|
||||
// Start incoming connectivity checks based on its configured role.
|
||||
func (t *ICETransport) Start(gatherer *ICEGatherer, params ICEParameters, role *ICERole) error {
|
||||
func (t *ICETransport) Start(ctx context.Context, gatherer *ICEGatherer, params ICEParameters, role *ICERole) error {
|
||||
t.lock.Lock()
|
||||
defer t.lock.Unlock()
|
||||
|
||||
@@ -142,11 +143,11 @@ func (t *ICETransport) Start(gatherer *ICEGatherer, params ICEParameters, role *
|
||||
t.conn = iceConn
|
||||
|
||||
config := mux.Config{
|
||||
Conn: t.conn,
|
||||
Conn: connctx.New(t.conn),
|
||||
BufferSize: receiveMTU,
|
||||
LoggerFactory: t.loggerFactory,
|
||||
}
|
||||
t.mux = mux.NewMux(config)
|
||||
t.mux = mux.NewMux(ctx, config)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -25,19 +26,19 @@ type testInterceptor struct {
|
||||
}
|
||||
|
||||
func (t *testInterceptor) BindLocalStream(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
|
||||
return interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) {
|
||||
return interceptor.RTPWriterFunc(func(ctx context.Context, p *rtp.Packet, attributes interceptor.Attributes) (int, error) {
|
||||
// set extension on outgoing packet
|
||||
p.Header.Extension = true
|
||||
p.Header.ExtensionProfile = 0xBEDE
|
||||
assert.NoError(t.t, p.Header.SetExtension(t.extensionID, []byte("write")))
|
||||
|
||||
return writer.Write(p, attributes)
|
||||
return writer.Write(ctx, p, attributes)
|
||||
})
|
||||
}
|
||||
|
||||
func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
|
||||
return interceptor.RTPReaderFunc(func() (*rtp.Packet, interceptor.Attributes, error) {
|
||||
p, attributes, err := reader.Read()
|
||||
return interceptor.RTPReaderFunc(func(ctx context.Context) (*rtp.Packet, interceptor.Attributes, error) {
|
||||
p, attributes, err := reader.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -49,7 +50,7 @@ func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader
|
||||
// write back a pli
|
||||
rtcpWriter := t.rtcpWriter.Load().(interceptor.RTCPWriter)
|
||||
pli := &rtcp.PictureLossIndication{SenderSSRC: info.SSRC, MediaSSRC: info.SSRC}
|
||||
_, err = rtcpWriter.Write([]rtcp.Packet{pli}, make(interceptor.Attributes))
|
||||
_, err = rtcpWriter.Write(ctx, []rtcp.Packet{pli}, make(interceptor.Attributes))
|
||||
assert.NoError(t.t, err)
|
||||
|
||||
return p, attributes, nil
|
||||
@@ -57,8 +58,8 @@ func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader
|
||||
}
|
||||
|
||||
func (t *testInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader {
|
||||
return interceptor.RTCPReaderFunc(func() ([]rtcp.Packet, interceptor.Attributes, error) {
|
||||
pkts, attributes, err := reader.Read()
|
||||
return interceptor.RTCPReaderFunc(func(ctx context.Context) ([]rtcp.Packet, interceptor.Attributes, error) {
|
||||
pkts, attributes, err := reader.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -122,7 +123,7 @@ func TestPeerConnection_Interceptor(t *testing.T) {
|
||||
wg.Add(1)
|
||||
*pending++
|
||||
receiverPC.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) {
|
||||
p, readErr := track.ReadRTP()
|
||||
p, readErr := track.ReadRTP(context.Background())
|
||||
if readErr != nil {
|
||||
t.Fatal(readErr)
|
||||
}
|
||||
@@ -133,7 +134,7 @@ func TestPeerConnection_Interceptor(t *testing.T) {
|
||||
wg.Done()
|
||||
|
||||
for {
|
||||
_, readErr = track.ReadRTP()
|
||||
_, readErr = track.ReadRTP(context.Background())
|
||||
if readErr != nil {
|
||||
return
|
||||
}
|
||||
@@ -143,13 +144,13 @@ func TestPeerConnection_Interceptor(t *testing.T) {
|
||||
wg.Add(1)
|
||||
*pending++
|
||||
go func() {
|
||||
_, readErr := sender.ReadRTCP()
|
||||
_, readErr := sender.ReadRTCP(context.Background())
|
||||
assert.NoError(t, readErr)
|
||||
atomic.AddInt32(pending, -1)
|
||||
wg.Done()
|
||||
|
||||
for {
|
||||
_, readErr = sender.ReadRTCP()
|
||||
_, readErr = sender.ReadRTCP(context.Background())
|
||||
if readErr != nil {
|
||||
return
|
||||
}
|
||||
@@ -166,7 +167,9 @@ func TestPeerConnection_Interceptor(t *testing.T) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
if routineErr := track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}); routineErr != nil {
|
||||
if routineErr := track.WriteSample(
|
||||
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
|
||||
); routineErr != nil {
|
||||
t.Error(routineErr)
|
||||
return
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/pion/interceptor"
|
||||
@@ -18,12 +19,12 @@ func (i *interceptorTrackLocalWriter) setRTPWriter(writer interceptor.RTPWriter)
|
||||
i.rtpWriter.Store(writer)
|
||||
}
|
||||
|
||||
func (i *interceptorTrackLocalWriter) WriteRTP(header *rtp.Header, payload []byte) (int, error) {
|
||||
func (i *interceptorTrackLocalWriter) WriteRTP(ctx context.Context, header *rtp.Header, payload []byte) (int, error) {
|
||||
writer := i.rtpWriter.Load().(interceptor.RTPWriter)
|
||||
|
||||
if writer == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return writer.Write(&rtp.Packet{Header: *header, Payload: payload}, make(interceptor.Attributes))
|
||||
return writer.Write(ctx, &rtp.Packet{Header: *header, Payload: payload}, make(interceptor.Attributes))
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
@@ -34,12 +35,23 @@ func (e *Endpoint) close() error {
|
||||
// Read reads a packet of len(p) bytes from the underlying conn
|
||||
// that are matched by the associated MuxFunc
|
||||
func (e *Endpoint) Read(p []byte) (int, error) {
|
||||
return e.buffer.Read(p)
|
||||
return e.buffer.ReadContext(context.Background(), p)
|
||||
}
|
||||
|
||||
// ReadContext reads a packet of len(p) bytes from the underlying conn
|
||||
// that are matched by the associated MuxFunc
|
||||
func (e *Endpoint) ReadContext(ctx context.Context, p []byte) (int, error) {
|
||||
return e.buffer.ReadContext(ctx, p)
|
||||
}
|
||||
|
||||
// Write writes len(p) bytes to the underlying conn
|
||||
func (e *Endpoint) Write(p []byte) (int, error) {
|
||||
n, err := e.mux.nextConn.Write(p)
|
||||
return e.WriteContext(context.Background(), p)
|
||||
}
|
||||
|
||||
// WriteContext writes len(p) bytes to the underlying conn
|
||||
func (e *Endpoint) WriteContext(ctx context.Context, p []byte) (int, error) {
|
||||
n, err := e.mux.nextConn.WriteContext(ctx, p)
|
||||
if errors.Is(err, ice.ErrNoCandidatePairs) {
|
||||
return 0, nil
|
||||
} else if errors.Is(err, ice.ErrClosed) {
|
||||
|
@@ -2,10 +2,11 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"net"
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/transport/connctx"
|
||||
"github.com/pion/transport/packetio"
|
||||
)
|
||||
|
||||
@@ -15,7 +16,7 @@ const maxBufferSize = 1000 * 1000 // 1MB
|
||||
// Config collects the arguments to mux.Mux construction into
|
||||
// a single structure
|
||||
type Config struct {
|
||||
Conn net.Conn
|
||||
Conn connctx.ConnCtx
|
||||
BufferSize int
|
||||
LoggerFactory logging.LoggerFactory
|
||||
}
|
||||
@@ -23,7 +24,7 @@ type Config struct {
|
||||
// Mux allows multiplexing
|
||||
type Mux struct {
|
||||
lock sync.RWMutex
|
||||
nextConn net.Conn
|
||||
nextConn connctx.ConnCtx
|
||||
endpoints map[*Endpoint]MatchFunc
|
||||
bufferSize int
|
||||
closedCh chan struct{}
|
||||
@@ -32,7 +33,7 @@ type Mux struct {
|
||||
}
|
||||
|
||||
// NewMux creates a new Mux
|
||||
func NewMux(config Config) *Mux {
|
||||
func NewMux(ctx context.Context, config Config) *Mux {
|
||||
m := &Mux{
|
||||
nextConn: config.Conn,
|
||||
endpoints: make(map[*Endpoint]MatchFunc),
|
||||
@@ -41,7 +42,7 @@ func NewMux(config Config) *Mux {
|
||||
log: config.LoggerFactory.NewLogger("mux"),
|
||||
}
|
||||
|
||||
go m.readLoop()
|
||||
go m.readLoop(ctx)
|
||||
|
||||
return m
|
||||
}
|
||||
@@ -96,26 +97,26 @@ func (m *Mux) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mux) readLoop() {
|
||||
func (m *Mux) readLoop(ctx context.Context) {
|
||||
defer func() {
|
||||
close(m.closedCh)
|
||||
}()
|
||||
|
||||
buf := make([]byte, m.bufferSize)
|
||||
for {
|
||||
n, err := m.nextConn.Read(buf)
|
||||
n, err := m.nextConn.ReadContext(ctx, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = m.dispatch(buf[:n])
|
||||
err = m.dispatch(ctx, buf[:n])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mux) dispatch(buf []byte) error {
|
||||
func (m *Mux) dispatch(ctx context.Context, buf []byte) error {
|
||||
var endpoint *Endpoint
|
||||
|
||||
m.lock.Lock()
|
||||
@@ -136,7 +137,7 @@ func (m *Mux) dispatch(buf []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := endpoint.buffer.Write(buf)
|
||||
_, err := endpoint.buffer.WriteContext(ctx, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -1,29 +1,33 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/transport/connctx"
|
||||
"github.com/pion/transport/test"
|
||||
)
|
||||
|
||||
func TestStressDuplex(t *testing.T) {
|
||||
// Limit runtime in case of deadlocks
|
||||
lim := test.TimeOut(time.Second * 20)
|
||||
lim := test.TimeOut(time.Second * 30)
|
||||
defer lim.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Check for leaking routines
|
||||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
// Run the test
|
||||
stressDuplex(t)
|
||||
stressDuplex(ctx, t)
|
||||
}
|
||||
|
||||
func stressDuplex(t *testing.T) {
|
||||
ca, cb, stop := pipeMemory()
|
||||
func stressDuplex(ctx context.Context, t *testing.T) {
|
||||
ca, cb, stop := pipeMemory(ctx)
|
||||
|
||||
defer func() {
|
||||
stop(t)
|
||||
@@ -34,13 +38,21 @@ func stressDuplex(t *testing.T) {
|
||||
MsgCount: 100,
|
||||
}
|
||||
|
||||
err := test.StressDuplex(ca, cb, opt)
|
||||
t.Run("WithoutContext", func(t *testing.T) {
|
||||
err := test.StressDuplex(ca, cb.Conn(), opt)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
t.Run("WithContext", func(t *testing.T) {
|
||||
err := test.StressDuplexContext(context.Background(), ca, cb, opt)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func pipeMemory() (*Endpoint, net.Conn, func(*testing.T)) {
|
||||
func pipeMemory(ctx context.Context) (*Endpoint, connctx.ConnCtx, func(*testing.T)) {
|
||||
// In memory pipe
|
||||
ca, cb := net.Pipe()
|
||||
|
||||
@@ -49,12 +61,12 @@ func pipeMemory() (*Endpoint, net.Conn, func(*testing.T)) {
|
||||
}
|
||||
|
||||
config := Config{
|
||||
Conn: ca,
|
||||
Conn: connctx.New(ca),
|
||||
BufferSize: 8192,
|
||||
LoggerFactory: logging.NewDefaultLoggerFactory(),
|
||||
}
|
||||
|
||||
m := NewMux(config)
|
||||
m := NewMux(ctx, config)
|
||||
e := m.NewEndpoint(matchAll)
|
||||
m.RemoveEndpoint(e)
|
||||
e = m.NewEndpoint(matchAll)
|
||||
@@ -70,7 +82,7 @@ func pipeMemory() (*Endpoint, net.Conn, func(*testing.T)) {
|
||||
}
|
||||
}
|
||||
|
||||
return e, cb, stop
|
||||
return e, connctx.New(cb), stop
|
||||
}
|
||||
|
||||
func TestNoEndpoints(t *testing.T) {
|
||||
@@ -82,13 +94,16 @@ func TestNoEndpoints(t *testing.T) {
|
||||
}
|
||||
|
||||
config := Config{
|
||||
Conn: ca,
|
||||
Conn: connctx.New(ca),
|
||||
BufferSize: 8192,
|
||||
LoggerFactory: logging.NewDefaultLoggerFactory(),
|
||||
}
|
||||
|
||||
m := NewMux(config)
|
||||
err = m.dispatch(make([]byte, 1))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
m := NewMux(ctx, config)
|
||||
err = m.dispatch(ctx, make([]byte, 1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@@ -1,11 +1,12 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Operation is a function
|
||||
type operation func()
|
||||
type operation func(ctx context.Context)
|
||||
|
||||
// Operations is a task executor.
|
||||
type operations struct {
|
||||
@@ -32,7 +33,7 @@ func (o *operations) Enqueue(op operation) {
|
||||
o.mu.Unlock()
|
||||
|
||||
if !running {
|
||||
go o.start()
|
||||
go o.start(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,13 +49,13 @@ func (o *operations) IsEmpty() bool {
|
||||
func (o *operations) Done() {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
o.Enqueue(func() {
|
||||
o.Enqueue(func(_ context.Context) {
|
||||
wg.Done()
|
||||
})
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (o *operations) pop() func() {
|
||||
func (o *operations) pop() func(context.Context) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
if len(o.ops) == 0 {
|
||||
@@ -66,7 +67,7 @@ func (o *operations) pop() func() {
|
||||
return fn
|
||||
}
|
||||
|
||||
func (o *operations) start() {
|
||||
func (o *operations) start(ctx context.Context) {
|
||||
defer func() {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
@@ -76,12 +77,12 @@ func (o *operations) start() {
|
||||
}
|
||||
// either a new operation was enqueued while we
|
||||
// were busy, or an operation panicked
|
||||
go o.start()
|
||||
go o.start(ctx)
|
||||
}()
|
||||
|
||||
fn := o.pop()
|
||||
for fn != nil {
|
||||
fn()
|
||||
fn(ctx)
|
||||
fn = o.pop()
|
||||
}
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -12,7 +13,7 @@ func TestOperations_Enqueue(t *testing.T) {
|
||||
results := make([]int, 16)
|
||||
for i := range results {
|
||||
func(j int) {
|
||||
ops.Enqueue(func() {
|
||||
ops.Enqueue(func(_ context.Context) {
|
||||
results[j] = j * j
|
||||
})
|
||||
}(i)
|
||||
|
@@ -3,11 +3,11 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/sdp/v3"
|
||||
"github.com/pion/transport/connctx"
|
||||
"github.com/pion/webrtc/v3/internal/util"
|
||||
"github.com/pion/webrtc/v3/pkg/rtcerr"
|
||||
)
|
||||
@@ -275,7 +276,7 @@ func (pc *PeerConnection) onNegotiationNeeded() {
|
||||
pc.ops.Enqueue(pc.negotiationNeededOp)
|
||||
}
|
||||
|
||||
func (pc *PeerConnection) negotiationNeededOp() {
|
||||
func (pc *PeerConnection) negotiationNeededOp(ctx context.Context) {
|
||||
// Don't run NegotiatedNeeded checks if OnNegotiationNeeded is not set
|
||||
if handler := pc.onNegotiationNeededHandler.Load(); handler == nil {
|
||||
return
|
||||
@@ -945,8 +946,8 @@ func (pc *PeerConnection) SetLocalDescription(desc SessionDescription) error {
|
||||
if err := pc.startRTPSenders(currentTransceivers); err != nil {
|
||||
return err
|
||||
}
|
||||
pc.ops.Enqueue(func() {
|
||||
pc.startRTP(haveLocalDescription, remoteDesc, currentTransceivers)
|
||||
pc.ops.Enqueue(func(ctx context.Context) {
|
||||
pc.startRTP(ctx, haveLocalDescription, remoteDesc, currentTransceivers)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1066,8 +1067,8 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
|
||||
if err = pc.startRTPSenders(currentTransceivers); err != nil {
|
||||
return err
|
||||
}
|
||||
pc.ops.Enqueue(func() {
|
||||
pc.startRTP(true, &desc, currentTransceivers)
|
||||
pc.ops.Enqueue(func(ctx context.Context) {
|
||||
pc.startRTP(ctx, true, &desc, currentTransceivers)
|
||||
})
|
||||
}
|
||||
return nil
|
||||
@@ -1099,16 +1100,16 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
|
||||
}
|
||||
}
|
||||
|
||||
pc.ops.Enqueue(func() {
|
||||
pc.startTransports(iceRole, dtlsRoleFromRemoteSDP(desc.parsed), remoteUfrag, remotePwd, fingerprint, fingerprintHash)
|
||||
pc.ops.Enqueue(func(ctx context.Context) {
|
||||
pc.startTransports(ctx, iceRole, dtlsRoleFromRemoteSDP(desc.parsed), remoteUfrag, remotePwd, fingerprint, fingerprintHash)
|
||||
if weOffer {
|
||||
pc.startRTP(false, &desc, currentTransceivers)
|
||||
pc.startRTP(ctx, false, &desc, currentTransceivers)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPReceiver) {
|
||||
func (pc *PeerConnection) startReceiver(ctx context.Context, incoming trackDetails, receiver *RTPReceiver) {
|
||||
encodings := []RTPDecodingParameters{}
|
||||
if incoming.ssrc != 0 {
|
||||
encodings = append(encodings, RTPDecodingParameters{RTPCodingParameters{SSRC: incoming.ssrc}})
|
||||
@@ -1137,7 +1138,7 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := receiver.Track().determinePayloadType(); err != nil {
|
||||
if err := receiver.Track().determinePayloadType(ctx); err != nil {
|
||||
pc.log.Warnf("Could not determine PayloadType for SSRC %d", receiver.Track().SSRC())
|
||||
return
|
||||
}
|
||||
@@ -1160,7 +1161,7 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece
|
||||
}
|
||||
|
||||
// startRTPReceivers opens knows inbound SRTP streams from the RemoteDescription
|
||||
func (pc *PeerConnection) startRTPReceivers(incomingTracks []trackDetails, currentTransceivers []*RTPTransceiver) { //nolint:gocognit
|
||||
func (pc *PeerConnection) startRTPReceivers(ctx context.Context, incomingTracks []trackDetails, currentTransceivers []*RTPTransceiver) { //nolint:gocognit
|
||||
localTransceivers := append([]*RTPTransceiver{}, currentTransceivers...)
|
||||
|
||||
remoteIsPlanB := false
|
||||
@@ -1207,7 +1208,7 @@ func (pc *PeerConnection) startRTPReceivers(incomingTracks []trackDetails, curre
|
||||
continue
|
||||
}
|
||||
|
||||
pc.startReceiver(incomingTrack, t.Receiver())
|
||||
pc.startReceiver(ctx, incomingTrack, t.Receiver())
|
||||
trackHandled = true
|
||||
break
|
||||
}
|
||||
@@ -1226,7 +1227,7 @@ func (pc *PeerConnection) startRTPReceivers(incomingTracks []trackDetails, curre
|
||||
pc.log.Warnf("Could not add transceiver for remote SSRC %d: %s", incoming.ssrc, err)
|
||||
continue
|
||||
}
|
||||
pc.startReceiver(incoming, t.Receiver())
|
||||
pc.startReceiver(ctx, incoming, t.Receiver())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1253,9 +1254,9 @@ func (pc *PeerConnection) startRTPSenders(currentTransceivers []*RTPTransceiver)
|
||||
}
|
||||
|
||||
// Start SCTP subsystem
|
||||
func (pc *PeerConnection) startSCTP() {
|
||||
func (pc *PeerConnection) startSCTP(ctx context.Context) {
|
||||
// Start sctp
|
||||
if err := pc.sctpTransport.Start(SCTPCapabilities{
|
||||
if err := pc.sctpTransport.Start(ctx, SCTPCapabilities{
|
||||
MaxMessageSize: 0,
|
||||
}); err != nil {
|
||||
pc.log.Warnf("Failed to start SCTP: %s", err)
|
||||
@@ -1289,7 +1290,7 @@ func (pc *PeerConnection) startSCTP() {
|
||||
pc.sctpTransport.lock.Unlock()
|
||||
}
|
||||
|
||||
func (pc *PeerConnection) handleUndeclaredSSRC(rtpStream io.Reader, ssrc SSRC) error { //nolint:gocognit
|
||||
func (pc *PeerConnection) handleUndeclaredSSRC(ctx context.Context, rtpStream connctx.Reader, ssrc SSRC) error { //nolint:gocognit
|
||||
remoteDescription := pc.RemoteDescription()
|
||||
if remoteDescription == nil {
|
||||
return errPeerConnRemoteDescriptionNil
|
||||
@@ -1318,7 +1319,7 @@ func (pc *PeerConnection) handleUndeclaredSSRC(rtpStream io.Reader, ssrc SSRC) e
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %d: %s", errPeerConnRemoteSSRCAddTransceiver, ssrc, err)
|
||||
}
|
||||
pc.startReceiver(incoming, t.Receiver())
|
||||
pc.startReceiver(ctx, incoming, t.Receiver())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1335,7 +1336,7 @@ func (pc *PeerConnection) handleUndeclaredSSRC(rtpStream io.Reader, ssrc SSRC) e
|
||||
b := make([]byte, receiveMTU)
|
||||
var mid, rid string
|
||||
for readCount := 0; readCount <= simulcastProbeCount; readCount++ {
|
||||
i, err := rtpStream.Read(b)
|
||||
i, err := rtpStream.ReadContext(ctx, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1379,7 +1380,7 @@ func (pc *PeerConnection) handleUndeclaredSSRC(rtpStream io.Reader, ssrc SSRC) e
|
||||
}
|
||||
|
||||
// undeclaredMediaProcessor handles RTP/RTCP packets that don't match any a:ssrc lines
|
||||
func (pc *PeerConnection) undeclaredMediaProcessor() {
|
||||
func (pc *PeerConnection) undeclaredMediaProcessor(ctx context.Context) {
|
||||
go func() {
|
||||
for {
|
||||
srtpSession, err := pc.dtlsTransport.getSRTPSession()
|
||||
@@ -1394,7 +1395,7 @@ func (pc *PeerConnection) undeclaredMediaProcessor() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := pc.handleUndeclaredSSRC(stream, SSRC(ssrc)); err != nil {
|
||||
if err := pc.handleUndeclaredSSRC(ctx, stream, SSRC(ssrc)); err != nil {
|
||||
pc.log.Errorf("Incoming unhandled RTP ssrc(%d), OnTrack will not be fired. %v", ssrc, err)
|
||||
}
|
||||
}
|
||||
@@ -1753,12 +1754,12 @@ func (pc *PeerConnection) SetIdentityProvider(provider string) error {
|
||||
|
||||
// WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the
|
||||
// packet is discarded. It also runs any configured interceptors.
|
||||
func (pc *PeerConnection) WriteRTCP(pkts []rtcp.Packet) error {
|
||||
_, err := pc.interceptorRTCPWriter.Write(pkts, make(interceptor.Attributes))
|
||||
func (pc *PeerConnection) WriteRTCP(ctx context.Context, pkts []rtcp.Packet) error {
|
||||
_, err := pc.interceptorRTCPWriter.Write(ctx, pkts, make(interceptor.Attributes))
|
||||
return err
|
||||
}
|
||||
|
||||
func (pc *PeerConnection) writeRTCP(pkts []rtcp.Packet, _ interceptor.Attributes) (int, error) {
|
||||
func (pc *PeerConnection) writeRTCP(ctx context.Context, pkts []rtcp.Packet, _ interceptor.Attributes) (int, error) {
|
||||
raw, err := rtcp.Marshal(pkts)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -1774,7 +1775,7 @@ func (pc *PeerConnection) writeRTCP(pkts []rtcp.Packet, _ interceptor.Attributes
|
||||
return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err)
|
||||
}
|
||||
|
||||
if n, err := writeStream.Write(raw); err != nil {
|
||||
if n, err := writeStream.WriteContext(ctx, raw); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return 0, nil
|
||||
@@ -1985,9 +1986,10 @@ func (pc *PeerConnection) GetStats() StatsReport {
|
||||
}
|
||||
|
||||
// Start all transports. PeerConnection now has enough state
|
||||
func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, remoteUfrag, remotePwd, fingerprint, fingerprintHash string) {
|
||||
func (pc *PeerConnection) startTransports(ctx context.Context, iceRole ICERole, dtlsRole DTLSRole, remoteUfrag, remotePwd, fingerprint, fingerprintHash string) {
|
||||
// Start the ice transport
|
||||
err := pc.iceTransport.Start(
|
||||
ctx,
|
||||
pc.iceGatherer,
|
||||
ICEParameters{
|
||||
UsernameFragment: remoteUfrag,
|
||||
@@ -2002,7 +2004,7 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re
|
||||
}
|
||||
|
||||
// Start the dtls transport
|
||||
err = pc.dtlsTransport.Start(DTLSParameters{
|
||||
err = pc.dtlsTransport.Start(ctx, DTLSParameters{
|
||||
Role: dtlsRole,
|
||||
Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}},
|
||||
})
|
||||
@@ -2013,7 +2015,7 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re
|
||||
}
|
||||
}
|
||||
|
||||
func (pc *PeerConnection) startRTP(isRenegotiation bool, remoteDesc *SessionDescription, currentTransceivers []*RTPTransceiver) {
|
||||
func (pc *PeerConnection) startRTP(ctx context.Context, isRenegotiation bool, remoteDesc *SessionDescription, currentTransceivers []*RTPTransceiver) {
|
||||
trackDetails := trackDetailsFromSDP(pc.log, remoteDesc.parsed)
|
||||
if isRenegotiation {
|
||||
for _, t := range currentTransceivers {
|
||||
@@ -2047,13 +2049,13 @@ func (pc *PeerConnection) startRTP(isRenegotiation bool, remoteDesc *SessionDesc
|
||||
}
|
||||
}
|
||||
|
||||
pc.startRTPReceivers(trackDetails, currentTransceivers)
|
||||
pc.startRTPReceivers(ctx, trackDetails, currentTransceivers)
|
||||
if haveApplicationMediaSection(remoteDesc.parsed) {
|
||||
pc.startSCTP()
|
||||
pc.startSCTP(ctx)
|
||||
}
|
||||
|
||||
if !isRenegotiation {
|
||||
pc.undeclaredMediaProcessor()
|
||||
pc.undeclaredMediaProcessor(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1137,7 +1137,7 @@ func TestPeerConnection_MassiveTracks(t *testing.T) {
|
||||
<-connected
|
||||
time.Sleep(1 * time.Second)
|
||||
for _, track := range tracks {
|
||||
assert.NoError(t, track.WriteRTP(samplePkt))
|
||||
assert.NoError(t, track.WriteRTP(context.Background(), samplePkt))
|
||||
}
|
||||
// Ping trackRecords to see if any track event not received yet.
|
||||
tooLong := time.After(timeoutDuration)
|
||||
|
@@ -5,6 +5,7 @@ package webrtc
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -88,7 +89,10 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
if routineErr := pcAnswer.WriteRTCP([]rtcp.Packet{&rtcp.RapidResynchronizationRequest{SenderSSRC: uint32(track.SSRC()), MediaSSRC: uint32(track.SSRC())}}); routineErr != nil {
|
||||
if routineErr := pcAnswer.WriteRTCP(
|
||||
context.Background(),
|
||||
[]rtcp.Packet{&rtcp.RapidResynchronizationRequest{SenderSSRC: uint32(track.SSRC()), MediaSSRC: uint32(track.SSRC())}},
|
||||
); routineErr != nil {
|
||||
awaitRTCPReceiverSend <- routineErr
|
||||
return
|
||||
}
|
||||
@@ -103,7 +107,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, routineErr := receiver.Read(make([]byte, 1400))
|
||||
_, routineErr := receiver.Read(context.Background(), make([]byte, 1400))
|
||||
if routineErr != nil {
|
||||
awaitRTCPReceiverRecv <- routineErr
|
||||
} else {
|
||||
@@ -113,7 +117,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
|
||||
|
||||
haveClosedAwaitRTPRecv := false
|
||||
for {
|
||||
p, routineErr := track.ReadRTP()
|
||||
p, routineErr := track.ReadRTP(context.Background())
|
||||
if routineErr != nil {
|
||||
close(awaitRTPRecvClosed)
|
||||
return
|
||||
@@ -136,7 +140,9 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
if routineErr := vp8Track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}); routineErr != nil {
|
||||
if routineErr := vp8Track.WriteSample(
|
||||
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
|
||||
); routineErr != nil {
|
||||
fmt.Println(routineErr)
|
||||
}
|
||||
|
||||
@@ -152,7 +158,10 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
if routineErr := pcOffer.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{SenderSSRC: uint32(sender.ssrc), MediaSSRC: uint32(sender.ssrc)}}); routineErr != nil {
|
||||
if routineErr := pcOffer.WriteRTCP(
|
||||
context.Background(),
|
||||
[]rtcp.Packet{&rtcp.PictureLossIndication{SenderSSRC: uint32(sender.ssrc), MediaSSRC: uint32(sender.ssrc)}},
|
||||
); routineErr != nil {
|
||||
awaitRTCPSenderSend <- routineErr
|
||||
}
|
||||
|
||||
@@ -166,7 +175,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, routineErr := sender.Read(make([]byte, 1400)); routineErr == nil {
|
||||
if _, routineErr := sender.Read(context.Background(), make([]byte, 1400)); routineErr == nil {
|
||||
close(awaitRTCPSenderRecv)
|
||||
}
|
||||
}()
|
||||
@@ -366,9 +375,13 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for i := 0; i <= 5; i++ {
|
||||
if rtpErr := vp8Track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}); rtpErr != nil {
|
||||
if rtpErr := vp8Track.WriteSample(
|
||||
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
|
||||
); rtpErr != nil {
|
||||
t.Fatal(rtpErr)
|
||||
} else if rtcpErr := pcOffer.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: 0}}); rtcpErr != nil {
|
||||
} else if rtcpErr := pcOffer.WriteRTCP(
|
||||
context.Background(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: 0}},
|
||||
); rtcpErr != nil {
|
||||
t.Fatal(rtcpErr)
|
||||
}
|
||||
}
|
||||
@@ -448,7 +461,9 @@ func TestUndeclaredSSRC(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
for {
|
||||
assert.NoError(t, vp8Writer.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
|
||||
assert.NoError(t, vp8Writer.WriteSample(
|
||||
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
|
||||
))
|
||||
time.Sleep(time.Millisecond * 25)
|
||||
|
||||
select {
|
||||
@@ -686,11 +701,11 @@ func TestRtpSenderReceiver_ReadClose_Error(t *testing.T) {
|
||||
|
||||
sender, receiver := tr.Sender(), tr.Receiver()
|
||||
assert.NoError(t, sender.Stop())
|
||||
_, err = sender.Read(make([]byte, 0, 1400))
|
||||
_, err = sender.Read(context.Background(), make([]byte, 0, 1400))
|
||||
assert.Error(t, err, io.ErrClosedPipe)
|
||||
|
||||
assert.NoError(t, receiver.Stop())
|
||||
_, err = receiver.Read(make([]byte, 0, 1400))
|
||||
_, err = receiver.Read(context.Background(), make([]byte, 0, 1400))
|
||||
assert.Error(t, err, io.ErrClosedPipe)
|
||||
|
||||
assert.NoError(t, pc.Close())
|
||||
@@ -837,7 +852,9 @@ func TestPlanBMediaExchange(t *testing.T) {
|
||||
select {
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
for _, track := range outboundTracks {
|
||||
assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
|
||||
assert.NoError(t, track.WriteSample(
|
||||
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
|
||||
))
|
||||
}
|
||||
case <-done:
|
||||
return
|
||||
|
@@ -25,7 +25,9 @@ func sendVideoUntilDone(done <-chan struct{}, t *testing.T, tracks []*TrackLocal
|
||||
select {
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
for _, track := range tracks {
|
||||
assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
|
||||
assert.NoError(t, track.WriteSample(
|
||||
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
|
||||
))
|
||||
}
|
||||
case <-done:
|
||||
return
|
||||
@@ -100,7 +102,9 @@ func TestPeerConnection_Renegotiation_AddTrack(t *testing.T) {
|
||||
|
||||
// Send 10 packets, OnTrack MUST not be fired
|
||||
for i := 0; i <= 10; i++ {
|
||||
assert.NoError(t, vp8Track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
|
||||
assert.NoError(t, vp8Track.WriteSample(
|
||||
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
|
||||
))
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -360,7 +364,7 @@ func TestPeerConnection_Renegotiation_CodecChange(t *testing.T) {
|
||||
pcAnswer.OnTrack(func(track *TrackRemote, r *RTPReceiver) {
|
||||
tracksCh <- track
|
||||
for {
|
||||
if _, readErr := track.ReadRTP(); readErr == io.EOF {
|
||||
if _, readErr := track.ReadRTP(context.Background()); readErr == io.EOF {
|
||||
tracksClosed <- struct{}{}
|
||||
return
|
||||
}
|
||||
@@ -450,7 +454,7 @@ func TestPeerConnection_Renegotiation_RemoveTrack(t *testing.T) {
|
||||
onTrackFiredFunc()
|
||||
|
||||
for {
|
||||
if _, err := track.ReadRTP(); err == io.EOF {
|
||||
if _, err := track.ReadRTP(context.Background()); err == io.EOF {
|
||||
trackClosedFunc()
|
||||
return
|
||||
}
|
||||
@@ -838,7 +842,9 @@ func TestNegotiationNeededRemoveTrack(t *testing.T) {
|
||||
sender, err := pcOffer.AddTrack(track)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
|
||||
assert.NoError(t, track.WriteSample(
|
||||
context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second},
|
||||
))
|
||||
|
||||
wg.Wait()
|
||||
|
||||
|
@@ -5,6 +5,7 @@ package webrtc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -142,7 +143,7 @@ func (s *testQuicStack) setSignal(sig *testQuicSignal, isOffer bool) error {
|
||||
}
|
||||
|
||||
// Start the ICE transport
|
||||
err = s.ice.Start(nil, sig.ICEParameters, &iceRole)
|
||||
err = s.ice.Start(context.Background(), nil, sig.ICEParameters, &iceRole)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -3,13 +3,14 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/interceptor"
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/srtp"
|
||||
"github.com/pion/srtp/v2"
|
||||
)
|
||||
|
||||
// trackStreams maintains a mapping of RTP/RTCP streams to a specific track
|
||||
@@ -132,41 +133,45 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error {
|
||||
}
|
||||
|
||||
// Read reads incoming RTCP for this RTPReceiver
|
||||
func (r *RTPReceiver) Read(b []byte) (n int, err error) {
|
||||
func (r *RTPReceiver) Read(ctx context.Context, b []byte) (n int, err error) {
|
||||
select {
|
||||
case <-r.received:
|
||||
return r.tracks[0].rtcpReadStream.Read(b)
|
||||
return r.tracks[0].rtcpReadStream.ReadContext(ctx, b)
|
||||
case <-r.closed:
|
||||
return 0, io.ErrClosedPipe
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// ReadSimulcast reads incoming RTCP for this RTPReceiver for given rid
|
||||
func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, err error) {
|
||||
func (r *RTPReceiver) ReadSimulcast(ctx context.Context, b []byte, rid string) (n int, err error) {
|
||||
select {
|
||||
case <-r.received:
|
||||
for _, t := range r.tracks {
|
||||
if t.track != nil && t.track.rid == rid {
|
||||
return t.rtcpReadStream.Read(b)
|
||||
return t.rtcpReadStream.ReadContext(ctx, b)
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid)
|
||||
case <-r.closed:
|
||||
return 0, io.ErrClosedPipe
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// ReadRTCP is a convenience method that wraps Read and unmarshal for you.
|
||||
// It also runs any configured interceptors.
|
||||
func (r *RTPReceiver) ReadRTCP() ([]rtcp.Packet, error) {
|
||||
pkts, _, err := r.interceptorRTCPReader.Read()
|
||||
func (r *RTPReceiver) ReadRTCP(ctx context.Context) ([]rtcp.Packet, error) {
|
||||
pkts, _, err := r.interceptorRTCPReader.Read(ctx)
|
||||
return pkts, err
|
||||
}
|
||||
|
||||
// ReadRTCP is a convenience method that wraps Read and unmarshal for you
|
||||
func (r *RTPReceiver) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) {
|
||||
func (r *RTPReceiver) readRTCP(ctx context.Context) ([]rtcp.Packet, interceptor.Attributes, error) {
|
||||
b := make([]byte, receiveMTU)
|
||||
i, err := r.Read(b)
|
||||
i, err := r.Read(ctx, b)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -180,9 +185,9 @@ func (r *RTPReceiver) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error)
|
||||
}
|
||||
|
||||
// ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you
|
||||
func (r *RTPReceiver) ReadSimulcastRTCP(rid string) ([]rtcp.Packet, error) {
|
||||
func (r *RTPReceiver) ReadSimulcastRTCP(ctx context.Context, rid string) ([]rtcp.Packet, error) {
|
||||
b := make([]byte, receiveMTU)
|
||||
i, err := r.ReadSimulcast(b, rid)
|
||||
i, err := r.ReadSimulcast(ctx, b, rid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -241,10 +246,10 @@ func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams {
|
||||
}
|
||||
|
||||
// readRTP should only be called by a track, this only exists so we can keep state in one place
|
||||
func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, err error) {
|
||||
func (r *RTPReceiver) readRTP(ctx context.Context, b []byte, reader *TrackRemote) (n int, err error) {
|
||||
<-r.received
|
||||
if t := r.streamsForTrack(reader); t != nil {
|
||||
return t.rtpReadStream.Read(b)
|
||||
return t.rtpReadStream.ReadContext(ctx, b)
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC())
|
||||
|
19
rtpsender.go
19
rtpsender.go
@@ -3,6 +3,7 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
@@ -175,8 +176,8 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error {
|
||||
writeStream.setRTPWriter(
|
||||
r.api.interceptor.BindLocalStream(
|
||||
info,
|
||||
interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) {
|
||||
return r.srtpStream.WriteRTP(&p.Header, p.Payload)
|
||||
interceptor.RTPWriterFunc(func(ctx context.Context, p *rtp.Packet, attributes interceptor.Attributes) (int, error) {
|
||||
return r.srtpStream.WriteRTP(ctx, &p.Header, p.Payload)
|
||||
}),
|
||||
))
|
||||
|
||||
@@ -208,25 +209,27 @@ func (r *RTPSender) Stop() error {
|
||||
}
|
||||
|
||||
// Read reads incoming RTCP for this RTPReceiver
|
||||
func (r *RTPSender) Read(b []byte) (n int, err error) {
|
||||
func (r *RTPSender) Read(ctx context.Context, b []byte) (n int, err error) {
|
||||
select {
|
||||
case <-r.sendCalled:
|
||||
return r.srtpStream.Read(b)
|
||||
return r.srtpStream.ReadContext(ctx, b)
|
||||
case <-r.stopCalled:
|
||||
return 0, io.ErrClosedPipe
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// ReadRTCP is a convenience method that wraps Read and unmarshals for you.
|
||||
// It also runs any configured interceptors.
|
||||
func (r *RTPSender) ReadRTCP() ([]rtcp.Packet, error) {
|
||||
pkts, _, err := r.interceptorRTCPReader.Read()
|
||||
func (r *RTPSender) ReadRTCP(ctx context.Context) ([]rtcp.Packet, error) {
|
||||
pkts, _, err := r.interceptorRTCPReader.Read(ctx)
|
||||
return pkts, err
|
||||
}
|
||||
|
||||
func (r *RTPSender) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) {
|
||||
func (r *RTPSender) readRTCP(ctx context.Context) ([]rtcp.Packet, interceptor.Attributes, error) {
|
||||
b := make([]byte, receiveMTU)
|
||||
i, err := r.Read(b)
|
||||
i, err := r.Read(ctx, b)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@@ -50,7 +50,7 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) {
|
||||
assert.Equal(t, uint64(1), atomic.AddUint64(&onTrackCount, 1))
|
||||
|
||||
for {
|
||||
pkt, err := track.ReadRTP()
|
||||
pkt, err := track.ReadRTP(context.Background())
|
||||
if err != nil {
|
||||
assert.True(t, errors.Is(io.EOF, err))
|
||||
return
|
||||
@@ -74,7 +74,9 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) {
|
||||
case <-seenPacketA.Done():
|
||||
return
|
||||
default:
|
||||
assert.NoError(t, trackA.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second}))
|
||||
assert.NoError(t, trackA.WriteSample(
|
||||
context.Background(), media.Sample{Data: []byte{0xAA}, Duration: time.Second},
|
||||
))
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -88,7 +90,9 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) {
|
||||
case <-seenPacketB.Done():
|
||||
return
|
||||
default:
|
||||
assert.NoError(t, trackB.WriteSample(media.Sample{Data: []byte{0xBB}, Duration: time.Second}))
|
||||
assert.NoError(t, trackB.WriteSample(
|
||||
context.Background(), media.Sample{Data: []byte{0xBB}, Duration: time.Second},
|
||||
))
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -123,7 +127,9 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) {
|
||||
case <-seenPacket.Done():
|
||||
return
|
||||
default:
|
||||
assert.NoError(t, trackA.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second}))
|
||||
assert.NoError(t, trackA.WriteSample(
|
||||
context.Background(), media.Sample{Data: []byte{0xAA}, Duration: time.Second},
|
||||
))
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -134,3 +140,14 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) {
|
||||
assert.NoError(t, receiver.Close())
|
||||
})
|
||||
}
|
||||
|
||||
func Test_RTPSender_ContextCancel(t *testing.T) {
|
||||
sender := &RTPSender{
|
||||
sendCalled: make(chan struct{}),
|
||||
stopCalled: make(chan struct{}),
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err := sender.Read(ctx, []byte{})
|
||||
assert.Equal(t, context.Canceled, err)
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"math"
|
||||
"sync"
|
||||
@@ -90,7 +91,7 @@ func (r *SCTPTransport) GetCapabilities() SCTPCapabilities {
|
||||
// Start the SCTPTransport. Since both local and remote parties must mutually
|
||||
// create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish
|
||||
// a connection over SCTP.
|
||||
func (r *SCTPTransport) Start(remoteCaps SCTPCapabilities) error {
|
||||
func (r *SCTPTransport) Start(ctx context.Context, remoteCaps SCTPCapabilities) error {
|
||||
if r.isStarted {
|
||||
return nil
|
||||
}
|
||||
|
@@ -3,11 +3,12 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/pion/rtp"
|
||||
"github.com/pion/srtp"
|
||||
"github.com/pion/srtp/v2"
|
||||
)
|
||||
|
||||
// srtpWriterFuture blocks Read/Write calls until
|
||||
@@ -18,11 +19,13 @@ type srtpWriterFuture struct {
|
||||
rtpWriteStream atomic.Value // *srtp.WriteStreamSRTP
|
||||
}
|
||||
|
||||
func (s *srtpWriterFuture) init() error {
|
||||
func (s *srtpWriterFuture) init(ctx context.Context) error {
|
||||
select {
|
||||
case <-s.rtpSender.stopCalled:
|
||||
return io.ErrClosedPipe
|
||||
case <-s.rtpSender.transport.srtpReady:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
srtcpSession, err := s.rtpSender.transport.getSRTCPSession()
|
||||
@@ -58,38 +61,38 @@ func (s *srtpWriterFuture) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *srtpWriterFuture) Read(b []byte) (n int, err error) {
|
||||
func (s *srtpWriterFuture) ReadContext(ctx context.Context, b []byte) (n int, err error) {
|
||||
if value := s.rtcpReadStream.Load(); value != nil {
|
||||
return value.(*srtp.ReadStreamSRTCP).Read(b)
|
||||
return value.(*srtp.ReadStreamSRTCP).ReadContext(ctx, b)
|
||||
}
|
||||
|
||||
if err := s.init(); err != nil || s.rtcpReadStream.Load() == nil {
|
||||
if err := s.init(ctx); err != nil || s.rtcpReadStream.Load() == nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return s.Read(b)
|
||||
return s.ReadContext(ctx, b)
|
||||
}
|
||||
|
||||
func (s *srtpWriterFuture) WriteRTP(header *rtp.Header, payload []byte) (int, error) {
|
||||
func (s *srtpWriterFuture) WriteRTP(ctx context.Context, header *rtp.Header, payload []byte) (int, error) {
|
||||
if value := s.rtpWriteStream.Load(); value != nil {
|
||||
return value.(*srtp.WriteStreamSRTP).WriteRTP(header, payload)
|
||||
return value.(*srtp.WriteStreamSRTP).WriteRTP(ctx, header, payload)
|
||||
}
|
||||
|
||||
if err := s.init(); err != nil || s.rtpWriteStream.Load() == nil {
|
||||
if err := s.init(ctx); err != nil || s.rtpWriteStream.Load() == nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return s.WriteRTP(header, payload)
|
||||
return s.WriteRTP(ctx, header, payload)
|
||||
}
|
||||
|
||||
func (s *srtpWriterFuture) Write(b []byte) (int, error) {
|
||||
func (s *srtpWriterFuture) Write(ctx context.Context, b []byte) (int, error) {
|
||||
if value := s.rtpWriteStream.Load(); value != nil {
|
||||
return value.(*srtp.WriteStreamSRTP).Write(b)
|
||||
return value.(*srtp.WriteStreamSRTP).WriteContext(ctx, b)
|
||||
}
|
||||
|
||||
if err := s.init(); err != nil || s.rtpWriteStream.Load() == nil {
|
||||
if err := s.init(ctx); err != nil || s.rtpWriteStream.Load() == nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return s.Write(b)
|
||||
return s.Write(ctx, b)
|
||||
}
|
||||
|
@@ -1,14 +1,18 @@
|
||||
package webrtc
|
||||
|
||||
import "github.com/pion/rtp"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pion/rtp"
|
||||
)
|
||||
|
||||
// TrackLocalWriter is the Writer for outbound RTP Packets
|
||||
type TrackLocalWriter interface {
|
||||
// WriteRTP encrypts a RTP packet and writes to the connection
|
||||
WriteRTP(header *rtp.Header, payload []byte) (int, error)
|
||||
WriteRTP(ctx context.Context, header *rtp.Header, payload []byte) (int, error)
|
||||
|
||||
// Write encrypts and writes a full RTP packet
|
||||
Write(b []byte) (int, error)
|
||||
Write(ctx context.Context, b []byte) (int, error)
|
||||
}
|
||||
|
||||
// TrackLocalContext is the Context passed when a TrackLocal has been Binded/Unbinded from a PeerConnection, and used
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -102,7 +103,7 @@ func (s *TrackLocalStaticRTP) Kind() RTPCodecType {
|
||||
// If one PeerConnection fails the packets will still be sent to
|
||||
// all PeerConnections. The error message will contain the ID of the failed
|
||||
// PeerConnections so you can remove them
|
||||
func (s *TrackLocalStaticRTP) WriteRTP(p *rtp.Packet) error {
|
||||
func (s *TrackLocalStaticRTP) WriteRTP(ctx context.Context, p *rtp.Packet) error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
@@ -112,7 +113,7 @@ func (s *TrackLocalStaticRTP) WriteRTP(p *rtp.Packet) error {
|
||||
for _, b := range s.bindings {
|
||||
outboundPacket.Header.SSRC = uint32(b.ssrc)
|
||||
outboundPacket.Header.PayloadType = uint8(b.payloadType)
|
||||
if _, err := b.writeStream.WriteRTP(&outboundPacket.Header, outboundPacket.Payload); err != nil {
|
||||
if _, err := b.writeStream.WriteRTP(ctx, &outboundPacket.Header, outboundPacket.Payload); err != nil {
|
||||
writeErrs = append(writeErrs, err)
|
||||
}
|
||||
}
|
||||
@@ -124,13 +125,13 @@ func (s *TrackLocalStaticRTP) WriteRTP(p *rtp.Packet) error {
|
||||
// If one PeerConnection fails the packets will still be sent to
|
||||
// all PeerConnections. The error message will contain the ID of the failed
|
||||
// PeerConnections so you can remove them
|
||||
func (s *TrackLocalStaticRTP) Write(b []byte) (n int, err error) {
|
||||
func (s *TrackLocalStaticRTP) Write(ctx context.Context, b []byte) (n int, err error) {
|
||||
packet := &rtp.Packet{}
|
||||
if err = packet.Unmarshal(b); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(b), s.WriteRTP(packet)
|
||||
return len(b), s.WriteRTP(ctx, packet)
|
||||
}
|
||||
|
||||
// TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples.
|
||||
@@ -208,7 +209,7 @@ func (s *TrackLocalStaticSample) Unbind(t TrackLocalContext) error {
|
||||
// If one PeerConnection fails the packets will still be sent to
|
||||
// all PeerConnections. The error message will contain the ID of the failed
|
||||
// PeerConnections so you can remove them
|
||||
func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error {
|
||||
func (s *TrackLocalStaticSample) WriteSample(ctx context.Context, sample media.Sample) error {
|
||||
s.rtpTrack.mu.RLock()
|
||||
p := s.packetizer
|
||||
clockRate := s.clockRate
|
||||
@@ -223,7 +224,7 @@ func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error {
|
||||
|
||||
writeErrs := []error{}
|
||||
for _, p := range packets {
|
||||
if err := s.rtpTrack.WriteRTP(p); err != nil {
|
||||
if err := s.rtpTrack.WriteRTP(ctx, p); err != nil {
|
||||
writeErrs = append(writeErrs, err)
|
||||
}
|
||||
}
|
||||
|
@@ -185,7 +185,7 @@ func Test_TrackLocalStatic_Mutate_Input(t *testing.T) {
|
||||
assert.NoError(t, signalPair(pcOffer, pcAnswer))
|
||||
|
||||
pkt := &rtp.Packet{Header: rtp.Header{SSRC: 1, PayloadType: 1}}
|
||||
assert.NoError(t, vp8Writer.WriteRTP(pkt))
|
||||
assert.NoError(t, vp8Writer.WriteRTP(context.Background(), pkt))
|
||||
|
||||
assert.Equal(t, pkt.Header.SSRC, uint32(1))
|
||||
assert.Equal(t, pkt.Header.PayloadType, uint8(1))
|
||||
|
@@ -3,6 +3,7 @@
|
||||
package webrtc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/interceptor"
|
||||
@@ -125,7 +126,7 @@ func (t *TrackRemote) Codec() RTPCodecParameters {
|
||||
}
|
||||
|
||||
// Read reads data from the track.
|
||||
func (t *TrackRemote) Read(b []byte) (n int, err error) {
|
||||
func (t *TrackRemote) Read(ctx context.Context, b []byte) (n int, err error) {
|
||||
t.mu.RLock()
|
||||
r := t.receiver
|
||||
peeked := t.peeked != nil
|
||||
@@ -144,12 +145,12 @@ func (t *TrackRemote) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
return r.readRTP(b, t)
|
||||
return r.readRTP(ctx, b, t)
|
||||
}
|
||||
|
||||
// peek is like Read, but it doesn't discard the packet read
|
||||
func (t *TrackRemote) peek(b []byte) (n int, err error) {
|
||||
n, err = t.Read(b)
|
||||
func (t *TrackRemote) peek(ctx context.Context, b []byte) (n int, err error) {
|
||||
n, err = t.Read(ctx, b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -167,14 +168,14 @@ func (t *TrackRemote) peek(b []byte) (n int, err error) {
|
||||
|
||||
// ReadRTP is a convenience method that wraps Read and unmarshals for you.
|
||||
// It also runs any configured interceptors.
|
||||
func (t *TrackRemote) ReadRTP() (*rtp.Packet, error) {
|
||||
p, _, err := t.interceptorRTPReader.Read()
|
||||
func (t *TrackRemote) ReadRTP(ctx context.Context) (*rtp.Packet, error) {
|
||||
p, _, err := t.interceptorRTPReader.Read(ctx)
|
||||
return p, err
|
||||
}
|
||||
|
||||
func (t *TrackRemote) readRTP() (*rtp.Packet, interceptor.Attributes, error) {
|
||||
func (t *TrackRemote) readRTP(ctx context.Context) (*rtp.Packet, interceptor.Attributes, error) {
|
||||
b := make([]byte, receiveMTU)
|
||||
i, err := t.Read(b)
|
||||
i, err := t.Read(ctx, b)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -188,9 +189,9 @@ func (t *TrackRemote) readRTP() (*rtp.Packet, interceptor.Attributes, error) {
|
||||
|
||||
// determinePayloadType blocks and reads a single packet to determine the PayloadType for this Track
|
||||
// this is useful because we can't announce it to the user until we know the payloadType
|
||||
func (t *TrackRemote) determinePayloadType() error {
|
||||
func (t *TrackRemote) determinePayloadType(ctx context.Context) error {
|
||||
b := make([]byte, receiveMTU)
|
||||
n, err := t.peek(b)
|
||||
n, err := t.peek(ctx, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
Reference in New Issue
Block a user