Files
wgstun/client.go
2022-04-02 11:29:34 +08:00

276 lines
5.4 KiB
Go

package main
import (
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"net"
"os"
"strconv"
"strings"
"time"
)
// client_once queries the server once for the current endpoints of peers and
// pushes any changed endpoints into the WireGuard interface ifname.
//
// When conn is non-nil the query is sent over that already-connected socket
// (StunResolveConn); otherwise a query to raddr is made via StunResolve.
// Keys that are not peers of ifname are skipped. Unless force is set, a peer
// is left untouched when its endpoint is unchanged, or when the interface
// reports a handshake newer than the handshake time returned by the server.
//
// once_print selects the terse per-key report used by one-shot mode; when it
// is false only the resolve summary and "ip changed" lines are printed.
func client_once(raddr *net.UDPAddr, conn *net.UDPConn, ifname string, peers []wgtypes.Key, force bool, once_print bool) {
	wgClient, err := wgctrl.New()
	if err != nil {
		printf("err create wireguard client instance: %v\n", err)
		return
	}
	// Deferred only after the error check: wgClient is nil when New fails.
	defer wgClient.Close()

	device, err := wgClient.Device(ifname)
	if err != nil {
		printf("err open wireguard interface: %v\n", err)
		return
	}

	// Keep only the requested keys that are configured on this interface;
	// unknown keys are skipped rather than aborting the whole update.
	nPeers := make([]wgtypes.Key, 0, len(peers))
	for _, peer := range peers {
		if _, ok := client_find_peer(device.Peers, peer); !ok {
			if once_print {
				printf("%s ignored (not interface peer)\n", peer.String())
			}
			continue
		}
		nPeers = append(nPeers, peer)
	}
	if len(nPeers) == 0 {
		if once_print {
			printf("nothing to query\n")
		}
		return
	}

	resolveStart := time.Now()
	var results []ResolveResult
	if conn != nil {
		results, err = StunResolveConn(conn, nPeers)
	} else {
		results, err = StunResolve(raddr, nPeers)
	}
	if err != nil {
		printf("err resolve: %v\n", err)
		return
	}
	if !once_print {
		printf("resolved %d peers in %v\n", len(nPeers), time.Since(resolveStart))
	}

	// Re-read the device so the comparisons below use peer state taken after
	// the (possibly slow) resolve round-trip.
	device, err = wgClient.Device(ifname)
	if err != nil {
		printf("err open wireguard interface: %v\n", err)
		return
	}

	cfg := wgtypes.Config{}
	for _, result := range results {
		if !result.Valid {
			if once_print {
				printf("%s not_resolved\n", result.Key.String())
			} else {
				printf("peer '%s' not_resolved\n", result.Key.String())
			}
			continue
		}
		peer, found := client_find_peer(device.Peers, result.Key)
		if !found {
			if once_print {
				printf("%s not_found\n", result.Key.String())
			}
			continue
		}
		if !force {
			if peer.Endpoint != nil &&
				result.Endpoint.IP.Equal(peer.Endpoint.IP) &&
				result.Endpoint.Port == peer.Endpoint.Port {
				if once_print {
					printf("%s no_change\n", result.Key.String())
				}
				continue
			}
			// The local interface saw a more recent handshake than the one
			// the server reported, so its current endpoint is kept.
			if peer.LastHandshakeTime.After(result.HandshakeTime) {
				if once_print {
					printf("%s newer\n", result.Key.String())
				}
				continue
			}
		}
		if once_print {
			printf("%s %s\n", result.Key.String(), result.Endpoint.String())
		} else {
			printf("peer '%s' ip changed\n", result.Key.String())
		}
		cfg.Peers = append(cfg.Peers, wgtypes.PeerConfig{
			PublicKey:  result.Key,
			UpdateOnly: true,
			Endpoint:   result.Endpoint,
		})
	}
	if len(cfg.Peers) == 0 {
		if once_print {
			printf("up-to-date\n")
		}
		return
	}
	if err := wgClient.ConfigureDevice(ifname, cfg); err != nil {
		printf("err submit wireguard config: %v\n", err)
		return
	}
	printf("processed config wireguard\n")
}

