update hijackDNS

This commit is contained in:
Jason
2019-08-16 20:47:42 +08:00
parent 78dfd36759
commit b44b3e13f6
3 changed files with 48 additions and 41 deletions

View File

@@ -139,11 +139,11 @@ func main() {
lwipWriter = filter.NewICMPFilter(lwipWriter).(io.Writer) lwipWriter = filter.NewICMPFilter(lwipWriter).(io.Writer)
// Register modules to proxy // Register modules to proxy
proxy.RegisterFakeDNS(fakeDNS)
proxy.RegisterMonitor(monitor) proxy.RegisterMonitor(monitor)
proxy.RegisterFakeDNS(fakeDNS, *args.HijackDNS)
// Register TCP and UDP handlers to handle accepted connections. // Register TCP and UDP handlers to handle accepted connections.
core.RegisterTCPConnHandler(proxy.NewTCPHandler(proxyHost, proxyPort)) core.RegisterTCPConnHandler(proxy.NewTCPHandler(proxyHost, proxyPort))
core.RegisterUDPConnHandler(proxy.NewUDPHandler(proxyHost, proxyPort, *args.UdpTimeout, *args.HijackDNS)) core.RegisterUDPConnHandler(proxy.NewUDPHandler(proxyHost, proxyPort, *args.UdpTimeout))
// Register an output callback to write packets output from lwip stack to tun // Register an output callback to write packets output from lwip stack to tun
// device, output function should be set before input any packets. // device, output function should be set before input any packets.

View File

@@ -3,24 +3,62 @@ package proxy
import ( import (
"errors" "errors"
"net" "net"
"strconv"
"strings"
D "github.com/xjasonlyu/tun2socks/component/fakedns" D "github.com/xjasonlyu/tun2socks/component/fakedns"
S "github.com/xjasonlyu/tun2socks/component/session" S "github.com/xjasonlyu/tun2socks/component/session"
) )
var ( var (
fakeDNS D.FakeDNS
monitor S.Monitor monitor S.Monitor
fakeDNS D.FakeDNS
hijackDNS []string
) )
func RegisterFakeDNS(d D.FakeDNS) { // Register Monitor
fakeDNS = d
}
func RegisterMonitor(m S.Monitor) { func RegisterMonitor(m S.Monitor) {
monitor = m monitor = m
} }
// Session Operation
func addSession(key interface{}, session *S.Session) {
if monitor != nil {
monitor.AddSession(key, session)
}
}
func removeSession(key interface{}) {
if monitor != nil {
monitor.RemoveSession(key)
}
}
// Register FakeDNS
func RegisterFakeDNS(d D.FakeDNS, h string) {
fakeDNS = d
hijackDNS = strings.Split(h, ",")
}
// Check target if is hijacked address.
func isHijacked(target *net.UDPAddr) bool {
if hijackDNS == nil {
return false
}
for _, addr := range hijackDNS {
host, port, err := net.SplitHostPort(addr)
if err != nil {
continue
}
portInt, _ := strconv.Atoi(port)
if (host == "*" && portInt == target.Port) || addr == target.String() {
return true
}
}
return false
}
// DNS lookup // DNS lookup
func lookupHost(target net.Addr) (targetHost string, err error) { func lookupHost(target net.Addr) (targetHost string, err error) {
var targetIP net.IP var targetIP net.IP
@@ -43,16 +81,3 @@ func lookupHost(target net.Addr) (targetHost string, err error) {
} }
return return
} }
// Session Operation
func addSession(key interface{}, session *S.Session) {
if monitor != nil {
monitor.AddSession(key, session)
}
}
func removeSession(key interface{}) {
if monitor != nil {
monitor.RemoveSession(key)
}
}

View File

@@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@@ -19,37 +18,20 @@ import (
type udpHandler struct { type udpHandler struct {
proxyHost string proxyHost string
proxyPort int proxyPort int
timeout time.Duration timeout time.Duration
hijackDNS []string
remoteAddrMap sync.Map remoteAddrMap sync.Map
remoteConnMap sync.Map remoteConnMap sync.Map
} }
func NewUDPHandler(proxyHost string, proxyPort int, timeout time.Duration, hijackDNS string) core.UDPConnHandler { func NewUDPHandler(proxyHost string, proxyPort int, timeout time.Duration) core.UDPConnHandler {
return &udpHandler{ return &udpHandler{
proxyHost: proxyHost, proxyHost: proxyHost,
proxyPort: proxyPort, proxyPort: proxyPort,
timeout: timeout, timeout: timeout,
hijackDNS: strings.Split(hijackDNS, ","),
} }
} }
func (h *udpHandler) isHijacked(target *net.UDPAddr) bool {
for _, addr := range h.hijackDNS {
host, port, err := net.SplitHostPort(addr)
if err != nil {
continue
}
portInt, _ := strconv.Atoi(port)
if (host == "*" && portInt == target.Port) || addr == target.String() {
return true
}
}
return false
}
func (h *udpHandler) fetchUDPInput(conn core.UDPConn, input net.PacketConn, addr *net.UDPAddr) { func (h *udpHandler) fetchUDPInput(conn core.UDPConn, input net.PacketConn, addr *net.UDPAddr) {
buf := pool.BufPool.Get().([]byte) buf := pool.BufPool.Get().([]byte)
@@ -77,7 +59,7 @@ func (h *udpHandler) fetchUDPInput(conn core.UDPConn, input net.PacketConn, addr
func (h *udpHandler) Connect(conn core.UDPConn, target *net.UDPAddr) error { func (h *udpHandler) Connect(conn core.UDPConn, target *net.UDPAddr) error {
// Check hijackDNS // Check hijackDNS
if h.isHijacked(target) { if isHijacked(target) {
return nil return nil
} }
@@ -132,7 +114,7 @@ func (h *udpHandler) ReceiveTo(conn core.UDPConn, data []byte, addr *net.UDPAddr
}() }()
// Check hijackDNS // Check hijackDNS
if h.isHijacked(addr) { if isHijacked(addr) {
resp, err := fakeDNS.Resolve(data) resp, err := fakeDNS.Resolve(data)
if err != nil { if err != nil {
return fmt.Errorf("hijack DNS request error: %v", err) return fmt.Errorf("hijack DNS request error: %v", err)