Update Interceptors to use []byte based API

Also update test to assert Attributes get passed all the way through

Resolves pion/interceptor#14
This commit is contained in:
Sean DuBois
2020-12-12 19:42:29 -08:00
parent ff1bc32221
commit 67826b1914
18 changed files with 210 additions and 306 deletions

View File

@@ -68,7 +68,7 @@ func main() { // nolint:gocognit
rtpBuf := make([]byte, 1400)
for {
i, readErr := remoteTrack.Read(rtpBuf)
i, _, readErr := remoteTrack.Read(rtpBuf)
if readErr != nil {
panic(readErr)
}

View File

@@ -82,7 +82,7 @@ func main() {
fmt.Printf("Track has started, of type %d: %s \n", track.PayloadType(), track.Codec().MimeType)
for {
// Read RTP packets being sent to Pion
rtp, readErr := track.ReadRTP()
rtp, _, readErr := track.ReadRTP()
if readErr != nil {
panic(readErr)
}

View File

@@ -116,7 +116,7 @@ func main() {
b := make([]byte, 1500)
for {
// Read
n, readErr := track.Read(b)
n, _, readErr := track.Read(b)
if readErr != nil {
panic(readErr)
}

View File

@@ -23,7 +23,7 @@ func saveToDisk(i media.Writer, track *webrtc.TrackRemote) {
}()
for {
rtpPacket, err := track.ReadRTP()
rtpPacket, _, err := track.ReadRTP()
if err != nil {
panic(err)
}

View File

@@ -92,7 +92,7 @@ func main() {
}()
for {
// Read RTP packets being sent to Pion
packet, readErr := track.ReadRTP()
packet, _, readErr := track.ReadRTP()
if readErr != nil {
panic(readErr)
}

View File

@@ -85,7 +85,7 @@ func main() { // nolint:gocognit
var isCurrTrack bool
for {
// Read RTP packets being sent to Pion
rtp, readErr := track.ReadRTP()
rtp, _, readErr := track.ReadRTP()
if readErr != nil {
panic(readErr)
}

2
go.mod
View File

@@ -8,7 +8,7 @@ require (
github.com/pion/datachannel v1.4.21
github.com/pion/dtls/v2 v2.0.4
github.com/pion/ice/v2 v2.0.14
github.com/pion/interceptor v0.0.5
github.com/pion/interceptor v0.0.6
github.com/pion/logging v0.2.2
github.com/pion/randutil v0.1.0
github.com/pion/rtcp v1.2.6

4
go.sum
View File

@@ -40,8 +40,8 @@ github.com/pion/dtls/v2 v2.0.4 h1:WuUcqi6oYMu/noNTz92QrF1DaFj4eXbhQ6dzaaAwOiI=
github.com/pion/dtls/v2 v2.0.4/go.mod h1:qAkFscX0ZHoI1E07RfYPoRw3manThveu+mlTDdOxoGI=
github.com/pion/ice/v2 v2.0.14 h1:FxXxauyykf89SWAtkQCfnHkno6G8+bhRkNguSh9zU+4=
github.com/pion/ice/v2 v2.0.14/go.mod h1:wqaUbOq5ObDNU5ox1hRsEst0rWfsKuH1zXjQFEWiZwM=
github.com/pion/interceptor v0.0.5 h1:BOwlubM1lntji3eNaVrhW1Qk3u1UoemrhM4mbv24XGM=
github.com/pion/interceptor v0.0.5/go.mod h1:lPVrf5xfosI989ZcmgPS4WwwRhd+XAyTFaYI2wHf7nU=
github.com/pion/interceptor v0.0.6 h1:530EdZi757pZEx510kvO25FkEuKm2mrb0p9NA+Xfj8E=
github.com/pion/interceptor v0.0.6/go.mod h1:QHkPVN5uyuw54wHqqL1KS9fxf3M3RzOlVKg/YrtK1so=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/mdns v0.0.4 h1:O4vvVqr4DGX63vzmO6Fw9vpy3lfztVWHGCQfyw0ZLSY=

View File

@@ -3,14 +3,16 @@
package webrtc
import (
"sync/atomic"
"github.com/pion/interceptor"
"github.com/pion/rtp"
)
// RegisterDefaultInterceptors will register some useful interceptors. If you want to customize which interceptors are loaded,
// you should copy the code from this method and remove unwanted interceptors.
func RegisterDefaultInterceptors(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Registry) error {
err := ConfigureNack(mediaEngine, interceptorRegistry)
if err != nil {
if err := ConfigureNack(mediaEngine, interceptorRegistry); err != nil {
return err
}
@@ -24,3 +26,47 @@ func ConfigureNack(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Re
interceptorRegistry.Add(&interceptor.NACK{})
return nil
}
type interceptorToTrackLocalWriter struct{ interceptor atomic.Value } // interceptor.RTPWriter }
func (i *interceptorToTrackLocalWriter) WriteRTP(header *rtp.Header, payload []byte) (int, error) {
if writer, ok := i.interceptor.Load().(interceptor.RTPWriter); ok && writer != nil {
return writer.Write(header, payload, interceptor.Attributes{})
}
return 0, nil
}
func (i *interceptorToTrackLocalWriter) Write(b []byte) (int, error) {
packet := &rtp.Packet{}
if err := packet.Unmarshal(b); err != nil {
return 0, err
}
return i.WriteRTP(&packet.Header, packet.Payload)
}
func createStreamInfo(id string, ssrc SSRC, payloadType PayloadType, codec RTPCodecCapability, webrtcHeaderExtensions []RTPHeaderExtensionParameter) interceptor.StreamInfo {
headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(webrtcHeaderExtensions))
for _, h := range webrtcHeaderExtensions {
headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI})
}
feedbacks := make([]interceptor.RTCPFeedback, 0, len(codec.RTCPFeedback))
for _, f := range codec.RTCPFeedback {
feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter})
}
return interceptor.StreamInfo{
ID: id,
Attributes: interceptor.Attributes{},
SSRC: uint32(ssrc),
PayloadType: uint8(payloadType),
RTPHeaderExtensions: headerExtensions,
MimeType: codec.MimeType,
ClockRate: codec.ClockRate,
Channels: codec.Channels,
SDPFmtpLine: codec.SDPFmtpLine,
RTCPFeedback: feedbacks,
}
}

View File

@@ -2,14 +2,13 @@
package webrtc
//
import (
"sync"
"sync/atomic"
"context"
"testing"
"time"
"github.com/pion/interceptor"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/transport/test"
"github.com/pion/webrtc/v3/pkg/media"
@@ -17,68 +16,37 @@ import (
)
type testInterceptor struct {
t *testing.T
extensionID uint8
rtcpWriter atomic.Value
lastRTCP atomic.Value
interceptor.NoOp
t *testing.T
}
func (t *testInterceptor) BindLocalStream(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
return interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) {
return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
// set extension on outgoing packet
p.Header.Extension = true
p.Header.ExtensionProfile = 0xBEDE
assert.NoError(t.t, p.Header.SetExtension(t.extensionID, []byte("write")))
header.Extension = true
header.ExtensionProfile = 0xBEDE
assert.NoError(t.t, header.SetExtension(2, []byte("foo")))
return writer.Write(p, attributes)
return writer.Write(header, payload, attributes)
})
}
func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
return interceptor.RTPReaderFunc(func() (*rtp.Packet, interceptor.Attributes, error) {
p, attributes, err := reader.Read()
if err != nil {
return nil, nil, err
}
// set extension on incoming packet
p.Header.Extension = true
p.Header.ExtensionProfile = 0xBEDE
assert.NoError(t.t, p.Header.SetExtension(t.extensionID, []byte("read")))
// write back a pli
rtcpWriter := t.rtcpWriter.Load().(interceptor.RTCPWriter)
pli := &rtcp.PictureLossIndication{SenderSSRC: info.SSRC, MediaSSRC: info.SSRC}
_, err = rtcpWriter.Write([]rtcp.Packet{pli}, make(interceptor.Attributes))
assert.NoError(t.t, err)
return p, attributes, nil
})
}
func (t *testInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader {
return interceptor.RTCPReaderFunc(func() ([]rtcp.Packet, interceptor.Attributes, error) {
pkts, attributes, err := reader.Read()
if err != nil {
return nil, nil, err
func (t *testInterceptor) BindRemoteStream(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) {
if a == nil {
a = interceptor.Attributes{}
}
t.lastRTCP.Store(pkts[0])
return pkts, attributes, nil
a.Set("attribute", "value")
return reader.Read(b, a)
})
}
func (t *testInterceptor) lastReadRTCP() rtcp.Packet {
p, _ := t.lastRTCP.Load().(rtcp.Packet)
return p
}
func (t *testInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter {
t.rtcpWriter.Store(writer)
return writer
}
// E2E test of the features of Interceptors
// * Assert an extension can be set on an outbound packet
// * Assert an extension can be read on an outbound packet
// * Assert that attributes set by an interceptor are returned to the Reader
func TestPeerConnection_Interceptor(t *testing.T) {
to := test.TimeOut(time.Second * 20)
defer to.Stop()
@@ -86,12 +54,12 @@ func TestPeerConnection_Interceptor(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
createPC := func(i interceptor.Interceptor) *PeerConnection {
createPC := func() *PeerConnection {
m := &MediaEngine{}
assert.NoError(t, m.RegisterDefaultCodecs())
ir := &interceptor.Registry{}
ir.Add(i)
ir.Add(&testInterceptor{t: t})
pc, err := NewAPI(WithMediaEngine(m), WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{})
assert.NoError(t, err)
@@ -99,75 +67,41 @@ func TestPeerConnection_Interceptor(t *testing.T) {
return pc
}
sendInterceptor := &testInterceptor{t: t, extensionID: 1}
senderPC := createPC(sendInterceptor)
receiverPC := createPC(&testInterceptor{t: t, extensionID: 2})
offerer := createPC()
answerer := createPC()
track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video", "pion")
assert.NoError(t, err)
sender, err := senderPC.AddTrack(track)
_, err = offerer.AddTrack(track)
assert.NoError(t, err)
pending := new(int32)
wg := &sync.WaitGroup{}
wg.Add(1)
*pending++
receiverPC.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) {
p, readErr := track.ReadRTP()
seenRTP, seenRTPCancel := context.WithCancel(context.Background())
answerer.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) {
p, attributes, readErr := track.ReadRTP()
assert.NoError(t, readErr)
assert.Equal(t, p.Extension, true)
assert.Equal(t, "write", string(p.GetExtension(1)))
assert.Equal(t, "read", string(p.GetExtension(2)))
atomic.AddInt32(pending, -1)
wg.Done()
assert.Equal(t, "foo", string(p.GetExtension(2)))
assert.Equal(t, "value", attributes.Get("attribute"))
for {
if _, readErr = track.ReadRTP(); readErr != nil {
return
}
}
seenRTPCancel()
})
wg.Add(1)
*pending++
go func() {
_, readErr := sender.ReadRTCP()
assert.NoError(t, readErr)
atomic.AddInt32(pending, -1)
wg.Done()
assert.NoError(t, signalPair(offerer, answerer))
func() {
ticker := time.NewTicker(time.Millisecond * 20)
for {
if _, readErr = sender.ReadRTCP(); readErr != nil {
select {
case <-seenRTP.Done():
return
case <-ticker.C:
assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
}
}
}()
assert.NoError(t, signalPair(senderPC, receiverPC))
wg.Add(1)
go func() {
defer wg.Done()
for {
time.Sleep(time.Millisecond * 100)
assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
if atomic.LoadInt32(pending) == 0 {
return
}
}
}()
wg.Wait()
assert.NoError(t, senderPC.Close())
assert.NoError(t, receiverPC.Close())
pli, _ := sendInterceptor.lastReadRTCP().(*rtcp.PictureLossIndication)
if pli == nil || pli.SenderSSRC == 0 {
t.Errorf("pli not found by send interceptor")
}
assert.NoError(t, offerer.Close())
assert.NoError(t, answerer.Close())
}

View File

@@ -1,27 +0,0 @@
// +build !js
package webrtc
import (
"sync/atomic"
"github.com/pion/interceptor"
"github.com/pion/rtp"
)
type interceptorTrackLocalWriter struct {
TrackLocalWriter
rtpWriter atomic.Value
}
func (i *interceptorTrackLocalWriter) setRTPWriter(writer interceptor.RTPWriter) {
i.rtpWriter.Store(writer)
}
func (i *interceptorTrackLocalWriter) WriteRTP(header *rtp.Header, payload []byte) (int, error) {
if writer, ok := i.rtpWriter.Load().(interceptor.RTPWriter); ok && writer != nil {
return writer.Write(&rtp.Packet{Header: *header, Payload: payload}, make(interceptor.Attributes))
}
return 0, nil
}

View File

@@ -1152,7 +1152,6 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece
receiver.Track().kind = receiver.kind
receiver.Track().codec = params.Codecs[0]
receiver.Track().params = params
receiver.Track().bindInterceptor()
receiver.Track().mu.Unlock()
pc.onTrack(receiver.Track(), receiver)

View File

@@ -105,7 +105,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
}()
go func() {
_, routineErr := receiver.Read(make([]byte, 1400))
_, _, routineErr := receiver.Read(make([]byte, 1400))
if routineErr != nil {
awaitRTCPReceiverRecv <- routineErr
} else {
@@ -115,7 +115,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
haveClosedAwaitRTPRecv := false
for {
p, routineErr := track.ReadRTP()
p, _, routineErr := track.ReadRTP()
if routineErr != nil {
close(awaitRTPRecvClosed)
return
@@ -168,7 +168,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
}()
go func() {
if _, routineErr := sender.Read(make([]byte, 1400)); routineErr == nil {
if _, _, routineErr := sender.Read(make([]byte, 1400)); routineErr == nil {
close(awaitRTCPSenderRecv)
}
}()
@@ -688,11 +688,11 @@ func TestRtpSenderReceiver_ReadClose_Error(t *testing.T) {
sender, receiver := tr.Sender(), tr.Receiver()
assert.NoError(t, sender.Stop())
_, err = sender.Read(make([]byte, 0, 1400))
_, _, err = sender.Read(make([]byte, 0, 1400))
assert.Error(t, err, io.ErrClosedPipe)
assert.NoError(t, receiver.Stop())
_, err = receiver.Read(make([]byte, 0, 1400))
_, _, err = receiver.Read(make([]byte, 0, 1400))
assert.Error(t, err, io.ErrClosedPipe)
assert.NoError(t, pc.Close())

View File

@@ -360,7 +360,7 @@ func TestPeerConnection_Renegotiation_CodecChange(t *testing.T) {
pcAnswer.OnTrack(func(track *TrackRemote, r *RTPReceiver) {
tracksCh <- track
for {
if _, readErr := track.ReadRTP(); readErr == io.EOF {
if _, _, readErr := track.ReadRTP(); readErr == io.EOF {
tracksClosed <- struct{}{}
return
}
@@ -450,7 +450,7 @@ func TestPeerConnection_Renegotiation_RemoveTrack(t *testing.T) {
onTrackFiredFunc()
for {
if _, err := track.ReadRTP(); err == io.EOF {
if _, _, err := track.ReadRTP(); err == io.EOF {
trackClosedFunc()
return
}

View File

@@ -10,14 +10,19 @@ import (
"github.com/pion/interceptor"
"github.com/pion/rtcp"
"github.com/pion/srtp/v2"
"github.com/pion/webrtc/v3/internal/util"
)
// trackStreams maintains a mapping of RTP/RTCP streams to a specific track
// a RTPReceiver may contain multiple streams if we are dealing with Multicast
type trackStreams struct {
track *TrackRemote
track *TrackRemote
rtpReadStream *srtp.ReadStreamSRTP
rtcpReadStream *srtp.ReadStreamSRTCP
rtpInterceptor interceptor.RTPReader
rtcpReadStream *srtp.ReadStreamSRTCP
rtcpInterceptor interceptor.RTCPReader
}
// RTPReceiver allows an application to inspect the receipt of a TrackRemote
@@ -32,8 +37,6 @@ type RTPReceiver struct {
// A reference to the associated api object
api *API
interceptorRTCPReader interceptor.RTCPReader
}
// NewRTPReceiver constructs a new RTPReceiver
@@ -50,7 +53,6 @@ func (api *API) NewRTPReceiver(kind RTPCodecType, transport *DTLSTransport) (*RT
received: make(chan interface{}),
tracks: []trackStreams{},
}
r.interceptorRTCPReader = api.interceptor.BindRTCPReader(interceptor.RTCPReaderFunc(r.readRTCP))
return r, nil
}
@@ -115,8 +117,7 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error {
}
var err error
t.rtpReadStream, t.rtcpReadStream, err = r.streamsForSSRC(parameters.Encodings[0].SSRC)
if err != nil {
if t.rtpReadStream, t.rtpInterceptor, t.rtcpReadStream, t.rtcpInterceptor, err = r.streamsForSSRC(parameters.Encodings[0].SSRC, interceptor.StreamInfo{}); err != nil {
return err
}
@@ -138,41 +139,35 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error {
}
// Read reads incoming RTCP for this RTPReceiver
func (r *RTPReceiver) Read(b []byte) (n int, err error) {
func (r *RTPReceiver) Read(b []byte) (n int, a interceptor.Attributes, err error) {
select {
case <-r.received:
return r.tracks[0].rtcpReadStream.Read(b)
return r.tracks[0].rtcpInterceptor.Read(b, a)
case <-r.closed:
return 0, io.ErrClosedPipe
return 0, nil, io.ErrClosedPipe
}
}
// ReadSimulcast reads incoming RTCP for this RTPReceiver for given rid
func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, err error) {
func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, a interceptor.Attributes, err error) {
select {
case <-r.received:
for _, t := range r.tracks {
if t.track != nil && t.track.rid == rid {
return t.rtcpReadStream.Read(b)
return t.rtcpInterceptor.Read(b, a)
}
}
return 0, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid)
return 0, nil, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid)
case <-r.closed:
return 0, io.ErrClosedPipe
return 0, nil, io.ErrClosedPipe
}
}
// ReadRTCP is a convenience method that wraps Read and unmarshal for you.
// It also runs any configured interceptors.
func (r *RTPReceiver) ReadRTCP() ([]rtcp.Packet, error) {
pkts, _, err := r.interceptorRTCPReader.Read()
return pkts, err
}
// ReadRTCP is a convenience method that wraps Read and unmarshal for you
func (r *RTPReceiver) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) {
func (r *RTPReceiver) ReadRTCP() ([]rtcp.Packet, interceptor.Attributes, error) {
b := make([]byte, receiveMTU)
i, err := r.Read(b)
i, attributes, err := r.Read(b)
if err != nil {
return nil, nil, err
}
@@ -182,18 +177,19 @@ func (r *RTPReceiver) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error)
return nil, nil, err
}
return pkts, make(interceptor.Attributes), nil
return pkts, attributes, nil
}
// ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you
func (r *RTPReceiver) ReadSimulcastRTCP(rid string) ([]rtcp.Packet, error) {
func (r *RTPReceiver) ReadSimulcastRTCP(rid string) ([]rtcp.Packet, interceptor.Attributes, error) {
b := make([]byte, receiveMTU)
i, err := r.ReadSimulcast(b, rid)
i, attributes, err := r.ReadSimulcast(b, rid)
if err != nil {
return nil, err
return nil, nil, err
}
return rtcp.Unmarshal(b[:i])
pkts, err := rtcp.Unmarshal(b[:i])
return pkts, attributes, err
}
func (r *RTPReceiver) haveReceived() bool {
@@ -209,32 +205,34 @@ func (r *RTPReceiver) haveReceived() bool {
func (r *RTPReceiver) Stop() error {
r.mu.Lock()
defer r.mu.Unlock()
var err error
select {
case <-r.closed:
return nil
return err
default:
}
select {
case <-r.received:
for i := range r.tracks {
errs := []error{}
if r.tracks[i].rtcpReadStream != nil {
if err := r.tracks[i].rtcpReadStream.Close(); err != nil {
return err
}
errs = append(errs, r.tracks[i].rtcpReadStream.Close())
}
if r.tracks[i].rtpReadStream != nil {
if err := r.tracks[i].rtpReadStream.Close(); err != nil {
return err
}
errs = append(errs, r.tracks[i].rtpReadStream.Close())
}
err = util.FlattenErrs(errs)
}
default:
}
close(r.closed)
return nil
return err
}
func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams {
@@ -247,13 +245,13 @@ func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams {
}
// readRTP should only be called by a track, this only exists so we can keep state in one place
func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, err error) {
func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, a interceptor.Attributes, err error) {
<-r.received
if t := r.streamsForTrack(reader); t != nil {
return t.rtpReadStream.Read(b)
return t.rtpInterceptor.Read(b, a)
}
return 0, fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC())
return 0, nil, fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC())
}
// receiveForRid is the sibling of Receive expect for RIDs instead of SSRCs
@@ -269,12 +267,11 @@ func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, ssrc SSRC)
r.tracks[i].track.codec = params.Codecs[0]
r.tracks[i].track.params = params
r.tracks[i].track.ssrc = ssrc
r.tracks[i].track.bindInterceptor()
streamInfo := createStreamInfo("", ssrc, params.Codecs[0].PayloadType, params.Codecs[0].RTPCodecCapability, params.HeaderExtensions)
r.tracks[i].track.mu.Unlock()
var err error
r.tracks[i].rtpReadStream, r.tracks[i].rtcpReadStream, err = r.streamsForSSRC(ssrc)
if err != nil {
if r.tracks[0].rtpReadStream, r.tracks[0].rtpInterceptor, r.tracks[0].rtcpReadStream, r.tracks[0].rtcpInterceptor, err = r.streamsForSSRC(ssrc, streamInfo); err != nil {
return nil, err
}
@@ -285,26 +282,36 @@ func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, ssrc SSRC)
return nil, fmt.Errorf("%w: %d", errRTPReceiverForSSRCTrackStreamNotFound, ssrc)
}
func (r *RTPReceiver) streamsForSSRC(ssrc SSRC) (*srtp.ReadStreamSRTP, *srtp.ReadStreamSRTCP, error) {
func (r *RTPReceiver) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) {
srtpSession, err := r.transport.getSRTPSession()
if err != nil {
return nil, nil, err
return nil, nil, nil, nil, err
}
rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc))
if err != nil {
return nil, nil, err
return nil, nil, nil, nil, err
}
rtpInterceptor := r.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = rtpReadStream.Read(in)
return n, a, err
}))
srtcpSession, err := r.transport.getSRTCPSession()
if err != nil {
return nil, nil, err
return nil, nil, nil, nil, err
}
rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc))
if err != nil {
return nil, nil, err
return nil, nil, nil, nil, err
}
return rtpReadStream, rtcpReadStream, nil
rtcpInterceptor := r.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = rtcpReadStream.Read(in)
return n, a, err
}))
return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil
}

View File

@@ -16,8 +16,10 @@ import (
type RTPSender struct {
track TrackLocal
srtpStream *srtpWriterFuture
context TrackLocalContext
srtpStream *srtpWriterFuture
rtcpInterceptor interceptor.RTCPReader
context TrackLocalContext
transport *DTLSTransport
@@ -36,8 +38,6 @@ type RTPSender struct {
mu sync.RWMutex
sendCalled, stopCalled chan struct{}
interceptorRTCPReader interceptor.RTCPReader
}
// NewRTPSender constructs a new RTPSender
@@ -64,9 +64,13 @@ func (api *API) NewRTPSender(track TrackLocal, transport *DTLSTransport) (*RTPSe
srtpStream: &srtpWriterFuture{},
}
r.interceptorRTCPReader = api.interceptor.BindRTCPReader(interceptor.RTCPReaderFunc(r.readRTCP))
r.srtpStream.rtpSender = r
r.rtcpInterceptor = r.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = r.srtpStream.Read(in)
return n, a, err
}))
return r, nil
}
@@ -156,8 +160,7 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error {
return errRTPSenderSendAlreadyCalled
}
writeStream := &interceptorTrackLocalWriter{TrackLocalWriter: r.srtpStream}
writeStream := &interceptorToTrackLocalWriter{}
r.context = TrackLocalContext{
id: r.id,
params: r.api.mediaEngine.getRTPParametersByKind(r.track.Kind(), []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}),
@@ -171,33 +174,11 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error {
}
r.context.params.Codecs = []RTPCodecParameters{codec}
headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(r.context.params.HeaderExtensions))
for _, h := range r.context.params.HeaderExtensions {
headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI})
}
feedbacks := make([]interceptor.RTCPFeedback, 0, len(codec.RTCPFeedback))
for _, f := range codec.RTCPFeedback {
feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter})
}
info := &interceptor.StreamInfo{
ID: r.context.id,
Attributes: interceptor.Attributes{},
SSRC: uint32(r.context.ssrc),
PayloadType: uint8(codec.PayloadType),
RTPHeaderExtensions: headerExtensions,
MimeType: codec.MimeType,
ClockRate: codec.ClockRate,
Channels: codec.Channels,
SDPFmtpLine: codec.SDPFmtpLine,
RTCPFeedback: feedbacks,
}
writeStream.setRTPWriter(
r.api.interceptor.BindLocalStream(
info,
interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) {
return r.srtpStream.WriteRTP(&p.Header, p.Payload)
}),
))
streamInfo := createStreamInfo(r.id, parameters.Encodings[0].SSRC, codec.PayloadType, codec.RTPCodecCapability, parameters.HeaderExtensions)
rtpInterceptor := r.api.interceptor.BindLocalStream(&streamInfo, interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
return r.srtpStream.WriteRTP(header, payload)
}))
writeStream.interceptor.Store(rtpInterceptor)
close(r.sendCalled)
return nil
@@ -227,25 +208,19 @@ func (r *RTPSender) Stop() error {
}
// Read reads incoming RTCP for this RTPReceiver
func (r *RTPSender) Read(b []byte) (n int, err error) {
func (r *RTPSender) Read(b []byte) (n int, a interceptor.Attributes, err error) {
select {
case <-r.sendCalled:
return r.srtpStream.Read(b)
return r.rtcpInterceptor.Read(b, a)
case <-r.stopCalled:
return 0, io.ErrClosedPipe
return 0, nil, io.ErrClosedPipe
}
}
// ReadRTCP is a convenience method that wraps Read and unmarshals for you.
// It also runs any configured interceptors.
func (r *RTPSender) ReadRTCP() ([]rtcp.Packet, error) {
pkts, _, err := r.interceptorRTCPReader.Read()
return pkts, err
}
func (r *RTPSender) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) {
func (r *RTPSender) ReadRTCP() ([]rtcp.Packet, interceptor.Attributes, error) {
b := make([]byte, receiveMTU)
i, err := r.Read(b)
i, attributes, err := r.Read(b)
if err != nil {
return nil, nil, err
}
@@ -255,7 +230,7 @@ func (r *RTPSender) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) {
return nil, nil, err
}
return pkts, make(interceptor.Attributes), nil
return pkts, attributes, nil
}
// hasSent tells if data has been ever sent for this instance

