use wireguard to set dns and iproute not using netsh or powershell

This commit is contained in:
naison
2021-10-23 10:59:41 +08:00
parent a38e77f067
commit 9194d154ac
7 changed files with 84 additions and 181 deletions

View File

@@ -4,50 +4,48 @@
package dns
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"net"
"os"
"os/exec"
"strconv"
)
func SetupDNS(ip string, namespace string) error {
tunName := os.Getenv("tunName")
log.Info("tun name: " + tunName)
_ = cleanDnsServer(tunName)
cmd := exec.Command("netsh", []string{
"interface",
"ipv4",
"add",
"dnsservers",
fmt.Sprintf("name=\"%s\"", tunName),
fmt.Sprintf("address=%s", ip),
"index=1",
}...)
output, err := cmd.CombinedOutput()
getenv := os.Getenv("luid")
parseUint, err := strconv.ParseUint(getenv, 10, 64)
if err != nil {
log.Warnf("error while set dns server, error: %v, output: %s, command: %v", err, string(output), cmd.Args)
log.Warningln(err)
return err
}
_ = addNicSuffixSearchList(namespace)
_ = updateNicMetric(tunName)
luid := winipcfg.LUID(parseUint)
err = luid.SetDNS(windows.AF_INET, []net.IP{net.ParseIP(ip)}, []string{
namespace + ".svc.cluster.local",
"svc.cluster.local",
"cluster.local",
})
_ = exec.CommandContext(context.Background(), "ipconfig", "/flushdns").Run()
if err != nil {
log.Warningln(err)
return err
}
//_ = updateNicMetric(tunName)
return nil
}
func CancelDNS() {
}
// @see https://docs.microsoft.com/en-us/powershell/module/dnsclient/set-dnsclientglobalsetting?view=windowsserver2019-ps#example-1--set-the-dns-suffix-search-list
func addNicSuffixSearchList(namespace string) error {
cmd := exec.Command("PowerShell", []string{
"Set-DnsClientGlobalSetting",
"-SuffixSearchList",
fmt.Sprintf("@(\"%s.svc.cluster.local\", \"svc.cluster.local\")", namespace),
}...)
output, err := cmd.CombinedOutput()
log.Info(cmd.Args)
getenv := os.Getenv("luid")
parseUint, err := strconv.ParseUint(getenv, 10, 64)
if err != nil {
log.Warnf("error while set dns suffix search list, err: %v, output: %s, command: %v", err, string(output), cmd.Args)
log.Warningln(err)
return
}
return err
luid := winipcfg.LUID(parseUint)
_ = luid.FlushDNS(windows.AF_INET)
}
func updateNicMetric(name string) error {
@@ -64,19 +62,3 @@ func updateNicMetric(name string) error {
}
return err
}
func cleanDnsServer(name string) error {
cmd := exec.Command("netsh", []string{
"interface",
"ipv4",
"delete",
"dnsservers",
fmt.Sprintf("\"%s\"", name),
"all",
}...)
out, err := cmd.CombinedOutput()
if err != nil {
log.Warnf("clean dnsservers failed, error: %v, output: %s, command: %v", err, string(out), cmd.Args)
}
return err
}

View File

