update socks tcp.go

This commit is contained in:
Jason
2019-08-11 20:41:52 +08:00
parent 7bc08271e8
commit 6c6d216aae

View File

@@ -1,7 +1,6 @@
package socks package socks
import ( import (
"context"
"io" "io"
"net" "net"
"strconv" "strconv"
@@ -34,47 +33,6 @@ func NewTCPHandler(proxyHost string, proxyPort uint16, fakeDns dns.FakeDns, sess
} }
} }
func ctxCopy(ctx context.Context, dst, src net.Conn) (written int64, err error) {
buf := core.NewBytes(core.BufSize)
defer core.FreeBytes(buf)
for {
select {
case <-ctx.Done():
return written, err
default:
}
src.SetReadDeadline(time.Now().Add(30*time.Second))
nr, er := src.Read(buf)
if nr > 0 {
dst.SetWriteDeadline(time.Now().Add(30*time.Second))
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if ew != nil {
if ew, ok := ew.(net.Error); !ok || !ew.Timeout() {
err = ew
break
}
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er, ok := er.(net.Error); !ok || !er.Timeout() {
if er != io.EOF {
err = er
}
break
}
}
}
return written, err
}
func (h *tcpHandler) relay(localConn, remoteConn net.Conn) { func (h *tcpHandler) relay(localConn, remoteConn net.Conn) {
var once sync.Once var once sync.Once
closeOnce := func() { closeOnce := func() {
@@ -88,28 +46,23 @@ func (h *tcpHandler) relay(localConn, remoteConn net.Conn) {
defer closeOnce() defer closeOnce()
up := make(chan struct{}) up := make(chan struct{})
down := make(chan struct{})
// UpLink // UpLink
go func() { go func() {
if _, err := io.Copy(remoteConn, localConn); err != nil { if _, err := io.Copy(remoteConn, localConn); err != nil {
closeOnce() closeOnce()
} }
tcpCloseRead(remoteConn)
up <- struct{}{} up <- struct{}{}
}() }()
// DownLink // DownLink
go func() { if _, err := io.Copy(localConn, remoteConn); err != nil {
if _, err := io.Copy(localConn, remoteConn); err != nil { closeOnce()
closeOnce()
}
down <- struct{}{}
}()
select {
case <-up: // Wait for Up Link done
case <-down: // Wait for Down Link done
} }
tcpCloseRead(localConn)
<-up
if h.sessionStater != nil { if h.sessionStater != nil {
h.sessionStater.RemoveSession(localConn) h.sessionStater.RemoveSession(localConn)
@@ -178,3 +131,9 @@ func tcpKeepAlive(conn net.Conn) {
tcp.SetKeepAlivePeriod(30 * time.Second) tcp.SetKeepAlivePeriod(30 * time.Second)
} }
} }
func tcpCloseRead(conn net.Conn) {
if c, ok := conn.(interface{ CloseRead() error }); ok {
c.CloseRead()
}
}