Files
wgstun/resolve.go
2022-04-02 11:29:34 +08:00

139 lines
3.3 KiB
Go

package main
import (
"bytes"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"net"
"time"
)
// ResolveResult holds the peer info reported by a wgstun server for one
// requested peer key.
type ResolveResult struct {
// Valid reports whether the remaining fields were filled in from a server
// response; check it before using Endpoint or HandshakeTime.
Valid bool
// Key is the public key of the peer this result refers to.
Key wgtypes.Key
// Endpoint is the peer's UDP endpoint as reported in the server response.
Endpoint *net.UDPAddr
// HandshakeTime is the last handshake time reported in the server response.
HandshakeTime time.Time
}
// ResolveError describes a failure while resolving peers through a wgstun
// server.
type ResolveError struct {
	// action names the step that failed (for example "dail", "send packet",
	// "read packet" or "timeout").
	action string
	// Addr is the remote address involved in the failed step.
	Addr string
	// Err is the underlying cause; it may be nil.
	Err error
}

// Error implements the error interface.
//
// The remote address is only included for the "dail" action. When Err is nil
// the message ends after the action instead of dereferencing a nil error.
func (err ResolveError) Error() string {
	msg := "error " + err.action
	if err.action == "dail" {
		msg = "error dail '" + err.Addr + "'"
	}
	if err.Err == nil {
		return msg
	}
	return msg + ": " + err.Err.Error()
}

// Unwrap returns the underlying cause so that errors.Is and errors.As can
// inspect it.
func (err ResolveError) Unwrap() error {
	return err.Err
}
// StunResolve queries the wgstun server at raddr for info about peers.
//
// It dials a dedicated UDP socket to raddr, delegates the exchange to
// StunResolveConn and closes the socket before returning. A dial failure is
// reported as a ResolveError carrying the remote address.
func StunResolve(raddr *net.UDPAddr, peers []wgtypes.Key) ([]ResolveResult, error) {
	conn, err := net.DialUDP("udp", nil, raddr)
	if err != nil {
		return nil, ResolveError{"dail", raddr.String(), err}
	}
	// The socket is created here and not handed to the caller, so it must be
	// released on every return path.
	defer conn.Close()

	return StunResolveConn(conn, peers)
}
// StunResolveConn queries a wgstun server for peer info over an existing UDP
// connection.
//
// One find-peer request is encoded per key in peers, and requests are batched
// into datagrams of at most MaxPacketSize bytes. Replies are then read until
// every requested peer has been answered: a response packet fills in the
// matching entries and marks them Valid, a NACK settles the matching entries
// while leaving them invalid. Replies for keys that were not requested are
// ignored.
//
// The returned slice has one entry per element of peers, in the same order.
// On failure the result is nil and the error is a ResolveError whose action is
// "send packet", "read packet" or "timeout". Every write gets a 5 second
// deadline and the whole read phase shares one 30 second deadline. The
// connection is not closed here.
func StunResolveConn(conn *net.UDPConn, peers []wgtypes.Key) ([]ResolveResult, error) {
	result := make([]ResolveResult, len(peers))
	for i, peer := range peers {
		result[i].Key = peer
	}
	// Nothing was asked for, so there is no reply to wait for.
	if len(peers) == 0 {
		return result, nil
	}

	// fail builds the ResolveError for a failed step; the address lookup is
	// guarded because RemoteAddr is nil on an unconnected socket.
	fail := func(action string, cause error) error {
		addr := ""
		if ra := conn.RemoteAddr(); ra != nil {
			addr = ra.String()
		}
		return ResolveError{action, addr, cause}
	}

	buf := new(bytes.Buffer)
	buf.Grow(MaxPacketSize)

	// flush sends the batched requests as one datagram and empties the batch.
	flush := func() error {
		if buf.Len() == 0 {
			return nil
		}
		// Best effort, as before: a failure to set the deadline is ignored.
		_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
		_, err := conn.Write(buf.Bytes())
		buf.Reset()
		return err
	}

	for _, peer := range peers {
		req := NewPacketBody()
		req.Type = PacketTypeFindPeer
		req.SetKey(peer)
		req.Reserved = 0
		// Start a new datagram when this request would overflow the current one.
		if buf.Len()+req.PacketSize() > MaxPacketSize {
			if err := flush(); err != nil {
				return nil, fail("send packet", err)
			}
		}
		WritePacket(buf, req)
	}
	if err := flush(); err != nil {
		return nil, fail("send packet", err)
	}

	// answered[i] records that result[i] received a response or a NACK, so
	// every entry is counted exactly once even with duplicate keys or
	// repeated replies.
	answered := make([]bool, len(result))
	pending := len(result)
	settle := func(i int) {
		if !answered[i] {
			answered[i] = true
			pending--
		}
	}

	// Best effort, as before: a failure to set the deadline is ignored.
	_ = conn.SetReadDeadline(time.Now().Add(30 * time.Second))
	for {
		// A fresh buffer per datagram: parsed packets may keep references into
		// it (TODO confirm against ReadPacket), so it is not reused.
		data := make([]byte, 2048)
		n, readErr := conn.Read(data)
		data = data[:n]

		for len(data) > 4 {
			p, rest, err := ReadPacket(data)
			if err != nil {
				return nil, fail("read packet", err)
			}

			if p.PacketType() == PacketTypeNack {
				// The checked assertion skips a NACK of an unexpected concrete
				// type instead of panicking.
				if body, ok := p.(*PacketBody); ok {
					key := body.GetKey()
					for i := range result {
						if keyEquals(result[i].Key, key) {
							settle(i)
						}
					}
				}
			}

			if res, ok := p.(ResponsePacket); ok {
				key := res.GetKey()
				for i := range result {
					if !result[i].Valid && keyEquals(result[i].Key, key) {
						result[i].Valid = true
						result[i].Endpoint = res.GetUDPAddr()
						result[i].HandshakeTime = res.GetHandshakeTime()
						settle(i)
					}
				}
			}

			data = rest
		}

		if pending == 0 {
			return result, nil
		}
		if readErr != nil {
			if netErr, ok := readErr.(net.Error); ok && netErr.Timeout() {
				return nil, fail("timeout", readErr)
			}
			return nil, fail("read packet", readErr)
		}
	}
}