attributes, stun: refactor

This commit is contained in:
Aleksandr Razumov
2017-02-01 09:19:43 +03:00
parent 4ef2b466cf
commit 2d0e6be7eb
6 changed files with 50 additions and 63 deletions

View File

@@ -7,26 +7,20 @@ import (
"strconv" "strconv"
) )
// blank is just blank string and exists just because it is ugly to keep it
// in code.
const blank = ""
// Attributes is list of message attributes. // Attributes is list of message attributes.
type Attributes []Attribute type Attributes []Attribute
// BlankAttribute is attribute that is returned by // Get returns first attribute from list by the type.
// Attributes.Get if nothing found. // If attribute is present the Attribute is returned and the
var BlankAttribute = Attribute{} // boolean is true. Otherwise the returned Attribute will be
// empty and boolean will be false.
// Get returns first attribute from list which match AttrType. If nothing func (a Attributes) Get(t AttrType) (Attribute, bool) {
// found, it returns blank attribute.
func (a Attributes) Get(t AttrType) Attribute {
for _, candidate := range a { for _, candidate := range a {
if candidate.Type == t { if candidate.Type == t {
return candidate return candidate, true
} }
} }
return BlankAttribute return Attribute{}, false
} }
// AttrType is attribute type. // AttrType is attribute type.
@@ -132,11 +126,6 @@ type Attribute struct {
Value []byte Value []byte
} }
// IsBlank returns true if attribute equals to BlankAttribute.
func (a Attribute) IsBlank() bool {
return a.Equal(BlankAttribute)
}
// Equal returns true if a == b. // Equal returns true if a == b.
func (a Attribute) Equal(b Attribute) bool { func (a Attribute) Equal(b Attribute) bool {
if a.Type != b.Type { if a.Type != b.Type {
@@ -164,11 +153,11 @@ func (a Attribute) String() string {
// if there is no value attribute with shuck type, // if there is no value attribute with shuck type,
// ErrAttributeNotFound is returned. // ErrAttributeNotFound is returned.
func (m *Message) getAttrValue(t AttrType) ([]byte, error) { func (m *Message) getAttrValue(t AttrType) ([]byte, error) {
v := m.Attributes.Get(t).Value v, ok := m.Attributes.Get(t)
if len(v) == 0 { if !ok {
return nil, ErrAttributeNotFound return nil, ErrAttributeNotFound
} }
return v, nil return v.Value, nil
} }
// AddSoftwareBytes adds SOFTWARE attribute with value from byte slice. // AddSoftwareBytes adds SOFTWARE attribute with value from byte slice.
@@ -184,18 +173,16 @@ func (m *Message) AddSoftware(software string) {
// GetSoftwareBytes returns SOFTWARE attribute value in byte slice. // GetSoftwareBytes returns SOFTWARE attribute value in byte slice.
// If not found, returns nil. // If not found, returns nil.
func (m *Message) GetSoftwareBytes() []byte { func (m *Message) GetSoftwareBytes() []byte {
return m.Attributes.Get(AttrSoftware).Value v, ok := m.Attributes.Get(AttrSoftware)
if !ok {
return nil
}
return v.Value
} }
// GetSoftware returns SOFTWARE attribute value in string. // GetSoftware returns SOFTWARE attribute value in string.
// If not found, returns blank string. // If not found, returns blank string.
func (m *Message) GetSoftware() string { func (m *Message) GetSoftware() string { return string(m.GetSoftwareBytes()) }
v := m.GetSoftwareBytes()
if len(v) == 0 {
return blank
}
return string(v)
}
// Address family values. // Address family values.
const ( const (

View File

@@ -27,7 +27,10 @@ func TestMessage_AddSoftware(t *testing.T) {
t.Errorf("Expected %s, got %s.", v, vRead) t.Errorf("Expected %s, got %s.", v, vRead)
} }
sAttr := m.Attributes.Get(AttrSoftware) sAttr, ok := m.Attributes.Get(AttrSoftware)
if !ok {
t.Error("sowfware attribute should be found")
}
s := sAttr.String() s := sAttr.String()
if !strings.HasPrefix(s, "SOFTWARE:") { if !strings.HasPrefix(s, "SOFTWARE:") {
t.Error("bad string representation", s) t.Error("bad string representation", s)
@@ -282,20 +285,3 @@ func TestMessage_AddErrorCodeDefault(t *testing.T) {
t.Error("bad reason", string(reason)) t.Error("bad reason", string(reason))
} }
} }
func TestAttribute_IsBlank(t *testing.T) {
var tt = [...]struct {
in Attribute
out bool
}{
{BlankAttribute, true}, // 0
{Attribute{Type: AttrUseCandidate}, false}, // 1
{Attribute{Value: []byte{1, 2, 3}}, false}, // 2
{Attribute{}, true}, // 3
}
for i, v := range tt {
if got := v.in.IsBlank(); got != v.out {
t.Errorf("tt[%d]: (%+v).IsMessage %v != %v", i, v.in, got, v.out)
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"strings"
"time" "time"
"github.com/codegangsta/cli" "github.com/codegangsta/cli"
@@ -15,6 +16,16 @@ const (
version = "0.2" version = "0.2"
) )
func normalize(address string) string {
if len(address) == 0 {
address = "0.0.0.0"
}
if !strings.Contains(address, ":") {
address = fmt.Sprintf("%s:%d", address, stun.DefaultPort)
}
return address
}
// Defaults for Client fields. // Defaults for Client fields.
const ( const (
DefaultClientRetries = 9 DefaultClientRetries = 9
@@ -179,7 +190,7 @@ func discover(c *cli.Context) error {
request := Request{ request := Request{
Message: m, Message: m,
Target: stun.Normalize(c.String("server")), Target: normalize(c.String("server")),
} }
return DefaultClient.Do(request, func(r Response) error { return DefaultClient.Do(request, func(r Response) error {

17
stun.go
View File

@@ -28,7 +28,6 @@ import (
"io" "io"
"net" "net"
"strconv" "strconv"
"strings"
"sync" "sync"
"github.com/ernado/buffer" "github.com/ernado/buffer"
@@ -39,17 +38,6 @@ var (
bin = binary.BigEndian bin = binary.BigEndian
) )
// Normalize returns normalized address.
func Normalize(address string) string {
if len(address) == 0 {
address = "0.0.0.0"
}
if !strings.Contains(address, ":") {
address = fmt.Sprintf("%s:%d", address, DefaultPort)
}
return address
}
// DefaultPort is IANA assigned port for "stun" protocol. // DefaultPort is IANA assigned port for "stun" protocol.
const DefaultPort = 3478 const DefaultPort = 3478
@@ -287,7 +275,10 @@ func (m Message) Equal(b Message) bool {
return false return false
} }
for _, a := range m.Attributes { for _, a := range m.Attributes {
aB := b.Attributes.Get(a.Type) aB, ok := b.Attributes.Get(a.Type)
if !ok {
return false
}
if !aB.Equal(a) { if !aB.Equal(a) {
return false return false
} }

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"log" "log"
"net" "net"
"strings"
"github.com/ernado/stun" "github.com/ernado/stun"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -156,11 +157,21 @@ func ListenUDPAndServe(serverNet, laddr string) error {
return s.Serve(c) return s.Serve(c)
} }
func normalize(address string) string {
if len(address) == 0 {
address = "0.0.0.0"
}
if !strings.Contains(address, ":") {
address = fmt.Sprintf("%s:%d", address, stun.DefaultPort)
}
return address
}
func main() { func main() {
flag.Parse() flag.Parse()
switch *network { switch *network {
case "udp": case "udp":
normalized := stun.Normalize(*address) normalized := normalize(*address)
fmt.Println("cydev/stun listening on", normalized, "via", *network) fmt.Println("cydev/stun listening on", normalized, "via", *network)
log.Fatal(ListenUDPAndServe(*network, normalized)) log.Fatal(ListenUDPAndServe(*network, normalized))
default: default:

View File

@@ -1,9 +1,10 @@
package main package main
import ( import (
"github.com/ernado/stun"
"net" "net"
"testing" "testing"
"github.com/ernado/stun"
) )
func BenchmarkBasicProcess(b *testing.B) { func BenchmarkBasicProcess(b *testing.B) {