Update lint rules, force testify/assert for tests

Use testify's assert package instead of the standard library's testing
package.
This commit is contained in:
Sean DuBois
2025-04-22 13:44:26 -04:00
parent dd072edae9
commit f32c107a62
28 changed files with 388 additions and 869 deletions

View File

@@ -19,12 +19,16 @@ linters-settings:
recommendations:
- errors
forbidigo:
analyze-types: true
forbid:
- ^fmt.Print(f|ln)?$
- ^log.(Panic|Fatal|Print)(f|ln)?$
- ^os.Exit$
- ^panic$
- ^print(ln)?$
- p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$
pkg: ^testing$
msg: "use testify/assert instead"
varnamelen:
max-distance: 12
min-name-length: 2
@@ -127,9 +131,12 @@ issues:
exclude-dirs-use-default: false
exclude-rules:
# Allow complex tests and examples, better to be self contained
- path: (examples|main\.go|_test\.go)
- path: (examples|main\.go)
linters:
- gocognit
- forbidigo
- path: _test\.go
linters:
- gocognit
# Allow forbidden identifiers in CLI commands

View File

@@ -147,7 +147,7 @@ func TestActiveTCP(t *testing.T) {
req.NoError(err)
req.NotNil(activeAgent)
passiveAgentConn, activeAgenConn := connect(passiveAgent, activeAgent)
passiveAgentConn, activeAgenConn := connect(t, passiveAgent, activeAgent)
req.NotNil(passiveAgentConn)
req.NotNil(activeAgenConn)
@@ -220,7 +220,7 @@ func TestActiveTCP_NonBlocking(t *testing.T) {
require.NoError(t, aAgent.AddRemoteCandidate(invalidCandidate))
require.NoError(t, bAgent.AddRemoteCandidate(invalidCandidate))
connect(aAgent, bAgent)
connect(t, aAgent, bAgent)
<-isConnected
}
@@ -284,7 +284,7 @@ func TestActiveTCP_Respect_NetworkTypes(t *testing.T) {
require.NoError(t, aAgent.AddRemoteCandidate(invalidCandidate))
require.NoError(t, bAgent.AddRemoteCandidate(invalidCandidate))
connect(aAgent, bAgent)
connect(t, aAgent, bAgent)
<-isConnected
require.NoError(t, tcpListener.Close())

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/pion/transport/v3/test"
"github.com/stretchr/testify/assert"
)
func TestConnectionStateNotifier(t *testing.T) {
@@ -33,7 +34,7 @@ func TestConnectionStateNotifier(t *testing.T) {
}
select {
case <-updates:
t.Errorf("received more updates than expected")
t.Errorf("received more updates than expected") // nolint
case <-time.After(1 * time.Second):
}
close(done)
@@ -53,14 +54,11 @@ func TestConnectionStateNotifier(t *testing.T) {
done := make(chan struct{})
go func() {
for i := 0; i < 10000; i++ {
x := <-updates
if x != ConnectionState(i) {
t.Errorf("expected %d got %d", x, i)
}
assert.Equal(t, ConnectionState(i), <-updates)
}
select {
case <-updates:
t.Errorf("received more updates than expected")
t.Errorf("received more updates than expected") // nolint
case <-time.After(1 * time.Second):
}
close(done)

File diff suppressed because it is too large Load Diff

View File

@@ -71,7 +71,7 @@ func TestMuxAgent(t *testing.T) {
require.NoError(t, agent.Close())
}()
conn, muxedConn := connect(agent, muxedA)
conn, muxedConn := connect(t, agent, muxedA)
pair := muxedA.getSelectedPair()
require.NotNil(t, pair)

View File

@@ -80,7 +80,7 @@ func TestRelayOnlyConnection(t *testing.T) {
bNotifier, bConnected := onConnected()
require.NoError(t, bAgent.OnConnectionStateChange(bNotifier))
connect(aAgent, bAgent)
connect(t, aAgent, bAgent)
<-aConnected
<-bConnected
}

View File

@@ -73,7 +73,7 @@ func TestServerReflexiveOnlyConnection(t *testing.T) {
bNotifier, bConnected := onConnected()
require.NoError(t, bAgent.OnConnectionStateChange(bNotifier))
connect(aAgent, bAgent)
connect(t, aAgent, bAgent)
<-aConnected
<-bConnected
}

View File

@@ -176,9 +176,7 @@ func TestCandidatePriority(t *testing.T) {
WantPriority: 16777215,
},
} {
if got, want := test.Candidate.Priority(), test.WantPriority; got != want {
t.Fatalf("Candidate(%v).Priority() = %d, want %d", test.Candidate, got, want)
}
require.Equal(t, test.Candidate.Priority(), test.WantPriority)
}
}
@@ -271,9 +269,7 @@ func mustCandidateHost(t *testing.T, conf *CandidateHostConfig) Candidate {
t.Helper()
cand, err := NewCandidateHost(conf)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
return cand
}
@@ -286,9 +282,7 @@ func mustCandidateHostWithExtensions(
t.Helper()
cand, err := NewCandidateHost(conf)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
cand.setExtensions(extensions)
@@ -299,9 +293,7 @@ func mustCandidateRelay(t *testing.T, conf *CandidateRelayConfig) Candidate {
t.Helper()
cand, err := NewCandidateRelay(conf)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
return cand
}
@@ -314,9 +306,7 @@ func mustCandidateRelayWithExtensions(
t.Helper()
cand, err := NewCandidateRelay(conf)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
cand.setExtensions(extensions)
@@ -327,9 +317,7 @@ func mustCandidateServerReflexive(t *testing.T, conf *CandidateServerReflexiveCo
t.Helper()
cand, err := NewCandidateServerReflexive(conf)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
return cand
}
@@ -342,9 +330,7 @@ func mustCandidateServerReflexiveWithExtensions(
t.Helper()
cand, err := NewCandidateServerReflexive(conf)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
cand.setExtensions(extensions)
@@ -359,9 +345,7 @@ func mustCandidatePeerReflexiveWithExtensions(
t.Helper()
cand, err := NewCandidatePeerReflexive(conf)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
cand.setExtensions(extensions)
@@ -603,7 +587,7 @@ func TestCandidateWriteTo(t *testing.T) {
})
require.NoError(t, err, "error creating test TCP listener")
conn, err := net.DialTCP("tcp", nil, listener.Addr().(*net.TCPAddr))
conn, err := net.DialTCP("tcp", nil, listener.Addr().(*net.TCPAddr)) // nolint
require.NoError(t, err, "error dialing test TCP connection")
loggerFactory := logging.NewDefaultLoggerFactory()
@@ -1049,9 +1033,7 @@ func TestCandidateGetExtension(t *testing.T) {
Priority: 500,
Foundation: "750",
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
candidate.setExtensions(extensions)
@@ -1086,9 +1068,7 @@ func TestCandidateGetExtension(t *testing.T) {
Priority: 500,
Foundation: "750",
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
candidate.setExtensions(extensions)
@@ -1111,9 +1091,7 @@ func TestCandidateGetExtension(t *testing.T) {
Foundation: "750",
TCPType: TCPTypeActive,
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
tcpType, ok := candidate.GetExtension("tcptype")
@@ -1136,9 +1114,7 @@ func TestCandidateGetExtension(t *testing.T) {
Priority: 500,
Foundation: "750",
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
tcpType, ok = candidate2.GetExtension("tcptype")
@@ -1163,9 +1139,7 @@ func TestBaseCandidateMarshalExtensions(t *testing.T) {
Priority: 500,
Foundation: "750",
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
candidate.setExtensions(extensions)
@@ -1181,9 +1155,7 @@ func TestBaseCandidateMarshalExtensions(t *testing.T) {
Priority: 500,
Foundation: "750",
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
value := candidate.marshalExtensions()
require.Equal(t, "", value)
@@ -1198,9 +1170,7 @@ func TestBaseCandidateMarshalExtensions(t *testing.T) {
Foundation: "750",
TCPType: TCPTypeActive,
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
value := candidate.marshalExtensions()
require.Equal(t, "tcptype active", value)
@@ -1296,9 +1266,7 @@ func TestBaseCandidateExtensionsEqual(t *testing.T) {
Priority: 500,
Foundation: "750",
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
cand.setExtensions(testCase.extensions1)
@@ -1316,9 +1284,7 @@ func TestCandidateAddExtension(t *testing.T) {
Priority: 500,
Foundation: "750",
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"}))
@@ -1335,9 +1301,7 @@ func TestCandidateAddExtension(t *testing.T) {
Priority: 500,
Foundation: "750",
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "d"}))
@@ -1355,9 +1319,7 @@ func TestCandidateAddExtension(t *testing.T) {
Foundation: "750",
TCPType: TCPTypeActive,
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
ext, ok := candidate.GetExtension("tcptype")
require.True(t, ok)
@@ -1380,9 +1342,7 @@ func TestCandidateAddExtension(t *testing.T) {
Priority: 500,
Foundation: "750",
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
require.NoError(t, candidate.AddExtension(CandidateExtension{"tcptype", "active"}))
@@ -1401,9 +1361,7 @@ func TestCandidateAddExtension(t *testing.T) {
Priority: 500,
Foundation: "750",
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
require.Error(t, candidate.AddExtension(CandidateExtension{"", ""}))
@@ -1424,9 +1382,7 @@ func TestCandidateRemoveExtension(t *testing.T) {
Priority: 500,
Foundation: "750",
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"}))
@@ -1445,9 +1401,7 @@ func TestCandidateRemoveExtension(t *testing.T) {
Priority: 500,
Foundation: "750",
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"}))
require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"}))
@@ -1467,9 +1421,7 @@ func TestCandidateRemoveExtension(t *testing.T) {
Foundation: "750",
TCPType: TCPTypeActive,
})
if err != nil {
t.Error(err)
}
require.NoError(t, err)
// tcptype extension should be removed, even if it's not in the extensions list (Not Parsed)
require.True(t, candidate.RemoveExtension("tcptype"))

View File

@@ -115,9 +115,7 @@ func TestCandidatePairPriority(t *testing.T) {
WantPriority: 72057593987596287,
},
} {
if got, want := test.Pair.priority(), test.WantPriority; got != want {
t.Fatalf("CandidatePair(%v).Priority() = %d, want %d", test.Pair, got, want)
}
require.Equal(t, test.Pair.priority(), test.WantPriority)
}
}
@@ -125,9 +123,7 @@ func TestCandidatePairEquality(t *testing.T) {
pairA := newCandidatePair(hostCandidate(), srflxCandidate(), true)
pairB := newCandidatePair(hostCandidate(), srflxCandidate(), false)
if !pairA.equal(pairB) {
t.Fatalf("Expected %v to equal %v", pairA, pairB)
}
require.True(t, pairA.equal(pairB))
}
func TestNilCandidatePairString(t *testing.T) {

View File

@@ -200,15 +200,16 @@ func addVNetSTUN(wanNet *vnet.Net, loggerFactory logging.LoggerFactory) (*turn.S
return server, err
}
func connectWithVNet(aAgent, bAgent *Agent) (*Conn, *Conn) {
func connectWithVNet(t *testing.T, aAgent, bAgent *Agent) (*Conn, *Conn) {
t.Helper()
// Manual signaling
aUfrag, aPwd, err := aAgent.GetLocalUserCredentials()
check(err)
require.NoError(t, err)
bUfrag, bPwd, err := bAgent.GetLocalUserCredentials()
check(err)
require.NoError(t, err)
gatherAndExchangeCandidates(aAgent, bAgent)
gatherAndExchangeCandidates(t, aAgent, bAgent)
accepted := make(chan struct{})
var aConn *Conn
@@ -216,12 +217,12 @@ func connectWithVNet(aAgent, bAgent *Agent) (*Conn, *Conn) {
go func() {
var acceptErr error
aConn, acceptErr = aAgent.Accept(context.TODO(), bUfrag, bPwd)
check(acceptErr)
require.NoError(t, acceptErr)
close(accepted)
}()
bConn, err := bAgent.Dial(context.TODO(), aUfrag, aPwd)
check(err)
require.NoError(t, err)
// Ensure accepted
<-accepted
@@ -234,7 +235,8 @@ type agentTestConfig struct {
nat1To1IPCandidateType CandidateType
}
func pipeWithVNet(vnet *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*Conn, *Conn) {
func pipeWithVNet(t *testing.T, vnet *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*Conn, *Conn) {
t.Helper()
aNotifier, aConnected := onConnected()
bNotifier, bConnected := onConnected()
@@ -255,13 +257,8 @@ func pipeWithVNet(vnet *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig)
}
aAgent, err := NewAgent(cfg0)
if err != nil {
panic(err)
}
err = aAgent.OnConnectionStateChange(aNotifier)
if err != nil {
panic(err)
}
require.NoError(t, err)
require.NoError(t, aAgent.OnConnectionStateChange(aNotifier))
if a1TestConfig.nat1To1IPCandidateType != CandidateTypeUnspecified {
nat1To1IPs = []string{
@@ -278,15 +275,10 @@ func pipeWithVNet(vnet *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig)
}
bAgent, err := NewAgent(cfg1)
if err != nil {
panic(err)
}
err = bAgent.OnConnectionStateChange(bNotifier)
if err != nil {
panic(err)
}
require.NoError(t, err)
require.NoError(t, bAgent.OnConnectionStateChange(bNotifier))
aConn, bConn := connectWithVNet(aAgent, bAgent)
aConn, bConn := connectWithVNet(t, aAgent, bAgent)
// Ensure pair selected
// Note: this assumes ConnectionStateConnected is thrown after selecting the final pair
@@ -347,7 +339,7 @@ func TestConnectivityVNet(t *testing.T) {
stunServerURL,
},
}
ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
ca, cb := pipeWithVNet(t, vnet, a0TestConfig, a1TestConfig)
time.Sleep(1 * time.Second)
@@ -381,7 +373,7 @@ func TestConnectivityVNet(t *testing.T) {
stunServerURL,
},
}
ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
ca, cb := pipeWithVNet(t, vnet, a0TestConfig, a1TestConfig)
log.Debug("Closing...")
closePipe(t, ca, cb)
@@ -413,7 +405,7 @@ func TestConnectivityVNet(t *testing.T) {
a1TestConfig := &agentTestConfig{
urls: []*stun.URI{},
}
ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
ca, cb := pipeWithVNet(t, vnet, a0TestConfig, a1TestConfig)
log.Debug("Closing...")
closePipe(t, ca, cb)
@@ -445,7 +437,7 @@ func TestConnectivityVNet(t *testing.T) {
a1TestConfig := &agentTestConfig{
urls: []*stun.URI{},
}
ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
ca, cb := pipeWithVNet(t, vnet, a0TestConfig, a1TestConfig)
log.Debug("Closing...")
closePipe(t, ca, cb)
@@ -527,7 +519,7 @@ func TestDisconnectedToConnected(t *testing.T) {
controlledStateChanges <- c
}))
connectWithVNet(controllingAgent, controlledAgent)
connectWithVNet(t, controllingAgent, controlledAgent)
blockUntilStateSeen := func(expectedState ConnectionState, stateQueue chan ConnectionState) {
for s := range stateQueue {
if s == expectedState {
@@ -618,7 +610,7 @@ func TestWriteUseValidPair(t *testing.T) {
require.NoError(t, controlledAgent.Close())
}()
gatherAndExchangeCandidates(controllingAgent, controlledAgent)
gatherAndExchangeCandidates(t, controllingAgent, controlledAgent)
controllingUfrag, controllingPwd, err := controllingAgent.GetLocalUserCredentials()
require.NoError(t, err)

View File

@@ -136,7 +136,6 @@ var (
errParseRelatedAddr = errors.New("failed to parse related addresses")
errParseExtension = errors.New("failed to parse extension")
errParseTCPType = errors.New("failed to parse TCP type")
errRead = errors.New("failed to read")
errUDPMuxDisabled = errors.New("UDPMux is not enabled")
errUnknownRole = errors.New("unknown role")
errWrite = errors.New("failed to write")

View File

@@ -12,7 +12,6 @@ import (
"io"
"net"
"net/url"
"reflect"
"sort"
"strconv"
"sync"
@@ -78,19 +77,13 @@ func TestListenUDP(t *testing.T) {
require.NoError(t, err)
p, _ := strconv.Atoi(port)
if p < portMin || p > portMax {
t.Fatalf("listenUDP with port restriction [%d, %d] listened on incorrect port (%s)", portMin, portMax, port)
}
require.False(t, p < portMin || p > portMax)
result = append(result, p)
portRange = append(portRange, portMin+i)
}
if sort.IntsAreSorted(result) {
t.Fatalf("listenUDP with port restriction [%d, %d], ports result should be random", portMin, portMax)
}
require.False(t, sort.IntsAreSorted(result))
sort.Ints(result)
if !reflect.DeepEqual(result, portRange) {
t.Fatalf("listenUDP with port restriction [%d, %d], got:%v, want:%v", portMin, portMax, result, portRange)
}
require.Equal(t, result, portRange)
_, err = listenUDPInPortRange(agent.net, agent.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0})
require.Equal(t, err, ErrPort, "listenUDP with port restriction [%d, %d], did not return ErrPort", portMin, portMax)
}

View File

@@ -8,7 +8,6 @@ package ice
import (
"context"
"errors"
"fmt"
"net"
"testing"
@@ -38,36 +37,25 @@ func TestVNetGather(t *testing.T) { //nolint:cyclop
}()
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if len(localIPs) > 0 {
t.Fatal("should return no local IP")
}
require.Len(t, localIPs, 0)
require.NoError(t, err)
})
t.Run("Gather a dynamic IP address", func(t *testing.T) {
cider := "1.2.3.0/24"
_, ipNet, err := net.ParseCIDR(cider)
if err != nil {
t.Fatalf("Failed to parse CIDR: %s", err)
}
require.NoError(t, err)
router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: cider,
LoggerFactory: loggerFactory,
})
if err != nil {
t.Fatalf("Failed to create a router: %s", err)
}
require.NoError(t, err)
nw, err := vnet.NewNet(&vnet.NetConfig{})
if err != nil {
t.Fatalf("Failed to create a Net: %s", err)
}
require.NoError(t, err)
err = router.AddNet(nw)
if err != nil {
t.Fatalf("Failed to add a Net to the router: %s", err)
}
require.NoError(t, router.AddNet(nw))
a, err := NewAgent(&AgentConfig{
Net: nw,
@@ -78,18 +66,12 @@ func TestVNetGather(t *testing.T) { //nolint:cyclop
}()
_, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if len(localAddrs) == 0 {
t.Fatal("should have one local IP")
}
require.Len(t, localAddrs, 1)
require.NoError(t, err)
for _, addr := range localAddrs {
if addr.IsLoopback() {
t.Fatal("should not return loopback IP")
}
if !ipNet.Contains(addr.AsSlice()) {
t.Fatal("should be contained in the CIDR")
}
require.False(t, addr.IsLoopback())
require.True(t, ipNet.Contains(addr.AsSlice()))
}
})
@@ -98,24 +80,15 @@ func TestVNetGather(t *testing.T) { //nolint:cyclop
CIDR: "1.2.3.0/24",
LoggerFactory: loggerFactory,
})
if err != nil {
t.Fatalf("Failed to create a router: %s", err)
}
require.NoError(t, err)
nw, err := vnet.NewNet(&vnet.NetConfig{})
if err != nil {
t.Fatalf("Failed to create a Net: %s", err)
}
require.NoError(t, err)
err = router.AddNet(nw)
if err != nil {
t.Fatalf("Failed to add a Net to the router: %s", err)
}
require.NoError(t, router.AddNet(nw))
agent, err := NewAgent(&AgentConfig{Net: nw})
if err != nil {
t.Fatalf("Failed to create agent: %s", err)
}
require.NoError(t, err)
defer func() {
require.NoError(t, agent.Close())
}()
@@ -127,35 +100,22 @@ func TestVNetGather(t *testing.T) { //nolint:cyclop
[]NetworkType{NetworkTypeUDP4},
false,
)
if len(localAddrs) == 0 {
t.Fatal("localInterfaces found no interfaces, unable to test")
}
require.NotEqual(t, 0, len(localAddrs))
require.NoError(t, err)
ip := localAddrs[0].AsSlice()
conn, err := listenUDPInPortRange(agent.net, agent.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0})
if err != nil {
t.Fatalf("listenUDP error with no port restriction %v", err)
} else if conn == nil {
t.Fatalf("listenUDP error with no port restriction return a nil conn")
}
err = conn.Close()
if err != nil {
t.Fatalf("failed to close conn")
}
require.NoError(t, err)
require.NotNil(t, conn)
require.NoError(t, conn.Close())
_, err = listenUDPInPortRange(agent.net, agent.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
if !errors.Is(err, ErrPort) {
t.Fatal("listenUDP with invalid port range did not return ErrPort")
}
require.ErrorIs(t, ErrPort, err)
conn, err = listenUDPInPortRange(agent.net, agent.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
if err != nil {
t.Fatalf("listenUDP error with no port restriction %v", err)
} else if conn == nil {
t.Fatalf("listenUDP error with no port restriction return a nil conn")
}
require.NoError(t, err)
require.NotNil(t, conn)
defer func() {
require.NoError(t, conn.Close())
}()
@@ -163,9 +123,7 @@ func TestVNetGather(t *testing.T) { //nolint:cyclop
_, port, err := net.SplitHostPort(conn.LocalAddr().String())
require.NoError(t, err)
if port != "5000" {
t.Fatalf("listenUDP with port restriction of 5000 listened on incorrect port (%s)", port)
}
require.Equal(t, "5000", port)
})
}
@@ -205,9 +163,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) { //nolint:cyclop
nw, err := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{localIP0, localIP1},
})
if err != nil {
t.Fatalf("Failed to create a Net: %s", err)
}
require.NoError(t, err)
err = lan.AddNet(nw)
require.NoError(t, err, "should succeed")
@@ -242,38 +198,22 @@ func TestVNetGatherWithNAT1To1(t *testing.T) { //nolint:cyclop
candidates, err := agent.GetLocalCandidates()
require.NoError(t, err, "should succeed")
if len(candidates) != 2 {
t.Fatal("There must be two candidates")
}
require.Len(t, candidates, 2)
lAddr := [2]*net.UDPAddr{nil, nil}
for i, candi := range candidates {
lAddr[i] = candi.(*CandidateHost).conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert
if candi.Port() != lAddr[i].Port {
t.Fatalf("Unexpected candidate port: %d", candi.Port())
}
require.Equal(t, candi.Port(), lAddr[i].Port)
}
if candidates[0].Address() == externalIP0 { //nolint:nestif
if candidates[1].Address() != externalIP1 {
t.Fatalf("Unexpected candidate IP: %s", candidates[1].Address())
}
if lAddr[0].IP.String() != localIP0 {
t.Fatalf("Unexpected listen IP: %s", lAddr[0].IP.String())
}
if lAddr[1].IP.String() != localIP1 {
t.Fatalf("Unexpected listen IP: %s", lAddr[1].IP.String())
}
require.Equal(t, candidates[1].Address(), externalIP1)
require.Equal(t, lAddr[0].IP.String(), localIP0)
require.Equal(t, lAddr[1].IP.String(), localIP1)
} else if candidates[0].Address() == externalIP1 {
if candidates[1].Address() != externalIP0 {
t.Fatalf("Unexpected candidate IP: %s", candidates[1].Address())
}
if lAddr[0].IP.String() != localIP1 {
t.Fatalf("Unexpected listen IP: %s", lAddr[0].IP.String())
}
if lAddr[1].IP.String() != localIP0 {
t.Fatalf("Unexpected listen IP: %s", lAddr[1].IP.String())
}
require.Equal(t, candidates[1].Address(), externalIP0)
require.Equal(t, lAddr[0].IP.String(), localIP1)
require.Equal(t, lAddr[1].IP.String(), localIP0)
}
})
@@ -304,9 +244,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) { //nolint:cyclop
"10.0.0.1",
},
})
if err != nil {
t.Fatalf("Failed to create a Net: %s", err)
}
require.NoError(t, err)
err = lan.AddNet(nw)
require.NoError(t, err, "should succeed")
@@ -344,9 +282,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) { //nolint:cyclop
candidates, err := agent.GetLocalCandidates()
require.NoError(t, err, "should succeed")
if len(candidates) != 2 {
t.Fatalf("Expected two candidates. actually %d", len(candidates))
}
require.Len(t, candidates, 2)
var candiHost *CandidateHost
var candiSrflx *CandidateServerReflexive
@@ -358,7 +294,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) { //nolint:cyclop
case *CandidateServerReflexive:
candiSrflx = candi
default:
t.Fatal("Unexpected candidate type")
t.Fatal("Unexpected candidate type") // nolint
}
}
@@ -377,18 +313,11 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
CIDR: "1.2.3.0/24",
LoggerFactory: loggerFactory,
})
if err != nil {
t.Fatalf("Failed to create a router: %s", err)
}
require.NoError(t, err)
nw, err := vnet.NewNet(&vnet.NetConfig{})
if err != nil {
t.Fatalf("Failed to create a Net: %s", err)
}
if err = router.AddNet(nw); err != nil {
t.Fatalf("Failed to add a Net to the router: %s", err)
}
require.NoError(t, err)
require.NoError(t, router.AddNet(nw))
t.Run("InterfaceFilter should exclude the interface", func(t *testing.T) {
agent, err := NewAgent(&AgentConfig{
@@ -412,10 +341,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
false,
)
require.NoError(t, err)
if len(localIPs) != 0 {
t.Fatal("InterfaceFilter should have excluded everything")
}
require.Len(t, localIPs, 0)
})
t.Run("IPFilter should exclude the IP", func(t *testing.T) {
@@ -440,10 +366,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
false,
)
require.NoError(t, err)
if len(localIPs) != 0 {
t.Fatal("IPFilter should have excluded everything")
}
require.Len(t, localIPs, 0)
})
t.Run("InterfaceFilter should not exclude the interface", func(t *testing.T) {
@@ -468,10 +391,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
false,
)
require.NoError(t, err)
if len(localIPs) == 0 {
t.Fatal("InterfaceFilter should not have excluded anything")
}
require.Len(t, localIPs, 1)
})
}

View File

@@ -4,69 +4,52 @@
package ice
import (
"errors"
"testing"
"github.com/pion/stun/v3"
"github.com/stretchr/testify/require"
)
func TestControlled_GetFrom(t *testing.T) { //nolint:dupl
m := new(stun.Message)
var attrCtr AttrControlled
if err := attrCtr.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error")
}
if err := m.Build(stun.BindingRequest, &attrCtr); err != nil {
t.Error(err)
}
require.ErrorIs(t, stun.ErrAttributeNotFound, attrCtr.GetFrom(m))
require.NoError(t, m.Build(stun.BindingRequest, &attrCtr))
m1 := new(stun.Message)
if _, err := m1.Write(m.Raw); err != nil {
t.Error(err)
}
_, err := m1.Write(m.Raw)
require.NoError(t, err)
var c1 AttrControlled
if err := c1.GetFrom(m1); err != nil {
t.Error(err)
}
if c1 != attrCtr {
t.Error("not equal")
}
require.NoError(t, c1.GetFrom(m1))
require.Equal(t, c1, attrCtr)
t.Run("IncorrectSize", func(t *testing.T) {
m3 := new(stun.Message)
m3.Add(stun.AttrICEControlled, make([]byte, 100))
var c2 AttrControlled
if err := c2.GetFrom(m3); !stun.IsAttrSizeInvalid(err) {
t.Error("should error")
}
require.True(t, stun.IsAttrSizeInvalid(c2.GetFrom(m3)))
})
}
func TestControlling_GetFrom(t *testing.T) { //nolint:dupl
m := new(stun.Message)
var attrCtr AttrControlling
if err := attrCtr.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error")
}
if err := m.Build(stun.BindingRequest, &attrCtr); err != nil {
t.Error(err)
}
require.ErrorIs(t, stun.ErrAttributeNotFound, attrCtr.GetFrom(m))
require.NoError(t, m.Build(stun.BindingRequest, &attrCtr))
m1 := new(stun.Message)
if _, err := m1.Write(m.Raw); err != nil {
t.Error(err)
}
_, err := m1.Write(m.Raw)
require.NoError(t, err)
var c1 AttrControlling
if err := c1.GetFrom(m1); err != nil {
t.Error(err)
}
if c1 != attrCtr {
t.Error("not equal")
}
require.NoError(t, c1.GetFrom(m1))
require.Equal(t, c1, attrCtr)
t.Run("IncorrectSize", func(t *testing.T) {
m3 := new(stun.Message)
m3.Add(stun.AttrICEControlling, make([]byte, 100))
var c2 AttrControlling
if err := c2.GetFrom(m3); !stun.IsAttrSizeInvalid(err) {
t.Error("should error")
}
require.True(t, stun.IsAttrSizeInvalid(c2.GetFrom(m3)))
})
}
@@ -74,70 +57,49 @@ func TestControl_GetFrom(t *testing.T) { //nolint:cyclop
t.Run("Blank", func(t *testing.T) {
m := new(stun.Message)
var c AttrControl
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error")
}
require.ErrorIs(t, stun.ErrAttributeNotFound, c.GetFrom(m))
})
t.Run("Controlling", func(t *testing.T) { //nolint:dupl
m := new(stun.Message)
var attCtr AttrControl
if err := attCtr.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error")
}
require.ErrorIs(t, stun.ErrAttributeNotFound, attCtr.GetFrom(m))
attCtr.Role = Controlling
attCtr.Tiebreaker = 4321
if err := m.Build(stun.BindingRequest, &attCtr); err != nil {
t.Error(err)
}
require.NoError(t, m.Build(stun.BindingRequest, &attCtr))
m1 := new(stun.Message)
if _, err := m1.Write(m.Raw); err != nil {
t.Error(err)
}
_, err := m1.Write(m.Raw)
require.NoError(t, err)
var c1 AttrControl
if err := c1.GetFrom(m1); err != nil {
t.Error(err)
}
if c1 != attCtr {
t.Error("not equal")
}
require.NoError(t, c1.GetFrom(m1))
require.Equal(t, c1, attCtr)
t.Run("IncorrectSize", func(t *testing.T) {
m3 := new(stun.Message)
m3.Add(stun.AttrICEControlling, make([]byte, 100))
var c2 AttrControl
if err := c2.GetFrom(m3); !stun.IsAttrSizeInvalid(err) {
t.Error("should error")
}
err := c2.GetFrom(m3)
require.True(t, stun.IsAttrSizeInvalid(err))
})
})
t.Run("Controlled", func(t *testing.T) { //nolint:dupl
m := new(stun.Message)
var attrCtrl AttrControl
if err := attrCtrl.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error")
}
require.ErrorIs(t, stun.ErrAttributeNotFound, attrCtrl.GetFrom(m))
attrCtrl.Role = Controlled
attrCtrl.Tiebreaker = 1234
if err := m.Build(stun.BindingRequest, &attrCtrl); err != nil {
t.Error(err)
}
require.NoError(t, m.Build(stun.BindingRequest, &attrCtrl))
m1 := new(stun.Message)
if _, err := m1.Write(m.Raw); err != nil {
t.Error(err)
}
_, err := m1.Write(m.Raw)
require.NoError(t, err)
var c1 AttrControl
if err := c1.GetFrom(m1); err != nil {
t.Error(err)
}
if c1 != attrCtrl {
t.Error("not equal")
}
require.NoError(t, c1.GetFrom(m1))
require.Equal(t, c1, attrCtrl)
t.Run("IncorrectSize", func(t *testing.T) {
m3 := new(stun.Message)
m3.Add(stun.AttrICEControlling, make([]byte, 100))
var c2 AttrControl
if err := c2.GetFrom(m3); !stun.IsAttrSizeInvalid(err) {
t.Error("should error")
}
err := c2.GetFrom(m3)
require.True(t, stun.IsAttrSizeInvalid(err))
})
})
}

View File

@@ -65,7 +65,7 @@ func TestMulticastDNSOnlyConnection(t *testing.T) {
bNotifier, bConnected := onConnected()
require.NoError(t, bAgent.OnConnectionStateChange(bNotifier))
connect(aAgent, bAgent)
connect(t, aAgent, bAgent)
<-aConnected
<-bConnected
})
@@ -124,7 +124,7 @@ func TestMulticastDNSMixedConnection(t *testing.T) {
bNotifier, bConnected := onConnected()
require.NoError(t, bAgent.OnConnectionStateChange(bNotifier))
connect(aAgent, bAgent)
connect(t, aAgent, bAgent)
<-aConnected
<-bConnected
})
@@ -195,7 +195,5 @@ func TestGenerateMulticastDNSName(t *testing.T) {
`^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-4[0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}.local+$`,
).MatchString
if !isMDNSName(name) {
t.Fatalf("mDNS name must be UUID v4 + \".local\" suffix, got %s", name)
}
require.True(t, isMDNSName(name))
}

