server: close session when there are no conns attached to it

This commit is contained in:
aler9
2021-05-08 15:35:13 +02:00
parent d3361ffd90
commit 028ed2b973
6 changed files with 366 additions and 267 deletions

View File

@@ -14,7 +14,7 @@ Features:
* Client * Client
* Query servers about available streams * Query servers about available streams
* Encrypt connection with TLS (RTSPS) * Encrypt connections with TLS (RTSPS)
* Read * Read
* Read streams from servers with UDP or TCP * Read streams from servers with UDP or TCP
* Switch protocol automatically (switch to TCP in case of server error or UDP timeout) * Switch protocol automatically (switch to TCP in case of server error or UDP timeout)
@@ -29,9 +29,9 @@ Features:
* Server * Server
* Handle requests from clients * Handle requests from clients
* Sessions and connections are independent; clients can control multiple sessions * Sessions and connections are independent; clients can control multiple sessions
* Encrypt connections with TLS (RTSPS)
* Read streams from clients with UDP or TCP * Read streams from clients with UDP or TCP
* Write streams to clients with UDP or TCP * Write streams to clients with UDP or TCP
* Encrypt streams with TLS (RTSPS)
* Generate RTCP receiver reports automatically * Generate RTCP receiver reports automatically
* Utilities * Utilities
* Encode and decode RTSP primitives, RTP/H264, RTP/AAC, SDP * Encode and decode RTSP primitives, RTP/H264, RTP/AAC, SDP

View File

