Improve code cov for udp_mux_universal

This commit is contained in:
philipch07
2025-09-12 23:03:43 -04:00
committed by philipch07
parent 1a618559fc
commit 351ccd8088
2 changed files with 533 additions and 0 deletions

View File

@@ -7,11 +7,14 @@
package ice package ice
import ( import (
"errors"
"net" "net"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/pion/logging"
"github.com/pion/transport/v3/stdnet"
"github.com/pion/transport/v3/test" "github.com/pion/transport/v3/test"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -138,3 +141,268 @@ func TestUnspecifiedUDPMux(t *testing.T) {
require.NoError(t, udpMuxMulti.Close()) require.NoError(t, udpMuxMulti.Close())
} }
func TestMultiUDPMux_GetConn_NoUDPMuxAvailable(t *testing.T) {
conn1, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
defer func() {
_ = conn1.Close()
}()
conn2, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
defer func() {
_ = conn2.Close()
}()
mux1 := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn1})
mux2 := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn2})
multi := NewMultiUDPMuxDefault(mux1, mux2)
defer func() {
_ = multi.Close()
}()
// tweak the port so it doesn't match.
addrs := multi.GetListenAddresses()
require.NotEmpty(t, addrs)
udpAddr, ok := addrs[0].(*net.UDPAddr)
require.True(t, ok, "expected *net.UDPAddr")
// change the port to something different so addr.String() is not in localAddrToMux.
missing := &net.UDPAddr{IP: udpAddr.IP, Port: udpAddr.Port + 1, Zone: udpAddr.Zone}
pc, getErr := multi.GetConn("missing-ufrag", missing)
require.Nil(t, pc)
require.ErrorIs(t, getErr, errNoUDPMuxAvailable)
}
type closeErrUDPMux struct {
UDPMux
ret error
}
func (w *closeErrUDPMux) Close() error {
_ = w.UDPMux.Close() // ensure underlying resources are released
return w.ret
}
var (
errCloseBoom = errors.New("close boom")
errCloseFirst = errors.New("first close failed")
errCloseSecond = errors.New("second close failed")
)
func TestMultiUDPMux_Close_PropagatesError(t *testing.T) {
udp1, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
udp2, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
mux1 := NewUDPMuxDefault(UDPMuxParams{UDPConn: udp1})
mux2real := NewUDPMuxDefault(UDPMuxParams{UDPConn: udp2})
mux2 := &closeErrUDPMux{UDPMux: mux2real, ret: errCloseBoom}
multi := NewMultiUDPMuxDefault(mux1, mux2)
got := multi.Close()
require.ErrorIs(t, got, errCloseBoom)
}
func TestMultiUDPMux_Close_LastErrorWins(t *testing.T) {
udpA, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
udpB, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
muxAReal := NewUDPMuxDefault(UDPMuxParams{UDPConn: udpA})
muxBReal := NewUDPMuxDefault(UDPMuxParams{UDPConn: udpB})
muxA := &closeErrUDPMux{UDPMux: muxAReal, ret: errCloseFirst}
muxB := &closeErrUDPMux{UDPMux: muxBReal, ret: errCloseSecond}
multi := NewMultiUDPMuxDefault(muxA, muxB)
got := multi.Close()
require.ErrorIs(t, got, errCloseSecond)
}
func TestUDPMuxFromPortOptions_Apply(t *testing.T) {
t.Run("IPFilter", func(t *testing.T) {
var p multiUDPMuxFromPortParam
keepLoopbackV4 := func(ip net.IP) bool { return ip.IsLoopback() && ip.To4() != nil }
opt := UDPMuxFromPortWithIPFilter(keepLoopbackV4)
opt.apply(&p)
require.NotNil(t, p.ipFilter)
require.True(t, p.ipFilter(net.ParseIP("127.0.0.1")))
require.False(t, p.ipFilter(net.ParseIP("8.8.8.8")))
})
t.Run("Networks single", func(t *testing.T) {
var p multiUDPMuxFromPortParam
opt := UDPMuxFromPortWithNetworks(NetworkTypeUDP4)
opt.apply(&p)
require.Len(t, p.networks, 1)
require.Equal(t, NetworkTypeUDP4, p.networks[0])
})
t.Run("Networks multiple", func(t *testing.T) {
var p multiUDPMuxFromPortParam
opt := UDPMuxFromPortWithNetworks(NetworkTypeUDP4, NetworkTypeUDP6)
opt.apply(&p)
require.Len(t, p.networks, 2)
require.ElementsMatch(t, []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, p.networks)
})
t.Run("ReadBufferSize", func(t *testing.T) {
var p multiUDPMuxFromPortParam
opt := UDPMuxFromPortWithReadBufferSize(4096)
opt.apply(&p)
require.Equal(t, 4096, p.readBufferSize)
})
t.Run("WriteBufferSize", func(t *testing.T) {
var p multiUDPMuxFromPortParam
opt := UDPMuxFromPortWithWriteBufferSize(8192)
opt.apply(&p)
require.Equal(t, 8192, p.writeBufferSize)
})
t.Run("Logger", func(t *testing.T) {
var p multiUDPMuxFromPortParam
logger := logging.NewDefaultLoggerFactory().NewLogger("ice-test")
opt := UDPMuxFromPortWithLogger(logger)
opt.apply(&p)
require.NotNil(t, p.logger)
require.Equal(t, logger, p.logger)
})
t.Run("Net", func(t *testing.T) {
var p multiUDPMuxFromPortParam
n, err := stdnet.NewNet()
require.NoError(t, err)
opt := UDPMuxFromPortWithNet(n)
opt.apply(&p)
require.NotNil(t, p.net)
require.Equal(t, n, p.net)
})
}
func TestNewMultiUDPMuxFromPort_PortInUse_ListenErrorAndCleanup(t *testing.T) {
pre, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
defer func() {
_ = pre.Close()
}()
srvAddr, ok := pre.LocalAddr().(*net.UDPAddr)
require.True(t, ok, "pre.LocalAddr is not *net.UDPAddr")
port := srvAddr.Port
multi, buildErr := NewMultiUDPMuxFromPort(
port,
UDPMuxFromPortWithLoopback(),
UDPMuxFromPortWithNetworks(NetworkTypeUDP4),
)
require.Nil(t, multi)
require.Error(t, buildErr)
}
func TestNewMultiUDPMuxFromPort_Success_SetsBuffers(t *testing.T) {
multi, err := NewMultiUDPMuxFromPort(
0,
UDPMuxFromPortWithLoopback(),
UDPMuxFromPortWithNetworks(NetworkTypeUDP4),
UDPMuxFromPortWithReadBufferSize(4096),
UDPMuxFromPortWithWriteBufferSize(8192),
)
require.NoError(t, err)
require.NotNil(t, multi)
addrs := multi.GetListenAddresses()
require.NotEmpty(t, addrs)
require.NoError(t, multi.Close())
}
func TestNewMultiUDPMuxFromPort_CleanupClosesAll(t *testing.T) {
stdNet, err := stdnet.NewNet()
require.NoError(t, err)
_, addrs, err := localInterfaces(stdNet, nil, nil, []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, true)
require.NoError(t, err)
if len(addrs) < 2 {
t.Skip("need at least two local addresses to hit partial-success then failure")
}
second := addrs[1]
l2, err := stdNet.ListenUDP("udp", &net.UDPAddr{
IP: second.AsSlice(),
Port: 0,
Zone: second.Zone(),
})
require.NoError(t, err)
defer func() {
_ = l2.Close()
}()
udpAddr2, ok := l2.LocalAddr().(*net.UDPAddr)
require.True(t, ok, "LocalAddr is not *net.UDPAddr")
picked := udpAddr2.Port
preBinds := []net.PacketConn{l2}
for i := 2; i < len(addrs); i++ {
a := addrs[i]
l, e := stdNet.ListenUDP("udp", &net.UDPAddr{
IP: a.AsSlice(),
Port: picked,
Zone: a.Zone(),
})
if e == nil {
preBinds = append(preBinds, l)
}
}
t.Cleanup(func() {
for _, c := range preBinds {
_ = c.Close()
}
})
require.GreaterOrEqual(t, len(preBinds), 1, "need at least one prebound address after the first")
multi, buildErr := NewMultiUDPMuxFromPort(
picked,
UDPMuxFromPortWithNet(stdNet),
UDPMuxFromPortWithNetworks(NetworkTypeUDP4, NetworkTypeUDP6),
UDPMuxFromPortWithLoopback(),
)
require.Nil(t, multi)
require.Error(t, buildErr)
first := addrs[0]
rebind, err := stdNet.ListenUDP("udp", &net.UDPAddr{
IP: first.AsSlice(),
Port: picked,
Zone: first.Zone(),
})
require.NoError(t, err, "expected first address/port to be free after cleanup")
_ = rebind.Close()
}

