Improve code cov by adding various tests

This commit is contained in:
philipch07
2025-09-14 13:24:04 -04:00
committed by philipch07
parent 14b3cccabf
commit f7437da850
13 changed files with 1350 additions and 7 deletions

View File

@@ -7,7 +7,9 @@
package ice package ice
import ( import (
"context"
"fmt" "fmt"
"io"
"net" "net"
"net/netip" "net/netip"
"sync/atomic" "sync/atomic"
@@ -290,3 +292,122 @@ func TestActiveTCP_Respect_NetworkTypes(t *testing.T) {
require.NoError(t, tcpListener.Close()) require.NoError(t, tcpListener.Close())
require.Equal(t, uint64(0), atomic.LoadUint64(&incomingTCPCount)) require.Equal(t, uint64(0), atomic.LoadUint64(&incomingTCPCount))
} }
func TestNewActiveTCPConn_LocalAddrError_EarlyReturn(t *testing.T) {
defer test.CheckRoutines(t)()
logger := logging.NewDefaultLoggerFactory().NewLogger("ice")
// an invalid local address so getTCPAddrOnInterface fails at ResolveTCPAddr.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ra := netip.MustParseAddrPort("127.0.0.1:1")
a := newActiveTCPConn(ctx, "this_is_not_a_valid_addr", ra, logger)
require.NotNil(t, a)
require.True(t, a.closed.Load(), "should be closed on early return error")
la := a.LocalAddr()
require.NotNil(t, la)
}
func TestActiveTCPConn_ReadLoop_BufferWriteError(t *testing.T) {
defer test.CheckRoutines(t)()
tcpListener, err := net.Listen("tcp", "127.0.0.1:0") // nolint: noctx
require.NoError(t, err)
defer func() { _ = tcpListener.Close() }()
ra := netip.MustParseAddrPort(tcpListener.Addr().String())
logger := logging.NewDefaultLoggerFactory().NewLogger("ice")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
a := newActiveTCPConn(ctx, "127.0.0.1:0", ra, logger)
require.NotNil(t, a)
srvConn, err := tcpListener.Accept()
require.NoError(t, err)
require.NoError(t, a.readBuffer.Close())
_, err = writeStreamingPacket(srvConn, []byte("ping"))
require.NoError(t, err)
require.NoError(t, a.Close())
require.NoError(t, srvConn.Close())
}
func TestActiveTCPConn_WriteLoop_WriteStreamingError(t *testing.T) {
defer test.CheckRoutines(t)()
tcpListener, err := net.Listen("tcp", "127.0.0.1:0") // nolint: noctx
require.NoError(t, err)
defer func() { _ = tcpListener.Close() }()
ra := netip.MustParseAddrPort(tcpListener.Addr().String())
logger := logging.NewDefaultLoggerFactory().NewLogger("ice")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
a := newActiveTCPConn(ctx, "127.0.0.1:0", ra, logger)
require.NotNil(t, a)
srvConn, err := tcpListener.Accept()
require.NoError(t, err)
require.NoError(t, srvConn.Close())
n, err := a.WriteTo([]byte("data"), nil)
require.NoError(t, err)
require.Equal(t, len("data"), n)
require.NoError(t, a.Close())
}
func TestActiveTCPConn_LocalAddr_DefaultWhenUnset(t *testing.T) {
defer test.CheckRoutines(t)()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
invalidLocal := "127.0.0.1:65536"
remote := netip.MustParseAddrPort("127.0.0.1:1")
log := logging.NewDefaultLoggerFactory().NewLogger("ice")
a := newActiveTCPConn(ctx, invalidLocal, remote, log)
require.NotNil(t, a)
require.True(t, a.closed.Load(), "expected early-return closed state")
la := a.LocalAddr()
ta, ok := la.(*net.TCPAddr)
require.True(t, ok, "LocalAddr() should return *net.TCPAddr")
require.Nil(t, ta.IP, "fallback *net.TCPAddr should be zero value (nil IP)")
require.Equal(t, 0, ta.Port, "fallback *net.TCPAddr should be zero value (port 0)")
require.Equal(t, "", ta.Zone, "fallback *net.TCPAddr should be zero value (empty zone)")
}
func TestActiveTCPConn_SetDeadlines_ReturnEOF(t *testing.T) {
defer test.CheckRoutines(t)()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
invalidLocal := "127.0.0.1:65536"
remote := netip.MustParseAddrPort("127.0.0.1:1")
log := logging.NewDefaultLoggerFactory().NewLogger("ice")
a := newActiveTCPConn(ctx, invalidLocal, remote, log)
require.NotNil(t, a)
require.True(t, a.closed.Load(), "expected early-return closed state")
err := a.SetReadDeadline(time.Now())
require.ErrorIs(t, err, io.EOF)
err = a.SetWriteDeadline(time.Now())
require.ErrorIs(t, err, io.EOF)
}

133
addr_test.go Normal file
View File

@@ -0,0 +1,133 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
//go:build !js
// +build !js
package ice
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
)
// A net.Addr type that parseAddr doesn't handle.
type unknownAddr struct{}
func (unknownAddr) Network() string { return "unknown" }
func (unknownAddr) String() string { return "unknown-addr" }
func TestParseAddrFromIface_ErrFromParseAddr(t *testing.T) {
in := unknownAddr{}
ip, port, nt, err := parseAddrFromIface(in, "eth0")
require.Error(t, err, "expected error from parseAddr for unknown net.Addr type")
require.Zero(t, port)
require.Zero(t, nt)
require.True(t, !ip.IsValid(), "ip should be zero value when error is returned")
}
func TestParseAddr_ErrorBranches(t *testing.T) {
t.Run("IPNet invalid IP -> error", func(t *testing.T) {
// length 1 slice -> ipAddrToNetIP fails
_, _, _, err := parseAddr(&net.IPNet{IP: net.IP{1}})
var convErr ipConvertError
require.ErrorAs(t, err, &convErr)
})
t.Run("IPAddr invalid IP -> error", func(t *testing.T) {
_, _, _, err := parseAddr(&net.IPAddr{IP: net.IP{1}, Zone: "eth0"})
var convErr ipConvertError
require.ErrorAs(t, err, &convErr)
})
t.Run("UDPAddr invalid IP -> error", func(t *testing.T) {
_, _, _, err := parseAddr(&net.UDPAddr{IP: net.IP{1}, Port: 3478})
var convErr ipConvertError
require.ErrorAs(t, err, &convErr)
})
t.Run("TCPAddr invalid IP -> error", func(t *testing.T) {
_, _, _, err := parseAddr(&net.TCPAddr{IP: net.IP{1}, Port: 3478})
var convErr ipConvertError
require.ErrorAs(t, err, &convErr)
})
t.Run("Unknown net.Addr type -> addrParseError", func(t *testing.T) {
_, _, _, err := parseAddr(unknownAddr{})
var ap addrParseError
require.ErrorAs(t, err, &ap)
})
}
func TestParseAddr_IPAddr_Success(t *testing.T) {
ip := net.ParseIP("fe80::1")
require.NotNil(t, ip)
gotIP, port, nt, err := parseAddr(&net.IPAddr{IP: ip, Zone: "lo0"})
require.NoError(t, err)
require.Equal(t, 0, port)
require.Equal(t, NetworkType(0), nt)
require.True(t, gotIP.Is6())
require.Equal(t, "lo0", gotIP.Zone())
require.Equal(t, 0, gotIP.Compare(netip.MustParseAddr("fe80::1%lo0").Unmap()))
}
func TestAddrParseError_Error(t *testing.T) {
e := addrParseError{addr: &net.TCPAddr{}}
require.Equal(t,
"do not know how to parse address type *net.TCPAddr",
e.Error(),
)
}
func TestIPConvertError_Error(t *testing.T) {
e := ipConvertError{ip: []byte("bad-ip")}
require.Equal(t,
"failed to convert IP 'bad-ip' to netip.Addr",
e.Error(),
)
}
func TestIPAddrToNetIP_Error_InvalidBytes(t *testing.T) {
bad := []byte{1} // invalid length -> AddrFromSlice returns ok=false
got, err := ipAddrToNetIP(bad, "")
require.Equal(t, netip.Addr{}, got, "should return zero addr on error")
require.Error(t, err)
require.IsType(t, ipConvertError{}, err)
require.Contains(t, err.Error(), "failed to convert IP")
}
func TestIPAddrToNetIP_OK_IPv4(t *testing.T) {
ipv4 := []byte{1, 2, 3, 4}
got, err := ipAddrToNetIP(ipv4, "")
require.NoError(t, err)
require.True(t, got.Is4())
want := netip.AddrFrom4([4]byte{1, 2, 3, 4})
require.Equal(t, want, got)
}
func TestAddrEqual_FirstParseError(t *testing.T) {
a := unknownAddr{}
b := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 9999}
require.False(t, addrEqual(a, b))
}
func TestAddrEqual_SecondParseError(t *testing.T) {
a := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 9999}
b := unknownAddr{}
require.False(t, addrEqual(a, b))
}
func TestAddrEqual_SameTypeIPPort(t *testing.T) {
a := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 4242}
b := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 4242}
require.True(t, addrEqual(a, b))
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/pion/transport/v3/test" "github.com/pion/transport/v3/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestConnectionStateNotifier(t *testing.T) { func TestConnectionStateNotifier(t *testing.T) {
@@ -70,3 +71,119 @@ func TestConnectionStateNotifier(t *testing.T) {
notifer.Close(true) notifer.Close(true)
}) })
} }
func TestHandlerNotifier_Close_AlreadyClosed(t *testing.T) {
defer test.CheckRoutines(t)()
notifier := &handlerNotifier{
connectionStateFunc: func(ConnectionState) {},
candidateFunc: func(Candidate) {},
candidatePairFunc: func(*CandidatePair) {},
done: make(chan struct{}),
}
// first close
notifier.Close(false)
isClosed := func(ch <-chan struct{}) bool {
select {
case <-ch:
return true
default:
return false
}
}
assert.True(t, isClosed(notifier.done), "expected h.done to be closed after first Close")
// second close should hit `case <-h.done` and return immediately
// without blocking on the WaitGroup.
finished := make(chan struct{}, 1)
go func() {
notifier.Close(true)
close(finished)
}()
assert.Eventually(t, func() bool {
select {
case <-finished:
return true
default:
return false
}
}, 250*time.Millisecond, 10*time.Millisecond, "second Close(true) did not return promptly")
// ensure still closed afterwards
assert.True(t, isClosed(notifier.done), "expected h.done to remain closed after second Close")
// sanity: no enqueues should start after close.
require.False(t, notifier.running)
require.Zero(t, len(notifier.connectionStates))
require.Zero(t, len(notifier.candidates))
require.Zero(t, len(notifier.selectedCandidatePairs))
}
func TestHandlerNotifier_EnqueueConnectionState_AfterClose(t *testing.T) {
defer test.CheckRoutines(t)()
connCh := make(chan struct{}, 1)
notifier := &handlerNotifier{
connectionStateFunc: func(ConnectionState) { connCh <- struct{}{} },
done: make(chan struct{}),
}
notifier.Close(false)
notifier.EnqueueConnectionState(ConnectionStateConnected)
assert.Never(t, func() bool {
select {
case <-connCh:
return true
default:
return false
}
}, 250*time.Millisecond, 10*time.Millisecond, "connectionStateFunc should not be called after close")
}
func TestHandlerNotifier_EnqueueCandidate_AfterClose(t *testing.T) {
defer test.CheckRoutines(t)()
candidateCh := make(chan struct{}, 1)
h := &handlerNotifier{
candidateFunc: func(Candidate) { candidateCh <- struct{}{} },
done: make(chan struct{}),
}
h.Close(false)
h.EnqueueCandidate(nil)
assert.Never(t, func() bool {
select {
case <-candidateCh:
return true
default:
return false
}
}, 250*time.Millisecond, 10*time.Millisecond, "candidateFunc should not be called after close")
}
func TestHandlerNotifier_EnqueueSelectedCandidatePair_AfterClose(t *testing.T) {
defer test.CheckRoutines(t)()
pairCh := make(chan struct{}, 1)
h := &handlerNotifier{
candidatePairFunc: func(*CandidatePair) { pairCh <- struct{}{} },
done: make(chan struct{}),
}
h.Close(false)
h.EnqueueSelectedCandidatePair(nil)
assert.Never(t, func() bool {
select {
case <-pairCh:
return true
default:
return false
}
}, 250*time.Millisecond, 10*time.Millisecond, "candidatePairFunc should not be called after close")
}

