fix race condition in tests (#3834)

This commit is contained in:
Alessandro Ros
2024-10-05 21:54:11 +02:00
committed by GitHub
parent 534b637bc7
commit 2586782031
10 changed files with 51 additions and 32 deletions

View File

@@ -76,5 +76,7 @@ func TestFromStreamSkipUnsupportedTracks(t *testing.T) {
err = FromStream(stream, l, m) err = FromStream(stream, l, m)
require.NoError(t, err) require.NoError(t, err)
defer stream.RemoveReader(l)
require.Equal(t, 2, n) require.Equal(t, 2, n)
} }

View File

@@ -64,5 +64,7 @@ func TestFromStreamSkipUnsupportedTracks(t *testing.T) {
err = FromStream(stream, l, nil, nil, 0) err = FromStream(stream, l, nil, nil, 0)
require.NoError(t, err) require.NoError(t, err)
defer stream.RemoveReader(l)
require.Equal(t, 1, n) require.Equal(t, 1, n)
} }

View File

@@ -78,5 +78,7 @@ func TestFromStreamSkipUnsupportedTracks(t *testing.T) {
err = FromStream(stream, l, conn, nil, 0) err = FromStream(stream, l, conn, nil, 0)
require.NoError(t, err) require.NoError(t, err)
defer stream.RemoveReader(l)
require.Equal(t, 2, n) require.Equal(t, 2, n)
} }

View File

@@ -66,6 +66,8 @@ func TestFromStreamSkipUnsupportedTracks(t *testing.T) {
err = FromStream(stream, l, pc) err = FromStream(stream, l, pc)
require.NoError(t, err) require.NoError(t, err)
defer stream.RemoveReader(l)
require.Equal(t, 1, n) require.Equal(t, 1, n)
} }
@@ -93,6 +95,7 @@ func TestFromStream(t *testing.T) {
err = FromStream(stream, nil, pc) err = FromStream(stream, nil, pc)
require.NoError(t, err) require.NoError(t, err)
defer stream.RemoveReader(nil)
require.Equal(t, ca.webrtcCaps, pc.OutgoingTracks[0].Caps) require.Equal(t, ca.webrtcCaps, pc.OutgoingTracks[0].Caps)
}) })

View File

