diff --git a/serverconn.go b/serverconn.go index 1c04e6ad..712febf8 100644 --- a/serverconn.go +++ b/serverconn.go @@ -34,7 +34,8 @@ type ServerConn struct { nconn net.Conn br *bufio.Reader bw *bufio.Writer - mutex sync.Mutex + writeMutex sync.Mutex + nextFramesEnabled bool framesEnabled bool readTimeoutEnabled bool } @@ -51,7 +52,7 @@ func (sc *ServerConn) NetConn() net.Conn { // EnableFrames allows reading and writing TCP frames. func (sc *ServerConn) EnableFrames(v bool) { - sc.framesEnabled = v + sc.nextFramesEnabled = v } // EnableReadTimeout sets or removes the timeout on incoming packets. @@ -247,17 +248,16 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan } handleRequestOuter := func(req *base.Request) error { - sc.mutex.Lock() - defer sc.mutex.Unlock() - // check cseq cseq, ok := req.Header["CSeq"] if !ok || len(cseq) != 1 { + sc.writeMutex.Lock() sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) base.Response{ StatusCode: base.StatusBadRequest, Header: base.Header{}, }.Write(sc.bw) + sc.writeMutex.Unlock() return ErrServerMissingCseq } @@ -277,9 +277,19 @@ func (sc *ServerConn) backgroundRead(handlers ServerConnReadHandlers, done chan handlers.OnResponse(res) } + sc.writeMutex.Lock() + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) res.Write(sc.bw) + // set framesEnabled after sending the response + // in order to start sending frames after the response + if sc.framesEnabled != sc.nextFramesEnabled { + sc.framesEnabled = sc.nextFramesEnabled + } + + sc.writeMutex.Unlock() + return err } @@ -347,8 +357,8 @@ func (sc *ServerConn) Read(handlers ServerConnReadHandlers) chan error { // WriteFrame writes a frame. func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, content []byte) error { - sc.mutex.Lock() - defer sc.mutex.Unlock() + sc.writeMutex.Lock() + defer sc.writeMutex.Unlock() if !sc.framesEnabled { return ErrServerFramesDisabled