View File

@@ -2012,3 +2012,28 @@ func TestRoleConflict(t *testing.T) {
runTest(t, false) runTest(t, false)
}) })
} }
func TestAgentConfig_initWithDefaults_UsesProvidedValues(t *testing.T) {
valMaxBindingReq := uint16(0)
valSrflxWait := 111 * time.Millisecond
valPrflxWait := 222 * time.Millisecond
valRelayWait := 3 * time.Second
valStunTimeout := 4 * time.Second
cfg := &AgentConfig{
MaxBindingRequests: &valMaxBindingReq,
SrflxAcceptanceMinWait: &valSrflxWait,
PrflxAcceptanceMinWait: &valPrflxWait,
RelayAcceptanceMinWait: &valRelayWait,
STUNGatherTimeout: &valStunTimeout,
}
var a Agent
cfg.initWithDefaults(&a)
require.Equal(t, valMaxBindingReq, a.maxBindingRequests, "expected override for MaxBindingRequests")
require.Equal(t, valSrflxWait, a.srflxAcceptanceMinWait, "expected override for SrflxAcceptanceMinWait")
require.Equal(t, valPrflxWait, a.prflxAcceptanceMinWait, "expected override for PrflxAcceptanceMinWait")
require.Equal(t, valRelayWait, a.relayAcceptanceMinWait, "expected override for RelayAcceptanceMinWait")
require.Equal(t, valStunTimeout, a.stunGatherTimeout, "expected override for STUNGatherTimeout")
}

