client: fix race condition

This commit is contained in:
aler9
2021-11-12 14:38:48 +01:00
committed by Alessandro Ros
parent e7dbfa5eb1
commit 07b1fe6a05
2 changed files with 22 additions and 24 deletions

View File

@@ -654,6 +654,19 @@ func (c *Client) trySwitchingProtocol() error {
}
func (c *Client) playRecordStart() {
// allow writing
c.writeMutex.Lock()
c.writeFrameAllowed = true
c.writeMutex.Unlock()
// start UDP listeners
if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast {
for _, cct := range c.tracks {
cct.udpRTPListener.start()
cct.udpRTCPListener.start()
}
}
// start timers
if c.state == clientStatePlay {
c.reportTimer = time.NewTimer(c.receiverReportPeriod)
@@ -676,19 +689,6 @@ func (c *Client) playRecordStart() {
c.reportTimer = time.NewTimer(c.senderReportPeriod)
}
// allow writing
c.writeMutex.Lock()
c.writeFrameAllowed = true
c.writeMutex.Unlock()
// start UDP listeners
if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast {
for _, cct := range c.tracks {
cct.udpRTPListener.start()
cct.udpRTCPListener.start()
}
}
// for some reason, SetReadDeadline() must always be called in the same
// goroutine, otherwise Read() freezes.
// therefore, we disable the deadline and perform a check with a ticker.
@@ -782,6 +782,11 @@ func (c *Client) playRecordClose() {
<-c.readerErr
}
// stop timers
c.reportTimer = emptyTimer()
c.checkStreamTimer = emptyTimer()
c.keepaliveTimer = emptyTimer()
// stop UDP listeners
if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast {
for _, cct := range c.tracks {
@@ -794,11 +799,6 @@ func (c *Client) playRecordClose() {
c.writeMutex.Lock()
c.writeFrameAllowed = false
c.writeMutex.Unlock()
// stop timers
c.reportTimer = emptyTimer()
c.checkStreamTimer = emptyTimer()
c.keepaliveTimer = emptyTimer()
}
func (c *Client) connOpen() error {

View File

@@ -394,7 +394,7 @@ func TestClientRead(t *testing.T) {
require.NoError(t, err)
}()
counter := uint64(0)
counter := 0
c := &Client{
Transport: func() *Transport {
@@ -415,8 +415,8 @@ func TestClientRead(t *testing.T) {
OnPacketRTP: func(c *Client, trackID int, payload []byte) {
// ignore multicast loopback
if transport == "multicast" {
add := atomic.AddUint64(&counter, 1)
if add >= 2 {
counter++
if counter >= 2 {
return
}
}
@@ -424,7 +424,7 @@ func TestClientRead(t *testing.T) {
require.Equal(t, 0, trackID)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload)
err = c.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08})
err := c.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08})
require.NoError(t, err)
},
}
@@ -442,8 +442,6 @@ func TestClientRead(t *testing.T) {
<-frameRecv
c.Close()
<-done
c.ReadFrames()
})
}
}