[context] ctxCopy

This commit is contained in:
Jason
2019-08-11 20:35:49 +08:00
parent 513f6f2966
commit 7bc08271e8

View File

@@ -1,6 +1,7 @@
package socks
import (
"context"
"io"
"net"
"strconv"
@@ -33,6 +34,47 @@ func NewTCPHandler(proxyHost string, proxyPort uint16, fakeDns dns.FakeDns, sess
}
}
func ctxCopy(ctx context.Context, dst, src net.Conn) (written int64, err error) {
buf := core.NewBytes(core.BufSize)
defer core.FreeBytes(buf)
for {
select {
case <-ctx.Done():
return written, err
default:
}
src.SetReadDeadline(time.Now().Add(30*time.Second))
nr, er := src.Read(buf)
if nr > 0 {
dst.SetWriteDeadline(time.Now().Add(30*time.Second))
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if ew != nil {
if ew, ok := ew.(net.Error); !ok || !ew.Timeout() {
err = ew
break
}
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er, ok := er.(net.Error); !ok || !er.Timeout() {
if er != io.EOF {
err = er
}
break
}
}
}
return written, err
}
func (h *tcpHandler) relay(localConn, remoteConn net.Conn) {
var once sync.Once
closeOnce := func() {
@@ -65,8 +107,8 @@ func (h *tcpHandler) relay(localConn, remoteConn net.Conn) {
}()
select {
case <-up: // Wait for UpLink done.
case <-down:
case <-up: // Wait for Up Link done
case <-down: // Wait for Down Link done
}
if h.sessionStater != nil {