Fix deadlock in track.Bind() (#466)

Occurs when read errors happen from a
driver source during a call to track.Bind()
This commit is contained in:
Kyle
2023-01-22 17:00:20 -08:00
committed by GitHub
parent f8f8511d94
commit 5da0ebf443
2 changed files with 62 additions and 6 deletions

View File

@@ -83,6 +83,7 @@ type baseTrack struct {
Source
err error
onErrorHandler func(error)
errMu sync.Mutex
mu sync.Mutex
endOnce sync.Once
kind MediaDeviceType
@@ -129,10 +130,10 @@ func (track *baseTrack) RID() string {
// OnEnded sets an error handler. When a track has been created and started, if an
// error occurs, handler will get called with the error given to the parameter.
func (track *baseTrack) OnEnded(handler func(error)) {
track.mu.Lock()
track.errMu.Lock()
track.onErrorHandler = handler
err := track.err
track.mu.Unlock()
track.errMu.Unlock()
if err != nil && handler != nil {
// Already errored.
@@ -144,10 +145,10 @@ func (track *baseTrack) OnEnded(handler func(error)) {
// onError is a callback when an error occurs
func (track *baseTrack) onError(err error) {
track.mu.Lock()
track.errMu.Lock()
track.err = err
handler := track.onErrorHandler
track.mu.Unlock()
track.errMu.Unlock()
if handler != nil {
track.endOnce.Do(func() {
@@ -171,6 +172,14 @@ func (track *baseTrack) bind(ctx webrtc.TrackLocalContext, specializedTrack Trac
for _, wantedCodec := range ctx.CodecParameters() {
logger.Debugf("trying to build %s rtp reader", wantedCodec.MimeType)
encodedReader, err = specializedTrack.NewRTPReader(wantedCodec.MimeType, uint32(ctx.SSRC()), rtpOutboundMTU)
track.errMu.Lock()
if track.err != nil {
err = track.err
encodedReader = nil
}
track.errMu.Unlock()
if err == nil {
selectedCodec = wantedCodec
break

View File

@@ -2,14 +2,33 @@ package mediadevices
import (
"errors"
"github.com/pion/interceptor"
"io"
"sync"
"testing"
"time"
"github.com/pion/interceptor"
"github.com/pion/webrtc/v3"
)
var errExpected error = errors.New("an error")
type DummyBindTrack struct {
*baseTrack
}
func (track *DummyBindTrack) Bind(ctx webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) {
track.mu.Lock()
defer track.mu.Unlock()
track.onError(errExpected)
<-time.After(5 * time.Millisecond)
return webrtc.RTPCodecParameters{}, nil
}
func TestOnEnded(t *testing.T) {
errExpected := errors.New("an error")
t.Run("ErrorAfterRegister", func(t *testing.T) {
tr := &baseTrack{}
@@ -54,6 +73,34 @@ func TestOnEnded(t *testing.T) {
t.Error("Timeout")
}
})
t.Run("ErrorDurringBind", func(t *testing.T) {
tr := &DummyBindTrack{
baseTrack: &baseTrack{
activePeerConnections: make(map[string]chan<- chan<- struct{}),
mu: sync.Mutex{},
},
}
called := make(chan error, 1)
tr.OnEnded(func(err error) {
called <- errExpected
})
_, err := tr.Bind(webrtc.TrackLocalContext{})
if err != nil {
t.Fatal(err)
}
select {
case err := <-called:
if err != errExpected {
t.Errorf("Expected to receive error: %v, got: %v", errExpected, err)
}
case <-time.After(10 * time.Millisecond):
t.Error("Timeout")
}
})
}
type fakeRTCPReader struct {