diff --git a/component/dialer/bind.go b/component/dialer/bind.go index 366044f..fd3a098 100755 --- a/component/dialer/bind.go +++ b/component/dialer/bind.go @@ -1,8 +1,11 @@ package dialer -import "net" +import ( + "net" + "sync" +) -var _boundInterface *net.Interface +var _bindOnce sync.Once // BindToInterface binds dialer to specific interface. func BindToInterface(name string) error { @@ -10,6 +13,9 @@ func BindToInterface(name string) error { if err != nil { return err } - _boundInterface = i + + _bindOnce.Do(func() { + addControl(bindToInterface(i)) + }) return nil } diff --git a/component/dialer/bind_darwin.go b/component/dialer/bind_darwin.go index 27bb7b0..dfc693d 100755 --- a/component/dialer/bind_darwin.go +++ b/component/dialer/bind_darwin.go @@ -7,18 +7,20 @@ import ( "golang.org/x/sys/unix" ) -func bindToInterface(network, address string, c syscall.RawConn) error { - ipStr, _, _ := net.SplitHostPort(address) - if ip := net.ParseIP(ipStr); ip != nil && !ip.IsGlobalUnicast() { - return nil - } - - return c.Control(func(fd uintptr) { - switch network { - case "tcp4", "udp4": - unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, _boundInterface.Index) - case "tcp6", "udp6": - unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, _boundInterface.Index) +func bindToInterface(i *net.Interface) controlFunc { + return func(network, address string, c syscall.RawConn) error { + ipStr, _, _ := net.SplitHostPort(address) + if ip := net.ParseIP(ipStr); ip != nil && !ip.IsGlobalUnicast() { + return nil } - }) + + return c.Control(func(fd uintptr) { + switch network { + case "tcp4", "udp4": + unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, i.Index) + case "tcp6", "udp6": + unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, i.Index) + } + }) + } } diff --git a/component/dialer/bind_linux.go b/component/dialer/bind_linux.go index 4e4ea61..76decc3 100755 --- a/component/dialer/bind_linux.go +++ b/component/dialer/bind_linux.go @@ -7,13 +7,15 @@ import ( "golang.org/x/sys/unix" ) -func bindToInterface(network, address string, c syscall.RawConn) error { - ipStr, _, _ := net.SplitHostPort(address) - if ip := net.ParseIP(ipStr); ip != nil && !ip.IsGlobalUnicast() { - return nil - } +func bindToInterface(i *net.Interface) controlFunc { + return func(network, address string, c syscall.RawConn) error { + ipStr, _, _ := net.SplitHostPort(address) + if ip := net.ParseIP(ipStr); ip != nil && !ip.IsGlobalUnicast() { + return nil + } - return c.Control(func(fd uintptr) { - unix.BindToDevice(int(fd), _boundInterface.Name) - }) + return c.Control(func(fd uintptr) { + unix.BindToDevice(int(fd), i.Name) + }) + } } diff --git a/component/dialer/bind_others.go b/component/dialer/bind_others.go index d4aa13c..d946ece 100755 --- a/component/dialer/bind_others.go +++ b/component/dialer/bind_others.go @@ -4,9 +4,12 @@ package dialer import ( "errors" + "net" "syscall" ) -func bindToInterface(network, address string, c syscall.RawConn) error { - return errors.New("unsupported platform") +func bindToInterface(_ *net.Interface) controlFunc { + return func(string, string, syscall.RawConn) error { + return errors.New("unsupported platform") + } } diff --git a/component/dialer/control.go b/component/dialer/control.go new file mode 100644 index 0000000..fbdf427 --- /dev/null +++ b/component/dialer/control.go @@ -0,0 +1,37 @@ +package dialer + +import ( + "errors" + "net" + "syscall" +) + +type controlFunc func(string, string, syscall.RawConn) error + +var ( + _controlPool = make([]controlFunc, 0, 2) +) + +func addControl(f controlFunc) { + _controlPool = append(_controlPool, f) +} + +func setControl(i interface{}) { + control := func(address, network string, c syscall.RawConn) error { + for _, f := range _controlPool { + if err := f(address, network, c); err != nil { + return err + } + } + return nil + } + + switch v := i.(type) { + case *net.Dialer: + v.Control = control + case *net.ListenConfig: + v.Control = control + default: + panic(errors.New("wrong type")) + } +} diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 2d375ee..72e5aa6 100755 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -11,20 +11,12 @@ func Dial(network, address string) (net.Conn, error) { func DialContext(ctx context.Context, network, address string) (net.Conn, error) { d := &net.Dialer{} - - if _boundInterface != nil { - d.Control = bindToInterface - } - + setControl(d) return d.DialContext(ctx, network, address) } func ListenPacket(network, address string) (net.PacketConn, error) { lc := &net.ListenConfig{} - - if _boundInterface != nil { - lc.Control = bindToInterface - } - + setControl(lc) return lc.ListenPacket(context.Background(), network, address) } diff --git a/component/dialer/fwmark.go b/component/dialer/fwmark.go new file mode 100644 index 0000000..c7e882b --- /dev/null +++ b/component/dialer/fwmark.go @@ -0,0 +1,14 @@ +package dialer + +import ( + "sync" +) + +var _setOnce sync.Once + +// SetMark sets the mark for each packet sent through this dialer(socket). +func SetMark(i int) { + _setOnce.Do(func() { + addControl(setMark(i)) + }) +} diff --git a/component/dialer/fwmark_linux.go b/component/dialer/fwmark_linux.go new file mode 100644 index 0000000..2a55f06 --- /dev/null +++ b/component/dialer/fwmark_linux.go @@ -0,0 +1,15 @@ +package dialer + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func setMark(i int) controlFunc { + return func(_, _ string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, i) + }) + } +} diff --git a/component/dialer/fwmark_others.go b/component/dialer/fwmark_others.go new file mode 100644 index 0000000..e5cdd14 --- /dev/null +++ b/component/dialer/fwmark_others.go @@ -0,0 +1,14 @@ +// +build !linux + +package dialer + +import ( + "errors" + "syscall" +) + +func setMark(_ int) controlFunc { + return func(string, string, syscall.RawConn) error { + return errors.New("fwmark: linux only") + } +} diff --git a/engine/engine.go b/engine/engine.go index ca25281..28a167d 100755 --- a/engine/engine.go +++ b/engine/engine.go @@ -31,6 +31,7 @@ func Insert(k *Key) { type Key struct { MTU uint32 + Mark int Proxy string Stats string Token string @@ -60,6 +61,7 @@ func (e *engine) start() error { for _, f := range []func() error{ e.setLogLevel, + e.setMark, e.setInterface, e.setStats, e.setProxy, @@ -93,12 +95,20 @@ func (e *engine) setLogLevel() error { return nil } +func (e *engine) setMark() error { + if e.Mark != 0 { + dialer.SetMark(e.Mark) + log.Infof("[DIALER] set fwmark: %d", e.Mark) + } + return nil +} + func (e *engine) setInterface() error { if e.Interface != "" { if err := dialer.BindToInterface(e.Interface); err != nil { return err } - log.Infof("[BOUND] bind to interface: %s", e.Interface) + log.Infof("[DIALER] use interface: %s", e.Interface) } return nil } @@ -108,7 +118,7 @@ func (e *engine) setStats() error { go func() { _ = stats.Start(e.Stats, e.Token) }() - log.Infof("[STATS] stats server listen at: http://%s", e.Stats) + log.Infof("[STATS] serve at: http://%s", e.Stats) } return nil } diff --git a/main.go b/main.go index dcd328d..eb5323a 100755 --- a/main.go +++ b/main.go @@ -15,12 +15,13 @@ var key = new(engine.Key) func init() { flag.StringVarP(&key.Device, "device", "d", "", "use this device [driver://]name") + flag.IntVar(&key.Mark, "fwmark", 0, "set firewall MARK (Linux only)") flag.StringVarP(&key.Interface, "interface", "i", "", "use network INTERFACE (Linux/MacOS only)") - flag.StringVarP(&key.Proxy, "proxy", "p", "", "use this proxy [protocol://]host[:port]") flag.StringVarP(&key.LogLevel, "loglevel", "l", "info", "log level [debug|info|warn|error|silent]") + flag.Uint32VarP(&key.MTU, "mtu", "m", 0, "set device maximum transmission unit (MTU)") + flag.StringVarP(&key.Proxy, "proxy", "p", "", "use this proxy [protocol://]host[:port]") flag.StringVar(&key.Stats, "stats", "", "HTTP statistic server listen address") flag.StringVar(&key.Token, "token", "", "HTTP statistic server auth token") - flag.Uint32VarP(&key.MTU, "mtu", "m", 0, "set device maximum transmission unit (MTU)") flag.BoolVarP(&key.Version, "version", "v", false, "show version information and quit") flag.Parse() }