mirror of
https://github.com/xjasonlyu/tun2socks.git
synced 2025-10-05 16:56:54 +08:00
Feature: add fwmark support
This commit is contained in:
@@ -1,8 +1,11 @@
|
|||||||
package dialer
|
package dialer
|
||||||
|
|
||||||
import "net"
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
var _boundInterface *net.Interface
|
var _bindOnce sync.Once
|
||||||
|
|
||||||
// BindToInterface binds dialer to specific interface.
|
// BindToInterface binds dialer to specific interface.
|
||||||
func BindToInterface(name string) error {
|
func BindToInterface(name string) error {
|
||||||
@@ -10,6 +13,9 @@ func BindToInterface(name string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_boundInterface = i
|
|
||||||
|
_bindOnce.Do(func() {
|
||||||
|
addControl(bindToInterface(i))
|
||||||
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -7,7 +7,8 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
func bindToInterface(network, address string, c syscall.RawConn) error {
|
func bindToInterface(i *net.Interface) controlFunc {
|
||||||
|
return func(network, address string, c syscall.RawConn) error {
|
||||||
ipStr, _, _ := net.SplitHostPort(address)
|
ipStr, _, _ := net.SplitHostPort(address)
|
||||||
if ip := net.ParseIP(ipStr); ip != nil && !ip.IsGlobalUnicast() {
|
if ip := net.ParseIP(ipStr); ip != nil && !ip.IsGlobalUnicast() {
|
||||||
return nil
|
return nil
|
||||||
@@ -16,9 +17,10 @@ func bindToInterface(network, address string, c syscall.RawConn) error {
|
|||||||
return c.Control(func(fd uintptr) {
|
return c.Control(func(fd uintptr) {
|
||||||
switch network {
|
switch network {
|
||||||
case "tcp4", "udp4":
|
case "tcp4", "udp4":
|
||||||
unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, _boundInterface.Index)
|
unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, i.Index)
|
||||||
case "tcp6", "udp6":
|
case "tcp6", "udp6":
|
||||||
unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, _boundInterface.Index)
|
unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, i.Index)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@@ -7,13 +7,15 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
func bindToInterface(network, address string, c syscall.RawConn) error {
|
func bindToInterface(i *net.Interface) controlFunc {
|
||||||
|
return func(network, address string, c syscall.RawConn) error {
|
||||||
ipStr, _, _ := net.SplitHostPort(address)
|
ipStr, _, _ := net.SplitHostPort(address)
|
||||||
if ip := net.ParseIP(ipStr); ip != nil && !ip.IsGlobalUnicast() {
|
if ip := net.ParseIP(ipStr); ip != nil && !ip.IsGlobalUnicast() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.Control(func(fd uintptr) {
|
return c.Control(func(fd uintptr) {
|
||||||
unix.BindToDevice(int(fd), _boundInterface.Name)
|
unix.BindToDevice(int(fd), i.Name)
|
||||||
})
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@@ -4,9 +4,12 @@ package dialer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
func bindToInterface(network, address string, c syscall.RawConn) error {
|
func bindToInterface(_ *net.Interface) controlFunc {
|
||||||
|
return func(string, string, syscall.RawConn) error {
|
||||||
return errors.New("unsupported platform")
|
return errors.New("unsupported platform")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
37
component/dialer/control.go
Normal file
37
component/dialer/control.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package dialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
type controlFunc func(string, string, syscall.RawConn) error
|
||||||
|
|
||||||
|
var (
|
||||||
|
_controlPool = make([]controlFunc, 0, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
func addControl(f controlFunc) {
|
||||||
|
_controlPool = append(_controlPool, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setControl(i interface{}) {
|
||||||
|
control := func(address, network string, c syscall.RawConn) error {
|
||||||
|
for _, f := range _controlPool {
|
||||||
|
if err := f(address, network, c); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := i.(type) {
|
||||||
|
case *net.Dialer:
|
||||||
|
v.Control = control
|
||||||
|
case *net.ListenConfig:
|
||||||
|
v.Control = control
|
||||||
|
default:
|
||||||
|
panic(errors.New("wrong type"))
|
||||||
|
}
|
||||||
|
}
|
@@ -11,20 +11,12 @@ func Dial(network, address string) (net.Conn, error) {
|
|||||||
|
|
||||||
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
d := &net.Dialer{}
|
d := &net.Dialer{}
|
||||||
|
setControl(d)
|
||||||
if _boundInterface != nil {
|
|
||||||
d.Control = bindToInterface
|
|
||||||
}
|
|
||||||
|
|
||||||
return d.DialContext(ctx, network, address)
|
return d.DialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenPacket(network, address string) (net.PacketConn, error) {
|
func ListenPacket(network, address string) (net.PacketConn, error) {
|
||||||
lc := &net.ListenConfig{}
|
lc := &net.ListenConfig{}
|
||||||
|
setControl(lc)
|
||||||
if _boundInterface != nil {
|
|
||||||
lc.Control = bindToInterface
|
|
||||||
}
|
|
||||||
|
|
||||||
return lc.ListenPacket(context.Background(), network, address)
|
return lc.ListenPacket(context.Background(), network, address)
|
||||||
}
|
}
|
||||||
|
14
component/dialer/fwmark.go
Normal file
14
component/dialer/fwmark.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package dialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _setOnce sync.Once
|
||||||
|
|
||||||
|
// SetMark sets the mark for each packet sent through this dialer(socket).
|
||||||
|
func SetMark(i int) {
|
||||||
|
_setOnce.Do(func() {
|
||||||
|
addControl(setMark(i))
|
||||||
|
})
|
||||||
|
}
|
15
component/dialer/fwmark_linux.go
Normal file
15
component/dialer/fwmark_linux.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package dialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setMark(i int) controlFunc {
|
||||||
|
return func(_, _ string, c syscall.RawConn) error {
|
||||||
|
return c.Control(func(fd uintptr) {
|
||||||
|
unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, i)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
14
component/dialer/fwmark_others.go
Normal file
14
component/dialer/fwmark_others.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
// +build !linux
|
||||||
|
|
||||||
|
package dialer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setMark(_ int) controlFunc {
|
||||||
|
return func(string, string, syscall.RawConn) error {
|
||||||
|
return errors.New("fwmark: linux only")
|
||||||
|
}
|
||||||
|
}
|
@@ -31,6 +31,7 @@ func Insert(k *Key) {
|
|||||||
|
|
||||||
type Key struct {
|
type Key struct {
|
||||||
MTU uint32
|
MTU uint32
|
||||||
|
Mark int
|
||||||
Proxy string
|
Proxy string
|
||||||
Stats string
|
Stats string
|
||||||
Token string
|
Token string
|
||||||
@@ -60,6 +61,7 @@ func (e *engine) start() error {
|
|||||||
|
|
||||||
for _, f := range []func() error{
|
for _, f := range []func() error{
|
||||||
e.setLogLevel,
|
e.setLogLevel,
|
||||||
|
e.setMark,
|
||||||
e.setInterface,
|
e.setInterface,
|
||||||
e.setStats,
|
e.setStats,
|
||||||
e.setProxy,
|
e.setProxy,
|
||||||
@@ -93,12 +95,20 @@ func (e *engine) setLogLevel() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *engine) setMark() error {
|
||||||
|
if e.Mark != 0 {
|
||||||
|
dialer.SetMark(e.Mark)
|
||||||
|
log.Infof("[DIALER] set fwmark: %d", e.Mark)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (e *engine) setInterface() error {
|
func (e *engine) setInterface() error {
|
||||||
if e.Interface != "" {
|
if e.Interface != "" {
|
||||||
if err := dialer.BindToInterface(e.Interface); err != nil {
|
if err := dialer.BindToInterface(e.Interface); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Infof("[BOUND] bind to interface: %s", e.Interface)
|
log.Infof("[DIALER] use interface: %s", e.Interface)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -108,7 +118,7 @@ func (e *engine) setStats() error {
|
|||||||
go func() {
|
go func() {
|
||||||
_ = stats.Start(e.Stats, e.Token)
|
_ = stats.Start(e.Stats, e.Token)
|
||||||
}()
|
}()
|
||||||
log.Infof("[STATS] stats server listen at: http://%s", e.Stats)
|
log.Infof("[STATS] serve at: http://%s", e.Stats)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
5
main.go
5
main.go
@@ -15,12 +15,13 @@ var key = new(engine.Key)
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
flag.StringVarP(&key.Device, "device", "d", "", "use this device [driver://]name")
|
flag.StringVarP(&key.Device, "device", "d", "", "use this device [driver://]name")
|
||||||
|
flag.IntVar(&key.Mark, "fwmark", 0, "set firewall MARK (Linux only)")
|
||||||
flag.StringVarP(&key.Interface, "interface", "i", "", "use network INTERFACE (Linux/MacOS only)")
|
flag.StringVarP(&key.Interface, "interface", "i", "", "use network INTERFACE (Linux/MacOS only)")
|
||||||
flag.StringVarP(&key.Proxy, "proxy", "p", "", "use this proxy [protocol://]host[:port]")
|
|
||||||
flag.StringVarP(&key.LogLevel, "loglevel", "l", "info", "log level [debug|info|warn|error|silent]")
|
flag.StringVarP(&key.LogLevel, "loglevel", "l", "info", "log level [debug|info|warn|error|silent]")
|
||||||
|
flag.Uint32VarP(&key.MTU, "mtu", "m", 0, "set device maximum transmission unit (MTU)")
|
||||||
|
flag.StringVarP(&key.Proxy, "proxy", "p", "", "use this proxy [protocol://]host[:port]")
|
||||||
flag.StringVar(&key.Stats, "stats", "", "HTTP statistic server listen address")
|
flag.StringVar(&key.Stats, "stats", "", "HTTP statistic server listen address")
|
||||||
flag.StringVar(&key.Token, "token", "", "HTTP statistic server auth token")
|
flag.StringVar(&key.Token, "token", "", "HTTP statistic server auth token")
|
||||||
flag.Uint32VarP(&key.MTU, "mtu", "m", 0, "set device maximum transmission unit (MTU)")
|
|
||||||
flag.BoolVarP(&key.Version, "version", "v", false, "show version information and quit")
|
flag.BoolVarP(&key.Version, "version", "v", false, "show version information and quit")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user