From c56eee37f8bab69a0c10cfe2d8bdbca989f66f3e Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Mon, 10 Apr 2023 22:42:19 +0200 Subject: [PATCH] do not listen on IPv6 when host is 0.0.0.0 (#240) (https://github.com/aler9/mediamtx/issues/1665) --- client.go | 8 ++++---- client_media.go | 3 ++- client_play_test.go | 4 ++-- client_udpl.go | 8 ++++---- restrict_network.go | 17 +++++++++++++++++ server.go | 2 +- server_udpl.go | 8 ++++---- 7 files changed, 34 insertions(+), 16 deletions(-) create mode 100644 restrict_network.go diff --git a/client.go b/client.go index 9f04f1d7..2a1489e6 100644 --- a/client.go +++ b/client.go @@ -1232,8 +1232,8 @@ func (c *Client) doSetup( err := cm.allocateUDPListeners( false, - ":"+strconv.FormatInt(int64(rtpPort), 10), - ":"+strconv.FormatInt(int64(rtcpPort), 10), + net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)), + net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)), ) if err != nil { return nil, err @@ -1377,8 +1377,8 @@ func (c *Client) doSetup( err := cm.allocateUDPListeners( true, - thRes.Destination.String()+":"+strconv.FormatInt(int64(thRes.Ports[0]), 10), - thRes.Destination.String()+":"+strconv.FormatInt(int64(thRes.Ports[1]), 10), + net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)), + net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)), ) if err != nil { return nil, err diff --git a/client_media.go b/client_media.go index 6bf8cbd6..73df164e 100644 --- a/client_media.go +++ b/client_media.go @@ -63,7 +63,8 @@ func (cm *clientMedia) allocateUDPListeners(multicast bool, rtpAddress string, r cm.c.WriteTimeout, multicast, rtcpAddress, - cm, false) + cm, + false) if err != nil { l1.close() return err diff --git a/client_play_test.go b/client_play_test.go index 1fcd6814..8cedb949 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -322,11 +322,11 @@ func TestClientPlay(t *testing.T) { clientPorts[i] = inTH.ClientPorts th.ServerPorts = &[2]int{34556 + i*2, 34557 + i*2} - l1s[i], err = net.ListenPacket("udp", listenIP+":"+strconv.FormatInt(int64(th.ServerPorts[0]), 10)) + l1s[i], err = net.ListenPacket("udp", net.JoinHostPort(listenIP, strconv.FormatInt(int64(th.ServerPorts[0]), 10))) require.NoError(t, err) defer l1s[i].Close() - l2s[i], err = net.ListenPacket("udp", listenIP+":"+strconv.FormatInt(int64(th.ServerPorts[1]), 10)) + l2s[i], err = net.ListenPacket("udp", net.JoinHostPort(listenIP, strconv.FormatInt(int64(th.ServerPorts[1]), 10))) require.NoError(t, err) defer l2s[i].Close() diff --git a/client_udpl.go b/client_udpl.go index 8a5faa56..03677b2e 100644 --- a/client_udpl.go +++ b/client_udpl.go @@ -49,7 +49,7 @@ func newClientUDPListenerPair( anyPortEnable, writeTimeout, false, - ":"+strconv.FormatInt(int64(rtpPort), 10), + net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)), cm, true) if err != nil { @@ -62,7 +62,7 @@ func newClientUDPListenerPair( anyPortEnable, writeTimeout, false, - ":"+strconv.FormatInt(int64(rtcpPort), 10), + net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)), cm, false) if err != nil { @@ -90,7 +90,7 @@ func newClientUDPListener( return nil, err } - tmp, err := listenPacket("udp", "224.0.0.0:"+port) + tmp, err := listenPacket(restrictNetwork("udp", "224.0.0.0:"+port)) if err != nil { return nil, err } @@ -116,7 +116,7 @@ func newClientUDPListener( pc = tmp.(*net.UDPConn) } else { - tmp, err := listenPacket("udp", address) + tmp, err := listenPacket(restrictNetwork("udp", address)) if err != nil { return nil, err } diff --git a/restrict_network.go b/restrict_network.go new file mode 100644 index 00000000..b191d8fa --- /dev/null +++ b/restrict_network.go @@ -0,0 +1,17 @@ +package gortsplib + +import ( + "net" +) + +// do not listen on IPv6 when address is 0.0.0.0. +func restrictNetwork(network string, address string) (string, string) { + host, _, err := net.SplitHostPort(address) + if err == nil { + if host == "0.0.0.0" { + return network + "4", address + } + } + + return network, address +} diff --git a/server.go b/server.go index 543b6166..b206983f 100644 --- a/server.go +++ b/server.go @@ -290,7 +290,7 @@ func (s *Server) Start() error { } var err error - s.tcpListener, err = s.Listen("tcp", s.RTSPAddress) + s.tcpListener, err = s.Listen(restrictNetwork("tcp", s.RTSPAddress)) if err != nil { if s.udpRTPListener != nil { s.udpRTPListener.close() diff --git a/server_udpl.go b/server_udpl.go index c42b1a47..30e4d639 100644 --- a/server_udpl.go +++ b/server_udpl.go @@ -60,7 +60,7 @@ func newServerUDPListenerMulticastPair( listenPacket, writeTimeout, true, - ip.String()+":"+strconv.FormatInt(int64(multicastRTPPort), 10), + net.JoinHostPort(ip.String(), strconv.FormatInt(int64(multicastRTPPort), 10)), true, ) if err != nil { @@ -71,7 +71,7 @@ func newServerUDPListenerMulticastPair( listenPacket, writeTimeout, true, - ip.String()+":"+strconv.FormatInt(int64(multicastRTCPPort), 10), + net.JoinHostPort(ip.String(), strconv.FormatInt(int64(multicastRTCPPort), 10)), false, ) if err != nil { @@ -97,7 +97,7 @@ func newServerUDPListener( return nil, err } - tmp, err := listenPacket("udp", "224.0.0.0:"+port) + tmp, err := listenPacket(restrictNetwork("udp", "224.0.0.0:"+port)) if err != nil { return nil, err } @@ -127,7 +127,7 @@ func newServerUDPListener( pc = tmp.(*net.UDPConn) } else { - tmp, err := listenPacket("udp", address) + tmp, err := listenPacket(restrictNetwork("udp", address)) if err != nil { return nil, err }