View File

@@ -50,7 +50,7 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) {
assert.Equal(t, uint64(1), atomic.AddUint64(&onTrackCount, 1))
for {
pkt, err := track.ReadRTP()
pkt, _, err := track.ReadRTP()
if err != nil {
assert.True(t, errors.Is(io.EOF, err))
return

View File

@@ -23,46 +23,18 @@ type TrackRemote struct {
params RTPParameters
rid string
receiver *RTPReceiver
peeked []byte
interceptorRTPReader interceptor.RTPReader
receiver *RTPReceiver
peeked []byte
peekedAttributes interceptor.Attributes
}
func newTrackRemote(kind RTPCodecType, ssrc SSRC, rid string, receiver *RTPReceiver) *TrackRemote {
t := &TrackRemote{
return &TrackRemote{
kind: kind,
ssrc: ssrc,
rid: rid,
receiver: receiver,
}
t.interceptorRTPReader = interceptor.RTPReaderFunc(t.readRTP)
return t
}
func (t *TrackRemote) bindInterceptor() {
headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(t.params.HeaderExtensions))
for _, h := range t.params.HeaderExtensions {
headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI})
}
feedbacks := make([]interceptor.RTCPFeedback, 0, len(t.codec.RTCPFeedback))
for _, f := range t.codec.RTCPFeedback {
feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter})
}
info := &interceptor.StreamInfo{
ID: t.id,
Attributes: interceptor.Attributes{},
SSRC: uint32(t.ssrc),
PayloadType: uint8(t.payloadType),
RTPHeaderExtensions: headerExtensions,
MimeType: t.codec.MimeType,
ClockRate: t.codec.ClockRate,
Channels: t.codec.Channels,
SDPFmtpLine: t.codec.SDPFmtpLine,
RTCPFeedback: feedbacks,
}
t.interceptorRTPReader = t.receiver.api.interceptor.BindRemoteStream(info, interceptor.RTPReaderFunc(t.readRTP))
}
// ID is the unique identifier for this Track. This should be unique for the
@@ -125,7 +97,7 @@ func (t *TrackRemote) Codec() RTPCodecParameters {
}
// Read reads data from the track.
func (t *TrackRemote) Read(b []byte) (n int, err error) {
func (t *TrackRemote) Read(b []byte) (n int, attributes interceptor.Attributes, err error) {
t.mu.RLock()
r := t.receiver
peeked := t.peeked != nil
@@ -134,7 +106,10 @@ func (t *TrackRemote) Read(b []byte) (n int, err error) {
if peeked {
t.mu.Lock()
data := t.peeked
attributes = t.peekedAttributes
t.peeked = nil
t.peekedAttributes = nil
t.mu.Unlock()
// someone else may have stolen our packet when we
// released the lock. Deal with it.
@@ -147,34 +122,10 @@ func (t *TrackRemote) Read(b []byte) (n int, err error) {
return r.readRTP(b, t)
}
// peek is like Read, but it doesn't discard the packet read
func (t *TrackRemote) peek(b []byte) (n int, err error) {
n, err = t.Read(b)
if err != nil {
return
}
t.mu.Lock()
// this might overwrite data if somebody peeked between the Read
// and us getting the lock. Oh well, we'll just drop a packet in
// that case.
data := make([]byte, n)
n = copy(data, b[:n])
t.peeked = data
t.mu.Unlock()
return
}
// ReadRTP is a convenience method that wraps Read and unmarshals for you.
// It also runs any configured interceptors.
func (t *TrackRemote) ReadRTP() (*rtp.Packet, error) {
p, _, err := t.interceptorRTPReader.Read()
return p, err
}
func (t *TrackRemote) readRTP() (*rtp.Packet, interceptor.Attributes, error) {
func (t *TrackRemote) ReadRTP() (*rtp.Packet, interceptor.Attributes, error) {
b := make([]byte, receiveMTU)
i, err := t.Read(b)
i, attributes, err := t.Read(b)
if err != nil {
return nil, nil, err
}
@@ -183,14 +134,14 @@ func (t *TrackRemote) readRTP() (*rtp.Packet, interceptor.Attributes, error) {
if err := r.Unmarshal(b[:i]); err != nil {
return nil, nil, err
}
return r, interceptor.Attributes{}, nil
return r, attributes, nil
}
// determinePayloadType blocks and reads a single packet to determine the PayloadType for this Track
// this is useful because we can't announce it to the user until we know the payloadType
func (t *TrackRemote) determinePayloadType() error {
b := make([]byte, receiveMTU)
n, err := t.peek(b)
n, _, err := t.peek(b)
if err != nil {
return err
}
@@ -205,3 +156,22 @@ func (t *TrackRemote) determinePayloadType() error {
return nil
}
// peek is like Read, but it doesn't discard the packet read
func (t *TrackRemote) peek(b []byte) (n int, a interceptor.Attributes, err error) {
n, a, err = t.Read(b)
if err != nil {
return
}
t.mu.Lock()
// this might overwrite data if somebody peeked between the Read
// and us getting the lock. Oh well, we'll just drop a packet in
// that case.
data := make([]byte, n)
n = copy(data, b[:n])
t.peeked = data
t.peekedAttributes = a
t.mu.Unlock()
return
}