View File

@@ -13,25 +13,11 @@ import (
)
func TestIsSupportedIPv6Partial(t *testing.T) {
if isSupportedIPv6Partial(net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}) {
t.Errorf("isSupportedIPv6Partial returned true with IPv4-compatible IPv6 address")
}
if isSupportedIPv6Partial(net.ParseIP("fec0::2333")) {
t.Errorf("isSupportedIPv6Partial returned true with IPv6 site-local unicast address")
}
if !isSupportedIPv6Partial(net.ParseIP("fe80::2333")) {
t.Errorf("isSupportedIPv6Partial returned false with IPv6 link-local address")
}
if !isSupportedIPv6Partial(net.ParseIP("ff02::2333")) {
t.Errorf("isSupportedIPv6Partial returned false with IPv6 link-local multicast address")
}
if !isSupportedIPv6Partial(net.ParseIP("2001::1")) {
t.Errorf("isSupportedIPv6Partial returned false with IPv6 global unicast address")
}
require.False(t, isSupportedIPv6Partial(net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}))
require.False(t, isSupportedIPv6Partial(net.ParseIP("fec0::2333")))
require.True(t, isSupportedIPv6Partial(net.ParseIP("fe80::2333")))
require.True(t, isSupportedIPv6Partial(net.ParseIP("ff02::2333")))
require.True(t, isSupportedIPv6Partial(net.ParseIP("2001::1")))
}
func TestCreateAddr(t *testing.T) {
@@ -67,7 +53,7 @@ func mustAddr(t *testing.T, ip net.IP) netip.Addr {
t.Helper()
addr, ok := netip.AddrFromSlice(ip)
if !ok {
t.Fatal(ipConvertError{ip})
t.Fatal(ipConvertError{ip}) // nolint
}
return addr

View File

@@ -46,13 +46,8 @@ func TestNetworkTypeParsing_Success(t *testing.T) {
},
} {
actual, err := determineNetworkType(test.inNetwork, mustAddr(t, test.inIP))
if err != nil {
t.Errorf("NetworkTypeParsing failed: %v", err)
}
if actual != test.expected {
t.Errorf("NetworkTypeParsing: '%s' -- input:%s expected:%s actual:%s",
test.name, test.inNetwork, test.expected, actual)
}
require.NoError(t, err)
require.Equal(t, test.expected, actual)
}
}
@@ -70,11 +65,8 @@ func TestNetworkTypeParsing_Failure(t *testing.T) {
ipv6,
},
} {
actual, err := determineNetworkType(test.inNetwork, mustAddr(t, test.inIP))
if err == nil {
t.Errorf("NetworkTypeParsing should fail: '%s' -- input:%s actual:%s",
test.name, test.inNetwork, actual)
}
_, err := determineNetworkType(test.inNetwork, mustAddr(t, test.inIP))
require.Error(t, err)
}
}

