From ea1c7c69becc276a79beb1300dc184da9d44c60c Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Fri, 18 Jun 2021 17:34:25 +0200 Subject: [PATCH] server: support receiving RTCP packets from multicast clients --- server_read_test.go | 66 +++++++++++++++++++++++++++++++++++---------- serversession.go | 17 ++++++------ serverstream.go | 9 +++++++ 3 files changed, 70 insertions(+), 22 deletions(-) diff --git a/server_read_test.go b/server_read_test.go index a41a3942..4885def9 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "net" "strconv" + "sync/atomic" "testing" "time" @@ -16,6 +17,32 @@ import ( "github.com/aler9/gortsplib/pkg/headers" ) +func multicastCapableIP(t *testing.T) string { + intfs, err := net.Interfaces() + require.NoError(t, err) + + for _, intf := range intfs { + if (intf.Flags & net.FlagMulticast) != 0 { + addrs, err := intf.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + switch v := addr.(type) { + case *net.IPNet: + return v.IP.String() + case *net.IPAddr: + return v.IP.String() + } + } + } + } + + t.Errorf("unable to find a multicast IP") + return "" +} + func TestServerReadSetupPath(t *testing.T) { for _, ca := range []struct { name string @@ -278,6 +305,8 @@ func TestServerRead(t *testing.T) { stream := NewServerStream(Tracks{track}) + counter := uint64(0) + s := &Server{ Handler: &testServerHandler{ onConnOpen: func(ctx *ServerHandlerOnConnOpenCtx) { @@ -309,6 +338,11 @@ func TestServerRead(t *testing.T) { }, nil }, onFrame: func(ctx *ServerHandlerOnFrameCtx) { + // skip multicast loopback + if proto == "multicast" && atomic.AddUint64(&counter, 1) <= 1 { + return + } + require.Equal(t, 0, ctx.TrackID) require.Equal(t, StreamTypeRTCP, ctx.StreamType) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload) @@ -333,11 +367,12 @@ func TestServerRead(t *testing.T) { s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} } - err = s.Start("localhost:8554") + listenIP := multicastCapableIP(t) + err = s.Start(listenIP + ":8554") require.NoError(t, err) defer s.Close() - nconn, err := net.Dial("tcp", "localhost:8554") + nconn, err := net.Dial("tcp", listenIP+":8554") require.NoError(t, err) conn := func() net.Conn { @@ -378,7 +413,7 @@ func TestServerRead(t *testing.T) { res, err := writeReqReadRes(bconn, base.Request{ Method: base.Setup, - URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream/trackID=0"), Header: base.Header{ "CSeq": base.HeaderValue{"1"}, "Transport": inTH.Write(), @@ -397,11 +432,11 @@ func TestServerRead(t *testing.T) { var l2 net.PacketConn switch proto { case "udp": - l1, err = net.ListenPacket("udp", "localhost:35466") + l1, err = net.ListenPacket("udp", listenIP+":35466") require.NoError(t, err) defer l1.Close() - l2, err = net.ListenPacket("udp", "localhost:35467") + l2, err = net.ListenPacket("udp", listenIP+":35467") require.NoError(t, err) defer l2.Close() @@ -437,7 +472,7 @@ func TestServerRead(t *testing.T) { res, err = writeReqReadRes(bconn, base.Request{ Method: base.Play, - URL: mustParseURL("rtsp://localhost:8554/teststream"), + URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Session": res.Header["Session"], @@ -453,14 +488,14 @@ func TestServerRead(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf[:n]) - buf = make([]byte, 2048) - // skip firewall opening if proto == "udp" { - _, _, err = l2.ReadFrom(buf) + buf := make([]byte, 2048) + _, _, err := l2.ReadFrom(buf) require.NoError(t, err) } + buf = make([]byte, 2048) n, _, err = l2.ReadFrom(buf) require.NoError(t, err) require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, buf[:n]) @@ -491,8 +526,11 @@ func TestServerRead(t *testing.T) { <-framesReceived case "multicast": - // sending RTCP with multicast is currently not supported - // since the source IP cannot be verified correctly + l2.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{ + IP: multicastIP, + Port: th.Ports[1], + }) + <-framesReceived default: err = base.InterleavedFrame{ @@ -508,7 +546,7 @@ func TestServerRead(t *testing.T) { // ping with OPTIONS res, err = writeReqReadRes(bconn, base.Request{ Method: base.Options, - URL: mustParseURL("rtsp://localhost:8554/teststream"), + URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"), Header: base.Header{ "CSeq": base.HeaderValue{"4"}, "Session": res.Header["Session"], @@ -520,7 +558,7 @@ func TestServerRead(t *testing.T) { // ping with GET_PARAMETER res, err = writeReqReadRes(bconn, base.Request{ Method: base.GetParameter, - URL: mustParseURL("rtsp://localhost:8554/teststream"), + URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"), Header: base.Header{ "CSeq": base.HeaderValue{"5"}, "Session": res.Header["Session"], @@ -532,7 +570,7 @@ func TestServerRead(t *testing.T) { res, err = writeReqReadRes(bconn, base.Request{ Method: base.Teardown, - URL: mustParseURL("rtsp://localhost:8554/teststream"), + URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"), Header: base.Header{ "CSeq": base.HeaderValue{"6"}, "Session": res.Header["Session"], diff --git a/serversession.go b/serversession.go index b1344f7e..8fadf556 100644 --- a/serversession.go +++ b/serversession.go @@ -788,15 +788,16 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.setuppedStream.readerSetActive(ss) - if *ss.setuppedProtocol == base.StreamProtocolUDP && - *ss.setuppedDelivery == base.StreamDeliveryUnicast { - // readers can send RTCP frames, they cannot sent RTP frames - for trackID, track := range ss.setuppedTracks { - sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false) + if *ss.setuppedProtocol == base.StreamProtocolUDP { + if *ss.setuppedDelivery == base.StreamDeliveryUnicast { + for trackID, track := range ss.setuppedTracks { + // readers can send RTCP frames + sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false) - // open the firewall by sending packets to the counterpart - ss.WriteFrame(trackID, StreamTypeRTCP, - []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) + // open the firewall by sending packets to the counterpart + ss.WriteFrame(trackID, StreamTypeRTCP, + []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) + } } return res, err diff --git a/serverstream.go b/serverstream.go index 68ee80de..475cfde8 100644 --- a/serverstream.go +++ b/serverstream.go @@ -157,6 +157,11 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) { if *ss.setuppedDelivery == base.StreamDeliveryUnicast { st.readersUnicast[ss] = struct{}{} + } else { + for trackID := range ss.setuppedTracks { + st.multicastListeners[trackID].rtcpListener.addClient( + ss.udpIP, st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false) + } } } @@ -166,6 +171,10 @@ func (st *ServerStream) readerSetInactive(ss *ServerSession) { if *ss.setuppedDelivery == base.StreamDeliveryUnicast { delete(st.readersUnicast, ss) + } else { + for trackID := range ss.setuppedTracks { + st.multicastListeners[trackID].rtcpListener.removeClient(ss) + } } }