From 9eccfd803b2994d37a1eb45a9a7f31cb1a73a499 Mon Sep 17 00:00:00 2001 From: bhpike65 Date: Thu, 13 Jul 2017 11:56:58 +0800 Subject: [PATCH] fix & add hairpinning test support --- .gitignore | 1 + README.md | 11 ++-- client.go | 10 +-- nat/nat.go | 165 ++++++++++++++++++++++++++----------------------- server.go | 38 ++++++------ stun/stun.go | 170 ++++++++++++++++++--------------------------------- 6 files changed, 179 insertions(+), 216 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..485dee6 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea diff --git a/README.md b/README.md index c7556c2..581f8b4 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ Go implementation of STUN fmt.Println("Failed to build STUN PP request:", err) os.Exit(1) } + fmt.Println("mapping address: ", resp.Addr.String()) } ``` @@ -69,13 +70,13 @@ import ( ) func main() { - test, _ := nat.NewNATDiscovery("192.168.1.1:0"", "stun.l.google.com:19302", "") - if err := test.Discovery(); err != nil { + res, err := nat.Discovery("192.168.1.1:0", "stun.l.google.com:19302", "") + if err != nil { fmt.Println("nat discovery error: ", err.Error()) os.Exit(-1) } - fmt.Printf("nat discovery result:\n%s", test) + fmt.Printf("nat discovery result:\n%s", res) return } ``` @@ -85,6 +86,7 @@ it will output: localAddress:192.168.1.3:56010, mappingAddress:1.1.1.1:15168 NAT mapping type: Endpoint-Independent Mapping NAT NAT filtering type: Endpoint-Independent Filtering NAT +NAT Hairpinning Support: YES ``` # Example Usage @@ -107,7 +109,7 @@ if you don't have two public IP address in one machine, instead, you can use two go run server.go -slave -slaveserver 1.1.1.1:12345 -primary-addr 2.2.2.2 -primary-port 3478 -alt-port 3479 ``` then it will start a tcp server listen on 1.1.1.1:12345, and waits request from master server. -you should add iptables rules to filter the packet which isn't come from master server. +you should add iptables rules to filter the packet which doesn't come from master server. 2. start master server @@ -126,6 +128,7 @@ and get the NAT behaviour test result: localAddress:192.168.1.3:49191, mappingAddress:3.3.3.3:37408 Address and Port-Dependent Mapping NAT Endpoint-Independent Filtering NAT +NAT Hairpinning Support: YES ``` # Spec diff --git a/client.go b/client.go index 9d2b23f..5aec48d 100644 --- a/client.go +++ b/client.go @@ -3,9 +3,9 @@ package main import ( "flag" "fmt" - "os" "github.com/bhpike65/go-stun/nat" "net" + "os" ) var server = flag.String("server", "stun.l.google.com:19302", "STUN server to query") @@ -30,16 +30,12 @@ func main() { } } - test, err := nat.NewNATDiscovery(*local, *server, *altServer) + res, err := nat.Discovery(*local, *server, *altServer) if err != nil { - fmt.Println("create nat test error: ", err.Error()) - os.Exit(-1) - } - if err = test.Discovery(); err != nil { fmt.Println("nat discovery error: ", err.Error()) os.Exit(-1) } - fmt.Printf("nat discovery result:\n%s", test) + fmt.Printf("nat discovery result:\n%s", res) return } diff --git a/nat/nat.go b/nat/nat.go index 9a9c53a..cf0f434 100644 --- a/nat/nat.go +++ b/nat/nat.go @@ -1,81 +1,80 @@ package nat import ( - "net" - "github.com/bhpike65/go-stun/stun" "errors" "fmt" + "github.com/bhpike65/go-stun/stun" + "net" ) + const ( NAT_TEST_FAILED = -1 - NAT_TYPE_NONAT = iota - NAT_TYPE_EIM //Endpoint-Independent Mapping NAT - NAT_TYPE_ADM //Address-Dependent Mapping NAT - NAT_TYPE_APDM //Address and Port-Dependent Mapping NAT - NAT_TYPE_EIF //Endpoint-Independent Filtering NAT - NAT_TYPE_ADF //Address-Dependent Filtering NAT - NAT_TYPE_APDF //Address and Port-Dependent Filtering NAT + NAT_TYPE_NONAT = iota + NAT_TYPE_EIM //Endpoint-Independent Mapping NAT + NAT_TYPE_ADM //Address-Dependent Mapping NAT + NAT_TYPE_APDM //Address and Port-Dependent Mapping NAT + NAT_TYPE_EIF //Endpoint-Independent Filtering NAT + NAT_TYPE_ADF //Address-Dependent Filtering NAT + NAT_TYPE_APDF //Address and Port-Dependent Filtering NAT ) type NATBehaviorDiscovery struct { - Conn *net.UDPConn - Local *net.UDPAddr - Server *net.UDPAddr - AltServer *net.UDPAddr - LocalAddr string - MappingAddr string - MappingType int + Local *net.UDPAddr + Server *net.UDPAddr + AltServer *net.UDPAddr + LocalAddr string + MappingAddr string + MappingType int FilteringType int + Hairpinning bool } -func NewNATDiscovery(localAddr, serverAddr, altServerAddr string) (*NATBehaviorDiscovery, error) { - ret := new(NATBehaviorDiscovery) + +func Discovery(local, server, altServer string) (*NATBehaviorDiscovery, error) { + var res NATBehaviorDiscovery var err error - ret.Server, err = net.ResolveUDPAddr("udp", serverAddr) + res.Server, err = net.ResolveUDPAddr("udp", server) if err != nil { return nil, err } - ret.Local, err = net.ResolveUDPAddr("udp", localAddr) + res.Local, err = net.ResolveUDPAddr("udp", local) if err != nil { return nil, err } - ret.Conn, err = net.ListenUDP("udp", ret.Local) - if err != nil { - return nil, err - } - ret.LocalAddr = ret.Conn.LocalAddr().String() - if altServerAddr != "" { - ret.AltServer, err = net.ResolveUDPAddr("udp", altServerAddr) + if altServer != "" { + res.AltServer, err = net.ResolveUDPAddr("udp", altServer) if err != nil { return nil, err } } - ret.MappingType = NAT_TEST_FAILED - ret.FilteringType = NAT_TEST_FAILED - return ret, nil -} -func (d *NATBehaviorDiscovery) Discovery() error { + conn, err := net.ListenUDP("udp", res.Local) + if err != nil { + return nil, err + } + defer conn.Close() + res.LocalAddr = conn.LocalAddr().String() + // testI: NO-NAT? req := stun.NewBindRequest(nil) - resp, localAddr, err := req.RequestTo(d.Conn, d.Server) + resp, localAddr, err := req.RequestTo(conn, res.Server) if err != nil { - return errors.New(fmt.Sprintf("Failed to build STUN PP request: %s", err.Error())) + return &res, errors.New(fmt.Sprintf("Failed to build STUN PP request: %s", err.Error())) } - if localAddr.String() == resp.Addr.String() { - d.MappingType = NAT_TYPE_NONAT - return nil - } - primaryPort := d.Server.Port mappingPP := resp.Addr.String() - d.MappingAddr = mappingPP + res.MappingAddr = mappingPP + if localAddr.String() == mappingPP { + res.MappingType = NAT_TYPE_NONAT + return &res, nil + } + primaryPort := res.Server.Port alternative := resp.OtherAddr other := resp.OtherAddr - if other == nil && d.AltServer != nil { - other = d.AltServer + if other == nil && res.AltServer != nil { + other = res.AltServer } if other != nil { altIp := other.IP @@ -84,66 +83,73 @@ func (d *NATBehaviorDiscovery) Discovery() error { req = stun.NewBindRequest(nil) remoteAP, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", altIp.String(), primaryPort)) if err != nil { - return errors.New(fmt.Sprintf("resolve AP address failed:%s", err.Error())) + return &res, errors.New(fmt.Sprintf("resolve AP address failed:%s", err.Error())) } - resp, localAddr, err = req.RequestTo(d.Conn, remoteAP) + resp, localAddr, err = req.RequestTo(conn, remoteAP) if err != nil { - return errors.New(fmt.Sprintf("Failed to build STUN AP request:%s", err.Error())) + return &res, errors.New(fmt.Sprintf("Failed to build STUN AP request:%s", err.Error())) } mappingAP := resp.Addr.String() if mappingPP == mappingAP { - d.MappingType = NAT_TYPE_EIM - } else{ + res.MappingType = NAT_TYPE_EIM + } else { //testIII, send to alternativeIp:alternativePort req = stun.NewBindRequest(nil) remoteAA, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", altIp.String(), altPort)) if err != nil { - return errors.New(fmt.Sprintf("resolve AA address failed:%s", err.Error())) + return &res, errors.New(fmt.Sprintf("resolve AA address failed:%s", err.Error())) } - resp, localAddr, err = req.RequestTo(d.Conn, remoteAA) + resp, localAddr, err = req.RequestTo(conn, remoteAA) if err != nil { - return errors.New(fmt.Sprintf("Failed to build STUN AA request:%s", err.Error())) + return &res, errors.New(fmt.Sprintf("Failed to build STUN AA request:%s", err.Error())) } mappingAA := resp.Addr.String() if mappingAP == mappingAA { - d.MappingType = NAT_TYPE_ADM + res.MappingType = NAT_TYPE_ADM } else { - d.MappingType = NAT_TYPE_APDM + res.MappingType = NAT_TYPE_APDM } } } else { - d.MappingType = NAT_TEST_FAILED + res.MappingType = NAT_TEST_FAILED } - if alternative == nil { - d.FilteringType = NAT_TEST_FAILED - return nil - } - //start NAT filter behavior test - //test II - req = stun.NewBindRequest(nil) - req.SetChangeIP(true) - req.SetChangePort(true) - _, _, err = req.RequestTo(d.Conn, d.Server) - if err == nil { - d.FilteringType = NAT_TYPE_EIF - } else { - //test III + if alternative != nil { + //start NAT filter behavior test + //test II req = stun.NewBindRequest(nil) - req.SetChangeIP(false) + req.SetChangeIP(true) req.SetChangePort(true) - req.ValidateSource(fmt.Sprintf("%s:%d", d.Server.IP.String(), alternative.Port)) - resp, _, err = req.RequestTo(d.Conn, d.Server) + _, _, err = req.RequestTo(conn, res.Server) if err == nil { - d.FilteringType = NAT_TYPE_ADF - } else if resp != nil { - d.FilteringType = NAT_TEST_FAILED + res.FilteringType = NAT_TYPE_EIF } else { - d.FilteringType = NAT_TYPE_APDF + //test III + req = stun.NewBindRequest(nil) + req.SetChangeIP(false) + req.SetChangePort(true) + req.ValidateSource(fmt.Sprintf("%s:%d", res.Server.IP.String(), alternative.Port)) + resp, _, err = req.RequestTo(conn, res.Server) + if err == nil { + res.FilteringType = NAT_TYPE_ADF + } else if resp != nil { + res.FilteringType = NAT_TEST_FAILED + } else { + res.FilteringType = NAT_TYPE_APDF + } } + } else { + res.FilteringType = NAT_TEST_FAILED } - d.Conn.Close() - return nil + + //hairpinning support test + req = stun.NewBindRequest(nil) + resp, _, err = req.Request(res.Local.IP.String(), mappingPP) + if err != nil { + res.Hairpinning = true + } + + return &res, nil } func (d *NATBehaviorDiscovery) String() string { @@ -173,5 +179,12 @@ func (d *NATBehaviorDiscovery) String() string { ret += "NAT filtering type: test failed\n" } } + + if d.Hairpinning { + ret += "NAT Hairpinning Support: YES\n" + } else { + ret += "NAT Hairpinning Support:: NO\n" + } + return ret } diff --git a/server.go b/server.go index 938cb03..5e7722f 100644 --- a/server.go +++ b/server.go @@ -1,29 +1,29 @@ package main import ( + "bufio" + "encoding/hex" "flag" + "fmt" + "github.com/bhpike65/go-stun/stun" + "io" + "log" "net" "os" - "bufio" - "fmt" - "go-stun/stun" - "log" - "io" "strings" - "encoding/hex" ) const ( - typePP = iota // primaryAddr:primaryPort - typePA // primaryAddr:alterAddr - typeAP // alterAddr:primaryPort - typeAA // alterAddr:alterAddr + typePP = iota // primaryAddr:primaryPort + typePA // primaryAddr:alterAddr + typeAP // alterAddr:primaryPort + typeAA // alterAddr:alterAddr typeMax ) var roleSet [typeMax]*net.UDPConn -var logger *log.Logger +var logger *log.Logger // ./stunserver --primaryAddr 1.1.1.1 --alternativeAddr 2.2.2.2 --primaryPort 3478 --alternativePort 3479 // ./stunserver --slaveserver 2.2.2.2:12345 --primaryAddr 1.1.1.1 --primaryPort 3478 --alternativePort 3479 @@ -40,7 +40,6 @@ var public = flag.Bool("public", true, "primaryAddr and alternativeAddr must be var slaveChan chan *string - var lanNets = []*net.IPNet{ {net.IPv4(10, 0, 0, 0), net.CIDRMask(8, 32)}, {net.IPv4(172, 16, 0, 0), net.CIDRMask(12, 32)}, @@ -51,12 +50,12 @@ var lanNets = []*net.IPNet{ func main() { flag.Parse() - logFile, err := os.OpenFile("./slave.log", os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666) + logFile, err := os.OpenFile("./slave.log", os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666) if err != nil { fmt.Println("failed to create slave.log: ", err.Error()) os.Exit(-1) } - logger = log.New(logFile,"",log.Llongfile | log.LstdFlags) + logger = log.New(logFile, "", log.Llongfile|log.LstdFlags) if *primaryAddr == "" || *alterAddr == "" { addrs, err := net.InterfaceAddrs() @@ -70,7 +69,7 @@ func main() { if ipnet.IP.To4() != nil && !lan.Contains(ipnet.IP) { if *primaryAddr == "" { *primaryAddr = ipnet.IP.String() - } else if *alterAddr == "" && *primaryAddr != ipnet.IP.String() { + } else if *alterAddr == "" && *primaryAddr != ipnet.IP.String() { *alterAddr = ipnet.IP.String() } else { break @@ -92,11 +91,11 @@ func main() { } } - roleSet[typePP], err = net.ListenUDP("udp", &net.UDPAddr{IP:net.ParseIP(*primaryAddr), Port:*primaryPort}) + roleSet[typePP], err = net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP(*primaryAddr), Port: *primaryPort}) if err != nil { logger.Fatal("listen on PP failed") } - roleSet[typePA], err = net.ListenUDP("udp", &net.UDPAddr{IP:net.ParseIP(*primaryAddr), Port:*alterPort}) + roleSet[typePA], err = net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP(*primaryAddr), Port: *alterPort}) if err != nil { logger.Fatal("listen on PA failed") } @@ -114,7 +113,7 @@ func main() { } slaveChan = make(chan *string, 128) go slaveClientWorker(slaveAddr) - aaAddr = &net.UDPAddr{IP:slaveAddr.IP, Port:*alterPort} + aaAddr = &net.UDPAddr{IP: slaveAddr.IP, Port: *alterPort} } } else if *slaveServer != "" { slaveAddr, err := net.ResolveTCPAddr("tcp", *slaveServer) @@ -128,7 +127,7 @@ func main() { if err != nil { logger.Fatalf("alterAddr %s:%d resolve failed", *alterAddr, alterPort) } - roleSet[typeAP], err = net.ListenUDP("udp", &net.UDPAddr{IP:net.ParseIP(*alterAddr), Port:*primaryPort}) + roleSet[typeAP], err = net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP(*alterAddr), Port: *primaryPort}) if err != nil { logger.Fatal("listen on PP failed") } @@ -264,4 +263,3 @@ func slaveProcessRequest(conn net.Conn) { req.RespondTo(roleSet[typePP], remote, nil) } } - diff --git a/stun/stun.go b/stun/stun.go index 298e8f3..13608e9 100644 --- a/stun/stun.go +++ b/stun/stun.go @@ -1,64 +1,64 @@ package stun import ( - "crypto/rand" - "fmt" - "net" "bytes" - "time" + "crypto/rand" "encoding/binary" - "io" "errors" + "fmt" "golang.org/x/net/ipv4" + "io" + "net" + "time" ) /* - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - |0 0| STUN Message Type | Message Length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Magic Cookie | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | | - | Transaction ID (96 bits) | - | | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |0 0| STUN Message Type | Message Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Magic Cookie | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + | Transaction ID (96 bits) | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - 0 1 - 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + 0 1 + 2 3 4 5 6 7 8 9 0 1 2 3 4 5 - +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ - |M |M |M|M|M|C|M|M|M|C|M|M|M|M| - |11|10|9|8|7|1|6|5|4|0|3|2|1|0| - +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + |M |M |M|M|M|C|M|M|M|C|M|M|M|M| + |11|10|9|8|7|1|6|5|4|0|3|2|1|0| + +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ - Figure 3: Format of STUN Message Type Field + Figure 3: Format of STUN Message Type Field */ type header struct { - Type uint16 - Length uint16 - Magic uint32 + Type uint16 + Length uint16 + Magic uint32 TransacrtonId [12]byte } type StunMessageReq struct { header - ChangeIp bool - ChangePort bool - RespSource string + ChangeIp bool + ChangePort bool + RespSource string //Candidate interface{} } type StunMessageResp struct { header - Addr *net.UDPAddr - OtherAddr *net.UDPAddr - ErrorCode uint16 - ErrorMsg string + Addr *net.UDPAddr + OtherAddr *net.UDPAddr + ErrorCode uint16 + ErrorMsg string } type attrHeader struct { @@ -78,14 +78,14 @@ const ( attrNonce = 0x15 attrXorAddress = 0x20 attrUseCandidate = 0x25 - attrPadding = 0x26 + attrPadding = 0x26 attrResponsePort = 0x27 // Comprehension optional - attrSoftware = 0x8022 + attrSoftware = 0x8022 //attrAlternate = 0x8023 - attrFingerprint = 0x8028 - attrOtherAddress= 0x802c + attrFingerprint = 0x8028 + attrOtherAddress = 0x802c ) const ( @@ -107,12 +107,12 @@ const ( const ( attrAddressFieldIpv4 = 1 attrAddressFieldIpv6 = 2 - attrAddressSizeIpv4 = 8 + attrAddressSizeIpv4 = 8 attrAddressSizeIpv6 = 20 ) const ( - magic = 0x2112a442 + magic = 0x2112a442 headerLen = 20 ) @@ -150,8 +150,8 @@ func (req *StunMessageReq) Unmarshal(data []byte) error { } if !typeIsRequest(req.Type) || methodFromMsgType(req.Type) != methodBinding || - req.Magic != magic || - int(req.Length+20) != len(data) { + req.Magic != magic || + int(req.Length+20) != len(data) { return errors.New("stun binding get an error format reply") } @@ -184,7 +184,7 @@ func (resp *StunMessageResp) Marshal() []byte { binary.Write(&buf, binary.BigEndian, resp.header) if resp.Addr.IP.To4() != nil { - binary.Write(&buf, binary.BigEndian, []interface{} { + binary.Write(&buf, binary.BigEndian, []interface{}{ uint16(attrAddress), uint16(attrAddressSizeIpv4), uint8(0), @@ -193,18 +193,18 @@ func (resp *StunMessageResp) Marshal() []byte { resp.Addr.IP.To4(), }) - binary.Write(&buf, binary.BigEndian, []interface{} { + binary.Write(&buf, binary.BigEndian, []interface{}{ uint16(attrXorAddress), uint16(attrAddressSizeIpv4), uint8(0), uint8(attrAddressFieldIpv4), uint16(resp.Addr.Port ^ magic>>16), }) - for i, field :=range resp.Addr.IP.To4() { + for i, field := range resp.Addr.IP.To4() { binary.Write(&buf, binary.BigEndian, uint8(field^magicBytes[i])) } } else { - binary.Write(&buf, binary.BigEndian, []interface{} { + binary.Write(&buf, binary.BigEndian, []interface{}{ uint16(attrAddress), uint16(attrAddressSizeIpv6), uint8(0), @@ -212,14 +212,14 @@ func (resp *StunMessageResp) Marshal() []byte { uint16(resp.Addr.Port), resp.Addr.IP.To16(), }) - binary.Write(&buf, binary.BigEndian, []interface{} { + binary.Write(&buf, binary.BigEndian, []interface{}{ uint16(attrXorAddress), uint16(attrAddressSizeIpv6), uint8(0), uint8(attrAddressFieldIpv6), uint16(resp.Addr.Port ^ magic>>16), }) - for i, field :=range resp.Addr.IP.To16() { + for i, field := range resp.Addr.IP.To16() { if i < 4 { binary.Write(&buf, binary.BigEndian, uint8(field^magicBytes[i])) } else { @@ -292,7 +292,7 @@ func (resp *StunMessageResp) Unmarshal(data []byte) error { if err != nil { return err } - resp.Addr = &net.UDPAddr{IP:ip, Port:port, Zone:""} + resp.Addr = &net.UDPAddr{IP: ip, Port: port, Zone: ""} } case attrXorAddress: ip, port, err := parseAddress(value) @@ -303,17 +303,17 @@ func (resp *StunMessageResp) Unmarshal(data []byte) error { ip[i] ^= data[4+i] } port ^= int(binary.BigEndian.Uint16(data[4:])) - resp.Addr = &net.UDPAddr{IP:ip, Port:port, Zone:""} + resp.Addr = &net.UDPAddr{IP: ip, Port: port, Zone: ""} haveXor = true case attrErrCode: resp.ErrorCode = uint16(value[2])*100 + uint16(value[3]) - resp.ErrorMsg = string(value[4:]) + resp.ErrorMsg = string(value[4:]) case attrOtherAddress: ip, port, err := parseAddress(value) if err != nil { return err } - resp.OtherAddr = &net.UDPAddr{IP:ip, Port:port, Zone:""} + resp.OtherAddr = &net.UDPAddr{IP: ip, Port: port, Zone: ""} default: } } @@ -337,29 +337,28 @@ func parseAddress(raw []byte) (net.IP, int, error) { ip := make([]byte, len(raw[4:])) copy(ip, raw[4:]) if len(ip) != family { - return nil, 0, errors.New("address parse error") + return nil, 0, errors.New("address parse error") } return net.IP(ip), int(port), nil } func getMsgType(class uint8, method uint16) uint16 { - return (method & 0x0f80) << 2 | (method & 0x0070) << 1 | (method & 0x0f) |(uint16(class) & 0x02) << 7 | (uint16(class) & 0x01) << 4 + return (method&0x0f80)<<2 | (method&0x0070)<<1 | (method & 0x0f) | (uint16(class)&0x02)<<7 | (uint16(class)&0x01)<<4 } func typeIsRequest(t uint16) bool { - return (t& 0x0110) == 0x0000 + return (t & 0x0110) == 0x0000 } func typeIsSuccessResp(t uint16) bool { - return (t& 0x0110) == 0x0100 + return (t & 0x0110) == 0x0100 } func typeIsErrorResp(t uint16) bool { - return (t& 0x0110) == 0x0110 + return (t & 0x0110) == 0x0110 } - func methodFromMsgType(t uint16) uint16 { - return (t & 0x000f) | ((t & 0x00e0) >>1) | ((t & 0x3E00) >>2) + return (t & 0x000f) | ((t & 0x00e0) >> 1) | ((t & 0x3E00) >> 2) } func NewBindRequest(tid []byte) *StunMessageReq { @@ -403,7 +402,7 @@ func (req *StunMessageReq) RequestTo(conn *net.UDPConn, to *net.UDPAddr) (*StunM loc, _ := net.ResolveUDPAddr("udp", conn.LocalAddr().String()) buf := make([]byte, 1500) - for retry:=0; retry < 3; retry++ { + for retry := 0; retry < 3; retry++ { _, err := pkConn.WriteTo(req.Marshal(), nil, to) if err != nil { return nil, nil, err @@ -453,55 +452,8 @@ func (req *StunMessageReq) Request(localAddr, remoteAddr string) (*StunMessageRe if err != nil { return nil, nil, err } - - pkConn := ipv4.NewPacketConn(sock) - pkConn.SetControlMessage(ipv4.FlagDst, true) - - defer pkConn.Close() - - if err := pkConn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { - fmt.Println("Couldn't set the socket timeout:", err) - } - - loc, _ := net.ResolveUDPAddr("udp", sock.LocalAddr().String()) - - buf := make([]byte, 1500) - - - for retry:=0; retry < 3; retry++ { - _, err = pkConn.WriteTo(req.Marshal(), nil, remote) - if err != nil { - return nil, nil, err - } - - n, cm, src, err := pkConn.ReadFrom(buf) - if err != nil { - if err, ok := err.(net.Error); ok && err.Timeout() { - - } - return nil, nil, err - } - loc.IP = cm.Dst - - var resp StunMessageResp - if err = resp.Unmarshal(buf[:n]); err != nil { - return nil, loc, err - } - if req.RespSource != "" && src.String() != req.RespSource { - return &resp, nil, errors.New("receive packet from unexpected source") - } - if resp.ErrorCode != 0 { - return &resp, loc, errors.New(resp.ErrorMsg) - } - if req.TransacrtonId != resp.TransacrtonId || - getMsgType(classResonseSuccess, methodBinding) != resp.Type || - resp.Addr == nil { - return &resp, loc, errors.New("receive error response") - } - return &resp, loc, nil - } - - return nil, nil, errors.New("request retry exceeds max times") + defer sock.Close() + return req.RequestTo(sock, remote) } func (req *StunMessageReq) RespondTo(conn *net.UDPConn, to *net.UDPAddr, other *net.UDPAddr) error { @@ -516,4 +468,4 @@ func (req *StunMessageReq) RespondTo(conn *net.UDPConn, to *net.UDPAddr, other * _, err := conn.WriteTo(resp.Marshal(), to) return err -} \ No newline at end of file +}