Update lint rules, force testify/assert for tests

Use testify's assert package instead of the standard library's testing
package.
This commit is contained in:
Joe Turki
2025-04-07 02:11:06 +02:00
parent f00fc07896
commit 8867eb8597
23 changed files with 620 additions and 1535 deletions

View File

@@ -19,12 +19,16 @@ linters-settings:
recommendations:
- errors
forbidigo:
analyze-types: true
forbid:
- ^fmt.Print(f|ln)?$
- ^log.(Panic|Fatal|Print)(f|ln)?$
- ^os.Exit$
- ^panic$
- ^print(ln)?$
- p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$
pkg: ^testing$
msg: "use testify/assert instead"
varnamelen:
max-distance: 12
min-name-length: 2
@@ -127,9 +131,12 @@ issues:
exclude-dirs-use-default: false
exclude-rules:
# Allow complex tests and examples, better to be self contained
- path: (examples|main\.go|_test\.go)
- path: (examples|main\.go)
linters:
- gocognit
- forbidigo
- path: _test\.go
linters:
- gocognit
# Allow forbidden identifiers in CLI commands

View File

@@ -4,10 +4,11 @@
package stun
import (
"errors"
"io"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMappedAddress(t *testing.T) {
@@ -16,48 +17,32 @@ func TestMappedAddress(t *testing.T) {
IP: net.ParseIP("122.12.34.5"),
Port: 5412,
}
if addr.String() != "122.12.34.5:5412" {
t.Error("bad string", addr)
}
assert.Equal(t, "122.12.34.5:5412", addr.String(), "bad string")
t.Run("Bad length", func(t *testing.T) {
badAddr := &MappedAddress{
IP: net.IP{1, 2, 3},
}
if err := badAddr.AddTo(msg); err == nil {
t.Error("should error")
}
assert.Error(t, badAddr.AddTo(msg), "should error")
})
t.Run("AddTo", func(t *testing.T) {
if err := addr.AddTo(msg); err != nil {
t.Error(err)
}
assert.NoError(t, addr.AddTo(msg))
t.Run("GetFrom", func(t *testing.T) {
got := new(MappedAddress)
if err := got.GetFrom(msg); err != nil {
t.Error(err)
}
if !got.IP.Equal(addr.IP) {
t.Error("got bad IP: ", got.IP)
}
assert.NoError(t, got.GetFrom(msg))
assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP)
t.Run("Not found", func(t *testing.T) {
message := new(Message)
if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) {
t.Error("should be not found: ", err)
}
assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found")
})
t.Run("Bad family", func(t *testing.T) {
v, _ := msg.Attributes.Get(AttrMappedAddress)
v.Value[0] = 32
if err := got.GetFrom(msg); err == nil {
t.Error("should error")
}
assert.Error(t, got.GetFrom(msg), "should error")
})
t.Run("Bad length", func(t *testing.T) {
message := new(Message)
message.Add(AttrMappedAddress, []byte{1, 2, 3})
if err := got.GetFrom(message); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Errorf("<%s> should be <%s>", err, io.ErrUnexpectedEOF)
}
assert.ErrorIs(t, got.GetFrom(message), io.ErrUnexpectedEOF)
})
})
})
@@ -70,22 +55,14 @@ func TestMappedAddressV6(t *testing.T) { //nolint:dupl
Port: 5412,
}
t.Run("AddTo", func(t *testing.T) {
if err := addr.AddTo(m); err != nil {
t.Error(err)
}
assert.NoError(t, addr.AddTo(m))
t.Run("GetFrom", func(t *testing.T) {
got := new(MappedAddress)
if err := got.GetFrom(m); err != nil {
t.Error(err)
}
if !got.IP.Equal(addr.IP) {
t.Error("got bad IP: ", got.IP)
}
assert.NoError(t, got.GetFrom(m))
assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP)
t.Run("Not found", func(t *testing.T) {
message := new(Message)
if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) {
t.Error("should be not found: ", err)
}
assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found")
})
})
})
@@ -98,22 +75,14 @@ func TestAlternateServer(t *testing.T) { //nolint:dupl
Port: 5412,
}
t.Run("AddTo", func(t *testing.T) {
if err := addr.AddTo(m); err != nil {
t.Error(err)
}
assert.NoError(t, addr.AddTo(m))
t.Run("GetFrom", func(t *testing.T) {
got := new(AlternateServer)
if err := got.GetFrom(m); err != nil {
t.Error(err)
}
if !got.IP.Equal(addr.IP) {
t.Error("got bad IP: ", got.IP)
}
assert.NoError(t, got.GetFrom(m))
assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP)
t.Run("Not found", func(t *testing.T) {
message := new(Message)
if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) {
t.Error("should be not found: ", err)
}
assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found")
})
})
})
@@ -126,22 +95,14 @@ func TestOtherAddress(t *testing.T) { //nolint:dupl
Port: 5412,
}
t.Run("AddTo", func(t *testing.T) {
if err := addr.AddTo(m); err != nil {
t.Error(err)
}
assert.NoError(t, addr.AddTo(m))
t.Run("GetFrom", func(t *testing.T) {
got := new(OtherAddress)
if err := got.GetFrom(m); err != nil {
t.Error(err)
}
if !got.IP.Equal(addr.IP) {
t.Error("got bad IP: ", got.IP)
}
assert.NoError(t, got.GetFrom(m))
assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP)
t.Run("Not found", func(t *testing.T) {
message := new(Message)
if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) {
t.Error("should be not found: ", err)
}
assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found")
})
})
})

View File

@@ -7,84 +7,44 @@ import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestAgent_ProcessInTransaction(t *testing.T) {
msg := New()
agent := NewAgent(func(e Event) {
if e.Error != nil {
t.Errorf("got error: %s", e.Error)
}
if !e.Message.Equal(msg) {
t.Errorf("%s (got) != %s (expected)", e.Message, msg)
}
assert.NoError(t, e.Error, "got error")
assert.True(t, e.Message.Equal(msg), "%s (got) != %s (expected)", e.Message, msg)
})
if err := msg.NewTransactionID(); err != nil {
t.Fatal(err)
}
if err := agent.Start(msg.TransactionID, time.Time{}); err != nil {
t.Fatal(err)
}
if err := agent.Process(msg); err != nil {
t.Error(err)
}
if err := agent.Close(); err != nil {
t.Error(err)
}
assert.NoError(t, msg.NewTransactionID())
assert.NoError(t, agent.Start(msg.TransactionID, time.Time{}))
assert.NoError(t, agent.Process(msg))
assert.NoError(t, agent.Close())
}
func TestAgent_Process(t *testing.T) {
msg := New()
agent := NewAgent(func(e Event) {
if e.Error != nil {
t.Errorf("got error: %s", e.Error)
}
if !e.Message.Equal(msg) {
t.Errorf("%s (got) != %s (expected)", e.Message, msg)
}
assert.NoError(t, e.Error, "got error")
assert.True(t, e.Message.Equal(msg), "%s (got) != %s (expected)", e.Message, msg)
})
if err := msg.NewTransactionID(); err != nil {
t.Fatal(err)
}
if err := agent.Process(msg); err != nil {
t.Error(err)
}
if err := agent.Close(); err != nil {
t.Error(err)
}
if err := agent.Process(msg); !errors.Is(err, ErrAgentClosed) {
t.Errorf("closed agent should return <%s>, but got <%s>",
ErrAgentClosed, err,
)
}
assert.NoError(t, msg.NewTransactionID())
assert.NoError(t, agent.Process(msg))
assert.NoError(t, agent.Close())
assert.ErrorIs(t, agent.Process(msg), ErrAgentClosed)
}
func TestAgent_Start(t *testing.T) {
agent := NewAgent(nil)
id := NewTransactionID()
deadline := time.Now().AddDate(0, 0, 1)
if err := agent.Start(id, deadline); err != nil {
t.Errorf("failed to statt transaction: %s", err)
}
if err := agent.Start(id, deadline); !errors.Is(err, ErrTransactionExists) {
t.Errorf("duplicate start should return <%s>, got <%s>",
ErrTransactionExists, err,
)
}
if err := agent.Close(); err != nil {
t.Error(err)
}
assert.NoError(t, agent.Start(id, deadline), "failed to start transaction")
assert.ErrorIs(t, agent.Start(id, deadline), ErrTransactionExists)
assert.NoError(t, agent.Close())
id = NewTransactionID()
if err := agent.Start(id, deadline); !errors.Is(err, ErrAgentClosed) {
t.Errorf("start on closed agent should return <%s>, got <%s>",
ErrAgentClosed, err,
)
}
if err := agent.SetHandler(nil); !errors.Is(err, ErrAgentClosed) {
t.Errorf("SetHandler on closed agent should return <%s>, got <%s>",
ErrAgentClosed, err,
)
}
assert.ErrorIs(t, agent.Start(id, deadline), ErrAgentClosed)
assert.ErrorIs(t, agent.SetHandler(nil), ErrAgentClosed)
}
func TestAgent_Stop(t *testing.T) {
@@ -92,36 +52,20 @@ func TestAgent_Stop(t *testing.T) {
agent := NewAgent(func(e Event) {
called <- e
})
if err := agent.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) {
t.Fatalf("unexpected error: %s, should be %s", err, ErrTransactionNotExists)
}
assert.ErrorIs(t, agent.Stop(transactionID{}), ErrTransactionNotExists)
id := NewTransactionID()
timeout := time.Millisecond * 200
if err := agent.Start(id, time.Now().Add(timeout)); err != nil {
t.Fatal(err)
}
if err := agent.Stop(id); err != nil {
t.Fatal(err)
}
assert.NoError(t, agent.Start(id, time.Now().Add(timeout)))
assert.NoError(t, agent.Stop(id))
select {
case e := <-called:
if !errors.Is(e.Error, ErrTransactionStopped) {
t.Fatalf("unexpected error: %s, should be %s",
e.Error, ErrTransactionStopped,
)
}
assert.ErrorIs(t, e.Error, ErrTransactionStopped)
case <-time.After(timeout * 2):
t.Fatal("timed out")
}
if err := agent.Close(); err != nil {
t.Fatal(err)
}
if err := agent.Close(); !errors.Is(err, ErrAgentClosed) {
t.Fatalf("a.Close returned %s instead of %s", err, ErrAgentClosed)
}
if err := agent.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) {
t.Fatalf("unexpected error: %s, should be %s", err, ErrAgentClosed)
assert.Fail(t, "timed out")
}
assert.NoError(t, agent.Close())
assert.ErrorIs(t, agent.Close(), ErrAgentClosed)
assert.ErrorIs(t, agent.Stop(transactionID{}), ErrAgentClosed)
}
func TestAgent_GC(t *testing.T) { //nolint:cyclop
@@ -136,60 +80,41 @@ func TestAgent_GC(t *testing.T) { //nolint:cyclop
agent.SetHandler(func(e Event) { //nolint:errcheck,gosec
id := e.TransactionID
shouldTimeOut, found := shouldTimeOutID[id]
if !found {
t.Error("unexpected transaction ID")
}
if shouldTimeOut && !errors.Is(e.Error, ErrTransactionTimeOut) {
t.Errorf("%x should time out, but got %v", id, e.Error)
}
if !shouldTimeOut && errors.Is(e.Error, ErrTransactionTimeOut) {
t.Errorf("%x should not time out, but got %v", id, e.Error)
assert.True(t, found, "unexpected transaction ID")
if shouldTimeOut {
assert.ErrorIs(t, e.Error, ErrTransactionTimeOut, "%x should time out", id)
} else {
assert.False(t, errors.Is(e.Error, ErrTransactionTimeOut), "%x should not time out", id)
}
})
for i := 0; i < 5; i++ {
id := NewTransactionID()
shouldTimeOutID[id] = false
if err := agent.Start(id, deadline); err != nil {
t.Fatal(err)
}
assert.NoError(t, agent.Start(id, deadline))
}
for i := 0; i < 5; i++ {
id := NewTransactionID()
shouldTimeOutID[id] = true
if err := agent.Start(id, deadlineNotGC); err != nil {
t.Fatal(err)
}
}
if err := agent.Collect(gcDeadline); err != nil {
t.Fatal(err)
}
if err := agent.Close(); err != nil {
t.Error(err)
}
if err := agent.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) {
t.Errorf("should <%s>, but got <%s>", ErrAgentClosed, err)
assert.NoError(t, agent.Start(id, deadlineNotGC))
}
assert.NoError(t, agent.Collect(gcDeadline))
assert.NoError(t, agent.Close())
assert.ErrorIs(t, agent.Collect(gcDeadline), ErrAgentClosed)
}
func BenchmarkAgent_GC(b *testing.B) {
agent := NewAgent(nil)
deadline := time.Now().AddDate(0, 0, 1)
for i := 0; i < agentCollectCap; i++ {
if err := agent.Start(NewTransactionID(), deadline); err != nil {
b.Fatal(err)
}
assert.NoError(b, agent.Start(NewTransactionID(), deadline))
}
defer func() {
if err := agent.Close(); err != nil {
b.Error(err)
}
assert.NoError(b, agent.Close())
}()
b.ReportAllocs()
gcDeadline := deadline.Add(-time.Second)
for i := 0; i < b.N; i++ {
if err := agent.Collect(gcDeadline); err != nil {
b.Fatal(err)
}
assert.NoError(b, agent.Collect(gcDeadline))
}
}
@@ -197,20 +122,14 @@ func BenchmarkAgent_Process(b *testing.B) {
agent := NewAgent(nil)
deadline := time.Now().AddDate(0, 0, 1)
for i := 0; i < 1000; i++ {
if err := agent.Start(NewTransactionID(), deadline); err != nil {
b.Fatal(err)
}
assert.NoError(b, agent.Start(NewTransactionID(), deadline))
}
defer func() {
if err := agent.Close(); err != nil {
b.Error(err)
}
assert.NoError(b, agent.Close())
}()
b.ReportAllocs()
m := MustBuild(TransactionID)
for i := 0; i < b.N; i++ {
if err := agent.Process(m); err != nil {
b.Fatal(err)
}
assert.NoError(b, agent.Process(m))
}
}

View File

@@ -6,7 +6,11 @@
package stun
import "testing"
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAttrOverflowErr_Error(t *testing.T) {
err := AttrOverflowErr{
@@ -14,9 +18,7 @@ func TestAttrOverflowErr_Error(t *testing.T) {
Max: 50,
Type: AttrLifetime,
}
if err.Error() != "incorrect length of LIFETIME attribute: 100 exceeds maximum 50" {
t.Error("bad error string", err)
}
assert.Equal(t, "incorrect length of LIFETIME attribute: 100 exceeds maximum 50", err.Error())
}
func TestAttrLengthErr_Error(t *testing.T) {
@@ -25,7 +27,5 @@ func TestAttrLengthErr_Error(t *testing.T) {
Expected: 15,
Got: 99,
}
if err.Error() != "incorrect length of ERROR-CODE attribute: got 99, expected 15" {
t.Errorf("bad error string: %s", err)
}
assert.Equal(t, "incorrect length of ERROR-CODE attribute: got 99, expected 15", err.Error())
}

View File

@@ -6,6 +6,8 @@ package stun
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
)
func BenchmarkMessage_GetNotFound(b *testing.B) {
@@ -31,16 +33,10 @@ func TestRawAttribute_AddTo(t *testing.T) {
Type: AttrData,
Value: v,
})
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
gotV, gotErr := m.Get(AttrData)
if gotErr != nil {
t.Fatal(gotErr)
}
if !bytes.Equal(gotV, v) {
t.Error("value mismatch")
}
assert.NoError(t, gotErr)
assert.True(t, bytes.Equal(gotV, v), "value mismatch")
}
func TestMessage_GetNoAllocs(t *testing.T) {
@@ -52,17 +48,13 @@ func TestMessage_GetNoAllocs(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() {
msg.Get(AttrSoftware) //nolint:errcheck,gosec
})
if allocs > 0 {
t.Error("allocated memory, but should not")
}
assert.Zero(t, allocs, "allocated memory, but should not")
})
t.Run("Not found", func(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() {
msg.Get(AttrOrigin) //nolint:errcheck,gosec
})
if allocs > 0 {
t.Error("allocated memory, but should not")
}
assert.Zero(t, allocs, "allocated memory, but should not")
})
}
@@ -83,11 +75,8 @@ func TestPadding(t *testing.T) {
{40, 40}, // 10
}
for i, c := range tt {
if got := nearestPaddedValueLength(c.in); got != c.out {
t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)",
i, c.in, got, c.out,
)
}
got := nearestPaddedValueLength(c.in)
assert.Equal(t, c.out, got, "[%d]: padd(%d)", i, c.in)
}
}
@@ -102,9 +91,8 @@ func TestAttrTypeRange(t *testing.T) {
a := a
t.Run(a.String(), func(t *testing.T) {
a := a
if a.Optional() || !a.Required() {
t.Error("should be required")
}
assert.True(t, a.Required(), "should be required")
assert.False(t, a.Optional(), "should be required")
})
}
for _, a := range []AttrType{
@@ -114,9 +102,8 @@ func TestAttrTypeRange(t *testing.T) {
} {
a := a
t.Run(a.String(), func(t *testing.T) {
if a.Required() || !a.Optional() {
t.Error("should be optional")
}
assert.False(t, a.Required(), "should be optional")
assert.True(t, a.Optional(), "should be optional")
})
}
}