View File

@@ -130,3 +130,48 @@ func TestNilCandidatePairString(t *testing.T) {
var nilCandidatePair *CandidatePair var nilCandidatePair *CandidatePair
require.Equal(t, nilCandidatePair.String(), "") require.Equal(t, nilCandidatePair.String(), "")
} }
func TestCandidatePairState_String(t *testing.T) {
tests := []struct {
name string
in CandidatePairState
want string
}{
{"waiting", CandidatePairStateWaiting, "waiting"},
{"in-progress", CandidatePairStateInProgress, "in-progress"},
{"failed", CandidatePairStateFailed, "failed"},
{"succeeded", CandidatePairStateSucceeded, "succeeded"},
{"unknown", CandidatePairState(255), "Unknown candidate pair state"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, tt.in.String())
})
}
}
func TestCandidatePairEqual_NilCases(t *testing.T) {
// both nil -> true
var a *CandidatePair
var b *CandidatePair
require.True(t, a.equal(b), "both nil pairs should be equal")
// left non-nil, right nil -> false
a = newCandidatePair(hostCandidate(), srflxCandidate(), true)
require.False(t, a.equal(nil), "non-nil vs nil should be false")
// left nil, right non-nil -> false
require.False(t, (*CandidatePair)(nil).equal(a), "nil vs non-nil should be false")
}
func TestCandidatePair_TimeGetters_DefaultZero(t *testing.T) {
p := newCandidatePair(hostCandidate(), srflxCandidate(), true)
require.True(t, p.FirstRequestSentAt().IsZero(), "FirstRequestSentAt should be zero by default")
require.True(t, p.LastRequestSentAt().IsZero(), "LastRequestSentAt should be zero by default")
require.True(t, p.FirstReponseReceivedAt().IsZero(), "FirstReponseReceivedAt should be zero by default")
require.True(t, p.LastResponseReceivedAt().IsZero(), "LastResponseReceivedAt should be zero by default")
require.True(t, p.FirstRequestReceivedAt().IsZero(), "FirstRequestReceivedAt should be zero by default")
require.True(t, p.LastRequestReceivedAt().IsZero(), "LastRequestReceivedAt should be zero by default")
}

39
candidatetype_test.go Normal file
View File

@@ -0,0 +1,39 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package ice
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestCandidateType_String_KnownCases(t *testing.T) {
cases := map[CandidateType]string{
CandidateTypeHost: "host",
CandidateTypeServerReflexive: "srflx",
CandidateTypePeerReflexive: "prflx",
CandidateTypeRelay: "relay",
CandidateTypeUnspecified: "Unknown candidate type",
}
for ct, want := range cases {
require.Equal(t, want, ct.String(), "unexpected string for %v", ct)
}
}
func TestCandidateType_String_Default(t *testing.T) {
const outOfBounds CandidateType = 255
require.Equal(t, "Unknown candidate type", outOfBounds.String())
}
func TestCandidateType_Preference_DefaultCase(t *testing.T) {
const outOfBounds CandidateType = 255
require.Equal(t, uint16(0), outOfBounds.Preference())
}
func TestContainsCandidateType_NilSlice(t *testing.T) {
var list []CandidateType // nil slice
require.False(t, containsCandidateType(CandidateTypeHost, list))
}