View File

@@ -7,11 +7,15 @@
package ice package ice
import ( import (
"encoding/binary"
"io"
"net" "net"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/pion/ice/v4/internal/fakenet"
"github.com/pion/logging"
"github.com/pion/stun/v3" "github.com/pion/stun/v3"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -130,3 +134,264 @@ func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag
require.NotNil(t, err) require.NotNil(t, err)
require.Nil(t, address) require.Nil(t, address)
} }
func TestUniversalUDPMux_GetConnForURL_UniquePerURL(t *testing.T) {
conn, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
udpMux := NewUniversalUDPMuxDefault(UniversalUDPMuxParams{
Logger: nil,
UDPConn: conn,
})
defer func() {
_ = udpMux.Close()
_ = conn.Close()
}()
lf := udpMux.LocalAddr()
require.NotNil(t, lf)
// different URLs -> must be distinct muxed conns
pc1, err := udpMux.GetConnForURL("ufragX", "stun:serverA", lf)
require.NoError(t, err)
defer func() {
_ = pc1.Close()
}()
pc2, err := udpMux.GetConnForURL("ufragX", "stun:serverB", lf)
require.NoError(t, err)
defer func() {
_ = pc2.Close()
}()
c1, ok := pc1.(*udpMuxedConn)
require.True(t, ok, "pc1 is not *udpMuxedConn")
c2, ok := pc2.(*udpMuxedConn)
require.True(t, ok, "pc2 is not *udpMuxedConn")
require.NotEqual(t, c1, c2, "expected distinct muxed conns for different URLs with same ufrag")
pc1b, err := udpMux.GetConnForURL("ufragX", "stun:serverA", lf)
require.NoError(t, err)
defer func() {
_ = pc1b.Close()
}()
c1b, ok := pc1b.(*udpMuxedConn)
require.True(t, ok, "pc1b is not *udpMuxedConn")
require.Equal(t, c1, c1b, "expected same muxed conn when requesting the same (ufrag,url)")
}
func newLogger() logging.LeveledLogger {
return logging.NewDefaultLoggerFactory().NewLogger("ice")
}
func newFakenetReader(t *testing.T, payload []byte) *fakenet.PacketConn {
t.Helper()
r, w := net.Pipe()
go func() {
_, _ = w.Write(payload)
_ = w.Close()
}()
pc := &fakenet.PacketConn{}
pc.Conn = r
return pc
}
func Test_udpConn_ReadFrom_STUNDecodeError(t *testing.T) {
server, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
t.Cleanup(func() { _ = server.Close() })
srvAddr, ok := server.LocalAddr().(*net.UDPAddr)
require.True(t, ok, "server.LocalAddr is not *net.UDPAddr")
client, err := net.DialUDP("udp4", nil, srvAddr)
require.NoError(t, err)
t.Cleanup(func() { _ = client.Close() })
// build a valid STUN Binding Request then corrupt the header length field.
msg := stun.New()
msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
msg.Encode()
raw := append([]byte{}, msg.Raw...)
decl := binary.BigEndian.Uint16(raw[2:4])
binary.BigEndian.PutUint16(raw[2:4], decl+4) // makes Decode() fail
_, err = client.Write(raw)
require.NoError(t, err)
u := &udpConn{PacketConn: server, mux: nil, logger: newLogger()}
_ = server.SetReadDeadline(time.Now().Add(time.Second))
buf := make([]byte, 1500)
n, addr, gotErr := u.ReadFrom(buf)
require.Equal(t, len(raw), n)
require.IsType(t, &net.UDPAddr{}, addr)
require.NoError(t, gotErr)
}
func Test_udpConn_ReadFrom_AddrNotUDP(t *testing.T) {
msg := stun.New()
msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
msg.Encode()
pc := newFakenetReader(t, msg.Raw)
u := &udpConn{PacketConn: pc, mux: nil, logger: newLogger()}
buf := make([]byte, 1500)
n, addr, gotErr := u.ReadFrom(buf)
require.Equal(t, len(msg.Raw), n)
require.NoError(t, gotErr)
require.NotNil(t, addr)
_, isUDP := addr.(*net.UDPAddr)
require.False(t, isUDP, "expected a non-UDP addr from fakenet.PacketConn")
}
func Test_udpConn_ReadFrom_XOR(t *testing.T) {
server, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
t.Cleanup(func() { _ = server.Close() })
srvAddr, ok := server.LocalAddr().(*net.UDPAddr)
require.True(t, ok, "server.LocalAddr is not *net.UDPAddr")
client, err := net.DialUDP("udp4", nil, srvAddr)
require.NoError(t, err)
t.Cleanup(func() { _ = client.Close() })
// success response + short XORMappedAddress value will make GetFrom() fail.
msg := stun.New()
msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassSuccessResponse}
msg.Add(stun.AttrXORMappedAddress, []byte{0x00}) // intentionally invalid
msg.Encode()
mux := &UniversalUDPMuxDefault{
UDPMuxDefault: &UDPMuxDefault{},
xorMappedMap: map[string]*xorMapped{
client.LocalAddr().String(): {
waitAddrReceived: make(chan struct{}),
expiresAt: time.Now().Add(time.Minute),
},
},
}
_, err = client.Write(msg.Raw)
require.NoError(t, err)
u := &udpConn{PacketConn: server, mux: mux, logger: newLogger()}
_ = server.SetReadDeadline(time.Now().Add(time.Second))
buf := make([]byte, 1500)
n, addr, gotErr := u.ReadFrom(buf)
require.Equal(t, len(msg.Raw), n)
require.IsType(t, &net.UDPAddr{}, addr)
require.NoError(t, gotErr)
}
func Test_udpConn_ReadFrom_NonSTUN(t *testing.T) {
payload := []byte("not a stun packet")
pc := newFakenetReader(t, payload)
u := &udpConn{PacketConn: pc, mux: nil, logger: newLogger()}
buf := make([]byte, 1500)
n, addr, gotErr := u.ReadFrom(buf)
require.NoError(t, gotErr)
require.Equal(t, len(payload), n)
require.Equal(t, payload, buf[:n])
require.NotNil(t, addr)
_, isUDP := addr.(*net.UDPAddr)
require.False(t, isUDP, "expected a non-UDP addr from fakenet.PacketConn")
}
func TestUniversalUDPMux_handleXORMappedResponse_NoMapping(t *testing.T) {
mux := &UniversalUDPMuxDefault{
UDPMuxDefault: &UDPMuxDefault{},
xorMappedMap: make(map[string]*xorMapped),
}
stunSrv := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 3478}
msg := stun.New()
err := mux.handleXORMappedResponse(stunSrv, msg)
require.ErrorIs(t, err, errNoXorAddrMapping)
}
func newFakePC(t *testing.T) (*fakenet.PacketConn, net.Conn, net.Conn) {
t.Helper()
c1, c2 := net.Pipe()
pc := &fakenet.PacketConn{}
pc.Conn = c1
return pc, c1, c2
}
func TestUniversalUDPMux_GetXORMappedAddr_Pending_WriteError(t *testing.T) {
serverAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 3478}
pc, c1, c2 := newFakePC(t)
_ = c2.Close() // other end unused
_ = c1.Close() // force future WriteTo to error
mux := &UniversalUDPMuxDefault{
UDPMuxDefault: &UDPMuxDefault{},
params: UniversalUDPMuxParams{
UDPConn: pc, // writeSTUN will call WriteTo on this fakenet PacketConn
},
xorMappedMap: map[string]*xorMapped{
serverAddr.String(): {
waitAddrReceived: make(chan struct{}),
expiresAt: time.Now().Add(time.Minute),
},
},
}
addr, err := mux.GetXORMappedAddr(serverAddr, time.Second)
require.Nil(t, addr)
require.ErrorIs(t, err, errWriteSTUNMessage)
}
func TestUniversalUDPMux_GetXORMappedAddr_WaitClosed_NoAddr(t *testing.T) {
serverAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 3478}
pc, c1, c2 := newFakePC(t)
drainDone := make(chan struct{})
go func() {
_, _ = io.Copy(io.Discard, c2)
close(drainDone)
}()
t.Cleanup(func() {
_ = c1.Close()
_ = c2.Close()
<-drainDone
})
waitCh := make(chan struct{})
close(waitCh)
mux := &UniversalUDPMuxDefault{
UDPMuxDefault: &UDPMuxDefault{},
params: UniversalUDPMuxParams{
UDPConn: pc,
},
xorMappedMap: map[string]*xorMapped{
serverAddr.String(): {
addr: nil,
waitAddrReceived: waitCh,
expiresAt: time.Now().Add(time.Minute),
},
},
}
addr, err := mux.GetXORMappedAddr(serverAddr, time.Second)
require.Nil(t, addr)
require.ErrorIs(t, err, errNoXorAddrMapping)
}