mirror of
				https://github.com/pion/webrtc.git
				synced 2025-10-26 08:40:38 +08:00 
			
		
		
		
	Read/Write RTP/RTCP packets with context
Control cancel/timeout by context.
This commit is contained in:
		| @@ -3,6 +3,7 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"io" | 	"io" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -112,19 +113,19 @@ func (s *testORTCStack) setSignal(sig *testORTCSignal, isOffer bool) error { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Start the ICE transport | 	// Start the ICE transport | ||||||
| 	err = s.ice.Start(nil, sig.ICEParameters, &iceRole) | 	err = s.ice.Start(context.Background(), nil, sig.ICEParameters, &iceRole) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Start the DTLS transport | 	// Start the DTLS transport | ||||||
| 	err = s.dtls.Start(sig.DTLSParameters) | 	err = s.dtls.Start(context.Background(), sig.DTLSParameters) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Start the SCTP transport | 	// Start the SCTP transport | ||||||
| 	err = s.sctp.Start(sig.SCTPCapabilities) | 	err = s.sctp.Start(context.Background(), sig.SCTPCapabilities) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"crypto/ecdsa" | 	"crypto/ecdsa" | ||||||
| 	"crypto/elliptic" | 	"crypto/elliptic" | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
| @@ -17,7 +18,7 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/pion/dtls/v2" | 	"github.com/pion/dtls/v2" | ||||||
| 	"github.com/pion/dtls/v2/pkg/crypto/fingerprint" | 	"github.com/pion/dtls/v2/pkg/crypto/fingerprint" | ||||||
| 	"github.com/pion/srtp" | 	"github.com/pion/srtp/v2" | ||||||
| 	"github.com/pion/webrtc/v3/internal/mux" | 	"github.com/pion/webrtc/v3/internal/mux" | ||||||
| 	"github.com/pion/webrtc/v3/internal/util" | 	"github.com/pion/webrtc/v3/internal/util" | ||||||
| 	"github.com/pion/webrtc/v3/pkg/rtcerr" | 	"github.com/pion/webrtc/v3/pkg/rtcerr" | ||||||
| @@ -146,7 +147,7 @@ func (t *DTLSTransport) GetRemoteCertificate() []byte { | |||||||
| 	return t.remoteCertificate | 	return t.remoteCertificate | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *DTLSTransport) startSRTP() error { | func (t *DTLSTransport) startSRTP(ctx context.Context) error { | ||||||
| 	srtpConfig := &srtp.Config{ | 	srtpConfig := &srtp.Config{ | ||||||
| 		Profile:       t.srtpProtectionProfile, | 		Profile:       t.srtpProtectionProfile, | ||||||
| 		LoggerFactory: t.api.settingEngine.LoggerFactory, | 		LoggerFactory: t.api.settingEngine.LoggerFactory, | ||||||
| @@ -185,12 +186,12 @@ func (t *DTLSTransport) startSRTP() error { | |||||||
| 		return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err) | 		return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig) | 	srtpSession, err := srtp.NewSessionSRTP(ctx, t.srtpEndpoint, srtpConfig) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("%w: %v", errFailedToStartSRTP, err) | 		return fmt.Errorf("%w: %v", errFailedToStartSRTP, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig) | 	srtcpSession, err := srtp.NewSessionSRTCP(ctx, t.srtcpEndpoint, srtpConfig) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err) | 		return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err) | ||||||
| 	} | 	} | ||||||
| @@ -244,7 +245,7 @@ func (t *DTLSTransport) role() DTLSRole { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Start DTLS transport negotiation with the parameters of the remote DTLS transport | // Start DTLS transport negotiation with the parameters of the remote DTLS transport | ||||||
| func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { | func (t *DTLSTransport) Start(ctx context.Context, remoteParameters DTLSParameters) error { | ||||||
| 	// Take lock and prepare connection, we must not hold the lock | 	// Take lock and prepare connection, we must not hold the lock | ||||||
| 	// when connecting | 	// when connecting | ||||||
| 	prepareTransport := func() (DTLSRole, *dtls.Config, error) { | 	prepareTransport := func() (DTLSRole, *dtls.Config, error) { | ||||||
| @@ -350,7 +351,7 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return t.startSRTP() | 	return t.startSRTP(ctx) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Stop stops and closes the DTLSTransport object. | // Stop stops and closes the DTLSTransport object. | ||||||
|   | |||||||
| @@ -100,6 +100,7 @@ func TestE2E_Audio(t *testing.T) { | |||||||
| 			go func() { | 			go func() { | ||||||
| 				for { | 				for { | ||||||
| 					if err := track.WriteSample( | 					if err := track.WriteSample( | ||||||
|  | 						context.Background(), | ||||||
| 						media.Sample{Data: silentOpusFrame, Duration: time.Millisecond * 20}, | 						media.Sample{Data: silentOpusFrame, Duration: time.Millisecond * 20}, | ||||||
| 					); err != nil { | 					); err != nil { | ||||||
| 						t.Errorf("Failed to WriteSample: %v", err) | 						t.Errorf("Failed to WriteSample: %v", err) | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| @@ -53,7 +54,9 @@ func main() { // nolint:gocognit | |||||||
| 		go func() { | 		go func() { | ||||||
| 			ticker := time.NewTicker(rtcpPLIInterval) | 			ticker := time.NewTicker(rtcpPLIInterval) | ||||||
| 			for range ticker.C { | 			for range ticker.C { | ||||||
| 				if rtcpSendErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())}}); rtcpSendErr != nil { | 				if rtcpSendErr := peerConnection.WriteRTCP( | ||||||
|  | 					context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())}}, | ||||||
|  | 				); rtcpSendErr != nil { | ||||||
| 					fmt.Println(rtcpSendErr) | 					fmt.Println(rtcpSendErr) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| @@ -68,13 +71,13 @@ func main() { // nolint:gocognit | |||||||
|  |  | ||||||
| 		rtpBuf := make([]byte, 1400) | 		rtpBuf := make([]byte, 1400) | ||||||
| 		for { | 		for { | ||||||
| 			i, readErr := remoteTrack.Read(rtpBuf) | 			i, readErr := remoteTrack.Read(context.TODO(), rtpBuf) | ||||||
| 			if readErr != nil { | 			if readErr != nil { | ||||||
| 				panic(readErr) | 				panic(readErr) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			// ErrClosedPipe means we don't have any subscribers, this is ok if no peers have connected yet | 			// ErrClosedPipe means we don't have any subscribers, this is ok if no peers have connected yet | ||||||
| 			if _, err = localTrack.Write(rtpBuf[:i]); err != nil && !errors.Is(err, io.ErrClosedPipe) { | 			if _, err = localTrack.Write(context.TODO(), rtpBuf[:i]); err != nil && !errors.Is(err, io.ErrClosedPipe) { | ||||||
| 				panic(err) | 				panic(err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -74,7 +74,7 @@ func main() { | |||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			time.Sleep(sleepTime) | 			time.Sleep(sleepTime) | ||||||
| 			if ivfErr = videoTrack.WriteSample(media.Sample{Data: frame, Duration: time.Second}); ivfErr != nil { | 			if ivfErr = videoTrack.WriteSample(context.TODO(), media.Sample{Data: frame, Duration: time.Second}); ivfErr != nil { | ||||||
| 				panic(ivfErr) | 				panic(ivfErr) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"flag" | 	"flag" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -110,7 +111,7 @@ func main() { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Start the ICE transport | 	// Start the ICE transport | ||||||
| 	err = ice.Start(nil, remoteSignal.ICEParameters, &iceRole) | 	err = ice.Start(context.TODO(), nil, remoteSignal.ICEParameters, &iceRole) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		panic(err) | 		panic(err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"flag" | 	"flag" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -111,19 +112,19 @@ func main() { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Start the ICE transport | 	// Start the ICE transport | ||||||
| 	err = ice.Start(nil, remoteSignal.ICEParameters, &iceRole) | 	err = ice.Start(context.TODO(), nil, remoteSignal.ICEParameters, &iceRole) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		panic(err) | 		panic(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Start the DTLS transport | 	// Start the DTLS transport | ||||||
| 	err = dtls.Start(remoteSignal.DTLSParameters) | 	err = dtls.Start(context.TODO(), remoteSignal.DTLSParameters) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		panic(err) | 		panic(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Start the SCTP transport | 	// Start the SCTP transport | ||||||
| 	err = sctp.Start(remoteSignal.SCTPCapabilities) | 	err = sctp.Start(context.TODO(), remoteSignal.SCTPCapabilities) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		panic(err) | 		panic(err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| @@ -145,7 +146,7 @@ func writeVideoToTrack(t *webrtc.TrackLocalStaticSample) { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		time.Sleep(sleepTime) | 		time.Sleep(sleepTime) | ||||||
| 		if err = t.WriteSample(media.Sample{Data: frame, Duration: time.Second}); err != nil { | 		if err = t.WriteSample(context.TODO(), media.Sample{Data: frame, Duration: time.Second}); err != nil { | ||||||
| 			fmt.Printf("Finish writing video track: %s ", err) | 			fmt.Printf("Finish writing video track: %s ", err) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -87,7 +87,7 @@ func main() { | |||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				time.Sleep(sleepTime) | 				time.Sleep(sleepTime) | ||||||
| 				if ivfErr = videoTrack.WriteSample(media.Sample{Data: frame, Duration: time.Second}); ivfErr != nil { | 				if ivfErr = videoTrack.WriteSample(context.TODO(), media.Sample{Data: frame, Duration: time.Second}); ivfErr != nil { | ||||||
| 					panic(ivfErr) | 					panic(ivfErr) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| @@ -138,7 +138,7 @@ func main() { | |||||||
| 				lastGranule = pageHeader.GranulePosition | 				lastGranule = pageHeader.GranulePosition | ||||||
| 				sampleDuration := time.Duration((sampleCount/48000)*1000) * time.Millisecond | 				sampleDuration := time.Duration((sampleCount/48000)*1000) * time.Millisecond | ||||||
|  |  | ||||||
| 				if oggErr = audioTrack.WriteSample(media.Sample{Data: pageData, Duration: sampleDuration}); oggErr != nil { | 				if oggErr = audioTrack.WriteSample(context.TODO(), media.Sample{Data: pageData, Duration: sampleDuration}); oggErr != nil { | ||||||
| 					panic(oggErr) | 					panic(oggErr) | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -72,7 +73,9 @@ func main() { | |||||||
| 		go func() { | 		go func() { | ||||||
| 			ticker := time.NewTicker(time.Second * 3) | 			ticker := time.NewTicker(time.Second * 3) | ||||||
| 			for range ticker.C { | 			for range ticker.C { | ||||||
| 				errSend := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}) | 				errSend := peerConnection.WriteRTCP( | ||||||
|  | 					context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}, | ||||||
|  | 				) | ||||||
| 				if errSend != nil { | 				if errSend != nil { | ||||||
| 					fmt.Println(errSend) | 					fmt.Println(errSend) | ||||||
| 				} | 				} | ||||||
| @@ -82,12 +85,12 @@ func main() { | |||||||
| 		fmt.Printf("Track has started, of type %d: %s \n", track.PayloadType(), track.Codec().MimeType) | 		fmt.Printf("Track has started, of type %d: %s \n", track.PayloadType(), track.Codec().MimeType) | ||||||
| 		for { | 		for { | ||||||
| 			// Read RTP packets being sent to Pion | 			// Read RTP packets being sent to Pion | ||||||
| 			rtp, readErr := track.ReadRTP() | 			rtp, readErr := track.ReadRTP(context.TODO()) | ||||||
| 			if readErr != nil { | 			if readErr != nil { | ||||||
| 				panic(readErr) | 				panic(readErr) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			if writeErr := outputTrack.WriteRTP(rtp); writeErr != nil { | 			if writeErr := outputTrack.WriteRTP(context.TODO(), rtp); writeErr != nil { | ||||||
| 				panic(writeErr) | 				panic(writeErr) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -107,7 +107,9 @@ func main() { | |||||||
| 		go func() { | 		go func() { | ||||||
| 			ticker := time.NewTicker(time.Second * 2) | 			ticker := time.NewTicker(time.Second * 2) | ||||||
| 			for range ticker.C { | 			for range ticker.C { | ||||||
| 				if rtcpErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}); rtcpErr != nil { | 				if rtcpErr := peerConnection.WriteRTCP( | ||||||
|  | 					context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}, | ||||||
|  | 				); rtcpErr != nil { | ||||||
| 					fmt.Println(rtcpErr) | 					fmt.Println(rtcpErr) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| @@ -116,7 +118,7 @@ func main() { | |||||||
| 		b := make([]byte, 1500) | 		b := make([]byte, 1500) | ||||||
| 		for { | 		for { | ||||||
| 			// Read | 			// Read | ||||||
| 			n, readErr := track.Read(b) | 			n, readErr := track.Read(context.TODO(), b) | ||||||
| 			if readErr != nil { | 			if readErr != nil { | ||||||
| 				panic(readErr) | 				panic(readErr) | ||||||
| 			} | 			} | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
|  |  | ||||||
| @@ -103,7 +104,7 @@ func main() { | |||||||
| 			panic(err) | 			panic(err) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if _, writeErr := videoTrack.Write(inboundRTPPacket[:n]); writeErr != nil { | 		if _, writeErr := videoTrack.Write(context.TODO(), inboundRTPPacket[:n]); writeErr != nil { | ||||||
| 			panic(writeErr) | 			panic(writeErr) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"os" | 	"os" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -23,7 +24,7 @@ func saveToDisk(i media.Writer, track *webrtc.TrackRemote) { | |||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	for { | 	for { | ||||||
| 		rtpPacket, err := track.ReadRTP() | 		rtpPacket, err := track.ReadRTP(context.TODO()) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			panic(err) | 			panic(err) | ||||||
| 		} | 		} | ||||||
| @@ -96,7 +97,9 @@ func main() { | |||||||
| 		go func() { | 		go func() { | ||||||
| 			ticker := time.NewTicker(time.Second * 3) | 			ticker := time.NewTicker(time.Second * 3) | ||||||
| 			for range ticker.C { | 			for range ticker.C { | ||||||
| 				errSend := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}) | 				errSend := peerConnection.WriteRTCP( | ||||||
|  | 					context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}, | ||||||
|  | 				) | ||||||
| 				if errSend != nil { | 				if errSend != nil { | ||||||
| 					fmt.Println(errSend) | 					fmt.Println(errSend) | ||||||
| 				} | 				} | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| @@ -80,24 +81,28 @@ func main() { | |||||||
| 			ticker := time.NewTicker(3 * time.Second) | 			ticker := time.NewTicker(3 * time.Second) | ||||||
| 			for range ticker.C { | 			for range ticker.C { | ||||||
| 				fmt.Printf("Sending pli for stream with rid: %q, ssrc: %d\n", track.RID(), track.SSRC()) | 				fmt.Printf("Sending pli for stream with rid: %q, ssrc: %d\n", track.RID(), track.SSRC()) | ||||||
| 				if writeErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}); writeErr != nil { | 				if writeErr := peerConnection.WriteRTCP( | ||||||
|  | 					context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}, | ||||||
|  | 				); writeErr != nil { | ||||||
| 					fmt.Println(writeErr) | 					fmt.Println(writeErr) | ||||||
| 				} | 				} | ||||||
| 				// Send a remb message with a very high bandwidth to trigger chrome to send also the high bitrate stream | 				// Send a remb message with a very high bandwidth to trigger chrome to send also the high bitrate stream | ||||||
| 				fmt.Printf("Sending remb for stream with rid: %q, ssrc: %d\n", track.RID(), track.SSRC()) | 				fmt.Printf("Sending remb for stream with rid: %q, ssrc: %d\n", track.RID(), track.SSRC()) | ||||||
| 				if writeErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.ReceiverEstimatedMaximumBitrate{Bitrate: 10000000, SenderSSRC: uint32(track.SSRC())}}); writeErr != nil { | 				if writeErr := peerConnection.WriteRTCP( | ||||||
|  | 					context.TODO(), []rtcp.Packet{&rtcp.ReceiverEstimatedMaximumBitrate{Bitrate: 10000000, SenderSSRC: uint32(track.SSRC())}}, | ||||||
|  | 				); writeErr != nil { | ||||||
| 					fmt.Println(writeErr) | 					fmt.Println(writeErr) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
| 		for { | 		for { | ||||||
| 			// Read RTP packets being sent to Pion | 			// Read RTP packets being sent to Pion | ||||||
| 			packet, readErr := track.ReadRTP() | 			packet, readErr := track.ReadRTP(context.TODO()) | ||||||
| 			if readErr != nil { | 			if readErr != nil { | ||||||
| 				panic(readErr) | 				panic(readErr) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			if writeErr := outputTracks[rid].WriteRTP(packet); writeErr != nil && !errors.Is(writeErr, io.ErrClosedPipe) { | 			if writeErr := outputTracks[rid].WriteRTP(context.TODO(), packet); writeErr != nil && !errors.Is(writeErr, io.ErrClosedPipe) { | ||||||
| 				panic(writeErr) | 				panic(writeErr) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| @@ -85,7 +86,7 @@ func main() { // nolint:gocognit | |||||||
| 		var isCurrTrack bool | 		var isCurrTrack bool | ||||||
| 		for { | 		for { | ||||||
| 			// Read RTP packets being sent to Pion | 			// Read RTP packets being sent to Pion | ||||||
| 			rtp, readErr := track.ReadRTP() | 			rtp, readErr := track.ReadRTP(context.TODO()) | ||||||
| 			if readErr != nil { | 			if readErr != nil { | ||||||
| 				panic(readErr) | 				panic(readErr) | ||||||
| 			} | 			} | ||||||
| @@ -104,7 +105,9 @@ func main() { // nolint:gocognit | |||||||
| 				// If just switched to this track, send PLI to get picture refresh | 				// If just switched to this track, send PLI to get picture refresh | ||||||
| 				if !isCurrTrack { | 				if !isCurrTrack { | ||||||
| 					isCurrTrack = true | 					isCurrTrack = true | ||||||
| 					if writeErr := peerConnection.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}); writeErr != nil { | 					if writeErr := peerConnection.WriteRTCP( | ||||||
|  | 						context.TODO(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}, | ||||||
|  | 					); writeErr != nil { | ||||||
| 						fmt.Println(writeErr) | 						fmt.Println(writeErr) | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| @@ -154,7 +157,7 @@ func main() { // nolint:gocognit | |||||||
| 			// Keep an increasing sequence number | 			// Keep an increasing sequence number | ||||||
| 			packet.SequenceNumber = i | 			packet.SequenceNumber = i | ||||||
| 			// Write out the packet, ignoring closed pipe if nobody is listening | 			// Write out the packet, ignoring closed pipe if nobody is listening | ||||||
| 			if err := outputTrack.WriteRTP(packet); err != nil && !errors.Is(err, io.ErrClosedPipe) { | 			if err := outputTrack.WriteRTP(context.TODO(), packet); err != nil && !errors.Is(err, io.ErrClosedPipe) { | ||||||
| 				panic(err) | 				panic(err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|   | |||||||
							
								
								
									
										6
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								go.mod
									
									
									
									
									
								
							| @@ -6,7 +6,7 @@ require ( | |||||||
| 	github.com/pion/datachannel v1.4.21 | 	github.com/pion/datachannel v1.4.21 | ||||||
| 	github.com/pion/dtls/v2 v2.0.4 | 	github.com/pion/dtls/v2 v2.0.4 | ||||||
| 	github.com/pion/ice/v2 v2.0.13 | 	github.com/pion/ice/v2 v2.0.13 | ||||||
| 	github.com/pion/interceptor v0.0.3 | 	github.com/pion/interceptor v0.0.4 | ||||||
| 	github.com/pion/logging v0.2.2 | 	github.com/pion/logging v0.2.2 | ||||||
| 	github.com/pion/quic v0.1.4 | 	github.com/pion/quic v0.1.4 | ||||||
| 	github.com/pion/randutil v0.1.0 | 	github.com/pion/randutil v0.1.0 | ||||||
| @@ -14,8 +14,8 @@ require ( | |||||||
| 	github.com/pion/rtp v1.6.1 | 	github.com/pion/rtp v1.6.1 | ||||||
| 	github.com/pion/sctp v1.7.11 | 	github.com/pion/sctp v1.7.11 | ||||||
| 	github.com/pion/sdp/v3 v3.0.3 | 	github.com/pion/sdp/v3 v3.0.3 | ||||||
| 	github.com/pion/srtp v1.5.2 | 	github.com/pion/srtp/v2 v2.0.0-rc.1 | ||||||
| 	github.com/pion/transport v0.11.0 | 	github.com/pion/transport v0.11.1 | ||||||
| 	github.com/sclevine/agouti v3.0.0+incompatible | 	github.com/sclevine/agouti v3.0.0+incompatible | ||||||
| 	github.com/stretchr/testify v1.6.1 | 	github.com/stretchr/testify v1.6.1 | ||||||
| 	golang.org/x/net v0.0.0-20201110031124-69a78807bb2b | 	golang.org/x/net v0.0.0-20201110031124-69a78807bb2b | ||||||
|   | |||||||
							
								
								
									
										10
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								go.sum
									
									
									
									
									
								
							| @@ -106,8 +106,8 @@ github.com/pion/dtls/v2 v2.0.4 h1:WuUcqi6oYMu/noNTz92QrF1DaFj4eXbhQ6dzaaAwOiI= | |||||||
| github.com/pion/dtls/v2 v2.0.4/go.mod h1:qAkFscX0ZHoI1E07RfYPoRw3manThveu+mlTDdOxoGI= | github.com/pion/dtls/v2 v2.0.4/go.mod h1:qAkFscX0ZHoI1E07RfYPoRw3manThveu+mlTDdOxoGI= | ||||||
| github.com/pion/ice/v2 v2.0.13 h1:lVe7g86tQ0vKdH430hQR/t7zV1oeXbK75130TUArrnw= | github.com/pion/ice/v2 v2.0.13 h1:lVe7g86tQ0vKdH430hQR/t7zV1oeXbK75130TUArrnw= | ||||||
| github.com/pion/ice/v2 v2.0.13/go.mod h1:mZlypgoynMn2ayhGsjrPY/G/WiRiYO8WCPC6gUeg1RA= | github.com/pion/ice/v2 v2.0.13/go.mod h1:mZlypgoynMn2ayhGsjrPY/G/WiRiYO8WCPC6gUeg1RA= | ||||||
| github.com/pion/interceptor v0.0.3 h1:VQtmPts/2IgYQtb9sZLTp6B0kIdHE5zBMQ6tCcgdJcM= | github.com/pion/interceptor v0.0.4 h1:FNq8cFDDv0i+Db9oEKccmMx4rerPm6pzBf4szjNaX2E= | ||||||
| github.com/pion/interceptor v0.0.3/go.mod h1:lPVrf5xfosI989ZcmgPS4WwwRhd+XAyTFaYI2wHf7nU= | github.com/pion/interceptor v0.0.4/go.mod h1:lPVrf5xfosI989ZcmgPS4WwwRhd+XAyTFaYI2wHf7nU= | ||||||
| github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= | github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= | ||||||
| github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= | github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= | ||||||
| github.com/pion/mdns v0.0.4 h1:O4vvVqr4DGX63vzmO6Fw9vpy3lfztVWHGCQfyw0ZLSY= | github.com/pion/mdns v0.0.4 h1:O4vvVqr4DGX63vzmO6Fw9vpy3lfztVWHGCQfyw0ZLSY= | ||||||
| @@ -125,8 +125,8 @@ github.com/pion/sctp v1.7.11 h1:UCnj7MsobLKLuP/Hh+JMiI/6W5Bs/VF45lWKgHFjSIE= | |||||||
| github.com/pion/sctp v1.7.11/go.mod h1:EhpTUQu1/lcK3xI+eriS6/96fWetHGCvBi9MSsnaBN0= | github.com/pion/sctp v1.7.11/go.mod h1:EhpTUQu1/lcK3xI+eriS6/96fWetHGCvBi9MSsnaBN0= | ||||||
| github.com/pion/sdp/v3 v3.0.3 h1:gJK9hk+JFD2NGIM1nXmqNCq1DkVaIZ9dlA3u3otnkaw= | github.com/pion/sdp/v3 v3.0.3 h1:gJK9hk+JFD2NGIM1nXmqNCq1DkVaIZ9dlA3u3otnkaw= | ||||||
| github.com/pion/sdp/v3 v3.0.3/go.mod h1:bNiSknmJE0HYBprTHXKPQ3+JjacTv5uap92ueJZKsRk= | github.com/pion/sdp/v3 v3.0.3/go.mod h1:bNiSknmJE0HYBprTHXKPQ3+JjacTv5uap92ueJZKsRk= | ||||||
| github.com/pion/srtp v1.5.2 h1:25DmvH+fqKZDqvX64vTwnycVwL9ooJxHF/gkX16bDBY= | github.com/pion/srtp/v2 v2.0.0-rc.1 h1:UGabuCAIE5Yn5qmFr9H8UGdzYdYIaMt/AZLskcBcumA= | ||||||
| github.com/pion/srtp v1.5.2/go.mod h1:NiBff/MSxUwMUwx/fRNyD/xGE+dVvf8BOCeXhjCXZ9U= | github.com/pion/srtp/v2 v2.0.0-rc.1/go.mod h1:i3Oyh9/RPIdYRXywhs0NwTkMV/XYYr8NGj1E3gsRgk8= | ||||||
| github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg= | github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg= | ||||||
| github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA= | github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA= | ||||||
| github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRhSzR+CZ8= | github.com/pion/transport v0.8.10/go.mod h1:tBmha/UCjpum5hqTWhfAEs3CO4/tHSg0MYRhSzR+CZ8= | ||||||
| @@ -135,6 +135,8 @@ github.com/pion/transport v0.10.1 h1:2W+yJT+0mOQ160ThZYUx5Zp2skzshiNgxrNE9GUfhJM | |||||||
| github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+DwKOzf4A= | github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+DwKOzf4A= | ||||||
| github.com/pion/transport v0.11.0 h1:Z1RhzqrWPPYj5Xed8P7pirTKTvXFoxDI3uJuuKu6akM= | github.com/pion/transport v0.11.0 h1:Z1RhzqrWPPYj5Xed8P7pirTKTvXFoxDI3uJuuKu6akM= | ||||||
| github.com/pion/transport v0.11.0/go.mod h1:ORH8Ouyl1enoJyHwU+MwMeQocWbeorEk5068FOsHjog= | github.com/pion/transport v0.11.0/go.mod h1:ORH8Ouyl1enoJyHwU+MwMeQocWbeorEk5068FOsHjog= | ||||||
|  | github.com/pion/transport v0.11.1 h1:z+6FJ3T7R4pX87efsiFLmIRLYVLLk7XZI76kHkI5V10= | ||||||
|  | github.com/pion/transport v0.11.1/go.mod h1:ORH8Ouyl1enoJyHwU+MwMeQocWbeorEk5068FOsHjog= | ||||||
| github.com/pion/turn/v2 v2.0.5 h1:iwMHqDfPEDEOFzwWKT56eFmh6DYC6o/+xnLAEzgISbA= | github.com/pion/turn/v2 v2.0.5 h1:iwMHqDfPEDEOFzwWKT56eFmh6DYC6o/+xnLAEzgISbA= | ||||||
| github.com/pion/turn/v2 v2.0.5/go.mod h1:APg43CFyt/14Uy7heYUOGWdkem/Wu4PhCO/bjyrTqMw= | github.com/pion/turn/v2 v2.0.5/go.mod h1:APg43CFyt/14Uy7heYUOGWdkem/Wu4PhCO/bjyrTqMw= | ||||||
| github.com/pion/udp v0.1.0 h1:uGxQsNyrqG3GLINv36Ff60covYmfrLoxzwnCsIYspXI= | github.com/pion/udp v0.1.0 h1:uGxQsNyrqG3GLINv36Ff60covYmfrLoxzwnCsIYspXI= | ||||||
|   | |||||||
| @@ -11,6 +11,7 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/pion/ice/v2" | 	"github.com/pion/ice/v2" | ||||||
| 	"github.com/pion/logging" | 	"github.com/pion/logging" | ||||||
|  | 	"github.com/pion/transport/connctx" | ||||||
| 	"github.com/pion/webrtc/v3/internal/mux" | 	"github.com/pion/webrtc/v3/internal/mux" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -70,7 +71,7 @@ func NewICETransport(gatherer *ICEGatherer, loggerFactory logging.LoggerFactory) | |||||||
| } | } | ||||||
|  |  | ||||||
| // Start incoming connectivity checks based on its configured role. | // Start incoming connectivity checks based on its configured role. | ||||||
| func (t *ICETransport) Start(gatherer *ICEGatherer, params ICEParameters, role *ICERole) error { | func (t *ICETransport) Start(ctx context.Context, gatherer *ICEGatherer, params ICEParameters, role *ICERole) error { | ||||||
| 	t.lock.Lock() | 	t.lock.Lock() | ||||||
| 	defer t.lock.Unlock() | 	defer t.lock.Unlock() | ||||||
|  |  | ||||||
| @@ -142,11 +143,11 @@ func (t *ICETransport) Start(gatherer *ICEGatherer, params ICEParameters, role * | |||||||
| 	t.conn = iceConn | 	t.conn = iceConn | ||||||
|  |  | ||||||
| 	config := mux.Config{ | 	config := mux.Config{ | ||||||
| 		Conn:          t.conn, | 		Conn:          connctx.New(t.conn), | ||||||
| 		BufferSize:    receiveMTU, | 		BufferSize:    receiveMTU, | ||||||
| 		LoggerFactory: t.loggerFactory, | 		LoggerFactory: t.loggerFactory, | ||||||
| 	} | 	} | ||||||
| 	t.mux = mux.NewMux(config) | 	t.mux = mux.NewMux(ctx, config) | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"testing" | 	"testing" | ||||||
| @@ -25,19 +26,19 @@ type testInterceptor struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (t *testInterceptor) BindLocalStream(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { | func (t *testInterceptor) BindLocalStream(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { | ||||||
| 	return interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) { | 	return interceptor.RTPWriterFunc(func(ctx context.Context, p *rtp.Packet, attributes interceptor.Attributes) (int, error) { | ||||||
| 		// set extension on outgoing packet | 		// set extension on outgoing packet | ||||||
| 		p.Header.Extension = true | 		p.Header.Extension = true | ||||||
| 		p.Header.ExtensionProfile = 0xBEDE | 		p.Header.ExtensionProfile = 0xBEDE | ||||||
| 		assert.NoError(t.t, p.Header.SetExtension(t.extensionID, []byte("write"))) | 		assert.NoError(t.t, p.Header.SetExtension(t.extensionID, []byte("write"))) | ||||||
|  |  | ||||||
| 		return writer.Write(p, attributes) | 		return writer.Write(ctx, p, attributes) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { | func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { | ||||||
| 	return interceptor.RTPReaderFunc(func() (*rtp.Packet, interceptor.Attributes, error) { | 	return interceptor.RTPReaderFunc(func(ctx context.Context) (*rtp.Packet, interceptor.Attributes, error) { | ||||||
| 		p, attributes, err := reader.Read() | 		p, attributes, err := reader.Read(ctx) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, nil, err | 			return nil, nil, err | ||||||
| 		} | 		} | ||||||
| @@ -49,7 +50,7 @@ func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader | |||||||
| 		// write back a pli | 		// write back a pli | ||||||
| 		rtcpWriter := t.rtcpWriter.Load().(interceptor.RTCPWriter) | 		rtcpWriter := t.rtcpWriter.Load().(interceptor.RTCPWriter) | ||||||
| 		pli := &rtcp.PictureLossIndication{SenderSSRC: info.SSRC, MediaSSRC: info.SSRC} | 		pli := &rtcp.PictureLossIndication{SenderSSRC: info.SSRC, MediaSSRC: info.SSRC} | ||||||
| 		_, err = rtcpWriter.Write([]rtcp.Packet{pli}, make(interceptor.Attributes)) | 		_, err = rtcpWriter.Write(ctx, []rtcp.Packet{pli}, make(interceptor.Attributes)) | ||||||
| 		assert.NoError(t.t, err) | 		assert.NoError(t.t, err) | ||||||
|  |  | ||||||
| 		return p, attributes, nil | 		return p, attributes, nil | ||||||
| @@ -57,8 +58,8 @@ func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader | |||||||
| } | } | ||||||
|  |  | ||||||
| func (t *testInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { | func (t *testInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { | ||||||
| 	return interceptor.RTCPReaderFunc(func() ([]rtcp.Packet, interceptor.Attributes, error) { | 	return interceptor.RTCPReaderFunc(func(ctx context.Context) ([]rtcp.Packet, interceptor.Attributes, error) { | ||||||
| 		pkts, attributes, err := reader.Read() | 		pkts, attributes, err := reader.Read(ctx) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, nil, err | 			return nil, nil, err | ||||||
| 		} | 		} | ||||||
| @@ -122,7 +123,7 @@ func TestPeerConnection_Interceptor(t *testing.T) { | |||||||
| 	wg.Add(1) | 	wg.Add(1) | ||||||
| 	*pending++ | 	*pending++ | ||||||
| 	receiverPC.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) { | 	receiverPC.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) { | ||||||
| 		p, readErr := track.ReadRTP() | 		p, readErr := track.ReadRTP(context.Background()) | ||||||
| 		if readErr != nil { | 		if readErr != nil { | ||||||
| 			t.Fatal(readErr) | 			t.Fatal(readErr) | ||||||
| 		} | 		} | ||||||
| @@ -133,7 +134,7 @@ func TestPeerConnection_Interceptor(t *testing.T) { | |||||||
| 		wg.Done() | 		wg.Done() | ||||||
|  |  | ||||||
| 		for { | 		for { | ||||||
| 			_, readErr = track.ReadRTP() | 			_, readErr = track.ReadRTP(context.Background()) | ||||||
| 			if readErr != nil { | 			if readErr != nil { | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| @@ -143,13 +144,13 @@ func TestPeerConnection_Interceptor(t *testing.T) { | |||||||
| 	wg.Add(1) | 	wg.Add(1) | ||||||
| 	*pending++ | 	*pending++ | ||||||
| 	go func() { | 	go func() { | ||||||
| 		_, readErr := sender.ReadRTCP() | 		_, readErr := sender.ReadRTCP(context.Background()) | ||||||
| 		assert.NoError(t, readErr) | 		assert.NoError(t, readErr) | ||||||
| 		atomic.AddInt32(pending, -1) | 		atomic.AddInt32(pending, -1) | ||||||
| 		wg.Done() | 		wg.Done() | ||||||
|  |  | ||||||
| 		for { | 		for { | ||||||
| 			_, readErr = sender.ReadRTCP() | 			_, readErr = sender.ReadRTCP(context.Background()) | ||||||
| 			if readErr != nil { | 			if readErr != nil { | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| @@ -166,7 +167,9 @@ func TestPeerConnection_Interceptor(t *testing.T) { | |||||||
| 		defer wg.Done() | 		defer wg.Done() | ||||||
| 		for { | 		for { | ||||||
| 			time.Sleep(time.Millisecond * 100) | 			time.Sleep(time.Millisecond * 100) | ||||||
| 			if routineErr := track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}); routineErr != nil { | 			if routineErr := track.WriteSample( | ||||||
|  | 				context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second}, | ||||||
|  | 			); routineErr != nil { | ||||||
| 				t.Error(routineErr) | 				t.Error(routineErr) | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
|  |  | ||||||
| 	"github.com/pion/interceptor" | 	"github.com/pion/interceptor" | ||||||
| @@ -18,12 +19,12 @@ func (i *interceptorTrackLocalWriter) setRTPWriter(writer interceptor.RTPWriter) | |||||||
| 	i.rtpWriter.Store(writer) | 	i.rtpWriter.Store(writer) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (i *interceptorTrackLocalWriter) WriteRTP(header *rtp.Header, payload []byte) (int, error) { | func (i *interceptorTrackLocalWriter) WriteRTP(ctx context.Context, header *rtp.Header, payload []byte) (int, error) { | ||||||
| 	writer := i.rtpWriter.Load().(interceptor.RTPWriter) | 	writer := i.rtpWriter.Load().(interceptor.RTPWriter) | ||||||
|  |  | ||||||
| 	if writer == nil { | 	if writer == nil { | ||||||
| 		return 0, nil | 		return 0, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return writer.Write(&rtp.Packet{Header: *header, Payload: payload}, make(interceptor.Attributes)) | 	return writer.Write(ctx, &rtp.Packet{Header: *header, Payload: payload}, make(interceptor.Attributes)) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package mux | package mux | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net" | 	"net" | ||||||
| @@ -34,12 +35,23 @@ func (e *Endpoint) close() error { | |||||||
| // Read reads a packet of len(p) bytes from the underlying conn | // Read reads a packet of len(p) bytes from the underlying conn | ||||||
| // that are matched by the associated MuxFunc | // that are matched by the associated MuxFunc | ||||||
| func (e *Endpoint) Read(p []byte) (int, error) { | func (e *Endpoint) Read(p []byte) (int, error) { | ||||||
| 	return e.buffer.Read(p) | 	return e.buffer.ReadContext(context.Background(), p) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ReadContext reads a packet of len(p) bytes from the underlying conn | ||||||
|  | // that are matched by the associated MuxFunc | ||||||
|  | func (e *Endpoint) ReadContext(ctx context.Context, p []byte) (int, error) { | ||||||
|  | 	return e.buffer.ReadContext(ctx, p) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Write writes len(p) bytes to the underlying conn | // Write writes len(p) bytes to the underlying conn | ||||||
| func (e *Endpoint) Write(p []byte) (int, error) { | func (e *Endpoint) Write(p []byte) (int, error) { | ||||||
| 	n, err := e.mux.nextConn.Write(p) | 	return e.WriteContext(context.Background(), p) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // WriteContext writes len(p) bytes to the underlying conn | ||||||
|  | func (e *Endpoint) WriteContext(ctx context.Context, p []byte) (int, error) { | ||||||
|  | 	n, err := e.mux.nextConn.WriteContext(ctx, p) | ||||||
| 	if errors.Is(err, ice.ErrNoCandidatePairs) { | 	if errors.Is(err, ice.ErrNoCandidatePairs) { | ||||||
| 		return 0, nil | 		return 0, nil | ||||||
| 	} else if errors.Is(err, ice.ErrClosed) { | 	} else if errors.Is(err, ice.ErrClosed) { | ||||||
|   | |||||||
| @@ -2,10 +2,11 @@ | |||||||
| package mux | package mux | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"net" | 	"context" | ||||||
| 	"sync" | 	"sync" | ||||||
|  |  | ||||||
| 	"github.com/pion/logging" | 	"github.com/pion/logging" | ||||||
|  | 	"github.com/pion/transport/connctx" | ||||||
| 	"github.com/pion/transport/packetio" | 	"github.com/pion/transport/packetio" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -15,7 +16,7 @@ const maxBufferSize = 1000 * 1000 // 1MB | |||||||
| // Config collects the arguments to mux.Mux construction into | // Config collects the arguments to mux.Mux construction into | ||||||
| // a single structure | // a single structure | ||||||
| type Config struct { | type Config struct { | ||||||
| 	Conn          net.Conn | 	Conn          connctx.ConnCtx | ||||||
| 	BufferSize    int | 	BufferSize    int | ||||||
| 	LoggerFactory logging.LoggerFactory | 	LoggerFactory logging.LoggerFactory | ||||||
| } | } | ||||||
| @@ -23,7 +24,7 @@ type Config struct { | |||||||
| // Mux allows multiplexing | // Mux allows multiplexing | ||||||
| type Mux struct { | type Mux struct { | ||||||
| 	lock       sync.RWMutex | 	lock       sync.RWMutex | ||||||
| 	nextConn   net.Conn | 	nextConn   connctx.ConnCtx | ||||||
| 	endpoints  map[*Endpoint]MatchFunc | 	endpoints  map[*Endpoint]MatchFunc | ||||||
| 	bufferSize int | 	bufferSize int | ||||||
| 	closedCh   chan struct{} | 	closedCh   chan struct{} | ||||||
| @@ -32,7 +33,7 @@ type Mux struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| // NewMux creates a new Mux | // NewMux creates a new Mux | ||||||
| func NewMux(config Config) *Mux { | func NewMux(ctx context.Context, config Config) *Mux { | ||||||
| 	m := &Mux{ | 	m := &Mux{ | ||||||
| 		nextConn:   config.Conn, | 		nextConn:   config.Conn, | ||||||
| 		endpoints:  make(map[*Endpoint]MatchFunc), | 		endpoints:  make(map[*Endpoint]MatchFunc), | ||||||
| @@ -41,7 +42,7 @@ func NewMux(config Config) *Mux { | |||||||
| 		log:        config.LoggerFactory.NewLogger("mux"), | 		log:        config.LoggerFactory.NewLogger("mux"), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	go m.readLoop() | 	go m.readLoop(ctx) | ||||||
|  |  | ||||||
| 	return m | 	return m | ||||||
| } | } | ||||||
| @@ -96,26 +97,26 @@ func (m *Mux) Close() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (m *Mux) readLoop() { | func (m *Mux) readLoop(ctx context.Context) { | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		close(m.closedCh) | 		close(m.closedCh) | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	buf := make([]byte, m.bufferSize) | 	buf := make([]byte, m.bufferSize) | ||||||
| 	for { | 	for { | ||||||
| 		n, err := m.nextConn.Read(buf) | 		n, err := m.nextConn.ReadContext(ctx, buf) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		err = m.dispatch(buf[:n]) | 		err = m.dispatch(ctx, buf[:n]) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (m *Mux) dispatch(buf []byte) error { | func (m *Mux) dispatch(ctx context.Context, buf []byte) error { | ||||||
| 	var endpoint *Endpoint | 	var endpoint *Endpoint | ||||||
|  |  | ||||||
| 	m.lock.Lock() | 	m.lock.Lock() | ||||||
| @@ -136,7 +137,7 @@ func (m *Mux) dispatch(buf []byte) error { | |||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	_, err := endpoint.buffer.Write(buf) | 	_, err := endpoint.buffer.WriteContext(ctx, buf) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -1,29 +1,33 @@ | |||||||
| package mux | package mux | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"net" | 	"net" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/pion/logging" | 	"github.com/pion/logging" | ||||||
|  | 	"github.com/pion/transport/connctx" | ||||||
| 	"github.com/pion/transport/test" | 	"github.com/pion/transport/test" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestStressDuplex(t *testing.T) { | func TestStressDuplex(t *testing.T) { | ||||||
| 	// Limit runtime in case of deadlocks | 	lim := test.TimeOut(time.Second * 30) | ||||||
| 	lim := test.TimeOut(time.Second * 20) |  | ||||||
| 	defer lim.Stop() | 	defer lim.Stop() | ||||||
|  |  | ||||||
|  | 	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  |  | ||||||
| 	// Check for leaking routines | 	// Check for leaking routines | ||||||
| 	report := test.CheckRoutines(t) | 	report := test.CheckRoutines(t) | ||||||
| 	defer report() | 	defer report() | ||||||
|  |  | ||||||
| 	// Run the test | 	// Run the test | ||||||
| 	stressDuplex(t) | 	stressDuplex(ctx, t) | ||||||
| } | } | ||||||
|  |  | ||||||
| func stressDuplex(t *testing.T) { | func stressDuplex(ctx context.Context, t *testing.T) { | ||||||
| 	ca, cb, stop := pipeMemory() | 	ca, cb, stop := pipeMemory(ctx) | ||||||
|  |  | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		stop(t) | 		stop(t) | ||||||
| @@ -34,13 +38,21 @@ func stressDuplex(t *testing.T) { | |||||||
| 		MsgCount: 100, | 		MsgCount: 100, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err := test.StressDuplex(ca, cb, opt) | 	t.Run("WithoutContext", func(t *testing.T) { | ||||||
| 	if err != nil { | 		err := test.StressDuplex(ca, cb.Conn(), opt) | ||||||
| 		t.Fatal(err) | 		if err != nil { | ||||||
| 	} | 			t.Fatal(err) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	t.Run("WithContext", func(t *testing.T) { | ||||||
|  | 		err := test.StressDuplexContext(context.Background(), ca, cb, opt) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func pipeMemory() (*Endpoint, net.Conn, func(*testing.T)) { | func pipeMemory(ctx context.Context) (*Endpoint, connctx.ConnCtx, func(*testing.T)) { | ||||||
| 	// In memory pipe | 	// In memory pipe | ||||||
| 	ca, cb := net.Pipe() | 	ca, cb := net.Pipe() | ||||||
|  |  | ||||||
| @@ -49,12 +61,12 @@ func pipeMemory() (*Endpoint, net.Conn, func(*testing.T)) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	config := Config{ | 	config := Config{ | ||||||
| 		Conn:          ca, | 		Conn:          connctx.New(ca), | ||||||
| 		BufferSize:    8192, | 		BufferSize:    8192, | ||||||
| 		LoggerFactory: logging.NewDefaultLoggerFactory(), | 		LoggerFactory: logging.NewDefaultLoggerFactory(), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	m := NewMux(config) | 	m := NewMux(ctx, config) | ||||||
| 	e := m.NewEndpoint(matchAll) | 	e := m.NewEndpoint(matchAll) | ||||||
| 	m.RemoveEndpoint(e) | 	m.RemoveEndpoint(e) | ||||||
| 	e = m.NewEndpoint(matchAll) | 	e = m.NewEndpoint(matchAll) | ||||||
| @@ -70,7 +82,7 @@ func pipeMemory() (*Endpoint, net.Conn, func(*testing.T)) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return e, cb, stop | 	return e, connctx.New(cb), stop | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestNoEndpoints(t *testing.T) { | func TestNoEndpoints(t *testing.T) { | ||||||
| @@ -82,13 +94,16 @@ func TestNoEndpoints(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	config := Config{ | 	config := Config{ | ||||||
| 		Conn:          ca, | 		Conn:          connctx.New(ca), | ||||||
| 		BufferSize:    8192, | 		BufferSize:    8192, | ||||||
| 		LoggerFactory: logging.NewDefaultLoggerFactory(), | 		LoggerFactory: logging.NewDefaultLoggerFactory(), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	m := NewMux(config) | 	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) | ||||||
| 	err = m.dispatch(make([]byte, 1)) | 	defer cancel() | ||||||
|  |  | ||||||
|  | 	m := NewMux(ctx, config) | ||||||
|  | 	err = m.dispatch(ctx, make([]byte, 1)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -1,11 +1,12 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"sync" | 	"sync" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Operation is a function | // Operation is a function | ||||||
| type operation func() | type operation func(ctx context.Context) | ||||||
|  |  | ||||||
| // Operations is a task executor. | // Operations is a task executor. | ||||||
| type operations struct { | type operations struct { | ||||||
| @@ -32,7 +33,7 @@ func (o *operations) Enqueue(op operation) { | |||||||
| 	o.mu.Unlock() | 	o.mu.Unlock() | ||||||
|  |  | ||||||
| 	if !running { | 	if !running { | ||||||
| 		go o.start() | 		go o.start(context.Background()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -48,13 +49,13 @@ func (o *operations) IsEmpty() bool { | |||||||
| func (o *operations) Done() { | func (o *operations) Done() { | ||||||
| 	var wg sync.WaitGroup | 	var wg sync.WaitGroup | ||||||
| 	wg.Add(1) | 	wg.Add(1) | ||||||
| 	o.Enqueue(func() { | 	o.Enqueue(func(_ context.Context) { | ||||||
| 		wg.Done() | 		wg.Done() | ||||||
| 	}) | 	}) | ||||||
| 	wg.Wait() | 	wg.Wait() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (o *operations) pop() func() { | func (o *operations) pop() func(context.Context) { | ||||||
| 	o.mu.Lock() | 	o.mu.Lock() | ||||||
| 	defer o.mu.Unlock() | 	defer o.mu.Unlock() | ||||||
| 	if len(o.ops) == 0 { | 	if len(o.ops) == 0 { | ||||||
| @@ -66,7 +67,7 @@ func (o *operations) pop() func() { | |||||||
| 	return fn | 	return fn | ||||||
| } | } | ||||||
|  |  | ||||||
| func (o *operations) start() { | func (o *operations) start(ctx context.Context) { | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		o.mu.Lock() | 		o.mu.Lock() | ||||||
| 		defer o.mu.Unlock() | 		defer o.mu.Unlock() | ||||||
| @@ -76,12 +77,12 @@ func (o *operations) start() { | |||||||
| 		} | 		} | ||||||
| 		// either a new operation was enqueued while we | 		// either a new operation was enqueued while we | ||||||
| 		// were busy, or an operation panicked | 		// were busy, or an operation panicked | ||||||
| 		go o.start() | 		go o.start(ctx) | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	fn := o.pop() | 	fn := o.pop() | ||||||
| 	for fn != nil { | 	for fn != nil { | ||||||
| 		fn() | 		fn(ctx) | ||||||
| 		fn = o.pop() | 		fn = o.pop() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| @@ -12,7 +13,7 @@ func TestOperations_Enqueue(t *testing.T) { | |||||||
| 		results := make([]int, 16) | 		results := make([]int, 16) | ||||||
| 		for i := range results { | 		for i := range results { | ||||||
| 			func(j int) { | 			func(j int) { | ||||||
| 				ops.Enqueue(func() { | 				ops.Enqueue(func(_ context.Context) { | ||||||
| 					results[j] = j * j | 					results[j] = j * j | ||||||
| 				}) | 				}) | ||||||
| 			}(i) | 			}(i) | ||||||
|   | |||||||
| @@ -3,11 +3,11 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"crypto/ecdsa" | 	"crypto/ecdsa" | ||||||
| 	"crypto/elliptic" | 	"crypto/elliptic" | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" |  | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| @@ -19,6 +19,7 @@ import ( | |||||||
| 	"github.com/pion/logging" | 	"github.com/pion/logging" | ||||||
| 	"github.com/pion/rtcp" | 	"github.com/pion/rtcp" | ||||||
| 	"github.com/pion/sdp/v3" | 	"github.com/pion/sdp/v3" | ||||||
|  | 	"github.com/pion/transport/connctx" | ||||||
| 	"github.com/pion/webrtc/v3/internal/util" | 	"github.com/pion/webrtc/v3/internal/util" | ||||||
| 	"github.com/pion/webrtc/v3/pkg/rtcerr" | 	"github.com/pion/webrtc/v3/pkg/rtcerr" | ||||||
| ) | ) | ||||||
| @@ -275,7 +276,7 @@ func (pc *PeerConnection) onNegotiationNeeded() { | |||||||
| 	pc.ops.Enqueue(pc.negotiationNeededOp) | 	pc.ops.Enqueue(pc.negotiationNeededOp) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (pc *PeerConnection) negotiationNeededOp() { | func (pc *PeerConnection) negotiationNeededOp(ctx context.Context) { | ||||||
| 	// Don't run NegotiatedNeeded checks if OnNegotiationNeeded is not set | 	// Don't run NegotiatedNeeded checks if OnNegotiationNeeded is not set | ||||||
| 	if handler := pc.onNegotiationNeededHandler.Load(); handler == nil { | 	if handler := pc.onNegotiationNeededHandler.Load(); handler == nil { | ||||||
| 		return | 		return | ||||||
| @@ -945,8 +946,8 @@ func (pc *PeerConnection) SetLocalDescription(desc SessionDescription) error { | |||||||
| 		if err := pc.startRTPSenders(currentTransceivers); err != nil { | 		if err := pc.startRTPSenders(currentTransceivers); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		pc.ops.Enqueue(func() { | 		pc.ops.Enqueue(func(ctx context.Context) { | ||||||
| 			pc.startRTP(haveLocalDescription, remoteDesc, currentTransceivers) | 			pc.startRTP(ctx, haveLocalDescription, remoteDesc, currentTransceivers) | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -1066,8 +1067,8 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error { | |||||||
| 			if err = pc.startRTPSenders(currentTransceivers); err != nil { | 			if err = pc.startRTPSenders(currentTransceivers); err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 			pc.ops.Enqueue(func() { | 			pc.ops.Enqueue(func(ctx context.Context) { | ||||||
| 				pc.startRTP(true, &desc, currentTransceivers) | 				pc.startRTP(ctx, true, &desc, currentTransceivers) | ||||||
| 			}) | 			}) | ||||||
| 		} | 		} | ||||||
| 		return nil | 		return nil | ||||||
| @@ -1099,16 +1100,16 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	pc.ops.Enqueue(func() { | 	pc.ops.Enqueue(func(ctx context.Context) { | ||||||
| 		pc.startTransports(iceRole, dtlsRoleFromRemoteSDP(desc.parsed), remoteUfrag, remotePwd, fingerprint, fingerprintHash) | 		pc.startTransports(ctx, iceRole, dtlsRoleFromRemoteSDP(desc.parsed), remoteUfrag, remotePwd, fingerprint, fingerprintHash) | ||||||
| 		if weOffer { | 		if weOffer { | ||||||
| 			pc.startRTP(false, &desc, currentTransceivers) | 			pc.startRTP(ctx, false, &desc, currentTransceivers) | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPReceiver) { | func (pc *PeerConnection) startReceiver(ctx context.Context, incoming trackDetails, receiver *RTPReceiver) { | ||||||
| 	encodings := []RTPDecodingParameters{} | 	encodings := []RTPDecodingParameters{} | ||||||
| 	if incoming.ssrc != 0 { | 	if incoming.ssrc != 0 { | ||||||
| 		encodings = append(encodings, RTPDecodingParameters{RTPCodingParameters{SSRC: incoming.ssrc}}) | 		encodings = append(encodings, RTPDecodingParameters{RTPCodingParameters{SSRC: incoming.ssrc}}) | ||||||
| @@ -1137,7 +1138,7 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	go func() { | 	go func() { | ||||||
| 		if err := receiver.Track().determinePayloadType(); err != nil { | 		if err := receiver.Track().determinePayloadType(ctx); err != nil { | ||||||
| 			pc.log.Warnf("Could not determine PayloadType for SSRC %d", receiver.Track().SSRC()) | 			pc.log.Warnf("Could not determine PayloadType for SSRC %d", receiver.Track().SSRC()) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @@ -1160,7 +1161,7 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece | |||||||
| } | } | ||||||
|  |  | ||||||
| // startRTPReceivers opens knows inbound SRTP streams from the RemoteDescription | // startRTPReceivers opens knows inbound SRTP streams from the RemoteDescription | ||||||
| func (pc *PeerConnection) startRTPReceivers(incomingTracks []trackDetails, currentTransceivers []*RTPTransceiver) { //nolint:gocognit | func (pc *PeerConnection) startRTPReceivers(ctx context.Context, incomingTracks []trackDetails, currentTransceivers []*RTPTransceiver) { //nolint:gocognit | ||||||
| 	localTransceivers := append([]*RTPTransceiver{}, currentTransceivers...) | 	localTransceivers := append([]*RTPTransceiver{}, currentTransceivers...) | ||||||
|  |  | ||||||
| 	remoteIsPlanB := false | 	remoteIsPlanB := false | ||||||
| @@ -1207,7 +1208,7 @@ func (pc *PeerConnection) startRTPReceivers(incomingTracks []trackDetails, curre | |||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			pc.startReceiver(incomingTrack, t.Receiver()) | 			pc.startReceiver(ctx, incomingTrack, t.Receiver()) | ||||||
| 			trackHandled = true | 			trackHandled = true | ||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
| @@ -1226,7 +1227,7 @@ func (pc *PeerConnection) startRTPReceivers(incomingTracks []trackDetails, curre | |||||||
| 				pc.log.Warnf("Could not add transceiver for remote SSRC %d: %s", incoming.ssrc, err) | 				pc.log.Warnf("Could not add transceiver for remote SSRC %d: %s", incoming.ssrc, err) | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 			pc.startReceiver(incoming, t.Receiver()) | 			pc.startReceiver(ctx, incoming, t.Receiver()) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -1253,9 +1254,9 @@ func (pc *PeerConnection) startRTPSenders(currentTransceivers []*RTPTransceiver) | |||||||
| } | } | ||||||
|  |  | ||||||
| // Start SCTP subsystem | // Start SCTP subsystem | ||||||
| func (pc *PeerConnection) startSCTP() { | func (pc *PeerConnection) startSCTP(ctx context.Context) { | ||||||
| 	// Start sctp | 	// Start sctp | ||||||
| 	if err := pc.sctpTransport.Start(SCTPCapabilities{ | 	if err := pc.sctpTransport.Start(ctx, SCTPCapabilities{ | ||||||
| 		MaxMessageSize: 0, | 		MaxMessageSize: 0, | ||||||
| 	}); err != nil { | 	}); err != nil { | ||||||
| 		pc.log.Warnf("Failed to start SCTP: %s", err) | 		pc.log.Warnf("Failed to start SCTP: %s", err) | ||||||
| @@ -1289,7 +1290,7 @@ func (pc *PeerConnection) startSCTP() { | |||||||
| 	pc.sctpTransport.lock.Unlock() | 	pc.sctpTransport.lock.Unlock() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (pc *PeerConnection) handleUndeclaredSSRC(rtpStream io.Reader, ssrc SSRC) error { //nolint:gocognit | func (pc *PeerConnection) handleUndeclaredSSRC(ctx context.Context, rtpStream connctx.Reader, ssrc SSRC) error { //nolint:gocognit | ||||||
| 	remoteDescription := pc.RemoteDescription() | 	remoteDescription := pc.RemoteDescription() | ||||||
| 	if remoteDescription == nil { | 	if remoteDescription == nil { | ||||||
| 		return errPeerConnRemoteDescriptionNil | 		return errPeerConnRemoteDescriptionNil | ||||||
| @@ -1318,7 +1319,7 @@ func (pc *PeerConnection) handleUndeclaredSSRC(rtpStream io.Reader, ssrc SSRC) e | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return fmt.Errorf("%w: %d: %s", errPeerConnRemoteSSRCAddTransceiver, ssrc, err) | 			return fmt.Errorf("%w: %d: %s", errPeerConnRemoteSSRCAddTransceiver, ssrc, err) | ||||||
| 		} | 		} | ||||||
| 		pc.startReceiver(incoming, t.Receiver()) | 		pc.startReceiver(ctx, incoming, t.Receiver()) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -1335,7 +1336,7 @@ func (pc *PeerConnection) handleUndeclaredSSRC(rtpStream io.Reader, ssrc SSRC) e | |||||||
| 	b := make([]byte, receiveMTU) | 	b := make([]byte, receiveMTU) | ||||||
| 	var mid, rid string | 	var mid, rid string | ||||||
| 	for readCount := 0; readCount <= simulcastProbeCount; readCount++ { | 	for readCount := 0; readCount <= simulcastProbeCount; readCount++ { | ||||||
| 		i, err := rtpStream.Read(b) | 		i, err := rtpStream.ReadContext(ctx, b) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| @@ -1379,7 +1380,7 @@ func (pc *PeerConnection) handleUndeclaredSSRC(rtpStream io.Reader, ssrc SSRC) e | |||||||
| } | } | ||||||
|  |  | ||||||
| // undeclaredMediaProcessor handles RTP/RTCP packets that don't match any a:ssrc lines | // undeclaredMediaProcessor handles RTP/RTCP packets that don't match any a:ssrc lines | ||||||
| func (pc *PeerConnection) undeclaredMediaProcessor() { | func (pc *PeerConnection) undeclaredMediaProcessor(ctx context.Context) { | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for { | 		for { | ||||||
| 			srtpSession, err := pc.dtlsTransport.getSRTPSession() | 			srtpSession, err := pc.dtlsTransport.getSRTPSession() | ||||||
| @@ -1394,7 +1395,7 @@ func (pc *PeerConnection) undeclaredMediaProcessor() { | |||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			if err := pc.handleUndeclaredSSRC(stream, SSRC(ssrc)); err != nil { | 			if err := pc.handleUndeclaredSSRC(ctx, stream, SSRC(ssrc)); err != nil { | ||||||
| 				pc.log.Errorf("Incoming unhandled RTP ssrc(%d), OnTrack will not be fired. %v", ssrc, err) | 				pc.log.Errorf("Incoming unhandled RTP ssrc(%d), OnTrack will not be fired. %v", ssrc, err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| @@ -1753,12 +1754,12 @@ func (pc *PeerConnection) SetIdentityProvider(provider string) error { | |||||||
|  |  | ||||||
| // WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the | // WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the | ||||||
| // packet is discarded. It also runs any configured interceptors. | // packet is discarded. It also runs any configured interceptors. | ||||||
| func (pc *PeerConnection) WriteRTCP(pkts []rtcp.Packet) error { | func (pc *PeerConnection) WriteRTCP(ctx context.Context, pkts []rtcp.Packet) error { | ||||||
| 	_, err := pc.interceptorRTCPWriter.Write(pkts, make(interceptor.Attributes)) | 	_, err := pc.interceptorRTCPWriter.Write(ctx, pkts, make(interceptor.Attributes)) | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func (pc *PeerConnection) writeRTCP(pkts []rtcp.Packet, _ interceptor.Attributes) (int, error) { | func (pc *PeerConnection) writeRTCP(ctx context.Context, pkts []rtcp.Packet, _ interceptor.Attributes) (int, error) { | ||||||
| 	raw, err := rtcp.Marshal(pkts) | 	raw, err := rtcp.Marshal(pkts) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return 0, err | 		return 0, err | ||||||
| @@ -1774,7 +1775,7 @@ func (pc *PeerConnection) writeRTCP(pkts []rtcp.Packet, _ interceptor.Attributes | |||||||
| 		return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err) | 		return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if n, err := writeStream.Write(raw); err != nil { | 	if n, err := writeStream.WriteContext(ctx, raw); err != nil { | ||||||
| 		return n, err | 		return n, err | ||||||
| 	} | 	} | ||||||
| 	return 0, nil | 	return 0, nil | ||||||
| @@ -1985,9 +1986,10 @@ func (pc *PeerConnection) GetStats() StatsReport { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Start all transports. PeerConnection now has enough state | // Start all transports. PeerConnection now has enough state | ||||||
| func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, remoteUfrag, remotePwd, fingerprint, fingerprintHash string) { | func (pc *PeerConnection) startTransports(ctx context.Context, iceRole ICERole, dtlsRole DTLSRole, remoteUfrag, remotePwd, fingerprint, fingerprintHash string) { | ||||||
| 	// Start the ice transport | 	// Start the ice transport | ||||||
| 	err := pc.iceTransport.Start( | 	err := pc.iceTransport.Start( | ||||||
|  | 		ctx, | ||||||
| 		pc.iceGatherer, | 		pc.iceGatherer, | ||||||
| 		ICEParameters{ | 		ICEParameters{ | ||||||
| 			UsernameFragment: remoteUfrag, | 			UsernameFragment: remoteUfrag, | ||||||
| @@ -2002,7 +2004,7 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Start the dtls transport | 	// Start the dtls transport | ||||||
| 	err = pc.dtlsTransport.Start(DTLSParameters{ | 	err = pc.dtlsTransport.Start(ctx, DTLSParameters{ | ||||||
| 		Role:         dtlsRole, | 		Role:         dtlsRole, | ||||||
| 		Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}}, | 		Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}}, | ||||||
| 	}) | 	}) | ||||||
| @@ -2013,7 +2015,7 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (pc *PeerConnection) startRTP(isRenegotiation bool, remoteDesc *SessionDescription, currentTransceivers []*RTPTransceiver) { | func (pc *PeerConnection) startRTP(ctx context.Context, isRenegotiation bool, remoteDesc *SessionDescription, currentTransceivers []*RTPTransceiver) { | ||||||
| 	trackDetails := trackDetailsFromSDP(pc.log, remoteDesc.parsed) | 	trackDetails := trackDetailsFromSDP(pc.log, remoteDesc.parsed) | ||||||
| 	if isRenegotiation { | 	if isRenegotiation { | ||||||
| 		for _, t := range currentTransceivers { | 		for _, t := range currentTransceivers { | ||||||
| @@ -2047,13 +2049,13 @@ func (pc *PeerConnection) startRTP(isRenegotiation bool, remoteDesc *SessionDesc | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	pc.startRTPReceivers(trackDetails, currentTransceivers) | 	pc.startRTPReceivers(ctx, trackDetails, currentTransceivers) | ||||||
| 	if haveApplicationMediaSection(remoteDesc.parsed) { | 	if haveApplicationMediaSection(remoteDesc.parsed) { | ||||||
| 		pc.startSCTP() | 		pc.startSCTP(ctx) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if !isRenegotiation { | 	if !isRenegotiation { | ||||||
| 		pc.undeclaredMediaProcessor() | 		pc.undeclaredMediaProcessor(ctx) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1137,7 +1137,7 @@ func TestPeerConnection_MassiveTracks(t *testing.T) { | |||||||
| 	<-connected | 	<-connected | ||||||
| 	time.Sleep(1 * time.Second) | 	time.Sleep(1 * time.Second) | ||||||
| 	for _, track := range tracks { | 	for _, track := range tracks { | ||||||
| 		assert.NoError(t, track.WriteRTP(samplePkt)) | 		assert.NoError(t, track.WriteRTP(context.Background(), samplePkt)) | ||||||
| 	} | 	} | ||||||
| 	// Ping trackRecords to see if any track event not received yet. | 	// Ping trackRecords to see if any track event not received yet. | ||||||
| 	tooLong := time.After(timeoutDuration) | 	tooLong := time.After(timeoutDuration) | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ package webrtc | |||||||
| import ( | import ( | ||||||
| 	"bufio" | 	"bufio" | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| @@ -88,7 +89,10 @@ func TestPeerConnection_Media_Sample(t *testing.T) { | |||||||
| 		go func() { | 		go func() { | ||||||
| 			for { | 			for { | ||||||
| 				time.Sleep(time.Millisecond * 100) | 				time.Sleep(time.Millisecond * 100) | ||||||
| 				if routineErr := pcAnswer.WriteRTCP([]rtcp.Packet{&rtcp.RapidResynchronizationRequest{SenderSSRC: uint32(track.SSRC()), MediaSSRC: uint32(track.SSRC())}}); routineErr != nil { | 				if routineErr := pcAnswer.WriteRTCP( | ||||||
|  | 					context.Background(), | ||||||
|  | 					[]rtcp.Packet{&rtcp.RapidResynchronizationRequest{SenderSSRC: uint32(track.SSRC()), MediaSSRC: uint32(track.SSRC())}}, | ||||||
|  | 				); routineErr != nil { | ||||||
| 					awaitRTCPReceiverSend <- routineErr | 					awaitRTCPReceiverSend <- routineErr | ||||||
| 					return | 					return | ||||||
| 				} | 				} | ||||||
| @@ -103,7 +107,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) { | |||||||
| 		}() | 		}() | ||||||
|  |  | ||||||
| 		go func() { | 		go func() { | ||||||
| 			_, routineErr := receiver.Read(make([]byte, 1400)) | 			_, routineErr := receiver.Read(context.Background(), make([]byte, 1400)) | ||||||
| 			if routineErr != nil { | 			if routineErr != nil { | ||||||
| 				awaitRTCPReceiverRecv <- routineErr | 				awaitRTCPReceiverRecv <- routineErr | ||||||
| 			} else { | 			} else { | ||||||
| @@ -113,7 +117,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) { | |||||||
|  |  | ||||||
| 		haveClosedAwaitRTPRecv := false | 		haveClosedAwaitRTPRecv := false | ||||||
| 		for { | 		for { | ||||||
| 			p, routineErr := track.ReadRTP() | 			p, routineErr := track.ReadRTP(context.Background()) | ||||||
| 			if routineErr != nil { | 			if routineErr != nil { | ||||||
| 				close(awaitRTPRecvClosed) | 				close(awaitRTPRecvClosed) | ||||||
| 				return | 				return | ||||||
| @@ -136,7 +140,9 @@ func TestPeerConnection_Media_Sample(t *testing.T) { | |||||||
| 	go func() { | 	go func() { | ||||||
| 		for { | 		for { | ||||||
| 			time.Sleep(time.Millisecond * 100) | 			time.Sleep(time.Millisecond * 100) | ||||||
| 			if routineErr := vp8Track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}); routineErr != nil { | 			if routineErr := vp8Track.WriteSample( | ||||||
|  | 				context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second}, | ||||||
|  | 			); routineErr != nil { | ||||||
| 				fmt.Println(routineErr) | 				fmt.Println(routineErr) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| @@ -152,7 +158,10 @@ func TestPeerConnection_Media_Sample(t *testing.T) { | |||||||
| 	go func() { | 	go func() { | ||||||
| 		for { | 		for { | ||||||
| 			time.Sleep(time.Millisecond * 100) | 			time.Sleep(time.Millisecond * 100) | ||||||
| 			if routineErr := pcOffer.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{SenderSSRC: uint32(sender.ssrc), MediaSSRC: uint32(sender.ssrc)}}); routineErr != nil { | 			if routineErr := pcOffer.WriteRTCP( | ||||||
|  | 				context.Background(), | ||||||
|  | 				[]rtcp.Packet{&rtcp.PictureLossIndication{SenderSSRC: uint32(sender.ssrc), MediaSSRC: uint32(sender.ssrc)}}, | ||||||
|  | 			); routineErr != nil { | ||||||
| 				awaitRTCPSenderSend <- routineErr | 				awaitRTCPSenderSend <- routineErr | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| @@ -166,7 +175,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) { | |||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	go func() { | 	go func() { | ||||||
| 		if _, routineErr := sender.Read(make([]byte, 1400)); routineErr == nil { | 		if _, routineErr := sender.Read(context.Background(), make([]byte, 1400)); routineErr == nil { | ||||||
| 			close(awaitRTCPSenderRecv) | 			close(awaitRTCPSenderRecv) | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| @@ -366,9 +375,13 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) { | |||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	for i := 0; i <= 5; i++ { | 	for i := 0; i <= 5; i++ { | ||||||
| 		if rtpErr := vp8Track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}); rtpErr != nil { | 		if rtpErr := vp8Track.WriteSample( | ||||||
|  | 			context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second}, | ||||||
|  | 		); rtpErr != nil { | ||||||
| 			t.Fatal(rtpErr) | 			t.Fatal(rtpErr) | ||||||
| 		} else if rtcpErr := pcOffer.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: 0}}); rtcpErr != nil { | 		} else if rtcpErr := pcOffer.WriteRTCP( | ||||||
|  | 			context.Background(), []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: 0}}, | ||||||
|  | 		); rtcpErr != nil { | ||||||
| 			t.Fatal(rtcpErr) | 			t.Fatal(rtcpErr) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @@ -448,7 +461,9 @@ func TestUndeclaredSSRC(t *testing.T) { | |||||||
|  |  | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for { | 		for { | ||||||
| 			assert.NoError(t, vp8Writer.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second})) | 			assert.NoError(t, vp8Writer.WriteSample( | ||||||
|  | 				context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second}, | ||||||
|  | 			)) | ||||||
| 			time.Sleep(time.Millisecond * 25) | 			time.Sleep(time.Millisecond * 25) | ||||||
|  |  | ||||||
| 			select { | 			select { | ||||||
| @@ -686,11 +701,11 @@ func TestRtpSenderReceiver_ReadClose_Error(t *testing.T) { | |||||||
|  |  | ||||||
| 	sender, receiver := tr.Sender(), tr.Receiver() | 	sender, receiver := tr.Sender(), tr.Receiver() | ||||||
| 	assert.NoError(t, sender.Stop()) | 	assert.NoError(t, sender.Stop()) | ||||||
| 	_, err = sender.Read(make([]byte, 0, 1400)) | 	_, err = sender.Read(context.Background(), make([]byte, 0, 1400)) | ||||||
| 	assert.Error(t, err, io.ErrClosedPipe) | 	assert.Error(t, err, io.ErrClosedPipe) | ||||||
|  |  | ||||||
| 	assert.NoError(t, receiver.Stop()) | 	assert.NoError(t, receiver.Stop()) | ||||||
| 	_, err = receiver.Read(make([]byte, 0, 1400)) | 	_, err = receiver.Read(context.Background(), make([]byte, 0, 1400)) | ||||||
| 	assert.Error(t, err, io.ErrClosedPipe) | 	assert.Error(t, err, io.ErrClosedPipe) | ||||||
|  |  | ||||||
| 	assert.NoError(t, pc.Close()) | 	assert.NoError(t, pc.Close()) | ||||||
| @@ -837,7 +852,9 @@ func TestPlanBMediaExchange(t *testing.T) { | |||||||
| 				select { | 				select { | ||||||
| 				case <-time.After(20 * time.Millisecond): | 				case <-time.After(20 * time.Millisecond): | ||||||
| 					for _, track := range outboundTracks { | 					for _, track := range outboundTracks { | ||||||
| 						assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second})) | 						assert.NoError(t, track.WriteSample( | ||||||
|  | 							context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second}, | ||||||
|  | 						)) | ||||||
| 					} | 					} | ||||||
| 				case <-done: | 				case <-done: | ||||||
| 					return | 					return | ||||||
|   | |||||||
| @@ -25,7 +25,9 @@ func sendVideoUntilDone(done <-chan struct{}, t *testing.T, tracks []*TrackLocal | |||||||
| 		select { | 		select { | ||||||
| 		case <-time.After(20 * time.Millisecond): | 		case <-time.After(20 * time.Millisecond): | ||||||
| 			for _, track := range tracks { | 			for _, track := range tracks { | ||||||
| 				assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second})) | 				assert.NoError(t, track.WriteSample( | ||||||
|  | 					context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second}, | ||||||
|  | 				)) | ||||||
| 			} | 			} | ||||||
| 		case <-done: | 		case <-done: | ||||||
| 			return | 			return | ||||||
| @@ -100,7 +102,9 @@ func TestPeerConnection_Renegotiation_AddTrack(t *testing.T) { | |||||||
|  |  | ||||||
| 	// Send 10 packets, OnTrack MUST not be fired | 	// Send 10 packets, OnTrack MUST not be fired | ||||||
| 	for i := 0; i <= 10; i++ { | 	for i := 0; i <= 10; i++ { | ||||||
| 		assert.NoError(t, vp8Track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second})) | 		assert.NoError(t, vp8Track.WriteSample( | ||||||
|  | 			context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second}, | ||||||
|  | 		)) | ||||||
| 		time.Sleep(20 * time.Millisecond) | 		time.Sleep(20 * time.Millisecond) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -360,7 +364,7 @@ func TestPeerConnection_Renegotiation_CodecChange(t *testing.T) { | |||||||
| 	pcAnswer.OnTrack(func(track *TrackRemote, r *RTPReceiver) { | 	pcAnswer.OnTrack(func(track *TrackRemote, r *RTPReceiver) { | ||||||
| 		tracksCh <- track | 		tracksCh <- track | ||||||
| 		for { | 		for { | ||||||
| 			if _, readErr := track.ReadRTP(); readErr == io.EOF { | 			if _, readErr := track.ReadRTP(context.Background()); readErr == io.EOF { | ||||||
| 				tracksClosed <- struct{}{} | 				tracksClosed <- struct{}{} | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| @@ -450,7 +454,7 @@ func TestPeerConnection_Renegotiation_RemoveTrack(t *testing.T) { | |||||||
| 		onTrackFiredFunc() | 		onTrackFiredFunc() | ||||||
|  |  | ||||||
| 		for { | 		for { | ||||||
| 			if _, err := track.ReadRTP(); err == io.EOF { | 			if _, err := track.ReadRTP(context.Background()); err == io.EOF { | ||||||
| 				trackClosedFunc() | 				trackClosedFunc() | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| @@ -838,7 +842,9 @@ func TestNegotiationNeededRemoveTrack(t *testing.T) { | |||||||
| 	sender, err := pcOffer.AddTrack(track) | 	sender, err := pcOffer.AddTrack(track) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
|  |  | ||||||
| 	assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second})) | 	assert.NoError(t, track.WriteSample( | ||||||
|  | 		context.Background(), media.Sample{Data: []byte{0x00}, Duration: time.Second}, | ||||||
|  | 	)) | ||||||
|  |  | ||||||
| 	wg.Wait() | 	wg.Wait() | ||||||
|  |  | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ package webrtc | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"encoding/binary" | 	"encoding/binary" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -142,7 +143,7 @@ func (s *testQuicStack) setSignal(sig *testQuicSignal, isOffer bool) error { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Start the ICE transport | 	// Start the ICE transport | ||||||
| 	err = s.ice.Start(nil, sig.ICEParameters, &iceRole) | 	err = s.ice.Start(context.Background(), nil, sig.ICEParameters, &iceRole) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -3,13 +3,14 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"sync" | 	"sync" | ||||||
|  |  | ||||||
| 	"github.com/pion/interceptor" | 	"github.com/pion/interceptor" | ||||||
| 	"github.com/pion/rtcp" | 	"github.com/pion/rtcp" | ||||||
| 	"github.com/pion/srtp" | 	"github.com/pion/srtp/v2" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // trackStreams maintains a mapping of RTP/RTCP streams to a specific track | // trackStreams maintains a mapping of RTP/RTCP streams to a specific track | ||||||
| @@ -132,41 +133,45 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Read reads incoming RTCP for this RTPReceiver | // Read reads incoming RTCP for this RTPReceiver | ||||||
| func (r *RTPReceiver) Read(b []byte) (n int, err error) { | func (r *RTPReceiver) Read(ctx context.Context, b []byte) (n int, err error) { | ||||||
| 	select { | 	select { | ||||||
| 	case <-r.received: | 	case <-r.received: | ||||||
| 		return r.tracks[0].rtcpReadStream.Read(b) | 		return r.tracks[0].rtcpReadStream.ReadContext(ctx, b) | ||||||
| 	case <-r.closed: | 	case <-r.closed: | ||||||
| 		return 0, io.ErrClosedPipe | 		return 0, io.ErrClosedPipe | ||||||
|  | 	case <-ctx.Done(): | ||||||
|  | 		return 0, ctx.Err() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // ReadSimulcast reads incoming RTCP for this RTPReceiver for given rid | // ReadSimulcast reads incoming RTCP for this RTPReceiver for given rid | ||||||
| func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, err error) { | func (r *RTPReceiver) ReadSimulcast(ctx context.Context, b []byte, rid string) (n int, err error) { | ||||||
| 	select { | 	select { | ||||||
| 	case <-r.received: | 	case <-r.received: | ||||||
| 		for _, t := range r.tracks { | 		for _, t := range r.tracks { | ||||||
| 			if t.track != nil && t.track.rid == rid { | 			if t.track != nil && t.track.rid == rid { | ||||||
| 				return t.rtcpReadStream.Read(b) | 				return t.rtcpReadStream.ReadContext(ctx, b) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		return 0, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid) | 		return 0, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid) | ||||||
| 	case <-r.closed: | 	case <-r.closed: | ||||||
| 		return 0, io.ErrClosedPipe | 		return 0, io.ErrClosedPipe | ||||||
|  | 	case <-ctx.Done(): | ||||||
|  | 		return 0, ctx.Err() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // ReadRTCP is a convenience method that wraps Read and unmarshal for you. | // ReadRTCP is a convenience method that wraps Read and unmarshal for you. | ||||||
| // It also runs any configured interceptors. | // It also runs any configured interceptors. | ||||||
| func (r *RTPReceiver) ReadRTCP() ([]rtcp.Packet, error) { | func (r *RTPReceiver) ReadRTCP(ctx context.Context) ([]rtcp.Packet, error) { | ||||||
| 	pkts, _, err := r.interceptorRTCPReader.Read() | 	pkts, _, err := r.interceptorRTCPReader.Read(ctx) | ||||||
| 	return pkts, err | 	return pkts, err | ||||||
| } | } | ||||||
|  |  | ||||||
| // ReadRTCP is a convenience method that wraps Read and unmarshal for you | // ReadRTCP is a convenience method that wraps Read and unmarshal for you | ||||||
| func (r *RTPReceiver) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { | func (r *RTPReceiver) readRTCP(ctx context.Context) ([]rtcp.Packet, interceptor.Attributes, error) { | ||||||
| 	b := make([]byte, receiveMTU) | 	b := make([]byte, receiveMTU) | ||||||
| 	i, err := r.Read(b) | 	i, err := r.Read(ctx, b) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
| @@ -180,9 +185,9 @@ func (r *RTPReceiver) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) | |||||||
| } | } | ||||||
|  |  | ||||||
| // ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you | // ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you | ||||||
| func (r *RTPReceiver) ReadSimulcastRTCP(rid string) ([]rtcp.Packet, error) { | func (r *RTPReceiver) ReadSimulcastRTCP(ctx context.Context, rid string) ([]rtcp.Packet, error) { | ||||||
| 	b := make([]byte, receiveMTU) | 	b := make([]byte, receiveMTU) | ||||||
| 	i, err := r.ReadSimulcast(b, rid) | 	i, err := r.ReadSimulcast(ctx, b, rid) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -241,10 +246,10 @@ func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams { | |||||||
| } | } | ||||||
|  |  | ||||||
| // readRTP should only be called by a track, this only exists so we can keep state in one place | // readRTP should only be called by a track, this only exists so we can keep state in one place | ||||||
| func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, err error) { | func (r *RTPReceiver) readRTP(ctx context.Context, b []byte, reader *TrackRemote) (n int, err error) { | ||||||
| 	<-r.received | 	<-r.received | ||||||
| 	if t := r.streamsForTrack(reader); t != nil { | 	if t := r.streamsForTrack(reader); t != nil { | ||||||
| 		return t.rtpReadStream.Read(b) | 		return t.rtpReadStream.ReadContext(ctx, b) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return 0, fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC()) | 	return 0, fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC()) | ||||||
|   | |||||||
							
								
								
									
										19
									
								
								rtpsender.go
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								rtpsender.go
									
									
									
									
									
								
							| @@ -3,6 +3,7 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"io" | 	"io" | ||||||
| 	"sync" | 	"sync" | ||||||
|  |  | ||||||
| @@ -175,8 +176,8 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error { | |||||||
| 	writeStream.setRTPWriter( | 	writeStream.setRTPWriter( | ||||||
| 		r.api.interceptor.BindLocalStream( | 		r.api.interceptor.BindLocalStream( | ||||||
| 			info, | 			info, | ||||||
| 			interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) { | 			interceptor.RTPWriterFunc(func(ctx context.Context, p *rtp.Packet, attributes interceptor.Attributes) (int, error) { | ||||||
| 				return r.srtpStream.WriteRTP(&p.Header, p.Payload) | 				return r.srtpStream.WriteRTP(ctx, &p.Header, p.Payload) | ||||||
| 			}), | 			}), | ||||||
| 		)) | 		)) | ||||||
|  |  | ||||||
| @@ -208,25 +209,27 @@ func (r *RTPSender) Stop() error { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Read reads incoming RTCP for this RTPReceiver | // Read reads incoming RTCP for this RTPReceiver | ||||||
| func (r *RTPSender) Read(b []byte) (n int, err error) { | func (r *RTPSender) Read(ctx context.Context, b []byte) (n int, err error) { | ||||||
| 	select { | 	select { | ||||||
| 	case <-r.sendCalled: | 	case <-r.sendCalled: | ||||||
| 		return r.srtpStream.Read(b) | 		return r.srtpStream.ReadContext(ctx, b) | ||||||
| 	case <-r.stopCalled: | 	case <-r.stopCalled: | ||||||
| 		return 0, io.ErrClosedPipe | 		return 0, io.ErrClosedPipe | ||||||
|  | 	case <-ctx.Done(): | ||||||
|  | 		return 0, ctx.Err() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // ReadRTCP is a convenience method that wraps Read and unmarshals for you. | // ReadRTCP is a convenience method that wraps Read and unmarshals for you. | ||||||
| // It also runs any configured interceptors. | // It also runs any configured interceptors. | ||||||
| func (r *RTPSender) ReadRTCP() ([]rtcp.Packet, error) { | func (r *RTPSender) ReadRTCP(ctx context.Context) ([]rtcp.Packet, error) { | ||||||
| 	pkts, _, err := r.interceptorRTCPReader.Read() | 	pkts, _, err := r.interceptorRTCPReader.Read(ctx) | ||||||
| 	return pkts, err | 	return pkts, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func (r *RTPSender) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { | func (r *RTPSender) readRTCP(ctx context.Context) ([]rtcp.Packet, interceptor.Attributes, error) { | ||||||
| 	b := make([]byte, receiveMTU) | 	b := make([]byte, receiveMTU) | ||||||
| 	i, err := r.Read(b) | 	i, err := r.Read(ctx, b) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -50,7 +50,7 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) { | |||||||
| 			assert.Equal(t, uint64(1), atomic.AddUint64(&onTrackCount, 1)) | 			assert.Equal(t, uint64(1), atomic.AddUint64(&onTrackCount, 1)) | ||||||
|  |  | ||||||
| 			for { | 			for { | ||||||
| 				pkt, err := track.ReadRTP() | 				pkt, err := track.ReadRTP(context.Background()) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					assert.True(t, errors.Is(io.EOF, err)) | 					assert.True(t, errors.Is(io.EOF, err)) | ||||||
| 					return | 					return | ||||||
| @@ -74,7 +74,9 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) { | |||||||
| 				case <-seenPacketA.Done(): | 				case <-seenPacketA.Done(): | ||||||
| 					return | 					return | ||||||
| 				default: | 				default: | ||||||
| 					assert.NoError(t, trackA.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second})) | 					assert.NoError(t, trackA.WriteSample( | ||||||
|  | 						context.Background(), media.Sample{Data: []byte{0xAA}, Duration: time.Second}, | ||||||
|  | 					)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
| @@ -88,7 +90,9 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) { | |||||||
| 				case <-seenPacketB.Done(): | 				case <-seenPacketB.Done(): | ||||||
| 					return | 					return | ||||||
| 				default: | 				default: | ||||||
| 					assert.NoError(t, trackB.WriteSample(media.Sample{Data: []byte{0xBB}, Duration: time.Second})) | 					assert.NoError(t, trackB.WriteSample( | ||||||
|  | 						context.Background(), media.Sample{Data: []byte{0xBB}, Duration: time.Second}, | ||||||
|  | 					)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
| @@ -123,7 +127,9 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) { | |||||||
| 				case <-seenPacket.Done(): | 				case <-seenPacket.Done(): | ||||||
| 					return | 					return | ||||||
| 				default: | 				default: | ||||||
| 					assert.NoError(t, trackA.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second})) | 					assert.NoError(t, trackA.WriteSample( | ||||||
|  | 						context.Background(), media.Sample{Data: []byte{0xAA}, Duration: time.Second}, | ||||||
|  | 					)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		}() | 		}() | ||||||
| @@ -134,3 +140,14 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) { | |||||||
| 		assert.NoError(t, receiver.Close()) | 		assert.NoError(t, receiver.Close()) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func Test_RTPSender_ContextCancel(t *testing.T) { | ||||||
|  | 	sender := &RTPSender{ | ||||||
|  | 		sendCalled: make(chan struct{}), | ||||||
|  | 		stopCalled: make(chan struct{}), | ||||||
|  | 	} | ||||||
|  | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
|  | 	cancel() | ||||||
|  | 	_, err := sender.Read(ctx, []byte{}) | ||||||
|  | 	assert.Equal(t, context.Canceled, err) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"io" | 	"io" | ||||||
| 	"math" | 	"math" | ||||||
| 	"sync" | 	"sync" | ||||||
| @@ -90,7 +91,7 @@ func (r *SCTPTransport) GetCapabilities() SCTPCapabilities { | |||||||
| // Start the SCTPTransport. Since both local and remote parties must mutually | // Start the SCTPTransport. Since both local and remote parties must mutually | ||||||
| // create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish | // create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish | ||||||
| // a connection over SCTP. | // a connection over SCTP. | ||||||
| func (r *SCTPTransport) Start(remoteCaps SCTPCapabilities) error { | func (r *SCTPTransport) Start(ctx context.Context, remoteCaps SCTPCapabilities) error { | ||||||
| 	if r.isStarted { | 	if r.isStarted { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -3,11 +3,12 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"io" | 	"io" | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
|  |  | ||||||
| 	"github.com/pion/rtp" | 	"github.com/pion/rtp" | ||||||
| 	"github.com/pion/srtp" | 	"github.com/pion/srtp/v2" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // srtpWriterFuture blocks Read/Write calls until | // srtpWriterFuture blocks Read/Write calls until | ||||||
| @@ -18,11 +19,13 @@ type srtpWriterFuture struct { | |||||||
| 	rtpWriteStream atomic.Value // *srtp.WriteStreamSRTP | 	rtpWriteStream atomic.Value // *srtp.WriteStreamSRTP | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *srtpWriterFuture) init() error { | func (s *srtpWriterFuture) init(ctx context.Context) error { | ||||||
| 	select { | 	select { | ||||||
| 	case <-s.rtpSender.stopCalled: | 	case <-s.rtpSender.stopCalled: | ||||||
| 		return io.ErrClosedPipe | 		return io.ErrClosedPipe | ||||||
| 	case <-s.rtpSender.transport.srtpReady: | 	case <-s.rtpSender.transport.srtpReady: | ||||||
|  | 	case <-ctx.Done(): | ||||||
|  | 		return ctx.Err() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	srtcpSession, err := s.rtpSender.transport.getSRTCPSession() | 	srtcpSession, err := s.rtpSender.transport.getSRTCPSession() | ||||||
| @@ -58,38 +61,38 @@ func (s *srtpWriterFuture) Close() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *srtpWriterFuture) Read(b []byte) (n int, err error) { | func (s *srtpWriterFuture) ReadContext(ctx context.Context, b []byte) (n int, err error) { | ||||||
| 	if value := s.rtcpReadStream.Load(); value != nil { | 	if value := s.rtcpReadStream.Load(); value != nil { | ||||||
| 		return value.(*srtp.ReadStreamSRTCP).Read(b) | 		return value.(*srtp.ReadStreamSRTCP).ReadContext(ctx, b) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if err := s.init(); err != nil || s.rtcpReadStream.Load() == nil { | 	if err := s.init(ctx); err != nil || s.rtcpReadStream.Load() == nil { | ||||||
| 		return 0, err | 		return 0, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return s.Read(b) | 	return s.ReadContext(ctx, b) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *srtpWriterFuture) WriteRTP(header *rtp.Header, payload []byte) (int, error) { | func (s *srtpWriterFuture) WriteRTP(ctx context.Context, header *rtp.Header, payload []byte) (int, error) { | ||||||
| 	if value := s.rtpWriteStream.Load(); value != nil { | 	if value := s.rtpWriteStream.Load(); value != nil { | ||||||
| 		return value.(*srtp.WriteStreamSRTP).WriteRTP(header, payload) | 		return value.(*srtp.WriteStreamSRTP).WriteRTP(ctx, header, payload) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if err := s.init(); err != nil || s.rtpWriteStream.Load() == nil { | 	if err := s.init(ctx); err != nil || s.rtpWriteStream.Load() == nil { | ||||||
| 		return 0, err | 		return 0, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return s.WriteRTP(header, payload) | 	return s.WriteRTP(ctx, header, payload) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *srtpWriterFuture) Write(b []byte) (int, error) { | func (s *srtpWriterFuture) Write(ctx context.Context, b []byte) (int, error) { | ||||||
| 	if value := s.rtpWriteStream.Load(); value != nil { | 	if value := s.rtpWriteStream.Load(); value != nil { | ||||||
| 		return value.(*srtp.WriteStreamSRTP).Write(b) | 		return value.(*srtp.WriteStreamSRTP).WriteContext(ctx, b) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if err := s.init(); err != nil || s.rtpWriteStream.Load() == nil { | 	if err := s.init(ctx); err != nil || s.rtpWriteStream.Load() == nil { | ||||||
| 		return 0, err | 		return 0, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return s.Write(b) | 	return s.Write(ctx, b) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,14 +1,18 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import "github.com/pion/rtp" | import ( | ||||||
|  | 	"context" | ||||||
|  |  | ||||||
|  | 	"github.com/pion/rtp" | ||||||
|  | ) | ||||||
|  |  | ||||||
| // TrackLocalWriter is the Writer for outbound RTP Packets | // TrackLocalWriter is the Writer for outbound RTP Packets | ||||||
| type TrackLocalWriter interface { | type TrackLocalWriter interface { | ||||||
| 	// WriteRTP encrypts a RTP packet and writes to the connection | 	// WriteRTP encrypts a RTP packet and writes to the connection | ||||||
| 	WriteRTP(header *rtp.Header, payload []byte) (int, error) | 	WriteRTP(ctx context.Context, header *rtp.Header, payload []byte) (int, error) | ||||||
|  |  | ||||||
| 	// Write encrypts and writes a full RTP packet | 	// Write encrypts and writes a full RTP packet | ||||||
| 	Write(b []byte) (int, error) | 	Write(ctx context.Context, b []byte) (int, error) | ||||||
| } | } | ||||||
|  |  | ||||||
| // TrackLocalContext is the Context passed when a TrackLocal has been Binded/Unbinded from a PeerConnection, and used | // TrackLocalContext is the Context passed when a TrackLocal has been Binded/Unbinded from a PeerConnection, and used | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
|  |  | ||||||
| @@ -102,7 +103,7 @@ func (s *TrackLocalStaticRTP) Kind() RTPCodecType { | |||||||
| // If one PeerConnection fails the packets will still be sent to | // If one PeerConnection fails the packets will still be sent to | ||||||
| // all PeerConnections. The error message will contain the ID of the failed | // all PeerConnections. The error message will contain the ID of the failed | ||||||
| // PeerConnections so you can remove them | // PeerConnections so you can remove them | ||||||
| func (s *TrackLocalStaticRTP) WriteRTP(p *rtp.Packet) error { | func (s *TrackLocalStaticRTP) WriteRTP(ctx context.Context, p *rtp.Packet) error { | ||||||
| 	s.mu.RLock() | 	s.mu.RLock() | ||||||
| 	defer s.mu.RUnlock() | 	defer s.mu.RUnlock() | ||||||
|  |  | ||||||
| @@ -112,7 +113,7 @@ func (s *TrackLocalStaticRTP) WriteRTP(p *rtp.Packet) error { | |||||||
| 	for _, b := range s.bindings { | 	for _, b := range s.bindings { | ||||||
| 		outboundPacket.Header.SSRC = uint32(b.ssrc) | 		outboundPacket.Header.SSRC = uint32(b.ssrc) | ||||||
| 		outboundPacket.Header.PayloadType = uint8(b.payloadType) | 		outboundPacket.Header.PayloadType = uint8(b.payloadType) | ||||||
| 		if _, err := b.writeStream.WriteRTP(&outboundPacket.Header, outboundPacket.Payload); err != nil { | 		if _, err := b.writeStream.WriteRTP(ctx, &outboundPacket.Header, outboundPacket.Payload); err != nil { | ||||||
| 			writeErrs = append(writeErrs, err) | 			writeErrs = append(writeErrs, err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @@ -124,13 +125,13 @@ func (s *TrackLocalStaticRTP) WriteRTP(p *rtp.Packet) error { | |||||||
| // If one PeerConnection fails the packets will still be sent to | // If one PeerConnection fails the packets will still be sent to | ||||||
| // all PeerConnections. The error message will contain the ID of the failed | // all PeerConnections. The error message will contain the ID of the failed | ||||||
| // PeerConnections so you can remove them | // PeerConnections so you can remove them | ||||||
| func (s *TrackLocalStaticRTP) Write(b []byte) (n int, err error) { | func (s *TrackLocalStaticRTP) Write(ctx context.Context, b []byte) (n int, err error) { | ||||||
| 	packet := &rtp.Packet{} | 	packet := &rtp.Packet{} | ||||||
| 	if err = packet.Unmarshal(b); err != nil { | 	if err = packet.Unmarshal(b); err != nil { | ||||||
| 		return 0, err | 		return 0, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return len(b), s.WriteRTP(packet) | 	return len(b), s.WriteRTP(ctx, packet) | ||||||
| } | } | ||||||
|  |  | ||||||
| // TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples. | // TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples. | ||||||
| @@ -208,7 +209,7 @@ func (s *TrackLocalStaticSample) Unbind(t TrackLocalContext) error { | |||||||
| // If one PeerConnection fails the packets will still be sent to | // If one PeerConnection fails the packets will still be sent to | ||||||
| // all PeerConnections. The error message will contain the ID of the failed | // all PeerConnections. The error message will contain the ID of the failed | ||||||
| // PeerConnections so you can remove them | // PeerConnections so you can remove them | ||||||
| func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error { | func (s *TrackLocalStaticSample) WriteSample(ctx context.Context, sample media.Sample) error { | ||||||
| 	s.rtpTrack.mu.RLock() | 	s.rtpTrack.mu.RLock() | ||||||
| 	p := s.packetizer | 	p := s.packetizer | ||||||
| 	clockRate := s.clockRate | 	clockRate := s.clockRate | ||||||
| @@ -223,7 +224,7 @@ func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error { | |||||||
|  |  | ||||||
| 	writeErrs := []error{} | 	writeErrs := []error{} | ||||||
| 	for _, p := range packets { | 	for _, p := range packets { | ||||||
| 		if err := s.rtpTrack.WriteRTP(p); err != nil { | 		if err := s.rtpTrack.WriteRTP(ctx, p); err != nil { | ||||||
| 			writeErrs = append(writeErrs, err) | 			writeErrs = append(writeErrs, err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -185,7 +185,7 @@ func Test_TrackLocalStatic_Mutate_Input(t *testing.T) { | |||||||
| 	assert.NoError(t, signalPair(pcOffer, pcAnswer)) | 	assert.NoError(t, signalPair(pcOffer, pcAnswer)) | ||||||
|  |  | ||||||
| 	pkt := &rtp.Packet{Header: rtp.Header{SSRC: 1, PayloadType: 1}} | 	pkt := &rtp.Packet{Header: rtp.Header{SSRC: 1, PayloadType: 1}} | ||||||
| 	assert.NoError(t, vp8Writer.WriteRTP(pkt)) | 	assert.NoError(t, vp8Writer.WriteRTP(context.Background(), pkt)) | ||||||
|  |  | ||||||
| 	assert.Equal(t, pkt.Header.SSRC, uint32(1)) | 	assert.Equal(t, pkt.Header.SSRC, uint32(1)) | ||||||
| 	assert.Equal(t, pkt.Header.PayloadType, uint8(1)) | 	assert.Equal(t, pkt.Header.PayloadType, uint8(1)) | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| package webrtc | package webrtc | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"sync" | 	"sync" | ||||||
|  |  | ||||||
| 	"github.com/pion/interceptor" | 	"github.com/pion/interceptor" | ||||||
| @@ -125,7 +126,7 @@ func (t *TrackRemote) Codec() RTPCodecParameters { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Read reads data from the track. | // Read reads data from the track. | ||||||
| func (t *TrackRemote) Read(b []byte) (n int, err error) { | func (t *TrackRemote) Read(ctx context.Context, b []byte) (n int, err error) { | ||||||
| 	t.mu.RLock() | 	t.mu.RLock() | ||||||
| 	r := t.receiver | 	r := t.receiver | ||||||
| 	peeked := t.peeked != nil | 	peeked := t.peeked != nil | ||||||
| @@ -144,12 +145,12 @@ func (t *TrackRemote) Read(b []byte) (n int, err error) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return r.readRTP(b, t) | 	return r.readRTP(ctx, b, t) | ||||||
| } | } | ||||||
|  |  | ||||||
| // peek is like Read, but it doesn't discard the packet read | // peek is like Read, but it doesn't discard the packet read | ||||||
| func (t *TrackRemote) peek(b []byte) (n int, err error) { | func (t *TrackRemote) peek(ctx context.Context, b []byte) (n int, err error) { | ||||||
| 	n, err = t.Read(b) | 	n, err = t.Read(ctx, b) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -167,14 +168,14 @@ func (t *TrackRemote) peek(b []byte) (n int, err error) { | |||||||
|  |  | ||||||
| // ReadRTP is a convenience method that wraps Read and unmarshals for you. | // ReadRTP is a convenience method that wraps Read and unmarshals for you. | ||||||
| // It also runs any configured interceptors. | // It also runs any configured interceptors. | ||||||
| func (t *TrackRemote) ReadRTP() (*rtp.Packet, error) { | func (t *TrackRemote) ReadRTP(ctx context.Context) (*rtp.Packet, error) { | ||||||
| 	p, _, err := t.interceptorRTPReader.Read() | 	p, _, err := t.interceptorRTPReader.Read(ctx) | ||||||
| 	return p, err | 	return p, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *TrackRemote) readRTP() (*rtp.Packet, interceptor.Attributes, error) { | func (t *TrackRemote) readRTP(ctx context.Context) (*rtp.Packet, interceptor.Attributes, error) { | ||||||
| 	b := make([]byte, receiveMTU) | 	b := make([]byte, receiveMTU) | ||||||
| 	i, err := t.Read(b) | 	i, err := t.Read(ctx, b) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
| @@ -188,9 +189,9 @@ func (t *TrackRemote) readRTP() (*rtp.Packet, interceptor.Attributes, error) { | |||||||
|  |  | ||||||
| // determinePayloadType blocks and reads a single packet to determine the PayloadType for this Track | // determinePayloadType blocks and reads a single packet to determine the PayloadType for this Track | ||||||
| // this is useful because we can't announce it to the user until we know the payloadType | // this is useful because we can't announce it to the user until we know the payloadType | ||||||
| func (t *TrackRemote) determinePayloadType() error { | func (t *TrackRemote) determinePayloadType(ctx context.Context) error { | ||||||
| 	b := make([]byte, receiveMTU) | 	b := make([]byte, receiveMTU) | ||||||
| 	n, err := t.peek(b) | 	n, err := t.peek(ctx, b) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Atsushi Watanabe
					Atsushi Watanabe