View File

@@ -18,6 +18,8 @@ import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var (
@@ -88,13 +90,9 @@ func BenchmarkClient_Do(b *testing.B) {
client, err := NewClient(noopConnection{},
WithAgent(agent),
)
if err != nil {
log.Fatal(err)
}
assert.NoError(b, err)
defer func() {
if closeErr := client.Close(); closeErr != nil {
panic(closeErr)
}
assert.NoError(b, client.Close())
}()
noopF := func(Event) {
@@ -163,9 +161,8 @@ func TestClosedOrPanic(t *testing.T) {
func() {
defer func() {
r, ok := recover().(error)
if !ok || !errors.Is(r, io.EOF) {
t.Error(r)
}
assert.True(t, ok, "should be error")
assert.ErrorIs(t, r, io.EOF)
}()
closedOrPanic(io.EOF)
}()
@@ -203,46 +200,32 @@ func TestClient_Start(t *testing.T) { //nolint:cyclop
},
}
client, err := NewClient(conn)
if err != nil {
log.Fatal(err)
}
assert.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Error(err)
}
if err := client.Close(); err == nil {
t.Error("second close should fail")
}
if err := client.Do(MustBuild(TransactionID), nil); err == nil {
t.Error("Do after Close should fail")
}
assert.NoError(t, client.Close())
assert.Error(t, client.Close(), "second close should fail")
assert.Error(t, client.Do(MustBuild(TransactionID), nil), "Do after Close should fail")
}()
msg := MustBuild(response, BindingRequest)
t.Log("init")
got := make(chan struct{})
write <- struct{}{}
t.Log("starting the first transaction")
if err := client.Start(msg, func(event Event) {
assert.NoError(t, client.Start(msg, func(event Event) {
t.Log("got first transaction callback")
if event.Error != nil {
t.Error(event.Error)
}
assert.NoError(t, event.Error)
got <- struct{}{}
}); err != nil {
t.Error(err)
}
}))
t.Log("starting the second transaction")
if err := client.Start(msg, func(Event) {
t.Error("should not be called")
}); !errors.Is(err, ErrTransactionExists) {
t.Errorf("unexpected error %v", err)
}
assert.ErrorIs(t, client.Start(msg, func(Event) {
assert.Fail(t, "should not be called")
}), ErrTransactionExists)
read <- struct{}{}
select {
case <-got:
// pass
case <-time.After(time.Millisecond * 10):
t.Error("timed out")
assert.Fail(t, "timed out")
}
}
@@ -256,34 +239,20 @@ func TestClient_Do(t *testing.T) {
},
}
client, err := NewClient(conn)
if err != nil {
log.Fatal(err)
}
assert.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Error(err)
}
if err := client.Close(); err == nil {
t.Error("second close should fail")
}
if err := client.Do(MustBuild(TransactionID), nil); err == nil {
t.Error("Do after Close should fail")
}
assert.NoError(t, client.Close())
assert.Error(t, client.Close(), "second close should fail")
assert.Error(t, client.Do(MustBuild(TransactionID), nil), "Do after Close should fail")
}()
m := MustBuild(
NewTransactionIDSetter(response.TransactionID),
)
if err := client.Do(m, func(event Event) {
if event.Error != nil {
t.Error(event.Error)
}
}); err != nil {
t.Error(err)
}
assert.NoError(t, client.Do(m, func(event Event) {
assert.NoError(t, event.Error)
}))
m = MustBuild(TransactionID)
if err := client.Do(m, nil); err != nil {
t.Error(err)
}
assert.NoError(t, client.Do(m, nil))
}
func TestCloseErr_Error(t *testing.T) {
@@ -299,11 +268,7 @@ func TestCloseErr_Error(t *testing.T) {
ConnectionErr: io.ErrUnexpectedEOF,
}, "failed to close: unexpected EOF (connection), <nil> (agent)"},
} {
if out := testCase.Err.Error(); out != testCase.Out {
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
id, testCase.Err, out, testCase.Out,
)
}
assert.Equal(t, testCase.Out, testCase.Err.Error(), "[%d]: Error(%#v)", id, testCase.Err)
}
}
@@ -320,11 +285,7 @@ func TestStopErr_Error(t *testing.T) {
Cause: io.ErrUnexpectedEOF,
}, "error while stopping due to unexpected EOF: <nil>"},
} {
if out := testcase.Err.Error(); out != testcase.Out {
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
id, testcase.Err, out, testcase.Out,
)
}
assert.Equal(t, testcase.Out, testcase.Err.Error(), "[%d]: Error(%#v)", id, testcase.Err)
}
}
@@ -365,25 +326,15 @@ func TestClientAgentError(t *testing.T) {
startErr: io.ErrUnexpectedEOF,
}),
)
if err != nil {
log.Fatal(err)
}
assert.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Error(err)
}
assert.NoError(t, client.Close())
}()
m := MustBuild(NewTransactionIDSetter(response.TransactionID))
if err := client.Do(m, nil); err != nil {
t.Error(err)
}
if err := client.Do(m, func(event Event) {
if event.Error == nil {
t.Error("error expected")
}
}); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Error("error expected")
}
assert.NoError(t, client.Do(m, nil))
assert.ErrorIs(t, client.Do(m, func(event Event) {
assert.Error(t, event.Error, "error expected")
}), io.ErrUnexpectedEOF)
}
func TestClientConnErr(t *testing.T) {
@@ -393,21 +344,13 @@ func TestClientConnErr(t *testing.T) {
},
}
client, err := NewClient(conn)
if err != nil {
log.Fatal(err)
}
assert.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Error(err)
}
assert.NoError(t, client.Close())
}()
m := MustBuild(TransactionID)
if err := client.Do(m, nil); err == nil {
t.Error("error expected")
}
if err := client.Do(m, NoopHandler()); err == nil {
t.Error("error expected")
}
assert.Error(t, client.Do(m, nil), "error expected")
assert.Error(t, client.Do(m, NoopHandler()), "error expected")
}
func TestClientConnErrStopErr(t *testing.T) {
@@ -421,26 +364,19 @@ func TestClientConnErrStopErr(t *testing.T) {
stopErr: io.ErrUnexpectedEOF,
}),
)
if err != nil {
log.Fatal(err)
}
assert.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Error(err)
}
assert.NoError(t, client.Close())
}()
m := MustBuild(TransactionID)
if err := client.Do(m, NoopHandler()); err == nil {
t.Error("error expected")
}
assert.Error(t, client.Do(m, NoopHandler()), "error expected")
}
func TestCallbackWaitHandler_setCallback(t *testing.T) {
c := callbackWaitHandler{}
defer func() {
if err := recover(); err == nil {
t.Error("should panic")
}
err := recover()
assert.NotNil(t, err, "should panic")
}()
c.setCallback(nil)
}
@@ -450,56 +386,39 @@ func TestCallbackWaitHandler_HandleEvent(t *testing.T) {
cond: sync.NewCond(new(sync.Mutex)),
}
defer func() {
if err := recover(); err == nil {
t.Error("should panic")
}
err := recover()
assert.NotNil(t, err, "should panic")
}()
c.HandleEvent(Event{})
}
func TestNewClientNoConnection(t *testing.T) {
c, err := NewClient(nil)
if c != nil {
t.Error("c should be nil")
}
if !errors.Is(err, ErrNoConnection) {
t.Error("bad error")
}
assert.Nil(t, c, "c should be nil")
assert.ErrorIs(t, err, ErrNoConnection, "bad error")
}
func TestDial(t *testing.T) {
c, err := Dial("udp4", "localhost:3458")
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
defer func() {
if err = c.Close(); err != nil {
t.Error(err)
}
assert.NoError(t, c.Close())
}()
}
func TestDialURI(t *testing.T) {
u, err := ParseURI("stun:localhost")
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
c, err := DialURI(u, &DialConfig{})
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
defer func() {
if err = c.Close(); err != nil {
t.Error(err)
}
assert.NoError(t, c.Close())
}()
}
func TestDialError(t *testing.T) {
_, err := Dial("bad?network", "?????")
if err == nil {
t.Fatal("error expected")
}
assert.Error(t, err, "error expected")
}
func TestClientCloseErr(t *testing.T) {
@@ -516,13 +435,11 @@ func TestClientCloseErr(t *testing.T) {
closeErr: io.ErrUnexpectedEOF,
}),
)
if err != nil {
log.Fatal(err)
}
assert.NoError(t, err)
defer func() {
if err, ok := c.Close().(CloseErr); !ok || !errors.Is(err.AgentErr, io.ErrUnexpectedEOF) { //nolint:errorlint
t.Error("unexpected close err")
}
err, ok := c.Close().(CloseErr) //nolint:errorlint
assert.True(t, ok, "should be CloseErr")
assert.ErrorIs(t, err.AgentErr, io.ErrUnexpectedEOF, "unexpected close err")
}()
}
@@ -542,12 +459,8 @@ func TestWithNoConnClose(t *testing.T) {
}),
WithNoConnClose(),
)
if err != nil {
log.Fatal(err)
}
if err := c.Close(); err != nil {
t.Error("unexpected non-nil error")
}
assert.NoError(t, err)
assert.NoError(t, c.Close(), "unexpected non-nil error")
}
type gcWaitAgent struct {
@@ -598,28 +511,20 @@ func TestClientGC(t *testing.T) {
WithAgent(agent),
WithTimeoutRate(time.Millisecond),
)
if err != nil {
log.Fatal(err)
}
assert.NoError(t, err)
defer func() {
if err = c.Close(); err != nil {
t.Error(err)
}
assert.NoError(t, c.Close())
}()
select {
case <-agent.gc:
case <-time.After(time.Millisecond * 200):
t.Error("timed out")
assert.Fail(t, "timed out")
}
}
func TestClientCheckInit(t *testing.T) {
if err := (&Client{}).Indicate(nil); !errors.Is(err, ErrClientNotInitialized) {
t.Error("unexpected error")
}
if err := (&Client{}).Do(nil, nil); !errors.Is(err, ErrClientNotInitialized) {
t.Error("unexpected error")
}
assert.ErrorIs(t, (&Client{}).Indicate(nil), ErrClientNotInitialized)
assert.ErrorIs(t, (&Client{}).Do(nil, nil), ErrClientNotInitialized)
}
func captureLog() (*bytes.Buffer, func()) {
@@ -645,9 +550,7 @@ func TestClientFinalizer(t *testing.T) {
},
}
client, err := NewClient(conn)
if err != nil {
log.Panic(err)
}
assert.NoError(t, err)
clientFinalizer(client)
clientFinalizer(client)
response := MustBuild(TransactionID, BindingSuccess)
@@ -663,9 +566,7 @@ func TestClientFinalizer(t *testing.T) {
closeErr: io.ErrUnexpectedEOF,
}),
)
if err != nil {
log.Panic(err)
}
assert.NoError(t, err)
clientFinalizer(client)
reader := bufio.NewScanner(buf)
var lines int
@@ -676,17 +577,11 @@ func TestClientFinalizer(t *testing.T) {
"<nil> (connection), unexpected EOF (agent)",
}
for reader.Scan() {
if reader.Text() != expectedLines[lines] {
t.Error(reader.Text(), "!=", expectedLines[lines])
}
assert.Equal(t, expectedLines[lines], reader.Text())
lines++
}
if reader.Err() != nil {
t.Error(err)
}
if lines != 3 {
t.Error("incorrect count of log lines:", lines)
}
assert.NoError(t, reader.Err())
assert.Equal(t, 3, lines, "incorrect count of log lines")
}
func TestCallbackWaitHandler(*testing.T) {
@@ -784,9 +679,7 @@ func TestClientRetransmission(t *testing.T) {
response.Encode()
connL, connR := net.Pipe()
defer func() {
if closeErr := connL.Close(); closeErr != nil {
panic(closeErr)
}
assert.NoError(t, connR.Close())
}()
collector := new(manualCollector)
clock := &manualClock{current: time.Now()}
@@ -814,36 +707,22 @@ func TestClientRetransmission(t *testing.T) {
WithCollector(collector),
WithRTO(time.Millisecond),
)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
client.SetRTO(time.Second)
gotReads := make(chan struct{})
go func() {
buf := make([]byte, 1500)
readN, readErr := connL.Read(buf)
if readErr != nil {
t.Error(readErr)
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
assert.NoError(t, readErr)
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
readN, readErr = connL.Read(buf)
if readErr != nil {
t.Error(readErr)
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
assert.NoError(t, readErr)
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
gotReads <- struct{}{}
}()
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
if event.Error != nil {
t.Error("failed")
}
}); doErr != nil {
t.Fatal(doErr)
}
assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
assert.NoError(t, event.Error, "failed")
}))
<-gotReads
}
@@ -854,9 +733,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
response.Encode()
connL, connR := net.Pipe()
defer func() {
if closeErr := connL.Close(); closeErr != nil {
panic(closeErr)
}
assert.NoError(t, connR.Close())
}()
collector := new(manualCollector)
clock := &manualClock{current: time.Now()}
@@ -874,9 +751,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
WithClock(clock),
WithCollector(collector),
)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
client.SetRTO(time.Second)
conns := new(sync.WaitGroup)
wg := new(sync.WaitGroup)
@@ -891,29 +766,21 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
if errors.Is(readErr, io.EOF) {
break
}
t.Error(readErr)
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
assert.NoError(t, readErr)
}
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
}
}()
wg.Add(1)
go func() {
defer wg.Done()
if doErr := client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) {
if event.Error != nil {
t.Error("failed")
}
}); doErr != nil {
t.Error(doErr)
}
assert.NoError(t, client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) {
assert.NoError(t, event.Error, "failed")
}))
}()
}
wg.Wait()
if connErr := connR.Close(); connErr != nil {
t.Error(connErr)
}
assert.NoError(t, connR.Close())
conns.Wait()
}
@@ -942,24 +809,22 @@ func (c errorCollector) Close() error { return c.closeError }
func TestNewClient(t *testing.T) {
t.Run("SetCallbackError", func(t *testing.T) {
setHandlerError := errClientSetHandler
if _, createErr := NewClient(noopConnection{},
_, createErr := NewClient(noopConnection{},
WithAgent(&errorAgent{
setHandlerError: setHandlerError,
}),
); !errors.Is(createErr, setHandlerError) {
t.Errorf("unexpected error returned: %v", createErr)
}
)
assert.ErrorIs(t, createErr, setHandlerError, "unexpected error returned")
})
t.Run("CollectorStartError", func(t *testing.T) {
startError := errClientStart
if _, createErr := NewClient(noopConnection{},
_, createErr := NewClient(noopConnection{},
WithAgent(&TestAgent{}),
WithCollector(errorCollector{
startError: startError,
}),
); !errors.Is(createErr, startError) {
t.Errorf("unexpected error returned: %v", createErr)
}
)
assert.ErrorIs(t, createErr, startError, "unexpected error returned")
})
}
@@ -972,13 +837,9 @@ func TestClient_Close(t *testing.T) {
}),
WithAgent(&TestAgent{}),
)
if createErr != nil {
t.Errorf("unexpected create error returned: %v", createErr)
}
assert.NoError(t, createErr, "unexpected create error returned")
gotCloseErr := c.Close()
if !errors.Is(gotCloseErr, closeErr) {
t.Errorf("unexpected close error returned: %v", gotCloseErr)
}
assert.ErrorIs(t, gotCloseErr, closeErr, "unexpected close error returned")
})
}
@@ -992,19 +853,13 @@ func TestClientDefaultHandler(t *testing.T) {
client, createErr := NewClient(noopConnection{},
WithAgent(agent),
WithHandler(func(e Event) {
if called {
t.Error("should not be called twice")
}
assert.False(t, called, "should not be called twice")
called = true
if e.TransactionID != id {
t.Error("wrong transaction ID")
}
assert.Equal(t, id, e.TransactionID, "wrong transaction ID")
handlerCalled <- struct{}{}
}),
)
if createErr != nil {
t.Fatal(createErr)
}
assert.NoError(t, createErr)
go func() {
agent.h(Event{
TransactionID: id,
@@ -1014,11 +869,9 @@ func TestClientDefaultHandler(t *testing.T) {
case <-handlerCalled:
// pass
case <-time.After(time.Millisecond * 100):
t.Fatal("timed out")
}
if closeErr := client.Close(); closeErr != nil {
t.Error(closeErr)
assert.Fail(t, "timed out")
}
assert.NoError(t, client.Close())
// Handler call should be ignored.
agent.h(Event{})
}
@@ -1030,15 +883,9 @@ func TestClientClosedStart(t *testing.T) {
c, createErr := NewClient(noopConnection{},
WithAgent(a),
)
if createErr != nil {
t.Fatal(createErr)
}
if closeErr := c.Close(); closeErr != nil {
t.Error(closeErr)
}
if startErr := c.start(&clientTransaction{}); !errors.Is(startErr, ErrClientClosed) {
t.Error("should error")
}
assert.NoError(t, createErr)
assert.NoError(t, c.Close())
assert.ErrorIs(t, c.start(&clientTransaction{}), ErrClientClosed)
}
func TestWithNoRetransmit(t *testing.T) {
@@ -1046,9 +893,7 @@ func TestWithNoRetransmit(t *testing.T) {
response.Encode()
connL, connR := net.Pipe()
defer func() {
if closeErr := connL.Close(); closeErr != nil {
panic(closeErr)
}
assert.NoError(t, connL.Close())
}()
collector := new(manualCollector)
clock := &manualClock{current: time.Now()}
@@ -1062,7 +907,7 @@ func TestWithNoRetransmit(t *testing.T) {
Error: ErrTransactionTimeOut,
})
} else {
t.Error("there should be no second attempt")
assert.Fail(t, "there should be no second attempt")
go agent.h(Event{
TransactionID: id,
Error: ErrTransactionTimeOut,
@@ -1078,28 +923,18 @@ func TestWithNoRetransmit(t *testing.T) {
WithRTO(0),
WithNoRetransmit,
)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
gotReads := make(chan struct{})
go func() {
buf := make([]byte, 1500)
readN, readErr := connL.Read(buf)
if readErr != nil {
t.Error(readErr)
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
assert.NoError(t, readErr)
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
gotReads <- struct{}{}
}()
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
if !errors.Is(event.Error, ErrTransactionTimeOut) {
t.Error("unexpected error")
}
}); doErr != nil {
t.Fatal(err)
}
assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
assert.ErrorIs(t, event.Error, ErrTransactionTimeOut, "unexpected error")
}))
<-gotReads
}
@@ -1114,9 +949,7 @@ func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
response.Encode()
connL, connR := net.Pipe()
defer func() {
if closeErr := connL.Close(); closeErr != nil {
panic(closeErr)
}
assert.NoError(t, connL.Close())
}()
collector := new(manualCollector)
shouldWait := false
@@ -1169,9 +1002,7 @@ func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
t.Log("clock locked")
<-clockLocked
t.Log("closing client")
if closeErr := client.Close(); closeErr != nil {
t.Error(closeErr)
}
assert.NoError(t, client.Close())
t.Log("client closed, unlocking clock")
clockWait <- struct{}{}
t.Log("clock unlocked")
@@ -1186,44 +1017,30 @@ func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
WithCollector(collector),
WithRTO(time.Millisecond),
)
if startClientErr != nil {
t.Fatal(startClientErr)
}
assert.NoError(t, startClientErr)
go func() {
buf := make([]byte, 1500)
readN, readErr := connL.Read(buf)
if readErr != nil {
t.Error(readErr)
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
assert.NoError(t, readErr)
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
readN, readErr = connL.Read(buf)
if readErr != nil {
t.Error(readErr)
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
assert.NoError(t, readErr)
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
gotReads <- struct{}{}
}()
t.Log("starting")
done := make(chan struct{})
go func() {
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
if !errors.Is(event.Error, ErrClientClosed) {
t.Error(event.Error)
}
}); doErr != nil {
t.Error(doErr)
}
assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
assert.ErrorIs(t, event.Error, ErrClientClosed)
}))
done <- struct{}{}
}()
select {
case <-done:
// ok
case <-time.After(time.Second * 5):
t.Error("timeout")
assert.Fail(t, "timeout")
}
}
@@ -1232,9 +1049,7 @@ func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
response.Encode()
connL, connR := net.Pipe()
defer func() {
if closeErr := connL.Close(); closeErr != nil {
panic(closeErr)
}
assert.NoError(t, connL.Close())
}()
collector := new(manualCollector)
shouldWait := false
@@ -1291,9 +1106,7 @@ func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
t.Log("clock locked")
<-clockLocked
t.Log("closing connection")
if closeErr := connL.Close(); closeErr != nil {
panic(closeErr)
}
assert.NoError(t, connL.Close())
t.Log("connection closed, unlocking clock")
clockWait <- struct{}{}
t.Log("clock unlocked")
@@ -1308,52 +1121,33 @@ func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
WithCollector(collector),
WithRTO(time.Millisecond),
)
if startClientErr != nil {
t.Fatal(startClientErr)
}
assert.NoError(t, startClientErr)
go func() {
buf := make([]byte, 1500)
readN, readErr := connL.Read(buf)
if readErr != nil {
t.Error(readErr)
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
assert.NoError(t, readErr)
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
readN, readErr = connL.Read(buf)
if readErr != nil {
t.Error(readErr)
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
assert.NoError(t, readErr)
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
gotReads <- struct{}{}
}()
t.Log("starting")
done := make(chan struct{})
go func() {
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
var e StopErr
if !errors.As(event.Error, &e) {
t.Error(event.Error)
} else {
if !errors.Is(e.Err, agentStopErr) {
t.Error("incorrect agent error")
}
if !errors.Is(e.Cause, io.ErrClosedPipe) {
t.Error("incorrect connection error")
}
}
}); doErr != nil {
t.Error(doErr)
}
assert.ErrorAs(t, event.Error, &e)
assert.ErrorIs(t, e.Err, agentStopErr, "incorrect agent error")
assert.ErrorIs(t, e.Cause, io.ErrClosedPipe, "incorrect connection error")
}))
done <- struct{}{}
}()
select {
case <-done:
// ok
case <-time.After(time.Second * 5):
t.Error("timeout")
assert.Fail(t, "timeout")
}
}
@@ -1362,9 +1156,7 @@ func TestClientRTOAgentErr(t *testing.T) {
response.Encode()
connL, connR := net.Pipe()
defer func() {
if closeErr := connL.Close(); closeErr != nil {
panic(closeErr)
}
assert.NoError(t, connL.Close())
}()
collector := new(manualCollector)
clock := callbackClock(time.Now)
@@ -1396,33 +1188,23 @@ func TestClientRTOAgentErr(t *testing.T) {
WithCollector(collector),
WithRTO(time.Millisecond),
)
if startClientErr != nil {
t.Fatal(startClientErr)
}
assert.NoError(t, startClientErr)
go func() {
buf := make([]byte, 1500)
readN, readErr := connL.Read(buf)
if readErr != nil {
t.Error(readErr)
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
assert.NoError(t, readErr)
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
gotReads <- struct{}{}
}()
t.Log("starting")
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
if !errors.Is(event.Error, agentStartErr) {
t.Error(event.Error)
}
}); doErr != nil {
t.Error(doErr)
}
assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
assert.ErrorIs(t, event.Error, agentStartErr)
}))
select {
case <-gotReads:
// ok
case <-time.After(time.Second * 5):
t.Error("reads timeout")
assert.Fail(t, "reads timeout")
}
}
@@ -1431,9 +1213,7 @@ func TestClient_HandleProcessError(t *testing.T) {
response.Encode()
connL, connR := net.Pipe()
defer func() {
if closeErr := connL.Close(); closeErr != nil {
panic(closeErr)
}
assert.NoError(t, connL.Close())
}()
collector := new(manualCollector)
clock := callbackClock(time.Now)
@@ -1451,14 +1231,10 @@ func TestClient_HandleProcessError(t *testing.T) {
WithCollector(collector),
WithRTO(time.Millisecond),
)
if startClientErr != nil {
t.Fatal(startClientErr)
}
assert.NoError(t, startClientErr)
go func() {
_, readErr := connL.Write(response.Raw)
if readErr != nil {
t.Error(readErr)
}
assert.NoError(t, readErr)
gotWrites <- struct{}{}
}()
t.Log("starting")
@@ -1466,20 +1242,16 @@ func TestClient_HandleProcessError(t *testing.T) {
case <-gotWrites:
// ok
case <-time.After(time.Second * 5):
t.Error("reads timeout")
}
if closeErr := client.Close(); closeErr != nil {
t.Error(closeErr)
assert.Fail(t, "reads timeout")
}
assert.NoError(t, client.Close())
}
func TestClientImmediateTimeout(t *testing.T) {
response := MustBuild(TransactionID, BindingSuccess)
connL, connR := net.Pipe()
defer func() {
if closeErr := connL.Close(); closeErr != nil {
panic(closeErr)
}
assert.NoError(t, connL.Close())
}()
collector := new(manualCollector)
clock := &manualClock{current: time.Now()}
@@ -1488,16 +1260,14 @@ func TestClientImmediateTimeout(t *testing.T) {
attempt := 0
agent.start = func(id [TransactionIDSize]byte, deadline time.Time) error {
if attempt == 0 {
if deadline.Before(clock.current.Add(rto / 2)) {
t.Error("deadline too fast")
}
assert.False(t, deadline.Before(clock.current.Add(rto/2)), "deadline too fast")
attempt++
go agent.h(Event{
TransactionID: id,
Message: response,
})
} else {
t.Error("there should be no second attempt")
assert.Fail(t, "there should be no second attempt")
go agent.h(Event{
TransactionID: id,
Error: ErrTransactionTimeOut,
@@ -1512,25 +1282,17 @@ func TestClientImmediateTimeout(t *testing.T) {
WithCollector(collector),
WithRTO(rto),
)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
gotReads := make(chan struct{})
go func() {
buf := make([]byte, 1500)
readN, readErr := connL.Read(buf)
if readErr != nil {
t.Error(readErr)
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
assert.NoError(t, readErr)
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
gotReads <- struct{}{}
}()
client.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec
if errors.Is(e.Error, ErrTransactionTimeOut) {
t.Error("unexpected error")
}
assert.NoError(t, e.Error, "unexpected error")
})
<-gotReads
}

View File

@@ -8,9 +8,10 @@ package stun
import (
"encoding/base64"
"errors"
"io"
"testing"
"github.com/stretchr/testify/assert"
)
func BenchmarkErrorCode_AddTo(b *testing.B) {
@@ -52,19 +53,13 @@ func TestErrorCodeAttribute_GetFrom(t *testing.T) {
m := New()
m.Add(AttrErrorCode, []byte{1})
c := new(ErrorCodeAttribute)
if err := c.GetFrom(m); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Errorf("GetFrom should return <%s>, but got <%s>",
io.ErrUnexpectedEOF, err,
)
}
assert.ErrorIs(t, c.GetFrom(m), io.ErrUnexpectedEOF)
}
func TestMessage_AddErrorCode(t *testing.T) {
m := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
t.Error(err)
}
assert.NoError(t, err)
copy(m.TransactionID[:], transactionID)
expectedCode := ErrorCode(438)
expectedReason := "Stale Nonce"
@@ -72,23 +67,13 @@ func TestMessage_AddErrorCode(t *testing.T) {
m.WriteHeader()
mRes := New()
if _, err = mRes.ReadFrom(m.reader()); err != nil {
t.Fatal(err)
}
_, err = mRes.ReadFrom(m.reader())
assert.NoError(t, err)
errCodeAttr := new(ErrorCodeAttribute)
if err = errCodeAttr.GetFrom(mRes); err != nil {
t.Error(err)
}
assert.NoError(t, errCodeAttr.GetFrom(mRes))
code := errCodeAttr.Code
if err != nil {
t.Error(err)
}
if code != expectedCode {
t.Error("bad code", code)
}
if string(errCodeAttr.Reason) != expectedReason {
t.Error("bad reason", string(errCodeAttr.Reason))
}
assert.Equal(t, expectedCode, code, "bad code")
assert.Equal(t, expectedReason, string(errCodeAttr.Reason), "bad reason")
}
func TestErrorCode(t *testing.T) {
@@ -96,19 +81,11 @@ func TestErrorCode(t *testing.T) {
Code: 404,
Reason: []byte("not found!"),
}
if attr.String() != "404: not found!" {
t.Error("bad string", attr)
}
assert.Equal(t, "404: not found!", attr.String(), "bad string")
m := New()
cod := ErrorCode(666)
if err := cod.AddTo(m); !errors.Is(err, ErrNoDefaultReason) {
t.Error("should be ErrNoDefaultReason", err)
}
if err := attr.GetFrom(m); err == nil {
t.Error("attr should not be in message")
}
assert.ErrorIs(t, cod.AddTo(m), ErrNoDefaultReason, "should be ErrNoDefaultReason")
assert.Error(t, attr.GetFrom(m), "attr should not be in message")
attr.Reason = make([]byte, 2048)
if err := attr.AddTo(m); err == nil {
t.Error("should error")
}
assert.Error(t, attr.AddTo(m), "should error")
}

View File

@@ -6,6 +6,8 @@ package stun
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestDecodeErr_IsInvalidCookie(t *testing.T) {
@@ -14,25 +16,13 @@ func TestDecodeErr_IsInvalidCookie(t *testing.T) {
decoded := new(Message)
m.Raw[4] = 55
_, err := decoded.Write(m.Raw)
if err == nil {
t.Fatal("should error")
}
assert.Error(t, err, "should error")
expected := "BadFormat for message/cookie: " +
"3712a442 is invalid magic cookie (should be 2112a442)"
if err.Error() != expected {
t.Error(err, "!=", expected)
}
assert.Equal(t, expected, err.Error(), "error message mismatch")
var dErr *DecodeErr
if !errors.As(err, &dErr) {
t.Error("not decode error")
}
if !dErr.IsInvalidCookie() {
t.Error("IsInvalidCookie = false, should be true")
}
if !dErr.IsPlaceChildren("cookie") {
t.Error("bad children")
}
if !dErr.IsPlaceParent("message") {
t.Error("bad parent")
}
assert.True(t, errors.As(err, &dErr), "not decode error")
assert.True(t, dErr.IsInvalidCookie(), "IsInvalidCookie = false, should be true")
assert.True(t, dErr.IsPlaceChildren("cookie"), "bad children")
assert.True(t, dErr.IsPlaceParent("message"), "bad parent")
}

View File

@@ -9,6 +9,8 @@ package stun
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func BenchmarkFingerprint_AddTo(b *testing.B) {
@@ -36,26 +38,18 @@ func TestFingerprint_Check(t *testing.T) {
m.WriteHeader()
Fingerprint.AddTo(m) //nolint:errcheck,gosec
m.WriteHeader()
if err := Fingerprint.Check(m); err != nil {
t.Error(err)
}
assert.NoError(t, Fingerprint.Check(m))
m.Raw[3]++
if err := Fingerprint.Check(m); err == nil {
t.Error("should error")
}
assert.Error(t, Fingerprint.Check(m))
}
func TestFingerprint_CheckBad(t *testing.T) {
m := new(Message)
addAttr(t, m, NewSoftware("software"))
m.WriteHeader()
if err := Fingerprint.Check(m); err == nil {
t.Error("should error")
}
assert.Error(t, Fingerprint.Check(m))
m.Add(AttrFingerprint, []byte{1, 2, 3})
if !IsAttrSizeInvalid(Fingerprint.Check(m)) {
t.Error("IsAttrSizeInvalid should be true")
}
assert.True(t, IsAttrSizeInvalid(Fingerprint.Check(m)))
}
func BenchmarkFingerprint_Check(b *testing.B) {

View File

@@ -7,6 +7,8 @@ import (
"encoding/binary"
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func FuzzMessage(f *testing.F) {
@@ -26,21 +28,12 @@ func FuzzMessage(f *testing.F) {
}
msg2 := New()
if _, err := msg2.Write(msg1.Raw); err != nil {
t.Fatalf("Failed to write: %s", err)
}
_, err := msg2.Write(msg1.Raw)
assert.NoError(t, err, "Failed to write")
if msg2.TransactionID != msg1.TransactionID {
t.Fatal("Transaction ID mismatch")
}
if msg2.Type != msg1.Type {
t.Fatal("Type mismatch")
}
if len(msg2.Attributes) != len(msg1.Attributes) {
t.Fatal("Attributes length mismatch")
}
assert.Equal(t, msg1.TransactionID, msg2.TransactionID, "Transaction ID mismatch")
assert.Equal(t, msg1.Type, msg2.Type, "Type mismatch")
assert.Equal(t, len(msg1.Attributes), len(msg2.Attributes), "Attributes length mismatch")
})
}
@@ -51,15 +44,11 @@ func FuzzType(f *testing.F) {
t1 := MessageType{}
t1.ReadValue(v)
v2 := t1.Value()
if v != v2 {
t.Fatal("v != v2")
}
assert.Equal(t, v, v2, "v != v2")
t2 := MessageType{}
t2.ReadValue(v2)
if t2 != t1 {
t.Fatal("t2 != t1")
}
assert.Equal(t, t1, t2, "t2 != t1")
})
}
@@ -94,20 +83,19 @@ func FuzzSetters(f *testing.F) {
m1.WriteHeader()
m1.Add(attr.t, value)
err := attr.g.GetFrom(m1)
if errors.Is(err, ErrAttributeNotFound) {
t.Fatalf("Unexpected 404: %s", err)
}
assert.False(t, errors.Is(err, ErrAttributeNotFound))
if err != nil {
return
}
m2.WriteHeader()
if err = attr.g.AddTo(m2); err != nil {
err = attr.g.AddTo(m2)
if err != nil {
// We allow decoding some text attributes
// when their length is too big, but
// not encoding.
if !IsAttrSizeOverflow(err) {
t.Fatal(err)
assert.NoError(t, err)
}
return
@@ -115,14 +103,10 @@ func FuzzSetters(f *testing.F) {
m3.WriteHeader()
v, err := m2.Get(attr.t)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
m3.Add(attr.t, v)
if !m2.Equal(m3) {
t.Fatalf("Not equal: %s != %s", m2, m3)
}
assert.True(t, m2.Equal(m3), "Not equal: %s != %s", m2, m3)
})
}

View File

@@ -9,6 +9,7 @@ import (
"testing"
"github.com/pion/stun/v3/internal/testutil"
"github.com/stretchr/testify/assert"
)
func BenchmarkBuildOverhead(b *testing.B) {
@@ -59,21 +60,12 @@ func TestMessage_Apply(t *testing.T) {
integrity,
Fingerprint,
)
if err != nil {
t.Fatal("failed to build:", err)
}
if err = msg.Check(Fingerprint, integrity); err != nil {
t.Fatal(err)
}
if _, err := decoded.Write(msg.Raw); err != nil {
t.Fatal(err)
}
if !decoded.Equal(msg) {
t.Error("not equal")
}
if err := integrity.Check(decoded); err != nil {
t.Fatal(err)
}
assert.NoError(t, err, "failed to build")
assert.NoError(t, msg.Check(Fingerprint, integrity))
_, err = decoded.Write(msg.Raw)
assert.NoError(t, err)
assert.True(t, decoded.Equal(msg))
assert.NoError(t, integrity.Check(decoded))
}
type errReturner struct {
@@ -97,25 +89,17 @@ func (e errReturner) GetFrom(*Message) error {
func TestHelpersErrorHandling(t *testing.T) {
m := New()
errReturn := errReturner{Err: errTError}
if err := m.Build(errReturn); !errors.Is(err, errReturn.Err) {
t.Error(err, "!=", errReturn.Err)
}
if err := m.Check(errReturn); !errors.Is(err, errReturn.Err) {
t.Error(err, "!=", errReturn.Err)
}
if err := m.Parse(errReturn); !errors.Is(err, errReturn.Err) {
t.Error(err, "!=", errReturn.Err)
}
assert.ErrorIs(t, m.Build(errReturn), errReturn.Err)
assert.ErrorIs(t, m.Check(errReturn), errReturn.Err)
assert.ErrorIs(t, m.Parse(errReturn), errReturn.Err)
t.Run("MustBuild", func(t *testing.T) {
t.Run("Positive", func(*testing.T) {
MustBuild(NewTransactionIDSetter(transactionID{}))
})
defer func() {
if p, ok := recover().(error); !ok || !errors.Is(p, errReturn.Err) {
t.Errorf("%s != %s",
p, errReturn.Err,
)
}
p, ok := recover().(error)
assert.True(t, ok)
assert.ErrorIs(t, p, errReturn.Err)
}()
MustBuild(errReturn)
})
@@ -123,91 +107,62 @@ func TestHelpersErrorHandling(t *testing.T) {
func TestMessage_ForEach(t *testing.T) { //nolint:cyclop
initial := New()
if err := initial.Build(
assert.NoError(t, initial.Build(
NewRealm("realm1"), NewRealm("realm2"),
); err != nil {
t.Fatal(err)
}
))
newMessage := func() *Message {
m := New()
if err := m.Build(
assert.NoError(t, m.Build(
NewRealm("realm1"), NewRealm("realm2"),
); err != nil {
t.Fatal(err)
}
))
return m
}
t.Run("NoResults", func(t *testing.T) {
m := newMessage()
if !m.Equal(initial) {
t.Error("m should be equal to initial")
}
if err := m.ForEach(AttrUsername, func(*Message) error {
t.Error("should not be called")
assert.True(t, m.Equal(initial), "m should be equal to initial")
assert.NoError(t, m.ForEach(AttrUsername, func(*Message) error {
assert.Fail(t, "should not be called")
return nil
}); err != nil {
t.Fatal(err)
}
if !m.Equal(initial) {
t.Error("m should be equal to initial")
}
}))
assert.True(t, m.Equal(initial), "m should be equal to initial")
})
t.Run("ReturnOnError", func(t *testing.T) {
m := newMessage()
var calls int
if err := m.ForEach(AttrRealm, func(*Message) error {
err := m.ForEach(AttrRealm, func(*Message) error {
if calls > 0 {
t.Error("called multiple times")
assert.Fail(t, "called multiple times")
}
calls++
return ErrAttributeNotFound
}); !errors.Is(err, ErrAttributeNotFound) {
t.Fatal(err)
}
if !m.Equal(initial) {
t.Error("m should be equal to initial")
}
})
assert.ErrorIs(t, err, ErrAttributeNotFound)
assert.True(t, m.Equal(initial), "m should be equal to initial")
})
t.Run("Positive", func(t *testing.T) {
msg := newMessage()
var realms []string
if err := msg.ForEach(AttrRealm, func(m *Message) error {
assert.NoError(t, msg.ForEach(AttrRealm, func(m *Message) error {
var realm Realm
if err := realm.GetFrom(m); err != nil {
return err
}
assert.NoError(t, realm.GetFrom(m))
realms = append(realms, realm.String())
return nil
}); err != nil {
t.Fatal(err)
}
if len(realms) != 2 {
t.Fatal("expected 2 realms")
}
if realms[0] != "realm1" {
t.Error("bad value for 1 realm")
}
if realms[1] != "realm2" {
t.Error("bad value for 2 realm")
}
if !msg.Equal(initial) {
t.Error("m should be equal to initial")
}
}))
assert.Len(t, realms, 2)
assert.Equal(t, "realm1", realms[0], "bad value for 1 realm")
assert.Equal(t, "realm2", realms[1], "bad value for 2 realm")
assert.True(t, msg.Equal(initial), "m should be equal to initial")
t.Run("ZeroAlloc", func(t *testing.T) {
msg = newMessage()
var realm Realm
testutil.ShouldNotAllocate(t, func() {
if err := msg.ForEach(AttrRealm, realm.GetFrom); err != nil {
t.Fatal(err)
}
assert.NoError(t, msg.ForEach(AttrRealm, realm.GetFrom))
})
if !msg.Equal(initial) {
t.Error("m should be equal to initial")
}
assert.True(t, msg.Equal(initial), "m should be equal to initial")
})
})
}

View File

@@ -12,6 +12,8 @@ import (
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func loadCSV(tb testing.TB, name string) [][]string {
@@ -21,9 +23,7 @@ func loadCSV(tb testing.TB, name string) [][]string {
r := csv.NewReader(bytes.NewReader(data))
r.Comment = '#'
records, err := r.ReadAll()
if err != nil {
tb.Fatal(err)
}
assert.NoError(tb, err)
return records
}
@@ -41,20 +41,14 @@ func TestIANA(t *testing.T) { //nolint:cyclop
continue
}
val, parseErr := strconv.ParseInt(v[2:], 16, 64)
if parseErr != nil {
t.Fatal(parseErr)
}
assert.NoError(t, parseErr)
t.Logf("value: 0x%x, name: %s", val, name)
methods[name] = Method(val) //nolint:gosec // G115
}
for val, name := range methodName() {
mapped, ok := methods[name]
if !ok {
t.Errorf("failed to find method %s in IANA", name)
}
if mapped != val {
t.Errorf("%s: IANA %d != actual %d", name, mapped, val)
}
assert.True(t, ok, "failed to find method %s in IANA", name)
assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val)
}
})
t.Run("Attributes", func(t *testing.T) {
@@ -69,9 +63,7 @@ func TestIANA(t *testing.T) { //nolint:cyclop
continue
}
val, parseErr := strconv.ParseInt(v[2:], 16, 64)
if parseErr != nil {
t.Fatal(parseErr)
}
assert.NoError(t, parseErr)
t.Logf("value: 0x%x, name: %s", val, name)
attrTypes[name] = AttrType(val) //nolint:gosec // G115
}
@@ -83,12 +75,8 @@ func TestIANA(t *testing.T) { //nolint:cyclop
}
for val, name := range attrNames() {
mapped, ok := attrTypes[name]
if !ok {
t.Errorf("failed to find attribute %s in IANA", name)
}
if mapped != val {
t.Errorf("%s: IANA %d != actual %d", name, mapped, val)
}
assert.True(t, ok, "failed to find attribute %s in IANA", name)
assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val)
}
})
t.Run("ErrorCodes", func(t *testing.T) {
@@ -103,21 +91,15 @@ func TestIANA(t *testing.T) { //nolint:cyclop
continue
}
val, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil {
t.Fatal(parseErr)
}
assert.NoError(t, parseErr)
t.Logf("value: 0x%x, name: %s", val, name)
errorCodes[name] = ErrorCode(val)
}
for val, nameB := range errorReasons {
name := string(nameB)
mapped, ok := errorCodes[name]
if !ok {
t.Errorf("failed to find error code %s in IANA", name)
}
if mapped != val {
t.Errorf("%s: IANA %d != actual %d", name, mapped, val)
}
assert.True(t, ok, "failed to find error code %s in IANA", name)
assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val)
}
})
}

View File

@@ -4,40 +4,29 @@
package stun
import (
"bytes"
"encoding/hex"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMessageIntegrity_AddTo_Simple(t *testing.T) {
integrity := NewLongTermIntegrity("user", "realm", "pass")
expected, err := hex.DecodeString("8493fbc53ba582fb4c044c456bdc40eb")
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected, integrity) {
t.Error(ErrIntegrityMismatch)
}
assert.NoError(t, err)
assert.EqualValues(t, expected, integrity)
t.Run("Check", func(t *testing.T) {
m := new(Message)
m.WriteHeader()
if err := integrity.AddTo(m); err != nil {
t.Error(err)
}
assert.NoError(t, integrity.AddTo(m))
NewSoftware("software").AddTo(m) //nolint:errcheck,gosec
m.WriteHeader()
dM := new(Message)
dM.Raw = m.Raw
if err := dM.Decode(); err != nil {
t.Error(err)
}
if err := integrity.Check(dM); err != nil {
t.Error(err)
}
assert.NoError(t, dM.Decode())
assert.NoError(t, integrity.Check(dM))
dM.Raw[24] += 12 // HMAC now invalid
if integrity.Check(dM) == nil {
t.Error("should be invalid")
}
assert.Error(t, integrity.Check(dM))
})
}
@@ -47,38 +36,23 @@ func TestMessageIntegrityWithFingerprint(t *testing.T) {
msg.WriteHeader()
NewSoftware("software").AddTo(msg) //nolint:errcheck,gosec
integrity := NewShortTermIntegrity("pwd")
if integrity.String() != "KEY: 0x707764" {
t.Error("bad string", integrity)
}
if err := integrity.Check(msg); err == nil {
t.Error("should error")
}
if err := integrity.AddTo(msg); err != nil {
t.Fatal(err)
}
if err := Fingerprint.AddTo(msg); err != nil {
t.Fatal(err)
}
if err := integrity.Check(msg); err != nil {
t.Fatal(err)
}
assert.Equal(t, "KEY: 0x707764", integrity.String())
assert.NoError(t, integrity.AddTo(msg))
assert.NoError(t, integrity.AddTo(msg))
assert.NoError(t, integrity.Check(msg))
assert.NoError(t, Fingerprint.AddTo(msg))
assert.NoError(t, integrity.Check(msg))
msg.Raw[24] = 33
if err := integrity.Check(msg); err == nil {
t.Fatal("mismatch expected")
}
assert.Error(t, integrity.Check(msg))
}
func TestMessageIntegrity(t *testing.T) {
m := new(Message)
i := NewShortTermIntegrity("password")
m.WriteHeader()
if err := i.AddTo(m); err != nil {
t.Error(err)
}
assert.NoError(t, i.AddTo(m))
_, err := m.Get(AttrMessageIntegrity)
if err != nil {
t.Error(err)
}
assert.NoError(t, err)
}
func TestMessageIntegrityBeforeFingerprint(t *testing.T) {
@@ -86,9 +60,7 @@ func TestMessageIntegrityBeforeFingerprint(t *testing.T) {
m.WriteHeader()
Fingerprint.AddTo(m) //nolint:errcheck,gosec
i := NewShortTermIntegrity("password")
if err := i.AddTo(m); err == nil {
t.Error("should error")
}
assert.Error(t, i.AddTo(m))
}
func BenchmarkMessageIntegrity_AddTo(b *testing.B) {
@@ -99,9 +71,7 @@ func BenchmarkMessageIntegrity_AddTo(b *testing.B) {
b.SetBytes(int64(len(m.Raw)))
for i := 0; i < b.N; i++ {
m.WriteHeader()
if err := integrity.AddTo(m); err != nil {
b.Error(err)
}
assert.NoError(b, integrity.AddTo(m))
m.Reset()
}
}
@@ -114,13 +84,9 @@ func BenchmarkMessageIntegrity_Check(b *testing.B) {
b.ReportAllocs()
m.WriteHeader()
b.SetBytes(int64(len(m.Raw)))
if err := integrity.AddTo(m); err != nil {
b.Error(err)
}
assert.NoError(b, integrity.AddTo(m))
m.WriteLength()
for i := 0; i < b.N; i++ {
if err := integrity.Check(m); err != nil {
b.Fatal(err)
}
assert.NoError(b, integrity.Check(m))
}
}

View File

@@ -11,6 +11,8 @@ import (
"fmt"
"hash"
"testing"
"github.com/stretchr/testify/assert"
)
type hmacTest struct {
@@ -524,26 +526,17 @@ func hmacTests() []hmacTest { //nolint:maintidx
func TestHMAC(t *testing.T) {
for i, tt := range hmacTests() {
hsh := New(tt.hash, tt.key)
if s := hsh.Size(); s != tt.size {
t.Errorf("Size: got %v, want %v", s, tt.size)
}
if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
}
assert.Equal(t, tt.size, hsh.Size(), "Size mismatch")
assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch")
for j := 0; j < 4; j++ { //nolint:varnamelen
n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue
}
assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n)
assert.NoError(t, err, "test %d.%d: Write error", i, j)
// Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
// Second iteration: make sure reset works.
@@ -568,18 +561,10 @@ func TestEqual(t *testing.T) {
b := []byte("test1")
c := []byte("test2")
if !Equal(b, b) {
t.Error("Equal failed with equal arguments")
}
if Equal(a, b) {
t.Error("Equal accepted a prefix of the second argument")
}
if Equal(b, a) {
t.Error("Equal accepted a prefix of the first argument")
}
if Equal(b, c) {
t.Error("Equal accepted unequal slices")
}
assert.True(t, Equal(b, b), "Equal failed with equal arguments")
assert.False(t, Equal(a, b), "Equal accepted a prefix of the second argument")
assert.False(t, Equal(b, a), "Equal accepted a prefix of the first argument")
assert.False(t, Equal(b, c), "Equal accepted unequal slices")
}
func BenchmarkHMACSHA256_1K(b *testing.B) {

View File

@@ -8,6 +8,8 @@ import (
"crypto/sha256"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func BenchmarkHMACSHA1_512(b *testing.B) {
@@ -44,26 +46,17 @@ func TestHMACReset(t *testing.T) {
for i, tt := range hmacTests() {
hsh := New(tt.hash, tt.key)
hsh.(*hmac).resetTo(tt.key) //nolint:forcetypeassert
if s := hsh.Size(); s != tt.size {
t.Errorf("Size: got %v, want %v", s, tt.size)
}
if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
}
assert.Equal(t, tt.size, hsh.Size(), "Size mismatch")
assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch")
for j := 0; j < 2; j++ {
n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue
}
assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n)
assert.NoError(t, err, "test %d.%d: Write error", i, j)
// Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
// Second iteration: make sure reset works.
@@ -78,26 +71,17 @@ func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl,cyclop
continue
}
hsh := AcquireSHA1(tt.key)
if s := hsh.Size(); s != tt.size {
t.Errorf("Size: got %v, want %v", s, tt.size)
}
if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
}
assert.Equal(t, tt.size, hsh.Size(), "Size mismatch")
assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch")
for j := 0; j < 2; j++ {
n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue
}
assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n)
assert.NoError(t, err, "test %d.%d: Write error", i, j)
// Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
// Second iteration: make sure reset works.
@@ -113,26 +97,17 @@ func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl,cyclop
continue
}
hsh := AcquireSHA256(tt.key)
if s := hsh.Size(); s != tt.size {
t.Errorf("Size: got %v, want %v", s, tt.size)
}
if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
}
assert.Equal(t, tt.size, hsh.Size(), "Size mismatch")
assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch")
for j := 0; j < 2; j++ {
n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue
}
assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n)
assert.NoError(t, err, "test %d.%d: Write error", i, j)
// Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
// Second iteration: make sure reset works.
@@ -150,7 +125,7 @@ func TestAssertBlockSize(t *testing.T) {
t.Run("Negative", func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("should panic")
assert.Fail(t, "should panic")
}
}()
h := AcquireSHA256(make([]byte, 0, 1024))

View File

@@ -6,6 +6,8 @@ package testutil
import (
"testing"
"github.com/stretchr/testify/assert"
)
// ShouldNotAllocate fails if f allocates.
@@ -17,7 +19,5 @@ func ShouldNotAllocate(t *testing.T, f func()) {
return
}
if a := testing.AllocsPerRun(10, f); a > 0 {
t.Errorf("Allocations detected: %f", a)
}
assert.Zero(t, testing.AllocsPerRun(10, f))
}

View File

@@ -21,6 +21,8 @@ import (
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
type attributeEncoder interface {
@@ -50,12 +52,9 @@ func TestMessageBuffer(t *testing.T) {
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader()
mDecoded := New()
if _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw)); err != nil {
t.Error(err)
}
if !mDecoded.Equal(m) {
t.Error(mDecoded, "!", m)
}
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw))
assert.NoError(t, err)
assert.True(t, mDecoded.Equal(m), "mDecoded != m")
}
func BenchmarkMessage_Write(b *testing.B) {
@@ -86,9 +85,7 @@ func TestMessageType_Value(t *testing.T) {
}
for _, tt := range tests {
b := tt.in.Value()
if b != tt.out {
t.Errorf("Value(%s) -> %s, want %s", tt.in, bUint16(b), bUint16(tt.out))
}
assert.Equal(t, tt.out, b, "Value(%s) -> %s, want %s", tt.in, bUint16(b), bUint16(tt.out))
}
}
@@ -104,9 +101,7 @@ func TestMessageType_ReadValue(t *testing.T) {
for _, tt := range tests {
m := MessageType{}
m.ReadValue(tt.in)
if m != tt.out {
t.Errorf("ReadValue(%s) -> %s, want %s", bUint16(tt.in), m, tt.out)
}
assert.Equal(t, tt.out, m, "ReadValue(%s) -> %s, want %s", bUint16(tt.in), m, tt.out)
}
}
@@ -121,12 +116,8 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
m := MessageType{}
v := tt.Value()
m.ReadValue(v)
if m != tt {
t.Errorf("ReadValue(%s -> %s) = %s, should be %s", tt, bUint16(v), m, tt)
if m.Method != tt.Method {
t.Errorf("%s != %s", bUint16(uint16(m.Method)), bUint16(uint16(tt.Method)))
}
}
assert.Equal(t, tt, m, "ReadValue(%s -> %s) = %s, should be %s", tt, bUint16(v), m, tt)
assert.Equal(t, tt.Method, m.Method, "%s != %s", bUint16(uint16(m.Method)), bUint16(uint16(tt.Method)))
}
}
@@ -137,32 +128,26 @@ func TestMessage_WriteTo(t *testing.T) {
msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
msg.WriteHeader()
buf := new(bytes.Buffer)
if _, err := msg.WriteTo(buf); err != nil {
t.Fatal(err)
}
_, err := msg.WriteTo(buf)
assert.NoError(t, err)
mDecoded := New()
if _, err := mDecoded.ReadFrom(buf); err != nil {
t.Error(err)
}
if !mDecoded.Equal(msg) {
t.Error(mDecoded, "!", msg)
}
_, err = mDecoded.ReadFrom(buf)
assert.NoError(t, err)
assert.True(t, mDecoded.Equal(msg), "mDecoded != msg")
}
func TestMessage_Cookie(t *testing.T) {
buf := make([]byte, 20)
mDecoded := New()
if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil {
t.Error("should error")
}
_, err := mDecoded.ReadFrom(bytes.NewReader(buf))
assert.Error(t, err, "should error")
}
func TestMessage_LengthLessHeaderSize(t *testing.T) {
buf := make([]byte, 8)
mDecoded := New()
if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil {
t.Error("should error")
}
_, err := mDecoded.ReadFrom(bytes.NewReader(buf))
assert.Error(t, err, "should error")
}
func TestMessage_BadLength(t *testing.T) {
@@ -176,9 +161,8 @@ func TestMessage_BadLength(t *testing.T) {
m.WriteHeader()
m.Raw[20+3] = 10 // set attr length = 10
mDecoded := New()
if _, err := mDecoded.Write(m.Raw); err == nil {
t.Error("should error")
}
_, err := mDecoded.Write(m.Raw)
assert.Error(t, err, "should error")
}
func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
@@ -197,13 +181,8 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
binary.BigEndian.PutUint16(m.Raw[2:4], 2) // rewrite to bad length
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+2]))
var e *DecodeErr
if errors.As(err, &e) {
if !e.IsPlace(DecodeErrPlace{"attribute", "header"}) {
t.Error(e, "bad place")
}
} else {
t.Error(err, "should be bad format")
}
assert.ErrorAs(t, err, &e)
assert.True(t, e.IsPlace(DecodeErrPlace{"attribute", "header"}), "bad place")
}
func TestMessage_AttrSizeLessThanLength(t *testing.T) {
@@ -226,13 +205,8 @@ func TestMessage_AttrSizeLessThanLength(t *testing.T) {
mDecoded := New()
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+5]))
var e *DecodeErr
if errors.As(err, &e) {
if !e.IsPlace(DecodeErrPlace{"attribute", "value"}) {
t.Error(e, "bad place")
}
} else {
t.Error(err, "should be bad format")
}
assert.ErrorAs(t, err, &e)
assert.True(t, e.IsPlace(DecodeErrPlace{"attribute", "value"}), "bad place")
}
type unexpectedEOFReader struct{}
@@ -244,9 +218,7 @@ func (r unexpectedEOFReader) Read([]byte) (int, error) {
func TestMessage_ReadFromError(t *testing.T) {
mDecoded := New()
_, err := mDecoded.ReadFrom(unexpectedEOFReader{})
if !errors.Is(err, io.ErrUnexpectedEOF) {
t.Error(err, "should be", io.ErrUnexpectedEOF)
}
assert.ErrorIs(t, err, io.ErrUnexpectedEOF, "should be", io.ErrUnexpectedEOF)
}
func BenchmarkMessageType_Value(b *testing.B) {
@@ -321,9 +293,7 @@ func BenchmarkMessage_ReadBytes(b *testing.B) {
func TestMessageClass_String(t *testing.T) {
defer func() {
if err := recover(); err == nil {
t.Error(err, "should be not nil")
}
assert.NotNil(t, recover())
}()
v := [...]MessageClass{
@@ -333,14 +303,12 @@ func TestMessageClass_String(t *testing.T) {
ClassIndication,
}
for _, k := range v {
if k.String() == "" {
t.Error(k, "bad stringer")
}
assert.NotEmpty(t, k.String(), "%v bad stringer", k)
}
// should panic
p := MessageClass(0x05).String()
t.Error("should panic!", p)
assert.Fail(t, "should panic", p)
}
func TestAttrType_String(t *testing.T) {
@@ -358,46 +326,26 @@ func TestAttrType_String(t *testing.T) {
AttrFingerprint,
}
for _, k := range attrType {
if k.String() == "" {
t.Error(k, "bad stringer")
}
if strings.HasPrefix(k.String(), "0x") {
t.Error(k, "bad stringer")
}
assert.NotEmpty(t, k.String(), "%v bad stringer", k)
assert.False(t, strings.HasPrefix(k.String(), "0x"), "%v bad stringer", k)
}
vNonStandard := AttrType(0x512)
if !strings.HasPrefix(vNonStandard.String(), "0x512") {
t.Error(vNonStandard, "bad prefix")
}
assert.True(t, strings.HasPrefix(vNonStandard.String(), "0x512"), "%v bad prefix", vNonStandard)
}
func TestMethod_String(t *testing.T) {
if MethodBinding.String() != "Binding" {
t.Error("binding is not binding!")
}
if Method(0x616).String() != "0x616" {
t.Error("Bad stringer", Method(0x616))
}
assert.Equal(t, "Binding", MethodBinding.String(), "binding is not binding!")
assert.Equal(t, "0x616", Method(0x616).String(), "Bad stringer")
}
func TestAttribute_Equal(t *testing.T) {
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
if !attr1.Equal(attr2) {
t.Error("should equal")
}
if attr1.Equal(RawAttribute{Type: 0x2}) {
t.Error("should not equal")
}
if attr1.Equal(RawAttribute{Length: 0x2}) {
t.Error("should not equal")
}
if attr1.Equal(RawAttribute{Length: 0x3}) {
t.Error("should not equal")
}
if attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) {
t.Error("should not equal")
}
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
assert.True(t, attr1.Equal(attr2))
assert.False(t, attr1.Equal(RawAttribute{Type: 0x2}))
assert.False(t, attr1.Equal(RawAttribute{Length: 0x2}))
assert.False(t, attr1.Equal(RawAttribute{Length: 0x3}))
assert.False(t, attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}))
}
func TestMessage_Equal(t *testing.T) { //nolint:cyclop
@@ -405,39 +353,23 @@ func TestMessage_Equal(t *testing.T) { //nolint:cyclop
attrs := Attributes{attr}
msg1 := &Message{Attributes: attrs, Length: 4 + 2}
msg2 := &Message{Attributes: attrs, Length: 4 + 2}
if !msg1.Equal(msg2) {
t.Error("should equal")
}
if msg1.Equal(&Message{Type: MessageType{Class: 128}}) {
t.Error("should not equal")
}
assert.True(t, msg1.Equal(msg2))
assert.False(t, msg1.Equal(&Message{Type: MessageType{Class: 128}}))
tID := [TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
}
if msg1.Equal(&Message{TransactionID: tID}) {
t.Error("should not equal")
}
if msg1.Equal(&Message{Length: 3}) {
t.Error("should not equal")
}
assert.False(t, msg1.Equal(&Message{TransactionID: tID}))
assert.False(t, msg1.Equal(&Message{Length: 3}))
tAttrs := Attributes{
{Length: 1, Value: []byte{0x1}, Type: 0x1},
}
if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
t.Error("should not equal")
}
assert.False(t, msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}))
tAttrs = Attributes{
{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2},
}
if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
t.Error("should not equal")
}
if !(*Message)(nil).Equal(nil) {
t.Error("nil should be equal to nil")
}
if msg1.Equal(nil) {
t.Error("non-nil should not be equal to nil")
}
assert.False(t, msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}))
assert.True(t, (*Message)(nil).Equal(nil), "nil should be equal to nil")
assert.False(t, msg1.Equal(nil), "non-nil should not be equal to nil")
t.Run("Nil attributes", func(t *testing.T) {
msg1 := &Message{
Attributes: nil,
@@ -447,61 +379,43 @@ func TestMessage_Equal(t *testing.T) { //nolint:cyclop
Attributes: attrs,
Length: 4 + 2,
}
if msg1.Equal(msg2) {
t.Error("should not equal")
}
if msg2.Equal(msg1) {
t.Error("should not equal")
}
assert.False(t, msg1.Equal(msg2))
assert.False(t, msg2.Equal(msg1))
msg2.Attributes = nil
if !msg1.Equal(msg2) {
t.Error("should equal")
}
assert.True(t, msg1.Equal(msg2))
})
t.Run("Attributes length", func(t *testing.T) {
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
a := &Message{Attributes: Attributes{attr}, Length: 4 + 2}
b := &Message{Attributes: Attributes{attr, attr1}, Length: 4 + 2}
if a.Equal(b) {
t.Error("should not equal")
}
assert.False(t, a.Equal(b))
})
t.Run("Attributes values", func(t *testing.T) {
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x1}
a := &Message{Attributes: Attributes{attr, attr}, Length: 4 + 2}
b := &Message{Attributes: Attributes{attr, attr1}, Length: 4 + 2}
if a.Equal(b) {
t.Error("should not equal")
}
assert.False(t, a.Equal(b))
})
}
func TestMessageGrow(t *testing.T) {
m := New()
m.grow(512)
if len(m.Raw) < 512 {
t.Error("Bad length", len(m.Raw))
}
assert.GreaterOrEqual(t, len(m.Raw), 512)
}
func TestMessageGrowSmaller(t *testing.T) {
m := New()
m.grow(2)
if cap(m.Raw) < 20 {
t.Error("Bad capacity", cap(m.Raw))
}
if len(m.Raw) < 20 {
t.Error("Bad length", len(m.Raw))
}
assert.GreaterOrEqual(t, cap(m.Raw), 20)
assert.GreaterOrEqual(t, len(m.Raw), 20)
}
func TestMessage_String(t *testing.T) {
m := New()
if m.String() == "" {
t.Error("bad string")
}
assert.NotEmpty(t, m.String())
}
func TestIsMessage(t *testing.T) {
@@ -525,9 +439,7 @@ func TestIsMessage(t *testing.T) {
}, true}, // 6
}
for i, v := range tt {
if got := IsMessage(v.in); got != v.out {
t.Errorf("tt[%d]: IsMessage(%+v) %v != %v", i, v.in, got, v.out)
}
assert.Equal(t, v.out, IsMessage(v.in), "tt[%d]: IsMessage(%+v)", i, v.in)
}
}
@@ -553,18 +465,12 @@ func loadData(tb testing.TB, name string) []byte {
name = filepath.Join("testdata", name)
f, err := os.Open(name) //nolint:gosec
if err != nil {
tb.Fatal(err)
}
assert.NoError(tb, err)
defer func() {
if errClose := f.Close(); errClose != nil {
tb.Fatal(errClose)
}
assert.NoError(tb, f.Close())
}()
v, err := io.ReadAll(f)
if err != nil {
tb.Fatal(err)
}
assert.NoError(tb, err)
return v
}
@@ -573,9 +479,7 @@ func TestExampleChrome(t *testing.T) {
buf := loadData(t, "ex1_chrome.stun")
m := New()
_, err := m.Write(buf)
if err != nil {
t.Errorf("Failed to parse ex1_chrome: %s", err)
}
assert.NoError(t, err, "Failed to parse ex1_chrome")
}
func TestMessageFromBrowsers(t *testing.T) {
@@ -583,9 +487,7 @@ func TestMessageFromBrowsers(t *testing.T) {
reader := csv.NewReader(bytes.NewReader(loadData(t, "frombrowsers.csv")))
reader.Comment = '#'
_, err := reader.Read() // skipping header
if err != nil {
t.Fatal("failed to skip header of csv: ", err)
}
assert.NoError(t, err, "failed to skip header of csv")
crcTable := crc64.MakeTable(crc64.ISO)
msg := New()
for {
@@ -593,23 +495,14 @@ func TestMessageFromBrowsers(t *testing.T) {
if errors.Is(err, io.EOF) {
break
}
if err != nil {
t.Fatal("failed to read csv line: ", err)
}
assert.NoError(t, err, "failed to read csv line")
data, err := base64.StdEncoding.DecodeString(line[1])
if err != nil {
t.Fatal("failed to decode ", line[1], " as base64: ", err)
}
assert.NoError(t, err)
b, err := strconv.ParseUint(line[2], 10, 64)
if err != nil {
t.Fatal(err)
}
if b != crc64.Checksum(data, crcTable) {
t.Error("crc64 check failed for ", line[1])
}
if _, err = msg.Write(data); err != nil {
t.Error("failed to decode ", line[1], " as message: ", err)
}
assert.NoError(t, err)
assert.Equal(t, b, crc64.Checksum(data, crcTable), "crc64 check failed for %s", line[1])
_, err = msg.Write(data)
assert.NoError(t, err, "failed to decode %s as message: %s", line[1], err)
msg.Reset()
}
}
@@ -619,9 +512,7 @@ func BenchmarkMessage_NewTransactionID(b *testing.B) {
m := new(Message)
m.WriteHeader()
for i := 0; i < b.N; i++ {
if err := m.NewTransactionID(); err != nil {
b.Fatal(err)
}
assert.NoError(b, m.NewTransactionID())
}
}
@@ -633,12 +524,8 @@ func BenchmarkMessageFull(b *testing.B) {
IP: net.IPv4(213, 1, 223, 5),
}
for i := 0; i < b.N; i++ {
if err := addr.AddTo(msg); err != nil {
b.Fatal(err)
}
if err := s.AddTo(msg); err != nil {
b.Fatal(err)
}
assert.NoError(b, addr.AddTo(msg))
assert.NoError(b, s.AddTo(msg))
msg.WriteAttributes()
msg.WriteHeader()
Fingerprint.AddTo(msg) //nolint:errcheck,gosec
@@ -655,12 +542,8 @@ func BenchmarkMessageFullHardcore(b *testing.B) {
IP: net.IPv4(213, 1, 223, 5),
}
for i := 0; i < b.N; i++ {
if err := addr.AddTo(msg); err != nil {
b.Fatal(err)
}
if err := s.AddTo(msg); err != nil {
b.Fatal(err)
}
assert.NoError(b, addr.AddTo(msg))
assert.NoError(b, s.AddTo(msg))
msg.WriteHeader()
msg.Reset()
}
@@ -684,12 +567,8 @@ func BenchmarkMessage_WriteHeader(b *testing.B) {
func TestMessage_Contains(t *testing.T) {
m := new(Message)
m.Add(AttrSoftware, []byte("value"))
if !m.Contains(AttrSoftware) {
t.Error("message should contain software")
}
if m.Contains(AttrNonce) {
t.Error("message should not contain nonce")
}
assert.True(t, m.Contains(AttrSoftware), "message should contain software")
assert.False(t, m.Contains(AttrNonce), "message should not contain nonce")
}
func ExampleMessage() {
@@ -787,13 +666,9 @@ func TestAllocations(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() {
m.Reset()
m.WriteHeader()
if err := s.AddTo(m); err != nil {
t.Errorf("[%d] failed to add", i)
}
assert.NoError(t, s.AddTo(m), "[%d] failed to add", i)
})
if allocs > 0 {
t.Errorf("[%d] allocated %.0f", i, allocs)
}
assert.Zero(t, allocs, "[%d] allocated", i)
}
}
@@ -818,9 +693,7 @@ func TestAllocationsGetters(t *testing.T) {
Fingerprint,
}
msg := New()
if err := msg.Build(setters...); err != nil {
t.Error("failed to build", err)
}
assert.NoError(t, msg.Build(setters...))
getters := []Getter{
new(Nonce),
new(Username),
@@ -832,66 +705,48 @@ func TestAllocationsGetters(t *testing.T) {
g := g
i := i
allocs := testing.AllocsPerRun(10, func() {
if err := g.GetFrom(msg); err != nil {
t.Errorf("[%d] failed to get", i)
}
assert.NoError(t, g.GetFrom(msg), "[%d] failed to get", i)
})
if allocs > 0 {
t.Errorf("[%d] allocated %.0f", i, allocs)
}
assert.Zero(t, allocs, "[%d] allocated", i)
}
}
func TestMessageFullSize(t *testing.T) {
msg := new(Message)
if err := msg.Build(BindingRequest,
assert.NoError(t, msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}),
NewSoftware("pion/stun"),
NewLongTermIntegrity("username", "realm", "password"),
Fingerprint,
); err != nil {
t.Fatal(err)
}
))
msg.Raw = msg.Raw[:len(msg.Raw)-10]
decoder := new(Message)
decoder.Raw = msg.Raw[:len(msg.Raw)-10]
if err := decoder.Decode(); err == nil {
t.Error("decode on truncated buffer should error")
}
assert.Error(t, decoder.Decode(), "decode on truncated buffer should error")
}
func TestMessage_CloneTo(t *testing.T) {
msg := new(Message)
if err := msg.Build(BindingRequest,
assert.NoError(t, msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}),
NewSoftware("pion/stun"),
NewLongTermIntegrity("username", "realm", "password"),
Fingerprint,
); err != nil {
t.Fatal(err)
}
))
msg.Encode()
msg2 := new(Message)
if err := msg.CloneTo(msg2); err != nil {
t.Fatal(err)
}
if !msg2.Equal(msg) {
t.Fatal("not equal")
}
assert.NoError(t, msg.CloneTo(msg2))
assert.True(t, msg2.Equal(msg), "cloned message should equal original")
// Corrupting m and checking that b is not corrupted.
s, ok := msg2.Attributes.Get(AttrSoftware)
if !ok {
t.Fatal("no software attribute")
}
assert.True(t, ok)
s.Value[0] = 'k'
if msg2.Equal(msg) {
t.Fatal("should not be equal")
}
assert.False(t, msg2.Equal(msg), "should not be equal")
}
func BenchmarkMessage_CloneTo(b *testing.B) {
@@ -919,29 +774,21 @@ func BenchmarkMessage_CloneTo(b *testing.B) {
func TestMessage_AddTo(t *testing.T) {
msg := new(Message)
if err := msg.Build(BindingRequest,
assert.NoError(t, msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}),
Fingerprint,
); err != nil {
t.Fatal(err)
}
))
msg.Encode()
b := new(Message)
if err := msg.CloneTo(b); err != nil {
t.Fatal(err)
}
assert.NoError(t, msg.CloneTo(b))
msg.TransactionID = [TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2,
}
if b.Equal(msg) {
t.Fatal("should not be equal")
}
assert.False(t, b.Equal(msg), "should not be equal")
msg.AddTo(b) //nolint:errcheck,gosec
if !b.Equal(msg) {
t.Fatal("should be equal")
}
assert.True(t, b.Equal(msg), "should be equal")
}
func BenchmarkMessage_AddTo(b *testing.B) {
@@ -966,9 +813,7 @@ func BenchmarkMessage_AddTo(b *testing.B) {
func TestDecode(t *testing.T) {
t.Run("Nil", func(t *testing.T) {
if err := Decode(nil, nil); !errors.Is(err, ErrDecodeToNil) {
t.Errorf("unexpected error: %v", err)
}
assert.ErrorIs(t, Decode(nil, nil), ErrDecodeToNil)
})
msg := New()
msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
@@ -976,22 +821,14 @@ func TestDecode(t *testing.T) {
msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
msg.WriteHeader()
mDecoded := New()
if err := Decode(msg.Raw, mDecoded); err != nil {
t.Errorf("unexpected error: %v", err)
}
if !mDecoded.Equal(msg) {
t.Error("decoded result is not equal to encoded message")
}
assert.NoError(t, Decode(msg.Raw, mDecoded))
assert.True(t, mDecoded.Equal(msg), "decoded result is not equal to encoded message")
t.Run("ZeroAlloc", func(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() {
mDecoded.Reset()
if err := Decode(msg.Raw, mDecoded); err != nil {
t.Error(err)
}
assert.NoError(t, Decode(msg.Raw, mDecoded))
})
if allocs > 0 {
t.Error("unexpected allocations")
}
assert.Zero(t, allocs, "unexpected allocations")
})
}
@@ -1021,25 +858,19 @@ func TestMessage_MarshalBinary(t *testing.T) {
},
)
data, err := msg.MarshalBinary()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
// Reset m.Raw to check retention.
for i := range msg.Raw {
msg.Raw[i] = 0
}
if err := msg.UnmarshalBinary(data); err != nil {
t.Fatal(err)
}
assert.NoError(t, msg.UnmarshalBinary(data))
// Reset data to check retention.
for i := range data {
data[i] = 0
}
if err := msg.Decode(); err != nil {
t.Fatal(err)
}
assert.NoError(t, msg.Decode())
}
func TestMessage_GobDecode(t *testing.T) {
@@ -1050,23 +881,17 @@ func TestMessage_GobDecode(t *testing.T) {
},
)
data, err := msg.GobEncode()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
// Reset m.Raw to check retention.
for i := range msg.Raw {
msg.Raw[i] = 0
}
if err := msg.GobDecode(data); err != nil {
t.Fatal(err)
}
assert.NoError(t, msg.GobDecode(data))
// Reset data to check retention.
for i := range data {
data[i] = 0
}
if err := msg.Decode(); err != nil {
t.Fatal(err)
}
assert.NoError(t, msg.Decode())
}

