diff --git a/body.go b/body.go index 74de514..61e5bd3 100644 --- a/body.go +++ b/body.go @@ -189,92 +189,47 @@ func (obj *OrderData) formWriteMain(writer *multipart.Writer) (err error) { return writer.Close() } -func paramsWrite(buf *bytes.Buffer, key string, val any) error { - if buf.Len() > 0 { - buf.WriteByte('&') +func paramsWrite(content []byte, key string, val any) []byte { + if len(content) > 0 { + content = append(content, '&') } - buf.WriteString(url.QueryEscape(key)) - buf.WriteByte('=') - var err error + content = append(content, []byte(url.QueryEscape(key))...) + content = append(content, '=') switch value := val.(type) { case []byte: - _, err = buf.Write(value) + content = append(content, value...) case string: - _, err = buf.WriteString(value) + content = append(content, []byte(value)...) default: - v, err2 := gson.Encode(val) - if err2 != nil { - return err2 - } - _, err = buf.Write(v) + v, _ := gson.Encode(val) + content = append(content, v...) } - return err + return content } func (obj *OrderData) MarshalJSON() ([]byte, error) { - buf := bytes.NewBuffer(nil) - err := buf.WriteByte('{') - if err != nil { - return nil, err - } + content := []byte{} + content = append(content, '{') for i, value := range obj.data { if i > 0 { - if err = buf.WriteByte(','); err != nil { - return nil, err - } - } - if _, err = buf.WriteString(`"` + value.key + `":`); err != nil { - return nil, err + content = append(content, ',') } + content = append(content, []byte(`"`+value.key+`":`)...) val, err := gson.Encode(value.val) if err != nil { return nil, err } - if _, err = buf.Write(val); err != nil { - return nil, err - } + content = append(content, val...) } - if err = buf.WriteByte('}'); err != nil { - return nil, err - } - return buf.Bytes(), nil + content = append(content, '}') + return content, nil } -func (obj *RequestOption) newBody(val any) (io.Reader, *OrderData, error) { - switch value := val.(type) { - case *OrderData: - return nil, value, nil - case io.Reader: - obj.readOne = true - return value, nil, nil - case string: - return bytes.NewReader(tools.StringToBytes(value)), nil, nil - case []byte: - return bytes.NewReader(value), nil, nil - case map[string]any: - orderMap := NewOrderData() - for key, val := range value { - orderMap.Add(key, val) - } - return nil, orderMap, nil - default: - jsonData, err := gson.Decode(val) - if err != nil { - return nil, nil, errors.New("invalid body type") - } - orderMap := NewOrderData() - for kk, vv := range jsonData.Map() { - orderMap.Add(kk, vv.Value()) - } - return nil, orderMap, nil - } -} - -func (obj *OrderData) parseParams() *bytes.Buffer { - buf := bytes.NewBuffer(nil) +func (obj *OrderData) parseParams() []byte { + content := []byte{} for _, value := range obj.data { - paramsWrite(buf, value.key, value.val) + content = paramsWrite(content, value.key, value.val) } - return buf + return content } func (obj *OrderData) parseForm(ctx context.Context, boundary string) (io.Reader, bool, error) { if len(obj.data) == 0 { @@ -301,27 +256,6 @@ func (obj *OrderData) parseForm(ctx context.Context, boundary string) (io.Reader } return bytes.NewReader(body.Bytes()), false, err } -func (obj *OrderData) parseData() io.Reader { - val := obj.parseParams().Bytes() - if val == nil { - return nil - } - return bytes.NewReader(val) -} -func (obj *OrderData) parseJson() (io.Reader, error) { - con, err := obj.MarshalJSON() - if err != nil { - return nil, err - } - return bytes.NewReader(con), nil -} -func (obj *OrderData) parseText() (io.Reader, error) { - con, err := obj.MarshalJSON() - if err != nil { - return nil, err - } - return bytes.NewReader(con), nil -} // Upload files with form-data, type File struct { @@ -342,94 +276,178 @@ func randomBoundary() (string, string) { func (obj *RequestOption) initBody(ctx context.Context) (io.Reader, error) { if obj.Body != nil { - body, orderData, err := obj.newBody(obj.Body) - if err != nil { - return nil, err + switch value := obj.Body.(type) { + case io.Reader: + obj.readOne = true + return value, nil + case string: + return bytes.NewReader(tools.StringToBytes(value)), nil + case []byte: + return bytes.NewReader(value), nil + default: + content, err := gson.Encode(value) + if err != nil { + return nil, errors.New("invalid body type") + } + return bytes.NewReader(content), nil } - if body != nil { - return body, nil - } - con, err := orderData.MarshalJSON() - if err != nil { - return nil, err - } - return bytes.NewReader(con), nil } else if obj.Form != nil { var boundary string if obj.ContentType == "" { obj.ContentType, boundary = randomBoundary() } - body, orderData, err := obj.newBody(obj.Form) - if err != nil { - return nil, err - } - if body != nil { + switch value := obj.Form.(type) { + case *OrderData: + body, once, err := value.parseForm(ctx, boundary) + if err != nil { + return nil, err + } + obj.readOne = once + return body, nil + case io.Reader: + obj.readOne = true + return value, nil + case string: + return bytes.NewReader(tools.StringToBytes(value)), nil + case []byte: + return bytes.NewReader(value), nil + case map[string]any: + orderMap := NewOrderData() + for key, val := range value { + orderMap.Add(key, val) + } + body, once, err := orderMap.parseForm(ctx, boundary) + if err != nil { + return nil, err + } + obj.readOne = once + return body, nil + default: + jsonData, err := gson.Decode(value) + if err != nil { + return nil, errors.New("invalid body type") + } + orderMap := NewOrderData() + for kk, vv := range jsonData.Map() { + orderMap.Add(kk, vv.Value()) + } + body, once, err := orderMap.parseForm(ctx, boundary) + if err != nil { + return nil, err + } + obj.readOne = once return body, nil } - body, once, err := orderData.parseForm(ctx, boundary) - if err != nil { - return nil, err - } - obj.readOne = once - return body, err } else if obj.Data != nil { if obj.ContentType == "" { obj.ContentType = "application/x-www-form-urlencoded" } - body, orderData, err := obj.newBody(obj.Data) - if err != nil { - return nil, err + switch value := obj.Data.(type) { + case *OrderData: + return bytes.NewReader(value.parseParams()), nil + case io.Reader: + obj.readOne = true + return value, nil + case string: + return bytes.NewReader(tools.StringToBytes(value)), nil + case []byte: + return bytes.NewReader(value), nil + case map[string]any: + orderMap := NewOrderData() + for key, val := range value { + orderMap.Add(key, val) + } + return bytes.NewReader(orderMap.parseParams()), nil + default: + jsonData, err := gson.Decode(value) + if err != nil { + return nil, errors.New("invalid body type") + } + orderMap := NewOrderData() + for kk, vv := range jsonData.Map() { + orderMap.Add(kk, vv.Value()) + } + return bytes.NewReader(orderMap.parseParams()), nil } - if body != nil { - return body, nil - } - return orderData.parseData(), nil } else if obj.Json != nil { if obj.ContentType == "" { obj.ContentType = "application/json" } - body, orderData, err := obj.newBody(obj.Json) - if err != nil { - return nil, err + switch value := obj.Json.(type) { + case io.Reader: + obj.readOne = true + return value, nil + case string: + return bytes.NewReader(tools.StringToBytes(value)), nil + case []byte: + return bytes.NewReader(value), nil + default: + content, err := gson.Encode(value) + if err != nil { + return nil, errors.New("invalid body type") + } + return bytes.NewReader(content), nil } - if body != nil { - return body, nil - } - return orderData.parseJson() } else if obj.Text != nil { if obj.ContentType == "" { obj.ContentType = "text/plain" } - body, orderData, err := obj.newBody(obj.Text) - if err != nil { - return nil, err + switch value := obj.Text.(type) { + case io.Reader: + obj.readOne = true + return value, nil + case string: + return bytes.NewReader(tools.StringToBytes(value)), nil + case []byte: + return bytes.NewReader(value), nil + default: + content, err := gson.Encode(value) + if err != nil { + return nil, errors.New("invalid body type") + } + return bytes.NewReader(content), nil } - if body != nil { - return body, nil - } - return orderData.parseText() } else { return nil, nil } } + func (obj *RequestOption) initParams() (*url.URL, error) { baseUrl := cloneUrl(obj.Url) if obj.Params == nil { return baseUrl, nil } - body, dataData, err := obj.newBody(obj.Params) - if err != nil { - return nil, err - } var query string - if body != nil { - paramsBytes, err := io.ReadAll(body) + switch value := obj.Params.(type) { + case *OrderData: + query = tools.BytesToString(value.parseParams()) + case io.Reader: + obj.readOne = true + con, err := io.ReadAll(value) if err != nil { return nil, err } - query = tools.BytesToString(paramsBytes) - } else { - query = dataData.parseParams().String() + query = tools.BytesToString(con) + case string: + query = value + case []byte: + query = tools.BytesToString(value) + case map[string]any: + orderMap := NewOrderData() + for key, val := range value { + orderMap.Add(key, val) + } + query = tools.BytesToString(orderMap.parseParams()) + default: + jsonData, err := gson.Decode(value) + if err != nil { + return nil, errors.New("invalid body type") + } + orderMap := NewOrderData() + for kk, vv := range jsonData.Map() { + orderMap.Add(kk, vv.Value()) + } + query = tools.BytesToString(orderMap.parseParams()) } if query == "" { return baseUrl, nil diff --git a/roundTripper.go b/roundTripper.go index d5586dd..0796464 100644 --- a/roundTripper.go +++ b/roundTripper.go @@ -306,10 +306,6 @@ func (obj *roundTripper) RoundTrip(ctx *Response) (err error) { select { case <-ctx.Context().Done(): return context.Cause(ctx.Context()) - default: - } - ctx.response = nil - select { case conn = <-obj.getConnPool(ctx.connKey): ctx.isNewConn = false default: @@ -317,8 +313,7 @@ func (obj *roundTripper) RoundTrip(ctx *Response) (err error) { conn, err = obj.newConn(ctx) } if err == nil { - err = ctx.doRequest(conn) - if err == nil { + if err = ctx.doRequest(conn); err == nil { break } } diff --git a/test/request/stream_test.go b/test/request/stream_test.go index 8aaabfa..16393bd 100644 --- a/test/request/stream_test.go +++ b/test/request/stream_test.go @@ -86,7 +86,7 @@ func TestStreamWithConn3(t *testing.T) { t.Fatal(err) } if resp.StatusCode() != 200 { - log.Print("状态吗为:",resp.StatusCode()) + log.Print("状态吗为:", resp.StatusCode()) t.Fatal("resp.StatusCode()!= 200") } // body := resp.Body()