diff --git a/client_read_test.go b/client_read_test.go index 28c54a00..6c03248c 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -354,13 +354,15 @@ func TestClientRead(t *testing.T) { // client -> server (RTCP) switch proto { - case "udp": - // skip firewall opening - buf := make([]byte, 2048) - _, _, err := l2.ReadFrom(buf) - require.NoError(t, err) + case "udp", "multicast": + if proto == "udp" { + // skip firewall opening + buf := make([]byte, 2048) + _, _, err := l2.ReadFrom(buf) + require.NoError(t, err) + } - buf = make([]byte, 2048) + buf := make([]byte, 2048) n, _, err := l2.ReadFrom(buf) require.NoError(t, err) require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, buf[:n]) @@ -414,20 +416,19 @@ func TestClientRead(t *testing.T) { defer close(done) conn.ReadFrames(func(id int, streamType StreamType, payload []byte) { // skip multicast loopback - if proto == "multicast" && atomic.AddUint64(&counter, 1) <= 2 { - return + if proto == "multicast" { + add := atomic.AddUint64(&counter, 1) + if add >= 2 { + return + } } require.Equal(t, 0, id) require.Equal(t, StreamTypeRTP, streamType) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) - if proto != "multicast" { - err = conn.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08}) - require.NoError(t, err) - } else { - close(frameRecv) - } + err = conn.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08}) + require.NoError(t, err) }) }() diff --git a/clientconn.go b/clientconn.go index 48198868..33d1eb7c 100644 --- a/clientconn.go +++ b/clientconn.go @@ -475,13 +475,15 @@ func (cc *ClientConn) runBackground() { } func (cc *ClientConn) runBackgroundPlayUDP() error { - // open the firewall by sending packets to the counterpart - for _, cct := range cc.tracks { - cct.udpRTPListener.write( - []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + if *cc.protocol == ClientProtocolUDP { + // open the firewall by sending packets to the counterpart + for _, cct := range cc.tracks { + cct.udpRTPListener.write( + []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) - cct.udpRTCPListener.write( - []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) + cct.udpRTCPListener.write( + []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) + } } for _, cct := range cc.tracks { @@ -1379,7 +1381,8 @@ func (cc *ClientConn) doSetup( switch proto { case ClientProtocolUDP: - rtpListener.remoteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP + rtpListener.remoteReadIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP + rtpListener.remoteWriteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP rtpListener.remoteZone = cc.nconn.RemoteAddr().(*net.TCPAddr).Zone if thRes.ServerPorts != nil { rtpListener.remotePort = thRes.ServerPorts[0] @@ -1388,7 +1391,8 @@ func (cc *ClientConn) doSetup( rtpListener.streamType = StreamTypeRTP cct.udpRTPListener = rtpListener - rtcpListener.remoteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP + rtcpListener.remoteReadIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP + rtcpListener.remoteWriteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP rtcpListener.remoteZone = cc.nconn.RemoteAddr().(*net.TCPAddr).Zone if thRes.ServerPorts != nil { rtcpListener.remotePort = thRes.ServerPorts[1] @@ -1398,14 +1402,16 @@ func (cc *ClientConn) doSetup( cct.udpRTCPListener = rtcpListener case ClientProtocolMulticast: - rtpListener.remoteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP + rtpListener.remoteReadIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP + rtpListener.remoteWriteIP = *thRes.Destination rtpListener.remoteZone = "" rtpListener.remotePort = thRes.Ports[0] rtpListener.trackID = trackID rtpListener.streamType = StreamTypeRTP cct.udpRTPListener = rtpListener - rtcpListener.remoteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP + rtcpListener.remoteReadIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP + rtcpListener.remoteWriteIP = *thRes.Destination rtcpListener.remoteZone = "" rtcpListener.remotePort = thRes.Ports[1] rtcpListener.trackID = trackID diff --git a/clientconnudpl.go b/clientconnudpl.go index 57083af6..b6dbcb38 100644 --- a/clientconnudpl.go +++ b/clientconnudpl.go @@ -21,7 +21,8 @@ const ( type clientConnUDPListener struct { cc *ClientConn pc *net.UDPConn - remoteIP net.IP + remoteReadIP net.IP + remoteWriteIP net.IP remoteZone string remotePort int trackID int @@ -31,6 +32,7 @@ type clientConnUDPListener struct { lastFrameTime *int64 writeMutex sync.Mutex + // out done chan struct{} } @@ -148,7 +150,7 @@ func (l *clientConnUDPListener) run() { uaddr := addr.(*net.UDPAddr) - if !l.remoteIP.Equal(uaddr.IP) || (!isAnyPort(l.remotePort) && l.remotePort != uaddr.Port) { + if !l.remoteReadIP.Equal(uaddr.IP) || (!isAnyPort(l.remotePort) && l.remotePort != uaddr.Port) { continue } @@ -167,7 +169,7 @@ func (l *clientConnUDPListener) run() { uaddr := addr.(*net.UDPAddr) - if !l.remoteIP.Equal(uaddr.IP) || (!isAnyPort(l.remotePort) && l.remotePort != uaddr.Port) { + if !l.remoteReadIP.Equal(uaddr.IP) || (!isAnyPort(l.remotePort) && l.remotePort != uaddr.Port) { continue } @@ -184,7 +186,7 @@ func (l *clientConnUDPListener) write(buf []byte) error { l.pc.SetWriteDeadline(time.Now().Add(l.cc.c.WriteTimeout)) _, err := l.pc.WriteTo(buf, &net.UDPAddr{ - IP: l.remoteIP, + IP: l.remoteWriteIP, Zone: l.remoteZone, Port: l.remotePort, })