add ServerConn.Close(), ServerSession.Close()

This commit is contained in:
aler9
2021-05-07 11:42:01 +02:00
parent 1c9fe6c394
commit e52fda806d
9 changed files with 430 additions and 203 deletions

View File

@@ -112,17 +112,6 @@ type ServerSessionAnnouncedTrack struct {
rtcpReceiver *rtcpreceiver.RTCPReceiver
}
type requestRes struct {
res *base.Response
err error
}
type requestReq struct {
sc *ServerConn
req *base.Request
res chan requestRes
}
// ServerSession is a server-side RTSP session.
type ServerSession struct {
s *Server
@@ -142,8 +131,9 @@ type ServerSession struct {
udpLastFrameTime *int64 // publish, udp
// in
request chan requestReq
terminate chan struct{}
request chan sessionReq
innerTerminate chan struct{}
terminate chan struct{}
}
func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
@@ -152,7 +142,8 @@ func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
id: id,
wg: wg,
lastRequestTime: time.Now(),
request: make(chan requestReq),
request: make(chan sessionReq),
innerTerminate: make(chan struct{}, 1),
terminate: make(chan struct{}),
}
@@ -162,6 +153,15 @@ func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
return ss
}
// Close closes the ServerSession.
func (ss *ServerSession) Close() error {
select {
case ss.innerTerminate <- struct{}{}:
default:
}
return nil
}
// State returns the state of the session.
func (ss *ServerSession) State() ServerSessionState {
return ss.state
@@ -203,78 +203,106 @@ func (ss *ServerSession) run() {
h.OnSessionOpen(ss)
}
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
defer checkTimeoutTicker.Stop()
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
defer receiverReportTicker.Stop()
err := func() error {
for {
select {
case req := <-ss.request:
res, err := ss.handleRequest(req.sc, req.req)
ss.lastRequestTime = time.Now()
if res.StatusCode == base.StatusOK {
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
}
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
req.res <- requestRes{res, nil}
return liberrors.ErrServerSessionTeardown{}
}
req.res <- requestRes{res, err}
case <-checkTimeoutTicker.C:
switch {
// in case of record and UDP, timeout happens when no frames are being received
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
now := time.Now()
lft := atomic.LoadInt64(ss.udpLastFrameTime)
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
return liberrors.ErrServerSessionTimedOut{}
}
// in case there's a linked TCP connection, timeout is handled in the connection
case ss.linkedConn != nil:
// otherwise, timeout happens when no requests arrives
default:
now := time.Now()
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
return liberrors.ErrServerSessionTimedOut{}
}
}
case <-receiverReportTicker.C:
if ss.state != ServerSessionStateRecord {
continue
}
now := time.Now()
for trackID, track := range ss.announcedTracks {
r := track.rtcpReceiver.Report(now)
ss.WriteFrame(trackID, StreamTypeRTCP, r)
}
case <-ss.terminate:
return liberrors.ErrServerTerminated{}
}
}
}()
readDone := make(chan error)
go func() {
for req := range ss.request {
req.res <- requestRes{nil, fmt.Errorf("terminated")}
}
readDone <- func() error {
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
defer checkTimeoutTicker.Stop()
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
defer receiverReportTicker.Stop()
for {
select {
case req := <-ss.request:
res, err := ss.handleRequest(req.sc, req.req)
ss.lastRequestTime = time.Now()
if res.StatusCode == base.StatusOK {
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
}
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
req.res <- sessionReqRes{res: res, err: nil}
return liberrors.ErrServerSessionTeardown{}
}
req.res <- sessionReqRes{
res: res,
err: err,
ss: ss,
}
case <-checkTimeoutTicker.C:
switch {
// in case of record and UDP, timeout happens when no frames are being received
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
now := time.Now()
lft := atomic.LoadInt64(ss.udpLastFrameTime)
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
return liberrors.ErrServerSessionTimedOut{}
}
// in case there's a linked TCP connection, timeout is handled in the connection
case ss.linkedConn != nil:
// otherwise, timeout happens when no requests arrives
default:
now := time.Now()
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
return liberrors.ErrServerSessionTimedOut{}
}
}
case <-receiverReportTicker.C:
if ss.state != ServerSessionStateRecord {
continue
}
now := time.Now()
for trackID, track := range ss.announcedTracks {
r := track.rtcpReceiver.Report(now)
ss.WriteFrame(trackID, StreamTypeRTCP, r)
}
case <-ss.innerTerminate:
return liberrors.ErrServerTerminated{}
}
}
}()
}()
var err error
select {
case err = <-readDone:
go func() {
for req := range ss.request {
req.res <- sessionReqRes{
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: liberrors.ErrServerTerminated{},
}
}
}()
ss.s.sessionClose <- ss
<-ss.terminate
case <-ss.terminate:
select {
case ss.innerTerminate <- struct{}{}:
default:
}
<-readDone
err = liberrors.ErrServerTerminated{}
}
switch ss.state {
case ServerSessionStatePlay:
if *ss.setupProtocol == StreamProtocolUDP {
@@ -292,9 +320,6 @@ func (ss *ServerSession) run() {
ss.s.connClose <- ss.linkedConn
}
ss.s.sessionClose <- ss
<-ss.terminate
close(ss.request)
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
@@ -310,6 +335,39 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}
switch req.Method {
case base.Options:
var methods []string
if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok {
methods = append(methods, string(base.Describe))
}
if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
methods = append(methods, string(base.Announce))
}
if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
methods = append(methods, string(base.Setup))
}
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
methods = append(methods, string(base.Play))
}
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
methods = append(methods, string(base.Record))
}
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
methods = append(methods, string(base.Pause))
}
methods = append(methods, string(base.GetParameter))
if _, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok {
methods = append(methods, string(base.SetParameter))
}
methods = append(methods, string(base.Teardown))
return &base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join(methods, ", ")},
},
}, nil
case base.Announce:
err := ss.checkState(map[ServerSessionState]struct{}{
ServerSessionStateInitial: {},
@@ -808,7 +866,9 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
}, nil
}
return nil, fmt.Errorf("unimplemented")
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerUnhandledRequest{Req: req}
}
// WriteFrame writes a frame.