replace ServerConn.EnableReadFrames with EnableFrames; prevent writing if the flag is disabled

This commit is contained in:
aler9
2020-12-15 20:07:04 +01:00
parent 27636bc810
commit e5b1260075
4 changed files with 24 additions and 22 deletions

View File

@@ -98,8 +98,7 @@ func handleConn(conn *gortsplib.ServerConn) {
readers[conn] = struct{}{}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(false)
conn.EnableFrames(true)
return &base.Response{
StatusCode: base.StatusOK,
@@ -120,7 +119,7 @@ func handleConn(conn *gortsplib.ServerConn) {
}, fmt.Errorf("someone is already publishing")
}
conn.EnableReadFrames(true)
conn.EnableFrames(true)
conn.EnableReadTimeout(true)
return &base.Response{

View File

@@ -97,8 +97,7 @@ func handleConn(conn *gortsplib.ServerConn) {
readers[conn] = struct{}{}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(false)
conn.EnableFrames(true)
return &base.Response{
StatusCode: base.StatusOK,
@@ -119,7 +118,7 @@ func handleConn(conn *gortsplib.ServerConn) {
}, fmt.Errorf("someone is already publishing")
}
conn.EnableReadFrames(true)
conn.EnableFrames(true)
conn.EnableReadTimeout(true)
return &base.Response{

View File

@@ -126,8 +126,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
ts.readers[conn] = struct{}{}
conn.EnableReadFrames(true)
conn.EnableReadTimeout(false)
conn.EnableFrames(true)
return &base.Response{
StatusCode: base.StatusOK,
@@ -147,7 +146,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
}, fmt.Errorf("someone is already publishing")
}
conn.EnableReadFrames(true)
conn.EnableFrames(true)
conn.EnableReadTimeout(true)
return &base.Response{

View File

@@ -25,17 +25,18 @@ var (
ErrServerContentTypeMissing = errors.New("Content-Type header is missing")
ErrServerNoTracksDefined = errors.New("no tracks defined")
ErrServerMissingCseq = errors.New("CSeq is missing")
ErrServerFramesDisabled = errors.New("frames are disabled")
)
// ServerConn is a server-side RTSP connection.
type ServerConn struct {
s *Server
nconn net.Conn
br *bufio.Reader
bw *bufio.Writer
mutex sync.Mutex
readFrames bool
readTimeout bool
s *Server
nconn net.Conn
br *bufio.Reader
bw *bufio.Writer
mutex sync.Mutex
framesEnabled bool
readTimeoutEnabled bool
}
// Close closes all the connection resources.
@@ -48,14 +49,14 @@ func (sc *ServerConn) NetConn() net.Conn {
return sc.nconn
}
// EnableReadFrames allows or denies receiving frames.
func (sc *ServerConn) EnableReadFrames(v bool) {
sc.readFrames = v
// EnableFrames allows reading and writing TCP frames.
func (sc *ServerConn) EnableFrames(v bool) {
sc.framesEnabled = v
}
// EnableReadTimeout sets or removes the timeout on incoming packets.
func (sc *ServerConn) EnableReadTimeout(v bool) {
sc.readTimeout = v
sc.readTimeoutEnabled = v
}
// ServerConnReadHandlers allows to set the handlers required by ServerConn.Read.
@@ -289,13 +290,13 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan
outer:
for {
if sc.readTimeout {
if sc.readTimeoutEnabled {
sc.nconn.SetReadDeadline(time.Now().Add(sc.s.conf.ReadTimeout))
} else {
sc.nconn.SetReadDeadline(time.Time{})
}
if sc.readFrames {
if sc.framesEnabled {
frame.Content = tcpFrameBuffer.Next()
what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br)
if err != nil {
@@ -349,6 +350,10 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, content []b
sc.mutex.Lock()
defer sc.mutex.Unlock()
if !sc.framesEnabled {
return ErrServerFramesDisabled
}
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
frame := base.InterleavedFrame{
TrackID: trackID,