all: update client and server

This commit is contained in:
Aleksandr Razumov
2016-05-11 16:53:35 +03:00
parent 1f6a032d28
commit 7434b6616c
8 changed files with 329 additions and 7 deletions

View File

@@ -0,0 +1,25 @@
package stun
import "testing"
func TestErrorCode_Reason(t *testing.T) {
codes := [...]ErrorCode{
CodeTryAlternate,
CodeBadRequest,
CodeUnauthorised,
CodeUnknownAttribute,
CodeStaleNonce,
CodeServerError,
}
for _, code := range codes {
if code.Reason() == "Unknown Error" {
t.Error(code, "should not be unknown")
}
if len(code.Reason()) == 0 {
t.Error(code, "should not be blank")
}
}
if ErrorCode(999).Reason() != "Unknown Error" {
t.Error("999 error should be Unknown")
}
}

View File

@@ -222,3 +222,33 @@ func TestMessage_AddErrorCode(t *testing.T) {
t.Error("bad reason", string(reason))
}
}
func TestMessage_AddErrorCodeDefault(t *testing.T) {
m := AcquireMessage()
defer ReleaseMessage(m)
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
if err != nil {
t.Error(err)
}
copy(m.TransactionID[:], transactionID)
expectedCode := 500
expectedReason := "Server Error"
m.AddErrorCodeDefault(expectedCode)
m.WriteHeader()
mRes := AcquireMessage()
defer ReleaseMessage(mRes)
if _, err = mRes.ReadFrom(m.reader()); err != nil {
t.Fatal(err)
}
code, reason, err := mRes.GetErrorCode()
if err != nil {
t.Error(err)
}
if code != expectedCode {
t.Error("bad code", code)
}
if string(reason) != expectedReason {
t.Error("bad reason", string(reason))
}
}

144
client.go
View File

@@ -1,5 +1,149 @@
package stun
import (
"github.com/pkg/errors"
"net"
"time"
)
// SameTransaction returns true of a and b have same Transaction ID.
func SameTransaction(a *Message, b *Message) bool {
return a.TransactionID == b.TransactionID
}
const (
DefaultClientRetries = 9
DefaultMaxTimeout = 2 * time.Second
DefaultInitialTimeout = 1 * time.Millisecond
)
var (
// DefaultClient is Client with defaults that are close
// to RFC recommendations.
DefaultClient = Client{}
)
// Client implements STUN client.
type Client struct {
Retries int
MaxTimeout time.Duration
InitialTimeout time.Duration
addr *net.UDPAddr
}
func (c Client) getRetries() int {
if c.Retries == 0 {
return DefaultClientRetries
}
return c.Retries
}
func (c Client) getMaxTimeout() time.Duration {
if c.MaxTimeout == 0 {
return DefaultMaxTimeout
}
return c.MaxTimeout
}
func (c Client) getInitialTimeout() time.Duration {
if c.InitialTimeout == 0 {
return DefaultInitialTimeout
}
return c.InitialTimeout
}
func (c *Client) getAddr() (*net.UDPAddr, error) {
var (
err error
addr *net.UDPAddr
)
if c.addr != nil {
return c.addr, nil
}
addr, err = net.ResolveUDPAddr("udp", "0.0.0.0:0")
if err == nil {
c.addr = addr
}
return c.addr, err
}
// Request is wrapper on message and target server address.
type Request struct {
Message *Message
Target string
}
// Response is message returned from STUN server.
type Response struct {
Message *Message
}
// ResponseHandler is handler executed if response is received.
type ResponseHandler func(r Response) error
const timeoutGrowthRate = 2
// Do performs request and passing response to handler. If error occurs
// during request, Do returns it, not calling the handler.
// Do returns any error that is returned by handler.
// Response is only valid during handler execution.
//
// Never store Response, Message pointer or any values obtained from
// Message getters, copy message and use it if needed.
func (c Client) Do(request Request, h ResponseHandler) error {
var (
targetAddr *net.UDPAddr
clientAddr *net.UDPAddr
conn *net.UDPConn
err error
)
if targetAddr, err = net.ResolveUDPAddr("udp", request.Target); err != nil {
return errors.Wrap(err, "failed to resolve")
}
if clientAddr, err = c.getAddr(); err != nil {
return errors.Wrap(err, "failed to get local addr")
}
if conn, err = net.DialUDP("udp", clientAddr, targetAddr); err != nil {
return errors.Wrap(err, "failed to dial")
}
var (
timeout = c.getInitialTimeout()
maxTimeout = c.getMaxTimeout()
maxRetries = c.getRetries()
deadline = time.Now()
message = AcquireMessage()
)
defer ReleaseMessage(message)
for i := 0; i < maxRetries; i++ {
if _, err = request.Message.WriteTo(conn); err != nil {
return errors.Wrap(err, "failed to write")
}
deadline = time.Now().Add(timeout)
if err = conn.SetReadDeadline(deadline); err != nil {
return errors.Wrap(err, "failed to set deadline")
}
// updating timeout
if timeout < maxTimeout {
timeout *= timeoutGrowthRate
}
message.Reset()
if _, err = message.ReadFrom(conn); err != nil {
if _, ok := err.(net.Error); ok {
continue
}
return errors.Wrap(err, "network failed")
}
if SameTransaction(message, request.Message) {
return h(Response{
Message: message,
})
}
}
return errors.Wrap(err, "max retries reached")
}

