mirror of
https://github.com/pion/ice.git
synced 2025-10-05 07:26:55 +08:00
Add multi-port wrappers for UDPMux and TCPMux
These wrappers allow a caller to provide UDPMux and TCPMux instances to the ICE agent that represent multiple open ports. This can be desirable in what would otherwise be single-port deployments, as it increases the chance that one of the fixed ports will not be blocked by a users firewall.
This commit is contained in:
@@ -139,4 +139,6 @@ var (
|
||||
errSendSTUNPacket = errors.New("failed to send STUN packet")
|
||||
errXORMappedAddrTimeout = errors.New("timeout while waiting for XORMappedAddr")
|
||||
errNotImplemented = errors.New("not implemented yet")
|
||||
errNoUDPMuxAvailable = errors.New("no UDP mux is available")
|
||||
errNoTCPMuxAvailable = errors.New("no TCP mux is available")
|
||||
)
|
||||
|
69
gather.go
69
gather.go
@@ -158,7 +158,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
||||
for _, ip := range localIPs {
|
||||
mappedIP := ip
|
||||
if a.mDNSMode != MulticastDNSModeQueryAndGather && a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost {
|
||||
if _mappedIP, err := a.extIPMapper.findExternalIP(ip.String()); err == nil {
|
||||
if _mappedIP, innerErr := a.extIPMapper.findExternalIP(ip.String()); innerErr == nil {
|
||||
mappedIP = _mappedIP
|
||||
} else {
|
||||
a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", ip.String())
|
||||
@@ -171,70 +171,93 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
||||
}
|
||||
|
||||
for network := range networks {
|
||||
var (
|
||||
port int
|
||||
type connAndPort struct {
|
||||
conn net.PacketConn
|
||||
err error
|
||||
port int
|
||||
}
|
||||
var (
|
||||
conns []connAndPort
|
||||
tcpType TCPType
|
||||
)
|
||||
|
||||
switch network {
|
||||
case tcp:
|
||||
// Handle ICE TCP passive mode
|
||||
var muxConns []net.PacketConn
|
||||
if multi, ok := a.tcpMux.(AllConnsGetter); ok {
|
||||
a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag)
|
||||
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil)
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrTCPMuxNotInitialized) {
|
||||
a.log.Warnf("error getting all tcp conns by ufrag: %s %s %s", network, ip, a.localUfrag)
|
||||
}
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
a.log.Debugf("GetConn by ufrag: %s", a.localUfrag)
|
||||
conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil)
|
||||
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil)
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrTCPMuxNotInitialized) {
|
||||
a.log.Warnf("error getting tcp conn by ufrag: %s %s %s", network, ip, a.localUfrag)
|
||||
}
|
||||
continue
|
||||
}
|
||||
muxConns = []net.PacketConn{conn}
|
||||
}
|
||||
|
||||
// Extract the port for each PacketConn we got.
|
||||
for _, conn := range muxConns {
|
||||
if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok {
|
||||
port = tcpConn.Port
|
||||
conns = append(conns, connAndPort{conn, tcpConn.Port})
|
||||
} else {
|
||||
a.log.Warnf("failed to get port of conn from TCPMux: %s %s %s", network, ip, a.localUfrag)
|
||||
}
|
||||
}
|
||||
if len(conns) == 0 {
|
||||
// Didn't succeed with any, try the next network.
|
||||
continue
|
||||
}
|
||||
tcpType = TCPTypePassive
|
||||
// is there a way to verify that the listen address is even
|
||||
// accessible from the current interface.
|
||||
case udp:
|
||||
conn, err = listenUDPInPortRange(a.net, a.log, int(a.portmax), int(a.portmin), network, &net.UDPAddr{IP: ip, Port: 0})
|
||||
conn, err := listenUDPInPortRange(a.net, a.log, int(a.portmax), int(a.portmin), network, &net.UDPAddr{IP: ip, Port: 0})
|
||||
if err != nil {
|
||||
a.log.Warnf("could not listen %s %s", network, ip)
|
||||
continue
|
||||
}
|
||||
|
||||
if udpConn, ok := conn.LocalAddr().(*net.UDPAddr); ok {
|
||||
port = udpConn.Port
|
||||
conns = append(conns, connAndPort{conn, udpConn.Port})
|
||||
} else {
|
||||
a.log.Warnf("failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, ip, a.localUfrag)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, connAndPort := range conns {
|
||||
hostConfig := CandidateHostConfig{
|
||||
Network: network,
|
||||
Address: address,
|
||||
Port: port,
|
||||
Port: connAndPort.port,
|
||||
Component: ComponentRTP,
|
||||
TCPType: tcpType,
|
||||
}
|
||||
|
||||
c, err := NewCandidateHost(&hostConfig)
|
||||
if err != nil {
|
||||
closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create host candidate: %s %s %d: %v", network, mappedIP, port, err))
|
||||
closeConnAndLog(connAndPort.conn, a.log, fmt.Sprintf("Failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if a.mDNSMode == MulticastDNSModeQueryAndGather {
|
||||
if err = c.setIP(ip); err != nil {
|
||||
closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create host candidate: %s %s %d: %v", network, mappedIP, port, err))
|
||||
closeConnAndLog(connAndPort.conn, a.log, fmt.Sprintf("Failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := a.addCandidate(ctx, c, conn); err != nil {
|
||||
if err := a.addCandidate(ctx, c, connAndPort.conn); err != nil {
|
||||
if closeErr := c.close(); closeErr != nil {
|
||||
a.log.Warnf("Failed to close candidate: %v", closeErr)
|
||||
}
|
||||
@@ -242,9 +265,10 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
|
||||
func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit
|
||||
if a.udpMux == nil {
|
||||
return errUDPMuxDisabled
|
||||
}
|
||||
@@ -259,7 +283,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
|
||||
|
||||
for _, candidateIP := range localIPs {
|
||||
if a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost {
|
||||
if mappedIP, err := a.extIPMapper.findExternalIP(candidateIP.String()); err != nil {
|
||||
if mappedIP, innerErr := a.extIPMapper.findExternalIP(candidateIP.String()); innerErr != nil {
|
||||
a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", candidateIP.String())
|
||||
continue
|
||||
} else {
|
||||
@@ -267,11 +291,25 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
var conns []net.PacketConn
|
||||
if multi, ok := a.udpMux.(AllConnsGetter); ok {
|
||||
conns, err = multi.GetAllConns(a.localUfrag, candidateIP.To4() == nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(conns) == 0 {
|
||||
a.log.Warnf("Failed to get any connections from MultiUDPMux for candidate: %s", candidateIP)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
conn, err := a.udpMux.GetConn(a.localUfrag, candidateIP.To4() == nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conns = []net.PacketConn{conn}
|
||||
}
|
||||
|
||||
for _, conn := range conns {
|
||||
udpAddr, ok := conn.LocalAddr().(*net.UDPAddr)
|
||||
if !ok {
|
||||
closeConnAndLog(conn, a.log, fmt.Sprintf("Failed to create host mux candidate: %s failed to cast", candidateIP))
|
||||
@@ -300,6 +338,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -585,7 +624,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli
|
||||
return
|
||||
}
|
||||
|
||||
conn, connectErr := dtls.Dial(network, udpAddr, &dtls.Config{
|
||||
conn, connectErr := dtls.Dial(network, udpAddr, &dtls.Config{ //nolint:contextcheck
|
||||
ServerName: url.Host,
|
||||
InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec
|
||||
})
|
||||
|
113
gather_test.go
113
gather_test.go
@@ -488,6 +488,119 @@ func TestTURNProxyDialer(t *testing.T) {
|
||||
assert.NoError(t, a.Close())
|
||||
}
|
||||
|
||||
// Assert that candidates are given for each mux in a MultiUDPMux
|
||||
func TestMultiUDPMuxUsage(t *testing.T) {
|
||||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
lim := test.TimeOut(time.Second * 30)
|
||||
defer lim.Stop()
|
||||
|
||||
var expectedPorts []int
|
||||
var udpMuxInstances []UDPMux
|
||||
for i := 0; i < 3; i++ {
|
||||
port := randomPort(t)
|
||||
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: port})
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
expectedPorts = append(expectedPorts, port)
|
||||
udpMuxInstances = append(udpMuxInstances, NewUDPMuxDefault(UDPMuxParams{
|
||||
UDPConn: conn,
|
||||
}))
|
||||
}
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
CandidateTypes: []CandidateType{CandidateTypeHost},
|
||||
UDPMux: NewMultiUDPMuxDefault(udpMuxInstances...),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
candidateCh := make(chan Candidate)
|
||||
assert.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
if c == nil {
|
||||
close(candidateCh)
|
||||
return
|
||||
}
|
||||
candidateCh <- c
|
||||
}))
|
||||
assert.NoError(t, a.GatherCandidates())
|
||||
|
||||
portFound := make(map[int]bool)
|
||||
for c := range candidateCh {
|
||||
portFound[c.Port()] = true
|
||||
assert.True(t, c.NetworkType().IsUDP(), "All candidates should be UDP")
|
||||
}
|
||||
assert.Len(t, portFound, len(expectedPorts))
|
||||
for _, port := range expectedPorts {
|
||||
assert.True(t, portFound[port], "There should be a candidate for each UDP mux port")
|
||||
}
|
||||
|
||||
assert.NoError(t, a.Close())
|
||||
}
|
||||
|
||||
// Assert that candidates are given for each mux in a MultiTCPMux
|
||||
func TestMultiTCPMuxUsage(t *testing.T) {
|
||||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
lim := test.TimeOut(time.Second * 30)
|
||||
defer lim.Stop()
|
||||
|
||||
var expectedPorts []int
|
||||
var tcpMuxInstances []TCPMux
|
||||
for i := 0; i < 3; i++ {
|
||||
port := randomPort(t)
|
||||
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Port: port,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
|
||||
expectedPorts = append(expectedPorts, port)
|
||||
tcpMuxInstances = append(tcpMuxInstances, NewTCPMuxDefault(TCPMuxParams{
|
||||
Listener: listener,
|
||||
ReadBufferSize: 8,
|
||||
}))
|
||||
}
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
CandidateTypes: []CandidateType{CandidateTypeHost},
|
||||
TCPMux: NewMultiTCPMuxDefault(tcpMuxInstances...),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
candidateCh := make(chan Candidate)
|
||||
assert.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
if c == nil {
|
||||
close(candidateCh)
|
||||
return
|
||||
}
|
||||
candidateCh <- c
|
||||
}))
|
||||
assert.NoError(t, a.GatherCandidates())
|
||||
|
||||
portFound := make(map[int]bool)
|
||||
for c := range candidateCh {
|
||||
if c.NetworkType().IsTCP() {
|
||||
portFound[c.Port()] = true
|
||||
}
|
||||
}
|
||||
assert.Len(t, portFound, len(expectedPorts))
|
||||
for _, port := range expectedPorts {
|
||||
assert.True(t, portFound[port], "There should be a candidate for each TCP mux port")
|
||||
}
|
||||
|
||||
assert.NoError(t, a.Close())
|
||||
}
|
||||
|
||||
// Assert that UniversalUDPMux is used while gathering when configured in the Agent
|
||||
func TestUniversalUDPMuxUsage(t *testing.T) {
|
||||
report := test.CheckRoutines(t)
|
||||
|
73
tcp_mux_multi.go
Normal file
73
tcp_mux_multi.go
Normal file
@@ -0,0 +1,73 @@
|
||||
// Package ice ...
|
||||
//
|
||||
//nolint:dupl
|
||||
package ice
|
||||
|
||||
import "net"
|
||||
|
||||
// MultiTCPMuxDefault implements both TCPMux and AllConnsGetter,
|
||||
// allowing users to pass multiple TCPMux instances to the ICE agent
|
||||
// configuration.
|
||||
type MultiTCPMuxDefault struct {
|
||||
muxs []TCPMux
|
||||
}
|
||||
|
||||
// NewMultiTCPMuxDefault creates an instance of MultiTCPMuxDefault that
|
||||
// uses the provided TCPMux instances.
|
||||
func NewMultiTCPMuxDefault(muxs ...TCPMux) *MultiTCPMuxDefault {
|
||||
return &MultiTCPMuxDefault{
|
||||
muxs: muxs,
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnByUfrag returns a PacketConn given the connection's ufrag and network
|
||||
// creates the connection if an existing one can't be found. This, unlike
|
||||
// GetAllConns, will only return a single PacketConn from the first mux that was
|
||||
// passed in to NewMultiTCPMuxDefault.
|
||||
func (m *MultiTCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error) {
|
||||
// NOTE: We always use the first element here in order to maintain the
|
||||
// behavior of using an existing connection if one exists.
|
||||
if len(m.muxs) == 0 {
|
||||
return nil, errNoTCPMuxAvailable
|
||||
}
|
||||
return m.muxs[0].GetConnByUfrag(ufrag, isIPv6)
|
||||
}
|
||||
|
||||
// RemoveConnByUfrag stops and removes the muxed packet connection
|
||||
// from all underlying TCPMux instances.
|
||||
func (m *MultiTCPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
for _, mux := range m.muxs {
|
||||
mux.RemoveConnByUfrag(ufrag)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllConns returns a PacketConn for each underlying TCPMux
|
||||
func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool) ([]net.PacketConn, error) {
|
||||
if len(m.muxs) == 0 {
|
||||
// Make sure that we either return at least one connection or an error.
|
||||
return nil, errNoTCPMuxAvailable
|
||||
}
|
||||
var conns []net.PacketConn
|
||||
for _, mux := range m.muxs {
|
||||
conn, err := mux.GetConnByUfrag(ufrag, isIPv6)
|
||||
if err != nil {
|
||||
// For now, this implementation is all or none.
|
||||
return nil, err
|
||||
}
|
||||
if conn != nil {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
}
|
||||
return conns, nil
|
||||
}
|
||||
|
||||
// Close the multi mux, no further connections could be created
|
||||
func (m *MultiTCPMuxDefault) Close() error {
|
||||
var err error
|
||||
for _, mux := range m.muxs {
|
||||
if e := mux.Close(); e != nil {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
128
tcp_mux_multi_test.go
Normal file
128
tcp_mux_multi_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
//go:build !js
|
||||
// +build !js
|
||||
|
||||
package ice
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/stun"
|
||||
"github.com/pion/transport/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMultiTCPMux_Recv(t *testing.T) {
|
||||
for name, buffersize := range map[string]int{
|
||||
"no buffer": 0,
|
||||
"buffered 4MB": 4 * 1024 * 1024,
|
||||
} {
|
||||
bufSize := buffersize
|
||||
t.Run(name, func(t *testing.T) {
|
||||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
loggerFactory := logging.NewDefaultLoggerFactory()
|
||||
|
||||
var muxInstances []TCPMux
|
||||
for i := 0; i < 3; i++ {
|
||||
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Port: 0,
|
||||
})
|
||||
require.NoError(t, err, "error starting listener")
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
|
||||
tcpMux := NewTCPMuxDefault(TCPMuxParams{
|
||||
Listener: listener,
|
||||
Logger: loggerFactory.NewLogger("ice"),
|
||||
ReadBufferSize: 20,
|
||||
WriteBufferSize: bufSize,
|
||||
})
|
||||
muxInstances = append(muxInstances, tcpMux)
|
||||
require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
|
||||
}
|
||||
|
||||
multiMux := NewMultiTCPMuxDefault(muxInstances...)
|
||||
defer func() {
|
||||
_ = multiMux.Close()
|
||||
}()
|
||||
|
||||
pktConns, err := multiMux.GetAllConns("myufrag", false)
|
||||
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
||||
|
||||
for _, pktConn := range pktConns {
|
||||
defer func() {
|
||||
_ = pktConn.Close()
|
||||
}()
|
||||
conn, err := net.DialTCP("tcp", nil, pktConn.LocalAddr().(*net.TCPAddr))
|
||||
require.NoError(t, err, "error dialing test tcp connection")
|
||||
|
||||
msg := stun.New()
|
||||
msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
|
||||
msg.Add(stun.AttrUsername, []byte("myufrag:otherufrag"))
|
||||
msg.Encode()
|
||||
|
||||
n, err := writeStreamingPacket(conn, msg.Raw)
|
||||
require.NoError(t, err, "error writing tcp stun packet")
|
||||
|
||||
recv := make([]byte, n)
|
||||
n2, raddr, err := pktConn.ReadFrom(recv)
|
||||
require.NoError(t, err, "error receiving data")
|
||||
assert.Equal(t, conn.LocalAddr(), raddr, "remote tcp address mismatch")
|
||||
assert.Equal(t, n, n2, "received byte size mismatch")
|
||||
assert.Equal(t, msg.Raw, recv, "received bytes mismatch")
|
||||
|
||||
// check echo response
|
||||
n, err = pktConn.WriteTo(recv, conn.LocalAddr())
|
||||
require.NoError(t, err, "error writing echo stun packet")
|
||||
recvEcho := make([]byte, n)
|
||||
n3, err := readStreamingPacket(conn, recvEcho)
|
||||
require.NoError(t, err, "error receiving echo data")
|
||||
assert.Equal(t, n2, n3, "received byte size mismatch")
|
||||
assert.Equal(t, msg.Raw, recvEcho, "received bytes mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) {
|
||||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
loggerFactory := logging.NewDefaultLoggerFactory()
|
||||
|
||||
var tcpMuxInstances []TCPMux
|
||||
for i := 0; i < 3; i++ {
|
||||
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Port: 0,
|
||||
})
|
||||
require.NoError(t, err, "error starting listener")
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
|
||||
tcpMux := NewTCPMuxDefault(TCPMuxParams{
|
||||
Listener: listener,
|
||||
Logger: loggerFactory.NewLogger("ice"),
|
||||
ReadBufferSize: 20,
|
||||
})
|
||||
tcpMuxInstances = append(tcpMuxInstances, tcpMux)
|
||||
}
|
||||
muxMulti := NewMultiTCPMuxDefault(tcpMuxInstances...)
|
||||
|
||||
_, err := muxMulti.GetAllConns("test", false)
|
||||
require.NoError(t, err, "error getting conn by ufrag")
|
||||
|
||||
require.NoError(t, muxMulti.Close(), "error closing tcpMux")
|
||||
|
||||
conn, err := muxMulti.GetAllConns("test", false)
|
||||
assert.Nil(t, conn, "should receive nil because mux is closed")
|
||||
assert.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed")
|
||||
}
|
81
udp_mux_multi.go
Normal file
81
udp_mux_multi.go
Normal file
@@ -0,0 +1,81 @@
|
||||
// Package ice ...
|
||||
//
|
||||
//nolint:dupl
|
||||
package ice
|
||||
|
||||
import "net"
|
||||
|
||||
// AllConnsGetter allows multiple fixed UDP or TCP ports to be used,
|
||||
// each which is multiplexed like UDPMux. AllConnsGetter also acts as
|
||||
// a UDPMux, in which case it will return a single connection for one
|
||||
// of the ports.
|
||||
type AllConnsGetter interface {
|
||||
GetAllConns(ufrag string, isIPv6 bool) ([]net.PacketConn, error)
|
||||
}
|
||||
|
||||
// MultiUDPMuxDefault implements both UDPMux and AllConnsGetter,
|
||||
// allowing users to pass multiple UDPMux instances to the ICE agent
|
||||
// configuration.
|
||||
type MultiUDPMuxDefault struct {
|
||||
muxs []UDPMux
|
||||
}
|
||||
|
||||
// NewMultiUDPMuxDefault creates an instance of MultiUDPMuxDefault that
|
||||
// uses the provided UDPMux instances.
|
||||
func NewMultiUDPMuxDefault(muxs ...UDPMux) *MultiUDPMuxDefault {
|
||||
return &MultiUDPMuxDefault{
|
||||
muxs: muxs,
|
||||
}
|
||||
}
|
||||
|
||||
// GetConn returns a PacketConn given the connection's ufrag and network
|
||||
// creates the connection if an existing one can't be found. This, unlike
|
||||
// GetAllConns, will only return a single PacketConn from the first
|
||||
// mux that was passed in to NewMultiUDPMuxDefault.
|
||||
func (m *MultiUDPMuxDefault) GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error) {
|
||||
// NOTE: We always use the first element here in order to maintain the
|
||||
// behavior of using an existing connection if one exists.
|
||||
if len(m.muxs) == 0 {
|
||||
return nil, errNoUDPMuxAvailable
|
||||
}
|
||||
return m.muxs[0].GetConn(ufrag, isIPv6)
|
||||
}
|
||||
|
||||
// RemoveConnByUfrag stops and removes the muxed packet connection
|
||||
// from all underlying UDPMux instances.
|
||||
func (m *MultiUDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
for _, mux := range m.muxs {
|
||||
mux.RemoveConnByUfrag(ufrag)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllConns returns a PacketConn for each underlying UDPMux
|
||||
func (m *MultiUDPMuxDefault) GetAllConns(ufrag string, isIPv6 bool) ([]net.PacketConn, error) {
|
||||
if len(m.muxs) == 0 {
|
||||
// Make sure that we either return at least one connection or an error.
|
||||
return nil, errNoUDPMuxAvailable
|
||||
}
|
||||
var conns []net.PacketConn
|
||||
for _, mux := range m.muxs {
|
||||
conn, err := mux.GetConn(ufrag, isIPv6)
|
||||
if err != nil {
|
||||
// For now, this implementation is all or none.
|
||||
return nil, err
|
||||
}
|
||||
if conn != nil {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
}
|
||||
return conns, nil
|
||||
}
|
||||
|
||||
// Close the multi mux, no further connections could be created
|
||||
func (m *MultiUDPMuxDefault) Close() error {
|
||||
var err error
|
||||
for _, mux := range m.muxs {
|
||||
if e := mux.Close(); e != nil {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
85
udp_mux_multi_test.go
Normal file
85
udp_mux_multi_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
//go:build !js
|
||||
// +build !js
|
||||
|
||||
package ice
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/transport/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMultiUDPMux(t *testing.T) {
|
||||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
lim := test.TimeOut(time.Second * 30)
|
||||
defer lim.Stop()
|
||||
|
||||
conn1, err := net.ListenUDP(udp, &net.UDPAddr{})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn2, err := net.ListenUDP(udp, &net.UDPAddr{})
|
||||
require.NoError(t, err)
|
||||
|
||||
udpMuxMulti := NewMultiUDPMuxDefault(
|
||||
NewUDPMuxDefault(UDPMuxParams{UDPConn: conn1}),
|
||||
NewUDPMuxDefault(UDPMuxParams{UDPConn: conn2}),
|
||||
)
|
||||
defer func() {
|
||||
_ = udpMuxMulti.Close()
|
||||
_ = conn1.Close()
|
||||
_ = conn2.Close()
|
||||
}()
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp)
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", "udp4")
|
||||
}()
|
||||
|
||||
// skip ipv6 test on i386
|
||||
const ptrSize = 32 << (^uintptr(0) >> 63)
|
||||
if ptrSize != 32 {
|
||||
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", "udp6")
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
require.NoError(t, udpMuxMulti.Close())
|
||||
|
||||
// can't create more connections
|
||||
_, err = udpMuxMulti.GetConn("failufrag", false)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, ufrag string, network string) {
|
||||
pktConns, err := udpMuxMulti.GetAllConns(ufrag, false)
|
||||
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
||||
defer func() {
|
||||
for _, c := range pktConns {
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
require.Len(t, pktConns, len(udpMuxMulti.muxs), "there should be a PacketConn for every mux")
|
||||
|
||||
// Try talking with each PacketConn
|
||||
for _, pktConn := range pktConns {
|
||||
remoteConn, err := net.DialUDP(network, nil, &net.UDPAddr{
|
||||
Port: pktConn.LocalAddr().(*net.UDPAddr).Port,
|
||||
})
|
||||
require.NoError(t, err, "error dialing test udp connection")
|
||||
testMuxConnectionPair(t, pktConn, remoteConn, ufrag)
|
||||
}
|
||||
}
|
@@ -122,8 +122,12 @@ func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, networ
|
||||
})
|
||||
require.NoError(t, err, "error dialing test udp connection")
|
||||
|
||||
testMuxConnectionPair(t, pktConn, remoteConn, ufrag)
|
||||
}
|
||||
|
||||
func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net.UDPConn, ufrag string) {
|
||||
// initial messages are dropped
|
||||
_, err = remoteConn.Write([]byte("dropped bytes"))
|
||||
_, err := remoteConn.Write([]byte("dropped bytes"))
|
||||
require.NoError(t, err)
|
||||
// wait for packet to be consumed
|
||||
time.Sleep(time.Millisecond)
|
||||
|
Reference in New Issue
Block a user