diff --git a/client.go b/client.go index a18c0770..4f2867cb 100644 --- a/client.go +++ b/client.go @@ -430,153 +430,155 @@ func (c *Client) Tracks() Tracks { func (c *Client) run() { defer close(c.done) - c.closeError = func() error { - for { - select { - case req := <-c.options: - res, err := c.doOptions(req.url) - req.res <- clientRes{res: res, err: err} - - case req := <-c.describe: - tracks, baseURL, res, err := c.doDescribe(req.url) - req.res <- clientRes{tracks: tracks, baseURL: baseURL, res: res, err: err} - - case req := <-c.announce: - res, err := c.doAnnounce(req.url, req.tracks) - req.res <- clientRes{res: res, err: err} - - case req := <-c.setup: - res, err := c.doSetup(req.forPlay, req.track, req.baseURL, req.rtpPort, req.rtcpPort) - req.res <- clientRes{res: res, err: err} - - case req := <-c.play: - res, err := c.doPlay(req.ra, false) - req.res <- clientRes{res: res, err: err} - - case req := <-c.record: - res, err := c.doRecord() - req.res <- clientRes{res: res, err: err} - - case req := <-c.pause: - res, err := c.doPause() - req.res <- clientRes{res: res, err: err} - - case <-c.udpReportTimer.C: - if c.state == clientStatePlay { - now := time.Now() - for trackID, cct := range c.tracks { - rr := cct.rtcpReceiver.Report(now) - if rr != nil { - c.WritePacketRTCP(trackID, rr) - } - } - - c.udpReportTimer = time.NewTimer(c.udpReceiverReportPeriod) - } else { // Record - now := time.Now() - for trackID, cct := range c.tracks { - sr := cct.rtcpSender.Report(now) - if sr != nil { - c.WritePacketRTCP(trackID, sr) - } - } - - c.udpReportTimer = time.NewTimer(c.udpSenderReportPeriod) - } - - case <-c.checkStreamTimer.C: - if *c.effectiveTransport == TransportUDP || - *c.effectiveTransport == TransportUDPMulticast { - if c.checkStreamInitial { - c.checkStreamInitial = false - - // check that at least one packet has been received - inTimeout := func() bool { - for _, cct := range c.tracks { - lft := atomic.LoadInt64(cct.udpRTPListener.lastPacketTime) - if lft != 0 { - return false - } - - lft = atomic.LoadInt64(cct.udpRTCPListener.lastPacketTime) - if lft != 0 { - return false - } - } - return true - }() - if inTimeout { - err := c.trySwitchingProtocol() - if err != nil { - return err - } - } - } else { - inTimeout := func() bool { - now := time.Now() - for _, cct := range c.tracks { - lft := time.Unix(atomic.LoadInt64(cct.udpRTPListener.lastPacketTime), 0) - if now.Sub(lft) < c.ReadTimeout { - return false - } - - lft = time.Unix(atomic.LoadInt64(cct.udpRTCPListener.lastPacketTime), 0) - if now.Sub(lft) < c.ReadTimeout { - return false - } - } - return true - }() - if inTimeout { - return liberrors.ErrClientUDPTimeout{} - } - } - } else { // TCP - inTimeout := func() bool { - now := time.Now() - lft := time.Unix(atomic.LoadInt64(c.tcpLastFrameTime), 0) - return now.Sub(lft) >= c.ReadTimeout - }() - if inTimeout { - return liberrors.ErrClientTCPTimeout{} - } - } - - c.checkStreamTimer = time.NewTimer(c.checkStreamPeriod) - - case <-c.keepaliveTimer.C: - _, err := c.do(&base.Request{ - Method: func() base.Method { - // the VLC integrated rtsp server requires GET_PARAMETER - if c.useGetParameter { - return base.GetParameter - } - return base.Options - }(), - // use the stream base URL, otherwise some cameras do not reply - URL: c.streamBaseURL, - }, true, false) - if err != nil { - return err - } - - c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) - - case err := <-c.readerErr: - c.readerErr = nil - return err - - case <-c.ctx.Done(): - return liberrors.ErrClientTerminated{} - } - } - }() + c.closeError = c.runInner() c.ctxCancel() c.doClose() } +func (c *Client) runInner() error { + for { + select { + case req := <-c.options: + res, err := c.doOptions(req.url) + req.res <- clientRes{res: res, err: err} + + case req := <-c.describe: + tracks, baseURL, res, err := c.doDescribe(req.url) + req.res <- clientRes{tracks: tracks, baseURL: baseURL, res: res, err: err} + + case req := <-c.announce: + res, err := c.doAnnounce(req.url, req.tracks) + req.res <- clientRes{res: res, err: err} + + case req := <-c.setup: + res, err := c.doSetup(req.forPlay, req.track, req.baseURL, req.rtpPort, req.rtcpPort) + req.res <- clientRes{res: res, err: err} + + case req := <-c.play: + res, err := c.doPlay(req.ra, false) + req.res <- clientRes{res: res, err: err} + + case req := <-c.record: + res, err := c.doRecord() + req.res <- clientRes{res: res, err: err} + + case req := <-c.pause: + res, err := c.doPause() + req.res <- clientRes{res: res, err: err} + + case <-c.udpReportTimer.C: + if c.state == clientStatePlay { + now := time.Now() + for trackID, cct := range c.tracks { + rr := cct.rtcpReceiver.Report(now) + if rr != nil { + c.WritePacketRTCP(trackID, rr) + } + } + + c.udpReportTimer = time.NewTimer(c.udpReceiverReportPeriod) + } else { // Record + now := time.Now() + for trackID, cct := range c.tracks { + sr := cct.rtcpSender.Report(now) + if sr != nil { + c.WritePacketRTCP(trackID, sr) + } + } + + c.udpReportTimer = time.NewTimer(c.udpSenderReportPeriod) + } + + case <-c.checkStreamTimer.C: + if *c.effectiveTransport == TransportUDP || + *c.effectiveTransport == TransportUDPMulticast { + if c.checkStreamInitial { + c.checkStreamInitial = false + + // check that at least one packet has been received + inTimeout := func() bool { + for _, cct := range c.tracks { + lft := atomic.LoadInt64(cct.udpRTPListener.lastPacketTime) + if lft != 0 { + return false + } + + lft = atomic.LoadInt64(cct.udpRTCPListener.lastPacketTime) + if lft != 0 { + return false + } + } + return true + }() + if inTimeout { + err := c.trySwitchingProtocol() + if err != nil { + return err + } + } + } else { + inTimeout := func() bool { + now := time.Now() + for _, cct := range c.tracks { + lft := time.Unix(atomic.LoadInt64(cct.udpRTPListener.lastPacketTime), 0) + if now.Sub(lft) < c.ReadTimeout { + return false + } + + lft = time.Unix(atomic.LoadInt64(cct.udpRTCPListener.lastPacketTime), 0) + if now.Sub(lft) < c.ReadTimeout { + return false + } + } + return true + }() + if inTimeout { + return liberrors.ErrClientUDPTimeout{} + } + } + } else { // TCP + inTimeout := func() bool { + now := time.Now() + lft := time.Unix(atomic.LoadInt64(c.tcpLastFrameTime), 0) + return now.Sub(lft) >= c.ReadTimeout + }() + if inTimeout { + return liberrors.ErrClientTCPTimeout{} + } + } + + c.checkStreamTimer = time.NewTimer(c.checkStreamPeriod) + + case <-c.keepaliveTimer.C: + _, err := c.do(&base.Request{ + Method: func() base.Method { + // the VLC integrated rtsp server requires GET_PARAMETER + if c.useGetParameter { + return base.GetParameter + } + return base.Options + }(), + // use the stream base URL, otherwise some cameras do not reply + URL: c.streamBaseURL, + }, true, false) + if err != nil { + return err + } + + c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) + + case err := <-c.readerErr: + c.readerErr = nil + return err + + case <-c.ctx.Done(): + return liberrors.ErrClientTerminated{} + } + } +} + func (c *Client) doClose() { if c.state == clientStatePlay || c.state == clientStateRecord { if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { @@ -957,22 +959,18 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba c.OnRequest(req) } + var buf bytes.Buffer + req.Write(&buf) + + c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) + _, err := c.conn.Write(buf.Bytes()) + if err != nil { + return nil, err + } + var res base.Response - err := func() error { - var buf bytes.Buffer - req.Write(&buf) - - c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) - _, err := c.conn.Write(buf.Bytes()) - if err != nil { - return err - } - - if skipResponse { - return nil - } - + if !skipResponse { c.conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) if allowFrames { @@ -983,51 +981,46 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba buf := make([]byte, c.ReadBufferSize) err = res.ReadIgnoreFrames(c.br, buf) if err != nil { - return err + return nil, err } } else { err = res.Read(c.br) if err != nil { - return err + return nil, err } } - return nil - }() - if err != nil { - return nil, err - } - - if c.OnResponse != nil { - c.OnResponse(&res) - } - - // get session from response - if v, ok := res.Header["Session"]; ok { - var sx headers.Session - err := sx.Read(v) - if err != nil { - return nil, liberrors.ErrClientSessionHeaderInvalid{Err: err} + if c.OnResponse != nil { + c.OnResponse(&res) } - c.session = sx.Session - if sx.Timeout != nil && *sx.Timeout > 0 { - c.keepalivePeriod = time.Duration(float64(*sx.Timeout)*0.8) * time.Second + // get session from response + if v, ok := res.Header["Session"]; ok { + var sx headers.Session + err := sx.Read(v) + if err != nil { + return nil, liberrors.ErrClientSessionHeaderInvalid{Err: err} + } + c.session = sx.Session + + if sx.Timeout != nil && *sx.Timeout > 0 { + c.keepalivePeriod = time.Duration(float64(*sx.Timeout)*0.8) * time.Second + } } - } - // if required, send request again with authentication - if res.StatusCode == base.StatusUnauthorized && req.URL.User != nil && c.sender == nil { - pass, _ := req.URL.User.Password() - user := req.URL.User.Username() + // if required, send request again with authentication + if res.StatusCode == base.StatusUnauthorized && req.URL.User != nil && c.sender == nil { + pass, _ := req.URL.User.Password() + user := req.URL.User.Username() - sender, err := auth.NewSender(res.Header["WWW-Authenticate"], user, pass) - if err != nil { - return nil, fmt.Errorf("unable to setup authentication: %s", err) + sender, err := auth.NewSender(res.Header["WWW-Authenticate"], user, pass) + if err != nil { + return nil, fmt.Errorf("unable to setup authentication: %s", err) + } + c.sender = sender + + return c.do(req, skipResponse, allowFrames) } - c.sender = sender - - return c.do(req, skipResponse, allowFrames) } return &res, nil diff --git a/serverconn.go b/serverconn.go index c7473102..ac0a258e 100644 --- a/serverconn.go +++ b/serverconn.go @@ -115,25 +115,7 @@ func (sc *ServerConn) run() { readDone := make(chan struct{}) go sc.runReader(readRequest, readErr, readDone) - err := func() error { - for { - select { - case req := <-readRequest: - req.res <- sc.handleRequestOuter(req.req) - - case err := <-readErr: - return err - - case ss := <-sc.sessionRemove: - if sc.session == ss { - sc.session = nil - } - - case <-sc.ctx.Done(): - return liberrors.ErrServerTerminated{} - } - } - }() + err := sc.runInner(readRequest, readErr) sc.ctxCancel() @@ -160,6 +142,26 @@ func (sc *ServerConn) run() { } } +func (sc *ServerConn) runInner(readRequest chan readReq, readErr chan error) error { + for { + select { + case req := <-readRequest: + req.res <- sc.handleRequestOuter(req.req) + + case err := <-readErr: + return err + + case ss := <-sc.sessionRemove: + if sc.session == ss { + sc.session = nil + } + + case <-sc.ctx.Done(): + return liberrors.ErrServerTerminated{} + } + } +} + var errSwitchReadFunc = errors.New("switch read function") func (sc *ServerConn) runReader(readRequest chan readReq, readErr chan error, readDone chan struct{}) { diff --git a/serversession.go b/serversession.go index 517a457d..3adf6e89 100644 --- a/serversession.go +++ b/serversession.go @@ -267,109 +267,7 @@ func (ss *ServerSession) run() { }) } - err := func() error { - for { - select { - case req := <-ss.request: - ss.lastRequestTime = time.Now() - - if _, ok := ss.conns[req.sc]; !ok { - ss.conns[req.sc] = struct{}{} - } - - res, err := ss.handleRequest(req.sc, req.req) - - var returnedSession *ServerSession - if err == nil || err == errSwitchReadFunc { - // ANNOUNCE responses don't contain the session header. - if req.req.Method != base.Announce { - if res.Header == nil { - res.Header = make(base.Header) - } - - res.Header["Session"] = headers.Session{ - Session: ss.secretID, - Timeout: func() *uint { - v := uint(ss.s.sessionTimeout / time.Second) - return &v - }(), - }.Write() - } - - // after a TEARDOWN, session must be unpaired with the connection. - if req.req.Method != base.Teardown { - returnedSession = ss - } - } - - savedMethod := req.req.Method - - req.res <- sessionRequestRes{ - res: res, - err: err, - ss: returnedSession, - } - - if (err == nil || err == errSwitchReadFunc) && savedMethod == base.Teardown { - return liberrors.ErrServerSessionTeardown{Author: req.sc.NetConn().RemoteAddr()} - } - - case sc := <-ss.connRemove: - delete(ss.conns, sc) - - // if session is not in state RECORD or PLAY, or transport is TCP - if (ss.state != ServerSessionStateRecord && - ss.state != ServerSessionStatePlay) || - *ss.setuppedTransport == TransportTCP { - // close session if there are no associated connections - if len(ss.conns) == 0 { - return liberrors.ErrServerSessionNotInUse{} - } - } - - case <-ss.startWriter: - if !ss.writerRunning && (ss.state == ServerSessionStateRecord || - ss.state == ServerSessionStatePlay) && - *ss.setuppedTransport == TransportTCP { - ss.writerRunning = true - ss.writerDone = make(chan struct{}) - go ss.runWriter() - } - - case <-ss.udpCheckStreamTimer.C: - now := time.Now() - - // in case of RECORD and UDP, timeout happens when no RTP or RTCP packets are being received - if ss.state == ServerSessionStateRecord { - lft := atomic.LoadInt64(ss.udpLastFrameTime) - if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout { - return liberrors.ErrServerNoUDPPacketsInAWhile{} - } - - // in case of PLAY and UDP, timeout happens when no RTSP request arrives - } else if now.Sub(ss.lastRequestTime) >= ss.s.sessionTimeout { - return liberrors.ErrServerNoRTSPRequestsInAWhile{} - } - - ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) - - case <-ss.udpReceiverReportTimer.C: - now := time.Now() - - for trackID, track := range ss.announcedTracks { - rr := track.rtcpReceiver.Report(now) - if rr != nil { - ss.WritePacketRTCP(trackID, rr) - } - } - - ss.udpReceiverReportTimer = time.NewTimer(ss.s.udpReceiverReportPeriod) - - case <-ss.ctx.Done(): - return liberrors.ErrServerTerminated{} - } - } - }() + err := ss.runInner() ss.ctxCancel() @@ -425,6 +323,110 @@ func (ss *ServerSession) run() { } } +func (ss *ServerSession) runInner() error { + for { + select { + case req := <-ss.request: + ss.lastRequestTime = time.Now() + + if _, ok := ss.conns[req.sc]; !ok { + ss.conns[req.sc] = struct{}{} + } + + res, err := ss.handleRequest(req.sc, req.req) + + var returnedSession *ServerSession + if err == nil || err == errSwitchReadFunc { + // ANNOUNCE responses don't contain the session header. + if req.req.Method != base.Announce { + if res.Header == nil { + res.Header = make(base.Header) + } + + res.Header["Session"] = headers.Session{ + Session: ss.secretID, + Timeout: func() *uint { + v := uint(ss.s.sessionTimeout / time.Second) + return &v + }(), + }.Write() + } + + // after a TEARDOWN, session must be unpaired with the connection. + if req.req.Method != base.Teardown { + returnedSession = ss + } + } + + savedMethod := req.req.Method + + req.res <- sessionRequestRes{ + res: res, + err: err, + ss: returnedSession, + } + + if (err == nil || err == errSwitchReadFunc) && savedMethod == base.Teardown { + return liberrors.ErrServerSessionTeardown{Author: req.sc.NetConn().RemoteAddr()} + } + + case sc := <-ss.connRemove: + delete(ss.conns, sc) + + // if session is not in state RECORD or PLAY, or transport is TCP + if (ss.state != ServerSessionStateRecord && + ss.state != ServerSessionStatePlay) || + *ss.setuppedTransport == TransportTCP { + // close session if there are no associated connections + if len(ss.conns) == 0 { + return liberrors.ErrServerSessionNotInUse{} + } + } + + case <-ss.startWriter: + if !ss.writerRunning && (ss.state == ServerSessionStateRecord || + ss.state == ServerSessionStatePlay) && + *ss.setuppedTransport == TransportTCP { + ss.writerRunning = true + ss.writerDone = make(chan struct{}) + go ss.runWriter() + } + + case <-ss.udpCheckStreamTimer.C: + now := time.Now() + + // in case of RECORD and UDP, timeout happens when no RTP or RTCP packets are being received + if ss.state == ServerSessionStateRecord { + lft := atomic.LoadInt64(ss.udpLastFrameTime) + if now.Sub(time.Unix(lft, 0)) >= ss.s.ReadTimeout { + return liberrors.ErrServerNoUDPPacketsInAWhile{} + } + + // in case of PLAY and UDP, timeout happens when no RTSP request arrives + } else if now.Sub(ss.lastRequestTime) >= ss.s.sessionTimeout { + return liberrors.ErrServerNoRTSPRequestsInAWhile{} + } + + ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod) + + case <-ss.udpReceiverReportTimer.C: + now := time.Now() + + for trackID, track := range ss.announcedTracks { + rr := track.rtcpReceiver.Report(now) + if rr != nil { + ss.WritePacketRTCP(trackID, rr) + } + } + + ss.udpReceiverReportTimer = time.NewTimer(ss.s.udpReceiverReportPeriod) + + case <-ss.ctx.Done(): + return liberrors.ErrServerTerminated{} + } + } +} + func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base.Response, error) { if ss.tcpConn != nil && sc != ss.tcpConn { return &base.Response{