From 07b1fe6a05174a6a773eea5a1fcedac94d5cc09f Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Fri, 12 Nov 2021 14:38:48 +0100 Subject: [PATCH] client: fix race condition --- client.go | 36 ++++++++++++++++++------------------ client_read_test.go | 10 ++++------ 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/client.go b/client.go index e1b04768..4ad1fadf 100644 --- a/client.go +++ b/client.go @@ -654,6 +654,19 @@ func (c *Client) trySwitchingProtocol() error { } func (c *Client) playRecordStart() { + // allow writing + c.writeMutex.Lock() + c.writeFrameAllowed = true + c.writeMutex.Unlock() + + // start UDP listeners + if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast { + for _, cct := range c.tracks { + cct.udpRTPListener.start() + cct.udpRTCPListener.start() + } + } + // start timers if c.state == clientStatePlay { c.reportTimer = time.NewTimer(c.receiverReportPeriod) @@ -676,19 +689,6 @@ func (c *Client) playRecordStart() { c.reportTimer = time.NewTimer(c.senderReportPeriod) } - // allow writing - c.writeMutex.Lock() - c.writeFrameAllowed = true - c.writeMutex.Unlock() - - // start UDP listeners - if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast { - for _, cct := range c.tracks { - cct.udpRTPListener.start() - cct.udpRTCPListener.start() - } - } - // for some reason, SetReadDeadline() must always be called in the same // goroutine, otherwise Read() freezes. // therefore, we disable the deadline and perform a check with a ticker. @@ -782,6 +782,11 @@ func (c *Client) playRecordClose() { <-c.readerErr } + // stop timers + c.reportTimer = emptyTimer() + c.checkStreamTimer = emptyTimer() + c.keepaliveTimer = emptyTimer() + // stop UDP listeners if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast { for _, cct := range c.tracks { @@ -794,11 +799,6 @@ func (c *Client) playRecordClose() { c.writeMutex.Lock() c.writeFrameAllowed = false c.writeMutex.Unlock() - - // stop timers - c.reportTimer = emptyTimer() - c.checkStreamTimer = emptyTimer() - c.keepaliveTimer = emptyTimer() } func (c *Client) connOpen() error { diff --git a/client_read_test.go b/client_read_test.go index d7793202..fc9a3b06 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -394,7 +394,7 @@ func TestClientRead(t *testing.T) { require.NoError(t, err) }() - counter := uint64(0) + counter := 0 c := &Client{ Transport: func() *Transport { @@ -415,8 +415,8 @@ func TestClientRead(t *testing.T) { OnPacketRTP: func(c *Client, trackID int, payload []byte) { // ignore multicast loopback if transport == "multicast" { - add := atomic.AddUint64(&counter, 1) - if add >= 2 { + counter++ + if counter >= 2 { return } } @@ -424,7 +424,7 @@ func TestClientRead(t *testing.T) { require.Equal(t, 0, trackID) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) - err = c.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08}) + err := c.WritePacketRTCP(0, []byte{0x05, 0x06, 0x07, 0x08}) require.NoError(t, err) }, } @@ -442,8 +442,6 @@ func TestClientRead(t *testing.T) { <-frameRecv c.Close() <-done - - c.ReadFrames() }) } }