View File

@@ -4,11 +4,16 @@
package ice package ice
import ( import (
"errors"
"net" "net"
"net/netip" "net/netip"
"sort"
"strings" "strings"
"testing" "testing"
"github.com/pion/logging"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -58,3 +63,181 @@ func mustAddr(t *testing.T, ip net.IP) netip.Addr {
return addr return addr
} }
type errInterfacesNet struct {
transport.Net
retErr error
}
func (e *errInterfacesNet) Interfaces() ([]*transport.Interface, error) {
return nil, e.retErr
}
var errBoom = errors.New("boom")
func TestLocalInterfaces_ErrorFromInterfaces(t *testing.T) {
base, err := stdnet.NewNet()
require.NoError(t, err)
wrapped := &errInterfacesNet{
Net: base,
retErr: errBoom,
}
ifaces, addrs, gotErr := localInterfaces(
wrapped,
nil,
nil,
nil,
false,
)
require.ErrorIs(t, gotErr, wrapped.retErr)
require.Nil(t, ifaces, "expected nil iface slice on error")
require.NotNil(t, addrs, "ipAddrs should be a non-nil empty slice")
require.Len(t, addrs, 0)
}
type fixedInterfacesNet struct {
transport.Net
list []*transport.Interface
}
func (f *fixedInterfacesNet) Interfaces() ([]*transport.Interface, error) {
return f.list, nil
}
func TestLocalInterfaces_SkipInterfaceDown(t *testing.T) {
base, err := stdnet.NewNet()
require.NoError(t, err)
sysIfaces, err := base.Interfaces()
require.NoError(t, err)
if len(sysIfaces) == 0 {
t.Skip("no system network interfaces available")
}
clone := *sysIfaces[0]
clone.Flags &^= net.FlagUp
wrapped := &fixedInterfacesNet{
Net: base,
list: []*transport.Interface{&clone},
}
ifcs, addrs, ierr := localInterfaces(
wrapped,
nil,
nil,
nil,
false,
)
require.NoError(t, ierr)
require.Len(t, ifcs, 0, "down interfaces must be skipped")
require.Len(t, addrs, 0, "no addresses should be collected from a down interface")
}
func TestLocalInterfaces_SkipLoopbackAddrs_WhenIncludeLoopbackFalse(t *testing.T) {
base, err := stdnet.NewNet()
require.NoError(t, err)
sysIfaces, err := base.Interfaces()
require.NoError(t, err)
var loop *transport.Interface
for _, ifc := range sysIfaces {
if ifc.Flags&net.FlagLoopback != 0 {
loop = ifc
break
}
}
if loop == nil {
t.Skip("no loopback interface found on this system")
}
// clone the loopback iface and clear the Loopback flag so the outer check
// doesn't drop it to force the inner `(ipAddr.IsLoopback() && !includeLoopback)`.
cloned := *loop
cloned.Flags |= net.FlagUp
cloned.Flags &^= net.FlagLoopback
wrapped := &fixedInterfacesNet{
Net: base,
list: []*transport.Interface{&cloned},
}
ifaces, addrs, ierr := localInterfaces(
wrapped,
nil, // interfaceFilter
nil, // ipFilter
nil, // networkTypes
false, // includeLoopback
)
require.NoError(t, ierr)
// don't assert on the number of interfaces because some systems may
// report the iface as having addresses in a way that causes it to be included.
// assert that all loopback addresses were skipped.
for _, a := range addrs {
require.False(t, a.IsLoopback(), "loopback addresses must be skipped when includeLoopback=false")
}
_ = ifaces // intentionally don't assert on this, see above comment
}
// Captures ListenUDP attempts and always fails so the loop exhausts.
type listenUDPCaptor struct {
transport.Net
attempts []int
}
func (c *listenUDPCaptor) ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
c.attempts = append(c.attempts, laddr.Port)
return nil, errBoom
}
func TestListenUDPInPortRange_DefaultsPortMinTo1024(t *testing.T) {
base, err := stdnet.NewNet()
require.NoError(t, err)
captor := &listenUDPCaptor{Net: base}
logger := logging.NewDefaultLoggerFactory().NewLogger("ice-test")
// portMin == 0 (should become 1024), portMax small to keep the loop short.
_, err = listenUDPInPortRange(
captor,
logger,
1030, // portMax
0, // portMin -> becomes 1024
udp4,
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0},
)
require.ErrorIs(t, err, ErrPort)
// should have attempted exactly [1024..1030] in some order.
sort.Ints(captor.attempts)
require.Equal(t, []int{1024, 1025, 1026, 1027, 1028, 1029, 1030}, captor.attempts)
}
func TestListenUDPInPortRange_DefaultsPortMaxToFFFF(t *testing.T) {
base, err := stdnet.NewNet()
require.NoError(t, err)
captor := &listenUDPCaptor{Net: base}
logger := logging.NewDefaultLoggerFactory().NewLogger("ice-test")
// portMax == 0 (should become 0xFFFF). Use portMin=65535 so the range is 1 port.
_, err = listenUDPInPortRange(
captor,
logger,
0, // portMax -> becomes 65535
65535, // portMin
udp4,
&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0},
)
require.ErrorIs(t, err, ErrPort)
require.Equal(t, []int{65535}, captor.attempts)
}

View File

