Update lint rules, force testify/assert for tests

Use testify's assert package instead of the standard library's testing
package.
This commit is contained in:
Joe Turki
2025-04-07 02:11:06 +02:00
parent f00fc07896
commit 8867eb8597
23 changed files with 620 additions and 1535 deletions

View File

@@ -19,12 +19,16 @@ linters-settings:
recommendations: recommendations:
- errors - errors
forbidigo: forbidigo:
analyze-types: true
forbid: forbid:
- ^fmt.Print(f|ln)?$ - ^fmt.Print(f|ln)?$
- ^log.(Panic|Fatal|Print)(f|ln)?$ - ^log.(Panic|Fatal|Print)(f|ln)?$
- ^os.Exit$ - ^os.Exit$
- ^panic$ - ^panic$
- ^print(ln)?$ - ^print(ln)?$
- p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$
pkg: ^testing$
msg: "use testify/assert instead"
varnamelen: varnamelen:
max-distance: 12 max-distance: 12
min-name-length: 2 min-name-length: 2
@@ -127,9 +131,12 @@ issues:
exclude-dirs-use-default: false exclude-dirs-use-default: false
exclude-rules: exclude-rules:
# Allow complex tests and examples, better to be self contained # Allow complex tests and examples, better to be self contained
- path: (examples|main\.go|_test\.go) - path: (examples|main\.go)
linters: linters:
- gocognit
- forbidigo - forbidigo
- path: _test\.go
linters:
- gocognit - gocognit
# Allow forbidden identifiers in CLI commands # Allow forbidden identifiers in CLI commands

View File

@@ -4,10 +4,11 @@
package stun package stun
import ( import (
"errors"
"io" "io"
"net" "net"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestMappedAddress(t *testing.T) { func TestMappedAddress(t *testing.T) {
@@ -16,48 +17,32 @@ func TestMappedAddress(t *testing.T) {
IP: net.ParseIP("122.12.34.5"), IP: net.ParseIP("122.12.34.5"),
Port: 5412, Port: 5412,
} }
if addr.String() != "122.12.34.5:5412" { assert.Equal(t, "122.12.34.5:5412", addr.String(), "bad string")
t.Error("bad string", addr)
}
t.Run("Bad length", func(t *testing.T) { t.Run("Bad length", func(t *testing.T) {
badAddr := &MappedAddress{ badAddr := &MappedAddress{
IP: net.IP{1, 2, 3}, IP: net.IP{1, 2, 3},
} }
if err := badAddr.AddTo(msg); err == nil { assert.Error(t, badAddr.AddTo(msg), "should error")
t.Error("should error")
}
}) })
t.Run("AddTo", func(t *testing.T) { t.Run("AddTo", func(t *testing.T) {
if err := addr.AddTo(msg); err != nil { assert.NoError(t, addr.AddTo(msg))
t.Error(err)
}
t.Run("GetFrom", func(t *testing.T) { t.Run("GetFrom", func(t *testing.T) {
got := new(MappedAddress) got := new(MappedAddress)
if err := got.GetFrom(msg); err != nil { assert.NoError(t, got.GetFrom(msg))
t.Error(err) assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP)
}
if !got.IP.Equal(addr.IP) {
t.Error("got bad IP: ", got.IP)
}
t.Run("Not found", func(t *testing.T) { t.Run("Not found", func(t *testing.T) {
message := new(Message) message := new(Message)
if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) { assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found")
t.Error("should be not found: ", err)
}
}) })
t.Run("Bad family", func(t *testing.T) { t.Run("Bad family", func(t *testing.T) {
v, _ := msg.Attributes.Get(AttrMappedAddress) v, _ := msg.Attributes.Get(AttrMappedAddress)
v.Value[0] = 32 v.Value[0] = 32
if err := got.GetFrom(msg); err == nil { assert.Error(t, got.GetFrom(msg), "should error")
t.Error("should error")
}
}) })
t.Run("Bad length", func(t *testing.T) { t.Run("Bad length", func(t *testing.T) {
message := new(Message) message := new(Message)
message.Add(AttrMappedAddress, []byte{1, 2, 3}) message.Add(AttrMappedAddress, []byte{1, 2, 3})
if err := got.GetFrom(message); !errors.Is(err, io.ErrUnexpectedEOF) { assert.ErrorIs(t, got.GetFrom(message), io.ErrUnexpectedEOF)
t.Errorf("<%s> should be <%s>", err, io.ErrUnexpectedEOF)
}
}) })
}) })
}) })
@@ -70,22 +55,14 @@ func TestMappedAddressV6(t *testing.T) { //nolint:dupl
Port: 5412, Port: 5412,
} }
t.Run("AddTo", func(t *testing.T) { t.Run("AddTo", func(t *testing.T) {
if err := addr.AddTo(m); err != nil { assert.NoError(t, addr.AddTo(m))
t.Error(err)
}
t.Run("GetFrom", func(t *testing.T) { t.Run("GetFrom", func(t *testing.T) {
got := new(MappedAddress) got := new(MappedAddress)
if err := got.GetFrom(m); err != nil { assert.NoError(t, got.GetFrom(m))
t.Error(err) assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP)
}
if !got.IP.Equal(addr.IP) {
t.Error("got bad IP: ", got.IP)
}
t.Run("Not found", func(t *testing.T) { t.Run("Not found", func(t *testing.T) {
message := new(Message) message := new(Message)
if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) { assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found")
t.Error("should be not found: ", err)
}
}) })
}) })
}) })
@@ -98,22 +75,14 @@ func TestAlternateServer(t *testing.T) { //nolint:dupl
Port: 5412, Port: 5412,
} }
t.Run("AddTo", func(t *testing.T) { t.Run("AddTo", func(t *testing.T) {
if err := addr.AddTo(m); err != nil { assert.NoError(t, addr.AddTo(m))
t.Error(err)
}
t.Run("GetFrom", func(t *testing.T) { t.Run("GetFrom", func(t *testing.T) {
got := new(AlternateServer) got := new(AlternateServer)
if err := got.GetFrom(m); err != nil { assert.NoError(t, got.GetFrom(m))
t.Error(err) assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP)
}
if !got.IP.Equal(addr.IP) {
t.Error("got bad IP: ", got.IP)
}
t.Run("Not found", func(t *testing.T) { t.Run("Not found", func(t *testing.T) {
message := new(Message) message := new(Message)
if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) { assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found")
t.Error("should be not found: ", err)
}
}) })
}) })
}) })
@@ -126,22 +95,14 @@ func TestOtherAddress(t *testing.T) { //nolint:dupl
Port: 5412, Port: 5412,
} }
t.Run("AddTo", func(t *testing.T) { t.Run("AddTo", func(t *testing.T) {
if err := addr.AddTo(m); err != nil { assert.NoError(t, addr.AddTo(m))
t.Error(err)
}
t.Run("GetFrom", func(t *testing.T) { t.Run("GetFrom", func(t *testing.T) {
got := new(OtherAddress) got := new(OtherAddress)
if err := got.GetFrom(m); err != nil { assert.NoError(t, got.GetFrom(m))
t.Error(err) assert.True(t, got.IP.Equal(addr.IP), "got bad IP: %v", got.IP)
}
if !got.IP.Equal(addr.IP) {
t.Error("got bad IP: ", got.IP)
}
t.Run("Not found", func(t *testing.T) { t.Run("Not found", func(t *testing.T) {
message := new(Message) message := new(Message)
if err := got.GetFrom(message); !errors.Is(err, ErrAttributeNotFound) { assert.ErrorIs(t, got.GetFrom(message), ErrAttributeNotFound, "should be not found")
t.Error("should be not found: ", err)
}
}) })
}) })
}) })

View File

@@ -7,84 +7,44 @@ import (
"errors" "errors"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func TestAgent_ProcessInTransaction(t *testing.T) { func TestAgent_ProcessInTransaction(t *testing.T) {
msg := New() msg := New()
agent := NewAgent(func(e Event) { agent := NewAgent(func(e Event) {
if e.Error != nil { assert.NoError(t, e.Error, "got error")
t.Errorf("got error: %s", e.Error) assert.True(t, e.Message.Equal(msg), "%s (got) != %s (expected)", e.Message, msg)
}
if !e.Message.Equal(msg) {
t.Errorf("%s (got) != %s (expected)", e.Message, msg)
}
}) })
if err := msg.NewTransactionID(); err != nil { assert.NoError(t, msg.NewTransactionID())
t.Fatal(err) assert.NoError(t, agent.Start(msg.TransactionID, time.Time{}))
} assert.NoError(t, agent.Process(msg))
if err := agent.Start(msg.TransactionID, time.Time{}); err != nil { assert.NoError(t, agent.Close())
t.Fatal(err)
}
if err := agent.Process(msg); err != nil {
t.Error(err)
}
if err := agent.Close(); err != nil {
t.Error(err)
}
} }
func TestAgent_Process(t *testing.T) { func TestAgent_Process(t *testing.T) {
msg := New() msg := New()
agent := NewAgent(func(e Event) { agent := NewAgent(func(e Event) {
if e.Error != nil { assert.NoError(t, e.Error, "got error")
t.Errorf("got error: %s", e.Error) assert.True(t, e.Message.Equal(msg), "%s (got) != %s (expected)", e.Message, msg)
}
if !e.Message.Equal(msg) {
t.Errorf("%s (got) != %s (expected)", e.Message, msg)
}
}) })
if err := msg.NewTransactionID(); err != nil { assert.NoError(t, msg.NewTransactionID())
t.Fatal(err) assert.NoError(t, agent.Process(msg))
} assert.NoError(t, agent.Close())
if err := agent.Process(msg); err != nil { assert.ErrorIs(t, agent.Process(msg), ErrAgentClosed)
t.Error(err)
}
if err := agent.Close(); err != nil {
t.Error(err)
}
if err := agent.Process(msg); !errors.Is(err, ErrAgentClosed) {
t.Errorf("closed agent should return <%s>, but got <%s>",
ErrAgentClosed, err,
)
}
} }
func TestAgent_Start(t *testing.T) { func TestAgent_Start(t *testing.T) {
agent := NewAgent(nil) agent := NewAgent(nil)
id := NewTransactionID() id := NewTransactionID()
deadline := time.Now().AddDate(0, 0, 1) deadline := time.Now().AddDate(0, 0, 1)
if err := agent.Start(id, deadline); err != nil { assert.NoError(t, agent.Start(id, deadline), "failed to start transaction")
t.Errorf("failed to statt transaction: %s", err) assert.ErrorIs(t, agent.Start(id, deadline), ErrTransactionExists)
} assert.NoError(t, agent.Close())
if err := agent.Start(id, deadline); !errors.Is(err, ErrTransactionExists) {
t.Errorf("duplicate start should return <%s>, got <%s>",
ErrTransactionExists, err,
)
}
if err := agent.Close(); err != nil {
t.Error(err)
}
id = NewTransactionID() id = NewTransactionID()
if err := agent.Start(id, deadline); !errors.Is(err, ErrAgentClosed) { assert.ErrorIs(t, agent.Start(id, deadline), ErrAgentClosed)
t.Errorf("start on closed agent should return <%s>, got <%s>", assert.ErrorIs(t, agent.SetHandler(nil), ErrAgentClosed)
ErrAgentClosed, err,
)
}
if err := agent.SetHandler(nil); !errors.Is(err, ErrAgentClosed) {
t.Errorf("SetHandler on closed agent should return <%s>, got <%s>",
ErrAgentClosed, err,
)
}
} }
func TestAgent_Stop(t *testing.T) { func TestAgent_Stop(t *testing.T) {
@@ -92,36 +52,20 @@ func TestAgent_Stop(t *testing.T) {
agent := NewAgent(func(e Event) { agent := NewAgent(func(e Event) {
called <- e called <- e
}) })
if err := agent.Stop(transactionID{}); !errors.Is(err, ErrTransactionNotExists) { assert.ErrorIs(t, agent.Stop(transactionID{}), ErrTransactionNotExists)
t.Fatalf("unexpected error: %s, should be %s", err, ErrTransactionNotExists)
}
id := NewTransactionID() id := NewTransactionID()
timeout := time.Millisecond * 200 timeout := time.Millisecond * 200
if err := agent.Start(id, time.Now().Add(timeout)); err != nil { assert.NoError(t, agent.Start(id, time.Now().Add(timeout)))
t.Fatal(err) assert.NoError(t, agent.Stop(id))
}
if err := agent.Stop(id); err != nil {
t.Fatal(err)
}
select { select {
case e := <-called: case e := <-called:
if !errors.Is(e.Error, ErrTransactionStopped) { assert.ErrorIs(t, e.Error, ErrTransactionStopped)
t.Fatalf("unexpected error: %s, should be %s",
e.Error, ErrTransactionStopped,
)
}
case <-time.After(timeout * 2): case <-time.After(timeout * 2):
t.Fatal("timed out") assert.Fail(t, "timed out")
}
if err := agent.Close(); err != nil {
t.Fatal(err)
}
if err := agent.Close(); !errors.Is(err, ErrAgentClosed) {
t.Fatalf("a.Close returned %s instead of %s", err, ErrAgentClosed)
}
if err := agent.Stop(transactionID{}); !errors.Is(err, ErrAgentClosed) {
t.Fatalf("unexpected error: %s, should be %s", err, ErrAgentClosed)
} }
assert.NoError(t, agent.Close())
assert.ErrorIs(t, agent.Close(), ErrAgentClosed)
assert.ErrorIs(t, agent.Stop(transactionID{}), ErrAgentClosed)
} }
func TestAgent_GC(t *testing.T) { //nolint:cyclop func TestAgent_GC(t *testing.T) { //nolint:cyclop
@@ -136,60 +80,41 @@ func TestAgent_GC(t *testing.T) { //nolint:cyclop
agent.SetHandler(func(e Event) { //nolint:errcheck,gosec agent.SetHandler(func(e Event) { //nolint:errcheck,gosec
id := e.TransactionID id := e.TransactionID
shouldTimeOut, found := shouldTimeOutID[id] shouldTimeOut, found := shouldTimeOutID[id]
if !found { assert.True(t, found, "unexpected transaction ID")
t.Error("unexpected transaction ID") if shouldTimeOut {
} assert.ErrorIs(t, e.Error, ErrTransactionTimeOut, "%x should time out", id)
if shouldTimeOut && !errors.Is(e.Error, ErrTransactionTimeOut) { } else {
t.Errorf("%x should time out, but got %v", id, e.Error) assert.False(t, errors.Is(e.Error, ErrTransactionTimeOut), "%x should not time out", id)
}
if !shouldTimeOut && errors.Is(e.Error, ErrTransactionTimeOut) {
t.Errorf("%x should not time out, but got %v", id, e.Error)
} }
}) })
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
id := NewTransactionID() id := NewTransactionID()
shouldTimeOutID[id] = false shouldTimeOutID[id] = false
if err := agent.Start(id, deadline); err != nil { assert.NoError(t, agent.Start(id, deadline))
t.Fatal(err)
}
} }
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
id := NewTransactionID() id := NewTransactionID()
shouldTimeOutID[id] = true shouldTimeOutID[id] = true
if err := agent.Start(id, deadlineNotGC); err != nil { assert.NoError(t, agent.Start(id, deadlineNotGC))
t.Fatal(err)
}
}
if err := agent.Collect(gcDeadline); err != nil {
t.Fatal(err)
}
if err := agent.Close(); err != nil {
t.Error(err)
}
if err := agent.Collect(gcDeadline); !errors.Is(err, ErrAgentClosed) {
t.Errorf("should <%s>, but got <%s>", ErrAgentClosed, err)
} }
assert.NoError(t, agent.Collect(gcDeadline))
assert.NoError(t, agent.Close())
assert.ErrorIs(t, agent.Collect(gcDeadline), ErrAgentClosed)
} }
func BenchmarkAgent_GC(b *testing.B) { func BenchmarkAgent_GC(b *testing.B) {
agent := NewAgent(nil) agent := NewAgent(nil)
deadline := time.Now().AddDate(0, 0, 1) deadline := time.Now().AddDate(0, 0, 1)
for i := 0; i < agentCollectCap; i++ { for i := 0; i < agentCollectCap; i++ {
if err := agent.Start(NewTransactionID(), deadline); err != nil { assert.NoError(b, agent.Start(NewTransactionID(), deadline))
b.Fatal(err)
}
} }
defer func() { defer func() {
if err := agent.Close(); err != nil { assert.NoError(b, agent.Close())
b.Error(err)
}
}() }()
b.ReportAllocs() b.ReportAllocs()
gcDeadline := deadline.Add(-time.Second) gcDeadline := deadline.Add(-time.Second)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := agent.Collect(gcDeadline); err != nil { assert.NoError(b, agent.Collect(gcDeadline))
b.Fatal(err)
}
} }
} }
@@ -197,20 +122,14 @@ func BenchmarkAgent_Process(b *testing.B) {
agent := NewAgent(nil) agent := NewAgent(nil)
deadline := time.Now().AddDate(0, 0, 1) deadline := time.Now().AddDate(0, 0, 1)
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
if err := agent.Start(NewTransactionID(), deadline); err != nil { assert.NoError(b, agent.Start(NewTransactionID(), deadline))
b.Fatal(err)
}
} }
defer func() { defer func() {
if err := agent.Close(); err != nil { assert.NoError(b, agent.Close())
b.Error(err)
}
}() }()
b.ReportAllocs() b.ReportAllocs()
m := MustBuild(TransactionID) m := MustBuild(TransactionID)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := agent.Process(m); err != nil { assert.NoError(b, agent.Process(m))
b.Fatal(err)
}
} }
} }

View File

@@ -6,7 +6,11 @@
package stun package stun
import "testing" import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAttrOverflowErr_Error(t *testing.T) { func TestAttrOverflowErr_Error(t *testing.T) {
err := AttrOverflowErr{ err := AttrOverflowErr{
@@ -14,9 +18,7 @@ func TestAttrOverflowErr_Error(t *testing.T) {
Max: 50, Max: 50,
Type: AttrLifetime, Type: AttrLifetime,
} }
if err.Error() != "incorrect length of LIFETIME attribute: 100 exceeds maximum 50" { assert.Equal(t, "incorrect length of LIFETIME attribute: 100 exceeds maximum 50", err.Error())
t.Error("bad error string", err)
}
} }
func TestAttrLengthErr_Error(t *testing.T) { func TestAttrLengthErr_Error(t *testing.T) {
@@ -25,7 +27,5 @@ func TestAttrLengthErr_Error(t *testing.T) {
Expected: 15, Expected: 15,
Got: 99, Got: 99,
} }
if err.Error() != "incorrect length of ERROR-CODE attribute: got 99, expected 15" { assert.Equal(t, "incorrect length of ERROR-CODE attribute: got 99, expected 15", err.Error())
t.Errorf("bad error string: %s", err)
}
} }

