Simplify handling if no TCP mux is configured

This commit is contained in:
Steffen Vogel
2023-04-18 18:20:06 +02:00
parent f40dd65abb
commit c596a7cc2b
5 changed files with 14 additions and 46 deletions

View File

@@ -297,6 +297,9 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
log: log,
net: config.Net,
proxyDialer: config.ProxyDialer,
tcpMux: config.TCPMux,
udpMux: config.UDPMux,
udpMuxSrflx: config.UDPMuxSrflx,
mDNSMode: mDNSMode,
mDNSName: mDNSName,
@@ -314,13 +317,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
includeLoopback: config.IncludeLoopback,
}
a.tcpMux = config.TCPMux
if a.tcpMux == nil {
a.tcpMux = newInvalidTCPMux()
}
a.udpMux = config.UDPMux
a.udpMuxSrflx = config.UDPMuxSrflx
if a.net == nil {
a.net, err = stdnet.NewNet()
if err != nil {
@@ -906,7 +902,9 @@ func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error)
}
func (a *Agent) removeUfragFromMux() {
if a.tcpMux != nil {
a.tcpMux.RemoveConnByUfrag(a.localUfrag)
}
if a.udpMux != nil {
a.udpMux.RemoveConnByUfrag(a.localUfrag)
}

View File

@@ -103,9 +103,6 @@ var (
// ErrRunCanceled indicates a run operation was canceled by its individual done
ErrRunCanceled = errors.New("run was canceled by done")
// ErrTCPMuxNotInitialized indicates TCPMux is not initialized and that invalidTCPMux is used.
ErrTCPMuxNotInitialized = errors.New("TCPMux is not initialized")
// ErrTCPRemoteAddrAlreadyExists indicates we already have the connection with same remote addr.
ErrTCPRemoteAddrAlreadyExists = errors.New("conn with same remote addr already exists")

View File

@@ -6,7 +6,6 @@ package ice
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
@@ -166,24 +165,24 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
switch network {
case tcp:
if a.tcpMux == nil {
continue
}
// Handle ICE TCP passive mode
var muxConns []net.PacketConn
if multi, ok := a.tcpMux.(AllConnsGetter); ok {
a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag)
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil, ip)
if err != nil {
if !errors.Is(err, ErrTCPMuxNotInitialized) {
a.log.Warnf("error getting all tcp conns by ufrag: %s %s %s", network, ip, a.localUfrag)
}
continue
}
} else {
a.log.Debugf("GetConn by ufrag: %s", a.localUfrag)
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil, ip)
if err != nil {
if !errors.Is(err, ErrTCPMuxNotInitialized) {
a.log.Warnf("error getting tcp conn by ufrag: %s %s %s", network, ip, a.localUfrag)
}
continue
}
muxConns = []net.PacketConn{conn}

View File

@@ -20,36 +20,13 @@ var ErrGetTransportAddress = errors.New("failed to get local transport address")
// TCPMux is allows grouping multiple TCP net.Conns and using them like UDP
// net.PacketConns. The main implementation of this is TCPMuxDefault, and this
// interface exists to:
// 1. prevent SEGV panics when TCPMuxDefault is not initialized by using the
// invalidTCPMux implementation, and
// 2. allow mocking in tests.
// interface exists to allow mocking in tests.
type TCPMux interface {
io.Closer
GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string)
}
// invalidTCPMux is an implementation of TCPMux that always returns ErrTCPMuxNotInitialized.
type invalidTCPMux struct{}
func newInvalidTCPMux() *invalidTCPMux {
return &invalidTCPMux{}
}
// Close implements TCPMux interface.
func (m *invalidTCPMux) Close() error {
return ErrTCPMuxNotInitialized
}
// GetConnByUfrag implements TCPMux interface.
func (m *invalidTCPMux) GetConnByUfrag(string, bool, net.IP) (net.PacketConn, error) {
return nil, ErrTCPMuxNotInitialized
}
// RemoveConnByUfrag implements TCPMux interface.
func (m *invalidTCPMux) RemoveConnByUfrag(string) {}
type ipAddr string
// TCPMuxDefault muxes TCP net.Conns into net.PacketConns and groups them by

View File

@@ -15,10 +15,7 @@ import (
"github.com/stretchr/testify/require"
)
var (
_ TCPMux = &TCPMuxDefault{}
_ TCPMux = &invalidTCPMux{}
)
var _ TCPMux = &TCPMuxDefault{}
func TestTCPMux_Recv(t *testing.T) {
for name, bufSize := range map[string]int{