Files
stun/client.go
2016-05-11 16:53:35 +03:00

150 lines
3.3 KiB
Go

package stun
import (
"github.com/pkg/errors"
"net"
"time"
)
// SameTransaction reports whether a and b carry the same Transaction ID.
func SameTransaction(a, b *Message) bool {
	left, right := a.TransactionID, b.TransactionID
	return left == right
}
// Fallback settings used by Client when the corresponding field is zero.
const (
	// DefaultClientRetries is the number of request attempts made when
	// Client.Retries is zero.
	DefaultClientRetries = 9
	// DefaultInitialTimeout is the read timeout of the first attempt when
	// Client.InitialTimeout is zero.
	DefaultInitialTimeout = time.Millisecond
	// DefaultMaxTimeout is the value at or above which the per-attempt read
	// timeout stops growing when Client.MaxTimeout is zero.
	DefaultMaxTimeout = 2 * time.Second
)
// DefaultClient is a zero-value Client, so every setting falls back to the
// package defaults (DefaultClientRetries, DefaultInitialTimeout and
// DefaultMaxTimeout).
var DefaultClient = Client{}
// Client implements STUN client.
//
// The zero value is usable: every exported field that is left at zero is
// replaced by the matching Default* constant when a request is performed.
type Client struct {
// Retries is the number of request attempts made by Do.
// Zero means DefaultClientRetries.
Retries int
// MaxTimeout is the value at or above which the per-attempt read timeout
// stops growing. Zero means DefaultMaxTimeout.
MaxTimeout time.Duration
// InitialTimeout is the read timeout used for the first attempt.
// Zero means DefaultInitialTimeout.
InitialTimeout time.Duration
// addr holds the local UDP address once getAddr has resolved it.
addr *net.UDPAddr
}
// getRetries returns the configured number of attempts, or
// DefaultClientRetries when Retries is zero.
func (c Client) getRetries() int {
	if n := c.Retries; n != 0 {
		return n
	}
	return DefaultClientRetries
}
// getMaxTimeout returns the configured maximum timeout, or
// DefaultMaxTimeout when MaxTimeout is zero.
func (c Client) getMaxTimeout() time.Duration {
	if limit := c.MaxTimeout; limit != 0 {
		return limit
	}
	return DefaultMaxTimeout
}
// getInitialTimeout returns the configured timeout for the first attempt, or
// DefaultInitialTimeout when InitialTimeout is zero.
func (c Client) getInitialTimeout() time.Duration {
	if first := c.InitialTimeout; first != 0 {
		return first
	}
	return DefaultInitialTimeout
}
// getAddr returns the local UDP address used for outgoing requests.
//
// An address already stored in c.addr is returned as is. Otherwise the
// wildcard address "0.0.0.0:0" is resolved and, on success, stored in c.addr
// for later calls. On failure the returned address is nil.
func (c *Client) getAddr() (*net.UDPAddr, error) {
	if cached := c.addr; cached != nil {
		return cached, nil
	}
	resolved, err := net.ResolveUDPAddr("udp", "0.0.0.0:0")
	if err != nil {
		// c.addr is still nil here, matching the "no address" result.
		return c.addr, err
	}
	c.addr = resolved
	return resolved, nil
}
// Request is a wrapper around a message and the target server address.
type Request struct {
// Message is the message that Client.Do writes to the target.
Message *Message
// Target is the server address; Client.Do resolves it with
// net.ResolveUDPAddr("udp", Target).
Target string
}
// Response is the message returned from a STUN server.
type Response struct {
// Message is the message read from the server. Per the documentation of
// Client.Do it is only valid while the handler is executing.
Message *Message
}
// ResponseHandler is the handler executed when a response is received.
// The error it returns becomes the result of Client.Do.
type ResponseHandler func(r Response) error
// timeoutGrowthRate is the factor by which Client.Do multiplies the read
// timeout after an attempt while the timeout is below the maximum.
const timeoutGrowthRate = 2
// Do sends request.Message to request.Target over UDP and passes the first
// response whose transaction ID matches the request to h.
//
// The request is re-sent on every attempt, up to the configured number of
// retries. Each attempt waits for a reply under a read deadline that starts
// at the initial timeout and is multiplied by timeoutGrowthRate after every
// attempt, never exceeding the maximum timeout once it has started growing.
//
// If an error occurs before a matching response arrives, Do returns it
// without calling h. Otherwise Do returns whatever h returns. When all
// attempts are used up without a matching response, Do returns a non-nil
// "max retries reached" error.
//
// Response is only valid during handler execution.
//
// Never store Response, Message pointer or any values obtained from
// Message getters, copy message and use it if needed.
func (c Client) Do(request Request, h ResponseHandler) error {
	targetAddr, err := net.ResolveUDPAddr("udp", request.Target)
	if err != nil {
		return errors.Wrap(err, "failed to resolve")
	}
	// NOTE(review): Do has a value receiver, so an address stored by getAddr
	// lands in this call's copy of the Client only.
	clientAddr, err := c.getAddr()
	if err != nil {
		return errors.Wrap(err, "failed to get local addr")
	}
	conn, err := net.DialUDP("udp", clientAddr, targetAddr)
	if err != nil {
		return errors.Wrap(err, "failed to dial")
	}
	// Release the socket on every return path; previously it was never closed.
	defer func() {
		_ = conn.Close() // best-effort cleanup, the request outcome is what is reported
	}()

	var (
		timeout    = c.getInitialTimeout()
		maxTimeout = c.getMaxTimeout()
		maxRetries = c.getRetries()
		message    = AcquireMessage()
	)
	defer ReleaseMessage(message)

	for i := 0; i < maxRetries; i++ {
		if _, err = request.Message.WriteTo(conn); err != nil {
			return errors.Wrap(err, "failed to write")
		}
		if err = conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
			return errors.Wrap(err, "failed to set deadline")
		}
		// Grow the timeout for the next attempt, but do not let the
		// multiplication push it past the configured maximum.
		if timeout < maxTimeout {
			timeout *= timeoutGrowthRate
			if timeout > maxTimeout {
				timeout = maxTimeout
			}
		}
		message.Reset()
		if _, err = message.ReadFrom(conn); err != nil {
			// Any net.Error (read-deadline expiry included) leads to
			// another attempt; other errors abort the request.
			if _, ok := err.(net.Error); ok {
				continue
			}
			return errors.Wrap(err, "network failed")
		}
		if SameTransaction(message, request.Message) {
			return h(Response{
				Message: message,
			})
		}
		// A response for a different transaction was read; try again.
	}

	// errors.Wrap returns nil for a nil error, which would report success
	// without the handler ever being called (e.g. the last read succeeded but
	// belonged to another transaction, or no attempt was made at all).
	if err == nil {
		return errors.New("max retries reached")
	}
	return errors.Wrap(err, "max retries reached")
}