diff --git a/engine/engine.go b/engine/engine.go index a68b0aa..1c9c5a2 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -2,20 +2,15 @@ package engine import ( "errors" - "fmt" "net" - "os" "github.com/xjasonlyu/tun2socks/v2/component/dialer" "github.com/xjasonlyu/tun2socks/v2/core/device" "github.com/xjasonlyu/tun2socks/v2/core/stack" - "github.com/xjasonlyu/tun2socks/v2/internal/version" "github.com/xjasonlyu/tun2socks/v2/log" "github.com/xjasonlyu/tun2socks/v2/proxy" "github.com/xjasonlyu/tun2socks/v2/stats" "github.com/xjasonlyu/tun2socks/v2/tunnel" - - "gopkg.in/yaml.v3" ) var _engine = &engine{} @@ -45,8 +40,6 @@ type Key struct { Device string `yaml:"device"` LogLevel string `yaml:"loglevel"` Interface string `yaml:"interface"` - Config string `yaml:"-"` - Version bool `yaml:"-"` } type engine struct { @@ -62,22 +55,15 @@ func (e *engine) start() error { return errors.New("empty key") } - if e.Version { - fmt.Println(version.String()) - fmt.Println(version.BuildString()) - os.Exit(0) - } - for _, f := range []func() error{ - e.setConfig, - e.setLogLevel, - e.setMark, - e.setInterface, - e.setStats, - e.setUDPTimeout, - e.setProxy, - e.setDevice, - e.setStack, + e.applyLogLevel, + e.applyMark, + e.applyInterface, + e.applyStats, + e.applyUDPTimeout, + e.applyProxy, + e.applyDevice, + e.applyStack, } { if err := f(); err != nil { return err @@ -97,19 +83,7 @@ func (e *engine) insert(k *Key) { e.Key = k } -func (e *engine) setConfig() error { - if e.Config == "" { - return nil - } - - data, err := os.ReadFile(e.Config) - if err != nil { - return err - } - return yaml.Unmarshal(data, e.Key) -} - -func (e *engine) setLogLevel() error { +func (e *engine) applyLogLevel() error { level, err := log.ParseLevel(e.LogLevel) if err != nil { return err @@ -118,7 +92,7 @@ func (e *engine) setLogLevel() error { return nil } -func (e *engine) setMark() error { +func (e *engine) applyMark() error { if e.Mark != 0 { dialer.SetMark(e.Mark) log.Infof("[DIALER] set fwmark: %#x", e.Mark) @@ -126,7 +100,7 @@ func (e *engine) setMark() error { return nil } -func (e *engine) setInterface() error { +func (e *engine) applyInterface() error { if e.Interface != "" { if err := dialer.BindToInterface(e.Interface); err != nil { return err @@ -136,7 +110,7 @@ func (e *engine) setInterface() error { return nil } -func (e *engine) setStats() error { +func (e *engine) applyStats() error { if e.Stats != "" { addr, err := net.ResolveTCPAddr("tcp", e.Stats) if err != nil { @@ -153,14 +127,14 @@ func (e *engine) setStats() error { return nil } -func (e *engine) setUDPTimeout() error { +func (e *engine) applyUDPTimeout() error { if e.UDPTimeout > 0 { tunnel.SetUDPTimeout(e.UDPTimeout) } return nil } -func (e *engine) setProxy() (err error) { +func (e *engine) applyProxy() (err error) { if e.Proxy == "" { return errors.New("empty proxy") } @@ -170,7 +144,7 @@ func (e *engine) setProxy() (err error) { return } -func (e *engine) setDevice() (err error) { +func (e *engine) applyDevice() (err error) { if e.Device == "" { return errors.New("empty device") } @@ -179,7 +153,7 @@ func (e *engine) setDevice() (err error) { return } -func (e *engine) setStack() (err error) { +func (e *engine) applyStack() (err error) { defer func() { if err == nil { log.Infof( diff --git a/main.go b/main.go index 0c80001..4861533 100644 --- a/main.go +++ b/main.go @@ -2,24 +2,32 @@ package main import ( "flag" + "fmt" "os" "os/signal" "syscall" "github.com/xjasonlyu/tun2socks/v2/engine" + "github.com/xjasonlyu/tun2socks/v2/internal/version" "github.com/xjasonlyu/tun2socks/v2/log" "go.uber.org/automaxprocs/maxprocs" + "gopkg.in/yaml.v3" ) -var key = new(engine.Key) +var ( + key = new(engine.Key) + + configFile string + versionFlag bool +) func init() { + flag.BoolVar(&versionFlag, "version", false, "Show version and then quit") flag.IntVar(&key.Mark, "fwmark", 0, "Set firewall MARK (Linux only)") flag.IntVar(&key.MTU, "mtu", 0, "Set device maximum transmission unit (MTU)") flag.IntVar(&key.UDPTimeout, "udp-timeout", 0, "Set timeout for each UDP session") - flag.BoolVar(&key.Version, "version", false, "Show version information and quit") - flag.StringVar(&key.Config, "config", "", "YAML format configuration file") + flag.StringVar(&configFile, "config", "", "YAML format configuration file") flag.StringVar(&key.Device, "device", "", "Use this device [driver://]name") flag.StringVar(&key.Interface, "interface", "", "Use network INTERFACE (Linux/MacOS only)") flag.StringVar(&key.LogLevel, "loglevel", "info", "Log level [debug|info|warning|error|silent]") @@ -32,16 +40,32 @@ func init() { func main() { maxprocs.Set(maxprocs.Logger(func(string, ...any) {})) + if versionFlag { + fmt.Println(version.String()) + fmt.Println(version.BuildString()) + os.Exit(0) + } + + if configFile != "" { + data, err := os.ReadFile(configFile) + if err != nil { + log.Fatalf("Failed to read config %s: %v", configFile, err) + } + if err = yaml.Unmarshal(data, key); err != nil { + log.Fatalf("Failed to unmarshal config %s: %v", configFile, err) + } + } + engine.Insert(key) - checkErr := func(msg string, f func() error) { - if err := f(); err != nil { + assert := func(msg string, err error) { + if err != nil { log.Fatalf("Failed to %s: %v", msg, err) } } - checkErr("start engine", engine.Start) - defer checkErr("stop engine", engine.Stop) + assert("start engine", engine.Start()) + defer assert("stop engine", engine.Stop()) sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)