View File

@@ -6,6 +6,8 @@ package stun
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRFC5769(t *testing.T) { //nolint:cyclop
@@ -32,19 +34,11 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop
"\xe5\x7a\x3b\xcf",
),
}
if err := m.Decode(); err != nil {
t.Error(err)
}
assert.NoError(t, m.Decode())
software := new(Software)
if err := software.GetFrom(m); err != nil {
t.Error(err)
}
if software.String() != "STUN test client" {
t.Error("bad software: ", software)
}
if err := Fingerprint.Check(m); err != nil {
t.Error("check failed: ", err)
}
assert.NoError(t, software.GetFrom(m))
assert.Equal(t, "STUN test client", software.String())
assert.NoError(t, Fingerprint.Check(m))
t.Run("Long-Term credentials", func(t *testing.T) {
msg := &Message{
Raw: []byte("\x00\x01\x00\x60" +
@@ -64,40 +58,24 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop
"\x2e\x85\xc9\xa2\x8c\xa8\x96\x66",
),
}
if err := msg.Decode(); err != nil {
t.Error(err)
}
assert.NoError(t, msg.Decode())
u := new(Username)
if err := u.GetFrom(msg); err != nil {
t.Error(err)
}
assert.NoError(t, u.GetFrom(msg))
expectedUsername := "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9"
if u.String() != expectedUsername {
t.Errorf("username: %q (got) != %q (exp)", u, expectedUsername)
}
assert.Equal(t, expectedUsername, u.String())
n := new(Nonce)
if err := n.GetFrom(msg); err != nil {
t.Error(err)
}
if n.String() != "f//499k954d6OL34oL9FSTvy64sA" {
t.Error("bad nonce")
}
assert.NoError(t, n.GetFrom(msg))
assert.Equal(t, "f//499k954d6OL34oL9FSTvy64sA", n.String())
r := new(Realm)
if err := r.GetFrom(msg); err != nil {
t.Error(err)
}
if r.String() != "example.org" { //nolint:goconst
t.Error("bad realm")
}
assert.NoError(t, r.GetFrom(msg))
assert.Equal(t, "example.org", r.String())
// checking HMAC
i := NewLongTermIntegrity(
"\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9",
"example.org",
"TheMatrIX",
)
if err := i.Check(msg); err != nil {
t.Error(err)
}
assert.NoError(t, i.Check(msg))
})
})
t.Run("Response", func(t *testing.T) {
@@ -117,32 +95,18 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop
"\xc0\x7d\x4c\x96",
),
}
if err := msg.Decode(); err != nil {
t.Error(err)
}
assert.NoError(t, msg.Decode())
software := new(Software)
if err := software.GetFrom(msg); err != nil {
t.Error(err)
}
if software.String() != "test vector" {
t.Error("bad software: ", software)
}
if err := Fingerprint.Check(msg); err != nil {
t.Error("Check failed: ", err)
}
assert.NoError(t, software.GetFrom(msg))
assert.Equal(t, "test vector", software.String())
assert.NoError(t, Fingerprint.Check(msg))
addr := new(XORMappedAddress)
if err := addr.GetFrom(msg); err != nil {
t.Error(err)
}
if !addr.IP.Equal(net.ParseIP("192.0.2.1")) {
t.Error("bad IP")
}
if addr.Port != 32853 {
t.Error("bad Port")
}
if err := Fingerprint.Check(msg); err != nil {
t.Error("check failed: ", err)
}
assert.NoError(t, addr.GetFrom(msg))
expected := "192.0.2.1"
assert.Equalf(t, expected, addr.IP.String(), "Expected %s, got %s", expected, addr.IP)
assert.Equal(t, 32853, addr.Port)
assert.NoError(t, Fingerprint.Check(msg))
})
t.Run("IPv6", func(t *testing.T) {
msg := &Message{
@@ -162,32 +126,20 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop
"\xc8\xfb\x0b\x4c",
),
}
if err := msg.Decode(); err != nil {
t.Error(err)
}
assert.NoError(t, msg.Decode())
software := new(Software)
if err := software.GetFrom(msg); err != nil {
t.Error(err)
}
if software.String() != "test vector" {
t.Error("bad software: ", software)
}
if err := Fingerprint.Check(msg); err != nil {
t.Error("Check failed: ", err)
}
assert.NoError(t, software.GetFrom(msg))
assert.Equal(t, "test vector", software.String())
assert.NoError(t, Fingerprint.Check(msg))
addr := new(XORMappedAddress)
if err := addr.GetFrom(msg); err != nil {
t.Error(err)
}
if !addr.IP.Equal(net.ParseIP("2001:db8:1234:5678:11:2233:4455:6677")) {
t.Error("bad IP")
}
if addr.Port != 32853 {
t.Error("bad Port")
}
if err := Fingerprint.Check(msg); err != nil {
t.Error("check failed: ", err)
}
assert.NoError(t, addr.GetFrom(msg))
expectedIP := "2001:db8:1234:5678:11:2233:4455:6677"
assert.Truef(
t, addr.IP.Equal(net.ParseIP(expectedIP)),
"Expected %s, got %s", expectedIP, addr.IP,
)
assert.Equal(t, 32853, addr.Port)
assert.NoError(t, Fingerprint.Check(msg))
})
})
}