View File

@@ -6,6 +6,8 @@ package stun
import ( import (
"bytes" "bytes"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func BenchmarkMessage_GetNotFound(b *testing.B) { func BenchmarkMessage_GetNotFound(b *testing.B) {
@@ -31,16 +33,10 @@ func TestRawAttribute_AddTo(t *testing.T) {
Type: AttrData, Type: AttrData,
Value: v, Value: v,
}) })
if err != nil { assert.NoError(t, err)
t.Fatal(err)
}
gotV, gotErr := m.Get(AttrData) gotV, gotErr := m.Get(AttrData)
if gotErr != nil { assert.NoError(t, gotErr)
t.Fatal(gotErr) assert.True(t, bytes.Equal(gotV, v), "value mismatch")
}
if !bytes.Equal(gotV, v) {
t.Error("value mismatch")
}
} }
func TestMessage_GetNoAllocs(t *testing.T) { func TestMessage_GetNoAllocs(t *testing.T) {
@@ -52,17 +48,13 @@ func TestMessage_GetNoAllocs(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() { allocs := testing.AllocsPerRun(10, func() {
msg.Get(AttrSoftware) //nolint:errcheck,gosec msg.Get(AttrSoftware) //nolint:errcheck,gosec
}) })
if allocs > 0 { assert.Zero(t, allocs, "allocated memory, but should not")
t.Error("allocated memory, but should not")
}
}) })
t.Run("Not found", func(t *testing.T) { t.Run("Not found", func(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() { allocs := testing.AllocsPerRun(10, func() {
msg.Get(AttrOrigin) //nolint:errcheck,gosec msg.Get(AttrOrigin) //nolint:errcheck,gosec
}) })
if allocs > 0 { assert.Zero(t, allocs, "allocated memory, but should not")
t.Error("allocated memory, but should not")
}
}) })
} }
@@ -83,11 +75,8 @@ func TestPadding(t *testing.T) {
{40, 40}, // 10 {40, 40}, // 10
} }
for i, c := range tt { for i, c := range tt {
if got := nearestPaddedValueLength(c.in); got != c.out { got := nearestPaddedValueLength(c.in)
t.Errorf("[%d]: padd(%d) %d (got) != %d (expected)", assert.Equal(t, c.out, got, "[%d]: padd(%d)", i, c.in)
i, c.in, got, c.out,
)
}
} }
} }
@@ -102,9 +91,8 @@ func TestAttrTypeRange(t *testing.T) {
a := a a := a
t.Run(a.String(), func(t *testing.T) { t.Run(a.String(), func(t *testing.T) {
a := a a := a
if a.Optional() || !a.Required() { assert.True(t, a.Required(), "should be required")
t.Error("should be required") assert.False(t, a.Optional(), "should be required")
}
}) })
} }
for _, a := range []AttrType{ for _, a := range []AttrType{
@@ -114,9 +102,8 @@ func TestAttrTypeRange(t *testing.T) {
} { } {
a := a a := a
t.Run(a.String(), func(t *testing.T) { t.Run(a.String(), func(t *testing.T) {
if a.Required() || !a.Optional() { assert.False(t, a.Required(), "should be optional")
t.Error("should be optional") assert.True(t, a.Optional(), "should be optional")
}
}) })
} }
} }

View File

