mirror of
https://github.com/pion/ice.git
synced 2025-09-27 03:45:54 +08:00
265 lines
7.5 KiB
Go
265 lines
7.5 KiB
Go
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
package ice
|
|
|
|
import (
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pion/logging"
|
|
"github.com/pion/stun/v3"
|
|
"github.com/pion/transport/v3/test"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
var _ TCPMux = &TCPMuxDefault{}
|
|
|
|
func TestTCPMux_Recv(t *testing.T) {
|
|
for name, bufSize := range map[string]int{
|
|
"no buffer": 0,
|
|
"buffered 4MB": 4 * 1024 * 1024,
|
|
} {
|
|
bufSize := bufSize
|
|
t.Run(name, func(t *testing.T) {
|
|
defer test.CheckRoutines(t)()
|
|
|
|
loggerFactory := logging.NewDefaultLoggerFactory()
|
|
|
|
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
|
IP: net.IP{127, 0, 0, 1},
|
|
Port: 0,
|
|
})
|
|
require.NoError(t, err, "error starting listener")
|
|
defer func() {
|
|
_ = listener.Close()
|
|
}()
|
|
|
|
tcpMux := NewTCPMuxDefault(TCPMuxParams{
|
|
Listener: listener,
|
|
Logger: loggerFactory.NewLogger("ice"),
|
|
ReadBufferSize: 20,
|
|
WriteBufferSize: bufSize,
|
|
})
|
|
|
|
defer func() {
|
|
_ = tcpMux.Close()
|
|
}()
|
|
|
|
require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
|
|
|
|
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) // nolint
|
|
require.NoError(t, err, "error dialing test TCP connection")
|
|
|
|
msg := stun.New()
|
|
msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
|
|
msg.Add(stun.AttrUsername, []byte("myufrag:otherufrag"))
|
|
msg.Encode()
|
|
|
|
n, err := writeStreamingPacket(conn, msg.Raw)
|
|
require.NoError(t, err, "error writing TCP STUN packet")
|
|
|
|
listenerAddr, ok := listener.Addr().(*net.TCPAddr)
|
|
require.True(t, ok)
|
|
|
|
pktConn, err := tcpMux.GetConnByUfrag("myufrag", false, listenerAddr.IP)
|
|
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
|
defer func() {
|
|
_ = pktConn.Close()
|
|
}()
|
|
|
|
recv := make([]byte, n)
|
|
n2, rAddr, err := pktConn.ReadFrom(recv)
|
|
require.NoError(t, err, "error receiving data")
|
|
require.Equal(t, conn.LocalAddr(), rAddr, "remote tcp address mismatch")
|
|
require.Equal(t, n, n2, "received byte size mismatch")
|
|
require.Equal(t, msg.Raw, recv, "received bytes mismatch")
|
|
|
|
// Check echo response
|
|
n, err = pktConn.WriteTo(recv, conn.LocalAddr())
|
|
require.NoError(t, err, "error writing echo STUN packet")
|
|
recvEcho := make([]byte, n)
|
|
n3, err := readStreamingPacket(conn, recvEcho)
|
|
require.NoError(t, err, "error receiving echo data")
|
|
require.Equal(t, n2, n3, "received byte size mismatch")
|
|
require.Equal(t, msg.Raw, recvEcho, "received bytes mismatch")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) {
|
|
defer test.CheckRoutines(t)()
|
|
|
|
loggerFactory := logging.NewDefaultLoggerFactory()
|
|
|
|
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
|
IP: net.IP{127, 0, 0, 1},
|
|
Port: 0,
|
|
})
|
|
require.NoError(t, err, "error starting listener")
|
|
defer func() {
|
|
_ = listener.Close()
|
|
}()
|
|
|
|
tcpMux := NewTCPMuxDefault(TCPMuxParams{
|
|
Listener: listener,
|
|
Logger: loggerFactory.NewLogger("ice"),
|
|
ReadBufferSize: 20,
|
|
})
|
|
|
|
defer func() {
|
|
_ = tcpMux.Close()
|
|
}()
|
|
|
|
listenerAddr, ok := listener.Addr().(*net.TCPAddr)
|
|
require.True(t, ok)
|
|
|
|
_, err = tcpMux.GetConnByUfrag("test", false, listenerAddr.IP)
|
|
require.NoError(t, err, "error getting conn by ufrag")
|
|
|
|
require.NoError(t, tcpMux.Close(), "error closing tcpMux")
|
|
|
|
conn, err := tcpMux.GetConnByUfrag("test", false, listenerAddr.IP)
|
|
require.Nil(t, conn, "should receive nil because mux is closed")
|
|
require.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed")
|
|
}
|
|
|
|
func TestTCPMux_FirstPacketTimeout(t *testing.T) {
|
|
defer test.CheckRoutines(t)()
|
|
|
|
loggerFactory := logging.NewDefaultLoggerFactory()
|
|
|
|
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
|
IP: net.IP{127, 0, 0, 1},
|
|
Port: 0,
|
|
})
|
|
require.NoError(t, err, "error starting listener")
|
|
defer func() {
|
|
_ = listener.Close()
|
|
}()
|
|
|
|
tcpMux := NewTCPMuxDefault(TCPMuxParams{
|
|
Listener: listener,
|
|
Logger: loggerFactory.NewLogger("ice"),
|
|
ReadBufferSize: 20,
|
|
FirstStunBindTimeout: time.Second,
|
|
})
|
|
|
|
require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
|
|
|
|
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) // nolint
|
|
require.NoError(t, err, "error dialing test TCP connection")
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
// Don't send any data, the mux should close the connection after the timeout
|
|
time.Sleep(1500 * time.Millisecond)
|
|
require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
|
|
buf := make([]byte, 1)
|
|
_, err = conn.Read(buf)
|
|
require.ErrorIs(t, err, io.EOF)
|
|
}
|
|
|
|
func TestTCPMux_NoLeakForConnectionFromStun(t *testing.T) {
|
|
defer test.CheckRoutines(t)()
|
|
|
|
loggerFactory := logging.NewDefaultLoggerFactory()
|
|
|
|
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
|
IP: net.IP{127, 0, 0, 1},
|
|
Port: 0,
|
|
})
|
|
require.NoError(t, err, "error starting listener")
|
|
defer func() {
|
|
_ = listener.Close()
|
|
}()
|
|
|
|
tcpMux := NewTCPMuxDefault(TCPMuxParams{
|
|
Listener: listener,
|
|
Logger: loggerFactory.NewLogger("ice"),
|
|
ReadBufferSize: 20,
|
|
AliveDurationForConnFromStun: time.Second,
|
|
})
|
|
|
|
defer func() {
|
|
_ = tcpMux.Close()
|
|
}()
|
|
|
|
require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
|
|
|
|
t.Run("close connection from stun msg after timeout", func(t *testing.T) {
|
|
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) // nolint
|
|
require.NoError(t, err, "error dialing test TCP connection")
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
|
|
stun.NewUsername("myufrag:otherufrag"),
|
|
stun.NewShortTermIntegrity("myufrag"),
|
|
stun.Fingerprint,
|
|
)
|
|
require.NoError(t, err, "error building STUN packet")
|
|
msg.Encode()
|
|
|
|
_, err = writeStreamingPacket(conn, msg.Raw)
|
|
require.NoError(t, err, "error writing TCP STUN packet")
|
|
|
|
time.Sleep(1500 * time.Millisecond)
|
|
require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
|
|
buf := make([]byte, 1)
|
|
_, err = conn.Read(buf)
|
|
require.ErrorIs(t, err, io.EOF)
|
|
})
|
|
|
|
t.Run("connection keep alive if access by user", func(t *testing.T) {
|
|
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr)) // nolint
|
|
require.NoError(t, err, "error dialing test TCP connection")
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
|
|
stun.NewUsername("myufrag2:otherufrag2"),
|
|
stun.NewShortTermIntegrity("myufrag2"),
|
|
stun.Fingerprint,
|
|
)
|
|
require.NoError(t, err, "error building STUN packet")
|
|
msg.Encode()
|
|
|
|
n, err := writeStreamingPacket(conn, msg.Raw)
|
|
require.NoError(t, err, "error writing TCP STUN packet")
|
|
|
|
// wait for the connection to be created
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
listenerAddr, ok := listener.Addr().(*net.TCPAddr)
|
|
require.True(t, ok)
|
|
|
|
pktConn, err := tcpMux.GetConnByUfrag("myufrag2", false, listenerAddr.IP)
|
|
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
|
defer func() {
|
|
_ = pktConn.Close()
|
|
}()
|
|
|
|
time.Sleep(1500 * time.Millisecond)
|
|
|
|
// timeout, not closed
|
|
buf := make([]byte, 1024)
|
|
require.NoError(t, conn.SetReadDeadline(time.Now().Add(100*time.Millisecond)))
|
|
_, err = conn.Read(buf)
|
|
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
|
|
|
|
recv := make([]byte, n)
|
|
n2, rAddr, err := pktConn.ReadFrom(recv)
|
|
require.NoError(t, err, "error receiving data")
|
|
require.Equal(t, conn.LocalAddr(), rAddr, "remote tcp address mismatch")
|
|
require.Equal(t, n, n2, "received byte size mismatch")
|
|
require.Equal(t, msg.Raw, recv, "received bytes mismatch")
|
|
})
|
|
}
|