diff --git a/README.md b/README.md index e637b881..a852eecc 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Features: * Client * Query servers about available streams - * Encrypt connection with TLS (RTSPS) + * Encrypt connections with TLS (RTSPS) * Read * Read streams from servers with UDP or TCP * Switch protocol automatically (switch to TCP in case of server error or UDP timeout) @@ -29,9 +29,9 @@ Features: * Server * Handle requests from clients * Sessions and connections are independent; clients can control multiple sessions + * Encrypt connections with TLS (RTSPS) * Read streams from clients with UDP or TCP * Write streams to clients with UDP or TCP - * Encrypt streams with TLS (RTSPS) * Generate RTCP receiver reports automatically * Utilities * Encode and decode RTSP primitives, RTP/H264, RTP/AAC, SDP diff --git a/server.go b/server.go index 3f1f27c8..4a1919ed 100644 --- a/server.go +++ b/server.go @@ -44,18 +44,18 @@ func newSessionID(sessions map[string]*ServerSession) (string, error) { } } -type sessionReqRes struct { +type requestRes struct { + ss *ServerSession res *base.Response err error - ss *ServerSession } -type sessionReq struct { +type request struct { sc *ServerConn req *base.Request id string create bool - res chan sessionReqRes + res chan requestRes } // Server is a RTSP server. @@ -134,10 +134,10 @@ type Server struct { exitError error // in - connClose chan *ServerConn - sessionReq chan sessionReq - sessionClose chan *ServerSession - terminate chan struct{} + connClose chan *ServerConn + sessionRequest chan request + sessionClose chan *ServerSession + terminate chan struct{} // out done chan struct{} @@ -237,7 +237,7 @@ func (s *Server) run() { s.sessions = make(map[string]*ServerSession) s.conns = make(map[*ServerConn]struct{}) s.connClose = make(chan *ServerConn) - s.sessionReq = make(chan sessionReq) + s.sessionRequest = make(chan request) s.sessionClose = make(chan *ServerSession) var wg sync.WaitGroup @@ -276,13 +276,13 @@ outer: } s.doConnClose(sc) - case req := <-s.sessionReq: + case req := <-s.sessionRequest: if ss, ok := s.sessions[req.id]; ok { ss.request <- req } else { if !req.create { - req.res <- sessionReqRes{ + req.res <- requestRes{ res: &base.Response{ StatusCode: base.StatusBadRequest, }, @@ -293,7 +293,7 @@ outer: id, err := newSessionID(s.sessions) if err != nil { - req.res <- sessionReqRes{ + req.res <- requestRes{ res: &base.Response{ StatusCode: base.StatusBadRequest, }, @@ -330,6 +330,7 @@ outer: if !ok { return } + nconn.Close() case _, ok := <-s.connClose: @@ -337,11 +338,12 @@ outer: return } - case req, ok := <-s.sessionReq: + case req, ok := <-s.sessionRequest: if !ok { return } - req.res <- sessionReqRes{ + + req.res <- requestRes{ res: &base.Response{ StatusCode: base.StatusBadRequest, }, @@ -379,7 +381,7 @@ outer: close(acceptErr) close(connNew) close(s.connClose) - close(s.sessionReq) + close(s.sessionRequest) close(s.sessionClose) close(s.done) } @@ -409,10 +411,12 @@ func (s *Server) StartAndWait(address string) error { func (s *Server) doConnClose(sc *ServerConn) { delete(s.conns, sc) - close(sc.terminate) + close(sc.parentTerminate) + sc.Close() } func (s *Server) doSessionClose(ss *ServerSession) { delete(s.sessions, ss.id) - close(ss.terminate) + close(ss.parentTerminate) + ss.Close() } diff --git a/server_read_test.go b/server_read_test.go index 6d54fe49..f66f58c2 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -935,7 +935,7 @@ func TestServerReadPlayPausePause(t *testing.T) { func TestServerReadTimeout(t *testing.T) { for _, proto := range []string{ "udp", - // checking TCP is useless, since there's no timeout when reading with TCP + // there's no timeout when reading with TCP } { t.Run(proto, func(t *testing.T) { sessionClosed := make(chan struct{}) diff --git a/server_test.go b/server_test.go index 20fe46cd..c3272a4e 100644 --- a/server_test.go +++ b/server_test.go @@ -929,3 +929,53 @@ func TestServerSessionClose(t *testing.T) { <-sessionClosed } + +func TestServerSessionAutoClose(t *testing.T) { + sessionClosed := make(chan struct{}) + + s := &Server{ + Handler: &testServerHandler{ + onSessionClose: func(ss *ServerSession, err error) { + close(sessionClosed) + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + InterleavedIDs: &[2]int{0, 1}, + }.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + conn.Close() + + <-sessionClosed +} diff --git a/serverconn.go b/serverconn.go index 1e9dd10e..402c0153 100644 --- a/serverconn.go +++ b/serverconn.go @@ -35,6 +35,11 @@ func getSessionID(header base.Header) string { return "" } +type readReq struct { + req *base.Request + res chan error +} + // ServerConn is a server-side RTSP connection. type ServerConn struct { s *Server @@ -43,19 +48,23 @@ type ServerConn struct { br *bufio.Reader bw *bufio.Writer + sessions map[string]*ServerSession + sessionsWG sync.WaitGroup + // TCP stream protocol - tcpFrameLinkedSession *ServerSession - tcpFrameIsRecording bool tcpFrameSetEnabled bool tcpFrameEnabled bool + tcpSession *ServerSession + tcpFrameIsRecording bool tcpFrameTimeout bool tcpFrameBuffer *multibuffer.MultiBuffer tcpFrameWriteBuffer *ringbuffer.RingBuffer tcpFrameBackgroundWriteDone chan struct{} // in - innerTerminate chan struct{} - terminate chan struct{} + sessionRemove chan *ServerSession + innerTerminate chan struct{} + parentTerminate chan struct{} } func newServerConn( @@ -64,11 +73,12 @@ func newServerConn( nconn net.Conn) *ServerConn { sc := &ServerConn{ - s: s, - wg: wg, - nconn: nconn, - innerTerminate: make(chan struct{}, 1), - terminate: make(chan struct{}), + s: s, + wg: wg, + nconn: nconn, + sessionRemove: make(chan *ServerSession), + innerTerminate: make(chan struct{}, 1), + parentTerminate: make(chan struct{}), } wg.Add(1) @@ -115,13 +125,17 @@ func (sc *ServerConn) run() { sc.br = bufio.NewReaderSize(conn, serverConnReadBufferSize) sc.bw = bufio.NewWriterSize(conn, serverConnWriteBufferSize) + sc.sessions = make(map[string]*ServerSession) // instantiate always to allow writing to this conn before Play() sc.tcpFrameWriteBuffer = ringbuffer.New(uint64(sc.s.ReadBufferCount)) - readDone := make(chan error) + readRequest := make(chan readReq) + readErr := make(chan error) + readDone := make(chan struct{}) go func() { - readDone <- func() error { + defer close(readDone) + readErr <- func() error { var req base.Request var frame base.InterleavedFrame @@ -140,15 +154,15 @@ func (sc *ServerConn) run() { switch what.(type) { case *base.InterleavedFrame: // forward frame only if it has been set up - if _, ok := sc.tcpFrameLinkedSession.setuppedTracks[frame.TrackID]; ok { + if _, ok := sc.tcpSession.setuppedTracks[frame.TrackID]; ok { if sc.tcpFrameIsRecording { - sc.tcpFrameLinkedSession.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(), + sc.tcpSession.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(), frame.StreamType, frame.Payload) } if h, ok := sc.s.Handler.(ServerHandlerOnFrame); ok { h.OnFrame(&ServerHandlerOnFrameCtx{ - Session: sc.tcpFrameLinkedSession, + Session: sc.tcpSession, TrackID: frame.TrackID, StreamType: frame.StreamType, Payload: frame.Payload, @@ -157,7 +171,9 @@ func (sc *ServerConn) run() { } case *base.Request: - err := sc.handleRequestOuter(&req) + cres := make(chan error) + readRequest <- readReq{req: &req, res: cres} + err := <-cres if err != nil { return err } @@ -169,7 +185,9 @@ func (sc *ServerConn) run() { return err } - err = sc.handleRequestOuter(&req) + cres := make(chan error) + readRequest <- readReq{req: &req, res: cres} + err = <-cres if err != nil { return err } @@ -179,51 +197,74 @@ func (sc *ServerConn) run() { }() err := func() error { - select { - case err := <-readDone: - if sc.tcpFrameEnabled { - sc.tcpFrameWriteBuffer.Close() - <-sc.tcpFrameBackgroundWriteDone + for { + select { + case req := <-readRequest: + req.res <- sc.handleRequestOuter(req.req) + + case err := <-readErr: + return err + + case ss := <-sc.sessionRemove: + if _, ok := sc.sessions[ss.ID()]; ok { + delete(sc.sessions, ss.ID()) + ss.connRemove <- sc + sc.sessionsWG.Done() + } + + case <-sc.innerTerminate: + return liberrors.ErrServerTerminated{} } - - sc.nconn.Close() - - sc.s.connClose <- sc - <-sc.terminate - - return err - - case <-sc.innerTerminate: - sc.nconn.Close() - <-readDone - - if sc.tcpFrameEnabled { - sc.tcpFrameWriteBuffer.Close() - <-sc.tcpFrameBackgroundWriteDone - } - - sc.s.connClose <- sc - <-sc.terminate - - return liberrors.ErrServerTerminated{} - - case <-sc.terminate: - sc.nconn.Close() - <-readDone - - if sc.tcpFrameEnabled { - sc.tcpFrameWriteBuffer.Close() - <-sc.tcpFrameBackgroundWriteDone - } - - return liberrors.ErrServerTerminated{} } }() + go func() { + for { + select { + case req, ok := <-readRequest: + if !ok { + return + } + + req.res <- liberrors.ErrServerTerminated{} + + case _, ok := <-readErr: + if !ok { + return + } + + case ss, ok := <-sc.sessionRemove: + if !ok { + return + } + + if _, ok := sc.sessions[ss.ID()]; ok { + sc.sessionsWG.Done() + } + } + } + }() + + sc.nconn.Close() + <-readDone + if sc.tcpFrameEnabled { - sc.s.sessionClose <- sc.tcpFrameLinkedSession + sc.tcpFrameWriteBuffer.Close() + <-sc.tcpFrameBackgroundWriteDone } + for _, ss := range sc.sessions { + ss.connRemove <- sc + } + sc.sessionsWG.Wait() + + sc.s.connClose <- sc + <-sc.parentTerminate + + close(readRequest) + close(readErr) + close(sc.sessionRemove) + if h, ok := sc.s.Handler.(ServerHandlerOnConnClose); ok { h.OnConnClose(sc, err) } @@ -241,8 +282,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { // the connection can't communicate with another session // if it's receiving or sending TCP frames. - if sc.tcpFrameLinkedSession != nil && - sxID != sc.tcpFrameLinkedSession.id { + if sc.tcpSession != nil && + sxID != sc.tcpSession.id { return &base.Response{ StatusCode: base.StatusBadRequest, }, liberrors.ErrServerLinkedToOtherSession{} @@ -252,16 +293,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { case base.Options: // handle request in session if sxID != "" { - cres := make(chan sessionReqRes) - sc.s.sessionReq <- sessionReq{ - sc: sc, - req: req, - id: sxID, - create: false, - res: cres, - } - res := <-cres - return res.res, res.err + _, res, err := sc.handleRequestInSession(sxID, req, false) + return res, err } // handle request here @@ -330,121 +363,65 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { case base.Announce: if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok { - cres := make(chan sessionReqRes) - sc.s.sessionReq <- sessionReq{ - sc: sc, - req: req, - id: sxID, - create: true, - res: cres, - } - res := <-cres - return res.res, res.err + _, res, err := sc.handleRequestInSession(sxID, req, true) + return res, err } case base.Setup: if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok { - cres := make(chan sessionReqRes) - sc.s.sessionReq <- sessionReq{ - sc: sc, - req: req, - id: sxID, - create: true, - res: cres, - } - res := <-cres - return res.res, res.err + _, res, err := sc.handleRequestInSession(sxID, req, true) + return res, err } case base.Play: if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok { - cres := make(chan sessionReqRes) - sc.s.sessionReq <- sessionReq{ - sc: sc, - req: req, - id: sxID, - create: false, - res: cres, - } - res := <-cres + ss, res, err := sc.handleRequestInSession(sxID, req, false) - if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok { - sc.tcpFrameLinkedSession = res.ss + if _, ok := err.(liberrors.ErrServerTCPFramesEnable); ok { + sc.tcpSession = ss sc.tcpFrameIsRecording = false sc.tcpFrameSetEnabled = true - return res.res, nil + return res, nil } - return res.res, res.err + return res, err } case base.Record: if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok { - cres := make(chan sessionReqRes) - sc.s.sessionReq <- sessionReq{ - sc: sc, - req: req, - id: sxID, - create: false, - res: cres, - } - res := <-cres + ss, res, err := sc.handleRequestInSession(sxID, req, false) - if _, ok := res.err.(liberrors.ErrServerTCPFramesEnable); ok { - sc.tcpFrameLinkedSession = res.ss + if _, ok := err.(liberrors.ErrServerTCPFramesEnable); ok { + sc.tcpSession = ss sc.tcpFrameIsRecording = true sc.tcpFrameSetEnabled = true - return res.res, nil + return res, nil } - return res.res, res.err + return res, err } case base.Pause: if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok { - cres := make(chan sessionReqRes) - sc.s.sessionReq <- sessionReq{ - sc: sc, - req: req, - id: sxID, - create: false, - res: cres, - } - res := <-cres + _, res, err := sc.handleRequestInSession(sxID, req, false) - if _, ok := res.err.(liberrors.ErrServerTCPFramesDisable); ok { + if _, ok := err.(liberrors.ErrServerTCPFramesDisable); ok { sc.tcpFrameSetEnabled = false - return res.res, nil + return res, nil } - return res.res, res.err + return res, err } case base.Teardown: - cres := make(chan sessionReqRes) - sc.s.sessionReq <- sessionReq{ - sc: sc, - req: req, - id: sxID, - create: false, - res: cres, - } - res := <-cres - return res.res, res.err + _, res, err := sc.handleRequestInSession(sxID, req, false) + return res, err case base.GetParameter: // handle request in session if sxID != "" { - cres := make(chan sessionReqRes) - sc.s.sessionReq <- sessionReq{ - sc: sc, - req: req, - id: sxID, - create: false, - res: cres, - } - res := <-cres - return res.res, res.err + _, res, err := sc.handleRequestInSession(sxID, req, false) + return res, err } // handle request here @@ -563,6 +540,45 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error { return err } +func (sc *ServerConn) handleRequestInSession(sxID string, req *base.Request, create bool, +) (*ServerSession, *base.Response, error) { + + // if the session is already linked to this conn, communicate directly with it + if sxID != "" { + if ss, ok := sc.sessions[sxID]; ok { + cres := make(chan requestRes) + ss.request <- request{ + sc: sc, + req: req, + id: sxID, + create: create, + res: cres, + } + res := <-cres + + return ss, res.res, res.err + } + } + + // otherwise, pass through Server + cres := make(chan requestRes) + sc.s.sessionRequest <- request{ + sc: sc, + req: req, + id: sxID, + create: create, + res: cres, + } + res := <-cres + + if res.ss != nil { + sc.sessions[res.ss.ID()] = res.ss + sc.sessionsWG.Add(1) + } + + return res.ss, res.res, res.err +} + func (sc *ServerConn) tcpFrameBackgroundWrite() { defer close(sc.tcpFrameBackgroundWriteDone) diff --git a/serversession.go b/serversession.go index 1f0bc8d9..896e35c0 100644 --- a/serversession.go +++ b/serversession.go @@ -118,33 +118,42 @@ type ServerSession struct { id string wg *sync.WaitGroup + conns map[*ServerConn]struct{} + connsWG sync.WaitGroup state ServerSessionState setuppedTracks map[int]ServerSessionSetuppedTrack setupProtocol *StreamProtocol setupPath *string setupQuery *string lastRequestTime time.Time - linkedConn *ServerConn // tcp + tcpConn *ServerConn // tcp udpIP net.IP // udp udpZone string // udp announcedTracks []ServerSessionAnnouncedTrack // publish udpLastFrameTime *int64 // publish, udp // in - request chan sessionReq - innerTerminate chan struct{} - terminate chan struct{} + request chan request + connRemove chan *ServerConn + innerTerminate chan struct{} + parentTerminate chan struct{} } -func newServerSession(s *Server, id string, wg *sync.WaitGroup) *ServerSession { +func newServerSession(s *Server, + id string, + wg *sync.WaitGroup, +) *ServerSession { + ss := &ServerSession{ s: s, id: id, wg: wg, + conns: make(map[*ServerConn]struct{}), lastRequestTime: time.Now(), - request: make(chan sessionReq), + request: make(chan request), + connRemove: make(chan *ServerConn), innerTerminate: make(chan struct{}, 1), - terminate: make(chan struct{}), + parentTerminate: make(chan struct{}), } wg.Add(1) @@ -208,105 +217,125 @@ func (ss *ServerSession) run() { h.OnSessionOpen(ss) } - readDone := make(chan error) - go func() { - readDone <- func() error { - checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod) - defer checkTimeoutTicker.Stop() + err := func() error { + checkTimeoutTicker := time.NewTicker(serverSessionCheckStreamPeriod) + defer checkTimeoutTicker.Stop() - receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod) - defer receiverReportTicker.Stop() + receiverReportTicker := time.NewTicker(ss.s.receiverReportPeriod) + defer receiverReportTicker.Stop() - for { - select { - case req := <-ss.request: - res, err := ss.handleRequest(req.sc, req.req) + for { + select { + case req := <-ss.request: + ss.lastRequestTime = time.Now() - ss.lastRequestTime = time.Now() + if _, ok := ss.conns[req.sc]; !ok { + ss.conns[req.sc] = struct{}{} + ss.connsWG.Add(1) + } - if res.StatusCode == base.StatusOK { - if res.Header == nil { - res.Header = make(base.Header) - } - res.Header["Session"] = base.HeaderValue{ss.id} + res, err := ss.handleRequest(req.sc, req.req) + + if res.StatusCode == base.StatusOK { + if res.Header == nil { + res.Header = make(base.Header) } + res.Header["Session"] = base.HeaderValue{ss.id} + } - if _, ok := err.(liberrors.ErrServerSessionTeardown); ok { - req.res <- sessionReqRes{res: res, err: nil} + if _, ok := err.(liberrors.ErrServerSessionTeardown); ok { + req.res <- requestRes{res: res, err: nil} + return liberrors.ErrServerSessionTeardown{} + } + + req.res <- requestRes{ + res: res, + err: err, + ss: ss, + } + + case sc := <-ss.connRemove: + if _, ok := ss.conns[sc]; ok { + delete(ss.conns, sc) + sc.sessionRemove <- ss + ss.connsWG.Done() + } + + // if session is not in state RECORD or PLAY, or protocol is TCP + if (ss.state != ServerSessionStateRecord && + ss.state != ServerSessionStatePlay) || + *ss.setupProtocol == StreamProtocolTCP { + + // close if there are no active connections + if len(ss.conns) == 0 { return liberrors.ErrServerSessionTeardown{} } - - req.res <- sessionReqRes{ - res: res, - err: err, - ss: ss, - } - - case <-checkTimeoutTicker.C: - switch { - // in case of record and UDP, timeout happens when no frames are being received - case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP: - now := time.Now() - lft := atomic.LoadInt64(ss.udpLastFrameTime) - if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout { - return liberrors.ErrServerSessionTimedOut{} - } - - // in case there's a linked TCP connection, timeout is handled in the connection - case ss.linkedConn != nil: - - // otherwise, timeout happens when no requests arrives - default: - now := time.Now() - if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor { - return liberrors.ErrServerSessionTimedOut{} - } - } - - case <-receiverReportTicker.C: - if ss.state != ServerSessionStateRecord { - continue - } - - now := time.Now() - for trackID, track := range ss.announcedTracks { - r := track.rtcpReceiver.Report(now) - ss.WriteFrame(trackID, StreamTypeRTCP, r) - } - - case <-ss.innerTerminate: - return liberrors.ErrServerTerminated{} } + + case <-checkTimeoutTicker.C: + switch { + // in case of RECORD and UDP, timeout happens when no frames are being received + case ss.state == ServerSessionStateRecord && *ss.setupProtocol == StreamProtocolUDP: + now := time.Now() + lft := atomic.LoadInt64(ss.udpLastFrameTime) + if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout { + return liberrors.ErrServerSessionTimedOut{} + } + + // in case of PLAY and UDP, timeout happens when no request arrives + case ss.state == ServerSessionStatePlay && *ss.setupProtocol == StreamProtocolUDP: + now := time.Now() + if now.Sub(ss.lastRequestTime) >= ss.s.closeSessionAfterNoRequestsFor { + return liberrors.ErrServerSessionTimedOut{} + } + + // otherwise, there's no timeout until all associated connections are closed + } + + case <-receiverReportTicker.C: + if ss.state != ServerSessionStateRecord { + continue + } + + now := time.Now() + for trackID, track := range ss.announcedTracks { + r := track.rtcpReceiver.Report(now) + ss.WriteFrame(trackID, StreamTypeRTCP, r) + } + + case <-ss.innerTerminate: + return liberrors.ErrServerTerminated{} } - }() + } }() - var err error - select { - case err = <-readDone: - go func() { - for req := range ss.request { - req.res <- sessionReqRes{ + go func() { + for { + select { + case req, ok := <-ss.request: + if !ok { + return + } + + req.res <- requestRes{ + ss: nil, res: &base.Response{ StatusCode: base.StatusBadRequest, }, err: liberrors.ErrServerTerminated{}, } + + case sc, ok := <-ss.connRemove: + if !ok { + return + } + + if _, ok := ss.conns[sc]; ok { + ss.connsWG.Done() + } } - }() - - ss.s.sessionClose <- ss - <-ss.terminate - - case <-ss.terminate: - select { - case ss.innerTerminate <- struct{}{}: - default: } - <-readDone - - err = liberrors.ErrServerTerminated{} - } + }() switch ss.state { case ServerSessionStatePlay: @@ -321,11 +350,16 @@ func (ss *ServerSession) run() { } } - if ss.linkedConn != nil { - ss.s.connClose <- ss.linkedConn + for sc := range ss.conns { + sc.sessionRemove <- ss } + ss.connsWG.Wait() + + ss.s.sessionClose <- ss + <-ss.parentTerminate close(ss.request) + close(ss.connRemove) if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok { h.OnSessionClose(ss, err) @@ -333,7 +367,7 @@ func (ss *ServerSession) run() { } func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) { - if ss.linkedConn != nil && sc != ss.linkedConn { + if ss.tcpConn != nil && sc != ss.tcpConn { return &base.Response{ StatusCode: base.StatusBadRequest, }, liberrors.ErrServerSessionLinkedToOtherConn{} @@ -620,10 +654,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base // this was causing problems during unit tests. if ua, ok := req.Header["User-Agent"]; ok && len(ua) == 1 && strings.HasPrefix(ua[0], "GStreamer") { - select { - case <-time.After(1 * time.Second): - case <-sc.terminate: - } + <-time.After(1 * time.Second) } return res, err @@ -664,7 +695,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.udpIP = sc.ip() ss.udpZone = sc.zone() } else { - ss.linkedConn = sc + ss.tcpConn = sc } } @@ -694,7 +725,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.udpIP = nil ss.udpZone = "" - ss.linkedConn = nil + ss.tcpConn = nil } return res, err @@ -732,7 +763,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.udpIP = sc.ip() ss.udpZone = sc.zone() } else { - ss.linkedConn = sc + ss.tcpConn = sc } res, err := ss.s.Handler.(ServerHandlerOnRecord).OnRecord(&ServerHandlerOnRecordCtx{ @@ -766,7 +797,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.udpIP = nil ss.udpZone = "" - ss.linkedConn = nil + ss.tcpConn = nil return res, err @@ -809,7 +840,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.state = ServerSessionStatePrePlay ss.udpIP = nil ss.udpZone = "" - ss.linkedConn = nil + ss.tcpConn = nil if *ss.setupProtocol == StreamProtocolUDP { ss.s.udpRTCPListener.removeClient(ss) @@ -821,7 +852,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.state = ServerSessionStatePreRecord ss.udpIP = nil ss.udpZone = "" - ss.linkedConn = nil + ss.tcpConn = nil if *ss.setupProtocol == StreamProtocolUDP { ss.s.udpRTPListener.removeClient(ss) @@ -834,8 +865,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base return res, err case base.Teardown: - ss.linkedConn = nil - return &base.Response{ StatusCode: base.StatusOK, }, liberrors.ErrServerSessionTeardown{} @@ -896,7 +925,7 @@ func (ss *ServerSession) WriteFrame(trackID int, streamType StreamType, payload }) } } else { - ss.linkedConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{ + ss.tcpConn.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{ TrackID: trackID, StreamType: streamType, Payload: payload,