mirror of
https://github.com/pion/stun.git
synced 2025-10-05 07:47:00 +08:00
all: update client and server
This commit is contained in:
25
attribute_errorcode_test.go
Normal file
25
attribute_errorcode_test.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestErrorCode_Reason(t *testing.T) {
|
||||||
|
codes := [...]ErrorCode{
|
||||||
|
CodeTryAlternate,
|
||||||
|
CodeBadRequest,
|
||||||
|
CodeUnauthorised,
|
||||||
|
CodeUnknownAttribute,
|
||||||
|
CodeStaleNonce,
|
||||||
|
CodeServerError,
|
||||||
|
}
|
||||||
|
for _, code := range codes {
|
||||||
|
if code.Reason() == "Unknown Error" {
|
||||||
|
t.Error(code, "should not be unknown")
|
||||||
|
}
|
||||||
|
if len(code.Reason()) == 0 {
|
||||||
|
t.Error(code, "should not be blank")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ErrorCode(999).Reason() != "Unknown Error" {
|
||||||
|
t.Error("999 error should be Unknown")
|
||||||
|
}
|
||||||
|
}
|
@@ -222,3 +222,33 @@ func TestMessage_AddErrorCode(t *testing.T) {
|
|||||||
t.Error("bad reason", string(reason))
|
t.Error("bad reason", string(reason))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMessage_AddErrorCodeDefault(t *testing.T) {
|
||||||
|
m := AcquireMessage()
|
||||||
|
defer ReleaseMessage(m)
|
||||||
|
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
copy(m.TransactionID[:], transactionID)
|
||||||
|
expectedCode := 500
|
||||||
|
expectedReason := "Server Error"
|
||||||
|
m.AddErrorCodeDefault(expectedCode)
|
||||||
|
m.WriteHeader()
|
||||||
|
|
||||||
|
mRes := AcquireMessage()
|
||||||
|
defer ReleaseMessage(mRes)
|
||||||
|
if _, err = mRes.ReadFrom(m.reader()); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
code, reason, err := mRes.GetErrorCode()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if code != expectedCode {
|
||||||
|
t.Error("bad code", code)
|
||||||
|
}
|
||||||
|
if string(reason) != expectedReason {
|
||||||
|
t.Error("bad reason", string(reason))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
144
client.go
144
client.go
@@ -1,5 +1,149 @@
|
|||||||
package stun
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SameTransaction returns true of a and b have same Transaction ID.
|
||||||
|
func SameTransaction(a *Message, b *Message) bool {
|
||||||
|
return a.TransactionID == b.TransactionID
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultClientRetries = 9
|
||||||
|
DefaultMaxTimeout = 2 * time.Second
|
||||||
|
DefaultInitialTimeout = 1 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// DefaultClient is Client with defaults that are close
|
||||||
|
// to RFC recommendations.
|
||||||
|
DefaultClient = Client{}
|
||||||
|
)
|
||||||
|
|
||||||
// Client implements STUN client.
|
// Client implements STUN client.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
|
Retries int
|
||||||
|
MaxTimeout time.Duration
|
||||||
|
InitialTimeout time.Duration
|
||||||
|
|
||||||
|
addr *net.UDPAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Client) getRetries() int {
|
||||||
|
if c.Retries == 0 {
|
||||||
|
return DefaultClientRetries
|
||||||
|
}
|
||||||
|
return c.Retries
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Client) getMaxTimeout() time.Duration {
|
||||||
|
if c.MaxTimeout == 0 {
|
||||||
|
return DefaultMaxTimeout
|
||||||
|
}
|
||||||
|
return c.MaxTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Client) getInitialTimeout() time.Duration {
|
||||||
|
if c.InitialTimeout == 0 {
|
||||||
|
return DefaultInitialTimeout
|
||||||
|
}
|
||||||
|
return c.InitialTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getAddr() (*net.UDPAddr, error) {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
addr *net.UDPAddr
|
||||||
|
)
|
||||||
|
if c.addr != nil {
|
||||||
|
return c.addr, nil
|
||||||
|
}
|
||||||
|
addr, err = net.ResolveUDPAddr("udp", "0.0.0.0:0")
|
||||||
|
if err == nil {
|
||||||
|
c.addr = addr
|
||||||
|
}
|
||||||
|
return c.addr, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request is wrapper on message and target server address.
|
||||||
|
type Request struct {
|
||||||
|
Message *Message
|
||||||
|
Target string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response is message returned from STUN server.
|
||||||
|
type Response struct {
|
||||||
|
Message *Message
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseHandler is handler executed if response is received.
|
||||||
|
type ResponseHandler func(r Response) error
|
||||||
|
|
||||||
|
const timeoutGrowthRate = 2
|
||||||
|
|
||||||
|
// Do performs request and passing response to handler. If error occurs
|
||||||
|
// during request, Do returns it, not calling the handler.
|
||||||
|
// Do returns any error that is returned by handler.
|
||||||
|
// Response is only valid during handler execution.
|
||||||
|
//
|
||||||
|
// Never store Response, Message pointer or any values obtained from
|
||||||
|
// Message getters, copy message and use it if needed.
|
||||||
|
func (c Client) Do(request Request, h ResponseHandler) error {
|
||||||
|
var (
|
||||||
|
targetAddr *net.UDPAddr
|
||||||
|
clientAddr *net.UDPAddr
|
||||||
|
conn *net.UDPConn
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if targetAddr, err = net.ResolveUDPAddr("udp", request.Target); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to resolve")
|
||||||
|
}
|
||||||
|
if clientAddr, err = c.getAddr(); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to get local addr")
|
||||||
|
}
|
||||||
|
if conn, err = net.DialUDP("udp", clientAddr, targetAddr); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to dial")
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
timeout = c.getInitialTimeout()
|
||||||
|
maxTimeout = c.getMaxTimeout()
|
||||||
|
maxRetries = c.getRetries()
|
||||||
|
deadline = time.Now()
|
||||||
|
message = AcquireMessage()
|
||||||
|
)
|
||||||
|
defer ReleaseMessage(message)
|
||||||
|
|
||||||
|
for i := 0; i < maxRetries; i++ {
|
||||||
|
if _, err = request.Message.WriteTo(conn); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to write")
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline = time.Now().Add(timeout)
|
||||||
|
if err = conn.SetReadDeadline(deadline); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to set deadline")
|
||||||
|
}
|
||||||
|
|
||||||
|
// updating timeout
|
||||||
|
if timeout < maxTimeout {
|
||||||
|
timeout *= timeoutGrowthRate
|
||||||
|
}
|
||||||
|
|
||||||
|
message.Reset()
|
||||||
|
if _, err = message.ReadFrom(conn); err != nil {
|
||||||
|
if _, ok := err.(net.Error); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return errors.Wrap(err, "network failed")
|
||||||
|
}
|
||||||
|
if SameTransaction(message, request.Message) {
|
||||||
|
return h(Response{
|
||||||
|
Message: message,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors.Wrap(err, "max retries reached")
|
||||||
}
|
}
|
||||||
|
@@ -158,3 +158,30 @@ func TestClientSend(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClient_Do(t *testing.T) {
|
||||||
|
skipIfNotFlagged(t, envExternalBlackbox)
|
||||||
|
client := Client{}
|
||||||
|
m := AcquireMessage()
|
||||||
|
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||||
|
m.TransactionID = NewTransactionID()
|
||||||
|
m.AddSoftware("cydev/stun alpha")
|
||||||
|
m.WriteHeader()
|
||||||
|
request := Request{
|
||||||
|
Target: "stun.l.google.com:19302",
|
||||||
|
Message: m,
|
||||||
|
}
|
||||||
|
if err := client.Do(request, func(r Response) error {
|
||||||
|
if r.Message.TransactionID != m.TransactionID {
|
||||||
|
t.Error("transaction id messmatch")
|
||||||
|
}
|
||||||
|
ip, port, err := r.Message.GetXORMappedAddress()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
log.Println("got", ip, port)
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
25
errors_test.go
Normal file
25
errors_test.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDecodeErr(t *testing.T) {
|
||||||
|
err := newDecodeErr("parent", "children", "message")
|
||||||
|
if !err.IsPlace(DecodeErrPlace{Parent: "parent", Children: "children"}) {
|
||||||
|
t.Error("isPlace test failed")
|
||||||
|
}
|
||||||
|
if !err.IsPlaceParent("parent") {
|
||||||
|
t.Error("parent test failed")
|
||||||
|
}
|
||||||
|
if !err.IsPlaceChildren("children") {
|
||||||
|
t.Error("children test failed")
|
||||||
|
}
|
||||||
|
if err.Error() != "BadFormat for parent/children: message" {
|
||||||
|
t.Error("bad Error string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestError_Error(t *testing.T) {
|
||||||
|
if Error("error").Error() != "error" {
|
||||||
|
t.Error("bad Error string")
|
||||||
|
}
|
||||||
|
}
|
66
integration_test.go
Normal file
66
integration_test.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package stun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newServer(t *testing.T) (*net.UDPAddr, func()) {
|
||||||
|
laddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
con, err := net.ListenUDP("udp", laddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
addr, ok := con.LocalAddr().(*net.UDPAddr)
|
||||||
|
if !ok {
|
||||||
|
t.Error("not UDP addr")
|
||||||
|
}
|
||||||
|
s := &Server{}
|
||||||
|
go s.Serve(con)
|
||||||
|
return addr, func() {
|
||||||
|
if err := con.Close(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestRequest(addr *net.UDPAddr, m *Message) Request {
|
||||||
|
return Request{
|
||||||
|
Message: m,
|
||||||
|
Target: addr.String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientServer(t *testing.T) {
|
||||||
|
serverAddr, closer := newServer(t)
|
||||||
|
defer closer()
|
||||||
|
m := AcquireFields(Message{
|
||||||
|
TransactionID: NewTransactionID(),
|
||||||
|
Type: MessageType{
|
||||||
|
Method: MethodBinding,
|
||||||
|
Class: ClassRequest,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
m.AddSoftware("cydev/stun alpha")
|
||||||
|
m.WriteHeader()
|
||||||
|
r := newTestRequest(serverAddr, m)
|
||||||
|
defer ReleaseMessage(m)
|
||||||
|
if err := DefaultClient.Do(r, func(res Response) error {
|
||||||
|
if res.Message.GetSoftware() != "cydev/stun" {
|
||||||
|
t.Error("bad software attribute")
|
||||||
|
}
|
||||||
|
ip, _, err := res.Message.GetXORMappedAddress()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if !ip.Equal(net.ParseIP("127.0.0.1")) {
|
||||||
|
t.Error("bad ip", ip)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
@@ -72,6 +72,9 @@ func (s *Server) getName() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) serveConn(c net.PacketConn) error {
|
func (s *Server) serveConn(c net.PacketConn) error {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
m := AcquireMessage()
|
m := AcquireMessage()
|
||||||
buf := make([]byte, 2048)
|
buf := make([]byte, 2048)
|
||||||
n, addr, err := c.ReadFrom(buf)
|
n, addr, err := c.ReadFrom(buf)
|
||||||
@@ -127,7 +130,8 @@ func (s *Server) getConcurrency() int {
|
|||||||
func (s *Server) Serve(c net.PacketConn) error {
|
func (s *Server) Serve(c net.PacketConn) error {
|
||||||
for {
|
for {
|
||||||
if err := s.serveConn(c); err != nil {
|
if err := s.serveConn(c); err != nil {
|
||||||
s.Logger.Printf("serve: %v", err)
|
s.logger().Printf("serve: %v", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -27,7 +27,8 @@ func wrapLogrus(f func(c *cli.Context) error) func(c *cli.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func discover(c *cli.Context) error {
|
func discover(c *cli.Context) error {
|
||||||
conn, err := net.Dial("udp", stun.Normalize(c.String("server")))
|
normalized := stun.Normalize(c.String("server"))
|
||||||
|
conn, err := net.Dial("udp", normalized)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -35,12 +36,12 @@ func discover(c *cli.Context) error {
|
|||||||
TransactionID: stun.NewTransactionID(),
|
TransactionID: stun.NewTransactionID(),
|
||||||
Type: stun.MessageType{
|
Type: stun.MessageType{
|
||||||
Method: stun.MethodBinding,
|
Method: stun.MethodBinding,
|
||||||
Class: stun.ClassRequest,
|
Class: stun.ClassRequest,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
m.AddSoftware("cydev/stun alpha")
|
m.AddSoftware("cydev/stun alpha")
|
||||||
m.WriteHeader()
|
m.WriteHeader()
|
||||||
timeout := 100 * time.Millisecond
|
timeout := 1000 * time.Millisecond
|
||||||
for i := 0; i < 9; i++ {
|
for i := 0; i < 9; i++ {
|
||||||
_, err := m.WriteTo(conn)
|
_, err := m.WriteTo(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -86,9 +87,9 @@ func main() {
|
|||||||
app.Usage = "command line client for STUN"
|
app.Usage = "command line client for STUN"
|
||||||
app.Flags = []cli.Flag{
|
app.Flags = []cli.Flag{
|
||||||
cli.StringFlag{
|
cli.StringFlag{
|
||||||
Name: "server",
|
Name: "server",
|
||||||
Value: "ci.cydev.ru",
|
Value: "ci.cydev.ru",
|
||||||
Usage: "STUN server address",
|
Usage: "STUN server address",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
app.Action = wrapLogrus(discover)
|
app.Action = wrapLogrus(discover)
|
||||||
|
Reference in New Issue
Block a user