View File

@@ -6,6 +6,8 @@ package stun
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
type errorReader struct{}
@@ -21,9 +23,7 @@ func (errorReader) Read([]byte) (int, error) {
func TestReadFullHelper(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("should panic")
}
assert.NotNil(t, recover(), "should panic")
}()
readFullOrPanic(errorReader{}, make([]byte, 1))
}
@@ -36,9 +36,7 @@ func (errorWriter) Write([]byte) (int, error) {
func TestWriteHelper(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("should panic")
}
assert.NotNil(t, recover(), "should panic")
}()
writeOrPanic(errorWriter{}, make([]byte, 1))
}

View File

@@ -9,6 +9,8 @@ import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
var errUDPServerUnsupportedNetwork = errors.New("unsupported network")
@@ -37,9 +39,7 @@ func NewUDPServer(
}
udpConn, err := net.ListenUDP(network, &net.UDPAddr{IP: net.ParseIP(ip), Port: 0})
if err != nil {
t.Fatal(err) //nolint:forbidigo
}
assert.NoError(t, err)
// Necessary for IPv6
address := fmt.Sprintf("%s:%d", ip, udpConn.LocalAddr().(*net.UDPAddr).Port) //nolint:forcetypeassert
@@ -81,18 +81,14 @@ func NewUDPServer(
select {
case err := <-errCh:
if err != nil {
t.Fatal(err)
assert.NoError(t, err)
return
}
default:
}
err := udpConn.Close()
if err != nil {
t.Fatal(err)
}
assert.NoError(t, udpConn.Close())
<-errCh
}, nil
}

