diff --git a/attributes.go b/attributes.go index b158c67..addbaca 100644 --- a/attributes.go +++ b/attributes.go @@ -7,26 +7,20 @@ import ( "strconv" ) -// blank is just blank string and exists just because it is ugly to keep it -// in code. -const blank = "" - // Attributes is list of message attributes. type Attributes []Attribute -// BlankAttribute is attribute that is returned by -// Attributes.Get if nothing found. -var BlankAttribute = Attribute{} - -// Get returns first attribute from list which match AttrType. If nothing -// found, it returns blank attribute. -func (a Attributes) Get(t AttrType) Attribute { +// Get returns first attribute from list by the type. +// If attribute is present the Attribute is returned and the +// boolean is true. Otherwise the returned Attribute will be +// empty and boolean will be false. +func (a Attributes) Get(t AttrType) (Attribute, bool) { for _, candidate := range a { if candidate.Type == t { - return candidate + return candidate, true } } - return BlankAttribute + return Attribute{}, false } // AttrType is attribute type. @@ -132,11 +126,6 @@ type Attribute struct { Value []byte } -// IsBlank returns true if attribute equals to BlankAttribute. -func (a Attribute) IsBlank() bool { - return a.Equal(BlankAttribute) -} - // Equal returns true if a == b. func (a Attribute) Equal(b Attribute) bool { if a.Type != b.Type { @@ -164,11 +153,11 @@ func (a Attribute) String() string { // if there is no value attribute with shuck type, // ErrAttributeNotFound is returned. func (m *Message) getAttrValue(t AttrType) ([]byte, error) { - v := m.Attributes.Get(t).Value - if len(v) == 0 { + v, ok := m.Attributes.Get(t) + if !ok { return nil, ErrAttributeNotFound } - return v, nil + return v.Value, nil } // AddSoftwareBytes adds SOFTWARE attribute with value from byte slice. @@ -184,18 +173,16 @@ func (m *Message) AddSoftware(software string) { // GetSoftwareBytes returns SOFTWARE attribute value in byte slice. // If not found, returns nil. func (m *Message) GetSoftwareBytes() []byte { - return m.Attributes.Get(AttrSoftware).Value + v, ok := m.Attributes.Get(AttrSoftware) + if !ok { + return nil + } + return v.Value } // GetSoftware returns SOFTWARE attribute value in string. // If not found, returns blank string. -func (m *Message) GetSoftware() string { - v := m.GetSoftwareBytes() - if len(v) == 0 { - return blank - } - return string(v) -} +func (m *Message) GetSoftware() string { return string(m.GetSoftwareBytes()) } // Address family values. const ( diff --git a/attributes_test.go b/attributes_test.go index ec85777..fc9179a 100644 --- a/attributes_test.go +++ b/attributes_test.go @@ -27,7 +27,10 @@ func TestMessage_AddSoftware(t *testing.T) { t.Errorf("Expected %s, got %s.", v, vRead) } - sAttr := m.Attributes.Get(AttrSoftware) + sAttr, ok := m.Attributes.Get(AttrSoftware) + if !ok { + t.Error("sowfware attribute should be found") + } s := sAttr.String() if !strings.HasPrefix(s, "SOFTWARE:") { t.Error("bad string representation", s) @@ -282,20 +285,3 @@ func TestMessage_AddErrorCodeDefault(t *testing.T) { t.Error("bad reason", string(reason)) } } - -func TestAttribute_IsBlank(t *testing.T) { - var tt = [...]struct { - in Attribute - out bool - }{ - {BlankAttribute, true}, // 0 - {Attribute{Type: AttrUseCandidate}, false}, // 1 - {Attribute{Value: []byte{1, 2, 3}}, false}, // 2 - {Attribute{}, true}, // 3 - } - for i, v := range tt { - if got := v.in.IsBlank(); got != v.out { - t.Errorf("tt[%d]: (%+v).IsMessage %v != %v", i, v.in, got, v.out) - } - } -} diff --git a/stun-cli/main.go b/stun-cli/main.go index f0c4832..59d388a 100644 --- a/stun-cli/main.go +++ b/stun-cli/main.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "os" + "strings" "time" "github.com/codegangsta/cli" @@ -15,6 +16,16 @@ const ( version = "0.2" ) +func normalize(address string) string { + if len(address) == 0 { + address = "0.0.0.0" + } + if !strings.Contains(address, ":") { + address = fmt.Sprintf("%s:%d", address, stun.DefaultPort) + } + return address +} + // Defaults for Client fields. const ( DefaultClientRetries = 9 @@ -179,7 +190,7 @@ func discover(c *cli.Context) error { request := Request{ Message: m, - Target: stun.Normalize(c.String("server")), + Target: normalize(c.String("server")), } return DefaultClient.Do(request, func(r Response) error { diff --git a/stun.go b/stun.go index b037869..257c969 100644 --- a/stun.go +++ b/stun.go @@ -28,7 +28,6 @@ import ( "io" "net" "strconv" - "strings" "sync" "github.com/ernado/buffer" @@ -39,17 +38,6 @@ var ( bin = binary.BigEndian ) -// Normalize returns normalized address. -func Normalize(address string) string { - if len(address) == 0 { - address = "0.0.0.0" - } - if !strings.Contains(address, ":") { - address = fmt.Sprintf("%s:%d", address, DefaultPort) - } - return address -} - // DefaultPort is IANA assigned port for "stun" protocol. const DefaultPort = 3478 @@ -287,7 +275,10 @@ func (m Message) Equal(b Message) bool { return false } for _, a := range m.Attributes { - aB := b.Attributes.Get(a.Type) + aB, ok := b.Attributes.Get(a.Type) + if !ok { + return false + } if !aB.Equal(a) { return false } diff --git a/stund/main.go b/stund/main.go index 9c2262b..5a94bb5 100644 --- a/stund/main.go +++ b/stund/main.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "net" + "strings" "github.com/ernado/stun" "github.com/pkg/errors" @@ -156,11 +157,21 @@ func ListenUDPAndServe(serverNet, laddr string) error { return s.Serve(c) } +func normalize(address string) string { + if len(address) == 0 { + address = "0.0.0.0" + } + if !strings.Contains(address, ":") { + address = fmt.Sprintf("%s:%d", address, stun.DefaultPort) + } + return address +} + func main() { flag.Parse() switch *network { case "udp": - normalized := stun.Normalize(*address) + normalized := normalize(*address) fmt.Println("cydev/stun listening on", normalized, "via", *network) log.Fatal(ListenUDPAndServe(*network, normalized)) default: diff --git a/stund/server_test.go b/stund/server_test.go index ad72b47..602e3ca 100644 --- a/stund/server_test.go +++ b/stund/server_test.go @@ -1,9 +1,10 @@ package main import ( - "github.com/ernado/stun" "net" "testing" + + "github.com/ernado/stun" ) func BenchmarkBasicProcess(b *testing.B) {