This commit is contained in:
Shiming Zhang
2022-02-26 20:22:53 +08:00
parent b09240f34d
commit 933f75497c
18 changed files with 522 additions and 49 deletions

View File

@@ -35,7 +35,9 @@ func init() {
func main() {
logger := log.New(os.Stderr, "[tun] ", log.LstdFlags)
controls := []control.ControlFunc{}
if mark != 0 {
controls = append(controls, control.ControlSocketMark(mark))
}

View File

@@ -1,6 +1,8 @@
package device
import "fmt"
import (
"fmt"
)
var registry = map[string]func(name string, mtu uint32) (Device, error){}

13
netstack/handler.go Normal file
View File

@@ -0,0 +1,13 @@
package netstack
import (
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// Handler is a TCP/UDP connection handler that implements
// HandleTCPConn and HandleUDPConn methods.
type Handler interface {
HandleTCPConn(stack.TransportEndpointID, *gonet.TCPConn)
HandleUDPConn(stack.TransportEndpointID, *gonet.UDPConn)
}

24
netstack/icmp.go Normal file
View File

@@ -0,0 +1,24 @@
package netstack
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
func withICMPHandler() Option {
return func(s *Stack) error {
// Add default route table for IPv4 and IPv6.
// This will handle all incoming ICMP packets.
s.SetRouteTable([]tcpip.Route{
{
Destination: header.IPv4EmptySubnet,
NIC: s.nicID,
},
{
Destination: header.IPv6EmptySubnet,
NIC: s.nicID,
},
})
return nil
}
}

59
netstack/nic.go Normal file
View File

@@ -0,0 +1,59 @@
package netstack
import (
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
// defaultNICID is the ID of default NIC used by DefaultStack.
defaultNICID tcpip.NICID = 0x01
// nicPromiscuousModeEnabled is the value used by stack to enable
// or disable NIC's promiscuous mode.
nicPromiscuousModeEnabled = true
// nicSpoofingEnabled is the value used by stack to enable or disable
// NIC's spoofing.
nicSpoofingEnabled = true
)
// withCreatingNIC creates NIC for stack.
func withCreatingNIC(ep stack.LinkEndpoint) Option {
return func(s *Stack) error {
if err := s.CreateNICWithOptions(s.nicID, ep,
stack.NICOptions{
Disabled: false,
// If no queueing discipline was specified
// provide a stub implementation that just
// delegates to the lower link endpoint.
QDisc: nil,
}); err != nil {
return fmt.Errorf("create NIC: %s", err)
}
return nil
}
}
// withPromiscuousMode sets promiscuous mode in the given NIC.
func withPromiscuousMode(v bool) Option {
return func(s *Stack) error {
if err := s.SetPromiscuousMode(s.nicID, v); err != nil {
return fmt.Errorf("set promiscuous mode: %s", err)
}
return nil
}
}
// withSpoofing sets address spoofing in the given NIC, allowing
// endpoints to bind to any address in the NIC.
func withSpoofing(v bool) Option {
return func(s *Stack) error {
if err := s.SetSpoofing(s.nicID, v); err != nil {
return fmt.Errorf("set spoofing: %s", err)
}
return nil
}
}

223
netstack/opts.go Normal file
View File

@@ -0,0 +1,223 @@
package netstack
import (
"fmt"
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
)
const (
// defaultTimeToLive specifies the default TTL used by stack.
defaultTimeToLive uint8 = 64
// ipForwardingEnabled is the value used by stack to enable packet
// forwarding between NICs.
ipForwardingEnabled = true
// icmpBurst is the default number of ICMP messages that can be sent in
// a single burst.
icmpBurst = 50
// icmpLimit is the default maximum number of ICMP messages permitted
// by this rate limiter.
icmpLimit rate.Limit = 1000
// tcpCongestionControl is the congestion control algorithm used by
// stack. ccReno is the default option in gVisor stack.
tcpCongestionControlAlgorithm = "reno" // "reno" or "cubic"
// tcpDelayEnabled is the value used by stack to enable or disable
// tcp delay option. Disable Nagle's algorithm here by default.
tcpDelayEnabled = false
// tcpModerateReceiveBufferEnabled is the value used by stack to
// enable or disable tcp receive buffer auto-tuning option.
tcpModerateReceiveBufferEnabled = true
// tcpSACKEnabled is the value used by stack to enable or disable
// tcp selective ACK.
tcpSACKEnabled = true
// tcpRecovery is the loss detection algorithm used by TCP.
tcpRecovery = tcpip.TCPRACKLossDetection
// tcpMinBufferSize is the smallest size of a send/recv buffer.
tcpMinBufferSize = tcp.MinBufferSize // 4 KiB
// tcpMaxBufferSize is the maximum permitted size of a send/recv buffer.
tcpMaxBufferSize = tcp.MaxBufferSize // 4 MiB
// tcpDefaultBufferSize is the default size of the send/recv buffer for
// a transport endpoint.
tcpDefaultBufferSize = 212 << 10 // 212 KiB
)
type Option func(*Stack) error
// WithDefault sets all default values for stack.
func WithDefault() Option {
return func(s *Stack) error {
opts := []Option{
WithDefaultTTL(defaultTimeToLive),
WithForwarding(ipForwardingEnabled),
// Config default stack ICMP settings.
WithICMPBurst(icmpBurst), WithICMPLimit(icmpLimit),
// We expect no packet loss, therefore we can bump buffers.
// Too large buffers thrash cache, so there is little point
// in too large buffers.
//
// Ref: https://github.com/cloudflare/slirpnetstack/blob/master/stack.go
WithTCPBufferSizeRange(tcpMinBufferSize, tcpDefaultBufferSize, tcpMaxBufferSize),
WithTCPCongestionControl(tcpCongestionControlAlgorithm),
WithTCPDelay(tcpDelayEnabled),
// Receive Buffer Auto-Tuning Option, see:
// https://github.com/google/gvisor/issues/1666
WithTCPModerateReceiveBuffer(tcpModerateReceiveBufferEnabled),
// TCP selective ACK Option, see:
// https://tools.ietf.org/html/rfc2018
WithTCPSACKEnabled(tcpSACKEnabled),
// TCPRACKLossDetection: indicates RACK is used for loss detection and
// recovery.
//
// TCPRACKStaticReoWnd: indicates the reordering window should not be
// adjusted when DSACK is received.
//
// TCPRACKNoDupTh: indicates RACK should not consider the classic three
// duplicate acknowledgements rule to mark the segments as lost. This
// is used when reordering is not detected.
WithTCPRecovery(tcpRecovery),
}
for _, opt := range opts {
if err := opt(s); err != nil {
return err
}
}
return nil
}
}
// WithDefaultTTL sets the default TTL used by stack.
func WithDefaultTTL(ttl uint8) Option {
return func(s *Stack) error {
opt := tcpip.DefaultTTLOption(ttl)
if err := s.SetNetworkProtocolOption(ipv4.ProtocolNumber, &opt); err != nil {
return fmt.Errorf("set ipv4 default TTL: %s", err)
}
if err := s.SetNetworkProtocolOption(ipv6.ProtocolNumber, &opt); err != nil {
return fmt.Errorf("set ipv6 default TTL: %s", err)
}
return nil
}
}
// WithForwarding sets packet forwarding between NICs for IPv4 & IPv6.
func WithForwarding(v bool) Option {
return func(s *Stack) error {
if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, v); err != nil {
return fmt.Errorf("set ipv4 forwarding: %s", err)
}
if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, v); err != nil {
return fmt.Errorf("set ipv6 forwarding: %s", err)
}
return nil
}
}
// WithICMPBurst sets the number of ICMP messages that can be sent
// in a single burst.
func WithICMPBurst(burst int) Option {
return func(s *Stack) error {
s.SetICMPBurst(burst)
return nil
}
}
// WithICMPLimit sets the maximum number of ICMP messages permitted
// by rate limiter.
func WithICMPLimit(limit rate.Limit) Option {
return func(s *Stack) error {
s.SetICMPLimit(limit)
return nil
}
}
// WithTCPBufferSizeRange sets the receive and send buffer size range for TCP.
func WithTCPBufferSizeRange(a, b, c int) Option {
return func(s *Stack) error {
rcvOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: a, Default: b, Max: c}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvOpt); err != nil {
return fmt.Errorf("set TCP receive buffer size range: %s", err)
}
sndOpt := tcpip.TCPSendBufferSizeRangeOption{Min: a, Default: b, Max: c}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sndOpt); err != nil {
return fmt.Errorf("set TCP send buffer size range: %s", err)
}
return nil
}
}
// WithTCPCongestionControl sets the current congestion control algorithm.
func WithTCPCongestionControl(cc string) Option {
return func(s *Stack) error {
opt := tcpip.CongestionControlOption(cc)
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
return fmt.Errorf("set TCP congestion control algorithm: %s", err)
}
return nil
}
}
// WithTCPDelay enables or disables Nagle's algorithm in TCP.
func WithTCPDelay(v bool) Option {
return func(s *Stack) error {
opt := tcpip.TCPDelayEnabled(v)
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
return fmt.Errorf("set TCP delay: %s", err)
}
return nil
}
}
// WithTCPModerateReceiveBuffer sets receive buffer moderation for TCP.
func WithTCPModerateReceiveBuffer(v bool) Option {
return func(s *Stack) error {
opt := tcpip.TCPModerateReceiveBufferOption(v)
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
return fmt.Errorf("set TCP moderate receive buffer: %s", err)
}
return nil
}
}
// WithTCPSACKEnabled sets the SACK option for TCP.
func WithTCPSACKEnabled(v bool) Option {
return func(s *Stack) error {
opt := tcpip.TCPSACKEnabled(v)
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
return fmt.Errorf("set TCP SACK: %s", err)
}
return nil
}
}
// WithTCPRecovery sets the recovery option for TCP.
func WithTCPRecovery(v tcpip.TCPRecovery) Option {
return func(s *Stack) error {
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
return fmt.Errorf("set TCP Recovery: %s", err)
}
return nil
}
}

