diff --git a/active_tcp_test.go b/active_tcp_test.go index c6728d9..db7f98c 100644 --- a/active_tcp_test.go +++ b/active_tcp_test.go @@ -7,7 +7,9 @@ package ice import ( + "context" "fmt" + "io" "net" "net/netip" "sync/atomic" @@ -290,3 +292,122 @@ func TestActiveTCP_Respect_NetworkTypes(t *testing.T) { require.NoError(t, tcpListener.Close()) require.Equal(t, uint64(0), atomic.LoadUint64(&incomingTCPCount)) } + +func TestNewActiveTCPConn_LocalAddrError_EarlyReturn(t *testing.T) { + defer test.CheckRoutines(t)() + + logger := logging.NewDefaultLoggerFactory().NewLogger("ice") + + // an invalid local address so getTCPAddrOnInterface fails at ResolveTCPAddr. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ra := netip.MustParseAddrPort("127.0.0.1:1") + + a := newActiveTCPConn(ctx, "this_is_not_a_valid_addr", ra, logger) + + require.NotNil(t, a) + require.True(t, a.closed.Load(), "should be closed on early return error") + la := a.LocalAddr() + require.NotNil(t, la) +} + +func TestActiveTCPConn_ReadLoop_BufferWriteError(t *testing.T) { + defer test.CheckRoutines(t)() + + tcpListener, err := net.Listen("tcp", "127.0.0.1:0") // nolint: noctx + require.NoError(t, err) + defer func() { _ = tcpListener.Close() }() + + ra := netip.MustParseAddrPort(tcpListener.Addr().String()) + logger := logging.NewDefaultLoggerFactory().NewLogger("ice") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + a := newActiveTCPConn(ctx, "127.0.0.1:0", ra, logger) + require.NotNil(t, a) + + srvConn, err := tcpListener.Accept() + require.NoError(t, err) + + require.NoError(t, a.readBuffer.Close()) + + _, err = writeStreamingPacket(srvConn, []byte("ping")) + require.NoError(t, err) + + require.NoError(t, a.Close()) + require.NoError(t, srvConn.Close()) +} + +func TestActiveTCPConn_WriteLoop_WriteStreamingError(t *testing.T) { + defer test.CheckRoutines(t)() + + tcpListener, err := net.Listen("tcp", "127.0.0.1:0") // nolint: noctx + require.NoError(t, err) + defer func() { _ = tcpListener.Close() }() + + ra := netip.MustParseAddrPort(tcpListener.Addr().String()) + logger := logging.NewDefaultLoggerFactory().NewLogger("ice") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + a := newActiveTCPConn(ctx, "127.0.0.1:0", ra, logger) + require.NotNil(t, a) + + srvConn, err := tcpListener.Accept() + require.NoError(t, err) + + require.NoError(t, srvConn.Close()) + + n, err := a.WriteTo([]byte("data"), nil) + require.NoError(t, err) + require.Equal(t, len("data"), n) + + require.NoError(t, a.Close()) +} + +func TestActiveTCPConn_LocalAddr_DefaultWhenUnset(t *testing.T) { + defer test.CheckRoutines(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + invalidLocal := "127.0.0.1:65536" + remote := netip.MustParseAddrPort("127.0.0.1:1") + + log := logging.NewDefaultLoggerFactory().NewLogger("ice") + + a := newActiveTCPConn(ctx, invalidLocal, remote, log) + require.NotNil(t, a) + require.True(t, a.closed.Load(), "expected early-return closed state") + + la := a.LocalAddr() + ta, ok := la.(*net.TCPAddr) + require.True(t, ok, "LocalAddr() should return *net.TCPAddr") + require.Nil(t, ta.IP, "fallback *net.TCPAddr should be zero value (nil IP)") + require.Equal(t, 0, ta.Port, "fallback *net.TCPAddr should be zero value (port 0)") + require.Equal(t, "", ta.Zone, "fallback *net.TCPAddr should be zero value (empty zone)") +} + +func TestActiveTCPConn_SetDeadlines_ReturnEOF(t *testing.T) { + defer test.CheckRoutines(t)() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + invalidLocal := "127.0.0.1:65536" + remote := netip.MustParseAddrPort("127.0.0.1:1") + log := logging.NewDefaultLoggerFactory().NewLogger("ice") + + a := newActiveTCPConn(ctx, invalidLocal, remote, log) + require.NotNil(t, a) + require.True(t, a.closed.Load(), "expected early-return closed state") + + err := a.SetReadDeadline(time.Now()) + require.ErrorIs(t, err, io.EOF) + + err = a.SetWriteDeadline(time.Now()) + require.ErrorIs(t, err, io.EOF) +} diff --git a/addr_test.go b/addr_test.go new file mode 100644 index 0000000..b5e6b3f --- /dev/null +++ b/addr_test.go @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js +// +build !js + +package ice + +import ( + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/require" +) + +// A net.Addr type that parseAddr doesn't handle. +type unknownAddr struct{} + +func (unknownAddr) Network() string { return "unknown" } +func (unknownAddr) String() string { return "unknown-addr" } + +func TestParseAddrFromIface_ErrFromParseAddr(t *testing.T) { + in := unknownAddr{} + ip, port, nt, err := parseAddrFromIface(in, "eth0") + + require.Error(t, err, "expected error from parseAddr for unknown net.Addr type") + require.Zero(t, port) + require.Zero(t, nt) + require.True(t, !ip.IsValid(), "ip should be zero value when error is returned") +} + +func TestParseAddr_ErrorBranches(t *testing.T) { + t.Run("IPNet invalid IP -> error", func(t *testing.T) { + // length 1 slice -> ipAddrToNetIP fails + _, _, _, err := parseAddr(&net.IPNet{IP: net.IP{1}}) + var convErr ipConvertError + require.ErrorAs(t, err, &convErr) + }) + + t.Run("IPAddr invalid IP -> error", func(t *testing.T) { + _, _, _, err := parseAddr(&net.IPAddr{IP: net.IP{1}, Zone: "eth0"}) + var convErr ipConvertError + require.ErrorAs(t, err, &convErr) + }) + + t.Run("UDPAddr invalid IP -> error", func(t *testing.T) { + _, _, _, err := parseAddr(&net.UDPAddr{IP: net.IP{1}, Port: 3478}) + var convErr ipConvertError + require.ErrorAs(t, err, &convErr) + }) + + t.Run("TCPAddr invalid IP -> error", func(t *testing.T) { + _, _, _, err := parseAddr(&net.TCPAddr{IP: net.IP{1}, Port: 3478}) + var convErr ipConvertError + require.ErrorAs(t, err, &convErr) + }) + + t.Run("Unknown net.Addr type -> addrParseError", func(t *testing.T) { + _, _, _, err := parseAddr(unknownAddr{}) + var ap addrParseError + require.ErrorAs(t, err, &ap) + }) +} + +func TestParseAddr_IPAddr_Success(t *testing.T) { + ip := net.ParseIP("fe80::1") + require.NotNil(t, ip) + + gotIP, port, nt, err := parseAddr(&net.IPAddr{IP: ip, Zone: "lo0"}) + require.NoError(t, err) + require.Equal(t, 0, port) + require.Equal(t, NetworkType(0), nt) + require.True(t, gotIP.Is6()) + require.Equal(t, "lo0", gotIP.Zone()) + require.Equal(t, 0, gotIP.Compare(netip.MustParseAddr("fe80::1%lo0").Unmap())) +} + +func TestAddrParseError_Error(t *testing.T) { + e := addrParseError{addr: &net.TCPAddr{}} + require.Equal(t, + "do not know how to parse address type *net.TCPAddr", + e.Error(), + ) +} + +func TestIPConvertError_Error(t *testing.T) { + e := ipConvertError{ip: []byte("bad-ip")} + require.Equal(t, + "failed to convert IP 'bad-ip' to netip.Addr", + e.Error(), + ) +} + +func TestIPAddrToNetIP_Error_InvalidBytes(t *testing.T) { + bad := []byte{1} // invalid length -> AddrFromSlice returns ok=false + got, err := ipAddrToNetIP(bad, "") + require.Equal(t, netip.Addr{}, got, "should return zero addr on error") + require.Error(t, err) + require.IsType(t, ipConvertError{}, err) + require.Contains(t, err.Error(), "failed to convert IP") +} + +func TestIPAddrToNetIP_OK_IPv4(t *testing.T) { + ipv4 := []byte{1, 2, 3, 4} + got, err := ipAddrToNetIP(ipv4, "") + require.NoError(t, err) + require.True(t, got.Is4()) + + want := netip.AddrFrom4([4]byte{1, 2, 3, 4}) + require.Equal(t, want, got) +} + +func TestAddrEqual_FirstParseError(t *testing.T) { + a := unknownAddr{} + b := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 9999} + + require.False(t, addrEqual(a, b)) +} + +func TestAddrEqual_SecondParseError(t *testing.T) { + a := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 9999} + b := unknownAddr{} + + require.False(t, addrEqual(a, b)) +} + +func TestAddrEqual_SameTypeIPPort(t *testing.T) { + a := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 4242} + b := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 4242} + + require.True(t, addrEqual(a, b)) +} diff --git a/agent_handlers_test.go b/agent_handlers_test.go index 0c980f6..059959c 100644 --- a/agent_handlers_test.go +++ b/agent_handlers_test.go @@ -9,6 +9,7 @@ import ( "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestConnectionStateNotifier(t *testing.T) { @@ -70,3 +71,119 @@ func TestConnectionStateNotifier(t *testing.T) { notifer.Close(true) }) } + +func TestHandlerNotifier_Close_AlreadyClosed(t *testing.T) { + defer test.CheckRoutines(t)() + + notifier := &handlerNotifier{ + connectionStateFunc: func(ConnectionState) {}, + candidateFunc: func(Candidate) {}, + candidatePairFunc: func(*CandidatePair) {}, + done: make(chan struct{}), + } + + // first close + notifier.Close(false) + + isClosed := func(ch <-chan struct{}) bool { + select { + case <-ch: + return true + default: + return false + } + } + assert.True(t, isClosed(notifier.done), "expected h.done to be closed after first Close") + + // second close should hit `case <-h.done` and return immediately + // without blocking on the WaitGroup. + finished := make(chan struct{}, 1) + go func() { + notifier.Close(true) + close(finished) + }() + + assert.Eventually(t, func() bool { + select { + case <-finished: + return true + default: + return false + } + }, 250*time.Millisecond, 10*time.Millisecond, "second Close(true) did not return promptly") + + // ensure still closed afterwards + assert.True(t, isClosed(notifier.done), "expected h.done to remain closed after second Close") + + // sanity: no enqueues should start after close. + require.False(t, notifier.running) + require.Zero(t, len(notifier.connectionStates)) + require.Zero(t, len(notifier.candidates)) + require.Zero(t, len(notifier.selectedCandidatePairs)) +} + +func TestHandlerNotifier_EnqueueConnectionState_AfterClose(t *testing.T) { + defer test.CheckRoutines(t)() + + connCh := make(chan struct{}, 1) + notifier := &handlerNotifier{ + connectionStateFunc: func(ConnectionState) { connCh <- struct{}{} }, + done: make(chan struct{}), + } + + notifier.Close(false) + notifier.EnqueueConnectionState(ConnectionStateConnected) + + assert.Never(t, func() bool { + select { + case <-connCh: + return true + default: + return false + } + }, 250*time.Millisecond, 10*time.Millisecond, "connectionStateFunc should not be called after close") +} + +func TestHandlerNotifier_EnqueueCandidate_AfterClose(t *testing.T) { + defer test.CheckRoutines(t)() + + candidateCh := make(chan struct{}, 1) + h := &handlerNotifier{ + candidateFunc: func(Candidate) { candidateCh <- struct{}{} }, + done: make(chan struct{}), + } + + h.Close(false) + h.EnqueueCandidate(nil) + + assert.Never(t, func() bool { + select { + case <-candidateCh: + return true + default: + return false + } + }, 250*time.Millisecond, 10*time.Millisecond, "candidateFunc should not be called after close") +} + +func TestHandlerNotifier_EnqueueSelectedCandidatePair_AfterClose(t *testing.T) { + defer test.CheckRoutines(t)() + + pairCh := make(chan struct{}, 1) + h := &handlerNotifier{ + candidatePairFunc: func(*CandidatePair) { pairCh <- struct{}{} }, + done: make(chan struct{}), + } + + h.Close(false) + h.EnqueueSelectedCandidatePair(nil) + + assert.Never(t, func() bool { + select { + case <-pairCh: + return true + default: + return false + } + }, 250*time.Millisecond, 10*time.Millisecond, "candidatePairFunc should not be called after close") +} diff --git a/agent_test.go b/agent_test.go index 7f381c0..4297592 100644 --- a/agent_test.go +++ b/agent_test.go @@ -2012,3 +2012,28 @@ func TestRoleConflict(t *testing.T) { runTest(t, false) }) } + +func TestAgentConfig_initWithDefaults_UsesProvidedValues(t *testing.T) { + valMaxBindingReq := uint16(0) + valSrflxWait := 111 * time.Millisecond + valPrflxWait := 222 * time.Millisecond + valRelayWait := 3 * time.Second + valStunTimeout := 4 * time.Second + + cfg := &AgentConfig{ + MaxBindingRequests: &valMaxBindingReq, + SrflxAcceptanceMinWait: &valSrflxWait, + PrflxAcceptanceMinWait: &valPrflxWait, + RelayAcceptanceMinWait: &valRelayWait, + STUNGatherTimeout: &valStunTimeout, + } + + var a Agent + cfg.initWithDefaults(&a) + + require.Equal(t, valMaxBindingReq, a.maxBindingRequests, "expected override for MaxBindingRequests") + require.Equal(t, valSrflxWait, a.srflxAcceptanceMinWait, "expected override for SrflxAcceptanceMinWait") + require.Equal(t, valPrflxWait, a.prflxAcceptanceMinWait, "expected override for PrflxAcceptanceMinWait") + require.Equal(t, valRelayWait, a.relayAcceptanceMinWait, "expected override for RelayAcceptanceMinWait") + require.Equal(t, valStunTimeout, a.stunGatherTimeout, "expected override for STUNGatherTimeout") +} diff --git a/candidatepair_test.go b/candidatepair_test.go index f6e8284..558d640 100644 --- a/candidatepair_test.go +++ b/candidatepair_test.go @@ -130,3 +130,48 @@ func TestNilCandidatePairString(t *testing.T) { var nilCandidatePair *CandidatePair require.Equal(t, nilCandidatePair.String(), "") } + +func TestCandidatePairState_String(t *testing.T) { + tests := []struct { + name string + in CandidatePairState + want string + }{ + {"waiting", CandidatePairStateWaiting, "waiting"}, + {"in-progress", CandidatePairStateInProgress, "in-progress"}, + {"failed", CandidatePairStateFailed, "failed"}, + {"succeeded", CandidatePairStateSucceeded, "succeeded"}, + {"unknown", CandidatePairState(255), "Unknown candidate pair state"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, tt.in.String()) + }) + } +} + +func TestCandidatePairEqual_NilCases(t *testing.T) { + // both nil -> true + var a *CandidatePair + var b *CandidatePair + require.True(t, a.equal(b), "both nil pairs should be equal") + + // left non-nil, right nil -> false + a = newCandidatePair(hostCandidate(), srflxCandidate(), true) + require.False(t, a.equal(nil), "non-nil vs nil should be false") + + // left nil, right non-nil -> false + require.False(t, (*CandidatePair)(nil).equal(a), "nil vs non-nil should be false") +} + +func TestCandidatePair_TimeGetters_DefaultZero(t *testing.T) { + p := newCandidatePair(hostCandidate(), srflxCandidate(), true) + + require.True(t, p.FirstRequestSentAt().IsZero(), "FirstRequestSentAt should be zero by default") + require.True(t, p.LastRequestSentAt().IsZero(), "LastRequestSentAt should be zero by default") + require.True(t, p.FirstReponseReceivedAt().IsZero(), "FirstReponseReceivedAt should be zero by default") + require.True(t, p.LastResponseReceivedAt().IsZero(), "LastResponseReceivedAt should be zero by default") + require.True(t, p.FirstRequestReceivedAt().IsZero(), "FirstRequestReceivedAt should be zero by default") + require.True(t, p.LastRequestReceivedAt().IsZero(), "LastRequestReceivedAt should be zero by default") +} diff --git a/candidatetype_test.go b/candidatetype_test.go new file mode 100644 index 0000000..e7342ec --- /dev/null +++ b/candidatetype_test.go @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package ice + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCandidateType_String_KnownCases(t *testing.T) { + cases := map[CandidateType]string{ + CandidateTypeHost: "host", + CandidateTypeServerReflexive: "srflx", + CandidateTypePeerReflexive: "prflx", + CandidateTypeRelay: "relay", + CandidateTypeUnspecified: "Unknown candidate type", + } + + for ct, want := range cases { + require.Equal(t, want, ct.String(), "unexpected string for %v", ct) + } +} + +func TestCandidateType_String_Default(t *testing.T) { + const outOfBounds CandidateType = 255 + require.Equal(t, "Unknown candidate type", outOfBounds.String()) +} + +func TestCandidateType_Preference_DefaultCase(t *testing.T) { + const outOfBounds CandidateType = 255 + require.Equal(t, uint16(0), outOfBounds.Preference()) +} + +func TestContainsCandidateType_NilSlice(t *testing.T) { + var list []CandidateType // nil slice + require.False(t, containsCandidateType(CandidateTypeHost, list)) +} diff --git a/net_test.go b/net_test.go index 8ebdf1f..3d6f052 100644 --- a/net_test.go +++ b/net_test.go @@ -4,11 +4,16 @@ package ice import ( + "errors" "net" "net/netip" + "sort" "strings" "testing" + "github.com/pion/logging" + "github.com/pion/transport/v3" + "github.com/pion/transport/v3/stdnet" "github.com/stretchr/testify/require" ) @@ -58,3 +63,181 @@ func mustAddr(t *testing.T, ip net.IP) netip.Addr { return addr } + +type errInterfacesNet struct { + transport.Net + retErr error +} + +func (e *errInterfacesNet) Interfaces() ([]*transport.Interface, error) { + return nil, e.retErr +} + +var errBoom = errors.New("boom") + +func TestLocalInterfaces_ErrorFromInterfaces(t *testing.T) { + base, err := stdnet.NewNet() + require.NoError(t, err) + + wrapped := &errInterfacesNet{ + Net: base, + retErr: errBoom, + } + + ifaces, addrs, gotErr := localInterfaces( + wrapped, + nil, + nil, + nil, + false, + ) + + require.ErrorIs(t, gotErr, wrapped.retErr) + require.Nil(t, ifaces, "expected nil iface slice on error") + require.NotNil(t, addrs, "ipAddrs should be a non-nil empty slice") + require.Len(t, addrs, 0) +} + +type fixedInterfacesNet struct { + transport.Net + list []*transport.Interface +} + +func (f *fixedInterfacesNet) Interfaces() ([]*transport.Interface, error) { + return f.list, nil +} + +func TestLocalInterfaces_SkipInterfaceDown(t *testing.T) { + base, err := stdnet.NewNet() + require.NoError(t, err) + + sysIfaces, err := base.Interfaces() + require.NoError(t, err) + if len(sysIfaces) == 0 { + t.Skip("no system network interfaces available") + } + + clone := *sysIfaces[0] + clone.Flags &^= net.FlagUp + + wrapped := &fixedInterfacesNet{ + Net: base, + list: []*transport.Interface{&clone}, + } + + ifcs, addrs, ierr := localInterfaces( + wrapped, + nil, + nil, + nil, + false, + ) + require.NoError(t, ierr) + require.Len(t, ifcs, 0, "down interfaces must be skipped") + require.Len(t, addrs, 0, "no addresses should be collected from a down interface") +} + +func TestLocalInterfaces_SkipLoopbackAddrs_WhenIncludeLoopbackFalse(t *testing.T) { + base, err := stdnet.NewNet() + require.NoError(t, err) + + sysIfaces, err := base.Interfaces() + require.NoError(t, err) + + var loop *transport.Interface + for _, ifc := range sysIfaces { + if ifc.Flags&net.FlagLoopback != 0 { + loop = ifc + + break + } + } + if loop == nil { + t.Skip("no loopback interface found on this system") + } + + // clone the loopback iface and clear the Loopback flag so the outer check + // doesn't drop it to force the inner `(ipAddr.IsLoopback() && !includeLoopback)`. + cloned := *loop + cloned.Flags |= net.FlagUp + cloned.Flags &^= net.FlagLoopback + + wrapped := &fixedInterfacesNet{ + Net: base, + list: []*transport.Interface{&cloned}, + } + + ifaces, addrs, ierr := localInterfaces( + wrapped, + nil, // interfaceFilter + nil, // ipFilter + nil, // networkTypes + false, // includeLoopback + ) + require.NoError(t, ierr) + + // don't assert on the number of interfaces because some systems may + // report the iface as having addresses in a way that causes it to be included. + // assert that all loopback addresses were skipped. + for _, a := range addrs { + require.False(t, a.IsLoopback(), "loopback addresses must be skipped when includeLoopback=false") + } + + _ = ifaces // intentionally don't assert on this, see above comment +} + +// Captures ListenUDP attempts and always fails so the loop exhausts. +type listenUDPCaptor struct { + transport.Net + attempts []int +} + +func (c *listenUDPCaptor) ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + c.attempts = append(c.attempts, laddr.Port) + + return nil, errBoom +} + +func TestListenUDPInPortRange_DefaultsPortMinTo1024(t *testing.T) { + base, err := stdnet.NewNet() + require.NoError(t, err) + + captor := &listenUDPCaptor{Net: base} + logger := logging.NewDefaultLoggerFactory().NewLogger("ice-test") + + // portMin == 0 (should become 1024), portMax small to keep the loop short. + _, err = listenUDPInPortRange( + captor, + logger, + 1030, // portMax + 0, // portMin -> becomes 1024 + udp4, + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}, + ) + require.ErrorIs(t, err, ErrPort) + + // should have attempted exactly [1024..1030] in some order. + sort.Ints(captor.attempts) + require.Equal(t, []int{1024, 1025, 1026, 1027, 1028, 1029, 1030}, captor.attempts) +} + +func TestListenUDPInPortRange_DefaultsPortMaxToFFFF(t *testing.T) { + base, err := stdnet.NewNet() + require.NoError(t, err) + + captor := &listenUDPCaptor{Net: base} + logger := logging.NewDefaultLoggerFactory().NewLogger("ice-test") + + // portMax == 0 (should become 0xFFFF). Use portMin=65535 so the range is 1 port. + _, err = listenUDPInPortRange( + captor, + logger, + 0, // portMax -> becomes 65535 + 65535, // portMin + udp4, + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}, + ) + require.ErrorIs(t, err, ErrPort) + + require.Equal(t, []int{65535}, captor.attempts) +} diff --git a/networktype_test.go b/networktype_test.go index 31aba30..70dfa70 100644 --- a/networktype_test.go +++ b/networktype_test.go @@ -83,3 +83,53 @@ func TestNetworkTypeIsTCP(t *testing.T) { require.False(t, NetworkTypeTCP4.IsUDP()) require.False(t, NetworkTypeTCP6.IsUDP()) } + +func TestNetworkType_String_Default(t *testing.T) { + var invalid NetworkType // 0 triggers default branch + require.Equal(t, ErrUnknownType.Error(), invalid.String()) + + require.Equal(t, "udp4", NetworkTypeUDP4.String()) + require.Equal(t, "udp6", NetworkTypeUDP6.String()) + require.Equal(t, "tcp4", NetworkTypeTCP4.String()) + require.Equal(t, "tcp6", NetworkTypeTCP6.String()) +} + +func TestNetworkType_NetworkShort_Default(t *testing.T) { + var invalid NetworkType + require.Equal(t, ErrUnknownType.Error(), invalid.NetworkShort()) + + require.Equal(t, udp, NetworkTypeUDP4.NetworkShort()) + require.Equal(t, udp, NetworkTypeUDP6.NetworkShort()) + require.Equal(t, tcp, NetworkTypeTCP4.NetworkShort()) + require.Equal(t, tcp, NetworkTypeTCP6.NetworkShort()) +} + +func TestNetworkType_IPvFlags_Default(t *testing.T) { + var invalid NetworkType + require.False(t, invalid.IsIPv4()) + require.False(t, invalid.IsIPv6()) + + require.True(t, NetworkTypeUDP4.IsIPv4()) + require.True(t, NetworkTypeTCP4.IsIPv4()) + require.False(t, NetworkTypeUDP6.IsIPv4()) + require.False(t, NetworkTypeTCP6.IsIPv4()) + + require.True(t, NetworkTypeUDP6.IsIPv6()) + require.True(t, NetworkTypeTCP6.IsIPv6()) + require.False(t, NetworkTypeUDP4.IsIPv6()) + require.False(t, NetworkTypeTCP4.IsIPv6()) +} + +func TestNetworkType_IsReliable(t *testing.T) { + // UDP is unreliable + require.False(t, NetworkTypeUDP4.IsReliable()) + require.False(t, NetworkTypeUDP6.IsReliable()) + + // TCP is reliable + require.True(t, NetworkTypeTCP4.IsReliable()) + require.True(t, NetworkTypeTCP6.IsReliable()) + + // default/unknown falls through to false + var invalid NetworkType + require.False(t, invalid.IsReliable()) +} diff --git a/role_test.go b/role_test.go new file mode 100644 index 0000000..712ab72 --- /dev/null +++ b/role_test.go @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js +// +build !js + +package ice + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnmarshalText_Success(t *testing.T) { + tests := []struct { + name string + in string + want Role + }{ + {"controlling", "controlling", Controlling}, + {"controlled", "controlled", Controlled}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var r Role + err := r.UnmarshalText([]byte(tt.in)) + require.NoError(t, err) + require.Equal(t, tt.want, r) + }) + } +} + +func TestUnmarshalText_UnknownKeepsValueAndErrors(t *testing.T) { + r := Controlled + err := r.UnmarshalText([]byte("neither")) + require.ErrorIs(t, err, errUnknownRole) + require.Equal(t, Controlled, r, "role should remain unchanged on error") +} + +func TestMarshalText(t *testing.T) { + tests := []struct { + name string + in Role + want string + }{ + {"controlling", Controlling, "controlling"}, + {"controlled", Controlled, "controlled"}, + {"unknown", Role(99), "unknown"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b, err := tt.in.MarshalText() + require.NoError(t, err) + require.Equal(t, tt.want, string(b)) + }) + } +} + +func TestString(t *testing.T) { + require.Equal(t, "controlling", Controlling.String()) + require.Equal(t, "controlled", Controlled.String()) + require.Equal(t, "unknown", Role(255).String()) +} diff --git a/selection_test.go b/selection_test.go index 810df8e..ffd0b35 100644 --- a/selection_test.go +++ b/selection_test.go @@ -10,12 +10,15 @@ import ( "bytes" "context" "errors" + "fmt" "io" "net" + "strings" "sync/atomic" "testing" "time" + "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" @@ -158,3 +161,199 @@ func TestBindingRequestHandler(t *testing.T) { closePipe(t, controllingConn, controlledConn) } + +// copied from pion/webrtc's peerconnection_go_test.go. +type testICELogger struct { + lastErrorMessage string +} + +func (t *testICELogger) Trace(string) {} +func (t *testICELogger) Tracef(string, ...any) {} +func (t *testICELogger) Debug(string) {} +func (t *testICELogger) Debugf(string, ...any) {} +func (t *testICELogger) Info(string) {} +func (t *testICELogger) Infof(string, ...any) {} +func (t *testICELogger) Warn(string) {} +func (t *testICELogger) Warnf(string, ...any) {} +func (t *testICELogger) Error(msg string) { t.lastErrorMessage = msg } +func (t *testICELogger) Errorf(format string, args ...any) { + t.lastErrorMessage = fmt.Sprintf(format, args...) +} + +type testICELoggerFactory struct { + logger *testICELogger +} + +func (t *testICELoggerFactory) NewLogger(string) logging.LeveledLogger { + return t.logger +} + +func TestControllingSelector_IsNominatable_LogsInvalidType(t *testing.T) { + testLogger := &testICELogger{} + loggerFactory := &testICELoggerFactory{logger: testLogger} + + sel := &controllingSelector{ + agent: &Agent{}, + log: loggerFactory.NewLogger("test"), + } + sel.Start() + + c := hostCandidate() + c.candidateBase.candidateType = CandidateTypeUnspecified + + got := sel.isNominatable(c) + + require.False(t, got) + require.Contains(t, testLogger.lastErrorMessage, "Invalid candidate type") + require.Contains(t, testLogger.lastErrorMessage, "Unknown candidate type") // from c.Type().String() +} + +func TestControllingSelector_NominatePair_BuildError(t *testing.T) { + testLogger := &testICELogger{} + loggerFactory := &testICELoggerFactory{logger: testLogger} + + // selector with an Agent with ufrags to make an oversized username + // (username = remoteUfrag + ":" + localUfrag) since oversized username causes + // stun.NewUsername(...) inside stun.Build to fail. + long := strings.Repeat("x", 300) // > 255 each side + sel := &controllingSelector{ + agent: &Agent{ + remoteUfrag: long, + localUfrag: long, + remotePwd: "pwd", // any non-empty value is fine + tieBreaker: 0, + }, + log: loggerFactory.NewLogger("test"), + } + sel.Start() + + p := newCandidatePair(hostCandidate(), hostCandidate(), true) + + sel.nominatePair(p) + + require.NotEmpty(t, testLogger.lastErrorMessage, "expected error log from nominatePair on Build failure") +} + +type pingNoIOCand struct{ candidateBase } + +func newPingNoIOCand() *pingNoIOCand { + return &pingNoIOCand{ + candidateBase: candidateBase{ + candidateType: CandidateTypeHost, + component: ComponentRTP, + }, + } +} +func (d *pingNoIOCand) writeTo(b []byte, _ Candidate) (int, error) { return len(b), nil } + +func bareAgentForPing() *Agent { + return &Agent{ + hostAcceptanceMinWait: time.Hour, + srflxAcceptanceMinWait: time.Hour, + prflxAcceptanceMinWait: time.Hour, + relayAcceptanceMinWait: time.Hour, + + checklist: []*CandidatePair{}, + keepaliveInterval: time.Second, + checkInterval: time.Second, + + connectionStateNotifier: &handlerNotifier{ + done: make(chan struct{}), + connectionStateFunc: func(ConnectionState) {}}, //nolint formatting + + candidateNotifier: &handlerNotifier{ + done: make(chan struct{}), + candidateFunc: func(Candidate) {}}, //nolint formatting + + selectedCandidatePairNotifier: &handlerNotifier{ + done: make(chan struct{}), + candidatePairFunc: func(*CandidatePair) {}}, //nolint formatting + } +} + +func bigStr() string { return strings.Repeat("x", 40000) } + +func TestControllingSelector_PingCandidate_BuildError(t *testing.T) { + a := bareAgentForPing() + // make Username really big so stun.Build returns an error. + a.remoteUfrag = bigStr() + a.localUfrag = bigStr() + a.remotePwd = "pwd" + a.tieBreaker = 1 + + testLogger := &testICELogger{} + sel := &controllingSelector{agent: a, log: testLogger} + sel.Start() + + local := newPingNoIOCand() + remote := newPingNoIOCand() + + sel.PingCandidate(local, remote) + + require.NotEmpty(t, testLogger.lastErrorMessage, "expected error to be logged from stun.Build") +} + +func TestControlledSelector_PingCandidate_BuildError(t *testing.T) { + a := bareAgentForPing() + a.remoteUfrag = bigStr() + a.localUfrag = bigStr() + a.remotePwd = "pwd" + a.tieBreaker = 1 + + testLogger := &testICELogger{} + sel := &controlledSelector{agent: a, log: testLogger} + sel.Start() + + local := newPingNoIOCand() + remote := newPingNoIOCand() + + sel.PingCandidate(local, remote) + + require.NotEmpty(t, testLogger.lastErrorMessage, "expected error to be logged from stun.Build") +} + +type warnTestLogger struct { + warned bool +} + +func (l *warnTestLogger) Trace(string) {} +func (l *warnTestLogger) Tracef(string, ...any) {} +func (l *warnTestLogger) Debug(string) {} +func (l *warnTestLogger) Debugf(string, ...any) {} +func (l *warnTestLogger) Info(string) {} +func (l *warnTestLogger) Infof(string, ...any) {} +func (l *warnTestLogger) Warn(string) { l.warned = true } +func (l *warnTestLogger) Warnf(string, ...any) { l.warned = true } +func (l *warnTestLogger) Error(string) {} +func (l *warnTestLogger) Errorf(string, ...any) {} + +type dummyNoIOCand struct{ candidateBase } + +func newDummyNoIOCand(t CandidateType) *dummyNoIOCand { + return &dummyNoIOCand{ + candidateBase: candidateBase{ + candidateType: t, + component: ComponentRTP, + }, + } +} +func (d *dummyNoIOCand) writeTo(p []byte, _ Candidate) (int, error) { return len(p), nil } + +func TestControlledSelector_HandleSuccessResponse_UnknownTxID(t *testing.T) { + logger := &warnTestLogger{} + + ag := &Agent{log: logger} + + sel := &controlledSelector{agent: ag, log: logger} + sel.Start() + + local := newDummyNoIOCand(CandidateTypeHost) + remote := newDummyNoIOCand(CandidateTypeHost) + + var m stun.Message + copy(m.TransactionID[:], []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}) + + sel.HandleSuccessResponse(&m, local, remote, nil) + + require.True(t, logger.warned, "expected Warnf to be called for unknown TransactionID (hitting !ok branch)") +} diff --git a/transport_test.go b/transport_test.go index 5bb3e84..cdb4f6b 100644 --- a/transport_test.go +++ b/transport_test.go @@ -310,3 +310,49 @@ func TestConnStats(t *testing.T) { require.Equal(t, uint64(10), ca.BytesSent()) require.Equal(t, uint64(10), cb.BytesReceived()) } + +func TestAgent_connect_ErrEarly(t *testing.T) { + defer test.CheckRoutines(t)() + + cfg := &AgentConfig{ + NetworkTypes: supportedNetworkTypes(), + } + a, err := NewAgent(cfg) + require.NoError(t, err) + + require.NoError(t, a.Close()) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + // isControlling = true + conn, cerr := a.connect(ctx, true, "ufragX", "pwdX") + require.Nil(t, conn) + require.Error(t, cerr, "expected error from a.loop.Err() short-circuit") +} + +func TestConn_Write_RejectsSTUN(t *testing.T) { + defer test.CheckRoutines(t)() + defer test.TimeOut(10 * time.Second).Stop() + + cfg := &AgentConfig{ + NetworkTypes: supportedNetworkTypes(), + MulticastDNSMode: MulticastDNSModeDisabled, + } + a, err := NewAgent(cfg) + require.NoError(t, err) + defer func() { + _ = a.Close() + }() + + c := &Conn{agent: a} + require.Nil(t, c.agent.getSelectedPair(), "precondition: no selected pair") + + msg := stun.New() + msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} + msg.Encode() + + n, werr := c.Write(msg.Raw) + require.Zero(t, n) + require.ErrorIs(t, werr, errWriteSTUNMessageToIceConn) +} diff --git a/udp_mux_test.go b/udp_mux_test.go index 7ccf69c..296438c 100644 --- a/udp_mux_test.go +++ b/udp_mux_test.go @@ -10,12 +10,15 @@ import ( "crypto/rand" "crypto/sha256" "encoding/binary" + "errors" "io" "net" "sync" "testing" "time" + "github.com/pion/ice/v4/internal/fakenet" + "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/require" @@ -511,3 +514,319 @@ func TestUDPMuxedConn_writePacket_NotifyDefaultBranch(t *testing.T) { // channel still full => no send happened (default path executed) require.Equal(t, 1, len(conn.notify)) } + +func TestNewUDPMuxDefault_LocalAddrNotUDPAddr(t *testing.T) { + defer test.CheckRoutines(t)() + + c1, c2 := net.Pipe() + defer func() { _ = c2.Close() }() + + pc := &fakenet.PacketConn{Conn: c1} + + mux := NewUDPMuxDefault(UDPMuxParams{ + Logger: logging.NewDefaultLoggerFactory().NewLogger("ice"), + UDPConn: pc, + }) + require.NotNil(t, mux) + + defer func() { _ = mux.Close() }() + + addrs := mux.GetListenAddresses() + require.Len(t, addrs, 1) + require.Equal(t, pc.LocalAddr().String(), addrs[0].String()) +} + +func TestUDPMuxDefault_GetConn_InvalidAddress(t *testing.T) { + defer test.CheckRoutines(t)() + + connA, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + + udpMux := NewUDPMuxDefault(UDPMuxParams{ + Logger: nil, + UDPConn: connA, + }) + defer func() { + _ = udpMux.Close() + _ = connA.Close() + }() + + connB, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + defer func() { _ = connB.Close() }() + + pc, gerr := udpMux.GetConn("some-ufrag", connB.LocalAddr()) + require.Nil(t, pc) + require.ErrorIs(t, gerr, errInvalidAddress) +} + +func TestUDPMuxDefault_registerConnForAddress_ClosedMuxEarlyReturn(t *testing.T) { + defer test.CheckRoutines(t)() + + udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + udpMux := NewUDPMuxDefault(UDPMuxParams{UDPConn: udp}) + + require.NoError(t, udpMux.Close()) + _ = udp.Close() + + conn := secondTestMuxedConn(t, 64) + addr, err := newIPPort(net.ParseIP("1.2.3.4"), "", 9999) + require.NoError(t, err) + + before := len(udpMux.addressMap) + udpMux.registerConnForAddress(conn, addr) + after := len(udpMux.addressMap) + + require.Equal(t, before, after) + _, exists := udpMux.addressMap[addr] + require.False(t, exists) +} + +func TestUDPMuxDefault_registerConnForAddress_ReplacesExisting(t *testing.T) { + defer test.CheckRoutines(t)() + + udp, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) + require.NoError(t, err) + udpMux := NewUDPMuxDefault(UDPMuxParams{UDPConn: udp}) + defer func() { + _ = udpMux.Close() + _ = udp.Close() + }() + + ipAddr, err := newIPPort(net.ParseIP("5.6.7.8"), "", 12345) + require.NoError(t, err) + + existing := secondTestMuxedConn(t, 64) + existing.addresses = []ipPort{ipAddr} + udpMux.addressMapMu.Lock() + udpMux.addressMap[ipAddr] = existing + udpMux.addressMapMu.Unlock() + + // new conn should replace existing mapping and cause removeAddress on the old one. + newConn := secondTestMuxedConn(t, 64) + udpMux.registerConnForAddress(newConn, ipAddr) + + // map should now point to newConn. + udpMux.addressMapMu.RLock() + mapped := udpMux.addressMap[ipAddr] + udpMux.addressMapMu.RUnlock() + require.Equal(t, newConn, mapped) + + // old conn should have ipAddr removed from its addresses. + require.False(t, existing.containsAddress(ipAddr), "old conn should have removed the address backref") +} + +func stunWithLen(l uint16) []byte { + m := stun.New() + m.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} + m.Encode() + out := append([]byte{}, m.Raw...) + out[2] = byte(l >> 8) + out[3] = byte(l & 0xff) + + return out +} + +type scriptedUDPPC struct { + local *net.UDPAddr + seq []struct { + data []byte + addr net.Addr + err error + } + i int +} + +func (s *scriptedUDPPC) ReadFrom(p []byte) (int, net.Addr, error) { + if s.i >= len(s.seq) { + return 0, s.local, errIoEOF + } + step := s.seq[s.i] + s.i++ + if step.err != nil { + return 0, step.addr, step.err + } + n := copy(p, step.data) + + return n, step.addr, nil +} +func (s *scriptedUDPPC) WriteTo([]byte, net.Addr) (int, error) { return 0, nil } +func (s *scriptedUDPPC) Close() error { return nil } +func (s *scriptedUDPPC) LocalAddr() net.Addr { return s.local } +func (s *scriptedUDPPC) SetDeadline(time.Time) error { return nil } +func (s *scriptedUDPPC) SetReadDeadline(time.Time) error { return nil } +func (s *scriptedUDPPC) SetWriteDeadline(time.Time) error { return nil } + +var errIoEOF = errors.New("EOF") + +func TestUDPMux_connWorker_AddrNotUDP(t *testing.T) { + defer test.CheckRoutines(t)() + + c1, c2 := net.Pipe() + defer func() { + _ = c2.Close() + }() + + pc := &fakenet.PacketConn{Conn: c1} + mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc}) + require.NotNil(t, mux) + defer func() { + _ = mux.Close() + }() + + _, _ = c2.Write([]byte("frame")) + _ = c2.Close() +} + +func TestUDPMux_connWorker_ReadError_Timeout(t *testing.T) { + defer test.CheckRoutines(t)() + + c1, c2 := net.Pipe() + pc := &fakenet.PacketConn{Conn: c1} + mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc}) + require.NotNil(t, mux) + + _ = pc.SetReadDeadline(time.Unix(0, 0)) + + _ = c2.Close() + _ = mux.Close() +} + +func TestUDPMux_connWorker_NewIPPortError(t *testing.T) { + defer test.CheckRoutines(t)() + + badIP := net.IP{1} + remote := &net.UDPAddr{IP: badIP, Port: 9999} + pc := &scriptedUDPPC{ + local: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7000}, + seq: []struct { + data []byte + addr net.Addr + err error + }{ + {data: []byte{1}, addr: remote, err: nil}, // triggers newIPPort error + }, + } + mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc}) + require.NotNil(t, mux) + _ = mux.Close() +} + +func TestUDPMux_connWorker_STUNDecodeError(t *testing.T) { + defer test.CheckRoutines(t)() + + remote := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 2), Port: 5678} + pc := &scriptedUDPPC{ + local: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7001}, + seq: []struct { + data []byte + addr net.Addr + err error + }{ + // bad STUN length -> Decode() error -> Warnf + continue + {data: stunWithLen(4), addr: remote, err: nil}, + {data: nil, addr: remote, err: errIoEOF}, // exit loop + }, + } + mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc}) + require.NotNil(t, mux) + _ = mux.Close() +} + +func TestUDPMux_connWorker_STUNNoUsername(t *testing.T) { + defer test.CheckRoutines(t)() + + msg := stun.New() + msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} + msg.Encode() // valid STUN + no USERNAME + + remote := &net.UDPAddr{IP: net.IPv4(10, 0, 0, 3), Port: 5679} + pc := &scriptedUDPPC{ + local: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7002}, + seq: []struct { + data []byte + addr net.Addr + err error + }{ + {data: append([]byte{}, msg.Raw...), addr: remote, err: nil}, // Get(USERNAME) fails + {data: nil, addr: remote, err: errIoEOF}, // exit loop + }, + } + mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc}) + require.NotNil(t, mux) + _ = mux.Close() +} + +func TestUDPMux_connWorker_WritePacketError(t *testing.T) { + defer test.CheckRoutines(t)() + + local := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7003} + remote := &net.UDPAddr{IP: net.IPv4(203, 0, 113, 7), Port: 5555} + payload := []byte("0123456789ABCDEF") + + pc := &scriptedUDPPC{ + local: local, + seq: []struct { + data []byte + addr net.Addr + err error + }{ + {data: payload, addr: remote, err: nil}, + {data: nil, addr: remote, err: errIoEOF}, // exit loop + }, + } + mux := NewUDPMuxDefault(UDPMuxParams{UDPConn: pc}) + require.NotNil(t, mux) + defer func() { + _ = mux.Close() + }() + + // shrink pool to force io.ErrShortBuffer in writePacket + mux.pool = &sync.Pool{New: func() any { return newBufferHolder(8) }} + + // make connWorker route to new conn. + c, err := mux.GetConn("ufragX", mux.LocalAddr()) + require.NoError(t, err) + defer func() { + _ = c.Close() + }() + + // remote port is controlled. we use 5555 here to skip int overflow check as we would + // otherwise have to cast remote.Port (int) to uint16. + ipport, err := newIPPort(remote.IP, remote.Zone, 5555) + require.NoError(t, err) + + cInner, ok := c.(*udpMuxedConn) + require.True(t, ok, "expected *udpMuxedConn from UDPMuxDefault.GetConn") + mux.registerConnForAddress(cInner, ipport) +} + +func TestNewUDPMuxDefault_UnspecifiedAddr_AutoInitNet(t *testing.T) { + defer test.CheckRoutines(t)() + + conn, err := net.ListenUDP(udp, &net.UDPAddr{IP: net.IPv4zero}) + require.NoError(t, err) + defer func() { _ = conn.Close() }() + + mux := NewUDPMuxDefault(UDPMuxParams{ + Logger: nil, + UDPConn: conn, + Net: nil, + }) + require.NotNil(t, mux) + defer func() { _ = mux.Close() }() + + addrs := mux.GetListenAddresses() + require.GreaterOrEqual(t, len(addrs), 1, "should list at least one local listen address") + + udpAddr, ok := conn.LocalAddr().(*net.UDPAddr) + require.True(t, ok, "LocalAddr is not *net.UDPAddr") + wantPort := udpAddr.Port + + for _, a := range addrs { + ua, ok := a.(*net.UDPAddr) + require.True(t, ok, "returned listen address must be *net.UDPAddr") + require.Equal(t, wantPort, ua.Port, "listen addresses should reuse the same UDP port") + } +} diff --git a/usecandidate_test.go b/usecandidate_test.go index d50315a..b3e3f3f 100644 --- a/usecandidate_test.go +++ b/usecandidate_test.go @@ -11,12 +11,14 @@ import ( ) func TestUseCandidateAttr_AddTo(t *testing.T) { - m := new(stun.Message) - require.False(t, UseCandidate().IsSet(m)) - require.NoError(t, m.Build(stun.BindingRequest, UseCandidate())) + msg := stun.New() + msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest} + require.False(t, UseCandidate().IsSet(msg)) - m1 := new(stun.Message) - _, err := m1.Write(m.Raw) - require.NoError(t, err) - require.True(t, UseCandidate().IsSet(m1)) + require.NoError(t, UseCandidate().AddTo(msg)) + msg.Encode() + + msg2 := &stun.Message{Raw: append([]byte{}, msg.Raw...)} + require.NoError(t, msg2.Decode()) + require.True(t, UseCandidate().IsSet(msg2)) }