@@ -83,3 +83,53 @@ func TestNetworkTypeIsTCP(t *testing.T) {
require.False(t, NetworkTypeTCP4.IsUDP()) require.False(t, NetworkTypeTCP4.IsUDP())
require.False(t, NetworkTypeTCP6.IsUDP()) require.False(t, NetworkTypeTCP6.IsUDP())
} }
func TestNetworkType_String_Default(t *testing.T) {
var invalid NetworkType // 0 triggers default branch
require.Equal(t, ErrUnknownType.Error(), invalid.String())
require.Equal(t, "udp4", NetworkTypeUDP4.String())
require.Equal(t, "udp6", NetworkTypeUDP6.String())
require.Equal(t, "tcp4", NetworkTypeTCP4.String())
require.Equal(t, "tcp6", NetworkTypeTCP6.String())
}
func TestNetworkType_NetworkShort_Default(t *testing.T) {
var invalid NetworkType
require.Equal(t, ErrUnknownType.Error(), invalid.NetworkShort())
require.Equal(t, udp, NetworkTypeUDP4.NetworkShort())
require.Equal(t, udp, NetworkTypeUDP6.NetworkShort())
require.Equal(t, tcp, NetworkTypeTCP4.NetworkShort())
require.Equal(t, tcp, NetworkTypeTCP6.NetworkShort())
}
func TestNetworkType_IPvFlags_Default(t *testing.T) {
var invalid NetworkType
require.False(t, invalid.IsIPv4())
require.False(t, invalid.IsIPv6())
require.True(t, NetworkTypeUDP4.IsIPv4())
require.True(t, NetworkTypeTCP4.IsIPv4())
require.False(t, NetworkTypeUDP6.IsIPv4())
require.False(t, NetworkTypeTCP6.IsIPv4())
require.True(t, NetworkTypeUDP6.IsIPv6())
require.True(t, NetworkTypeTCP6.IsIPv6())
require.False(t, NetworkTypeUDP4.IsIPv6())
require.False(t, NetworkTypeTCP4.IsIPv6())
}
func TestNetworkType_IsReliable(t *testing.T) {
// UDP is unreliable
require.False(t, NetworkTypeUDP4.IsReliable())
require.False(t, NetworkTypeUDP6.IsReliable())
// TCP is reliable
require.True(t, NetworkTypeTCP4.IsReliable())
require.True(t, NetworkTypeTCP6.IsReliable())
// default/unknown falls through to false
var invalid NetworkType
require.False(t, invalid.IsReliable())
}

64
role_test.go Normal file
View File

@@ -0,0 +1,64 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
//go:build !js
// +build !js
package ice
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestUnmarshalText_Success(t *testing.T) {
tests := []struct {
name string
in string
want Role
}{
{"controlling", "controlling", Controlling},
{"controlled", "controlled", Controlled},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var r Role
err := r.UnmarshalText([]byte(tt.in))
require.NoError(t, err)
require.Equal(t, tt.want, r)
})
}
}
func TestUnmarshalText_UnknownKeepsValueAndErrors(t *testing.T) {
r := Controlled
err := r.UnmarshalText([]byte("neither"))
require.ErrorIs(t, err, errUnknownRole)
require.Equal(t, Controlled, r, "role should remain unchanged on error")
}
func TestMarshalText(t *testing.T) {
tests := []struct {
name string
in Role
want string
}{
{"controlling", Controlling, "controlling"},
{"controlled", Controlled, "controlled"},
{"unknown", Role(99), "unknown"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b, err := tt.in.MarshalText()
require.NoError(t, err)
require.Equal(t, tt.want, string(b))
})
}
}
func TestString(t *testing.T) {
require.Equal(t, "controlling", Controlling.String())
require.Equal(t, "controlled", Controlled.String())
require.Equal(t, "unknown", Role(255).String())
}

View File

