diff --git a/cmd/main.go b/cmd/main.go index 2f51cea..ae126cf 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -139,11 +139,11 @@ func main() { lwipWriter = filter.NewICMPFilter(lwipWriter).(io.Writer) // Register modules to proxy - proxy.RegisterFakeDNS(fakeDNS) proxy.RegisterMonitor(monitor) + proxy.RegisterFakeDNS(fakeDNS, *args.HijackDNS) // Register TCP and UDP handlers to handle accepted connections. core.RegisterTCPConnHandler(proxy.NewTCPHandler(proxyHost, proxyPort)) - core.RegisterUDPConnHandler(proxy.NewUDPHandler(proxyHost, proxyPort, *args.UdpTimeout, *args.HijackDNS)) + core.RegisterUDPConnHandler(proxy.NewUDPHandler(proxyHost, proxyPort, *args.UdpTimeout)) // Register an output callback to write packets output from lwip stack to tun // device, output function should be set before input any packets. diff --git a/proxy/proxy.go b/proxy/proxy.go index cdc997d..ca07f83 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -3,24 +3,62 @@ package proxy import ( "errors" "net" + "strconv" + "strings" D "github.com/xjasonlyu/tun2socks/component/fakedns" S "github.com/xjasonlyu/tun2socks/component/session" ) var ( - fakeDNS D.FakeDNS monitor S.Monitor + + fakeDNS D.FakeDNS + hijackDNS []string ) -func RegisterFakeDNS(d D.FakeDNS) { - fakeDNS = d -} - +// Register Monitor func RegisterMonitor(m S.Monitor) { monitor = m } +// Session Operation +func addSession(key interface{}, session *S.Session) { + if monitor != nil { + monitor.AddSession(key, session) + } +} + +func removeSession(key interface{}) { + if monitor != nil { + monitor.RemoveSession(key) + } +} + +// Register FakeDNS +func RegisterFakeDNS(d D.FakeDNS, h string) { + fakeDNS = d + hijackDNS = strings.Split(h, ",") +} + +// Check target if is hijacked address. +func isHijacked(target *net.UDPAddr) bool { + if hijackDNS == nil { + return false + } + for _, addr := range hijackDNS { + host, port, err := net.SplitHostPort(addr) + if err != nil { + continue + } + portInt, _ := strconv.Atoi(port) + if (host == "*" && portInt == target.Port) || addr == target.String() { + return true + } + } + return false +} + // DNS lookup func lookupHost(target net.Addr) (targetHost string, err error) { var targetIP net.IP @@ -43,16 +81,3 @@ func lookupHost(target net.Addr) (targetHost string, err error) { } return } - -// Session Operation -func addSession(key interface{}, session *S.Session) { - if monitor != nil { - monitor.AddSession(key, session) - } -} - -func removeSession(key interface{}) { - if monitor != nil { - monitor.RemoveSession(key) - } -} diff --git a/proxy/udp.go b/proxy/udp.go index c7cdd34..543962d 100644 --- a/proxy/udp.go +++ b/proxy/udp.go @@ -4,7 +4,6 @@ import ( "fmt" "net" "strconv" - "strings" "sync" "time" @@ -19,37 +18,20 @@ import ( type udpHandler struct { proxyHost string proxyPort int - timeout time.Duration - hijackDNS []string remoteAddrMap sync.Map remoteConnMap sync.Map } -func NewUDPHandler(proxyHost string, proxyPort int, timeout time.Duration, hijackDNS string) core.UDPConnHandler { +func NewUDPHandler(proxyHost string, proxyPort int, timeout time.Duration) core.UDPConnHandler { return &udpHandler{ proxyHost: proxyHost, proxyPort: proxyPort, timeout: timeout, - hijackDNS: strings.Split(hijackDNS, ","), } } -func (h *udpHandler) isHijacked(target *net.UDPAddr) bool { - for _, addr := range h.hijackDNS { - host, port, err := net.SplitHostPort(addr) - if err != nil { - continue - } - portInt, _ := strconv.Atoi(port) - if (host == "*" && portInt == target.Port) || addr == target.String() { - return true - } - } - return false -} - func (h *udpHandler) fetchUDPInput(conn core.UDPConn, input net.PacketConn, addr *net.UDPAddr) { buf := pool.BufPool.Get().([]byte) @@ -77,7 +59,7 @@ func (h *udpHandler) fetchUDPInput(conn core.UDPConn, input net.PacketConn, addr func (h *udpHandler) Connect(conn core.UDPConn, target *net.UDPAddr) error { // Check hijackDNS - if h.isHijacked(target) { + if isHijacked(target) { return nil } @@ -132,7 +114,7 @@ func (h *udpHandler) ReceiveTo(conn core.UDPConn, data []byte, addr *net.UDPAddr }() // Check hijackDNS - if h.isHijacked(addr) { + if isHijacked(addr) { resp, err := fakeDNS.Resolve(data) if err != nil { return fmt.Errorf("hijack DNS request error: %v", err)