client: fix timeout when writing to back channels (#575) (#774)

Keep alives are now sent when writing to back channels too.
This commit is contained in:
Alessandro Ros
2025-05-04 13:17:53 +02:00
committed by GitHub
parent 7f5aac27d1
commit 2cbdc2a0b7
2 changed files with 35 additions and 11 deletions

View File

@@ -352,8 +352,8 @@ type Client struct {
checkTimeoutTimer *time.Timer checkTimeoutTimer *time.Timer
checkTimeoutInitial bool checkTimeoutInitial bool
tcpLastFrameTime *int64 tcpLastFrameTime *int64
keepalivePeriod time.Duration keepAlivePeriod time.Duration
keepaliveTimer *time.Timer keepAliveTimer *time.Timer
closeError error closeError error
writer *asyncProcessor writer *asyncProcessor
writerMutex sync.RWMutex writerMutex sync.RWMutex
@@ -481,8 +481,8 @@ func (c *Client) Start(scheme string, host string) error {
c.ctx = ctx c.ctx = ctx
c.ctxCancel = ctxCancel c.ctxCancel = ctxCancel
c.checkTimeoutTimer = emptyTimer() c.checkTimeoutTimer = emptyTimer()
c.keepalivePeriod = 30 * time.Second c.keepAlivePeriod = 30 * time.Second
c.keepaliveTimer = emptyTimer() c.keepAliveTimer = emptyTimer()
if c.BytesReceived != nil { if c.BytesReceived != nil {
c.bytesReceived = c.BytesReceived c.bytesReceived = c.BytesReceived
@@ -659,12 +659,12 @@ func (c *Client) runInner() error {
} }
c.checkTimeoutTimer = time.NewTimer(c.checkTimeoutPeriod) c.checkTimeoutTimer = time.NewTimer(c.checkTimeoutPeriod)
case <-c.keepaliveTimer.C: case <-c.keepAliveTimer.C:
err := c.doKeepAlive() err := c.doKeepAlive()
if err != nil { if err != nil {
return err return err
} }
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) c.keepAliveTimer = time.NewTimer(c.keepAlivePeriod)
case <-chWriterError: case <-chWriterError:
return c.writer.stopError return c.writer.stopError
@@ -889,9 +889,11 @@ func (c *Client) startTransportRoutines() {
c.tcpBuffer = make([]byte, c.MaxPacketSize+4) c.tcpBuffer = make([]byte, c.MaxPacketSize+4)
} }
if c.state == clientStatePlay && c.stdChannelSetupped { if c.state == clientStatePlay {
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) c.keepAliveTimer = time.NewTimer(c.keepAlivePeriod)
}
if c.state == clientStatePlay && c.stdChannelSetupped {
switch *c.effectiveTransport { switch *c.effectiveTransport {
case TransportUDP: case TransportUDP:
c.checkTimeoutTimer = time.NewTimer(c.InitialUDPReadTimeout) c.checkTimeoutTimer = time.NewTimer(c.InitialUDPReadTimeout)
@@ -918,7 +920,7 @@ func (c *Client) stopTransportRoutines() {
} }
c.checkTimeoutTimer = emptyTimer() c.checkTimeoutTimer = emptyTimer()
c.keepaliveTimer = emptyTimer() c.keepAliveTimer = emptyTimer()
for _, cm := range c.setuppedMedias { for _, cm := range c.setuppedMedias {
cm.stop() cm.stop()
@@ -1056,7 +1058,7 @@ func (c *Client) do(req *base.Request, skipResponse bool) (*base.Response, error
c.session = sx.Session c.session = sx.Session
if sx.Timeout != nil && *sx.Timeout > 0 { if sx.Timeout != nil && *sx.Timeout > 0 {
c.keepalivePeriod = time.Duration(*sx.Timeout) * time.Second * 8 / 10 c.keepAlivePeriod = time.Duration(*sx.Timeout) * time.Second * 8 / 10
} }
} }

View File

@@ -2609,7 +2609,7 @@ func TestClientPlaySeek(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func TestClientPlayKeepalive(t *testing.T) { func TestClientPlayKeepAlive(t *testing.T) {
for _, ca := range []string{"response before frame", "response after frame", "no response"} { for _, ca := range []string{"response before frame", "response after frame", "no response"} {
t.Run(ca, func(t *testing.T) { t.Run(ca, func(t *testing.T) {
l, err := net.Listen("tcp", "localhost:8554") l, err := net.Listen("tcp", "localhost:8554")
@@ -3436,6 +3436,10 @@ func TestClientPlayBackChannel(t *testing.T) {
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Transport": th.Marshal(), "Transport": th.Marshal(),
"Session": headers.Session{
Session: "ABCDE",
Timeout: uintPtr(1),
}.Marshal(),
}, },
}) })
require.NoError(t, err2) require.NoError(t, err2)
@@ -3458,6 +3462,10 @@ func TestClientPlayBackChannel(t *testing.T) {
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
Header: base.Header{ Header: base.Header{
"Transport": th.Marshal(), "Transport": th.Marshal(),
"Session": headers.Session{
Session: "ABCDE",
Timeout: uintPtr(1),
}.Marshal(),
}, },
}) })
require.NoError(t, err2) require.NoError(t, err2)
@@ -3489,6 +3497,20 @@ func TestClientPlayBackChannel(t *testing.T) {
require.Equal(t, uint32(1), sr.PacketCount) require.Equal(t, uint32(1), sr.PacketCount)
require.Equal(t, uint32(4), sr.OctetCount) require.Equal(t, uint32(4), sr.OctetCount)
recv := make(chan struct{})
go func() {
defer close(recv)
req, err2 = conn.ReadRequest()
require.NoError(t, err2)
require.Equal(t, base.Options, req.Method)
}()
select {
case <-recv:
case <-time.After(2 * time.Second):
t.Errorf("should not happen")
}
err2 = conn.WriteInterleavedFrame(&base.InterleavedFrame{ err2 = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0, Channel: 0,
Payload: testRTPPacketMarshaled, Payload: testRTPPacketMarshaled,