@@ -10,12 +10,15 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"strings"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
"github.com/pion/logging"
"github.com/pion/stun/v3" "github.com/pion/stun/v3"
"github.com/pion/transport/v3/test" "github.com/pion/transport/v3/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -158,3 +161,199 @@ func TestBindingRequestHandler(t *testing.T) {
closePipe(t, controllingConn, controlledConn) closePipe(t, controllingConn, controlledConn)
} }
// copied from pion/webrtc's peerconnection_go_test.go.
type testICELogger struct {
lastErrorMessage string
}
func (t *testICELogger) Trace(string) {}
func (t *testICELogger) Tracef(string, ...any) {}
func (t *testICELogger) Debug(string) {}
func (t *testICELogger) Debugf(string, ...any) {}
func (t *testICELogger) Info(string) {}
func (t *testICELogger) Infof(string, ...any) {}
func (t *testICELogger) Warn(string) {}
func (t *testICELogger) Warnf(string, ...any) {}
func (t *testICELogger) Error(msg string) { t.lastErrorMessage = msg }
func (t *testICELogger) Errorf(format string, args ...any) {
t.lastErrorMessage = fmt.Sprintf(format, args...)
}
type testICELoggerFactory struct {
logger *testICELogger
}
func (t *testICELoggerFactory) NewLogger(string) logging.LeveledLogger {
return t.logger
}
func TestControllingSelector_IsNominatable_LogsInvalidType(t *testing.T) {
testLogger := &testICELogger{}
loggerFactory := &testICELoggerFactory{logger: testLogger}
sel := &controllingSelector{
agent: &Agent{},
log: loggerFactory.NewLogger("test"),
}
sel.Start()
c := hostCandidate()
c.candidateBase.candidateType = CandidateTypeUnspecified
got := sel.isNominatable(c)
require.False(t, got)
require.Contains(t, testLogger.lastErrorMessage, "Invalid candidate type")
require.Contains(t, testLogger.lastErrorMessage, "Unknown candidate type") // from c.Type().String()
}
func TestControllingSelector_NominatePair_BuildError(t *testing.T) {
testLogger := &testICELogger{}
loggerFactory := &testICELoggerFactory{logger: testLogger}
// selector with an Agent with ufrags to make an oversized username
// (username = remoteUfrag + ":" + localUfrag) since oversized username causes
// stun.NewUsername(...) inside stun.Build to fail.
long := strings.Repeat("x", 300) // > 255 each side
sel := &controllingSelector{
agent: &Agent{
remoteUfrag: long,
localUfrag: long,
remotePwd: "pwd", // any non-empty value is fine
tieBreaker: 0,
},
log: loggerFactory.NewLogger("test"),
}
sel.Start()
p := newCandidatePair(hostCandidate(), hostCandidate(), true)
sel.nominatePair(p)
require.NotEmpty(t, testLogger.lastErrorMessage, "expected error log from nominatePair on Build failure")
}
type pingNoIOCand struct{ candidateBase }
func newPingNoIOCand() *pingNoIOCand {
return &pingNoIOCand{
candidateBase: candidateBase{
candidateType: CandidateTypeHost,
component: ComponentRTP,
},
}
}
func (d *pingNoIOCand) writeTo(b []byte, _ Candidate) (int, error) { return len(b), nil }
func bareAgentForPing() *Agent {
return &Agent{
hostAcceptanceMinWait: time.Hour,
srflxAcceptanceMinWait: time.Hour,
prflxAcceptanceMinWait: time.Hour,
relayAcceptanceMinWait: time.Hour,
checklist: []*CandidatePair{},
keepaliveInterval: time.Second,
checkInterval: time.Second,
connectionStateNotifier: &handlerNotifier{
done: make(chan struct{}),
connectionStateFunc: func(ConnectionState) {}}, //nolint formatting
candidateNotifier: &handlerNotifier{
done: make(chan struct{}),
candidateFunc: func(Candidate) {}}, //nolint formatting
selectedCandidatePairNotifier: &handlerNotifier{
done: make(chan struct{}),
candidatePairFunc: func(*CandidatePair) {}}, //nolint formatting
}
}
func bigStr() string { return strings.Repeat("x", 40000) }
func TestControllingSelector_PingCandidate_BuildError(t *testing.T) {
a := bareAgentForPing()
// make Username really big so stun.Build returns an error.
a.remoteUfrag = bigStr()
a.localUfrag = bigStr()
a.remotePwd = "pwd"
a.tieBreaker = 1
testLogger := &testICELogger{}
sel := &controllingSelector{agent: a, log: testLogger}
sel.Start()
local := newPingNoIOCand()
remote := newPingNoIOCand()
sel.PingCandidate(local, remote)
require.NotEmpty(t, testLogger.lastErrorMessage, "expected error to be logged from stun.Build")
}
func TestControlledSelector_PingCandidate_BuildError(t *testing.T) {
a := bareAgentForPing()
a.remoteUfrag = bigStr()
a.localUfrag = bigStr()
a.remotePwd = "pwd"
a.tieBreaker = 1
testLogger := &testICELogger{}
sel := &controlledSelector{agent: a, log: testLogger}
sel.Start()
local := newPingNoIOCand()
remote := newPingNoIOCand()
sel.PingCandidate(local, remote)
require.NotEmpty(t, testLogger.lastErrorMessage, "expected error to be logged from stun.Build")
}
type warnTestLogger struct {
warned bool
}
func (l *warnTestLogger) Trace(string) {}
func (l *warnTestLogger) Tracef(string, ...any) {}
func (l *warnTestLogger) Debug(string) {}
func (l *warnTestLogger) Debugf(string, ...any) {}
func (l *warnTestLogger) Info(string) {}
func (l *warnTestLogger) Infof(string, ...any) {}
func (l *warnTestLogger) Warn(string) { l.warned = true }
func (l *warnTestLogger) Warnf(string, ...any) { l.warned = true }
func (l *warnTestLogger) Error(string) {}
func (l *warnTestLogger) Errorf(string, ...any) {}
type dummyNoIOCand struct{ candidateBase }
func newDummyNoIOCand(t CandidateType) *dummyNoIOCand {
return &dummyNoIOCand{
candidateBase: candidateBase{
candidateType: t,
component: ComponentRTP,
},
}
}
func (d *dummyNoIOCand) writeTo(p []byte, _ Candidate) (int, error) { return len(p), nil }
func TestControlledSelector_HandleSuccessResponse_UnknownTxID(t *testing.T) {
logger := &warnTestLogger{}
ag := &Agent{log: logger}
sel := &controlledSelector{agent: ag, log: logger}
sel.Start()
local := newDummyNoIOCand(CandidateTypeHost)
remote := newDummyNoIOCand(CandidateTypeHost)
var m stun.Message
copy(m.TransactionID[:], []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})
sel.HandleSuccessResponse(&m, local, remote, nil)
require.True(t, logger.warned, "expected Warnf to be called for unknown TransactionID (hitting !ok branch)")
}

View File

@@ -310,3 +310,49 @@ func TestConnStats(t *testing.T) {
require.Equal(t, uint64(10), ca.BytesSent()) require.Equal(t, uint64(10), ca.BytesSent())
require.Equal(t, uint64(10), cb.BytesReceived()) require.Equal(t, uint64(10), cb.BytesReceived())
} }
func TestAgent_connect_ErrEarly(t *testing.T) {
defer test.CheckRoutines(t)()
cfg := &AgentConfig{
NetworkTypes: supportedNetworkTypes(),
}
a, err := NewAgent(cfg)
require.NoError(t, err)
require.NoError(t, a.Close())
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// isControlling = true
conn, cerr := a.connect(ctx, true, "ufragX", "pwdX")
require.Nil(t, conn)
require.Error(t, cerr, "expected error from a.loop.Err() short-circuit")
}
func TestConn_Write_RejectsSTUN(t *testing.T) {
defer test.CheckRoutines(t)()
defer test.TimeOut(10 * time.Second).Stop()
cfg := &AgentConfig{
NetworkTypes: supportedNetworkTypes(),
MulticastDNSMode: MulticastDNSModeDisabled,
}
a, err := NewAgent(cfg)
require.NoError(t, err)
defer func() {
_ = a.Close()
}()
c := &Conn{agent: a}
require.Nil(t, c.agent.getSelectedPair(), "precondition: no selected pair")
msg := stun.New()
msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
msg.Encode()
n, werr := c.Write(msg.Raw)
require.Zero(t, n)
require.ErrorIs(t, werr, errWriteSTUNMessageToIceConn)
}

