Bolting on send side simulcast

Introduces AddEncoding method in RTP sender to add simulcast encodings.

Added UTs for AddEncoding.
Also modified the Simulcast send test to use the new API.
This commit is contained in:
boks1971
2022-02-22 15:20:27 +05:30
committed by David Zhao
parent e2b8d4c1d7
commit 37e16a3b15
8 changed files with 399 additions and 101 deletions

View File

@@ -142,6 +142,9 @@ var (
// ErrRTPSenderNewTrackHasIncorrectKind indicates that the new track is of a different kind than the previous/original // ErrRTPSenderNewTrackHasIncorrectKind indicates that the new track is of a different kind than the previous/original
ErrRTPSenderNewTrackHasIncorrectKind = errors.New("new track must be of the same kind as previous") ErrRTPSenderNewTrackHasIncorrectKind = errors.New("new track must be of the same kind as previous")
// ErrRTPSenderNewTrackHasIncorrectEnvelope indicates that the new track has a different envelope than the previous/original
ErrRTPSenderNewTrackHasIncorrectEnvelope = errors.New("new track must have the same envelope as previous")
// ErrUnbindFailed indicates that a TrackLocal was not able to be unbind // ErrUnbindFailed indicates that a TrackLocal was not able to be unbind
ErrUnbindFailed = errors.New("failed to unbind TrackLocal from PeerConnection") ErrUnbindFailed = errors.New("failed to unbind TrackLocal from PeerConnection")
@@ -202,10 +205,16 @@ var (
errRTPReceiverWithSSRCTrackStreamNotFound = errors.New("unable to find stream for Track with SSRC") errRTPReceiverWithSSRCTrackStreamNotFound = errors.New("unable to find stream for Track with SSRC")
errRTPReceiverForRIDTrackStreamNotFound = errors.New("no trackStreams found for RID") errRTPReceiverForRIDTrackStreamNotFound = errors.New("no trackStreams found for RID")
errRTPSenderTrackNil = errors.New("Track must not be nil") errRTPSenderTrackNil = errors.New("Track must not be nil")
errRTPSenderDTLSTransportNil = errors.New("DTLSTransport must not be nil") errRTPSenderDTLSTransportNil = errors.New("DTLSTransport must not be nil")
errRTPSenderSendAlreadyCalled = errors.New("Send has already been called") errRTPSenderSendAlreadyCalled = errors.New("Send has already been called")
errRTPSenderTrackRemoved = errors.New("Sender Track has been removed or replaced to nil") errRTPSenderStopped = errors.New("Sender has already been stopped")
errRTPSenderTrackRemoved = errors.New("Sender Track has been removed or replaced to nil")
errRTPSenderRidNil = errors.New("Sender cannot add encoding as rid is empty")
errRTPSenderNoBaseEncoding = errors.New("Sender cannot add encoding as there is no base track")
errRTPSenderBaseEncodingMismatch = errors.New("Sender cannot add encoding as provided track does not match base track")
errRTPSenderRIDCollision = errors.New("Sender cannot encoding due to RID collision")
errRTPSenderNoTrackForRID = errors.New("Sender does not have track for RID")
errRTPTransceiverCannotChangeMid = errors.New("errRTPSenderTrackNil") errRTPTransceiverCannotChangeMid = errors.New("errRTPSenderTrackNil")
errRTPTransceiverSetSendingInvalidState = errors.New("invalid state change in RTPTransceiver.setSending") errRTPTransceiverSetSendingInvalidState = errors.New("invalid state change in RTPTransceiver.setSending")

View File

@@ -162,7 +162,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) {
go func() { go func() {
for { for {
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
if routineErr := pcOffer.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{SenderSSRC: uint32(sender.ssrc), MediaSSRC: uint32(sender.ssrc)}}); routineErr != nil { if routineErr := pcOffer.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{SenderSSRC: uint32(sender.trackEncodings[0].ssrc), MediaSSRC: uint32(sender.trackEncodings[0].ssrc)}}); routineErr != nil {
awaitRTCPSenderSend <- routineErr awaitRTCPSenderSend <- routineErr
} }
@@ -643,12 +643,12 @@ func TestAddTransceiverAddTrack_Reuse(t *testing.T) {
track1, sender1 := addTrack() track1, sender1 := addTrack()
assert.Equal(t, 1, len(pc.GetTransceivers())) assert.Equal(t, 1, len(pc.GetTransceivers()))
assert.Equal(t, sender1, tr.Sender()) assert.Equal(t, sender1, tr.Sender())
assert.Equal(t, track1, tr.Sender().track) assert.Equal(t, track1, tr.Sender().Track())
require.NoError(t, pc.RemoveTrack(sender1)) require.NoError(t, pc.RemoveTrack(sender1))
track2, _ := addTrack() track2, _ := addTrack()
assert.Equal(t, 1, len(pc.GetTransceivers())) assert.Equal(t, 1, len(pc.GetTransceivers()))
assert.Equal(t, track2, tr.Sender().track) assert.Equal(t, track2, tr.Sender().Track())
addTrack() addTrack()
assert.Equal(t, 2, len(pc.GetTransceivers())) assert.Equal(t, 2, len(pc.GetTransceivers()))
@@ -1256,23 +1256,47 @@ func TestPeerConnection_Simulcast(t *testing.T) {
pcOffer, pcAnswer, err := NewAPI(WithMediaEngine(m)).newPair(Configuration{}) pcOffer, pcAnswer, err := NewAPI(WithMediaEngine(m)).newPair(Configuration{})
assert.NoError(t, err) assert.NoError(t, err)
vp8Writer, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2") vp8WriterA, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID("a"))
assert.NoError(t, err) assert.NoError(t, err)
_, err = pcOffer.AddTrack(vp8Writer) sender, err := pcOffer.AddTrack(vp8WriterA)
assert.NoError(t, err)
assert.NotNil(t, sender)
vp8WriterB, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID("b"))
assert.NoError(t, err)
err = sender.AddEncoding(vp8WriterB)
assert.NoError(t, err)
vp8WriterC, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID("c"))
assert.NoError(t, err)
err = sender.AddEncoding(vp8WriterC)
assert.NoError(t, err) assert.NoError(t, err)
ridMap = map[string]int{} ridMap = map[string]int{}
pcAnswer.OnTrack(onTrackHandler) pcAnswer.OnTrack(onTrackHandler)
assert.NoError(t, signalPairWithModification(pcOffer, pcAnswer, func(sessionDescription string) string { parameters := sender.GetParameters()
sessionDescription = strings.Split(sessionDescription, "a=end-of-candidates\r\n")[0] assert.Equal(t, "a", parameters.Encodings[0].RID)
sessionDescription = filterSsrc(sessionDescription) assert.Equal(t, "b", parameters.Encodings[1].RID)
for _, rid := range rids { assert.Equal(t, "c", parameters.Encodings[2].RID)
sessionDescription += "a=" + sdpAttributeRid + ":" + rid + " send\r\n"
var midID, ridID, rsidID uint8
for _, extension := range parameters.HeaderExtensions {
switch extension.URI {
case sdp.SDESMidURI:
midID = uint8(extension.ID)
case sdp.SDESRTPStreamIDURI:
ridID = uint8(extension.ID)
case sdesRepairRTPStreamIDURI:
rsidID = uint8(extension.ID)
} }
return sessionDescription + "a=simulcast:send " + strings.Join(rids, ";") + "\r\n" }
})) assert.NotZero(t, midID)
assert.NotZero(t, ridID)
assert.NotZero(t, rsidID)
assert.NoError(t, signalPair(pcOffer, pcAnswer))
for sequenceNumber := uint16(0); !ridsFullfilled(); sequenceNumber++ { for sequenceNumber := uint16(0); !ridsFullfilled(); sequenceNumber++ {
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)
@@ -1284,17 +1308,26 @@ func TestPeerConnection_Simulcast(t *testing.T) {
SequenceNumber: sequenceNumber, SequenceNumber: sequenceNumber,
PayloadType: 96, PayloadType: 96,
} }
assert.NoError(t, header.SetExtension(1, []byte("0"))) assert.NoError(t, header.SetExtension(midID, []byte("0")))
// Send RSID for first 10 packets // Send RSID for first 10 packets
if sequenceNumber >= 10 { if sequenceNumber >= 10 {
assert.NoError(t, header.SetExtension(2, []byte(rid))) assert.NoError(t, header.SetExtension(ridID, []byte(rid)))
} else { } else {
assert.NoError(t, header.SetExtension(3, []byte(rid))) assert.NoError(t, header.SetExtension(rsidID, []byte(rid)))
header.SSRC += 10 header.SSRC += 10
} }
_, err := vp8Writer.bindings[0].writeStream.WriteRTP(header, []byte{0x00}) var writer *TrackLocalStaticRTP
switch rid {
case "a":
writer = vp8WriterA
case "b":
writer = vp8WriterB
case "c":
writer = vp8WriterC
}
_, err = writer.bindings[0].writeStream.WriteRTP(header, []byte{0x00})
assert.NoError(t, err) assert.NoError(t, err)
} }
} }

