mirror of
https://github.com/aler9/gortsplib
synced 2025-10-05 07:06:58 +08:00
server: close session when there are no conns attached to it
This commit is contained in:
@@ -14,7 +14,7 @@ Features:
|
||||
|
||||
* Client
|
||||
* Query servers about available streams
|
||||
* Encrypt connection with TLS (RTSPS)
|
||||
* Encrypt connections with TLS (RTSPS)
|
||||
* Read
|
||||
* Read streams from servers with UDP or TCP
|
||||
* Switch protocol automatically (switch to TCP in case of server error or UDP timeout)
|
||||
@@ -29,9 +29,9 @@ Features:
|
||||
* Server
|
||||
* Handle requests from clients
|
||||
* Sessions and connections are independent; clients can control multiple sessions
|
||||
* Encrypt connections with TLS (RTSPS)
|
||||
* Read streams from clients with UDP or TCP
|
||||
* Write streams to clients with UDP or TCP
|
||||
* Encrypt streams with TLS (RTSPS)
|
||||
* Generate RTCP receiver reports automatically
|
||||
* Utilities
|
||||
* Encode and decode RTSP primitives, RTP/H264, RTP/AAC, SDP
|
||||
|
38
server.go
38
server.go
@@ -44,18 +44,18 @@ func newSessionID(sessions map[string]*ServerSession) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
type sessionReqRes struct {
|
||||
type requestRes struct {
|
||||
ss *ServerSession
|
||||
res *base.Response
|
||||
err error
|
||||
ss *ServerSession
|
||||
}
|
||||
|
||||
type sessionReq struct {
|
||||
type request struct {
|
||||
sc *ServerConn
|
||||
req *base.Request
|
||||
id string
|
||||
create bool
|
||||
res chan sessionReqRes
|
||||
res chan requestRes
|
||||
}
|
||||
|
||||
// Server is a RTSP server.
|
||||
@@ -134,10 +134,10 @@ type Server struct {
|
||||
exitError error
|
||||
|
||||
// in
|
||||
connClose chan *ServerConn
|
||||
sessionReq chan sessionReq
|
||||
sessionClose chan *ServerSession
|
||||
terminate chan struct{}
|
||||
connClose chan *ServerConn
|
||||
sessionRequest chan request
|
||||
sessionClose chan *ServerSession
|
||||
terminate chan struct{}
|
||||
|
||||
// out
|
||||
done chan struct{}
|
||||
@@ -237,7 +237,7 @@ func (s *Server) run() {
|
||||
s.sessions = make(map[string]*ServerSession)
|
||||
s.conns = make(map[*ServerConn]struct{})
|
||||
s.connClose = make(chan *ServerConn)
|
||||
s.sessionReq = make(chan sessionReq)
|
||||
s.sessionRequest = make(chan request)
|
||||
s.sessionClose = make(chan *ServerSession)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
@@ -276,13 +276,13 @@ outer:
|
||||
}
|
||||
s.doConnClose(sc)
|
||||
|
||||
case req := <-s.sessionReq:
|
||||
case req := <-s.sessionRequest:
|
||||
if ss, ok := s.sessions[req.id]; ok {
|
||||
ss.request <- req
|
||||
|
||||
} else {
|
||||
if !req.create {
|
||||
req.res <- sessionReqRes{
|
||||
req.res <- requestRes{
|
||||
res: &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
},
|
||||
@@ -293,7 +293,7 @@ outer:
|
||||
|
||||
id, err := newSessionID(s.sessions)
|
||||
if err != nil {
|
||||
req.res <- sessionReqRes{
|
||||
req.res <- requestRes{
|
||||
res: &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
},
|
||||
@@ -330,6 +330,7 @@ outer:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
nconn.Close()
|
||||
|
||||
case _, ok := <-s.connClose:
|
||||
@@ -337,11 +338,12 @@ outer:
|
||||
return
|
||||
}
|
||||
|
||||
case req, ok := <-s.sessionReq:
|
||||
case req, ok := <-s.sessionRequest:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
req.res <- sessionReqRes{
|
||||
|
||||
req.res <- requestRes{
|
||||
res: &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
},
|
||||
@@ -379,7 +381,7 @@ outer:
|
||||
close(acceptErr)
|
||||
close(connNew)
|
||||
close(s.connClose)
|
||||
close(s.sessionReq)
|
||||
close(s.sessionRequest)
|
||||
close(s.sessionClose)
|
||||
close(s.done)
|
||||
}
|
||||
@@ -409,10 +411,12 @@ func (s *Server) StartAndWait(address string) error {
|
||||
|
||||
func (s *Server) doConnClose(sc *ServerConn) {
|
||||
delete(s.conns, sc)
|
||||
close(sc.terminate)
|
||||
close(sc.parentTerminate)
|
||||
sc.Close()
|
||||
}
|
||||
|
||||
func (s *Server) doSessionClose(ss *ServerSession) {
|
||||
delete(s.sessions, ss.id)
|
||||
close(ss.terminate)
|
||||
close(ss.parentTerminate)
|
||||
ss.Close()
|
||||
}
|
||||
|
@@ -935,7 +935,7 @@ func TestServerReadPlayPausePause(t *testing.T) {
|
||||
func TestServerReadTimeout(t *testing.T) {
|
||||
for _, proto := range []string{
|
||||
"udp",
|
||||
// checking TCP is useless, since there's no timeout when reading with TCP
|
||||
// there's no timeout when reading with TCP
|
||||
} {
|
||||
t.Run(proto, func(t *testing.T) {
|
||||
sessionClosed := make(chan struct{})
|
||||
|
@@ -929,3 +929,53 @@ func TestServerSessionClose(t *testing.T) {
|
||||
|
||||
<-sessionClosed
|
||||
}
|
||||
|
||||
func TestServerSessionAutoClose(t *testing.T) {
|
||||
sessionClosed := make(chan struct{})
|
||||
|
||||
s := &Server{
|
||||
Handler: &testServerHandler{
|
||||
onSessionClose: func(ss *ServerSession, err error) {
|
||||
close(sessionClosed)
|
||||
},
|
||||
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := s.Start("127.0.0.1:8554")
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
conn, err := net.Dial("tcp", "localhost:8554")
|
||||
require.NoError(t, err)
|
||||
bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||||
|
||||
err = base.Request{
|
||||
Method: base.Setup,
|
||||
URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
|
||||
Header: base.Header{
|
||||
"CSeq": base.HeaderValue{"1"},
|
||||
"Transport": headers.Transport{
|
||||
Protocol: StreamProtocolTCP,
|
||||
Delivery: func() *base.StreamDelivery {
|
||||
v := base.StreamDeliveryUnicast
|
||||
return &v
|
||||
}(),
|
||||
Mode: func() *headers.TransportMode {
|
||||
v := headers.TransportModePlay
|
||||
return &v
|
||||
}(),
|
||||
InterleavedIDs: &[2]int{0, 1},
|
||||
}.Write(),
|
||||
},
|
||||
}.Write(bconn.Writer)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn.Close()
|
||||
|
||||
<-sessionClosed
|
||||
}
|
||||
|
304
serverconn.go
304
serverconn.go
@@ -35,6 +35,11 @@ func getSessionID(header base.Header) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type readReq struct {
|
||||
req *base.Request
|
||||
res chan error
|
||||
}
|
||||
|
||||
// ServerConn is a server-side RTSP connection.
|
||||
type ServerConn struct {
|
||||
s *Server
|
||||
@@ -43,19 +48,23 @@ type ServerConn struct {
|
||||
br *bufio.Reader
|
||||
bw *bufio.Writer
|
||||
|
||||
sessions map[string]*ServerSession
|
||||
sessionsWG sync.WaitGroup
|
||||
|
||||
// TCP stream protocol
|
||||
tcpFrameLinkedSession *ServerSession
|
||||
tcpFrameIsRecording bool
|
||||
tcpFrameSetEnabled bool
|
||||
tcpFrameEnabled bool
|
||||
tcpSession *ServerSession
|
||||
tcpFrameIsRecording bool
|
||||
tcpFrameTimeout bool
|
||||
tcpFrameBuffer *multibuffer.MultiBuffer
|
||||
tcpFrameWriteBuffer *ringbuffer.RingBuffer
|
||||
tcpFrameBackgroundWriteDone chan struct{}
|
||||
|
||||
// in
|
||||
innerTerminate chan struct{}
|
||||
terminate chan struct{}
|
||||
sessionRemove chan *ServerSession
|
||||
innerTerminate chan struct{}
|
||||
parentTerminate chan struct{}
|
||||
}
|
||||
|
||||
func newServerConn(
|
||||
@@ -64,11 +73,12 @@ func newServerConn(
|
||||
nconn net.Conn) *ServerConn {
|
||||
|
||||
sc := &ServerConn{
|
||||
s: s,
|
||||
wg: wg,
|
||||
nconn: nconn,
|
||||
innerTerminate: make(chan struct{}, 1),
|
||||
terminate: make(chan struct{}),
|
||||
s: s,
|
||||
wg: wg,
|
||||
nconn: nconn,
|
||||
sessionRemove: make(chan *ServerSession),
|
||||
innerTerminate: make(chan struct{}, 1),
|
||||
parentTerminate: make(chan struct{}),
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
@@ -115,13 +125,17 @@ func (sc *ServerConn) run() {
|
||||
|
||||
sc.br = bufio.NewReaderSize(conn, serverConnReadBufferSize)
|
||||
sc.bw = bufio.NewWriterSize(conn, serverConnWriteBufferSize)
|
||||
sc.sessions = make(map[string]*ServerSession)
|
||||
|
||||
// instantiate always to allow writing to this conn before Play()
|
||||
sc.tcpFrameWriteBuffer = ringbuffer.New(uint64(sc.s.ReadBufferCount))
|
||||
|
||||
readDone := make(chan error)
|
||||
readRequest := make(chan readReq)
|
||||
readErr := make(chan error)
|
||||
readDone := make(chan struct{})
|
||||
go func() {
|
||||
readDone <- func() error {
|
||||
defer close(readDone)
|
||||
readErr <- func() error {
|
||||
var req base.Request
|
||||
var frame base.InterleavedFrame
|
||||
|
||||
@@ -140,15 +154,15 @@ func (sc *ServerConn) run() {
|
||||
switch what.(type) {
|
||||
case *base.InterleavedFrame:
|
||||
// forward frame only if it has been set up
|
||||
if _, ok := sc.tcpFrameLinkedSession.setuppedTracks[frame.TrackID]; ok {
|
||||
if _, ok := sc.tcpSession.setuppedTracks[frame.TrackID]; ok {
|
||||
if sc.tcpFrameIsRecording {
|
||||
sc.tcpFrameLinkedSession.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
|
||||
sc.tcpSession.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(),
|
||||
frame.StreamType, frame.Payload)
|
||||
}
|
||||
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnFrame); ok {
|
||||
h.OnFrame(&ServerHandlerOnFrameCtx{
|
||||
Session: sc.tcpFrameLinkedSession,
|
||||
Session: sc.tcpSession,
|
||||
TrackID: frame.TrackID,
|
||||
StreamType: frame.StreamType,
|
||||
Payload: frame.Payload,
|
||||
@@ -157,7 +171,9 @@ func (sc *ServerConn) run() {
|
||||
}
|
||||
|
||||
case *base.Request:
|
||||
err := sc.handleRequestOuter(&req)
|
||||
cres := make(chan error)
|
||||
readRequest <- readReq{req: &req, res: cres}
|
||||
err := <-cres
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -169,7 +185,9 @@ func (sc *ServerConn) run() {
|
||||
return err
|
||||
}
|
||||
|
||||
err = sc.handleRequestOuter(&req)
|
||||
cres := make(chan error)
|
||||
readRequest <- readReq{req: &req, res: cres}
|
||||
err = <-cres
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -179,51 +197,74 @@ func (sc *ServerConn) run() {
|
||||
}()
|
||||
|
||||
err := func() error {
|
||||
select {
|
||||
case err := <-readDone:
|
||||
if sc.tcpFrameEnabled {
|
||||
sc.tcpFrameWriteBuffer.Close()
|
||||
<-sc.tcpFrameBackgroundWriteDone
|
||||
for {
|
||||
select {
|
||||
case req := <-readRequest:
|
||||
req.res <- sc.handleRequestOuter(req.req)
|
||||
|
||||
case err := <-readErr:
|
||||
return err
|
||||
|
||||
case ss := <-sc.sessionRemove:
|
||||
if _, ok := sc.sessions[ss.ID()]; ok {
|
||||
delete(sc.sessions, ss.ID())
|
||||
ss.connRemove <- sc
|
||||
sc.sessionsWG.Done()
|
||||
}
|
||||
|
||||
case <-sc.innerTerminate:
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
|
||||
sc.nconn.Close()
|
||||
|
||||
sc.s.connClose <- sc
|
||||
<-sc.terminate
|
||||
|
||||
return err
|
||||
|
||||
case <-sc.innerTerminate:
|
||||
sc.nconn.Close()
|
||||
<-readDone
|
||||
|
||||
if sc.tcpFrameEnabled {
|
||||
sc.tcpFrameWriteBuffer.Close()
|
||||
<-sc.tcpFrameBackgroundWriteDone
|
||||
}
|
||||
|
||||
sc.s.connClose <- sc
|
||||
<-sc.terminate
|
||||
|
||||
return liberrors.ErrServerTerminated{}
|
||||
|
||||
case <-sc.terminate:
|
||||
sc.nconn.Close()
|
||||
<-readDone
|
||||
|
||||
if sc.tcpFrameEnabled {
|
||||
sc.tcpFrameWriteBuffer.Close()
|
||||
<-sc.tcpFrameBackgroundWriteDone
|
||||
}
|
||||
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case req, ok := <-readRequest:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
req.res <- liberrors.ErrServerTerminated{}
|
||||
|
||||
case _, ok := <-readErr:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
case ss, ok := <-sc.sessionRemove:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := sc.sessions[ss.ID()]; ok {
|
||||
sc.sessionsWG.Done()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
sc.nconn.Close()
|
||||
<-readDone
|
||||
|
||||
if sc.tcpFrameEnabled {
|
||||
sc.s.sessionClose <- sc.tcpFrameLinkedSession
|
||||
sc.tcpFrameWriteBuffer.Close()
|
||||
<-sc.tcpFrameBackgroundWriteDone
|
||||
}
|
||||
|
||||
for _, ss := range sc.sessions {
|
||||
ss.connRemove <- sc
|
||||
}
|
||||
sc.sessionsWG.Wait()
|
||||
|
||||
sc.s.connClose <- sc
|
||||
<-sc.parentTerminate
|
||||
|
||||
close(readRequest)
|
||||
close(readErr)
|
||||
close(sc.sessionRemove)
|
||||
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnConnClose); ok {
|
||||
h.OnConnClose(sc, err)
|
||||
}
|
||||
@@ -241,8 +282,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
// the connection can't communicate with another session
|
||||
// if it's receiving or sending TCP frames.
|
||||
if sc.tcpFrameLinkedSession != nil &&
|
||||
sxID != sc.tcpFrameLinkedSession.id {
|
||||
if sc.tcpSession != nil &&
|
||||
sxID != sc.tcpSession.id {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerLinkedToOtherSession{}
|
||||
@@ -252,16 +293,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
case base.Options:
|
||||
// handle request in session
|
||||
if sxID != "" {
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: false,
|
||||
res: cres,
|
||||
}
|
||||
res := <-cres
|
||||
return res.res, res.err
|
||||
_, res, err := sc.handleRequestInSession(sxID, req, false)
|
||||
return res, err
|
||||
}
|
||||
|
||||
// handle request here
|
||||
@@ -330,121 +363,65 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
case base.Announce:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok {
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: true,
|
||||
res: cres,
|
||||
}
|
||||
res := <-cres
|
||||
return res.res, res.err
|
||||
_, res, err := sc.handleRequestInSession(sxID, req, true)
|
||||
return res, err
|
||||
}
|
||||
|
||||
case base.Setup:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok {
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: true,
|
||||
res: cres,
|
||||
}
|
||||
res := <-cres
|
||||
return res.res, res.err
|
||||
_, res, err := sc.handleRequestInSession(sxID, req, true)
|
||||
return res, err
|
||||
}
|
||||
|
||||
case base.Play:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: false,
|
||||
res: cres,
|
||||
}
|
||||
res := <-cres
|
||||
ss, res, err := sc.handleRequestInSession(sxID, req, false)
|
||||
|
||||
if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok {
|
||||
sc.tcpFrameLinkedSession = res.ss
|
||||
if _, ok := err.(liberrors.ErrServerTCPFramesEnable); ok {
|
||||
sc.tcpSession = ss
|
||||
sc.tcpFrameIsRecording = false
|
||||
sc.tcpFrameSetEnabled = true
|
||||
return res.res, nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
return res.res, res.err
|
||||
return res, err
|
||||
}
|
||||
|
||||
case base.Record:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: false,
|
||||
res: cres,
|
||||
}
|
||||
res := <-cres
|
||||
ss, res, err := sc.handleRequestInSession(sxID, req, false)
|
||||
|
||||
if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok {
|
||||
sc.tcpFrameLinkedSession = res.ss
|
||||
if _, ok := err.(liberrors.ErrServerTCPFramesEnable); ok {
|
||||
sc.tcpSession = ss
|
||||
sc.tcpFrameIsRecording = true
|
||||
sc.tcpFrameSetEnabled = true
|
||||
return res.res, nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
return res.res, res.err
|
||||
return res, err
|
||||
}
|
||||
|
||||
case base.Pause:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: false,
|
||||
res: cres,
|
||||
}
|
||||
res := <-cres
|
||||
_, res, err := sc.handleRequestInSession(sxID, req, false)
|
||||
|
||||
if _, ok := res.err.(liberrors.ErrServerTCPFramesDisable); ok {
|
||||
if _, ok := err.(liberrors.ErrServerTCPFramesDisable); ok {
|
||||
sc.tcpFrameSetEnabled = false
|
||||
return res.res, nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
return res.res, res.err
|
||||
return res, err
|
||||
}
|
||||
|
||||
case base.Teardown:
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: false,
|
||||
res: cres,
|
||||
}
|
||||
res := <-cres
|
||||
return res.res, res.err
|
||||
_, res, err := sc.handleRequestInSession(sxID, req, false)
|
||||
return res, err
|
||||
|
||||
case base.GetParameter:
|
||||
// handle request in session
|
||||
if sxID != "" {
|
||||
cres := make(chan sessionReqRes)
|
||||
sc.s.sessionReq <- sessionReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: false,
|
||||
res: cres,
|
||||
}
|
||||
res := <-cres
|
||||
return res.res, res.err
|
||||
_, res, err := sc.handleRequestInSession(sxID, req, false)
|
||||
return res, err
|
||||
}
|
||||
|
||||
// handle request here
|
||||
@@ -563,6 +540,45 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (sc *ServerConn) handleRequestInSession(sxID string, req *base.Request, create bool,
|
||||
) (*ServerSession, *base.Response, error) {
|
||||
|
||||
// if the session is already linked to this conn, communicate directly with it
|
||||
if sxID != "" {
|
||||
if ss, ok := sc.sessions[sxID]; ok {
|
||||
cres := make(chan requestRes)
|
||||
ss.request <- request{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: create,
|
||||
res: cres,
|
||||
}
|
||||
res := <-cres
|
||||
|
||||
return ss, res.res, res.err
|
||||
}
|
||||
}
|
||||
|
||||
// otherwise, pass through Server
|
||||
cres := make(chan requestRes)
|
||||
sc.s.sessionRequest <- request{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: create,
|
||||
res: cres,
|
||||
}
|
||||
res := <-cres
|
||||
|
||||
if res.ss != nil {
|
||||
sc.sessions[res.ss.ID()] = res.ss
|
||||
sc.sessionsWG.Add(1)
|
||||
}
|
||||
|
||||
return res.ss, res.res, res.err
|
||||
}
|
||||
|
||||
func (sc *ServerConn) tcpFrameBackgroundWrite() {
|
||||
defer close(sc.tcpFrameBackgroundWriteDone)
|
||||
|
||||
|
235
serversession.go
235
serversession.go
@@ -118,33 +118,42 @@ type ServerSession struct {
|
||||
id string
|
||||
wg *sync.WaitGroup
|
||||
|
||||
conns map[*ServerConn]struct{}
|
||||
connsWG sync.WaitGroup
|
||||
state ServerSessionState
|
||||
setuppedTracks map[int]ServerSessionSetuppedTrack
|
||||
setupProtocol *StreamProtocol
|
||||
setupPath *string
|
||||
setupQuery *string
|
||||
lastRequestTime time.Time
|
||||
linkedConn *ServerConn // tcp
|
||||
tcpConn *ServerConn // tcp
|
||||
udpIP net.IP // udp
|
||||
udpZone string // udp
|
||||
announcedTracks []ServerSessionAnnouncedTrack // publish
|
||||
udpLastFrameTime *int64 // publish, udp
|
||||
|
||||
// in
|
||||
request chan sessionReq
|
||||
innerTerminate chan struct{}
|
||||
terminate chan struct{}
|
||||
request chan request
|
||||
connRemove chan *ServerConn
|
||||
innerTerminate chan struct{}
|
||||
parentTerminate chan struct{}
|
||||
}
|
||||
|
||||
func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession {
|
||||
func newServerSession(s *Server,
|
||||
id string,
|
||||
wg *sync.WaitGroup,
|
||||
) *ServerSession {
|
||||
|
||||
ss := &ServerSession{
|
||||
s: s,
|
||||
id: id,
|
||||
wg: wg,
|
||||
conns: make(map[*ServerConn]struct{}),
|
||||
lastRequestTime: time.Now(),
|
||||
request: make(chan sessionReq),
|
||||
request: make(chan request),
|
||||
connRemove: make(chan *ServerConn),
|
||||
innerTerminate: make(chan struct{}, 1),
|
||||
terminate: make(chan struct{}),
|
||||
parentTerminate: make(chan struct{}),
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
@@ -208,105 +217,125 @@ func (ss *ServerSession) run() {
|
||||
h.OnSessionOpen(ss)
|
||||
}
|
||||
|
||||
readDone := make(chan error)
|
||||
go func() {
|
||||
readDone <- func() error {
|
||||
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
|
||||
defer checkTimeoutTicker.Stop()
|
||||
err := func() error {
|
||||
checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod)
|
||||
defer checkTimeoutTicker.Stop()
|
||||
|
||||
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
|
||||
defer receiverReportTicker.Stop()
|
||||
receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod)
|
||||
defer receiverReportTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-ss.request:
|
||||
res, err := ss.handleRequest(req.sc, req.req)
|
||||
for {
|
||||
select {
|
||||
case req := <-ss.request:
|
||||
ss.lastRequestTime = time.Now()
|
||||
|
||||
ss.lastRequestTime = time.Now()
|
||||
if _, ok := ss.conns[req.sc]; !ok {
|
||||
ss.conns[req.sc] = struct{}{}
|
||||
ss.connsWG.Add(1)
|
||||
}
|
||||
|
||||
if res.StatusCode == base.StatusOK {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
res, err := ss.handleRequest(req.sc, req.req)
|
||||
|
||||
if res.StatusCode == base.StatusOK {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
res.Header["Session"] = base.HeaderValue{ss.id}
|
||||
}
|
||||
|
||||
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
|
||||
req.res <- sessionReqRes{res: res, err: nil}
|
||||
if _, ok := err.(liberrors.ErrServerSessionTeardown); ok {
|
||||
req.res <- requestRes{res: res, err: nil}
|
||||
return liberrors.ErrServerSessionTeardown{}
|
||||
}
|
||||
|
||||
req.res <- requestRes{
|
||||
res: res,
|
||||
err: err,
|
||||
ss: ss,
|
||||
}
|
||||
|
||||
case sc := <-ss.connRemove:
|
||||
if _, ok := ss.conns[sc]; ok {
|
||||
delete(ss.conns, sc)
|
||||
sc.sessionRemove <- ss
|
||||
ss.connsWG.Done()
|
||||
}
|
||||
|
||||
// if session is not in state RECORD or PLAY, or protocol is TCP
|
||||
if (ss.state != ServerSessionStateRecord &&
|
||||
ss.state != ServerSessionStatePlay) ||
|
||||
*ss.setupProtocol == StreamProtocolTCP {
|
||||
|
||||
// close if there are no active connections
|
||||
if len(ss.conns) == 0 {
|
||||
return liberrors.ErrServerSessionTeardown{}
|
||||
}
|
||||
|
||||
req.res <- sessionReqRes{
|
||||
res: res,
|
||||
err: err,
|
||||
ss: ss,
|
||||
}
|
||||
|
||||
case <-checkTimeoutTicker.C:
|
||||
switch {
|
||||
// in case of record and UDP, timeout happens when no frames are being received
|
||||
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
|
||||
now := time.Now()
|
||||
lft := atomic.LoadInt64(ss.udpLastFrameTime)
|
||||
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
|
||||
// in case there's a linked TCP connection, timeout is handled in the connection
|
||||
case ss.linkedConn != nil:
|
||||
|
||||
// otherwise, timeout happens when no requests arrives
|
||||
default:
|
||||
now := time.Now()
|
||||
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
}
|
||||
|
||||
case <-receiverReportTicker.C:
|
||||
if ss.state != ServerSessionStateRecord {
|
||||
continue
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for trackID, track := range ss.announcedTracks {
|
||||
r := track.rtcpReceiver.Report(now)
|
||||
ss.WriteFrame(trackID, StreamTypeRTCP, r)
|
||||
}
|
||||
|
||||
case <-ss.innerTerminate:
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
|
||||
case <-checkTimeoutTicker.C:
|
||||
switch {
|
||||
// in case of RECORD and UDP, timeout happens when no frames are being received
|
||||
case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP:
|
||||
now := time.Now()
|
||||
lft := atomic.LoadInt64(ss.udpLastFrameTime)
|
||||
if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
|
||||
// in case of PLAY and UDP, timeout happens when no request arrives
|
||||
case ss.state == ServerSessionStatePlay && *ss.setupProtocol == StreamProtocolUDP:
|
||||
now := time.Now()
|
||||
if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor {
|
||||
return liberrors.ErrServerSessionTimedOut{}
|
||||
}
|
||||
|
||||
// otherwise, there's no timeout until all associated connections are closed
|
||||
}
|
||||
|
||||
case <-receiverReportTicker.C:
|
||||
if ss.state != ServerSessionStateRecord {
|
||||
continue
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for trackID, track := range ss.announcedTracks {
|
||||
r := track.rtcpReceiver.Report(now)
|
||||
ss.WriteFrame(trackID, StreamTypeRTCP, r)
|
||||
}
|
||||
|
||||
case <-ss.innerTerminate:
|
||||
return liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
var err error
|
||||
select {
|
||||
case err = <-readDone:
|
||||
go func() {
|
||||
for req := range ss.request {
|
||||
req.res <- sessionReqRes{
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case req, ok := <-ss.request:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
req.res <- requestRes{
|
||||
ss: nil,
|
||||
res: &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
},
|
||||
err: liberrors.ErrServerTerminated{},
|
||||
}
|
||||
|
||||
case sc, ok := <-ss.connRemove:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := ss.conns[sc]; ok {
|
||||
ss.connsWG.Done()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ss.s.sessionClose <- ss
|
||||
<-ss.terminate
|
||||
|
||||
case <-ss.terminate:
|
||||
select {
|
||||
case ss.innerTerminate <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-readDone
|
||||
|
||||
err = liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}()
|
||||
|
||||
switch ss.state {
|
||||
case ServerSessionStatePlay:
|
||||
@@ -321,11 +350,16 @@ func (ss *ServerSession) run() {
|
||||
}
|
||||
}
|
||||
|
||||
if ss.linkedConn != nil {
|
||||
ss.s.connClose <- ss.linkedConn
|
||||
for sc := range ss.conns {
|
||||
sc.sessionRemove <- ss
|
||||
}
|
||||
ss.connsWG.Wait()
|
||||
|
||||
ss.s.sessionClose <- ss
|
||||
<-ss.parentTerminate
|
||||
|
||||
close(ss.request)
|
||||
close(ss.connRemove)
|
||||
|
||||
if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok {
|
||||
h.OnSessionClose(ss, err)
|
||||
@@ -333,7 +367,7 @@ func (ss *ServerSession) run() {
|
||||
}
|
||||
|
||||
func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) {
|
||||
if ss.linkedConn != nil && sc != ss.linkedConn {
|
||||
if ss.tcpConn != nil && sc != ss.tcpConn {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerSessionLinkedToOtherConn{}
|
||||
@@ -620,10 +654,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
// this was causing problems during unit tests.
|
||||
if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 &&
|
||||
strings.HasPrefix(ua[0], "GStreamer") {
|
||||
select {
|
||||
case <-time.After(1 * time.Second):
|
||||
case <-sc.terminate:
|
||||
}
|
||||
<-time.After(1 * time.Second)
|
||||
}
|
||||
|
||||
return res, err
|
||||
@@ -664,7 +695,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
ss.udpIP = sc.ip()
|
||||
ss.udpZone = sc.zone()
|
||||
} else {
|
||||
ss.linkedConn = sc
|
||||
ss.tcpConn = sc
|
||||
}
|
||||
}
|
||||
|
||||
@@ -694,7 +725,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
|
||||
ss.udpIP = nil
|
||||
ss.udpZone = ""
|
||||
ss.linkedConn = nil
|
||||
ss.tcpConn = nil
|
||||
}
|
||||
|
||||
return res, err
|
||||
@@ -732,7 +763,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
ss.udpIP = sc.ip()
|
||||
ss.udpZone = sc.zone()
|
||||
} else {
|
||||
ss.linkedConn = sc
|
||||
ss.tcpConn = sc
|
||||
}
|
||||
|
||||
res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{
|
||||
@@ -766,7 +797,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
|
||||
ss.udpIP = nil
|
||||
ss.udpZone = ""
|
||||
ss.linkedConn = nil
|
||||
ss.tcpConn = nil
|
||||
|
||||
return res, err
|
||||
|
||||
@@ -809,7 +840,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
ss.state = ServerSessionStatePrePlay
|
||||
ss.udpIP = nil
|
||||
ss.udpZone = ""
|
||||
ss.linkedConn = nil
|
||||
ss.tcpConn = nil
|
||||
|
||||
if *ss.setupProtocol == StreamProtocolUDP {
|
||||
ss.s.udpRTCPListener.removeClient(ss)
|
||||
@@ -821,7 +852,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
ss.state = ServerSessionStatePreRecord
|
||||
ss.udpIP = nil
|
||||
ss.udpZone = ""
|
||||
ss.linkedConn = nil
|
||||
ss.tcpConn = nil
|
||||
|
||||
if *ss.setupProtocol == StreamProtocolUDP {
|
||||
ss.s.udpRTPListener.removeClient(ss)
|
||||
@@ -834,8 +865,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
return res, err
|
||||
|
||||
case base.Teardown:
|
||||
ss.linkedConn = nil
|
||||
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusOK,
|
||||
}, liberrors.ErrServerSessionTeardown{}
|
||||
@@ -896,7 +925,7 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload
|
||||
})
|
||||
}
|
||||
} else {
|
||||
ss.linkedConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{
|
||||
ss.tcpConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{
|
||||
TrackID: trackID,
|
||||
StreamType: streamType,
|
||||
Payload: payload,
|
||||
|
Reference in New Issue
Block a user