@@ -305,7 +305,7 @@ func TestServerRead(t *testing.T) {
s.PathReady(&dummyPath{}) s.PathReady(&dummyPath{})
time.Sleep(100 * time.Millisecond) str.WaitRunningReader()
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
str.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{ str.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{
@@ -398,7 +398,7 @@ func TestServerReadAuthorizationHeader(t *testing.T) {
s.PathReady(&dummyPath{}) s.PathReady(&dummyPath{})
time.Sleep(100 * time.Millisecond) str.WaitRunningReader()
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
str.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{ str.WriteUnit(test.MediaH264, test.FormatH264, &unit.H264{

View File

@@ -163,14 +163,14 @@ func TestServerPublish(t *testing.T) {
return nil return nil
}) })
path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)
err = w.WriteH264(0, 0, true, [][]byte{ err = w.WriteH264(0, 0, true, [][]byte{
{5, 2, 3, 4}, {5, 2, 3, 4},
}) })
require.NoError(t, err) require.NoError(t, err)
path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)
<-recv <-recv
}) })
} }
@@ -250,6 +250,8 @@ func TestServerRead(t *testing.T) {
videoTrack, _ := r.Tracks() videoTrack, _ := r.Tracks()
require.Equal(t, test.FormatH264, videoTrack) require.Equal(t, test.FormatH264, videoTrack)
stream.WaitRunningReader()
stream.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], &unit.H264{ stream.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], &unit.H264{
Base: unit.Base{ Base: unit.Base{
NTP: time.Time{}, NTP: time.Time{},

View File

@@ -150,6 +150,9 @@ func TestServerPublish(t *testing.T) {
return nil return nil
}) })
path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)
err = source.WritePacketRTP(media0, &rtp.Packet{ err = source.WritePacketRTP(media0, &rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: 2, Version: 2,
@@ -163,9 +166,6 @@ func TestServerPublish(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)
<-recv <-recv
} }

View File

@@ -156,6 +156,9 @@ func TestServerPublish(t *testing.T) {
return nil return nil
}) })
path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)
err = w.WriteH264(track, 0, 0, true, [][]byte{ err = w.WriteH264(track, 0, 0, true, [][]byte{
{5, 2}, {5, 2},
}) })
@@ -164,9 +167,6 @@ func TestServerPublish(t *testing.T) {
err = bw.Flush() err = bw.Flush()
require.NoError(t, err) require.NoError(t, err)
path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)
<-recv <-recv
} }
@@ -219,6 +219,8 @@ func TestServerRead(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer reader.Close() defer reader.Close()
stream.WaitRunningReader()
stream.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], &unit.H264{ stream.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], &unit.H264{
Base: unit.Base{ Base: unit.Base{
NTP: time.Time{}, NTP: time.Time{},

View File

@@ -340,6 +340,9 @@ func TestServerPublish(t *testing.T) {
return nil return nil
}) })
path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)
err = track.WriteRTP(&rtp.Packet{ err = track.WriteRTP(&rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: 2, Version: 2,
@@ -353,9 +356,6 @@ func TestServerPublish(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
path.stream.StartReader(reader)
defer path.stream.RemoveReader(reader)
<-recv <-recv
} }
@@ -572,19 +572,11 @@ func TestServerRead(t *testing.T) {
} }
writerDone := make(chan struct{}) writerDone := make(chan struct{})
defer func() { <-writerDone }()
writerTerminate := make(chan struct{})
defer close(writerTerminate)
go func() { go func() {
defer close(writerDone) defer close(writerDone)
for {
select { str.WaitRunningReader()
case <-time.After(100 * time.Millisecond):
case <-writerTerminate:
return
}
r := reflect.New(reflect.TypeOf(ca.unit).Elem()) r := reflect.New(reflect.TypeOf(ca.unit).Elem())
r.Elem().Set(reflect.ValueOf(ca.unit).Elem()) r.Elem().Set(reflect.ValueOf(ca.unit).Elem())
@@ -595,7 +587,6 @@ func TestServerRead(t *testing.T) {
} else { } else {
str.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], r.Interface().(unit.Unit)) str.WriteUnit(desc.Medias[0], desc.Medias[0].Formats[0], r.Interface().(unit.Unit))
} }
}
}() }()
tracks, err := wc.Read(context.Background()) tracks, err := wc.Read(context.Background())
@@ -615,6 +606,7 @@ func TestServerRead(t *testing.T) {
wc.StartReading() wc.StartReading()
<-writerDone
<-done <-done
}) })
} }

View File

@@ -36,6 +36,8 @@ type Stream struct {
rtspStream *gortsplib.ServerStream rtspStream *gortsplib.ServerStream
rtspsStream *gortsplib.ServerStream rtspsStream *gortsplib.ServerStream
streamReaders map[Reader]*streamReader streamReaders map[Reader]*streamReader
readerRunning chan struct{}
} }
// New allocates a Stream. // New allocates a Stream.
@@ -55,6 +57,7 @@ func New(
s.streamMedias = make(map[*description.Media]*streamMedia) s.streamMedias = make(map[*description.Media]*streamMedia)
s.streamReaders = make(map[Reader]*streamReader) s.streamReaders = make(map[Reader]*streamReader)
s.readerRunning = make(chan struct{})
for _, media := range desc.Medias { for _, media := range desc.Medias {
var err error var err error
@@ -180,6 +183,12 @@ func (s *Stream) StartReader(reader Reader) {
sf.startReader(sr) sf.startReader(sr)
} }
} }
select {
case <-s.readerRunning:
default:
close(s.readerRunning)
}
} }
// ReaderError returns whenever there's an error. // ReaderError returns whenever there's an error.
@@ -209,6 +218,11 @@ func (s *Stream) ReaderFormats(reader Reader) []format.Format {
return formats return formats
} }
// WaitRunningReader waits for a running reader.
func (s *Stream) WaitRunningReader() {
<-s.readerRunning
}
// WriteUnit writes a Unit. // WriteUnit writes a Unit.
func (s *Stream) WriteUnit(medi *description.Media, forma format.Format, u unit.Unit) { func (s *Stream) WriteUnit(medi *description.Media, forma format.Format, u unit.Unit) {
sm := s.streamMedias[medi] sm := s.streamMedias[medi]