From a25572ced6f19f1d39bb746764d2d25c5b333657 Mon Sep 17 00:00:00 2001 From: Jason Date: Wed, 7 Aug 2019 20:26:11 +0800 Subject: [PATCH] fix a bug --- proxy/socks/tcp.go | 9 +++++---- proxy/socks/udp.go | 30 ++++++++++++++++-------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/proxy/socks/tcp.go b/proxy/socks/tcp.go index c3bc07e..b2552ae 100644 --- a/proxy/socks/tcp.go +++ b/proxy/socks/tcp.go @@ -20,6 +20,8 @@ import ( type tcpHandler struct { sync.Mutex + sessKey string + proxyHost string proxyPort uint16 @@ -110,8 +112,7 @@ func (h *tcpHandler) relay(localConn, remoteConn net.Conn, sess *stats.Session) <-upCh // Wait for uplink done. if h.sessionStater != nil { - key := fmt.Sprintf("%s:%s", localConn.LocalAddr().Network(), localConn.LocalAddr().String()) - h.sessionStater.RemoveSession(key) + h.sessionStater.RemoveSession(h.sessKey) } } @@ -156,8 +157,8 @@ func (h *tcpHandler) Handle(localConn net.Conn, target *net.TCPAddr) error { DownloadBytes: 0, SessionStart: time.Now(), } - key := fmt.Sprintf("%s:%s", localConn.LocalAddr().Network(), localConn.LocalAddr().String()) - h.sessionStater.AddSession(key, sess) + h.sessKey = fmt.Sprintf("%s:%s", localConn.LocalAddr().Network(), localConn.LocalAddr().String()) + h.sessionStater.AddSession(h.sessKey, sess) } // set keepalive diff --git a/proxy/socks/udp.go b/proxy/socks/udp.go index 398c6f1..1fe6fc8 100644 --- a/proxy/socks/udp.go +++ b/proxy/socks/udp.go @@ -19,7 +19,9 @@ import ( type udpHandler struct { sync.Mutex - closed bool + closed bool + sessKey string + proxyHost string proxyPort uint16 timeout time.Duration @@ -40,8 +42,9 @@ func NewUDPHandler(proxyHost string, proxyPort uint16, timeout time.Duration, fa tcpConns: make(map[core.UDPConn]net.Conn, 8), remoteAddrs: make(map[core.UDPConn]*net.UDPAddr, 8), fakeDns: fakeDns, - timeout: timeout, sessionStater: sessionStater, + timeout: timeout, + closed: false, } } @@ -86,7 +89,7 @@ func (h *udpHandler) fetchUDPInput(conn core.UDPConn, input net.PacketConn) { } n, err = conn.WriteFrom(buf[int(3+len(addr)):n], resolvedAddr) if n > 0 && h.sessionStater != nil { - if sess := h.sessionStater.GetSession(conn); sess != nil { + if sess := h.sessionStater.GetSession(h.sessKey); sess != nil { sess.AddDownloadBytes(int64(n)) } } @@ -194,8 +197,8 @@ func (h *udpHandler) connectInternal(conn core.UDPConn, targetAddr string) error DownloadBytes: 0, SessionStart: time.Now(), } - key := fmt.Sprintf("%s:%s", conn.LocalAddr().Network(), conn.LocalAddr().String()) - h.sessionStater.AddSession(key, sess) + h.sessKey = fmt.Sprintf("%s:%s", conn.LocalAddr().Network(), conn.LocalAddr().String()) + h.sessionStater.AddSession(h.sessKey, sess) } log.Access(process, "proxy", "udp", conn.LocalAddr().String(), targetAddr) } @@ -204,7 +207,7 @@ func (h *udpHandler) connectInternal(conn core.UDPConn, targetAddr string) error func (h *udpHandler) ReceiveTo(conn core.UDPConn, data []byte, addr *net.UDPAddr) error { h.Lock() - pc, ok1 := h.udpConns[conn] + remoteUDPConn, ok1 := h.udpConns[conn] remoteAddr, ok2 := h.remoteAddrs[conn] h.Unlock() @@ -220,9 +223,9 @@ func (h *udpHandler) ReceiveTo(conn core.UDPConn, data []byte, addr *net.UDPAddr targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(addr.Port)) buf := append([]byte{0, 0, 0}, ParseAddr(targetAddr)...) buf = append(buf, data[:]...) - n, err := pc.WriteTo(buf, remoteAddr) + n, err := remoteUDPConn.WriteTo(buf, remoteAddr) if n > 0 && h.sessionStater != nil { - if sess := h.sessionStater.GetSession(conn); sess != nil { + if sess := h.sessionStater.GetSession(h.sessKey); sess != nil { sess.AddUploadBytes(int64(n)) } } @@ -245,12 +248,12 @@ func (h *udpHandler) Close(conn core.UDPConn) { h.Lock() defer h.Unlock() - if c, ok := h.tcpConns[conn]; ok { - c.Close() + if remoteConn, ok := h.tcpConns[conn]; ok { + remoteConn.Close() delete(h.tcpConns, conn) } - if pc, ok := h.udpConns[conn]; ok { - pc.Close() + if remoteUDPConn, ok := h.udpConns[conn]; ok { + remoteUDPConn.Close() delete(h.udpConns, conn) } @@ -258,8 +261,7 @@ func (h *udpHandler) Close(conn core.UDPConn) { delete(h.remoteAddrs, conn) if h.sessionStater != nil { - key := fmt.Sprintf("%s:%s", conn.LocalAddr().Network(), conn.LocalAddr().String()) - h.sessionStater.RemoveSession(key) + h.sessionStater.RemoveSession(h.sessKey) } h.closed = true