@@ -18,6 +18,8 @@ import (
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
var ( var (
@@ -88,13 +90,9 @@ func BenchmarkClient_Do(b *testing.B) {
client, err := NewClient(noopConnection{}, client, err := NewClient(noopConnection{},
WithAgent(agent), WithAgent(agent),
) )
if err != nil { assert.NoError(b, err)
log.Fatal(err)
}
defer func() { defer func() {
if closeErr := client.Close(); closeErr != nil { assert.NoError(b, client.Close())
panic(closeErr)
}
}() }()
noopF := func(Event) { noopF := func(Event) {
@@ -163,9 +161,8 @@ func TestClosedOrPanic(t *testing.T) {
func() { func() {
defer func() { defer func() {
r, ok := recover().(error) r, ok := recover().(error)
if !ok || !errors.Is(r, io.EOF) { assert.True(t, ok, "should be error")
t.Error(r) assert.ErrorIs(t, r, io.EOF)
}
}() }()
closedOrPanic(io.EOF) closedOrPanic(io.EOF)
}() }()
@@ -203,46 +200,32 @@ func TestClient_Start(t *testing.T) { //nolint:cyclop
}, },
} }
client, err := NewClient(conn) client, err := NewClient(conn)
if err != nil { assert.NoError(t, err)
log.Fatal(err)
}
defer func() { defer func() {
if err := client.Close(); err != nil { assert.NoError(t, client.Close())
t.Error(err) assert.Error(t, client.Close(), "second close should fail")
} assert.Error(t, client.Do(MustBuild(TransactionID), nil), "Do after Close should fail")
if err := client.Close(); err == nil {
t.Error("second close should fail")
}
if err := client.Do(MustBuild(TransactionID), nil); err == nil {
t.Error("Do after Close should fail")
}
}() }()
msg := MustBuild(response, BindingRequest) msg := MustBuild(response, BindingRequest)
t.Log("init") t.Log("init")
got := make(chan struct{}) got := make(chan struct{})
write <- struct{}{} write <- struct{}{}
t.Log("starting the first transaction") t.Log("starting the first transaction")
if err := client.Start(msg, func(event Event) { assert.NoError(t, client.Start(msg, func(event Event) {
t.Log("got first transaction callback") t.Log("got first transaction callback")
if event.Error != nil { assert.NoError(t, event.Error)
t.Error(event.Error)
}
got <- struct{}{} got <- struct{}{}
}); err != nil { }))
t.Error(err)
}
t.Log("starting the second transaction") t.Log("starting the second transaction")
if err := client.Start(msg, func(Event) { assert.ErrorIs(t, client.Start(msg, func(Event) {
t.Error("should not be called") assert.Fail(t, "should not be called")
}); !errors.Is(err, ErrTransactionExists) { }), ErrTransactionExists)
t.Errorf("unexpected error %v", err)
}
read <- struct{}{} read <- struct{}{}
select { select {
case <-got: case <-got:
// pass // pass
case <-time.After(time.Millisecond * 10): case <-time.After(time.Millisecond * 10):
t.Error("timed out") assert.Fail(t, "timed out")
} }
} }
@@ -256,34 +239,20 @@ func TestClient_Do(t *testing.T) {
}, },
} }
client, err := NewClient(conn) client, err := NewClient(conn)
if err != nil { assert.NoError(t, err)
log.Fatal(err)
}
defer func() { defer func() {
if err := client.Close(); err != nil { assert.NoError(t, client.Close())
t.Error(err) assert.Error(t, client.Close(), "second close should fail")
} assert.Error(t, client.Do(MustBuild(TransactionID), nil), "Do after Close should fail")
if err := client.Close(); err == nil {
t.Error("second close should fail")
}
if err := client.Do(MustBuild(TransactionID), nil); err == nil {
t.Error("Do after Close should fail")
}
}() }()
m := MustBuild( m := MustBuild(
NewTransactionIDSetter(response.TransactionID), NewTransactionIDSetter(response.TransactionID),
) )
if err := client.Do(m, func(event Event) { assert.NoError(t, client.Do(m, func(event Event) {
if event.Error != nil { assert.NoError(t, event.Error)
t.Error(event.Error) }))
}
}); err != nil {
t.Error(err)
}
m = MustBuild(TransactionID) m = MustBuild(TransactionID)
if err := client.Do(m, nil); err != nil { assert.NoError(t, client.Do(m, nil))
t.Error(err)
}
} }
func TestCloseErr_Error(t *testing.T) { func TestCloseErr_Error(t *testing.T) {
@@ -299,11 +268,7 @@ func TestCloseErr_Error(t *testing.T) {
ConnectionErr: io.ErrUnexpectedEOF, ConnectionErr: io.ErrUnexpectedEOF,
}, "failed to close: unexpected EOF (connection), <nil> (agent)"}, }, "failed to close: unexpected EOF (connection), <nil> (agent)"},
} { } {
if out := testCase.Err.Error(); out != testCase.Out { assert.Equal(t, testCase.Out, testCase.Err.Error(), "[%d]: Error(%#v)", id, testCase.Err)
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
id, testCase.Err, out, testCase.Out,
)
}
} }
} }
@@ -320,11 +285,7 @@ func TestStopErr_Error(t *testing.T) {
Cause: io.ErrUnexpectedEOF, Cause: io.ErrUnexpectedEOF,
}, "error while stopping due to unexpected EOF: <nil>"}, }, "error while stopping due to unexpected EOF: <nil>"},
} { } {
if out := testcase.Err.Error(); out != testcase.Out { assert.Equal(t, testcase.Out, testcase.Err.Error(), "[%d]: Error(%#v)", id, testcase.Err)
t.Errorf("[%d]: Error(%#v) %q (got) != %q (expected)",
id, testcase.Err, out, testcase.Out,
)
}
} }
} }
@@ -365,25 +326,15 @@ func TestClientAgentError(t *testing.T) {
startErr: io.ErrUnexpectedEOF, startErr: io.ErrUnexpectedEOF,
}), }),
) )
if err != nil { assert.NoError(t, err)
log.Fatal(err)
}
defer func() { defer func() {
if err := client.Close(); err != nil { assert.NoError(t, client.Close())
t.Error(err)
}
}() }()
m := MustBuild(NewTransactionIDSetter(response.TransactionID)) m := MustBuild(NewTransactionIDSetter(response.TransactionID))
if err := client.Do(m, nil); err != nil { assert.NoError(t, client.Do(m, nil))
t.Error(err) assert.ErrorIs(t, client.Do(m, func(event Event) {
} assert.Error(t, event.Error, "error expected")
if err := client.Do(m, func(event Event) { }), io.ErrUnexpectedEOF)
if event.Error == nil {
t.Error("error expected")
}
}); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Error("error expected")
}
} }
func TestClientConnErr(t *testing.T) { func TestClientConnErr(t *testing.T) {
@@ -393,21 +344,13 @@ func TestClientConnErr(t *testing.T) {
}, },
} }
client, err := NewClient(conn) client, err := NewClient(conn)
if err != nil { assert.NoError(t, err)
log.Fatal(err)
}
defer func() { defer func() {
if err := client.Close(); err != nil { assert.NoError(t, client.Close())
t.Error(err)
}
}() }()
m := MustBuild(TransactionID) m := MustBuild(TransactionID)
if err := client.Do(m, nil); err == nil { assert.Error(t, client.Do(m, nil), "error expected")
t.Error("error expected") assert.Error(t, client.Do(m, NoopHandler()), "error expected")
}
if err := client.Do(m, NoopHandler()); err == nil {
t.Error("error expected")
}
} }
func TestClientConnErrStopErr(t *testing.T) { func TestClientConnErrStopErr(t *testing.T) {
@@ -421,26 +364,19 @@ func TestClientConnErrStopErr(t *testing.T) {
stopErr: io.ErrUnexpectedEOF, stopErr: io.ErrUnexpectedEOF,
}), }),
) )
if err != nil { assert.NoError(t, err)
log.Fatal(err)
}
defer func() { defer func() {
if err := client.Close(); err != nil { assert.NoError(t, client.Close())
t.Error(err)
}
}() }()
m := MustBuild(TransactionID) m := MustBuild(TransactionID)
if err := client.Do(m, NoopHandler()); err == nil { assert.Error(t, client.Do(m, NoopHandler()), "error expected")
t.Error("error expected")
}
} }
func TestCallbackWaitHandler_setCallback(t *testing.T) { func TestCallbackWaitHandler_setCallback(t *testing.T) {
c := callbackWaitHandler{} c := callbackWaitHandler{}
defer func() { defer func() {
if err := recover(); err == nil { err := recover()
t.Error("should panic") assert.NotNil(t, err, "should panic")
}
}() }()
c.setCallback(nil) c.setCallback(nil)
} }
@@ -450,56 +386,39 @@ func TestCallbackWaitHandler_HandleEvent(t *testing.T) {
cond: sync.NewCond(new(sync.Mutex)), cond: sync.NewCond(new(sync.Mutex)),
} }
defer func() { defer func() {
if err := recover(); err == nil { err := recover()
t.Error("should panic") assert.NotNil(t, err, "should panic")
}
}() }()
c.HandleEvent(Event{}) c.HandleEvent(Event{})
} }
func TestNewClientNoConnection(t *testing.T) { func TestNewClientNoConnection(t *testing.T) {
c, err := NewClient(nil) c, err := NewClient(nil)
if c != nil { assert.Nil(t, c, "c should be nil")
t.Error("c should be nil") assert.ErrorIs(t, err, ErrNoConnection, "bad error")
}
if !errors.Is(err, ErrNoConnection) {
t.Error("bad error")
}
} }
func TestDial(t *testing.T) { func TestDial(t *testing.T) {
c, err := Dial("udp4", "localhost:3458") c, err := Dial("udp4", "localhost:3458")
if err != nil { assert.NoError(t, err)
t.Fatal(err)
}
defer func() { defer func() {
if err = c.Close(); err != nil { assert.NoError(t, c.Close())
t.Error(err)
}
}() }()
} }
func TestDialURI(t *testing.T) { func TestDialURI(t *testing.T) {
u, err := ParseURI("stun:localhost") u, err := ParseURI("stun:localhost")
if err != nil { assert.NoError(t, err)
t.Fatal(err)
}
c, err := DialURI(u, &DialConfig{}) c, err := DialURI(u, &DialConfig{})
if err != nil { assert.NoError(t, err)
t.Fatal(err)
}
defer func() { defer func() {
if err = c.Close(); err != nil { assert.NoError(t, c.Close())
t.Error(err)
}
}() }()
} }
func TestDialError(t *testing.T) { func TestDialError(t *testing.T) {
_, err := Dial("bad?network", "?????") _, err := Dial("bad?network", "?????")
if err == nil { assert.Error(t, err, "error expected")
t.Fatal("error expected")
}
} }
func TestClientCloseErr(t *testing.T) { func TestClientCloseErr(t *testing.T) {
@@ -516,13 +435,11 @@ func TestClientCloseErr(t *testing.T) {
closeErr: io.ErrUnexpectedEOF, closeErr: io.ErrUnexpectedEOF,
}), }),
) )
if err != nil { assert.NoError(t, err)
log.Fatal(err)
}
defer func() { defer func() {
if err, ok := c.Close().(CloseErr); !ok || !errors.Is(err.AgentErr, io.ErrUnexpectedEOF) { //nolint:errorlint err, ok := c.Close().(CloseErr) //nolint:errorlint
t.Error("unexpected close err") assert.True(t, ok, "should be CloseErr")
} assert.ErrorIs(t, err.AgentErr, io.ErrUnexpectedEOF, "unexpected close err")
}() }()
} }
@@ -542,12 +459,8 @@ func TestWithNoConnClose(t *testing.T) {
}), }),
WithNoConnClose(), WithNoConnClose(),
) )
if err != nil { assert.NoError(t, err)
log.Fatal(err) assert.NoError(t, c.Close(), "unexpected non-nil error")
}
if err := c.Close(); err != nil {
t.Error("unexpected non-nil error")
}
} }
type gcWaitAgent struct { type gcWaitAgent struct {
@@ -598,28 +511,20 @@ func TestClientGC(t *testing.T) {
WithAgent(agent), WithAgent(agent),
WithTimeoutRate(time.Millisecond), WithTimeoutRate(time.Millisecond),
) )
if err != nil { assert.NoError(t, err)
log.Fatal(err)
}
defer func() { defer func() {
if err = c.Close(); err != nil { assert.NoError(t, c.Close())
t.Error(err)
}
}() }()
select { select {
case <-agent.gc: case <-agent.gc:
case <-time.After(time.Millisecond * 200): case <-time.After(time.Millisecond * 200):
t.Error("timed out") assert.Fail(t, "timed out")
} }
} }
func TestClientCheckInit(t *testing.T) { func TestClientCheckInit(t *testing.T) {
if err := (&Client{}).Indicate(nil); !errors.Is(err, ErrClientNotInitialized) { assert.ErrorIs(t, (&Client{}).Indicate(nil), ErrClientNotInitialized)
t.Error("unexpected error") assert.ErrorIs(t, (&Client{}).Do(nil, nil), ErrClientNotInitialized)
}
if err := (&Client{}).Do(nil, nil); !errors.Is(err, ErrClientNotInitialized) {
t.Error("unexpected error")
}
} }
func captureLog() (*bytes.Buffer, func()) { func captureLog() (*bytes.Buffer, func()) {
@@ -645,9 +550,7 @@ func TestClientFinalizer(t *testing.T) {
}, },
} }
client, err := NewClient(conn) client, err := NewClient(conn)
if err != nil { assert.NoError(t, err)
log.Panic(err)
}
clientFinalizer(client) clientFinalizer(client)
clientFinalizer(client) clientFinalizer(client)
response := MustBuild(TransactionID, BindingSuccess) response := MustBuild(TransactionID, BindingSuccess)
@@ -663,9 +566,7 @@ func TestClientFinalizer(t *testing.T) {
closeErr: io.ErrUnexpectedEOF, closeErr: io.ErrUnexpectedEOF,
}), }),
) )
if err != nil { assert.NoError(t, err)
log.Panic(err)
}
clientFinalizer(client) clientFinalizer(client)
reader := bufio.NewScanner(buf) reader := bufio.NewScanner(buf)
var lines int var lines int
@@ -676,17 +577,11 @@ func TestClientFinalizer(t *testing.T) {
"<nil> (connection), unexpected EOF (agent)", "<nil> (connection), unexpected EOF (agent)",
} }
for reader.Scan() { for reader.Scan() {
if reader.Text() != expectedLines[lines] { assert.Equal(t, expectedLines[lines], reader.Text())
t.Error(reader.Text(), "!=", expectedLines[lines])
}
lines++ lines++
} }
if reader.Err() != nil { assert.NoError(t, reader.Err())
t.Error(err) assert.Equal(t, 3, lines, "incorrect count of log lines")
}
if lines != 3 {
t.Error("incorrect count of log lines:", lines)
}
} }
func TestCallbackWaitHandler(*testing.T) { func TestCallbackWaitHandler(*testing.T) {
@@ -784,9 +679,7 @@ func TestClientRetransmission(t *testing.T) {
response.Encode() response.Encode()
connL, connR := net.Pipe() connL, connR := net.Pipe()
defer func() { defer func() {
if closeErr := connL.Close(); closeErr != nil { assert.NoError(t, connR.Close())
panic(closeErr)
}
}() }()
collector := new(manualCollector) collector := new(manualCollector)
clock := &manualClock{current: time.Now()} clock := &manualClock{current: time.Now()}
@@ -814,36 +707,22 @@ func TestClientRetransmission(t *testing.T) {
WithCollector(collector), WithCollector(collector),
WithRTO(time.Millisecond), WithRTO(time.Millisecond),
) )
if err != nil { assert.NoError(t, err)
t.Fatal(err)
}
client.SetRTO(time.Second) client.SetRTO(time.Second)
gotReads := make(chan struct{}) gotReads := make(chan struct{})
go func() { go func() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
readN, readErr := connL.Read(buf) readN, readErr := connL.Read(buf)
if readErr != nil { assert.NoError(t, readErr)
t.Error(readErr) assert.True(t, IsMessage(buf[:readN]), "should be STUN")
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
readN, readErr = connL.Read(buf) readN, readErr = connL.Read(buf)
if readErr != nil { assert.NoError(t, readErr)
t.Error(readErr) assert.True(t, IsMessage(buf[:readN]), "should be STUN")
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
gotReads <- struct{}{} gotReads <- struct{}{}
}() }()
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
if event.Error != nil { assert.NoError(t, event.Error, "failed")
t.Error("failed") }))
}
}); doErr != nil {
t.Fatal(doErr)
}
<-gotReads <-gotReads
} }
@@ -854,9 +733,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
response.Encode() response.Encode()
connL, connR := net.Pipe() connL, connR := net.Pipe()
defer func() { defer func() {
if closeErr := connL.Close(); closeErr != nil { assert.NoError(t, connR.Close())
panic(closeErr)
}
}() }()
collector := new(manualCollector) collector := new(manualCollector)
clock := &manualClock{current: time.Now()} clock := &manualClock{current: time.Now()}
@@ -874,9 +751,7 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
WithClock(clock), WithClock(clock),
WithCollector(collector), WithCollector(collector),
) )
if err != nil { assert.NoError(t, err)
t.Fatal(err)
}
client.SetRTO(time.Second) client.SetRTO(time.Second)
conns := new(sync.WaitGroup) conns := new(sync.WaitGroup)
wg := new(sync.WaitGroup) wg := new(sync.WaitGroup)
@@ -891,29 +766,21 @@ func testClientDoConcurrent(t *testing.T, concurrency int) { //nolint:cyclop
if errors.Is(readErr, io.EOF) { if errors.Is(readErr, io.EOF) {
break break
} }
t.Error(readErr) assert.NoError(t, readErr)
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
} }
assert.True(t, IsMessage(buf[:readN]), "should be STUN")
} }
}() }()
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
if doErr := client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) { assert.NoError(t, client.Do(MustBuild(TransactionID, BindingRequest), func(event Event) {
if event.Error != nil { assert.NoError(t, event.Error, "failed")
t.Error("failed") }))
}
}); doErr != nil {
t.Error(doErr)
}
}() }()
} }
wg.Wait() wg.Wait()
if connErr := connR.Close(); connErr != nil { assert.NoError(t, connR.Close())
t.Error(connErr)
}
conns.Wait() conns.Wait()
} }
@@ -942,24 +809,22 @@ func (c errorCollector) Close() error { return c.closeError }
func TestNewClient(t *testing.T) { func TestNewClient(t *testing.T) {
t.Run("SetCallbackError", func(t *testing.T) { t.Run("SetCallbackError", func(t *testing.T) {
setHandlerError := errClientSetHandler setHandlerError := errClientSetHandler
if _, createErr := NewClient(noopConnection{}, _, createErr := NewClient(noopConnection{},
WithAgent(&errorAgent{ WithAgent(&errorAgent{
setHandlerError: setHandlerError, setHandlerError: setHandlerError,
}), }),
); !errors.Is(createErr, setHandlerError) { )
t.Errorf("unexpected error returned: %v", createErr) assert.ErrorIs(t, createErr, setHandlerError, "unexpected error returned")
}
}) })
t.Run("CollectorStartError", func(t *testing.T) { t.Run("CollectorStartError", func(t *testing.T) {
startError := errClientStart startError := errClientStart
if _, createErr := NewClient(noopConnection{}, _, createErr := NewClient(noopConnection{},
WithAgent(&TestAgent{}), WithAgent(&TestAgent{}),
WithCollector(errorCollector{ WithCollector(errorCollector{
startError: startError, startError: startError,
}), }),
); !errors.Is(createErr, startError) { )
t.Errorf("unexpected error returned: %v", createErr) assert.ErrorIs(t, createErr, startError, "unexpected error returned")
}
}) })
} }
@@ -972,13 +837,9 @@ func TestClient_Close(t *testing.T) {
}), }),
WithAgent(&TestAgent{}), WithAgent(&TestAgent{}),
) )
if createErr != nil { assert.NoError(t, createErr, "unexpected create error returned")
t.Errorf("unexpected create error returned: %v", createErr)
}
gotCloseErr := c.Close() gotCloseErr := c.Close()
if !errors.Is(gotCloseErr, closeErr) { assert.ErrorIs(t, gotCloseErr, closeErr, "unexpected close error returned")
t.Errorf("unexpected close error returned: %v", gotCloseErr)
}
}) })
} }
@@ -992,19 +853,13 @@ func TestClientDefaultHandler(t *testing.T) {
client, createErr := NewClient(noopConnection{}, client, createErr := NewClient(noopConnection{},
WithAgent(agent), WithAgent(agent),
WithHandler(func(e Event) { WithHandler(func(e Event) {
if called { assert.False(t, called, "should not be called twice")
t.Error("should not be called twice")
}
called = true called = true
if e.TransactionID != id { assert.Equal(t, id, e.TransactionID, "wrong transaction ID")
t.Error("wrong transaction ID")
}
handlerCalled <- struct{}{} handlerCalled <- struct{}{}
}), }),
) )
if createErr != nil { assert.NoError(t, createErr)
t.Fatal(createErr)
}
go func() { go func() {
agent.h(Event{ agent.h(Event{
TransactionID: id, TransactionID: id,
@@ -1014,11 +869,9 @@ func TestClientDefaultHandler(t *testing.T) {
case <-handlerCalled: case <-handlerCalled:
// pass // pass
case <-time.After(time.Millisecond * 100): case <-time.After(time.Millisecond * 100):
t.Fatal("timed out") assert.Fail(t, "timed out")
}
if closeErr := client.Close(); closeErr != nil {
t.Error(closeErr)
} }
assert.NoError(t, client.Close())
// Handler call should be ignored. // Handler call should be ignored.
agent.h(Event{}) agent.h(Event{})
} }
@@ -1030,15 +883,9 @@ func TestClientClosedStart(t *testing.T) {
c, createErr := NewClient(noopConnection{}, c, createErr := NewClient(noopConnection{},
WithAgent(a), WithAgent(a),
) )
if createErr != nil { assert.NoError(t, createErr)
t.Fatal(createErr) assert.NoError(t, c.Close())
} assert.ErrorIs(t, c.start(&clientTransaction{}), ErrClientClosed)
if closeErr := c.Close(); closeErr != nil {
t.Error(closeErr)
}
if startErr := c.start(&clientTransaction{}); !errors.Is(startErr, ErrClientClosed) {
t.Error("should error")
}
} }
func TestWithNoRetransmit(t *testing.T) { func TestWithNoRetransmit(t *testing.T) {
@@ -1046,9 +893,7 @@ func TestWithNoRetransmit(t *testing.T) {
response.Encode() response.Encode()
connL, connR := net.Pipe() connL, connR := net.Pipe()
defer func() { defer func() {
if closeErr := connL.Close(); closeErr != nil { assert.NoError(t, connL.Close())
panic(closeErr)
}
}() }()
collector := new(manualCollector) collector := new(manualCollector)
clock := &manualClock{current: time.Now()} clock := &manualClock{current: time.Now()}
@@ -1062,7 +907,7 @@ func TestWithNoRetransmit(t *testing.T) {
Error: ErrTransactionTimeOut, Error: ErrTransactionTimeOut,
}) })
} else { } else {
t.Error("there should be no second attempt") assert.Fail(t, "there should be no second attempt")
go agent.h(Event{ go agent.h(Event{
TransactionID: id, TransactionID: id,
Error: ErrTransactionTimeOut, Error: ErrTransactionTimeOut,
@@ -1078,28 +923,18 @@ func TestWithNoRetransmit(t *testing.T) {
WithRTO(0), WithRTO(0),
WithNoRetransmit, WithNoRetransmit,
) )
if err != nil { assert.NoError(t, err)
t.Fatal(err)
}
gotReads := make(chan struct{}) gotReads := make(chan struct{})
go func() { go func() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
readN, readErr := connL.Read(buf) readN, readErr := connL.Read(buf)
if readErr != nil { assert.NoError(t, readErr)
t.Error(readErr) assert.True(t, IsMessage(buf[:readN]), "should be STUN")
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
gotReads <- struct{}{} gotReads <- struct{}{}
}() }()
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
if !errors.Is(event.Error, ErrTransactionTimeOut) { assert.ErrorIs(t, event.Error, ErrTransactionTimeOut, "unexpected error")
t.Error("unexpected error") }))
}
}); doErr != nil {
t.Fatal(err)
}
<-gotReads <-gotReads
} }
@@ -1114,9 +949,7 @@ func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
response.Encode() response.Encode()
connL, connR := net.Pipe() connL, connR := net.Pipe()
defer func() { defer func() {
if closeErr := connL.Close(); closeErr != nil { assert.NoError(t, connL.Close())
panic(closeErr)
}
}() }()
collector := new(manualCollector) collector := new(manualCollector)
shouldWait := false shouldWait := false
@@ -1169,9 +1002,7 @@ func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
t.Log("clock locked") t.Log("clock locked")
<-clockLocked <-clockLocked
t.Log("closing client") t.Log("closing client")
if closeErr := client.Close(); closeErr != nil { assert.NoError(t, client.Close())
t.Error(closeErr)
}
t.Log("client closed, unlocking clock") t.Log("client closed, unlocking clock")
clockWait <- struct{}{} clockWait <- struct{}{}
t.Log("clock unlocked") t.Log("clock unlocked")
@@ -1186,44 +1017,30 @@ func TestClientRTOStartErr(t *testing.T) { //nolint:cyclop
WithCollector(collector), WithCollector(collector),
WithRTO(time.Millisecond), WithRTO(time.Millisecond),
) )
if startClientErr != nil { assert.NoError(t, startClientErr)
t.Fatal(startClientErr)
}
go func() { go func() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
readN, readErr := connL.Read(buf) readN, readErr := connL.Read(buf)
if readErr != nil { assert.NoError(t, readErr)
t.Error(readErr) assert.True(t, IsMessage(buf[:readN]), "should be STUN")
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
readN, readErr = connL.Read(buf) readN, readErr = connL.Read(buf)
if readErr != nil { assert.NoError(t, readErr)
t.Error(readErr) assert.True(t, IsMessage(buf[:readN]), "should be STUN")
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
gotReads <- struct{}{} gotReads <- struct{}{}
}() }()
t.Log("starting") t.Log("starting")
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
if !errors.Is(event.Error, ErrClientClosed) { assert.ErrorIs(t, event.Error, ErrClientClosed)
t.Error(event.Error) }))
}
}); doErr != nil {
t.Error(doErr)
}
done <- struct{}{} done <- struct{}{}
}() }()
select { select {
case <-done: case <-done:
// ok // ok
case <-time.After(time.Second * 5): case <-time.After(time.Second * 5):
t.Error("timeout") assert.Fail(t, "timeout")
} }
} }
@@ -1232,9 +1049,7 @@ func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
response.Encode() response.Encode()
connL, connR := net.Pipe() connL, connR := net.Pipe()
defer func() { defer func() {
if closeErr := connL.Close(); closeErr != nil { assert.NoError(t, connL.Close())
panic(closeErr)
}
}() }()
collector := new(manualCollector) collector := new(manualCollector)
shouldWait := false shouldWait := false
@@ -1291,9 +1106,7 @@ func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
t.Log("clock locked") t.Log("clock locked")
<-clockLocked <-clockLocked
t.Log("closing connection") t.Log("closing connection")
if closeErr := connL.Close(); closeErr != nil { assert.NoError(t, connL.Close())
panic(closeErr)
}
t.Log("connection closed, unlocking clock") t.Log("connection closed, unlocking clock")
clockWait <- struct{}{} clockWait <- struct{}{}
t.Log("clock unlocked") t.Log("clock unlocked")
@@ -1308,52 +1121,33 @@ func TestClientRTOWriteErr(t *testing.T) { //nolint:cyclop
WithCollector(collector), WithCollector(collector),
WithRTO(time.Millisecond), WithRTO(time.Millisecond),
) )
if startClientErr != nil { assert.NoError(t, startClientErr)
t.Fatal(startClientErr)
}
go func() { go func() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
readN, readErr := connL.Read(buf) readN, readErr := connL.Read(buf)
if readErr != nil { assert.NoError(t, readErr)
t.Error(readErr) assert.True(t, IsMessage(buf[:readN]), "should be STUN")
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
readN, readErr = connL.Read(buf) readN, readErr = connL.Read(buf)
if readErr != nil { assert.NoError(t, readErr)
t.Error(readErr) assert.True(t, IsMessage(buf[:readN]), "should be STUN")
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
gotReads <- struct{}{} gotReads <- struct{}{}
}() }()
t.Log("starting") t.Log("starting")
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
var e StopErr var e StopErr
if !errors.As(event.Error, &e) { assert.ErrorAs(t, event.Error, &e)
t.Error(event.Error) assert.ErrorIs(t, e.Err, agentStopErr, "incorrect agent error")
} else { assert.ErrorIs(t, e.Cause, io.ErrClosedPipe, "incorrect connection error")
if !errors.Is(e.Err, agentStopErr) { }))
t.Error("incorrect agent error")
}
if !errors.Is(e.Cause, io.ErrClosedPipe) {
t.Error("incorrect connection error")
}
}
}); doErr != nil {
t.Error(doErr)
}
done <- struct{}{} done <- struct{}{}
}() }()
select { select {
case <-done: case <-done:
// ok // ok
case <-time.After(time.Second * 5): case <-time.After(time.Second * 5):
t.Error("timeout") assert.Fail(t, "timeout")
} }
} }
@@ -1362,9 +1156,7 @@ func TestClientRTOAgentErr(t *testing.T) {
response.Encode() response.Encode()
connL, connR := net.Pipe() connL, connR := net.Pipe()
defer func() { defer func() {
if closeErr := connL.Close(); closeErr != nil { assert.NoError(t, connL.Close())
panic(closeErr)
}
}() }()
collector := new(manualCollector) collector := new(manualCollector)
clock := callbackClock(time.Now) clock := callbackClock(time.Now)
@@ -1396,33 +1188,23 @@ func TestClientRTOAgentErr(t *testing.T) {
WithCollector(collector), WithCollector(collector),
WithRTO(time.Millisecond), WithRTO(time.Millisecond),
) )
if startClientErr != nil { assert.NoError(t, startClientErr)
t.Fatal(startClientErr)
}
go func() { go func() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
readN, readErr := connL.Read(buf) readN, readErr := connL.Read(buf)
if readErr != nil { assert.NoError(t, readErr)
t.Error(readErr) assert.True(t, IsMessage(buf[:readN]), "should be STUN")
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
gotReads <- struct{}{} gotReads <- struct{}{}
}() }()
t.Log("starting") t.Log("starting")
if doErr := client.Do(MustBuild(response, BindingRequest), func(event Event) { assert.NoError(t, client.Do(MustBuild(response, BindingRequest), func(event Event) {
if !errors.Is(event.Error, agentStartErr) { assert.ErrorIs(t, event.Error, agentStartErr)
t.Error(event.Error) }))
}
}); doErr != nil {
t.Error(doErr)
}
select { select {
case <-gotReads: case <-gotReads:
// ok // ok
case <-time.After(time.Second * 5): case <-time.After(time.Second * 5):
t.Error("reads timeout") assert.Fail(t, "reads timeout")
} }
} }
@@ -1431,9 +1213,7 @@ func TestClient_HandleProcessError(t *testing.T) {
response.Encode() response.Encode()
connL, connR := net.Pipe() connL, connR := net.Pipe()
defer func() { defer func() {
if closeErr := connL.Close(); closeErr != nil { assert.NoError(t, connL.Close())
panic(closeErr)
}
}() }()
collector := new(manualCollector) collector := new(manualCollector)
clock := callbackClock(time.Now) clock := callbackClock(time.Now)
@@ -1451,14 +1231,10 @@ func TestClient_HandleProcessError(t *testing.T) {
WithCollector(collector), WithCollector(collector),
WithRTO(time.Millisecond), WithRTO(time.Millisecond),
) )
if startClientErr != nil { assert.NoError(t, startClientErr)
t.Fatal(startClientErr)
}
go func() { go func() {
_, readErr := connL.Write(response.Raw) _, readErr := connL.Write(response.Raw)
if readErr != nil { assert.NoError(t, readErr)
t.Error(readErr)
}
gotWrites <- struct{}{} gotWrites <- struct{}{}
}() }()
t.Log("starting") t.Log("starting")
@@ -1466,20 +1242,16 @@ func TestClient_HandleProcessError(t *testing.T) {
case <-gotWrites: case <-gotWrites:
// ok // ok
case <-time.After(time.Second * 5): case <-time.After(time.Second * 5):
t.Error("reads timeout") assert.Fail(t, "reads timeout")
}
if closeErr := client.Close(); closeErr != nil {
t.Error(closeErr)
} }
assert.NoError(t, client.Close())
} }
func TestClientImmediateTimeout(t *testing.T) { func TestClientImmediateTimeout(t *testing.T) {
response := MustBuild(TransactionID, BindingSuccess) response := MustBuild(TransactionID, BindingSuccess)
connL, connR := net.Pipe() connL, connR := net.Pipe()
defer func() { defer func() {
if closeErr := connL.Close(); closeErr != nil { assert.NoError(t, connL.Close())
panic(closeErr)
}
}() }()
collector := new(manualCollector) collector := new(manualCollector)
clock := &manualClock{current: time.Now()} clock := &manualClock{current: time.Now()}
@@ -1488,16 +1260,14 @@ func TestClientImmediateTimeout(t *testing.T) {
attempt := 0 attempt := 0
agent.start = func(id [TransactionIDSize]byte, deadline time.Time) error { agent.start = func(id [TransactionIDSize]byte, deadline time.Time) error {
if attempt == 0 { if attempt == 0 {
if deadline.Before(clock.current.Add(rto / 2)) { assert.False(t, deadline.Before(clock.current.Add(rto/2)), "deadline too fast")
t.Error("deadline too fast")
}
attempt++ attempt++
go agent.h(Event{ go agent.h(Event{
TransactionID: id, TransactionID: id,
Message: response, Message: response,
}) })
} else { } else {
t.Error("there should be no second attempt") assert.Fail(t, "there should be no second attempt")
go agent.h(Event{ go agent.h(Event{
TransactionID: id, TransactionID: id,
Error: ErrTransactionTimeOut, Error: ErrTransactionTimeOut,
@@ -1512,25 +1282,17 @@ func TestClientImmediateTimeout(t *testing.T) {
WithCollector(collector), WithCollector(collector),
WithRTO(rto), WithRTO(rto),
) )
if err != nil { assert.NoError(t, err)
t.Fatal(err)
}
gotReads := make(chan struct{}) gotReads := make(chan struct{})
go func() { go func() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
readN, readErr := connL.Read(buf) readN, readErr := connL.Read(buf)
if readErr != nil { assert.NoError(t, readErr)
t.Error(readErr) assert.True(t, IsMessage(buf[:readN]), "should be STUN")
}
if !IsMessage(buf[:readN]) {
t.Error("should be STUN")
}
gotReads <- struct{}{} gotReads <- struct{}{}
}() }()
client.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec client.Start(MustBuild(response, BindingRequest), func(e Event) { //nolint:errcheck,gosec
if errors.Is(e.Error, ErrTransactionTimeOut) { assert.NoError(t, e.Error, "unexpected error")
t.Error("unexpected error")
}
}) })
<-gotReads <-gotReads
} }

View File

@@ -8,9 +8,10 @@ package stun
import ( import (
"encoding/base64" "encoding/base64"
"errors"
"io" "io"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func BenchmarkErrorCode_AddTo(b *testing.B) { func BenchmarkErrorCode_AddTo(b *testing.B) {
@@ -52,19 +53,13 @@ func TestErrorCodeAttribute_GetFrom(t *testing.T) {
m := New() m := New()
m.Add(AttrErrorCode, []byte{1}) m.Add(AttrErrorCode, []byte{1})
c := new(ErrorCodeAttribute) c := new(ErrorCodeAttribute)
if err := c.GetFrom(m); !errors.Is(err, io.ErrUnexpectedEOF) { assert.ErrorIs(t, c.GetFrom(m), io.ErrUnexpectedEOF)
t.Errorf("GetFrom should return <%s>, but got <%s>",
io.ErrUnexpectedEOF, err,
)
}
} }
func TestMessage_AddErrorCode(t *testing.T) { func TestMessage_AddErrorCode(t *testing.T) {
m := New() m := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil { assert.NoError(t, err)
t.Error(err)
}
copy(m.TransactionID[:], transactionID) copy(m.TransactionID[:], transactionID)
expectedCode := ErrorCode(438) expectedCode := ErrorCode(438)
expectedReason := "Stale Nonce" expectedReason := "Stale Nonce"
@@ -72,23 +67,13 @@ func TestMessage_AddErrorCode(t *testing.T) {
m.WriteHeader() m.WriteHeader()
mRes := New() mRes := New()
if _, err = mRes.ReadFrom(m.reader()); err != nil { _, err = mRes.ReadFrom(m.reader())
t.Fatal(err) assert.NoError(t, err)
}
errCodeAttr := new(ErrorCodeAttribute) errCodeAttr := new(ErrorCodeAttribute)
if err = errCodeAttr.GetFrom(mRes); err != nil { assert.NoError(t, errCodeAttr.GetFrom(mRes))
t.Error(err)
}
code := errCodeAttr.Code code := errCodeAttr.Code
if err != nil { assert.Equal(t, expectedCode, code, "bad code")
t.Error(err) assert.Equal(t, expectedReason, string(errCodeAttr.Reason), "bad reason")
}
if code != expectedCode {
t.Error("bad code", code)
}
if string(errCodeAttr.Reason) != expectedReason {
t.Error("bad reason", string(errCodeAttr.Reason))
}
} }
func TestErrorCode(t *testing.T) { func TestErrorCode(t *testing.T) {
@@ -96,19 +81,11 @@ func TestErrorCode(t *testing.T) {
Code: 404, Code: 404,
Reason: []byte("not found!"), Reason: []byte("not found!"),
} }
if attr.String() != "404: not found!" { assert.Equal(t, "404: not found!", attr.String(), "bad string")
t.Error("bad string", attr)
}
m := New() m := New()
cod := ErrorCode(666) cod := ErrorCode(666)
if err := cod.AddTo(m); !errors.Is(err, ErrNoDefaultReason) { assert.ErrorIs(t, cod.AddTo(m), ErrNoDefaultReason, "should be ErrNoDefaultReason")
t.Error("should be ErrNoDefaultReason", err) assert.Error(t, attr.GetFrom(m), "attr should not be in message")
}
if err := attr.GetFrom(m); err == nil {
t.Error("attr should not be in message")
}
attr.Reason = make([]byte, 2048) attr.Reason = make([]byte, 2048)
if err := attr.AddTo(m); err == nil { assert.Error(t, attr.AddTo(m), "should error")
t.Error("should error")
}
} }

View File

@@ -6,6 +6,8 @@ package stun
import ( import (
"errors" "errors"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestDecodeErr_IsInvalidCookie(t *testing.T) { func TestDecodeErr_IsInvalidCookie(t *testing.T) {
@@ -14,25 +16,13 @@ func TestDecodeErr_IsInvalidCookie(t *testing.T) {
decoded := new(Message) decoded := new(Message)
m.Raw[4] = 55 m.Raw[4] = 55
_, err := decoded.Write(m.Raw) _, err := decoded.Write(m.Raw)
if err == nil { assert.Error(t, err, "should error")
t.Fatal("should error")
}
expected := "BadFormat for message/cookie: " + expected := "BadFormat for message/cookie: " +
"3712a442 is invalid magic cookie (should be 2112a442)" "3712a442 is invalid magic cookie (should be 2112a442)"
if err.Error() != expected { assert.Equal(t, expected, err.Error(), "error message mismatch")
t.Error(err, "!=", expected)
}
var dErr *DecodeErr var dErr *DecodeErr
if !errors.As(err, &dErr) { assert.True(t, errors.As(err, &dErr), "not decode error")
t.Error("not decode error") assert.True(t, dErr.IsInvalidCookie(), "IsInvalidCookie = false, should be true")
} assert.True(t, dErr.IsPlaceChildren("cookie"), "bad children")
if !dErr.IsInvalidCookie() { assert.True(t, dErr.IsPlaceParent("message"), "bad parent")
t.Error("IsInvalidCookie = false, should be true")
}
if !dErr.IsPlaceChildren("cookie") {
t.Error("bad children")
}
if !dErr.IsPlaceParent("message") {
t.Error("bad parent")
}
} }

View File

@@ -9,6 +9,8 @@ package stun
import ( import (
"net" "net"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func BenchmarkFingerprint_AddTo(b *testing.B) { func BenchmarkFingerprint_AddTo(b *testing.B) {
@@ -36,26 +38,18 @@ func TestFingerprint_Check(t *testing.T) {
m.WriteHeader() m.WriteHeader()
Fingerprint.AddTo(m) //nolint:errcheck,gosec Fingerprint.AddTo(m) //nolint:errcheck,gosec
m.WriteHeader() m.WriteHeader()
if err := Fingerprint.Check(m); err != nil { assert.NoError(t, Fingerprint.Check(m))
t.Error(err)
}
m.Raw[3]++ m.Raw[3]++
if err := Fingerprint.Check(m); err == nil { assert.Error(t, Fingerprint.Check(m))
t.Error("should error")
}
} }
func TestFingerprint_CheckBad(t *testing.T) { func TestFingerprint_CheckBad(t *testing.T) {
m := new(Message) m := new(Message)
addAttr(t, m, NewSoftware("software")) addAttr(t, m, NewSoftware("software"))
m.WriteHeader() m.WriteHeader()
if err := Fingerprint.Check(m); err == nil { assert.Error(t, Fingerprint.Check(m))
t.Error("should error")
}
m.Add(AttrFingerprint, []byte{1, 2, 3}) m.Add(AttrFingerprint, []byte{1, 2, 3})
if !IsAttrSizeInvalid(Fingerprint.Check(m)) { assert.True(t, IsAttrSizeInvalid(Fingerprint.Check(m)))
t.Error("IsAttrSizeInvalid should be true")
}
} }
func BenchmarkFingerprint_Check(b *testing.B) { func BenchmarkFingerprint_Check(b *testing.B) {

View File

@@ -7,6 +7,8 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func FuzzMessage(f *testing.F) { func FuzzMessage(f *testing.F) {
@@ -26,21 +28,12 @@ func FuzzMessage(f *testing.F) {
} }
msg2 := New() msg2 := New()
if _, err := msg2.Write(msg1.Raw); err != nil { _, err := msg2.Write(msg1.Raw)
t.Fatalf("Failed to write: %s", err) assert.NoError(t, err, "Failed to write")
}
if msg2.TransactionID != msg1.TransactionID { assert.Equal(t, msg1.TransactionID, msg2.TransactionID, "Transaction ID mismatch")
t.Fatal("Transaction ID mismatch") assert.Equal(t, msg1.Type, msg2.Type, "Type mismatch")
} assert.Equal(t, len(msg1.Attributes), len(msg2.Attributes), "Attributes length mismatch")
if msg2.Type != msg1.Type {
t.Fatal("Type mismatch")
}
if len(msg2.Attributes) != len(msg1.Attributes) {
t.Fatal("Attributes length mismatch")
}
}) })
} }
@@ -51,15 +44,11 @@ func FuzzType(f *testing.F) {
t1 := MessageType{} t1 := MessageType{}
t1.ReadValue(v) t1.ReadValue(v)
v2 := t1.Value() v2 := t1.Value()
if v != v2 { assert.Equal(t, v, v2, "v != v2")
t.Fatal("v != v2")
}
t2 := MessageType{} t2 := MessageType{}
t2.ReadValue(v2) t2.ReadValue(v2)
if t2 != t1 { assert.Equal(t, t1, t2, "t2 != t1")
t.Fatal("t2 != t1")
}
}) })
} }
@@ -94,20 +83,19 @@ func FuzzSetters(f *testing.F) {
m1.WriteHeader() m1.WriteHeader()
m1.Add(attr.t, value) m1.Add(attr.t, value)
err := attr.g.GetFrom(m1) err := attr.g.GetFrom(m1)
if errors.Is(err, ErrAttributeNotFound) { assert.False(t, errors.Is(err, ErrAttributeNotFound))
t.Fatalf("Unexpected 404: %s", err)
}
if err != nil { if err != nil {
return return
} }
m2.WriteHeader() m2.WriteHeader()
if err = attr.g.AddTo(m2); err != nil { err = attr.g.AddTo(m2)
if err != nil {
// We allow decoding some text attributes // We allow decoding some text attributes
// when their length is too big, but // when their length is too big, but
// not encoding. // not encoding.
if !IsAttrSizeOverflow(err) { if !IsAttrSizeOverflow(err) {
t.Fatal(err) assert.NoError(t, err)
} }
return return
@@ -115,14 +103,10 @@ func FuzzSetters(f *testing.F) {
m3.WriteHeader() m3.WriteHeader()
v, err := m2.Get(attr.t) v, err := m2.Get(attr.t)
if err != nil { assert.NoError(t, err)
t.Fatal(err)
}
m3.Add(attr.t, v) m3.Add(attr.t, v)
if !m2.Equal(m3) { assert.True(t, m2.Equal(m3), "Not equal: %s != %s", m2, m3)
t.Fatalf("Not equal: %s != %s", m2, m3)
}
}) })
} }

View File

@@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/pion/stun/v3/internal/testutil" "github.com/pion/stun/v3/internal/testutil"
"github.com/stretchr/testify/assert"
) )
func BenchmarkBuildOverhead(b *testing.B) { func BenchmarkBuildOverhead(b *testing.B) {
@@ -59,21 +60,12 @@ func TestMessage_Apply(t *testing.T) {
integrity, integrity,
Fingerprint, Fingerprint,
) )
if err != nil { assert.NoError(t, err, "failed to build")
t.Fatal("failed to build:", err) assert.NoError(t, msg.Check(Fingerprint, integrity))
} _, err = decoded.Write(msg.Raw)
if err = msg.Check(Fingerprint, integrity); err != nil { assert.NoError(t, err)
t.Fatal(err) assert.True(t, decoded.Equal(msg))
} assert.NoError(t, integrity.Check(decoded))
if _, err := decoded.Write(msg.Raw); err != nil {
t.Fatal(err)
}
if !decoded.Equal(msg) {
t.Error("not equal")
}
if err := integrity.Check(decoded); err != nil {
t.Fatal(err)
}
} }
type errReturner struct { type errReturner struct {
@@ -97,25 +89,17 @@ func (e errReturner) GetFrom(*Message) error {
func TestHelpersErrorHandling(t *testing.T) { func TestHelpersErrorHandling(t *testing.T) {
m := New() m := New()
errReturn := errReturner{Err: errTError} errReturn := errReturner{Err: errTError}
if err := m.Build(errReturn); !errors.Is(err, errReturn.Err) { assert.ErrorIs(t, m.Build(errReturn), errReturn.Err)
t.Error(err, "!=", errReturn.Err) assert.ErrorIs(t, m.Check(errReturn), errReturn.Err)
} assert.ErrorIs(t, m.Parse(errReturn), errReturn.Err)
if err := m.Check(errReturn); !errors.Is(err, errReturn.Err) {
t.Error(err, "!=", errReturn.Err)
}
if err := m.Parse(errReturn); !errors.Is(err, errReturn.Err) {
t.Error(err, "!=", errReturn.Err)
}
t.Run("MustBuild", func(t *testing.T) { t.Run("MustBuild", func(t *testing.T) {
t.Run("Positive", func(*testing.T) { t.Run("Positive", func(*testing.T) {
MustBuild(NewTransactionIDSetter(transactionID{})) MustBuild(NewTransactionIDSetter(transactionID{}))
}) })
defer func() { defer func() {
if p, ok := recover().(error); !ok || !errors.Is(p, errReturn.Err) { p, ok := recover().(error)
t.Errorf("%s != %s", assert.True(t, ok)
p, errReturn.Err, assert.ErrorIs(t, p, errReturn.Err)
)
}
}() }()
MustBuild(errReturn) MustBuild(errReturn)
}) })
@@ -123,91 +107,62 @@ func TestHelpersErrorHandling(t *testing.T) {
func TestMessage_ForEach(t *testing.T) { //nolint:cyclop func TestMessage_ForEach(t *testing.T) { //nolint:cyclop
initial := New() initial := New()
if err := initial.Build( assert.NoError(t, initial.Build(
NewRealm("realm1"), NewRealm("realm2"), NewRealm("realm1"), NewRealm("realm2"),
); err != nil { ))
t.Fatal(err)
}
newMessage := func() *Message { newMessage := func() *Message {
m := New() m := New()
if err := m.Build( assert.NoError(t, m.Build(
NewRealm("realm1"), NewRealm("realm2"), NewRealm("realm1"), NewRealm("realm2"),
); err != nil { ))
t.Fatal(err)
}
return m return m
} }
t.Run("NoResults", func(t *testing.T) { t.Run("NoResults", func(t *testing.T) {
m := newMessage() m := newMessage()
if !m.Equal(initial) { assert.True(t, m.Equal(initial), "m should be equal to initial")
t.Error("m should be equal to initial") assert.NoError(t, m.ForEach(AttrUsername, func(*Message) error {
} assert.Fail(t, "should not be called")
if err := m.ForEach(AttrUsername, func(*Message) error {
t.Error("should not be called")
return nil return nil
}); err != nil { }))
t.Fatal(err) assert.True(t, m.Equal(initial), "m should be equal to initial")
}
if !m.Equal(initial) {
t.Error("m should be equal to initial")
}
}) })
t.Run("ReturnOnError", func(t *testing.T) { t.Run("ReturnOnError", func(t *testing.T) {
m := newMessage() m := newMessage()
var calls int var calls int
if err := m.ForEach(AttrRealm, func(*Message) error { err := m.ForEach(AttrRealm, func(*Message) error {
if calls > 0 { if calls > 0 {
t.Error("called multiple times") assert.Fail(t, "called multiple times")
} }
calls++ calls++
return ErrAttributeNotFound return ErrAttributeNotFound
}); !errors.Is(err, ErrAttributeNotFound) { })
t.Fatal(err) assert.ErrorIs(t, err, ErrAttributeNotFound)
} assert.True(t, m.Equal(initial), "m should be equal to initial")
if !m.Equal(initial) {
t.Error("m should be equal to initial")
}
}) })
t.Run("Positive", func(t *testing.T) { t.Run("Positive", func(t *testing.T) {
msg := newMessage() msg := newMessage()
var realms []string var realms []string
if err := msg.ForEach(AttrRealm, func(m *Message) error { assert.NoError(t, msg.ForEach(AttrRealm, func(m *Message) error {
var realm Realm var realm Realm
if err := realm.GetFrom(m); err != nil { assert.NoError(t, realm.GetFrom(m))
return err
}
realms = append(realms, realm.String()) realms = append(realms, realm.String())
return nil return nil
}); err != nil { }))
t.Fatal(err) assert.Len(t, realms, 2)
} assert.Equal(t, "realm1", realms[0], "bad value for 1 realm")
if len(realms) != 2 { assert.Equal(t, "realm2", realms[1], "bad value for 2 realm")
t.Fatal("expected 2 realms") assert.True(t, msg.Equal(initial), "m should be equal to initial")
}
if realms[0] != "realm1" {
t.Error("bad value for 1 realm")
}
if realms[1] != "realm2" {
t.Error("bad value for 2 realm")
}
if !msg.Equal(initial) {
t.Error("m should be equal to initial")
}
t.Run("ZeroAlloc", func(t *testing.T) { t.Run("ZeroAlloc", func(t *testing.T) {
msg = newMessage() msg = newMessage()
var realm Realm var realm Realm
testutil.ShouldNotAllocate(t, func() { testutil.ShouldNotAllocate(t, func() {
if err := msg.ForEach(AttrRealm, realm.GetFrom); err != nil { assert.NoError(t, msg.ForEach(AttrRealm, realm.GetFrom))
t.Fatal(err)
}
}) })
if !msg.Equal(initial) { assert.True(t, msg.Equal(initial), "m should be equal to initial")
t.Error("m should be equal to initial")
}
}) })
}) })
} }

View File

@@ -12,6 +12,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func loadCSV(tb testing.TB, name string) [][]string { func loadCSV(tb testing.TB, name string) [][]string {
@@ -21,9 +23,7 @@ func loadCSV(tb testing.TB, name string) [][]string {
r := csv.NewReader(bytes.NewReader(data)) r := csv.NewReader(bytes.NewReader(data))
r.Comment = '#' r.Comment = '#'
records, err := r.ReadAll() records, err := r.ReadAll()
if err != nil { assert.NoError(tb, err)
tb.Fatal(err)
}
return records return records
} }
@@ -41,20 +41,14 @@ func TestIANA(t *testing.T) { //nolint:cyclop
continue continue
} }
val, parseErr := strconv.ParseInt(v[2:], 16, 64) val, parseErr := strconv.ParseInt(v[2:], 16, 64)
if parseErr != nil { assert.NoError(t, parseErr)
t.Fatal(parseErr)
}
t.Logf("value: 0x%x, name: %s", val, name) t.Logf("value: 0x%x, name: %s", val, name)
methods[name] = Method(val) //nolint:gosec // G115 methods[name] = Method(val) //nolint:gosec // G115
} }
for val, name := range methodName() { for val, name := range methodName() {
mapped, ok := methods[name] mapped, ok := methods[name]
if !ok { assert.True(t, ok, "failed to find method %s in IANA", name)
t.Errorf("failed to find method %s in IANA", name) assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val)
}
if mapped != val {
t.Errorf("%s: IANA %d != actual %d", name, mapped, val)
}
} }
}) })
t.Run("Attributes", func(t *testing.T) { t.Run("Attributes", func(t *testing.T) {
@@ -69,9 +63,7 @@ func TestIANA(t *testing.T) { //nolint:cyclop
continue continue
} }
val, parseErr := strconv.ParseInt(v[2:], 16, 64) val, parseErr := strconv.ParseInt(v[2:], 16, 64)
if parseErr != nil { assert.NoError(t, parseErr)
t.Fatal(parseErr)
}
t.Logf("value: 0x%x, name: %s", val, name) t.Logf("value: 0x%x, name: %s", val, name)
attrTypes[name] = AttrType(val) //nolint:gosec // G115 attrTypes[name] = AttrType(val) //nolint:gosec // G115
} }
@@ -83,12 +75,8 @@ func TestIANA(t *testing.T) { //nolint:cyclop
} }
for val, name := range attrNames() { for val, name := range attrNames() {
mapped, ok := attrTypes[name] mapped, ok := attrTypes[name]
if !ok { assert.True(t, ok, "failed to find attribute %s in IANA", name)
t.Errorf("failed to find attribute %s in IANA", name) assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val)
}
if mapped != val {
t.Errorf("%s: IANA %d != actual %d", name, mapped, val)
}
} }
}) })
t.Run("ErrorCodes", func(t *testing.T) { t.Run("ErrorCodes", func(t *testing.T) {
@@ -103,21 +91,15 @@ func TestIANA(t *testing.T) { //nolint:cyclop
continue continue
} }
val, parseErr := strconv.ParseInt(v, 10, 64) val, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil { assert.NoError(t, parseErr)
t.Fatal(parseErr)
}
t.Logf("value: 0x%x, name: %s", val, name) t.Logf("value: 0x%x, name: %s", val, name)
errorCodes[name] = ErrorCode(val) errorCodes[name] = ErrorCode(val)
} }
for val, nameB := range errorReasons { for val, nameB := range errorReasons {
name := string(nameB) name := string(nameB)
mapped, ok := errorCodes[name] mapped, ok := errorCodes[name]
if !ok { assert.True(t, ok, "failed to find error code %s in IANA", name)
t.Errorf("failed to find error code %s in IANA", name) assert.Equal(t, mapped, val, "%s: IANA %d != actual %d", name, mapped, val)
}
if mapped != val {
t.Errorf("%s: IANA %d != actual %d", name, mapped, val)
}
} }
}) })
} }

