diff --git a/dns/dns_windows.go b/dns/dns_windows.go index 4a9da8aa..6e39b42e 100644 --- a/dns/dns_windows.go +++ b/dns/dns_windows.go @@ -4,50 +4,48 @@ package dns import ( + "context" "fmt" log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "net" "os" "os/exec" + "strconv" ) func SetupDNS(ip string, namespace string) error { - tunName := os.Getenv("tunName") - log.Info("tun name: " + tunName) - _ = cleanDnsServer(tunName) - cmd := exec.Command("netsh", []string{ - "interface", - "ipv4", - "add", - "dnsservers", - fmt.Sprintf("name=\"%s\"", tunName), - fmt.Sprintf("address=%s", ip), - "index=1", - }...) - output, err := cmd.CombinedOutput() + getenv := os.Getenv("luid") + parseUint, err := strconv.ParseUint(getenv, 10, 64) if err != nil { - log.Warnf("error while set dns server, error: %v, output: %s, command: %v", err, string(output), cmd.Args) + log.Warningln(err) + return err } - _ = addNicSuffixSearchList(namespace) - _ = updateNicMetric(tunName) + luid := winipcfg.LUID(parseUint) + err = luid.SetDNS(windows.AF_INET, []net.IP{net.ParseIP(ip)}, []string{ + namespace + ".svc.cluster.local", + "svc.cluster.local", + "cluster.local", + }) + _ = exec.CommandContext(context.Background(), "ipconfig", "/flushdns").Run() + if err != nil { + log.Warningln(err) + return err + } + //_ = updateNicMetric(tunName) return nil } func CancelDNS() { -} - -// @see https://docs.microsoft.com/en-us/powershell/module/dnsclient/set-dnsclientglobalsetting?view=windowsserver2019-ps#example-1--set-the-dns-suffix-search-list -func addNicSuffixSearchList(namespace string) error { - cmd := exec.Command("PowerShell", []string{ - "Set-DnsClientGlobalSetting", - "-SuffixSearchList", - fmt.Sprintf("@(\"%s.svc.cluster.local\", \"svc.cluster.local\")", namespace), - }...) - output, err := cmd.CombinedOutput() - log.Info(cmd.Args) + getenv := os.Getenv("luid") + parseUint, err := strconv.ParseUint(getenv, 10, 64) if err != nil { - log.Warnf("error while set dns suffix search list, err: %v, output: %s, command: %v", err, string(output), cmd.Args) + log.Warningln(err) + return } - return err + luid := winipcfg.LUID(parseUint) + _ = luid.FlushDNS(windows.AF_INET) } func updateNicMetric(name string) error { @@ -64,19 +62,3 @@ func updateNicMetric(name string) error { } return err } - -func cleanDnsServer(name string) error { - cmd := exec.Command("netsh", []string{ - "interface", - "ipv4", - "delete", - "dnsservers", - fmt.Sprintf("\"%s\"", name), - "all", - }...) - out, err := cmd.CombinedOutput() - if err != nil { - log.Warnf("clean dnsservers failed, error: %v, output: %s, command: %v", err, string(out), cmd.Args) - } - return err -} diff --git a/pkg/route.go b/pkg/route.go index 6c1b0e71..56240aad 100644 --- a/pkg/route.go +++ b/pkg/route.go @@ -70,10 +70,9 @@ func (r *Route) GenRouters() ([]router, error) { case "tcp": ln, err = core.TCPListener(node.Addr) case "tun": - cfg := tun.TunConfig{ + cfg := tun.Config{ Name: node.Get("name"), Addr: node.Get("net"), - Peer: node.Get("peer"), MTU: node.GetInt("mtu"), Routes: tunRoutes, Gateway: node.Get("gw"), diff --git a/tun/tun.go b/tun/tun.go index 444ee56b..2b3ef168 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -9,11 +9,10 @@ import ( "time" ) -// TunConfig is the config for TUN device. -type TunConfig struct { +// Config is the config for TUN device. +type Config struct { Name string Addr string - Peer string // peer addr of point-to-point on MacOS MTU int Routes []IPRoute Gateway string @@ -23,11 +22,11 @@ type tunListener struct { addr net.Addr conns chan net.Conn closed chan struct{} - config TunConfig + config Config } // TunListener creates a listener for tun tunnel. -func TunListener(cfg TunConfig) (Listener, error) { +func TunListener(cfg Config) (Listener, error) { threads := 1 ln := &tunListener{ conns: make(chan net.Conn, threads), diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index 9dfde459..280ad36b 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -11,7 +11,7 @@ import ( "github.com/songgao/water" ) -func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) { +func createTun(cfg Config) (conn net.Conn, itf *net.Interface, err error) { ip, _, err := net.ParseCIDR(cfg.Addr) if err != nil { return diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 7b25f0e4..393feb41 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -13,7 +13,7 @@ import ( "github.com/songgao/water" ) -func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) { +func createTun(cfg Config) (conn net.Conn, itf *net.Interface, err error) { ip, ipNet, err := net.ParseCIDR(cfg.Addr) if err != nil { return diff --git a/tun/tun_unix.go b/tun/tun_unix.go index 51bfce0d..9763d696 100644 --- a/tun/tun_unix.go +++ b/tun/tun_unix.go @@ -14,7 +14,7 @@ import ( "github.com/songgao/water" ) -func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) { +func createTun(cfg Config) (conn net.Conn, itf *net.Interface, err error) { ip, _, err := net.ParseCIDR(cfg.Addr) if err != nil { return diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 77c51f18..1319d099 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -1,114 +1,60 @@ package tun import ( - "context" "fmt" "github.com/pkg/errors" "golang.org/x/sys/windows" wireguardtun "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "k8s.io/client-go/util/retry" "net" - "os/exec" - "strings" + "os" "time" - - log "github.com/sirupsen/logrus" ) -func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) { +func createTun(cfg Config) (net.Conn, *net.Interface, error) { ip, ipNet, err := net.ParseCIDR(cfg.Addr) if err != nil { - return + return nil, nil, err } - ifce, itf, err := openTun(context.Background()) - if err != nil { - return - } - name, err := ifce.Name() - - cmd := fmt.Sprintf("netsh interface ip set address name=\"%s\" "+ - "source=static addr=%s mask=%s gateway=none", - name, ip.String(), ipMask(ipNet.Mask)) - log.Debug("[tun]", cmd) - - args := strings.Split(cmd, " ") - err = retry.OnError(retry.DefaultRetry, func(err error) bool { - return err != nil - }, func() error { - if er := exec.Command(args[0], args[1:]...).Run(); er != nil { - return fmt.Errorf("%s: %v", cmd, er) - } - return nil - }) - if err != nil { - return - } - - if err = addTunRoutes(name, cfg.Gateway, cfg.Routes...); err != nil { - return - } - - itf, err = net.InterfaceByName(name) - if err != nil { - return - } - - conn = &winTunConn{ - ifce: ifce, - addr: &net.IPAddr{IP: ip}, - } - return -} - -func openTun(ctx context.Context) (td wireguardtun.Device, p *net.Interface, err error) { interfaceName := "wg1" - if td, err = wireguardtun.CreateTUN(interfaceName, 0); err != nil { + if len(cfg.Name) != 0 { + interfaceName = cfg.Name + } + tunDevice, err := wireguardtun.CreateTUN(interfaceName, cfg.MTU) + if err != nil { return nil, nil, fmt.Errorf("failed to create TUN device: %w", err) } - if _, err = td.Name(); err != nil { - return nil, nil, fmt.Errorf("failed to get real name of TUN device: %w", err) + _ = os.Setenv("luid", fmt.Sprintf("%d", tunDevice.(*wireguardtun.NativeTun).LUID())) + + luid := winipcfg.LUID(tunDevice.(*wireguardtun.NativeTun).LUID()) + if err = luid.AddIPAddress(net.IPNet{IP: ip, Mask: ipNet.Mask}); err != nil { + return nil, nil, err } - if i, err := winipcfg.LUID(td.(*wireguardtun.NativeTun).LUID()).Interface(); err != nil { - return nil, nil, fmt.Errorf("failed to get interface for TUN device: %w", err) - } else { - if p, err = net.InterfaceByIndex(int(i.InterfaceIndex)); err != nil { - return nil, nil, fmt.Errorf("failed to get interface for TUN device: %w", err) + + if err = addTunRoutes(luid, cfg.Gateway, cfg.Routes...); err != nil { + return nil, nil, err + } + + row2, _ := luid.Interface() + iface, _ := net.InterfaceByIndex(int(row2.InterfaceIndex)) + return &winTunConn{ifce: tunDevice, addr: &net.IPAddr{IP: ip}}, iface, nil +} + +func addTunRoutes(ifName winipcfg.LUID, gw string, routes ...IPRoute) error { + _ = ifName.FlushRoutes(windows.AF_INET) + for _, route := range routes { + if route.Dest == nil { + continue + } + if gw != "" { + route.Gateway = net.ParseIP(gw) + } else { + route.Gateway = net.IPv4(0, 0, 0, 0) + } + if err := ifName.AddRoute(*route.Dest, route.Gateway, 0); err != nil { + return err } } - return td, p, nil -} - -func (t *winTunConn) Close() error { - return t.ifce.Close() -} - -func (t *winTunConn) getLUID() winipcfg.LUID { - return winipcfg.LUID(t.ifce.(*wireguardtun.NativeTun).LUID()) -} - -func (t *winTunConn) addSubnet(_ context.Context, subnet *net.IPNet) error { - return t.getLUID().AddIPAddress(*subnet) -} - -func (t *winTunConn) removeSubnet(_ context.Context, subnet *net.IPNet) error { - return t.getLUID().DeleteIPAddress(*subnet) -} - -func (t *winTunConn) setDNS(ctx context.Context, server net.IP, domains []string) (err error) { - ipFamily := func(ip net.IP) winipcfg.AddressFamily { - f := winipcfg.AddressFamily(windows.AF_INET6) - if ip4 := ip.To4(); ip4 != nil { - f = windows.AF_INET - } - return f - } - family := ipFamily(server) - luid := t.getLUID() - if err = luid.SetDNS(family, []net.IP{server}, domains); err != nil { - return err - } - _ = exec.CommandContext(ctx, "ipconfig", "/flushdns").Run() return nil } @@ -117,6 +63,16 @@ type winTunConn struct { addr net.Addr } +func (c *winTunConn) Close() error { + err := c.ifce.Close() + if name, err := c.ifce.Name(); err == nil { + if wt, err := wireguardtun.WintunPool.OpenAdapter(name); err == nil { + _, err = wt.Delete(true) + } + } + return err +} + func (c *winTunConn) Read(b []byte) (n int, err error) { return c.ifce.Read(b, 0) } @@ -133,47 +89,14 @@ func (c *winTunConn) RemoteAddr() net.Addr { return &net.IPAddr{} } -func (c *winTunConn) SetDeadline(t time.Time) error { +func (c *winTunConn) SetDeadline(time.Time) error { return &net.OpError{Op: "set", Net: "tun", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -func (c *winTunConn) SetReadDeadline(t time.Time) error { +func (c *winTunConn) SetReadDeadline(time.Time) error { return &net.OpError{Op: "set", Net: "tun", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -func (c *winTunConn) SetWriteDeadline(t time.Time) error { +func (c *winTunConn) SetWriteDeadline(time.Time) error { return &net.OpError{Op: "set", Net: "tun", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } - -func addTunRoutes(ifName string, gw string, routes ...IPRoute) error { - for _, route := range routes { - if route.Dest == nil { - continue - } - - deleteRoute(ifName, route.Dest.String()) - - cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=\"%s\" store=active", - route.Dest.String(), ifName) - if gw != "" { - cmd += " nexthop=" + gw - } - log.Debugf("[tun] %s", cmd) - args := strings.Split(cmd, " ") - if er := exec.Command(args[0], args[1:]...).Run(); er != nil { - return fmt.Errorf("%s: %v", cmd, er) - } - } - return nil -} - -func deleteRoute(ifName string, route string) error { - cmd := fmt.Sprintf("netsh interface ip delete route prefix=%s interface=\"%s\" store=active", - route, ifName) - args := strings.Split(cmd, " ") - return exec.Command(args[0], args[1:]...).Run() -} - -func ipMask(mask net.IPMask) string { - return fmt.Sprintf("%d.%d.%d.%d", mask[0], mask[1], mask[2], mask[3]) -}