Feature: add fwmark support

This commit is contained in:
xjasonlyu
2021-02-10 15:46:55 +08:00
parent ea62deb0b6
commit 7f7f2913a8
11 changed files with 136 additions and 40 deletions

View File

@@ -1,8 +1,11 @@
package dialer
import "net"
import (
"net"
"sync"
)
var _boundInterface *net.Interface
var _bindOnce sync.Once
// BindToInterface binds dialer to specific interface.
func BindToInterface(name string) error {
@@ -10,6 +13,9 @@ func BindToInterface(name string) error {
if err != nil {
return err
}
_boundInterface = i
_bindOnce.Do(func() {
addControl(bindToInterface(i))
})
return nil
}

View File

@@ -7,7 +7,8 @@ import (
"golang.org/x/sys/unix"
)
func bindToInterface(network, address string, c syscall.RawConn) error {
func bindToInterface(i *net.Interface) controlFunc {
return func(network, address string, c syscall.RawConn) error {
ipStr, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(ipStr); ip != nil && !ip.IsGlobalUnicast() {
return nil
@@ -16,9 +17,10 @@ func bindToInterface(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
switch network {
case "tcp4", "udp4":
unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, _boundInterface.Index)
unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, i.Index)
case "tcp6", "udp6":
unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, _boundInterface.Index)
unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, i.Index)
}
})
}
}

View File

@@ -7,13 +7,15 @@ import (
"golang.org/x/sys/unix"
)
func bindToInterface(network, address string, c syscall.RawConn) error {
func bindToInterface(i *net.Interface) controlFunc {
return func(network, address string, c syscall.RawConn) error {
ipStr, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(ipStr); ip != nil && !ip.IsGlobalUnicast() {
return nil
}
return c.Control(func(fd uintptr) {
unix.BindToDevice(int(fd), _boundInterface.Name)
unix.BindToDevice(int(fd), i.Name)
})
}
}

View File

@@ -4,9 +4,12 @@ package dialer
import (
"errors"
"net"
"syscall"
)
func bindToInterface(network, address string, c syscall.RawConn) error {
func bindToInterface(_ *net.Interface) controlFunc {
return func(string, string, syscall.RawConn) error {
return errors.New("unsupported platform")
}
}

View File

@@ -0,0 +1,37 @@
package dialer
import (
"errors"
"net"
"syscall"
)
type controlFunc func(string, string, syscall.RawConn) error
var (
_controlPool = make([]controlFunc, 0, 2)
)
func addControl(f controlFunc) {
_controlPool = append(_controlPool, f)
}
func setControl(i interface{}) {
control := func(address, network string, c syscall.RawConn) error {
for _, f := range _controlPool {
if err := f(address, network, c); err != nil {
return err
}
}
return nil
}
switch v := i.(type) {
case *net.Dialer:
v.Control = control
case *net.ListenConfig:
v.Control = control
default:
panic(errors.New("wrong type"))
}
}

View File

@@ -11,20 +11,12 @@ func Dial(network, address string) (net.Conn, error) {
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
d := &net.Dialer{}
if _boundInterface != nil {
d.Control = bindToInterface
}
setControl(d)
return d.DialContext(ctx, network, address)
}
func ListenPacket(network, address string) (net.PacketConn, error) {
lc := &net.ListenConfig{}
if _boundInterface != nil {
lc.Control = bindToInterface
}
setControl(lc)
return lc.ListenPacket(context.Background(), network, address)
}

View File

@@ -0,0 +1,14 @@
package dialer
import (
"sync"
)
var _setOnce sync.Once
// SetMark sets the mark for each packet sent through this dialer(socket).
func SetMark(i int) {
_setOnce.Do(func() {
addControl(setMark(i))
})
}

View File

@@ -0,0 +1,15 @@
package dialer
import (
"syscall"
"golang.org/x/sys/unix"
)
func setMark(i int) controlFunc {
return func(_, _ string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, i)
})
}
}

View File

@@ -0,0 +1,14 @@
// +build !linux
package dialer
import (
"errors"
"syscall"
)
func setMark(_ int) controlFunc {
return func(string, string, syscall.RawConn) error {
return errors.New("fwmark: linux only")
}
}

View File

@@ -31,6 +31,7 @@ func Insert(k *Key) {
type Key struct {
MTU uint32
Mark int
Proxy string
Stats string
Token string
@@ -60,6 +61,7 @@ func (e *engine) start() error {
for _, f := range []func() error{
e.setLogLevel,
e.setMark,
e.setInterface,
e.setStats,
e.setProxy,
@@ -93,12 +95,20 @@ func (e *engine) setLogLevel() error {
return nil
}
func (e *engine) setMark() error {
if e.Mark != 0 {
dialer.SetMark(e.Mark)
log.Infof("[DIALER] set fwmark: %d", e.Mark)
}
return nil
}
func (e *engine) setInterface() error {
if e.Interface != "" {
if err := dialer.BindToInterface(e.Interface); err != nil {
return err
}
log.Infof("[BOUND] bind to interface: %s", e.Interface)
log.Infof("[DIALER] use interface: %s", e.Interface)
}
return nil
}
@@ -108,7 +118,7 @@ func (e *engine) setStats() error {
go func() {
_ = stats.Start(e.Stats, e.Token)
}()
log.Infof("[STATS] stats server listen at: http://%s", e.Stats)
log.Infof("[STATS] serve at: http://%s", e.Stats)
}
return nil
}

View File

@@ -15,12 +15,13 @@ var key = new(engine.Key)
func init() {
flag.StringVarP(&key.Device, "device", "d", "", "use this device [driver://]name")
flag.IntVar(&key.Mark, "fwmark", 0, "set firewall MARK (Linux only)")
flag.StringVarP(&key.Interface, "interface", "i", "", "use network INTERFACE (Linux/MacOS only)")
flag.StringVarP(&key.Proxy, "proxy", "p", "", "use this proxy [protocol://]host[:port]")
flag.StringVarP(&key.LogLevel, "loglevel", "l", "info", "log level [debug|info|warn|error|silent]")
flag.Uint32VarP(&key.MTU, "mtu", "m", 0, "set device maximum transmission unit (MTU)")
flag.StringVarP(&key.Proxy, "proxy", "p", "", "use this proxy [protocol://]host[:port]")
flag.StringVar(&key.Stats, "stats", "", "HTTP statistic server listen address")
flag.StringVar(&key.Token, "token", "", "HTTP statistic server auth token")
flag.Uint32VarP(&key.MTU, "mtu", "m", 0, "set device maximum transmission unit (MTU)")
flag.BoolVarP(&key.Version, "version", "v", false, "show version information and quit")
flag.Parse()
}