View File

@@ -4,40 +4,29 @@
package stun package stun
import ( import (
"bytes"
"encoding/hex" "encoding/hex"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestMessageIntegrity_AddTo_Simple(t *testing.T) { func TestMessageIntegrity_AddTo_Simple(t *testing.T) {
integrity := NewLongTermIntegrity("user", "realm", "pass") integrity := NewLongTermIntegrity("user", "realm", "pass")
expected, err := hex.DecodeString("8493fbc53ba582fb4c044c456bdc40eb") expected, err := hex.DecodeString("8493fbc53ba582fb4c044c456bdc40eb")
if err != nil { assert.NoError(t, err)
t.Fatal(err) assert.EqualValues(t, expected, integrity)
}
if !bytes.Equal(expected, integrity) {
t.Error(ErrIntegrityMismatch)
}
t.Run("Check", func(t *testing.T) { t.Run("Check", func(t *testing.T) {
m := new(Message) m := new(Message)
m.WriteHeader() m.WriteHeader()
if err := integrity.AddTo(m); err != nil { assert.NoError(t, integrity.AddTo(m))
t.Error(err)
}
NewSoftware("software").AddTo(m) //nolint:errcheck,gosec NewSoftware("software").AddTo(m) //nolint:errcheck,gosec
m.WriteHeader() m.WriteHeader()
dM := new(Message) dM := new(Message)
dM.Raw = m.Raw dM.Raw = m.Raw
if err := dM.Decode(); err != nil { assert.NoError(t, dM.Decode())
t.Error(err) assert.NoError(t, integrity.Check(dM))
}
if err := integrity.Check(dM); err != nil {
t.Error(err)
}
dM.Raw[24] += 12 // HMAC now invalid dM.Raw[24] += 12 // HMAC now invalid
if integrity.Check(dM) == nil { assert.Error(t, integrity.Check(dM))
t.Error("should be invalid")
}
}) })
} }
@@ -47,38 +36,23 @@ func TestMessageIntegrityWithFingerprint(t *testing.T) {
msg.WriteHeader() msg.WriteHeader()
NewSoftware("software").AddTo(msg) //nolint:errcheck,gosec NewSoftware("software").AddTo(msg) //nolint:errcheck,gosec
integrity := NewShortTermIntegrity("pwd") integrity := NewShortTermIntegrity("pwd")
if integrity.String() != "KEY: 0x707764" { assert.Equal(t, "KEY: 0x707764", integrity.String())
t.Error("bad string", integrity) assert.NoError(t, integrity.AddTo(msg))
} assert.NoError(t, integrity.AddTo(msg))
if err := integrity.Check(msg); err == nil { assert.NoError(t, integrity.Check(msg))
t.Error("should error") assert.NoError(t, Fingerprint.AddTo(msg))
} assert.NoError(t, integrity.Check(msg))
if err := integrity.AddTo(msg); err != nil {
t.Fatal(err)
}
if err := Fingerprint.AddTo(msg); err != nil {
t.Fatal(err)
}
if err := integrity.Check(msg); err != nil {
t.Fatal(err)
}
msg.Raw[24] = 33 msg.Raw[24] = 33
if err := integrity.Check(msg); err == nil { assert.Error(t, integrity.Check(msg))
t.Fatal("mismatch expected")
}
} }
func TestMessageIntegrity(t *testing.T) { func TestMessageIntegrity(t *testing.T) {
m := new(Message) m := new(Message)
i := NewShortTermIntegrity("password") i := NewShortTermIntegrity("password")
m.WriteHeader() m.WriteHeader()
if err := i.AddTo(m); err != nil { assert.NoError(t, i.AddTo(m))
t.Error(err)
}
_, err := m.Get(AttrMessageIntegrity) _, err := m.Get(AttrMessageIntegrity)
if err != nil { assert.NoError(t, err)
t.Error(err)
}
} }
func TestMessageIntegrityBeforeFingerprint(t *testing.T) { func TestMessageIntegrityBeforeFingerprint(t *testing.T) {
@@ -86,9 +60,7 @@ func TestMessageIntegrityBeforeFingerprint(t *testing.T) {
m.WriteHeader() m.WriteHeader()
Fingerprint.AddTo(m) //nolint:errcheck,gosec Fingerprint.AddTo(m) //nolint:errcheck,gosec
i := NewShortTermIntegrity("password") i := NewShortTermIntegrity("password")
if err := i.AddTo(m); err == nil { assert.Error(t, i.AddTo(m))
t.Error("should error")
}
} }
func BenchmarkMessageIntegrity_AddTo(b *testing.B) { func BenchmarkMessageIntegrity_AddTo(b *testing.B) {
@@ -99,9 +71,7 @@ func BenchmarkMessageIntegrity_AddTo(b *testing.B) {
b.SetBytes(int64(len(m.Raw))) b.SetBytes(int64(len(m.Raw)))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
m.WriteHeader() m.WriteHeader()
if err := integrity.AddTo(m); err != nil { assert.NoError(b, integrity.AddTo(m))
b.Error(err)
}
m.Reset() m.Reset()
} }
} }
@@ -114,13 +84,9 @@ func BenchmarkMessageIntegrity_Check(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
m.WriteHeader() m.WriteHeader()
b.SetBytes(int64(len(m.Raw))) b.SetBytes(int64(len(m.Raw)))
if err := integrity.AddTo(m); err != nil { assert.NoError(b, integrity.AddTo(m))
b.Error(err)
}
m.WriteLength() m.WriteLength()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := integrity.Check(m); err != nil { assert.NoError(b, integrity.Check(m))
b.Fatal(err)
}
} }
} }

