mirror of
https://github.com/pion/stun.git
synced 2025-10-05 07:47:00 +08:00
all: update client and server
This commit is contained in:
25
attribute_errorcode_test.go
Normal file
25
attribute_errorcode_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package stun
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestErrorCode_Reason(t *testing.T) {
|
||||
codes := [...]ErrorCode{
|
||||
CodeTryAlternate,
|
||||
CodeBadRequest,
|
||||
CodeUnauthorised,
|
||||
CodeUnknownAttribute,
|
||||
CodeStaleNonce,
|
||||
CodeServerError,
|
||||
}
|
||||
for _, code := range codes {
|
||||
if code.Reason() == "Unknown Error" {
|
||||
t.Error(code, "should not be unknown")
|
||||
}
|
||||
if len(code.Reason()) == 0 {
|
||||
t.Error(code, "should not be blank")
|
||||
}
|
||||
}
|
||||
if ErrorCode(999).Reason() != "Unknown Error" {
|
||||
t.Error("999 error should be Unknown")
|
||||
}
|
||||
}
|
@@ -222,3 +222,33 @@ func TestMessage_AddErrorCode(t *testing.T) {
|
||||
t.Error("bad reason", string(reason))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_AddErrorCodeDefault(t *testing.T) {
|
||||
m := AcquireMessage()
|
||||
defer ReleaseMessage(m)
|
||||
transactionID, err := base64.StdEncoding.DecodeString("jxhBARZwX+rsC6er")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
copy(m.TransactionID[:], transactionID)
|
||||
expectedCode := 500
|
||||
expectedReason := "Server Error"
|
||||
m.AddErrorCodeDefault(expectedCode)
|
||||
m.WriteHeader()
|
||||
|
||||
mRes := AcquireMessage()
|
||||
defer ReleaseMessage(mRes)
|
||||
if _, err = mRes.ReadFrom(m.reader()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
code, reason, err := mRes.GetErrorCode()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if code != expectedCode {
|
||||
t.Error("bad code", code)
|
||||
}
|
||||
if string(reason) != expectedReason {
|
||||
t.Error("bad reason", string(reason))
|
||||
}
|
||||
}
|
||||
|
144
client.go
144
client.go
@@ -1,5 +1,149 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SameTransaction returns true of a and b have same Transaction ID.
|
||||
func SameTransaction(a *Message, b *Message) bool {
|
||||
return a.TransactionID == b.TransactionID
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultClientRetries = 9
|
||||
DefaultMaxTimeout = 2 * time.Second
|
||||
DefaultInitialTimeout = 1 * time.Millisecond
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultClient is Client with defaults that are close
|
||||
// to RFC recommendations.
|
||||
DefaultClient = Client{}
|
||||
)
|
||||
|
||||
// Client implements STUN client.
|
||||
type Client struct {
|
||||
Retries int
|
||||
MaxTimeout time.Duration
|
||||
InitialTimeout time.Duration
|
||||
|
||||
addr *net.UDPAddr
|
||||
}
|
||||
|
||||
func (c Client) getRetries() int {
|
||||
if c.Retries == 0 {
|
||||
return DefaultClientRetries
|
||||
}
|
||||
return c.Retries
|
||||
}
|
||||
|
||||
func (c Client) getMaxTimeout() time.Duration {
|
||||
if c.MaxTimeout == 0 {
|
||||
return DefaultMaxTimeout
|
||||
}
|
||||
return c.MaxTimeout
|
||||
}
|
||||
|
||||
func (c Client) getInitialTimeout() time.Duration {
|
||||
if c.InitialTimeout == 0 {
|
||||
return DefaultInitialTimeout
|
||||
}
|
||||
return c.InitialTimeout
|
||||
}
|
||||
|
||||
func (c *Client) getAddr() (*net.UDPAddr, error) {
|
||||
var (
|
||||
err error
|
||||
addr *net.UDPAddr
|
||||
)
|
||||
if c.addr != nil {
|
||||
return c.addr, nil
|
||||
}
|
||||
addr, err = net.ResolveUDPAddr("udp", "0.0.0.0:0")
|
||||
if err == nil {
|
||||
c.addr = addr
|
||||
}
|
||||
return c.addr, err
|
||||
}
|
||||
|
||||
// Request is wrapper on message and target server address.
|
||||
type Request struct {
|
||||
Message *Message
|
||||
Target string
|
||||
}
|
||||
|
||||
// Response is message returned from STUN server.
|
||||
type Response struct {
|
||||
Message *Message
|
||||
}
|
||||
|
||||
// ResponseHandler is handler executed if response is received.
|
||||
type ResponseHandler func(r Response) error
|
||||
|
||||
const timeoutGrowthRate = 2
|
||||
|
||||
// Do performs request and passing response to handler. If error occurs
|
||||
// during request, Do returns it, not calling the handler.
|
||||
// Do returns any error that is returned by handler.
|
||||
// Response is only valid during handler execution.
|
||||
//
|
||||
// Never store Response, Message pointer or any values obtained from
|
||||
// Message getters, copy message and use it if needed.
|
||||
func (c Client) Do(request Request, h ResponseHandler) error {
|
||||
var (
|
||||
targetAddr *net.UDPAddr
|
||||
clientAddr *net.UDPAddr
|
||||
conn *net.UDPConn
|
||||
err error
|
||||
)
|
||||
if targetAddr, err = net.ResolveUDPAddr("udp", request.Target); err != nil {
|
||||
return errors.Wrap(err, "failed to resolve")
|
||||
}
|
||||
if clientAddr, err = c.getAddr(); err != nil {
|
||||
return errors.Wrap(err, "failed to get local addr")
|
||||
}
|
||||
if conn, err = net.DialUDP("udp", clientAddr, targetAddr); err != nil {
|
||||
return errors.Wrap(err, "failed to dial")
|
||||
}
|
||||
|
||||
var (
|
||||
timeout = c.getInitialTimeout()
|
||||
maxTimeout = c.getMaxTimeout()
|
||||
maxRetries = c.getRetries()
|
||||
deadline = time.Now()
|
||||
message = AcquireMessage()
|
||||
)
|
||||
defer ReleaseMessage(message)
|
||||
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
if _, err = request.Message.WriteTo(conn); err != nil {
|
||||
return errors.Wrap(err, "failed to write")
|
||||
}
|
||||
|
||||
deadline = time.Now().Add(timeout)
|
||||
if err = conn.SetReadDeadline(deadline); err != nil {
|
||||
return errors.Wrap(err, "failed to set deadline")
|
||||
}
|
||||
|
||||
// updating timeout
|
||||
if timeout < maxTimeout {
|
||||
timeout *= timeoutGrowthRate
|
||||
}
|
||||
|
||||
message.Reset()
|
||||
if _, err = message.ReadFrom(conn); err != nil {
|
||||
if _, ok := err.(net.Error); ok {
|
||||
continue
|
||||
}
|
||||
return errors.Wrap(err, "network failed")
|
||||
}
|
||||
if SameTransaction(message, request.Message) {
|
||||
return h(Response{
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
}
|
||||
return errors.Wrap(err, "max retries reached")
|
||||
}
|
||||
|
@@ -158,3 +158,30 @@ func TestClientSend(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Do(t *testing.T) {
|
||||
skipIfNotFlagged(t, envExternalBlackbox)
|
||||
client := Client{}
|
||||
m := AcquireMessage()
|
||||
m.Type = MessageType{Method: MethodBinding, Class: ClassRequest}
|
||||
m.TransactionID = NewTransactionID()
|
||||
m.AddSoftware("cydev/stun alpha")
|
||||
m.WriteHeader()
|
||||
request := Request{
|
||||
Target: "stun.l.google.com:19302",
|
||||
Message: m,
|
||||
}
|
||||
if err := client.Do(request, func(r Response) error {
|
||||
if r.Message.TransactionID != m.TransactionID {
|
||||
t.Error("transaction id messmatch")
|
||||
}
|
||||
ip, port, err := r.Message.GetXORMappedAddress()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
log.Println("got", ip, port)
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
25
errors_test.go
Normal file
25
errors_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package stun
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDecodeErr(t *testing.T) {
|
||||
err := newDecodeErr("parent", "children", "message")
|
||||
if !err.IsPlace(DecodeErrPlace{Parent: "parent", Children: "children"}) {
|
||||
t.Error("isPlace test failed")
|
||||
}
|
||||
if !err.IsPlaceParent("parent") {
|
||||
t.Error("parent test failed")
|
||||
}
|
||||
if !err.IsPlaceChildren("children") {
|
||||
t.Error("children test failed")
|
||||
}
|
||||
if err.Error() != "BadFormat for parent/children: message" {
|
||||
t.Error("bad Error string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestError_Error(t *testing.T) {
|
||||
if Error("error").Error() != "error" {
|
||||
t.Error("bad Error string")
|
||||
}
|
||||
}
|
66
integration_test.go
Normal file
66
integration_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newServer(t *testing.T) (*net.UDPAddr, func()) {
|
||||
laddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
con, err := net.ListenUDP("udp", laddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addr, ok := con.LocalAddr().(*net.UDPAddr)
|
||||
if !ok {
|
||||
t.Error("not UDP addr")
|
||||
}
|
||||
s := &Server{}
|
||||
go s.Serve(con)
|
||||
return addr, func() {
|
||||
if err := con.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newTestRequest(addr *net.UDPAddr, m *Message) Request {
|
||||
return Request{
|
||||
Message: m,
|
||||
Target: addr.String(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientServer(t *testing.T) {
|
||||
serverAddr, closer := newServer(t)
|
||||
defer closer()
|
||||
m := AcquireFields(Message{
|
||||
TransactionID: NewTransactionID(),
|
||||
Type: MessageType{
|
||||
Method: MethodBinding,
|
||||
Class: ClassRequest,
|
||||
},
|
||||
})
|
||||
m.AddSoftware("cydev/stun alpha")
|
||||
m.WriteHeader()
|
||||
r := newTestRequest(serverAddr, m)
|
||||
defer ReleaseMessage(m)
|
||||
if err := DefaultClient.Do(r, func(res Response) error {
|
||||
if res.Message.GetSoftware() != "cydev/stun" {
|
||||
t.Error("bad software attribute")
|
||||
}
|
||||
ip, _, err := res.Message.GetXORMappedAddress()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !ip.Equal(net.ParseIP("127.0.0.1")) {
|
||||
t.Error("bad ip", ip)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
@@ -72,6 +72,9 @@ func (s *Server) getName() string {
|
||||
}
|
||||
|
||||
func (s *Server) serveConn(c net.PacketConn) error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
m := AcquireMessage()
|
||||
buf := make([]byte, 2048)
|
||||
n, addr, err := c.ReadFrom(buf)
|
||||
@@ -127,7 +130,8 @@ func (s *Server) getConcurrency() int {
|
||||
func (s *Server) Serve(c net.PacketConn) error {
|
||||
for {
|
||||
if err := s.serveConn(c); err != nil {
|
||||
s.Logger.Printf("serve: %v", err)
|
||||
s.logger().Printf("serve: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -27,7 +27,8 @@ func wrapLogrus(f func(c *cli.Context) error) func(c *cli.Context) error {
|
||||
}
|
||||
|
||||
func discover(c *cli.Context) error {
|
||||
conn, err := net.Dial("udp", stun.Normalize(c.String("server")))
|
||||
normalized := stun.Normalize(c.String("server"))
|
||||
conn, err := net.Dial("udp", normalized)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -35,12 +36,12 @@ func discover(c *cli.Context) error {
|
||||
TransactionID: stun.NewTransactionID(),
|
||||
Type: stun.MessageType{
|
||||
Method: stun.MethodBinding,
|
||||
Class: stun.ClassRequest,
|
||||
Class: stun.ClassRequest,
|
||||
},
|
||||
})
|
||||
m.AddSoftware("cydev/stun alpha")
|
||||
m.WriteHeader()
|
||||
timeout := 100 * time.Millisecond
|
||||
timeout := 1000 * time.Millisecond
|
||||
for i := 0; i < 9; i++ {
|
||||
_, err := m.WriteTo(conn)
|
||||
if err != nil {
|
||||
@@ -86,9 +87,9 @@ func main() {
|
||||
app.Usage = "command line client for STUN"
|
||||
app.Flags = []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "server",
|
||||
Value: "ci.cydev.ru",
|
||||
Usage: "STUN server address",
|
||||
Name: "server",
|
||||
Value: "ci.cydev.ru",
|
||||
Usage: "STUN server address",
|
||||
},
|
||||
}
|
||||
app.Action = wrapLogrus(discover)
|
||||
|
Reference in New Issue
Block a user