View File

@@ -10,12 +10,15 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"errors"
"io" "io"
"net" "net"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/pion/ice/v4/internal/fakenet"
"github.com/pion/logging"
"github.com/pion/stun/v3" "github.com/pion/stun/v3"
"github.com/pion/transport/v3/test" "github.com/pion/transport/v3/test"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -511,3 +514,319 @@ func TestUDPMuxedConn_writePacket_NotifyDefaultBranch(t *testing.T) {
// channel still full => no send happened (default path executed) // channel still full => no send happened (default path executed)
require.Equal(t, 1, len(conn.notify)) require.Equal(t, 1, len(conn.notify))
} }
func TestNewUDPMuxDefault_LocalAddrNotUDPAddr(t *testing.T) {
defer test.CheckRoutines(t)()
c1, c2 := net.Pipe()
defer func() { _ = c2.Close() }()
pc := &fakenet.PacketConn{Conn: c1}
mux := NewUDPMuxDefault(UDPMuxParams{
Logger: logging.NewDefaultLoggerFactory().NewLogger("ice"),
UDPConn: pc,
})
require.NotNil(t, mux)
defer func() { _ = mux.Close() }()
addrs := mux.GetListenAddresses()
require.Len(t, addrs, 1)
require.Equal(t, pc.LocalAddr().String(), addrs[0].String())
}
func TestUDPMuxDefault_GetConn_InvalidAddress(t *testing.T) {
defer test.CheckRoutines(t)()
connA, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
udpMux := NewUDPMuxDefault(UDPMuxParams{
Logger: nil,
UDPConn: connA,
})
defer func() {
_ = udpMux.Close()
_ = connA.Close()
}()
connB, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
defer func() { _ = connB.Close() }()
pc, gerr := udpMux.GetConn("some-ufrag", connB.LocalAddr())
require.Nil(t, pc)
require.ErrorIs(t, gerr, errInvalidAddress)
}
func TestUDPMuxDefault_registerConnForAddress_ClosedMuxEarlyReturn(t *testing.T) {
defer test.CheckRoutines(t)()
udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
udpMux := NewUDPMuxDefault(UDPMuxParams{UDPConn: udp})
require.NoError(t, udpMux.Close())
_ = udp.Close()
conn := secondTestMuxedConn(t, 64)
addr, err := newIPPort(net.ParseIP("1.2.3.4"), "", 9999)
require.NoError(t, err)
before := len(udpMux.addressMap)
udpMux.registerConnForAddress(conn, addr)
after := len(udpMux.addressMap)
require.Equal(t, before, after)
_, exists := udpMux.addressMap[addr]
require.False(t, exists)
}
func TestUDPMuxDefault_registerConnForAddress_ReplacesExisting(t *testing.T) {
defer test.CheckRoutines(t)()
udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
udpMux := NewUDPMuxDefault(UDPMuxParams{UDPConn: udp})
defer func() {
_ = udpMux.Close()
_ = udp.Close()
}()
ipAddr, err := newIPPort(net.ParseIP("5.6.7.8"), "", 12345)
require.NoError(t, err)
existing := secondTestMuxedConn(t, 64)
existing.addresses = []ipPort{ipAddr}
udpMux.addressMapMu.Lock()
udpMux.addressMap[ipAddr] = existing
udpMux.addressMapMu.Unlock()
// new conn should replace existing mapping and cause removeAddress on the old one.
newConn := secondTestMuxedConn(t, 64)
udpMux.registerConnForAddress(newConn, ipAddr)
// map should now point to newConn.
udpMux.addressMapMu.RLock()
mapped := udpMux.addressMap[ipAddr]
udpMux.addressMapMu.RUnlock()
require.Equal(t, newConn, mapped)
// old conn should have ipAddr removed from its addresses.
require.False(t, existing.containsAddress(ipAddr), "old conn should have removed the address backref")
}
func stunWithLen(l uint16) []byte {
m := stun.New()
m.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
m.Encode()
out := append([]byte{}, m.Raw...)
out[2] = byte(l >> 8)
out[3] = byte(l & 0xff)
return out
}
type scriptedUDPPC struct {
local *net.UDPAddr
seq []struct {
data []byte
addr net.Addr
err error
}
i int
}
func (s *scriptedUDPPC) ReadFrom(p []byte) (int, net.Addr, error) {
if s.i >= len(s.seq) {
return 0, s.local, errIoEOF
}
step := s.seq[s.i]
s.i++
if step.err != nil {
return 0, step.addr, step.err
}
n := copy(p, step.data)
return n, step.addr, nil
}
func (s *scriptedUDPPC) WriteTo([]byte, net.Addr) (int, error) { return 0, nil }
func (s *scriptedUDPPC) Close() error { return nil }
func (s *scriptedUDPPC) LocalAddr() net.Addr { return s.local }
func (s *scriptedUDPPC) SetDeadline(time.Time) error { return nil }
func (s *scriptedUDPPC) SetReadDeadline(time.Time) error { return nil }
func (s *scriptedUDPPC) SetWriteDeadline(time.Time) error { return nil }
var errIoEOF = errors.New("EOF")
func TestUDPMux_connWorker_AddrNotUDP(t *testing.T) {
defer test.CheckRoutines(t)()
c1, c2 := net.Pipe()
defer func() {
_ = c2.Close()
}()
pc := &fakenet.PacketConn{Conn: c1}
mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc})
require.NotNil(t, mux)
defer func() {
_ = mux.Close()
}()
_, _ = c2.Write([]byte("frame"))
_ = c2.Close()
}
func TestUDPMux_connWorker_ReadError_Timeout(t *testing.T) {
defer test.CheckRoutines(t)()
c1, c2 := net.Pipe()
pc := &fakenet.PacketConn{Conn: c1}
mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc})
require.NotNil(t, mux)
_ = pc.SetReadDeadline(time.Unix(0, 0))
_ = c2.Close()
_ = mux.Close()
}
func TestUDPMux_connWorker_NewIPPortError(t *testing.T) {
defer test.CheckRoutines(t)()
badIP := net.IP{1}
remote := &net.UDPAddr{IP: badIP, Port: 9999}
pc := &scriptedUDPPC{
local: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7000},
seq: []struct {
data []byte
addr net.Addr
err error
}{
{data: []byte{1}, addr: remote, err: nil}, // triggers newIPPort error
},
}
mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc})
require.NotNil(t, mux)
_ = mux.Close()
}
func TestUDPMux_connWorker_STUNDecodeError(t *testing.T) {
defer test.CheckRoutines(t)()
remote := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 2), Port: 5678}
pc := &scriptedUDPPC{
local: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7001},
seq: []struct {
data []byte
addr net.Addr
err error
}{
// bad STUN length -> Decode() error -> Warnf + continue
{data: stunWithLen(4), addr: remote, err: nil},
{data: nil, addr: remote, err: errIoEOF}, // exit loop
},
}
mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc})
require.NotNil(t, mux)
_ = mux.Close()
}
func TestUDPMux_connWorker_STUNNoUsername(t *testing.T) {
defer test.CheckRoutines(t)()
msg := stun.New()
msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
msg.Encode() // valid STUN + no USERNAME
remote := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 3), Port: 5679}
pc := &scriptedUDPPC{
local: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7002},
seq: []struct {
data []byte
addr net.Addr
err error
}{
{data: append([]byte{}, msg.Raw...), addr: remote, err: nil}, // Get(USERNAME) fails
{data: nil, addr: remote, err: errIoEOF}, // exit loop
},
}
mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc})
require.NotNil(t, mux)
_ = mux.Close()
}
func TestUDPMux_connWorker_WritePacketError(t *testing.T) {
defer test.CheckRoutines(t)()
local := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7003}
remote := &net.UDPAddr{IP: net.IPv4(203, 0, 113, 7), Port: 5555}
payload := []byte("0123456789ABCDEF")
pc := &scriptedUDPPC{
local: local,
seq: []struct {
data []byte
addr net.Addr
err error
}{
{data: payload, addr: remote, err: nil},
{data: nil, addr: remote, err: errIoEOF}, // exit loop
},
}
mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc})
require.NotNil(t, mux)
defer func() {
_ = mux.Close()
}()
// shrink pool to force io.ErrShortBuffer in writePacket
mux.pool = &sync.Pool{New: func() any { return newBufferHolder(8) }}
// make connWorker route to new conn.
c, err := mux.GetConn("ufragX", mux.LocalAddr())
require.NoError(t, err)
defer func() {
_ = c.Close()
}()
// remote port is controlled. we use 5555 here to skip int overflow check as we would
// otherwise have to cast remote.Port (int) to uint16.
ipport, err := newIPPort(remote.IP, remote.Zone, 5555)
require.NoError(t, err)
cInner, ok := c.(*udpMuxedConn)
require.True(t, ok, "expected *udpMuxedConn from UDPMuxDefault.GetConn")
mux.registerConnForAddress(cInner, ipport)
}
func TestNewUDPMuxDefault_UnspecifiedAddr_AutoInitNet(t *testing.T) {
defer test.CheckRoutines(t)()
conn, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4zero})
require.NoError(t, err)
defer func() { _ = conn.Close() }()
mux := NewUDPMuxDefault(UDPMuxParams{
Logger: nil,
UDPConn: conn,
Net: nil,
})
require.NotNil(t, mux)
defer func() { _ = mux.Close() }()
addrs := mux.GetListenAddresses()
require.GreaterOrEqual(t, len(addrs), 1, "should list at least one local listen address")
udpAddr, ok := conn.LocalAddr().(*net.UDPAddr)
require.True(t, ok, "LocalAddr is not *net.UDPAddr")
wantPort := udpAddr.Port
for _, a := range addrs {
ua, ok := a.(*net.UDPAddr)
require.True(t, ok, "returned listen address must be *net.UDPAddr")
require.Equal(t, wantPort, ua.Port, "listen addresses should reuse the same UDP port")
}
}

View File

@@ -11,12 +11,14 @@ import (
) )
func TestUseCandidateAttr_AddTo(t *testing.T) { func TestUseCandidateAttr_AddTo(t *testing.T) {
m := new(stun.Message) msg := stun.New()
require.False(t, UseCandidate().IsSet(m)) msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
require.NoError(t, m.Build(stun.BindingRequest, UseCandidate())) require.False(t, UseCandidate().IsSet(msg))
m1 := new(stun.Message) require.NoError(t, UseCandidate().AddTo(msg))
_, err := m1.Write(m.Raw) msg.Encode()
require.NoError(t, err)
require.True(t, UseCandidate().IsSet(m1)) msg2 := &stun.Message{Raw: append([]byte{}, msg.Raw...)}
require.NoError(t, msg2.Decode())
require.True(t, UseCandidate().IsSet(msg2))
} }