mirror of
https://github.com/aler9/gortsplib
synced 2025-10-05 15:16:51 +08:00
client: allow receiving UDP packets before PLAY response
This commit is contained in:
108
client.go
108
client.go
@@ -597,7 +597,15 @@ func (c *Client) run() {
|
|||||||
|
|
||||||
func (c *Client) doClose(isClosing bool) {
|
func (c *Client) doClose(isClosing bool) {
|
||||||
if c.state == clientStatePlay || c.state == clientStateRecord {
|
if c.state == clientStatePlay || c.state == clientStateRecord {
|
||||||
c.playRecordClose(isClosing)
|
if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast {
|
||||||
|
// stop UDP listeners
|
||||||
|
for _, cct := range c.tracks {
|
||||||
|
cct.udpRTPListener.stop()
|
||||||
|
cct.udpRTCPListener.stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.playRecordStop(isClosing)
|
||||||
|
|
||||||
c.do(&base.Request{
|
c.do(&base.Request{
|
||||||
Method: base.Teardown,
|
Method: base.Teardown,
|
||||||
@@ -686,14 +694,6 @@ func (c *Client) playRecordStart() {
|
|||||||
c.writeFrameAllowed = true
|
c.writeFrameAllowed = true
|
||||||
c.writeMutex.Unlock()
|
c.writeMutex.Unlock()
|
||||||
|
|
||||||
// start UDP listeners
|
|
||||||
if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast {
|
|
||||||
for _, cct := range c.tracks {
|
|
||||||
cct.udpRTPListener.start()
|
|
||||||
cct.udpRTCPListener.start()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// start timers
|
// start timers
|
||||||
if c.state == clientStatePlay {
|
if c.state == clientStatePlay {
|
||||||
c.reportTimer = time.NewTimer(c.receiverReportPeriod)
|
c.reportTimer = time.NewTimer(c.receiverReportPeriod)
|
||||||
@@ -801,7 +801,7 @@ func (c *Client) runReader() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) playRecordClose(isClosing bool) {
|
func (c *Client) playRecordStop(isClosing bool) {
|
||||||
// stop reader
|
// stop reader
|
||||||
if c.readerErr != nil {
|
if c.readerErr != nil {
|
||||||
c.nconn.SetReadDeadline(time.Now())
|
c.nconn.SetReadDeadline(time.Now())
|
||||||
@@ -813,14 +813,6 @@ func (c *Client) playRecordClose(isClosing bool) {
|
|||||||
c.checkStreamTimer = emptyTimer()
|
c.checkStreamTimer = emptyTimer()
|
||||||
c.keepaliveTimer = emptyTimer()
|
c.keepaliveTimer = emptyTimer()
|
||||||
|
|
||||||
// stop UDP listeners
|
|
||||||
if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast {
|
|
||||||
for _, cct := range c.tracks {
|
|
||||||
cct.udpRTPListener.stop()
|
|
||||||
cct.udpRTCPListener.stop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// forbid writing
|
// forbid writing
|
||||||
c.writeMutex.Lock()
|
c.writeMutex.Lock()
|
||||||
c.writeFrameAllowed = false
|
c.writeFrameAllowed = false
|
||||||
@@ -1536,9 +1528,21 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// open the firewall by sending packets to the counterpart.
|
if c.OnPlay != nil {
|
||||||
// do this before sending the PLAY request.
|
c.OnPlay(c)
|
||||||
if *c.protocol == TransportUDP {
|
}
|
||||||
|
|
||||||
|
c.state = clientStatePlay
|
||||||
|
|
||||||
|
// setup UDP communication before sending the request.
|
||||||
|
if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast {
|
||||||
|
// start UDP listeners
|
||||||
|
for _, cct := range c.tracks {
|
||||||
|
cct.udpRTPListener.start()
|
||||||
|
cct.udpRTCPListener.start()
|
||||||
|
}
|
||||||
|
|
||||||
|
// open the firewall by sending packets to the counterpart.
|
||||||
for _, cct := range c.tracks {
|
for _, cct := range c.tracks {
|
||||||
cct.udpRTPListener.write(
|
cct.udpRTPListener.write(
|
||||||
[]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
|
[]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
|
||||||
@@ -1548,10 +1552,6 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.OnPlay != nil {
|
|
||||||
c.OnPlay(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
header := make(base.Header)
|
header := make(base.Header)
|
||||||
|
|
||||||
// Range is mandatory in Parrot Streaming Server
|
// Range is mandatory in Parrot Streaming Server
|
||||||
@@ -1574,12 +1574,21 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp
|
|||||||
}
|
}
|
||||||
|
|
||||||
if res.StatusCode != base.StatusOK {
|
if res.StatusCode != base.StatusOK {
|
||||||
|
if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast {
|
||||||
|
// stop UDP listeners
|
||||||
|
for _, cct := range c.tracks {
|
||||||
|
cct.udpRTPListener.stop()
|
||||||
|
cct.udpRTCPListener.stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.state = clientStatePrePlay
|
||||||
|
|
||||||
return nil, liberrors.ErrClientBadStatusCode{
|
return nil, liberrors.ErrClientBadStatusCode{
|
||||||
Code: res.StatusCode, Message: res.StatusMessage,
|
Code: res.StatusCode, Message: res.StatusMessage,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.state = clientStatePlay
|
|
||||||
c.lastRange = ra
|
c.lastRange = ra
|
||||||
|
|
||||||
c.playRecordStart()
|
c.playRecordStart()
|
||||||
@@ -1609,6 +1618,16 @@ func (c *Client) doRecord() (*base.Response, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.state = clientStateRecord
|
||||||
|
|
||||||
|
if *c.protocol == TransportUDP {
|
||||||
|
// start UDP listeners
|
||||||
|
for _, cct := range c.tracks {
|
||||||
|
cct.udpRTPListener.start()
|
||||||
|
cct.udpRTCPListener.start()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
res, err := c.do(&base.Request{
|
res, err := c.do(&base.Request{
|
||||||
Method: base.Record,
|
Method: base.Record,
|
||||||
URL: c.streamBaseURL,
|
URL: c.streamBaseURL,
|
||||||
@@ -1618,13 +1637,21 @@ func (c *Client) doRecord() (*base.Response, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if res.StatusCode != base.StatusOK {
|
if res.StatusCode != base.StatusOK {
|
||||||
|
if *c.protocol == TransportUDP {
|
||||||
|
// stop UDP listeners
|
||||||
|
for _, cct := range c.tracks {
|
||||||
|
cct.udpRTPListener.stop()
|
||||||
|
cct.udpRTCPListener.stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.state = clientStatePreRecord
|
||||||
|
|
||||||
return nil, liberrors.ErrClientBadStatusCode{
|
return nil, liberrors.ErrClientBadStatusCode{
|
||||||
Code: res.StatusCode, Message: res.StatusMessage,
|
Code: res.StatusCode, Message: res.StatusMessage,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.state = clientStateRecord
|
|
||||||
|
|
||||||
c.playRecordStart()
|
c.playRecordStart()
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -1653,7 +1680,23 @@ func (c *Client) doPause() (*base.Response, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.playRecordClose(false)
|
c.playRecordStop(false)
|
||||||
|
|
||||||
|
if *c.protocol == TransportUDP || *c.protocol == TransportUDPMulticast {
|
||||||
|
// stop UDP listeners
|
||||||
|
for _, cct := range c.tracks {
|
||||||
|
cct.udpRTPListener.stop()
|
||||||
|
cct.udpRTCPListener.stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// change state regardless of the response
|
||||||
|
switch c.state {
|
||||||
|
case clientStatePlay:
|
||||||
|
c.state = clientStatePrePlay
|
||||||
|
case clientStateRecord:
|
||||||
|
c.state = clientStatePreRecord
|
||||||
|
}
|
||||||
|
|
||||||
res, err := c.do(&base.Request{
|
res, err := c.do(&base.Request{
|
||||||
Method: base.Pause,
|
Method: base.Pause,
|
||||||
@@ -1669,13 +1712,6 @@ func (c *Client) doPause() (*base.Response, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch c.state {
|
|
||||||
case clientStatePlay:
|
|
||||||
c.state = clientStatePrePlay
|
|
||||||
case clientStateRecord:
|
|
||||||
c.state = clientStatePreRecord
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -360,14 +360,12 @@ func TestClientRead(t *testing.T) {
|
|||||||
// client -> server (RTCP)
|
// client -> server (RTCP)
|
||||||
switch transport {
|
switch transport {
|
||||||
case "udp", "multicast":
|
case "udp", "multicast":
|
||||||
if transport == "udp" {
|
// skip firewall opening
|
||||||
// skip firewall opening
|
|
||||||
buf := make([]byte, 2048)
|
|
||||||
_, _, err := l2.ReadFrom(buf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, 2048)
|
buf := make([]byte, 2048)
|
||||||
|
_, _, err := l2.ReadFrom(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf = make([]byte, 2048)
|
||||||
n, _, err := l2.ReadFrom(buf)
|
n, _, err := l2.ReadFrom(buf)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, buf[:n])
|
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, buf[:n])
|
||||||
@@ -416,7 +414,7 @@ func TestClientRead(t *testing.T) {
|
|||||||
// ignore multicast loopback
|
// ignore multicast loopback
|
||||||
if transport == "multicast" {
|
if transport == "multicast" {
|
||||||
counter++
|
counter++
|
||||||
if counter >= 2 {
|
if counter <= 1 || counter >= 3 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user