message: add ForEach helper

This commit is contained in:
Aleksandr Razumov
2018-08-02 10:51:10 +03:00
parent 04a5217700
commit 23a3a57854
2 changed files with 117 additions and 0 deletions

View File

@@ -77,3 +77,26 @@ func Build(setters ...Setter) (*Message, error) {
m := new(Message)
return m, m.Build(setters...)
}
// ForEach is helper that iterates over message attributes allowing to call
// Getter in f callback to get all attributes of type t and returning on first
// f error.
//
// The m.Get method inside f will be returning next attribute on each f call.
// Does not error if there are no results.
func (m *Message) ForEach(t AttrType, f func(m *Message) error) error {
attrs := m.Attributes
defer func() {
m.Attributes = attrs
}()
for i, a := range attrs {
if a.Type != t {
continue
}
m.Attributes = attrs[i:]
if err := f(m); err != nil {
return err
}
}
return nil
}

View File

@@ -3,6 +3,8 @@ package stun
import (
"errors"
"testing"
"github.com/gortc/stun/internal/testutil"
)
func BenchmarkBuildOverhead(b *testing.B) {
@@ -112,3 +114,95 @@ func TestHelpersErrorHandling(t *testing.T) {
MustBuild(e)
})
}
func TestMessage_ForEach(t *testing.T) {
initial := New()
if err := initial.Build(
NewRealm("realm1"), NewRealm("realm2"),
); err != nil {
t.Fatal(err)
}
newMessage := func() *Message {
m := New()
if err := m.Build(
NewRealm("realm1"), NewRealm("realm2"),
); err != nil {
t.Fatal(err)
}
return m
}
t.Run("NoResults", func(t *testing.T) {
m := newMessage()
if !m.Equal(initial) {
t.Error("m should be equal to initial")
}
if err := m.ForEach(AttrUsername, func(m *Message) error {
t.Error("should not be called")
return nil
}); err != nil {
t.Fatal(err)
}
if !m.Equal(initial) {
t.Error("m should be equal to initial")
}
})
t.Run("ReturnOnError", func(t *testing.T) {
m := newMessage()
var calls int
if err := m.ForEach(AttrRealm, func(m *Message) error {
if calls > 0 {
t.Error("called multiple times")
}
calls++
return ErrAttributeNotFound
}); err != ErrAttributeNotFound {
t.Fatal(err)
}
if !m.Equal(initial) {
t.Error("m should be equal to initial")
}
})
t.Run("Positive", func(t *testing.T) {
m := newMessage()
var realms []string
if err := m.ForEach(AttrRealm, func(m *Message) error {
var realm Realm
if err := realm.GetFrom(m); err != nil {
return err
}
realms = append(realms, realm.String())
return nil
}); err != nil {
t.Fatal(err)
}
if len(realms) != 2 {
t.Fatal("expected 2 realms")
}
if realms[0] != "realm1" {
t.Error("bad value for 1 realm")
}
if realms[1] != "realm2" {
t.Error("bad value for 2 realm")
}
if !m.Equal(initial) {
t.Error("m should be equal to initial")
}
t.Run("ZeroAlloc", func(t *testing.T) {
m = newMessage()
var realm Realm
testutil.ShouldNotAllocate(t, func() {
if err := m.ForEach(AttrRealm, func(m *Message) error {
if err := realm.GetFrom(m); err != nil {
return err
}
return nil
}); err != nil {
t.Fatal(err)
}
})
if !m.Equal(initial) {
t.Error("m should be equal to initial")
}
})
})
}