diff --git a/README.md b/README.md index f363e3e7..0794c5e3 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,6 @@ Features: * [client-publish-options](examples/client-publish-options/main.go) * [client-publish-pause](examples/client-publish-pause/main.go) * [server](examples/server/main.go) -* [server-udp](examples/server-udp/main.go) * [server-tls](examples/server-tls/main.go) ## API Documentation diff --git a/examples/server-tls/main.go b/examples/server-tls/main.go index 768f903d..d78fa91c 100644 --- a/examples/server-tls/main.go +++ b/examples/server-tls/main.go @@ -15,131 +15,126 @@ import ( // 2. allow a single client to publish a stream with TCP // 3. allow multiple clients to read that stream with TCP -var mutex sync.Mutex -var publisher *gortsplib.ServerConn -var readers = make(map[*gortsplib.ServerConn]struct{}) -var sdp []byte +type serverHandler struct { + mutex sync.Mutex + publisher *gortsplib.ServerConn + readers map[*gortsplib.ServerConn]struct{} + sdp []byte +} -// this is called for each incoming connection -func handleConn(conn *gortsplib.ServerConn) { - defer conn.Close() +// called when a connection is opened. +func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) { + log.Printf("conn opened") +} - log.Printf("client connected") +// called when a connection is closed. +func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) { + log.Println("conn closed (%v)", err) - // called after receiving a DESCRIBE request. - onDescribe := func(ctx *gortsplib.ServerConnDescribeCtx) (*base.Response, []byte, error) { - mutex.Lock() - defer mutex.Unlock() + sh.mutex.Lock() + defer sh.mutex.Unlock() - // no one is publishing yet - if publisher == nil { - return &base.Response{ - StatusCode: base.StatusNotFound, - }, nil, nil - } - - return &base.Response{ - StatusCode: base.StatusOK, - }, sdp, nil - } - - // called after receiving an ANNOUNCE request. - onAnnounce := func(ctx *gortsplib.ServerConnAnnounceCtx) (*base.Response, error) { - mutex.Lock() - defer mutex.Unlock() - - if publisher != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("someone is already publishing") - } - - publisher = conn - sdp = ctx.Tracks.Write() - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - } - - // called after receiving a SETUP request. - onSetup := func(ctx *gortsplib.ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - } - - // called after receiving a PLAY request. - onPlay := func(ctx *gortsplib.ServerConnPlayCtx) (*base.Response, error) { - mutex.Lock() - defer mutex.Unlock() - - readers[conn] = struct{}{} - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - } - - // called after receiving a RECORD request. - onRecord := func(ctx *gortsplib.ServerConnRecordCtx) (*base.Response, error) { - mutex.Lock() - defer mutex.Unlock() - - if conn != publisher { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("someone is already publishing") - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - } - - // called after receiving a frame. - onFrame := func(trackID int, typ gortsplib.StreamType, buf []byte) { - mutex.Lock() - defer mutex.Unlock() - - // if we are the publisher, route frames to readers - if conn == publisher { - for r := range readers { - r.WriteFrame(trackID, typ, buf) - } - } - } - - err := <-conn.Read(gortsplib.ServerConnReadHandlers{ - OnDescribe: onDescribe, - OnAnnounce: onAnnounce, - OnSetup: onSetup, - OnPlay: onPlay, - OnRecord: onRecord, - OnFrame: onFrame, - }) - log.Printf("client disconnected (%s)", err) - - mutex.Lock() - defer mutex.Unlock() - - if conn == publisher { - publisher = nil - sdp = nil + if sc == sh.publisher { + sh.publisher = nil + sh.sdp = nil } else { - delete(readers, conn) + delete(sh.readers, sc) + } +} + +// called after receiving a DESCRIBE request. +func (sh *serverHandler) OnDescribe(ctx *gortsplib.ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) { + sh.mutex.Lock() + defer sh.mutex.Unlock() + + // no one is publishing yet + if sh.publisher == nil { + return &base.Response{ + StatusCode: base.StatusNotFound, + }, nil, nil + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, sh.sdp, nil +} + +// called after receiving an ANNOUNCE request. +func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (*base.Response, error) { + sh.mutex.Lock() + defer sh.mutex.Unlock() + + if sh.publisher != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("someone is already publishing") + } + + sh.publisher = ctx.Conn + sh.sdp = ctx.Tracks.Write() + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil +} + +// called after receiving a SETUP request. +func (sh *serverHandler) OnSetup(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil +} + +// called after receiving a PLAY request. +func (sh *serverHandler) OnPlay(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Response, error) { + sh.mutex.Lock() + defer sh.mutex.Unlock() + + sh.readers[ctx.Conn] = struct{}{} + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil +} + +// called after receiving a RECORD request. +func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*base.Response, error) { + sh.mutex.Lock() + defer sh.mutex.Unlock() + + if ctx.Conn != sh.publisher { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("someone is already publishing") + } + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil +} + +// called after receiving a frame. +func (sh *serverHandler) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) { + sh.mutex.Lock() + defer sh.mutex.Unlock() + + // if we are the publisher, route frames to readers + if ctx.Conn == sh.publisher { + for r := range sh.readers { + r.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) + } } } @@ -152,24 +147,12 @@ func main() { panic(err) } - // create server + // configure server s := &gortsplib.Server{ + Handler: &serverHandler{}, TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}}, } - err = s.Serve(":8554") - if err != nil { - panic(err) - } - log.Printf("server is ready") - - // accept connections - for { - conn, err := s.Accept() - if err != nil { - panic(err) - } - - go handleConn(conn) - } + // start server and wait until a fatal error + panic(s.StartAndWait(":8554")) } diff --git a/examples/server-udp/main.go b/examples/server-udp/main.go deleted file mode 100644 index abedd32b..00000000 --- a/examples/server-udp/main.go +++ /dev/null @@ -1,167 +0,0 @@ -package main - -import ( - "fmt" - "log" - "sync" - - "github.com/aler9/gortsplib" - "github.com/aler9/gortsplib/pkg/base" -) - -// This example shows how to -// 1. create a RTSP server which accepts plain connections -// 2. allow a single client to publish a stream with TCP or UDP -// 3. allow multiple clients to read that stream with TCP or UDP - -var mutex sync.Mutex -var publisher *gortsplib.ServerConn -var readers = make(map[*gortsplib.ServerConn]struct{}) -var sdp []byte - -// this is called for each incoming connection -func handleConn(conn *gortsplib.ServerConn) { - defer conn.Close() - - log.Printf("client connected") - - // called after receiving a DESCRIBE request. - onDescribe := func(ctx *gortsplib.ServerConnDescribeCtx) (*base.Response, []byte, error) { - mutex.Lock() - defer mutex.Unlock() - - // no one is publishing yet - if publisher == nil { - return &base.Response{ - StatusCode: base.StatusNotFound, - }, nil, nil - } - - return &base.Response{ - StatusCode: base.StatusOK, - }, sdp, nil - } - - // called after receiving an ANNOUNCE request. - onAnnounce := func(ctx *gortsplib.ServerConnAnnounceCtx) (*base.Response, error) { - mutex.Lock() - defer mutex.Unlock() - - if publisher != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("someone is already publishing") - } - - publisher = conn - sdp = ctx.Tracks.Write() - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - } - - // called after receiving a SETUP request. - onSetup := func(ctx *gortsplib.ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - } - - // called after receiving a PLAY request. - onPlay := func(ctx *gortsplib.ServerConnPlayCtx) (*base.Response, error) { - mutex.Lock() - defer mutex.Unlock() - - readers[conn] = struct{}{} - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - } - - // called after receiving a RECORD request. - onRecord := func(ctx *gortsplib.ServerConnRecordCtx) (*base.Response, error) { - mutex.Lock() - defer mutex.Unlock() - - if conn != publisher { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("someone is already publishing") - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - } - - // called after receiving a frame. - onFrame := func(trackID int, typ gortsplib.StreamType, buf []byte) { - mutex.Lock() - defer mutex.Unlock() - - // if we are the publisher, route frames to readers - if conn == publisher { - for r := range readers { - r.WriteFrame(trackID, typ, buf) - } - } - } - - err := <-conn.Read(gortsplib.ServerConnReadHandlers{ - OnDescribe: onDescribe, - OnAnnounce: onAnnounce, - OnSetup: onSetup, - OnPlay: onPlay, - OnRecord: onRecord, - OnFrame: onFrame, - }) - log.Printf("client disconnected (%s)", err) - - mutex.Lock() - defer mutex.Unlock() - - if conn == publisher { - publisher = nil - sdp = nil - } else { - delete(readers, conn) - } -} - -func main() { - // create server - s := &gortsplib.Server{ - UDPRTPAddress: ":8000", - UDPRTCPAddress: ":8001", - } - err := s.Serve(":8554") - if err != nil { - panic(err) - } - - log.Printf("server is ready") - - // accept connections - for { - conn, err := s.Accept() - if err != nil { - panic(err) - } - - go handleConn(conn) - } -} diff --git a/examples/server/main.go b/examples/server/main.go index 0d421c1c..3ec6d661 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -11,154 +11,140 @@ import ( // This example shows how to // 1. create a RTSP server which accepts plain connections -// 2. allow a single client to publish a stream with TCP -// 3. allow multiple clients to read that stream with TCP +// 2. allow a single client to publish a stream with TCP or UDP +// 3. allow multiple clients to read that stream with TCP or UDP -var mutex sync.Mutex -var publisher *gortsplib.ServerConn -var readers = make(map[*gortsplib.ServerConn]struct{}) -var sdp []byte +type serverHandler struct { + mutex sync.Mutex + publisher *gortsplib.ServerConn + readers map[*gortsplib.ServerConn]struct{} + sdp []byte +} -// this is called for each incoming connection -func handleConn(conn *gortsplib.ServerConn) { - defer conn.Close() +// called when a connection is opened. +func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) { + log.Printf("conn opened") +} - log.Printf("client connected") +// called when a connection is closed. +func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) { + log.Println("conn closed (%v)", err) - // called after receiving a DESCRIBE request. - onDescribe := func(ctx *gortsplib.ServerConnDescribeCtx) (*base.Response, []byte, error) { - mutex.Lock() - defer mutex.Unlock() + sh.mutex.Lock() + defer sh.mutex.Unlock() - // no one is publishing yet - if publisher == nil { - return &base.Response{ - StatusCode: base.StatusNotFound, - }, nil, nil - } - - return &base.Response{ - StatusCode: base.StatusOK, - }, sdp, nil - } - - // called after receiving an ANNOUNCE request. - onAnnounce := func(ctx *gortsplib.ServerConnAnnounceCtx) (*base.Response, error) { - mutex.Lock() - defer mutex.Unlock() - - if publisher != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("someone is already publishing") - } - - publisher = conn - sdp = ctx.Tracks.Write() - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - } - - // called after receiving a SETUP request. - onSetup := func(ctx *gortsplib.ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - } - - // called after receiving a PLAY request. - onPlay := func(ctx *gortsplib.ServerConnPlayCtx) (*base.Response, error) { - mutex.Lock() - defer mutex.Unlock() - - readers[conn] = struct{}{} - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - } - - // called after receiving a RECORD request. - onRecord := func(ctx *gortsplib.ServerConnRecordCtx) (*base.Response, error) { - mutex.Lock() - defer mutex.Unlock() - - if conn != publisher { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("someone is already publishing") - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil - } - - // called after receiving a frame. - onFrame := func(trackID int, typ gortsplib.StreamType, buf []byte) { - mutex.Lock() - defer mutex.Unlock() - - // if we are the publisher, route frames to readers - if conn == publisher { - for r := range readers { - r.WriteFrame(trackID, typ, buf) - } - } - } - - err := <-conn.Read(gortsplib.ServerConnReadHandlers{ - OnDescribe: onDescribe, - OnAnnounce: onAnnounce, - OnSetup: onSetup, - OnPlay: onPlay, - OnRecord: onRecord, - OnFrame: onFrame, - }) - log.Printf("client disconnected (%s)", err) - - mutex.Lock() - defer mutex.Unlock() - - if conn == publisher { - publisher = nil - sdp = nil + if sc == sh.publisher { + sh.publisher = nil + sh.sdp = nil } else { - delete(readers, conn) + delete(sh.readers, sc) + } +} + +// called after receiving a DESCRIBE request. +func (sh *serverHandler) OnDescribe(ctx *gortsplib.ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) { + sh.mutex.Lock() + defer sh.mutex.Unlock() + + // no one is publishing yet + if sh.publisher == nil { + return &base.Response{ + StatusCode: base.StatusNotFound, + }, nil, nil + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, sh.sdp, nil +} + +// called after receiving an ANNOUNCE request. +func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (*base.Response, error) { + sh.mutex.Lock() + defer sh.mutex.Unlock() + + if sh.publisher != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("someone is already publishing") + } + + sh.publisher = ctx.Conn + sh.sdp = ctx.Tracks.Write() + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil +} + +// called after receiving a SETUP request. +func (sh *serverHandler) OnSetup(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil +} + +// called after receiving a PLAY request. +func (sh *serverHandler) OnPlay(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Response, error) { + sh.mutex.Lock() + defer sh.mutex.Unlock() + + sh.readers[ctx.Conn] = struct{}{} + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil +} + +// called after receiving a RECORD request. +func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*base.Response, error) { + sh.mutex.Lock() + defer sh.mutex.Unlock() + + if ctx.Conn != sh.publisher { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("someone is already publishing") + } + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil +} + +// called after receiving a frame. +func (sh *serverHandler) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) { + sh.mutex.Lock() + defer sh.mutex.Unlock() + + // if we are the publisher, route frames to readers + if ctx.Conn == sh.publisher { + for r := range sh.readers { + r.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) + } } } func main() { - // create server - s := &gortsplib.Server{} - err := s.Serve(":8554") - if err != nil { - panic(err) + // configure server + s := &gortsplib.Server{ + Handler: &serverHandler{}, + UDPRTPAddress: ":8000", + UDPRTCPAddress: ":8001", } - log.Printf("server is ready") - - // accept connections - for { - conn, err := s.Accept() - if err != nil { - panic(err) - } - - go handleConn(conn) - } + // start server and wait until a fatal error + panic(s.StartAndWait(":8554")) } diff --git a/server.go b/server.go index dfe22091..9476971e 100644 --- a/server.go +++ b/server.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "strconv" + "sync" "time" ) @@ -24,6 +25,9 @@ func extractPort(address string) (int, error) { // Server is a RTSP server. type Server struct { + // an handler to handle requests. + Handler ServerHandler + // a TLS configuration to accept TLS (RTSPS) connections. TLSConfig *tls.Config @@ -65,10 +69,19 @@ type Server struct { tcpListener net.Listener udpRTPListener *serverUDPListener udpRTCPListener *serverUDPListener + conns map[*ServerConn]struct{} + exitError error + + // in + connClose chan *ServerConn + terminate chan struct{} + + // out + done chan struct{} } -// Serve starts listening on the given address. -func (s *Server) Serve(address string) error { +// Start starts listening on the given address. +func (s *Server) Start(address string) error { if s.ReadTimeout == 0 { s.ReadTimeout = 10 * time.Second } @@ -125,6 +138,7 @@ func (s *Server) Serve(address string) error { s.udpRTCPListener, err = newServerUDPListener(s, s.UDPRTCPAddress, StreamTypeRTCP) if err != nil { + s.udpRTPListener.close() return err } } @@ -132,33 +146,132 @@ func (s *Server) Serve(address string) error { var err error s.tcpListener, err = s.Listen("tcp", address) if err != nil { + s.udpRTPListener.close() + s.udpRTPListener.close() return err } + s.terminate = make(chan struct{}) + s.done = make(chan struct{}) + + go s.run() + return nil } -// Accept accepts a connection. -func (s *Server) Accept() (*ServerConn, error) { - nconn, err := s.tcpListener.Accept() - if err != nil { - return nil, err +func (s *Server) run() { + s.conns = make(map[*ServerConn]struct{}) + s.connClose = make(chan *ServerConn) + + var wg sync.WaitGroup + + wg.Add(1) + connNew := make(chan net.Conn) + acceptErr := make(chan error) + go func() { + defer wg.Done() + acceptErr <- func() error { + for { + nconn, err := s.tcpListener.Accept() + if err != nil { + return err + } + + connNew <- nconn + } + }() + }() + +outer: + for { + select { + case err := <-acceptErr: + s.exitError = err + break outer + + case nconn := <-connNew: + sc := newServerConn(s, &wg, nconn) + s.conns[sc] = struct{}{} + + case sc := <-s.connClose: + if _, ok := s.conns[sc]; !ok { + continue + } + s.doConnClose(sc) + + case <-s.terminate: + break outer + } } - return newServerConn(s, nconn), nil -} + go func() { + for { + select { + case _, ok := <-acceptErr: + if !ok { + return + } -// Close closes all the server resources. -func (s *Server) Close() error { - s.tcpListener.Close() + case nconn, ok := <-connNew: + if !ok { + return + } + nconn.Close() - if s.udpRTPListener != nil { - s.udpRTPListener.close() - } + case _, ok := <-s.connClose: + if !ok { + return + } + } + } + }() if s.udpRTCPListener != nil { s.udpRTCPListener.close() } + if s.udpRTPListener != nil { + s.udpRTPListener.close() + } + + s.tcpListener.Close() + + for sc := range s.conns { + s.doConnClose(sc) + } + + wg.Wait() + + close(acceptErr) + close(connNew) + close(s.connClose) + close(s.done) +} + +// Close closes all the server resources and waits for the server to exit. +func (s *Server) Close() error { + close(s.terminate) + <-s.done return nil } + +// Wait waits until a fatal error. +func (s *Server) Wait() error { + <-s.done + return s.exitError +} + +// StartAndWait starts the server and waits until a fatal error. +func (s *Server) StartAndWait(address string) error { + err := s.Start(address) + if err != nil { + return err + } + + return s.Wait() +} + +func (s *Server) doConnClose(sc *ServerConn) { + delete(s.conns, sc) + close(sc.terminate) +} diff --git a/serverconn_test.go b/server_test.go similarity index 59% rename from serverconn_test.go rename to server_test.go index 2c3c9b1d..68293d1b 100644 --- a/serverconn_test.go +++ b/server_test.go @@ -13,201 +13,70 @@ import ( "github.com/stretchr/testify/require" "github.com/aler9/gortsplib/pkg/base" - "github.com/aler9/gortsplib/pkg/liberrors" ) -type testServ struct { - s *Server - wg sync.WaitGroup - mutex sync.Mutex - publisher *ServerConn - sdp []byte - readers map[*ServerConn]struct{} +type testServerHandler struct { + onConnClose func(*ServerConn, error) + onDescribe func(*ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) + onAnnounce func(*ServerHandlerOnAnnounceCtx) (*base.Response, error) + onSetup func(*ServerHandlerOnSetupCtx) (*base.Response, error) + onPlay func(*ServerHandlerOnPlayCtx) (*base.Response, error) + onRecord func(*ServerHandlerOnRecordCtx) (*base.Response, error) + onPause func(*ServerHandlerOnPauseCtx) (*base.Response, error) + onFrame func(*ServerHandlerOnFrameCtx) } -func newTestServ(tlsConf *tls.Config) (*testServ, error) { - s := &Server{} - if tlsConf != nil { - s.TLSConfig = tlsConf - } else { - s.UDPRTPAddress = "127.0.0.1:8000" - s.UDPRTCPAddress = "127.0.0.1:8001" - } - - err := s.Serve("127.0.0.1:8554") - if err != nil { - return nil, err - } - - ts := &testServ{ - s: s, - readers: make(map[*ServerConn]struct{}), - } - - ts.wg.Add(1) - go ts.run() - - return ts, nil -} - -func (ts *testServ) close() { - ts.s.Close() - ts.wg.Wait() -} - -func (ts *testServ) run() { - defer ts.wg.Done() - - for { - conn, err := ts.s.Accept() - if err != nil { - return - } - - ts.wg.Add(1) - go ts.handleConn(conn) +func (sh *testServerHandler) OnConnClose(sc *ServerConn, err error) { + if sh.onConnClose != nil { + sh.onConnClose(sc, err) } } -func (ts *testServ) handleConn(conn *ServerConn) { - defer ts.wg.Done() - defer conn.Close() - - onDescribe := func(ctx *ServerConnDescribeCtx) (*base.Response, []byte, error) { - if ctx.Path != "teststream" { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, nil, fmt.Errorf("invalid path (%s)", ctx.Req.URL) - } - - ts.mutex.Lock() - defer ts.mutex.Unlock() - - if ts.publisher == nil { - return &base.Response{ - StatusCode: base.StatusNotFound, - }, nil, nil - } - - return &base.Response{ - StatusCode: base.StatusOK, - }, ts.sdp, nil +func (sh *testServerHandler) OnDescribe(ctx *ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) { + if sh.onDescribe != nil { + return sh.onDescribe(ctx) } + return nil, nil, fmt.Errorf("unimplemented") +} - onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) { - if ctx.Path != "teststream" { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) - } - - ts.mutex.Lock() - defer ts.mutex.Unlock() - - if ts.publisher != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("someone is already publishing") - } - - ts.publisher = conn - ts.sdp = ctx.Tracks.Write() - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil +func (sh *testServerHandler) OnAnnounce(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + if sh.onAnnounce != nil { + return sh.onAnnounce(ctx) } + return nil, fmt.Errorf("unimplemented") +} - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - if ctx.Path != "teststream" { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil +func (sh *testServerHandler) OnSetup(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + if sh.onSetup != nil { + return sh.onSetup(ctx) } + return nil, fmt.Errorf("unimplemented") +} - onPlay := func(ctx *ServerConnPlayCtx) (*base.Response, error) { - if ctx.Path != "teststream" { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) - } - - ts.mutex.Lock() - defer ts.mutex.Unlock() - - ts.readers[conn] = struct{}{} - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil +func (sh *testServerHandler) OnPlay(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + if sh.onPlay != nil { + return sh.onPlay(ctx) } + return nil, fmt.Errorf("unimplemented") +} - onRecord := func(ctx *ServerConnRecordCtx) (*base.Response, error) { - if ctx.Path != "teststream" { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) - } - - ts.mutex.Lock() - defer ts.mutex.Unlock() - - if conn != ts.publisher { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("someone is already publishing") - } - - return &base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Session": base.HeaderValue{"12345678"}, - }, - }, nil +func (sh *testServerHandler) OnRecord(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) { + if sh.onRecord != nil { + return sh.onRecord(ctx) } + return nil, fmt.Errorf("unimplemented") +} - onFrame := func(trackID int, typ StreamType, buf []byte) { - ts.mutex.Lock() - defer ts.mutex.Unlock() - - if conn == ts.publisher { - for r := range ts.readers { - r.WriteFrame(trackID, typ, buf) - } - } +func (sh *testServerHandler) OnPause(ctx *ServerHandlerOnPauseCtx) (*base.Response, error) { + if sh.onPause != nil { + return sh.onPause(ctx) } + return nil, fmt.Errorf("unimplemented") +} - <-conn.Read(ServerConnReadHandlers{ - OnDescribe: onDescribe, - OnAnnounce: onAnnounce, - OnSetup: onSetup, - OnPlay: onPlay, - OnRecord: onRecord, - OnFrame: onFrame, - }) - - ts.mutex.Lock() - defer ts.mutex.Unlock() - - if conn == ts.publisher { - ts.publisher = nil - ts.sdp = nil - } else { - delete(ts.readers, conn) +func (sh *testServerHandler) OnFrame(ctx *ServerHandlerOnFrameCtx) { + if sh.onFrame != nil { + sh.onFrame(ctx) } } @@ -298,22 +167,156 @@ func TestServerHighLevelPublishRead(t *testing.T) { t.Run(encryptedStr+"_"+ca.publisherSoft+"_"+ca.publisherProto+"_"+ ca.readerSoft+"_"+ca.readerProto, func(t *testing.T) { + + var mutex sync.Mutex + var publisher *ServerConn + var sdp []byte + readers := make(map[*ServerConn]struct{}) + + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + mutex.Lock() + defer mutex.Unlock() + + if sc == publisher { + publisher = nil + sdp = nil + } else { + delete(readers, sc) + } + }, + onDescribe: func(ctx *ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) { + if ctx.Path != "teststream" { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, nil, fmt.Errorf("invalid path (%s)", ctx.Req.URL) + } + + mutex.Lock() + defer mutex.Unlock() + + if publisher == nil { + return &base.Response{ + StatusCode: base.StatusNotFound, + }, nil, nil + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, sdp, nil + }, + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + if ctx.Path != "teststream" { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) + } + + mutex.Lock() + defer mutex.Unlock() + + if publisher != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("someone is already publishing") + } + + publisher = ctx.Conn + sdp = ctx.Tracks.Write() + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + if ctx.Path != "teststream" { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) + } + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + if ctx.Path != "teststream" { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) + } + + mutex.Lock() + defer mutex.Unlock() + + readers[ctx.Conn] = struct{}{} + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + }, + onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) { + if ctx.Path != "teststream" { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid path (%s)", ctx.Req.URL) + } + + mutex.Lock() + defer mutex.Unlock() + + if ctx.Conn != publisher { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("someone is already publishing") + } + + return &base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Session": base.HeaderValue{"12345678"}, + }, + }, nil + }, + onFrame: func(ctx *ServerHandlerOnFrameCtx) { + mutex.Lock() + defer mutex.Unlock() + + if ctx.Conn == publisher { + for r := range readers { + r.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) + } + } + }, + }, + } + var proto string - var tlsConf *tls.Config if !ca.encrypted { proto = "rtsp" - tlsConf = nil + s.UDPRTPAddress = "127.0.0.1:8000" + s.UDPRTCPAddress = "127.0.0.1:8001" } else { proto = "rtsps" cert, err := tls.X509KeyPair(serverCert, serverKey) require.NoError(t, err) - tlsConf = &tls.Config{Certificates: []tls.Certificate{cert}} + s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} } - ts, err := newTestServ(tlsConf) + err := s.Start("127.0.0.1:8554") require.NoError(t, err) - defer ts.close() + defer s.Close() switch ca.publisherSoft { case "ffmpeg": @@ -374,7 +377,7 @@ func TestServerErrorWrongUDPPorts(t *testing.T) { UDPRTPAddress: "127.0.0.1:8006", UDPRTCPAddress: "127.0.0.1:8009", } - err := s.Serve("127.0.0.1:8554") + err := s.Start("127.0.0.1:8554") require.Error(t, err) }) @@ -383,29 +386,17 @@ func TestServerErrorWrongUDPPorts(t *testing.T) { UDPRTPAddress: "127.0.0.1:8003", UDPRTCPAddress: "127.0.0.1:8004", } - err := s.Serve("127.0.0.1:8554") + err := s.Start("127.0.0.1:8554") require.Error(t, err) }) } func TestServerCSeq(t *testing.T) { s := &Server{} - err := s.Serve("127.0.0.1:8554") + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - <-conn.Read(ServerConnReadHandlers{}) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -429,24 +420,17 @@ func TestServerCSeq(t *testing.T) { } func TestServerErrorCSeqMissing(t *testing.T) { - s := &Server{} - err := s.Serve("127.0.0.1:8554") + h := &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + require.Equal(t, "CSeq is missing", err.Error()) + }, + } + + s := &Server{Handler: h} + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - err = <-conn.Read(ServerConnReadHandlers{}) - require.Equal(t, "CSeq is missing", err.Error()) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -467,24 +451,10 @@ func TestServerErrorCSeqMissing(t *testing.T) { func TestServerTeardownResponse(t *testing.T) { s := &Server{} - err := s.Serve("127.0.0.1:8554") + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - err = <-conn.Read(ServerConnReadHandlers{}) - _, ok := err.(liberrors.ErrServerTeardown) - require.Equal(t, true, ok) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() diff --git a/serverconn.go b/serverconn.go index e65f1c94..68e91c04 100644 --- a/serverconn.go +++ b/serverconn.go @@ -7,6 +7,7 @@ import ( "net" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -85,7 +86,20 @@ func setupGetTrackIDPathQuery(url *base.URL, return 0, "", "", fmt.Errorf("invalid track path (%s)", pathAndQuery) } -// ServerConnState is the state of the connection. +// ServerConnSetuppedTrack is a setupped track of a ServerConn. +type ServerConnSetuppedTrack struct { + udpRTPPort int + udpRTCPPort int +} + +// ServerConnAnnouncedTrack is an announced track of a ServerConn. +type ServerConnAnnouncedTrack struct { + track *Track + rtcpReceiver *rtcpreceiver.RTCPReceiver + udpLastFrameTime *int64 +} + +// ServerConnState is a state of a ServerConn. type ServerConnState int // standard states. @@ -114,142 +128,10 @@ func (s ServerConnState) String() string { return "unknown" } -// ServerConnSetuppedTrack is a setupped track of a ServerConn. -type ServerConnSetuppedTrack struct { - udpRTPPort int - udpRTCPPort int -} - -// ServerConnAnnouncedTrack is an announced track of a ServerConn. -type ServerConnAnnouncedTrack struct { - track *Track - rtcpReceiver *rtcpreceiver.RTCPReceiver - udpLastFrameTime *int64 -} - -// ServerConnOptionsCtx is the context of a OPTIONS request. -type ServerConnOptionsCtx struct { - Req *base.Request - Path string - Query string -} - -// ServerConnDescribeCtx is the context of a DESCRIBE request. -type ServerConnDescribeCtx struct { - Req *base.Request - Path string - Query string -} - -// ServerConnAnnounceCtx is the context of a ANNOUNCE request. -type ServerConnAnnounceCtx struct { - Req *base.Request - Path string - Query string - Tracks Tracks -} - -// ServerConnSetupCtx is the context of a OPTIONS request. -type ServerConnSetupCtx struct { - Req *base.Request - Path string - Query string - TrackID int - Transport *headers.Transport -} - -// ServerConnPlayCtx is the context of a PLAY request. -type ServerConnPlayCtx struct { - Req *base.Request - Path string - Query string -} - -// ServerConnRecordCtx is the context of a RECORD request. -type ServerConnRecordCtx struct { - Req *base.Request - Path string - Query string -} - -// ServerConnPauseCtx is the context of a PAUSE request. -type ServerConnPauseCtx struct { - Req *base.Request - Path string - Query string -} - -// ServerConnGetParameterCtx is the context of a GET_PARAMETER request. -type ServerConnGetParameterCtx struct { - Req *base.Request - Path string - Query string -} - -// ServerConnSetParameterCtx is the context of a SET_PARAMETER request. -type ServerConnSetParameterCtx struct { - Req *base.Request - Path string - Query string -} - -// ServerConnTeardownCtx is the context of a TEARDOWN request. -type ServerConnTeardownCtx struct { - Req *base.Request - Path string - Query string -} - -// ServerConnReadHandlers allows to set the handlers required by ServerConn.Read. -// all fields are optional. -type ServerConnReadHandlers struct { - // called after receiving any request. - OnRequest func(req *base.Request) - - // called before sending any response. - OnResponse func(res *base.Response) - - // called after receiving a OPTIONS request. - // if nil, it is generated automatically. - OnOptions func(ctx *ServerConnOptionsCtx) (*base.Response, error) - - // called after receiving a DESCRIBE request. - // the 2nd return value is a SDP, that is inserted into the response. - OnDescribe func(ctx *ServerConnDescribeCtx) (*base.Response, []byte, error) - - // called after receiving an ANNOUNCE request. - OnAnnounce func(ctx *ServerConnAnnounceCtx) (*base.Response, error) - - // called after receiving a SETUP request. - OnSetup func(ctx *ServerConnSetupCtx) (*base.Response, error) - - // called after receiving a PLAY request. - OnPlay func(ctx *ServerConnPlayCtx) (*base.Response, error) - - // called after receiving a RECORD request. - OnRecord func(ctx *ServerConnRecordCtx) (*base.Response, error) - - // called after receiving a PAUSE request. - OnPause func(ctx *ServerConnPauseCtx) (*base.Response, error) - - // called after receiving a GET_PARAMETER request. - // if nil, it is generated automatically. - OnGetParameter func(ctx *ServerConnGetParameterCtx) (*base.Response, error) - - // called after receiving a SET_PARAMETER request. - OnSetParameter func(ctx *ServerConnSetParameterCtx) (*base.Response, error) - - // called after receiving a TEARDOWN request. - // if nil, it is generated automatically. - OnTeardown func(ctx *ServerConnTeardownCtx) (*base.Response, error) - - // called after receiving a frame. - OnFrame func(trackID int, streamType StreamType, payload []byte) -} - // ServerConn is a server-side RTSP connection. type ServerConn struct { s *Server + wg *sync.WaitGroup nconn net.Conn br *bufio.Reader bw *bufio.Writer @@ -267,9 +149,6 @@ type ServerConn struct { tcpFrameWriteBuffer *ringbuffer.RingBuffer tcpBackgroundWriteDone chan struct{} - // read - readHandlers ServerConnReadHandlers - // publish announcedTracks []ServerConnAnnouncedTrack backgroundRecordTerminate chan struct{} @@ -282,31 +161,20 @@ type ServerConn struct { func newServerConn( s *Server, + wg *sync.WaitGroup, nconn net.Conn) *ServerConn { - conn := func() net.Conn { - if s.TLSConfig != nil { - return tls.Server(nconn, s.TLSConfig) - } - return nconn - }() - return &ServerConn{ - s: s, - nconn: nconn, - br: bufio.NewReaderSize(conn, serverConnReadBufferSize), - bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize), - // always instantiate to allow writing to it before Play() - tcpFrameWriteBuffer: ringbuffer.New(uint64(s.ReadBufferCount)), - tcpBackgroundWriteDone: make(chan struct{}), - terminate: make(chan struct{}), + sc := &ServerConn{ + s: s, + wg: wg, + nconn: nconn, + terminate: make(chan struct{}), } -} -// Close closes all the connection resources. -func (sc *ServerConn) Close() error { - err := sc.nconn.Close() - close(sc.terminate) - return err + wg.Add(1) + go sc.run() + + return sc } // State returns the state. @@ -329,25 +197,17 @@ func (sc *ServerConn) AnnouncedTracks() []ServerConnAnnouncedTrack { return sc.announcedTracks } -func (sc *ServerConn) tcpBackgroundWrite() { - defer close(sc.tcpBackgroundWriteDone) +// NetConn returns the underlying net.Conn. +func (sc *ServerConn) NetConn() net.Conn { + return sc.nconn +} - for { - what, ok := sc.tcpFrameWriteBuffer.Pull() - if !ok { - return - } +func (sc *ServerConn) ip() net.IP { + return sc.nconn.RemoteAddr().(*net.TCPAddr).IP +} - switch w := what.(type) { - case *base.InterleavedFrame: - sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) - w.Write(sc.bw) - - case *base.Response: - sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) - w.Write(sc.bw) - } - } +func (sc *ServerConn) zone() string { + return sc.nconn.RemoteAddr().(*net.TCPAddr).Zone } func (sc *ServerConn) checkState(allowed map[ServerConnState]struct{}) error { @@ -364,89 +224,46 @@ func (sc *ServerConn) checkState(allowed map[ServerConnState]struct{}) error { return liberrors.ErrServerWrongState{AllowedList: allowedList, State: sc.state} } -// NetConn returns the underlying net.Conn. -func (sc *ServerConn) NetConn() net.Conn { - return sc.nconn -} +func (sc *ServerConn) run() { + defer sc.wg.Done() -func (sc *ServerConn) ip() net.IP { - return sc.nconn.RemoteAddr().(*net.TCPAddr).IP -} - -func (sc *ServerConn) zone() string { - return sc.nconn.RemoteAddr().(*net.TCPAddr).Zone -} - -func (sc *ServerConn) frameModeEnable() { - switch sc.state { - case ServerConnStatePlay: - if *sc.setupProtocol == StreamProtocolTCP { - sc.doEnableTCPFrame = true - } else { - // readers can send RTCP frames, they cannot sent RTP frames - for trackID, track := range sc.setuppedTracks { - sc.s.udpRTCPListener.addClient(sc.ip(), track.udpRTCPPort, sc, trackID, false) - } - } - - case ServerConnStateRecord: - if *sc.setupProtocol == StreamProtocolTCP { - sc.doEnableTCPFrame = true - sc.tcpFrameTimeout = true - - } else { - for trackID, track := range sc.setuppedTracks { - sc.s.udpRTPListener.addClient(sc.ip(), track.udpRTPPort, sc, trackID, true) - sc.s.udpRTCPListener.addClient(sc.ip(), track.udpRTCPPort, sc, trackID, true) - - // open the firewall by sending packets to the counterpart - sc.WriteFrame(trackID, StreamTypeRTP, - []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) - sc.WriteFrame(trackID, StreamTypeRTCP, - []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) - } - } - - sc.backgroundRecordTerminate = make(chan struct{}) - sc.backgroundRecordDone = make(chan struct{}) - go sc.backgroundRecord() + if h, ok := sc.s.Handler.(ServerHandlerOnConnOpen); ok { + h.OnConnOpen(sc) } -} -func (sc *ServerConn) frameModeDisable() { - switch sc.state { - case ServerConnStatePlay: - if *sc.setupProtocol == StreamProtocolTCP { - sc.tcpFrameEnabled = false - sc.tcpFrameWriteBuffer.Close() - <-sc.tcpBackgroundWriteDone - sc.tcpFrameWriteBuffer.Reset() - - } else { - for _, track := range sc.setuppedTracks { - sc.s.udpRTCPListener.removeClient(sc.ip(), track.udpRTCPPort) - } + conn := func() net.Conn { + if sc.s.TLSConfig != nil { + return tls.Server(sc.nconn, sc.s.TLSConfig) } + return sc.nconn + }() - case ServerConnStateRecord: - close(sc.backgroundRecordTerminate) - <-sc.backgroundRecordDone + sc.br = bufio.NewReaderSize(conn, serverConnReadBufferSize) + sc.bw = bufio.NewWriterSize(conn, serverConnWriteBufferSize) - if *sc.setupProtocol == StreamProtocolTCP { - sc.tcpFrameTimeout = false - sc.nconn.SetReadDeadline(time.Time{}) + // instantiate always to allow writing to this conn before Play() + sc.tcpFrameWriteBuffer = ringbuffer.New(uint64(sc.s.ReadBufferCount)) + sc.tcpBackgroundWriteDone = make(chan struct{}) - sc.tcpFrameEnabled = false - sc.tcpFrameWriteBuffer.Close() - <-sc.tcpBackgroundWriteDone - sc.tcpFrameWriteBuffer.Reset() + readDone := make(chan error) + go func() { + readDone <- sc.backgroundRead() + }() - } else { - for _, track := range sc.setuppedTracks { - sc.s.udpRTPListener.removeClient(sc.ip(), track.udpRTPPort) - sc.s.udpRTCPListener.removeClient(sc.ip(), track.udpRTCPPort) - } - } + var err error + select { + case err = <-readDone: + sc.nconn.Close() + sc.s.connClose <- sc + <-sc.terminate + + case <-sc.terminate: + sc.nconn.Close() + err = <-readDone + } + + if h, ok := sc.s.Handler.(ServerHandlerOnConnClose); ok { + h.OnConnClose(sc, err) } } @@ -458,13 +275,9 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, liberrors.ErrServerCSeqMissing{} } - if sc.readHandlers.OnRequest != nil { - sc.readHandlers.OnRequest(req) - } - switch req.Method { case base.Options: - if sc.readHandlers.OnOptions != nil { + if h, ok := sc.s.Handler.(ServerHandlerOnOptions); ok { pathAndQuery, ok := req.URL.RTSPPath() if !ok { return &base.Response{ @@ -474,7 +287,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { path, query := base.PathSplitQuery(pathAndQuery) - return sc.readHandlers.OnOptions(&ServerConnOptionsCtx{ + return h.OnOptions(&ServerHandlerOnOptionsCtx{ + Conn: sc, Req: req, Path: path, Query: query, @@ -482,26 +296,26 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } var methods []string - if sc.readHandlers.OnDescribe != nil { + if _, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok { methods = append(methods, string(base.Describe)) } - if sc.readHandlers.OnAnnounce != nil { + if _, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok { methods = append(methods, string(base.Announce)) } - if sc.readHandlers.OnSetup != nil { + if _, ok := sc.s.Handler.(ServerHandlerOnSetup); ok { methods = append(methods, string(base.Setup)) } - if sc.readHandlers.OnPlay != nil { + if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok { methods = append(methods, string(base.Play)) } - if sc.readHandlers.OnRecord != nil { + if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok { methods = append(methods, string(base.Record)) } - if sc.readHandlers.OnPause != nil { + if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok { methods = append(methods, string(base.Pause)) } methods = append(methods, string(base.GetParameter)) - if sc.readHandlers.OnSetParameter != nil { + if _, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok { methods = append(methods, string(base.SetParameter)) } methods = append(methods, string(base.Teardown)) @@ -514,7 +328,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, nil case base.Describe: - if sc.readHandlers.OnDescribe != nil { + if h, ok := sc.s.Handler.(ServerHandlerOnDescribe); ok { err := sc.checkState(map[ServerConnState]struct{}{ ServerConnStateInitial: {}, }) @@ -533,7 +347,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { path, query := base.PathSplitQuery(pathAndQuery) - res, sdp, err := sc.readHandlers.OnDescribe(&ServerConnDescribeCtx{ + res, sdp, err := h.OnDescribe(&ServerHandlerOnDescribeCtx{ + Conn: sc, Req: req, Path: path, Query: query, @@ -553,7 +368,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } case base.Announce: - if sc.readHandlers.OnAnnounce != nil { + if h, ok := sc.s.Handler.(ServerHandlerOnAnnounce); ok { err := sc.checkState(map[ServerConnState]struct{}{ ServerConnStateInitial: {}, }) @@ -621,7 +436,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } } - res, err := sc.readHandlers.OnAnnounce(&ServerConnAnnounceCtx{ + res, err := h.OnAnnounce(&ServerHandlerOnAnnounceCtx{ + Conn: sc, Req: req, Path: path, Query: query, @@ -650,7 +466,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } case base.Setup: - if sc.readHandlers.OnSetup != nil { + if h, ok := sc.s.Handler.(ServerHandlerOnSetup); ok { err := sc.checkState(map[ServerConnState]struct{}{ ServerConnStateInitial: {}, ServerConnStatePrePlay: {}, @@ -741,7 +557,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, liberrors.ErrServerTracksDifferentProtocols{} } - res, err := sc.readHandlers.OnSetup(&ServerConnSetupCtx{ + res, err := h.OnSetup(&ServerHandlerOnSetupCtx{ + Conn: sc, Req: req, Path: path, Query: query, @@ -810,7 +627,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } case base.Play: - if sc.readHandlers.OnPlay != nil { + if h, ok := sc.s.Handler.(ServerHandlerOnPlay); ok { // play can be sent twice, allow calling it even if we're already playing err := sc.checkState(map[ServerConnState]struct{}{ ServerConnStatePrePlay: {}, @@ -840,7 +657,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { path, query := base.PathSplitQuery(pathAndQuery) - res, err := sc.readHandlers.OnPlay(&ServerConnPlayCtx{ + res, err := h.OnPlay(&ServerHandlerOnPlayCtx{ + Conn: sc, Req: req, Path: path, Query: query, @@ -855,7 +673,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } case base.Record: - if sc.readHandlers.OnRecord != nil { + if h, ok := sc.s.Handler.(ServerHandlerOnRecord); ok { err := sc.checkState(map[ServerConnState]struct{}{ ServerConnStatePreRecord: {}, }) @@ -889,7 +707,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { path, query := base.PathSplitQuery(pathAndQuery) - res, err := sc.readHandlers.OnRecord(&ServerConnRecordCtx{ + res, err := h.OnRecord(&ServerHandlerOnRecordCtx{ + Conn: sc, Req: req, Path: path, Query: query, @@ -904,7 +723,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } case base.Pause: - if sc.readHandlers.OnPause != nil { + if h, ok := sc.s.Handler.(ServerHandlerOnPause); ok { err := sc.checkState(map[ServerConnState]struct{}{ ServerConnStatePrePlay: {}, ServerConnStatePlay: {}, @@ -929,7 +748,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { path, query := base.PathSplitQuery(pathAndQuery) - res, err := sc.readHandlers.OnPause(&ServerConnPauseCtx{ + res, err := h.OnPause(&ServerHandlerOnPauseCtx{ + Conn: sc, Req: req, Path: path, Query: query, @@ -951,7 +771,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } case base.GetParameter: - if sc.readHandlers.OnGetParameter != nil { + if h, ok := sc.s.Handler.(ServerHandlerOnGetParameter); ok { pathAndQuery, ok := req.URL.RTSPPath() if !ok { return &base.Response{ @@ -961,7 +781,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { path, query := base.PathSplitQuery(pathAndQuery) - return sc.readHandlers.OnGetParameter(&ServerConnGetParameterCtx{ + return h.OnGetParameter(&ServerHandlerOnGetParameterCtx{ + Conn: sc, Req: req, Path: path, Query: query, @@ -978,7 +799,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, nil case base.SetParameter: - if sc.readHandlers.OnSetParameter != nil { + if h, ok := sc.s.Handler.(ServerHandlerOnSetParameter); ok { pathAndQuery, ok := req.URL.RTSPPath() if !ok { return &base.Response{ @@ -988,7 +809,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { path, query := base.PathSplitQuery(pathAndQuery) - return sc.readHandlers.OnSetParameter(&ServerConnSetParameterCtx{ + return h.OnSetParameter(&ServerHandlerOnSetParameterCtx{ + Conn: sc, Req: req, Path: path, Query: query, @@ -996,7 +818,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } case base.Teardown: - if sc.readHandlers.OnTeardown != nil { + if h, ok := sc.s.Handler.(ServerHandlerOnTeardown); ok { pathAndQuery, ok := req.URL.RTSPPath() if !ok { return &base.Response{ @@ -1006,7 +828,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { path, query := base.PathSplitQuery(pathAndQuery) - return sc.readHandlers.OnTeardown(&ServerConnTeardownCtx{ + return h.OnTeardown(&ServerHandlerOnTeardownCtx{ + Conn: sc, Req: req, Path: path, Query: query, @@ -1024,6 +847,10 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } func (sc *ServerConn) handleRequestOuter(req *base.Request) error { + if h, ok := sc.s.Handler.(ServerHandlerOnRequest); ok { + h.OnRequest(req) + } + res, err := sc.handleRequest(req) if res.Header == nil { @@ -1038,8 +865,8 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error { // add server res.Header["Server"] = base.HeaderValue{"gortsplib"} - if sc.readHandlers.OnResponse != nil { - sc.readHandlers.OnResponse(res) + if h, ok := sc.s.Handler.(ServerHandlerOnResponse); ok { + h.OnResponse(res) } switch { @@ -1102,7 +929,15 @@ func (sc *ServerConn) backgroundRead() error { sc.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(), frame.StreamType, frame.Payload) } - sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Payload) + + if h, ok := sc.s.Handler.(ServerHandlerOnFrame); ok { + h.OnFrame(&ServerHandlerOnFrameCtx{ + Conn: sc, + TrackID: frame.TrackID, + StreamType: frame.StreamType, + Payload: frame.Payload, + }) + } } case *base.Request: @@ -1129,50 +964,6 @@ func (sc *ServerConn) backgroundRead() error { } } -// Read starts reading requests and frames. -// it returns a channel that is written when the reading stops. -func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error { - // channel is buffered, since listening to it is not mandatory - done := make(chan error, 1) - - sc.readHandlers = readHandlers - - go func() { - done <- sc.backgroundRead() - }() - - return done -} - -// WriteFrame writes a frame. -func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) { - if *sc.setupProtocol == StreamProtocolUDP { - track := sc.setuppedTracks[trackID] - - if streamType == StreamTypeRTP { - sc.s.udpRTPListener.write(payload, &net.UDPAddr{ - IP: sc.ip(), - Zone: sc.zone(), - Port: track.udpRTPPort, - }) - return - } - - sc.s.udpRTCPListener.write(payload, &net.UDPAddr{ - IP: sc.ip(), - Zone: sc.zone(), - Port: track.udpRTCPPort, - }) - return - } - - sc.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{ - TrackID: trackID, - StreamType: streamType, - Payload: payload, - }) -} - func (sc *ServerConn) backgroundRecord() { defer close(sc.backgroundRecordDone) @@ -1217,3 +1008,126 @@ func (sc *ServerConn) backgroundRecord() { } } } + +func (sc *ServerConn) tcpBackgroundWrite() { + defer close(sc.tcpBackgroundWriteDone) + + for { + what, ok := sc.tcpFrameWriteBuffer.Pull() + if !ok { + return + } + + switch w := what.(type) { + case *base.InterleavedFrame: + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) + w.Write(sc.bw) + + case *base.Response: + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) + w.Write(sc.bw) + } + } +} + +func (sc *ServerConn) frameModeEnable() { + switch sc.state { + case ServerConnStatePlay: + if *sc.setupProtocol == StreamProtocolTCP { + sc.doEnableTCPFrame = true + } else { + // readers can send RTCP frames, they cannot sent RTP frames + for trackID, track := range sc.setuppedTracks { + sc.s.udpRTCPListener.addClient(sc.ip(), track.udpRTCPPort, sc, trackID, false) + } + } + + case ServerConnStateRecord: + if *sc.setupProtocol == StreamProtocolTCP { + sc.doEnableTCPFrame = true + sc.tcpFrameTimeout = true + + } else { + for trackID, track := range sc.setuppedTracks { + sc.s.udpRTPListener.addClient(sc.ip(), track.udpRTPPort, sc, trackID, true) + sc.s.udpRTCPListener.addClient(sc.ip(), track.udpRTCPPort, sc, trackID, true) + + // open the firewall by sending packets to the counterpart + sc.WriteFrame(trackID, StreamTypeRTP, + []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + sc.WriteFrame(trackID, StreamTypeRTCP, + []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) + } + } + + sc.backgroundRecordTerminate = make(chan struct{}) + sc.backgroundRecordDone = make(chan struct{}) + go sc.backgroundRecord() + } +} + +func (sc *ServerConn) frameModeDisable() { + switch sc.state { + case ServerConnStatePlay: + if *sc.setupProtocol == StreamProtocolTCP { + sc.tcpFrameEnabled = false + sc.tcpFrameWriteBuffer.Close() + <-sc.tcpBackgroundWriteDone + sc.tcpFrameWriteBuffer.Reset() + + } else { + for _, track := range sc.setuppedTracks { + sc.s.udpRTCPListener.removeClient(sc.ip(), track.udpRTCPPort) + } + } + + case ServerConnStateRecord: + close(sc.backgroundRecordTerminate) + <-sc.backgroundRecordDone + + if *sc.setupProtocol == StreamProtocolTCP { + sc.tcpFrameTimeout = false + sc.nconn.SetReadDeadline(time.Time{}) + + sc.tcpFrameEnabled = false + sc.tcpFrameWriteBuffer.Close() + <-sc.tcpBackgroundWriteDone + sc.tcpFrameWriteBuffer.Reset() + + } else { + for _, track := range sc.setuppedTracks { + sc.s.udpRTPListener.removeClient(sc.ip(), track.udpRTPPort) + sc.s.udpRTCPListener.removeClient(sc.ip(), track.udpRTCPPort) + } + } + } +} + +// WriteFrame writes a frame. +func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) { + if *sc.setupProtocol == StreamProtocolUDP { + track := sc.setuppedTracks[trackID] + + if streamType == StreamTypeRTP { + sc.s.udpRTPListener.write(payload, &net.UDPAddr{ + IP: sc.ip(), + Zone: sc.zone(), + Port: track.udpRTPPort, + }) + return + } + + sc.s.udpRTCPListener.write(payload, &net.UDPAddr{ + IP: sc.ip(), + Zone: sc.zone(), + Port: track.udpRTCPPort, + }) + return + } + + sc.tcpFrameWriteBuffer.Push(&base.InterleavedFrame{ + TrackID: trackID, + StreamType: streamType, + Payload: payload, + }) +} diff --git a/serverhandler.go b/serverhandler.go new file mode 100644 index 00000000..f37b2c1f --- /dev/null +++ b/serverhandler.go @@ -0,0 +1,183 @@ +package gortsplib + +import ( + "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/headers" +) + +// ServerHandler is the interface implemented by all the server handlers. +type ServerHandler interface { +} + +// ServerHandlerOnConnOpen can be implemented by a ServerHandler. +type ServerHandlerOnConnOpen interface { + OnConnOpen(*ServerConn) +} + +// ServerHandlerOnConnClose can be implemented by a ServerHandler. +type ServerHandlerOnConnClose interface { + OnConnClose(*ServerConn, error) +} + +// ServerHandlerOnRequest can be implemented by a ServerHandler. +type ServerHandlerOnRequest interface { + OnRequest(*base.Request) +} + +// ServerHandlerOnResponse can be implemented by a ServerHandler. +type ServerHandlerOnResponse interface { + OnResponse(*base.Response) +} + +// ServerHandlerOnOptionsCtx is the context of an OPTIONS request. +type ServerHandlerOnOptionsCtx struct { + Conn *ServerConn + Req *base.Request + Path string + Query string +} + +// ServerHandlerOnOptions can be implemented by a ServerHandler. +type ServerHandlerOnOptions interface { + OnOptions(*ServerHandlerOnOptionsCtx) (*base.Response, error) +} + +// ServerHandlerOnDescribeCtx is the context of a DESCRIBE request. +type ServerHandlerOnDescribeCtx struct { + Conn *ServerConn + Req *base.Request + Path string + Query string +} + +// ServerHandlerOnDescribe can be implemented by a ServerHandler. +type ServerHandlerOnDescribe interface { + OnDescribe(*ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) +} + +// ServerHandlerOnAnnounceCtx is the context of an ANNOUNCE request. +type ServerHandlerOnAnnounceCtx struct { + Conn *ServerConn + // Session *ServerSession + Req *base.Request + Path string + Query string + Tracks Tracks +} + +// ServerHandlerOnAnnounce can be implemented by a ServerHandler. +type ServerHandlerOnAnnounce interface { + OnAnnounce(*ServerHandlerOnAnnounceCtx) (*base.Response, error) +} + +// ServerHandlerOnSetupCtx is the context of a OPTIONS request. +type ServerHandlerOnSetupCtx struct { + Conn *ServerConn + Session *ServerSession + Req *base.Request + Path string + Query string + TrackID int + Transport *headers.Transport +} + +// ServerHandlerOnSetup can be implemented by a ServerHandler. +type ServerHandlerOnSetup interface { + OnSetup(*ServerHandlerOnSetupCtx) (*base.Response, error) +} + +// ServerHandlerOnPlayCtx is the context of a PLAY request. +type ServerHandlerOnPlayCtx struct { + Conn *ServerConn + // Session *ServerSession + Req *base.Request + Path string + Query string +} + +// ServerHandlerOnPlay can be implemented by a ServerHandler. +type ServerHandlerOnPlay interface { + OnPlay(*ServerHandlerOnPlayCtx) (*base.Response, error) +} + +// ServerHandlerOnRecordCtx is the context of a RECORD request. +type ServerHandlerOnRecordCtx struct { + Conn *ServerConn + // Session *ServerSession + Req *base.Request + Path string + Query string +} + +// ServerHandlerOnRecord can be implemented by a ServerHandler. +type ServerHandlerOnRecord interface { + OnRecord(*ServerHandlerOnRecordCtx) (*base.Response, error) +} + +// ServerHandlerOnPauseCtx is the context of a PAUSE request. +type ServerHandlerOnPauseCtx struct { + Conn *ServerConn + // Session *ServerSession + Req *base.Request + Path string + Query string +} + +// ServerHandlerOnPause can be implemented by a ServerHandler. +type ServerHandlerOnPause interface { + OnPause(*ServerHandlerOnPauseCtx) (*base.Response, error) +} + +// ServerHandlerOnGetParameterCtx is the context of a GET_PARAMETER request. +type ServerHandlerOnGetParameterCtx struct { + Conn *ServerConn + Req *base.Request + Path string + Query string +} + +// ServerHandlerOnGetParameter can be implemented by a ServerHandler. +type ServerHandlerOnGetParameter interface { + OnGetParameter(*ServerHandlerOnGetParameterCtx) (*base.Response, error) +} + +// ServerHandlerOnSetParameterCtx is the context of a SET_PARAMETER request. +type ServerHandlerOnSetParameterCtx struct { + Conn *ServerConn + Req *base.Request + Path string + Query string +} + +// ServerHandlerOnSetParameter can be implemented by a ServerHandler. +type ServerHandlerOnSetParameter interface { + OnSetParameter(*ServerHandlerOnSetParameterCtx) (*base.Response, error) +} + +// ServerHandlerOnTeardownCtx is the context of a TEARDOWN request. +type ServerHandlerOnTeardownCtx struct { + Conn *ServerConn + // Session *ServerSession + Req *base.Request + Path string + Query string +} + +// ServerHandlerOnTeardown can be implemented by a ServerHandler. +type ServerHandlerOnTeardown interface { + OnTeardown(*ServerHandlerOnTeardownCtx) (*base.Response, error) +} + +// ServerHandlerOnFrameCtx is the context of a frame request. +type ServerHandlerOnFrameCtx struct { + Conn *ServerConn + // Session *ServerSession + TrackID int + StreamType StreamType + Payload []byte +} + +// ServerHandlerOnFrame can be implemented by a ServerHandler. +type ServerHandlerOnFrame interface { + OnFrame(*ServerHandlerOnFrameCtx) +} diff --git a/serverconnpublish_test.go b/serverpublish_test.go similarity index 75% rename from serverconnpublish_test.go rename to serverpublish_test.go index bbeef487..8bf9132e 100644 --- a/serverconnpublish_test.go +++ b/serverpublish_test.go @@ -70,45 +70,30 @@ func TestServerPublishSetupPath(t *testing.T) { }, } { t.Run(ca.name, func(t *testing.T) { - type pathTrackIDPair struct { - path string - trackID int - } - setupDone := make(chan pathTrackIDPair) + setupDone := make(chan struct{}) - s := &Server{} - err := s.Serve("127.0.0.1:8554") + s := &Server{ + Handler: &testServerHandler{ + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + require.Equal(t, ca.path, ctx.Path) + require.Equal(t, ca.trackID, ctx.TrackID) + close(setupDone) + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - setupDone <- pathTrackIDPair{ctx.Path, ctx.TrackID} - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - <-conn.Read(ServerConnReadHandlers{ - OnAnnounce: onAnnounce, - OnSetup: onSetup, - }) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -178,9 +163,7 @@ func TestServerPublishSetupPath(t *testing.T) { }.Write(bconn.Writer) require.NoError(t, err) - pair := <-setupDone - require.Equal(t, ca.path, pair.path) - require.Equal(t, ca.trackID, pair.trackID) + <-setupDone err = res.Read(bconn.Reader) require.NoError(t, err) @@ -192,39 +175,28 @@ func TestServerPublishSetupPath(t *testing.T) { func TestServerPublishSetupErrorDifferentPaths(t *testing.T) { serverErr := make(chan error) - s := &Server{} - err := s.Serve("127.0.0.1:8554") + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + serverErr <- err + }, + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - err = <-conn.Read(ServerConnReadHandlers{ - OnAnnounce: onAnnounce, - OnSetup: onSetup, - }) - serverErr <- err - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -291,39 +263,28 @@ func TestServerPublishSetupErrorDifferentPaths(t *testing.T) { func TestServerPublishSetupErrorTrackTwice(t *testing.T) { serverErr := make(chan error) - s := &Server{} - err := s.Serve("127.0.0.1:8554") + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + serverErr <- err + }, + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - err = <-conn.Read(ServerConnReadHandlers{ - OnAnnounce: onAnnounce, - OnSetup: onSetup, - }) - serverErr <- err - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -404,46 +365,33 @@ func TestServerPublishSetupErrorTrackTwice(t *testing.T) { func TestServerPublishRecordErrorPartialTracks(t *testing.T) { serverErr := make(chan error) - s := &Server{} - err := s.Serve("127.0.0.1:8554") + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + serverErr <- err + }, + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onRecord := func(ctx *ServerConnRecordCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - err = <-conn.Read(ServerConnReadHandlers{ - OnAnnounce: onAnnounce, - OnSetup: onSetup, - OnRecord: onRecord, - }) - serverErr <- err - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -529,67 +477,50 @@ func TestServerPublish(t *testing.T) { "tcp", } { t.Run(proto, func(t *testing.T) { - s := &Server{} + rtpReceived := uint64(0) + + s := &Server{ + Handler: &testServerHandler{ + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onFrame: func(ctx *ServerHandlerOnFrameCtx) { + if atomic.SwapUint64(&rtpReceived, 1) == 0 { + require.Equal(t, 0, ctx.TrackID) + require.Equal(t, StreamTypeRTP, ctx.StreamType) + require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload) + } else { + require.Equal(t, 0, ctx.TrackID) + require.Equal(t, StreamTypeRTCP, ctx.StreamType) + require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, ctx.Payload) + + ctx.Conn.WriteFrame(0, StreamTypeRTCP, []byte{0x09, 0x0A, 0x0B, 0x0C}) + } + }, + }, + } if proto == "udp" { s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTCPAddress = "127.0.0.1:8001" } - err := s.Serve("127.0.0.1:8554") + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onRecord := func(ctx *ServerConnRecordCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - rtpReceived := uint64(0) - onFrame := func(trackID int, typ StreamType, buf []byte) { - if atomic.SwapUint64(&rtpReceived, 1) == 0 { - require.Equal(t, 0, trackID) - require.Equal(t, StreamTypeRTP, typ) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf) - } else { - require.Equal(t, 0, trackID) - require.Equal(t, StreamTypeRTCP, typ) - require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, buf) - - conn.WriteFrame(0, StreamTypeRTCP, []byte{0x09, 0x0A, 0x0B, 0x0C}) - } - } - - <-conn.Read(ServerConnReadHandlers{ - OnAnnounce: onAnnounce, - OnSetup: onSetup, - OnRecord: onRecord, - OnFrame: onFrame, - }) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -742,53 +673,34 @@ func TestServerPublish(t *testing.T) { func TestServerPublishErrorWrongProtocol(t *testing.T) { s := &Server{ + Handler: &testServerHandler{ + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onFrame: func(ctx *ServerHandlerOnFrameCtx) { + t.Error("should not happen") + }, + }, UDPRTPAddress: "127.0.0.1:8000", UDPRTCPAddress: "127.0.0.1:8001", } - err := s.Serve("127.0.0.1:8554") + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onRecord := func(ctx *ServerConnRecordCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onFrame := func(trackID int, typ StreamType, buf []byte) { - t.Error("should not happen") - } - - <-conn.Read(ServerConnReadHandlers{ - OnAnnounce: onAnnounce, - OnSetup: onSetup, - OnRecord: onRecord, - OnFrame: onFrame, - }) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -875,51 +787,30 @@ func TestServerPublishErrorWrongProtocol(t *testing.T) { func TestServerPublishRTCPReport(t *testing.T) { s := &Server{ + Handler: &testServerHandler{ + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, receiverReportPeriod: 1 * time.Second, } - err := s.Serve("127.0.0.1:8554") + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onRecord := func(ctx *ServerConnRecordCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onFrame := func(trackID int, typ StreamType, buf []byte) { - } - - <-conn.Read(ServerConnReadHandlers{ - OnAnnounce: onAnnounce, - OnSetup: onSetup, - OnRecord: onRecord, - OnFrame: onFrame, - }) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -1053,6 +944,31 @@ func TestServerPublishErrorTimeout(t *testing.T) { errDone := make(chan struct{}) s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + if proto == "udp" { + require.Equal(t, "no UDP packets received (maybe there's a firewall/NAT in between)", err.Error()) + } else { + require.True(t, strings.HasSuffix(err.Error(), "i/o timeout")) + } + close(errDone) + }, + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, ReadTimeout: 1 * time.Second, } @@ -1061,56 +977,10 @@ func TestServerPublishErrorTimeout(t *testing.T) { s.UDPRTCPAddress = "127.0.0.1:8001" } - err := s.Serve("127.0.0.1:8554") + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onAnnounce := func(ctx *ServerConnAnnounceCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onRecord := func(ctx *ServerConnRecordCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onFrame := func(trackID int, typ StreamType, buf []byte) { - } - - err = <-conn.Read(ServerConnReadHandlers{ - OnAnnounce: onAnnounce, - OnSetup: onSetup, - OnRecord: onRecord, - OnFrame: onFrame, - }) - - if proto == "udp" { - require.Equal(t, "no UDP packets received (maybe there's a firewall/NAT in between)", err.Error()) - } else { - require.True(t, strings.HasSuffix(err.Error(), "i/o timeout")) - } - - close(errDone) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() diff --git a/serverconnread_test.go b/serverread_test.go similarity index 69% rename from serverconnread_test.go rename to serverread_test.go index bb9bfe98..3b5e6e48 100644 --- a/serverconnread_test.go +++ b/serverread_test.go @@ -58,38 +58,25 @@ func TestServerReadSetupPath(t *testing.T) { }, } { t.Run(ca.name, func(t *testing.T) { - type pathTrackIDPair struct { - path string - trackID int - } - setupDone := make(chan pathTrackIDPair) + setupDone := make(chan struct{}) - s := &Server{} - err := s.Serve("127.0.0.1:8554") + s := &Server{ + Handler: &testServerHandler{ + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + require.Equal(t, ca.path, ctx.Path) + require.Equal(t, ca.trackID, ctx.TrackID) + close(setupDone) + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - setupDone <- pathTrackIDPair{ctx.Path, ctx.TrackID} - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - <-conn.Read(ServerConnReadHandlers{ - OnSetup: onSetup, - }) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -118,9 +105,7 @@ func TestServerReadSetupPath(t *testing.T) { }.Write(bconn.Writer) require.NoError(t, err) - pair := <-setupDone - require.Equal(t, ca.path, pair.path) - require.Equal(t, ca.trackID, pair.trackID) + <-setupDone var res base.Response err = res.Read(bconn.Reader) @@ -133,32 +118,23 @@ func TestServerReadSetupPath(t *testing.T) { func TestServerReadSetupErrorDifferentPaths(t *testing.T) { serverErr := make(chan error) - s := &Server{} - err := s.Serve("127.0.0.1:8554") + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + serverErr <- err + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - err = <-conn.Read(ServerConnReadHandlers{ - OnSetup: onSetup, - }) - serverErr <- err - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -215,32 +191,23 @@ func TestServerReadSetupErrorDifferentPaths(t *testing.T) { func TestServerReadSetupErrorTrackTwice(t *testing.T) { serverErr := make(chan error) - s := &Server{} - err := s.Serve("127.0.0.1:8554") + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + serverErr <- err + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - err = <-conn.Read(ServerConnReadHandlers{ - OnSetup: onSetup, - }) - serverErr <- err - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -303,52 +270,35 @@ func TestServerRead(t *testing.T) { framesReceived := make(chan struct{}) s := &Server{ + Handler: &testServerHandler{ + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) + ctx.Conn.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08}) + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onFrame: func(ctx *ServerHandlerOnFrameCtx) { + require.Equal(t, 0, ctx.TrackID) + require.Equal(t, StreamTypeRTCP, ctx.StreamType) + require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload) + close(framesReceived) + }, + }, UDPRTPAddress: "127.0.0.1:8000", UDPRTCPAddress: "127.0.0.1:8001", } - err := s.Serve("127.0.0.1:8554") + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onPlay := func(ctx *ServerConnPlayCtx) (*base.Response, error) { - conn.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) - conn.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08}) - - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onFrame := func(trackID int, typ StreamType, buf []byte) { - require.Equal(t, 0, trackID) - require.Equal(t, StreamTypeRTCP, typ) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf) - close(framesReceived) - } - - <-conn.Read(ServerConnReadHandlers{ - OnSetup: onSetup, - OnPlay: onPlay, - OnFrame: onFrame, - }) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -467,63 +417,52 @@ func TestServerRead(t *testing.T) { } func TestServerReadTCPResponseBeforeFrames(t *testing.T) { - s := &Server{} - err := s.Serve("127.0.0.1:8554") + writerDone := make(chan struct{}) + writerTerminate := make(chan struct{}) + + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + close(writerTerminate) + <-writerDone + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + go func() { + defer close(writerDone) + + ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) + + t := time.NewTicker(50 * time.Millisecond) + defer t.Stop() + + for { + select { + case <-t.C: + ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) + case <-writerTerminate: + return + } + } + }() + + time.Sleep(50 * time.Millisecond) + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - writerDone := make(chan struct{}) - defer func() { <-writerDone }() - writerTerminate := make(chan struct{}) - defer close(writerTerminate) - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onPlay := func(ctx *ServerConnPlayCtx) (*base.Response, error) { - go func() { - defer close(writerDone) - - conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) - - t := time.NewTicker(50 * time.Millisecond) - defer t.Stop() - - for { - select { - case <-t.C: - conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) - case <-writerTerminate: - return - } - } - }() - - time.Sleep(50 * time.Millisecond) - - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - <-conn.Read(ServerConnReadHandlers{ - OnSetup: onSetup, - OnPlay: onPlay, - }) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -575,61 +514,50 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { } func TestServerReadPlayPlay(t *testing.T) { - s := &Server{} - err := s.Serve("127.0.0.1:8554") + writerTerminate := make(chan struct{}) + writerDone := make(chan struct{}) + + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + close(writerTerminate) + <-writerDone + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + if ctx.Conn.State() != ServerConnStatePlay { + go func() { + defer close(writerDone) + + t := time.NewTicker(50 * time.Millisecond) + defer t.Stop() + + for { + select { + case <-t.C: + ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) + case <-writerTerminate: + return + } + } + }() + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - writerDone := make(chan struct{}) - defer func() { <-writerDone }() - writerTerminate := make(chan struct{}) - defer close(writerTerminate) - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onPlay := func(ctx *ServerConnPlayCtx) (*base.Response, error) { - if conn.State() != ServerConnStatePlay { - go func() { - defer close(writerDone) - - t := time.NewTicker(50 * time.Millisecond) - defer t.Stop() - - for { - select { - case <-t.C: - conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) - case <-writerTerminate: - return - } - } - }() - } - - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - <-conn.Read(ServerConnReadHandlers{ - OnSetup: onSetup, - OnPlay: onPlay, - }) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -690,70 +618,57 @@ func TestServerReadPlayPlay(t *testing.T) { } func TestServerReadPlayPausePlay(t *testing.T) { - s := &Server{} - err := s.Serve("127.0.0.1:8554") + writerStarted := false + writerDone := make(chan struct{}) + writerTerminate := make(chan struct{}) + + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + close(writerTerminate) + <-writerDone + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + if !writerStarted { + writerStarted = true + go func() { + defer close(writerDone) + + t := time.NewTicker(50 * time.Millisecond) + defer t.Stop() + + for { + select { + case <-t.C: + ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) + case <-writerTerminate: + return + } + } + }() + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPause: func(ctx *ServerHandlerOnPauseCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - writerStarted := false - writerDone := make(chan struct{}) - defer func() { <-writerDone }() - writerTerminate := make(chan struct{}) - defer close(writerTerminate) - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onPlay := func(ctx *ServerConnPlayCtx) (*base.Response, error) { - if !writerStarted { - writerStarted = true - go func() { - defer close(writerDone) - - t := time.NewTicker(50 * time.Millisecond) - defer t.Stop() - - for { - select { - case <-t.C: - conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) - case <-writerTerminate: - return - } - } - }() - } - - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onPause := func(ctx *ServerConnPauseCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - <-conn.Read(ServerConnReadHandlers{ - OnSetup: onSetup, - OnPlay: onPlay, - OnPause: onPause, - }) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() @@ -832,66 +747,53 @@ func TestServerReadPlayPausePlay(t *testing.T) { } func TestServerReadPlayPausePause(t *testing.T) { - s := &Server{} - err := s.Serve("127.0.0.1:8554") + writerDone := make(chan struct{}) + writerTerminate := make(chan struct{}) + + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + close(writerTerminate) + <-writerDone + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + go func() { + defer close(writerDone) + + t := time.NewTicker(50 * time.Millisecond) + defer t.Stop() + + for { + select { + case <-t.C: + ctx.Conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) + case <-writerTerminate: + return + } + } + }() + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPause: func(ctx *ServerHandlerOnPauseCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") require.NoError(t, err) defer s.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - writerDone := make(chan struct{}) - defer func() { <-writerDone }() - writerTerminate := make(chan struct{}) - defer close(writerTerminate) - - onSetup := func(ctx *ServerConnSetupCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onPlay := func(ctx *ServerConnPlayCtx) (*base.Response, error) { - go func() { - defer close(writerDone) - - t := time.NewTicker(50 * time.Millisecond) - defer t.Stop() - - for { - select { - case <-t.C: - conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) - case <-writerTerminate: - return - } - } - }() - - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onPause := func(ctx *ServerConnPauseCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - <-conn.Read(ServerConnReadHandlers{ - OnSetup: onSetup, - OnPlay: onPlay, - OnPause: onPause, - }) - }() - conn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer conn.Close() diff --git a/serversession.go b/serversession.go new file mode 100644 index 00000000..5d5abe01 --- /dev/null +++ b/serversession.go @@ -0,0 +1,5 @@ +package gortsplib + +// ServerSession is a server-side RTSP session. +type ServerSession struct { +} diff --git a/serverudpl.go b/serverudpl.go index 1cb36221..c25efdbc 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -42,6 +42,7 @@ func (p *clientAddr) fill(ip net.IP, port int) { } type serverUDPListener struct { + s *Server pc *net.UDPConn streamType StreamType writeTimeout time.Duration @@ -71,6 +72,7 @@ func newServerUDPListener( } u := &serverUDPListener{ + s: s, pc: pc, clients: make(map[clientAddr]*clientData), done: make(chan struct{}), @@ -125,7 +127,14 @@ func (u *serverUDPListener) run() { clientData.sc.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n]) } - clientData.sc.readHandlers.OnFrame(clientData.trackID, u.streamType, buf[:n]) + if h, ok := u.s.Handler.(ServerHandlerOnFrame); ok { + h.OnFrame(&ServerHandlerOnFrameCtx{ + Conn: clientData.sc, + TrackID: clientData.trackID, + StreamType: u.streamType, + Payload: buf[:n], + }) + } }() } }()