View File

@@ -4,38 +4,29 @@
package ice
import (
"errors"
"testing"
"github.com/pion/stun/v3"
"github.com/stretchr/testify/require"
)
func TestPriority_GetFrom(t *testing.T) { //nolint:dupl
m := new(stun.Message)
var priority PriorityAttr
if err := priority.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error")
}
if err := m.Build(stun.BindingRequest, &priority); err != nil {
t.Error(err)
}
require.ErrorIs(t, stun.ErrAttributeNotFound, priority.GetFrom(m))
require.NoError(t, m.Build(stun.BindingRequest, &priority))
m1 := new(stun.Message)
if _, err := m1.Write(m.Raw); err != nil {
t.Error(err)
}
_, err := m1.Write(m.Raw)
require.NoError(t, err)
var p1 PriorityAttr
if err := p1.GetFrom(m1); err != nil {
t.Error(err)
}
if p1 != priority {
t.Error("not equal")
}
require.NoError(t, p1.GetFrom(m1))
require.Equal(t, p1, priority)
t.Run("IncorrectSize", func(t *testing.T) {
m3 := new(stun.Message)
m3.Add(stun.AttrPriority, make([]byte, 100))
var p2 PriorityAttr
if err := p2.GetFrom(m3); !stun.IsAttrSizeInvalid(err) {
t.Error("should error")
}
require.True(t, stun.IsAttrSizeInvalid(p2.GetFrom(m3)))
})
}

