client: avoid sending/receiving invalid packet when reading with multicast

This commit is contained in:
aler9
2022-07-05 22:54:40 +02:00
parent 18005a9cde
commit fb39087150
4 changed files with 12 additions and 18 deletions

View File

@@ -1664,9 +1664,11 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp
return nil, err return nil, err
} }
// open the firewall by sending packets to the counterpart. // open the firewall by sending test packets to the counterpart.
// do this before sending the request. // do this before sending the request.
if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { // don't do this with multicast, otherwise the RTP packet is going to be broadcasted
// to all listeners, including us, messing up the stream.
if *c.effectiveTransport == TransportUDP {
for _, ct := range c.tracks { for _, ct := range c.tracks {
byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal()
ct.udpRTPListener.write(byts) ct.udpRTPListener.write(byts)

View File

@@ -368,11 +368,13 @@ func TestClientRead(t *testing.T) {
switch transport { switch transport {
case "udp", "multicast": case "udp", "multicast":
// skip firewall opening // skip firewall opening
if transport == "udp" {
buf := make([]byte, 2048) buf := make([]byte, 2048)
_, _, err := l2.ReadFrom(buf) _, _, err := l2.ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
}
buf = make([]byte, 2048) buf := make([]byte, 2048)
n, _, err := l2.ReadFrom(buf) n, _, err := l2.ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
packets, err := rtcp.Unmarshal(buf[:n]) packets, err := rtcp.Unmarshal(buf[:n])
@@ -403,8 +405,6 @@ func TestClientRead(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
counter := 0
c := &Client{ c := &Client{
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
@@ -427,14 +427,6 @@ func TestClientRead(t *testing.T) {
} }
c.OnPacketRTP = func(ctx *ClientOnPacketRTPCtx) { c.OnPacketRTP = func(ctx *ClientOnPacketRTPCtx) {
// ignore multicast loopback
if transport == "multicast" {
counter++
if counter <= 1 || counter >= 3 {
return
}
}
require.Equal(t, 0, ctx.TrackID) require.Equal(t, 0, ctx.TrackID)
require.Equal(t, &testRTPPacket, ctx.Packet) require.Equal(t, &testRTPPacket, ctx.Packet)

View File

@@ -347,7 +347,7 @@ func TestServerRead(t *testing.T) {
}, nil }, nil
}, },
onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) { onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) {
// skip multicast loopback // ignore multicast loopback
if transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 { if transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 {
return return
} }

View File

@@ -987,7 +987,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
go ss.runWriter() go ss.runWriter()
for trackID, st := range ss.setuppedTracks { for trackID, st := range ss.setuppedTracks {
// open the firewall by sending packets to the counterpart // open the firewall by sending test packets to the counterpart.
ss.WritePacketRTP(trackID, &rtp.Packet{Header: rtp.Header{Version: 2}}) ss.WritePacketRTP(trackID, &rtp.Packet{Header: rtp.Header{Version: 2}})
ss.WritePacketRTCP(trackID, &rtcp.ReceiverReport{}) ss.WritePacketRTCP(trackID, &rtcp.ReceiverReport{})