client: allow bidirectional communication with multicast

This commit is contained in:
aler9
2021-06-26 13:25:38 +02:00
parent a512762ba0
commit 649c63cf5b
3 changed files with 37 additions and 28 deletions

View File

@@ -354,13 +354,15 @@ func TestClientRead(t *testing.T) {
// client -> server (RTCP) // client -> server (RTCP)
switch proto { switch proto {
case "udp": case "udp", "multicast":
if proto == "udp" {
// skip firewall opening // skip firewall opening
buf := make([]byte, 2048) buf := make([]byte, 2048)
_, _, err := l2.ReadFrom(buf) _, _, err := l2.ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
}
buf = make([]byte, 2048) buf := make([]byte, 2048)
n, _, err := l2.ReadFrom(buf) n, _, err := l2.ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, buf[:n]) require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, buf[:n])
@@ -414,20 +416,19 @@ func TestClientRead(t *testing.T) {
defer close(done) defer close(done)
conn.ReadFrames(func(id int, streamType StreamType, payload []byte) { conn.ReadFrames(func(id int, streamType StreamType, payload []byte) {
// skip multicast loopback // skip multicast loopback
if proto == "multicast" && atomic.AddUint64(&counter, 1) <= 2 { if proto == "multicast" {
add := atomic.AddUint64(&counter, 1)
if add >= 2 {
return return
} }
}
require.Equal(t, 0, id) require.Equal(t, 0, id)
require.Equal(t, StreamTypeRTP, streamType) require.Equal(t, StreamTypeRTP, streamType)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload)
if proto != "multicast" {
err = conn.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08}) err = conn.WriteFrame(0, StreamTypeRTCP, []byte{0x05, 0x06, 0x07, 0x08})
require.NoError(t, err) require.NoError(t, err)
} else {
close(frameRecv)
}
}) })
}() }()

View File

@@ -475,6 +475,7 @@ func (cc *ClientConn) runBackground() {
} }
func (cc *ClientConn) runBackgroundPlayUDP() error { func (cc *ClientConn) runBackgroundPlayUDP() error {
if *cc.protocol == ClientProtocolUDP {
// open the firewall by sending packets to the counterpart // open the firewall by sending packets to the counterpart
for _, cct := range cc.tracks { for _, cct := range cc.tracks {
cct.udpRTPListener.write( cct.udpRTPListener.write(
@@ -483,6 +484,7 @@ func (cc *ClientConn) runBackgroundPlayUDP() error {
cct.udpRTCPListener.write( cct.udpRTCPListener.write(
[]byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}) []byte{0x80, 0xc9, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00})
} }
}
for _, cct := range cc.tracks { for _, cct := range cc.tracks {
cct.udpRTPListener.start() cct.udpRTPListener.start()
@@ -1379,7 +1381,8 @@ func (cc *ClientConn) doSetup(
switch proto { switch proto {
case ClientProtocolUDP: case ClientProtocolUDP:
rtpListener.remoteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP rtpListener.remoteReadIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP
rtpListener.remoteWriteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP
rtpListener.remoteZone = cc.nconn.RemoteAddr().(*net.TCPAddr).Zone rtpListener.remoteZone = cc.nconn.RemoteAddr().(*net.TCPAddr).Zone
if thRes.ServerPorts != nil { if thRes.ServerPorts != nil {
rtpListener.remotePort = thRes.ServerPorts[0] rtpListener.remotePort = thRes.ServerPorts[0]
@@ -1388,7 +1391,8 @@ func (cc *ClientConn) doSetup(
rtpListener.streamType = StreamTypeRTP rtpListener.streamType = StreamTypeRTP
cct.udpRTPListener = rtpListener cct.udpRTPListener = rtpListener
rtcpListener.remoteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP rtcpListener.remoteReadIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP
rtcpListener.remoteWriteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP
rtcpListener.remoteZone = cc.nconn.RemoteAddr().(*net.TCPAddr).Zone rtcpListener.remoteZone = cc.nconn.RemoteAddr().(*net.TCPAddr).Zone
if thRes.ServerPorts != nil { if thRes.ServerPorts != nil {
rtcpListener.remotePort = thRes.ServerPorts[1] rtcpListener.remotePort = thRes.ServerPorts[1]
@@ -1398,14 +1402,16 @@ func (cc *ClientConn) doSetup(
cct.udpRTCPListener = rtcpListener cct.udpRTCPListener = rtcpListener
case ClientProtocolMulticast: case ClientProtocolMulticast:
rtpListener.remoteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP rtpListener.remoteReadIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP
rtpListener.remoteWriteIP = *thRes.Destination
rtpListener.remoteZone = "" rtpListener.remoteZone = ""
rtpListener.remotePort = thRes.Ports[0] rtpListener.remotePort = thRes.Ports[0]
rtpListener.trackID = trackID rtpListener.trackID = trackID
rtpListener.streamType = StreamTypeRTP rtpListener.streamType = StreamTypeRTP
cct.udpRTPListener = rtpListener cct.udpRTPListener = rtpListener
rtcpListener.remoteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP rtcpListener.remoteReadIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP
rtcpListener.remoteWriteIP = *thRes.Destination
rtcpListener.remoteZone = "" rtcpListener.remoteZone = ""
rtcpListener.remotePort = thRes.Ports[1] rtcpListener.remotePort = thRes.Ports[1]
rtcpListener.trackID = trackID rtcpListener.trackID = trackID

View File

@@ -21,7 +21,8 @@ const (
type clientConnUDPListener struct { type clientConnUDPListener struct {
cc *ClientConn cc *ClientConn
pc *net.UDPConn pc *net.UDPConn
remoteIP net.IP remoteReadIP net.IP
remoteWriteIP net.IP
remoteZone string remoteZone string
remotePort int remotePort int
trackID int trackID int
@@ -31,6 +32,7 @@ type clientConnUDPListener struct {
lastFrameTime *int64 lastFrameTime *int64
writeMutex sync.Mutex writeMutex sync.Mutex
// out
done chan struct{} done chan struct{}
} }
@@ -148,7 +150,7 @@ func (l *clientConnUDPListener) run() {
uaddr := addr.(*net.UDPAddr) uaddr := addr.(*net.UDPAddr)
if !l.remoteIP.Equal(uaddr.IP) || (!isAnyPort(l.remotePort) && l.remotePort != uaddr.Port) { if !l.remoteReadIP.Equal(uaddr.IP) || (!isAnyPort(l.remotePort) && l.remotePort != uaddr.Port) {
continue continue
} }
@@ -167,7 +169,7 @@ func (l *clientConnUDPListener) run() {
uaddr := addr.(*net.UDPAddr) uaddr := addr.(*net.UDPAddr)
if !l.remoteIP.Equal(uaddr.IP) || (!isAnyPort(l.remotePort) && l.remotePort != uaddr.Port) { if !l.remoteReadIP.Equal(uaddr.IP) || (!isAnyPort(l.remotePort) && l.remotePort != uaddr.Port) {
continue continue
} }
@@ -184,7 +186,7 @@ func (l *clientConnUDPListener) write(buf []byte) error {
l.pc.SetWriteDeadline(time.Now().Add(l.cc.c.WriteTimeout)) l.pc.SetWriteDeadline(time.Now().Add(l.cc.c.WriteTimeout))
_, err := l.pc.WriteTo(buf, &net.UDPAddr{ _, err := l.pc.WriteTo(buf, &net.UDPAddr{
IP: l.remoteIP, IP: l.remoteWriteIP,
Zone: l.remoteZone, Zone: l.remoteZone,
Port: l.remotePort, Port: l.remotePort,
}) })