diff --git a/README.md b/README.md index 2b4a7c32..b3fdb9e3 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Features: * Publish tracks to servers with UDP or TCP * Pause reading or publishing without disconnecting from the server * Server - * Handle server-side connections + * Build servers and handle publishers and readers ## Examples @@ -28,6 +28,7 @@ Features: * [client-publish](examples/client-publish.go) * [client-publish-options](examples/client-publish-options.go) * [client-publish-pause](examples/client-publish-pause.go) +* [server](examples/server.go) ## Documentation diff --git a/examples/server.go b/examples/server.go new file mode 100644 index 00000000..6bc97c9d --- /dev/null +++ b/examples/server.go @@ -0,0 +1,61 @@ +// +build ignore + +package main + +import ( + "fmt" + "strings" + + "github.com/aler9/gortsplib" + "github.com/aler9/gortsplib/pkg/base" +) + +type serverConnHandler struct { +} + +func (sc *serverConnHandler) OnClose(err error) { +} + +func (sc *serverConnHandler) OnRequest(req *base.Request) (*base.Response, error) { + switch req.Method { + case base.Options: + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Announce), + string(base.Setup), + string(base.Play), + string(base.Record), + string(base.Pause), + string(base.Teardown), + }, ", ")}, + }, + }, nil + + case base.Teardown: + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{}, + }, fmt.Errorf("terminated") + } + + return &base.Response{ + StatusCode: base.StatusBadRequest, + Header: base.Header{}, + }, fmt.Errorf("unhandled method: %v", req.Method) +} + +func (sc *serverConnHandler) OnFrame(id int, typ gortsplib.StreamType, buf []byte) { +} + +func main() { + // create server + gortsplib.Serve(":8554", func(c *gortsplib.ServerConn) gortsplib.ServerConnHandler { + return &serverConnHandler{} + }) + + // wait forever + select {} +} diff --git a/server.go b/server.go index b9a5521d..61dfda9c 100644 --- a/server.go +++ b/server.go @@ -3,33 +3,27 @@ package gortsplib import ( "bufio" "net" + "sync" "time" - - "github.com/aler9/gortsplib/pkg/base" - "github.com/aler9/gortsplib/pkg/multibuffer" ) -// ServerHandler is the interface that must be implemented to use a Server. -type ServerHandler interface { -} - // Server is a RTSP server. type Server struct { conf ServerConf listener *net.TCPListener + handler func(sc *ServerConn) ServerConnHandler + wg sync.WaitGroup } // Close closes the server. func (s *Server) Close() error { - return s.listener.Close() + s.listener.Close() + s.wg.Wait() + return nil } -// Accept accepts a connection. -func (s *Server) Accept() (*ServerConn, error) { - nconn, err := s.listener.Accept() - if err != nil { - return nil, err - } +func (s *Server) run() { + defer s.wg.Done() if s.conf.ReadTimeout == 0 { s.conf.ReadTimeout = 10 * time.Second @@ -41,15 +35,26 @@ func (s *Server) Accept() (*ServerConn, error) { s.conf.ReadBufferCount = 1 } - sc := &ServerConn{ - conf: s.conf, - nconn: nconn, - br: bufio.NewReaderSize(nconn, serverReadBufferSize), - bw: bufio.NewWriterSize(nconn, serverWriteBufferSize), - request: &base.Request{}, - frame: &base.InterleavedFrame{}, - tcpFrameBuffer: multibuffer.New(s.conf.ReadBufferCount, clientTCPFrameReadBufferSize), - } + for { + nconn, err := s.listener.Accept() + if err != nil { + break + } - return sc, nil + sc := &ServerConn{ + s: s, + nconn: nconn, + br: bufio.NewReaderSize(nconn, serverReadBufferSize), + bw: bufio.NewWriterSize(nconn, serverWriteBufferSize), + } + + sc.connHandler = s.handler(sc) + if sc.connHandler == nil { + nconn.Close() + continue + } + + s.wg.Add(1) + go sc.run() + } } diff --git a/serverconf.go b/serverconf.go index 6f69d226..1cf08dff 100644 --- a/serverconf.go +++ b/serverconf.go @@ -9,7 +9,7 @@ import ( var DefaultServerConf = ServerConf{} // Serve starts a server on the given address. -func Serve(address string, handler ServerHandler) (*Server, error) { +func Serve(address string, handler func(sc *ServerConn) ServerConnHandler) (*Server, error) { return DefaultServerConf.Serve(address, handler) } @@ -36,7 +36,7 @@ type ServerConf struct { } // Serve starts a server on the given address. -func (c ServerConf) Serve(address string, handler ServerHandler) (*Server, error) { +func (c ServerConf) Serve(address string, handler func(sc *ServerConn) ServerConnHandler) (*Server, error) { if c.ReadTimeout == 0 { c.ReadTimeout = 10 * time.Second } @@ -63,7 +63,11 @@ func (c ServerConf) Serve(address string, handler ServerHandler) (*Server, error s := &Server{ conf: c, listener: listener, + handler: handler, } + s.wg.Add(1) + go s.run() + return s, nil } diff --git a/serverconn.go b/serverconn.go index 61b11e66..6039bd02 100644 --- a/serverconn.go +++ b/serverconn.go @@ -2,7 +2,9 @@ package gortsplib import ( "bufio" + "fmt" "net" + "sync" "time" "github.com/aler9/gortsplib/pkg/base" @@ -14,63 +16,137 @@ const ( serverWriteBufferSize = 4096 ) +// ServerConnHandler is the interface that must be implemented to use a ServerConn. +type ServerConnHandler interface { + OnClose(err error) + OnRequest(req *base.Request) (*base.Response, error) + OnFrame(rackID int, streamType StreamType, content []byte) +} + // ServerConn is a server-side RTSP connection. type ServerConn struct { - conf ServerConf - nconn net.Conn - br *bufio.Reader - bw *bufio.Writer - request *base.Request - frame *base.InterleavedFrame - tcpFrameBuffer *multibuffer.MultiBuffer + s *Server + nconn net.Conn + connHandler ServerConnHandler + br *bufio.Reader + bw *bufio.Writer + mutex sync.Mutex + frames bool + readTimeout bool } // Close closes all the ServerConn resources. -func (s *ServerConn) Close() error { - return s.nconn.Close() +func (sc *ServerConn) Close() error { + return sc.nconn.Close() } // NetConn returns the underlying net.Conn. -func (s *ServerConn) NetConn() net.Conn { - return s.nconn +func (sc *ServerConn) NetConn() net.Conn { + return sc.nconn } -// ReadRequest reads a Request. -func (s *ServerConn) ReadRequest() (*base.Request, error) { - s.nconn.SetReadDeadline(time.Time{}) // disable deadline - err := s.request.Read(s.br) - if err != nil { - return nil, err +// EnableFrames allows or denies receiving frames. +func (sc *ServerConn) EnableFrames(v bool) { + sc.frames = v +} + +// EnableReadTimeout sets or removes the timeout on incoming packets. +func (sc *ServerConn) EnableReadTimeout(v bool) { + sc.readTimeout = v +} + +func (sc *ServerConn) run() { + defer sc.s.wg.Done() + + var req base.Request + var frame base.InterleavedFrame + tcpFrameBuffer := multibuffer.New(sc.s.conf.ReadBufferCount, clientTCPFrameReadBufferSize) + var errRet error + +outer: + for { + if sc.readTimeout { + sc.nconn.SetReadDeadline(time.Now().Add(sc.s.conf.ReadTimeout)) + } else { + sc.nconn.SetReadDeadline(time.Time{}) + } + + if sc.frames { + frame.Content = tcpFrameBuffer.Next() + what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br) + if err != nil { + errRet = err + break outer + } + + switch what.(type) { + case *base.InterleavedFrame: + sc.connHandler.OnFrame(frame.TrackID, frame.StreamType, frame.Content) + + case *base.Request: + err := sc.handleRequest(&req) + if err != nil { + errRet = err + break outer + } + } + + } else { + err := req.Read(sc.br) + if err != nil { + errRet = err + break outer + } + + err = sc.handleRequest(&req) + if err != nil { + errRet = err + break outer + } + } } - return s.request, nil + sc.nconn.Close() + sc.connHandler.OnClose(errRet) } -// ReadFrameTCPOrRequest reads an InterleavedFrame or a Request. -func (s *ServerConn) ReadFrameTCPOrRequest(timeout bool) (interface{}, error) { - s.frame.Content = s.tcpFrameBuffer.Next() +func (sc *ServerConn) handleRequest(req *base.Request) error { + sc.mutex.Lock() + defer sc.mutex.Unlock() - if timeout { - s.nconn.SetReadDeadline(time.Now().Add(s.conf.ReadTimeout)) + // check cseq + cseq, ok := req.Header["CSeq"] + if !ok || len(cseq) != 1 { + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) + base.Response{ + StatusCode: base.StatusBadRequest, + Header: base.Header{}, + }.Write(sc.bw) + return fmt.Errorf("cseq is missing") } - return base.ReadInterleavedFrameOrRequest(s.frame, s.request, s.br) + res, err := sc.connHandler.OnRequest(req) + + // add cseq to response + res.Header["CSeq"] = cseq + + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) + res.Write(sc.bw) + + return err } -// WriteResponse writes a Response. -func (s *ServerConn) WriteResponse(res *base.Response) error { - s.nconn.SetWriteDeadline(time.Now().Add(s.conf.WriteTimeout)) - return res.Write(s.bw) -} +// WriteFrame writes a frame. +func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, content []byte) error { + sc.mutex.Lock() + defer sc.mutex.Unlock() -// WriteFrameTCP writes an InterleavedFrame. -func (s *ServerConn) WriteFrameTCP(trackID int, streamType StreamType, content []byte) error { frame := base.InterleavedFrame{ TrackID: trackID, StreamType: streamType, Content: content, } - s.nconn.SetWriteDeadline(time.Now().Add(s.conf.WriteTimeout)) - return frame.Write(s.bw) + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout)) + return frame.Write(sc.bw) }