diff --git a/examples/server-tls/main.go b/examples/server-tls/main.go index 563f90eb..4ec7150f 100644 --- a/examples/server-tls/main.go +++ b/examples/server-tls/main.go @@ -38,7 +38,7 @@ func (sh *serverHandler) OnSessionOpen(ss *gortsplib.ServerSession) { } // called when a session is closed. -func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession) { +func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession, err error) { log.Printf("session closed") sh.mutex.Lock() diff --git a/examples/server/main.go b/examples/server/main.go index fa78bc37..e026f634 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -37,7 +37,7 @@ func (sh *serverHandler) OnSessionOpen(ss *gortsplib.ServerSession) { } // called when a session is closed. -func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession) { +func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession, err error) { log.Printf("session closed") sh.mutex.Lock() diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index a5cf1215..b3dee787 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -7,6 +7,22 @@ import ( "github.com/aler9/gortsplib/pkg/headers" ) +// ErrServerTerminated is an error that can be returned by a server. +type ErrServerTerminated struct{} + +// Error implements the error interface. +func (e ErrServerTerminated) Error() string { + return "terminated" +} + +// ErrServerSessionTimedOut is an error that can be returned by a server. +type ErrServerSessionTimedOut struct{} + +// Error implements the error interface. +func (e ErrServerSessionTimedOut) Error() string { + return "timed out" +} + // ErrServerTCPFramesEnable is an error that can be returned by a server. type ErrServerTCPFramesEnable struct{} @@ -192,11 +208,11 @@ func (e ErrServerLinkedToOtherSession) Error() string { return "connection is linked to another session" } -// ErrServerTeardown is an error that can be returned by a server. -type ErrServerTeardown struct{} +// ErrServerSessionTeardown is an error that can be returned by a server. +type ErrServerSessionTeardown struct{} // Error implements the error interface. -func (e ErrServerTeardown) Error() string { +func (e ErrServerSessionTeardown) Error() string { return "teardown" } diff --git a/server_publish_test.go b/server_publish_test.go index 9ff939e3..91ba3c35 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -500,7 +500,7 @@ func TestServerPublish(t *testing.T) { onSessionOpen: func(ss *ServerSession) { close(sessionOpened) }, - onSessionClose: func(ss *ServerSession) { + onSessionClose: func(ss *ServerSession, err error) { close(sessionClosed) }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { @@ -997,7 +997,7 @@ func TestServerPublishErrorTimeout(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onSessionClose: func(ss *ServerSession) { + onSessionClose: func(ss *ServerSession, err error) { close(sessionClosed) }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { diff --git a/server_read_test.go b/server_read_test.go index 7e63efcf..71ab3856 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -288,7 +288,7 @@ func TestServerRead(t *testing.T) { onSessionOpen: func(ss *ServerSession) { close(sessionOpened) }, - onSessionClose: func(ss *ServerSession) { + onSessionClose: func(ss *ServerSession, err error) { close(sessionClosed) }, onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { diff --git a/server_test.go b/server_test.go index ddd3030d..49dd0e9d 100644 --- a/server_test.go +++ b/server_test.go @@ -19,7 +19,7 @@ type testServerHandler struct { onConnOpen func(*ServerConn) onConnClose func(*ServerConn, error) onSessionOpen func(*ServerSession) - onSessionClose func(*ServerSession) + onSessionClose func(*ServerSession, error) onDescribe func(*ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) onAnnounce func(*ServerHandlerOnAnnounceCtx) (*base.Response, error) onSetup func(*ServerHandlerOnSetupCtx) (*base.Response, error) @@ -49,9 +49,9 @@ func (sh *testServerHandler) OnSessionOpen(ss *ServerSession) { } } -func (sh *testServerHandler) OnSessionClose(ss *ServerSession) { +func (sh *testServerHandler) OnSessionClose(ss *ServerSession, err error) { if sh.onSessionClose != nil { - sh.onSessionClose(ss) + sh.onSessionClose(ss, err) } } @@ -211,7 +211,7 @@ func TestServerHighLevelPublishRead(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onSessionClose: func(ss *ServerSession) { + onSessionClose: func(ss *ServerSession, err error) { mutex.Lock() defer mutex.Unlock() diff --git a/serverconn.go b/serverconn.go index 4994ae68..068573b1 100644 --- a/serverconn.go +++ b/serverconn.go @@ -168,25 +168,28 @@ func (sc *ServerConn) run() { }() }() - var err error - select { - case err = <-readDone: - if sc.tcpFrameEnabled { - sc.tcpFrameWriteBuffer.Close() - <-sc.tcpFrameBackgroundWriteDone - } - sc.nconn.Close() - sc.s.connClose <- sc - <-sc.terminate + err := func() error { + select { + case err := <-readDone: + if sc.tcpFrameEnabled { + sc.tcpFrameWriteBuffer.Close() + <-sc.tcpFrameBackgroundWriteDone + } + sc.nconn.Close() + sc.s.connClose <- sc + <-sc.terminate + return err - case <-sc.terminate: - if sc.tcpFrameEnabled { - sc.tcpFrameWriteBuffer.Close() - <-sc.tcpFrameBackgroundWriteDone + case <-sc.terminate: + if sc.tcpFrameEnabled { + sc.tcpFrameWriteBuffer.Close() + <-sc.tcpFrameBackgroundWriteDone + } + sc.nconn.Close() + <-readDone + return liberrors.ErrServerTerminated{} } - sc.nconn.Close() - err = <-readDone - } + }() if sc.tcpFrameEnabled { sc.s.sessionClose <- sc.tcpFrameLinkedSession diff --git a/serverhandler.go b/serverhandler.go index 41e03e92..26ef55a4 100644 --- a/serverhandler.go +++ b/serverhandler.go @@ -26,7 +26,7 @@ type ServerHandlerOnSessionOpen interface { // ServerHandlerOnSessionClose can be implemented by a ServerHandler. type ServerHandlerOnSessionClose interface { - OnSessionClose(*ServerSession) + OnSessionClose(*ServerSession, error) } // ServerHandlerOnRequest can be implemented by a ServerHandler. diff --git a/serversession.go b/serversession.go index 3a71c08d..86aa7d48 100644 --- a/serversession.go +++ b/serversession.go @@ -210,64 +210,65 @@ func (ss *ServerSession) run() { receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod) defer receiverReportTicker.Stop() -outer: - for { - select { - case req := <-ss.request: - res, err := ss.handleRequest(req.sc, req.req) + err := func() error { + for { + select { + case req := <-ss.request: + res, err := ss.handleRequest(req.sc, req.req) - ss.lastRequestTime = time.Now() + ss.lastRequestTime = time.Now() - if res.StatusCode == base.StatusOK { - if res.Header == nil { - res.Header = make(base.Header) + if res.StatusCode == base.StatusOK { + if res.Header == nil { + res.Header = make(base.Header) + } + res.Header["Session"] = base.HeaderValue{ss.id} } - res.Header["Session"] = base.HeaderValue{ss.id} - } - if _, ok := err.(liberrors.ErrServerTeardown); ok { - req.res <- requestRes{res, nil} - break outer - } + if _, ok := err.(liberrors.ErrServerSessionTeardown); ok { + req.res <- requestRes{res, nil} + return liberrors.ErrServerSessionTeardown{} + } - req.res <- requestRes{res, err} + req.res <- requestRes{res, err} + + case <-checkTimeoutTicker.C: + switch { + // in case of record and UDP, timeout happens when no frames are being received + case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP: + now := time.Now() + lft := atomic.LoadInt64(ss.udpLastFrameTime) + if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout { + return liberrors.ErrServerSessionTimedOut{} + } + + // in case there's a linked TCP connection, timeout is handled in the connection + case ss.linkedConn != nil: + + // otherwise, timeout happens when no requests arrives + default: + now := time.Now() + if now.Sub(ss.lastRequestTime) >= serverSessionCloseAfterNoRequestsFor { + return liberrors.ErrServerSessionTimedOut{} + } + } + + case <-receiverReportTicker.C: + if ss.state != ServerSessionStateRecord { + continue + } - case <-checkTimeoutTicker.C: - switch { - // in case of record and UDP, timeout happens when no frames are being received - case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP: now := time.Now() - lft := atomic.LoadInt64(ss.udpLastFrameTime) - if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout { - break outer + for trackID, track := range ss.announcedTracks { + r := track.rtcpReceiver.Report(now) + ss.WriteFrame(trackID, StreamTypeRTCP, r) } - // in case there's a linked TCP connection, timeout is handled in the connection - case ss.linkedConn != nil: - - // otherwise, timeout happens when no requests arrives - default: - now := time.Now() - if now.Sub(ss.lastRequestTime) >= serverSessionCloseAfterNoRequestsFor { - break outer - } + case <-ss.terminate: + return liberrors.ErrServerTerminated{} } - - case <-receiverReportTicker.C: - if ss.state != ServerSessionStateRecord { - continue - } - - now := time.Now() - for trackID, track := range ss.announcedTracks { - r := track.rtcpReceiver.Report(now) - ss.WriteFrame(trackID, StreamTypeRTCP, r) - } - - case <-ss.terminate: - break outer } - } + }() go func() { for req := range ss.request { @@ -298,7 +299,7 @@ outer: close(ss.request) if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok { - h.OnSessionClose(ss) + h.OnSessionClose(ss, err) } } @@ -758,7 +759,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base return &base.Response{ StatusCode: base.StatusOK, - }, liberrors.ErrServerTeardown{} + }, liberrors.ErrServerSessionTeardown{} case base.GetParameter: if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok {