From 351ccd808850434a98100e2c512e8e776a45f2d5 Mon Sep 17 00:00:00 2001 From: philipch07 Date: Fri, 12 Sep 2025 23:03:43 -0400 Subject: [PATCH] Improve code cov for udp_mux_universal --- udp_mux_multi_test.go | 268 ++++++++++++++++++++++++++++++++++++++ udp_mux_universal_test.go | 265 +++++++++++++++++++++++++++++++++++++ 2 files changed, 533 insertions(+) diff --git a/udp_mux_multi_test.go b/udp_mux_multi_test.go index 8fe6536..109ff9b 100644 --- a/udp_mux_multi_test.go +++ b/udp_mux_multi_test.go @@ -7,11 +7,14 @@ package ice import ( + "errors" "net" "sync" "testing" "time" + "github.com/pion/logging" + "github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" ) @@ -138,3 +141,268 @@ func TestUnspecifiedUDPMux(t *testing.T) { require.NoError(t, udpMuxMulti.Close()) } + +func TestMultiUDPMux_GetConn_NoUDPMuxAvailable(t *testing.T) { + conn1, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + defer func() { + _ = conn1.Close() + }() + + conn2, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + defer func() { + _ = conn2.Close() + }() + + mux1 := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn1}) + mux2 := NewUDPMuxDefault(UDPMuxParams{UDPConn: conn2}) + multi := NewMultiUDPMuxDefault(mux1, mux2) + defer func() { + _ = multi.Close() + }() + + // tweak the port so it doesn't match. + addrs := multi.GetListenAddresses() + require.NotEmpty(t, addrs) + + udpAddr, ok := addrs[0].(*net.UDPAddr) + require.True(t, ok, "expected *net.UDPAddr") + + // change the port to something different so addr.String() is not in localAddrToMux. + missing := &net.UDPAddr{IP: udpAddr.IP, Port: udpAddr.Port + 1, Zone: udpAddr.Zone} + + pc, getErr := multi.GetConn("missing-ufrag", missing) + require.Nil(t, pc) + require.ErrorIs(t, getErr, errNoUDPMuxAvailable) +} + +type closeErrUDPMux struct { + UDPMux + ret error +} + +func (w *closeErrUDPMux) Close() error { + _ = w.UDPMux.Close() // ensure underlying resources are released + + return w.ret +} + +var ( + errCloseBoom = errors.New("close boom") + errCloseFirst = errors.New("first close failed") + errCloseSecond = errors.New("second close failed") +) + +func TestMultiUDPMux_Close_PropagatesError(t *testing.T) { + udp1, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + udp2, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + + mux1 := NewUDPMuxDefault(UDPMuxParams{UDPConn: udp1}) + mux2real := NewUDPMuxDefault(UDPMuxParams{UDPConn: udp2}) + + mux2 := &closeErrUDPMux{UDPMux: mux2real, ret: errCloseBoom} + + multi := NewMultiUDPMuxDefault(mux1, mux2) + got := multi.Close() + + require.ErrorIs(t, got, errCloseBoom) +} + +func TestMultiUDPMux_Close_LastErrorWins(t *testing.T) { + udpA, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + udpB, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + + muxAReal := NewUDPMuxDefault(UDPMuxParams{UDPConn: udpA}) + muxBReal := NewUDPMuxDefault(UDPMuxParams{UDPConn: udpB}) + + muxA := &closeErrUDPMux{UDPMux: muxAReal, ret: errCloseFirst} + muxB := &closeErrUDPMux{UDPMux: muxBReal, ret: errCloseSecond} + + multi := NewMultiUDPMuxDefault(muxA, muxB) + got := multi.Close() + + require.ErrorIs(t, got, errCloseSecond) +} + +func TestUDPMuxFromPortOptions_Apply(t *testing.T) { + t.Run("IPFilter", func(t *testing.T) { + var p multiUDPMuxFromPortParam + + keepLoopbackV4 := func(ip net.IP) bool { return ip.IsLoopback() && ip.To4() != nil } + opt := UDPMuxFromPortWithIPFilter(keepLoopbackV4) + opt.apply(&p) + + require.NotNil(t, p.ipFilter) + require.True(t, p.ipFilter(net.ParseIP("127.0.0.1"))) + require.False(t, p.ipFilter(net.ParseIP("8.8.8.8"))) + }) + + t.Run("Networks single", func(t *testing.T) { + var p multiUDPMuxFromPortParam + + opt := UDPMuxFromPortWithNetworks(NetworkTypeUDP4) + opt.apply(&p) + + require.Len(t, p.networks, 1) + require.Equal(t, NetworkTypeUDP4, p.networks[0]) + }) + + t.Run("Networks multiple", func(t *testing.T) { + var p multiUDPMuxFromPortParam + + opt := UDPMuxFromPortWithNetworks(NetworkTypeUDP4, NetworkTypeUDP6) + opt.apply(&p) + + require.Len(t, p.networks, 2) + require.ElementsMatch(t, []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, p.networks) + }) + + t.Run("ReadBufferSize", func(t *testing.T) { + var p multiUDPMuxFromPortParam + + opt := UDPMuxFromPortWithReadBufferSize(4096) + opt.apply(&p) + + require.Equal(t, 4096, p.readBufferSize) + }) + + t.Run("WriteBufferSize", func(t *testing.T) { + var p multiUDPMuxFromPortParam + + opt := UDPMuxFromPortWithWriteBufferSize(8192) + opt.apply(&p) + + require.Equal(t, 8192, p.writeBufferSize) + }) + + t.Run("Logger", func(t *testing.T) { + var p multiUDPMuxFromPortParam + + logger := logging.NewDefaultLoggerFactory().NewLogger("ice-test") + opt := UDPMuxFromPortWithLogger(logger) + opt.apply(&p) + + require.NotNil(t, p.logger) + require.Equal(t, logger, p.logger) + }) + + t.Run("Net", func(t *testing.T) { + var p multiUDPMuxFromPortParam + + n, err := stdnet.NewNet() + require.NoError(t, err) + + opt := UDPMuxFromPortWithNet(n) + opt.apply(&p) + + require.NotNil(t, p.net) + require.Equal(t, n, p.net) + }) +} + +func TestNewMultiUDPMuxFromPort_PortInUse_ListenErrorAndCleanup(t *testing.T) { + pre, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + defer func() { + _ = pre.Close() + }() + + srvAddr, ok := pre.LocalAddr().(*net.UDPAddr) + require.True(t, ok, "pre.LocalAddr is not *net.UDPAddr") + port := srvAddr.Port + + multi, buildErr := NewMultiUDPMuxFromPort( + port, + UDPMuxFromPortWithLoopback(), + UDPMuxFromPortWithNetworks(NetworkTypeUDP4), + ) + + require.Nil(t, multi) + require.Error(t, buildErr) +} + +func TestNewMultiUDPMuxFromPort_Success_SetsBuffers(t *testing.T) { + multi, err := NewMultiUDPMuxFromPort( + 0, + UDPMuxFromPortWithLoopback(), + UDPMuxFromPortWithNetworks(NetworkTypeUDP4), + UDPMuxFromPortWithReadBufferSize(4096), + UDPMuxFromPortWithWriteBufferSize(8192), + ) + require.NoError(t, err) + require.NotNil(t, multi) + + addrs := multi.GetListenAddresses() + require.NotEmpty(t, addrs) + + require.NoError(t, multi.Close()) +} + +func TestNewMultiUDPMuxFromPort_CleanupClosesAll(t *testing.T) { + stdNet, err := stdnet.NewNet() + require.NoError(t, err) + + _, addrs, err := localInterfaces(stdNet, nil, nil, []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, true) + require.NoError(t, err) + if len(addrs) < 2 { + t.Skip("need at least two local addresses to hit partial-success then failure") + } + + second := addrs[1] + l2, err := stdNet.ListenUDP("udp", &net.UDPAddr{ + IP: second.AsSlice(), + Port: 0, + Zone: second.Zone(), + }) + require.NoError(t, err) + defer func() { + _ = l2.Close() + }() + + udpAddr2, ok := l2.LocalAddr().(*net.UDPAddr) + require.True(t, ok, "LocalAddr is not *net.UDPAddr") + picked := udpAddr2.Port + + preBinds := []net.PacketConn{l2} + for i := 2; i < len(addrs); i++ { + a := addrs[i] + l, e := stdNet.ListenUDP("udp", &net.UDPAddr{ + IP: a.AsSlice(), + Port: picked, + Zone: a.Zone(), + }) + if e == nil { + preBinds = append(preBinds, l) + } + } + t.Cleanup(func() { + for _, c := range preBinds { + _ = c.Close() + } + }) + + require.GreaterOrEqual(t, len(preBinds), 1, "need at least one prebound address after the first") + + multi, buildErr := NewMultiUDPMuxFromPort( + picked, + UDPMuxFromPortWithNet(stdNet), + UDPMuxFromPortWithNetworks(NetworkTypeUDP4, NetworkTypeUDP6), + UDPMuxFromPortWithLoopback(), + ) + require.Nil(t, multi) + require.Error(t, buildErr) + + first := addrs[0] + rebind, err := stdNet.ListenUDP("udp", &net.UDPAddr{ + IP: first.AsSlice(), + Port: picked, + Zone: first.Zone(), + }) + require.NoError(t, err, "expected first address/port to be free after cleanup") + _ = rebind.Close() +} diff --git a/udp_mux_universal_test.go b/udp_mux_universal_test.go index 1f648e6..20a4351 100644 --- a/udp_mux_universal_test.go +++ b/udp_mux_universal_test.go @@ -7,11 +7,15 @@ package ice import ( + "encoding/binary" + "io" "net" "sync" "testing" "time" + "github.com/pion/ice/v4/internal/fakenet" + "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/stretchr/testify/require" ) @@ -130,3 +134,264 @@ func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag require.NotNil(t, err) require.Nil(t, address) } + +func TestUniversalUDPMux_GetConnForURL_UniquePerURL(t *testing.T) { + conn, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + + udpMux := NewUniversalUDPMuxDefault(UniversalUDPMuxParams{ + Logger: nil, + UDPConn: conn, + }) + defer func() { + _ = udpMux.Close() + _ = conn.Close() + }() + + lf := udpMux.LocalAddr() + require.NotNil(t, lf) + + // different URLs -> must be distinct muxed conns + pc1, err := udpMux.GetConnForURL("ufragX", "stun:serverA", lf) + require.NoError(t, err) + defer func() { + _ = pc1.Close() + }() + + pc2, err := udpMux.GetConnForURL("ufragX", "stun:serverB", lf) + require.NoError(t, err) + defer func() { + _ = pc2.Close() + }() + + c1, ok := pc1.(*udpMuxedConn) + require.True(t, ok, "pc1 is not *udpMuxedConn") + c2, ok := pc2.(*udpMuxedConn) + require.True(t, ok, "pc2 is not *udpMuxedConn") + require.NotEqual(t, c1, c2, "expected distinct muxed conns for different URLs with same ufrag") + + pc1b, err := udpMux.GetConnForURL("ufragX", "stun:serverA", lf) + require.NoError(t, err) + defer func() { + _ = pc1b.Close() + }() + + c1b, ok := pc1b.(*udpMuxedConn) + require.True(t, ok, "pc1b is not *udpMuxedConn") + + require.Equal(t, c1, c1b, "expected same muxed conn when requesting the same (ufrag,url)") +} + +func newLogger() logging.LeveledLogger { + return logging.NewDefaultLoggerFactory().NewLogger("ice") +} + +func newFakenetReader(t *testing.T, payload []byte) *fakenet.PacketConn { + t.Helper() + r, w := net.Pipe() + go func() { + _, _ = w.Write(payload) + _ = w.Close() + }() + pc := &fakenet.PacketConn{} + pc.Conn = r + + return pc +} + +func Test_udpConn_ReadFrom_STUNDecodeError(t *testing.T) { + server, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + t.Cleanup(func() { _ = server.Close() }) + + srvAddr, ok := server.LocalAddr().(*net.UDPAddr) + require.True(t, ok, "server.LocalAddr is not *net.UDPAddr") + + client, err := net.DialUDP("udp4", nil, srvAddr) + require.NoError(t, err) + t.Cleanup(func() { _ = client.Close() }) + + // build a valid STUN Binding Request then corrupt the header length field. + msg := stun.New() + msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} + msg.Encode() + raw := append([]byte{}, msg.Raw...) + decl := binary.BigEndian.Uint16(raw[2:4]) + binary.BigEndian.PutUint16(raw[2:4], decl+4) // makes Decode() fail + + _, err = client.Write(raw) + require.NoError(t, err) + + u := &udpConn{PacketConn: server, mux: nil, logger: newLogger()} + _ = server.SetReadDeadline(time.Now().Add(time.Second)) + + buf := make([]byte, 1500) + n, addr, gotErr := u.ReadFrom(buf) + + require.Equal(t, len(raw), n) + require.IsType(t, &net.UDPAddr{}, addr) + require.NoError(t, gotErr) +} + +func Test_udpConn_ReadFrom_AddrNotUDP(t *testing.T) { + msg := stun.New() + msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} + msg.Encode() + + pc := newFakenetReader(t, msg.Raw) + u := &udpConn{PacketConn: pc, mux: nil, logger: newLogger()} + + buf := make([]byte, 1500) + n, addr, gotErr := u.ReadFrom(buf) + + require.Equal(t, len(msg.Raw), n) + require.NoError(t, gotErr) + + require.NotNil(t, addr) + _, isUDP := addr.(*net.UDPAddr) + require.False(t, isUDP, "expected a non-UDP addr from fakenet.PacketConn") +} + +func Test_udpConn_ReadFrom_XOR(t *testing.T) { + server, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + t.Cleanup(func() { _ = server.Close() }) + + srvAddr, ok := server.LocalAddr().(*net.UDPAddr) + require.True(t, ok, "server.LocalAddr is not *net.UDPAddr") + + client, err := net.DialUDP("udp4", nil, srvAddr) + require.NoError(t, err) + t.Cleanup(func() { _ = client.Close() }) + + // success response + short XORMappedAddress value will make GetFrom() fail. + msg := stun.New() + msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassSuccessResponse} + msg.Add(stun.AttrXORMappedAddress, []byte{0x00}) // intentionally invalid + msg.Encode() + + mux := &UniversalUDPMuxDefault{ + UDPMuxDefault: &UDPMuxDefault{}, + xorMappedMap: map[string]*xorMapped{ + client.LocalAddr().String(): { + waitAddrReceived: make(chan struct{}), + expiresAt: time.Now().Add(time.Minute), + }, + }, + } + + _, err = client.Write(msg.Raw) + require.NoError(t, err) + + u := &udpConn{PacketConn: server, mux: mux, logger: newLogger()} + _ = server.SetReadDeadline(time.Now().Add(time.Second)) + + buf := make([]byte, 1500) + n, addr, gotErr := u.ReadFrom(buf) + + require.Equal(t, len(msg.Raw), n) + require.IsType(t, &net.UDPAddr{}, addr) + require.NoError(t, gotErr) +} + +func Test_udpConn_ReadFrom_NonSTUN(t *testing.T) { + payload := []byte("not a stun packet") + pc := newFakenetReader(t, payload) + + u := &udpConn{PacketConn: pc, mux: nil, logger: newLogger()} + + buf := make([]byte, 1500) + n, addr, gotErr := u.ReadFrom(buf) + + require.NoError(t, gotErr) + require.Equal(t, len(payload), n) + require.Equal(t, payload, buf[:n]) + + require.NotNil(t, addr) + _, isUDP := addr.(*net.UDPAddr) + require.False(t, isUDP, "expected a non-UDP addr from fakenet.PacketConn") +} + +func TestUniversalUDPMux_handleXORMappedResponse_NoMapping(t *testing.T) { + mux := &UniversalUDPMuxDefault{ + UDPMuxDefault: &UDPMuxDefault{}, + xorMappedMap: make(map[string]*xorMapped), + } + + stunSrv := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 3478} + msg := stun.New() + + err := mux.handleXORMappedResponse(stunSrv, msg) + require.ErrorIs(t, err, errNoXorAddrMapping) +} + +func newFakePC(t *testing.T) (*fakenet.PacketConn, net.Conn, net.Conn) { + t.Helper() + c1, c2 := net.Pipe() + pc := &fakenet.PacketConn{} + pc.Conn = c1 + + return pc, c1, c2 +} + +func TestUniversalUDPMux_GetXORMappedAddr_Pending_WriteError(t *testing.T) { + serverAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 3478} + + pc, c1, c2 := newFakePC(t) + _ = c2.Close() // other end unused + _ = c1.Close() // force future WriteTo to error + + mux := &UniversalUDPMuxDefault{ + UDPMuxDefault: &UDPMuxDefault{}, + params: UniversalUDPMuxParams{ + UDPConn: pc, // writeSTUN will call WriteTo on this fakenet PacketConn + }, + xorMappedMap: map[string]*xorMapped{ + serverAddr.String(): { + waitAddrReceived: make(chan struct{}), + expiresAt: time.Now().Add(time.Minute), + }, + }, + } + + addr, err := mux.GetXORMappedAddr(serverAddr, time.Second) + require.Nil(t, addr) + require.ErrorIs(t, err, errWriteSTUNMessage) +} + +func TestUniversalUDPMux_GetXORMappedAddr_WaitClosed_NoAddr(t *testing.T) { + serverAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 3478} + + pc, c1, c2 := newFakePC(t) + drainDone := make(chan struct{}) + go func() { + _, _ = io.Copy(io.Discard, c2) + close(drainDone) + }() + t.Cleanup(func() { + _ = c1.Close() + _ = c2.Close() + <-drainDone + }) + + waitCh := make(chan struct{}) + close(waitCh) + + mux := &UniversalUDPMuxDefault{ + UDPMuxDefault: &UDPMuxDefault{}, + params: UniversalUDPMuxParams{ + UDPConn: pc, + }, + xorMappedMap: map[string]*xorMapped{ + serverAddr.String(): { + addr: nil, + waitAddrReceived: waitCh, + expiresAt: time.Now().Add(time.Minute), + }, + }, + } + + addr, err := mux.GetXORMappedAddr(serverAddr, time.Second) + require.Nil(t, addr) + require.ErrorIs(t, err, errNoXorAddrMapping) +}