mirror of
https://github.com/pion/mediadevices.git
synced 2025-10-24 00:53:09 +08:00
Fix deadlock in track.Bind() (#466)
Occurs when read errors happen from a driver source during a call to track.Bind()
This commit is contained in:
17
track.go
17
track.go
@@ -83,6 +83,7 @@ type baseTrack struct {
|
|||||||
Source
|
Source
|
||||||
err error
|
err error
|
||||||
onErrorHandler func(error)
|
onErrorHandler func(error)
|
||||||
|
errMu sync.Mutex
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
endOnce sync.Once
|
endOnce sync.Once
|
||||||
kind MediaDeviceType
|
kind MediaDeviceType
|
||||||
@@ -129,10 +130,10 @@ func (track *baseTrack) RID() string {
|
|||||||
// OnEnded sets an error handler. When a track has been created and started, if an
|
// OnEnded sets an error handler. When a track has been created and started, if an
|
||||||
// error occurs, handler will get called with the error given to the parameter.
|
// error occurs, handler will get called with the error given to the parameter.
|
||||||
func (track *baseTrack) OnEnded(handler func(error)) {
|
func (track *baseTrack) OnEnded(handler func(error)) {
|
||||||
track.mu.Lock()
|
track.errMu.Lock()
|
||||||
track.onErrorHandler = handler
|
track.onErrorHandler = handler
|
||||||
err := track.err
|
err := track.err
|
||||||
track.mu.Unlock()
|
track.errMu.Unlock()
|
||||||
|
|
||||||
if err != nil && handler != nil {
|
if err != nil && handler != nil {
|
||||||
// Already errored.
|
// Already errored.
|
||||||
@@ -144,10 +145,10 @@ func (track *baseTrack) OnEnded(handler func(error)) {
|
|||||||
|
|
||||||
// onError is a callback when an error occurs
|
// onError is a callback when an error occurs
|
||||||
func (track *baseTrack) onError(err error) {
|
func (track *baseTrack) onError(err error) {
|
||||||
track.mu.Lock()
|
track.errMu.Lock()
|
||||||
track.err = err
|
track.err = err
|
||||||
handler := track.onErrorHandler
|
handler := track.onErrorHandler
|
||||||
track.mu.Unlock()
|
track.errMu.Unlock()
|
||||||
|
|
||||||
if handler != nil {
|
if handler != nil {
|
||||||
track.endOnce.Do(func() {
|
track.endOnce.Do(func() {
|
||||||
@@ -171,6 +172,14 @@ func (track *baseTrack) bind(ctx webrtc.TrackLocalContext, specializedTrack Trac
|
|||||||
for _, wantedCodec := range ctx.CodecParameters() {
|
for _, wantedCodec := range ctx.CodecParameters() {
|
||||||
logger.Debugf("trying to build %s rtp reader", wantedCodec.MimeType)
|
logger.Debugf("trying to build %s rtp reader", wantedCodec.MimeType)
|
||||||
encodedReader, err = specializedTrack.NewRTPReader(wantedCodec.MimeType, uint32(ctx.SSRC()), rtpOutboundMTU)
|
encodedReader, err = specializedTrack.NewRTPReader(wantedCodec.MimeType, uint32(ctx.SSRC()), rtpOutboundMTU)
|
||||||
|
|
||||||
|
track.errMu.Lock()
|
||||||
|
if track.err != nil {
|
||||||
|
err = track.err
|
||||||
|
encodedReader = nil
|
||||||
|
}
|
||||||
|
track.errMu.Unlock()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
selectedCodec = wantedCodec
|
selectedCodec = wantedCodec
|
||||||
break
|
break
|
||||||
|
@@ -2,14 +2,33 @@ package mediadevices
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/pion/interceptor"
|
|
||||||
"io"
|
"io"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/interceptor"
|
||||||
|
"github.com/pion/webrtc/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errExpected error = errors.New("an error")
|
||||||
|
|
||||||
|
type DummyBindTrack struct {
|
||||||
|
*baseTrack
|
||||||
|
}
|
||||||
|
|
||||||
|
func (track *DummyBindTrack) Bind(ctx webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) {
|
||||||
|
track.mu.Lock()
|
||||||
|
defer track.mu.Unlock()
|
||||||
|
|
||||||
|
track.onError(errExpected)
|
||||||
|
|
||||||
|
<-time.After(5 * time.Millisecond)
|
||||||
|
|
||||||
|
return webrtc.RTPCodecParameters{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestOnEnded(t *testing.T) {
|
func TestOnEnded(t *testing.T) {
|
||||||
errExpected := errors.New("an error")
|
|
||||||
|
|
||||||
t.Run("ErrorAfterRegister", func(t *testing.T) {
|
t.Run("ErrorAfterRegister", func(t *testing.T) {
|
||||||
tr := &baseTrack{}
|
tr := &baseTrack{}
|
||||||
@@ -54,6 +73,34 @@ func TestOnEnded(t *testing.T) {
|
|||||||
t.Error("Timeout")
|
t.Error("Timeout")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("ErrorDurringBind", func(t *testing.T) {
|
||||||
|
tr := &DummyBindTrack{
|
||||||
|
baseTrack: &baseTrack{
|
||||||
|
activePeerConnections: make(map[string]chan<- chan<- struct{}),
|
||||||
|
mu: sync.Mutex{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
called := make(chan error, 1)
|
||||||
|
tr.OnEnded(func(err error) {
|
||||||
|
called <- errExpected
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := tr.Bind(webrtc.TrackLocalContext{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-called:
|
||||||
|
if err != errExpected {
|
||||||
|
t.Errorf("Expected to receive error: %v, got: %v", errExpected, err)
|
||||||
|
}
|
||||||
|
case <-time.After(10 * time.Millisecond):
|
||||||
|
t.Error("Timeout")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type fakeRTCPReader struct {
|
type fakeRTCPReader struct {
|
||||||
|
Reference in New Issue
Block a user