diff --git a/client.go b/client.go index a74590d1..e2fb3298 100644 --- a/client.go +++ b/client.go @@ -1664,9 +1664,11 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp return nil, err } - // open the firewall by sending packets to the counterpart. + // open the firewall by sending test packets to the counterpart. // do this before sending the request. - if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { + // don't do this with multicast, otherwise the RTP packet is going to be broadcasted + // to all listeners, including us, messing up the stream. + if *c.effectiveTransport == TransportUDP { for _, ct := range c.tracks { byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() ct.udpRTPListener.write(byts) diff --git a/client_read_test.go b/client_read_test.go index 406e1f69..1360217f 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -368,11 +368,13 @@ func TestClientRead(t *testing.T) { switch transport { case "udp", "multicast": // skip firewall opening - buf := make([]byte, 2048) - _, _, err := l2.ReadFrom(buf) - require.NoError(t, err) + if transport == "udp" { + buf := make([]byte, 2048) + _, _, err := l2.ReadFrom(buf) + require.NoError(t, err) + } - buf = make([]byte, 2048) + buf := make([]byte, 2048) n, _, err := l2.ReadFrom(buf) require.NoError(t, err) packets, err := rtcp.Unmarshal(buf[:n]) @@ -403,8 +405,6 @@ func TestClientRead(t *testing.T) { require.NoError(t, err) }() - counter := 0 - c := &Client{ TLSConfig: &tls.Config{ InsecureSkipVerify: true, @@ -427,14 +427,6 @@ func TestClientRead(t *testing.T) { } c.OnPacketRTP = func(ctx *ClientOnPacketRTPCtx) { - // ignore multicast loopback - if transport == "multicast" { - counter++ - if counter <= 1 || counter >= 3 { - return - } - } - require.Equal(t, 0, ctx.TrackID) require.Equal(t, &testRTPPacket, ctx.Packet) diff --git a/server_read_test.go b/server_read_test.go index 77e11e83..7bebf5ca 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -347,7 +347,7 @@ func TestServerRead(t *testing.T) { }, nil }, onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) { - // skip multicast loopback + // ignore multicast loopback if transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 { return } diff --git a/serversession.go b/serversession.go index 074e84bf..59237239 100644 --- a/serversession.go +++ b/serversession.go @@ -987,7 +987,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base go ss.runWriter() for trackID, st := range ss.setuppedTracks { - // open the firewall by sending packets to the counterpart + // open the firewall by sending test packets to the counterpart. ss.WritePacketRTP(trackID, &rtp.Packet{Header: rtp.Header{Version: 2}}) ss.WritePacketRTCP(trackID, &rtcp.ReceiverReport{})