From e52fda806d15fdeafb53dc987afa451900209a02 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Fri, 7 May 2021 11:42:01 +0200 Subject: [PATCH] add ServerConn.Close(), ServerSession.Close() --- client.go | 2 +- examples/server-tls/main.go | 8 +- examples/server/main.go | 8 +- pkg/liberrors/server.go | 14 ++- server.go | 50 ++++++-- server_read_test.go | 39 +++++- server_test.go | 79 +++++++++++- serverconn.go | 199 ++++++++++++++++-------------- serversession.go | 234 ++++++++++++++++++++++-------------- 9 files changed, 430 insertions(+), 203 deletions(-) diff --git a/client.go b/client.go index a376be7f..f000d2ec 100644 --- a/client.go +++ b/client.go @@ -74,7 +74,7 @@ type Client struct { // callback called before every request. OnRequest func(req *base.Request) - // callback called after very response. + // callback called after every response. OnResponse func(res *base.Response) // function used to initialize the TCP client. diff --git a/examples/server-tls/main.go b/examples/server-tls/main.go index 4ec7150f..b3f76ddf 100644 --- a/examples/server-tls/main.go +++ b/examples/server-tls/main.go @@ -22,22 +22,22 @@ type serverHandler struct { sdp []byte } -// called when a connection is opened. +// called after a connection is opened. func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) { log.Printf("conn opened") } -// called when a connection is closed. +// called after a connection is closed. func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) { log.Printf("conn closed (%v)", err) } -// called when a session is opened. +// called after a session is opened. func (sh *serverHandler) OnSessionOpen(ss *gortsplib.ServerSession) { log.Printf("session opened") } -// called when a session is closed. +// called after a session is closed. func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession, err error) { log.Printf("session closed") diff --git a/examples/server/main.go b/examples/server/main.go index e026f634..abd8e367 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -21,22 +21,22 @@ type serverHandler struct { sdp []byte } -// called when a connection is opened. +// called after a connection is opened. func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) { log.Printf("conn opened") } -// called when a connection is closed. +// called after a connection is closed. func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) { log.Printf("conn closed (%v)", err) } -// called when a session is opened. +// called after a session is opened. func (sh *serverHandler) OnSessionOpen(ss *gortsplib.ServerSession) { log.Printf("session opened") } -// called when a session is closed. +// called after a session is closed. func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession, err error) { log.Printf("session closed") diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index b3dee787..e9c365f2 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -15,6 +15,14 @@ func (e ErrServerTerminated) Error() string { return "terminated" } +// ErrServerSessionNotFound is an error that can be returned by a server. +type ErrServerSessionNotFound struct{} + +// Error implements the error interface. +func (e ErrServerSessionNotFound) Error() string { + return "session not found" +} + // ErrServerSessionTimedOut is an error that can be returned by a server. type ErrServerSessionTimedOut struct{} @@ -48,11 +56,13 @@ func (e ErrServerCSeqMissing) Error() string { } // ErrServerUnhandledRequest is an error that can be returned by a server. -type ErrServerUnhandledRequest struct{} +type ErrServerUnhandledRequest struct { + Req *base.Request +} // Error implements the error interface. func (e ErrServerUnhandledRequest) Error() string { - return "unhandled request" + return fmt.Sprintf("unhandled request (%v %v)", e.Req.Method, e.Req.URL) } // ErrServerWrongState is an error that can be returned by a server. diff --git a/server.go b/server.go index a3be138d..1bb28f00 100644 --- a/server.go +++ b/server.go @@ -9,6 +9,9 @@ import ( "strconv" "sync" "time" + + "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/liberrors" ) func extractPort(address string) (int, error) { @@ -41,10 +44,18 @@ func newSessionID(sessions map[string]*ServerSession) (string, error) { } } -type sessionGetReq struct { +type sessionReqRes struct { + res *base.Response + err error + ss *ServerSession +} + +type sessionReq struct { + sc *ServerConn + req *base.Request id string create bool - res chan *ServerSession + res chan sessionReqRes } // Server is a RTSP server. @@ -100,7 +111,7 @@ type Server struct { // in connClose chan *ServerConn - sessionGet chan sessionGetReq + sessionReq chan sessionReq sessionClose chan *ServerSession terminate chan struct{} @@ -194,7 +205,7 @@ func (s *Server) run() { s.sessions = make(map[string]*ServerSession) s.conns = make(map[*ServerConn]struct{}) s.connClose = make(chan *ServerConn) - s.sessionGet = make(chan sessionGetReq) + s.sessionReq = make(chan sessionReq) s.sessionClose = make(chan *ServerSession) var wg sync.WaitGroup @@ -233,25 +244,35 @@ outer: } s.doConnClose(sc) - case req := <-s.sessionGet: + case req := <-s.sessionReq: if ss, ok := s.sessions[req.id]; ok { - req.res <- ss + ss.request <- req } else { if !req.create { - req.res <- nil + req.res <- sessionReqRes{ + res: &base.Response{ + StatusCode: base.StatusBadRequest, + }, + err: liberrors.ErrServerSessionNotFound{}, + } continue } id, err := newSessionID(s.sessions) if err != nil { - req.res <- nil + req.res <- sessionReqRes{ + res: &base.Response{ + StatusCode: base.StatusBadRequest, + }, + err: fmt.Errorf("internal error"), + } continue } ss := newServerSession(s, id, &wg) s.sessions[id] = ss - req.res <- ss + ss.request <- req } case ss := <-s.sessionClose: @@ -284,11 +305,16 @@ outer: return } - case req, ok := <-s.sessionGet: + case req, ok := <-s.sessionReq: if !ok { return } - req.res <- nil + req.res <- sessionReqRes{ + res: &base.Response{ + StatusCode: base.StatusBadRequest, + }, + err: liberrors.ErrServerTerminated{}, + } case _, ok := <-s.sessionClose: if !ok { @@ -321,7 +347,7 @@ outer: close(acceptErr) close(connNew) close(s.connClose) - close(s.sessionGet) + close(s.sessionReq) close(s.sessionClose) close(s.done) } diff --git a/server_read_test.go b/server_read_test.go index 9848ab9c..e2946fb5 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -310,6 +310,11 @@ func TestServerRead(t *testing.T) { require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload) close(framesReceived) }, + onGetParameter: func(ctx *ServerHandlerOnGetParameterCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, }, } @@ -453,11 +458,43 @@ func TestServerRead(t *testing.T) { <-framesReceived + if proto == "udp" { + // ping with OPTIONS + err = base.Request{ + Method: base.Options, + URL: base.MustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"4"}, + "Session": res.Header["Session"], + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + // ping with GET_PARAMETER + err = base.Request{ + Method: base.GetParameter, + URL: base.MustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"5"}, + "Session": res.Header["Session"], + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + } + err = base.Request{ Method: base.Teardown, URL: base.MustParseURL("rtsp://localhost:8554/teststream"), Header: base.Header{ - "CSeq": base.HeaderValue{"3"}, + "CSeq": base.HeaderValue{"6"}, "Session": res.Header["Session"], }, }.Write(bconn.Writer) diff --git a/server_test.go b/server_test.go index 0debbd6b..465f8fba 100644 --- a/server_test.go +++ b/server_test.go @@ -427,6 +427,31 @@ func TestServerErrorWrongUDPPorts(t *testing.T) { }) } +func TestServerConnClose(t *testing.T) { + connClosed := make(chan struct{}) + + s := &Server{ + Handler: &testServerHandler{ + onConnOpen: func(sc *ServerConn) { + sc.Close() + }, + onConnClose: func(sc *ServerConn, err error) { + close(connClosed) + }, + }, + } + + err := s.Start("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + + <-connClosed +} + func TestServerCSeq(t *testing.T) { s := &Server{} err := s.Start("127.0.0.1:8554") @@ -493,7 +518,7 @@ func TestServerErrorCSeqMissing(t *testing.T) { func TestServerErrorInvalidMethod(t *testing.T) { h := &testServerHandler{ onConnClose: func(sc *ServerConn, err error) { - require.Equal(t, "unhandled request", err.Error()) + require.Equal(t, "unhandled request (INVALID rtsp://localhost:8554/)", err.Error()) }, } @@ -846,3 +871,55 @@ func TestServerErrorInvalidSession(t *testing.T) { }) } } + +func TestServerSessionClose(t *testing.T) { + sessionClosed := make(chan struct{}) + + s := &Server{ + Handler: &testServerHandler{ + onSessionOpen: func(ss *ServerSession) { + ss.Close() + }, + onSessionClose: func(ss *ServerSession, err error) { + close(sessionClosed) + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + InterleavedIDs: &[2]int{0, 1}, + }.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + <-sessionClosed +} diff --git a/serverconn.go b/serverconn.go index c494e438..1e9dd10e 100644 --- a/serverconn.go +++ b/serverconn.go @@ -3,7 +3,6 @@ package gortsplib import ( "bufio" "crypto/tls" - "fmt" "net" "strings" "sync" @@ -55,7 +54,8 @@ type ServerConn struct { tcpFrameBackgroundWriteDone chan struct{} // in - terminate chan struct{} + innerTerminate chan struct{} + terminate chan struct{} } func newServerConn( @@ -64,10 +64,11 @@ func newServerConn( nconn net.Conn) *ServerConn { sc := &ServerConn{ - s: s, - wg: wg, - nconn: nconn, - terminate: make(chan struct{}), + s: s, + wg: wg, + nconn: nconn, + innerTerminate: make(chan struct{}, 1), + terminate: make(chan struct{}), } wg.Add(1) @@ -76,6 +77,15 @@ func newServerConn( return sc } +// Close closes the ServerConn. +func (sc *ServerConn) Close() error { + select { + case sc.innerTerminate <- struct{}{}: + default: + } + return nil +} + // NetConn returns the underlying net.Conn. func (sc *ServerConn) NetConn() net.Conn { return sc.nconn @@ -177,12 +187,26 @@ func (sc *ServerConn) run() { } sc.nconn.Close() - sc.s.connClose <- sc + sc.s.connClose <- sc <-sc.terminate return err + case <-sc.innerTerminate: + sc.nconn.Close() + <-readDone + + if sc.tcpFrameEnabled { + sc.tcpFrameWriteBuffer.Close() + <-sc.tcpFrameBackgroundWriteDone + } + + sc.s.connClose <- sc + <-sc.terminate + + return liberrors.ErrServerTerminated{} + case <-sc.terminate: sc.nconn.Close() <-readDone @@ -226,6 +250,21 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { switch req.Method { case base.Options: + // handle request in session + if sxID != "" { + cres := make(chan sessionReqRes) + sc.s.sessionReq <- sessionReq{ + sc: sc, + req: req, + id: sxID, + create: false, + res: cres, + } + res := <-cres + return res.res, res.err + } + + // handle request here var methods []string if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok { methods = append(methods, string(base.Describe)) @@ -291,58 +330,46 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { case base.Announce: if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok { - sres := make(chan *ServerSession) - sc.s.sessionGet <- sessionGetReq{id: sxID, create: true, res: sres} - ss := <-sres - - if ss == nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("terminated") + cres := make(chan sessionReqRes) + sc.s.sessionReq <- sessionReq{ + sc: sc, + req: req, + id: sxID, + create: true, + res: cres, } - - rres := make(chan requestRes) - ss.request <- requestReq{sc: sc, req: req, res: rres} - res := <-rres + res := <-cres return res.res, res.err } case base.Setup: if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok { - sres := make(chan *ServerSession) - sc.s.sessionGet <- sessionGetReq{id: sxID, create: true, res: sres} - ss := <-sres - - if ss == nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("terminated") + cres := make(chan sessionReqRes) + sc.s.sessionReq <- sessionReq{ + sc: sc, + req: req, + id: sxID, + create: true, + res: cres, } - - rres := make(chan requestRes) - ss.request <- requestReq{sc: sc, req: req, res: rres} - res := <-rres + res := <-cres return res.res, res.err } case base.Play: if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok { - sres := make(chan *ServerSession) - sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres} - ss := <-sres - - if ss == nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerInvalidSession{} + cres := make(chan sessionReqRes) + sc.s.sessionReq <- sessionReq{ + sc: sc, + req: req, + id: sxID, + create: false, + res: cres, } - - rres := make(chan requestRes) - ss.request <- requestReq{sc: sc, req: req, res: rres} - res := <-rres + res := <-cres if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok { - sc.tcpFrameLinkedSession = ss + sc.tcpFrameLinkedSession = res.ss sc.tcpFrameIsRecording = false sc.tcpFrameSetEnabled = true return res.res, nil @@ -353,22 +380,18 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { case base.Record: if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok { - sres := make(chan *ServerSession) - sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres} - ss := <-sres - - if ss == nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerInvalidSession{} + cres := make(chan sessionReqRes) + sc.s.sessionReq <- sessionReq{ + sc: sc, + req: req, + id: sxID, + create: false, + res: cres, } - - rres := make(chan requestRes) - ss.request <- requestReq{sc: sc, req: req, res: rres} - res := <-rres + res := <-cres if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok { - sc.tcpFrameLinkedSession = ss + sc.tcpFrameLinkedSession = res.ss sc.tcpFrameIsRecording = true sc.tcpFrameSetEnabled = true return res.res, nil @@ -379,19 +402,15 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { case base.Pause: if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok { - sres := make(chan *ServerSession) - sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres} - ss := <-sres - - if ss == nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerInvalidSession{} + cres := make(chan sessionReqRes) + sc.s.sessionReq <- sessionReq{ + sc: sc, + req: req, + id: sxID, + create: false, + res: cres, } - - rres := make(chan requestRes) - ss.request <- requestReq{sc: sc, req: req, res: rres} - res := <-rres + res := <-cres if _, ok := res.err.(liberrors.ErrServerTCPFramesDisable); ok { sc.tcpFrameSetEnabled = false @@ -402,31 +421,29 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } case base.Teardown: - sres := make(chan *ServerSession) - sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres} - ss := <-sres - - if ss == nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerInvalidSession{} + cres := make(chan sessionReqRes) + sc.s.sessionReq <- sessionReq{ + sc: sc, + req: req, + id: sxID, + create: false, + res: cres, } - - rres := make(chan requestRes) - ss.request <- requestReq{sc: sc, req: req, res: rres} - res := <-rres + res := <-cres return res.res, res.err case base.GetParameter: - sres := make(chan *ServerSession) - sc.s.sessionGet <- sessionGetReq{id: sxID, create: false, res: sres} - ss := <-sres - - // send request to session - if ss != nil { - rres := make(chan requestRes) - ss.request <- requestReq{sc: sc, req: req, res: rres} - res := <-rres + // handle request in session + if sxID != "" { + cres := make(chan sessionReqRes) + sc.s.sessionReq <- sessionReq{ + sc: sc, + req: req, + id: sxID, + create: false, + res: cres, + } + res := <-cres return res.res, res.err } @@ -471,7 +488,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerUnhandledRequest{} + }, liberrors.ErrServerUnhandledRequest{Req: req} } func (sc *ServerConn) handleRequestOuter(req *base.Request) error { diff --git a/serversession.go b/serversession.go index 37a81453..33ebdce8 100644 --- a/serversession.go +++ b/serversession.go @@ -112,17 +112,6 @@ type ServerSessionAnnouncedTrack struct { rtcpReceiver *rtcpreceiver.RTCPReceiver } -type requestRes struct { - res *base.Response - err error -} - -type requestReq struct { - sc *ServerConn - req *base.Request - res chan requestRes -} - // ServerSession is a server-side RTSP session. type ServerSession struct { s *Server @@ -142,8 +131,9 @@ type ServerSession struct { udpLastFrameTime *int64 // publish, udp // in - request chan requestReq - terminate chan struct{} + request chan sessionReq + innerTerminate chan struct{} + terminate chan struct{} } func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession { @@ -152,7 +142,8 @@ func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession { id: id, wg: wg, lastRequestTime: time.Now(), - request: make(chan requestReq), + request: make(chan sessionReq), + innerTerminate: make(chan struct{}, 1), terminate: make(chan struct{}), } @@ -162,6 +153,15 @@ func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession { return ss } +// Close closes the ServerSession. +func (ss *ServerSession) Close() error { + select { + case ss.innerTerminate <- struct{}{}: + default: + } + return nil +} + // State returns the state of the session. func (ss *ServerSession) State() ServerSessionState { return ss.state @@ -203,78 +203,106 @@ func (ss *ServerSession) run() { h.OnSessionOpen(ss) } - checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod) - defer checkTimeoutTicker.Stop() - - receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod) - defer receiverReportTicker.Stop() - - err := func() error { - for { - select { - case req := <-ss.request: - res, err := ss.handleRequest(req.sc, req.req) - - ss.lastRequestTime = time.Now() - - if res.StatusCode == base.StatusOK { - if res.Header == nil { - res.Header = make(base.Header) - } - res.Header["Session"] = base.HeaderValue{ss.id} - } - - if _, ok := err.(liberrors.ErrServerSessionTeardown); ok { - req.res <- requestRes{res, nil} - return liberrors.ErrServerSessionTeardown{} - } - - req.res <- requestRes{res, err} - - case <-checkTimeoutTicker.C: - switch { - // in case of record and UDP, timeout happens when no frames are being received - case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP: - now := time.Now() - lft := atomic.LoadInt64(ss.udpLastFrameTime) - if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout { - return liberrors.ErrServerSessionTimedOut{} - } - - // in case there's a linked TCP connection, timeout is handled in the connection - case ss.linkedConn != nil: - - // otherwise, timeout happens when no requests arrives - default: - now := time.Now() - if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor { - return liberrors.ErrServerSessionTimedOut{} - } - } - - case <-receiverReportTicker.C: - if ss.state != ServerSessionStateRecord { - continue - } - - now := time.Now() - for trackID, track := range ss.announcedTracks { - r := track.rtcpReceiver.Report(now) - ss.WriteFrame(trackID, StreamTypeRTCP, r) - } - - case <-ss.terminate: - return liberrors.ErrServerTerminated{} - } - } - }() - + readDone := make(chan error) go func() { - for req := range ss.request { - req.res <- requestRes{nil, fmt.Errorf("terminated")} - } + readDone <- func() error { + checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod) + defer checkTimeoutTicker.Stop() + + receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod) + defer receiverReportTicker.Stop() + + for { + select { + case req := <-ss.request: + res, err := ss.handleRequest(req.sc, req.req) + + ss.lastRequestTime = time.Now() + + if res.StatusCode == base.StatusOK { + if res.Header == nil { + res.Header = make(base.Header) + } + res.Header["Session"] = base.HeaderValue{ss.id} + } + + if _, ok := err.(liberrors.ErrServerSessionTeardown); ok { + req.res <- sessionReqRes{res: res, err: nil} + return liberrors.ErrServerSessionTeardown{} + } + + req.res <- sessionReqRes{ + res: res, + err: err, + ss: ss, + } + + case <-checkTimeoutTicker.C: + switch { + // in case of record and UDP, timeout happens when no frames are being received + case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP: + now := time.Now() + lft := atomic.LoadInt64(ss.udpLastFrameTime) + if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout { + return liberrors.ErrServerSessionTimedOut{} + } + + // in case there's a linked TCP connection, timeout is handled in the connection + case ss.linkedConn != nil: + + // otherwise, timeout happens when no requests arrives + default: + now := time.Now() + if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor { + return liberrors.ErrServerSessionTimedOut{} + } + } + + case <-receiverReportTicker.C: + if ss.state != ServerSessionStateRecord { + continue + } + + now := time.Now() + for trackID, track := range ss.announcedTracks { + r := track.rtcpReceiver.Report(now) + ss.WriteFrame(trackID, StreamTypeRTCP, r) + } + + case <-ss.innerTerminate: + return liberrors.ErrServerTerminated{} + } + } + }() }() + var err error + select { + case err = <-readDone: + go func() { + for req := range ss.request { + req.res <- sessionReqRes{ + res: &base.Response{ + StatusCode: base.StatusBadRequest, + }, + err: liberrors.ErrServerTerminated{}, + } + } + }() + + ss.s.sessionClose <- ss + <-ss.terminate + + case <-ss.terminate: + select { + case ss.innerTerminate <- struct{}{}: + default: + } + <-readDone + + err = liberrors.ErrServerTerminated{} + } + switch ss.state { case ServerSessionStatePlay: if *ss.setupProtocol == StreamProtocolUDP { @@ -292,9 +320,6 @@ func (ss *ServerSession) run() { ss.s.connClose <- ss.linkedConn } - ss.s.sessionClose <- ss - <-ss.terminate - close(ss.request) if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok { @@ -310,6 +335,39 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base } switch req.Method { + case base.Options: + var methods []string + if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok { + methods = append(methods, string(base.Describe)) + } + if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok { + methods = append(methods, string(base.Announce)) + } + if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok { + methods = append(methods, string(base.Setup)) + } + if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok { + methods = append(methods, string(base.Play)) + } + if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok { + methods = append(methods, string(base.Record)) + } + if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok { + methods = append(methods, string(base.Pause)) + } + methods = append(methods, string(base.GetParameter)) + if _, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok { + methods = append(methods, string(base.SetParameter)) + } + methods = append(methods, string(base.Teardown)) + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join(methods, ", ")}, + }, + }, nil + case base.Announce: err := ss.checkState(map[ServerSessionState]struct{}{ ServerSessionStateInitial: {}, @@ -808,7 +866,9 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base }, nil } - return nil, fmt.Errorf("unimplemented") + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerUnhandledRequest{Req: req} } // WriteFrame writes a frame.