View File

@@ -11,6 +11,8 @@ import (
"fmt" "fmt"
"hash" "hash"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
type hmacTest struct { type hmacTest struct {
@@ -524,26 +526,17 @@ func hmacTests() []hmacTest { //nolint:maintidx
func TestHMAC(t *testing.T) { func TestHMAC(t *testing.T) {
for i, tt := range hmacTests() { for i, tt := range hmacTests() {
hsh := New(tt.hash, tt.key) hsh := New(tt.hash, tt.key)
if s := hsh.Size(); s != tt.size { assert.Equal(t, tt.size, hsh.Size(), "Size mismatch")
t.Errorf("Size: got %v, want %v", s, tt.size) assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch")
}
if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
}
for j := 0; j < 4; j++ { //nolint:varnamelen for j := 0; j < 4; j++ { //nolint:varnamelen
n, err := hsh.Write(tt.in) n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil { assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n)
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) assert.NoError(t, err, "test %d.%d: Write error", i, j)
continue
}
// Repetitive Sum() calls should return the same value // Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ { for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", hsh.Sum(nil)) sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out { assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
} }
// Second iteration: make sure reset works. // Second iteration: make sure reset works.
@@ -568,18 +561,10 @@ func TestEqual(t *testing.T) {
b := []byte("test1") b := []byte("test1")
c := []byte("test2") c := []byte("test2")
if !Equal(b, b) { assert.True(t, Equal(b, b), "Equal failed with equal arguments")
t.Error("Equal failed with equal arguments") assert.False(t, Equal(a, b), "Equal accepted a prefix of the second argument")
} assert.False(t, Equal(b, a), "Equal accepted a prefix of the first argument")
if Equal(a, b) { assert.False(t, Equal(b, c), "Equal accepted unequal slices")
t.Error("Equal accepted a prefix of the second argument")
}
if Equal(b, a) {
t.Error("Equal accepted a prefix of the first argument")
}
if Equal(b, c) {
t.Error("Equal accepted unequal slices")
}
} }
func BenchmarkHMACSHA256_1K(b *testing.B) { func BenchmarkHMACSHA256_1K(b *testing.B) {

View File

@@ -8,6 +8,8 @@ import (
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func BenchmarkHMACSHA1_512(b *testing.B) { func BenchmarkHMACSHA1_512(b *testing.B) {
@@ -44,26 +46,17 @@ func TestHMACReset(t *testing.T) {
for i, tt := range hmacTests() { for i, tt := range hmacTests() {
hsh := New(tt.hash, tt.key) hsh := New(tt.hash, tt.key)
hsh.(*hmac).resetTo(tt.key) //nolint:forcetypeassert hsh.(*hmac).resetTo(tt.key) //nolint:forcetypeassert
if s := hsh.Size(); s != tt.size { assert.Equal(t, tt.size, hsh.Size(), "Size mismatch")
t.Errorf("Size: got %v, want %v", s, tt.size) assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch")
}
if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
}
for j := 0; j < 2; j++ { for j := 0; j < 2; j++ {
n, err := hsh.Write(tt.in) n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil { assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n)
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) assert.NoError(t, err, "test %d.%d: Write error", i, j)
continue
}
// Repetitive Sum() calls should return the same value // Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ { for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", hsh.Sum(nil)) sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out { assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
} }
// Second iteration: make sure reset works. // Second iteration: make sure reset works.
@@ -78,26 +71,17 @@ func TestHMACPool_SHA1(t *testing.T) { //nolint:dupl,cyclop
continue continue
} }
hsh := AcquireSHA1(tt.key) hsh := AcquireSHA1(tt.key)
if s := hsh.Size(); s != tt.size { assert.Equal(t, tt.size, hsh.Size(), "Size mismatch")
t.Errorf("Size: got %v, want %v", s, tt.size) assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch")
}
if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
}
for j := 0; j < 2; j++ { for j := 0; j < 2; j++ {
n, err := hsh.Write(tt.in) n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil { assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n)
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) assert.NoError(t, err, "test %d.%d: Write error", i, j)
continue
}
// Repetitive Sum() calls should return the same value // Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ { for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", hsh.Sum(nil)) sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out { assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
} }
// Second iteration: make sure reset works. // Second iteration: make sure reset works.
@@ -113,26 +97,17 @@ func TestHMACPool_SHA256(t *testing.T) { //nolint:dupl,cyclop
continue continue
} }
hsh := AcquireSHA256(tt.key) hsh := AcquireSHA256(tt.key)
if s := hsh.Size(); s != tt.size { assert.Equal(t, tt.size, hsh.Size(), "Size mismatch")
t.Errorf("Size: got %v, want %v", s, tt.size) assert.Equal(t, tt.blocksize, hsh.BlockSize(), "BlockSize mismatch")
}
if b := hsh.BlockSize(); b != tt.blocksize {
t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize)
}
for j := 0; j < 2; j++ { for j := 0; j < 2; j++ {
n, err := hsh.Write(tt.in) n, err := hsh.Write(tt.in)
if n != len(tt.in) || err != nil { assert.Equal(t, len(tt.in), n, "test %d.%d: Write(%d) = %d", i, j, len(tt.in), n)
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) assert.NoError(t, err, "test %d.%d: Write error", i, j)
continue
}
// Repetitive Sum() calls should return the same value // Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ { for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", hsh.Sum(nil)) sum := fmt.Sprintf("%x", hsh.Sum(nil))
if sum != tt.out { assert.Equal(t, tt.out, sum, "test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
t.Errorf("test %d.%d.%d: have %s want %s", i, j, k, sum, tt.out)
}
} }
// Second iteration: make sure reset works. // Second iteration: make sure reset works.
@@ -150,7 +125,7 @@ func TestAssertBlockSize(t *testing.T) {
t.Run("Negative", func(t *testing.T) { t.Run("Negative", func(t *testing.T) {
defer func() { defer func() {
if r := recover(); r == nil { if r := recover(); r == nil {
t.Error("should panic") assert.Fail(t, "should panic")
} }
}() }()
h := AcquireSHA256(make([]byte, 0, 1024)) h := AcquireSHA256(make([]byte, 0, 1024))

View File

@@ -6,6 +6,8 @@ package testutil
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
// ShouldNotAllocate fails if f allocates. // ShouldNotAllocate fails if f allocates.
@@ -17,7 +19,5 @@ func ShouldNotAllocate(t *testing.T, f func()) {
return return
} }
if a := testing.AllocsPerRun(10, f); a > 0 { assert.Zero(t, testing.AllocsPerRun(10, f))
t.Errorf("Allocations detected: %f", a)
}
} }

View File

@@ -21,6 +21,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
type attributeEncoder interface { type attributeEncoder interface {
@@ -50,12 +52,9 @@ func TestMessageBuffer(t *testing.T) {
m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) m.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
m.WriteHeader() m.WriteHeader()
mDecoded := New() mDecoded := New()
if _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw)); err != nil { _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw))
t.Error(err) assert.NoError(t, err)
} assert.True(t, mDecoded.Equal(m), "mDecoded != m")
if !mDecoded.Equal(m) {
t.Error(mDecoded, "!", m)
}
} }
func BenchmarkMessage_Write(b *testing.B) { func BenchmarkMessage_Write(b *testing.B) {
@@ -86,9 +85,7 @@ func TestMessageType_Value(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
b := tt.in.Value() b := tt.in.Value()
if b != tt.out { assert.Equal(t, tt.out, b, "Value(%s) -> %s, want %s", tt.in, bUint16(b), bUint16(tt.out))
t.Errorf("Value(%s) -> %s, want %s", tt.in, bUint16(b), bUint16(tt.out))
}
} }
} }
@@ -104,9 +101,7 @@ func TestMessageType_ReadValue(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
m := MessageType{} m := MessageType{}
m.ReadValue(tt.in) m.ReadValue(tt.in)
if m != tt.out { assert.Equal(t, tt.out, m, "ReadValue(%s) -> %s, want %s", bUint16(tt.in), m, tt.out)
t.Errorf("ReadValue(%s) -> %s, want %s", bUint16(tt.in), m, tt.out)
}
} }
} }
@@ -121,12 +116,8 @@ func TestMessageType_ReadWriteValue(t *testing.T) {
m := MessageType{} m := MessageType{}
v := tt.Value() v := tt.Value()
m.ReadValue(v) m.ReadValue(v)
if m != tt { assert.Equal(t, tt, m, "ReadValue(%s -> %s) = %s, should be %s", tt, bUint16(v), m, tt)
t.Errorf("ReadValue(%s -> %s) = %s, should be %s", tt, bUint16(v), m, tt) assert.Equal(t, tt.Method, m.Method, "%s != %s", bUint16(uint16(m.Method)), bUint16(uint16(tt.Method)))
if m.Method != tt.Method {
t.Errorf("%s != %s", bUint16(uint16(m.Method)), bUint16(uint16(tt.Method)))
}
}
} }
} }
@@ -137,32 +128,26 @@ func TestMessage_WriteTo(t *testing.T) {
msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
msg.WriteHeader() msg.WriteHeader()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if _, err := msg.WriteTo(buf); err != nil { _, err := msg.WriteTo(buf)
t.Fatal(err) assert.NoError(t, err)
}
mDecoded := New() mDecoded := New()
if _, err := mDecoded.ReadFrom(buf); err != nil { _, err = mDecoded.ReadFrom(buf)
t.Error(err) assert.NoError(t, err)
} assert.True(t, mDecoded.Equal(msg), "mDecoded != msg")
if !mDecoded.Equal(msg) {
t.Error(mDecoded, "!", msg)
}
} }
func TestMessage_Cookie(t *testing.T) { func TestMessage_Cookie(t *testing.T) {
buf := make([]byte, 20) buf := make([]byte, 20)
mDecoded := New() mDecoded := New()
if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil { _, err := mDecoded.ReadFrom(bytes.NewReader(buf))
t.Error("should error") assert.Error(t, err, "should error")
}
} }
func TestMessage_LengthLessHeaderSize(t *testing.T) { func TestMessage_LengthLessHeaderSize(t *testing.T) {
buf := make([]byte, 8) buf := make([]byte, 8)
mDecoded := New() mDecoded := New()
if _, err := mDecoded.ReadFrom(bytes.NewReader(buf)); err == nil { _, err := mDecoded.ReadFrom(bytes.NewReader(buf))
t.Error("should error") assert.Error(t, err, "should error")
}
} }
func TestMessage_BadLength(t *testing.T) { func TestMessage_BadLength(t *testing.T) {
@@ -176,9 +161,8 @@ func TestMessage_BadLength(t *testing.T) {
m.WriteHeader() m.WriteHeader()
m.Raw[20+3] = 10 // set attr length = 10 m.Raw[20+3] = 10 // set attr length = 10
mDecoded := New() mDecoded := New()
if _, err := mDecoded.Write(m.Raw); err == nil { _, err := mDecoded.Write(m.Raw)
t.Error("should error") assert.Error(t, err, "should error")
}
} }
func TestMessage_AttrLengthLessThanHeader(t *testing.T) { func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
@@ -197,13 +181,8 @@ func TestMessage_AttrLengthLessThanHeader(t *testing.T) {
binary.BigEndian.PutUint16(m.Raw[2:4], 2) // rewrite to bad length binary.BigEndian.PutUint16(m.Raw[2:4], 2) // rewrite to bad length
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+2])) _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+2]))
var e *DecodeErr var e *DecodeErr
if errors.As(err, &e) { assert.ErrorAs(t, err, &e)
if !e.IsPlace(DecodeErrPlace{"attribute", "header"}) { assert.True(t, e.IsPlace(DecodeErrPlace{"attribute", "header"}), "bad place")
t.Error(e, "bad place")
}
} else {
t.Error(err, "should be bad format")
}
} }
func TestMessage_AttrSizeLessThanLength(t *testing.T) { func TestMessage_AttrSizeLessThanLength(t *testing.T) {
@@ -226,13 +205,8 @@ func TestMessage_AttrSizeLessThanLength(t *testing.T) {
mDecoded := New() mDecoded := New()
_, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+5])) _, err := mDecoded.ReadFrom(bytes.NewReader(m.Raw[:20+5]))
var e *DecodeErr var e *DecodeErr
if errors.As(err, &e) { assert.ErrorAs(t, err, &e)
if !e.IsPlace(DecodeErrPlace{"attribute", "value"}) { assert.True(t, e.IsPlace(DecodeErrPlace{"attribute", "value"}), "bad place")
t.Error(e, "bad place")
}
} else {
t.Error(err, "should be bad format")
}
} }
type unexpectedEOFReader struct{} type unexpectedEOFReader struct{}
@@ -244,9 +218,7 @@ func (r unexpectedEOFReader) Read([]byte) (int, error) {
func TestMessage_ReadFromError(t *testing.T) { func TestMessage_ReadFromError(t *testing.T) {
mDecoded := New() mDecoded := New()
_, err := mDecoded.ReadFrom(unexpectedEOFReader{}) _, err := mDecoded.ReadFrom(unexpectedEOFReader{})
if !errors.Is(err, io.ErrUnexpectedEOF) { assert.ErrorIs(t, err, io.ErrUnexpectedEOF, "should be", io.ErrUnexpectedEOF)
t.Error(err, "should be", io.ErrUnexpectedEOF)
}
} }
func BenchmarkMessageType_Value(b *testing.B) { func BenchmarkMessageType_Value(b *testing.B) {
@@ -321,9 +293,7 @@ func BenchmarkMessage_ReadBytes(b *testing.B) {
func TestMessageClass_String(t *testing.T) { func TestMessageClass_String(t *testing.T) {
defer func() { defer func() {
if err := recover(); err == nil { assert.NotNil(t, recover())
t.Error(err, "should be not nil")
}
}() }()
v := [...]MessageClass{ v := [...]MessageClass{
@@ -333,14 +303,12 @@ func TestMessageClass_String(t *testing.T) {
ClassIndication, ClassIndication,
} }
for _, k := range v { for _, k := range v {
if k.String() == "" { assert.NotEmpty(t, k.String(), "%v bad stringer", k)
t.Error(k, "bad stringer")
}
} }
// should panic // should panic
p := MessageClass(0x05).String() p := MessageClass(0x05).String()
t.Error("should panic!", p) assert.Fail(t, "should panic", p)
} }
func TestAttrType_String(t *testing.T) { func TestAttrType_String(t *testing.T) {
@@ -358,46 +326,26 @@ func TestAttrType_String(t *testing.T) {
AttrFingerprint, AttrFingerprint,
} }
for _, k := range attrType { for _, k := range attrType {
if k.String() == "" { assert.NotEmpty(t, k.String(), "%v bad stringer", k)
t.Error(k, "bad stringer") assert.False(t, strings.HasPrefix(k.String(), "0x"), "%v bad stringer", k)
}
if strings.HasPrefix(k.String(), "0x") {
t.Error(k, "bad stringer")
}
} }
vNonStandard := AttrType(0x512) vNonStandard := AttrType(0x512)
if !strings.HasPrefix(vNonStandard.String(), "0x512") { assert.True(t, strings.HasPrefix(vNonStandard.String(), "0x512"), "%v bad prefix", vNonStandard)
t.Error(vNonStandard, "bad prefix")
}
} }
func TestMethod_String(t *testing.T) { func TestMethod_String(t *testing.T) {
if MethodBinding.String() != "Binding" { assert.Equal(t, "Binding", MethodBinding.String(), "binding is not binding!")
t.Error("binding is not binding!") assert.Equal(t, "0x616", Method(0x616).String(), "Bad stringer")
}
if Method(0x616).String() != "0x616" {
t.Error("Bad stringer", Method(0x616))
}
} }
func TestAttribute_Equal(t *testing.T) { func TestAttribute_Equal(t *testing.T) {
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}} attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}} attr2 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
if !attr1.Equal(attr2) { assert.True(t, attr1.Equal(attr2))
t.Error("should equal") assert.False(t, attr1.Equal(RawAttribute{Type: 0x2}))
} assert.False(t, attr1.Equal(RawAttribute{Length: 0x2}))
if attr1.Equal(RawAttribute{Type: 0x2}) { assert.False(t, attr1.Equal(RawAttribute{Length: 0x3}))
t.Error("should not equal") assert.False(t, attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}))
}
if attr1.Equal(RawAttribute{Length: 0x2}) {
t.Error("should not equal")
}
if attr1.Equal(RawAttribute{Length: 0x3}) {
t.Error("should not equal")
}
if attr1.Equal(RawAttribute{Length: 2, Value: []byte{0x1, 0x3}}) {
t.Error("should not equal")
}
} }
func TestMessage_Equal(t *testing.T) { //nolint:cyclop func TestMessage_Equal(t *testing.T) { //nolint:cyclop
@@ -405,39 +353,23 @@ func TestMessage_Equal(t *testing.T) { //nolint:cyclop
attrs := Attributes{attr} attrs := Attributes{attr}
msg1 := &Message{Attributes: attrs, Length: 4 + 2} msg1 := &Message{Attributes: attrs, Length: 4 + 2}
msg2 := &Message{Attributes: attrs, Length: 4 + 2} msg2 := &Message{Attributes: attrs, Length: 4 + 2}
if !msg1.Equal(msg2) { assert.True(t, msg1.Equal(msg2))
t.Error("should equal") assert.False(t, msg1.Equal(&Message{Type: MessageType{Class: 128}}))
}
if msg1.Equal(&Message{Type: MessageType{Class: 128}}) {
t.Error("should not equal")
}
tID := [TransactionIDSize]byte{ tID := [TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
} }
if msg1.Equal(&Message{TransactionID: tID}) { assert.False(t, msg1.Equal(&Message{TransactionID: tID}))
t.Error("should not equal") assert.False(t, msg1.Equal(&Message{Length: 3}))
}
if msg1.Equal(&Message{Length: 3}) {
t.Error("should not equal")
}
tAttrs := Attributes{ tAttrs := Attributes{
{Length: 1, Value: []byte{0x1}, Type: 0x1}, {Length: 1, Value: []byte{0x1}, Type: 0x1},
} }
if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) { assert.False(t, msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}))
t.Error("should not equal")
}
tAttrs = Attributes{ tAttrs = Attributes{
{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2}, {Length: 2, Value: []byte{0x1, 0x1}, Type: 0x2},
} }
if msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}) { assert.False(t, msg1.Equal(&Message{Attributes: tAttrs, Length: 4 + 2}))
t.Error("should not equal") assert.True(t, (*Message)(nil).Equal(nil), "nil should be equal to nil")
} assert.False(t, msg1.Equal(nil), "non-nil should not be equal to nil")
if !(*Message)(nil).Equal(nil) {
t.Error("nil should be equal to nil")
}
if msg1.Equal(nil) {
t.Error("non-nil should not be equal to nil")
}
t.Run("Nil attributes", func(t *testing.T) { t.Run("Nil attributes", func(t *testing.T) {
msg1 := &Message{ msg1 := &Message{
Attributes: nil, Attributes: nil,
@@ -447,61 +379,43 @@ func TestMessage_Equal(t *testing.T) { //nolint:cyclop
Attributes: attrs, Attributes: attrs,
Length: 4 + 2, Length: 4 + 2,
} }
if msg1.Equal(msg2) { assert.False(t, msg1.Equal(msg2))
t.Error("should not equal") assert.False(t, msg2.Equal(msg1))
}
if msg2.Equal(msg1) {
t.Error("should not equal")
}
msg2.Attributes = nil msg2.Attributes = nil
if !msg1.Equal(msg2) { assert.True(t, msg1.Equal(msg2))
t.Error("should equal")
}
}) })
t.Run("Attributes length", func(t *testing.T) { t.Run("Attributes length", func(t *testing.T) {
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1} attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1} attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
a := &Message{Attributes: Attributes{attr}, Length: 4 + 2} a := &Message{Attributes: Attributes{attr}, Length: 4 + 2}
b := &Message{Attributes: Attributes{attr, attr1}, Length: 4 + 2} b := &Message{Attributes: Attributes{attr, attr1}, Length: 4 + 2}
if a.Equal(b) { assert.False(t, a.Equal(b))
t.Error("should not equal")
}
}) })
t.Run("Attributes values", func(t *testing.T) { t.Run("Attributes values", func(t *testing.T) {
attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1} attr := RawAttribute{Length: 2, Value: []byte{0x1, 0x2}, Type: 0x1}
attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x1} attr1 := RawAttribute{Length: 2, Value: []byte{0x1, 0x1}, Type: 0x1}
a := &Message{Attributes: Attributes{attr, attr}, Length: 4 + 2} a := &Message{Attributes: Attributes{attr, attr}, Length: 4 + 2}
b := &Message{Attributes: Attributes{attr, attr1}, Length: 4 + 2} b := &Message{Attributes: Attributes{attr, attr1}, Length: 4 + 2}
if a.Equal(b) { assert.False(t, a.Equal(b))
t.Error("should not equal")
}
}) })
} }
func TestMessageGrow(t *testing.T) { func TestMessageGrow(t *testing.T) {
m := New() m := New()
m.grow(512) m.grow(512)
if len(m.Raw) < 512 { assert.GreaterOrEqual(t, len(m.Raw), 512)
t.Error("Bad length", len(m.Raw))
}
} }
func TestMessageGrowSmaller(t *testing.T) { func TestMessageGrowSmaller(t *testing.T) {
m := New() m := New()
m.grow(2) m.grow(2)
if cap(m.Raw) < 20 { assert.GreaterOrEqual(t, cap(m.Raw), 20)
t.Error("Bad capacity", cap(m.Raw)) assert.GreaterOrEqual(t, len(m.Raw), 20)
}
if len(m.Raw) < 20 {
t.Error("Bad length", len(m.Raw))
}
} }
func TestMessage_String(t *testing.T) { func TestMessage_String(t *testing.T) {
m := New() m := New()
if m.String() == "" { assert.NotEmpty(t, m.String())
t.Error("bad string")
}
} }
func TestIsMessage(t *testing.T) { func TestIsMessage(t *testing.T) {
@@ -525,9 +439,7 @@ func TestIsMessage(t *testing.T) {
}, true}, // 6 }, true}, // 6
} }
for i, v := range tt { for i, v := range tt {
if got := IsMessage(v.in); got != v.out { assert.Equal(t, v.out, IsMessage(v.in), "tt[%d]: IsMessage(%+v)", i, v.in)
t.Errorf("tt[%d]: IsMessage(%+v) %v != %v", i, v.in, got, v.out)
}
} }
} }
@@ -553,18 +465,12 @@ func loadData(tb testing.TB, name string) []byte {
name = filepath.Join("testdata", name) name = filepath.Join("testdata", name)
f, err := os.Open(name) //nolint:gosec f, err := os.Open(name) //nolint:gosec
if err != nil { assert.NoError(tb, err)
tb.Fatal(err)
}
defer func() { defer func() {
if errClose := f.Close(); errClose != nil { assert.NoError(tb, f.Close())
tb.Fatal(errClose)
}
}() }()
v, err := io.ReadAll(f) v, err := io.ReadAll(f)
if err != nil { assert.NoError(tb, err)
tb.Fatal(err)
}
return v return v
} }
@@ -573,9 +479,7 @@ func TestExampleChrome(t *testing.T) {
buf := loadData(t, "ex1_chrome.stun") buf := loadData(t, "ex1_chrome.stun")
m := New() m := New()
_, err := m.Write(buf) _, err := m.Write(buf)
if err != nil { assert.NoError(t, err, "Failed to parse ex1_chrome")
t.Errorf("Failed to parse ex1_chrome: %s", err)
}
} }
func TestMessageFromBrowsers(t *testing.T) { func TestMessageFromBrowsers(t *testing.T) {
@@ -583,9 +487,7 @@ func TestMessageFromBrowsers(t *testing.T) {
reader := csv.NewReader(bytes.NewReader(loadData(t, "frombrowsers.csv"))) reader := csv.NewReader(bytes.NewReader(loadData(t, "frombrowsers.csv")))
reader.Comment = '#' reader.Comment = '#'
_, err := reader.Read() // skipping header _, err := reader.Read() // skipping header
if err != nil { assert.NoError(t, err, "failed to skip header of csv")
t.Fatal("failed to skip header of csv: ", err)
}
crcTable := crc64.MakeTable(crc64.ISO) crcTable := crc64.MakeTable(crc64.ISO)
msg := New() msg := New()
for { for {
@@ -593,23 +495,14 @@ func TestMessageFromBrowsers(t *testing.T) {
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
break break
} }
if err != nil { assert.NoError(t, err, "failed to read csv line")
t.Fatal("failed to read csv line: ", err)
}
data, err := base64.StdEncoding.DecodeString(line[1]) data, err := base64.StdEncoding.DecodeString(line[1])
if err != nil { assert.NoError(t, err)
t.Fatal("failed to decode ", line[1], " as base64: ", err)
}
b, err := strconv.ParseUint(line[2], 10, 64) b, err := strconv.ParseUint(line[2], 10, 64)
if err != nil { assert.NoError(t, err)
t.Fatal(err) assert.Equal(t, b, crc64.Checksum(data, crcTable), "crc64 check failed for %s", line[1])
} _, err = msg.Write(data)
if b != crc64.Checksum(data, crcTable) { assert.NoError(t, err, "failed to decode %s as message: %s", line[1], err)
t.Error("crc64 check failed for ", line[1])
}
if _, err = msg.Write(data); err != nil {
t.Error("failed to decode ", line[1], " as message: ", err)
}
msg.Reset() msg.Reset()
} }
} }
@@ -619,9 +512,7 @@ func BenchmarkMessage_NewTransactionID(b *testing.B) {
m := new(Message) m := new(Message)
m.WriteHeader() m.WriteHeader()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := m.NewTransactionID(); err != nil { assert.NoError(b, m.NewTransactionID())
b.Fatal(err)
}
} }
} }
@@ -633,12 +524,8 @@ func BenchmarkMessageFull(b *testing.B) {
IP: net.IPv4(213, 1, 223, 5), IP: net.IPv4(213, 1, 223, 5),
} }
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := addr.AddTo(msg); err != nil { assert.NoError(b, addr.AddTo(msg))
b.Fatal(err) assert.NoError(b, s.AddTo(msg))
}
if err := s.AddTo(msg); err != nil {
b.Fatal(err)
}
msg.WriteAttributes() msg.WriteAttributes()
msg.WriteHeader() msg.WriteHeader()
Fingerprint.AddTo(msg) //nolint:errcheck,gosec Fingerprint.AddTo(msg) //nolint:errcheck,gosec
@@ -655,12 +542,8 @@ func BenchmarkMessageFullHardcore(b *testing.B) {
IP: net.IPv4(213, 1, 223, 5), IP: net.IPv4(213, 1, 223, 5),
} }
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := addr.AddTo(msg); err != nil { assert.NoError(b, addr.AddTo(msg))
b.Fatal(err) assert.NoError(b, s.AddTo(msg))
}
if err := s.AddTo(msg); err != nil {
b.Fatal(err)
}
msg.WriteHeader() msg.WriteHeader()
msg.Reset() msg.Reset()
} }
@@ -684,12 +567,8 @@ func BenchmarkMessage_WriteHeader(b *testing.B) {
func TestMessage_Contains(t *testing.T) { func TestMessage_Contains(t *testing.T) {
m := new(Message) m := new(Message)
m.Add(AttrSoftware, []byte("value")) m.Add(AttrSoftware, []byte("value"))
if !m.Contains(AttrSoftware) { assert.True(t, m.Contains(AttrSoftware), "message should contain software")
t.Error("message should contain software") assert.False(t, m.Contains(AttrNonce), "message should not contain nonce")
}
if m.Contains(AttrNonce) {
t.Error("message should not contain nonce")
}
} }
func ExampleMessage() { func ExampleMessage() {
@@ -787,13 +666,9 @@ func TestAllocations(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() { allocs := testing.AllocsPerRun(10, func() {
m.Reset() m.Reset()
m.WriteHeader() m.WriteHeader()
if err := s.AddTo(m); err != nil { assert.NoError(t, s.AddTo(m), "[%d] failed to add", i)
t.Errorf("[%d] failed to add", i)
}
}) })
if allocs > 0 { assert.Zero(t, allocs, "[%d] allocated", i)
t.Errorf("[%d] allocated %.0f", i, allocs)
}
} }
} }
@@ -818,9 +693,7 @@ func TestAllocationsGetters(t *testing.T) {
Fingerprint, Fingerprint,
} }
msg := New() msg := New()
if err := msg.Build(setters...); err != nil { assert.NoError(t, msg.Build(setters...))
t.Error("failed to build", err)
}
getters := []Getter{ getters := []Getter{
new(Nonce), new(Nonce),
new(Username), new(Username),
@@ -832,66 +705,48 @@ func TestAllocationsGetters(t *testing.T) {
g := g g := g
i := i i := i
allocs := testing.AllocsPerRun(10, func() { allocs := testing.AllocsPerRun(10, func() {
if err := g.GetFrom(msg); err != nil { assert.NoError(t, g.GetFrom(msg), "[%d] failed to get", i)
t.Errorf("[%d] failed to get", i)
}
}) })
if allocs > 0 { assert.Zero(t, allocs, "[%d] allocated", i)
t.Errorf("[%d] allocated %.0f", i, allocs)
}
} }
} }
func TestMessageFullSize(t *testing.T) { func TestMessageFullSize(t *testing.T) {
msg := new(Message) msg := new(Message)
if err := msg.Build(BindingRequest, assert.NoError(t, msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{ NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}), }),
NewSoftware("pion/stun"), NewSoftware("pion/stun"),
NewLongTermIntegrity("username", "realm", "password"), NewLongTermIntegrity("username", "realm", "password"),
Fingerprint, Fingerprint,
); err != nil { ))
t.Fatal(err)
}
msg.Raw = msg.Raw[:len(msg.Raw)-10] msg.Raw = msg.Raw[:len(msg.Raw)-10]
decoder := new(Message) decoder := new(Message)
decoder.Raw = msg.Raw[:len(msg.Raw)-10] decoder.Raw = msg.Raw[:len(msg.Raw)-10]
if err := decoder.Decode(); err == nil { assert.Error(t, decoder.Decode(), "decode on truncated buffer should error")
t.Error("decode on truncated buffer should error")
}
} }
func TestMessage_CloneTo(t *testing.T) { func TestMessage_CloneTo(t *testing.T) {
msg := new(Message) msg := new(Message)
if err := msg.Build(BindingRequest, assert.NoError(t, msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{ NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}), }),
NewSoftware("pion/stun"), NewSoftware("pion/stun"),
NewLongTermIntegrity("username", "realm", "password"), NewLongTermIntegrity("username", "realm", "password"),
Fingerprint, Fingerprint,
); err != nil { ))
t.Fatal(err)
}
msg.Encode() msg.Encode()
msg2 := new(Message) msg2 := new(Message)
if err := msg.CloneTo(msg2); err != nil { assert.NoError(t, msg.CloneTo(msg2))
t.Fatal(err) assert.True(t, msg2.Equal(msg), "cloned message should equal original")
}
if !msg2.Equal(msg) {
t.Fatal("not equal")
}
// Corrupting m and checking that b is not corrupted. // Corrupting m and checking that b is not corrupted.
s, ok := msg2.Attributes.Get(AttrSoftware) s, ok := msg2.Attributes.Get(AttrSoftware)
if !ok { assert.True(t, ok)
t.Fatal("no software attribute")
}
s.Value[0] = 'k' s.Value[0] = 'k'
if msg2.Equal(msg) { assert.False(t, msg2.Equal(msg), "should not be equal")
t.Fatal("should not be equal")
}
} }
func BenchmarkMessage_CloneTo(b *testing.B) { func BenchmarkMessage_CloneTo(b *testing.B) {
@@ -919,29 +774,21 @@ func BenchmarkMessage_CloneTo(b *testing.B) {
func TestMessage_AddTo(t *testing.T) { func TestMessage_AddTo(t *testing.T) {
msg := new(Message) msg := new(Message)
if err := msg.Build(BindingRequest, assert.NoError(t, msg.Build(BindingRequest,
NewTransactionIDSetter([TransactionIDSize]byte{ NewTransactionIDSetter([TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
}), }),
Fingerprint, Fingerprint,
); err != nil { ))
t.Fatal(err)
}
msg.Encode() msg.Encode()
b := new(Message) b := new(Message)
if err := msg.CloneTo(b); err != nil { assert.NoError(t, msg.CloneTo(b))
t.Fatal(err)
}
msg.TransactionID = [TransactionIDSize]byte{ msg.TransactionID = [TransactionIDSize]byte{
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2,
} }
if b.Equal(msg) { assert.False(t, b.Equal(msg), "should not be equal")
t.Fatal("should not be equal")
}
msg.AddTo(b) //nolint:errcheck,gosec msg.AddTo(b) //nolint:errcheck,gosec
if !b.Equal(msg) { assert.True(t, b.Equal(msg), "should be equal")
t.Fatal("should be equal")
}
} }
func BenchmarkMessage_AddTo(b *testing.B) { func BenchmarkMessage_AddTo(b *testing.B) {
@@ -966,9 +813,7 @@ func BenchmarkMessage_AddTo(b *testing.B) {
func TestDecode(t *testing.T) { func TestDecode(t *testing.T) {
t.Run("Nil", func(t *testing.T) { t.Run("Nil", func(t *testing.T) {
if err := Decode(nil, nil); !errors.Is(err, ErrDecodeToNil) { assert.ErrorIs(t, Decode(nil, nil), ErrDecodeToNil)
t.Errorf("unexpected error: %v", err)
}
}) })
msg := New() msg := New()
msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest} msg.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
@@ -976,22 +821,14 @@ func TestDecode(t *testing.T) {
msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa}) msg.Add(AttrErrorCode, []byte{0xff, 0xfe, 0xfa})
msg.WriteHeader() msg.WriteHeader()
mDecoded := New() mDecoded := New()
if err := Decode(msg.Raw, mDecoded); err != nil { assert.NoError(t, Decode(msg.Raw, mDecoded))
t.Errorf("unexpected error: %v", err) assert.True(t, mDecoded.Equal(msg), "decoded result is not equal to encoded message")
}
if !mDecoded.Equal(msg) {
t.Error("decoded result is not equal to encoded message")
}
t.Run("ZeroAlloc", func(t *testing.T) { t.Run("ZeroAlloc", func(t *testing.T) {
allocs := testing.AllocsPerRun(10, func() { allocs := testing.AllocsPerRun(10, func() {
mDecoded.Reset() mDecoded.Reset()
if err := Decode(msg.Raw, mDecoded); err != nil { assert.NoError(t, Decode(msg.Raw, mDecoded))
t.Error(err)
}
}) })
if allocs > 0 { assert.Zero(t, allocs, "unexpected allocations")
t.Error("unexpected allocations")
}
}) })
} }
@@ -1021,25 +858,19 @@ func TestMessage_MarshalBinary(t *testing.T) {
}, },
) )
data, err := msg.MarshalBinary() data, err := msg.MarshalBinary()
if err != nil { assert.NoError(t, err)
t.Fatal(err)
}
// Reset m.Raw to check retention. // Reset m.Raw to check retention.
for i := range msg.Raw { for i := range msg.Raw {
msg.Raw[i] = 0 msg.Raw[i] = 0
} }
if err := msg.UnmarshalBinary(data); err != nil { assert.NoError(t, msg.UnmarshalBinary(data))
t.Fatal(err)
}
// Reset data to check retention. // Reset data to check retention.
for i := range data { for i := range data {
data[i] = 0 data[i] = 0
} }
if err := msg.Decode(); err != nil { assert.NoError(t, msg.Decode())
t.Fatal(err)
}
} }
func TestMessage_GobDecode(t *testing.T) { func TestMessage_GobDecode(t *testing.T) {
@@ -1050,23 +881,17 @@ func TestMessage_GobDecode(t *testing.T) {
}, },
) )
data, err := msg.GobEncode() data, err := msg.GobEncode()
if err != nil { assert.NoError(t, err)
t.Fatal(err)
}
// Reset m.Raw to check retention. // Reset m.Raw to check retention.
for i := range msg.Raw { for i := range msg.Raw {
msg.Raw[i] = 0 msg.Raw[i] = 0
} }
if err := msg.GobDecode(data); err != nil { assert.NoError(t, msg.GobDecode(data))
t.Fatal(err)
}
// Reset data to check retention. // Reset data to check retention.
for i := range data { for i := range data {
data[i] = 0 data[i] = 0
} }
if err := msg.Decode(); err != nil { assert.NoError(t, msg.Decode())
t.Fatal(err)
}
} }