View File

@@ -371,7 +371,7 @@ func TestPeerConnection_Transceiver_Mid(t *testing.T) {
// Must have 3 media descriptions (2 video channels) // Must have 3 media descriptions (2 video channels)
assert.Equal(t, len(offer.parsed.MediaDescriptions), 2) assert.Equal(t, len(offer.parsed.MediaDescriptions), 2)
assert.True(t, sdpMidHasSsrc(offer, "0", sender1.ssrc), "Expected mid %q with ssrc %d, offer.SDP: %s", "0", sender1.ssrc, offer.SDP) assert.True(t, sdpMidHasSsrc(offer, "0", sender1.trackEncodings[0].ssrc), "Expected mid %q with ssrc %d, offer.SDP: %s", "0", sender1.trackEncodings[0].ssrc, offer.SDP)
// Remove first track, must keep same number of media // Remove first track, must keep same number of media
// descriptions and same track ssrc for mid 1 as previous // descriptions and same track ssrc for mid 1 as previous
@@ -382,7 +382,7 @@ func TestPeerConnection_Transceiver_Mid(t *testing.T) {
assert.Equal(t, len(offer.parsed.MediaDescriptions), 2) assert.Equal(t, len(offer.parsed.MediaDescriptions), 2)
assert.True(t, sdpMidHasSsrc(offer, "1", sender2.ssrc), "Expected mid %q with ssrc %d, offer.SDP: %s", "1", sender2.ssrc, offer.SDP) assert.True(t, sdpMidHasSsrc(offer, "1", sender2.trackEncodings[0].ssrc), "Expected mid %q with ssrc %d, offer.SDP: %s", "1", sender2.trackEncodings[0].ssrc, offer.SDP)
_, err = pcAnswer.CreateAnswer(nil) _, err = pcAnswer.CreateAnswer(nil)
assert.Equal(t, err, &rtcerr.InvalidStateError{Err: ErrIncorrectSignalingState}) assert.Equal(t, err, &rtcerr.InvalidStateError{Err: ErrIncorrectSignalingState})
@@ -402,8 +402,8 @@ func TestPeerConnection_Transceiver_Mid(t *testing.T) {
// We reuse the existing non-sending transceiver // We reuse the existing non-sending transceiver
assert.Equal(t, len(offer.parsed.MediaDescriptions), 2) assert.Equal(t, len(offer.parsed.MediaDescriptions), 2)
assert.True(t, sdpMidHasSsrc(offer, "0", sender3.ssrc), "Expected mid %q with ssrc %d, offer.sdp: %s", "0", sender3.ssrc, offer.SDP) assert.True(t, sdpMidHasSsrc(offer, "0", sender3.trackEncodings[0].ssrc), "Expected mid %q with ssrc %d, offer.sdp: %s", "0", sender3.trackEncodings[0].ssrc, offer.SDP)
assert.True(t, sdpMidHasSsrc(offer, "1", sender2.ssrc), "Expected mid %q with ssrc %d, offer.sdp: %s", "1", sender2.ssrc, offer.SDP) assert.True(t, sdpMidHasSsrc(offer, "1", sender2.trackEncodings[0].ssrc), "Expected mid %q with ssrc %d, offer.sdp: %s", "1", sender2.trackEncodings[0].ssrc, offer.SDP)
closePairNow(t, pcOffer, pcAnswer) closePairNow(t, pcOffer, pcAnswer)
} }

View File

@@ -4,6 +4,7 @@
package webrtc package webrtc
import ( import (
"fmt"
"io" "io"
"sync" "sync"
"time" "time"
@@ -12,23 +13,30 @@ import (
"github.com/pion/randutil" "github.com/pion/randutil"
"github.com/pion/rtcp" "github.com/pion/rtcp"
"github.com/pion/rtp" "github.com/pion/rtp"
"github.com/pion/webrtc/v3/internal/util"
) )
// RTPSender allows an application to control how a given Track is encoded and transmitted to a remote peer type trackEncoding struct {
type RTPSender struct {
track TrackLocal track TrackLocal
srtpStream *srtpWriterFuture srtpStream *srtpWriterFuture
rtcpInterceptor interceptor.RTCPReader rtcpInterceptor interceptor.RTCPReader
streamInfo interceptor.StreamInfo streamInfo interceptor.StreamInfo
context TrackLocalContext context TrackLocalContext
ssrc SSRC
}
// RTPSender allows an application to control how a given Track is encoded and transmitted to a remote peer
type RTPSender struct {
trackEncodings []*trackEncoding
transport *DTLSTransport transport *DTLSTransport
payloadType PayloadType payloadType PayloadType
kind RTPCodecType kind RTPCodecType
ssrc SSRC
// nolint:godox // nolint:godox
// TODO(sgotti) remove this when in future we'll avoid replacing // TODO(sgotti) remove this when in future we'll avoid replacing
@@ -60,23 +68,15 @@ func (api *API) NewRTPSender(track TrackLocal, transport *DTLSTransport) (*RTPSe
} }
r := &RTPSender{ r := &RTPSender{
track: track,
transport: transport, transport: transport,
api: api, api: api,
sendCalled: make(chan struct{}), sendCalled: make(chan struct{}),
stopCalled: make(chan struct{}), stopCalled: make(chan struct{}),
ssrc: SSRC(randutil.NewMathRandomGenerator().Uint32()),
id: id, id: id,
srtpStream: &srtpWriterFuture{},
kind: track.Kind(), kind: track.Kind(),
} }
r.srtpStream.rtpSender = r r.addEncoding(track)
r.rtcpInterceptor = r.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = r.srtpStream.Read(in)
return n, a, err
}))
return r, nil return r, nil
} }
@@ -108,24 +108,26 @@ func (r *RTPSender) Transport() *DTLSTransport {
} }
func (r *RTPSender) getParameters() RTPSendParameters { func (r *RTPSender) getParameters() RTPSendParameters {
var rid string var encodings []RTPEncodingParameters
if r.track != nil { for _, trackEncoding := range r.trackEncodings {
rid = r.track.RID() var rid string
if trackEncoding.track != nil {
rid = trackEncoding.track.RID()
}
encodings = append(encodings, RTPEncodingParameters{
RTPCodingParameters: RTPCodingParameters{
RID: rid,
SSRC: trackEncoding.ssrc,
PayloadType: r.payloadType,
},
})
} }
sendParameters := RTPSendParameters{ sendParameters := RTPSendParameters{
RTPParameters: r.api.mediaEngine.getRTPParametersByKind( RTPParameters: r.api.mediaEngine.getRTPParametersByKind(
r.kind, r.kind,
[]RTPTransceiverDirection{RTPTransceiverDirectionSendonly}, []RTPTransceiverDirection{RTPTransceiverDirectionSendonly},
), ),
Encodings: []RTPEncodingParameters{ Encodings: encodings,
{
RTPCodingParameters: RTPCodingParameters{
RID: rid,
SSRC: r.ssrc,
PayloadType: r.payloadType,
},
},
},
} }
if r.rtpTransceiver != nil { if r.rtpTransceiver != nil {
sendParameters.Codecs = r.rtpTransceiver.getCodecs() sendParameters.Codecs = r.rtpTransceiver.getCodecs()
@@ -143,11 +145,81 @@ func (r *RTPSender) GetParameters() RTPSendParameters {
return r.getParameters() return r.getParameters()
} }
// AddEncoding adds an encoding to RTPSender. Used by simulcast senders.
func (r *RTPSender) AddEncoding(track TrackLocal) error {
r.mu.Lock()
defer r.mu.Unlock()
if track == nil {
return errRTPSenderTrackNil
}
if track.RID() == "" {
return errRTPSenderRidNil
}
if r.hasStopped() {
return errRTPSenderStopped
}
if r.hasSent() {
return errRTPSenderSendAlreadyCalled
}
var refTrack TrackLocal
if len(r.trackEncodings) != 0 {
refTrack = r.trackEncodings[0].track
}
if refTrack == nil || refTrack.RID() == "" {
return errRTPSenderNoBaseEncoding
}
if refTrack.ID() != track.ID() || refTrack.StreamID() != track.StreamID() || refTrack.Kind() != track.Kind() {
return errRTPSenderBaseEncodingMismatch
}
for _, encoding := range r.trackEncodings {
if encoding.track == nil {
continue
}
if encoding.track.RID() == track.RID() {
return errRTPSenderRIDCollision
}
}
r.addEncoding(track)
return nil
}
func (r *RTPSender) addEncoding(track TrackLocal) {
ssrc := SSRC(randutil.NewMathRandomGenerator().Uint32())
trackEncoding := &trackEncoding{
track: track,
srtpStream: &srtpWriterFuture{ssrc: ssrc},
ssrc: ssrc,
}
trackEncoding.srtpStream.rtpSender = r
trackEncoding.rtcpInterceptor = r.api.interceptor.BindRTCPReader(
interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = trackEncoding.srtpStream.Read(in)
return n, a, err
}),
)
r.trackEncodings = append(r.trackEncodings, trackEncoding)
}
// Track returns the RTCRtpTransceiver track, or nil // Track returns the RTCRtpTransceiver track, or nil
func (r *RTPSender) Track() TrackLocal { func (r *RTPSender) Track() TrackLocal {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
return r.track
if len(r.trackEncodings) == 0 {
return nil
}
return r.trackEncodings[0].track
} }
// ReplaceTrack replaces the track currently being used as the sender's source with a new TrackLocal. // ReplaceTrack replaces the track currently being used as the sender's source with a new TrackLocal.
@@ -161,26 +233,38 @@ func (r *RTPSender) ReplaceTrack(track TrackLocal) error {
return ErrRTPSenderNewTrackHasIncorrectKind return ErrRTPSenderNewTrackHasIncorrectKind
} }
if r.hasSent() && r.track != nil { // cannot replace simulcast envelope
if err := r.track.Unbind(r.context); err != nil { if track != nil && len(r.trackEncodings) > 1 {
return ErrRTPSenderNewTrackHasIncorrectEnvelope
}
var replacedTrack TrackLocal
var context *TrackLocalContext
if len(r.trackEncodings) != 0 {
replacedTrack = r.trackEncodings[0].track
context = &r.trackEncodings[0].context
}
if r.hasSent() && replacedTrack != nil {
if err := replacedTrack.Unbind(*context); err != nil {
return err return err
} }
} }
if !r.hasSent() || track == nil { if !r.hasSent() || track == nil {
r.track = track r.trackEncodings[0].track = track
return nil return nil
} }
codec, err := track.Bind(TrackLocalContext{ codec, err := track.Bind(TrackLocalContext{
id: r.context.id, id: context.id,
params: r.api.mediaEngine.getRTPParametersByKind(track.Kind(), []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}), params: r.api.mediaEngine.getRTPParametersByKind(track.Kind(), []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}),
ssrc: r.context.ssrc, ssrc: context.ssrc,
writeStream: r.context.writeStream, writeStream: context.writeStream,
rtcpInterceptor: context.rtcpInterceptor,
}) })
if err != nil { if err != nil {
// Re-bind the original track // Re-bind the original track
if _, reBindErr := r.track.Bind(r.context); reBindErr != nil { if _, reBindErr := replacedTrack.Bind(*context); reBindErr != nil {
return reBindErr return reBindErr
} }
@@ -189,10 +273,10 @@ func (r *RTPSender) ReplaceTrack(track TrackLocal) error {
// Codec has changed // Codec has changed
if r.payloadType != codec.PayloadType { if r.payloadType != codec.PayloadType {
r.context.params.Codecs = []RTPCodecParameters{codec} context.params.Codecs = []RTPCodecParameters{codec}
} }
r.track = track r.trackEncodings[0].track = track
return nil return nil
} }
@@ -204,29 +288,42 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error {
switch { switch {
case r.hasSent(): case r.hasSent():
return errRTPSenderSendAlreadyCalled return errRTPSenderSendAlreadyCalled
case r.track == nil: case r.trackEncodings[0].track == nil:
return errRTPSenderTrackRemoved return errRTPSenderTrackRemoved
} }
writeStream := &interceptorToTrackLocalWriter{} for idx, trackEncoding := range r.trackEncodings {
r.context = TrackLocalContext{ writeStream := &interceptorToTrackLocalWriter{}
id: r.id, trackEncoding.context = TrackLocalContext{
params: r.api.mediaEngine.getRTPParametersByKind(r.track.Kind(), []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}), id: r.id,
ssrc: parameters.Encodings[0].SSRC, params: r.api.mediaEngine.getRTPParametersByKind(trackEncoding.track.Kind(), []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}),
writeStream: writeStream, ssrc: parameters.Encodings[idx].SSRC,
} writeStream: writeStream,
rtcpInterceptor: trackEncoding.rtcpInterceptor,
}
codec, err := r.track.Bind(r.context) codec, err := trackEncoding.track.Bind(trackEncoding.context)
if err != nil { if err != nil {
return err return err
} }
r.context.params.Codecs = []RTPCodecParameters{codec} trackEncoding.context.params.Codecs = []RTPCodecParameters{codec}
r.streamInfo = *createStreamInfo(r.id, parameters.Encodings[0].SSRC, codec.PayloadType, codec.RTPCodecCapability, parameters.HeaderExtensions) trackEncoding.streamInfo = *createStreamInfo(
rtpInterceptor := r.api.interceptor.BindLocalStream(&r.streamInfo, interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { r.id,
return r.srtpStream.WriteRTP(header, payload) parameters.Encodings[idx].SSRC,
})) codec.PayloadType,
writeStream.interceptor.Store(rtpInterceptor) codec.RTPCodecCapability,
parameters.HeaderExtensions,
)
srtpStream := trackEncoding.srtpStream
rtpInterceptor := r.api.interceptor.BindLocalStream(
&trackEncoding.streamInfo,
interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
return srtpStream.WriteRTP(header, payload)
}),
)
writeStream.interceptor.Store(rtpInterceptor)
}
close(r.sendCalled) close(r.sendCalled)
return nil return nil
@@ -252,16 +349,20 @@ func (r *RTPSender) Stop() error {
return err return err
} }
r.api.interceptor.UnbindLocalStream(&r.streamInfo) errs := []error{}
for _, trackEncoding := range r.trackEncodings {
r.api.interceptor.UnbindLocalStream(&trackEncoding.streamInfo)
errs = append(errs, trackEncoding.srtpStream.Close())
}
return r.srtpStream.Close() return util.FlattenErrs(errs)
} }
// Read reads incoming RTCP for this RTPReceiver // Read reads incoming RTCP for this RTPSender
func (r *RTPSender) Read(b []byte) (n int, a interceptor.Attributes, err error) { func (r *RTPSender) Read(b []byte) (n int, a interceptor.Attributes, err error) {
select { select {
case <-r.sendCalled: case <-r.sendCalled:
return r.rtcpInterceptor.Read(b, a) return r.trackEncodings[0].rtcpInterceptor.Read(b, a)
case <-r.stopCalled: case <-r.stopCalled:
return 0, nil, io.ErrClosedPipe return 0, nil, io.ErrClosedPipe
} }
@@ -283,10 +384,50 @@ func (r *RTPSender) ReadRTCP() ([]rtcp.Packet, interceptor.Attributes, error) {
return pkts, attributes, nil return pkts, attributes, nil
} }
// ReadSimulcast reads incoming RTCP for this RTPSender for given rid
func (r *RTPSender) ReadSimulcast(b []byte, rid string) (n int, a interceptor.Attributes, err error) {
select {
case <-r.sendCalled:
for _, t := range r.trackEncodings {
if t.track != nil && t.track.RID() == rid {
return t.rtcpInterceptor.Read(b, a)
}
}
return 0, nil, fmt.Errorf("%w: %s", errRTPSenderNoTrackForRID, rid)
case <-r.stopCalled:
return 0, nil, io.ErrClosedPipe
}
}
// ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you
func (r *RTPSender) ReadSimulcastRTCP(rid string) ([]rtcp.Packet, interceptor.Attributes, error) {
b := make([]byte, r.api.settingEngine.getReceiveMTU())
i, attributes, err := r.ReadSimulcast(b, rid)
if err != nil {
return nil, nil, err
}
pkts, err := rtcp.Unmarshal(b[:i])
return pkts, attributes, err
}
// SetReadDeadline sets the deadline for the Read operation. // SetReadDeadline sets the deadline for the Read operation.
// Setting to zero means no deadline. // Setting to zero means no deadline.
func (r *RTPSender) SetReadDeadline(t time.Time) error { func (r *RTPSender) SetReadDeadline(t time.Time) error {
return r.srtpStream.SetReadDeadline(t) return r.trackEncodings[0].srtpStream.SetReadDeadline(t)
}
// SetReadDeadlineSimulcast sets the max amount of time the RTCP stream for a given rid will block before returning. 0 is forever.
func (r *RTPSender) SetReadDeadlineSimulcast(deadline time.Time, rid string) error {
r.mu.RLock()
defer r.mu.RUnlock()
for _, t := range r.trackEncodings {
if t.track != nil && t.track.RID() == rid {
return t.srtpStream.SetReadDeadline(deadline)
}
}
return fmt.Errorf("%w: %s", errRTPSenderNoTrackForRID, rid)
} }
// hasSent tells if data has been ever sent for this instance // hasSent tells if data has been ever sent for this instance

View File

@@ -117,7 +117,7 @@ func Test_RTPSender_GetParameters(t *testing.T) {
parameters := rtpTransceiver.Sender().GetParameters() parameters := rtpTransceiver.Sender().GetParameters()
assert.NotEqual(t, 0, len(parameters.Codecs)) assert.NotEqual(t, 0, len(parameters.Codecs))
assert.Equal(t, 1, len(parameters.Encodings)) assert.Equal(t, 1, len(parameters.Encodings))
assert.Equal(t, rtpTransceiver.Sender().ssrc, parameters.Encodings[0].SSRC) assert.Equal(t, rtpTransceiver.Sender().trackEncodings[0].ssrc, parameters.Encodings[0].SSRC)
assert.Equal(t, "", parameters.Encodings[0].RID) assert.Equal(t, "", parameters.Encodings[0].RID)
closePairNow(t, offerer, answerer) closePairNow(t, offerer, answerer)
@@ -340,3 +340,64 @@ func Test_RTPSender_Send_Track_Removed(t *testing.T) {
assert.NoError(t, peerConnection.Close()) assert.NoError(t, peerConnection.Close())
} }
func Test_RTPSender_Add_Encoding(t *testing.T) {
track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
assert.NoError(t, err)
peerConnection, err := NewPeerConnection(Configuration{})
assert.NoError(t, err)
rtpSender, err := peerConnection.AddTrack(track)
assert.NoError(t, err)
assert.Equal(t, errRTPSenderTrackNil, rtpSender.AddEncoding(nil))
track1, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
assert.NoError(t, err)
assert.Equal(t, errRTPSenderRidNil, rtpSender.AddEncoding(track1))
track1, err = NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion", WithRTPStreamID("h"))
assert.NoError(t, err)
assert.Equal(t, errRTPSenderNoBaseEncoding, rtpSender.AddEncoding(track1))
track, err = NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion", WithRTPStreamID("q"))
assert.NoError(t, err)
rtpSender, err = peerConnection.AddTrack(track)
assert.NoError(t, err)
track1, err = NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video1", "pion", WithRTPStreamID("h"))
assert.NoError(t, err)
assert.Equal(t, errRTPSenderBaseEncodingMismatch, rtpSender.AddEncoding(track1))
track1, err = NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion1", WithRTPStreamID("h"))
assert.NoError(t, err)
assert.Equal(t, errRTPSenderBaseEncodingMismatch, rtpSender.AddEncoding(track1))
track1, err = NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeOpus}, "video", "pion", WithRTPStreamID("h"))
assert.NoError(t, err)
assert.Equal(t, errRTPSenderBaseEncodingMismatch, rtpSender.AddEncoding(track1))
track1, err = NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion", WithRTPStreamID("q"))
assert.NoError(t, err)
assert.Equal(t, errRTPSenderRIDCollision, rtpSender.AddEncoding(track1))
track1, err = NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion", WithRTPStreamID("h"))
assert.NoError(t, err)
assert.NoError(t, rtpSender.AddEncoding(track1))
err = rtpSender.Send(rtpSender.GetParameters())
assert.NoError(t, err)
track1, err = NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion", WithRTPStreamID("f"))
assert.NoError(t, err)
assert.Equal(t, errRTPSenderSendAlreadyCalled, rtpSender.AddEncoding(track1))
err = rtpSender.Stop()
assert.NoError(t, err)
assert.Equal(t, errRTPSenderStopped, rtpSender.AddEncoding(track1))
assert.NoError(t, peerConnection.Close())
}

