// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
//go:build !js
// +build !js
package ice
import (
"context"
"net"
"net/netip"
"sync"
"testing"
"time"
"github.com/pion/stun/v3"
"github.com/pion/transport/v3/test"
"github.com/stretchr/testify/require"
)
// TestStressDuplex drives bidirectional traffic over a connected pipe and
// verifies no goroutines are leaked afterwards.
func TestStressDuplex(t *testing.T) {
	// Fail the test if any goroutine outlives it.
	defer test.CheckRoutines(t)()

	// Bound total runtime in case of deadlocks.
	defer test.TimeOut(time.Second * 20).Stop()

	stressDuplex(t)
}
// testTimeout polls conn's agent state until it leaves Connected, then
// asserts the disconnect did not happen before the configured timeout
// (within a small margin). Fails if the connection never times out within
// the polling budget. Closes conn on exit.
func testTimeout(t *testing.T, conn *Conn, timeout time.Duration) {
	t.Helper()

	const (
		pollRate = 100 * time.Millisecond
		margin   = 20 * time.Millisecond // tolerated timing error
	)

	ticker := time.NewTicker(pollRate)
	defer func() {
		ticker.Stop()
		require.NoError(t, conn.Close())
	}()

	startedAt := time.Now()
	// Allow the timeout plus one keepalive interval plus one poll of slack.
	budget := timeout + defaultKeepaliveInterval + pollRate
	for waited := time.Duration(0); waited <= budget; waited += pollRate {
		<-ticker.C

		var state ConnectionState
		require.NoError(t, conn.agent.loop.Run(context.Background(), func(_ context.Context) {
			state = conn.agent.connectionState
		}))
		if state == ConnectionStateConnected {
			continue
		}

		// Disconnected: it must not have happened sooner than the timeout.
		require.Less(t, timeout, time.Since(startedAt)+margin)

		return
	}
	t.Fatalf("Connection failed to time out in time. (expected timeout: %v)", timeout) //nolint
}
// TestTimeout verifies that a connection whose peer has closed times out
// after the default and after an explicitly configured disconnect timeout.
func TestTimeout(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping test in short mode.")
	}

	// Fail the test if any goroutine outlives it.
	defer test.CheckRoutines(t)()

	// Bound total runtime in case of deadlocks.
	defer test.TimeOut(time.Second * 20).Stop()

	t.Run("WithoutDisconnectTimeout", func(t *testing.T) {
		aliveConn, closedConn := pipe(t, nil)
		require.NoError(t, closedConn.Close())
		testTimeout(t, aliveConn, defaultDisconnectedTimeout)
	})

	t.Run("WithDisconnectTimeout", func(t *testing.T) {
		aliveConn, closedConn := pipeWithTimeout(t, 5*time.Second, 3*time.Second)
		require.NoError(t, closedConn.Close())
		testTimeout(t, aliveConn, 5*time.Second)
	})
}
// TestReadClosed verifies that reading from a closed Conn returns an error.
func TestReadClosed(t *testing.T) {
	// Fail the test if any goroutine outlives it.
	defer test.CheckRoutines(t)()

	// Bound total runtime in case of deadlocks.
	defer test.TimeOut(time.Second * 20).Stop()

	connA, connB := pipe(t, nil)
	require.NoError(t, connA.Close())
	require.NoError(t, connB.Close())

	buf := make([]byte, 10)
	_, err := connA.Read(buf)
	require.Error(t, err)
}
// stressDuplex runs the transport stress harness over a freshly created pipe
// and closes both ends when done.
func stressDuplex(t *testing.T) {
	t.Helper()

	connA, connB := pipe(t, nil)
	defer func() {
		require.NoError(t, connA.Close())
		require.NoError(t, connB.Close())
	}()

	require.NoError(t, test.StressDuplex(connA, connB, test.Options{
		MsgSize:  10,
		MsgCount: 1, // Order not reliable due to UDP & potentially multiple candidate pairs.
	}))
}
// gatherAndExchangeCandidates starts candidate gathering on both agents,
// waits until each signals end-of-candidates (a nil candidate), then feeds
// every local candidate of one agent to the other as a remote candidate.
func gatherAndExchangeCandidates(t *testing.T, aAgent, bAgent *Agent) {
	t.Helper()

	var gathered sync.WaitGroup
	gathered.Add(2)
	// Register callbacks and kick off gathering on a then b, as before.
	for _, agent := range []*Agent{aAgent, bAgent} {
		require.NoError(t, agent.OnCandidate(func(candidate Candidate) {
			// A nil candidate marks the end of gathering.
			if candidate == nil {
				gathered.Done()
			}
		}))
		require.NoError(t, agent.GatherCandidates())
	}
	gathered.Wait()

	aLocal, err := aAgent.GetLocalCandidates()
	require.NoError(t, err)
	for _, cand := range aLocal {
		// Gathered candidates must not expose location-tracked addresses.
		if addr, parseErr := netip.ParseAddr(cand.Address()); parseErr == nil {
			require.False(t, shouldFilterLocationTrackedIP(addr))
		}
		dup, copyErr := cand.copy()
		require.NoError(t, copyErr)
		require.NoError(t, bAgent.AddRemoteCandidate(dup))
	}

	bLocal, err := bAgent.GetLocalCandidates()
	require.NoError(t, err)
	for _, cand := range bLocal {
		dup, copyErr := cand.copy()
		require.NoError(t, copyErr)
		require.NoError(t, aAgent.AddRemoteCandidate(dup))
	}
}
// connect exchanges candidates between the two agents, then has aAgent
// Accept while bAgent Dials, returning the resulting pair of Conns
// (accept side first, dial side second).
func connect(t *testing.T, aAgent, bAgent *Agent) (*Conn, *Conn) {
	t.Helper()
	gatherAndExchangeCandidates(t, aAgent, bAgent)

	accepted := make(chan struct{})
	var (
		aConn     *Conn
		acceptErr error
	)
	go func() {
		// Collect the error instead of calling require here: require invokes
		// t.FailNow, which must only be called from the test goroutine.
		defer close(accepted)
		var bUfrag, bPwd string
		bUfrag, bPwd, acceptErr = bAgent.GetLocalUserCredentials()
		if acceptErr != nil {
			return
		}
		aConn, acceptErr = aAgent.Accept(context.TODO(), bUfrag, bPwd)
	}()

	aUfrag, aPwd, err := aAgent.GetLocalUserCredentials()
	require.NoError(t, err)
	bConn, err := bAgent.Dial(context.TODO(), aUfrag, aPwd)
	require.NoError(t, err)

	// Ensure accepted, then surface any error from the accept goroutine.
	<-accepted
	require.NoError(t, acceptErr)

	return aConn, bConn
}
// pipe builds two agents from defaultConfig (which may be nil; Urls and
// NetworkTypes are always overridden), connects them, and returns both
// Conns once each side has reported ConnectionStateConnected.
func pipe(t *testing.T, defaultConfig *AgentConfig) (*Conn, *Conn) {
	t.Helper()

	aNotifier, aConnected := onConnected()
	bNotifier, bConnected := onConnected()

	var urls []*stun.URI
	cfg := &AgentConfig{}
	if defaultConfig != nil {
		// Copy the template so the caller's config is not mutated.
		*cfg = *defaultConfig
	}
	cfg.Urls = urls
	cfg.NetworkTypes = supportedNetworkTypes()

	aAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	require.NoError(t, aAgent.OnConnectionStateChange(aNotifier))

	bAgent, err := NewAgent(cfg)
	require.NoError(t, err)
	require.NoError(t, bAgent.OnConnectionStateChange(bNotifier))

	aConn, bConn := connect(t, aAgent, bAgent)

	// Ensure pair selected
	// Note: this assumes ConnectionStateConnected is thrown after selecting the final pair
	<-aConnected
	<-bConnected

	return aConn, bConn
}
// pipeWithTimeout is pipe with explicit disconnect-timeout and keepalive
// settings. It delegates to pipe, which applies the identical Urls and
// NetworkTypes overrides the previous standalone implementation performed,
// so behavior is unchanged while the duplicated setup code is removed.
func pipeWithTimeout(t *testing.T, disconnectTimeout time.Duration, iceKeepalive time.Duration) (*Conn, *Conn) {
	t.Helper()

	return pipe(t, &AgentConfig{
		DisconnectedTimeout: &disconnectTimeout,
		KeepaliveInterval:   &iceKeepalive,
	})
}
// onConnected returns a connection-state handler together with a channel
// that is closed when the handler first observes ConnectionStateConnected.
func onConnected() (func(ConnectionState), chan struct{}) {
	connected := make(chan struct{})
	handler := func(state ConnectionState) {
		if state == ConnectionStateConnected {
			close(connected)
		}
	}

	return handler, connected
}
// randomPort obtains a free UDP port by binding 127.0.0.1:0, releasing the
// socket, and returning the kernel-chosen port number.
func randomPort(tb testing.TB) int {
	tb.Helper()

	conn, err := net.ListenPacket("udp4", "127.0.0.1:0") // nolint: noctx
	if err != nil {
		tb.Fatalf("failed to pickPort: %v", err)
	}
	defer func() {
		_ = conn.Close()
	}()

	udpAddr, ok := conn.LocalAddr().(*net.UDPAddr)
	if !ok {
		tb.Fatalf("unknown addr type %T", conn.LocalAddr())

		return 0 // unreachable; Fatalf stops the test
	}

	return udpAddr.Port
}
// TestConnStats checks that BytesSent/BytesReceived reflect a 10-byte
// write from one end of a pipe read by the other.
func TestConnStats(t *testing.T) {
	// Check for leaking routines
	defer test.CheckRoutines(t)()

	// Limit runtime in case of deadlocks
	defer test.TimeOut(time.Second * 20).Stop()

	ca, cb := pipe(t, nil)
	defer closePipe(t, ca, cb)

	_, err := ca.Write(make([]byte, 10))
	require.NoError(t, err)

	// Read directly in the test goroutine: require invokes t.FailNow, which
	// must not be called from a spawned goroutine, and the read simply
	// blocks until the datagram written above arrives — no goroutine or
	// WaitGroup is needed.
	buf := make([]byte, 10)
	_, err = cb.Read(buf)
	require.NoError(t, err)

	require.Equal(t, uint64(10), ca.BytesSent())
	require.Equal(t, uint64(10), cb.BytesReceived())
}
// TestAgent_connect_ErrEarly verifies that connect fails immediately when
// the agent has already been closed.
func TestAgent_connect_ErrEarly(t *testing.T) {
	defer test.CheckRoutines(t)()

	agent, err := NewAgent(&AgentConfig{
		NetworkTypes: supportedNetworkTypes(),
	})
	require.NoError(t, err)
	require.NoError(t, agent.Close())

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	// isControlling = true
	conn, connErr := agent.connect(ctx, true, "ufragX", "pwdX")
	require.Nil(t, conn)
	require.Error(t, connErr, "expected error from a.loop.Err() short-circuit")
}
// TestConn_Write_RejectsSTUN verifies that writing a valid STUN message
// through a Conn with no selected pair is rejected with
// errWriteSTUNMessageToIceConn.
func TestConn_Write_RejectsSTUN(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(10 * time.Second).Stop()

	agent, err := NewAgent(&AgentConfig{
		NetworkTypes:     supportedNetworkTypes(),
		MulticastDNSMode: MulticastDNSModeDisabled,
	})
	require.NoError(t, err)
	defer func() {
		_ = agent.Close()
	}()

	conn := &Conn{agent: agent}
	require.Nil(t, conn.agent.getSelectedPair(), "precondition: no selected pair")

	// Build an encoded binding request so Write sees a real STUN message.
	req := stun.New()
	req.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
	req.Encode()

	n, werr := conn.Write(req.Raw)
	require.Zero(t, n)
	require.ErrorIs(t, werr, errWriteSTUNMessageToIceConn)
}