@@ -70,10 +70,9 @@ func (r *Route) GenRouters() ([]router, error) {
case "tcp":
ln, err = core.TCPListener(node.Addr)
case "tun":
cfg := tun.TunConfig{
cfg := tun.Config{
Name: node.Get("name"),
Addr: node.Get("net"),
Peer: node.Get("peer"),
MTU: node.GetInt("mtu"),
Routes: tunRoutes,
Gateway: node.Get("gw"),

View File

@@ -9,11 +9,10 @@ import (
"time"
)
// TunConfig is the config for TUN device.
type TunConfig struct {
// Config is the config for TUN device.
type Config struct {
Name string
Addr string
Peer string // peer addr of point-to-point on MacOS
MTU int
Routes []IPRoute
Gateway string
@@ -23,11 +22,11 @@ type tunListener struct {
addr net.Addr
conns chan net.Conn
closed chan struct{}
config TunConfig
config Config
}
// TunListener creates a listener for tun tunnel.
func TunListener(cfg TunConfig) (Listener, error) {
func TunListener(cfg Config) (Listener, error) {
threads := 1
ln := &tunListener{
conns: make(chan net.Conn, threads),

View File

@@ -11,7 +11,7 @@ import (
"github.com/songgao/water"
)
func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) {
func createTun(cfg Config) (conn net.Conn, itf *net.Interface, err error) {
ip, _, err := net.ParseCIDR(cfg.Addr)
if err != nil {
return

View File

@@ -13,7 +13,7 @@ import (
"github.com/songgao/water"
)
func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) {
func createTun(cfg Config) (conn net.Conn, itf *net.Interface, err error) {
ip, ipNet, err := net.ParseCIDR(cfg.Addr)
if err != nil {
return

View File

@@ -14,7 +14,7 @@ import (
"github.com/songgao/water"
)
func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) {
func createTun(cfg Config) (conn net.Conn, itf *net.Interface, err error) {
ip, _, err := net.ParseCIDR(cfg.Addr)
if err != nil {
return

View File

@@ -1,114 +1,60 @@
package tun
import (
"context"
"fmt"
"github.com/pkg/errors"
"golang.org/x/sys/windows"
wireguardtun "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"k8s.io/client-go/util/retry"
"net"
"os/exec"
"strings"
"os"
"time"
log "github.com/sirupsen/logrus"
)
func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) {
func createTun(cfg Config) (net.Conn, *net.Interface, error) {
ip, ipNet, err := net.ParseCIDR(cfg.Addr)
if err != nil {
return
return nil, nil, err
}
ifce, itf, err := openTun(context.Background())
if err != nil {
return
}
name, err := ifce.Name()
cmd := fmt.Sprintf("netsh interface ip set address name=\"%s\" "+
"source=static addr=%s mask=%s gateway=none",
name, ip.String(), ipMask(ipNet.Mask))
log.Debug("[tun]", cmd)
args := strings.Split(cmd, " ")
err = retry.OnError(retry.DefaultRetry, func(err error) bool {
return err != nil
}, func() error {
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
return fmt.Errorf("%s: %v", cmd, er)
}
return nil
})
if err != nil {
return
}
if err = addTunRoutes(name, cfg.Gateway, cfg.Routes...); err != nil {
return
}
itf, err = net.InterfaceByName(name)
if err != nil {
return
}
conn = &winTunConn{
ifce: ifce,
addr: &net.IPAddr{IP: ip},
}
return
}
func openTun(ctx context.Context) (td wireguardtun.Device, p *net.Interface, err error) {
interfaceName := "wg1"
if td, err = wireguardtun.CreateTUN(interfaceName, 0); err != nil {
if len(cfg.Name) != 0 {
interfaceName = cfg.Name
}
tunDevice, err := wireguardtun.CreateTUN(interfaceName, cfg.MTU)
if err != nil {
return nil, nil, fmt.Errorf("failed to create TUN device: %w", err)
}
if _, err = td.Name(); err != nil {
return nil, nil, fmt.Errorf("failed to get real name of TUN device: %w", err)
_ = os.Setenv("luid", fmt.Sprintf("%d", tunDevice.(*wireguardtun.NativeTun).LUID()))
luid := winipcfg.LUID(tunDevice.(*wireguardtun.NativeTun).LUID())
if err = luid.AddIPAddress(net.IPNet{IP: ip, Mask: ipNet.Mask}); err != nil {
return nil, nil, err
}
if i, err := winipcfg.LUID(td.(*wireguardtun.NativeTun).LUID()).Interface(); err != nil {
return nil, nil, fmt.Errorf("failed to get interface for TUN device: %w", err)
} else {
if p, err = net.InterfaceByIndex(int(i.InterfaceIndex)); err != nil {
return nil, nil, fmt.Errorf("failed to get interface for TUN device: %w", err)
if err = addTunRoutes(luid, cfg.Gateway, cfg.Routes...); err != nil {
return nil, nil, err
}
row2, _ := luid.Interface()
iface, _ := net.InterfaceByIndex(int(row2.InterfaceIndex))
return &winTunConn{ifce: tunDevice, addr: &net.IPAddr{IP: ip}}, iface, nil
}
func addTunRoutes(ifName winipcfg.LUID, gw string, routes ...IPRoute) error {
_ = ifName.FlushRoutes(windows.AF_INET)
for _, route := range routes {
if route.Dest == nil {
continue
}
if gw != "" {
route.Gateway = net.ParseIP(gw)
} else {
route.Gateway = net.IPv4(0, 0, 0, 0)
}
if err := ifName.AddRoute(*route.Dest, route.Gateway, 0); err != nil {
return err
}
}
return td, p, nil
}
func (t *winTunConn) Close() error {
return t.ifce.Close()
}
func (t *winTunConn) getLUID() winipcfg.LUID {
return winipcfg.LUID(t.ifce.(*wireguardtun.NativeTun).LUID())
}
func (t *winTunConn) addSubnet(_ context.Context, subnet *net.IPNet) error {
return t.getLUID().AddIPAddress(*subnet)
}
func (t *winTunConn) removeSubnet(_ context.Context, subnet *net.IPNet) error {
return t.getLUID().DeleteIPAddress(*subnet)
}
func (t *winTunConn) setDNS(ctx context.Context, server net.IP, domains []string) (err error) {
ipFamily := func(ip net.IP) winipcfg.AddressFamily {
f := winipcfg.AddressFamily(windows.AF_INET6)
if ip4 := ip.To4(); ip4 != nil {
f = windows.AF_INET
}
return f
}
family := ipFamily(server)
luid := t.getLUID()
if err = luid.SetDNS(family, []net.IP{server}, domains); err != nil {
return err
}
_ = exec.CommandContext(ctx, "ipconfig", "/flushdns").Run()
return nil
}
@@ -117,6 +63,16 @@ type winTunConn struct {
addr net.Addr
}
func (c *winTunConn) Close() error {
err := c.ifce.Close()
if name, err := c.ifce.Name(); err == nil {
if wt, err := wireguardtun.WintunPool.OpenAdapter(name); err == nil {
_, err = wt.Delete(true)
}
}
return err
}
func (c *winTunConn) Read(b []byte) (n int, err error) {
return c.ifce.Read(b, 0)
}
@@ -133,47 +89,14 @@ func (c *winTunConn) RemoteAddr() net.Addr {
return &net.IPAddr{}
}
func (c *winTunConn) SetDeadline(t time.Time) error {
func (c *winTunConn) SetDeadline(time.Time) error {
return &net.OpError{Op: "set", Net: "tun", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *winTunConn) SetReadDeadline(t time.Time) error {
func (c *winTunConn) SetReadDeadline(time.Time) error {
return &net.OpError{Op: "set", Net: "tun", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *winTunConn) SetWriteDeadline(t time.Time) error {
func (c *winTunConn) SetWriteDeadline(time.Time) error {
return &net.OpError{Op: "set", Net: "tun", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func addTunRoutes(ifName string, gw string, routes ...IPRoute) error {
for _, route := range routes {
if route.Dest == nil {
continue
}
deleteRoute(ifName, route.Dest.String())
cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=\"%s\" store=active",
route.Dest.String(), ifName)
if gw != "" {
cmd += " nexthop=" + gw
}
log.Debugf("[tun] %s", cmd)
args := strings.Split(cmd, " ")
if er := exec.Command(args[0], args[1:]...).Run(); er != nil {
return fmt.Errorf("%s: %v", cmd, er)
}
}
return nil
}
func deleteRoute(ifName string, route string) error {
cmd := fmt.Sprintf("netsh interface ip delete route prefix=%s interface=\"%s\" store=active",
route, ifName)
args := strings.Split(cmd, " ")
return exec.Command(args[0], args[1:]...).Run()
}
func ipMask(mask net.IPMask) string {
return fmt.Sprintf("%d.%d.%d.%d", mask[0], mask[1], mask[2], mask[3])
}