66
sdp.go
View File

@@ -340,7 +340,60 @@ func populateLocalCandidates(sessionDescription *SessionDescription, i *ICEGathe
} }
} }
func addTransceiverSDP(d *sdp.SessionDescription, isPlanB, shouldAddCandidates bool, dtlsFingerprints []DTLSFingerprint, mediaEngine *MediaEngine, midValue string, iceParams ICEParameters, candidates []ICECandidate, dtlsRole sdp.ConnectionRole, iceGatheringState ICEGatheringState, mediaSection mediaSection) (bool, error) { func addSenderSDP(
mediaSection mediaSection,
isPlanB bool,
media *sdp.MediaDescription,
) {
for _, mt := range mediaSection.transceivers {
sender := mt.Sender()
if sender == nil {
continue
}
track := sender.Track()
if track == nil {
continue
}
sendParameters := sender.GetParameters()
for _, encoding := range sendParameters.Encodings {
media = media.WithMediaSource(uint32(encoding.SSRC), track.StreamID() /* cname */, track.StreamID() /* streamLabel */, track.ID())
if !isPlanB {
media = media.WithPropertyAttribute("msid:" + track.StreamID() + " " + track.ID())
}
}
if len(sendParameters.Encodings) > 1 {
sendRids := make([]string, 0, len(sendParameters.Encodings))
for _, encoding := range sendParameters.Encodings {
media.WithValueAttribute(sdpAttributeRid, encoding.RID+" send")
sendRids = append(sendRids, encoding.RID)
}
// Simulcast
media.WithValueAttribute("simulcast", "send "+strings.Join(sendRids, ";"))
}
if !isPlanB {
break
}
}
}
func addTransceiverSDP(
d *sdp.SessionDescription,
isPlanB bool,
shouldAddCandidates bool,
dtlsFingerprints []DTLSFingerprint,
mediaEngine *MediaEngine,
midValue string,
iceParams ICEParameters,
candidates []ICECandidate,
dtlsRole sdp.ConnectionRole,
iceGatheringState ICEGatheringState,
mediaSection mediaSection,
) (bool, error) {
transceivers := mediaSection.transceivers transceivers := mediaSection.transceivers
if len(transceivers) < 1 { if len(transceivers) < 1 {
return false, errSDPZeroTransceivers return false, errSDPZeroTransceivers
@@ -410,16 +463,7 @@ func addTransceiverSDP(d *sdp.SessionDescription, isPlanB, shouldAddCandidates b
media.WithValueAttribute("simulcast", "recv "+strings.Join(recvRids, ";")) media.WithValueAttribute("simulcast", "recv "+strings.Join(recvRids, ";"))
} }
for _, mt := range transceivers { addSenderSDP(mediaSection, isPlanB, media)
if sender := mt.Sender(); sender != nil && sender.Track() != nil {
track := sender.Track()
media = media.WithMediaSource(uint32(sender.ssrc), track.StreamID() /* cname */, track.StreamID() /* streamLabel */, track.ID())
if !isPlanB {
media = media.WithPropertyAttribute("msid:" + track.StreamID() + " " + track.ID())
break
}
}
}
media = media.WithPropertyAttribute(t.Direction().String()) media = media.WithPropertyAttribute(t.Direction().String())

View File

@@ -16,6 +16,7 @@ import (
// srtpWriterFuture blocks Read/Write calls until // srtpWriterFuture blocks Read/Write calls until
// the SRTP Session is available // the SRTP Session is available
type srtpWriterFuture struct { type srtpWriterFuture struct {
ssrc SSRC
rtpSender *RTPSender rtpSender *RTPSender
rtcpReadStream atomic.Value // *srtp.ReadStreamSRTCP rtcpReadStream atomic.Value // *srtp.ReadStreamSRTCP
rtpWriteStream atomic.Value // *srtp.WriteStreamSRTP rtpWriteStream atomic.Value // *srtp.WriteStreamSRTP
@@ -52,7 +53,7 @@ func (s *srtpWriterFuture) init(returnWhenNoSRTP bool) error {
return err return err
} }
rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(s.rtpSender.ssrc)) rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(s.ssrc))
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,6 +1,9 @@
package webrtc package webrtc
import "github.com/pion/rtp" import (
"github.com/pion/interceptor"
"github.com/pion/rtp"
)
// TrackLocalWriter is the Writer for outbound RTP Packets // TrackLocalWriter is the Writer for outbound RTP Packets
type TrackLocalWriter interface { type TrackLocalWriter interface {
@@ -14,10 +17,11 @@ type TrackLocalWriter interface {
// TrackLocalContext is the Context passed when a TrackLocal has been Binded/Unbinded from a PeerConnection, and used // TrackLocalContext is the Context passed when a TrackLocal has been Binded/Unbinded from a PeerConnection, and used
// in Interceptors. // in Interceptors.
type TrackLocalContext struct { type TrackLocalContext struct {
id string id string
params RTPParameters params RTPParameters
ssrc SSRC ssrc SSRC
writeStream TrackLocalWriter writeStream TrackLocalWriter
rtcpInterceptor interceptor.RTCPReader
} }
// CodecParameters returns the negotiated RTPCodecParameters. These are the codecs supported by both // CodecParameters returns the negotiated RTPCodecParameters. These are the codecs supported by both
@@ -49,6 +53,11 @@ func (t *TrackLocalContext) ID() string {
return t.id return t.id
} }
// RTCPReader returns the RTCP interceptor for this TrackLocal. Used to read RTCP of this TrackLocal.
func (t *TrackLocalContext) RTCPReader() interceptor.RTCPReader {
return t.rtcpInterceptor
}
// TrackLocal is an interface that controls how the user can send media // TrackLocal is an interface that controls how the user can send media
// The user can provide their own TrackLocal implementations, or use // The user can provide their own TrackLocal implementations, or use
// the implementations in pkg/media // the implementations in pkg/media