diff --git a/proxy/socks/tcp.go b/proxy/socks/tcp.go index 2cef1bf..2ef35b7 100644 --- a/proxy/socks/tcp.go +++ b/proxy/socks/tcp.go @@ -1,7 +1,6 @@ package socks import ( - "context" "io" "net" "strconv" @@ -34,47 +33,6 @@ func NewTCPHandler(proxyHost string, proxyPort uint16, fakeDns dns.FakeDns, sess } } -func ctxCopy(ctx context.Context, dst, src net.Conn) (written int64, err error) { - buf := core.NewBytes(core.BufSize) - defer core.FreeBytes(buf) - - for { - select { - case <-ctx.Done(): - return written, err - default: - } - src.SetReadDeadline(time.Now().Add(30*time.Second)) - nr, er := src.Read(buf) - if nr > 0 { - dst.SetWriteDeadline(time.Now().Add(30*time.Second)) - nw, ew := dst.Write(buf[0:nr]) - if nw > 0 { - written += int64(nw) - } - if ew != nil { - if ew, ok := ew.(net.Error); !ok || !ew.Timeout() { - err = ew - break - } - } - if nr != nw { - err = io.ErrShortWrite - break - } - } - if er != nil { - if er, ok := er.(net.Error); !ok || !er.Timeout() { - if er != io.EOF { - err = er - } - break - } - } - } - return written, err -} - func (h *tcpHandler) relay(localConn, remoteConn net.Conn) { var once sync.Once closeOnce := func() { @@ -88,28 +46,23 @@ func (h *tcpHandler) relay(localConn, remoteConn net.Conn) { defer closeOnce() up := make(chan struct{}) - down := make(chan struct{}) // UpLink go func() { if _, err := io.Copy(remoteConn, localConn); err != nil { closeOnce() } + tcpCloseRead(remoteConn) up <- struct{}{} }() // DownLink - go func() { - if _, err := io.Copy(localConn, remoteConn); err != nil { - closeOnce() - } - down <- struct{}{} - }() - - select { - case <-up: // Wait for Up Link done - case <-down: // Wait for Down Link done + if _, err := io.Copy(localConn, remoteConn); err != nil { + closeOnce() } + tcpCloseRead(localConn) + + <-up if h.sessionStater != nil { h.sessionStater.RemoveSession(localConn) @@ -178,3 +131,9 @@ func tcpKeepAlive(conn net.Conn) { tcp.SetKeepAlivePeriod(30 * time.Second) } } + +func tcpCloseRead(conn net.Conn) { + if c, ok := conn.(interface{ CloseRead() error }); ok { + c.CloseRead() + } +}