// Mirror of https://github.com/op0xA5/wgstun.git
// (synced 2025-09-26 19:41:16 +08:00; 588 lines, 13 KiB, Go).
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"golang.zx2c4.com/wireguard/wgctrl"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
)
|
|
|
|
// minCacheDictLen is the peer count above which updateCache builds hash-map
// indexes for the peer cache. With this many peers or fewer, lookups fall
// back to a linear O(n) scan. A larger value means the maps are rebuilt on
// fewer refreshes.
const (
	minCacheDictLen = 16
)
|
|
|
|
type peerCacheItem struct {
|
|
PublicKey wgtypes.Key
|
|
InternalIP net.IP
|
|
IP net.IP
|
|
Port int
|
|
LastHandshakeTime time.Time
|
|
}
|
|
var peerCache = struct {
|
|
// items store all peerCacheItems
|
|
items []peerCacheItem
|
|
// keyDict or ipDict use for speed up retrieve item when items count larger than minCacheDictLen
|
|
keyDict map[wgtypes.Key]*peerCacheItem
|
|
ipDict map[[net.IPv6len]byte]*peerCacheItem
|
|
expires time.Time
|
|
err error
|
|
mu sync.RWMutex
|
|
}{}
|
|
|
|
// errExpired is returned by the peer-cache lookups when the cached data is
// past its expiry time; callers respond by calling updateCache and retrying.
var (
	errExpired = errors.New("cache expired")
)
|
|
// find peer info by peer's public key
|
|
func findPeerCache(key wgtypes.Key) (*peerCacheItem, error) {
|
|
peerCache.mu.RLock()
|
|
defer peerCache.mu.RUnlock()
|
|
|
|
if time.Now().After(peerCache.expires) {
|
|
return nil, errExpired
|
|
}
|
|
|
|
if peerCache.err != nil {
|
|
return nil, peerCache.err
|
|
}
|
|
|
|
if peerCache.keyDict != nil {
|
|
res, _ := peerCache.keyDict[key]
|
|
return res, nil
|
|
}
|
|
|
|
for i := range peerCache.items {
|
|
if keyEquals(key, peerCache.items[i].PublicKey) {
|
|
return &peerCache.items[i], nil
|
|
}
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
// find peer info by peer's allowed ip, only support /32 CIDR
|
|
func findInternalIPCache(ip net.IP) (*peerCacheItem, error) {
|
|
if ip.Equal(net.IPv4zero) {
|
|
return nil, nil
|
|
}
|
|
|
|
peerCache.mu.RLock()
|
|
defer peerCache.mu.RUnlock()
|
|
|
|
if time.Now().After(peerCache.expires) {
|
|
return nil, errExpired
|
|
}
|
|
|
|
if peerCache.err != nil {
|
|
return nil, peerCache.err
|
|
}
|
|
|
|
if peerCache.keyDict != nil {
|
|
ip16 := [net.IPv6len]byte{}
|
|
copy(ip16[:], ip.To16())
|
|
res, _ := peerCache.ipDict[ip16]
|
|
return res, nil
|
|
}
|
|
|
|
for i := range peerCache.items {
|
|
if ip.Equal(peerCache.items[i].InternalIP) {
|
|
return &peerCache.items[i], nil
|
|
}
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func updateCache() error {
|
|
peerCache.mu.Lock()
|
|
defer peerCache.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
if !now.After(peerCache.expires) {
|
|
return nil
|
|
}
|
|
|
|
peerCache.expires = now.Add(server_config.expire)
|
|
|
|
wgClient, err := wgctrl.New()
|
|
defer wgClient.Close()
|
|
if err != nil {
|
|
peerCache.err = err
|
|
return err
|
|
}
|
|
|
|
device, err := wgClient.Device(server_config.ifname)
|
|
if err != nil {
|
|
peerCache.err = err
|
|
return err
|
|
}
|
|
|
|
if len(device.Peers) == 0 {
|
|
peerCache.items = peerCache.items[:0]
|
|
peerCache.keyDict = nil
|
|
peerCache.ipDict = nil
|
|
return nil
|
|
}
|
|
|
|
if cap(peerCache.items) < len(device.Peers) {
|
|
// extend cache array space
|
|
_cap := ((len(device.Peers)-1)/16 + 1) * 16
|
|
peerCache.items = make([]peerCacheItem, len(device.Peers), _cap)
|
|
} else {
|
|
// reuse cache array
|
|
peerCache.items = peerCache.items[:len(device.Peers)]
|
|
}
|
|
|
|
for i, peer := range device.Peers {
|
|
peerCache.items[i] = peerCacheItem{
|
|
PublicKey: peer.PublicKey,
|
|
InternalIP: net.IPv4zero,
|
|
IP: net.IPv4zero,
|
|
Port: 0,
|
|
LastHandshakeTime: peer.LastHandshakeTime,
|
|
}
|
|
|
|
if peer.Endpoint != nil {
|
|
peerCache.items[i].IP = peer.Endpoint.IP
|
|
peerCache.items[i].Port = peer.Endpoint.Port
|
|
}
|
|
|
|
for _, ipnet := range peer.AllowedIPs {
|
|
ones, bits := ipnet.Mask.Size()
|
|
if ones == bits {
|
|
peerCache.items[i].InternalIP = ipnet.IP
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(peerCache.items) > minCacheDictLen {
|
|
peerCache.keyDict = make(map[wgtypes.Key]*peerCacheItem, len(peerCache.items))
|
|
for i := range peerCache.items {
|
|
peerCache.keyDict[peerCache.items[i].PublicKey] = &peerCache.items[i]
|
|
}
|
|
peerCache.ipDict = make(map[[net.IPv6len]byte]*peerCacheItem)
|
|
for i := range peerCache.items {
|
|
ip16 := [net.IPv6len]byte{}
|
|
copy(ip16[:], peerCache.items[i].InternalIP.To16())
|
|
peerCache.ipDict[ip16] = &peerCache.items[i]
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// serve udp connection
|
|
func listenFunc(conn *net.UDPConn) {
|
|
// buffer for read udp packet
|
|
bufferPool := sync.Pool{
|
|
New: func() interface{} { return make([]byte, 2048) },
|
|
}
|
|
// buffer for hold response data in handler function
|
|
packetBufferPool := sync.Pool{
|
|
New: func() interface{} { return new(bytes.Buffer) },
|
|
}
|
|
|
|
for {
|
|
buffer := bufferPool.Get().([]byte)
|
|
n, addr, err := conn.ReadFromUDP(buffer)
|
|
if err != nil {
|
|
bufferPool.Put(buffer)
|
|
|
|
// a hack check an error cause by network closed by user
|
|
str := err.Error()
|
|
if strings.Contains(str, "use of closed network connection") {
|
|
return
|
|
}
|
|
|
|
printf("ReadFromUDP err: %v\n", err)
|
|
conn.Close()
|
|
printf("listen goroutine exited\n")
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
packetBuffer := packetBufferPool.Get().(*bytes.Buffer)
|
|
packetBuffer.Reset()
|
|
handleFunc(buffer[:n], addr, conn, packetBuffer)
|
|
packetBufferPool.Put(packetBuffer)
|
|
bufferPool.Put(buffer)
|
|
}()
|
|
}
|
|
}
|
|
|
|
// handle once query
|
|
func handleFunc(b []byte, addr *net.UDPAddr, conn *net.UDPConn, buffer *bytes.Buffer) {
|
|
var req *PacketBody
|
|
var res Packet
|
|
|
|
// response should send separately if individualPacket is true, like Pong packet
|
|
individualPacket := false
|
|
|
|
for len(b) > 4 {
|
|
packet, _b, err := ReadPacket(b)
|
|
if err != nil {
|
|
printf("error read packet: %v\n", err)
|
|
continue
|
|
}
|
|
b = _b
|
|
|
|
switch packet.PacketType() {
|
|
case PacketTypePing:
|
|
req = packet.(*PacketBody)
|
|
|
|
pong := NewPacketBody()
|
|
pong.Type = PacketTypePong
|
|
copy(pong.Identity[:], req.Identity[:])
|
|
res = pong
|
|
individualPacket = true
|
|
break
|
|
case PacketTypeFindPeer:
|
|
req = packet.(*PacketBody)
|
|
|
|
item, err := findPeerCache(req.GetKey())
|
|
if err == errExpired {
|
|
err = updateCache()
|
|
if err != nil {
|
|
printf("error communicate wireguard device: %v\n", err)
|
|
}
|
|
item, err = findPeerCache(req.GetKey())
|
|
}
|
|
|
|
if item == nil {
|
|
nack := NewPacketBody()
|
|
nack.Type = PacketTypeNack
|
|
nack.SetKey(req.GetKey())
|
|
res = nack
|
|
} else {
|
|
if v4 := item.IP.To4(); v4 != nil {
|
|
_res := NewPacketIPv4Body()
|
|
_res.SetKey(item.PublicKey)
|
|
_res.SetIP(item.IP)
|
|
_res.Port = uint16(item.Port)
|
|
_res.SetHandshakeTime(item.LastHandshakeTime)
|
|
res = _res
|
|
} else {
|
|
_res := NewPacketIPv6Body()
|
|
_res.SetKey(item.PublicKey)
|
|
_res.SetIP(item.IP)
|
|
_res.Port = uint16(item.Port)
|
|
_res.SetHandshakeTime(item.LastHandshakeTime)
|
|
res = _res
|
|
}
|
|
}
|
|
|
|
break
|
|
case PacketTypeGetMyIP:
|
|
req = packet.(*PacketBody)
|
|
|
|
item, err := findInternalIPCache(addr.IP)
|
|
if err == errExpired {
|
|
err = updateCache()
|
|
if err != nil {
|
|
printf("error communicate wireguard device: %v\n", err)
|
|
}
|
|
item, err = findPeerCache(req.GetKey())
|
|
}
|
|
|
|
if item == nil {
|
|
nack := NewPacketBody()
|
|
nack.Type = PacketTypeNack
|
|
res = nack
|
|
} else {
|
|
if v4 := item.IP.To4(); v4 != nil {
|
|
_res := NewPacketIPv4Body()
|
|
_res.SetIP(item.IP)
|
|
_res.Port = uint16(item.Port)
|
|
_res.SetHandshakeTime(item.LastHandshakeTime)
|
|
res = _res
|
|
} else {
|
|
_res := NewPacketIPv6Body()
|
|
_res.SetIP(item.IP)
|
|
_res.Port = uint16(item.Port)
|
|
_res.SetHandshakeTime(item.LastHandshakeTime)
|
|
res = _res
|
|
}
|
|
}
|
|
|
|
break
|
|
}
|
|
|
|
if individualPacket {
|
|
if buffer.Len() > 0 {
|
|
_, err = conn.WriteToUDP(buffer.Bytes(), addr)
|
|
if err != nil {
|
|
printf("error write packet: %v\n", err)
|
|
}
|
|
}
|
|
|
|
buffer.Reset()
|
|
|
|
if buffer.Cap() == 0 {
|
|
buffer.Grow(MaxPacketSize)
|
|
}
|
|
WritePacket(buffer, res)
|
|
_, err = conn.WriteToUDP(buffer.Bytes(), addr)
|
|
if err != nil {
|
|
printf("error write packet: %v\n", err)
|
|
}
|
|
|
|
buffer.Reset()
|
|
} else {
|
|
bufferLen := buffer.Len()
|
|
if bufferLen > 0 && bufferLen + res.PacketSize() > MaxPacketSize {
|
|
_, err = conn.WriteToUDP(buffer.Bytes(), addr)
|
|
if err != nil {
|
|
printf("error write packet: %v\n", err)
|
|
}
|
|
|
|
buffer.Reset()
|
|
}
|
|
|
|
if buffer.Cap() == 0 {
|
|
buffer.Grow(MaxPacketSize)
|
|
}
|
|
WritePacket(buffer, res)
|
|
}
|
|
}
|
|
|
|
if buffer.Len() > 0 {
|
|
_, err := conn.WriteToUDP(buffer.Bytes(), addr)
|
|
if err != nil {
|
|
printf("error write packet: %s\n", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// httpServer is the HTTP server created by httpListenFunc. server_main closes
// it during shutdown. It stays nil unless an HTTP listener was started.
var (
	httpServer *http.Server
)
|
|
func httpListenFunc(l *net.TCPListener, root string) {
|
|
http.Handle("/peer", http.HandlerFunc(httpHandler))
|
|
http.Handle("/", http.FileServer(http.Dir(root)))
|
|
|
|
httpServer = &http.Server{
|
|
Addr: l.Addr().String(),
|
|
ReadTimeout: 60 * time.Second,
|
|
ReadHeaderTimeout: 60 * time.Second,
|
|
WriteTimeout: 300 * time.Second,
|
|
IdleTimeout: 300 * time.Second,
|
|
}
|
|
|
|
err := httpServer.Serve(l)
|
|
if err != nil {
|
|
printf("error serve http: %s\n", err)
|
|
}
|
|
}
|
|
|
|
func httpHandler(w http.ResponseWriter, r *http.Request) {
|
|
err := r.ParseForm()
|
|
if err != nil {
|
|
http.Error(w, "400 Bad Request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
key := r.Form.Get("public_key")
|
|
if key == "" {
|
|
key = r.Form.Get("key")
|
|
}
|
|
if key == "" {
|
|
http.Error(w, "404 Not Found", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
typ := r.Form.Get("type")
|
|
|
|
wgkey, err := wgtypes.ParseKey(strings.ReplaceAll(key, " ", "+"))
|
|
if err != nil {
|
|
http.Error(w, "404 Not Found", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
item, err := findPeerCache(wgkey)
|
|
if err == errExpired {
|
|
err = updateCache()
|
|
if err != nil {
|
|
printf("error communicate wireguard device: %v\n", err)
|
|
http.Error(w, "500 Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
item, err = findPeerCache(wgkey)
|
|
}
|
|
|
|
if item == nil {
|
|
http.Error(w, "404 Not Found", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
if typ == "json" {
|
|
body, err := json.Marshal(struct {
|
|
Addr string `json:"addr"`
|
|
IP string `json:"ip"`
|
|
LastHandshake string `json:"last_handshake"`
|
|
}{
|
|
item.IP.String() + ":" + strconv.Itoa(item.Port),
|
|
item.IP.String(),
|
|
item.LastHandshakeTime.Format("2006-01-02 15:04:05"),
|
|
})
|
|
if err != nil {
|
|
http.Error(w, "500 Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
w.Header().Set("Expires", time.Now().Add(-1 * time.Second).Format(http.TimeFormat))
|
|
w.Header().Set("Last-Modified", item.LastHandshakeTime.Format(http.TimeFormat))
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
w.Header().Set("Content-Type", "text/json; charset=utf-8")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write(body)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Expires", time.Now().Add(-1 * time.Second).Format(http.TimeFormat))
|
|
w.Header().Set("Last-Modified", item.LastHandshakeTime.Format(http.TimeFormat))
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(item.IP.String() + ":" + strconv.Itoa(item.Port)))
|
|
return
|
|
}
|
|
|
|
// server_config holds the settings server_main parses from the command line
// and updateCache reads on every refresh.
var server_config struct {
	// ifname is the WireGuard interface to query (-i / --interface).
	ifname string
	// expire is how long one cache refresh stays valid (-e / --expire,
	// given in seconds, at least one second).
	expire time.Duration
}
|
|
|
|
func server_main(args []string) {
|
|
var listen_str string = ":55550"
|
|
var ifname_str string
|
|
var expire_str string = "5"
|
|
var verbose bool
|
|
var listen_http_str string = ""
|
|
var http_root string = "./wwwroot"
|
|
_ = verbose
|
|
|
|
var ps *string
|
|
for _, arg := range args {
|
|
if ps != nil {
|
|
*ps = arg
|
|
ps = nil
|
|
continue
|
|
}
|
|
|
|
switch arg {
|
|
case "-l", "--listen":
|
|
ps = &listen_str
|
|
break
|
|
case "-e", "--expire":
|
|
ps = &expire_str
|
|
break
|
|
case "-i", "--interface":
|
|
ps = &ifname_str
|
|
break
|
|
case "-v":
|
|
verbose = true
|
|
break
|
|
case "-w", "--web":
|
|
ps = &listen_http_str
|
|
break
|
|
case "-r", "--root":
|
|
ps = &http_root
|
|
break
|
|
default:
|
|
printf("unknown parameter '%s'\n", arg)
|
|
os.Exit(1)
|
|
break
|
|
}
|
|
}
|
|
|
|
if ps != nil {
|
|
printf("parameter value not set\n")
|
|
os.Exit(1)
|
|
return
|
|
}
|
|
|
|
if ifname_str == ""{
|
|
usage()
|
|
return
|
|
}
|
|
|
|
listenAddr, err := net.ResolveUDPAddr("udp", listen_str)
|
|
if err != nil {
|
|
printf("error resolve address '%s': %v\n", listen_str, err)
|
|
os.Exit(1)
|
|
return
|
|
}
|
|
|
|
var httpListenAddr *net.TCPAddr
|
|
if listen_http_str != "" {
|
|
httpListenAddr, err = net.ResolveTCPAddr("tcp", listen_http_str)
|
|
if err != nil {
|
|
printf("error resolve address '%s': %v\n", listen_http_str, err)
|
|
os.Exit(1)
|
|
return
|
|
}
|
|
}
|
|
|
|
expireSec, err := strconv.Atoi(expire_str)
|
|
if err != nil {
|
|
printf("error parse expire '%s': %v\n", expire_str, err)
|
|
os.Exit(1)
|
|
return
|
|
}
|
|
|
|
if expireSec < 1 {
|
|
expireSec = 1
|
|
}
|
|
server_config.expire = time.Second * time.Duration(expireSec)
|
|
server_config.ifname = ifname_str
|
|
|
|
listener, err := net.ListenUDP("udp", listenAddr)
|
|
if err != nil {
|
|
printf("error create socket '%s': %v\n", listenAddr.String(), err)
|
|
os.Exit(1)
|
|
return
|
|
}
|
|
|
|
printf("server started on '%s'\n", listenAddr.String())
|
|
|
|
var httpListener *net.TCPListener
|
|
if httpListenAddr != nil {
|
|
httpListener, err = net.ListenTCP("tcp", httpListenAddr)
|
|
if err != nil {
|
|
printf("error create socket '%s': %v\n", httpListenAddr.String(), err)
|
|
os.Exit(1)
|
|
return
|
|
}
|
|
|
|
printf("http server started on '%s'\n", httpListenAddr.String())
|
|
}
|
|
|
|
go listenFunc(listener)
|
|
if httpListener != nil {
|
|
go httpListenFunc(httpListener, http_root)
|
|
}
|
|
|
|
signalC := make(chan os.Signal)
|
|
signal.Notify(signalC, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM)
|
|
s := <- signalC
|
|
printf("signal %s received\n", s.String())
|
|
|
|
_ = listener.Close()
|
|
if httpServer != nil {
|
|
httpServer.Close()
|
|
}
|
|
if httpListener != nil {
|
|
_ = httpListener.Close()
|
|
}
|
|
|
|
printf("server exited\n")
|
|
os.Exit(0)
|
|
}
|