diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index fcf778ca..a5cf1215 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -31,12 +31,12 @@ func (e ErrServerCSeqMissing) Error() string { return "CSeq is missing" } -// ErrServerInvalidMethod is an error that can be returned by a server. -type ErrServerInvalidMethod struct{} +// ErrServerUnhandledRequest is an error that can be returned by a server. +type ErrServerUnhandledRequest struct{} // Error implements the error interface. -func (e ErrServerInvalidMethod) Error() string { - return "invalid method" +func (e ErrServerUnhandledRequest) Error() string { + return "unhandled request" } // ErrServerWrongState is an error that can be returned by a server. diff --git a/server_publish_test.go b/server_publish_test.go index 27a56657..9ff939e3 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -993,17 +993,12 @@ func TestServerPublishErrorTimeout(t *testing.T) { "tls", } { t.Run(proto, func(t *testing.T) { - errDone := make(chan struct{}) + sessionClosed := make(chan struct{}) s := &Server{ Handler: &testServerHandler{ onSessionClose: func(ss *ServerSession) { - /*if proto == "udp" { - require.Equal(t, "no UDP packets received (maybe there's a firewall/NAT in between)", err.Error()) - } else { - require.True(t, strings.HasSuffix(err.Error(), "i/o timeout")) - }*/ - close(errDone) + close(sessionClosed) }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { return &base.Response{ @@ -1130,7 +1125,7 @@ func TestServerPublishErrorTimeout(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - <-errDone + <-sessionClosed }) } } diff --git a/server_test.go b/server_test.go index e58e1525..ddd3030d 100644 --- a/server_test.go +++ b/server_test.go @@ -488,7 +488,7 @@ func TestServerErrorCSeqMissing(t *testing.T) { func TestServerErrorInvalidMethod(t *testing.T) { h := &testServerHandler{ onConnClose: func(sc *ServerConn, err error) { - require.Equal(t, "invalid method", err.Error()) + require.Equal(t, "unhandled request", err.Error()) }, } diff --git a/serverconn.go b/serverconn.go index 29310a39..4994ae68 100644 --- a/serverconn.go +++ b/serverconn.go @@ -296,7 +296,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { rres := make(chan requestRes) ss.request <- requestReq{sc: sc, req: req, res: rres} res := <-rres - return res.res, res.err } @@ -315,7 +314,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { rres := make(chan requestRes) ss.request <- requestReq{sc: sc, req: req, res: rres} res := <-rres - return res.res, res.err } @@ -409,10 +407,22 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { rres := make(chan requestRes) ss.request <- requestReq{sc: sc, req: req, res: rres} res := <-rres - return res.res, res.err case base.GetParameter: + sres := make(chan *ServerSession) + sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres} + ss := <-sres + + // send request to session + if ss != nil { + rres := make(chan requestRes) + ss.request <- requestReq{sc: sc, req: req, res: rres} + res := <-rres + return res.res, res.err + } + + // handle request here if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok { pathAndQuery, ok := req.URL.RTSPPath() if !ok { @@ -431,15 +441,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }) } - // GET_PARAMETER is used like a ping - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Content-Type": base.HeaderValue{"text/parameters"}, - }, - Body: []byte("\n"), - }, nil - case base.SetParameter: if h, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok { pathAndQuery, ok := req.URL.RTSPPath() @@ -462,7 +463,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerInvalidMethod{} + }, liberrors.ErrServerUnhandledRequest{} } func (sc *ServerConn) handleRequestOuter(req *base.Request) error { @@ -473,7 +474,7 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error { res, err := sc.handleRequest(req) if res.Header == nil { - res.Header = base.Header{} + res.Header = make(base.Header) } // add cseq diff --git a/serverhandler.go b/serverhandler.go index 7dbcc5da..41e03e92 100644 --- a/serverhandler.go +++ b/serverhandler.go @@ -127,10 +127,11 @@ type ServerHandlerOnPause interface { // ServerHandlerOnGetParameterCtx is the context of a GET_PARAMETER request. type ServerHandlerOnGetParameterCtx struct { - Conn *ServerConn - Req *base.Request - Path string - Query string + Session *ServerSession + Conn *ServerConn + Req *base.Request + Path string + Query string } // ServerHandlerOnGetParameter can be implemented by a ServerHandler. diff --git a/serversession.go b/serversession.go index 37223fcd..3a71c08d 100644 --- a/serversession.go +++ b/serversession.go @@ -16,7 +16,8 @@ import ( ) const ( - serverSessionCheckStreamPeriod = 1 * time.Second + serverSessionCheckStreamPeriod = 1 * time.Second + serverSessionCloseAfterNoRequestsFor = 1 * 60 * time.Second ) func setupGetTrackIDPathQuery(url *base.URL, @@ -108,9 +109,8 @@ type ServerSessionSetuppedTrack struct { // ServerSessionAnnouncedTrack is an announced track of a ServerSession. type ServerSessionAnnouncedTrack struct { - track *Track - rtcpReceiver *rtcpreceiver.RTCPReceiver - udpLastFrameTime *int64 + track *Track + rtcpReceiver *rtcpreceiver.RTCPReceiver } type requestRes struct { @@ -130,21 +130,17 @@ type ServerSession struct { id string wg *sync.WaitGroup - state ServerSessionState - setuppedTracks map[int]ServerSessionSetuppedTrack - setupProtocol *StreamProtocol - setupPath *string - setupQuery *string - - // TCP stream protocol - linkedConn *ServerConn - - // UDP stream protocol - udpIP net.IP - udpZone string - - // publish - announcedTracks []ServerSessionAnnouncedTrack + state ServerSessionState + setuppedTracks map[int]ServerSessionSetuppedTrack + setupProtocol *StreamProtocol + setupPath *string + setupQuery *string + lastRequestTime time.Time + linkedConn *ServerConn // tcp + udpIP net.IP // udp + udpZone string // udp + announcedTracks []ServerSessionAnnouncedTrack // publish + udpLastFrameTime *int64 // publish, udp // in request chan requestReq @@ -153,11 +149,12 @@ type ServerSession struct { func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession { ss := &ServerSession{ - s: s, - id: id, - wg: wg, - request: make(chan requestReq), - terminate: make(chan struct{}), + s: s, + id: id, + wg: wg, + lastRequestTime: time.Now(), + request: make(chan requestReq), + terminate: make(chan struct{}), } wg.Add(1) @@ -207,8 +204,8 @@ func (ss *ServerSession) run() { h.OnSessionOpen(ss) } - checkStreamTicker := time.NewTicker(serverSessionCheckStreamPeriod) - defer checkStreamTicker.Stop() + checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod) + defer checkTimeoutTicker.Stop() receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod) defer receiverReportTicker.Stop() @@ -219,6 +216,15 @@ outer: case req := <-ss.request: res, err := ss.handleRequest(req.sc, req.req) + ss.lastRequestTime = time.Now() + + if res.StatusCode == base.StatusOK { + if res.Header == nil { + res.Header = make(base.Header) + } + res.Header["Session"] = base.HeaderValue{ss.id} + } + if _, ok := err.(liberrors.ErrServerTeardown); ok { req.res <- requestRes{res, nil} break outer @@ -226,23 +232,25 @@ outer: req.res <- requestRes{res, err} - case <-checkStreamTicker.C: - if ss.state != ServerSessionStateRecord || *ss.setupProtocol != StreamProtocolUDP { - continue - } - - inTimeout := func() bool { + case <-checkTimeoutTicker.C: + switch { + // in case of record and UDP, timeout happens when no frames are being received + case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP: now := time.Now() - for _, track := range ss.announcedTracks { - lft := atomic.LoadInt64(track.udpLastFrameTime) - if now.Sub(time.Unix(lft, 0)) < ss.s.ReadTimeout { - return false - } + lft := atomic.LoadInt64(ss.udpLastFrameTime) + if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout { + break outer + } + + // in case there's a linked TCP connection, timeout is handled in the connection + case ss.linkedConn != nil: + + // otherwise, timeout happens when no requests arrives + default: + now := time.Now() + if now.Sub(ss.lastRequestTime) >= serverSessionCloseAfterNoRequestsFor { + break outer } - return true - }() - if inTimeout { - break outer } case <-receiverReportTicker.C: @@ -387,20 +395,14 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.announcedTracks = make([]ServerSessionAnnouncedTrack, len(tracks)) for trackID, track := range tracks { clockRate, _ := track.ClockRate() - v := time.Now().Unix() - ss.announcedTracks[trackID] = ServerSessionAnnouncedTrack{ - track: track, - rtcpReceiver: rtcpreceiver.New(nil, clockRate), - udpLastFrameTime: &v, + track: track, + rtcpReceiver: rtcpreceiver.New(nil, clockRate), } } - if res.Header == nil { - res.Header = make(base.Header) - } - - res.Header["Session"] = base.HeaderValue{ss.id} + v := time.Now().Unix() + ss.udpLastFrameTime = &v } return res, err @@ -517,8 +519,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base res.Header = make(base.Header) } - res.Header["Session"] = base.HeaderValue{ss.id} - if th.Protocol == StreamProtocolUDP { ss.setuppedTracks[trackID] = ServerSessionSetuppedTrack{ udpRTPPort: th.ClientPorts[0], @@ -595,7 +595,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base path, query := base.PathSplitQuery(pathAndQuery) - if ss.state != ServerSessionStatePlay { + if ss.state != ServerSessionStatePlay && *ss.setupProtocol == StreamProtocolTCP { ss.linkedConn = sc } @@ -611,12 +611,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base if ss.state != ServerSessionStatePlay { ss.state = ServerSessionStatePlay - if res.Header == nil { - res.Header = make(base.Header) - } - - res.Header["Session"] = base.HeaderValue{ss.id} - if *ss.setupProtocol == StreamProtocolUDP { ss.udpIP = sc.ip() ss.udpZone = sc.zone() @@ -625,6 +619,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base for trackID, track := range ss.setuppedTracks { sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false) } + return res, err } @@ -675,12 +670,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base if res.StatusCode == base.StatusOK { ss.state = ServerSessionStateRecord - if res.Header == nil { - res.Header = make(base.Header) - } - - res.Header["Session"] = base.HeaderValue{ss.id} - if *ss.setupProtocol == StreamProtocolUDP { ss.udpIP = sc.ip() ss.udpZone = sc.zone() @@ -695,6 +684,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.WriteFrame(trackID, StreamTypeRTCP, []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) } + return res, err } @@ -738,12 +728,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base }) if res.StatusCode == base.StatusOK { - if res.Header == nil { - res.Header = make(base.Header) - } - - res.Header["Session"] = base.HeaderValue{ss.id} - switch ss.state { case ServerSessionStatePlay: ss.state = ServerSessionStatePrePlay @@ -775,6 +759,36 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base return &base.Response{ StatusCode: base.StatusOK, }, liberrors.ErrServerTeardown{} + + case base.GetParameter: + if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok { + pathAndQuery, ok := req.URL.RTSPPath() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerNoPath{} + } + + path, query := base.PathSplitQuery(pathAndQuery) + + return h.OnGetParameter(&ServerHandlerOnGetParameterCtx{ + Session: ss, + Conn: sc, + Req: req, + Path: path, + Query: query, + }) + } + + // GET_PARAMETER is used like a ping when reading, and sometimes + // also when publishing; reply with 200 + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"text/parameters"}, + }, + Body: []byte("\n"), + }, nil } return nil, fmt.Errorf("unimplemented") diff --git a/serverudpl.go b/serverudpl.go index 6b2fe2bf..c9bf0e5a 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -123,7 +123,7 @@ func (u *serverUDPListener) run() { if clientData.isPublishing { now := time.Now() - atomic.StoreInt64(clientData.ss.announcedTracks[clientData.trackID].udpLastFrameTime, now.Unix()) + atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix()) clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n]) }