client: fix race condition

This commit is contained in:
aler9
2021-11-12 14:38:48 +01:00
committed by Alessandro Ros
parent e7dbfa5eb1
commit 07b1fe6a05
2 changed files with 22 additions and 24 deletions

View File

@@ -654,6 +654,19 @@ func (c *Client) trySwitchingProtocol() error {
} }
func (c *Client) playRecordStart() { func (c *Client) playRecordStart() {
// allow writing
c.writeMutex.Lock()
c.writeFrameAllowed = true
c.writeMutex.Unlock()
// start UDP listeners
if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast {
for _, cct := range c.tracks {
cct.udpRTPListener.start()
cct.udpRTCPListener.start()
}
}
// start timers // start timers
if c.state == clientStatePlay { if c.state == clientStatePlay {
c.reportTimer = time.NewTimer(c.receiverReportPeriod) c.reportTimer = time.NewTimer(c.receiverReportPeriod)
@@ -676,19 +689,6 @@ func (c *Client) playRecordStart() {
c.reportTimer = time.NewTimer(c.senderReportPeriod) c.reportTimer = time.NewTimer(c.senderReportPeriod)
} }
// allow writing
c.writeMutex.Lock()
c.writeFrameAllowed = true
c.writeMutex.Unlock()
// start UDP listeners
if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast {
for _, cct := range c.tracks {
cct.udpRTPListener.start()
cct.udpRTCPListener.start()
}
}
// for some reason, SetReadDeadline() must always be called in the same // for some reason, SetReadDeadline() must always be called in the same
// goroutine, otherwise Read() freezes. // goroutine, otherwise Read() freezes.
// therefore, we disable the deadline and perform a check with a ticker. // therefore, we disable the deadline and perform a check with a ticker.
@@ -782,6 +782,11 @@ func (c *Client) playRecordClose() {
<-c.readerErr <-c.readerErr
} }
// stop timers
c.reportTimer = emptyTimer()
c.checkStreamTimer = emptyTimer()
c.keepaliveTimer = emptyTimer()
// stop UDP listeners // stop UDP listeners
if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast { if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast {
for _, cct := range c.tracks { for _, cct := range c.tracks {
@@ -794,11 +799,6 @@ func (c *Client) playRecordClose() {
c.writeMutex.Lock() c.writeMutex.Lock()
c.writeFrameAllowed = false c.writeFrameAllowed = false
c.writeMutex.Unlock() c.writeMutex.Unlock()
// stop timers
c.reportTimer = emptyTimer()
c.checkStreamTimer = emptyTimer()
c.keepaliveTimer = emptyTimer()
} }
func (c *Client) connOpen() error { func (c *Client) connOpen() error {

View File

@@ -394,7 +394,7 @@ func TestClientRead(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
counter := uint64(0) counter := 0
c := &Client{ c := &Client{
Transport: func() *Transport { Transport: func() *Transport {
@@ -415,8 +415,8 @@ func TestClientRead(t *testing.T) {
OnPacketRTP: func(c *Client, trackID int, payload []byte) { OnPacketRTP: func(c *Client, trackID int, payload []byte) {
// ignore multicast loopback // ignore multicast loopback
if transport == "multicast" { if transport == "multicast" {
add := atomic.AddUint64(&counter, 1) counter++
if add >= 2 { if counter >= 2 {
return return
} }
} }
@@ -424,7 +424,7 @@ func TestClientRead(t *testing.T) {
require.Equal(t, 0, trackID) require.Equal(t, 0, trackID)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload)
err = c.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08}) err := c.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08})
require.NoError(t, err) require.NoError(t, err)
}, },
} }
@@ -442,8 +442,6 @@ func TestClientRead(t *testing.T) {
<-frameRecv <-frameRecv
c.Close() c.Close()
<-done <-done
c.ReadFrames()
}) })
} }
} }