fix a bug

This commit is contained in:
Jason
2019-08-07 20:26:11 +08:00
parent 95ef383397
commit a25572ced6
2 changed files with 21 additions and 18 deletions

View File

@@ -20,6 +20,8 @@ import (
type tcpHandler struct { type tcpHandler struct {
sync.Mutex sync.Mutex
sessKey string
proxyHost string proxyHost string
proxyPort uint16 proxyPort uint16
@@ -110,8 +112,7 @@ func (h *tcpHandler) relay(localConn, remoteConn net.Conn, sess *stats.Session)
<-upCh // Wait for uplink done. <-upCh // Wait for uplink done.
if h.sessionStater != nil { if h.sessionStater != nil {
key := fmt.Sprintf("%s:%s", localConn.LocalAddr().Network(), localConn.LocalAddr().String()) h.sessionStater.RemoveSession(h.sessKey)
h.sessionStater.RemoveSession(key)
} }
} }
@@ -156,8 +157,8 @@ func (h *tcpHandler) Handle(localConn net.Conn, target *net.TCPAddr) error {
DownloadBytes: 0, DownloadBytes: 0,
SessionStart: time.Now(), SessionStart: time.Now(),
} }
key := fmt.Sprintf("%s:%s", localConn.LocalAddr().Network(), localConn.LocalAddr().String()) h.sessKey = fmt.Sprintf("%s:%s", localConn.LocalAddr().Network(), localConn.LocalAddr().String())
h.sessionStater.AddSession(key, sess) h.sessionStater.AddSession(h.sessKey, sess)
} }
// set keepalive // set keepalive

View File

@@ -19,7 +19,9 @@ import (
type udpHandler struct { type udpHandler struct {
sync.Mutex sync.Mutex
closed bool closed bool
sessKey string
proxyHost string proxyHost string
proxyPort uint16 proxyPort uint16
timeout time.Duration timeout time.Duration
@@ -40,8 +42,9 @@ func NewUDPHandler(proxyHost string, proxyPort uint16, timeout time.Duration, fa
tcpConns: make(map[core.UDPConn]net.Conn, 8), tcpConns: make(map[core.UDPConn]net.Conn, 8),
remoteAddrs: make(map[core.UDPConn]*net.UDPAddr, 8), remoteAddrs: make(map[core.UDPConn]*net.UDPAddr, 8),
fakeDns: fakeDns, fakeDns: fakeDns,
timeout: timeout,
sessionStater: sessionStater, sessionStater: sessionStater,
timeout: timeout,
closed: false,
} }
} }
@@ -86,7 +89,7 @@ func (h *udpHandler) fetchUDPInput(conn core.UDPConn, input net.PacketConn) {
} }
n, err = conn.WriteFrom(buf[int(3+len(addr)):n], resolvedAddr) n, err = conn.WriteFrom(buf[int(3+len(addr)):n], resolvedAddr)
if n > 0 && h.sessionStater != nil { if n > 0 && h.sessionStater != nil {
if sess := h.sessionStater.GetSession(conn); sess != nil { if sess := h.sessionStater.GetSession(h.sessKey); sess != nil {
sess.AddDownloadBytes(int64(n)) sess.AddDownloadBytes(int64(n))
} }
} }
@@ -194,8 +197,8 @@ func (h *udpHandler) connectInternal(conn core.UDPConn, targetAddr string) error
DownloadBytes: 0, DownloadBytes: 0,
SessionStart: time.Now(), SessionStart: time.Now(),
} }
key := fmt.Sprintf("%s:%s", conn.LocalAddr().Network(), conn.LocalAddr().String()) h.sessKey = fmt.Sprintf("%s:%s", conn.LocalAddr().Network(), conn.LocalAddr().String())
h.sessionStater.AddSession(key, sess) h.sessionStater.AddSession(h.sessKey, sess)
} }
log.Access(process, "proxy", "udp", conn.LocalAddr().String(), targetAddr) log.Access(process, "proxy", "udp", conn.LocalAddr().String(), targetAddr)
} }
@@ -204,7 +207,7 @@ func (h *udpHandler) connectInternal(conn core.UDPConn, targetAddr string) error
func (h *udpHandler) ReceiveTo(conn core.UDPConn, data []byte, addr *net.UDPAddr) error { func (h *udpHandler) ReceiveTo(conn core.UDPConn, data []byte, addr *net.UDPAddr) error {
h.Lock() h.Lock()
pc, ok1 := h.udpConns[conn] remoteUDPConn, ok1 := h.udpConns[conn]
remoteAddr, ok2 := h.remoteAddrs[conn] remoteAddr, ok2 := h.remoteAddrs[conn]
h.Unlock() h.Unlock()
@@ -220,9 +223,9 @@ func (h *udpHandler) ReceiveTo(conn core.UDPConn, data []byte, addr *net.UDPAddr
targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(addr.Port)) targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(addr.Port))
buf := append([]byte{0, 0, 0}, ParseAddr(targetAddr)...) buf := append([]byte{0, 0, 0}, ParseAddr(targetAddr)...)
buf = append(buf, data[:]...) buf = append(buf, data[:]...)
n, err := pc.WriteTo(buf, remoteAddr) n, err := remoteUDPConn.WriteTo(buf, remoteAddr)
if n > 0 && h.sessionStater != nil { if n > 0 && h.sessionStater != nil {
if sess := h.sessionStater.GetSession(conn); sess != nil { if sess := h.sessionStater.GetSession(h.sessKey); sess != nil {
sess.AddUploadBytes(int64(n)) sess.AddUploadBytes(int64(n))
} }
} }
@@ -245,12 +248,12 @@ func (h *udpHandler) Close(conn core.UDPConn) {
h.Lock() h.Lock()
defer h.Unlock() defer h.Unlock()
if c, ok := h.tcpConns[conn]; ok { if remoteConn, ok := h.tcpConns[conn]; ok {
c.Close() remoteConn.Close()
delete(h.tcpConns, conn) delete(h.tcpConns, conn)
} }
if pc, ok := h.udpConns[conn]; ok { if remoteUDPConn, ok := h.udpConns[conn]; ok {
pc.Close() remoteUDPConn.Close()
delete(h.udpConns, conn) delete(h.udpConns, conn)
} }
@@ -258,8 +261,7 @@ func (h *udpHandler) Close(conn core.UDPConn) {
delete(h.remoteAddrs, conn) delete(h.remoteAddrs, conn)
if h.sessionStater != nil { if h.sessionStater != nil {
key := fmt.Sprintf("%s:%s", conn.LocalAddr().Network(), conn.LocalAddr().String()) h.sessionStater.RemoveSession(h.sessKey)
h.sessionStater.RemoveSession(key)
} }
h.closed = true h.closed = true