diff --git a/examples/server-tls.go b/examples/server-tls.go index 6e8706bb..edf78a47 100644 --- a/examples/server-tls.go +++ b/examples/server-tls.go @@ -98,8 +98,7 @@ func handleConn(conn *gortsplib.ServerConn) { readers[conn] = struct{}{} - conn.EnableReadFrames(true) - conn.EnableReadTimeout(false) + conn.EnableFrames(true) return &base.Response{ StatusCode: base.StatusOK, @@ -120,7 +119,7 @@ func handleConn(conn *gortsplib.ServerConn) { }, fmt.Errorf("someone is already publishing") } - conn.EnableReadFrames(true) + conn.EnableFrames(true) conn.EnableReadTimeout(true) return &base.Response{ diff --git a/examples/server.go b/examples/server.go index 22afc197..b038f6cb 100644 --- a/examples/server.go +++ b/examples/server.go @@ -97,8 +97,7 @@ func handleConn(conn *gortsplib.ServerConn) { readers[conn] = struct{}{} - conn.EnableReadFrames(true) - conn.EnableReadTimeout(false) + conn.EnableFrames(true) return &base.Response{ StatusCode: base.StatusOK, @@ -119,7 +118,7 @@ func handleConn(conn *gortsplib.ServerConn) { }, fmt.Errorf("someone is already publishing") } - conn.EnableReadFrames(true) + conn.EnableFrames(true) conn.EnableReadTimeout(true) return &base.Response{ diff --git a/serverconf_test.go b/serverconf_test.go index 966a9b1c..eb169891 100644 --- a/serverconf_test.go +++ b/serverconf_test.go @@ -126,8 +126,7 @@ func (ts *testServ) handleConn(conn *ServerConn) { ts.readers[conn] = struct{}{} - conn.EnableReadFrames(true) - conn.EnableReadTimeout(false) + conn.EnableFrames(true) return &base.Response{ StatusCode: base.StatusOK, @@ -147,7 +146,7 @@ func (ts *testServ) handleConn(conn *ServerConn) { }, fmt.Errorf("someone is already publishing") } - conn.EnableReadFrames(true) + conn.EnableFrames(true) conn.EnableReadTimeout(true) return &base.Response{ diff --git a/serverconn.go b/serverconn.go index 287f1316..1c04e6ad 100644 --- a/serverconn.go +++ b/serverconn.go @@ -25,17 +25,18 @@ var ( ErrServerContentTypeMissing = errors.New("Content-Type header is missing") ErrServerNoTracksDefined = errors.New("no tracks defined") ErrServerMissingCseq = errors.New("CSeq is missing") + ErrServerFramesDisabled = errors.New("frames are disabled") ) // ServerConn is a server-side RTSP connection. type ServerConn struct { - s *Server - nconn net.Conn - br *bufio.Reader - bw *bufio.Writer - mutex sync.Mutex - readFrames bool - readTimeout bool + s *Server + nconn net.Conn + br *bufio.Reader + bw *bufio.Writer + mutex sync.Mutex + framesEnabled bool + readTimeoutEnabled bool } // Close closes all the connection resources. @@ -48,14 +49,14 @@ func (sc *ServerConn) NetConn() net.Conn { return sc.nconn } -// EnableReadFrames allows or denies receiving frames. -func (sc *ServerConn) EnableReadFrames(v bool) { - sc.readFrames = v +// EnableFrames allows reading and writing TCP frames. +func (sc *ServerConn) EnableFrames(v bool) { + sc.framesEnabled = v } // EnableReadTimeout sets or removes the timeout on incoming packets. func (sc *ServerConn) EnableReadTimeout(v bool) { - sc.readTimeout = v + sc.readTimeoutEnabled = v } // ServerConnReadHandlers allows to set the handlers required by ServerConn.Read. @@ -289,13 +290,13 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan outer: for { - if sc.readTimeout { + if sc.readTimeoutEnabled { sc.nconn.SetReadDeadline(time.Now().Add(sc.s.conf.ReadTimeout)) } else { sc.nconn.SetReadDeadline(time.Time{}) } - if sc.readFrames { + if sc.framesEnabled { frame.Content = tcpFrameBuffer.Next() what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br) if err != nil { @@ -349,6 +350,10 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, content []b sc.mutex.Lock() defer sc.mutex.Unlock() + if !sc.framesEnabled { + return ErrServerFramesDisabled + } + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) frame := base.InterleavedFrame{ TrackID: trackID,