diff --git a/server.go b/server.go index 59f60493..a3be138d 100644 --- a/server.go +++ b/server.go @@ -88,7 +88,8 @@ type Server struct { // It defaults to net.Listen Listen func(network string, address string) (net.Listener, error) - receiverReportPeriod time.Duration + receiverReportPeriod time.Duration + closeSessionAfterNoRequestsFor time.Duration tcpListener net.Listener udpRTPListener *serverUDPListener @@ -129,6 +130,9 @@ func (s *Server) Start(address string) error { if s.receiverReportPeriod == 0 { s.receiverReportPeriod = 10 * time.Second } + if s.closeSessionAfterNoRequestsFor == 0 { + s.closeSessionAfterNoRequestsFor = 1 * 60 * time.Second + } if s.TLSConfig != nil && s.UDPRTPAddress != "" { return fmt.Errorf("TLS can't be used together with UDP") diff --git a/server_read_test.go b/server_read_test.go index 71ab3856..6c1ae06c 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -904,3 +904,101 @@ func TestServerReadPlayPausePause(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) } + +func TestServerReadErrorTimeout(t *testing.T) { + for _, proto := range []string{ + "udp", + // checking TCP is useless, since there's no timeout when reading with TCP + } { + t.Run(proto, func(t *testing.T) { + sessionClosed := make(chan struct{}) + + s := &Server{ + Handler: &testServerHandler{ + onSessionClose: func(ss *ServerSession, err error) { + close(sessionClosed) + }, + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + ReadTimeout: 1 * time.Second, + closeSessionAfterNoRequestsFor: 1 * time.Second, + } + + s.UDPRTPAddress = "127.0.0.1:8000" + s.UDPRTCPAddress = "127.0.0.1:8001" + + err := s.Start("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + nconn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer nconn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(nconn), bufio.NewWriter(nconn)) + + inTH := &headers.Transport{ + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + } + + if proto == "udp" { + inTH.Protocol = StreamProtocolUDP + inTH.ClientPorts = &[2]int{35466, 35467} + } else { + inTH.Protocol = StreamProtocolTCP + inTH.InterleavedIDs = &[2]int{0, 1} + } + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": inTH.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + var res base.Response + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + err = base.Request{ + Method: base.Play, + URL: base.MustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Session": res.Header["Session"], + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + <-sessionClosed + }) + } +} diff --git a/serversession.go b/serversession.go index 86aa7d48..59850279 100644 --- a/serversession.go +++ b/serversession.go @@ -16,8 +16,7 @@ import ( ) const ( - serverSessionCheckStreamPeriod = 1 * time.Second - serverSessionCloseAfterNoRequestsFor = 1 * 60 * time.Second + serverSessionCheckStreamPeriod = 1 * time.Second ) func setupGetTrackIDPathQuery(url *base.URL, @@ -248,7 +247,7 @@ func (ss *ServerSession) run() { // otherwise, timeout happens when no requests arrives default: now := time.Now() - if now.Sub(ss.lastRequestTime) >= serverSessionCloseAfterNoRequestsFor { + if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor { return liberrors.ErrServerSessionTimedOut{} } }