server: support receiving RTCP packets from multicast clients

This commit is contained in:
aler9
2021-06-18 17:34:25 +02:00
parent d165f1fead
commit ea1c7c69be
3 changed files with 70 additions and 22 deletions

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls"
"net"
"strconv"
"sync/atomic"
"testing"
"time"
@@ -16,6 +17,32 @@ import (
"github.com/aler9/gortsplib/pkg/headers"
)
func multicastCapableIP(t *testing.T) string {
intfs, err := net.Interfaces()
require.NoError(t, err)
for _, intf := range intfs {
if (intf.Flags & net.FlagMulticast) != 0 {
addrs, err := intf.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
switch v := addr.(type) {
case *net.IPNet:
return v.IP.String()
case *net.IPAddr:
return v.IP.String()
}
}
}
}
t.Errorf("unable to find a multicast IP")
return ""
}
func TestServerReadSetupPath(t *testing.T) {
for _, ca := range []struct {
name string
@@ -278,6 +305,8 @@ func TestServerRead(t *testing.T) {
stream := NewServerStream(Tracks{track})
counter := uint64(0)
s := &Server{
Handler: &testServerHandler{
onConnOpen: func(ctx *ServerHandlerOnConnOpenCtx) {
@@ -309,6 +338,11 @@ func TestServerRead(t *testing.T) {
}, nil
},
onFrame: func(ctx *ServerHandlerOnFrameCtx) {
// skip multicast loopback
if proto == "multicast" && atomic.AddUint64(&counter, 1) <= 1 {
return
}
require.Equal(t, 0, ctx.TrackID)
require.Equal(t, StreamTypeRTCP, ctx.StreamType)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload)
@@ -333,11 +367,12 @@ func TestServerRead(t *testing.T) {
s.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
}
err = s.Start("localhost:8554")
listenIP := multicastCapableIP(t)
err = s.Start(listenIP + ":8554")
require.NoError(t, err)
defer s.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
nconn, err := net.Dial("tcp", listenIP+":8554")
require.NoError(t, err)
conn := func() net.Conn {
@@ -378,7 +413,7 @@ func TestServerRead(t *testing.T) {
res, err := writeReqReadRes(bconn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream/trackID=0"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Transport": inTH.Write(),
@@ -397,11 +432,11 @@ func TestServerRead(t *testing.T) {
var l2 net.PacketConn
switch proto {
case "udp":
l1, err = net.ListenPacket("udp", "localhost:35466")
l1, err = net.ListenPacket("udp", listenIP+":35466")
require.NoError(t, err)
defer l1.Close()
l2, err = net.ListenPacket("udp", "localhost:35467")
l2, err = net.ListenPacket("udp", listenIP+":35467")
require.NoError(t, err)
defer l2.Close()
@@ -437,7 +472,7 @@ func TestServerRead(t *testing.T) {
res, err = writeReqReadRes(bconn, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": res.Header["Session"],
@@ -453,14 +488,14 @@ func TestServerRead(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, buf[:n])
buf = make([]byte, 2048)
// skip firewall opening
if proto == "udp" {
_, _, err = l2.ReadFrom(buf)
buf := make([]byte, 2048)
_, _, err := l2.ReadFrom(buf)
require.NoError(t, err)
}
buf = make([]byte, 2048)
n, _, err = l2.ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, buf[:n])
@@ -491,8 +526,11 @@ func TestServerRead(t *testing.T) {
<-framesReceived
case "multicast":
// sending RTCP with multicast is currently not supported
// since the source IP cannot be verified correctly
l2.WriteTo([]byte{0x01, 0x02, 0x03, 0x04}, &net.UDPAddr{
IP: multicastIP,
Port: th.Ports[1],
})
<-framesReceived
default:
err = base.InterleavedFrame{
@@ -508,7 +546,7 @@ func TestServerRead(t *testing.T) {
// ping with OPTIONS
res, err = writeReqReadRes(bconn, base.Request{
Method: base.Options,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"4"},
"Session": res.Header["Session"],
@@ -520,7 +558,7 @@ func TestServerRead(t *testing.T) {
// ping with GET_PARAMETER
res, err = writeReqReadRes(bconn, base.Request{
Method: base.GetParameter,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"5"},
"Session": res.Header["Session"],
@@ -532,7 +570,7 @@ func TestServerRead(t *testing.T) {
res, err = writeReqReadRes(bconn, base.Request{
Method: base.Teardown,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
URL: mustParseURL("rtsp://" + listenIP + ":8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"6"},
"Session": res.Header["Session"],

View File

@@ -788,15 +788,16 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.setuppedStream.readerSetActive(ss)
if *ss.setuppedProtocol == base.StreamProtocolUDP &&
*ss.setuppedDelivery == base.StreamDeliveryUnicast {
// readers can send RTCP frames, they cannot sent RTP frames
for trackID, track := range ss.setuppedTracks {
sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false)
if *ss.setuppedProtocol == base.StreamProtocolUDP {
if *ss.setuppedDelivery == base.StreamDeliveryUnicast {
for trackID, track := range ss.setuppedTracks {
// readers can send RTCP frames
sc.s.udpRTCPListener.addClient(ss.udpIP, track.udpRTCPPort, ss, trackID, false)
// open the firewall by sending packets to the counterpart
ss.WriteFrame(trackID, StreamTypeRTCP,
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
// open the firewall by sending packets to the counterpart
ss.WriteFrame(trackID, StreamTypeRTCP,
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
}
}
return res, err

View File

@@ -157,6 +157,11 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) {
if *ss.setuppedDelivery == base.StreamDeliveryUnicast {
st.readersUnicast[ss] = struct{}{}
} else {
for trackID := range ss.setuppedTracks {
st.multicastListeners[trackID].rtcpListener.addClient(
ss.udpIP, st.multicastListeners[trackID].rtcpListener.port(), ss, trackID, false)
}
}
}
@@ -166,6 +171,10 @@ func (st *ServerStream) readerSetInactive(ss *ServerSession) {
if *ss.setuppedDelivery == base.StreamDeliveryUnicast {
delete(st.readersUnicast, ss)
} else {
for trackID := range ss.setuppedTracks {
st.multicastListeners[trackID].rtcpListener.removeClient(ss)
}
}
}