diff --git a/cmd/main.go b/cmd/main.go index 72cacef..2f51cea 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -54,6 +54,7 @@ type CmdArgs struct { FakeIPRange *string FakeDNSAddr *string FakeDNSHosts *string + HijackDNS *string // Session Stats EnableStats *bool @@ -142,7 +143,7 @@ func main() { proxy.RegisterMonitor(monitor) // Register TCP and UDP handlers to handle accepted connections. core.RegisterTCPConnHandler(proxy.NewTCPHandler(proxyHost, proxyPort)) - core.RegisterUDPConnHandler(proxy.NewUDPHandler(proxyHost, proxyPort, *args.UdpTimeout)) + core.RegisterUDPConnHandler(proxy.NewUDPHandler(proxyHost, proxyPort, *args.UdpTimeout, *args.HijackDNS)) // Register an output callback to write packets output from lwip stack to tun // device, output function should be set before input any packets. diff --git a/cmd/main_fakedns.go b/cmd/main_fakedns.go index 49fbac3..9dc5ab7 100644 --- a/cmd/main_fakedns.go +++ b/cmd/main_fakedns.go @@ -14,6 +14,7 @@ func init() { args.FakeDNSAddr = flag.String("fakeDNSAddr", ":53", "Listen address of fake DNS") args.FakeIPRange = flag.String("fakeIPRange", "198.18.0.0/15", "Fake IP CIDR range for DNS") args.FakeDNSHosts = flag.String("fakeDNSHosts", "", "DNS hosts mapping, e.g. 'example.com=1.1.1.1,example.net=2.2.2.2'") + args.HijackDNS = flag.String("hijackDNS", "", "Hijack the DNS query to get a fake ip, e.g. '*:53' or '8.8.8.8:53,8.8.4.4:53'") registerInitFn(func() { if *args.EnableFakeDNS { diff --git a/component/fakedns/fakedns.go b/component/fakedns/fakedns.go index 2282e26..a0f484f 100644 --- a/component/fakedns/fakedns.go +++ b/component/fakedns/fakedns.go @@ -8,7 +8,7 @@ type FakeDNS interface { Start() error Stop() error - // Generate a fake dns response for the specify request. + // Resolve a fake dns response for the specify request. Resolve([]byte) ([]byte, error) // IPToHost returns the corresponding domain for the given IP. diff --git a/proxy/tcp.go b/proxy/tcp.go index a2ddc09..5542a0c 100644 --- a/proxy/tcp.go +++ b/proxy/tcp.go @@ -81,7 +81,7 @@ func (h *tcpHandler) Handle(conn net.Conn, target *net.TCPAddr) error { // Lookup fakeDNS host record targetHost, err := lookupHost(target) if err != nil { - log.Warnf("lookup target host error: %v", err) + log.Warnf("lookup target host: %v", err) return err } diff --git a/proxy/udp.go b/proxy/udp.go index f2f2be8..c7cdd34 100644 --- a/proxy/udp.go +++ b/proxy/udp.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "strconv" + "strings" "sync" "time" @@ -18,20 +19,37 @@ import ( type udpHandler struct { proxyHost string proxyPort int + timeout time.Duration + hijackDNS []string remoteAddrMap sync.Map remoteConnMap sync.Map } -func NewUDPHandler(proxyHost string, proxyPort int, timeout time.Duration) core.UDPConnHandler { +func NewUDPHandler(proxyHost string, proxyPort int, timeout time.Duration, hijackDNS string) core.UDPConnHandler { return &udpHandler{ proxyHost: proxyHost, proxyPort: proxyPort, timeout: timeout, + hijackDNS: strings.Split(hijackDNS, ","), } } +func (h *udpHandler) isHijacked(target *net.UDPAddr) bool { + for _, addr := range h.hijackDNS { + host, port, err := net.SplitHostPort(addr) + if err != nil { + continue + } + portInt, _ := strconv.Atoi(port) + if (host == "*" && portInt == target.Port) || addr == target.String() { + return true + } + } + return false +} + func (h *udpHandler) fetchUDPInput(conn core.UDPConn, input net.PacketConn, addr *net.UDPAddr) { buf := pool.BufPool.Get().([]byte) @@ -58,14 +76,15 @@ func (h *udpHandler) fetchUDPInput(conn core.UDPConn, input net.PacketConn, addr } func (h *udpHandler) Connect(conn core.UDPConn, target *net.UDPAddr) error { - if target.Port == 53 { + // Check hijackDNS + if h.isHijacked(target) { return nil } // Lookup fakeDNS host record targetHost, err := lookupHost(target) if err != nil { - log.Warnf("lookup target host error: %v", err) + log.Warnf("lookup target host: %v", err) return err } @@ -105,10 +124,18 @@ func (h *udpHandler) Connect(conn core.UDPConn, target *net.UDPAddr) error { } func (h *udpHandler) ReceiveTo(conn core.UDPConn, data []byte, addr *net.UDPAddr) (err error) { - if addr.Port == 53 { + // Close if return error + defer func() { + if err != nil { + h.Close(conn) + } + }() + + // Check hijackDNS + if h.isHijacked(addr) { resp, err := fakeDNS.Resolve(data) if err != nil { - log.Warnf("hijack DNS: %v", err) + return fmt.Errorf("hijack DNS request error: %v", err) } else { if _, err = conn.WriteFrom(resp, addr); err != nil { return fmt.Errorf("write dns answer failed: %v", err) @@ -118,13 +145,6 @@ func (h *udpHandler) ReceiveTo(conn core.UDPConn, data []byte, addr *net.UDPAddr } } - // Close if return error - defer func() { - if err != nil { - h.Close(conn) - } - }() - var remoteAddr net.Addr var remoteConn net.PacketConn