View File

@@ -7,9 +7,10 @@
package stun
import (
"errors"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestSoftware_GetFrom(t *testing.T) {
@@ -22,44 +23,29 @@ func TestSoftware_GetFrom(t *testing.T) {
Raw: make([]byte, 0, 256),
}
software := new(Software)
if _, err := m2.ReadFrom(msg.reader()); err != nil {
t.Error(err)
}
if err := software.GetFrom(msg); err != nil {
t.Fatal(err)
}
if software.String() != val {
t.Errorf("Expected %q, got %q.", val, software)
}
_, err := m2.ReadFrom(msg.reader())
assert.NoError(t, err)
assert.NoError(t, software.GetFrom(msg))
assert.Equal(t, val, software.String())
sAttr, ok := msg.Attributes.Get(AttrSoftware)
if !ok {
t.Error("software attribute should be found")
}
assert.True(t, ok, "software attribute should be found")
s := sAttr.String()
if !strings.HasPrefix(s, "SOFTWARE:") {
t.Error("bad string representation", s)
}
assert.True(t, strings.HasPrefix(s, "SOFTWARE:"), "bad string representation")
}
func TestSoftware_AddTo_Invalid(t *testing.T) {
m := New()
s := make(Software, 1024)
if err := s.AddTo(m); !IsAttrSizeOverflow(err) {
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
}
if err := s.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
assert.True(t, IsAttrSizeOverflow(s.AddTo(m)), "AddTo should return *AttrOverflowErr")
assert.ErrorIs(t, s.GetFrom(m), ErrAttributeNotFound)
}
func TestSoftware_AddTo_Regression(t *testing.T) {
// s.AddTo checked len(m.Raw) instead of len(s.Raw).
m := &Message{Raw: make([]byte, 2048)}
s := make(Software, 100)
if err := s.AddTo(m); err != nil {
t.Errorf("AddTo should return <nil>, got: %v", err)
}
assert.NoError(t, s.AddTo(m))
}
func BenchmarkUsername_AddTo(b *testing.B) {
@@ -95,28 +81,18 @@ func TestUsername(t *testing.T) {
msg.WriteHeader()
t.Run("Bad length", func(t *testing.T) {
badU := make(Username, 600)
if err := badU.AddTo(msg); !IsAttrSizeOverflow(err) {
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
}
assert.True(t, IsAttrSizeOverflow(badU.AddTo(msg)), "AddTo should return *AttrOverflowErr")
})
t.Run("AddTo", func(t *testing.T) {
if err := uName.AddTo(msg); err != nil {
t.Error("errored:", err)
}
assert.NoError(t, uName.AddTo(msg))
t.Run("GetFrom", func(t *testing.T) {
got := new(Username)
if err := got.GetFrom(msg); err != nil {
t.Error("errored:", err)
}
if got.String() != username {
t.Errorf("expedted: %s, got: %s", username, got)
}
assert.NoError(t, got.GetFrom(msg))
assert.Equal(t, username, got.String())
t.Run("Not found", func(t *testing.T) {
m := new(Message)
u := new(Username)
if err := u.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) {
t.Error("Should error")
}
assert.ErrorIs(t, u.GetFrom(m), ErrAttributeNotFound)
})
})
})
@@ -124,14 +100,10 @@ func TestUsername(t *testing.T) {
m := new(Message)
m.WriteHeader()
u := NewUsername("username")
if allocs := testing.AllocsPerRun(10, func() {
if err := u.AddTo(m); err != nil {
t.Error(err)
}
assert.Empty(t, testing.AllocsPerRun(10, func() {
assert.NoError(t, u.AddTo(m))
m.Reset()
}); allocs > 0 {
t.Errorf("got %f allocations, zero expected", allocs)
}
}))
})
}
@@ -145,38 +117,23 @@ func TestRealm_GetFrom(t *testing.T) {
Raw: make([]byte, 0, 256),
}
r := new(Realm)
if err := r.GetFrom(m2); !errors.Is(err, ErrAttributeNotFound) {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
if _, err := m2.ReadFrom(msg.reader()); err != nil {
t.Error(err)
}
if err := r.GetFrom(msg); err != nil {
t.Fatal(err)
}
if r.String() != val {
t.Errorf("Expected %q, got %q.", val, r)
}
assert.ErrorIs(t, r.GetFrom(m2), ErrAttributeNotFound)
_, err := m2.ReadFrom(msg.reader())
assert.NoError(t, err)
assert.NoError(t, r.GetFrom(msg))
assert.Equal(t, val, r.String())
rAttr, ok := msg.Attributes.Get(AttrRealm)
if !ok {
t.Error("realm attribute should be found")
}
assert.True(t, ok, "realm attribute should be found")
s := rAttr.String()
if !strings.HasPrefix(s, "REALM:") {
t.Error("bad string representation", s)
}
assert.True(t, strings.HasPrefix(s, "REALM:"), "bad string representation")
}
func TestRealm_AddTo_Invalid(t *testing.T) {
m := New()
r := make(Realm, 1024)
if err := r.AddTo(m); !IsAttrSizeOverflow(err) {
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
}
if err := r.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
assert.True(t, IsAttrSizeOverflow(r.AddTo(m)), "AddTo should return *AttrOverflowErr")
assert.ErrorIs(t, r.GetFrom(m), ErrAttributeNotFound)
}
func TestNonce_GetFrom(t *testing.T) {
@@ -189,50 +146,31 @@ func TestNonce_GetFrom(t *testing.T) {
Raw: make([]byte, 0, 256),
}
var nonce Nonce
if _, err := m2.ReadFrom(msg.reader()); err != nil {
t.Error(err)
}
if err := nonce.GetFrom(msg); err != nil {
t.Fatal(err)
}
if nonce.String() != val {
t.Errorf("Expected %q, got %q.", val, nonce)
}
_, err := m2.ReadFrom(msg.reader())
assert.NoError(t, err)
assert.NoError(t, nonce.GetFrom(msg))
assert.Equal(t, val, nonce.String())
nAttr, ok := msg.Attributes.Get(AttrNonce)
if !ok {
t.Error("nonce attribute should be found")
}
assert.True(t, ok, "nonce attribute should be found")
s := nAttr.String()
if !strings.HasPrefix(s, "NONCE:") {
t.Error("bad string representation", s)
}
assert.True(t, strings.HasPrefix(s, "NONCE:"), "bad string representation")
}
func TestNonce_AddTo_Invalid(t *testing.T) {
m := New()
n := make(Nonce, 1024)
if err := n.AddTo(m); !IsAttrSizeOverflow(err) {
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
}
if err := n.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
assert.True(t, IsAttrSizeOverflow(n.AddTo(m)), "AddTo should return *AttrOverflowErr")
assert.ErrorIs(t, n.GetFrom(m), ErrAttributeNotFound)
}
func TestNonce_AddTo(t *testing.T) {
m := New()
n := Nonce("example.org")
if err := n.AddTo(m); err != nil {
t.Error(err)
}
assert.NoError(t, n.AddTo(m))
v, err := m.Get(AttrNonce)
if err != nil {
t.Error(err)
}
if string(v) != "example.org" {
t.Errorf("bad nonce %q", v)
}
assert.NoError(t, err)
assert.Equal(t, "example.org", string(v))
}
func BenchmarkNonce_AddTo(b *testing.B) {

View File

@@ -5,6 +5,8 @@ package stun
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestUnknownAttributes(t *testing.T) {
@@ -13,33 +15,19 @@ func TestUnknownAttributes(t *testing.T) {
AttrDontFragment,
AttrChannelNumber,
}
if attr.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" {
t.Error("bad String:", attr)
}
if (UnknownAttributes{}).String() != "<nil>" {
t.Error("bad blank string")
}
if err := attr.AddTo(msg); err != nil {
t.Error(err)
}
assert.Equal(t, "DONT-FRAGMENT, CHANNEL-NUMBER", attr.String())
assert.Equal(t, "<nil>", (UnknownAttributes{}).String())
assert.NoError(t, attr.AddTo(msg))
t.Run("GetFrom", func(t *testing.T) {
attrs := make(UnknownAttributes, 10)
if err := attrs.GetFrom(msg); err != nil {
t.Error(err)
}
assert.NoError(t, attrs.GetFrom(msg))
for i, at := range *attr {
if at != attrs[i] {
t.Error("expected", at, "!=", attrs[i])
}
assert.Equal(t, at, attrs[i])
}
mBlank := new(Message)
if err := attrs.GetFrom(mBlank); err == nil {
t.Error("should error")
}
assert.Error(t, attrs.GetFrom(mBlank))
mBlank.Add(AttrUnknownAttributes, []byte{1, 2, 3})
if err := attrs.GetFrom(mBlank); err == nil {
t.Error("should error")
}
assert.Error(t, attrs.GetFrom(mBlank))
})
}

