replace ServerConn.EnableReadFrames with EnableFrames; prevent writing if the flag is disabled

This commit is contained in:
aler9
2020-12-15 20:07:04 +01:00
parent 27636bc810
commit e5b1260075
4 changed files with 24 additions and 22 deletions

View File

@@ -98,8 +98,7 @@ func handleConn(conn *gortsplib.ServerConn) {
readers[conn] = struct{}{} readers[conn] = struct{}{}
conn.EnableReadFrames(true) conn.EnableFrames(true)
conn.EnableReadTimeout(false)
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
@@ -120,7 +119,7 @@ func handleConn(conn *gortsplib.ServerConn) {
}, fmt.Errorf("someone is already publishing") }, fmt.Errorf("someone is already publishing")
} }
conn.EnableReadFrames(true) conn.EnableFrames(true)
conn.EnableReadTimeout(true) conn.EnableReadTimeout(true)
return &base.Response{ return &base.Response{

View File

@@ -97,8 +97,7 @@ func handleConn(conn *gortsplib.ServerConn) {
readers[conn] = struct{}{} readers[conn] = struct{}{}
conn.EnableReadFrames(true) conn.EnableFrames(true)
conn.EnableReadTimeout(false)
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
@@ -119,7 +118,7 @@ func handleConn(conn *gortsplib.ServerConn) {
}, fmt.Errorf("someone is already publishing") }, fmt.Errorf("someone is already publishing")
} }
conn.EnableReadFrames(true) conn.EnableFrames(true)
conn.EnableReadTimeout(true) conn.EnableReadTimeout(true)
return &base.Response{ return &base.Response{

View File

@@ -126,8 +126,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
ts.readers[conn] = struct{}{} ts.readers[conn] = struct{}{}
conn.EnableReadFrames(true) conn.EnableFrames(true)
conn.EnableReadTimeout(false)
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
@@ -147,7 +146,7 @@ func (ts *testServ) handleConn(conn *ServerConn) {
}, fmt.Errorf("someone is already publishing") }, fmt.Errorf("someone is already publishing")
} }
conn.EnableReadFrames(true) conn.EnableFrames(true)
conn.EnableReadTimeout(true) conn.EnableReadTimeout(true)
return &base.Response{ return &base.Response{

View File

@@ -25,6 +25,7 @@ var (
ErrServerContentTypeMissing = errors.New("Content-Type header is missing") ErrServerContentTypeMissing = errors.New("Content-Type header is missing")
ErrServerNoTracksDefined = errors.New("no tracks defined") ErrServerNoTracksDefined = errors.New("no tracks defined")
ErrServerMissingCseq = errors.New("CSeq is missing") ErrServerMissingCseq = errors.New("CSeq is missing")
ErrServerFramesDisabled = errors.New("frames are disabled")
) )
// ServerConn is a server-side RTSP connection. // ServerConn is a server-side RTSP connection.
@@ -34,8 +35,8 @@ type ServerConn struct {
br *bufio.Reader br *bufio.Reader
bw *bufio.Writer bw *bufio.Writer
mutex sync.Mutex mutex sync.Mutex
readFrames bool framesEnabled bool
readTimeout bool readTimeoutEnabled bool
} }
// Close closes all the connection resources. // Close closes all the connection resources.
@@ -48,14 +49,14 @@ func (sc *ServerConn) NetConn() net.Conn {
return sc.nconn return sc.nconn
} }
// EnableReadFrames allows or denies receiving frames. // EnableFrames allows reading and writing TCP frames.
func (sc *ServerConn) EnableReadFrames(v bool) { func (sc *ServerConn) EnableFrames(v bool) {
sc.readFrames = v sc.framesEnabled = v
} }
// EnableReadTimeout sets or removes the timeout on incoming packets. // EnableReadTimeout sets or removes the timeout on incoming packets.
func (sc *ServerConn) EnableReadTimeout(v bool) { func (sc *ServerConn) EnableReadTimeout(v bool) {
sc.readTimeout = v sc.readTimeoutEnabled = v
} }
// ServerConnReadHandlers allows to set the handlers required by ServerConn.Read. // ServerConnReadHandlers allows to set the handlers required by ServerConn.Read.
@@ -289,13 +290,13 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan
outer: outer:
for { for {
if sc.readTimeout { if sc.readTimeoutEnabled {
sc.nconn.SetReadDeadline(time.Now().Add(sc.s.conf.ReadTimeout)) sc.nconn.SetReadDeadline(time.Now().Add(sc.s.conf.ReadTimeout))
} else { } else {
sc.nconn.SetReadDeadline(time.Time{}) sc.nconn.SetReadDeadline(time.Time{})
} }
if sc.readFrames { if sc.framesEnabled {
frame.Content = tcpFrameBuffer.Next() frame.Content = tcpFrameBuffer.Next()
what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br) what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br)
if err != nil { if err != nil {
@@ -349,6 +350,10 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, content []b
sc.mutex.Lock() sc.mutex.Lock()
defer sc.mutex.Unlock() defer sc.mutex.Unlock()
if !sc.framesEnabled {
return ErrServerFramesDisabled
}
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
frame := base.InterleavedFrame{ frame := base.InterleavedFrame{
TrackID: trackID, TrackID: trackID,