View File

@@ -6,6 +6,8 @@ package stun
import ( import (
"net" "net"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestRFC5769(t *testing.T) { //nolint:cyclop func TestRFC5769(t *testing.T) { //nolint:cyclop
@@ -32,19 +34,11 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop
"\xe5\x7a\x3b\xcf", "\xe5\x7a\x3b\xcf",
), ),
} }
if err := m.Decode(); err != nil { assert.NoError(t, m.Decode())
t.Error(err)
}
software := new(Software) software := new(Software)
if err := software.GetFrom(m); err != nil { assert.NoError(t, software.GetFrom(m))
t.Error(err) assert.Equal(t, "STUN test client", software.String())
} assert.NoError(t, Fingerprint.Check(m))
if software.String() != "STUN test client" {
t.Error("bad software: ", software)
}
if err := Fingerprint.Check(m); err != nil {
t.Error("check failed: ", err)
}
t.Run("Long-Term credentials", func(t *testing.T) { t.Run("Long-Term credentials", func(t *testing.T) {
msg := &Message{ msg := &Message{
Raw: []byte("\x00\x01\x00\x60" + Raw: []byte("\x00\x01\x00\x60" +
@@ -64,40 +58,24 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop
"\x2e\x85\xc9\xa2\x8c\xa8\x96\x66", "\x2e\x85\xc9\xa2\x8c\xa8\x96\x66",
), ),
} }
if err := msg.Decode(); err != nil { assert.NoError(t, msg.Decode())
t.Error(err)
}
u := new(Username) u := new(Username)
if err := u.GetFrom(msg); err != nil { assert.NoError(t, u.GetFrom(msg))
t.Error(err)
}
expectedUsername := "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9" expectedUsername := "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9"
if u.String() != expectedUsername { assert.Equal(t, expectedUsername, u.String())
t.Errorf("username: %q (got) != %q (exp)", u, expectedUsername)
}
n := new(Nonce) n := new(Nonce)
if err := n.GetFrom(msg); err != nil { assert.NoError(t, n.GetFrom(msg))
t.Error(err) assert.Equal(t, "f//499k954d6OL34oL9FSTvy64sA", n.String())
}
if n.String() != "f//499k954d6OL34oL9FSTvy64sA" {
t.Error("bad nonce")
}
r := new(Realm) r := new(Realm)
if err := r.GetFrom(msg); err != nil { assert.NoError(t, r.GetFrom(msg))
t.Error(err) assert.Equal(t, "example.org", r.String())
}
if r.String() != "example.org" { //nolint:goconst
t.Error("bad realm")
}
// checking HMAC // checking HMAC
i := NewLongTermIntegrity( i := NewLongTermIntegrity(
"\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9", "\u30DE\u30C8\u30EA\u30C3\u30AF\u30B9",
"example.org", "example.org",
"TheMatrIX", "TheMatrIX",
) )
if err := i.Check(msg); err != nil { assert.NoError(t, i.Check(msg))
t.Error(err)
}
}) })
}) })
t.Run("Response", func(t *testing.T) { t.Run("Response", func(t *testing.T) {
@@ -117,32 +95,18 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop
"\xc0\x7d\x4c\x96", "\xc0\x7d\x4c\x96",
), ),
} }
if err := msg.Decode(); err != nil { assert.NoError(t, msg.Decode())
t.Error(err)
}
software := new(Software) software := new(Software)
if err := software.GetFrom(msg); err != nil { assert.NoError(t, software.GetFrom(msg))
t.Error(err) assert.Equal(t, "test vector", software.String())
} assert.NoError(t, Fingerprint.Check(msg))
if software.String() != "test vector" {
t.Error("bad software: ", software)
}
if err := Fingerprint.Check(msg); err != nil {
t.Error("Check failed: ", err)
}
addr := new(XORMappedAddress) addr := new(XORMappedAddress)
if err := addr.GetFrom(msg); err != nil { assert.NoError(t, addr.GetFrom(msg))
t.Error(err) expected := "192.0.2.1"
} assert.Equalf(t, expected, addr.IP.String(), "Expected %s, got %s", expected, addr.IP)
if !addr.IP.Equal(net.ParseIP("192.0.2.1")) { assert.Equal(t, 32853, addr.Port)
t.Error("bad IP") assert.NoError(t, Fingerprint.Check(msg))
}
if addr.Port != 32853 {
t.Error("bad Port")
}
if err := Fingerprint.Check(msg); err != nil {
t.Error("check failed: ", err)
}
}) })
t.Run("IPv6", func(t *testing.T) { t.Run("IPv6", func(t *testing.T) {
msg := &Message{ msg := &Message{
@@ -162,32 +126,20 @@ func TestRFC5769(t *testing.T) { //nolint:cyclop
"\xc8\xfb\x0b\x4c", "\xc8\xfb\x0b\x4c",
), ),
} }
if err := msg.Decode(); err != nil { assert.NoError(t, msg.Decode())
t.Error(err)
}
software := new(Software) software := new(Software)
if err := software.GetFrom(msg); err != nil { assert.NoError(t, software.GetFrom(msg))
t.Error(err) assert.Equal(t, "test vector", software.String())
} assert.NoError(t, Fingerprint.Check(msg))
if software.String() != "test vector" {
t.Error("bad software: ", software)
}
if err := Fingerprint.Check(msg); err != nil {
t.Error("Check failed: ", err)
}
addr := new(XORMappedAddress) addr := new(XORMappedAddress)
if err := addr.GetFrom(msg); err != nil { assert.NoError(t, addr.GetFrom(msg))
t.Error(err) expectedIP := "2001:db8:1234:5678:11:2233:4455:6677"
} assert.Truef(
if !addr.IP.Equal(net.ParseIP("2001:db8:1234:5678:11:2233:4455:6677")) { t, addr.IP.Equal(net.ParseIP(expectedIP)),
t.Error("bad IP") "Expected %s, got %s", expectedIP, addr.IP,
} )
if addr.Port != 32853 { assert.Equal(t, 32853, addr.Port)
t.Error("bad Port") assert.NoError(t, Fingerprint.Check(msg))
}
if err := Fingerprint.Check(msg); err != nil {
t.Error("check failed: ", err)
}
}) })
}) })
} }