// client_find_peer returns the peer in peers whose public key matches key
// (compared with keyEquals), and whether such a peer exists.
func client_find_peer(peers []wgtypes.Peer, key wgtypes.Key) (wgtypes.Peer, bool) {
	for _, p := range peers {
		if keyEquals(key, p.PublicKey) {
			return p, true
		}
	}
	return wgtypes.Peer{}, false
}
// client_pull runs client_once every interval seconds, forever, reusing one
// UDP socket connected to raddr for every query. It exits the process with
// status 1 if interval is not positive or the socket cannot be created.
func client_pull(raddr *net.UDPAddr, ifname string, peers []wgtypes.Key, interval int) {
	// A non-positive period would give a nil tick channel (time.Tick) or a
	// panic (time.NewTicker); the loop must never silently block forever.
	if interval <= 0 {
		printf("invalid interval %d: must be positive\n", interval)
		os.Exit(1)
	}
	conn, err := net.DialUDP("udp", nil, raddr)
	if err != nil {
		printf("err DialUDP: %v\n", err)
		os.Exit(1)
	}
	defer conn.Close()
	printf("create UDP connection at '%s'\n", conn.LocalAddr().String())

	ticker := time.NewTicker(time.Duration(interval) * time.Second)
	defer ticker.Stop()
	for range ticker.C {
		client_once(nil, conn, ifname, peers, false, false)
	}
}
// client_main parses the client command-line arguments and then runs either
// a single query (once mode) or the periodic pull loop.
//
// Recognized arguments (value flags consume the following argument):
//
//	-s, --server ADDR     query server, default "10.77.1.1:55550";
//	                      port 55550 is added when ADDR has none
//	-i, --interface NAME  WireGuard interface to update (required)
//	-t, --interval SECS   pull interval in seconds, default "10"
//	-c, --config PATH     only affects the message printed when required
//	                      values are missing
//	-1, --once            query once and return
//	-f, --force           apply endpoints unconditionally (once mode only)
//	-v                    accepted; currently has no effect
//
// Every other argument is taken as a peer public key. At least one key is
// required.
//
// NOTE(review): the -c path is never loaded in this function — confirm
// whether config loading is meant to happen elsewhere.
func client_main(args []string, once bool) {
	server_str := "10.77.1.1:55550"
	interval_str := "10"
	var ifname_str, config string
	var peers_str []string
	var force bool

	// pending points at the variable that receives the next argument.
	var pending *string
	for _, arg := range args {
		if pending != nil {
			*pending = arg
			pending = nil
			continue
		}
		switch arg {
		case "-s", "--server":
			pending = &server_str
		case "-i", "--interface":
			pending = &ifname_str
		case "-t", "--interval":
			pending = &interval_str
		case "-c", "--config":
			pending = &config
		case "-1", "--once":
			once = true
		case "-f", "--force":
			force = true
		case "-v":
			// Verbose flag is accepted for compatibility but not implemented.
		default:
			peers_str = append(peers_str, arg)
		}
	}
	if pending != nil {
		printf("parameter value not set\n")
		os.Exit(1)
	}
	if server_str == "" || ifname_str == "" || len(peers_str) == 0 {
		if config != "" {
			printf("some parameter not set in config\n")
			return
		}
		usage()
		return
	}

	// Append the default port when none was given. SplitHostPort (unlike a
	// plain ":" check) treats a bare IPv6 address as port-less, and
	// JoinHostPort adds the brackets IPv6 needs.
	if _, _, err := net.SplitHostPort(server_str); err != nil {
		server_str = net.JoinHostPort(strings.Trim(server_str, "[]"), "55550")
	}
	raddr, err := net.ResolveUDPAddr("udp", server_str)
	if err != nil {
		printf("error resolve address '%s': %v\n", server_str, err)
		os.Exit(1)
	}

	peers := make([]wgtypes.Key, len(peers_str))
	for i, key_str := range peers_str {
		peers[i], err = wgtypes.ParseKey(key_str)
		if err != nil {
			// Report the offending key, shortened to keep the message readable.
			if len(key_str) > 9 {
				key_str = key_str[:6] + "..."
			}
			printf("error parse key '%s': %v\n", key_str, err)
			os.Exit(1)
		}
	}

	if force && !once {
		printf("--force can only be used in once mode\n")
		os.Exit(1)
	}
	if once {
		client_once(raddr, nil, ifname_str, peers, force, true)
		return
	}

	interval, err := strconv.Atoi(interval_str)
	if err != nil {
		printf("error parse interval '%s': %v\n", interval_str, err)
		os.Exit(1)
	}
	client_pull(raddr, ifname_str, peers, interval)
}