diff --git a/examples/server-udp/main.go b/examples/server-udp/main.go index 466f347e..91d61c54 100644 --- a/examples/server-udp/main.go +++ b/examples/server-udp/main.go @@ -147,20 +147,10 @@ func handleConn(conn *gortsplib.ServerConn) { } func main() { - // to publish or read UDP streams, two UDP listeners must be created - udpRTPListener, err := gortsplib.NewServerUDPListener(":8000") - if err != nil { - panic(err) - } - udpRTCPListener, err := gortsplib.NewServerUDPListener(":8001") - if err != nil { - panic(err) - } - // create configuration conf := gortsplib.ServerConf{ - UDPRTPListener: udpRTPListener, - UDPRTCPListener: udpRTCPListener, + UDPRTPAddress: ":8000", + UDPRTCPAddress: ":8001", } // create server diff --git a/server.go b/server.go index 4ca83217..cc18bfbe 100644 --- a/server.go +++ b/server.go @@ -8,8 +8,10 @@ import ( // Server is a RTSP server. type Server struct { - conf ServerConf - listener net.Listener + conf ServerConf + tcpListener net.Listener + udpRTPListener *serverUDPListener + udpRTCPListener *serverUDPListener } func newServer(conf ServerConf, address string) (*Server, error) { @@ -29,28 +31,36 @@ func newServer(conf ServerConf, address string) (*Server, error) { conf.Listen = net.Listen } - if conf.TLSConfig != nil && conf.UDPRTPListener != nil { + if conf.TLSConfig != nil && conf.UDPRTPAddress != "" { return nil, fmt.Errorf("TLS can't be used together with UDP") } - if (conf.UDPRTPListener != nil && conf.UDPRTCPListener == nil) || - (conf.UDPRTPListener == nil && conf.UDPRTCPListener != nil) { - return nil, fmt.Errorf("UDPRTPListener and UDPRTPListener must be used together") - } - - if conf.UDPRTPListener != nil { - conf.UDPRTPListener.initialize(conf, StreamTypeRTP) - conf.UDPRTCPListener.initialize(conf, StreamTypeRTCP) - } - - listener, err := conf.Listen("tcp", address) - if err != nil { - return nil, err + if (conf.UDPRTPAddress != "" && conf.UDPRTCPAddress == "") || + (conf.UDPRTPAddress == "" && conf.UDPRTCPAddress != "") { + return nil, fmt.Errorf("UDPRTPAddress and UDPRTCPAddress must be used together") } s := &Server{ - conf: conf, - listener: listener, + conf: conf, + } + + if conf.UDPRTPAddress != "" { + var err error + s.udpRTPListener, err = newServerUDPListener(conf, conf.UDPRTPAddress, StreamTypeRTP) + if err != nil { + return nil, err + } + + s.udpRTCPListener, err = newServerUDPListener(conf, conf.UDPRTCPAddress, StreamTypeRTCP) + if err != nil { + return nil, err + } + } + + var err error + s.tcpListener, err = conf.Listen("tcp", address) + if err != nil { + return nil, err } return s, nil @@ -58,15 +68,25 @@ func newServer(conf ServerConf, address string) (*Server, error) { // Close closes the server. func (s *Server) Close() error { - return s.listener.Close() + s.tcpListener.Close() + + if s.udpRTPListener != nil { + s.udpRTPListener.close() + } + + if s.udpRTCPListener != nil { + s.udpRTCPListener.close() + } + + return nil } // Accept accepts a connection. func (s *Server) Accept() (*ServerConn, error) { - nconn, err := s.listener.Accept() + nconn, err := s.tcpListener.Accept() if err != nil { return nil, err } - return newServerConn(s.conf, nconn), nil + return newServerConn(s.conf, s.udpRTPListener, s.udpRTCPListener, nconn), nil } diff --git a/serverconf.go b/serverconf.go index 3d4a6c11..f528006e 100644 --- a/serverconf.go +++ b/serverconf.go @@ -20,13 +20,13 @@ type ServerConf struct { // a TLS configuration to accept TLS (RTSPS) connections. TLSConfig *tls.Config - // a ServerUDPListener to send and receive UDP/RTP packets. - // If UDPRTPListener and UDPRTCPListener are not null, the server can accept and send UDP streams. - UDPRTPListener *ServerUDPListener + // a port to send and receive UDP/RTP packets. + // If UDPRTPAddress and UDPRTCPAddress are != "", the server can accept and send UDP streams. + UDPRTPAddress string - // a ServerUDPListener to send and receive UDP/RTCP packets. - // If UDPRTPListener and UDPRTCPListener are not null, the server can accept and send UDP streams. - UDPRTCPListener *ServerUDPListener + // a port to send and receive UDP/RTCP packets. + // If UDPRTPAddress and UDPRTCPAddress are != "", the server can accept and send UDP streams. + UDPRTCPAddress string // timeout of read operations. // It defaults to 10 seconds diff --git a/serverconf_test.go b/serverconf_test.go index 6e119327..b84cea5f 100644 --- a/serverconf_test.go +++ b/serverconf_test.go @@ -18,40 +18,24 @@ import ( ) type testServ struct { - s *Server - udpRTPListener *ServerUDPListener - udpRTCPListener *ServerUDPListener - wg sync.WaitGroup - mutex sync.Mutex - publisher *ServerConn - sdp []byte - readers map[*ServerConn]struct{} + s *Server + wg sync.WaitGroup + mutex sync.Mutex + publisher *ServerConn + sdp []byte + readers map[*ServerConn]struct{} } func newTestServ(tlsConf *tls.Config) (*testServ, error) { var conf ServerConf - var udpRTPListener *ServerUDPListener - var udpRTCPListener *ServerUDPListener if tlsConf != nil { conf = ServerConf{ TLSConfig: tlsConf, } - } else { - var err error - udpRTPListener, err = NewServerUDPListener(":8000") - if err != nil { - return nil, err - } - - udpRTCPListener, err = NewServerUDPListener(":8001") - if err != nil { - return nil, err - } - conf = ServerConf{ - UDPRTPListener: udpRTPListener, - UDPRTCPListener: udpRTCPListener, + UDPRTPAddress: ":8000", + UDPRTCPAddress: ":8001", } } @@ -61,10 +45,8 @@ func newTestServ(tlsConf *tls.Config) (*testServ, error) { } ts := &testServ{ - s: s, - udpRTPListener: udpRTPListener, - udpRTCPListener: udpRTCPListener, - readers: make(map[*ServerConn]struct{}), + s: s, + readers: make(map[*ServerConn]struct{}), } ts.wg.Add(1) @@ -76,12 +58,6 @@ func newTestServ(tlsConf *tls.Config) (*testServ, error) { func (ts *testServ) close() { ts.s.Close() ts.wg.Wait() - if ts.udpRTPListener != nil { - ts.udpRTPListener.Close() - } - if ts.udpRTCPListener != nil { - ts.udpRTCPListener.Close() - } } func (ts *testServ) run() { diff --git a/serverconn.go b/serverconn.go index 7e1edc35..cd9ed048 100644 --- a/serverconn.go +++ b/serverconn.go @@ -122,6 +122,8 @@ type ServerConnReadHandlers struct { type ServerConn struct { conf ServerConf nconn net.Conn + udpRTPListener *serverUDPListener + udpRTCPListener *serverUDPListener br *bufio.Reader bw *bufio.Writer state ServerConnState @@ -147,7 +149,10 @@ type ServerConn struct { terminate chan struct{} } -func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn { +func newServerConn(conf ServerConf, + udpRTPListener *serverUDPListener, + udpRTCPListener *serverUDPListener, + nconn net.Conn) *ServerConn { conn := func() net.Conn { if conf.TLSConfig != nil { return tls.Server(nconn, conf.TLSConfig) @@ -157,6 +162,8 @@ func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn { return &ServerConn{ conf: conf, + udpRTPListener: udpRTPListener, + udpRTCPListener: udpRTCPListener, nconn: nconn, br: bufio.NewReaderSize(conn, serverConnReadBufferSize), bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize), @@ -263,8 +270,8 @@ func (sc *ServerConn) frameModeEnable() { } else { for trackID, track := range sc.setuppedTracks { - sc.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc) - sc.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc) + sc.udpRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc) + sc.udpRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc) // open the firewall by sending packets to the counterpart sc.WriteFrame(trackID, StreamTypeRTP, @@ -303,8 +310,8 @@ func (sc *ServerConn) frameModeDisable() { } else { for _, track := range sc.setuppedTracks { - sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort) - sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort) + sc.udpRTPListener.removePublisher(sc.ip(), track.rtpPort) + sc.udpRTCPListener.removePublisher(sc.ip(), track.rtcpPort) } } } @@ -550,7 +557,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } if th.Protocol == StreamProtocolUDP { - if sc.conf.UDPRTPListener == nil { + if sc.udpRTPListener == nil { return &base.Response{ StatusCode: base.StatusUnsupportedTransport, }, nil @@ -600,7 +607,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { return &v }(), ClientPorts: th.ClientPorts, - ServerPorts: &[2]int{sc.conf.UDPRTPListener.port(), sc.conf.UDPRTCPListener.port()}, + ServerPorts: &[2]int{sc.udpRTPListener.port(), sc.udpRTCPListener.port()}, }.Write() } else { @@ -903,7 +910,7 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b track := sc.setuppedTracks[trackID] if streamType == StreamTypeRTP { - sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{ + sc.udpRTPListener.write(payload, &net.UDPAddr{ IP: sc.ip(), Zone: sc.zone(), Port: track.rtpPort, @@ -911,7 +918,7 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b return } - sc.conf.UDPRTCPListener.write(payload, &net.UDPAddr{ + sc.udpRTCPListener.write(payload, &net.UDPAddr{ IP: sc.ip(), Zone: sc.zone(), Port: track.rtcpPort, diff --git a/serverudpl.go b/serverudpl.go index a95f571e..204d2768 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -40,10 +40,8 @@ func (p *publisherAddr) fill(ip net.IP, port int) { } } -// ServerUDPListener is a UDP server that can be used to send and receive RTP and RTCP packets. -type ServerUDPListener struct { +type serverUDPListener struct { pc *net.UDPConn - initialized bool streamType StreamType writeTimeout time.Duration readBuf *multibuffer.MultiBuffer @@ -55,8 +53,11 @@ type ServerUDPListener struct { done chan struct{} } -// NewServerUDPListener allocates a ServerUDPListener. -func NewServerUDPListener(address string) (*ServerUDPListener, error) { +func newServerUDPListener( + conf ServerConf, + address string, + streamType StreamType) (*serverUDPListener, error) { + tmp, err := net.ListenPacket("udp", address) if err != nil { return nil, err @@ -68,37 +69,29 @@ func NewServerUDPListener(address string) (*ServerUDPListener, error) { return nil, err } - return &ServerUDPListener{ + s := &serverUDPListener{ pc: pc, publishers: make(map[publisherAddr]*publisherData), done: make(chan struct{}), - }, nil -} - -// Close closes the listener. -func (s *ServerUDPListener) Close() { - s.pc.Close() - - if s.initialized { - s.ringBuffer.Close() - <-s.done - } -} - -func (s *ServerUDPListener) initialize(conf ServerConf, streamType StreamType) { - if s.initialized { - return } - s.initialized = true s.streamType = streamType s.writeTimeout = conf.WriteTimeout s.readBuf = multibuffer.New(uint64(conf.ReadBufferCount), uint64(conf.ReadBufferSize)) s.ringBuffer = ringbuffer.New(uint64(conf.ReadBufferCount)) + go s.run() + + return s, nil } -func (s *ServerUDPListener) run() { +func (s *serverUDPListener) close() { + s.pc.Close() + s.ringBuffer.Close() + <-s.done +} + +func (s *serverUDPListener) run() { defer close(s.done) var wg sync.WaitGroup @@ -153,15 +146,15 @@ func (s *ServerUDPListener) run() { wg.Wait() } -func (s *ServerUDPListener) port() int { +func (s *serverUDPListener) port() int { return s.pc.LocalAddr().(*net.UDPAddr).Port } -func (s *ServerUDPListener) write(buf []byte, addr *net.UDPAddr) { +func (s *serverUDPListener) write(buf []byte, addr *net.UDPAddr) { s.ringBuffer.Push(bufAddrPair{buf, addr}) } -func (s *ServerUDPListener) addPublisher(ip net.IP, port int, trackID int, sc *ServerConn) { +func (s *serverUDPListener) addPublisher(ip net.IP, port int, trackID int, sc *ServerConn) { s.publishersMutex.Lock() defer s.publishersMutex.Unlock() @@ -174,7 +167,7 @@ func (s *ServerUDPListener) addPublisher(ip net.IP, port int, trackID int, sc *S } } -func (s *ServerUDPListener) removePublisher(ip net.IP, port int) { +func (s *serverUDPListener) removePublisher(ip net.IP, port int) { s.publishersMutex.Lock() defer s.publishersMutex.Unlock()