new server structure

This commit is contained in:
aler9
2020-12-08 21:15:02 +01:00
parent 927511d81e
commit c7f6d77392
5 changed files with 118 additions and 133 deletions

View File

@@ -33,7 +33,7 @@ func (c *ClientConn) Play() (*base.Response, error) {
return res, nil return res, nil
} }
func (c *ClientConn) backgroundPlayUDP(onFrameDone chan error) { func (c *ClientConn) backgroundPlayUDP(done chan error) {
defer close(c.backgroundDone) defer close(c.backgroundDone)
var returnError error var returnError error
@@ -44,7 +44,7 @@ func (c *ClientConn) backgroundPlayUDP(onFrameDone chan error) {
c.udpRtcpListeners[trackID].stop() c.udpRtcpListeners[trackID].stop()
} }
onFrameDone <- returnError done <- returnError
}() }()
// open the firewall by sending packets to the counterpart // open the firewall by sending packets to the counterpart
@@ -141,13 +141,13 @@ func (c *ClientConn) backgroundPlayUDP(onFrameDone chan error) {
} }
} }
func (c *ClientConn) backgroundPlayTCP(onFrameDone chan error) { func (c *ClientConn) backgroundPlayTCP(done chan error) {
defer close(c.backgroundDone) defer close(c.backgroundDone)
var returnError error var returnError error
defer func() { defer func() {
onFrameDone <- returnError done <- returnError
}() }()
readerDone := make(chan error) readerDone := make(chan error)
@@ -209,30 +209,30 @@ func (c *ClientConn) backgroundPlayTCP(onFrameDone chan error) {
} }
// OnFrame sets a callback that is called when a frame is received. // OnFrame sets a callback that is called when a frame is received.
// it returns a channel that is called when the reading stops. // it returns a channel that is written when the reading stops.
// This can be called only after Play(). // This can be called only after Play().
func (c *ClientConn) OnFrame(cb func(int, StreamType, []byte)) chan error { func (c *ClientConn) OnFrame(onFrame func(int, StreamType, []byte)) chan error {
// channel is buffered, since listening to it is not mandatory // channel is buffered, since listening to it is not mandatory
onFrameDone := make(chan error, 1) done := make(chan error, 1)
err := c.checkState(map[clientConnState]struct{}{ err := c.checkState(map[clientConnState]struct{}{
clientConnStatePrePlay: {}, clientConnStatePrePlay: {},
}) })
if err != nil { if err != nil {
onFrameDone <- err done <- err
return onFrameDone return done
} }
c.state = clientConnStatePlay c.state = clientConnStatePlay
c.readCB = cb c.readCB = onFrame
c.backgroundTerminate = make(chan struct{}) c.backgroundTerminate = make(chan struct{})
c.backgroundDone = make(chan struct{}) c.backgroundDone = make(chan struct{})
if *c.streamProtocol == StreamProtocolUDP { if *c.streamProtocol == StreamProtocolUDP {
go c.backgroundPlayUDP(onFrameDone) go c.backgroundPlayUDP(done)
} else { } else {
go c.backgroundPlayTCP(onFrameDone) go c.backgroundPlayTCP(done)
} }
return onFrameDone return done
} }

View File

@@ -10,13 +10,8 @@ import (
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
) )
type serverConnHandler struct { func handleConn(conn *gortsplib.ServerConn) {
} onRequest := func(req *base.Request) (*base.Response, error) {
func (sc *serverConnHandler) OnClose(err error) {
}
func (sc *serverConnHandler) OnRequest(req *base.Request) (*base.Response, error) {
switch req.Method { switch req.Method {
case base.Options: case base.Options:
return &base.Response{ return &base.Response{
@@ -45,17 +40,31 @@ func (sc *serverConnHandler) OnRequest(req *base.Request) (*base.Response, error
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
Header: base.Header{}, Header: base.Header{},
}, fmt.Errorf("unhandled method: %v", req.Method) }, fmt.Errorf("unhandled method: %v", req.Method)
} }
func (sc *serverConnHandler) OnFrame(id int, typ gortsplib.StreamType, buf []byte) { onFrame := func(id int, typ gortsplib.StreamType, buf []byte) {
}
done := conn.Read(onRequest, onFrame)
err := <-done
panic(err)
} }
func main() { func main() {
// create server // create server
gortsplib.Serve(":8554", func(c *gortsplib.ServerConn) gortsplib.ServerConnHandler { s, err := gortsplib.Serve(":8554")
return &serverConnHandler{} if err != nil {
}) panic(err)
}
// wait forever // accept connections
select {} for {
conn, err := s.Accept()
if err != nil {
panic(err)
}
go handleConn(conn)
}
} }

View File

@@ -3,58 +3,30 @@ package gortsplib
import ( import (
"bufio" "bufio"
"net" "net"
"sync"
"time"
) )
// Server is a RTSP server. // Server is a RTSP server.
type Server struct { type Server struct {
conf ServerConf conf ServerConf
listener *net.TCPListener listener *net.TCPListener
handler func(sc *ServerConn) ServerConnHandler
wg sync.WaitGroup
} }
// Close closes the server. // Close closes the server.
func (s *Server) Close() error { func (s *Server) Close() error {
s.listener.Close() return s.listener.Close()
s.wg.Wait()
return nil
} }
func (s *Server) run() { // Accept accepts a connection.
defer s.wg.Done() func (s *Server) Accept() (*ServerConn, error) {
if s.conf.ReadTimeout == 0 {
s.conf.ReadTimeout = 10 * time.Second
}
if s.conf.WriteTimeout == 0 {
s.conf.WriteTimeout = 10 * time.Second
}
if s.conf.ReadBufferCount == 0 {
s.conf.ReadBufferCount = 1
}
for {
nconn, err := s.listener.Accept() nconn, err := s.listener.Accept()
if err != nil { if err != nil {
break return nil, err
} }
sc := &ServerConn{ return &ServerConn{
s: s, s: s,
nconn: nconn, nconn: nconn,
br: bufio.NewReaderSize(nconn, serverReadBufferSize), br: bufio.NewReaderSize(nconn, serverReadBufferSize),
bw: bufio.NewWriterSize(nconn, serverWriteBufferSize), bw: bufio.NewWriterSize(nconn, serverWriteBufferSize),
} }, nil
sc.connHandler = s.handler(sc)
if sc.connHandler == nil {
nconn.Close()
continue
}
s.wg.Add(1)
go sc.run()
}
} }

View File

@@ -9,8 +9,8 @@ import (
var DefaultServerConf = ServerConf{} var DefaultServerConf = ServerConf{}
// Serve starts a server on the given address. // Serve starts a server on the given address.
func Serve(address string, handler func(sc *ServerConn) ServerConnHandler) (*Server, error) { func Serve(address string) (*Server, error) {
return DefaultServerConf.Serve(address, handler) return DefaultServerConf.Serve(address)
} }
// ServerConf allows to configure a Server. // ServerConf allows to configure a Server.
@@ -36,7 +36,7 @@ type ServerConf struct {
} }
// Serve starts a server on the given address. // Serve starts a server on the given address.
func (c ServerConf) Serve(address string, handler func(sc *ServerConn) ServerConnHandler) (*Server, error) { func (c ServerConf) Serve(address string) (*Server, error) {
if c.ReadTimeout == 0 { if c.ReadTimeout == 0 {
c.ReadTimeout = 10 * time.Second c.ReadTimeout = 10 * time.Second
} }
@@ -63,11 +63,7 @@ func (c ServerConf) Serve(address string, handler func(sc *ServerConn) ServerCon
s := &Server{ s := &Server{
conf: c, conf: c,
listener: listener, listener: listener,
handler: handler,
} }
s.wg.Add(1)
go s.run()
return s, nil return s, nil
} }

View File

@@ -16,18 +16,10 @@ const (
serverWriteBufferSize = 4096 serverWriteBufferSize = 4096
) )
// ServerConnHandler is the interface that must be implemented to use a ServerConn.
type ServerConnHandler interface {
OnClose(err error)
OnRequest(req *base.Request) (*base.Response, error)
OnFrame(rackID int, streamType StreamType, content []byte)
}
// ServerConn is a server-side RTSP connection. // ServerConn is a server-side RTSP connection.
type ServerConn struct { type ServerConn struct {
s *Server s *Server
nconn net.Conn nconn net.Conn
connHandler ServerConnHandler
br *bufio.Reader br *bufio.Reader
bw *bufio.Writer bw *bufio.Writer
mutex sync.Mutex mutex sync.Mutex
@@ -35,7 +27,7 @@ type ServerConn struct {
readTimeout bool readTimeout bool
} }
// Close closes all the ServerConn resources. // Close closes all the connection resources.
func (sc *ServerConn) Close() error { func (sc *ServerConn) Close() error {
return sc.nconn.Close() return sc.nconn.Close()
} }
@@ -55,8 +47,36 @@ func (sc *ServerConn) EnableReadTimeout(v bool) {
sc.readTimeout = v sc.readTimeout = v
} }
func (sc *ServerConn) run() { func (sc *ServerConn) backgroundRead(
defer sc.s.wg.Done() onRequest func(req *base.Request) (*base.Response, error),
onFrame func(trackID int, streamType StreamType, content []byte),
done chan error,
) {
handleRequest := func(req *base.Request) error {
sc.mutex.Lock()
defer sc.mutex.Unlock()
// check cseq
cseq, ok := req.Header["CSeq"]
if !ok || len(cseq) != 1 {
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
base.Response{
StatusCode: base.StatusBadRequest,
Header: base.Header{},
}.Write(sc.bw)
return fmt.Errorf("cseq is missing")
}
res, err := onRequest(req)
// add cseq to response
res.Header["CSeq"] = cseq
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
res.Write(sc.bw)
return err
}
var req base.Request var req base.Request
var frame base.InterleavedFrame var frame base.InterleavedFrame
@@ -81,10 +101,10 @@ outer:
switch what.(type) { switch what.(type) {
case *base.InterleavedFrame: case *base.InterleavedFrame:
sc.connHandler.OnFrame(frame.TrackID, frame.StreamType, frame.Content) onFrame(frame.TrackID, frame.StreamType, frame.Content)
case *base.Request: case *base.Request:
err := sc.handleRequest(&req) err := handleRequest(&req)
if err != nil { if err != nil {
errRet = err errRet = err
break outer break outer
@@ -98,7 +118,7 @@ outer:
break outer break outer
} }
err = sc.handleRequest(&req) err = handleRequest(&req)
if err != nil { if err != nil {
errRet = err errRet = err
break outer break outer
@@ -107,33 +127,21 @@ outer:
} }
sc.nconn.Close() sc.nconn.Close()
sc.connHandler.OnClose(errRet) done <- errRet
} }
func (sc *ServerConn) handleRequest(req *base.Request) error { // Read starts reading requests and frames from the connection.
sc.mutex.Lock() // it returns a channel that is written when the reading stops.
defer sc.mutex.Unlock() func (sc *ServerConn) Read(
onRequest func(req *base.Request) (*base.Response, error),
onFrame func(trackID int, streamType StreamType, content []byte),
) chan error {
// channel is buffered, since listening to it is not mandatory
done := make(chan error, 1)
// check cseq go sc.backgroundRead(onRequest, onFrame, done)
cseq, ok := req.Header["CSeq"]
if !ok || len(cseq) != 1 {
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
base.Response{
StatusCode: base.StatusBadRequest,
Header: base.Header{},
}.Write(sc.bw)
return fmt.Errorf("cseq is missing")
}
res, err := sc.connHandler.OnRequest(req) return done
// add cseq to response
res.Header["CSeq"] = cseq
sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.conf.WriteTimeout))
res.Write(sc.bw)
return err
} }
// WriteFrame writes a frame. // WriteFrame writes a frame.