diff --git a/rtpsender.go b/rtpsender.go index e8f0a2f1..8cfe97d3 100644 --- a/rtpsender.go +++ b/rtpsender.go @@ -13,13 +13,16 @@ import ( // RTPSender allows an application to control how a given Track is encoded and transmitted to a remote peer type RTPSender struct { - track TrackLocal + track TrackLocal + rtcpReadStream *srtp.ReadStreamSRTCP + rtpWriteStream *srtp.WriteStreamSRTP transport *DTLSTransport payloadType PayloadType ssrc SSRC + codec RTPCodecParameters // nolint:godox // TODO(sgotti) remove this when in future we'll avoid replacing @@ -86,10 +89,39 @@ func (r *RTPSender) Track() TrackLocal { return r.track } -func (r *RTPSender) setTrack(track TrackLocal) { +// ReplaceTrack replaces the track currently being used as the sender's source with a new TrackLocal. +// The new track must be of the same media kind (audio, video, etc) and switching the track should not +// require negotiation. +func (r *RTPSender) ReplaceTrack(track TrackLocal) error { r.mu.Lock() defer r.mu.Unlock() + + if r.hasSent() { + if err := r.track.Unbind(TrackLocalContext{ + id: r.id, + ssrc: r.ssrc, + writeStream: r.rtpWriteStream, + }); err != nil { + return err + } + } + + if !r.hasSent() || track == nil { + r.track = track + return nil + } + + if _, err := track.Bind(TrackLocalContext{ + id: r.id, + codecs: []RTPCodecParameters{r.codec}, + ssrc: r.ssrc, + writeStream: r.rtpWriteStream, + }); err != nil { + return err + } + r.track = track + return nil } // Send Attempts to set the parameters controlling the sending of media. @@ -116,16 +148,15 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error { return err } - rtpWriteStream, err := srtpSession.OpenWriteStream() - if err != nil { + if r.rtpWriteStream, err = srtpSession.OpenWriteStream(); err != nil { return err } - if err = r.track.Bind(TrackLocalContext{ + if r.codec, err = r.track.Bind(TrackLocalContext{ id: r.id, codecs: r.api.mediaEngine.getCodecsByKind(r.track.Kind()), ssrc: parameters.Encodings.SSRC, - writeStream: rtpWriteStream, + writeStream: r.rtpWriteStream, }); err != nil { return err } @@ -150,14 +181,6 @@ func (r *RTPSender) Stop() error { return nil } - if err := r.track.Unbind(TrackLocalContext{ - id: r.id, - codecs: r.api.mediaEngine.getCodecsByKind(r.track.Kind()), - ssrc: r.ssrc, - }); err != nil { - return err - } - return r.rtcpReadStream.Close() } diff --git a/rtpsender_test.go b/rtpsender_test.go new file mode 100644 index 00000000..99a13d42 --- /dev/null +++ b/rtpsender_test.go @@ -0,0 +1,136 @@ +// +build !js + +package webrtc + +import ( + "bytes" + "context" + "errors" + "io" + "sync/atomic" + "testing" + "time" + + "github.com/pion/transport/test" + "github.com/pion/webrtc/v3/pkg/media" + "github.com/stretchr/testify/assert" +) + +func Test_RTPSender_ReplaceTrack(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + t.Run("Basic", func(t *testing.T) { + s := SettingEngine{} + s.DisableSRTPReplayProtection(true) + + m := &MediaEngine{} + assert.NoError(t, m.RegisterDefaultCodecs()) + + sender, receiver, err := NewAPI(WithMediaEngine(m), WithSettingEngine(s)).newPair(Configuration{}) + assert.NoError(t, err) + + trackA, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video", "pion") + assert.NoError(t, err) + + trackB, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video", "pion") + assert.NoError(t, err) + + rtpSender, err := sender.AddTrack(trackA) + assert.NoError(t, err) + + seenPacketA, seenPacketACancel := context.WithCancel(context.Background()) + seenPacketB, seenPacketBCancel := context.WithCancel(context.Background()) + + var onTrackCount uint64 + receiver.OnTrack(func(track *TrackRemote, _ *RTPReceiver) { + assert.Equal(t, uint64(1), atomic.AddUint64(&onTrackCount, 1)) + + for { + pkt, err := track.ReadRTP() + if err != nil { + assert.True(t, errors.Is(io.EOF, err)) + return + } + + switch { + case bytes.Equal(pkt.Payload, []byte{0x10, 0xAA}): + seenPacketACancel() + case bytes.Equal(pkt.Payload, []byte{0x10, 0xBB}): + seenPacketBCancel() + } + } + }) + + assert.NoError(t, signalPair(sender, receiver)) + + // Block Until packet with 0xAA has been seen + func() { + for range time.Tick(time.Millisecond * 20) { + select { + case <-seenPacketA.Done(): + return + default: + assert.NoError(t, trackA.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second})) + } + } + }() + + assert.NoError(t, rtpSender.ReplaceTrack(trackB)) + + // Block Until packet with 0xBB has been seen + func() { + for range time.Tick(time.Millisecond * 20) { + select { + case <-seenPacketB.Done(): + return + default: + assert.NoError(t, trackB.WriteSample(media.Sample{Data: []byte{0xBB}, Duration: time.Second})) + } + } + }() + + assert.NoError(t, sender.Close()) + assert.NoError(t, receiver.Close()) + }) + + t.Run("Invalid Codec Change", func(t *testing.T) { + sender, receiver, err := newPair() + assert.NoError(t, err) + + trackA, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video", "pion") + assert.NoError(t, err) + + trackB, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/h264", SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42001f"}, "video", "pion") + assert.NoError(t, err) + + rtpSender, err := sender.AddTrack(trackA) + assert.NoError(t, err) + + assert.NoError(t, signalPair(sender, receiver)) + + seenPacket, seenPacketCancel := context.WithCancel(context.Background()) + receiver.OnTrack(func(_ *TrackRemote, _ *RTPReceiver) { + seenPacketCancel() + }) + + func() { + for range time.Tick(time.Millisecond * 20) { + select { + case <-seenPacket.Done(): + return + default: + assert.NoError(t, trackA.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second})) + } + } + }() + + assert.True(t, errors.Is(rtpSender.ReplaceTrack(trackB), ErrUnsupportedCodec)) + + assert.NoError(t, sender.Close()) + assert.NoError(t, receiver.Close()) + }) +} diff --git a/rtptransceiver.go b/rtptransceiver.go index b5b8fd8e..b0c1f03d 100644 --- a/rtptransceiver.go +++ b/rtptransceiver.go @@ -101,7 +101,9 @@ func (t *RTPTransceiver) setDirection(d RTPTransceiverDirection) { } func (t *RTPTransceiver) setSendingTrack(track TrackLocal) error { - t.Sender().setTrack(track) + if err := t.Sender().ReplaceTrack(track); err != nil { + return err + } if track == nil { t.setSender(nil) } diff --git a/track_local.go b/track_local.go index 53913370..1b232b95 100644 --- a/track_local.go +++ b/track_local.go @@ -49,7 +49,7 @@ type TrackLocal interface { // Bind should implement the way how the media data flows from the Track to the PeerConnection // This will be called internally after signaling is complete and the list of available // codecs has been determined - Bind(TrackLocalContext) error + Bind(TrackLocalContext) (RTPCodecParameters, error) // Unbind should implement the teardown logic when the track is no longer needed. This happens // because a track has been stopped. diff --git a/track_local_static.go b/track_local_static.go index 0ed8fc67..297f256c 100644 --- a/track_local_static.go +++ b/track_local_static.go @@ -43,7 +43,7 @@ func NewTrackLocalStaticRTP(c RTPCodecCapability, id, streamID string) (*TrackLo // Bind is called by the PeerConnection after negotiation is complete // This asserts that the code requested is supported by the remote peer. // If so it setups all the state (SSRC and PayloadType) to have a call -func (s *TrackLocalStaticRTP) Bind(t TrackLocalContext) error { +func (s *TrackLocalStaticRTP) Bind(t TrackLocalContext) (RTPCodecParameters, error) { s.mu.Lock() defer s.mu.Unlock() @@ -55,10 +55,10 @@ func (s *TrackLocalStaticRTP) Bind(t TrackLocalContext) error { writeStream: t.WriteStream(), id: t.ID(), }) - return nil + return codec, nil } - return ErrUnsupportedCodec + return RTPCodecParameters{}, ErrUnsupportedCodec } // Unbind implements the teardown logic when the track is no longer needed. This happens @@ -165,9 +165,10 @@ func (s *TrackLocalStaticSample) Kind() RTPCodecType { return s.rtpTrack.Kind() // Bind is called by the PeerConnection after negotiation is complete // This asserts that the code requested is supported by the remote peer. // If so it setups all the state (SSRC and PayloadType) to have a call -func (s *TrackLocalStaticSample) Bind(t TrackLocalContext) error { - if err := s.rtpTrack.Bind(t); err != nil { - return err +func (s *TrackLocalStaticSample) Bind(t TrackLocalContext) (RTPCodecParameters, error) { + codec, err := s.rtpTrack.Bind(t) + if err != nil { + return codec, err } s.rtpTrack.mu.Lock() @@ -175,18 +176,12 @@ func (s *TrackLocalStaticSample) Bind(t TrackLocalContext) error { // We only need one packetizer if s.packetizer != nil { - return nil + return codec, nil } - parameters := RTPCodecParameters{RTPCodecCapability: s.rtpTrack.codec} - codec, err := codecParametersFuzzySearch(parameters, t.CodecParameters()) + payloader, err := payloaderForCodec(codec.RTPCodecCapability) if err != nil { - return err - } - - payloader, err := payloaderForCodec(s.rtpTrack.codec) - if err != nil { - return err + return codec, err } s.packetizer = rtp.NewPacketizer( @@ -198,7 +193,7 @@ func (s *TrackLocalStaticSample) Bind(t TrackLocalContext) error { codec.ClockRate, ) s.clockRate = float64(codec.RTPCodecCapability.ClockRate) - return nil + return codec, nil } // Unbind implements the teardown logic when the track is no longer needed. This happens