disproxy bug,optimize http.readResponse, optimize newRequest

This commit is contained in:
bxd
2023-12-06 17:54:40 +08:00
parent 9928687a3e
commit 2d959bdeb3
5 changed files with 125 additions and 48 deletions

12
conn.go
View File

@@ -7,7 +7,7 @@ import (
"io"
"log"
"net"
"net/http"
"net/textproto"
"sync"
"sync/atomic"
@@ -27,7 +27,7 @@ type connecotr struct {
rawConn net.Conn
h2RawConn *http2.ClientConn
proxy string
r *bufio.Reader
r *textproto.Reader
w *bufio.Writer
pr *pipCon
inPool bool
@@ -85,12 +85,8 @@ func (obj *connecotr) wrapBody(task *reqTask) {
task.res.Body = body
}
func (obj *connecotr) http1Req(task *reqTask) {
task.err = httpWrite(task.req, obj.w, task.orderHeaders)
// if task.err = task.req.Write(obj.w); task.err == nil {
// task.err = obj.w.Flush()
// }
if task.err == nil {
task.res, task.err = http.ReadResponse(obj.r, task.req)
if task.err = httpWrite(task.req, obj.w, task.orderHeaders); task.err == nil {
task.res, task.err = readResponse(obj.r, task.req)
if task.err != nil {
task.err = tools.WrapError(task.err, "http1 read error")
} else if task.res == nil {

View File

@@ -163,14 +163,13 @@ func (obj *RequestOption) initParams() (*url.URL, error) {
if query == "" {
return obj.Url, nil
}
pu := cloneUrl(obj.Url)
pquery := pu.Query().Encode()
pquery := obj.Url.Query().Encode()
if pquery == "" {
pu.RawQuery = query
obj.Url.RawQuery = query
} else {
pu.RawQuery = pquery + "&" + query
obj.Url.RawQuery = pquery + "&" + query
}
return pu, nil
return obj.Url, nil
}
func (obj *Client) newRequestOption(option RequestOption) RequestOption {
// start

View File

@@ -229,17 +229,19 @@ func (obj *Client) Request(ctx context.Context, method string, href string, opti
rawOption = options[0]
}
optionBak := obj.newRequestOption(rawOption)
if optionBak.Url == nil {
if optionBak.Url, err = url.Parse(href); err != nil {
if optionBak.Method == "" {
optionBak.Method = method
}
uhref := optionBak.Url
if uhref == nil {
if uhref, err = url.Parse(href); err != nil {
err = tools.WrapError(err, "url parse error")
return
}
}
if optionBak.Method == "" {
optionBak.Method = method
}
for maxRetries := 0; maxRetries <= optionBak.MaxRetries; maxRetries++ {
option := optionBak
option.Url = cloneUrl(uhref)
response, err = obj.request(ctx, &option)
if err == nil || errors.Is(err, errFatal) || option.once {
return
@@ -280,6 +282,22 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response
response.disDecode = option.DisDecode
response.stream = option.Stream
//init headers and orderheaders,befor init ctxData
headers, err := option.initHeaders()
if err != nil {
return response, tools.WrapError(err, errors.New("tempRequest init headers error"), err)
}
if headers == nil {
headers = http.Header{
"User-Agent": []string{UserAgent},
"Accept": []string{"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7"},
"Accept-Encoding": []string{"gzip, deflate, br"},
"Accept-Language": []string{AcceptLanguage},
"Sec-Ch-Ua": []string{SecChUa},
"Sec-Ch-Ua-Mobile": []string{"?0"},
"Sec-Ch-Ua-Platform": []string{`"Windows"`},
}
}
//init ctxData
ctxData, err := NewReqCtxData(ctx, option)
if err != nil {
@@ -303,30 +321,11 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response
return response, tools.WrapError(err, errors.New("tempRequest init body error"), err)
}
//create request
reqs, err := http.NewRequestWithContext(response.ctx, strings.ToUpper(option.Method), href.String(), body)
reqs, err := newRequestWithContext(response.ctx, option.Method, href, body)
if err != nil {
return response, tools.WrapError(errFatal, errors.New("tempRequest 构造request失败"), err)
}
//init headers
//init headers and orderheaders,befor init ctxData
headers, err := option.initHeaders()
if err != nil {
return response, tools.WrapError(err, errors.New("tempRequest init headers error"), err)
}
if headers == nil {
reqs.Header = http.Header{
"User-Agent": []string{UserAgent},
"Accept": []string{"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7"},
"Accept-Encoding": []string{"gzip, deflate, br"},
"Accept-Language": []string{AcceptLanguage},
"Sec-Ch-Ua": []string{SecChUa},
"Sec-Ch-Ua-Mobile": []string{"?0"},
"Sec-Ch-Ua-Platform": []string{`"Windows"`},
}
} else {
reqs.Header = headers
}
reqs.Header = headers
//add Referer
if reqs.Header.Get("Referer") == "" {
if option.Referer != "" {

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls"
"errors"
"net"
"net/textproto"
"net/url"
"strings"
"time"
@@ -141,14 +142,16 @@ func (obj *roundTripper) newConnecotr(netConn net.Conn) *connecotr {
return conne
}
func (obj *roundTripper) dial(ctxData *reqCtxData, req *http.Request) (conn *connecotr, err error) {
proxy := cloneUrl(ctxData.proxy)
if proxy == nil && !ctxData.disProxy && obj.getProxy != nil {
proxyStr, err := obj.getProxy(req.Context(), proxy)
if err != nil {
return conn, err
}
if proxy, err = gtls.VerifyProxy(proxyStr); err != nil {
return conn, err
var proxy *url.URL
if !ctxData.disProxy {
if proxy = cloneUrl(ctxData.proxy); proxy == nil && obj.getProxy != nil {
proxyStr, err := obj.getProxy(req.Context(), proxy)
if err != nil {
return conn, err
}
if proxy, err = gtls.VerifyProxy(proxyStr); err != nil {
return conn, err
}
}
}
netConn, err := obj.dialer.DialContextWithProxy(req.Context(), ctxData, "tcp", req.URL.Scheme, getAddr(req.URL), getHost(req), proxy, obj.tlsConfigClone())
@@ -190,7 +193,7 @@ func (obj *roundTripper) dial(ctxData *reqCtxData, req *http.Request) (conn *con
return conne, err
}
} else {
conne.r, conne.w = bufio.NewReader(conne), bufio.NewWriter(conne)
conne.r, conne.w = textproto.NewReader(bufio.NewReader(conne)), bufio.NewWriter(conne)
}
return conne, err
}
@@ -261,6 +264,13 @@ func (obj *roundTripper) RoundTrip(req *http.Request) (response *http.Response,
ctxData := GetReqCtxData(req.Context())
if ctxData.requestCallBack != nil {
if err = ctxData.requestCallBack(req.Context(), req, nil); err != nil {
if err == http.ErrUseLastResponse {
if req.Response == nil {
return nil, errors.New("errUseLastResponse response is nil")
} else {
return req.Response, nil
}
}
return nil, err
}
}

View File

@@ -3,11 +3,15 @@ package requests
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/textproto"
"net/url"
"strconv"
"strings"
"sync"
_ "unsafe"
@@ -85,6 +89,12 @@ func removeZone(host string) string
//go:linkname shouldSendContentLength net/http.(*transferWriter).shouldSendContentLength
func shouldSendContentLength(t *http.Request) bool
//go:linkname removeEmptyPort net/http.removeEmptyPort
func removeEmptyPort(host string) string
//go:linkname readTransfer net/http.readTransfer
func readTransfer(msg any, r *bufio.Reader) (err error)
func httpWrite(r *http.Request, w *bufio.Writer, orderHeaders []string) (err error) {
host := r.Host
if host == "" {
@@ -182,3 +192,66 @@ func init() {
return strings.Builder{}
}
}
func newRequestWithContext(ctx context.Context, method string, u *url.URL, body io.Reader) (*http.Request, error) {
req := (&http.Request{}).WithContext(ctx)
if method == "" {
req.Method = http.MethodGet
} else {
req.Method = strings.ToUpper(method)
}
req.URL = u
req.Proto = "HTTP/1.1"
req.ProtoMajor = 1
req.ProtoMinor = 1
req.Host = u.Host
u.Host = removeEmptyPort(u.Host)
if body != nil {
if v, ok := body.(interface{ Len() int }); ok {
req.ContentLength = int64(v.Len())
}
rc, ok := body.(io.ReadCloser)
if !ok {
rc = io.NopCloser(body)
}
req.Body = rc
}
return req, nil
}
func readResponse(tp *textproto.Reader, req *http.Request) (*http.Response, error) {
resp := &http.Response{
Request: req,
}
// Parse the first line of the response.
line, err := tp.ReadLine()
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return nil, err
}
proto, status, ok := strings.Cut(line, " ")
if !ok {
return nil, errors.New("malformed HTTP response")
}
resp.Proto = proto
resp.Status = strings.TrimLeft(status, " ")
statusCode, _, _ := strings.Cut(resp.Status, " ")
if resp.StatusCode, err = strconv.Atoi(statusCode); err != nil {
return nil, errors.New("malformed HTTP status code")
}
if resp.ProtoMajor, resp.ProtoMinor, ok = http.ParseHTTPVersion(resp.Proto); !ok {
return nil, errors.New("malformed HTTP version")
}
// Parse the response headers.
mimeHeader, err := tp.ReadMIMEHeader()
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return nil, err
}
resp.Header = http.Header(mimeHeader)
return resp, readTransfer(resp, tp.R)
}