simplify UDP configuration

This commit is contained in:
aler9
2021-03-06 09:46:24 +01:00
parent 964331cacd
commit 260af6e041
6 changed files with 96 additions and 110 deletions

View File

@@ -147,20 +147,10 @@ func handleConn(conn *gortsplib.ServerConn) {
} }
func main() { func main() {
// to publish or read UDP streams, two UDP listeners must be created
udpRTPListener, err := gortsplib.NewServerUDPListener(":8000")
if err != nil {
panic(err)
}
udpRTCPListener, err := gortsplib.NewServerUDPListener(":8001")
if err != nil {
panic(err)
}
// create configuration // create configuration
conf := gortsplib.ServerConf{ conf := gortsplib.ServerConf{
UDPRTPListener: udpRTPListener, UDPRTPAddress: ":8000",
UDPRTCPListener: udpRTCPListener, UDPRTCPAddress: ":8001",
} }
// create server // create server

View File

@@ -8,8 +8,10 @@ import (
// Server is a RTSP server. // Server is a RTSP server.
type Server struct { type Server struct {
conf ServerConf conf ServerConf
listener net.Listener tcpListener net.Listener
udpRTPListener *serverUDPListener
udpRTCPListener *serverUDPListener
} }
func newServer(conf ServerConf, address string) (*Server, error) { func newServer(conf ServerConf, address string) (*Server, error) {
@@ -29,28 +31,36 @@ func newServer(conf ServerConf, address string) (*Server, error) {
conf.Listen = net.Listen conf.Listen = net.Listen
} }
if conf.TLSConfig != nil && conf.UDPRTPListener != nil { if conf.TLSConfig != nil && conf.UDPRTPAddress != "" {
return nil, fmt.Errorf("TLS can't be used together with UDP") return nil, fmt.Errorf("TLS can't be used together with UDP")
} }
if (conf.UDPRTPListener != nil && conf.UDPRTCPListener == nil) || if (conf.UDPRTPAddress != "" && conf.UDPRTCPAddress == "") ||
(conf.UDPRTPListener == nil && conf.UDPRTCPListener != nil) { (conf.UDPRTPAddress == "" && conf.UDPRTCPAddress != "") {
return nil, fmt.Errorf("UDPRTPListener and UDPRTPListener must be used together") return nil, fmt.Errorf("UDPRTPAddress and UDPRTCPAddress must be used together")
}
if conf.UDPRTPListener != nil {
conf.UDPRTPListener.initialize(conf, StreamTypeRTP)
conf.UDPRTCPListener.initialize(conf, StreamTypeRTCP)
}
listener, err := conf.Listen("tcp", address)
if err != nil {
return nil, err
} }
s := &Server{ s := &Server{
conf: conf, conf: conf,
listener: listener, }
if conf.UDPRTPAddress != "" {
var err error
s.udpRTPListener, err = newServerUDPListener(conf, conf.UDPRTPAddress, StreamTypeRTP)
if err != nil {
return nil, err
}
s.udpRTCPListener, err = newServerUDPListener(conf, conf.UDPRTCPAddress, StreamTypeRTCP)
if err != nil {
return nil, err
}
}
var err error
s.tcpListener, err = conf.Listen("tcp", address)
if err != nil {
return nil, err
} }
return s, nil return s, nil
@@ -58,15 +68,25 @@ func newServer(conf ServerConf, address string) (*Server, error) {
// Close closes the server. // Close closes the server.
func (s *Server) Close() error { func (s *Server) Close() error {
return s.listener.Close() s.tcpListener.Close()
if s.udpRTPListener != nil {
s.udpRTPListener.close()
}
if s.udpRTCPListener != nil {
s.udpRTCPListener.close()
}
return nil
} }
// Accept accepts a connection. // Accept accepts a connection.
func (s *Server) Accept() (*ServerConn, error) { func (s *Server) Accept() (*ServerConn, error) {
nconn, err := s.listener.Accept() nconn, err := s.tcpListener.Accept()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newServerConn(s.conf, nconn), nil return newServerConn(s.conf, s.udpRTPListener, s.udpRTCPListener, nconn), nil
} }

View File

@@ -20,13 +20,13 @@ type ServerConf struct {
// a TLS configuration to accept TLS (RTSPS) connections. // a TLS configuration to accept TLS (RTSPS) connections.
TLSConfig *tls.Config TLSConfig *tls.Config
// a ServerUDPListener to send and receive UDP/RTP packets. // a port to send and receive UDP/RTP packets.
// If UDPRTPListener and UDPRTCPListener are not null, the server can accept and send UDP streams. // If UDPRTPAddress and UDPRTCPAddress are != "", the server can accept and send UDP streams.
UDPRTPListener *ServerUDPListener UDPRTPAddress string
// a ServerUDPListener to send and receive UDP/RTCP packets. // a port to send and receive UDP/RTCP packets.
// If UDPRTPListener and UDPRTCPListener are not null, the server can accept and send UDP streams. // If UDPRTPAddress and UDPRTCPAddress are != "", the server can accept and send UDP streams.
UDPRTCPListener *ServerUDPListener UDPRTCPAddress string
// timeout of read operations. // timeout of read operations.
// It defaults to 10 seconds // It defaults to 10 seconds

View File

@@ -18,40 +18,24 @@ import (
) )
type testServ struct { type testServ struct {
s *Server s *Server
udpRTPListener *ServerUDPListener wg sync.WaitGroup
udpRTCPListener *ServerUDPListener mutex sync.Mutex
wg sync.WaitGroup publisher *ServerConn
mutex sync.Mutex sdp []byte
publisher *ServerConn readers map[*ServerConn]struct{}
sdp []byte
readers map[*ServerConn]struct{}
} }
func newTestServ(tlsConf *tls.Config) (*testServ, error) { func newTestServ(tlsConf *tls.Config) (*testServ, error) {
var conf ServerConf var conf ServerConf
var udpRTPListener *ServerUDPListener
var udpRTCPListener *ServerUDPListener
if tlsConf != nil { if tlsConf != nil {
conf = ServerConf{ conf = ServerConf{
TLSConfig: tlsConf, TLSConfig: tlsConf,
} }
} else { } else {
var err error
udpRTPListener, err = NewServerUDPListener(":8000")
if err != nil {
return nil, err
}
udpRTCPListener, err = NewServerUDPListener(":8001")
if err != nil {
return nil, err
}
conf = ServerConf{ conf = ServerConf{
UDPRTPListener: udpRTPListener, UDPRTPAddress: ":8000",
UDPRTCPListener: udpRTCPListener, UDPRTCPAddress: ":8001",
} }
} }
@@ -61,10 +45,8 @@ func newTestServ(tlsConf *tls.Config) (*testServ, error) {
} }
ts := &testServ{ ts := &testServ{
s: s, s: s,
udpRTPListener: udpRTPListener, readers: make(map[*ServerConn]struct{}),
udpRTCPListener: udpRTCPListener,
readers: make(map[*ServerConn]struct{}),
} }
ts.wg.Add(1) ts.wg.Add(1)
@@ -76,12 +58,6 @@ func newTestServ(tlsConf *tls.Config) (*testServ, error) {
func (ts *testServ) close() { func (ts *testServ) close() {
ts.s.Close() ts.s.Close()
ts.wg.Wait() ts.wg.Wait()
if ts.udpRTPListener != nil {
ts.udpRTPListener.Close()
}
if ts.udpRTCPListener != nil {
ts.udpRTCPListener.Close()
}
} }
func (ts *testServ) run() { func (ts *testServ) run() {

View File

@@ -122,6 +122,8 @@ type ServerConnReadHandlers struct {
type ServerConn struct { type ServerConn struct {
conf ServerConf conf ServerConf
nconn net.Conn nconn net.Conn
udpRTPListener *serverUDPListener
udpRTCPListener *serverUDPListener
br *bufio.Reader br *bufio.Reader
bw *bufio.Writer bw *bufio.Writer
state ServerConnState state ServerConnState
@@ -147,7 +149,10 @@ type ServerConn struct {
terminate chan struct{} terminate chan struct{}
} }
func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn { func newServerConn(conf ServerConf,
udpRTPListener *serverUDPListener,
udpRTCPListener *serverUDPListener,
nconn net.Conn) *ServerConn {
conn := func() net.Conn { conn := func() net.Conn {
if conf.TLSConfig != nil { if conf.TLSConfig != nil {
return tls.Server(nconn, conf.TLSConfig) return tls.Server(nconn, conf.TLSConfig)
@@ -157,6 +162,8 @@ func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn {
return &ServerConn{ return &ServerConn{
conf: conf, conf: conf,
udpRTPListener: udpRTPListener,
udpRTCPListener: udpRTCPListener,
nconn: nconn, nconn: nconn,
br: bufio.NewReaderSize(conn, serverConnReadBufferSize), br: bufio.NewReaderSize(conn, serverConnReadBufferSize),
bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize), bw: bufio.NewWriterSize(conn, serverConnWriteBufferSize),
@@ -263,8 +270,8 @@ func (sc *ServerConn) frameModeEnable() {
} else { } else {
for trackID, track := range sc.setuppedTracks { for trackID, track := range sc.setuppedTracks {
sc.conf.UDPRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc) sc.udpRTPListener.addPublisher(sc.ip(), track.rtpPort, trackID, sc)
sc.conf.UDPRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc) sc.udpRTCPListener.addPublisher(sc.ip(), track.rtcpPort, trackID, sc)
// open the firewall by sending packets to the counterpart // open the firewall by sending packets to the counterpart
sc.WriteFrame(trackID, StreamTypeRTP, sc.WriteFrame(trackID, StreamTypeRTP,
@@ -303,8 +310,8 @@ func (sc *ServerConn) frameModeDisable() {
} else { } else {
for _, track := range sc.setuppedTracks { for _, track := range sc.setuppedTracks {
sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort) sc.udpRTPListener.removePublisher(sc.ip(), track.rtpPort)
sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort) sc.udpRTCPListener.removePublisher(sc.ip(), track.rtcpPort)
} }
} }
} }
@@ -550,7 +557,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
} }
if th.Protocol == StreamProtocolUDP { if th.Protocol == StreamProtocolUDP {
if sc.conf.UDPRTPListener == nil { if sc.udpRTPListener == nil {
return &base.Response{ return &base.Response{
StatusCode: base.StatusUnsupportedTransport, StatusCode: base.StatusUnsupportedTransport,
}, nil }, nil
@@ -600,7 +607,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
return &v return &v
}(), }(),
ClientPorts: th.ClientPorts, ClientPorts: th.ClientPorts,
ServerPorts: &[2]int{sc.conf.UDPRTPListener.port(), sc.conf.UDPRTCPListener.port()}, ServerPorts: &[2]int{sc.udpRTPListener.port(), sc.udpRTCPListener.port()},
}.Write() }.Write()
} else { } else {
@@ -903,7 +910,7 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b
track := sc.setuppedTracks[trackID] track := sc.setuppedTracks[trackID]
if streamType == StreamTypeRTP { if streamType == StreamTypeRTP {
sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{ sc.udpRTPListener.write(payload, &net.UDPAddr{
IP: sc.ip(), IP: sc.ip(),
Zone: sc.zone(), Zone: sc.zone(),
Port: track.rtpPort, Port: track.rtpPort,
@@ -911,7 +918,7 @@ func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []b
return return
} }
sc.conf.UDPRTCPListener.write(payload, &net.UDPAddr{ sc.udpRTCPListener.write(payload, &net.UDPAddr{
IP: sc.ip(), IP: sc.ip(),
Zone: sc.zone(), Zone: sc.zone(),
Port: track.rtcpPort, Port: track.rtcpPort,

View File

@@ -40,10 +40,8 @@ func (p *publisherAddr) fill(ip net.IP, port int) {
} }
} }
// ServerUDPListener is a UDP server that can be used to send and receive RTP and RTCP packets. type serverUDPListener struct {
type ServerUDPListener struct {
pc *net.UDPConn pc *net.UDPConn
initialized bool
streamType StreamType streamType StreamType
writeTimeout time.Duration writeTimeout time.Duration
readBuf *multibuffer.MultiBuffer readBuf *multibuffer.MultiBuffer
@@ -55,8 +53,11 @@ type ServerUDPListener struct {
done chan struct{} done chan struct{}
} }
// NewServerUDPListener allocates a ServerUDPListener. func newServerUDPListener(
func NewServerUDPListener(address string) (*ServerUDPListener, error) { conf ServerConf,
address string,
streamType StreamType) (*serverUDPListener, error) {
tmp, err := net.ListenPacket("udp", address) tmp, err := net.ListenPacket("udp", address)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -68,37 +69,29 @@ func NewServerUDPListener(address string) (*ServerUDPListener, error) {
return nil, err return nil, err
} }
return &ServerUDPListener{ s := &serverUDPListener{
pc: pc, pc: pc,
publishers: make(map[publisherAddr]*publisherData), publishers: make(map[publisherAddr]*publisherData),
done: make(chan struct{}), done: make(chan struct{}),
}, nil
}
// Close closes the listener.
func (s *ServerUDPListener) Close() {
s.pc.Close()
if s.initialized {
s.ringBuffer.Close()
<-s.done
}
}
func (s *ServerUDPListener) initialize(conf ServerConf, streamType StreamType) {
if s.initialized {
return
} }
s.initialized = true
s.streamType = streamType s.streamType = streamType
s.writeTimeout = conf.WriteTimeout s.writeTimeout = conf.WriteTimeout
s.readBuf = multibuffer.New(uint64(conf.ReadBufferCount), uint64(conf.ReadBufferSize)) s.readBuf = multibuffer.New(uint64(conf.ReadBufferCount), uint64(conf.ReadBufferSize))
s.ringBuffer = ringbuffer.New(uint64(conf.ReadBufferCount)) s.ringBuffer = ringbuffer.New(uint64(conf.ReadBufferCount))
go s.run() go s.run()
return s, nil
} }
func (s *ServerUDPListener) run() { func (s *serverUDPListener) close() {
s.pc.Close()
s.ringBuffer.Close()
<-s.done
}
func (s *serverUDPListener) run() {
defer close(s.done) defer close(s.done)
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -153,15 +146,15 @@ func (s *ServerUDPListener) run() {
wg.Wait() wg.Wait()
} }
func (s *ServerUDPListener) port() int { func (s *serverUDPListener) port() int {
return s.pc.LocalAddr().(*net.UDPAddr).Port return s.pc.LocalAddr().(*net.UDPAddr).Port
} }
func (s *ServerUDPListener) write(buf []byte, addr *net.UDPAddr) { func (s *serverUDPListener) write(buf []byte, addr *net.UDPAddr) {
s.ringBuffer.Push(bufAddrPair{buf, addr}) s.ringBuffer.Push(bufAddrPair{buf, addr})
} }
func (s *ServerUDPListener) addPublisher(ip net.IP, port int, trackID int, sc *ServerConn) { func (s *serverUDPListener) addPublisher(ip net.IP, port int, trackID int, sc *ServerConn) {
s.publishersMutex.Lock() s.publishersMutex.Lock()
defer s.publishersMutex.Unlock() defer s.publishersMutex.Unlock()
@@ -174,7 +167,7 @@ func (s *ServerUDPListener) addPublisher(ip net.IP, port int, trackID int, sc *S
} }
} }
func (s *ServerUDPListener) removePublisher(ip net.IP, port int) { func (s *serverUDPListener) removePublisher(ip net.IP, port int) {
s.publishersMutex.Lock() s.publishersMutex.Lock()
defer s.publishersMutex.Unlock() defer s.publishersMutex.Unlock()