80
netstack/stack.go Normal file
View File

@@ -0,0 +1,80 @@
// Package stack provides a thin wrapper around a gVisor's stack.
package netstack
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
type Stack struct {
*stack.Stack
handler Handler
nicID tcpip.NICID
}
// NewStack allocates a new *Stack with given options.
func NewStack(ep stack.LinkEndpoint, handler Handler, opts ...Option) (*Stack, error) {
s := &Stack{
Stack: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
}),
handler: handler,
nicID: defaultNICID,
}
opts = append(opts,
// Important: We must initiate transport protocol handlers
// before creating NIC, otherwise NIC would dispatch packets
// to stack and cause race condition.
withICMPHandler(), withTCPHandler(), withUDPHandler(),
// Create stack NIC and then bind link endpoint.
withCreatingNIC(ep),
// In the past we did s.AddAddressRange to assign 0.0.0.0/0
// onto the interface. We need that to be able to terminate
// all the incoming connections - to any ip. AddressRange API
// has been removed and the suggested workaround is to use
// Promiscuous mode. https://github.com/google/gvisor/issues/3876
//
// Ref: https://github.com/cloudflare/slirpnetstack/blob/master/stack.go
withPromiscuousMode(nicPromiscuousModeEnabled),
// Enable spoofing if a stack may send packets from unowned addresses.
// This change required changes to some netgophers since previously,
// promiscuous mode was enough to let the netstack respond to all
// incoming packets regardless of the packet's destination address. Now
// that a stack.Route is not held for each incoming packet, finding a route
// may fail with local addresses we don't own but accepted packets for
// while in promiscuous mode. Since we also want to be able to send from
// any address (in response the received promiscuous mode packets), we need
// to enable spoofing.
//
// Ref: https://github.com/google/gvisor/commit/8c0701462a84ff77e602f1626aec49479c308127
withSpoofing(nicSpoofingEnabled),
)
for _, opt := range opts {
if err := opt(s); err != nil {
return nil, err
}
}
return s, nil
}

