mirror of
https://github.com/pion/stun.git
synced 2025-10-27 17:51:04 +08:00
Update lint rules, force testify/assert for tests
Use testify's assert package instead of the standard library's testing package.
This commit is contained in:
@@ -19,12 +19,16 @@ linters-settings:
|
||||
recommendations:
|
||||
- errors
|
||||
forbidigo:
|
||||
analyze-types: true
|
||||
forbid:
|
||||
- ^fmt.Print(f|ln)?$
|
||||
- ^log.(Panic|Fatal|Print)(f|ln)?$
|
||||
- ^os.Exit$
|
||||
- ^panic$
|
||||
- ^print(ln)?$
|
||||
- p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$
|
||||
pkg: ^testing$
|
||||
msg: "use testify/assert instead"
|
||||
varnamelen:
|
||||
max-distance: 12
|
||||
min-name-length: 2
|
||||
@@ -127,9 +131,12 @@ issues:
|
||||
exclude-dirs-use-default: false
|
||||
exclude-rules:
|
||||
# Allow complex tests and examples, better to be self contained
|
||||
- path: (examples|main\.go|_test\.go)
|
||||
- path: (examples|main\.go)
|
||||
linters:
|
||||
- gocognit
|
||||
- forbidigo
|
||||
- path: _test\.go
|
||||
linters:
|
||||
- gocognit
|
||||
|
||||
# Allow forbidden identifiers in CLI commands
|
||||
|
||||
83
addr_test.go
83
addr_test.go
@@ -4,10 +4,11 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMappedAddress(t *testing.T) {
|
||||
@@ -16,48 +17,32 @@ func TestMappedAddress(t *testing.T) {
|
||||
IP: net.ParseIP("122.12.34.5"),
|
||||
Port: 5412,
|
||||
}
|
||||
if addr.String() != "122.12.34.5:5412" {
|
||||
t.Error("bad string", addr)
|
||||
}
|
||||
assert.Equal(t, "122.12.34.5:5412", addr.String(), "bad string")
|
||||
t.Run("Bad length", func(t *testing.T) {
|
||||
badAddr := &MappedAddress{
|
||||
IP: net.IP{1, 2, 3},
|
||||
}
|
||||
if err := badAddr.AddTo(msg); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
assert.Error(t, badAddr.AddTo(msg), "should error")
|
||||
})
|
||||
t.Run("AddTo", func(t *testing.T) {
|
||||
if err := addr.AddTo(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, addr.AddTo(msg))
|
||||
t.Run("GetFrom", func(t *testing.T) {
|
||||
got := new(MappedAddress)
|
||||
if err := got.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !got.IP.Equal(addr.IP) {
|
||||
t.Error("got bad IP: ", got.IP)
|
||||
}
|
||||
assert.NoError(t, got.GetFrom(msg))
|
||||
assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP)
|
||||
t.Run("Not found", func(t *testing.T) {
|
||||
message := new(Message)
|
||||
if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Error("should be not found: ", err)
|
||||
}
|
||||
assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found")
|
||||
})
|
||||
t.Run("Bad family", func(t *testing.T) {
|
||||
v, _ := msg.Attributes.Get(AttrMappedAddress)
|
||||
v.Value[0] = 32
|
||||
if err := got.GetFrom(msg); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
assert.Error(t, got.GetFrom(msg), "should error")
|
||||
})
|
||||
t.Run("Bad length", func(t *testing.T) {
|
||||
message := new(Message)
|
||||
message.Add(AttrMappedAddress, []byte{1, 2, 3})
|
||||
if err := got.GetFrom(message); !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
t.Errorf("<%s> should be <%s>", err, io.ErrUnexpectedEOF)
|
||||
}
|
||||
assert.ErrorIs(t, got.GetFrom(message), io.ErrUnexpectedEOF)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -70,22 +55,14 @@ func TestMappedAddressV6(t *testing.T) { //nolint:dupl
|
||||
Port: 5412,
|
||||
}
|
||||
t.Run("AddTo", func(t *testing.T) {
|
||||
if err := addr.AddTo(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, addr.AddTo(m))
|
||||
t.Run("GetFrom", func(t *testing.T) {
|
||||
got := new(MappedAddress)
|
||||
if err := got.GetFrom(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !got.IP.Equal(addr.IP) {
|
||||
t.Error("got bad IP: ", got.IP)
|
||||
}
|
||||
assert.NoError(t, got.GetFrom(m))
|
||||
assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP)
|
||||
t.Run("Not found", func(t *testing.T) {
|
||||
message := new(Message)
|
||||
if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Error("should be not found: ", err)
|
||||
}
|
||||
assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found")
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -98,22 +75,14 @@ func TestAlternateServer(t *testing.T) { //nolint:dupl
|
||||
Port: 5412,
|
||||
}
|
||||
t.Run("AddTo", func(t *testing.T) {
|
||||
if err := addr.AddTo(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, addr.AddTo(m))
|
||||
t.Run("GetFrom", func(t *testing.T) {
|
||||
got := new(AlternateServer)
|
||||
if err := got.GetFrom(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !got.IP.Equal(addr.IP) {
|
||||
t.Error("got bad IP: ", got.IP)
|
||||
}
|
||||
assert.NoError(t, got.GetFrom(m))
|
||||
assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP)
|
||||
t.Run("Not found", func(t *testing.T) {
|
||||
message := new(Message)
|
||||
if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Error("should be not found: ", err)
|
||||
}
|
||||
assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found")
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -126,22 +95,14 @@ func TestOtherAddress(t *testing.T) { //nolint:dupl
|
||||
Port: 5412,
|
||||
}
|
||||
t.Run("AddTo", func(t *testing.T) {
|
||||
if err := addr.AddTo(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, addr.AddTo(m))
|
||||
t.Run("GetFrom", func(t *testing.T) {
|
||||
got := new(OtherAddress)
|
||||
if err := got.GetFrom(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !got.IP.Equal(addr.IP) {
|
||||
t.Error("got bad IP: ", got.IP)
|
||||
}
|
||||
assert.NoError(t, got.GetFrom(m))
|
||||
assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP)
|
||||
t.Run("Not found", func(t *testing.T) {
|
||||
message := new(Message)
|
||||
if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Error("should be not found: ", err)
|
||||
}
|
||||
assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found")
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
167
agent_test.go
167
agent_test.go
@@ -7,84 +7,44 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAgent_ProcessInTransaction(t *testing.T) {
|
||||
msg := New()
|
||||
agent := NewAgent(func(e Event) {
|
||||
if e.Error != nil {
|
||||
t.Errorf("got error: %s", e.Error)
|
||||
}
|
||||
if !e.Message.Equal(msg) {
|
||||
t.Errorf("%s (got) != %s (expected)", e.Message, msg)
|
||||
}
|
||||
assert.NoError(t, e.Error, "got error")
|
||||
assert.True(t, e.Message.Equal(msg), "%s (got) != %s (expected)", e.Message, msg)
|
||||
})
|
||||
if err := msg.NewTransactionID(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := agent.Start(msg.TransactionID, time.Time{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := agent.Process(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := agent.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, msg.NewTransactionID())
|
||||
assert.NoError(t, agent.Start(msg.TransactionID, time.Time{}))
|
||||
assert.NoError(t, agent.Process(msg))
|
||||
assert.NoError(t, agent.Close())
|
||||
}
|
||||
|
||||
func TestAgent_Process(t *testing.T) {
|
||||
msg := New()
|
||||
agent := NewAgent(func(e Event) {
|
||||
if e.Error != nil {
|
||||
t.Errorf("got error: %s", e.Error)
|
||||
}
|
||||
if !e.Message.Equal(msg) {
|
||||
t.Errorf("%s (got) != %s (expected)", e.Message, msg)
|
||||
}
|
||||
assert.NoError(t, e.Error, "got error")
|
||||
assert.True(t, e.Message.Equal(msg), "%s (got) != %s (expected)", e.Message, msg)
|
||||
})
|
||||
if err := msg.NewTransactionID(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := agent.Process(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := agent.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := agent.Process(msg); !errors.Is(err, ErrAgentClosed) {
|
||||
t.Errorf("closed agent should return <%s>, but got <%s>",
|
||||
ErrAgentClosed, err,
|
||||
)
|
||||
}
|
||||
assert.NoError(t, msg.NewTransactionID())
|
||||
assert.NoError(t, agent.Process(msg))
|
||||
assert.NoError(t, agent.Close())
|
||||
assert.ErrorIs(t, agent.Process(msg), ErrAgentClosed)
|
||||
}
|
||||
|
||||
func TestAgent_Start(t *testing.T) {
|
||||
agent := NewAgent(nil)
|
||||
id := NewTransactionID()
|
||||
deadline := time.Now().AddDate(0, 0, 1)
|
||||
if err := agent.Start(id, deadline); err != nil {
|
||||
t.Errorf("failed to statt transaction: %s", err)
|
||||
}
|
||||
if err := agent.Start(id, deadline); !errors.Is(err, ErrTransactionExists) {
|
||||
t.Errorf("duplicate start should return <%s>, got <%s>",
|
||||
ErrTransactionExists, err,
|
||||
)
|
||||
}
|
||||
if err := agent.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, agent.Start(id, deadline), "failed to start transaction")
|
||||
assert.ErrorIs(t, agent.Start(id, deadline), ErrTransactionExists)
|
||||
assert.NoError(t, agent.Close())
|
||||
id = NewTransactionID()
|
||||
if err := agent.Start(id, deadline); !errors.Is(err, ErrAgentClosed) {
|
||||
t.Errorf("start on closed agent should return <%s>, got <%s>",
|
||||
ErrAgentClosed, err,
|
||||
)
|
||||
}
|
||||
if err := agent.SetHandler(nil); !errors.Is(err, ErrAgentClosed) {
|
||||
t.Errorf("SetHandler on closed agent should return <%s>, got <%s>",
|
||||
ErrAgentClosed, err,
|
||||
)
|
||||
}
|
||||
assert.ErrorIs(t, agent.Start(id, deadline), ErrAgentClosed)
|
||||
assert.ErrorIs(t, agent.SetHandler(nil), ErrAgentClosed)
|
||||
}
|
||||
|
||||
func TestAgent_Stop(t *testing.T) {
|
||||
@@ -92,36 +52,20 @@ func TestAgent_Stop(t *testing.T) {
|
||||
agent := NewAgent(func(e Event) {
|
||||
called <- e
|
||||
})
|
||||
if err := agent.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) {
|
||||
t.Fatalf("unexpected error: %s, should be %s", err, ErrTransactionNotExists)
|
||||
}
|
||||
assert.ErrorIs(t, agent.Stop(transactionID{}), ErrTransactionNotExists)
|
||||
id := NewTransactionID()
|
||||
timeout := time.Millisecond * 200
|
||||
if err := agent.Start(id, time.Now().Add(timeout)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := agent.Stop(id); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, agent.Start(id, time.Now().Add(timeout)))
|
||||
assert.NoError(t, agent.Stop(id))
|
||||
select {
|
||||
case e := <-called:
|
||||
if !errors.Is(e.Error, ErrTransactionStopped) {
|
||||
t.Fatalf("unexpected error: %s, should be %s",
|
||||
e.Error, ErrTransactionStopped,
|
||||
)
|
||||
}
|
||||
assert.ErrorIs(t, e.Error, ErrTransactionStopped)
|
||||
case <-time.After(timeout * 2):
|
||||
t.Fatal("timed out")
|
||||
}
|
||||
if err := agent.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := agent.Close(); !errors.Is(err, ErrAgentClosed) {
|
||||
t.Fatalf("a.Close returned %s instead of %s", err, ErrAgentClosed)
|
||||
}
|
||||
if err := agent.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) {
|
||||
t.Fatalf("unexpected error: %s, should be %s", err, ErrAgentClosed)
|
||||
assert.Fail(t, "timed out")
|
||||
}
|
||||
assert.NoError(t, agent.Close())
|
||||
assert.ErrorIs(t, agent.Close(), ErrAgentClosed)
|
||||
assert.ErrorIs(t, agent.Stop(transactionID{}), ErrAgentClosed)
|
||||
}
|
||||
|
||||
func TestAgent_GC(t *testing.T) { //nolint:cyclop
|
||||
@@ -136,60 +80,41 @@ func TestAgent_GC(t *testing.T) { //nolint:cyclop
|
||||
agent.SetHandler(func(e Event) { //nolint:errcheck,gosec
|
||||
id := e.TransactionID
|
||||
shouldTimeOut, found := shouldTimeOutID[id]
|
||||
if !found {
|
||||
t.Error("unexpected transaction ID")
|
||||
}
|
||||
if shouldTimeOut && !errors.Is(e.Error, ErrTransactionTimeOut) {
|
||||
t.Errorf("%x should time out, but got %v", id, e.Error)
|
||||
}
|
||||
if !shouldTimeOut && errors.Is(e.Error, ErrTransactionTimeOut) {
|
||||
t.Errorf("%x should not time out, but got %v", id, e.Error)
|
||||
assert.True(t, found, "unexpected transaction ID")
|
||||
if shouldTimeOut {
|
||||
assert.ErrorIs(t, e.Error, ErrTransactionTimeOut, "%x should time out", id)
|
||||
} else {
|
||||
assert.False(t, errors.Is(e.Error, ErrTransactionTimeOut), "%x should not time out", id)
|
||||
}
|
||||
})
|
||||
for i := 0; i < 5; i++ {
|
||||
id := NewTransactionID()
|
||||
shouldTimeOutID[id] = false
|
||||
if err := agent.Start(id, deadline); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, agent.Start(id, deadline))
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
id := NewTransactionID()
|
||||
shouldTimeOutID[id] = true
|
||||
if err := agent.Start(id, deadlineNotGC); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if err := agent.Collect(gcDeadline); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := agent.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := agent.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) {
|
||||
t.Errorf("should <%s>, but got <%s>", ErrAgentClosed, err)
|
||||
assert.NoError(t, agent.Start(id, deadlineNotGC))
|
||||
}
|
||||
assert.NoError(t, agent.Collect(gcDeadline))
|
||||
assert.NoError(t, agent.Close())
|
||||
assert.ErrorIs(t, agent.Collect(gcDeadline), ErrAgentClosed)
|
||||
}
|
||||
|
||||
func BenchmarkAgent_GC(b *testing.B) {
|
||||
agent := NewAgent(nil)
|
||||
deadline := time.Now().AddDate(0, 0, 1)
|
||||
for i := 0; i < agentCollectCap; i++ {
|
||||
if err := agent.Start(NewTransactionID(), deadline); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
assert.NoError(b, agent.Start(NewTransactionID(), deadline))
|
||||
}
|
||||
defer func() {
|
||||
if err := agent.Close(); err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
assert.NoError(b, agent.Close())
|
||||
}()
|
||||
b.ReportAllocs()
|
||||
gcDeadline := deadline.Add(-time.Second)
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := agent.Collect(gcDeadline); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
assert.NoError(b, agent.Collect(gcDeadline))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,20 +122,14 @@ func BenchmarkAgent_Process(b *testing.B) {
|
||||
agent := NewAgent(nil)
|
||||
deadline := time.Now().AddDate(0, 0, 1)
|
||||
for i := 0; i < 1000; i++ {
|
||||
if err := agent.Start(NewTransactionID(), deadline); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
assert.NoError(b, agent.Start(NewTransactionID(), deadline))
|
||||
}
|
||||
defer func() {
|
||||
if err := agent.Close(); err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
assert.NoError(b, agent.Close())
|
||||
}()
|
||||
b.ReportAllocs()
|
||||
m := MustBuild(TransactionID)
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := agent.Process(m); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
assert.NoError(b, agent.Process(m))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,11 @@
|
||||
|
||||
package stun
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAttrOverflowErr_Error(t *testing.T) {
|
||||
err := AttrOverflowErr{
|
||||
@@ -14,9 +18,7 @@ func TestAttrOverflowErr_Error(t *testing.T) {
|
||||
Max: 50,
|
||||
Type: AttrLifetime,
|
||||
}
|
||||
if err.Error() != "incorrect length of LIFETIME attribute: 100 exceeds maximum 50" {
|
||||
t.Error("bad error string", err)
|
||||
}
|
||||
assert.Equal(t, "incorrect length of LIFETIME attribute: 100 exceeds maximum 50", err.Error())
|
||||
}
|
||||
|
||||
func TestAttrLengthErr_Error(t *testing.T) {
|
||||
@@ -25,7 +27,5 @@ func TestAttrLengthErr_Error(t *testing.T) {
|
||||
Expected: 15,
|
||||
Got: 99,
|
||||
}
|
||||
if err.Error() != "incorrect length of ERROR-CODE attribute: got 99, expected 15" {
|
||||
t.Errorf("bad error string: %s", err)
|
||||
}
|
||||
assert.Equal(t, "incorrect length of ERROR-CODE attribute: got 99, expected 15", err.Error())
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ package stun
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func BenchmarkMessage_GetNotFound(b *testing.B) {
|
||||
@@ -31,16 +33,10 @@ func TestRawAttribute_AddTo(t *testing.T) {
|
||||
Type: AttrData,
|
||||
Value: v,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
gotV, gotErr := m.Get(AttrData)
|
||||
if gotErr != nil {
|
||||
t.Fatal(gotErr)
|
||||
}
|
||||
if !bytes.Equal(gotV, v) {
|
||||
t.Error("value mismatch")
|
||||
}
|
||||
assert.NoError(t, gotErr)
|
||||
assert.True(t, bytes.Equal(gotV, v), "value mismatch")
|
||||
}
|
||||
|
||||
func TestMessage_GetNoAllocs(t *testing.T) {
|
||||
@@ -52,17 +48,13 @@ func TestMessage_GetNoAllocs(t *testing.T) {
|
||||
allocs := testing.AllocsPerRun(10, func() {
|
||||
msg.Get(AttrSoftware) //nolint:errcheck,gosec
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Error("allocated memory, but should not")
|
||||
}
|
||||
assert.Zero(t, allocs, "allocated memory, but should not")
|
||||
})
|
||||
t.Run("Not found", func(t *testing.T) {
|
||||
allocs := testing.AllocsPerRun(10, func() {
|
||||
msg.Get(AttrOrigin) //nolint:errcheck,gosec
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Error("allocated memory, but should not")
|
||||
}
|
||||
assert.Zero(t, allocs, "allocated memory, but should not")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -83,11 +75,8 @@ func TestPadding(t *testing.T) {
|
||||
{40, 40}, // 10
|
||||
}
|
||||
for i, c := range tt {
|
||||
if got := nearestPaddedValueLength(c.in); got != c.out {
|
||||
t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)",
|
||||
i, c.in, got, c.out,
|
||||
)
|
||||
}
|
||||
got := nearestPaddedValueLength(c.in)
|
||||
assert.Equal(t, c.out, got, "[%d]: padd(%d)", i, c.in)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,9 +91,8 @@ func TestAttrTypeRange(t *testing.T) {
|
||||
a := a
|
||||
t.Run(a.String(), func(t *testing.T) {
|
||||
a := a
|
||||
if a.Optional() || !a.Required() {
|
||||
t.Error("should be required")
|
||||
}
|
||||
assert.True(t, a.Required(), "should be required")
|
||||
assert.False(t, a.Optional(), "should be required")
|
||||
})
|
||||
}
|
||||
for _, a := range []AttrType{
|
||||
@@ -114,9 +102,8 @@ func TestAttrTypeRange(t *testing.T) {
|
||||
} {
|
||||
a := a
|
||||
t.Run(a.String(), func(t *testing.T) {
|
||||
if a.Required() || !a.Optional() {
|
||||
t.Error("should be optional")
|
||||
}
|
||||
assert.False(t, a.Required(), "should be optional")
|
||||
assert.True(t, a.Optional(), "should be optional")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
544
client_test.go
544
client_test.go
@@ -18,6 +18,8 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -88,13 +90,9 @@ func BenchmarkClient_Do(b *testing.B) {
|
||||
client, err := NewClient(noopConnection{},
|
||||
WithAgent(agent),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
assert.NoError(b, err)
|
||||
defer func() {
|
||||
if closeErr := client.Close(); closeErr != nil {
|
||||
panic(closeErr)
|
||||
}
|
||||
assert.NoError(b, client.Close())
|
||||
}()
|
||||
|
||||
noopF := func(Event) {
|
||||
@@ -163,9 +161,8 @@ func TestClosedOrPanic(t *testing.T) {
|
||||
func() {
|
||||
defer func() {
|
||||
r, ok := recover().(error)
|
||||
if !ok || !errors.Is(r, io.EOF) {
|
||||
t.Error(r)
|
||||
}
|
||||
assert.True(t, ok, "should be error")
|
||||
assert.ErrorIs(t, r, io.EOF)
|
||||
}()
|
||||
closedOrPanic(io.EOF)
|
||||
}()
|
||||
@@ -203,46 +200,32 @@ func TestClient_Start(t *testing.T) { //nolint:cyclop
|
||||
},
|
||||
}
|
||||
client, err := NewClient(conn)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := client.Close(); err == nil {
|
||||
t.Error("second close should fail")
|
||||
}
|
||||
if err := client.Do(MustBuild(TransactionID), nil); err == nil {
|
||||
t.Error("Do after Close should fail")
|
||||
}
|
||||
assert.NoError(t, client.Close())
|
||||
assert.Error(t, client.Close(), "second close should fail")
|
||||
assert.Error(t, client.Do(MustBuild(TransactionID), nil), "Do after Close should fail")
|
||||
}()
|
||||
msg := MustBuild(response, BindingRequest)
|
||||
t.Log("init")
|
||||
got := make(chan struct{})
|
||||
write <- struct{}{}
|
||||
t.Log("starting the first transaction")
|
||||
if err := client.Start(msg, func(event Event) {
|
||||
assert.NoError(t, client.Start(msg, func(event Event) {
|
||||
t.Log("got first transaction callback")
|
||||
if event.Error != nil {
|
||||
t.Error(event.Error)
|
||||
}
|
||||
assert.NoError(t, event.Error)
|
||||
got <- struct{}{}
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}))
|
||||
t.Log("starting the second transaction")
|
||||
if err := client.Start(msg, func(Event) {
|
||||
t.Error("should not be called")
|
||||
}); !errors.Is(err, ErrTransactionExists) {
|
||||
t.Errorf("unexpected error %v", err)
|
||||
}
|
||||
assert.ErrorIs(t, client.Start(msg, func(Event) {
|
||||
assert.Fail(t, "should not be called")
|
||||
}), ErrTransactionExists)
|
||||
read <- struct{}{}
|
||||
select {
|
||||
case <-got:
|
||||
// pass
|
||||
case <-time.After(time.Millisecond * 10):
|
||||
t.Error("timed out")
|
||||
assert.Fail(t, "timed out")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,34 +239,20 @@ func TestClient_Do(t *testing.T) {
|
||||
},
|
||||
}
|
||||
client, err := NewClient(conn)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := client.Close(); err == nil {
|
||||
t.Error("second close should fail")
|
||||
}
|
||||
if err := client.Do(MustBuild(TransactionID), nil); err == nil {
|
||||
t.Error("Do after Close should fail")
|
||||
}
|
||||
assert.NoError(t, client.Close())
|
||||
assert.Error(t, client.Close(), "second close should fail")
|
||||
assert.Error(t, client.Do(MustBuild(TransactionID), nil), "Do after Close should fail")
|
||||
}()
|
||||
m := MustBuild(
|
||||
NewTransactionIDSetter(response.TransactionID),
|
||||
)
|
||||
if err := client.Do(m, func(event Event) {
|
||||
if event.Error != nil {
|
||||
t.Error(event.Error)
|
||||
}
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, client.Do(m, func(event Event) {
|
||||
assert.NoError(t, event.Error)
|
||||
}))
|
||||
m = MustBuild(TransactionID)
|
||||
if err := client.Do(m, nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, client.Do(m, nil))
|
||||
}
|
||||
|
||||
func TestCloseErr_Error(t *testing.T) {
|
||||
@@ -299,11 +268,7 @@ func TestCloseErr_Error(t *testing.T) {
|
||||
ConnectionErr: io.ErrUnexpectedEOF,
|
||||
}, "failed to close: unexpected EOF (connection), <nil> (agent)"},
|
||||
} {
|
||||
if out := testCase.Err.Error(); out != testCase.Out {
|
||||
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
|
||||
id, testCase.Err, out, testCase.Out,
|
||||
)
|
||||
}
|
||||
assert.Equal(t, testCase.Out, testCase.Err.Error(), "[%d]: Error(%#v)", id, testCase.Err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,11 +285,7 @@ func TestStopErr_Error(t *testing.T) {
|
||||
Cause: io.ErrUnexpectedEOF,
|
||||
}, "error while stopping due to unexpected EOF: <nil>"},
|
||||
} {
|
||||
if out := testcase.Err.Error(); out != testcase.Out {
|
||||
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
|
||||
id, testcase.Err, out, testcase.Out,
|
||||
)
|
||||
}
|
||||
assert.Equal(t, testcase.Out, testcase.Err.Error(), "[%d]: Error(%#v)", id, testcase.Err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -365,25 +326,15 @@ func TestClientAgentError(t *testing.T) {
|
||||
startErr: io.ErrUnexpectedEOF,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, client.Close())
|
||||
}()
|
||||
m := MustBuild(NewTransactionIDSetter(response.TransactionID))
|
||||
if err := client.Do(m, nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := client.Do(m, func(event Event) {
|
||||
if event.Error == nil {
|
||||
t.Error("error expected")
|
||||
}
|
||||
}); !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
t.Error("error expected")
|
||||
}
|
||||
assert.NoError(t, client.Do(m, nil))
|
||||
assert.ErrorIs(t, client.Do(m, func(event Event) {
|
||||
assert.Error(t, event.Error, "error expected")
|
||||
}), io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
func TestClientConnErr(t *testing.T) {
|
||||
@@ -393,21 +344,13 @@ func TestClientConnErr(t *testing.T) {
|
||||
},
|
||||
}
|
||||
client, err := NewClient(conn)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, client.Close())
|
||||
}()
|
||||
m := MustBuild(TransactionID)
|
||||
if err := client.Do(m, nil); err == nil {
|
||||
t.Error("error expected")
|
||||
}
|
||||
if err := client.Do(m, NoopHandler()); err == nil {
|
||||
t.Error("error expected")
|
||||
}
|
||||
assert.Error(t, client.Do(m, nil), "error expected")
|
||||
assert.Error(t, client.Do(m, NoopHandler()), "error expected")
|
||||
}
|
||||
|
||||
func TestClientConnErrStopErr(t *testing.T) {
|
||||
@@ -421,26 +364,19 @@ func TestClientConnErrStopErr(t *testing.T) {
|
||||
stopErr: io.ErrUnexpectedEOF,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, client.Close())
|
||||
}()
|
||||
m := MustBuild(TransactionID)
|
||||
if err := client.Do(m, NoopHandler()); err == nil {
|
||||
t.Error("error expected")
|
||||
}
|
||||
assert.Error(t, client.Do(m, NoopHandler()), "error expected")
|
||||
}
|
||||
|
||||
func TestCallbackWaitHandler_setCallback(t *testing.T) {
|
||||
c := callbackWaitHandler{}
|
||||
defer func() {
|
||||
if err := recover(); err == nil {
|
||||
t.Error("should panic")
|
||||
}
|
||||
err := recover()
|
||||
assert.NotNil(t, err, "should panic")
|
||||
}()
|
||||
c.setCallback(nil)
|
||||
}
|
||||
@@ -450,56 +386,39 @@ func TestCallbackWaitHandler_HandleEvent(t *testing.T) {
|
||||
cond: sync.NewCond(new(sync.Mutex)),
|
||||
}
|
||||
defer func() {
|
||||
if err := recover(); err == nil {
|
||||
t.Error("should panic")
|
||||
}
|
||||
err := recover()
|
||||
assert.NotNil(t, err, "should panic")
|
||||
}()
|
||||
c.HandleEvent(Event{})
|
||||
}
|
||||
|
||||
func TestNewClientNoConnection(t *testing.T) {
|
||||
c, err := NewClient(nil)
|
||||
if c != nil {
|
||||
t.Error("c should be nil")
|
||||
}
|
||||
if !errors.Is(err, ErrNoConnection) {
|
||||
t.Error("bad error")
|
||||
}
|
||||
assert.Nil(t, c, "c should be nil")
|
||||
assert.ErrorIs(t, err, ErrNoConnection, "bad error")
|
||||
}
|
||||
|
||||
func TestDial(t *testing.T) {
|
||||
c, err := Dial("udp4", "localhost:3458")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
if err = c.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, c.Close())
|
||||
}()
|
||||
}
|
||||
|
||||
func TestDialURI(t *testing.T) {
|
||||
u, err := ParseURI("stun:localhost")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
c, err := DialURI(u, &DialConfig{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
if err = c.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, c.Close())
|
||||
}()
|
||||
}
|
||||
|
||||
func TestDialError(t *testing.T) {
|
||||
_, err := Dial("bad?network", "?????")
|
||||
if err == nil {
|
||||
t.Fatal("error expected")
|
||||
}
|
||||
assert.Error(t, err, "error expected")
|
||||
}
|
||||
|
||||
func TestClientCloseErr(t *testing.T) {
|
||||
@@ -516,13 +435,11 @@ func TestClientCloseErr(t *testing.T) {
|
||||
closeErr: io.ErrUnexpectedEOF,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
if err, ok := c.Close().(CloseErr); !ok || !errors.Is(err.AgentErr, io.ErrUnexpectedEOF) { //nolint:errorlint
|
||||
t.Error("unexpected close err")
|
||||
}
|
||||
err, ok := c.Close().(CloseErr) //nolint:errorlint
|
||||
assert.True(t, ok, "should be CloseErr")
|
||||
assert.ErrorIs(t, err.AgentErr, io.ErrUnexpectedEOF, "unexpected close err")
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -542,12 +459,8 @@ func TestWithNoConnClose(t *testing.T) {
|
||||
}),
|
||||
WithNoConnClose(),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := c.Close(); err != nil {
|
||||
t.Error("unexpected non-nil error")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, c.Close(), "unexpected non-nil error")
|
||||
}
|
||||
|
||||
type gcWaitAgent struct {
|
||||
@@ -598,28 +511,20 @@ func TestClientGC(t *testing.T) {
|
||||
WithAgent(agent),
|
||||
WithTimeoutRate(time.Millisecond),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
if err = c.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, c.Close())
|
||||
}()
|
||||
select {
|
||||
case <-agent.gc:
|
||||
case <-time.After(time.Millisecond * 200):
|
||||
t.Error("timed out")
|
||||
assert.Fail(t, "timed out")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientCheckInit(t *testing.T) {
|
||||
if err := (&Client{}).Indicate(nil); !errors.Is(err, ErrClientNotInitialized) {
|
||||
t.Error("unexpected error")
|
||||
}
|
||||
if err := (&Client{}).Do(nil, nil); !errors.Is(err, ErrClientNotInitialized) {
|
||||
t.Error("unexpected error")
|
||||
}
|
||||
assert.ErrorIs(t, (&Client{}).Indicate(nil), ErrClientNotInitialized)
|
||||
assert.ErrorIs(t, (&Client{}).Do(nil, nil), ErrClientNotInitialized)
|
||||
}
|
||||
|
||||
func captureLog() (*bytes.Buffer, func()) {
|
||||
@@ -645,9 +550,7 @@ func TestClientFinalizer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
client, err := NewClient(conn)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
clientFinalizer(client)
|
||||
clientFinalizer(client)
|
||||
response := MustBuild(TransactionID, BindingSuccess)
|
||||
@@ -663,9 +566,7 @@ func TestClientFinalizer(t *testing.T) {
|
||||
closeErr: io.ErrUnexpectedEOF,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
clientFinalizer(client)
|
||||
reader := bufio.NewScanner(buf)
|
||||
var lines int
|
||||
@@ -676,17 +577,11 @@ func TestClientFinalizer(t *testing.T) {
|
||||
"<nil> (connection), unexpected EOF (agent)",
|
||||
}
|
||||
for reader.Scan() {
|
||||
if reader.Text() != expectedLines[lines] {
|
||||
t.Error(reader.Text(), "!=", expectedLines[lines])
|
||||
}
|
||||
assert.Equal(t, expectedLines[lines], reader.Text())
|
||||
lines++
|
||||
}
|
||||
if reader.Err() != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if lines != 3 {
|
||||
t.Error("incorrect count of log lines:", lines)
|
||||
}
|
||||
assert.NoError(t, reader.Err())
|
||||
assert.Equal(t, 3, lines, "incorrect count of log lines")
|
||||
}
|
||||
|
||||
func TestCallbackWaitHandler(*testing.T) {
|
||||
@@ -784,9 +679,7 @@ func TestClientRetransmission(t *testing.T) {
|
||||
response.Encode()
|
||||
connL, connR := net.Pipe()
|
||||
defer func() {
|
||||
if closeErr := connL.Close(); closeErr != nil {
|
||||
panic(closeErr)
|
||||
}
|
||||
assert.NoError(t, connR.Close())
|
||||
}()
|
||||
collector := new(manualCollector)
|
||||
clock := &manualClock{current: time.Now()}
|
||||
@@ -814,36 +707,22 @@ func TestClientRetransmission(t *testing.T) {
|
||||
WithCollector(collector),
|
||||
WithRTO(time.Millisecond),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
client.SetRTO(time.Second)
|
||||
gotReads := make(chan struct{})
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
readN, readErr := connL.Read(buf)
|
||||
if readErr != nil {
|
||||
t.Error(readErr)
|
||||
}
|
||||
if !IsMessage(buf[:readN]) {
|
||||
t.Error("should be STUN")
|
||||
}
|
||||
assert.NoError(t, readErr)
|
||||
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
|
||||
readN, readErr = connL.Read(buf)
|
||||
if readErr != nil {
|
||||
t.Error(readErr)
|
||||
}
|
||||
if !IsMessage(buf[:readN]) {
|
||||
t.Error("should be STUN")
|
||||
}
|
||||
assert.NoError(t, readErr)
|
||||
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
|
||||
gotReads <- struct{}{}
|
||||
}()
|
||||
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if event.Error != nil {
|
||||
t.Error("failed")
|
||||
}
|
||||
}); doErr != nil {
|
||||
t.Fatal(doErr)
|
||||
}
|
||||
assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
assert.NoError(t, event.Error, "failed")
|
||||
}))
|
||||
<-gotReads
|
||||
}
|
||||
|
||||
@@ -854,9 +733,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
|
||||
response.Encode()
|
||||
connL, connR := net.Pipe()
|
||||
defer func() {
|
||||
if closeErr := connL.Close(); closeErr != nil {
|
||||
panic(closeErr)
|
||||
}
|
||||
assert.NoError(t, connR.Close())
|
||||
}()
|
||||
collector := new(manualCollector)
|
||||
clock := &manualClock{current: time.Now()}
|
||||
@@ -874,9 +751,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
|
||||
WithClock(clock),
|
||||
WithCollector(collector),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
client.SetRTO(time.Second)
|
||||
conns := new(sync.WaitGroup)
|
||||
wg := new(sync.WaitGroup)
|
||||
@@ -891,29 +766,21 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
break
|
||||
}
|
||||
t.Error(readErr)
|
||||
}
|
||||
if !IsMessage(buf[:readN]) {
|
||||
t.Error("should be STUN")
|
||||
assert.NoError(t, readErr)
|
||||
}
|
||||
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
|
||||
}
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if doErr := client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) {
|
||||
if event.Error != nil {
|
||||
t.Error("failed")
|
||||
}
|
||||
}); doErr != nil {
|
||||
t.Error(doErr)
|
||||
}
|
||||
assert.NoError(t, client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) {
|
||||
assert.NoError(t, event.Error, "failed")
|
||||
}))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if connErr := connR.Close(); connErr != nil {
|
||||
t.Error(connErr)
|
||||
}
|
||||
assert.NoError(t, connR.Close())
|
||||
conns.Wait()
|
||||
}
|
||||
|
||||
@@ -942,24 +809,22 @@ func (c errorCollector) Close() error { return c.closeError }
|
||||
func TestNewClient(t *testing.T) {
|
||||
t.Run("SetCallbackError", func(t *testing.T) {
|
||||
setHandlerError := errClientSetHandler
|
||||
if _, createErr := NewClient(noopConnection{},
|
||||
_, createErr := NewClient(noopConnection{},
|
||||
WithAgent(&errorAgent{
|
||||
setHandlerError: setHandlerError,
|
||||
}),
|
||||
); !errors.Is(createErr, setHandlerError) {
|
||||
t.Errorf("unexpected error returned: %v", createErr)
|
||||
}
|
||||
)
|
||||
assert.ErrorIs(t, createErr, setHandlerError, "unexpected error returned")
|
||||
})
|
||||
t.Run("CollectorStartError", func(t *testing.T) {
|
||||
startError := errClientStart
|
||||
if _, createErr := NewClient(noopConnection{},
|
||||
_, createErr := NewClient(noopConnection{},
|
||||
WithAgent(&TestAgent{}),
|
||||
WithCollector(errorCollector{
|
||||
startError: startError,
|
||||
}),
|
||||
); !errors.Is(createErr, startError) {
|
||||
t.Errorf("unexpected error returned: %v", createErr)
|
||||
}
|
||||
)
|
||||
assert.ErrorIs(t, createErr, startError, "unexpected error returned")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -972,13 +837,9 @@ func TestClient_Close(t *testing.T) {
|
||||
}),
|
||||
WithAgent(&TestAgent{}),
|
||||
)
|
||||
if createErr != nil {
|
||||
t.Errorf("unexpected create error returned: %v", createErr)
|
||||
}
|
||||
assert.NoError(t, createErr, "unexpected create error returned")
|
||||
gotCloseErr := c.Close()
|
||||
if !errors.Is(gotCloseErr, closeErr) {
|
||||
t.Errorf("unexpected close error returned: %v", gotCloseErr)
|
||||
}
|
||||
assert.ErrorIs(t, gotCloseErr, closeErr, "unexpected close error returned")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -992,19 +853,13 @@ func TestClientDefaultHandler(t *testing.T) {
|
||||
client, createErr := NewClient(noopConnection{},
|
||||
WithAgent(agent),
|
||||
WithHandler(func(e Event) {
|
||||
if called {
|
||||
t.Error("should not be called twice")
|
||||
}
|
||||
assert.False(t, called, "should not be called twice")
|
||||
called = true
|
||||
if e.TransactionID != id {
|
||||
t.Error("wrong transaction ID")
|
||||
}
|
||||
assert.Equal(t, id, e.TransactionID, "wrong transaction ID")
|
||||
handlerCalled <- struct{}{}
|
||||
}),
|
||||
)
|
||||
if createErr != nil {
|
||||
t.Fatal(createErr)
|
||||
}
|
||||
assert.NoError(t, createErr)
|
||||
go func() {
|
||||
agent.h(Event{
|
||||
TransactionID: id,
|
||||
@@ -1014,11 +869,9 @@ func TestClientDefaultHandler(t *testing.T) {
|
||||
case <-handlerCalled:
|
||||
// pass
|
||||
case <-time.After(time.Millisecond * 100):
|
||||
t.Fatal("timed out")
|
||||
}
|
||||
if closeErr := client.Close(); closeErr != nil {
|
||||
t.Error(closeErr)
|
||||
assert.Fail(t, "timed out")
|
||||
}
|
||||
assert.NoError(t, client.Close())
|
||||
// Handler call should be ignored.
|
||||
agent.h(Event{})
|
||||
}
|
||||
@@ -1030,15 +883,9 @@ func TestClientClosedStart(t *testing.T) {
|
||||
c, createErr := NewClient(noopConnection{},
|
||||
WithAgent(a),
|
||||
)
|
||||
if createErr != nil {
|
||||
t.Fatal(createErr)
|
||||
}
|
||||
if closeErr := c.Close(); closeErr != nil {
|
||||
t.Error(closeErr)
|
||||
}
|
||||
if startErr := c.start(&clientTransaction{}); !errors.Is(startErr, ErrClientClosed) {
|
||||
t.Error("should error")
|
||||
}
|
||||
assert.NoError(t, createErr)
|
||||
assert.NoError(t, c.Close())
|
||||
assert.ErrorIs(t, c.start(&clientTransaction{}), ErrClientClosed)
|
||||
}
|
||||
|
||||
func TestWithNoRetransmit(t *testing.T) {
|
||||
@@ -1046,9 +893,7 @@ func TestWithNoRetransmit(t *testing.T) {
|
||||
response.Encode()
|
||||
connL, connR := net.Pipe()
|
||||
defer func() {
|
||||
if closeErr := connL.Close(); closeErr != nil {
|
||||
panic(closeErr)
|
||||
}
|
||||
assert.NoError(t, connL.Close())
|
||||
}()
|
||||
collector := new(manualCollector)
|
||||
clock := &manualClock{current: time.Now()}
|
||||
@@ -1062,7 +907,7 @@ func TestWithNoRetransmit(t *testing.T) {
|
||||
Error: ErrTransactionTimeOut,
|
||||
})
|
||||
} else {
|
||||
t.Error("there should be no second attempt")
|
||||
assert.Fail(t, "there should be no second attempt")
|
||||
go agent.h(Event{
|
||||
TransactionID: id,
|
||||
Error: ErrTransactionTimeOut,
|
||||
@@ -1078,28 +923,18 @@ func TestWithNoRetransmit(t *testing.T) {
|
||||
WithRTO(0),
|
||||
WithNoRetransmit,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
gotReads := make(chan struct{})
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
readN, readErr := connL.Read(buf)
|
||||
if readErr != nil {
|
||||
t.Error(readErr)
|
||||
}
|
||||
if !IsMessage(buf[:readN]) {
|
||||
t.Error("should be STUN")
|
||||
}
|
||||
assert.NoError(t, readErr)
|
||||
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
|
||||
gotReads <- struct{}{}
|
||||
}()
|
||||
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if !errors.Is(event.Error, ErrTransactionTimeOut) {
|
||||
t.Error("unexpected error")
|
||||
}
|
||||
}); doErr != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
assert.ErrorIs(t, event.Error, ErrTransactionTimeOut, "unexpected error")
|
||||
}))
|
||||
<-gotReads
|
||||
}
|
||||
|
||||
@@ -1114,9 +949,7 @@ func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
|
||||
response.Encode()
|
||||
connL, connR := net.Pipe()
|
||||
defer func() {
|
||||
if closeErr := connL.Close(); closeErr != nil {
|
||||
panic(closeErr)
|
||||
}
|
||||
assert.NoError(t, connL.Close())
|
||||
}()
|
||||
collector := new(manualCollector)
|
||||
shouldWait := false
|
||||
@@ -1169,9 +1002,7 @@ func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
|
||||
t.Log("clock locked")
|
||||
<-clockLocked
|
||||
t.Log("closing client")
|
||||
if closeErr := client.Close(); closeErr != nil {
|
||||
t.Error(closeErr)
|
||||
}
|
||||
assert.NoError(t, client.Close())
|
||||
t.Log("client closed, unlocking clock")
|
||||
clockWait <- struct{}{}
|
||||
t.Log("clock unlocked")
|
||||
@@ -1186,44 +1017,30 @@ func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
|
||||
WithCollector(collector),
|
||||
WithRTO(time.Millisecond),
|
||||
)
|
||||
if startClientErr != nil {
|
||||
t.Fatal(startClientErr)
|
||||
}
|
||||
assert.NoError(t, startClientErr)
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
readN, readErr := connL.Read(buf)
|
||||
if readErr != nil {
|
||||
t.Error(readErr)
|
||||
}
|
||||
if !IsMessage(buf[:readN]) {
|
||||
t.Error("should be STUN")
|
||||
}
|
||||
assert.NoError(t, readErr)
|
||||
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
|
||||
readN, readErr = connL.Read(buf)
|
||||
if readErr != nil {
|
||||
t.Error(readErr)
|
||||
}
|
||||
if !IsMessage(buf[:readN]) {
|
||||
t.Error("should be STUN")
|
||||
}
|
||||
assert.NoError(t, readErr)
|
||||
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
|
||||
gotReads <- struct{}{}
|
||||
}()
|
||||
t.Log("starting")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if !errors.Is(event.Error, ErrClientClosed) {
|
||||
t.Error(event.Error)
|
||||
}
|
||||
}); doErr != nil {
|
||||
t.Error(doErr)
|
||||
}
|
||||
assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
assert.ErrorIs(t, event.Error, ErrClientClosed)
|
||||
}))
|
||||
done <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
// ok
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Error("timeout")
|
||||
assert.Fail(t, "timeout")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1232,9 +1049,7 @@ func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
|
||||
response.Encode()
|
||||
connL, connR := net.Pipe()
|
||||
defer func() {
|
||||
if closeErr := connL.Close(); closeErr != nil {
|
||||
panic(closeErr)
|
||||
}
|
||||
assert.NoError(t, connL.Close())
|
||||
}()
|
||||
collector := new(manualCollector)
|
||||
shouldWait := false
|
||||
@@ -1291,9 +1106,7 @@ func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
|
||||
t.Log("clock locked")
|
||||
<-clockLocked
|
||||
t.Log("closing connection")
|
||||
if closeErr := connL.Close(); closeErr != nil {
|
||||
panic(closeErr)
|
||||
}
|
||||
assert.NoError(t, connL.Close())
|
||||
t.Log("connection closed, unlocking clock")
|
||||
clockWait <- struct{}{}
|
||||
t.Log("clock unlocked")
|
||||
@@ -1308,52 +1121,33 @@ func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
|
||||
WithCollector(collector),
|
||||
WithRTO(time.Millisecond),
|
||||
)
|
||||
if startClientErr != nil {
|
||||
t.Fatal(startClientErr)
|
||||
}
|
||||
assert.NoError(t, startClientErr)
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
readN, readErr := connL.Read(buf)
|
||||
if readErr != nil {
|
||||
t.Error(readErr)
|
||||
}
|
||||
if !IsMessage(buf[:readN]) {
|
||||
t.Error("should be STUN")
|
||||
}
|
||||
assert.NoError(t, readErr)
|
||||
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
|
||||
readN, readErr = connL.Read(buf)
|
||||
if readErr != nil {
|
||||
t.Error(readErr)
|
||||
}
|
||||
if !IsMessage(buf[:readN]) {
|
||||
t.Error("should be STUN")
|
||||
}
|
||||
assert.NoError(t, readErr)
|
||||
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
|
||||
gotReads <- struct{}{}
|
||||
}()
|
||||
t.Log("starting")
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
var e StopErr
|
||||
if !errors.As(event.Error, &e) {
|
||||
t.Error(event.Error)
|
||||
} else {
|
||||
if !errors.Is(e.Err, agentStopErr) {
|
||||
t.Error("incorrect agent error")
|
||||
}
|
||||
if !errors.Is(e.Cause, io.ErrClosedPipe) {
|
||||
t.Error("incorrect connection error")
|
||||
}
|
||||
}
|
||||
}); doErr != nil {
|
||||
t.Error(doErr)
|
||||
}
|
||||
assert.ErrorAs(t, event.Error, &e)
|
||||
assert.ErrorIs(t, e.Err, agentStopErr, "incorrect agent error")
|
||||
assert.ErrorIs(t, e.Cause, io.ErrClosedPipe, "incorrect connection error")
|
||||
}))
|
||||
done <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
// ok
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Error("timeout")
|
||||
assert.Fail(t, "timeout")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1362,9 +1156,7 @@ func TestClientRTOAgentErr(t *testing.T) {
|
||||
response.Encode()
|
||||
connL, connR := net.Pipe()
|
||||
defer func() {
|
||||
if closeErr := connL.Close(); closeErr != nil {
|
||||
panic(closeErr)
|
||||
}
|
||||
assert.NoError(t, connL.Close())
|
||||
}()
|
||||
collector := new(manualCollector)
|
||||
clock := callbackClock(time.Now)
|
||||
@@ -1396,33 +1188,23 @@ func TestClientRTOAgentErr(t *testing.T) {
|
||||
WithCollector(collector),
|
||||
WithRTO(time.Millisecond),
|
||||
)
|
||||
if startClientErr != nil {
|
||||
t.Fatal(startClientErr)
|
||||
}
|
||||
assert.NoError(t, startClientErr)
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
readN, readErr := connL.Read(buf)
|
||||
if readErr != nil {
|
||||
t.Error(readErr)
|
||||
}
|
||||
if !IsMessage(buf[:readN]) {
|
||||
t.Error("should be STUN")
|
||||
}
|
||||
assert.NoError(t, readErr)
|
||||
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
|
||||
gotReads <- struct{}{}
|
||||
}()
|
||||
t.Log("starting")
|
||||
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
if !errors.Is(event.Error, agentStartErr) {
|
||||
t.Error(event.Error)
|
||||
}
|
||||
}); doErr != nil {
|
||||
t.Error(doErr)
|
||||
}
|
||||
assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
|
||||
assert.ErrorIs(t, event.Error, agentStartErr)
|
||||
}))
|
||||
select {
|
||||
case <-gotReads:
|
||||
// ok
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Error("reads timeout")
|
||||
assert.Fail(t, "reads timeout")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1431,9 +1213,7 @@ func TestClient_HandleProcessError(t *testing.T) {
|
||||
response.Encode()
|
||||
connL, connR := net.Pipe()
|
||||
defer func() {
|
||||
if closeErr := connL.Close(); closeErr != nil {
|
||||
panic(closeErr)
|
||||
}
|
||||
assert.NoError(t, connL.Close())
|
||||
}()
|
||||
collector := new(manualCollector)
|
||||
clock := callbackClock(time.Now)
|
||||
@@ -1451,14 +1231,10 @@ func TestClient_HandleProcessError(t *testing.T) {
|
||||
WithCollector(collector),
|
||||
WithRTO(time.Millisecond),
|
||||
)
|
||||
if startClientErr != nil {
|
||||
t.Fatal(startClientErr)
|
||||
}
|
||||
assert.NoError(t, startClientErr)
|
||||
go func() {
|
||||
_, readErr := connL.Write(response.Raw)
|
||||
if readErr != nil {
|
||||
t.Error(readErr)
|
||||
}
|
||||
assert.NoError(t, readErr)
|
||||
gotWrites <- struct{}{}
|
||||
}()
|
||||
t.Log("starting")
|
||||
@@ -1466,20 +1242,16 @@ func TestClient_HandleProcessError(t *testing.T) {
|
||||
case <-gotWrites:
|
||||
// ok
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Error("reads timeout")
|
||||
}
|
||||
if closeErr := client.Close(); closeErr != nil {
|
||||
t.Error(closeErr)
|
||||
assert.Fail(t, "reads timeout")
|
||||
}
|
||||
assert.NoError(t, client.Close())
|
||||
}
|
||||
|
||||
func TestClientImmediateTimeout(t *testing.T) {
|
||||
response := MustBuild(TransactionID, BindingSuccess)
|
||||
connL, connR := net.Pipe()
|
||||
defer func() {
|
||||
if closeErr := connL.Close(); closeErr != nil {
|
||||
panic(closeErr)
|
||||
}
|
||||
assert.NoError(t, connL.Close())
|
||||
}()
|
||||
collector := new(manualCollector)
|
||||
clock := &manualClock{current: time.Now()}
|
||||
@@ -1488,16 +1260,14 @@ func TestClientImmediateTimeout(t *testing.T) {
|
||||
attempt := 0
|
||||
agent.start = func(id [TransactionIDSize]byte, deadline time.Time) error {
|
||||
if attempt == 0 {
|
||||
if deadline.Before(clock.current.Add(rto / 2)) {
|
||||
t.Error("deadline too fast")
|
||||
}
|
||||
assert.False(t, deadline.Before(clock.current.Add(rto/2)), "deadline too fast")
|
||||
attempt++
|
||||
go agent.h(Event{
|
||||
TransactionID: id,
|
||||
Message: response,
|
||||
})
|
||||
} else {
|
||||
t.Error("there should be no second attempt")
|
||||
assert.Fail(t, "there should be no second attempt")
|
||||
go agent.h(Event{
|
||||
TransactionID: id,
|
||||
Error: ErrTransactionTimeOut,
|
||||
@@ -1512,25 +1282,17 @@ func TestClientImmediateTimeout(t *testing.T) {
|
||||
WithCollector(collector),
|
||||
WithRTO(rto),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
gotReads := make(chan struct{})
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
readN, readErr := connL.Read(buf)
|
||||
if readErr != nil {
|
||||
t.Error(readErr)
|
||||
}
|
||||
if !IsMessage(buf[:readN]) {
|
||||
t.Error("should be STUN")
|
||||
}
|
||||
assert.NoError(t, readErr)
|
||||
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
|
||||
gotReads <- struct{}{}
|
||||
}()
|
||||
client.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec
|
||||
if errors.Is(e.Error, ErrTransactionTimeOut) {
|
||||
t.Error("unexpected error")
|
||||
}
|
||||
assert.NoError(t, e.Error, "unexpected error")
|
||||
})
|
||||
<-gotReads
|
||||
}
|
||||
|
||||
@@ -8,9 +8,10 @@ package stun
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func BenchmarkErrorCode_AddTo(b *testing.B) {
|
||||
@@ -52,19 +53,13 @@ func TestErrorCodeAttribute_GetFrom(t *testing.T) {
|
||||
m := New()
|
||||
m.Add(AttrErrorCode, []byte{1})
|
||||
c := new(ErrorCodeAttribute)
|
||||
if err := c.GetFrom(m); !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
t.Errorf("GetFrom should return <%s>, but got <%s>",
|
||||
io.ErrUnexpectedEOF, err,
|
||||
)
|
||||
}
|
||||
assert.ErrorIs(t, c.GetFrom(m), io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
func TestMessage_AddErrorCode(t *testing.T) {
|
||||
m := New()
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
copy(m.TransactionID[:], transactionID)
|
||||
expectedCode := ErrorCode(438)
|
||||
expectedReason := "Stale Nonce"
|
||||
@@ -72,23 +67,13 @@ func TestMessage_AddErrorCode(t *testing.T) {
|
||||
m.WriteHeader()
|
||||
|
||||
mRes := New()
|
||||
if _, err = mRes.ReadFrom(m.reader()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = mRes.ReadFrom(m.reader())
|
||||
assert.NoError(t, err)
|
||||
errCodeAttr := new(ErrorCodeAttribute)
|
||||
if err = errCodeAttr.GetFrom(mRes); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, errCodeAttr.GetFrom(mRes))
|
||||
code := errCodeAttr.Code
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if code != expectedCode {
|
||||
t.Error("bad code", code)
|
||||
}
|
||||
if string(errCodeAttr.Reason) != expectedReason {
|
||||
t.Error("bad reason", string(errCodeAttr.Reason))
|
||||
}
|
||||
assert.Equal(t, expectedCode, code, "bad code")
|
||||
assert.Equal(t, expectedReason, string(errCodeAttr.Reason), "bad reason")
|
||||
}
|
||||
|
||||
func TestErrorCode(t *testing.T) {
|
||||
@@ -96,19 +81,11 @@ func TestErrorCode(t *testing.T) {
|
||||
Code: 404,
|
||||
Reason: []byte("not found!"),
|
||||
}
|
||||
if attr.String() != "404: not found!" {
|
||||
t.Error("bad string", attr)
|
||||
}
|
||||
assert.Equal(t, "404: not found!", attr.String(), "bad string")
|
||||
m := New()
|
||||
cod := ErrorCode(666)
|
||||
if err := cod.AddTo(m); !errors.Is(err, ErrNoDefaultReason) {
|
||||
t.Error("should be ErrNoDefaultReason", err)
|
||||
}
|
||||
if err := attr.GetFrom(m); err == nil {
|
||||
t.Error("attr should not be in message")
|
||||
}
|
||||
assert.ErrorIs(t, cod.AddTo(m), ErrNoDefaultReason, "should be ErrNoDefaultReason")
|
||||
assert.Error(t, attr.GetFrom(m), "attr should not be in message")
|
||||
attr.Reason = make([]byte, 2048)
|
||||
if err := attr.AddTo(m); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
assert.Error(t, attr.AddTo(m), "should error")
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ package stun
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDecodeErr_IsInvalidCookie(t *testing.T) {
|
||||
@@ -14,25 +16,13 @@ func TestDecodeErr_IsInvalidCookie(t *testing.T) {
|
||||
decoded := new(Message)
|
||||
m.Raw[4] = 55
|
||||
_, err := decoded.Write(m.Raw)
|
||||
if err == nil {
|
||||
t.Fatal("should error")
|
||||
}
|
||||
assert.Error(t, err, "should error")
|
||||
expected := "BadFormat for message/cookie: " +
|
||||
"3712a442 is invalid magic cookie (should be 2112a442)"
|
||||
if err.Error() != expected {
|
||||
t.Error(err, "!=", expected)
|
||||
}
|
||||
assert.Equal(t, expected, err.Error(), "error message mismatch")
|
||||
var dErr *DecodeErr
|
||||
if !errors.As(err, &dErr) {
|
||||
t.Error("not decode error")
|
||||
}
|
||||
if !dErr.IsInvalidCookie() {
|
||||
t.Error("IsInvalidCookie = false, should be true")
|
||||
}
|
||||
if !dErr.IsPlaceChildren("cookie") {
|
||||
t.Error("bad children")
|
||||
}
|
||||
if !dErr.IsPlaceParent("message") {
|
||||
t.Error("bad parent")
|
||||
}
|
||||
assert.True(t, errors.As(err, &dErr), "not decode error")
|
||||
assert.True(t, dErr.IsInvalidCookie(), "IsInvalidCookie = false, should be true")
|
||||
assert.True(t, dErr.IsPlaceChildren("cookie"), "bad children")
|
||||
assert.True(t, dErr.IsPlaceParent("message"), "bad parent")
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ package stun
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func BenchmarkFingerprint_AddTo(b *testing.B) {
|
||||
@@ -36,26 +38,18 @@ func TestFingerprint_Check(t *testing.T) {
|
||||
m.WriteHeader()
|
||||
Fingerprint.AddTo(m) //nolint:errcheck,gosec
|
||||
m.WriteHeader()
|
||||
if err := Fingerprint.Check(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, Fingerprint.Check(m))
|
||||
m.Raw[3]++
|
||||
if err := Fingerprint.Check(m); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
assert.Error(t, Fingerprint.Check(m))
|
||||
}
|
||||
|
||||
func TestFingerprint_CheckBad(t *testing.T) {
|
||||
m := new(Message)
|
||||
addAttr(t, m, NewSoftware("software"))
|
||||
m.WriteHeader()
|
||||
if err := Fingerprint.Check(m); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
assert.Error(t, Fingerprint.Check(m))
|
||||
m.Add(AttrFingerprint, []byte{1, 2, 3})
|
||||
if !IsAttrSizeInvalid(Fingerprint.Check(m)) {
|
||||
t.Error("IsAttrSizeInvalid should be true")
|
||||
}
|
||||
assert.True(t, IsAttrSizeInvalid(Fingerprint.Check(m)))
|
||||
}
|
||||
|
||||
func BenchmarkFingerprint_Check(b *testing.B) {
|
||||
|
||||
46
fuzz_test.go
46
fuzz_test.go
@@ -7,6 +7,8 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func FuzzMessage(f *testing.F) {
|
||||
@@ -26,21 +28,12 @@ func FuzzMessage(f *testing.F) {
|
||||
}
|
||||
|
||||
msg2 := New()
|
||||
if _, err := msg2.Write(msg1.Raw); err != nil {
|
||||
t.Fatalf("Failed to write: %s", err)
|
||||
}
|
||||
_, err := msg2.Write(msg1.Raw)
|
||||
assert.NoError(t, err, "Failed to write")
|
||||
|
||||
if msg2.TransactionID != msg1.TransactionID {
|
||||
t.Fatal("Transaction ID mismatch")
|
||||
}
|
||||
|
||||
if msg2.Type != msg1.Type {
|
||||
t.Fatal("Type mismatch")
|
||||
}
|
||||
|
||||
if len(msg2.Attributes) != len(msg1.Attributes) {
|
||||
t.Fatal("Attributes length mismatch")
|
||||
}
|
||||
assert.Equal(t, msg1.TransactionID, msg2.TransactionID, "Transaction ID mismatch")
|
||||
assert.Equal(t, msg1.Type, msg2.Type, "Type mismatch")
|
||||
assert.Equal(t, len(msg1.Attributes), len(msg2.Attributes), "Attributes length mismatch")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -51,15 +44,11 @@ func FuzzType(f *testing.F) {
|
||||
t1 := MessageType{}
|
||||
t1.ReadValue(v)
|
||||
v2 := t1.Value()
|
||||
if v != v2 {
|
||||
t.Fatal("v != v2")
|
||||
}
|
||||
assert.Equal(t, v, v2, "v != v2")
|
||||
|
||||
t2 := MessageType{}
|
||||
t2.ReadValue(v2)
|
||||
if t2 != t1 {
|
||||
t.Fatal("t2 != t1")
|
||||
}
|
||||
assert.Equal(t, t1, t2, "t2 != t1")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -94,20 +83,19 @@ func FuzzSetters(f *testing.F) {
|
||||
m1.WriteHeader()
|
||||
m1.Add(attr.t, value)
|
||||
err := attr.g.GetFrom(m1)
|
||||
if errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Fatalf("Unexpected 404: %s", err)
|
||||
}
|
||||
assert.False(t, errors.Is(err, ErrAttributeNotFound))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
m2.WriteHeader()
|
||||
if err = attr.g.AddTo(m2); err != nil {
|
||||
err = attr.g.AddTo(m2)
|
||||
if err != nil {
|
||||
// We allow decoding some text attributes
|
||||
// when their length is too big, but
|
||||
// not encoding.
|
||||
if !IsAttrSizeOverflow(err) {
|
||||
t.Fatal(err)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -115,14 +103,10 @@ func FuzzSetters(f *testing.F) {
|
||||
|
||||
m3.WriteHeader()
|
||||
v, err := m2.Get(attr.t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
m3.Add(attr.t, v)
|
||||
|
||||
if !m2.Equal(m3) {
|
||||
t.Fatalf("Not equal: %s != %s", m2, m3)
|
||||
}
|
||||
assert.True(t, m2.Equal(m3), "Not equal: %s != %s", m2, m3)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
117
helpers_test.go
117
helpers_test.go
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/pion/stun/v3/internal/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func BenchmarkBuildOverhead(b *testing.B) {
|
||||
@@ -59,21 +60,12 @@ func TestMessage_Apply(t *testing.T) {
|
||||
integrity,
|
||||
Fingerprint,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal("failed to build:", err)
|
||||
}
|
||||
if err = msg.Check(Fingerprint, integrity); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := decoded.Write(msg.Raw); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !decoded.Equal(msg) {
|
||||
t.Error("not equal")
|
||||
}
|
||||
if err := integrity.Check(decoded); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err, "failed to build")
|
||||
assert.NoError(t, msg.Check(Fingerprint, integrity))
|
||||
_, err = decoded.Write(msg.Raw)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, decoded.Equal(msg))
|
||||
assert.NoError(t, integrity.Check(decoded))
|
||||
}
|
||||
|
||||
type errReturner struct {
|
||||
@@ -97,25 +89,17 @@ func (e errReturner) GetFrom(*Message) error {
|
||||
func TestHelpersErrorHandling(t *testing.T) {
|
||||
m := New()
|
||||
errReturn := errReturner{Err: errTError}
|
||||
if err := m.Build(errReturn); !errors.Is(err, errReturn.Err) {
|
||||
t.Error(err, "!=", errReturn.Err)
|
||||
}
|
||||
if err := m.Check(errReturn); !errors.Is(err, errReturn.Err) {
|
||||
t.Error(err, "!=", errReturn.Err)
|
||||
}
|
||||
if err := m.Parse(errReturn); !errors.Is(err, errReturn.Err) {
|
||||
t.Error(err, "!=", errReturn.Err)
|
||||
}
|
||||
assert.ErrorIs(t, m.Build(errReturn), errReturn.Err)
|
||||
assert.ErrorIs(t, m.Check(errReturn), errReturn.Err)
|
||||
assert.ErrorIs(t, m.Parse(errReturn), errReturn.Err)
|
||||
t.Run("MustBuild", func(t *testing.T) {
|
||||
t.Run("Positive", func(*testing.T) {
|
||||
MustBuild(NewTransactionIDSetter(transactionID{}))
|
||||
})
|
||||
defer func() {
|
||||
if p, ok := recover().(error); !ok || !errors.Is(p, errReturn.Err) {
|
||||
t.Errorf("%s != %s",
|
||||
p, errReturn.Err,
|
||||
)
|
||||
}
|
||||
p, ok := recover().(error)
|
||||
assert.True(t, ok)
|
||||
assert.ErrorIs(t, p, errReturn.Err)
|
||||
}()
|
||||
MustBuild(errReturn)
|
||||
})
|
||||
@@ -123,91 +107,62 @@ func TestHelpersErrorHandling(t *testing.T) {
|
||||
|
||||
func TestMessage_ForEach(t *testing.T) { //nolint:cyclop
|
||||
initial := New()
|
||||
if err := initial.Build(
|
||||
assert.NoError(t, initial.Build(
|
||||
NewRealm("realm1"), NewRealm("realm2"),
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
))
|
||||
newMessage := func() *Message {
|
||||
m := New()
|
||||
if err := m.Build(
|
||||
assert.NoError(t, m.Build(
|
||||
NewRealm("realm1"), NewRealm("realm2"),
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
))
|
||||
|
||||
return m
|
||||
}
|
||||
t.Run("NoResults", func(t *testing.T) {
|
||||
m := newMessage()
|
||||
if !m.Equal(initial) {
|
||||
t.Error("m should be equal to initial")
|
||||
}
|
||||
if err := m.ForEach(AttrUsername, func(*Message) error {
|
||||
t.Error("should not be called")
|
||||
assert.True(t, m.Equal(initial), "m should be equal to initial")
|
||||
assert.NoError(t, m.ForEach(AttrUsername, func(*Message) error {
|
||||
assert.Fail(t, "should not be called")
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !m.Equal(initial) {
|
||||
t.Error("m should be equal to initial")
|
||||
}
|
||||
}))
|
||||
assert.True(t, m.Equal(initial), "m should be equal to initial")
|
||||
})
|
||||
t.Run("ReturnOnError", func(t *testing.T) {
|
||||
m := newMessage()
|
||||
var calls int
|
||||
if err := m.ForEach(AttrRealm, func(*Message) error {
|
||||
err := m.ForEach(AttrRealm, func(*Message) error {
|
||||
if calls > 0 {
|
||||
t.Error("called multiple times")
|
||||
assert.Fail(t, "called multiple times")
|
||||
}
|
||||
calls++
|
||||
|
||||
return ErrAttributeNotFound
|
||||
}); !errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !m.Equal(initial) {
|
||||
t.Error("m should be equal to initial")
|
||||
}
|
||||
})
|
||||
assert.ErrorIs(t, err, ErrAttributeNotFound)
|
||||
assert.True(t, m.Equal(initial), "m should be equal to initial")
|
||||
})
|
||||
t.Run("Positive", func(t *testing.T) {
|
||||
msg := newMessage()
|
||||
var realms []string
|
||||
if err := msg.ForEach(AttrRealm, func(m *Message) error {
|
||||
assert.NoError(t, msg.ForEach(AttrRealm, func(m *Message) error {
|
||||
var realm Realm
|
||||
if err := realm.GetFrom(m); err != nil {
|
||||
return err
|
||||
}
|
||||
assert.NoError(t, realm.GetFrom(m))
|
||||
realms = append(realms, realm.String())
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(realms) != 2 {
|
||||
t.Fatal("expected 2 realms")
|
||||
}
|
||||
if realms[0] != "realm1" {
|
||||
t.Error("bad value for 1 realm")
|
||||
}
|
||||
if realms[1] != "realm2" {
|
||||
t.Error("bad value for 2 realm")
|
||||
}
|
||||
if !msg.Equal(initial) {
|
||||
t.Error("m should be equal to initial")
|
||||
}
|
||||
}))
|
||||
assert.Len(t, realms, 2)
|
||||
assert.Equal(t, "realm1", realms[0], "bad value for 1 realm")
|
||||
assert.Equal(t, "realm2", realms[1], "bad value for 2 realm")
|
||||
assert.True(t, msg.Equal(initial), "m should be equal to initial")
|
||||
t.Run("ZeroAlloc", func(t *testing.T) {
|
||||
msg = newMessage()
|
||||
var realm Realm
|
||||
testutil.ShouldNotAllocate(t, func() {
|
||||
if err := msg.ForEach(AttrRealm, realm.GetFrom); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, msg.ForEach(AttrRealm, realm.GetFrom))
|
||||
})
|
||||
if !msg.Equal(initial) {
|
||||
t.Error("m should be equal to initial")
|
||||
}
|
||||
assert.True(t, msg.Equal(initial), "m should be equal to initial")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
42
iana_test.go
42
iana_test.go
@@ -12,6 +12,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func loadCSV(tb testing.TB, name string) [][]string {
|
||||
@@ -21,9 +23,7 @@ func loadCSV(tb testing.TB, name string) [][]string {
|
||||
r := csv.NewReader(bytes.NewReader(data))
|
||||
r.Comment = '#'
|
||||
records, err := r.ReadAll()
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
assert.NoError(tb, err)
|
||||
|
||||
return records
|
||||
}
|
||||
@@ -41,20 +41,14 @@ func TestIANA(t *testing.T) { //nolint:cyclop
|
||||
continue
|
||||
}
|
||||
val, parseErr := strconv.ParseInt(v[2:], 16, 64)
|
||||
if parseErr != nil {
|
||||
t.Fatal(parseErr)
|
||||
}
|
||||
assert.NoError(t, parseErr)
|
||||
t.Logf("value: 0x%x, name: %s", val, name)
|
||||
methods[name] = Method(val) //nolint:gosec // G115
|
||||
}
|
||||
for val, name := range methodName() {
|
||||
mapped, ok := methods[name]
|
||||
if !ok {
|
||||
t.Errorf("failed to find method %s in IANA", name)
|
||||
}
|
||||
if mapped != val {
|
||||
t.Errorf("%s: IANA %d != actual %d", name, mapped, val)
|
||||
}
|
||||
assert.True(t, ok, "failed to find method %s in IANA", name)
|
||||
assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val)
|
||||
}
|
||||
})
|
||||
t.Run("Attributes", func(t *testing.T) {
|
||||
@@ -69,9 +63,7 @@ func TestIANA(t *testing.T) { //nolint:cyclop
|
||||
continue
|
||||
}
|
||||
val, parseErr := strconv.ParseInt(v[2:], 16, 64)
|
||||
if parseErr != nil {
|
||||
t.Fatal(parseErr)
|
||||
}
|
||||
assert.NoError(t, parseErr)
|
||||
t.Logf("value: 0x%x, name: %s", val, name)
|
||||
attrTypes[name] = AttrType(val) //nolint:gosec // G115
|
||||
}
|
||||
@@ -83,12 +75,8 @@ func TestIANA(t *testing.T) { //nolint:cyclop
|
||||
}
|
||||
for val, name := range attrNames() {
|
||||
mapped, ok := attrTypes[name]
|
||||
if !ok {
|
||||
t.Errorf("failed to find attribute %s in IANA", name)
|
||||
}
|
||||
if mapped != val {
|
||||
t.Errorf("%s: IANA %d != actual %d", name, mapped, val)
|
||||
}
|
||||
assert.True(t, ok, "failed to find attribute %s in IANA", name)
|
||||
assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val)
|
||||
}
|
||||
})
|
||||
t.Run("ErrorCodes", func(t *testing.T) {
|
||||
@@ -103,21 +91,15 @@ func TestIANA(t *testing.T) { //nolint:cyclop
|
||||
continue
|
||||
}
|
||||
val, parseErr := strconv.ParseInt(v, 10, 64)
|
||||
if parseErr != nil {
|
||||
t.Fatal(parseErr)
|
||||
}
|
||||
assert.NoError(t, parseErr)
|
||||
t.Logf("value: 0x%x, name: %s", val, name)
|
||||
errorCodes[name] = ErrorCode(val)
|
||||
}
|
||||
for val, nameB := range errorReasons {
|
||||
name := string(nameB)
|
||||
mapped, ok := errorCodes[name]
|
||||
if !ok {
|
||||
t.Errorf("failed to find error code %s in IANA", name)
|
||||
}
|
||||
if mapped != val {
|
||||
t.Errorf("%s: IANA %d != actual %d", name, mapped, val)
|
||||
}
|
||||
assert.True(t, ok, "failed to find error code %s in IANA", name)
|
||||
assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,40 +4,29 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMessageIntegrity_AddTo_Simple(t *testing.T) {
|
||||
integrity := NewLongTermIntegrity("user", "realm", "pass")
|
||||
expected, err := hex.DecodeString("8493fbc53ba582fb4c044c456bdc40eb")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(expected, integrity) {
|
||||
t.Error(ErrIntegrityMismatch)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, expected, integrity)
|
||||
t.Run("Check", func(t *testing.T) {
|
||||
m := new(Message)
|
||||
m.WriteHeader()
|
||||
if err := integrity.AddTo(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, integrity.AddTo(m))
|
||||
NewSoftware("software").AddTo(m) //nolint:errcheck,gosec
|
||||
m.WriteHeader()
|
||||
dM := new(Message)
|
||||
dM.Raw = m.Raw
|
||||
if err := dM.Decode(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := integrity.Check(dM); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, dM.Decode())
|
||||
assert.NoError(t, integrity.Check(dM))
|
||||
dM.Raw[24] += 12 // HMAC now invalid
|
||||
if integrity.Check(dM) == nil {
|
||||
t.Error("should be invalid")
|
||||
}
|
||||
assert.Error(t, integrity.Check(dM))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -47,38 +36,23 @@ func TestMessageIntegrityWithFingerprint(t *testing.T) {
|
||||
msg.WriteHeader()
|
||||
NewSoftware("software").AddTo(msg) //nolint:errcheck,gosec
|
||||
integrity := NewShortTermIntegrity("pwd")
|
||||
if integrity.String() != "KEY: 0x707764" {
|
||||
t.Error("bad string", integrity)
|
||||
}
|
||||
if err := integrity.Check(msg); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
if err := integrity.AddTo(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := Fingerprint.AddTo(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := integrity.Check(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, "KEY: 0x707764", integrity.String())
|
||||
assert.NoError(t, integrity.AddTo(msg))
|
||||
assert.NoError(t, integrity.AddTo(msg))
|
||||
assert.NoError(t, integrity.Check(msg))
|
||||
assert.NoError(t, Fingerprint.AddTo(msg))
|
||||
assert.NoError(t, integrity.Check(msg))
|
||||
msg.Raw[24] = 33
|
||||
if err := integrity.Check(msg); err == nil {
|
||||
t.Fatal("mismatch expected")
|
||||
}
|
||||
assert.Error(t, integrity.Check(msg))
|
||||
}
|
||||
|
||||
func TestMessageIntegrity(t *testing.T) {
|
||||
m := new(Message)
|
||||
i := NewShortTermIntegrity("password")
|
||||
m.WriteHeader()
|
||||
if err := i.AddTo(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, i.AddTo(m))
|
||||
_, err := m.Get(AttrMessageIntegrity)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestMessageIntegrityBeforeFingerprint(t *testing.T) {
|
||||
@@ -86,9 +60,7 @@ func TestMessageIntegrityBeforeFingerprint(t *testing.T) {
|
||||
m.WriteHeader()
|
||||
Fingerprint.AddTo(m) //nolint:errcheck,gosec
|
||||
i := NewShortTermIntegrity("password")
|
||||
if err := i.AddTo(m); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
assert.Error(t, i.AddTo(m))
|
||||
}
|
||||
|
||||
func BenchmarkMessageIntegrity_AddTo(b *testing.B) {
|
||||
@@ -99,9 +71,7 @@ func BenchmarkMessageIntegrity_AddTo(b *testing.B) {
|
||||
b.SetBytes(int64(len(m.Raw)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.WriteHeader()
|
||||
if err := integrity.AddTo(m); err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
assert.NoError(b, integrity.AddTo(m))
|
||||
m.Reset()
|
||||
}
|
||||
}
|
||||
@@ -114,13 +84,9 @@ func BenchmarkMessageIntegrity_Check(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
m.WriteHeader()
|
||||
b.SetBytes(int64(len(m.Raw)))
|
||||
if err := integrity.AddTo(m); err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
assert.NoError(b, integrity.AddTo(m))
|
||||
m.WriteLength()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := integrity.Check(m); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
assert.NoError(b, integrity.Check(m))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"fmt"
|
||||
"hash"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type hmacTest struct {
|
||||
@@ -524,26 +526,17 @@ func hmacTests() []hmacTest { //nolint:maintidx
|
||||
func TestHMAC(t *testing.T) {
|
||||
for i, tt := range hmacTests() {
|
||||
hsh := New(tt.hash, tt.key)
|
||||
if s := hsh.Size(); s != tt.size {
|
||||
t.Errorf("Size: got %v, want %v", s, tt.size)
|
||||
}
|
||||
if b := hsh.BlockSize(); b != tt.blocksize {
|
||||
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
||||
}
|
||||
assert.Equal(t, tt.size, hsh.Size(), "Size mismatch")
|
||||
assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch")
|
||||
for j := 0; j < 4; j++ { //nolint:varnamelen
|
||||
n, err := hsh.Write(tt.in)
|
||||
if n != len(tt.in) || err != nil {
|
||||
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
||||
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n)
|
||||
assert.NoError(t, err, "test %d.%d: Write error", i, j)
|
||||
|
||||
// Repetitive Sum() calls should return the same value
|
||||
for k := 0; k < 2; k++ {
|
||||
sum := fmt.Sprintf("%x", hsh.Sum(nil))
|
||||
if sum != tt.out {
|
||||
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||
}
|
||||
assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||
}
|
||||
|
||||
// Second iteration: make sure reset works.
|
||||
@@ -568,18 +561,10 @@ func TestEqual(t *testing.T) {
|
||||
b := []byte("test1")
|
||||
c := []byte("test2")
|
||||
|
||||
if !Equal(b, b) {
|
||||
t.Error("Equal failed with equal arguments")
|
||||
}
|
||||
if Equal(a, b) {
|
||||
t.Error("Equal accepted a prefix of the second argument")
|
||||
}
|
||||
if Equal(b, a) {
|
||||
t.Error("Equal accepted a prefix of the first argument")
|
||||
}
|
||||
if Equal(b, c) {
|
||||
t.Error("Equal accepted unequal slices")
|
||||
}
|
||||
assert.True(t, Equal(b, b), "Equal failed with equal arguments")
|
||||
assert.False(t, Equal(a, b), "Equal accepted a prefix of the second argument")
|
||||
assert.False(t, Equal(b, a), "Equal accepted a prefix of the first argument")
|
||||
assert.False(t, Equal(b, c), "Equal accepted unequal slices")
|
||||
}
|
||||
|
||||
func BenchmarkHMACSHA256_1K(b *testing.B) {
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func BenchmarkHMACSHA1_512(b *testing.B) {
|
||||
@@ -44,26 +46,17 @@ func TestHMACReset(t *testing.T) {
|
||||
for i, tt := range hmacTests() {
|
||||
hsh := New(tt.hash, tt.key)
|
||||
hsh.(*hmac).resetTo(tt.key) //nolint:forcetypeassert
|
||||
if s := hsh.Size(); s != tt.size {
|
||||
t.Errorf("Size: got %v, want %v", s, tt.size)
|
||||
}
|
||||
if b := hsh.BlockSize(); b != tt.blocksize {
|
||||
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
||||
}
|
||||
assert.Equal(t, tt.size, hsh.Size(), "Size mismatch")
|
||||
assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch")
|
||||
for j := 0; j < 2; j++ {
|
||||
n, err := hsh.Write(tt.in)
|
||||
if n != len(tt.in) || err != nil {
|
||||
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
||||
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n)
|
||||
assert.NoError(t, err, "test %d.%d: Write error", i, j)
|
||||
|
||||
// Repetitive Sum() calls should return the same value
|
||||
for k := 0; k < 2; k++ {
|
||||
sum := fmt.Sprintf("%x", hsh.Sum(nil))
|
||||
if sum != tt.out {
|
||||
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||
}
|
||||
assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||
}
|
||||
|
||||
// Second iteration: make sure reset works.
|
||||
@@ -78,26 +71,17 @@ func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl,cyclop
|
||||
continue
|
||||
}
|
||||
hsh := AcquireSHA1(tt.key)
|
||||
if s := hsh.Size(); s != tt.size {
|
||||
t.Errorf("Size: got %v, want %v", s, tt.size)
|
||||
}
|
||||
if b := hsh.BlockSize(); b != tt.blocksize {
|
||||
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
||||
}
|
||||
assert.Equal(t, tt.size, hsh.Size(), "Size mismatch")
|
||||
assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch")
|
||||
for j := 0; j < 2; j++ {
|
||||
n, err := hsh.Write(tt.in)
|
||||
if n != len(tt.in) || err != nil {
|
||||
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
||||
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n)
|
||||
assert.NoError(t, err, "test %d.%d: Write error", i, j)
|
||||
|
||||
// Repetitive Sum() calls should return the same value
|
||||
for k := 0; k < 2; k++ {
|
||||
sum := fmt.Sprintf("%x", hsh.Sum(nil))
|
||||
if sum != tt.out {
|
||||
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||
}
|
||||
assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||
}
|
||||
|
||||
// Second iteration: make sure reset works.
|
||||
@@ -113,26 +97,17 @@ func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl,cyclop
|
||||
continue
|
||||
}
|
||||
hsh := AcquireSHA256(tt.key)
|
||||
if s := hsh.Size(); s != tt.size {
|
||||
t.Errorf("Size: got %v, want %v", s, tt.size)
|
||||
}
|
||||
if b := hsh.BlockSize(); b != tt.blocksize {
|
||||
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
|
||||
}
|
||||
assert.Equal(t, tt.size, hsh.Size(), "Size mismatch")
|
||||
assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch")
|
||||
for j := 0; j < 2; j++ {
|
||||
n, err := hsh.Write(tt.in)
|
||||
if n != len(tt.in) || err != nil {
|
||||
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
|
||||
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n)
|
||||
assert.NoError(t, err, "test %d.%d: Write error", i, j)
|
||||
|
||||
// Repetitive Sum() calls should return the same value
|
||||
for k := 0; k < 2; k++ {
|
||||
sum := fmt.Sprintf("%x", hsh.Sum(nil))
|
||||
if sum != tt.out {
|
||||
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||
}
|
||||
assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
|
||||
}
|
||||
|
||||
// Second iteration: make sure reset works.
|
||||
@@ -150,7 +125,7 @@ func TestAssertBlockSize(t *testing.T) {
|
||||
t.Run("Negative", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("should panic")
|
||||
assert.Fail(t, "should panic")
|
||||
}
|
||||
}()
|
||||
h := AcquireSHA256(make([]byte, 0, 1024))
|
||||
|
||||
@@ -6,6 +6,8 @@ package testutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// ShouldNotAllocate fails if f allocates.
|
||||
@@ -17,7 +19,5 @@ func ShouldNotAllocate(t *testing.T, f func()) {
|
||||
|
||||
return
|
||||
}
|
||||
if a := testing.AllocsPerRun(10, f); a > 0 {
|
||||
t.Errorf("Allocations detected: %f", a)
|
||||
}
|
||||
assert.Zero(t, testing.AllocsPerRun(10, f))
|
||||
}
|
||||
|
||||
387
message_test.go
387
message_test.go
@@ -21,6 +21,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type attributeEncoder interface {
|
||||
@@ -50,12 +52,9 @@ func TestMessageBuffer(t *testing.T) {
|
||||
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||
m.WriteHeader()
|
||||
mDecoded := New()
|
||||
if _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw)); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !mDecoded.Equal(m) {
|
||||
t.Error(mDecoded, "!", m)
|
||||
}
|
||||
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw))
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, mDecoded.Equal(m), "mDecoded != m")
|
||||
}
|
||||
|
||||
func BenchmarkMessage_Write(b *testing.B) {
|
||||
@@ -86,9 +85,7 @@ func TestMessageType_Value(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
b := tt.in.Value()
|
||||
if b != tt.out {
|
||||
t.Errorf("Value(%s) -> %s, want %s", tt.in, bUint16(b), bUint16(tt.out))
|
||||
}
|
||||
assert.Equal(t, tt.out, b, "Value(%s) -> %s, want %s", tt.in, bUint16(b), bUint16(tt.out))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,9 +101,7 @@ func TestMessageType_ReadValue(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
m := MessageType{}
|
||||
m.ReadValue(tt.in)
|
||||
if m != tt.out {
|
||||
t.Errorf("ReadValue(%s) -> %s, want %s", bUint16(tt.in), m, tt.out)
|
||||
}
|
||||
assert.Equal(t, tt.out, m, "ReadValue(%s) -> %s, want %s", bUint16(tt.in), m, tt.out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,12 +116,8 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
|
||||
m := MessageType{}
|
||||
v := tt.Value()
|
||||
m.ReadValue(v)
|
||||
if m != tt {
|
||||
t.Errorf("ReadValue(%s -> %s) = %s, should be %s", tt, bUint16(v), m, tt)
|
||||
if m.Method != tt.Method {
|
||||
t.Errorf("%s != %s", bUint16(uint16(m.Method)), bUint16(uint16(tt.Method)))
|
||||
}
|
||||
}
|
||||
assert.Equal(t, tt, m, "ReadValue(%s -> %s) = %s, should be %s", tt, bUint16(v), m, tt)
|
||||
assert.Equal(t, tt.Method, m.Method, "%s != %s", bUint16(uint16(m.Method)), bUint16(uint16(tt.Method)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,32 +128,26 @@ func TestMessage_WriteTo(t *testing.T) {
|
||||
msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||
msg.WriteHeader()
|
||||
buf := new(bytes.Buffer)
|
||||
if _, err := msg.WriteTo(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err := msg.WriteTo(buf)
|
||||
assert.NoError(t, err)
|
||||
mDecoded := New()
|
||||
if _, err := mDecoded.ReadFrom(buf); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !mDecoded.Equal(msg) {
|
||||
t.Error(mDecoded, "!", msg)
|
||||
}
|
||||
_, err = mDecoded.ReadFrom(buf)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, mDecoded.Equal(msg), "mDecoded != msg")
|
||||
}
|
||||
|
||||
func TestMessage_Cookie(t *testing.T) {
|
||||
buf := make([]byte, 20)
|
||||
mDecoded := New()
|
||||
if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
_, err := mDecoded.ReadFrom(bytes.NewReader(buf))
|
||||
assert.Error(t, err, "should error")
|
||||
}
|
||||
|
||||
func TestMessage_LengthLessHeaderSize(t *testing.T) {
|
||||
buf := make([]byte, 8)
|
||||
mDecoded := New()
|
||||
if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
_, err := mDecoded.ReadFrom(bytes.NewReader(buf))
|
||||
assert.Error(t, err, "should error")
|
||||
}
|
||||
|
||||
func TestMessage_BadLength(t *testing.T) {
|
||||
@@ -176,9 +161,8 @@ func TestMessage_BadLength(t *testing.T) {
|
||||
m.WriteHeader()
|
||||
m.Raw[20+3] = 10 // set attr length = 10
|
||||
mDecoded := New()
|
||||
if _, err := mDecoded.Write(m.Raw); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
_, err := mDecoded.Write(m.Raw)
|
||||
assert.Error(t, err, "should error")
|
||||
}
|
||||
|
||||
func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
|
||||
@@ -197,13 +181,8 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
|
||||
binary.BigEndian.PutUint16(m.Raw[2:4], 2) // rewrite to bad length
|
||||
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+2]))
|
||||
var e *DecodeErr
|
||||
if errors.As(err, &e) {
|
||||
if !e.IsPlace(DecodeErrPlace{"attribute", "header"}) {
|
||||
t.Error(e, "bad place")
|
||||
}
|
||||
} else {
|
||||
t.Error(err, "should be bad format")
|
||||
}
|
||||
assert.ErrorAs(t, err, &e)
|
||||
assert.True(t, e.IsPlace(DecodeErrPlace{"attribute", "header"}), "bad place")
|
||||
}
|
||||
|
||||
func TestMessage_AttrSizeLessThanLength(t *testing.T) {
|
||||
@@ -226,13 +205,8 @@ func TestMessage_AttrSizeLessThanLength(t *testing.T) {
|
||||
mDecoded := New()
|
||||
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+5]))
|
||||
var e *DecodeErr
|
||||
if errors.As(err, &e) {
|
||||
if !e.IsPlace(DecodeErrPlace{"attribute", "value"}) {
|
||||
t.Error(e, "bad place")
|
||||
}
|
||||
} else {
|
||||
t.Error(err, "should be bad format")
|
||||
}
|
||||
assert.ErrorAs(t, err, &e)
|
||||
assert.True(t, e.IsPlace(DecodeErrPlace{"attribute", "value"}), "bad place")
|
||||
}
|
||||
|
||||
type unexpectedEOFReader struct{}
|
||||
@@ -244,9 +218,7 @@ func (r unexpectedEOFReader) Read([]byte) (int, error) {
|
||||
func TestMessage_ReadFromError(t *testing.T) {
|
||||
mDecoded := New()
|
||||
_, err := mDecoded.ReadFrom(unexpectedEOFReader{})
|
||||
if !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
t.Error(err, "should be", io.ErrUnexpectedEOF)
|
||||
}
|
||||
assert.ErrorIs(t, err, io.ErrUnexpectedEOF, "should be", io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
func BenchmarkMessageType_Value(b *testing.B) {
|
||||
@@ -321,9 +293,7 @@ func BenchmarkMessage_ReadBytes(b *testing.B) {
|
||||
|
||||
func TestMessageClass_String(t *testing.T) {
|
||||
defer func() {
|
||||
if err := recover(); err == nil {
|
||||
t.Error(err, "should be not nil")
|
||||
}
|
||||
assert.NotNil(t, recover())
|
||||
}()
|
||||
|
||||
v := [...]MessageClass{
|
||||
@@ -333,14 +303,12 @@ func TestMessageClass_String(t *testing.T) {
|
||||
ClassIndication,
|
||||
}
|
||||
for _, k := range v {
|
||||
if k.String() == "" {
|
||||
t.Error(k, "bad stringer")
|
||||
}
|
||||
assert.NotEmpty(t, k.String(), "%v bad stringer", k)
|
||||
}
|
||||
|
||||
// should panic
|
||||
p := MessageClass(0x05).String()
|
||||
t.Error("should panic!", p)
|
||||
assert.Fail(t, "should panic", p)
|
||||
}
|
||||
|
||||
func TestAttrType_String(t *testing.T) {
|
||||
@@ -358,46 +326,26 @@ func TestAttrType_String(t *testing.T) {
|
||||
AttrFingerprint,
|
||||
}
|
||||
for _, k := range attrType {
|
||||
if k.String() == "" {
|
||||
t.Error(k, "bad stringer")
|
||||
}
|
||||
if strings.HasPrefix(k.String(), "0x") {
|
||||
t.Error(k, "bad stringer")
|
||||
}
|
||||
assert.NotEmpty(t, k.String(), "%v bad stringer", k)
|
||||
assert.False(t, strings.HasPrefix(k.String(), "0x"), "%v bad stringer", k)
|
||||
}
|
||||
vNonStandard := AttrType(0x512)
|
||||
if !strings.HasPrefix(vNonStandard.String(), "0x512") {
|
||||
t.Error(vNonStandard, "bad prefix")
|
||||
}
|
||||
assert.True(t, strings.HasPrefix(vNonStandard.String(), "0x512"), "%v bad prefix", vNonStandard)
|
||||
}
|
||||
|
||||
func TestMethod_String(t *testing.T) {
|
||||
if MethodBinding.String() != "Binding" {
|
||||
t.Error("binding is not binding!")
|
||||
}
|
||||
if Method(0x616).String() != "0x616" {
|
||||
t.Error("Bad stringer", Method(0x616))
|
||||
}
|
||||
assert.Equal(t, "Binding", MethodBinding.String(), "binding is not binding!")
|
||||
assert.Equal(t, "0x616", Method(0x616).String(), "Bad stringer")
|
||||
}
|
||||
|
||||
func TestAttribute_Equal(t *testing.T) {
|
||||
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
|
||||
attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}}
|
||||
if !attr1.Equal(attr2) {
|
||||
t.Error("should equal")
|
||||
}
|
||||
if attr1.Equal(RawAttribute{Type: 0x2}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
if attr1.Equal(RawAttribute{Length: 0x2}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
if attr1.Equal(RawAttribute{Length: 0x3}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
if attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
|
||||
attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
|
||||
assert.True(t, attr1.Equal(attr2))
|
||||
assert.False(t, attr1.Equal(RawAttribute{Type: 0x2}))
|
||||
assert.False(t, attr1.Equal(RawAttribute{Length: 0x2}))
|
||||
assert.False(t, attr1.Equal(RawAttribute{Length: 0x3}))
|
||||
assert.False(t, attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}))
|
||||
}
|
||||
|
||||
func TestMessage_Equal(t *testing.T) { //nolint:cyclop
|
||||
@@ -405,39 +353,23 @@ func TestMessage_Equal(t *testing.T) { //nolint:cyclop
|
||||
attrs := Attributes{attr}
|
||||
msg1 := &Message{Attributes: attrs, Length: 4 + 2}
|
||||
msg2 := &Message{Attributes: attrs, Length: 4 + 2}
|
||||
if !msg1.Equal(msg2) {
|
||||
t.Error("should equal")
|
||||
}
|
||||
if msg1.Equal(&Message{Type: MessageType{Class: 128}}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
assert.True(t, msg1.Equal(msg2))
|
||||
assert.False(t, msg1.Equal(&Message{Type: MessageType{Class: 128}}))
|
||||
tID := [TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
}
|
||||
if msg1.Equal(&Message{TransactionID: tID}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
if msg1.Equal(&Message{Length: 3}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
assert.False(t, msg1.Equal(&Message{TransactionID: tID}))
|
||||
assert.False(t, msg1.Equal(&Message{Length: 3}))
|
||||
tAttrs := Attributes{
|
||||
{Length: 1, Value: []byte{0x1}, Type: 0x1},
|
||||
}
|
||||
if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
assert.False(t, msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}))
|
||||
tAttrs = Attributes{
|
||||
{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2},
|
||||
}
|
||||
if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
if !(*Message)(nil).Equal(nil) {
|
||||
t.Error("nil should be equal to nil")
|
||||
}
|
||||
if msg1.Equal(nil) {
|
||||
t.Error("non-nil should not be equal to nil")
|
||||
}
|
||||
assert.False(t, msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}))
|
||||
assert.True(t, (*Message)(nil).Equal(nil), "nil should be equal to nil")
|
||||
assert.False(t, msg1.Equal(nil), "non-nil should not be equal to nil")
|
||||
t.Run("Nil attributes", func(t *testing.T) {
|
||||
msg1 := &Message{
|
||||
Attributes: nil,
|
||||
@@ -447,61 +379,43 @@ func TestMessage_Equal(t *testing.T) { //nolint:cyclop
|
||||
Attributes: attrs,
|
||||
Length: 4 + 2,
|
||||
}
|
||||
if msg1.Equal(msg2) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
if msg2.Equal(msg1) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
assert.False(t, msg1.Equal(msg2))
|
||||
assert.False(t, msg2.Equal(msg1))
|
||||
msg2.Attributes = nil
|
||||
if !msg1.Equal(msg2) {
|
||||
t.Error("should equal")
|
||||
}
|
||||
assert.True(t, msg1.Equal(msg2))
|
||||
})
|
||||
t.Run("Attributes length", func(t *testing.T) {
|
||||
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
|
||||
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
|
||||
a := &Message{Attributes: Attributes{attr}, Length: 4 + 2}
|
||||
b := &Message{Attributes: Attributes{attr, attr1}, Length: 4 + 2}
|
||||
if a.Equal(b) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
})
|
||||
t.Run("Attributes values", func(t *testing.T) {
|
||||
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
|
||||
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x1}
|
||||
a := &Message{Attributes: Attributes{attr, attr}, Length: 4 + 2}
|
||||
b := &Message{Attributes: Attributes{attr, attr1}, Length: 4 + 2}
|
||||
if a.Equal(b) {
|
||||
t.Error("should not equal")
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMessageGrow(t *testing.T) {
|
||||
m := New()
|
||||
m.grow(512)
|
||||
if len(m.Raw) < 512 {
|
||||
t.Error("Bad length", len(m.Raw))
|
||||
}
|
||||
assert.GreaterOrEqual(t, len(m.Raw), 512)
|
||||
}
|
||||
|
||||
func TestMessageGrowSmaller(t *testing.T) {
|
||||
m := New()
|
||||
m.grow(2)
|
||||
if cap(m.Raw) < 20 {
|
||||
t.Error("Bad capacity", cap(m.Raw))
|
||||
}
|
||||
if len(m.Raw) < 20 {
|
||||
t.Error("Bad length", len(m.Raw))
|
||||
}
|
||||
assert.GreaterOrEqual(t, cap(m.Raw), 20)
|
||||
assert.GreaterOrEqual(t, len(m.Raw), 20)
|
||||
}
|
||||
|
||||
func TestMessage_String(t *testing.T) {
|
||||
m := New()
|
||||
if m.String() == "" {
|
||||
t.Error("bad string")
|
||||
}
|
||||
assert.NotEmpty(t, m.String())
|
||||
}
|
||||
|
||||
func TestIsMessage(t *testing.T) {
|
||||
@@ -525,9 +439,7 @@ func TestIsMessage(t *testing.T) {
|
||||
}, true}, // 6
|
||||
}
|
||||
for i, v := range tt {
|
||||
if got := IsMessage(v.in); got != v.out {
|
||||
t.Errorf("tt[%d]: IsMessage(%+v) %v != %v", i, v.in, got, v.out)
|
||||
}
|
||||
assert.Equal(t, v.out, IsMessage(v.in), "tt[%d]: IsMessage(%+v)", i, v.in)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -553,18 +465,12 @@ func loadData(tb testing.TB, name string) []byte {
|
||||
|
||||
name = filepath.Join("testdata", name)
|
||||
f, err := os.Open(name) //nolint:gosec
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
assert.NoError(tb, err)
|
||||
defer func() {
|
||||
if errClose := f.Close(); errClose != nil {
|
||||
tb.Fatal(errClose)
|
||||
}
|
||||
assert.NoError(tb, f.Close())
|
||||
}()
|
||||
v, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
assert.NoError(tb, err)
|
||||
|
||||
return v
|
||||
}
|
||||
@@ -573,9 +479,7 @@ func TestExampleChrome(t *testing.T) {
|
||||
buf := loadData(t, "ex1_chrome.stun")
|
||||
m := New()
|
||||
_, err := m.Write(buf)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to parse ex1_chrome: %s", err)
|
||||
}
|
||||
assert.NoError(t, err, "Failed to parse ex1_chrome")
|
||||
}
|
||||
|
||||
func TestMessageFromBrowsers(t *testing.T) {
|
||||
@@ -583,9 +487,7 @@ func TestMessageFromBrowsers(t *testing.T) {
|
||||
reader := csv.NewReader(bytes.NewReader(loadData(t, "frombrowsers.csv")))
|
||||
reader.Comment = '#'
|
||||
_, err := reader.Read() // skipping header
|
||||
if err != nil {
|
||||
t.Fatal("failed to skip header of csv: ", err)
|
||||
}
|
||||
assert.NoError(t, err, "failed to skip header of csv")
|
||||
crcTable := crc64.MakeTable(crc64.ISO)
|
||||
msg := New()
|
||||
for {
|
||||
@@ -593,23 +495,14 @@ func TestMessageFromBrowsers(t *testing.T) {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal("failed to read csv line: ", err)
|
||||
}
|
||||
assert.NoError(t, err, "failed to read csv line")
|
||||
data, err := base64.StdEncoding.DecodeString(line[1])
|
||||
if err != nil {
|
||||
t.Fatal("failed to decode ", line[1], " as base64: ", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
b, err := strconv.ParseUint(line[2], 10, 64)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if b != crc64.Checksum(data, crcTable) {
|
||||
t.Error("crc64 check failed for ", line[1])
|
||||
}
|
||||
if _, err = msg.Write(data); err != nil {
|
||||
t.Error("failed to decode ", line[1], " as message: ", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, b, crc64.Checksum(data, crcTable), "crc64 check failed for %s", line[1])
|
||||
_, err = msg.Write(data)
|
||||
assert.NoError(t, err, "failed to decode %s as message: %s", line[1], err)
|
||||
msg.Reset()
|
||||
}
|
||||
}
|
||||
@@ -619,9 +512,7 @@ func BenchmarkMessage_NewTransactionID(b *testing.B) {
|
||||
m := new(Message)
|
||||
m.WriteHeader()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := m.NewTransactionID(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
assert.NoError(b, m.NewTransactionID())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -633,12 +524,8 @@ func BenchmarkMessageFull(b *testing.B) {
|
||||
IP: net.IPv4(213, 1, 223, 5),
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := addr.AddTo(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := s.AddTo(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
assert.NoError(b, addr.AddTo(msg))
|
||||
assert.NoError(b, s.AddTo(msg))
|
||||
msg.WriteAttributes()
|
||||
msg.WriteHeader()
|
||||
Fingerprint.AddTo(msg) //nolint:errcheck,gosec
|
||||
@@ -655,12 +542,8 @@ func BenchmarkMessageFullHardcore(b *testing.B) {
|
||||
IP: net.IPv4(213, 1, 223, 5),
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := addr.AddTo(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if err := s.AddTo(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
assert.NoError(b, addr.AddTo(msg))
|
||||
assert.NoError(b, s.AddTo(msg))
|
||||
msg.WriteHeader()
|
||||
msg.Reset()
|
||||
}
|
||||
@@ -684,12 +567,8 @@ func BenchmarkMessage_WriteHeader(b *testing.B) {
|
||||
func TestMessage_Contains(t *testing.T) {
|
||||
m := new(Message)
|
||||
m.Add(AttrSoftware, []byte("value"))
|
||||
if !m.Contains(AttrSoftware) {
|
||||
t.Error("message should contain software")
|
||||
}
|
||||
if m.Contains(AttrNonce) {
|
||||
t.Error("message should not contain nonce")
|
||||
}
|
||||
assert.True(t, m.Contains(AttrSoftware), "message should contain software")
|
||||
assert.False(t, m.Contains(AttrNonce), "message should not contain nonce")
|
||||
}
|
||||
|
||||
func ExampleMessage() {
|
||||
@@ -787,13 +666,9 @@ func TestAllocations(t *testing.T) {
|
||||
allocs := testing.AllocsPerRun(10, func() {
|
||||
m.Reset()
|
||||
m.WriteHeader()
|
||||
if err := s.AddTo(m); err != nil {
|
||||
t.Errorf("[%d] failed to add", i)
|
||||
}
|
||||
assert.NoError(t, s.AddTo(m), "[%d] failed to add", i)
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Errorf("[%d] allocated %.0f", i, allocs)
|
||||
}
|
||||
assert.Zero(t, allocs, "[%d] allocated", i)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -818,9 +693,7 @@ func TestAllocationsGetters(t *testing.T) {
|
||||
Fingerprint,
|
||||
}
|
||||
msg := New()
|
||||
if err := msg.Build(setters...); err != nil {
|
||||
t.Error("failed to build", err)
|
||||
}
|
||||
assert.NoError(t, msg.Build(setters...))
|
||||
getters := []Getter{
|
||||
new(Nonce),
|
||||
new(Username),
|
||||
@@ -832,66 +705,48 @@ func TestAllocationsGetters(t *testing.T) {
|
||||
g := g
|
||||
i := i
|
||||
allocs := testing.AllocsPerRun(10, func() {
|
||||
if err := g.GetFrom(msg); err != nil {
|
||||
t.Errorf("[%d] failed to get", i)
|
||||
}
|
||||
assert.NoError(t, g.GetFrom(msg), "[%d] failed to get", i)
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Errorf("[%d] allocated %.0f", i, allocs)
|
||||
}
|
||||
assert.Zero(t, allocs, "[%d] allocated", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageFullSize(t *testing.T) {
|
||||
msg := new(Message)
|
||||
if err := msg.Build(BindingRequest,
|
||||
assert.NoError(t, msg.Build(BindingRequest,
|
||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||
}),
|
||||
NewSoftware("pion/stun"),
|
||||
NewLongTermIntegrity("username", "realm", "password"),
|
||||
Fingerprint,
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
))
|
||||
msg.Raw = msg.Raw[:len(msg.Raw)-10]
|
||||
|
||||
decoder := new(Message)
|
||||
decoder.Raw = msg.Raw[:len(msg.Raw)-10]
|
||||
if err := decoder.Decode(); err == nil {
|
||||
t.Error("decode on truncated buffer should error")
|
||||
}
|
||||
assert.Error(t, decoder.Decode(), "decode on truncated buffer should error")
|
||||
}
|
||||
|
||||
func TestMessage_CloneTo(t *testing.T) {
|
||||
msg := new(Message)
|
||||
if err := msg.Build(BindingRequest,
|
||||
assert.NoError(t, msg.Build(BindingRequest,
|
||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||
}),
|
||||
NewSoftware("pion/stun"),
|
||||
NewLongTermIntegrity("username", "realm", "password"),
|
||||
Fingerprint,
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
))
|
||||
msg.Encode()
|
||||
msg2 := new(Message)
|
||||
if err := msg.CloneTo(msg2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !msg2.Equal(msg) {
|
||||
t.Fatal("not equal")
|
||||
}
|
||||
assert.NoError(t, msg.CloneTo(msg2))
|
||||
assert.True(t, msg2.Equal(msg), "cloned message should equal original")
|
||||
// Corrupting m and checking that b is not corrupted.
|
||||
s, ok := msg2.Attributes.Get(AttrSoftware)
|
||||
if !ok {
|
||||
t.Fatal("no software attribute")
|
||||
}
|
||||
assert.True(t, ok)
|
||||
s.Value[0] = 'k'
|
||||
if msg2.Equal(msg) {
|
||||
t.Fatal("should not be equal")
|
||||
}
|
||||
assert.False(t, msg2.Equal(msg), "should not be equal")
|
||||
}
|
||||
|
||||
func BenchmarkMessage_CloneTo(b *testing.B) {
|
||||
@@ -919,29 +774,21 @@ func BenchmarkMessage_CloneTo(b *testing.B) {
|
||||
|
||||
func TestMessage_AddTo(t *testing.T) {
|
||||
msg := new(Message)
|
||||
if err := msg.Build(BindingRequest,
|
||||
assert.NoError(t, msg.Build(BindingRequest,
|
||||
NewTransactionIDSetter([TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
|
||||
}),
|
||||
Fingerprint,
|
||||
); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
))
|
||||
msg.Encode()
|
||||
b := new(Message)
|
||||
if err := msg.CloneTo(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, msg.CloneTo(b))
|
||||
msg.TransactionID = [TransactionIDSize]byte{
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2,
|
||||
}
|
||||
if b.Equal(msg) {
|
||||
t.Fatal("should not be equal")
|
||||
}
|
||||
assert.False(t, b.Equal(msg), "should not be equal")
|
||||
msg.AddTo(b) //nolint:errcheck,gosec
|
||||
if !b.Equal(msg) {
|
||||
t.Fatal("should be equal")
|
||||
}
|
||||
assert.True(t, b.Equal(msg), "should be equal")
|
||||
}
|
||||
|
||||
func BenchmarkMessage_AddTo(b *testing.B) {
|
||||
@@ -966,9 +813,7 @@ func BenchmarkMessage_AddTo(b *testing.B) {
|
||||
|
||||
func TestDecode(t *testing.T) {
|
||||
t.Run("Nil", func(t *testing.T) {
|
||||
if err := Decode(nil, nil); !errors.Is(err, ErrDecodeToNil) {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
assert.ErrorIs(t, Decode(nil, nil), ErrDecodeToNil)
|
||||
})
|
||||
msg := New()
|
||||
msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||
@@ -976,22 +821,14 @@ func TestDecode(t *testing.T) {
|
||||
msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
|
||||
msg.WriteHeader()
|
||||
mDecoded := New()
|
||||
if err := Decode(msg.Raw, mDecoded); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if !mDecoded.Equal(msg) {
|
||||
t.Error("decoded result is not equal to encoded message")
|
||||
}
|
||||
assert.NoError(t, Decode(msg.Raw, mDecoded))
|
||||
assert.True(t, mDecoded.Equal(msg), "decoded result is not equal to encoded message")
|
||||
t.Run("ZeroAlloc", func(t *testing.T) {
|
||||
allocs := testing.AllocsPerRun(10, func() {
|
||||
mDecoded.Reset()
|
||||
if err := Decode(msg.Raw, mDecoded); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, Decode(msg.Raw, mDecoded))
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Error("unexpected allocations")
|
||||
}
|
||||
assert.Zero(t, allocs, "unexpected allocations")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1021,25 +858,19 @@ func TestMessage_MarshalBinary(t *testing.T) {
|
||||
},
|
||||
)
|
||||
data, err := msg.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Reset m.Raw to check retention.
|
||||
for i := range msg.Raw {
|
||||
msg.Raw[i] = 0
|
||||
}
|
||||
if err := msg.UnmarshalBinary(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, msg.UnmarshalBinary(data))
|
||||
|
||||
// Reset data to check retention.
|
||||
for i := range data {
|
||||
data[i] = 0
|
||||
}
|
||||
if err := msg.Decode(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, msg.Decode())
|
||||
}
|
||||
|
||||
func TestMessage_GobDecode(t *testing.T) {
|
||||
@@ -1050,23 +881,17 @@ func TestMessage_GobDecode(t *testing.T) {
|
||||
},
|
||||
)
|
||||
data, err := msg.GobEncode()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Reset m.Raw to check retention.
|
||||
for i := range msg.Raw {
|
||||
msg.Raw[i] = 0
|
||||
}
|
||||
if err := msg.GobDecode(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, msg.GobDecode(data))
|
||||
|
||||
// Reset data to check retention.
|
||||
for i := range data {
|
||||
data[i] = 0
|
||||
}
|
||||
if err := msg.Decode(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, msg.Decode())
|
||||
}
|
||||
|
||||
120
rfc5769_test.go
120
rfc5769_test.go
@@ -6,6 +6,8 @@ package stun
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRFC5769(t *testing.T) { //nolint:cyclop
|
||||
@@ -32,19 +34,11 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop
|
||||
"\xe5\x7a\x3b\xcf",
|
||||
),
|
||||
}
|
||||
if err := m.Decode(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, m.Decode())
|
||||
software := new(Software)
|
||||
if err := software.GetFrom(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if software.String() != "STUN test client" {
|
||||
t.Error("bad software: ", software)
|
||||
}
|
||||
if err := Fingerprint.Check(m); err != nil {
|
||||
t.Error("check failed: ", err)
|
||||
}
|
||||
assert.NoError(t, software.GetFrom(m))
|
||||
assert.Equal(t, "STUN test client", software.String())
|
||||
assert.NoError(t, Fingerprint.Check(m))
|
||||
t.Run("Long-Term credentials", func(t *testing.T) {
|
||||
msg := &Message{
|
||||
Raw: []byte("\x00\x01\x00\x60" +
|
||||
@@ -64,40 +58,24 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop
|
||||
"\x2e\x85\xc9\xa2\x8c\xa8\x96\x66",
|
||||
),
|
||||
}
|
||||
if err := msg.Decode(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, msg.Decode())
|
||||
u := new(Username)
|
||||
if err := u.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, u.GetFrom(msg))
|
||||
expectedUsername := "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9"
|
||||
if u.String() != expectedUsername {
|
||||
t.Errorf("username: %q (got) != %q (exp)", u, expectedUsername)
|
||||
}
|
||||
assert.Equal(t, expectedUsername, u.String())
|
||||
n := new(Nonce)
|
||||
if err := n.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if n.String() != "f//499k954d6OL34oL9FSTvy64sA" {
|
||||
t.Error("bad nonce")
|
||||
}
|
||||
assert.NoError(t, n.GetFrom(msg))
|
||||
assert.Equal(t, "f//499k954d6OL34oL9FSTvy64sA", n.String())
|
||||
r := new(Realm)
|
||||
if err := r.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if r.String() != "example.org" { //nolint:goconst
|
||||
t.Error("bad realm")
|
||||
}
|
||||
assert.NoError(t, r.GetFrom(msg))
|
||||
assert.Equal(t, "example.org", r.String())
|
||||
// checking HMAC
|
||||
i := NewLongTermIntegrity(
|
||||
"\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9",
|
||||
"example.org",
|
||||
"TheMatrIX",
|
||||
)
|
||||
if err := i.Check(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, i.Check(msg))
|
||||
})
|
||||
})
|
||||
t.Run("Response", func(t *testing.T) {
|
||||
@@ -117,32 +95,18 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop
|
||||
"\xc0\x7d\x4c\x96",
|
||||
),
|
||||
}
|
||||
if err := msg.Decode(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, msg.Decode())
|
||||
|
||||
software := new(Software)
|
||||
if err := software.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if software.String() != "test vector" {
|
||||
t.Error("bad software: ", software)
|
||||
}
|
||||
if err := Fingerprint.Check(msg); err != nil {
|
||||
t.Error("Check failed: ", err)
|
||||
}
|
||||
assert.NoError(t, software.GetFrom(msg))
|
||||
assert.Equal(t, "test vector", software.String())
|
||||
assert.NoError(t, Fingerprint.Check(msg))
|
||||
addr := new(XORMappedAddress)
|
||||
if err := addr.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !addr.IP.Equal(net.ParseIP("192.0.2.1")) {
|
||||
t.Error("bad IP")
|
||||
}
|
||||
if addr.Port != 32853 {
|
||||
t.Error("bad Port")
|
||||
}
|
||||
if err := Fingerprint.Check(msg); err != nil {
|
||||
t.Error("check failed: ", err)
|
||||
}
|
||||
assert.NoError(t, addr.GetFrom(msg))
|
||||
expected := "192.0.2.1"
|
||||
assert.Equalf(t, expected, addr.IP.String(), "Expected %s, got %s", expected, addr.IP)
|
||||
assert.Equal(t, 32853, addr.Port)
|
||||
assert.NoError(t, Fingerprint.Check(msg))
|
||||
})
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
msg := &Message{
|
||||
@@ -162,32 +126,20 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop
|
||||
"\xc8\xfb\x0b\x4c",
|
||||
),
|
||||
}
|
||||
if err := msg.Decode(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, msg.Decode())
|
||||
software := new(Software)
|
||||
if err := software.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if software.String() != "test vector" {
|
||||
t.Error("bad software: ", software)
|
||||
}
|
||||
if err := Fingerprint.Check(msg); err != nil {
|
||||
t.Error("Check failed: ", err)
|
||||
}
|
||||
assert.NoError(t, software.GetFrom(msg))
|
||||
assert.Equal(t, "test vector", software.String())
|
||||
assert.NoError(t, Fingerprint.Check(msg))
|
||||
addr := new(XORMappedAddress)
|
||||
if err := addr.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !addr.IP.Equal(net.ParseIP("2001:db8:1234:5678:11:2233:4455:6677")) {
|
||||
t.Error("bad IP")
|
||||
}
|
||||
if addr.Port != 32853 {
|
||||
t.Error("bad Port")
|
||||
}
|
||||
if err := Fingerprint.Check(msg); err != nil {
|
||||
t.Error("check failed: ", err)
|
||||
}
|
||||
assert.NoError(t, addr.GetFrom(msg))
|
||||
expectedIP := "2001:db8:1234:5678:11:2233:4455:6677"
|
||||
assert.Truef(
|
||||
t, addr.IP.Equal(net.ParseIP(expectedIP)),
|
||||
"Expected %s, got %s", expectedIP, addr.IP,
|
||||
)
|
||||
assert.Equal(t, 32853, addr.Port)
|
||||
assert.NoError(t, Fingerprint.Check(msg))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
10
stun_test.go
10
stun_test.go
@@ -6,6 +6,8 @@ package stun
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type errorReader struct{}
|
||||
@@ -21,9 +23,7 @@ func (errorReader) Read([]byte) (int, error) {
|
||||
|
||||
func TestReadFullHelper(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("should panic")
|
||||
}
|
||||
assert.NotNil(t, recover(), "should panic")
|
||||
}()
|
||||
readFullOrPanic(errorReader{}, make([]byte, 1))
|
||||
}
|
||||
@@ -36,9 +36,7 @@ func (errorWriter) Write([]byte) (int, error) {
|
||||
|
||||
func TestWriteHelper(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("should panic")
|
||||
}
|
||||
assert.NotNil(t, recover(), "should panic")
|
||||
}()
|
||||
writeOrPanic(errorWriter{}, make([]byte, 1))
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var errUDPServerUnsupportedNetwork = errors.New("unsupported network")
|
||||
@@ -37,9 +39,7 @@ func NewUDPServer(
|
||||
}
|
||||
|
||||
udpConn, err := net.ListenUDP(network, &net.UDPAddr{IP: net.ParseIP(ip), Port: 0})
|
||||
if err != nil {
|
||||
t.Fatal(err) //nolint:forbidigo
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Necessary for IPv6
|
||||
address := fmt.Sprintf("%s:%d", ip, udpConn.LocalAddr().(*net.UDPAddr).Port) //nolint:forcetypeassert
|
||||
@@ -81,18 +81,14 @@ func NewUDPServer(
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
err := udpConn.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assert.NoError(t, udpConn.Close())
|
||||
<-errCh
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -7,9 +7,10 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSoftware_GetFrom(t *testing.T) {
|
||||
@@ -22,44 +23,29 @@ func TestSoftware_GetFrom(t *testing.T) {
|
||||
Raw: make([]byte, 0, 256),
|
||||
}
|
||||
software := new(Software)
|
||||
if _, err := m2.ReadFrom(msg.reader()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := software.GetFrom(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if software.String() != val {
|
||||
t.Errorf("Expected %q, got %q.", val, software)
|
||||
}
|
||||
_, err := m2.ReadFrom(msg.reader())
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, software.GetFrom(msg))
|
||||
assert.Equal(t, val, software.String())
|
||||
|
||||
sAttr, ok := msg.Attributes.Get(AttrSoftware)
|
||||
if !ok {
|
||||
t.Error("software attribute should be found")
|
||||
}
|
||||
assert.True(t, ok, "software attribute should be found")
|
||||
s := sAttr.String()
|
||||
if !strings.HasPrefix(s, "SOFTWARE:") {
|
||||
t.Error("bad string representation", s)
|
||||
}
|
||||
assert.True(t, strings.HasPrefix(s, "SOFTWARE:"), "bad string representation")
|
||||
}
|
||||
|
||||
func TestSoftware_AddTo_Invalid(t *testing.T) {
|
||||
m := New()
|
||||
s := make(Software, 1024)
|
||||
if err := s.AddTo(m); !IsAttrSizeOverflow(err) {
|
||||
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
|
||||
}
|
||||
if err := s.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
|
||||
}
|
||||
assert.True(t, IsAttrSizeOverflow(s.AddTo(m)), "AddTo should return *AttrOverflowErr")
|
||||
assert.ErrorIs(t, s.GetFrom(m), ErrAttributeNotFound)
|
||||
}
|
||||
|
||||
func TestSoftware_AddTo_Regression(t *testing.T) {
|
||||
// s.AddTo checked len(m.Raw) instead of len(s.Raw).
|
||||
m := &Message{Raw: make([]byte, 2048)}
|
||||
s := make(Software, 100)
|
||||
if err := s.AddTo(m); err != nil {
|
||||
t.Errorf("AddTo should return <nil>, got: %v", err)
|
||||
}
|
||||
assert.NoError(t, s.AddTo(m))
|
||||
}
|
||||
|
||||
func BenchmarkUsername_AddTo(b *testing.B) {
|
||||
@@ -95,28 +81,18 @@ func TestUsername(t *testing.T) {
|
||||
msg.WriteHeader()
|
||||
t.Run("Bad length", func(t *testing.T) {
|
||||
badU := make(Username, 600)
|
||||
if err := badU.AddTo(msg); !IsAttrSizeOverflow(err) {
|
||||
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
|
||||
}
|
||||
assert.True(t, IsAttrSizeOverflow(badU.AddTo(msg)), "AddTo should return *AttrOverflowErr")
|
||||
})
|
||||
t.Run("AddTo", func(t *testing.T) {
|
||||
if err := uName.AddTo(msg); err != nil {
|
||||
t.Error("errored:", err)
|
||||
}
|
||||
assert.NoError(t, uName.AddTo(msg))
|
||||
t.Run("GetFrom", func(t *testing.T) {
|
||||
got := new(Username)
|
||||
if err := got.GetFrom(msg); err != nil {
|
||||
t.Error("errored:", err)
|
||||
}
|
||||
if got.String() != username {
|
||||
t.Errorf("expedted: %s, got: %s", username, got)
|
||||
}
|
||||
assert.NoError(t, got.GetFrom(msg))
|
||||
assert.Equal(t, username, got.String())
|
||||
t.Run("Not found", func(t *testing.T) {
|
||||
m := new(Message)
|
||||
u := new(Username)
|
||||
if err := u.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Error("Should error")
|
||||
}
|
||||
assert.ErrorIs(t, u.GetFrom(m), ErrAttributeNotFound)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -124,14 +100,10 @@ func TestUsername(t *testing.T) {
|
||||
m := new(Message)
|
||||
m.WriteHeader()
|
||||
u := NewUsername("username")
|
||||
if allocs := testing.AllocsPerRun(10, func() {
|
||||
if err := u.AddTo(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Empty(t, testing.AllocsPerRun(10, func() {
|
||||
assert.NoError(t, u.AddTo(m))
|
||||
m.Reset()
|
||||
}); allocs > 0 {
|
||||
t.Errorf("got %f allocations, zero expected", allocs)
|
||||
}
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -145,38 +117,23 @@ func TestRealm_GetFrom(t *testing.T) {
|
||||
Raw: make([]byte, 0, 256),
|
||||
}
|
||||
r := new(Realm)
|
||||
if err := r.GetFrom(m2); !errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
|
||||
}
|
||||
if _, err := m2.ReadFrom(msg.reader()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := r.GetFrom(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if r.String() != val {
|
||||
t.Errorf("Expected %q, got %q.", val, r)
|
||||
}
|
||||
assert.ErrorIs(t, r.GetFrom(m2), ErrAttributeNotFound)
|
||||
_, err := m2.ReadFrom(msg.reader())
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, r.GetFrom(msg))
|
||||
assert.Equal(t, val, r.String())
|
||||
|
||||
rAttr, ok := msg.Attributes.Get(AttrRealm)
|
||||
if !ok {
|
||||
t.Error("realm attribute should be found")
|
||||
}
|
||||
assert.True(t, ok, "realm attribute should be found")
|
||||
s := rAttr.String()
|
||||
if !strings.HasPrefix(s, "REALM:") {
|
||||
t.Error("bad string representation", s)
|
||||
}
|
||||
assert.True(t, strings.HasPrefix(s, "REALM:"), "bad string representation")
|
||||
}
|
||||
|
||||
func TestRealm_AddTo_Invalid(t *testing.T) {
|
||||
m := New()
|
||||
r := make(Realm, 1024)
|
||||
if err := r.AddTo(m); !IsAttrSizeOverflow(err) {
|
||||
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
|
||||
}
|
||||
if err := r.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
|
||||
}
|
||||
assert.True(t, IsAttrSizeOverflow(r.AddTo(m)), "AddTo should return *AttrOverflowErr")
|
||||
assert.ErrorIs(t, r.GetFrom(m), ErrAttributeNotFound)
|
||||
}
|
||||
|
||||
func TestNonce_GetFrom(t *testing.T) {
|
||||
@@ -189,50 +146,31 @@ func TestNonce_GetFrom(t *testing.T) {
|
||||
Raw: make([]byte, 0, 256),
|
||||
}
|
||||
var nonce Nonce
|
||||
if _, err := m2.ReadFrom(msg.reader()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := nonce.GetFrom(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if nonce.String() != val {
|
||||
t.Errorf("Expected %q, got %q.", val, nonce)
|
||||
}
|
||||
_, err := m2.ReadFrom(msg.reader())
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, nonce.GetFrom(msg))
|
||||
assert.Equal(t, val, nonce.String())
|
||||
|
||||
nAttr, ok := msg.Attributes.Get(AttrNonce)
|
||||
if !ok {
|
||||
t.Error("nonce attribute should be found")
|
||||
}
|
||||
assert.True(t, ok, "nonce attribute should be found")
|
||||
s := nAttr.String()
|
||||
if !strings.HasPrefix(s, "NONCE:") {
|
||||
t.Error("bad string representation", s)
|
||||
}
|
||||
assert.True(t, strings.HasPrefix(s, "NONCE:"), "bad string representation")
|
||||
}
|
||||
|
||||
func TestNonce_AddTo_Invalid(t *testing.T) {
|
||||
m := New()
|
||||
n := make(Nonce, 1024)
|
||||
if err := n.AddTo(m); !IsAttrSizeOverflow(err) {
|
||||
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
|
||||
}
|
||||
if err := n.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) {
|
||||
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
|
||||
}
|
||||
assert.True(t, IsAttrSizeOverflow(n.AddTo(m)), "AddTo should return *AttrOverflowErr")
|
||||
assert.ErrorIs(t, n.GetFrom(m), ErrAttributeNotFound)
|
||||
}
|
||||
|
||||
func TestNonce_AddTo(t *testing.T) {
|
||||
m := New()
|
||||
n := Nonce("example.org")
|
||||
if err := n.AddTo(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, n.AddTo(m))
|
||||
v, err := m.Get(AttrNonce)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if string(v) != "example.org" {
|
||||
t.Errorf("bad nonce %q", v)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "example.org", string(v))
|
||||
}
|
||||
|
||||
func BenchmarkNonce_AddTo(b *testing.B) {
|
||||
|
||||
@@ -5,6 +5,8 @@ package stun
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUnknownAttributes(t *testing.T) {
|
||||
@@ -13,33 +15,19 @@ func TestUnknownAttributes(t *testing.T) {
|
||||
AttrDontFragment,
|
||||
AttrChannelNumber,
|
||||
}
|
||||
if attr.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" {
|
||||
t.Error("bad String:", attr)
|
||||
}
|
||||
if (UnknownAttributes{}).String() != "<nil>" {
|
||||
t.Error("bad blank string")
|
||||
}
|
||||
if err := attr.AddTo(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Equal(t, "DONT-FRAGMENT, CHANNEL-NUMBER", attr.String())
|
||||
assert.Equal(t, "<nil>", (UnknownAttributes{}).String())
|
||||
assert.NoError(t, attr.AddTo(msg))
|
||||
t.Run("GetFrom", func(t *testing.T) {
|
||||
attrs := make(UnknownAttributes, 10)
|
||||
if err := attrs.GetFrom(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, attrs.GetFrom(msg))
|
||||
for i, at := range *attr {
|
||||
if at != attrs[i] {
|
||||
t.Error("expected", at, "!=", attrs[i])
|
||||
}
|
||||
assert.Equal(t, at, attrs[i])
|
||||
}
|
||||
mBlank := new(Message)
|
||||
if err := attrs.GetFrom(mBlank); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
assert.Error(t, attrs.GetFrom(mBlank))
|
||||
mBlank.Add(AttrUnknownAttributes, []byte{1, 2, 3})
|
||||
if err := attrs.GetFrom(mBlank); err == nil {
|
||||
t.Error("should error")
|
||||
}
|
||||
assert.Error(t, attrs.GetFrom(mBlank))
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
120
xoraddr_test.go
120
xoraddr_test.go
@@ -11,10 +11,11 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
|
||||
@@ -31,82 +32,56 @@ func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
|
||||
func BenchmarkXORMappedAddress_GetFrom(b *testing.B) {
|
||||
msg := New()
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
assert.NoError(b, err)
|
||||
copy(msg.TransactionID[:], transactionID)
|
||||
addrValue, err := hex.DecodeString("00019cd5f49f38ae")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
assert.NoError(b, err)
|
||||
msg.Add(AttrXORMappedAddress, addrValue)
|
||||
addr := new(XORMappedAddress)
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := addr.GetFrom(msg); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
assert.NoError(b, addr.GetFrom(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestXORMappedAddress_GetFrom(t *testing.T) {
|
||||
m := New()
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
copy(m.TransactionID[:], transactionID)
|
||||
addrValue, err := hex.DecodeString("00019cd5f49f38ae")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
m.Add(AttrXORMappedAddress, addrValue)
|
||||
addr := new(XORMappedAddress)
|
||||
if err = addr.GetFrom(m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !addr.IP.Equal(net.ParseIP("213.141.156.236")) {
|
||||
t.Error("bad IP", addr.IP, "!=", "213.141.156.236")
|
||||
}
|
||||
if addr.Port != 48583 {
|
||||
t.Error("bad Port", addr.Port, "!=", 48583)
|
||||
}
|
||||
assert.NoError(t, addr.GetFrom(m))
|
||||
assert.True(t, addr.IP.Equal(net.ParseIP("213.141.156.236")))
|
||||
assert.Equal(t, 48583, addr.Port)
|
||||
t.Run("UnexpectedEOF", func(t *testing.T) {
|
||||
m := New()
|
||||
// {0, 1} is correct addr family.
|
||||
m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4})
|
||||
addr := new(XORMappedAddress)
|
||||
if err = addr.GetFrom(m); !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
t.Errorf("len(v) = 4 should render <%s> error, got <%s>",
|
||||
io.ErrUnexpectedEOF, err,
|
||||
)
|
||||
}
|
||||
assert.ErrorIs(t, addr.GetFrom(m), io.ErrUnexpectedEOF, "len(v) = 4 should return io.ErrUnexpectedEOF")
|
||||
})
|
||||
t.Run("AttrOverflowErr", func(t *testing.T) {
|
||||
m := New()
|
||||
// {0, 1} is correct addr family.
|
||||
m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4, 5, 6, 7, 8, 9, 1, 1, 1, 1, 1, 2, 3, 4})
|
||||
addr := new(XORMappedAddress)
|
||||
if err := addr.GetFrom(m); !IsAttrSizeOverflow(err) {
|
||||
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
|
||||
}
|
||||
assert.True(t, IsAttrSizeOverflow(addr.GetFrom(m)), "GetFrom should return *AttrOverflowErr")
|
||||
})
|
||||
}
|
||||
|
||||
func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
|
||||
msg := New()
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
copy(msg.TransactionID[:], transactionID)
|
||||
expectedIP := net.ParseIP("213.141.156.236")
|
||||
expectedPort := 21254
|
||||
addr := new(XORMappedAddress)
|
||||
|
||||
if err = addr.GetFrom(msg); err == nil {
|
||||
t.Fatal(err, "should be nil")
|
||||
}
|
||||
assert.Error(t, addr.GetFrom(msg))
|
||||
|
||||
addr.IP = expectedIP
|
||||
addr.Port = expectedPort
|
||||
@@ -115,20 +90,15 @@ func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
|
||||
|
||||
mRes := New()
|
||||
binary.BigEndian.PutUint16(msg.Raw[20+4:20+4+2], 0x21)
|
||||
if _, err = mRes.ReadFrom(bytes.NewReader(msg.Raw)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = addr.GetFrom(msg); err == nil {
|
||||
t.Fatal(err, "should not be nil")
|
||||
}
|
||||
_, err = mRes.ReadFrom(bytes.NewReader(msg.Raw))
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, addr.GetFrom(msg))
|
||||
}
|
||||
|
||||
func TestXORMappedAddress_AddTo(t *testing.T) {
|
||||
msg := New()
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
copy(msg.TransactionID[:], transactionID)
|
||||
expectedIP := net.ParseIP("213.141.156.236")
|
||||
expectedPort := 21254
|
||||
@@ -136,31 +106,20 @@ func TestXORMappedAddress_AddTo(t *testing.T) {
|
||||
IP: net.ParseIP("213.141.156.236"),
|
||||
Port: expectedPort,
|
||||
}
|
||||
if err = addr.AddTo(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.NoError(t, addr.AddTo(msg))
|
||||
msg.WriteHeader()
|
||||
mRes := New()
|
||||
if _, err = mRes.Write(msg.Raw); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = addr.GetFrom(mRes); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !addr.IP.Equal(expectedIP) {
|
||||
t.Errorf("%s (got) != %s (expected)", addr.IP, expectedIP)
|
||||
}
|
||||
if addr.Port != expectedPort {
|
||||
t.Error("bad Port", addr.Port, "!=", expectedPort)
|
||||
}
|
||||
_, err = mRes.Write(msg.Raw)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, addr.GetFrom(mRes))
|
||||
assert.True(t, addr.IP.Equal(expectedIP), "Expected %s, got %s", expectedIP, addr.IP)
|
||||
assert.Equal(t, expectedPort, addr.Port)
|
||||
}
|
||||
|
||||
func TestXORMappedAddress_AddTo_IPv6(t *testing.T) {
|
||||
msg := New()
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
copy(msg.TransactionID[:], transactionID)
|
||||
expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009")
|
||||
expectedPort := 21254
|
||||
@@ -172,19 +131,12 @@ func TestXORMappedAddress_AddTo_IPv6(t *testing.T) {
|
||||
msg.WriteHeader()
|
||||
|
||||
mRes := New()
|
||||
if _, err = mRes.ReadFrom(msg.reader()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = mRes.ReadFrom(msg.reader())
|
||||
assert.NoError(t, err)
|
||||
gotAddr := new(XORMappedAddress)
|
||||
if err = gotAddr.GetFrom(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !gotAddr.IP.Equal(expectedIP) {
|
||||
t.Error("bad IP", gotAddr.IP, "!=", expectedIP)
|
||||
}
|
||||
if gotAddr.Port != expectedPort {
|
||||
t.Error("bad Port", gotAddr.Port, "!=", expectedPort)
|
||||
}
|
||||
assert.NoError(t, gotAddr.GetFrom(mRes))
|
||||
assert.True(t, gotAddr.IP.Equal(expectedIP), "Expected %s, got %s", expectedIP, gotAddr.IP)
|
||||
assert.Equal(t, expectedPort, gotAddr.Port)
|
||||
}
|
||||
|
||||
func TestXORMappedAddress_AddTo_Invalid(t *testing.T) {
|
||||
@@ -193,9 +145,7 @@ func TestXORMappedAddress_AddTo_Invalid(t *testing.T) {
|
||||
IP: []byte{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
Port: 21254,
|
||||
}
|
||||
if err := addr.AddTo(m); !errors.Is(err, ErrBadIPLength) {
|
||||
t.Errorf("AddTo should return %q, got: %v", ErrBadIPLength, err)
|
||||
}
|
||||
assert.ErrorIs(t, addr.AddTo(m), ErrBadIPLength)
|
||||
}
|
||||
|
||||
func TestXORMappedAddress_String(t *testing.T) {
|
||||
@@ -219,12 +169,6 @@ func TestXORMappedAddress_String(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for i, c := range tt {
|
||||
if got := c.in.String(); got != c.out {
|
||||
t.Errorf("[%d]: XORMappesAddres.String() %s (got) != %s (expected)",
|
||||
i,
|
||||
got,
|
||||
c.out,
|
||||
)
|
||||
}
|
||||
assert.Equalf(t, c.out, c.in.String(), "[%d]: XORMappesAddres.String()", i)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user