From 3ba7c373b98ad21147aa7d45928aa065e8054a9a Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Fri, 12 Nov 2021 16:22:02 +0100 Subject: [PATCH] client: allow receiving UDP packets before PLAY response --- client.go | 108 +++++++++++++++++++++++++++++--------------- client_read_test.go | 14 +++--- 2 files changed, 78 insertions(+), 44 deletions(-) diff --git a/client.go b/client.go index 11519d0c..f3c890fe 100644 --- a/client.go +++ b/client.go @@ -597,7 +597,15 @@ func (c *Client) run() { func (c *Client) doClose(isClosing bool) { if c.state == clientStatePlay || c.state == clientStateRecord { - c.playRecordClose(isClosing) + if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast { + // stop UDP listeners + for _, cct := range c.tracks { + cct.udpRTPListener.stop() + cct.udpRTCPListener.stop() + } + } + + c.playRecordStop(isClosing) c.do(&base.Request{ Method: base.Teardown, @@ -686,14 +694,6 @@ func (c *Client) playRecordStart() { c.writeFrameAllowed = true c.writeMutex.Unlock() - // start UDP listeners - if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast { - for _, cct := range c.tracks { - cct.udpRTPListener.start() - cct.udpRTCPListener.start() - } - } - // start timers if c.state == clientStatePlay { c.reportTimer = time.NewTimer(c.receiverReportPeriod) @@ -801,7 +801,7 @@ func (c *Client) runReader() error { } } -func (c *Client) playRecordClose(isClosing bool) { +func (c *Client) playRecordStop(isClosing bool) { // stop reader if c.readerErr != nil { c.nconn.SetReadDeadline(time.Now()) @@ -813,14 +813,6 @@ func (c *Client) playRecordClose(isClosing bool) { c.checkStreamTimer = emptyTimer() c.keepaliveTimer = emptyTimer() - // stop UDP listeners - if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast { - for _, cct := range c.tracks { - cct.udpRTPListener.stop() - cct.udpRTCPListener.stop() - } - } - // forbid writing c.writeMutex.Lock() c.writeFrameAllowed = false @@ -1536,9 +1528,21 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp return nil, err } - // open the firewall by sending packets to the counterpart. - // do this before sending the PLAY request. - if *c.protocol == TransportUDP { + if c.OnPlay != nil { + c.OnPlay(c) + } + + c.state = clientStatePlay + + // setup UDP communication before sending the request. + if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast { + // start UDP listeners + for _, cct := range c.tracks { + cct.udpRTPListener.start() + cct.udpRTCPListener.start() + } + + // open the firewall by sending packets to the counterpart. for _, cct := range c.tracks { cct.udpRTPListener.write( []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) @@ -1548,10 +1552,6 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp } } - if c.OnPlay != nil { - c.OnPlay(c) - } - header := make(base.Header) // Range is mandatory in Parrot Streaming Server @@ -1574,12 +1574,21 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp } if res.StatusCode != base.StatusOK { + if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast { + // stop UDP listeners + for _, cct := range c.tracks { + cct.udpRTPListener.stop() + cct.udpRTCPListener.stop() + } + } + + c.state = clientStatePrePlay + return nil, liberrors.ErrClientBadStatusCode{ Code: res.StatusCode, Message: res.StatusMessage, } } - c.state = clientStatePlay c.lastRange = ra c.playRecordStart() @@ -1609,6 +1618,16 @@ func (c *Client) doRecord() (*base.Response, error) { return nil, err } + c.state = clientStateRecord + + if *c.protocol == TransportUDP { + // start UDP listeners + for _, cct := range c.tracks { + cct.udpRTPListener.start() + cct.udpRTCPListener.start() + } + } + res, err := c.do(&base.Request{ Method: base.Record, URL: c.streamBaseURL, @@ -1618,13 +1637,21 @@ func (c *Client) doRecord() (*base.Response, error) { } if res.StatusCode != base.StatusOK { + if *c.protocol == TransportUDP { + // stop UDP listeners + for _, cct := range c.tracks { + cct.udpRTPListener.stop() + cct.udpRTCPListener.stop() + } + } + + c.state = clientStatePreRecord + return nil, liberrors.ErrClientBadStatusCode{ Code: res.StatusCode, Message: res.StatusMessage, } } - c.state = clientStateRecord - c.playRecordStart() return nil, nil @@ -1653,7 +1680,23 @@ func (c *Client) doPause() (*base.Response, error) { return nil, err } - c.playRecordClose(false) + c.playRecordStop(false) + + if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast { + // stop UDP listeners + for _, cct := range c.tracks { + cct.udpRTPListener.stop() + cct.udpRTCPListener.stop() + } + } + + // change state regardless of the response + switch c.state { + case clientStatePlay: + c.state = clientStatePrePlay + case clientStateRecord: + c.state = clientStatePreRecord + } res, err := c.do(&base.Request{ Method: base.Pause, @@ -1669,13 +1712,6 @@ func (c *Client) doPause() (*base.Response, error) { } } - switch c.state { - case clientStatePlay: - c.state = clientStatePrePlay - case clientStateRecord: - c.state = clientStatePreRecord - } - return res, nil } diff --git a/client_read_test.go b/client_read_test.go index 5a3d786e..c58b69ae 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -360,14 +360,12 @@ func TestClientRead(t *testing.T) { // client -> server (RTCP) switch transport { case "udp", "multicast": - if transport == "udp" { - // skip firewall opening - buf := make([]byte, 2048) - _, _, err := l2.ReadFrom(buf) - require.NoError(t, err) - } - + // skip firewall opening buf := make([]byte, 2048) + _, _, err := l2.ReadFrom(buf) + require.NoError(t, err) + + buf = make([]byte, 2048) n, _, err := l2.ReadFrom(buf) require.NoError(t, err) require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, buf[:n]) @@ -416,7 +414,7 @@ func TestClientRead(t *testing.T) { // ignore multicast loopback if transport == "multicast" { counter++ - if counter >= 2 { + if counter <= 1 || counter >= 3 { return } }