diff --git a/clientconnread.go b/clientconnread.go index 9ce369e7..daa47b19 100644 --- a/clientconnread.go +++ b/clientconnread.go @@ -33,7 +33,7 @@ func (c *ClientConn) Play() (*base.Response, error) { return res, nil } -func (c *ClientConn) backgroundPlayUDP(onFrameDone chan error) { +func (c *ClientConn) backgroundPlayUDP(done chan error) { defer close(c.backgroundDone) var returnError error @@ -44,7 +44,7 @@ func (c *ClientConn) backgroundPlayUDP(onFrameDone chan error) { c.udpRtcpListeners[trackID].stop() } - onFrameDone <- returnError + done <- returnError }() // open the firewall by sending packets to the counterpart @@ -141,13 +141,13 @@ func (c *ClientConn) backgroundPlayUDP(onFrameDone chan error) { } } -func (c *ClientConn) backgroundPlayTCP(onFrameDone chan error) { +func (c *ClientConn) backgroundPlayTCP(done chan error) { defer close(c.backgroundDone) var returnError error defer func() { - onFrameDone <- returnError + done <- returnError }() readerDone := make(chan error) @@ -209,30 +209,30 @@ func (c *ClientConn) backgroundPlayTCP(onFrameDone chan error) { } // OnFrame sets a callback that is called when a frame is received. -// it returns a channel that is called when the reading stops. +// it returns a channel that is written when the reading stops. // This can be called only after Play(). -func (c *ClientConn) OnFrame(cb func(int, StreamType, []byte)) chan error { +func (c *ClientConn) OnFrame(onFrame func(int, StreamType, []byte)) chan error { // channel is buffered, since listening to it is not mandatory - onFrameDone := make(chan error, 1) + done := make(chan error, 1) err := c.checkState(map[clientConnState]struct{}{ clientConnStatePrePlay: {}, }) if err != nil { - onFrameDone <- err - return onFrameDone + done <- err + return done } c.state = clientConnStatePlay - c.readCB = cb + c.readCB = onFrame c.backgroundTerminate = make(chan struct{}) c.backgroundDone = make(chan struct{}) if *c.streamProtocol == StreamProtocolUDP { - go c.backgroundPlayUDP(onFrameDone) + go c.backgroundPlayUDP(done) } else { - go c.backgroundPlayTCP(onFrameDone) + go c.backgroundPlayTCP(done) } - return onFrameDone + return done } diff --git a/examples/server.go b/examples/server.go index 6bc97c9d..c962f418 100644 --- a/examples/server.go +++ b/examples/server.go @@ -10,52 +10,61 @@ import ( "github.com/aler9/gortsplib/pkg/base" ) -type serverConnHandler struct { -} +func handleConn(conn *gortsplib.ServerConn) { + onRequest := func(req *base.Request) (*base.Response, error) { + switch req.Method { + case base.Options: + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Announce), + string(base.Setup), + string(base.Play), + string(base.Record), + string(base.Pause), + string(base.Teardown), + }, ", ")}, + }, + }, nil -func (sc *serverConnHandler) OnClose(err error) { -} + case base.Teardown: + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{}, + }, fmt.Errorf("terminated") + } -func (sc *serverConnHandler) OnRequest(req *base.Request) (*base.Response, error) { - switch req.Method { - case base.Options: return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Public": base.HeaderValue{strings.Join([]string{ - string(base.Describe), - string(base.Announce), - string(base.Setup), - string(base.Play), - string(base.Record), - string(base.Pause), - string(base.Teardown), - }, ", ")}, - }, - }, nil - - case base.Teardown: - return &base.Response{ - StatusCode: base.StatusOK, + StatusCode: base.StatusBadRequest, Header: base.Header{}, - }, fmt.Errorf("terminated") + }, fmt.Errorf("unhandled method: %v", req.Method) } - return &base.Response{ - StatusCode: base.StatusBadRequest, - Header: base.Header{}, - }, fmt.Errorf("unhandled method: %v", req.Method) -} + onFrame := func(id int, typ gortsplib.StreamType, buf []byte) { + } -func (sc *serverConnHandler) OnFrame(id int, typ gortsplib.StreamType, buf []byte) { + done := conn.Read(onRequest, onFrame) + + err := <-done + panic(err) } func main() { // create server - gortsplib.Serve(":8554", func(c *gortsplib.ServerConn) gortsplib.ServerConnHandler { - return &serverConnHandler{} - }) + s, err := gortsplib.Serve(":8554") + if err != nil { + panic(err) + } - // wait forever - select {} + // accept connections + for { + conn, err := s.Accept() + if err != nil { + panic(err) + } + + go handleConn(conn) + } } diff --git a/server.go b/server.go index 61dfda9c..02cc795a 100644 --- a/server.go +++ b/server.go @@ -3,58 +3,30 @@ package gortsplib import ( "bufio" "net" - "sync" - "time" ) // Server is a RTSP server. type Server struct { conf ServerConf listener *net.TCPListener - handler func(sc *ServerConn) ServerConnHandler - wg sync.WaitGroup } // Close closes the server. func (s *Server) Close() error { - s.listener.Close() - s.wg.Wait() - return nil + return s.listener.Close() } -func (s *Server) run() { - defer s.wg.Done() - - if s.conf.ReadTimeout == 0 { - s.conf.ReadTimeout = 10 * time.Second - } - if s.conf.WriteTimeout == 0 { - s.conf.WriteTimeout = 10 * time.Second - } - if s.conf.ReadBufferCount == 0 { - s.conf.ReadBufferCount = 1 +// Accept accepts a connection. +func (s *Server) Accept() (*ServerConn, error) { + nconn, err := s.listener.Accept() + if err != nil { + return nil, err } - for { - nconn, err := s.listener.Accept() - if err != nil { - break - } - - sc := &ServerConn{ - s: s, - nconn: nconn, - br: bufio.NewReaderSize(nconn, serverReadBufferSize), - bw: bufio.NewWriterSize(nconn, serverWriteBufferSize), - } - - sc.connHandler = s.handler(sc) - if sc.connHandler == nil { - nconn.Close() - continue - } - - s.wg.Add(1) - go sc.run() - } + return &ServerConn{ + s: s, + nconn: nconn, + br: bufio.NewReaderSize(nconn, serverReadBufferSize), + bw: bufio.NewWriterSize(nconn, serverWriteBufferSize), + }, nil } diff --git a/serverconf.go b/serverconf.go index 1cf08dff..e21cb5c0 100644 --- a/serverconf.go +++ b/serverconf.go @@ -9,8 +9,8 @@ import ( var DefaultServerConf = ServerConf{} // Serve starts a server on the given address. -func Serve(address string, handler func(sc *ServerConn) ServerConnHandler) (*Server, error) { - return DefaultServerConf.Serve(address, handler) +func Serve(address string) (*Server, error) { + return DefaultServerConf.Serve(address) } // ServerConf allows to configure a Server. @@ -36,7 +36,7 @@ type ServerConf struct { } // Serve starts a server on the given address. -func (c ServerConf) Serve(address string, handler func(sc *ServerConn) ServerConnHandler) (*Server, error) { +func (c ServerConf) Serve(address string) (*Server, error) { if c.ReadTimeout == 0 { c.ReadTimeout = 10 * time.Second } @@ -63,11 +63,7 @@ func (c ServerConf) Serve(address string, handler func(sc *ServerConn) ServerCon s := &Server{ conf: c, listener: listener, - handler: handler, } - s.wg.Add(1) - go s.run() - return s, nil } diff --git a/serverconn.go b/serverconn.go index 6039bd02..d9b70b06 100644 --- a/serverconn.go +++ b/serverconn.go @@ -16,18 +16,10 @@ const ( serverWriteBufferSize = 4096 ) -// ServerConnHandler is the interface that must be implemented to use a ServerConn. -type ServerConnHandler interface { - OnClose(err error) - OnRequest(req *base.Request) (*base.Response, error) - OnFrame(rackID int, streamType StreamType, content []byte) -} - // ServerConn is a server-side RTSP connection. type ServerConn struct { s *Server nconn net.Conn - connHandler ServerConnHandler br *bufio.Reader bw *bufio.Writer mutex sync.Mutex @@ -35,7 +27,7 @@ type ServerConn struct { readTimeout bool } -// Close closes all the ServerConn resources. +// Close closes all the connection resources. func (sc *ServerConn) Close() error { return sc.nconn.Close() } @@ -55,8 +47,36 @@ func (sc *ServerConn) EnableReadTimeout(v bool) { sc.readTimeout = v } -func (sc *ServerConn) run() { - defer sc.s.wg.Done() +func (sc *ServerConn) backgroundRead( + onRequest func(req *base.Request) (*base.Response, error), + onFrame func(trackID int, streamType StreamType, content []byte), + done chan error, +) { + handleRequest := func(req *base.Request) error { + sc.mutex.Lock() + defer sc.mutex.Unlock() + + // check cseq + cseq, ok := req.Header["CSeq"] + if !ok || len(cseq) != 1 { + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) + base.Response{ + StatusCode: base.StatusBadRequest, + Header: base.Header{}, + }.Write(sc.bw) + return fmt.Errorf("cseq is missing") + } + + res, err := onRequest(req) + + // add cseq to response + res.Header["CSeq"] = cseq + + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) + res.Write(sc.bw) + + return err + } var req base.Request var frame base.InterleavedFrame @@ -81,10 +101,10 @@ outer: switch what.(type) { case *base.InterleavedFrame: - sc.connHandler.OnFrame(frame.TrackID, frame.StreamType, frame.Content) + onFrame(frame.TrackID, frame.StreamType, frame.Content) case *base.Request: - err := sc.handleRequest(&req) + err := handleRequest(&req) if err != nil { errRet = err break outer @@ -98,7 +118,7 @@ outer: break outer } - err = sc.handleRequest(&req) + err = handleRequest(&req) if err != nil { errRet = err break outer @@ -107,33 +127,21 @@ outer: } sc.nconn.Close() - sc.connHandler.OnClose(errRet) + done <- errRet } -func (sc *ServerConn) handleRequest(req *base.Request) error { - sc.mutex.Lock() - defer sc.mutex.Unlock() +// Read starts reading requests and frames from the connection. +// it returns a channel that is written when the reading stops. +func (sc *ServerConn) Read( + onRequest func(req *base.Request) (*base.Response, error), + onFrame func(trackID int, streamType StreamType, content []byte), +) chan error { + // channel is buffered, since listening to it is not mandatory + done := make(chan error, 1) - // check cseq - cseq, ok := req.Header["CSeq"] - if !ok || len(cseq) != 1 { - sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) - base.Response{ - StatusCode: base.StatusBadRequest, - Header: base.Header{}, - }.Write(sc.bw) - return fmt.Errorf("cseq is missing") - } + go sc.backgroundRead(onRequest, onFrame, done) - res, err := sc.connHandler.OnRequest(req) - - // add cseq to response - res.Header["CSeq"] = cseq - - sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) - res.Write(sc.bw) - - return err + return done } // WriteFrame writes a frame.