mirror of
https://github.com/aler9/gortsplib
synced 2025-10-21 14:19:36 +08:00
server: close session when there are no conns attached to it
This commit is contained in:
235
serversession.go
235
serversession.go
@@ -118,33 +118,42 @@ type ServerSession struct {
|
||||
id string
|
||||
wg *sync.WaitGroup
|
||||
|
||||
conns map[*ServerConn]struct{}
|
||||
connsWG sync.WaitGroup
|
||||
state ServerSessionState
|
||||
setuppedTracks map[int]ServerSessionSetuppedTrack
|
||||
setupProtocol *StreamProtocol
|
||||
setupPath *string
|
||||
setupQuery *string
|
||||
lastRequestTime time.Time
|
||||
linkedConn *ServerConn // tcp
|
||||
tcpConn *ServerConn // tcp
|
||||
udpIP net.IP // udp
|
||||
udpZone string // udp
|
||||
announcedTracks []ServerSessionAnnouncedTrack // publish
|
||||
udpLastFrameTime *int64 // publish, udp
|
||||
|
||||
// in
|
||||
request chan sessionReq
|
||||
innerTerminate chan struct{}
|
||||
terminate chan struct{}
|
||||
request chan request
|
||||
connRemove chan *ServerConn
|
||||
innerTerminate chan struct{}
|
||||
parentTerminate chan struct{}
|
||||
}
|
||||
|
||||
func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
|
||||
func newServerSession(s *Server,
|
||||
id string,
|
||||
wg *sync.WaitGroup,
|
||||
) *ServerSession {
|
||||
|
||||
ss := &ServerSession{
|
||||
s: s,
|
||||
id: id,
|
||||
wg: wg,
|
||||
conns: make(map[*ServerConn]struct{}),
|
||||
lastRequestTime: time.Now(),
|
||||
request: make(chan sessionReq),
|
||||
request: make(chan request),
|
||||
connRemove: make(chan *ServerConn),
|
||||
innerTerminate: make(chan struct{}, 1),
|
||||
terminate: make(chan struct{}),
|
||||
parentTerminate: make(chan struct{}),
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
@@ -208,105 +217,125 @@ func (ss *ServerSession) run() {
|
||||
h.OnSessionOpen(ss)
|
||||
}
|
||||
|
||||
readDone := make(chan error)
|
||||
go func() {
|
||||
readDone <- func() error {
|
||||
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
|
||||
defer checkTimeoutTicker.Stop()
|
||||
err := func() error {
|
||||
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
|
||||
defer checkTimeoutTicker.Stop()
|
||||
|
||||
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
|
||||
defer receiverReportTicker.Stop()
|
||||
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
|
||||
defer receiverReportTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-ss.request:
|
||||
res, err := ss.handleRequest(req.sc, req.req)
|
||||
for {
|
||||
select {
|
||||
case req := <-ss.request:
|
||||
ss.lastRequestTime = time.Now()
|
||||
|
||||
ss.lastRequestTime = time.Now()
|
||||
if _, ok := ss.conns[req.sc]; !ok {
|
||||
ss.conns[req.sc] = struct{}{}
|
||||
ss.connsWG.Add(1)
|
||||
}
|
||||
|
||||
if res.StatusCode == base.StatusOK {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
res, err := ss.handleRequest(req.sc, req.req)
|
||||
|
||||
if res.StatusCode == base.StatusOK {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
}
|
||||
|
||||
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
|
||||
req.res <- sessionReqRes{res: res, err: nil}
|
||||
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
|
||||
req.res <- requestRes{res: res, err: nil}
|
||||
return liberrors.ErrServerSessionTeardown{}
|
||||
}
|
||||
|
||||
req.res <- requestRes{
|
||||
res: res,
|
||||
err: err,
|
||||
ss: ss,
|
||||
}
|
||||
|
||||
case sc := <-ss.connRemove:
|
||||
if _, ok := ss.conns[sc]; ok {
|
||||
delete(ss.conns, sc)
|
||||
sc.sessionRemove <- ss
|
||||
ss.connsWG.Done()
|
||||
}
|
||||
|
||||
// if session is not in state RECORD or PLAY, or protocol is TCP
|
||||
if (ss.state != ServerSessionStateRecord &&
|
||||
ss.state != ServerSessionStatePlay) ||
|
||||
*ss.setupProtocol == StreamProtocolTCP {
|
||||
|
||||
// close if there are no active connections
|
||||
if len(ss.conns) == 0 {
|
||||
return liberrors.ErrServerSessionTeardown{}
|
||||
}
|
||||
|
||||
req.res <- sessionReqRes{
|
||||
res: res,
|
||||
err: err,
|
||||
ss: ss,
|
||||
}
|
||||
|
||||
case <-checkTimeoutTicker.C:
|
||||
switch {
|
||||
// in case of record and UDP, timeout happens when no frames are being received
|
||||
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
|
||||
now := time.Now()
|
||||
lft := atomic.LoadInt64(ss.udpLastFrameTime)
|
||||
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
|
||||
// in case there's a linked TCP connection, timeout is handled in the connection
|
||||
case ss.linkedConn != nil:
|
||||
|
||||
// otherwise, timeout happens when no requests arrives
|
||||
default:
|
||||
now := time.Now()
|
||||
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
}
|
||||
|
||||
case <-receiverReportTicker.C:
|
||||
if ss.state != ServerSessionStateRecord {
|
||||
continue
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for trackID, track := range ss.announcedTracks {
|
||||
r := track.rtcpReceiver.Report(now)
|
||||
ss.WriteFrame(trackID, StreamTypeRTCP, r)
|
||||
}
|
||||
|
||||
case <-ss.innerTerminate:
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
|
||||
case <-checkTimeoutTicker.C:
|
||||
switch {
|
||||
// in case of RECORD and UDP, timeout happens when no frames are being received
|
||||
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
|
||||
now := time.Now()
|
||||
lft := atomic.LoadInt64(ss.udpLastFrameTime)
|
||||
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
|
||||
// in case of PLAY and UDP, timeout happens when no request arrives
|
||||
case ss.state == ServerSessionStatePlay && *ss.setupProtocol == StreamProtocolUDP:
|
||||
now := time.Now()
|
||||
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
|
||||
// otherwise, there's no timeout until all associated connections are closed
|
||||
}
|
||||
|
||||
case <-receiverReportTicker.C:
|
||||
if ss.state != ServerSessionStateRecord {
|
||||
continue
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for trackID, track := range ss.announcedTracks {
|
||||
r := track.rtcpReceiver.Report(now)
|
||||
ss.WriteFrame(trackID, StreamTypeRTCP, r)
|
||||
}
|
||||
|
||||
case <-ss.innerTerminate:
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
var err error
|
||||
select {
|
||||
case err = <-readDone:
|
||||
go func() {
|
||||
for req := range ss.request {
|
||||
req.res <- sessionReqRes{
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case req, ok := <-ss.request:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
req.res <- requestRes{
|
||||
ss: nil,
|
||||
res: &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
},
|
||||
err: liberrors.ErrServerTerminated{},
|
||||
}
|
||||
|
||||
case sc, ok := <-ss.connRemove:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := ss.conns[sc]; ok {
|
||||
ss.connsWG.Done()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ss.s.sessionClose <- ss
|
||||
<-ss.terminate
|
||||
|
||||
case <-ss.terminate:
|
||||
select {
|
||||
case ss.innerTerminate <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-readDone
|
||||
|
||||
err = liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}()
|
||||
|
||||
switch ss.state {
|
||||
case ServerSessionStatePlay:
|
||||
@@ -321,11 +350,16 @@ func (ss *ServerSession) run() {
|
||||
}
|
||||
}
|
||||
|
||||
if ss.linkedConn != nil {
|
||||
ss.s.connClose <- ss.linkedConn
|
||||
for sc := range ss.conns {
|
||||
sc.sessionRemove <- ss
|
||||
}
|
||||
ss.connsWG.Wait()
|
||||
|
||||
ss.s.sessionClose <- ss
|
||||
<-ss.parentTerminate
|
||||
|
||||
close(ss.request)
|
||||
close(ss.connRemove)
|
||||
|
||||
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
|
||||
h.OnSessionClose(ss, err)
|
||||
@@ -333,7 +367,7 @@ func (ss *ServerSession) run() {
|
||||
}
|
||||
|
||||
func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) {
|
||||
if ss.linkedConn != nil && sc != ss.linkedConn {
|
||||
if ss.tcpConn != nil && sc != ss.tcpConn {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerSessionLinkedToOtherConn{}
|
||||
@@ -620,10 +654,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
// this was causing problems during unit tests.
|
||||
if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 &&
|
||||
strings.HasPrefix(ua[0], "GStreamer") {
|
||||
select {
|
||||
case <-time.After(1 * time.Second):
|
||||
case <-sc.terminate:
|
||||
}
|
||||
<-time.After(1 * time.Second)
|
||||
}
|
||||
|
||||
return res, err
|
||||
@@ -664,7 +695,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
ss.udpIP = sc.ip()
|
||||
ss.udpZone = sc.zone()
|
||||
} else {
|
||||
ss.linkedConn = sc
|
||||
ss.tcpConn = sc
|
||||
}
|
||||
}
|
||||
|
||||
@@ -694,7 +725,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
|
||||
ss.udpIP = nil
|
||||
ss.udpZone = ""
|
||||
ss.linkedConn = nil
|
||||
ss.tcpConn = nil
|
||||
}
|
||||
|
||||
return res, err
|
||||
@@ -732,7 +763,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
ss.udpIP = sc.ip()
|
||||
ss.udpZone = sc.zone()
|
||||
} else {
|
||||
ss.linkedConn = sc
|
||||
ss.tcpConn = sc
|
||||
}
|
||||
|
||||
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
|
||||
@@ -766,7 +797,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
|
||||
ss.udpIP = nil
|
||||
ss.udpZone = ""
|
||||
ss.linkedConn = nil
|
||||
ss.tcpConn = nil
|
||||
|
||||
return res, err
|
||||
|
||||
@@ -809,7 +840,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
ss.state = ServerSessionStatePrePlay
|
||||
ss.udpIP = nil
|
||||
ss.udpZone = ""
|
||||
ss.linkedConn = nil
|
||||
ss.tcpConn = nil
|
||||
|
||||
if *ss.setupProtocol == StreamProtocolUDP {
|
||||
ss.s.udpRTCPListener.removeClient(ss)
|
||||
@@ -821,7 +852,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
ss.state = ServerSessionStatePreRecord
|
||||
ss.udpIP = nil
|
||||
ss.udpZone = ""
|
||||
ss.linkedConn = nil
|
||||
ss.tcpConn = nil
|
||||
|
||||
if *ss.setupProtocol == StreamProtocolUDP {
|
||||
ss.s.udpRTPListener.removeClient(ss)
|
||||
@@ -834,8 +865,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
return res, err
|
||||
|
||||
case base.Teardown:
|
||||
ss.linkedConn = nil
|
||||
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, liberrors.ErrServerSessionTeardown{}
|
||||
@@ -896,7 +925,7 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload
|
||||
})
|
||||
}
|
||||
} else {
|
||||
ss.linkedConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{
|
||||
ss.tcpConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{
|
||||
TrackID: trackID,
|
||||
StreamType: streamType,
|
||||
Payload: payload,
|
||||
|
Reference in New Issue
Block a user