diff --git a/client_read_test.go b/client_read_test.go index 96678783..3bf37cea 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -272,7 +272,7 @@ func TestClientRead(t *testing.T) { v := base.StreamDeliveryMulticast th.Delivery = &v th.Protocol = base.StreamProtocolUDP - v2 := multicastIP.String() + v2 := "224.1.0.1" th.Destination = &v2 th.Ports = &[2]int{25000, 25001} @@ -286,7 +286,7 @@ func TestClientRead(t *testing.T) { require.NoError(t, err) for _, intf := range intfs { - err := p.JoinGroup(&intf, &net.UDPAddr{IP: multicastIP}) + err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")}) require.NoError(t, err) } @@ -300,7 +300,7 @@ func TestClientRead(t *testing.T) { require.NoError(t, err) for _, intf := range intfs { - err := p.JoinGroup(&intf, &net.UDPAddr{IP: multicastIP}) + err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")}) require.NoError(t, err) } @@ -341,7 +341,7 @@ func TestClientRead(t *testing.T) { case "multicast": time.Sleep(1 * time.Second) l1.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ - IP: multicastIP, + IP: net.ParseIP("224.1.0.1"), Port: 25000, }) diff --git a/examples/server/main.go b/examples/server/main.go index 21920f2b..cee0de71 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -137,9 +137,12 @@ func (sh *serverHandler) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) { func main() { // configure server s := &gortsplib.Server{ - Handler: &serverHandler{}, - UDPRTPAddress: ":8000", - UDPRTCPAddress: ":8001", + Handler: &serverHandler{}, + UDPRTPAddress: ":8000", + UDPRTCPAddress: ":8001", + MulticastIPRange: "244.1.0.0/16", + MulticastRTPPort: 8002, + MulticastRTCPPort: 8003, } // start server and wait until a fatal error diff --git a/server.go b/server.go index a61b9dce..2c648bbf 100644 --- a/server.go +++ b/server.go @@ -45,18 +45,22 @@ func newSessionID(sessions map[string]*ServerSession) (string, error) { } } -type requestRes struct { +type sessionRequestRes struct { ss *ServerSession res *base.Response err error } -type request struct { +type sessionRequestReq struct { sc *ServerConn req *base.Request id string create bool - res chan requestRes + res chan sessionRequestRes +} + +type streamMulticastIPReq struct { + res chan net.IP } // Server is a RTSP server. @@ -68,7 +72,7 @@ type Server struct { Handler ServerHandler // - // connection + // RTSP parameters // // timeout of read operations. // It defaults to 10 seconds @@ -78,16 +82,21 @@ type Server struct { WriteTimeout time.Duration // a TLS configuration to accept TLS (RTSPS) connections. TLSConfig *tls.Config - // a port to send and receive UDP/RTP packets. - // If UDPRTPAddress and UDPRTCPAddress are != "", the server can accept and send UDP streams. + // a port to send and receive RTP packets with UDP. + // If UDPRTPAddress and UDPRTCPAddress are filled, the server can read and write UDP streams. UDPRTPAddress string - // a port to send and receive UDP/RTCP packets. - // If UDPRTPAddress and UDPRTCPAddress are != "", the server can accept and send UDP streams. + // a port to send and receive RTCP packets with UDP. + // If UDPRTPAddress and UDPRTCPAddress are filled, the server can read and write UDP streams. UDPRTCPAddress string - - // - // reading / writing - // + // a range of multicast IPs to use. + // If MulticastIPRange, MulticastRTPPort, MulticastRTCPPort are filled, the server can read and write UDP-multicast streams. + MulticastIPRange string + // a port to send RTP packets with UDP-multicast. + // If MulticastIPRange, MulticastRTPPort, MulticastRTCPPort are filled, the server can read and write UDP-multicast streams. + MulticastRTPPort uint + // a port to send RTCP packets with UDP-multicast. + // If MulticastIPRange, MulticastRTPPort, MulticastRTCPPort are filled, the server can read and write UDP-multicast streams. + MulticastRTCPPort uint // read buffer count. // If greater than 1, allows to pass buffers to routines different than the one // that is reading frames. @@ -120,6 +129,8 @@ type Server struct { ctx context.Context ctxCancel func() wg sync.WaitGroup + multicastNet *net.IPNet + multicastNextIP net.IP tcpListener net.Listener udpRTPListener *serverUDPListener udpRTCPListener *serverUDPListener @@ -129,24 +140,23 @@ type Server struct { streams map[*ServerStream]struct{} // in - connClose chan *ServerConn - sessionRequest chan request - sessionClose chan *ServerSession - streamAdd chan *ServerStream - streamRemove chan *ServerStream + connClose chan *ServerConn + sessionRequest chan sessionRequestReq + sessionClose chan *ServerSession + streamAdd chan *ServerStream + streamRemove chan *ServerStream + streamMulticastIP chan streamMulticastIPReq } // Start starts listening on the given address. func (s *Server) Start(address string) error { - // connection + // RTSP parameters if s.ReadTimeout == 0 { s.ReadTimeout = 10 * time.Second } if s.WriteTimeout == 0 { s.WriteTimeout = 10 * time.Second } - - // reading / writing if s.ReadBufferCount == 0 { s.ReadBufferCount = 512 } @@ -210,11 +220,63 @@ func (s *Server) Start(address string) error { } } + if s.MulticastIPRange != "" && (s.MulticastRTPPort == 0 || s.MulticastRTCPPort == 0) || + (s.MulticastRTPPort != 0 && (s.MulticastRTCPPort == 0 || s.MulticastIPRange == "")) || + s.MulticastRTCPPort != 0 && (s.MulticastRTPPort == 0 || s.MulticastIPRange == "") { + if s.udpRTPListener != nil { + s.udpRTPListener.close() + } + if s.udpRTCPListener != nil { + s.udpRTCPListener.close() + } + return fmt.Errorf("MulticastIPRange, MulticastRTPPort and MulticastRTCPPort must be used together") + } + + if s.MulticastIPRange != "" { + if (s.MulticastRTPPort % 2) != 0 { + if s.udpRTPListener != nil { + s.udpRTPListener.close() + } + if s.udpRTCPListener != nil { + s.udpRTCPListener.close() + } + return fmt.Errorf("RTP port must be even") + } + + if s.MulticastRTCPPort != (s.MulticastRTPPort + 1) { + if s.udpRTPListener != nil { + s.udpRTPListener.close() + } + if s.udpRTCPListener != nil { + s.udpRTCPListener.close() + } + return fmt.Errorf("RTCP and RTP ports must be consecutive") + } + + var err error + _, s.multicastNet, err = net.ParseCIDR(s.MulticastIPRange) + if err != nil { + if s.udpRTPListener != nil { + s.udpRTPListener.close() + } + if s.udpRTCPListener != nil { + s.udpRTCPListener.close() + } + return err + } + + s.multicastNextIP = s.multicastNet.IP + } + var err error s.tcpListener, err = s.Listen("tcp", address) if err != nil { - s.udpRTPListener.close() - s.udpRTPListener.close() + if s.udpRTPListener != nil { + s.udpRTPListener.close() + } + if s.udpRTCPListener != nil { + s.udpRTCPListener.close() + } return err } @@ -246,10 +308,11 @@ func (s *Server) run() { s.conns = make(map[*ServerConn]struct{}) s.streams = make(map[*ServerStream]struct{}) s.connClose = make(chan *ServerConn) - s.sessionRequest = make(chan request) + s.sessionRequest = make(chan sessionRequestReq) s.sessionClose = make(chan *ServerSession) s.streamAdd = make(chan *ServerStream) s.streamRemove = make(chan *ServerStream) + s.streamMulticastIP = make(chan streamMulticastIPReq) s.wg.Add(1) connNew := make(chan net.Conn) @@ -300,7 +363,7 @@ outer: ss.request <- req } else { if !req.create { - req.res <- requestRes{ + req.res <- sessionRequestRes{ res: &base.Response{ StatusCode: base.StatusBadRequest, }, @@ -311,7 +374,7 @@ outer: id, err := newSessionID(s.sessions) if err != nil { - req.res <- requestRes{ + req.res <- sessionRequestRes{ res: &base.Response{ StatusCode: base.StatusBadRequest, }, @@ -326,7 +389,7 @@ outer: select { case ss.request <- req: case <-ss.ctx.Done(): - req.res <- requestRes{ + req.res <- sessionRequestRes{ res: &base.Response{ StatusCode: base.StatusBadRequest, }, @@ -348,6 +411,15 @@ outer: case st := <-s.streamRemove: delete(s.streams, st) + case req := <-s.streamMulticastIP: + ip32 := binary.BigEndian.Uint32(s.multicastNextIP) + mask := binary.BigEndian.Uint32(s.multicastNet.Mask) + ip32 = (ip32 & mask) | ((ip32 + 1) & ^mask) + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, ip32) + s.multicastNextIP = ip + req.res <- ip + case <-s.ctx.Done(): break outer } diff --git a/server_read_test.go b/server_read_test.go index e333bc7c..4e71c5fa 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -357,10 +357,15 @@ func TestServerRead(t *testing.T) { } switch proto { - case "udp", "multicast": + case "udp": s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTCPAddress = "127.0.0.1:8001" + case "multicast": + s.MulticastIPRange = "224.1.0.0/16" + s.MulticastRTPPort = 8000 + s.MulticastRTCPPort = 8001 + case "tls": cert, err := tls.X509KeyPair(serverCert, serverKey) require.NoError(t, err) @@ -451,7 +456,7 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) for _, intf := range intfs { - err := p.JoinGroup(&intf, &net.UDPAddr{IP: multicastIP}) + err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(*th.Destination)}) require.NoError(t, err) } @@ -465,7 +470,7 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) for _, intf := range intfs { - err := p.JoinGroup(&intf, &net.UDPAddr{IP: multicastIP}) + err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(*th.Destination)}) require.NoError(t, err) } } @@ -527,7 +532,7 @@ func TestServerRead(t *testing.T) { case "multicast": l2.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ - IP: multicastIP, + IP: net.ParseIP(*th.Destination), Port: th.Ports[1], }) <-framesReceived diff --git a/server_test.go b/server_test.go index e993367e..f5bb90d5 100644 --- a/server_test.go +++ b/server_test.go @@ -389,6 +389,9 @@ func TestServerHighLevelPublishRead(t *testing.T) { proto = "rtsp" s.UDPRTPAddress = "127.0.0.1:8000" s.UDPRTCPAddress = "127.0.0.1:8001" + s.MulticastIPRange = "224.1.0.0/16" + s.MulticastRTPPort = 8002 + s.MulticastRTCPPort = 8003 } err := s.Start("localhost:8554") diff --git a/serverconn.go b/serverconn.go index 0ffda1b8..f484372e 100644 --- a/serverconn.go +++ b/serverconn.go @@ -543,8 +543,8 @@ func (sc *ServerConn) handleRequestInSession( // if the session is already linked to this conn, communicate directly with it if sxID != "" { if ss, ok := sc.sessions[sxID]; ok { - cres := make(chan requestRes) - sreq := request{ + cres := make(chan sessionRequestRes) + sreq := sessionRequestReq{ sc: sc, req: req, id: sxID, @@ -566,8 +566,8 @@ func (sc *ServerConn) handleRequestInSession( } // otherwise, pass through Server - cres := make(chan requestRes) - sreq := request{ + cres := make(chan sessionRequestRes) + sreq := sessionRequestReq{ sc: sc, req: req, id: sxID, diff --git a/serversession.go b/serversession.go index 71b7764d..7de85539 100644 --- a/serversession.go +++ b/serversession.go @@ -137,7 +137,7 @@ type ServerSession struct { udpLastFrameTime *int64 // publish, udp // in - request chan request + request chan sessionRequestReq connRemove chan *ServerConn } @@ -156,7 +156,7 @@ func newServerSession( ctxCancel: ctxCancel, conns: make(map[*ServerConn]struct{}), lastRequestTime: time.Now(), - request: make(chan request), + request: make(chan sessionRequestReq), connRemove: make(chan *ServerConn), } @@ -253,11 +253,11 @@ func (ss *ServerSession) run() { } if _, ok := err.(liberrors.ErrServerSessionTeardown); ok { - req.res <- requestRes{res: res, err: nil} + req.res <- sessionRequestRes{res: res, err: nil} return liberrors.ErrServerSessionTeardown{} } - req.res <- requestRes{ + req.res <- sessionRequestRes{ res: res, err: err, ss: ss, @@ -327,7 +327,8 @@ func (ss *ServerSession) run() { case ServerSessionStatePlay: ss.setuppedStream.readerSetInactive(ss) - if *ss.setuppedProtocol == base.StreamProtocolUDP { + if *ss.setuppedProtocol == base.StreamProtocolUDP && + *ss.setuppedDelivery == base.StreamDeliveryUnicast { ss.s.udpRTCPListener.removeClient(ss) } @@ -572,18 +573,23 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base } if inTH.Protocol == base.StreamProtocolUDP { - if ss.s.udpRTPListener == nil { + if delivery == base.StreamDeliveryUnicast { + if ss.s.udpRTPListener == nil { + return &base.Response{ + StatusCode: base.StatusUnsupportedTransport, + }, nil + } + + if inTH.ClientPorts == nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerTransportHeaderNoClientPorts{} + } + } else if ss.s.MulticastIPRange == "" { return &base.Response{ StatusCode: base.StatusUnsupportedTransport, }, nil } - - if delivery == base.StreamDeliveryUnicast && inTH.ClientPorts == nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerTransportHeaderNoClientPorts{} - } - } else { if delivery == base.StreamDeliveryMulticast { return &base.Response{ @@ -626,16 +632,22 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base }) if res.StatusCode == base.StatusOK { - th := headers.Transport{} - if ss.state == ServerSessionStateInitial { + err := stream.readerAdd(ss, delivery == base.StreamDeliveryMulticast) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err + } + ss.state = ServerSessionStatePrePlay ss.setuppedPath = &path ss.setuppedQuery = &query ss.setuppedStream = stream - stream.readerAdd(ss, delivery == base.StreamDeliveryMulticast) } + th := headers.Transport{} + if ss.state == ServerSessionStatePrePlay { ssrc := stream.ssrc(trackID) if ssrc != 0 { @@ -663,7 +675,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base th.Delivery = &de v := uint(127) th.TTL = &v - d := multicastIP.String() + d := stream.multicastListeners[trackID].rtpListener.ip().String() th.Destination = &d th.Ports = &[2]int{ stream.multicastListeners[trackID].rtpListener.port(), diff --git a/serverstream.go b/serverstream.go index 073a0c1b..44b66fb9 100644 --- a/serverstream.go +++ b/serverstream.go @@ -108,7 +108,7 @@ func (st *ServerStream) lastSequenceNumber(trackID int) uint16 { return uint16(atomic.LoadUint32(&st.trackInfos[trackID].lastSequenceNumber)) } -func (st *ServerStream) readerAdd(ss *ServerSession, isMulticast bool) { +func (st *ServerStream) readerAdd(ss *ServerSession, isMulticast bool) error { st.mutex.Lock() defer st.mutex.Unlock() @@ -123,18 +123,31 @@ func (st *ServerStream) readerAdd(ss *ServerSession, isMulticast bool) { } if !isMulticast || st.multicastListeners != nil { - return + return nil } st.multicastListeners = make([]*listenerPair, len(st.tracks)) for i := range st.tracks { - rtpListener, rtcpListener := newServerUDPListenerMulticastPair(st.s) + rtpListener, rtcpListener, err := newServerUDPListenerMulticastPair(st.s) + if err != nil { + for _, l := range st.multicastListeners { + if l != nil { + l.rtpListener.close() + l.rtcpListener.close() + } + } + st.multicastListeners = nil + return err + } + st.multicastListeners[i] = &listenerPair{ rtpListener: rtpListener, rtcpListener: rtcpListener, } } + + return nil } func (st *ServerStream) readerRemove(ss *ServerSession) { @@ -207,13 +220,13 @@ func (st *ServerStream) WriteFrame(trackID int, streamType StreamType, payload [ if st.multicastListeners != nil { if streamType == StreamTypeRTP { st.multicastListeners[trackID].rtpListener.write(payload, &net.UDPAddr{ - IP: multicastIP, + IP: st.multicastListeners[trackID].rtpListener.ip(), Zone: "", Port: st.multicastListeners[trackID].rtpListener.port(), }) } else { st.multicastListeners[trackID].rtcpListener.write(payload, &net.UDPAddr{ - IP: multicastIP, + IP: st.multicastListeners[trackID].rtpListener.ip(), Zone: "", Port: st.multicastListeners[trackID].rtcpListener.port(), }) diff --git a/serverudpl.go b/serverudpl.go index 0e6adf3b..49b090f1 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -2,7 +2,7 @@ package gortsplib import ( "context" - "math/rand" + "fmt" "net" "strconv" "sync" @@ -19,8 +19,6 @@ const ( serverConnUDPListenerKernelReadBufferSize = 0x80000 // same as gstreamer's rtspsrc ) -var multicastIP = net.ParseIP("239.0.0.0") - type bufAddrPair struct { buf []byte addr *net.UDPAddr @@ -55,6 +53,7 @@ type serverUDPListener struct { ctxCancel func() wg sync.WaitGroup pc *net.UDPConn + listenIP net.IP streamType StreamType writeTimeout time.Duration readBuf *multibuffer.MultiBuffer @@ -63,25 +62,29 @@ type serverUDPListener struct { ringBuffer *ringbuffer.RingBuffer } -func newServerUDPListenerMulticastPair(s *Server) (*serverUDPListener, *serverUDPListener) { - // choose two consecutive ports in range 65535-10000 - // rtp must be even and rtcp odd - for { - rtpPort := (rand.Intn((65535-10000)/2) * 2) + 10000 - rtpListener, err := newServerUDPListener(s, true, multicastIP.String()+":"+strconv.FormatInt(int64(rtpPort), 10), StreamTypeRTP) - if err != nil { - continue - } - - rtcpPort := rtpPort + 1 - rtcpListener, err := newServerUDPListener(s, true, multicastIP.String()+":"+strconv.FormatInt(int64(rtcpPort), 10), StreamTypeRTCP) - if err != nil { - rtpListener.close() - continue - } - - return rtpListener, rtcpListener +func newServerUDPListenerMulticastPair(s *Server) (*serverUDPListener, *serverUDPListener, error) { + res := make(chan net.IP) + select { + case s.streamMulticastIP <- streamMulticastIPReq{res: res}: + case <-s.ctx.Done(): + return nil, nil, fmt.Errorf("terminated") } + ip := <-res + + rtpListener, err := newServerUDPListener(s, true, + ip.String()+":"+strconv.FormatInt(int64(s.MulticastRTPPort), 10), StreamTypeRTP) + if err != nil { + return nil, nil, err + } + + rtcpListener, err := newServerUDPListener(s, true, + ip.String()+":"+strconv.FormatInt(int64(s.MulticastRTCPPort), 10), StreamTypeRTCP) + if err != nil { + rtpListener.close() + return nil, nil, err + } + + return rtpListener, rtcpListener, nil } func newServerUDPListener( @@ -90,6 +93,7 @@ func newServerUDPListener( address string, streamType StreamType) (*serverUDPListener, error) { var pc *net.UDPConn + var listenIP net.IP if multicast { host, port, err := net.SplitHostPort(address) if err != nil { @@ -113,8 +117,10 @@ func newServerUDPListener( return nil, err } + listenIP = net.ParseIP(host) + for _, intf := range intfs { - err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP(host)}) + err := p.JoinGroup(&intf, &net.UDPAddr{IP: listenIP}) if err != nil { return nil, err } @@ -126,7 +132,9 @@ func newServerUDPListener( if err != nil { return nil, err } + pc = tmp.(*net.UDPConn) + listenIP = tmp.LocalAddr().(*net.UDPAddr).IP } err := pc.SetReadBuffer(serverConnUDPListenerKernelReadBufferSize) @@ -141,6 +149,7 @@ func newServerUDPListener( ctx: ctx, ctxCancel: ctxCancel, pc: pc, + listenIP: listenIP, clients: make(map[clientAddr]*clientData), } @@ -160,6 +169,14 @@ func (u *serverUDPListener) close() { u.wg.Wait() } +func (u *serverUDPListener) ip() net.IP { + return u.listenIP +} + +func (u *serverUDPListener) port() int { + return u.pc.LocalAddr().(*net.UDPAddr).Port +} + func (u *serverUDPListener) run() { defer u.wg.Done() @@ -225,10 +242,6 @@ func (u *serverUDPListener) run() { u.ringBuffer.Close() } -func (u *serverUDPListener) port() int { - return u.pc.LocalAddr().(*net.UDPAddr).Port -} - func (u *serverUDPListener) write(buf []byte, addr *net.UDPAddr) { u.ringBuffer.Push(bufAddrPair{buf, addr}) }