View File

@@ -6,6 +6,8 @@ package stun
import ( import (
"errors" "errors"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
type errorReader struct{} type errorReader struct{}
@@ -21,9 +23,7 @@ func (errorReader) Read([]byte) (int, error) {
func TestReadFullHelper(t *testing.T) { func TestReadFullHelper(t *testing.T) {
defer func() { defer func() {
if r := recover(); r == nil { assert.NotNil(t, recover(), "should panic")
t.Error("should panic")
}
}() }()
readFullOrPanic(errorReader{}, make([]byte, 1)) readFullOrPanic(errorReader{}, make([]byte, 1))
} }
@@ -36,9 +36,7 @@ func (errorWriter) Write([]byte) (int, error) {
func TestWriteHelper(t *testing.T) { func TestWriteHelper(t *testing.T) {
defer func() { defer func() {
if r := recover(); r == nil { assert.NotNil(t, recover(), "should panic")
t.Error("should panic")
}
}() }()
writeOrPanic(errorWriter{}, make([]byte, 1)) writeOrPanic(errorWriter{}, make([]byte, 1))
} }

View File

@@ -9,6 +9,8 @@ import (
"fmt" "fmt"
"net" "net"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
var errUDPServerUnsupportedNetwork = errors.New("unsupported network") var errUDPServerUnsupportedNetwork = errors.New("unsupported network")
@@ -37,9 +39,7 @@ func NewUDPServer(
} }
udpConn, err := net.ListenUDP(network, &net.UDPAddr{IP: net.ParseIP(ip), Port: 0}) udpConn, err := net.ListenUDP(network, &net.UDPAddr{IP: net.ParseIP(ip), Port: 0})
if err != nil { assert.NoError(t, err)
t.Fatal(err) //nolint:forbidigo
}
// Necessary for IPv6 // Necessary for IPv6
address := fmt.Sprintf("%s:%d", ip, udpConn.LocalAddr().(*net.UDPAddr).Port) //nolint:forcetypeassert address := fmt.Sprintf("%s:%d", ip, udpConn.LocalAddr().(*net.UDPAddr).Port) //nolint:forcetypeassert
@@ -81,18 +81,14 @@ func NewUDPServer(
select { select {
case err := <-errCh: case err := <-errCh:
if err != nil { if err != nil {
t.Fatal(err) assert.NoError(t, err)
return return
} }
default: default:
} }
err := udpConn.Close() assert.NoError(t, udpConn.Close())
if err != nil {
t.Fatal(err)
}
<-errCh <-errCh
}, nil }, nil
} }

View File

@@ -7,9 +7,10 @@
package stun package stun
import ( import (
"errors"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestSoftware_GetFrom(t *testing.T) { func TestSoftware_GetFrom(t *testing.T) {
@@ -22,44 +23,29 @@ func TestSoftware_GetFrom(t *testing.T) {
Raw: make([]byte, 0, 256), Raw: make([]byte, 0, 256),
} }
software := new(Software) software := new(Software)
if _, err := m2.ReadFrom(msg.reader()); err != nil { _, err := m2.ReadFrom(msg.reader())
t.Error(err) assert.NoError(t, err)
} assert.NoError(t, software.GetFrom(msg))
if err := software.GetFrom(msg); err != nil { assert.Equal(t, val, software.String())
t.Fatal(err)
}
if software.String() != val {
t.Errorf("Expected %q, got %q.", val, software)
}
sAttr, ok := msg.Attributes.Get(AttrSoftware) sAttr, ok := msg.Attributes.Get(AttrSoftware)
if !ok { assert.True(t, ok, "software attribute should be found")
t.Error("software attribute should be found")
}
s := sAttr.String() s := sAttr.String()
if !strings.HasPrefix(s, "SOFTWARE:") { assert.True(t, strings.HasPrefix(s, "SOFTWARE:"), "bad string representation")
t.Error("bad string representation", s)
}
} }
func TestSoftware_AddTo_Invalid(t *testing.T) { func TestSoftware_AddTo_Invalid(t *testing.T) {
m := New() m := New()
s := make(Software, 1024) s := make(Software, 1024)
if err := s.AddTo(m); !IsAttrSizeOverflow(err) { assert.True(t, IsAttrSizeOverflow(s.AddTo(m)), "AddTo should return *AttrOverflowErr")
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) assert.ErrorIs(t, s.GetFrom(m), ErrAttributeNotFound)
}
if err := s.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
} }
func TestSoftware_AddTo_Regression(t *testing.T) { func TestSoftware_AddTo_Regression(t *testing.T) {
// s.AddTo checked len(m.Raw) instead of len(s.Raw). // s.AddTo checked len(m.Raw) instead of len(s.Raw).
m := &Message{Raw: make([]byte, 2048)} m := &Message{Raw: make([]byte, 2048)}
s := make(Software, 100) s := make(Software, 100)
if err := s.AddTo(m); err != nil { assert.NoError(t, s.AddTo(m))
t.Errorf("AddTo should return <nil>, got: %v", err)
}
} }
func BenchmarkUsername_AddTo(b *testing.B) { func BenchmarkUsername_AddTo(b *testing.B) {
@@ -95,28 +81,18 @@ func TestUsername(t *testing.T) {
msg.WriteHeader() msg.WriteHeader()
t.Run("Bad length", func(t *testing.T) { t.Run("Bad length", func(t *testing.T) {
badU := make(Username, 600) badU := make(Username, 600)
if err := badU.AddTo(msg); !IsAttrSizeOverflow(err) { assert.True(t, IsAttrSizeOverflow(badU.AddTo(msg)), "AddTo should return *AttrOverflowErr")
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
}
}) })
t.Run("AddTo", func(t *testing.T) { t.Run("AddTo", func(t *testing.T) {
if err := uName.AddTo(msg); err != nil { assert.NoError(t, uName.AddTo(msg))
t.Error("errored:", err)
}
t.Run("GetFrom", func(t *testing.T) { t.Run("GetFrom", func(t *testing.T) {
got := new(Username) got := new(Username)
if err := got.GetFrom(msg); err != nil { assert.NoError(t, got.GetFrom(msg))
t.Error("errored:", err) assert.Equal(t, username, got.String())
}
if got.String() != username {
t.Errorf("expedted: %s, got: %s", username, got)
}
t.Run("Not found", func(t *testing.T) { t.Run("Not found", func(t *testing.T) {
m := new(Message) m := new(Message)
u := new(Username) u := new(Username)
if err := u.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) { assert.ErrorIs(t, u.GetFrom(m), ErrAttributeNotFound)
t.Error("Should error")
}
}) })
}) })
}) })
@@ -124,14 +100,10 @@ func TestUsername(t *testing.T) {
m := new(Message) m := new(Message)
m.WriteHeader() m.WriteHeader()
u := NewUsername("username") u := NewUsername("username")
if allocs := testing.AllocsPerRun(10, func() { assert.Empty(t, testing.AllocsPerRun(10, func() {
if err := u.AddTo(m); err != nil { assert.NoError(t, u.AddTo(m))
t.Error(err)
}
m.Reset() m.Reset()
}); allocs > 0 { }))
t.Errorf("got %f allocations, zero expected", allocs)
}
}) })
} }
@@ -145,38 +117,23 @@ func TestRealm_GetFrom(t *testing.T) {
Raw: make([]byte, 0, 256), Raw: make([]byte, 0, 256),
} }
r := new(Realm) r := new(Realm)
if err := r.GetFrom(m2); !errors.Is(err, ErrAttributeNotFound) { assert.ErrorIs(t, r.GetFrom(m2), ErrAttributeNotFound)
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err) _, err := m2.ReadFrom(msg.reader())
} assert.NoError(t, err)
if _, err := m2.ReadFrom(msg.reader()); err != nil { assert.NoError(t, r.GetFrom(msg))
t.Error(err) assert.Equal(t, val, r.String())
}
if err := r.GetFrom(msg); err != nil {
t.Fatal(err)
}
if r.String() != val {
t.Errorf("Expected %q, got %q.", val, r)
}
rAttr, ok := msg.Attributes.Get(AttrRealm) rAttr, ok := msg.Attributes.Get(AttrRealm)
if !ok { assert.True(t, ok, "realm attribute should be found")
t.Error("realm attribute should be found")
}
s := rAttr.String() s := rAttr.String()
if !strings.HasPrefix(s, "REALM:") { assert.True(t, strings.HasPrefix(s, "REALM:"), "bad string representation")
t.Error("bad string representation", s)
}
} }
func TestRealm_AddTo_Invalid(t *testing.T) { func TestRealm_AddTo_Invalid(t *testing.T) {
m := New() m := New()
r := make(Realm, 1024) r := make(Realm, 1024)
if err := r.AddTo(m); !IsAttrSizeOverflow(err) { assert.True(t, IsAttrSizeOverflow(r.AddTo(m)), "AddTo should return *AttrOverflowErr")
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) assert.ErrorIs(t, r.GetFrom(m), ErrAttributeNotFound)
}
if err := r.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
} }
func TestNonce_GetFrom(t *testing.T) { func TestNonce_GetFrom(t *testing.T) {
@@ -189,50 +146,31 @@ func TestNonce_GetFrom(t *testing.T) {
Raw: make([]byte, 0, 256), Raw: make([]byte, 0, 256),
} }
var nonce Nonce var nonce Nonce
if _, err := m2.ReadFrom(msg.reader()); err != nil { _, err := m2.ReadFrom(msg.reader())
t.Error(err) assert.NoError(t, err)
} assert.NoError(t, nonce.GetFrom(msg))
if err := nonce.GetFrom(msg); err != nil { assert.Equal(t, val, nonce.String())
t.Fatal(err)
}
if nonce.String() != val {
t.Errorf("Expected %q, got %q.", val, nonce)
}
nAttr, ok := msg.Attributes.Get(AttrNonce) nAttr, ok := msg.Attributes.Get(AttrNonce)
if !ok { assert.True(t, ok, "nonce attribute should be found")
t.Error("nonce attribute should be found")
}
s := nAttr.String() s := nAttr.String()
if !strings.HasPrefix(s, "NONCE:") { assert.True(t, strings.HasPrefix(s, "NONCE:"), "bad string representation")
t.Error("bad string representation", s)
}
} }
func TestNonce_AddTo_Invalid(t *testing.T) { func TestNonce_AddTo_Invalid(t *testing.T) {
m := New() m := New()
n := make(Nonce, 1024) n := make(Nonce, 1024)
if err := n.AddTo(m); !IsAttrSizeOverflow(err) { assert.True(t, IsAttrSizeOverflow(n.AddTo(m)), "AddTo should return *AttrOverflowErr")
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err) assert.ErrorIs(t, n.GetFrom(m), ErrAttributeNotFound)
}
if err := n.GetFrom(m); !errors.Is(err, ErrAttributeNotFound) {
t.Errorf("GetFrom should return %q, got: %v", ErrAttributeNotFound, err)
}
} }
func TestNonce_AddTo(t *testing.T) { func TestNonce_AddTo(t *testing.T) {
m := New() m := New()
n := Nonce("example.org") n := Nonce("example.org")
if err := n.AddTo(m); err != nil { assert.NoError(t, n.AddTo(m))
t.Error(err)
}
v, err := m.Get(AttrNonce) v, err := m.Get(AttrNonce)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.Equal(t, "example.org", string(v))
}
if string(v) != "example.org" {
t.Errorf("bad nonce %q", v)
}
} }
func BenchmarkNonce_AddTo(b *testing.B) { func BenchmarkNonce_AddTo(b *testing.B) {

View File

@@ -5,6 +5,8 @@ package stun
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestUnknownAttributes(t *testing.T) { func TestUnknownAttributes(t *testing.T) {
@@ -13,33 +15,19 @@ func TestUnknownAttributes(t *testing.T) {
AttrDontFragment, AttrDontFragment,
AttrChannelNumber, AttrChannelNumber,
} }
if attr.String() != "DONT-FRAGMENT, CHANNEL-NUMBER" { assert.Equal(t, "DONT-FRAGMENT, CHANNEL-NUMBER", attr.String())
t.Error("bad String:", attr) assert.Equal(t, "<nil>", (UnknownAttributes{}).String())
} assert.NoError(t, attr.AddTo(msg))
if (UnknownAttributes{}).String() != "<nil>" {
t.Error("bad blank string")
}
if err := attr.AddTo(msg); err != nil {
t.Error(err)
}
t.Run("GetFrom", func(t *testing.T) { t.Run("GetFrom", func(t *testing.T) {
attrs := make(UnknownAttributes, 10) attrs := make(UnknownAttributes, 10)
if err := attrs.GetFrom(msg); err != nil { assert.NoError(t, attrs.GetFrom(msg))
t.Error(err)
}
for i, at := range *attr { for i, at := range *attr {
if at != attrs[i] { assert.Equal(t, at, attrs[i])
t.Error("expected", at, "!=", attrs[i])
}
} }
mBlank := new(Message) mBlank := new(Message)
if err := attrs.GetFrom(mBlank); err == nil { assert.Error(t, attrs.GetFrom(mBlank))
t.Error("should error")
}
mBlank.Add(AttrUnknownAttributes, []byte{1, 2, 3}) mBlank.Add(AttrUnknownAttributes, []byte{1, 2, 3})
if err := attrs.GetFrom(mBlank); err == nil { assert.Error(t, attrs.GetFrom(mBlank))
t.Error("should error")
}
}) })
} }

View File

@@ -11,10 +11,11 @@ import (
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"errors"
"io" "io"
"net" "net"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func BenchmarkXORMappedAddress_AddTo(b *testing.B) { func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
@@ -31,82 +32,56 @@ func BenchmarkXORMappedAddress_AddTo(b *testing.B) {
func BenchmarkXORMappedAddress_GetFrom(b *testing.B) { func BenchmarkXORMappedAddress_GetFrom(b *testing.B) {
msg := New() msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil { assert.NoError(b, err)
b.Error(err)
}
copy(msg.TransactionID[:], transactionID) copy(msg.TransactionID[:], transactionID)
addrValue, err := hex.DecodeString("00019cd5f49f38ae") addrValue, err := hex.DecodeString("00019cd5f49f38ae")
if err != nil { assert.NoError(b, err)
b.Error(err)
}
msg.Add(AttrXORMappedAddress, addrValue) msg.Add(AttrXORMappedAddress, addrValue)
addr := new(XORMappedAddress) addr := new(XORMappedAddress)
b.ReportAllocs() b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if err := addr.GetFrom(msg); err != nil { assert.NoError(b, addr.GetFrom(msg))
b.Fatal(err)
}
} }
} }
func TestXORMappedAddress_GetFrom(t *testing.T) { func TestXORMappedAddress_GetFrom(t *testing.T) {
m := New() m := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil { assert.NoError(t, err)
t.Error(err)
}
copy(m.TransactionID[:], transactionID) copy(m.TransactionID[:], transactionID)
addrValue, err := hex.DecodeString("00019cd5f49f38ae") addrValue, err := hex.DecodeString("00019cd5f49f38ae")
if err != nil { assert.NoError(t, err)
t.Error(err)
}
m.Add(AttrXORMappedAddress, addrValue) m.Add(AttrXORMappedAddress, addrValue)
addr := new(XORMappedAddress) addr := new(XORMappedAddress)
if err = addr.GetFrom(m); err != nil { assert.NoError(t, addr.GetFrom(m))
t.Error(err) assert.True(t, addr.IP.Equal(net.ParseIP("213.141.156.236")))
} assert.Equal(t, 48583, addr.Port)
if !addr.IP.Equal(net.ParseIP("213.141.156.236")) {
t.Error("bad IP", addr.IP, "!=", "213.141.156.236")
}
if addr.Port != 48583 {
t.Error("bad Port", addr.Port, "!=", 48583)
}
t.Run("UnexpectedEOF", func(t *testing.T) { t.Run("UnexpectedEOF", func(t *testing.T) {
m := New() m := New()
// {0, 1} is correct addr family. // {0, 1} is correct addr family.
m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4}) m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4})
addr := new(XORMappedAddress) addr := new(XORMappedAddress)
if err = addr.GetFrom(m); !errors.Is(err, io.ErrUnexpectedEOF) { assert.ErrorIs(t, addr.GetFrom(m), io.ErrUnexpectedEOF, "len(v) = 4 should return io.ErrUnexpectedEOF")
t.Errorf("len(v) = 4 should render <%s> error, got <%s>",
io.ErrUnexpectedEOF, err,
)
}
}) })
t.Run("AttrOverflowErr", func(t *testing.T) { t.Run("AttrOverflowErr", func(t *testing.T) {
m := New() m := New()
// {0, 1} is correct addr family. // {0, 1} is correct addr family.
m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4, 5, 6, 7, 8, 9, 1, 1, 1, 1, 1, 2, 3, 4}) m.Add(AttrXORMappedAddress, []byte{0, 1, 3, 4, 5, 6, 7, 8, 9, 1, 1, 1, 1, 1, 2, 3, 4})
addr := new(XORMappedAddress) addr := new(XORMappedAddress)
if err := addr.GetFrom(m); !IsAttrSizeOverflow(err) { assert.True(t, IsAttrSizeOverflow(addr.GetFrom(m)), "GetFrom should return *AttrOverflowErr")
t.Errorf("AddTo should return *AttrOverflowErr, got: %v", err)
}
}) })
} }
func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) { func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
msg := New() msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil { assert.NoError(t, err)
t.Error(err)
}
copy(msg.TransactionID[:], transactionID) copy(msg.TransactionID[:], transactionID)
expectedIP := net.ParseIP("213.141.156.236") expectedIP := net.ParseIP("213.141.156.236")
expectedPort := 21254 expectedPort := 21254
addr := new(XORMappedAddress) addr := new(XORMappedAddress)
if err = addr.GetFrom(msg); err == nil { assert.Error(t, addr.GetFrom(msg))
t.Fatal(err, "should be nil")
}
addr.IP = expectedIP addr.IP = expectedIP
addr.Port = expectedPort addr.Port = expectedPort
@@ -115,20 +90,15 @@ func TestXORMappedAddress_GetFrom_Invalid(t *testing.T) {
mRes := New() mRes := New()
binary.BigEndian.PutUint16(msg.Raw[20+4:20+4+2], 0x21) binary.BigEndian.PutUint16(msg.Raw[20+4:20+4+2], 0x21)
if _, err = mRes.ReadFrom(bytes.NewReader(msg.Raw)); err != nil { _, err = mRes.ReadFrom(bytes.NewReader(msg.Raw))
t.Fatal(err) assert.NoError(t, err)
} assert.Error(t, addr.GetFrom(msg))
if err = addr.GetFrom(msg); err == nil {
t.Fatal(err, "should not be nil")
}
} }
func TestXORMappedAddress_AddTo(t *testing.T) { func TestXORMappedAddress_AddTo(t *testing.T) {
msg := New() msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil { assert.NoError(t, err)
t.Error(err)
}
copy(msg.TransactionID[:], transactionID) copy(msg.TransactionID[:], transactionID)
expectedIP := net.ParseIP("213.141.156.236") expectedIP := net.ParseIP("213.141.156.236")
expectedPort := 21254 expectedPort := 21254
@@ -136,31 +106,20 @@ func TestXORMappedAddress_AddTo(t *testing.T) {
IP: net.ParseIP("213.141.156.236"), IP: net.ParseIP("213.141.156.236"),
Port: expectedPort, Port: expectedPort,
} }
if err = addr.AddTo(msg); err != nil { assert.NoError(t, addr.AddTo(msg))
t.Fatal(err)
}
msg.WriteHeader() msg.WriteHeader()
mRes := New() mRes := New()
if _, err = mRes.Write(msg.Raw); err != nil { _, err = mRes.Write(msg.Raw)
t.Fatal(err) assert.NoError(t, err)
} assert.NoError(t, addr.GetFrom(mRes))
if err = addr.GetFrom(mRes); err != nil { assert.True(t, addr.IP.Equal(expectedIP), "Expected %s, got %s", expectedIP, addr.IP)
t.Fatal(err) assert.Equal(t, expectedPort, addr.Port)
}
if !addr.IP.Equal(expectedIP) {
t.Errorf("%s (got) != %s (expected)", addr.IP, expectedIP)
}
if addr.Port != expectedPort {
t.Error("bad Port", addr.Port, "!=", expectedPort)
}
} }
func TestXORMappedAddress_AddTo_IPv6(t *testing.T) { func TestXORMappedAddress_AddTo_IPv6(t *testing.T) {
msg := New() msg := New()
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er") transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil { assert.NoError(t, err)
t.Error(err)
}
copy(msg.TransactionID[:], transactionID) copy(msg.TransactionID[:], transactionID)
expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009") expectedIP := net.ParseIP("fe80::dc2b:44ff:fe20:6009")
expectedPort := 21254 expectedPort := 21254
@@ -172,19 +131,12 @@ func TestXORMappedAddress_AddTo_IPv6(t *testing.T) {
msg.WriteHeader() msg.WriteHeader()
mRes := New() mRes := New()
if _, err = mRes.ReadFrom(msg.reader()); err != nil { _, err = mRes.ReadFrom(msg.reader())
t.Fatal(err) assert.NoError(t, err)
}
gotAddr := new(XORMappedAddress) gotAddr := new(XORMappedAddress)
if err = gotAddr.GetFrom(msg); err != nil { assert.NoError(t, gotAddr.GetFrom(mRes))
t.Fatal(err) assert.True(t, gotAddr.IP.Equal(expectedIP), "Expected %s, got %s", expectedIP, gotAddr.IP)
} assert.Equal(t, expectedPort, gotAddr.Port)
if !gotAddr.IP.Equal(expectedIP) {
t.Error("bad IP", gotAddr.IP, "!=", expectedIP)
}
if gotAddr.Port != expectedPort {
t.Error("bad Port", gotAddr.Port, "!=", expectedPort)
}
} }
func TestXORMappedAddress_AddTo_Invalid(t *testing.T) { func TestXORMappedAddress_AddTo_Invalid(t *testing.T) {
@@ -193,9 +145,7 @@ func TestXORMappedAddress_AddTo_Invalid(t *testing.T) {
IP: []byte{1, 2, 3, 4, 5, 6, 7, 8}, IP: []byte{1, 2, 3, 4, 5, 6, 7, 8},
Port: 21254, Port: 21254,
} }
if err := addr.AddTo(m); !errors.Is(err, ErrBadIPLength) { assert.ErrorIs(t, addr.AddTo(m), ErrBadIPLength)
t.Errorf("AddTo should return %q, got: %v", ErrBadIPLength, err)
}
} }
func TestXORMappedAddress_String(t *testing.T) { func TestXORMappedAddress_String(t *testing.T) {
@@ -219,12 +169,6 @@ func TestXORMappedAddress_String(t *testing.T) {
}, },
} }
for i, c := range tt { for i, c := range tt {
if got := c.in.String(); got != c.out { assert.Equalf(t, c.out, c.in.String(), "[%d]: XORMappesAddres.String()", i)
t.Errorf("[%d]: XORMappesAddres.String() %s (got) != %s (expected)",
i,
got,
c.out,
)
}
} }
} }