mirror of
https://github.com/gortc/stun.git
synced 2025-09-27 12:52:16 +08:00
attributes, stun: refactor
This commit is contained in:
@@ -7,26 +7,20 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// blank is just blank string and exists just because it is ugly to keep it
|
|
||||||
// in code.
|
|
||||||
const blank = ""
|
|
||||||
|
|
||||||
// Attributes is list of message attributes.
|
// Attributes is list of message attributes.
|
||||||
type Attributes []Attribute
|
type Attributes []Attribute
|
||||||
|
|
||||||
// BlankAttribute is attribute that is returned by
|
// Get returns first attribute from list by the type.
|
||||||
// Attributes.Get if nothing found.
|
// If attribute is present the Attribute is returned and the
|
||||||
var BlankAttribute = Attribute{}
|
// boolean is true. Otherwise the returned Attribute will be
|
||||||
|
// empty and boolean will be false.
|
||||||
// Get returns first attribute from list which match AttrType. If nothing
|
func (a Attributes) Get(t AttrType) (Attribute, bool) {
|
||||||
// found, it returns blank attribute.
|
|
||||||
func (a Attributes) Get(t AttrType) Attribute {
|
|
||||||
for _, candidate := range a {
|
for _, candidate := range a {
|
||||||
if candidate.Type == t {
|
if candidate.Type == t {
|
||||||
return candidate
|
return candidate, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return BlankAttribute
|
return Attribute{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// AttrType is attribute type.
|
// AttrType is attribute type.
|
||||||
@@ -132,11 +126,6 @@ type Attribute struct {
|
|||||||
Value []byte
|
Value []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsBlank returns true if attribute equals to BlankAttribute.
|
|
||||||
func (a Attribute) IsBlank() bool {
|
|
||||||
return a.Equal(BlankAttribute)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Equal returns true if a == b.
|
// Equal returns true if a == b.
|
||||||
func (a Attribute) Equal(b Attribute) bool {
|
func (a Attribute) Equal(b Attribute) bool {
|
||||||
if a.Type != b.Type {
|
if a.Type != b.Type {
|
||||||
@@ -164,11 +153,11 @@ func (a Attribute) String() string {
|
|||||||
// if there is no value attribute with shuck type,
|
// if there is no value attribute with shuck type,
|
||||||
// ErrAttributeNotFound is returned.
|
// ErrAttributeNotFound is returned.
|
||||||
func (m *Message) getAttrValue(t AttrType) ([]byte, error) {
|
func (m *Message) getAttrValue(t AttrType) ([]byte, error) {
|
||||||
v := m.Attributes.Get(t).Value
|
v, ok := m.Attributes.Get(t)
|
||||||
if len(v) == 0 {
|
if !ok {
|
||||||
return nil, ErrAttributeNotFound
|
return nil, ErrAttributeNotFound
|
||||||
}
|
}
|
||||||
return v, nil
|
return v.Value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddSoftwareBytes adds SOFTWARE attribute with value from byte slice.
|
// AddSoftwareBytes adds SOFTWARE attribute with value from byte slice.
|
||||||
@@ -184,18 +173,16 @@ func (m *Message) AddSoftware(software string) {
|
|||||||
// GetSoftwareBytes returns SOFTWARE attribute value in byte slice.
|
// GetSoftwareBytes returns SOFTWARE attribute value in byte slice.
|
||||||
// If not found, returns nil.
|
// If not found, returns nil.
|
||||||
func (m *Message) GetSoftwareBytes() []byte {
|
func (m *Message) GetSoftwareBytes() []byte {
|
||||||
return m.Attributes.Get(AttrSoftware).Value
|
v, ok := m.Attributes.Get(AttrSoftware)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return v.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSoftware returns SOFTWARE attribute value in string.
|
// GetSoftware returns SOFTWARE attribute value in string.
|
||||||
// If not found, returns blank string.
|
// If not found, returns blank string.
|
||||||
func (m *Message) GetSoftware() string {
|
func (m *Message) GetSoftware() string { return string(m.GetSoftwareBytes()) }
|
||||||
v := m.GetSoftwareBytes()
|
|
||||||
if len(v) == 0 {
|
|
||||||
return blank
|
|
||||||
}
|
|
||||||
return string(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Address family values.
|
// Address family values.
|
||||||
const (
|
const (
|
||||||
|
@@ -27,7 +27,10 @@ func TestMessage_AddSoftware(t *testing.T) {
|
|||||||
t.Errorf("Expected %s, got %s.", v, vRead)
|
t.Errorf("Expected %s, got %s.", v, vRead)
|
||||||
}
|
}
|
||||||
|
|
||||||
sAttr := m.Attributes.Get(AttrSoftware)
|
sAttr, ok := m.Attributes.Get(AttrSoftware)
|
||||||
|
if !ok {
|
||||||
|
t.Error("sowfware attribute should be found")
|
||||||
|
}
|
||||||
s := sAttr.String()
|
s := sAttr.String()
|
||||||
if !strings.HasPrefix(s, "SOFTWARE:") {
|
if !strings.HasPrefix(s, "SOFTWARE:") {
|
||||||
t.Error("bad string representation", s)
|
t.Error("bad string representation", s)
|
||||||
@@ -282,20 +285,3 @@ func TestMessage_AddErrorCodeDefault(t *testing.T) {
|
|||||||
t.Error("bad reason", string(reason))
|
t.Error("bad reason", string(reason))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAttribute_IsBlank(t *testing.T) {
|
|
||||||
var tt = [...]struct {
|
|
||||||
in Attribute
|
|
||||||
out bool
|
|
||||||
}{
|
|
||||||
{BlankAttribute, true}, // 0
|
|
||||||
{Attribute{Type: AttrUseCandidate}, false}, // 1
|
|
||||||
{Attribute{Value: []byte{1, 2, 3}}, false}, // 2
|
|
||||||
{Attribute{}, true}, // 3
|
|
||||||
}
|
|
||||||
for i, v := range tt {
|
|
||||||
if got := v.in.IsBlank(); got != v.out {
|
|
||||||
t.Errorf("tt[%d]: (%+v).IsMessage %v != %v", i, v.in, got, v.out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/codegangsta/cli"
|
"github.com/codegangsta/cli"
|
||||||
@@ -15,6 +16,16 @@ const (
|
|||||||
version = "0.2"
|
version = "0.2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func normalize(address string) string {
|
||||||
|
if len(address) == 0 {
|
||||||
|
address = "0.0.0.0"
|
||||||
|
}
|
||||||
|
if !strings.Contains(address, ":") {
|
||||||
|
address = fmt.Sprintf("%s:%d", address, stun.DefaultPort)
|
||||||
|
}
|
||||||
|
return address
|
||||||
|
}
|
||||||
|
|
||||||
// Defaults for Client fields.
|
// Defaults for Client fields.
|
||||||
const (
|
const (
|
||||||
DefaultClientRetries = 9
|
DefaultClientRetries = 9
|
||||||
@@ -179,7 +190,7 @@ func discover(c *cli.Context) error {
|
|||||||
|
|
||||||
request := Request{
|
request := Request{
|
||||||
Message: m,
|
Message: m,
|
||||||
Target: stun.Normalize(c.String("server")),
|
Target: normalize(c.String("server")),
|
||||||
}
|
}
|
||||||
|
|
||||||
return DefaultClient.Do(request, func(r Response) error {
|
return DefaultClient.Do(request, func(r Response) error {
|
||||||
|
17
stun.go
17
stun.go
@@ -28,7 +28,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/ernado/buffer"
|
"github.com/ernado/buffer"
|
||||||
@@ -39,17 +38,6 @@ var (
|
|||||||
bin = binary.BigEndian
|
bin = binary.BigEndian
|
||||||
)
|
)
|
||||||
|
|
||||||
// Normalize returns normalized address.
|
|
||||||
func Normalize(address string) string {
|
|
||||||
if len(address) == 0 {
|
|
||||||
address = "0.0.0.0"
|
|
||||||
}
|
|
||||||
if !strings.Contains(address, ":") {
|
|
||||||
address = fmt.Sprintf("%s:%d", address, DefaultPort)
|
|
||||||
}
|
|
||||||
return address
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultPort is IANA assigned port for "stun" protocol.
|
// DefaultPort is IANA assigned port for "stun" protocol.
|
||||||
const DefaultPort = 3478
|
const DefaultPort = 3478
|
||||||
|
|
||||||
@@ -287,7 +275,10 @@ func (m Message) Equal(b Message) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
for _, a := range m.Attributes {
|
for _, a := range m.Attributes {
|
||||||
aB := b.Attributes.Get(a.Type)
|
aB, ok := b.Attributes.Get(a.Type)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
if !aB.Equal(a) {
|
if !aB.Equal(a) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/ernado/stun"
|
"github.com/ernado/stun"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -156,11 +157,21 @@ func ListenUDPAndServe(serverNet, laddr string) error {
|
|||||||
return s.Serve(c)
|
return s.Serve(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalize(address string) string {
|
||||||
|
if len(address) == 0 {
|
||||||
|
address = "0.0.0.0"
|
||||||
|
}
|
||||||
|
if !strings.Contains(address, ":") {
|
||||||
|
address = fmt.Sprintf("%s:%d", address, stun.DefaultPort)
|
||||||
|
}
|
||||||
|
return address
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
switch *network {
|
switch *network {
|
||||||
case "udp":
|
case "udp":
|
||||||
normalized := stun.Normalize(*address)
|
normalized := normalize(*address)
|
||||||
fmt.Println("cydev/stun listening on", normalized, "via", *network)
|
fmt.Println("cydev/stun listening on", normalized, "via", *network)
|
||||||
log.Fatal(ListenUDPAndServe(*network, normalized))
|
log.Fatal(ListenUDPAndServe(*network, normalized))
|
||||||
default:
|
default:
|
||||||
|
@@ -1,9 +1,10 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/ernado/stun"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ernado/stun"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BenchmarkBasicProcess(b *testing.B) {
|
func BenchmarkBasicProcess(b *testing.B) {
|
||||||
|
Reference in New Issue
Block a user