View File

@@ -158,3 +158,30 @@ func TestClientSend(t *testing.T) {
}
}
}
func TestClient_Do(t *testing.T) {
skipIfNotFlagged(t, envExternalBlackbox)
client := Client{}
m := AcquireMessage()
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
m.TransactionID = NewTransactionID()
m.AddSoftware("cydev/stun alpha")
m.WriteHeader()
request := Request{
Target: "stun.l.google.com:19302",
Message: m,
}
if err := client.Do(request, func(r Response) error {
if r.Message.TransactionID != m.TransactionID {
t.Error("transaction id messmatch")
}
ip, port, err := r.Message.GetXORMappedAddress()
if err != nil {
t.Error(err)
}
log.Println("got", ip, port)
return nil
}); err != nil {
t.Fatal(err)
}
}

25
errors_test.go Normal file
View File

@@ -0,0 +1,25 @@
package stun
import "testing"
func TestDecodeErr(t *testing.T) {
err := newDecodeErr("parent", "children", "message")
if !err.IsPlace(DecodeErrPlace{Parent: "parent", Children: "children"}) {
t.Error("isPlace test failed")
}
if !err.IsPlaceParent("parent") {
t.Error("parent test failed")
}
if !err.IsPlaceChildren("children") {
t.Error("children test failed")
}
if err.Error() != "BadFormat for parent/children: message" {
t.Error("bad Error string")
}
}
func TestError_Error(t *testing.T) {
if Error("error").Error() != "error" {
t.Error("bad Error string")
}
}

66
integration_test.go Normal file
View File

@@ -0,0 +1,66 @@
package stun
import (
"net"
"testing"
)
func newServer(t *testing.T) (*net.UDPAddr, func()) {
laddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
con, err := net.ListenUDP("udp", laddr)
if err != nil {
t.Fatal(err)
}
addr, ok := con.LocalAddr().(*net.UDPAddr)
if !ok {
t.Error("not UDP addr")
}
s := &Server{}
go s.Serve(con)
return addr, func() {
if err := con.Close(); err != nil {
t.Error(err)
}
}
}
func newTestRequest(addr *net.UDPAddr, m *Message) Request {
return Request{
Message: m,
Target: addr.String(),
}
}
func TestClientServer(t *testing.T) {
serverAddr, closer := newServer(t)
defer closer()
m := AcquireFields(Message{
TransactionID: NewTransactionID(),
Type: MessageType{
Method: MethodBinding,
Class: ClassRequest,
},
})
m.AddSoftware("cydev/stun alpha")
m.WriteHeader()
r := newTestRequest(serverAddr, m)
defer ReleaseMessage(m)
if err := DefaultClient.Do(r, func(res Response) error {
if res.Message.GetSoftware() != "cydev/stun" {
t.Error("bad software attribute")
}
ip, _, err := res.Message.GetXORMappedAddress()
if err != nil {
t.Error(err)
}
if !ip.Equal(net.ParseIP("127.0.0.1")) {
t.Error("bad ip", ip)
}
return nil
}); err != nil {
t.Error(err)
}
}

View File

@@ -72,6 +72,9 @@ func (s *Server) getName() string {
}
func (s *Server) serveConn(c net.PacketConn) error {
if c == nil {
return nil
}
m := AcquireMessage()
buf := make([]byte, 2048)
n, addr, err := c.ReadFrom(buf)
@@ -127,7 +130,8 @@ func (s *Server) getConcurrency() int {
func (s *Server) Serve(c net.PacketConn) error {
for {
if err := s.serveConn(c); err != nil {
s.Logger.Printf("serve: %v", err)
s.logger().Printf("serve: %v", err)
return err
}
}
}

View File

@@ -27,7 +27,8 @@ func wrapLogrus(f func(c *cli.Context) error) func(c *cli.Context) error {
}
func discover(c *cli.Context) error {
conn, err := net.Dial("udp", stun.Normalize(c.String("server")))
normalized := stun.Normalize(c.String("server"))
conn, err := net.Dial("udp", normalized)
if err != nil {
return err
}
@@ -40,7 +41,7 @@ func discover(c *cli.Context) error {
})
m.AddSoftware("cydev/stun alpha")
m.WriteHeader()
timeout := 100 * time.Millisecond
timeout := 1000 * time.Millisecond
for i := 0; i < 9; i++ {
_, err := m.WriteTo(conn)
if err != nil {