View File

@@ -67,15 +67,10 @@ func TestRandomGeneratorCollision(t *testing.T) {
}
wg.Wait()
if len(rands) != num {
t.Fatal("Failed to generate randoms")
}
require.Len(t, rands, num)
for i := 0; i < num; i++ {
for j := i + 1; j < num; j++ {
if rands[i] == rands[j] {
t.Fatalf("generateRandString caused collision: %s == %s", rands[i], rands[j])
}
require.NotEqual(t, rands[i], rands[j])
}
}
}

View File

@@ -97,7 +97,7 @@ func TestBindingRequestHandler(t *testing.T) {
require.NoError(t, err)
require.NoError(t, controlledAgent.OnConnectionStateChange(bNotifier))
controlledConn, controllingConn := connect(controlledAgent, controllingAgent)
controlledConn, controllingConn := connect(t, controlledAgent, controllingAgent)
<-aConnected
<-bConnected

View File

@@ -61,7 +61,7 @@ func TestMultiTCPMux_Recv(t *testing.T) {
defer func() {
_ = pktConn.Close()
}()
conn, err := net.DialTCP("tcp", nil, pktConn.LocalAddr().(*net.TCPAddr))
conn, err := net.DialTCP("tcp", nil, pktConn.LocalAddr().(*net.TCPAddr)) // nolint
require.NoError(t, err, "error dialing test TCP connection")
msg := stun.New()

View File

@@ -51,7 +51,7 @@ func TestTCPMux_Recv(t *testing.T) {
require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) // nolint
require.NoError(t, err, "error dialing test TCP connection")
msg := stun.New()
@@ -150,7 +150,7 @@ func TestTCPMux_FirstPacketTimeout(t *testing.T) {
require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) // nolint
require.NoError(t, err, "error dialing test TCP connection")
defer func() {
_ = conn.Close()
@@ -192,7 +192,7 @@ func TestTCPMux_NoLeakForConnectionFromStun(t *testing.T) {
require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
t.Run("close connection from stun msg after timeout", func(t *testing.T) {
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) // nolint
require.NoError(t, err, "error dialing test TCP connection")
defer func() {
_ = conn.Close()
@@ -217,7 +217,7 @@ func TestTCPMux_NoLeakForConnectionFromStun(t *testing.T) {
})
t.Run("connection keep alive if access by user", func(t *testing.T) {
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) // nolint
require.NoError(t, err, "error dialing test TCP connection")
defer func() {
_ = conn.Close()

View File

@@ -38,10 +38,7 @@ func testTimeout(t *testing.T, conn *Conn, timeout time.Duration) {
ticker := time.NewTicker(pollRate)
defer func() {
ticker.Stop()
err := conn.Close()
if err != nil {
t.Error(err)
}
require.NoError(t, conn.Close())
}()
startedAt := time.Now()
@@ -51,26 +48,18 @@ func testTimeout(t *testing.T, conn *Conn, timeout time.Duration) {
var cs ConnectionState
err := conn.agent.loop.Run(context.Background(), func(_ context.Context) {
require.NoError(t, conn.agent.loop.Run(context.Background(), func(_ context.Context) {
cs = conn.agent.connectionState
})
if err != nil {
// We should never get here.
panic(err)
}
}))
if cs != ConnectionStateConnected {
elapsed := time.Since(startedAt)
if elapsed+margin < timeout {
t.Fatalf("Connection timed out %f msec early", elapsed.Seconds()*1000)
} else {
t.Logf("Connection timed out in %f msec", elapsed.Seconds()*1000)
require.Less(t, timeout, elapsed+margin)
return
}
return
}
}
t.Fatalf("Connection failed to time out in time. (expected timeout: %v)", timeout)
t.Fatalf("Connection failed to time out in time. (expected timeout: %v)", timeout) //nolint
}
func TestTimeout(t *testing.T) {
@@ -85,24 +74,14 @@ func TestTimeout(t *testing.T) {
defer test.TimeOut(time.Second * 20).Stop()
t.Run("WithoutDisconnectTimeout", func(t *testing.T) {
ca, cb := pipe(nil)
err := cb.Close()
if err != nil {
// We should never get here.
panic(err)
}
ca, cb := pipe(t, nil)
require.NoError(t, cb.Close())
testTimeout(t, ca, defaultDisconnectedTimeout)
})
t.Run("WithDisconnectTimeout", func(t *testing.T) {
ca, cb := pipeWithTimeout(5*time.Second, 3*time.Second)
err := cb.Close()
if err != nil {
// We should never get here.
panic(err)
}
ca, cb := pipeWithTimeout(t, 5*time.Second, 3*time.Second)
require.NoError(t, cb.Close())
testTimeout(t, ca, 5*time.Second)
})
}
@@ -114,31 +93,19 @@ func TestReadClosed(t *testing.T) {
// Limit runtime in case of deadlocks
defer test.TimeOut(time.Second * 20).Stop()
ca, cb := pipe(nil)
err := ca.Close()
if err != nil {
// We should never get here.
panic(err)
}
err = cb.Close()
if err != nil {
// We should never get here.
panic(err)
}
ca, cb := pipe(t, nil)
require.NoError(t, ca.Close())
require.NoError(t, cb.Close())
empty := make([]byte, 10)
_, err = ca.Read(empty)
if err == nil {
t.Fatalf("Reading from a closed channel should return an error")
}
_, err := ca.Read(empty)
require.Error(t, err)
}
func stressDuplex(t *testing.T) {
t.Helper()
ca, cb := pipe(nil)
ca, cb := pipe(t, nil)
defer func() {
require.NoError(t, ca.Close())
@@ -153,58 +120,52 @@ func stressDuplex(t *testing.T) {
require.NoError(t, test.StressDuplex(ca, cb, opt))
}
func check(err error) {
if err != nil {
panic(err)
}
}
func gatherAndExchangeCandidates(aAgent, bAgent *Agent) {
func gatherAndExchangeCandidates(t *testing.T, aAgent, bAgent *Agent) {
t.Helper()
var wg sync.WaitGroup
wg.Add(2)
check(aAgent.OnCandidate(func(candidate Candidate) {
require.NoError(t, aAgent.OnCandidate(func(candidate Candidate) {
if candidate == nil {
wg.Done()
}
}))
check(aAgent.GatherCandidates())
require.NoError(t, aAgent.GatherCandidates())
check(bAgent.OnCandidate(func(candidate Candidate) {
require.NoError(t, bAgent.OnCandidate(func(candidate Candidate) {
if candidate == nil {
wg.Done()
}
}))
check(bAgent.GatherCandidates())
require.NoError(t, bAgent.GatherCandidates())
wg.Wait()
candidates, err := aAgent.GetLocalCandidates()
check(err)
require.NoError(t, err)
for _, c := range candidates {
if addr, parseErr := netip.ParseAddr(c.Address()); parseErr == nil {
if shouldFilterLocationTrackedIP(addr) {
panic(addr)
}
require.False(t, shouldFilterLocationTrackedIP(addr))
}
candidateCopy, copyErr := c.copy()
check(copyErr)
check(bAgent.AddRemoteCandidate(candidateCopy))
require.NoError(t, copyErr)
require.NoError(t, bAgent.AddRemoteCandidate(candidateCopy))
}
candidates, err = bAgent.GetLocalCandidates()
check(err)
require.NoError(t, err)
for _, c := range candidates {
candidateCopy, copyErr := c.copy()
check(copyErr)
check(aAgent.AddRemoteCandidate(candidateCopy))
require.NoError(t, copyErr)
require.NoError(t, aAgent.AddRemoteCandidate(candidateCopy))
}
}
func connect(aAgent, bAgent *Agent) (*Conn, *Conn) {
gatherAndExchangeCandidates(aAgent, bAgent)
func connect(t *testing.T, aAgent, bAgent *Agent) (*Conn, *Conn) {
t.Helper()
gatherAndExchangeCandidates(t, aAgent, bAgent)
accepted := make(chan struct{})
var aConn *Conn
@@ -212,15 +173,15 @@ func connect(aAgent, bAgent *Agent) (*Conn, *Conn) {
go func() {
var acceptErr error
bUfrag, bPwd, acceptErr := bAgent.GetLocalUserCredentials()
check(acceptErr)
require.NoError(t, acceptErr)
aConn, acceptErr = aAgent.Accept(context.TODO(), bUfrag, bPwd)
check(acceptErr)
require.NoError(t, acceptErr)
close(accepted)
}()
aUfrag, aPwd, err := aAgent.GetLocalUserCredentials()
check(err)
require.NoError(t, err)
bConn, err := bAgent.Dial(context.TODO(), aUfrag, aPwd)
check(err)
require.NoError(t, err)
// Ensure accepted
<-accepted
@@ -228,7 +189,8 @@ func connect(aAgent, bAgent *Agent) (*Conn, *Conn) {
return aConn, bConn
}
func pipe(defaultConfig *AgentConfig) (*Conn, *Conn) {
func pipe(t *testing.T, defaultConfig *AgentConfig) (*Conn, *Conn) {
t.Helper()
var urls []*stun.URI
aNotifier, aConnected := onConnected()
@@ -243,15 +205,15 @@ func pipe(defaultConfig *AgentConfig) (*Conn, *Conn) {
cfg.NetworkTypes = supportedNetworkTypes()
aAgent, err := NewAgent(cfg)
check(err)
check(aAgent.OnConnectionStateChange(aNotifier))
require.NoError(t, err)
require.NoError(t, aAgent.OnConnectionStateChange(aNotifier))
bAgent, err := NewAgent(cfg)
check(err)
require.NoError(t, err)
check(bAgent.OnConnectionStateChange(bNotifier))
require.NoError(t, bAgent.OnConnectionStateChange(bNotifier))
aConn, bConn := connect(aAgent, bAgent)
aConn, bConn := connect(t, aAgent, bAgent)
// Ensure pair selected
// Note: this assumes ConnectionStateConnected is thrown after selecting the final pair
@@ -261,7 +223,8 @@ func pipe(defaultConfig *AgentConfig) (*Conn, *Conn) {
return aConn, bConn
}
func pipeWithTimeout(disconnectTimeout time.Duration, iceKeepalive time.Duration) (*Conn, *Conn) {
func pipeWithTimeout(t *testing.T, disconnectTimeout time.Duration, iceKeepalive time.Duration) (*Conn, *Conn) {
t.Helper()
var urls []*stun.URI
aNotifier, aConnected := onConnected()
@@ -275,14 +238,14 @@ func pipeWithTimeout(disconnectTimeout time.Duration, iceKeepalive time.Duration
}
aAgent, err := NewAgent(cfg)
check(err)
check(aAgent.OnConnectionStateChange(aNotifier))
require.NoError(t, err)
require.NoError(t, aAgent.OnConnectionStateChange(aNotifier))
bAgent, err := NewAgent(cfg)
check(err)
check(bAgent.OnConnectionStateChange(bNotifier))
require.NoError(t, err)
require.NoError(t, bAgent.OnConnectionStateChange(bNotifier))
aConn, bConn := connect(aAgent, bAgent)
aConn, bConn := connect(t, aAgent, bAgent)
// Ensure pair selected
// Note: this assumes ConnectionStateConnected is thrown after selecting the final pair
@@ -328,29 +291,22 @@ func TestConnStats(t *testing.T) {
// Limit runtime in case of deadlocks
defer test.TimeOut(time.Second * 20).Stop()
ca, cb := pipe(nil)
if _, err := ca.Write(make([]byte, 10)); err != nil {
t.Fatal("unexpected error trying to write")
}
ca, cb := pipe(t, nil)
_, err := ca.Write(make([]byte, 10))
require.NoError(t, err)
defer closePipe(t, ca, cb)
var wg sync.WaitGroup
wg.Add(1)
go func() {
buf := make([]byte, 10)
if _, err := cb.Read(buf); err != nil {
panic(errRead)
}
_, err := cb.Read(buf)
require.NoError(t, err)
wg.Done()
}()
wg.Wait()
if ca.BytesSent() != 10 {
t.Fatal("bytes sent don't match")
}
if cb.BytesReceived() != 10 {
t.Fatal("bytes received don't match")
}
require.Equal(t, uint64(10), ca.BytesSent())
require.Equal(t, uint64(10), cb.BytesReceived())
}

View File

@@ -53,7 +53,7 @@ func TestRemoteLocalAddr(t *testing.T) {
})
t.Run("Remote/Local Pair Match between Agents", func(t *testing.T) {
ca, cb := pipeWithVNet(builtVnet,
ca, cb := pipeWithVNet(t, builtVnet,
&agentTestConfig{
urls: []*stun.URI{stunServerURL},
},

View File

@@ -103,7 +103,7 @@ func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, u
// Try talking with each PacketConn
for _, pktConn := range pktConns {
remoteConn, err := net.DialUDP(network, nil, pktConn.LocalAddr().(*net.UDPAddr))
remoteConn, err := net.DialUDP(network, nil, pktConn.LocalAddr().(*net.UDPAddr)) // nolint
require.NoError(t, err, "error dialing test UDP connection")
testMuxConnectionPair(t, pktConn, remoteConn, ufrag)
}

View File

@@ -252,7 +252,7 @@ func verifyPacket(t *testing.T, b []byte, nextSeq uint32) {
func TestUDPMux_Agent_Restart(t *testing.T) {
oneSecond := time.Second
connA, connB := pipe(&AgentConfig{
connA, connB := pipe(t, &AgentConfig{
DisconnectedTimeout: &oneSecond,
FailedTimeout: &oneSecond,
})
@@ -279,7 +279,7 @@ func TestUDPMux_Agent_Restart(t *testing.T) {
require.NoError(t, connA.agent.SetRemoteCredentials(ufragB, pwdB))
require.NoError(t, connB.agent.SetRemoteCredentials(ufragA, pwdA))
gatherAndExchangeCandidates(connA.agent, connB.agent)
gatherAndExchangeCandidates(t, connA.agent, connB.agent)
// Wait until both have gone back to connected
<-aConnected

View File

@@ -51,7 +51,7 @@ func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag
_ = pktConn.Close()
}()
remoteConn, err := net.DialUDP(network, nil, &net.UDPAddr{
remoteConn, err := net.DialUDP(network, nil, &net.UDPAddr{ // nolint
Port: udpMux.LocalAddr().(*net.UDPAddr).Port,
})
require.NoError(t, err, "error dialing test UDP connection")

View File

@@ -7,21 +7,16 @@ import (
"testing"
"github.com/pion/stun/v3"
"github.com/stretchr/testify/require"
)
func TestUseCandidateAttr_AddTo(t *testing.T) {
m := new(stun.Message)
if UseCandidate().IsSet(m) {
t.Error("should not be set")
}
if err := m.Build(stun.BindingRequest, UseCandidate()); err != nil {
t.Error(err)
}
require.False(t, UseCandidate().IsSet(m))
require.NoError(t, m.Build(stun.BindingRequest, UseCandidate()))
m1 := new(stun.Message)
if _, err := m1.Write(m.Raw); err != nil {
t.Error(err)
}
if !UseCandidate().IsSet(m1) {
t.Error("should be set")
}
_, err := m1.Write(m.Raw)
require.NoError(t, err)
require.True(t, UseCandidate().IsSet(m1))
}