diff --git a/client.go b/client.go index 7e2b1c4a..0879322d 100644 --- a/client.go +++ b/client.go @@ -352,8 +352,8 @@ type Client struct { checkTimeoutTimer *time.Timer checkTimeoutInitial bool tcpLastFrameTime *int64 - keepalivePeriod time.Duration - keepaliveTimer *time.Timer + keepAlivePeriod time.Duration + keepAliveTimer *time.Timer closeError error writer *asyncProcessor writerMutex sync.RWMutex @@ -481,8 +481,8 @@ func (c *Client) Start(scheme string, host string) error { c.ctx = ctx c.ctxCancel = ctxCancel c.checkTimeoutTimer = emptyTimer() - c.keepalivePeriod = 30 * time.Second - c.keepaliveTimer = emptyTimer() + c.keepAlivePeriod = 30 * time.Second + c.keepAliveTimer = emptyTimer() if c.BytesReceived != nil { c.bytesReceived = c.BytesReceived @@ -659,12 +659,12 @@ func (c *Client) runInner() error { } c.checkTimeoutTimer = time.NewTimer(c.checkTimeoutPeriod) - case <-c.keepaliveTimer.C: + case <-c.keepAliveTimer.C: err := c.doKeepAlive() if err != nil { return err } - c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) + c.keepAliveTimer = time.NewTimer(c.keepAlivePeriod) case <-chWriterError: return c.writer.stopError @@ -889,9 +889,11 @@ func (c *Client) startTransportRoutines() { c.tcpBuffer = make([]byte, c.MaxPacketSize+4) } - if c.state == clientStatePlay && c.stdChannelSetupped { - c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) + if c.state == clientStatePlay { + c.keepAliveTimer = time.NewTimer(c.keepAlivePeriod) + } + if c.state == clientStatePlay && c.stdChannelSetupped { switch *c.effectiveTransport { case TransportUDP: c.checkTimeoutTimer = time.NewTimer(c.InitialUDPReadTimeout) @@ -918,7 +920,7 @@ func (c *Client) stopTransportRoutines() { } c.checkTimeoutTimer = emptyTimer() - c.keepaliveTimer = emptyTimer() + c.keepAliveTimer = emptyTimer() for _, cm := range c.setuppedMedias { cm.stop() @@ -1056,7 +1058,7 @@ func (c *Client) do(req *base.Request, skipResponse bool) (*base.Response, error c.session = sx.Session if sx.Timeout != nil && *sx.Timeout > 0 { - c.keepalivePeriod = time.Duration(*sx.Timeout) * time.Second * 8 / 10 + c.keepAlivePeriod = time.Duration(*sx.Timeout) * time.Second * 8 / 10 } } diff --git a/client_play_test.go b/client_play_test.go index 895785c2..28864b68 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -2609,7 +2609,7 @@ func TestClientPlaySeek(t *testing.T) { require.NoError(t, err) } -func TestClientPlayKeepalive(t *testing.T) { +func TestClientPlayKeepAlive(t *testing.T) { for _, ca := range []string{"response before frame", "response after frame", "no response"} { t.Run(ca, func(t *testing.T) { l, err := net.Listen("tcp", "localhost:8554") @@ -3436,6 +3436,10 @@ func TestClientPlayBackChannel(t *testing.T) { StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), + "Session": headers.Session{ + Session: "ABCDE", + Timeout: uintPtr(1), + }.Marshal(), }, }) require.NoError(t, err2) @@ -3458,6 +3462,10 @@ func TestClientPlayBackChannel(t *testing.T) { StatusCode: base.StatusOK, Header: base.Header{ "Transport": th.Marshal(), + "Session": headers.Session{ + Session: "ABCDE", + Timeout: uintPtr(1), + }.Marshal(), }, }) require.NoError(t, err2) @@ -3489,6 +3497,20 @@ func TestClientPlayBackChannel(t *testing.T) { require.Equal(t, uint32(1), sr.PacketCount) require.Equal(t, uint32(4), sr.OctetCount) + recv := make(chan struct{}) + go func() { + defer close(recv) + req, err2 = conn.ReadRequest() + require.NoError(t, err2) + require.Equal(t, base.Options, req.Method) + }() + + select { + case <-recv: + case <-time.After(2 * time.Second): + t.Errorf("should not happen") + } + err2 = conn.WriteInterleavedFrame(&base.InterleavedFrame{ Channel: 0, Payload: testRTPPacketMarshaled,