diff --git a/examples/server-tls/main.go b/examples/server-tls/main.go index 58a39c1f..768f903d 100644 --- a/examples/server-tls/main.go +++ b/examples/server-tls/main.go @@ -151,15 +151,16 @@ func main() { if err != nil { panic(err) } - conf := gortsplib.ServerConf{ - TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}}, - } // create server - s, err := conf.Serve(":8554") + s := &gortsplib.Server{ + TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}}, + } + err = s.Serve(":8554") if err != nil { panic(err) } + log.Printf("server is ready") // accept connections diff --git a/examples/server-udp/main.go b/examples/server-udp/main.go index f2eb2f39..abedd32b 100644 --- a/examples/server-udp/main.go +++ b/examples/server-udp/main.go @@ -143,17 +143,16 @@ func handleConn(conn *gortsplib.ServerConn) { } func main() { - // create configuration - conf := gortsplib.ServerConf{ + // create server + s := &gortsplib.Server{ UDPRTPAddress: ":8000", UDPRTCPAddress: ":8001", } - - // create server - s, err := conf.Serve(":8554") + err := s.Serve(":8554") if err != nil { panic(err) } + log.Printf("server is ready") // accept connections diff --git a/examples/server/main.go b/examples/server/main.go index b2eae2c3..0d421c1c 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -144,10 +144,12 @@ func handleConn(conn *gortsplib.ServerConn) { func main() { // create server - s, err := gortsplib.Serve(":8554") + s := &gortsplib.Server{} + err := s.Serve(":8554") if err != nil { panic(err) } + log.Printf("server is ready") // accept connections diff --git a/server.go b/server.go index c02d4bff..dfe22091 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,7 @@ package gortsplib import ( + "crypto/tls" "fmt" "net" "strconv" @@ -23,87 +24,131 @@ func extractPort(address string) (int, error) { // Server is a RTSP server. type Server struct { - conf ServerConf + // a TLS configuration to accept TLS (RTSPS) connections. + TLSConfig *tls.Config + + // a port to send and receive UDP/RTP packets. + // If UDPRTPAddress and UDPRTCPAddress are != "", the server can accept and send UDP streams. + UDPRTPAddress string + + // a port to send and receive UDP/RTCP packets. + // If UDPRTPAddress and UDPRTCPAddress are != "", the server can accept and send UDP streams. + UDPRTCPAddress string + + // timeout of read operations. + // It defaults to 10 seconds + ReadTimeout time.Duration + + // timeout of write operations. + // It defaults to 10 seconds + WriteTimeout time.Duration + + // read buffer count. + // If greater than 1, allows to pass buffers to routines different than the one + // that is reading frames. + // It also allows to buffer routed frames and mitigate network fluctuations + // that are particularly high when using UDP. + // It defaults to 512 + ReadBufferCount int + + // read buffer size. + // This must be touched only when the server reports problems about buffer sizes. + // It defaults to 2048. + ReadBufferSize int + + // function used to initialize the TCP listener. + // It defaults to net.Listen + Listen func(network string, address string) (net.Listener, error) + + receiverReportPeriod time.Duration + tcpListener net.Listener udpRTPListener *serverUDPListener udpRTCPListener *serverUDPListener } -func newServer(conf ServerConf, address string) (*Server, error) { - if conf.ReadTimeout == 0 { - conf.ReadTimeout = 10 * time.Second +// Serve starts listening on the given address. +func (s *Server) Serve(address string) error { + if s.ReadTimeout == 0 { + s.ReadTimeout = 10 * time.Second } - if conf.WriteTimeout == 0 { - conf.WriteTimeout = 10 * time.Second + if s.WriteTimeout == 0 { + s.WriteTimeout = 10 * time.Second } - if conf.ReadBufferCount == 0 { - conf.ReadBufferCount = 512 + if s.ReadBufferCount == 0 { + s.ReadBufferCount = 512 } - if conf.ReadBufferSize == 0 { - conf.ReadBufferSize = 2048 + if s.ReadBufferSize == 0 { + s.ReadBufferSize = 2048 } - if conf.Listen == nil { - conf.Listen = net.Listen + if s.Listen == nil { + s.Listen = net.Listen } - if conf.receiverReportPeriod == 0 { - conf.receiverReportPeriod = 10 * time.Second + if s.receiverReportPeriod == 0 { + s.receiverReportPeriod = 10 * time.Second } - if conf.TLSConfig != nil && conf.UDPRTPAddress != "" { - return nil, fmt.Errorf("TLS can't be used together with UDP") + if s.TLSConfig != nil && s.UDPRTPAddress != "" { + return fmt.Errorf("TLS can't be used together with UDP") } - if (conf.UDPRTPAddress != "" && conf.UDPRTCPAddress == "") || - (conf.UDPRTPAddress == "" && conf.UDPRTCPAddress != "") { - return nil, fmt.Errorf("UDPRTPAddress and UDPRTCPAddress must be used together") + if (s.UDPRTPAddress != "" && s.UDPRTCPAddress == "") || + (s.UDPRTPAddress == "" && s.UDPRTCPAddress != "") { + return fmt.Errorf("UDPRTPAddress and UDPRTCPAddress must be used together") } - s := &Server{ - conf: conf, - } - - if conf.UDPRTPAddress != "" { - rtpPort, err := extractPort(conf.UDPRTPAddress) + if s.UDPRTPAddress != "" { + rtpPort, err := extractPort(s.UDPRTPAddress) if err != nil { - return nil, err + return err } - rtcpPort, err := extractPort(conf.UDPRTCPAddress) + rtcpPort, err := extractPort(s.UDPRTCPAddress) if err != nil { - return nil, err + return err } if (rtpPort % 2) != 0 { - return nil, fmt.Errorf("RTP port must be even") + return fmt.Errorf("RTP port must be even") } if rtcpPort != (rtpPort + 1) { - return nil, fmt.Errorf("RTCP and RTP ports must be consecutive") + return fmt.Errorf("RTCP and RTP ports must be consecutive") } - s.udpRTPListener, err = newServerUDPListener(conf, conf.UDPRTPAddress, StreamTypeRTP) + s.udpRTPListener, err = newServerUDPListener(s, s.UDPRTPAddress, StreamTypeRTP) if err != nil { - return nil, err + return err } - s.udpRTCPListener, err = newServerUDPListener(conf, conf.UDPRTCPAddress, StreamTypeRTCP) + s.udpRTCPListener, err = newServerUDPListener(s, s.UDPRTCPAddress, StreamTypeRTCP) if err != nil { - return nil, err + return err } } var err error - s.tcpListener, err = conf.Listen("tcp", address) + s.tcpListener, err = s.Listen("tcp", address) + if err != nil { + return err + } + + return nil +} + +// Accept accepts a connection. +func (s *Server) Accept() (*ServerConn, error) { + nconn, err := s.tcpListener.Accept() if err != nil { return nil, err } - return s, nil + return newServerConn(s, nconn), nil } -// Close closes the server. +// Close closes all the server resources. func (s *Server) Close() error { s.tcpListener.Close() @@ -117,13 +162,3 @@ func (s *Server) Close() error { return nil } - -// Accept accepts a connection. -func (s *Server) Accept() (*ServerConn, error) { - nconn, err := s.tcpListener.Accept() - if err != nil { - return nil, err - } - - return newServerConn(s.conf, s.udpRTPListener, s.udpRTCPListener, nconn), nil -} diff --git a/serverconf.go b/serverconf.go deleted file mode 100644 index 1b0f6506..00000000 --- a/serverconf.go +++ /dev/null @@ -1,62 +0,0 @@ -package gortsplib - -import ( - "crypto/tls" - "net" - "time" -) - -// DefaultServerConf is the default ServerConf. -var DefaultServerConf = ServerConf{} - -// Serve starts a server on the given address. -func Serve(address string) (*Server, error) { - return DefaultServerConf.Serve(address) -} - -// ServerConf allows to configure a Server. -// All fields are optional. -type ServerConf struct { - // a TLS configuration to accept TLS (RTSPS) connections. - TLSConfig *tls.Config - - // a port to send and receive UDP/RTP packets. - // If UDPRTPAddress and UDPRTCPAddress are != "", the server can accept and send UDP streams. - UDPRTPAddress string - - // a port to send and receive UDP/RTCP packets. - // If UDPRTPAddress and UDPRTCPAddress are != "", the server can accept and send UDP streams. - UDPRTCPAddress string - - // timeout of read operations. - // It defaults to 10 seconds - ReadTimeout time.Duration - - // timeout of write operations. - // It defaults to 10 seconds - WriteTimeout time.Duration - - // read buffer count. - // If greater than 1, allows to pass buffers to routines different than the one - // that is reading frames. - // It also allows to buffer routed frames and mitigate network fluctuations - // that are particularly high when using UDP. - // It defaults to 512 - ReadBufferCount int - - // read buffer size. - // This must be touched only when the server reports problems about buffer sizes. - // It defaults to 2048. - ReadBufferSize int - - // function used to initialize the TCP listener. - // It defaults to net.Listen - Listen func(network string, address string) (net.Listener, error) - - receiverReportPeriod time.Duration -} - -// Serve starts a server on the given address. -func (c ServerConf) Serve(address string) (*Server, error) { - return newServer(c, address) -} diff --git a/serverconn.go b/serverconn.go index a965e7be..e65f1c94 100644 --- a/serverconn.go +++ b/serverconn.go @@ -249,17 +249,15 @@ type ServerConnReadHandlers struct { // ServerConn is a server-side RTSP connection. type ServerConn struct { - conf ServerConf - nconn net.Conn - udpRTPListener *serverUDPListener - udpRTCPListener *serverUDPListener - br *bufio.Reader - bw *bufio.Writer - state ServerConnState - setuppedTracks map[int]ServerConnSetuppedTrack - setupProtocol *StreamProtocol - setupPath *string - setupQuery *string + s *Server + nconn net.Conn + br *bufio.Reader + bw *bufio.Writer + state ServerConnState + setuppedTracks map[int]ServerConnSetuppedTrack + setupProtocol *StreamProtocol + setupPath *string + setupQuery *string // TCP stream protocol doEnableTCPFrame bool @@ -282,26 +280,23 @@ type ServerConn struct { terminate chan struct{} } -func newServerConn(conf ServerConf, - udpRTPListener *serverUDPListener, - udpRTCPListener *serverUDPListener, +func newServerConn( + s *Server, nconn net.Conn) *ServerConn { conn := func() net.Conn { - if conf.TLSConfig != nil { - return tls.Server(nconn, conf.TLSConfig) + if s.TLSConfig != nil { + return tls.Server(nconn, s.TLSConfig) } return nconn }() return &ServerConn{ - conf: conf, - udpRTPListener: udpRTPListener, - udpRTCPListener: udpRTCPListener, - nconn: nconn, - br: bufio.NewReaderSize(conn, serverConnReadBufferSize), - bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize), + s: s, + nconn: nconn, + br: bufio.NewReaderSize(conn, serverConnReadBufferSize), + bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize), // always instantiate to allow writing to it before Play() - tcpFrameWriteBuffer: ringbuffer.New(uint64(conf.ReadBufferCount)), + tcpFrameWriteBuffer: ringbuffer.New(uint64(s.ReadBufferCount)), tcpBackgroundWriteDone: make(chan struct{}), terminate: make(chan struct{}), } @@ -345,11 +340,11 @@ func (sc *ServerConn) tcpBackgroundWrite() { switch w := what.(type) { case *base.InterleavedFrame: - sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) w.Write(sc.bw) case *base.Response: - sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) w.Write(sc.bw) } } @@ -390,7 +385,7 @@ func (sc *ServerConn) frameModeEnable() { } else { // readers can send RTCP frames, they cannot sent RTP frames for trackID, track := range sc.setuppedTracks { - sc.udpRTCPListener.addClient(sc.ip(), track.udpRTCPPort, sc, trackID, false) + sc.s.udpRTCPListener.addClient(sc.ip(), track.udpRTCPPort, sc, trackID, false) } } @@ -401,8 +396,8 @@ func (sc *ServerConn) frameModeEnable() { } else { for trackID, track := range sc.setuppedTracks { - sc.udpRTPListener.addClient(sc.ip(), track.udpRTPPort, sc, trackID, true) - sc.udpRTCPListener.addClient(sc.ip(), track.udpRTCPPort, sc, trackID, true) + sc.s.udpRTPListener.addClient(sc.ip(), track.udpRTPPort, sc, trackID, true) + sc.s.udpRTCPListener.addClient(sc.ip(), track.udpRTCPPort, sc, trackID, true) // open the firewall by sending packets to the counterpart sc.WriteFrame(trackID, StreamTypeRTP, @@ -429,7 +424,7 @@ func (sc *ServerConn) frameModeDisable() { } else { for _, track := range sc.setuppedTracks { - sc.udpRTCPListener.removeClient(sc.ip(), track.udpRTCPPort) + sc.s.udpRTCPListener.removeClient(sc.ip(), track.udpRTCPPort) } } @@ -448,8 +443,8 @@ func (sc *ServerConn) frameModeDisable() { } else { for _, track := range sc.setuppedTracks { - sc.udpRTPListener.removeClient(sc.ip(), track.udpRTPPort) - sc.udpRTCPListener.removeClient(sc.ip(), track.udpRTCPPort) + sc.s.udpRTPListener.removeClient(sc.ip(), track.udpRTPPort) + sc.s.udpRTCPListener.removeClient(sc.ip(), track.udpRTCPPort) } } } @@ -712,7 +707,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } if th.Protocol == StreamProtocolUDP { - if sc.udpRTPListener == nil { + if sc.s.udpRTPListener == nil { return &base.Response{ StatusCode: base.StatusUnsupportedTransport, }, nil @@ -777,7 +772,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { return &v }(), ClientPorts: th.ClientPorts, - ServerPorts: &[2]int{sc.udpRTPListener.port(), sc.udpRTCPListener.port()}, + ServerPorts: &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}, }.Write() } else { @@ -1053,17 +1048,17 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error { sc.tcpFrameEnabled = true if sc.state == ServerConnStateRecord { - sc.tcpFrameBuffer = multibuffer.New(uint64(sc.conf.ReadBufferCount), uint64(sc.conf.ReadBufferSize)) + sc.tcpFrameBuffer = multibuffer.New(uint64(sc.s.ReadBufferCount), uint64(sc.s.ReadBufferSize)) } else { // when playing, tcpFrameBuffer is only used to receive RTCP receiver reports, // that are much smaller than RTP frames and are sent at a fixed interval // (about 2 frames every 10 secs). // decrease RAM consumption by allocating less buffers. - sc.tcpFrameBuffer = multibuffer.New(8, uint64(sc.conf.ReadBufferSize)) + sc.tcpFrameBuffer = multibuffer.New(8, uint64(sc.s.ReadBufferSize)) } // write response before frames - sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) res.Write(sc.bw) // start background write @@ -1074,7 +1069,7 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error { sc.tcpFrameWriteBuffer.Push(res) default: // write directly - sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) + sc.nconn.SetWriteDeadline(time.Now().Add(sc.s.WriteTimeout)) res.Write(sc.bw) } @@ -1090,7 +1085,7 @@ func (sc *ServerConn) backgroundRead() error { for { if sc.tcpFrameEnabled { if sc.tcpFrameTimeout { - sc.nconn.SetReadDeadline(time.Now().Add(sc.conf.ReadTimeout)) + sc.nconn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout)) } frame.Payload = sc.tcpFrameBuffer.Next() @@ -1155,7 +1150,7 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b track := sc.setuppedTracks[trackID] if streamType == StreamTypeRTP { - sc.udpRTPListener.write(payload, &net.UDPAddr{ + sc.s.udpRTPListener.write(payload, &net.UDPAddr{ IP: sc.ip(), Zone: sc.zone(), Port: track.udpRTPPort, @@ -1163,7 +1158,7 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b return } - sc.udpRTCPListener.write(payload, &net.UDPAddr{ + sc.s.udpRTCPListener.write(payload, &net.UDPAddr{ IP: sc.ip(), Zone: sc.zone(), Port: track.udpRTCPPort, @@ -1184,7 +1179,7 @@ func (sc *ServerConn) backgroundRecord() { checkStreamTicker := time.NewTicker(serverConnCheckStreamPeriod) defer checkStreamTicker.Stop() - receiverReportTicker := time.NewTicker(sc.conf.receiverReportPeriod) + receiverReportTicker := time.NewTicker(sc.s.receiverReportPeriod) defer receiverReportTicker.Stop() for { @@ -1198,7 +1193,7 @@ func (sc *ServerConn) backgroundRecord() { now := time.Now() for _, track := range sc.announcedTracks { lft := atomic.LoadInt64(track.udpLastFrameTime) - if now.Sub(time.Unix(lft, 0)) < sc.conf.ReadTimeout { + if now.Sub(time.Unix(lft, 0)) < sc.s.ReadTimeout { return false } } diff --git a/serverconn_test.go b/serverconn_test.go index 699c7e68..2c3c9b1d 100644 --- a/serverconn_test.go +++ b/serverconn_test.go @@ -26,19 +26,15 @@ type testServ struct { } func newTestServ(tlsConf *tls.Config) (*testServ, error) { - var conf ServerConf + s := &Server{} if tlsConf != nil { - conf = ServerConf{ - TLSConfig: tlsConf, - } + s.TLSConfig = tlsConf } else { - conf = ServerConf{ - UDPRTPAddress: "127.0.0.1:8000", - UDPRTCPAddress: "127.0.0.1:8001", - } + s.UDPRTPAddress = "127.0.0.1:8000" + s.UDPRTCPAddress = "127.0.0.1:8001" } - s, err := conf.Serve("127.0.0.1:8554") + err := s.Serve("127.0.0.1:8554") if err != nil { return nil, err } @@ -374,26 +370,27 @@ func TestServerHighLevelPublishRead(t *testing.T) { func TestServerErrorWrongUDPPorts(t *testing.T) { t.Run("non consecutive", func(t *testing.T) { - conf := ServerConf{ + s := &Server{ UDPRTPAddress: "127.0.0.1:8006", UDPRTCPAddress: "127.0.0.1:8009", } - _, err := conf.Serve("127.0.0.1:8554") + err := s.Serve("127.0.0.1:8554") require.Error(t, err) }) t.Run("non even", func(t *testing.T) { - conf := ServerConf{ + s := &Server{ UDPRTPAddress: "127.0.0.1:8003", UDPRTCPAddress: "127.0.0.1:8004", } - _, err := conf.Serve("127.0.0.1:8554") + err := s.Serve("127.0.0.1:8554") require.Error(t, err) }) } func TestServerCSeq(t *testing.T) { - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -432,7 +429,8 @@ func TestServerCSeq(t *testing.T) { } func TestServerErrorCSeqMissing(t *testing.T) { - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -468,7 +466,8 @@ func TestServerErrorCSeqMissing(t *testing.T) { } func TestServerTeardownResponse(t *testing.T) { - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() diff --git a/serverconnpublish_test.go b/serverconnpublish_test.go index df6294a6..bbeef487 100644 --- a/serverconnpublish_test.go +++ b/serverconnpublish_test.go @@ -76,7 +76,8 @@ func TestServerPublishSetupPath(t *testing.T) { } setupDone := make(chan pathTrackIDPair) - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -191,7 +192,8 @@ func TestServerPublishSetupPath(t *testing.T) { func TestServerPublishSetupErrorDifferentPaths(t *testing.T) { serverErr := make(chan error) - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -289,7 +291,8 @@ func TestServerPublishSetupErrorDifferentPaths(t *testing.T) { func TestServerPublishSetupErrorTrackTwice(t *testing.T) { serverErr := make(chan error) - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -401,7 +404,8 @@ func TestServerPublishSetupErrorTrackTwice(t *testing.T) { func TestServerPublishRecordErrorPartialTracks(t *testing.T) { serverErr := make(chan error) - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -525,14 +529,14 @@ func TestServerPublish(t *testing.T) { "tcp", } { t.Run(proto, func(t *testing.T) { - conf := ServerConf{} + s := &Server{} if proto == "udp" { - conf.UDPRTPAddress = "127.0.0.1:8000" - conf.UDPRTCPAddress = "127.0.0.1:8001" + s.UDPRTPAddress = "127.0.0.1:8000" + s.UDPRTCPAddress = "127.0.0.1:8001" } - s, err := conf.Serve("127.0.0.1:8554") + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -737,12 +741,12 @@ func TestServerPublish(t *testing.T) { } func TestServerPublishErrorWrongProtocol(t *testing.T) { - conf := ServerConf{ + s := &Server{ UDPRTPAddress: "127.0.0.1:8000", UDPRTCPAddress: "127.0.0.1:8001", } - s, err := conf.Serve("127.0.0.1:8554") + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -870,11 +874,11 @@ func TestServerPublishErrorWrongProtocol(t *testing.T) { } func TestServerPublishRTCPReport(t *testing.T) { - conf := ServerConf{ + s := &Server{ receiverReportPeriod: 1 * time.Second, } - s, err := conf.Serve("127.0.0.1:8554") + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -1048,16 +1052,16 @@ func TestServerPublishErrorTimeout(t *testing.T) { t.Run(proto, func(t *testing.T) { errDone := make(chan struct{}) - conf := ServerConf{ + s := &Server{ ReadTimeout: 1 * time.Second, } if proto == "udp" { - conf.UDPRTPAddress = "127.0.0.1:8000" - conf.UDPRTCPAddress = "127.0.0.1:8001" + s.UDPRTPAddress = "127.0.0.1:8000" + s.UDPRTCPAddress = "127.0.0.1:8001" } - s, err := conf.Serve("127.0.0.1:8554") + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() diff --git a/serverconnread_test.go b/serverconnread_test.go index 65eefdec..bb9bfe98 100644 --- a/serverconnread_test.go +++ b/serverconnread_test.go @@ -64,7 +64,8 @@ func TestServerReadSetupPath(t *testing.T) { } setupDone := make(chan pathTrackIDPair) - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -132,7 +133,8 @@ func TestServerReadSetupPath(t *testing.T) { func TestServerReadSetupErrorDifferentPaths(t *testing.T) { serverErr := make(chan error) - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -213,7 +215,8 @@ func TestServerReadSetupErrorDifferentPaths(t *testing.T) { func TestServerReadSetupErrorTrackTwice(t *testing.T) { serverErr := make(chan error) - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -299,12 +302,12 @@ func TestServerRead(t *testing.T) { t.Run(proto, func(t *testing.T) { framesReceived := make(chan struct{}) - conf := ServerConf{ + s := &Server{ UDPRTPAddress: "127.0.0.1:8000", UDPRTCPAddress: "127.0.0.1:8001", } - s, err := conf.Serve("127.0.0.1:8554") + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -464,7 +467,8 @@ func TestServerRead(t *testing.T) { } func TestServerReadTCPResponseBeforeFrames(t *testing.T) { - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -571,7 +575,8 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { } func TestServerReadPlayPlay(t *testing.T) { - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -685,7 +690,8 @@ func TestServerReadPlayPlay(t *testing.T) { } func TestServerReadPlayPausePlay(t *testing.T) { - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -826,7 +832,8 @@ func TestServerReadPlayPausePlay(t *testing.T) { } func TestServerReadPlayPausePause(t *testing.T) { - s, err := Serve("127.0.0.1:8554") + s := &Server{} + err := s.Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() diff --git a/serverudpl.go b/serverudpl.go index 978875cd..1cb36221 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -55,7 +55,7 @@ type serverUDPListener struct { } func newServerUDPListener( - conf ServerConf, + s *Server, address string, streamType StreamType) (*serverUDPListener, error) { @@ -70,30 +70,30 @@ func newServerUDPListener( return nil, err } - s := &serverUDPListener{ + u := &serverUDPListener{ pc: pc, clients: make(map[clientAddr]*clientData), done: make(chan struct{}), } - s.streamType = streamType - s.writeTimeout = conf.WriteTimeout - s.readBuf = multibuffer.New(uint64(conf.ReadBufferCount), uint64(conf.ReadBufferSize)) - s.ringBuffer = ringbuffer.New(uint64(conf.ReadBufferCount)) + u.streamType = streamType + u.writeTimeout = s.WriteTimeout + u.readBuf = multibuffer.New(uint64(s.ReadBufferCount), uint64(s.ReadBufferSize)) + u.ringBuffer = ringbuffer.New(uint64(s.ReadBufferCount)) - go s.run() + go u.run() - return s, nil + return u, nil } -func (s *serverUDPListener) close() { - s.pc.Close() - s.ringBuffer.Close() - <-s.done +func (u *serverUDPListener) close() { + u.pc.Close() + u.ringBuffer.Close() + <-u.done } -func (s *serverUDPListener) run() { - defer close(s.done) +func (u *serverUDPListener) run() { + defer close(u.done) var wg sync.WaitGroup @@ -102,19 +102,19 @@ func (s *serverUDPListener) run() { defer wg.Done() for { - buf := s.readBuf.Next() - n, addr, err := s.pc.ReadFromUDP(buf) + buf := u.readBuf.Next() + n, addr, err := u.pc.ReadFromUDP(buf) if err != nil { break } func() { - s.clientsMutex.RLock() - defer s.clientsMutex.RUnlock() + u.clientsMutex.RLock() + defer u.clientsMutex.RUnlock() var clientAddr clientAddr clientAddr.fill(addr.IP, addr.Port) - clientData, ok := s.clients[clientAddr] + clientData, ok := u.clients[clientAddr] if !ok { return } @@ -122,10 +122,10 @@ func (s *serverUDPListener) run() { if clientData.isPublishing { now := time.Now() atomic.StoreInt64(clientData.sc.announcedTracks[clientData.trackID].udpLastFrameTime, now.Unix()) - clientData.sc.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, s.streamType, buf[:n]) + clientData.sc.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n]) } - clientData.sc.readHandlers.OnFrame(clientData.trackID, s.streamType, buf[:n]) + clientData.sc.readHandlers.OnFrame(clientData.trackID, u.streamType, buf[:n]) }() } }() @@ -135,48 +135,48 @@ func (s *serverUDPListener) run() { defer wg.Done() for { - tmp, ok := s.ringBuffer.Pull() + tmp, ok := u.ringBuffer.Pull() if !ok { return } pair := tmp.(bufAddrPair) - s.pc.SetWriteDeadline(time.Now().Add(s.writeTimeout)) - s.pc.WriteTo(pair.buf, pair.addr) + u.pc.SetWriteDeadline(time.Now().Add(u.writeTimeout)) + u.pc.WriteTo(pair.buf, pair.addr) } }() wg.Wait() } -func (s *serverUDPListener) port() int { - return s.pc.LocalAddr().(*net.UDPAddr).Port +func (u *serverUDPListener) port() int { + return u.pc.LocalAddr().(*net.UDPAddr).Port } -func (s *serverUDPListener) write(buf []byte, addr *net.UDPAddr) { - s.ringBuffer.Push(bufAddrPair{buf, addr}) +func (u *serverUDPListener) write(buf []byte, addr *net.UDPAddr) { + u.ringBuffer.Push(bufAddrPair{buf, addr}) } -func (s *serverUDPListener) addClient(ip net.IP, port int, sc *ServerConn, trackID int, isPublishing bool) { - s.clientsMutex.Lock() - defer s.clientsMutex.Unlock() +func (u *serverUDPListener) addClient(ip net.IP, port int, sc *ServerConn, trackID int, isPublishing bool) { + u.clientsMutex.Lock() + defer u.clientsMutex.Unlock() var addr clientAddr addr.fill(ip, port) - s.clients[addr] = &clientData{ + u.clients[addr] = &clientData{ sc: sc, trackID: trackID, isPublishing: isPublishing, } } -func (s *serverUDPListener) removeClient(ip net.IP, port int) { - s.clientsMutex.Lock() - defer s.clientsMutex.Unlock() +func (u *serverUDPListener) removeClient(ip net.IP, port int) { + u.clientsMutex.Lock() + defer u.clientsMutex.Unlock() var addr clientAddr addr.fill(ip, port) - delete(s.clients, addr) + delete(u.clients, addr) }