mirror of
				https://github.com/aler9/gortsplib
				synced 2025-11-01 02:52:36 +08:00 
			
		
		
		
	client: avoid sending/receiving invalid packet when reading with multicast
This commit is contained in:
		| @@ -1664,9 +1664,11 @@ func (c *Client) doPlay(ra *headers.Range, isSwitchingProtocol bool) (*base.Resp | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// open the firewall by sending packets to the counterpart. | 	// open the firewall by sending test packets to the counterpart. | ||||||
| 	// do this before sending the request. | 	// do this before sending the request. | ||||||
| 	if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { | 	// don't do this with multicast, otherwise the RTP packet is going to be broadcasted | ||||||
|  | 	// to all listeners, including us, messing up the stream. | ||||||
|  | 	if *c.effectiveTransport == TransportUDP { | ||||||
| 		for _, ct := range c.tracks { | 		for _, ct := range c.tracks { | ||||||
| 			byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() | 			byts, _ := (&rtp.Packet{Header: rtp.Header{Version: 2}}).Marshal() | ||||||
| 			ct.udpRTPListener.write(byts) | 			ct.udpRTPListener.write(byts) | ||||||
|   | |||||||
| @@ -368,11 +368,13 @@ func TestClientRead(t *testing.T) { | |||||||
| 				switch transport { | 				switch transport { | ||||||
| 				case "udp", "multicast": | 				case "udp", "multicast": | ||||||
| 					// skip firewall opening | 					// skip firewall opening | ||||||
| 					buf := make([]byte, 2048) | 					if transport == "udp" { | ||||||
| 					_, _, err := l2.ReadFrom(buf) | 						buf := make([]byte, 2048) | ||||||
| 					require.NoError(t, err) | 						_, _, err := l2.ReadFrom(buf) | ||||||
|  | 						require.NoError(t, err) | ||||||
|  | 					} | ||||||
|  |  | ||||||
| 					buf = make([]byte, 2048) | 					buf := make([]byte, 2048) | ||||||
| 					n, _, err := l2.ReadFrom(buf) | 					n, _, err := l2.ReadFrom(buf) | ||||||
| 					require.NoError(t, err) | 					require.NoError(t, err) | ||||||
| 					packets, err := rtcp.Unmarshal(buf[:n]) | 					packets, err := rtcp.Unmarshal(buf[:n]) | ||||||
| @@ -403,8 +405,6 @@ func TestClientRead(t *testing.T) { | |||||||
| 				require.NoError(t, err) | 				require.NoError(t, err) | ||||||
| 			}() | 			}() | ||||||
|  |  | ||||||
| 			counter := 0 |  | ||||||
|  |  | ||||||
| 			c := &Client{ | 			c := &Client{ | ||||||
| 				TLSConfig: &tls.Config{ | 				TLSConfig: &tls.Config{ | ||||||
| 					InsecureSkipVerify: true, | 					InsecureSkipVerify: true, | ||||||
| @@ -427,14 +427,6 @@ func TestClientRead(t *testing.T) { | |||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			c.OnPacketRTP = func(ctx *ClientOnPacketRTPCtx) { | 			c.OnPacketRTP = func(ctx *ClientOnPacketRTPCtx) { | ||||||
| 				// ignore multicast loopback |  | ||||||
| 				if transport == "multicast" { |  | ||||||
| 					counter++ |  | ||||||
| 					if counter <= 1 || counter >= 3 { |  | ||||||
| 						return |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				require.Equal(t, 0, ctx.TrackID) | 				require.Equal(t, 0, ctx.TrackID) | ||||||
| 				require.Equal(t, &testRTPPacket, ctx.Packet) | 				require.Equal(t, &testRTPPacket, ctx.Packet) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -347,7 +347,7 @@ func TestServerRead(t *testing.T) { | |||||||
| 						}, nil | 						}, nil | ||||||
| 					}, | 					}, | ||||||
| 					onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) { | 					onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) { | ||||||
| 						// skip multicast loopback | 						// ignore multicast loopback | ||||||
| 						if transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 { | 						if transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 { | ||||||
| 							return | 							return | ||||||
| 						} | 						} | ||||||
|   | |||||||
| @@ -987,7 +987,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base | |||||||
| 			go ss.runWriter() | 			go ss.runWriter() | ||||||
|  |  | ||||||
| 			for trackID, st := range ss.setuppedTracks { | 			for trackID, st := range ss.setuppedTracks { | ||||||
| 				// open the firewall by sending packets to the counterpart | 				// open the firewall by sending test packets to the counterpart. | ||||||
| 				ss.WritePacketRTP(trackID, &rtp.Packet{Header: rtp.Header{Version: 2}}) | 				ss.WritePacketRTP(trackID, &rtp.Packet{Header: rtp.Header{Version: 2}}) | ||||||
| 				ss.WritePacketRTCP(trackID, &rtcp.ReceiverReport{}) | 				ss.WritePacketRTCP(trackID, &rtcp.ReceiverReport{}) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 aler9
					aler9