77
netstack/tcp.go Normal file
View File

@@ -0,0 +1,77 @@
package netstack
import (
"fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
// defaultWndSize if set to zero, the default
// receive window buffer size is used instead.
defaultWndSize = 0
// maxConnAttempts specifies the maximum number
// of in-flight tcp connection attempts.
maxConnAttempts = 2 << 10
// tcpKeepaliveCount is the maximum number of
// TCP keep-alive probes to send before giving up
// and killing the connection if no response is
// obtained from the other end.
tcpKeepaliveCount = 9
// tcpKeepaliveIdle specifies the time a connection
// must remain idle before the first TCP keepalive
// packet is sent. Once this time is reached,
// tcpKeepaliveInterval option is used instead.
tcpKeepaliveIdle = 60 * time.Second
// tcpKeepaliveInterval specifies the interval
// time between sending TCP keepalive packets.
tcpKeepaliveInterval = 30 * time.Second
)
func withTCPHandler() Option {
return func(s *Stack) error {
tcpForwarder := tcp.NewForwarder(s.Stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
ep, err := r.CreateEndpoint(&wq)
if err != nil {
// RST: prevent potential half-open TCP connection leak.
r.Complete(true)
return
}
defer r.Complete(false)
setKeepalive(ep)
s.handler.HandleTCPConn(r.ID(), gonet.NewTCPConn(&wq, ep))
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
return nil
}
}
func setKeepalive(ep tcpip.Endpoint) error {
ep.SocketOptions().SetKeepAlive(true)
idle := tcpip.KeepaliveIdleOption(tcpKeepaliveIdle)
if err := ep.SetSockOpt(&idle); err != nil {
return fmt.Errorf("set keepalive idle: %s", err)
}
interval := tcpip.KeepaliveIntervalOption(tcpKeepaliveInterval)
if err := ep.SetSockOpt(&interval); err != nil {
return fmt.Errorf("set keepalive interval: %s", err)
}
if err := ep.SetSockOptInt(tcpip.KeepaliveCountOption, tcpKeepaliveCount); err != nil {
return fmt.Errorf("set keepalive count: %s", err)
}
return nil
}

23
netstack/udp.go Normal file
View File

@@ -0,0 +1,23 @@
package netstack
import (
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
func withUDPHandler() Option {
return func(s *Stack) error {
udpForwarder := udp.NewForwarder(s.Stack, func(r *udp.ForwarderRequest) {
var wq waiter.Queue
ep, err := r.CreateEndpoint(&wq)
if err != nil {
return
}
s.handler.HandleUDPConn(r.ID(), gonet.NewUDPConn(s.Stack, &wq, ep))
})
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
return nil
}
}

View File

@@ -9,7 +9,7 @@ import (
"strings"
)
func cmd(name string, arg ...string) error {
func command(name string, arg ...string) error {
c := exec.Command(name, arg...)
out, err := c.CombinedOutput()
if err != nil {

View File

@@ -7,12 +7,12 @@ import (
"syscall"
)
func cmd(name string, arg ...string) (string, error) {
func command(name string, arg ...string) error {
c := exec.Command(name, arg...)
c.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
out, err := c.CombinedOutput()
if err != nil {
return "", fmt.Errorf("%q: %w: %q", strings.Join(append([]string{name}, arg...), " "), err, out)
return fmt.Errorf("%q: %w: %q", strings.Join(append([]string{name}, arg...), " "), err, out)
}
return strings.TrimSpace(string(out)), nil
return nil
}

View File

@@ -11,7 +11,3 @@ func toIpAndMask(cidr string) (string, string, error) {
}
return ipNet.IP.String(), net.IP(ipNet.Mask).String(), nil
}
func GetDevice() string {
return "tun"
}

View File

@@ -7,12 +7,6 @@ import (
"strings"
)
// Check everything needed for tun setup
func Check() error {
// TODO: check whether ifconfig and route command exists
return nil
}
// SetRoute set specified ip range route to tun device
func SetRoute(name string, ipRange []string) error {
var err, lastErr error
@@ -20,7 +14,7 @@ func SetRoute(name string, ipRange []string) error {
tunIp := strings.Split(r, "/")[0]
if i == 0 {
// run command: ifconfig utun6 inet 172.20.0.0/16 172.20.0.0
err = cmd("ifconfig",
err = command("ifconfig",
name,
"inet",
r,
@@ -28,7 +22,7 @@ func SetRoute(name string, ipRange []string) error {
)
} else {
// run command: ifconfig utun6 add 172.20.0.0/16 172.20.0.1
err = cmd("ifconfig",
err = command("ifconfig",
name,
"add",
r,
@@ -40,7 +34,7 @@ func SetRoute(name string, ipRange []string) error {
continue
}
// run command: route add -net 172.20.0.0/16 -interface utun6
err = cmd("route",
err = command("route",
"add",
"-net",
r,

View File

@@ -1,15 +1,9 @@
package route
// Check everything needed for tun setup
func Check() error {
// TODO: check whether ip command exists
return nil
}
// SetRoute let specified ip range route to tun device
func SetRoute(name string, ipRange []string) error {
// run command: ip link set dev kt0 up
err := cmd("ip",
err := command("ip",
"link",
"set",
"dev",
@@ -23,7 +17,7 @@ func SetRoute(name string, ipRange []string) error {
var lastErr error
for _, r := range ipRange {
// run command: ip route add 10.96.0.0/16 dev kt0
err = cmd("ip",
err = command("ip",
"route",
"add",
r,

View File

@@ -1,23 +1,9 @@
package route
import (
"fmt"
"strings"
wintun "golang.zx2c4.com/wintun"
)
// Check everything needed for tun setup
func CheckContext() (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("failed to found tun driver: %v", r)
}
}()
wintun.RunningVersion()
return
}
// SetRoute let specified ip range route to tun device
func SetRoute(name string, ipRange []string) error {
var lastErr error
@@ -29,7 +15,7 @@ func SetRoute(name string, ipRange []string) error {
}
if i == 0 {
// run command: netsh interface ip set address KtConnectTunnel static 172.20.0.1 255.255.0.0
err = cmd("netsh",
err = command("netsh",
"interface",
"ip",
"set",
@@ -41,7 +27,7 @@ func SetRoute(name string, ipRange []string) error {
)
} else {
// run command: netsh interface ip add address KtConnectTunnel 172.21.0.1 255.255.0.0
err = cmd("netsh",
err = command("netsh",
"interface",
"ip",
"add",
@@ -56,7 +42,7 @@ func SetRoute(name string, ipRange []string) error {
continue
}
// run command: route add 172.20.0.0 mask 255.255.0.0 172.20.0.1
err = cmd("route",
err = command("route",
"add",
ip,
"mask",

8
tun.go
View File

@@ -2,8 +2,8 @@ package tun
import (
"github.com/wzshiming/tun/device"
"github.com/wzshiming/tun/netstack"
"github.com/wzshiming/tun/route"
"github.com/wzshiming/tun/stack"
)
type Config struct {
@@ -25,7 +25,7 @@ func NewTun(c Config) *Tun {
type Tun struct {
*Config
stack *stack.Stack
stack *netstack.Stack
device device.Device
}
@@ -34,7 +34,7 @@ func (t *Tun) Start() error {
t.Name = route.GetName()
}
if t.Device == "" {
t.Device = route.GetDevice()
t.Device = "tun"
}
d, err := device.NewDevice(t.Device, t.Name, uint32(t.MTU))
@@ -43,7 +43,7 @@ func (t *Tun) Start() error {
}
t.device = d
s, err := stack.NewStack(t.device, t, stack.WithDefault())
s, err := netstack.NewStack(t.device, t, netstack.WithDefault())
if err != nil {
return err
}

View File

@@ -62,7 +62,7 @@ func (t *Tun) handleUDPConn(id stack.TransportEndpointID, uc *gonet.UDPConn) {
Port: int(id.LocalPort),
}
pc, err := t.ListenPacket.ListenPacket(context.Background(), "udp", remote.String())
pc, err := t.ListenPacket.ListenPacket(context.Background(), "udp", ":0")
if err != nil {
if t.Logger != nil {
t.Logger.Println("UDP listen error:", err)