Files
wgstun/server.go
2022-04-02 11:29:34 +08:00

588 lines
13 KiB
Go

package main
import (
"bytes"
"encoding/json"
"errors"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"net"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
)
// minCacheDictLen is the peer-count threshold for building the lookup maps.
// While the cache holds at most minCacheDictLen peers, lookups scan the items
// slice linearly (O(n)); the maps are only built above that count.
// A larger value means the maps are rebuilt less often.
const minCacheDictLen = 16
// peerCacheItem is the cached state of one peer of the WireGuard device.
type peerCacheItem struct {
// PublicKey is the peer's public key as reported by the device.
PublicKey wgtypes.Key
// InternalIP is the first allowed IP of the peer whose mask covers a single
// address; it stays net.IPv4zero when the peer has no such entry.
InternalIP net.IP
// IP and Port are the peer's endpoint; they stay net.IPv4zero and 0 when the
// device reports no endpoint for the peer.
IP net.IP
Port int
// LastHandshakeTime is copied from the device's peer record.
LastHandshakeTime time.Time
}
// peerCache is the process-wide snapshot of the device's peers. The lookup
// functions read it under mu.RLock and updateCache rewrites it under mu.Lock.
var peerCache = struct {
// items stores every cached peer.
items []peerCacheItem
// keyDict and ipDict index items by public key and by 16-byte internal IP.
// They speed up lookups and are only built when items holds more than
// minCacheDictLen entries.
keyDict map[wgtypes.Key]*peerCacheItem
ipDict map[[net.IPv6len]byte]*peerCacheItem
// expires is the instant after which lookups report errExpired.
expires time.Time
// err is the error recorded by updateCache when querying the device failed;
// lookups return it instead of a result while it is set.
err error
// mu guards the fields above.
mu sync.RWMutex
}{}
// errExpired is returned by the cache lookups when the cached snapshot is past
// its expiry time; callers react by calling updateCache and looking up again.
var errExpired = errors.New("cache expired")
// findPeerCache looks up the cached peer whose public key equals key.
//
// It returns errExpired once the cache is past its expiry time, the error held
// in peerCache.err when one is set, and (nil, nil) when no cached peer matches.
func findPeerCache(key wgtypes.Key) (*peerCacheItem, error) {
	peerCache.mu.RLock()
	defer peerCache.mu.RUnlock()

	switch {
	case time.Now().After(peerCache.expires):
		return nil, errExpired
	case peerCache.err != nil:
		return nil, peerCache.err
	case peerCache.keyDict != nil:
		// Indexed path: a missing key yields the nil pointer.
		return peerCache.keyDict[key], nil
	}

	// No index available: scan the items one by one.
	for idx := range peerCache.items {
		candidate := &peerCache.items[idx]
		if keyEquals(key, candidate.PublicKey) {
			return candidate, nil
		}
	}
	return nil, nil
}
// findInternalIPCache looks up the cached peer whose InternalIP equals ip.
// InternalIP is a single-address allowed IP, so wider allowed-IP ranges are
// never matched.
//
// net.IPv4zero never matches. The function returns errExpired once the cache
// is past its expiry time, the error held in peerCache.err when one is set,
// and (nil, nil) when no cached peer matches.
func findInternalIPCache(ip net.IP) (*peerCacheItem, error) {
	if ip.Equal(net.IPv4zero) {
		return nil, nil
	}

	peerCache.mu.RLock()
	defer peerCache.mu.RUnlock()

	switch {
	case time.Now().After(peerCache.expires):
		return nil, errExpired
	case peerCache.err != nil:
		return nil, peerCache.err
	case peerCache.keyDict != nil:
		// Indexed path: the IP index is keyed by the 16-byte form of the address.
		var lookupKey [net.IPv6len]byte
		copy(lookupKey[:], ip.To16())
		return peerCache.ipDict[lookupKey], nil
	}

	// No index available: scan the items one by one.
	for idx := range peerCache.items {
		candidate := &peerCache.items[idx]
		if ip.Equal(candidate.InternalIP) {
			return candidate, nil
		}
	}
	return nil, nil
}
// updateCache refreshes peerCache from the WireGuard device named in
// server_config.ifname.
//
// It is a no-op (returning nil) when the cache has not expired yet, which is
// the case when another goroutine refreshed it while this one was waiting for
// the lock. The expiry is pushed forward before the device is queried, so a
// failing device is re-queried at most once per server_config.expire.
//
// On failure the error is stored in peerCache.err and returned; on success
// peerCache.err is cleared and items and both lookup maps are replaced.
func updateCache() error {
	peerCache.mu.Lock()
	defer peerCache.mu.Unlock()

	now := time.Now()
	if !now.After(peerCache.expires) {
		return nil
	}
	peerCache.expires = now.Add(server_config.expire)

	wgClient, err := wgctrl.New()
	if err != nil {
		peerCache.err = err
		return err
	}
	// Deferred only after the error check: the client is not usable when New fails.
	defer wgClient.Close()

	device, err := wgClient.Device(server_config.ifname)
	if err != nil {
		peerCache.err = err
		return err
	}
	// A successful query supersedes any error left by an earlier refresh.
	peerCache.err = nil

	// Always build a fresh slice. The lookup functions hand out pointers into
	// items that callers read after releasing the lock, so the previous
	// snapshot must never be overwritten in place.
	items := make([]peerCacheItem, len(device.Peers))
	for i := range device.Peers {
		peer := &device.Peers[i]
		entry := &items[i]
		*entry = peerCacheItem{
			PublicKey:         peer.PublicKey,
			InternalIP:        net.IPv4zero,
			IP:                net.IPv4zero,
			Port:              0,
			LastHandshakeTime: peer.LastHandshakeTime,
		}
		if peer.Endpoint != nil {
			entry.IP = peer.Endpoint.IP
			entry.Port = peer.Endpoint.Port
		}
		// The first allowed IP that covers exactly one address is the peer's
		// internal IP. bits is 0 for a non-canonical mask, which is skipped.
		for j := range peer.AllowedIPs {
			ones, bits := peer.AllowedIPs[j].Mask.Size()
			if bits != 0 && ones == bits {
				entry.InternalIP = peer.AllowedIPs[j].IP
				break
			}
		}
	}

	// The maps are rebuilt from scratch (or dropped) on every refresh so they
	// can never point at entries of an older snapshot.
	var keyDict map[wgtypes.Key]*peerCacheItem
	var ipDict map[[net.IPv6len]byte]*peerCacheItem
	if len(items) > minCacheDictLen {
		keyDict = make(map[wgtypes.Key]*peerCacheItem, len(items))
		ipDict = make(map[[net.IPv6len]byte]*peerCacheItem, len(items))
		for i := range items {
			keyDict[items[i].PublicKey] = &items[i]

			var ip16 [net.IPv6len]byte
			copy(ip16[:], items[i].InternalIP.To16())
			ipDict[ip16] = &items[i]
		}
	}

	peerCache.items = items
	peerCache.keyDict = keyDict
	peerCache.ipDict = ipDict
	return nil
}
// serve udp connection
func listenFunc(conn *net.UDPConn) {
// buffer for read udp packet
bufferPool := sync.Pool{
New: func() interface{} { return make([]byte, 2048) },
}
// buffer for hold response data in handler function
packetBufferPool := sync.Pool{
New: func() interface{} { return new(bytes.Buffer) },
}
for {
buffer := bufferPool.Get().([]byte)
n, addr, err := conn.ReadFromUDP(buffer)
if err != nil {
bufferPool.Put(buffer)
// a hack check an error cause by network closed by user
str := err.Error()
if strings.Contains(str, "use of closed network connection") {
return
}
printf("ReadFromUDP err: %v\n", err)
conn.Close()
printf("listen goroutine exited\n")
return
}
go func() {
packetBuffer := packetBufferPool.Get().(*bytes.Buffer)
packetBuffer.Reset()
handleFunc(buffer[:n], addr, conn, packetBuffer)
packetBufferPool.Put(packetBuffer)
bufferPool.Put(buffer)
}()
}
}
// handle once query
func handleFunc(b []byte, addr *net.UDPAddr, conn *net.UDPConn, buffer *bytes.Buffer) {
var req *PacketBody
var res Packet
// response should send separately if individualPacket is true, like Pong packet
individualPacket := false
for len(b) > 4 {
packet, _b, err := ReadPacket(b)
if err != nil {
printf("error read packet: %v\n", err)
continue
}
b = _b
switch packet.PacketType() {
case PacketTypePing:
req = packet.(*PacketBody)
pong := NewPacketBody()
pong.Type = PacketTypePong
copy(pong.Identity[:], req.Identity[:])
res = pong
individualPacket = true
break
case PacketTypeFindPeer:
req = packet.(*PacketBody)
item, err := findPeerCache(req.GetKey())
if err == errExpired {
err = updateCache()
if err != nil {
printf("error communicate wireguard device: %v\n", err)
}
item, err = findPeerCache(req.GetKey())
}
if item == nil {
nack := NewPacketBody()
nack.Type = PacketTypeNack
nack.SetKey(req.GetKey())
res = nack
} else {
if v4 := item.IP.To4(); v4 != nil {
_res := NewPacketIPv4Body()
_res.SetKey(item.PublicKey)
_res.SetIP(item.IP)
_res.Port = uint16(item.Port)
_res.SetHandshakeTime(item.LastHandshakeTime)
res = _res
} else {
_res := NewPacketIPv6Body()
_res.SetKey(item.PublicKey)
_res.SetIP(item.IP)
_res.Port = uint16(item.Port)
_res.SetHandshakeTime(item.LastHandshakeTime)
res = _res
}
}
break
case PacketTypeGetMyIP:
req = packet.(*PacketBody)
item, err := findInternalIPCache(addr.IP)
if err == errExpired {
err = updateCache()
if err != nil {
printf("error communicate wireguard device: %v\n", err)
}
item, err = findPeerCache(req.GetKey())
}
if item == nil {
nack := NewPacketBody()
nack.Type = PacketTypeNack
res = nack
} else {
if v4 := item.IP.To4(); v4 != nil {
_res := NewPacketIPv4Body()
_res.SetIP(item.IP)
_res.Port = uint16(item.Port)
_res.SetHandshakeTime(item.LastHandshakeTime)
res = _res
} else {
_res := NewPacketIPv6Body()
_res.SetIP(item.IP)
_res.Port = uint16(item.Port)
_res.SetHandshakeTime(item.LastHandshakeTime)
res = _res
}
}
break
}
if individualPacket {
if buffer.Len() > 0 {
_, err = conn.WriteToUDP(buffer.Bytes(), addr)
if err != nil {
printf("error write packet: %v\n", err)
}
}
buffer.Reset()
if buffer.Cap() == 0 {
buffer.Grow(MaxPacketSize)
}
WritePacket(buffer, res)
_, err = conn.WriteToUDP(buffer.Bytes(), addr)
if err != nil {
printf("error write packet: %v\n", err)
}
buffer.Reset()
} else {
bufferLen := buffer.Len()
if bufferLen > 0 && bufferLen + res.PacketSize() > MaxPacketSize {
_, err = conn.WriteToUDP(buffer.Bytes(), addr)
if err != nil {
printf("error write packet: %v\n", err)
}
buffer.Reset()
}
if buffer.Cap() == 0 {
buffer.Grow(MaxPacketSize)
}
WritePacket(buffer, res)
}
}
if buffer.Len() > 0 {
_, err := conn.WriteToUDP(buffer.Bytes(), addr)
if err != nil {
printf("error write packet: %s\n", err)
}
}
}
// httpServer is the HTTP server instance created by httpListenFunc; it stays
// nil until then. server_main closes it on shutdown when it is non-nil.
var httpServer *http.Server
// httpListenFunc serves HTTP on l until the server is closed or fails.
//
// It registers two handlers on the default serve mux: "/peer" is answered by
// httpHandler and every other path by a file server rooted at root. The
// server is stored in the package-level httpServer so it can be closed from
// elsewhere. An error from Serve is logged, except for the
// http.ErrServerClosed that Serve returns after a deliberate Close.
func httpListenFunc(l *net.TCPListener, root string) {
	http.HandleFunc("/peer", httpHandler)
	http.Handle("/", http.FileServer(http.Dir(root)))

	httpServer = &http.Server{
		Addr:              l.Addr().String(),
		ReadTimeout:       60 * time.Second,
		ReadHeaderTimeout: 60 * time.Second,
		WriteTimeout:      300 * time.Second,
		IdleTimeout:       300 * time.Second,
	}

	err := httpServer.Serve(l)
	if err != nil && err != http.ErrServerClosed {
		printf("error serve http: %s\n", err)
	}
}
// httpHandler answers "/peer" requests with the endpoint of one peer.
//
// Form parameters:
//   - public_key (or key as a fallback): the peer's base64 public key.
//   - type: "json" selects a JSON object with addr, ip and last_handshake;
//     anything else yields the plain text "host:port".
//
// Status codes: 400 when the form cannot be parsed, 404 when the key is
// missing, malformed or unknown, 500 when the WireGuard device cannot be
// queried or the JSON body cannot be encoded, 200 otherwise.
func httpHandler(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		http.Error(w, "400 Bad Request", http.StatusBadRequest)
		return
	}

	key := r.Form.Get("public_key")
	if key == "" {
		key = r.Form.Get("key")
	}
	if key == "" {
		http.Error(w, "404 Not Found", http.StatusNotFound)
		return
	}

	// Spaces are turned back into '+': an unescaped '+' of the base64 key in
	// the query string is decoded as a space by form parsing.
	wgkey, err := wgtypes.ParseKey(strings.ReplaceAll(key, " ", "+"))
	if err != nil {
		http.Error(w, "404 Not Found", http.StatusNotFound)
		return
	}

	item, err := findPeerCache(wgkey)
	if err == errExpired {
		if err = updateCache(); err != nil {
			printf("error communicate wireguard device: %v\n", err)
			http.Error(w, "500 Internal Server Error", http.StatusInternalServerError)
			return
		}
		item, err = findPeerCache(wgkey)
	}
	if err != nil && err != errExpired {
		// The cache is in an error state: report it as a server-side failure
		// instead of pretending the peer does not exist.
		printf("error communicate wireguard device: %v\n", err)
		http.Error(w, "500 Internal Server Error", http.StatusInternalServerError)
		return
	}
	if item == nil {
		http.Error(w, "404 Not Found", http.StatusNotFound)
		return
	}

	// JoinHostPort brackets IPv6 addresses so the port stays unambiguous.
	endpoint := net.JoinHostPort(item.IP.String(), strconv.Itoa(item.Port))

	contentType := "text/plain; charset=utf-8"
	body := []byte(endpoint)
	if r.Form.Get("type") == "json" {
		encoded, err := json.Marshal(struct {
			Addr          string `json:"addr"`
			IP            string `json:"ip"`
			LastHandshake string `json:"last_handshake"`
		}{
			Addr:          endpoint,
			IP:            item.IP.String(),
			LastHandshake: item.LastHandshakeTime.Format("2006-01-02 15:04:05"),
		})
		if err != nil {
			http.Error(w, "500 Internal Server Error", http.StatusInternalServerError)
			return
		}
		contentType = "text/json; charset=utf-8"
		body = encoded
	}

	// http.TimeFormat hard-codes "GMT", so the times must be converted to UTC
	// before formatting.
	header := w.Header()
	header.Set("Expires", time.Now().Add(-1*time.Second).UTC().Format(http.TimeFormat))
	header.Set("Last-Modified", item.LastHandshakeTime.UTC().Format(http.TimeFormat))
	header.Set("Cache-Control", "no-cache")
	header.Set("X-Content-Type-Options", "nosniff")
	header.Set("Content-Type", contentType)
	w.WriteHeader(http.StatusOK)
	// A failed write means the client went away; there is nothing left to do.
	_, _ = w.Write(body)
}
// server_config holds the settings that server_main derives from the command
// line and that updateCache reads.
var server_config = struct {
// ifname is the name of the WireGuard device whose peers are cached.
ifname string
// expire is how long a refreshed peer cache stays valid.
expire time.Duration
}{}
// server_main parses the server's command-line arguments, starts the UDP
// service (and the HTTP service when requested) and blocks until SIGHUP,
// SIGQUIT or SIGTERM arrives; it then closes the listeners and exits the
// process with status 0. Invalid arguments or listen failures exit with
// status 1.
//
// Recognised arguments:
//
//	-l, --listen     UDP listen address (default ":55550")
//	-e, --expire     peer cache lifetime in seconds, raised to 1 if lower (default "5")
//	-i, --interface  WireGuard device name (required; usage is printed without it)
//	-w, --web        HTTP listen address; the HTTP service is off when empty
//	-r, --root       directory served over HTTP (default "./wwwroot")
//	-v               verbose flag (accepted, currently not read anywhere)
func server_main(args []string) {
	var (
		listenStr     = ":55550"
		ifnameStr     string
		expireStr     = "5"
		listenHTTPStr string
		httpRoot      = "./wwwroot"
		verbose       bool
	)
	_ = verbose // accepted on the command line but not consumed yet

	// pending points at the option variable that the next argument fills in.
	var pending *string
	for _, arg := range args {
		if pending != nil {
			*pending = arg
			pending = nil
			continue
		}
		switch arg {
		case "-l", "--listen":
			pending = &listenStr
		case "-e", "--expire":
			pending = &expireStr
		case "-i", "--interface":
			pending = &ifnameStr
		case "-v":
			verbose = true
		case "-w", "--web":
			pending = &listenHTTPStr
		case "-r", "--root":
			pending = &httpRoot
		default:
			printf("unknown parameter '%s'\n", arg)
			os.Exit(1)
		}
	}
	if pending != nil {
		printf("parameter value not set\n")
		os.Exit(1)
	}
	if ifnameStr == "" {
		usage()
		return
	}

	listenAddr, err := net.ResolveUDPAddr("udp", listenStr)
	if err != nil {
		printf("error resolve address '%s': %v\n", listenStr, err)
		os.Exit(1)
	}
	var httpListenAddr *net.TCPAddr
	if listenHTTPStr != "" {
		httpListenAddr, err = net.ResolveTCPAddr("tcp", listenHTTPStr)
		if err != nil {
			printf("error resolve address '%s': %v\n", listenHTTPStr, err)
			os.Exit(1)
		}
	}

	expireSec, err := strconv.Atoi(expireStr)
	if err != nil {
		printf("error parse expire '%s': %v\n", expireStr, err)
		os.Exit(1)
	}
	if expireSec < 1 {
		expireSec = 1
	}
	server_config.expire = time.Second * time.Duration(expireSec)
	server_config.ifname = ifnameStr

	listener, err := net.ListenUDP("udp", listenAddr)
	if err != nil {
		printf("error create socket '%s': %v\n", listenAddr.String(), err)
		os.Exit(1)
	}
	printf("server started on '%s'\n", listenAddr.String())

	var httpListener *net.TCPListener
	if httpListenAddr != nil {
		httpListener, err = net.ListenTCP("tcp", httpListenAddr)
		if err != nil {
			printf("error create socket '%s': %v\n", httpListenAddr.String(), err)
			os.Exit(1)
		}
		printf("http server started on '%s'\n", httpListenAddr.String())
	}

	go listenFunc(listener)
	if httpListener != nil {
		go httpListenFunc(httpListener, httpRoot)
	}

	// signal.Notify never blocks when delivering, so the channel needs room
	// for one signal or a signal arriving before the receive would be lost.
	signalC := make(chan os.Signal, 1)
	signal.Notify(signalC, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM)
	s := <-signalC
	printf("signal %s received\n", s.String())

	// Shutdown is best effort: the process exits right after, so close errors
	// are deliberately ignored.
	_ = listener.Close()
	if httpServer != nil {
		_ = httpServer.Close()
	}
	if httpListener != nil {
		_ = httpListener.Close()
	}
	printf("server exited\n")
	os.Exit(0)
}