View File

@@ -11,10 +11,11 @@ import (
"encoding/base64"
"encoding/binary"
"encoding/hex"
"errors"
"io"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
@@ -31,82 +32,56 @@ func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
func BenchmarkXORMappedAddress_GetFrom(b *testing.B) {
msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
b.Error(err)
}
assert.NoError(b, err)
copy(msg.TransactionID[:], transactionID)
addrValue, err := hex.DecodeString("00019cd5f49f38ae")
if err != nil {
b.Error(err)
}
assert.NoError(b, err)
msg.Add(AttrXORMappedAddress, addrValue)
addr := new(XORMappedAddress)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if err := addr.GetFrom(msg); err != nil {
b.Fatal(err)
}
assert.NoError(b, addr.GetFrom(msg))
}
}
func TestXORMappedAddress_GetFrom(t *testing.T) {
m := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
t.Error(err)
}
assert.NoError(t, err)
copy(m.TransactionID[:], transactionID)
addrValue, err := hex.DecodeString("00019cd5f49f38ae")
if err != nil {
t.Error(err)
}
assert.NoError(t, err)
m.Add(AttrXORMappedAddress, addrValue)
addr := new(XORMappedAddress)
if err = addr.GetFrom(m); err != nil {
t.Error(err)
}
if !addr.IP.Equal(net.ParseIP("213.141.156.236")) {
t.Error("bad IP", addr.IP, "!=", "213.141.156.236")
}
if addr.Port != 48583 {
t.Error("bad Port", addr.Port, "!=", 48583)
}
assert.NoError(t, addr.GetFrom(m))
assert.True(t, addr.IP.Equal(net.ParseIP("213.141.156.236")))
assert.Equal(t, 48583, addr.Port)
t.Run("UnexpectedEOF", func(t *testing.T) {
m := New()
// {0, 1} is correct addr family.
m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4})
addr := new(XORMappedAddress)
if err = addr.GetFrom(m); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Errorf("len(v) = 4 should render <%s> error, got <%s>",
io.ErrUnexpectedEOF, err,
)
}
assert.ErrorIs(t, addr.GetFrom(m), io.ErrUnexpectedEOF, "len(v) = 4 should return io.ErrUnexpectedEOF")
})
t.Run("AttrOverflowErr", func(t *testing.T) {
m := New()
// {0, 1} is correct addr family.
m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4, 5, 6, 7, 8, 9, 1, 1, 1, 1, 1, 2, 3, 4})
addr := new(XORMappedAddress)
if err := addr.GetFrom(m); !IsAttrSizeOverflow(err) {
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
}
assert.True(t, IsAttrSizeOverflow(addr.GetFrom(m)), "GetFrom should return *AttrOverflowErr")
})
}
func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
t.Error(err)
}
assert.NoError(t, err)
copy(msg.TransactionID[:], transactionID)
expectedIP := net.ParseIP("213.141.156.236")
expectedPort := 21254
addr := new(XORMappedAddress)
if err = addr.GetFrom(msg); err == nil {
t.Fatal(err, "should be nil")
}
assert.Error(t, addr.GetFrom(msg))
addr.IP = expectedIP
addr.Port = expectedPort
@@ -115,20 +90,15 @@ func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
mRes := New()
binary.BigEndian.PutUint16(msg.Raw[20+4:20+4+2], 0x21)
if _, err = mRes.ReadFrom(bytes.NewReader(msg.Raw)); err != nil {
t.Fatal(err)
}
if err = addr.GetFrom(msg); err == nil {
t.Fatal(err, "should not be nil")
}
_, err = mRes.ReadFrom(bytes.NewReader(msg.Raw))
assert.NoError(t, err)
assert.Error(t, addr.GetFrom(msg))
}
func TestXORMappedAddress_AddTo(t *testing.T) {
msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
t.Error(err)
}
assert.NoError(t, err)
copy(msg.TransactionID[:], transactionID)
expectedIP := net.ParseIP("213.141.156.236")
expectedPort := 21254
@@ -136,31 +106,20 @@ func TestXORMappedAddress_AddTo(t *testing.T) {
IP: net.ParseIP("213.141.156.236"),
Port: expectedPort,
}
if err = addr.AddTo(msg); err != nil {
t.Fatal(err)
}
assert.NoError(t, addr.AddTo(msg))
msg.WriteHeader()
mRes := New()
if _, err = mRes.Write(msg.Raw); err != nil {
t.Fatal(err)
}
if err = addr.GetFrom(mRes); err != nil {
t.Fatal(err)
}
if !addr.IP.Equal(expectedIP) {
t.Errorf("%s (got) != %s (expected)", addr.IP, expectedIP)
}
if addr.Port != expectedPort {
t.Error("bad Port", addr.Port, "!=", expectedPort)
}
_, err = mRes.Write(msg.Raw)
assert.NoError(t, err)
assert.NoError(t, addr.GetFrom(mRes))
assert.True(t, addr.IP.Equal(expectedIP), "Expected %s, got %s", expectedIP, addr.IP)
assert.Equal(t, expectedPort, addr.Port)
}
func TestXORMappedAddress_AddTo_IPv6(t *testing.T) {
msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
t.Error(err)
}
assert.NoError(t, err)
copy(msg.TransactionID[:], transactionID)
expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009")
expectedPort := 21254
@@ -172,19 +131,12 @@ func TestXORMappedAddress_AddTo_IPv6(t *testing.T) {
msg.WriteHeader()
mRes := New()
if _, err = mRes.ReadFrom(msg.reader()); err != nil {
t.Fatal(err)
}
_, err = mRes.ReadFrom(msg.reader())
assert.NoError(t, err)
gotAddr := new(XORMappedAddress)
if err = gotAddr.GetFrom(msg); err != nil {
t.Fatal(err)
}
if !gotAddr.IP.Equal(expectedIP) {
t.Error("bad IP", gotAddr.IP, "!=", expectedIP)
}
if gotAddr.Port != expectedPort {
t.Error("bad Port", gotAddr.Port, "!=", expectedPort)
}
assert.NoError(t, gotAddr.GetFrom(mRes))
assert.True(t, gotAddr.IP.Equal(expectedIP), "Expected %s, got %s", expectedIP, gotAddr.IP)
assert.Equal(t, expectedPort, gotAddr.Port)
}
func TestXORMappedAddress_AddTo_Invalid(t *testing.T) {
@@ -193,9 +145,7 @@ func TestXORMappedAddress_AddTo_Invalid(t *testing.T) {
IP: []byte{1, 2, 3, 4, 5, 6, 7, 8},
Port: 21254,
}
if err := addr.AddTo(m); !errors.Is(err, ErrBadIPLength) {
t.Errorf("AddTo should return %q, got: %v", ErrBadIPLength, err)
}
assert.ErrorIs(t, addr.AddTo(m), ErrBadIPLength)
}
func TestXORMappedAddress_String(t *testing.T) {
@@ -219,12 +169,6 @@ func TestXORMappedAddress_String(t *testing.T) {
},
}
for i, c := range tt {
if got := c.in.String(); got != c.out {
t.Errorf("[%d]: XORMappesAddres.String() %s (got) != %s (expected)",
i,
got,
c.out,
)
}
assert.Equalf(t, c.out, c.in.String(), "[%d]: XORMappesAddres.String()", i)
}
}