diff --git a/proxy/direct/udp.go b/proxy/direct/udp.go index 37e6cf0..ed72f30 100644 --- a/proxy/direct/udp.go +++ b/proxy/direct/udp.go @@ -40,7 +40,7 @@ func (h *udpHandler) fetchUDPInput(conn core.UDPConn, pc *net.UDPConn) { return } - if _, err = conn.WriteFrom(buf[:n], addr); err != nil { + if _, err := conn.WriteFrom(buf[:n], addr); err != nil { log.Warnf("failed to write UDP data to TUN: %v", err) return } @@ -54,8 +54,11 @@ func (h *udpHandler) Connect(conn core.UDPConn, target *net.UDPAddr) error { log.Errorf("failed to bind udp address") return err } + h.remoteUDPConnMap.Store(conn, pc) + go h.fetchUDPInput(conn, pc) + log.Infof("new proxy connection for target: %s:%s", target.Network(), target.String()) return nil } @@ -80,9 +83,10 @@ func (h *udpHandler) ReceiveTo(conn core.UDPConn, data []byte, addr *net.UDPAddr } func (h *udpHandler) Close(conn core.UDPConn) { + conn.Close() + if pc, ok := h.remoteUDPConnMap.Load(conn); ok { pc.(*net.UDPConn).Close() h.remoteUDPConnMap.Delete(conn) } - conn.Close() } diff --git a/proxy/redirect/udp.go b/proxy/redirect/udp.go index aa913f5..0c2a57d 100644 --- a/proxy/redirect/udp.go +++ b/proxy/redirect/udp.go @@ -45,8 +45,7 @@ func (h *udpHandler) fetchUDPInput(conn core.UDPConn, pc *net.UDPConn) { return } - _, err = conn.WriteFrom(buf[:n], addr) - if err != nil { + if _, err := conn.WriteFrom(buf[:n], addr); err != nil { log.Warnf("failed to write UDP data to TUN") return } @@ -61,9 +60,12 @@ func (h *udpHandler) Connect(conn core.UDPConn, target *net.UDPAddr) error { return err } targetAddr, _ := net.ResolveUDPAddr("udp", h.target) + h.remoteAddrMap.Store(conn, targetAddr) h.remoteUDPConnMap.Store(conn, pc) + go h.fetchUDPInput(conn, pc) + log.Infof("new proxy connection for target: %s:%s", target.Network(), target.String()) return nil } diff --git a/proxy/socks/udp.go b/proxy/socks/udp.go index ebe2784..c2ec078 100644 --- a/proxy/socks/udp.go +++ b/proxy/socks/udp.go @@ -233,6 +233,8 @@ func (h *udpHandler) ReceiveTo(conn core.UDPConn, data []byte, addr *net.UDPAddr } func (h *udpHandler) Close(conn core.UDPConn) { + conn.Close() + if remoteConn, ok := h.remoteConnMap.Load(conn); ok { remoteConn.(net.Conn).Close() h.remoteConnMap.Delete(conn) @@ -243,7 +245,6 @@ func (h *udpHandler) Close(conn core.UDPConn) { h.remotePacketConnMap.Delete(conn) } - conn.Close() h.remoteAddrMap.Delete(conn) if h.sessionStater != nil {