server: close session when there are no conns attached to it

This commit is contained in:
aler9
2021-05-08 15:35:13 +02:00
parent d3361ffd90
commit 028ed2b973
6 changed files with 366 additions and 267 deletions

View File

@@ -118,33 +118,42 @@ type ServerSession struct {
id string
wg *sync.WaitGroup
conns map[*ServerConn]struct{}
connsWG sync.WaitGroup
state ServerSessionState
setuppedTracks map[int]ServerSessionSetuppedTrack
setupProtocol *StreamProtocol
setupPath *string
setupQuery *string
lastRequestTime time.Time
linkedConn *ServerConn // tcp
tcpConn *ServerConn // tcp
udpIP net.IP // udp
udpZone string // udp
announcedTracks []ServerSessionAnnouncedTrack // publish
udpLastFrameTime *int64 // publish, udp
// in
request chan sessionReq
innerTerminate chan struct{}
terminate chan struct{}
request chan request
connRemove chan *ServerConn
innerTerminate chan struct{}
parentTerminate chan struct{}
}
func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
func newServerSession(s *Server,
id string,
wg *sync.WaitGroup,
) *ServerSession {
ss := &ServerSession{
s: s,
id: id,
wg: wg,
conns: make(map[*ServerConn]struct{}),
lastRequestTime: time.Now(),
request: make(chan sessionReq),
request: make(chan request),
connRemove: make(chan *ServerConn),
innerTerminate: make(chan struct{}, 1),
terminate: make(chan struct{}),
parentTerminate: make(chan struct{}),
}
wg.Add(1)
@@ -208,105 +217,125 @@ func (ss *ServerSession) run() {
h.OnSessionOpen(ss)
}
readDone := make(chan error)
go func() {
readDone <- func() error {
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
defer checkTimeoutTicker.Stop()
err := func() error {
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
defer checkTimeoutTicker.Stop()
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
defer receiverReportTicker.Stop()
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
defer receiverReportTicker.Stop()
for {
select {
case req := <-ss.request:
res, err := ss.handleRequest(req.sc, req.req)
for {
select {
case req := <-ss.request:
ss.lastRequestTime = time.Now()
ss.lastRequestTime = time.Now()
if _, ok := ss.conns[req.sc]; !ok {
ss.conns[req.sc] = struct{}{}
ss.connsWG.Add(1)
}
if res.StatusCode == base.StatusOK {
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
res, err := ss.handleRequest(req.sc, req.req)
if res.StatusCode == base.StatusOK {
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["Session"] = base.HeaderValue{ss.id}
}
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
req.res <- sessionReqRes{res: res, err: nil}
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
req.res <- requestRes{res: res, err: nil}
return liberrors.ErrServerSessionTeardown{}
}
req.res <- requestRes{
res: res,
err: err,
ss: ss,
}
case sc := <-ss.connRemove:
if _, ok := ss.conns[sc]; ok {
delete(ss.conns, sc)
sc.sessionRemove <- ss
ss.connsWG.Done()
}
// if session is not in state RECORD or PLAY, or protocol is TCP
if (ss.state != ServerSessionStateRecord &&
ss.state != ServerSessionStatePlay) ||
*ss.setupProtocol == StreamProtocolTCP {
// close if there are no active connections
if len(ss.conns) == 0 {
return liberrors.ErrServerSessionTeardown{}
}
req.res <- sessionReqRes{
res: res,
err: err,
ss: ss,
}
case <-checkTimeoutTicker.C:
switch {
// in case of record and UDP, timeout happens when no frames are being received
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
now := time.Now()
lft := atomic.LoadInt64(ss.udpLastFrameTime)
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
return liberrors.ErrServerSessionTimedOut{}
}
// in case there's a linked TCP connection, timeout is handled in the connection
case ss.linkedConn != nil:
// otherwise, timeout happens when no requests arrives
default:
now := time.Now()
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
return liberrors.ErrServerSessionTimedOut{}
}
}
case <-receiverReportTicker.C:
if ss.state != ServerSessionStateRecord {
continue
}
now := time.Now()
for trackID, track := range ss.announcedTracks {
r := track.rtcpReceiver.Report(now)
ss.WriteFrame(trackID, StreamTypeRTCP, r)
}
case <-ss.innerTerminate:
return liberrors.ErrServerTerminated{}
}
case <-checkTimeoutTicker.C:
switch {
// in case of RECORD and UDP, timeout happens when no frames are being received
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
now := time.Now()
lft := atomic.LoadInt64(ss.udpLastFrameTime)
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
return liberrors.ErrServerSessionTimedOut{}
}
// in case of PLAY and UDP, timeout happens when no request arrives
case ss.state == ServerSessionStatePlay && *ss.setupProtocol == StreamProtocolUDP:
now := time.Now()
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
return liberrors.ErrServerSessionTimedOut{}
}
// otherwise, there's no timeout until all associated connections are closed
}
case <-receiverReportTicker.C:
if ss.state != ServerSessionStateRecord {
continue
}
now := time.Now()
for trackID, track := range ss.announcedTracks {
r := track.rtcpReceiver.Report(now)
ss.WriteFrame(trackID, StreamTypeRTCP, r)
}
case <-ss.innerTerminate:
return liberrors.ErrServerTerminated{}
}
}()
}
}()
var err error
select {
case err = <-readDone:
go func() {
for req := range ss.request {
req.res <- sessionReqRes{
go func() {
for {
select {
case req, ok := <-ss.request:
if !ok {
return
}
req.res <- requestRes{
ss: nil,
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: liberrors.ErrServerTerminated{},
}
case sc, ok := <-ss.connRemove:
if !ok {
return
}
if _, ok := ss.conns[sc]; ok {
ss.connsWG.Done()
}
}
}()
ss.s.sessionClose <- ss
<-ss.terminate
case <-ss.terminate:
select {
case ss.innerTerminate <- struct{}{}:
default:
}
<-readDone
err = liberrors.ErrServerTerminated{}
}
}()
switch ss.state {
case ServerSessionStatePlay:
@@ -321,11 +350,16 @@ func (ss *ServerSession) run() {
}
}
if ss.linkedConn != nil {
ss.s.connClose <- ss.linkedConn
for sc := range ss.conns {
sc.sessionRemove <- ss
}
ss.connsWG.Wait()
ss.s.sessionClose <- ss
<-ss.parentTerminate
close(ss.request)
close(ss.connRemove)
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
h.OnSessionClose(ss, err)
@@ -333,7 +367,7 @@ func (ss *ServerSession) run() {
}
func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) {
if ss.linkedConn != nil && sc != ss.linkedConn {
if ss.tcpConn != nil && sc != ss.tcpConn {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerSessionLinkedToOtherConn{}
@@ -620,10 +654,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
// this was causing problems during unit tests.
if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 &&
strings.HasPrefix(ua[0], "GStreamer") {
select {
case <-time.After(1 * time.Second):
case <-sc.terminate:
}
<-time.After(1 * time.Second)
}
return res, err
@@ -664,7 +695,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.udpIP = sc.ip()
ss.udpZone = sc.zone()
} else {
ss.linkedConn = sc
ss.tcpConn = sc
}
}
@@ -694,7 +725,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.udpIP = nil
ss.udpZone = ""
ss.linkedConn = nil
ss.tcpConn = nil
}
return res, err
@@ -732,7 +763,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.udpIP = sc.ip()
ss.udpZone = sc.zone()
} else {
ss.linkedConn = sc
ss.tcpConn = sc
}
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
@@ -766,7 +797,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.udpIP = nil
ss.udpZone = ""
ss.linkedConn = nil
ss.tcpConn = nil
return res, err
@@ -809,7 +840,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.state = ServerSessionStatePrePlay
ss.udpIP = nil
ss.udpZone = ""
ss.linkedConn = nil
ss.tcpConn = nil
if *ss.setupProtocol == StreamProtocolUDP {
ss.s.udpRTCPListener.removeClient(ss)
@@ -821,7 +852,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.state = ServerSessionStatePreRecord
ss.udpIP = nil
ss.udpZone = ""
ss.linkedConn = nil
ss.tcpConn = nil
if *ss.setupProtocol == StreamProtocolUDP {
ss.s.udpRTPListener.removeClient(ss)
@@ -834,8 +865,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
return res, err
case base.Teardown:
ss.linkedConn = nil
return &base.Response{
StatusCode: base.StatusOK,
}, liberrors.ErrServerSessionTeardown{}
@@ -896,7 +925,7 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload
})
}
} else {
ss.linkedConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{
ss.tcpConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{
TrackID: trackID,
StreamType: streamType,
Payload: payload,