@@ -44,18 +44,18 @@ func newSessionID(sessions map[string]*ServerSession) (string, error) {
} }
} }
type sessionReqRes struct { type requestRes struct {
ss *ServerSession
res *base.Response res *base.Response
err error err error
ss *ServerSession
} }
type sessionReq struct { type request struct {
sc *ServerConn sc *ServerConn
req *base.Request req *base.Request
id string id string
create bool create bool
res chan sessionReqRes res chan requestRes
} }
// Server is a RTSP server. // Server is a RTSP server.
@@ -134,10 +134,10 @@ type Server struct {
exitError error exitError error
// in // in
connClose chan *ServerConn connClose chan *ServerConn
sessionReq chan sessionReq sessionRequest chan request
sessionClose chan *ServerSession sessionClose chan *ServerSession
terminate chan struct{} terminate chan struct{}
// out // out
done chan struct{} done chan struct{}
@@ -237,7 +237,7 @@ func (s *Server) run() {
s.sessions = make(map[string]*ServerSession) s.sessions = make(map[string]*ServerSession)
s.conns = make(map[*ServerConn]struct{}) s.conns = make(map[*ServerConn]struct{})
s.connClose = make(chan *ServerConn) s.connClose = make(chan *ServerConn)
s.sessionReq = make(chan sessionReq) s.sessionRequest = make(chan request)
s.sessionClose = make(chan *ServerSession) s.sessionClose = make(chan *ServerSession)
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -276,13 +276,13 @@ outer:
} }
s.doConnClose(sc) s.doConnClose(sc)
case req := <-s.sessionReq: case req := <-s.sessionRequest:
if ss, ok := s.sessions[req.id]; ok { if ss, ok := s.sessions[req.id]; ok {
ss.request <- req ss.request <- req
} else { } else {
if !req.create { if !req.create {
req.res <- sessionReqRes{ req.res <- requestRes{
res: &base.Response{ res: &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, },
@@ -293,7 +293,7 @@ outer:
id, err := newSessionID(s.sessions) id, err := newSessionID(s.sessions)
if err != nil { if err != nil {
req.res <- sessionReqRes{ req.res <- requestRes{
res: &base.Response{ res: &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, },
@@ -330,6 +330,7 @@ outer:
if !ok { if !ok {
return return
} }
nconn.Close() nconn.Close()
case _, ok := <-s.connClose: case _, ok := <-s.connClose:
@@ -337,11 +338,12 @@ outer:
return return
} }
case req, ok := <-s.sessionReq: case req, ok := <-s.sessionRequest:
if !ok { if !ok {
return return
} }
req.res <- sessionReqRes{
req.res <- requestRes{
res: &base.Response{ res: &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, },
@@ -379,7 +381,7 @@ outer:
close(acceptErr) close(acceptErr)
close(connNew) close(connNew)
close(s.connClose) close(s.connClose)
close(s.sessionReq) close(s.sessionRequest)
close(s.sessionClose) close(s.sessionClose)
close(s.done) close(s.done)
} }
@@ -409,10 +411,12 @@ func (s *Server) StartAndWait(address string) error {
func (s *Server) doConnClose(sc *ServerConn) { func (s *Server) doConnClose(sc *ServerConn) {
delete(s.conns, sc) delete(s.conns, sc)
close(sc.terminate) close(sc.parentTerminate)
sc.Close()
} }
func (s *Server) doSessionClose(ss *ServerSession) { func (s *Server) doSessionClose(ss *ServerSession) {
delete(s.sessions, ss.id) delete(s.sessions, ss.id)
close(ss.terminate) close(ss.parentTerminate)
ss.Close()
} }

View File

@@ -935,7 +935,7 @@ func TestServerReadPlayPausePause(t *testing.T) {
func TestServerReadTimeout(t *testing.T) { func TestServerReadTimeout(t *testing.T) {
for _, proto := range []string{ for _, proto := range []string{
"udp", "udp",
// checking TCP is useless, since there's no timeout when reading with TCP // there's no timeout when reading with TCP
} { } {
t.Run(proto, func(t *testing.T) { t.Run(proto, func(t *testing.T) {
sessionClosed := make(chan struct{}) sessionClosed := make(chan struct{})

View File

@@ -929,3 +929,53 @@ func TestServerSessionClose(t *testing.T) {
<-sessionClosed <-sessionClosed
} }
func TestServerSessionAutoClose(t *testing.T) {
sessionClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onSessionClose: func(ss *ServerSession, err error) {
close(sessionClosed)
},
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
}
err := s.Start("127.0.0.1:8554")
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
err = base.Request{
Method: base.Setup,
URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Transport": headers.Transport{
Protocol: StreamProtocolTCP,
Delivery: func() *base.StreamDelivery {
v := base.StreamDeliveryUnicast
return &v
}(),
Mode: func() *headers.TransportMode {
v := headers.TransportModePlay
return &v
}(),
InterleavedIDs: &[2]int{0, 1},
}.Write(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
conn.Close()
<-sessionClosed
}

View File

@@ -35,6 +35,11 @@ func getSessionID(header base.Header) string {
return "" return ""
} }
type readReq struct {
req *base.Request
res chan error
}
// ServerConn is a server-side RTSP connection. // ServerConn is a server-side RTSP connection.
type ServerConn struct { type ServerConn struct {
s *Server s *Server
@@ -43,19 +48,23 @@ type ServerConn struct {
br *bufio.Reader br *bufio.Reader
bw *bufio.Writer bw *bufio.Writer
sessions map[string]*ServerSession
sessionsWG sync.WaitGroup
// TCP stream protocol // TCP stream protocol
tcpFrameLinkedSession *ServerSession
tcpFrameIsRecording bool
tcpFrameSetEnabled bool tcpFrameSetEnabled bool
tcpFrameEnabled bool tcpFrameEnabled bool
tcpSession *ServerSession
tcpFrameIsRecording bool
tcpFrameTimeout bool tcpFrameTimeout bool
tcpFrameBuffer *multibuffer.MultiBuffer tcpFrameBuffer *multibuffer.MultiBuffer
tcpFrameWriteBuffer *ringbuffer.RingBuffer tcpFrameWriteBuffer *ringbuffer.RingBuffer
tcpFrameBackgroundWriteDone chan struct{} tcpFrameBackgroundWriteDone chan struct{}
// in // in
innerTerminate chan struct{} sessionRemove chan *ServerSession
terminate chan struct{} innerTerminate chan struct{}
parentTerminate chan struct{}
} }
func newServerConn( func newServerConn(
@@ -64,11 +73,12 @@ func newServerConn(
nconn net.Conn) *ServerConn { nconn net.Conn) *ServerConn {
sc := &ServerConn{ sc := &ServerConn{
s: s, s: s,
wg: wg, wg: wg,
nconn: nconn, nconn: nconn,
innerTerminate: make(chan struct{}, 1), sessionRemove: make(chan *ServerSession),
terminate: make(chan struct{}), innerTerminate: make(chan struct{}, 1),
parentTerminate: make(chan struct{}),
} }
wg.Add(1) wg.Add(1)
@@ -115,13 +125,17 @@ func (sc *ServerConn) run() {
sc.br = bufio.NewReaderSize(conn, serverConnReadBufferSize) sc.br = bufio.NewReaderSize(conn, serverConnReadBufferSize)
sc.bw = bufio.NewWriterSize(conn, serverConnWriteBufferSize) sc.bw = bufio.NewWriterSize(conn, serverConnWriteBufferSize)
sc.sessions = make(map[string]*ServerSession)
// instantiate always to allow writing to this conn before Play() // instantiate always to allow writing to this conn before Play()
sc.tcpFrameWriteBuffer = ringbuffer.New(uint64(sc.s.ReadBufferCount)) sc.tcpFrameWriteBuffer = ringbuffer.New(uint64(sc.s.ReadBufferCount))
readDone := make(chan error) readRequest := make(chan readReq)
readErr := make(chan error)
readDone := make(chan struct{})
go func() { go func() {
readDone <- func() error { defer close(readDone)
readErr <- func() error {
var req base.Request var req base.Request
var frame base.InterleavedFrame var frame base.InterleavedFrame
@@ -140,15 +154,15 @@ func (sc *ServerConn) run() {
switch what.(type) { switch what.(type) {
case *base.InterleavedFrame: case *base.InterleavedFrame:
// forward frame only if it has been set up // forward frame only if it has been set up
if _, ok := sc.tcpFrameLinkedSession.setuppedTracks[frame.TrackID]; ok { if _, ok := sc.tcpSession.setuppedTracks[frame.TrackID]; ok {
if sc.tcpFrameIsRecording { if sc.tcpFrameIsRecording {
sc.tcpFrameLinkedSession.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(), sc.tcpSession.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
frame.StreamType, frame.Payload) frame.StreamType, frame.Payload)
} }
if h, ok := sc.s.Handler.(ServerHandlerOnFrame); ok { if h, ok := sc.s.Handler.(ServerHandlerOnFrame); ok {
h.OnFrame(&ServerHandlerOnFrameCtx{ h.OnFrame(&ServerHandlerOnFrameCtx{
Session: sc.tcpFrameLinkedSession, Session: sc.tcpSession,
TrackID: frame.TrackID, TrackID: frame.TrackID,
StreamType: frame.StreamType, StreamType: frame.StreamType,
Payload: frame.Payload, Payload: frame.Payload,
@@ -157,7 +171,9 @@ func (sc *ServerConn) run() {
} }
case *base.Request: case *base.Request:
err := sc.handleRequestOuter(&req) cres := make(chan error)
readRequest <- readReq{req: &req, res: cres}
err := <-cres
if err != nil { if err != nil {
return err return err
} }
@@ -169,7 +185,9 @@ func (sc *ServerConn) run() {
return err return err
} }
err = sc.handleRequestOuter(&req) cres := make(chan error)
readRequest <- readReq{req: &req, res: cres}
err = <-cres
if err != nil { if err != nil {
return err return err
} }
@@ -179,51 +197,74 @@ func (sc *ServerConn) run() {
}() }()
err := func() error { err := func() error {
select { for {
case err := <-readDone: select {
if sc.tcpFrameEnabled { case req := <-readRequest:
sc.tcpFrameWriteBuffer.Close() req.res <- sc.handleRequestOuter(req.req)
<-sc.tcpFrameBackgroundWriteDone
case err := <-readErr:
return err
case ss := <-sc.sessionRemove:
if _, ok := sc.sessions[ss.ID()]; ok {
delete(sc.sessions, ss.ID())
ss.connRemove <- sc
sc.sessionsWG.Done()
}
case <-sc.innerTerminate:
return liberrors.ErrServerTerminated{}
} }
sc.nconn.Close()
sc.s.connClose <- sc
<-sc.terminate
return err
case <-sc.innerTerminate:
sc.nconn.Close()
<-readDone
if sc.tcpFrameEnabled {
sc.tcpFrameWriteBuffer.Close()
<-sc.tcpFrameBackgroundWriteDone
}
sc.s.connClose <- sc
<-sc.terminate
return liberrors.ErrServerTerminated{}
case <-sc.terminate:
sc.nconn.Close()
<-readDone
if sc.tcpFrameEnabled {
sc.tcpFrameWriteBuffer.Close()
<-sc.tcpFrameBackgroundWriteDone
}
return liberrors.ErrServerTerminated{}
} }
}() }()
go func() {
for {
select {
case req, ok := <-readRequest:
if !ok {
return
}
req.res <- liberrors.ErrServerTerminated{}
case _, ok := <-readErr:
if !ok {
return
}
case ss, ok := <-sc.sessionRemove:
if !ok {
return
}
if _, ok := sc.sessions[ss.ID()]; ok {
sc.sessionsWG.Done()
}
}
}
}()
sc.nconn.Close()
<-readDone
if sc.tcpFrameEnabled { if sc.tcpFrameEnabled {
sc.s.sessionClose <- sc.tcpFrameLinkedSession sc.tcpFrameWriteBuffer.Close()
<-sc.tcpFrameBackgroundWriteDone
} }
for _, ss := range sc.sessions {
ss.connRemove <- sc
}
sc.sessionsWG.Wait()
sc.s.connClose <- sc
<-sc.parentTerminate
close(readRequest)
close(readErr)
close(sc.sessionRemove)
if h, ok := sc.s.Handler.(ServerHandlerOnConnClose); ok { if h, ok := sc.s.Handler.(ServerHandlerOnConnClose); ok {
h.OnConnClose(sc, err) h.OnConnClose(sc, err)
} }
@@ -241,8 +282,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
// the connection can't communicate with another session // the connection can't communicate with another session
// if it's receiving or sending TCP frames. // if it's receiving or sending TCP frames.
if sc.tcpFrameLinkedSession != nil && if sc.tcpSession != nil &&
sxID != sc.tcpFrameLinkedSession.id { sxID != sc.tcpSession.id {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerLinkedToOtherSession{} }, liberrors.ErrServerLinkedToOtherSession{}
@@ -252,16 +293,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Options: case base.Options:
// handle request in session // handle request in session
if sxID != "" { if sxID != "" {
cres := make(chan sessionReqRes) _, res, err := sc.handleRequestInSession(sxID, req, false)
sc.s.sessionReq <- sessionReq{ return res, err
sc: sc,
req: req,
id: sxID,
create: false,
res: cres,
}
res := <-cres
return res.res, res.err
} }
// handle request here // handle request here
@@ -330,121 +363,65 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
case base.Announce: case base.Announce:
if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok { if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
cres := make(chan sessionReqRes) _, res, err := sc.handleRequestInSession(sxID, req, true)
sc.s.sessionReq <- sessionReq{ return res, err
sc: sc,
req: req,
id: sxID,
create: true,
res: cres,
}
res := <-cres
return res.res, res.err
} }
case base.Setup: case base.Setup:
if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok { if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
cres := make(chan sessionReqRes) _, res, err := sc.handleRequestInSession(sxID, req, true)
sc.s.sessionReq <- sessionReq{ return res, err
sc: sc,
req: req,
id: sxID,
create: true,
res: cres,
}
res := <-cres
return res.res, res.err
} }
case base.Play: case base.Play:
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok { if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
cres := make(chan sessionReqRes) ss, res, err := sc.handleRequestInSession(sxID, req, false)
sc.s.sessionReq <- sessionReq{
sc: sc,
req: req,
id: sxID,
create: false,
res: cres,
}
res := <-cres
if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok { if _, ok := err.(liberrors.ErrServerTCPFramesEnable); ok {
sc.tcpFrameLinkedSession = res.ss sc.tcpSession = ss
sc.tcpFrameIsRecording = false sc.tcpFrameIsRecording = false
sc.tcpFrameSetEnabled = true sc.tcpFrameSetEnabled = true
return res.res, nil return res, nil
} }
return res.res, res.err return res, err
} }
case base.Record: case base.Record:
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok { if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
cres := make(chan sessionReqRes) ss, res, err := sc.handleRequestInSession(sxID, req, false)
sc.s.sessionReq <- sessionReq{
sc: sc,
req: req,
id: sxID,
create: false,
res: cres,
}
res := <-cres
if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok { if _, ok := err.(liberrors.ErrServerTCPFramesEnable); ok {
sc.tcpFrameLinkedSession = res.ss sc.tcpSession = ss
sc.tcpFrameIsRecording = true sc.tcpFrameIsRecording = true
sc.tcpFrameSetEnabled = true sc.tcpFrameSetEnabled = true
return res.res, nil return res, nil
} }
return res.res, res.err return res, err
} }
case base.Pause: case base.Pause:
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok { if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
cres := make(chan sessionReqRes) _, res, err := sc.handleRequestInSession(sxID, req, false)
sc.s.sessionReq <- sessionReq{
sc: sc,
req: req,
id: sxID,
create: false,
res: cres,
}
res := <-cres
if _, ok := res.err.(liberrors.ErrServerTCPFramesDisable); ok { if _, ok := err.(liberrors.ErrServerTCPFramesDisable); ok {
sc.tcpFrameSetEnabled = false sc.tcpFrameSetEnabled = false
return res.res, nil return res, nil
} }
return res.res, res.err return res, err
} }
case base.Teardown: case base.Teardown:
cres := make(chan sessionReqRes) _, res, err := sc.handleRequestInSession(sxID, req, false)
sc.s.sessionReq <- sessionReq{ return res, err
sc: sc,
req: req,
id: sxID,
create: false,
res: cres,
}
res := <-cres
return res.res, res.err
case base.GetParameter: case base.GetParameter:
// handle request in session // handle request in session
if sxID != "" { if sxID != "" {
cres := make(chan sessionReqRes) _, res, err := sc.handleRequestInSession(sxID, req, false)
sc.s.sessionReq <- sessionReq{ return res, err
sc: sc,
req: req,
id: sxID,
create: false,
res: cres,
}
res := <-cres
return res.res, res.err
} }
// handle request here // handle request here
@@ -563,6 +540,45 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
return err return err
} }
func (sc *ServerConn) handleRequestInSession(sxID string, req *base.Request, create bool,
) (*ServerSession, *base.Response, error) {
// if the session is already linked to this conn, communicate directly with it
if sxID != "" {
if ss, ok := sc.sessions[sxID]; ok {
cres := make(chan requestRes)
ss.request <- request{
sc: sc,
req: req,
id: sxID,
create: create,
res: cres,
}
res := <-cres
return ss, res.res, res.err
}
}
// otherwise, pass through Server
cres := make(chan requestRes)
sc.s.sessionRequest <- request{
sc: sc,
req: req,
id: sxID,
create: create,
res: cres,
}
res := <-cres
if res.ss != nil {
sc.sessions[res.ss.ID()] = res.ss
sc.sessionsWG.Add(1)
}
return res.ss, res.res, res.err
}
func (sc *ServerConn) tcpFrameBackgroundWrite() { func (sc *ServerConn) tcpFrameBackgroundWrite() {
defer close(sc.tcpFrameBackgroundWriteDone) defer close(sc.tcpFrameBackgroundWriteDone)

View File

@@ -118,33 +118,42 @@ type ServerSession struct {
id string id string
wg *sync.WaitGroup wg *sync.WaitGroup
conns map[*ServerConn]struct{}
connsWG sync.WaitGroup
state ServerSessionState state ServerSessionState
setuppedTracks map[int]ServerSessionSetuppedTrack setuppedTracks map[int]ServerSessionSetuppedTrack
setupProtocol *StreamProtocol setupProtocol *StreamProtocol
setupPath *string setupPath *string
setupQuery *string setupQuery *string
lastRequestTime time.Time lastRequestTime time.Time
linkedConn *ServerConn // tcp tcpConn *ServerConn // tcp
udpIP net.IP // udp udpIP net.IP // udp
udpZone string // udp udpZone string // udp
announcedTracks []ServerSessionAnnouncedTrack // publish announcedTracks []ServerSessionAnnouncedTrack // publish
udpLastFrameTime *int64 // publish, udp udpLastFrameTime *int64 // publish, udp
// in // in
request chan sessionReq request chan request
innerTerminate chan struct{} connRemove chan *ServerConn
terminate chan struct{} innerTerminate chan struct{}
parentTerminate chan struct{}
} }
func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession { func newServerSession(s *Server,
id string,
wg *sync.WaitGroup,
) *ServerSession {
ss := &ServerSession{ ss := &ServerSession{
s: s, s: s,
id: id, id: id,
wg: wg, wg: wg,
conns: make(map[*ServerConn]struct{}),
lastRequestTime: time.Now(), lastRequestTime: time.Now(),
request: make(chan sessionReq), request: make(chan request),
connRemove: make(chan *ServerConn),
innerTerminate: make(chan struct{}, 1), innerTerminate: make(chan struct{}, 1),
terminate: make(chan struct{}), parentTerminate: make(chan struct{}),
} }
wg.Add(1) wg.Add(1)
@@ -208,105 +217,125 @@ func (ss *ServerSession) run() {
h.OnSessionOpen(ss) h.OnSessionOpen(ss)
} }
readDone := make(chan error) err := func() error {
go func() { checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
readDone <- func() error { defer checkTimeoutTicker.Stop()
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
defer checkTimeoutTicker.Stop()
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod) receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
defer receiverReportTicker.Stop() defer receiverReportTicker.Stop()
for { for {
select { select {
case req := <-ss.request: case req := <-ss.request:
res, err := ss.handleRequest(req.sc, req.req) ss.lastRequestTime = time.Now()
ss.lastRequestTime = time.Now() if _, ok := ss.conns[req.sc]; !ok {
ss.conns[req.sc] = struct{}{}
ss.connsWG.Add(1)
}
if res.StatusCode == base.StatusOK { res, err := ss.handleRequest(req.sc, req.req)
if res.Header == nil {
res.Header = make(base.Header) if res.StatusCode == base.StatusOK {
} if res.Header == nil {
res.Header["Session"] = base.HeaderValue{ss.id} res.Header = make(base.Header)
} }
res.Header["Session"] = base.HeaderValue{ss.id}
}
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok { if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
req.res <- sessionReqRes{res: res, err: nil} req.res <- requestRes{res: res, err: nil}
return liberrors.ErrServerSessionTeardown{}
}
req.res <- requestRes{
res: res,
err: err,
ss: ss,
}
case sc := <-ss.connRemove:
if _, ok := ss.conns[sc]; ok {
delete(ss.conns, sc)
sc.sessionRemove <- ss
ss.connsWG.Done()
}
// if session is not in state RECORD or PLAY, or protocol is TCP
if (ss.state != ServerSessionStateRecord &&
ss.state != ServerSessionStatePlay) ||
*ss.setupProtocol == StreamProtocolTCP {
// close if there are no active connections
if len(ss.conns) == 0 {
return liberrors.ErrServerSessionTeardown{} return liberrors.ErrServerSessionTeardown{}
} }
req.res <- sessionReqRes{
res: res,
err: err,
ss: ss,
}
case <-checkTimeoutTicker.C:
switch {
// in case of record and UDP, timeout happens when no frames are being received
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
now := time.Now()
lft := atomic.LoadInt64(ss.udpLastFrameTime)
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
return liberrors.ErrServerSessionTimedOut{}
}
// in case there's a linked TCP connection, timeout is handled in the connection
case ss.linkedConn != nil:
// otherwise, timeout happens when no requests arrives
default:
now := time.Now()
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
return liberrors.ErrServerSessionTimedOut{}
}
}
case <-receiverReportTicker.C:
if ss.state != ServerSessionStateRecord {
continue
}
now := time.Now()
for trackID, track := range ss.announcedTracks {
r := track.rtcpReceiver.Report(now)
ss.WriteFrame(trackID, StreamTypeRTCP, r)
}
case <-ss.innerTerminate:
return liberrors.ErrServerTerminated{}
} }
case <-checkTimeoutTicker.C:
switch {
// in case of RECORD and UDP, timeout happens when no frames are being received
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
now := time.Now()
lft := atomic.LoadInt64(ss.udpLastFrameTime)
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
return liberrors.ErrServerSessionTimedOut{}
}
// in case of PLAY and UDP, timeout happens when no request arrives
case ss.state == ServerSessionStatePlay && *ss.setupProtocol == StreamProtocolUDP:
now := time.Now()
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
return liberrors.ErrServerSessionTimedOut{}
}
// otherwise, there's no timeout until all associated connections are closed
}
case <-receiverReportTicker.C:
if ss.state != ServerSessionStateRecord {
continue
}
now := time.Now()
for trackID, track := range ss.announcedTracks {
r := track.rtcpReceiver.Report(now)
ss.WriteFrame(trackID, StreamTypeRTCP, r)
}
case <-ss.innerTerminate:
return liberrors.ErrServerTerminated{}
} }
}() }
}() }()
var err error go func() {
select { for {
case err = <-readDone: select {
go func() { case req, ok := <-ss.request:
for req := range ss.request { if !ok {
req.res <- sessionReqRes{ return
}
req.res <- requestRes{
ss: nil,
res: &base.Response{ res: &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, },
err: liberrors.ErrServerTerminated{}, err: liberrors.ErrServerTerminated{},
} }
case sc, ok := <-ss.connRemove:
if !ok {
return
}
if _, ok := ss.conns[sc]; ok {
ss.connsWG.Done()
}
} }
}()
ss.s.sessionClose <- ss
<-ss.terminate
case <-ss.terminate:
select {
case ss.innerTerminate <- struct{}{}:
default:
} }
<-readDone }()
err = liberrors.ErrServerTerminated{}
}
switch ss.state { switch ss.state {
case ServerSessionStatePlay: case ServerSessionStatePlay:
@@ -321,11 +350,16 @@ func (ss *ServerSession) run() {
} }
} }
if ss.linkedConn != nil { for sc := range ss.conns {
ss.s.connClose <- ss.linkedConn sc.sessionRemove <- ss
} }
ss.connsWG.Wait()
ss.s.sessionClose <- ss
<-ss.parentTerminate
close(ss.request) close(ss.request)
close(ss.connRemove)
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok { if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
h.OnSessionClose(ss, err) h.OnSessionClose(ss, err)
@@ -333,7 +367,7 @@ func (ss *ServerSession) run() {
} }
func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) { func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) {
if ss.linkedConn != nil && sc != ss.linkedConn { if ss.tcpConn != nil && sc != ss.tcpConn {
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerSessionLinkedToOtherConn{} }, liberrors.ErrServerSessionLinkedToOtherConn{}
@@ -620,10 +654,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
// this was causing problems during unit tests. // this was causing problems during unit tests.
if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 && if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 &&
strings.HasPrefix(ua[0], "GStreamer") { strings.HasPrefix(ua[0], "GStreamer") {
select { <-time.After(1 * time.Second)
case <-time.After(1 * time.Second):
case <-sc.terminate:
}
} }
return res, err return res, err
@@ -664,7 +695,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.udpIP = sc.ip() ss.udpIP = sc.ip()
ss.udpZone = sc.zone() ss.udpZone = sc.zone()
} else { } else {
ss.linkedConn = sc ss.tcpConn = sc
} }
} }
@@ -694,7 +725,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.udpIP = nil ss.udpIP = nil
ss.udpZone = "" ss.udpZone = ""
ss.linkedConn = nil ss.tcpConn = nil
} }
return res, err return res, err
@@ -732,7 +763,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.udpIP = sc.ip() ss.udpIP = sc.ip()
ss.udpZone = sc.zone() ss.udpZone = sc.zone()
} else { } else {
ss.linkedConn = sc ss.tcpConn = sc
} }
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{ res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
@@ -766,7 +797,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.udpIP = nil ss.udpIP = nil
ss.udpZone = "" ss.udpZone = ""
ss.linkedConn = nil ss.tcpConn = nil
return res, err return res, err
@@ -809,7 +840,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.state = ServerSessionStatePrePlay ss.state = ServerSessionStatePrePlay
ss.udpIP = nil ss.udpIP = nil
ss.udpZone = "" ss.udpZone = ""
ss.linkedConn = nil ss.tcpConn = nil
if *ss.setupProtocol == StreamProtocolUDP { if *ss.setupProtocol == StreamProtocolUDP {
ss.s.udpRTCPListener.removeClient(ss) ss.s.udpRTCPListener.removeClient(ss)
@@ -821,7 +852,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.state = ServerSessionStatePreRecord ss.state = ServerSessionStatePreRecord
ss.udpIP = nil ss.udpIP = nil
ss.udpZone = "" ss.udpZone = ""
ss.linkedConn = nil ss.tcpConn = nil
if *ss.setupProtocol == StreamProtocolUDP { if *ss.setupProtocol == StreamProtocolUDP {
ss.s.udpRTPListener.removeClient(ss) ss.s.udpRTPListener.removeClient(ss)
@@ -834,8 +865,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
return res, err return res, err
case base.Teardown: case base.Teardown:
ss.linkedConn = nil
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, liberrors.ErrServerSessionTeardown{} }, liberrors.ErrServerSessionTeardown{}
@@ -896,7 +925,7 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload
}) })
} }
} else { } else {
ss.linkedConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{ ss.tcpConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{
TrackID: trackID, TrackID: trackID,
StreamType: streamType, StreamType: streamType,
Payload: payload, Payload: payload,