diff --git a/dial.go b/dial.go index b970aff..bbfd0b9 100644 --- a/dial.go +++ b/dial.go @@ -388,13 +388,7 @@ func (obj *Dialer) verifyUDPSocks5(ctx *Response, conn net.Conn, proxyAddr Addre if err != nil { return } - var cnl context.CancelFunc - udpCtx, cnl := context.WithCancel(context.TODO()) - wrapConn = NewUDPConn(udpCtx, wrapConn, &net.UDPAddr{IP: proxyAddress.IP, Port: proxyAddress.Port}, remoteAddr) - go func() { - tools.Copy(io.Discard, conn) - cnl() - }() + wrapConn = NewUDPConn(conn, wrapConn, &net.UDPAddr{IP: proxyAddress.IP, Port: proxyAddress.Port}, remoteAddr) return } func (obj *Dialer) writeCmd(conn net.Conn, network string) (err error) { diff --git a/roundTripper.go b/roundTripper.go index 232d839..ca0d575 100644 --- a/roundTripper.go +++ b/roundTripper.go @@ -148,13 +148,13 @@ func (obj *roundTripper) ghttp3Dial(ctx *Response, remoteAddress Address, proxyA } conn = obj.newConnecotr() - conn.Conn, err = http3.NewClient(netConn, func() { + conn.Conn, err = http3.NewClient(netConn, udpConn, func() { conn.forceCnl(errors.New("http3 client close")) }) if ct, ok := udpConn.(interface { - Context() context.Context + SetTcpCloseFunc(f func(error)) }); ok { - context.AfterFunc(ct.Context(), func() { + ct.SetTcpCloseFunc(func(err error) { conn.forceCnl(errors.New("http3 client close with udp")) }) } @@ -193,14 +193,14 @@ func (obj *roundTripper) uhttp3Dial(ctx *Response, remoteAddress Address, proxyA return nil, err } conn = obj.newConnecotr() - conn.Conn, err = http3.NewClient(netConn, func() { + conn.Conn, err = http3.NewClient(netConn, udpConn, func() { conn.forceCnl(errors.New("http3 client close")) }) if ct, ok := udpConn.(interface { - Context() context.Context + SetTcpCloseFunc(f func(error)) }); ok { - context.AfterFunc(ct.Context(), func() { - conn.forceCnl(errors.New("http3 client close with udp")) + ct.SetTcpCloseFunc(func(err error) { + conn.forceCnl(errors.New("uhttp3 client close with udp")) }) } return @@ -459,6 +459,7 @@ func (obj *roundTripper) RoundTrip(ctx *Response) (err error) { return err } err = obj.poolRoundTrip(task) + // log.Print(err) if err == nil || !task.suppertRetry() { break } diff --git a/socks5.go b/socks5.go index 809e1e0..bf2da76 100644 --- a/socks5.go +++ b/socks5.go @@ -2,13 +2,14 @@ package requests import ( "bytes" - "context" "encoding/binary" "errors" "io" "math" "net" "strconv" + + "github.com/gospider007/tools" ) const MaxUdpPacket int = math.MaxUint16 - 28 @@ -93,23 +94,31 @@ func (a Address) IsZero() bool { } type UDPConn struct { - ctx context.Context net.PacketConn + tcpConn net.Conn prefix []byte bufRead [MaxUdpPacket]byte bufWrite [MaxUdpPacket]byte proxyAddress net.Addr remoteAddress Address + tcpCloseFunc func(error) } -func NewUDPConn(ctx context.Context, packConn net.PacketConn, proxyAddress net.Addr, remoteAddress Address) *UDPConn { - return &UDPConn{ - ctx: ctx, +func NewUDPConn(tcpConn net.Conn, packConn net.PacketConn, proxyAddress net.Addr, remoteAddress Address) *UDPConn { + ucon := &UDPConn{ + tcpConn: tcpConn, remoteAddress: remoteAddress, PacketConn: packConn, proxyAddress: proxyAddress, prefix: []byte{0, 0, 0}, } + go func() { + _, err := tools.Copy(io.Discard, tcpConn) + if ucon.tcpCloseFunc != nil { + ucon.tcpCloseFunc(err) + } + }() + return ucon } func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { @@ -153,6 +162,10 @@ func (c *UDPConn) SetReadBuffer(i int) error { func (c *UDPConn) SetWriteBuffer(i int) error { return c.PacketConn.(*net.UDPConn).SetWriteBuffer(i) } -func (c *UDPConn) Context() context.Context { - return c.ctx +func (c *UDPConn) SetTcpCloseFunc(f func(error)) { + c.tcpCloseFunc = f +} +func (c *UDPConn) Close() error { + c.tcpConn.Close() + return c.PacketConn.Close() } diff --git a/test/proxy/http3_proxy_test.go b/test/proxy/http3_proxy_test.go index 607e36f..bb8a39e 100644 --- a/test/proxy/http3_proxy_test.go +++ b/test/proxy/http3_proxy_test.go @@ -139,6 +139,7 @@ func TestHttp3Proxy2(t *testing.T) { } fmt.Println(resp.StatusCode()) fmt.Println(resp.Proto()) + resp.CloseConn() time.Sleep(time.Second) } }