更新流式传输操作

This commit is contained in:
Liujian
2025-02-26 11:13:10 +08:00
parent 8a59d592d0
commit cad39db9a4
12 changed files with 204 additions and 185 deletions

View File

@@ -7,7 +7,6 @@ import (
"net"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
@@ -17,9 +16,9 @@ import (
"github.com/valyala/fasthttp"
)
func ProxyTimeout(scheme string, host string, node eocontext.INode, req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration, stream bool) error {
func ProxyTimeout(scheme string, host string, node eocontext.INode, req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration) error {
addr := fmt.Sprintf("%s://%s", scheme, node.Addr())
err := defaultClient.ProxyTimeout(addr, host, req, resp, timeout, stream)
err := defaultClient.ProxyTimeout(addr, host, req, resp, timeout)
if err != nil {
node.Down()
}
@@ -81,7 +80,7 @@ func GenDialFunc(isTls bool) (fasthttp.DialFunc, error) {
return Dial, nil
}
func (c *Client) getHostClient(addr string, rewriteHost string, stream bool) (*fasthttp.HostClient, string, error) {
func (c *Client) getHostClient(addr string, rewriteHost string) (*fasthttp.HostClient, string, error) {
scheme, nodeAddr := readAddress(addr)
host := nodeAddr
@@ -96,7 +95,7 @@ func (c *Client) getHostClient(addr string, rewriteHost string, stream bool) (*f
if isTLS {
m = c.ms
}
key := fmt.Sprintf("%s-%s", host, strconv.FormatBool(stream))
key := host
hc := m[key]
c.mLock.RUnlock()
if hc != nil {
@@ -142,7 +141,7 @@ func (c *Client) getHostClient(addr string, rewriteHost string, stream bool) (*f
InsecureSkipVerify: true,
},
Dial: dial,
StreamResponseBody: stream,
StreamResponseBody: true,
MaxConns: DefaultMaxConns,
MaxConnWaitTimeout: DefaultMaxConnWaitTimeout,
RetryIf: func(request *fasthttp.Request) bool {
@@ -186,7 +185,7 @@ func (c *Client) getHostClient(addr string, rewriteHost string, stream bool) (*f
// continue in the background and the response will be discarded.
// If requests take too long and the connection pool gets filled up please
// try setting a ReadTimeout.
func (c *Client) ProxyTimeout(addr string, host string, req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration, stream bool) error {
func (c *Client) ProxyTimeout(addr string, host string, req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration) error {
request := req
request.Header.ResetConnectionClose()
request.Header.Set("Connection", "keep-alive")
@@ -197,7 +196,7 @@ func (c *Client) ProxyTimeout(addr string, host string, req *fasthttp.Request, r
}
}()
client, scheme, err := c.getHostClient(addr, host, stream)
client, scheme, err := c.getHostClient(addr, host)
if err != nil {
return err
}

View File

@@ -18,7 +18,7 @@ func TestMyselfProxyTimeout(t *testing.T) {
req.Header.SetContentType("application/json")
t.Log(string(req.URI().RequestURI()), req.URI().String(), string(req.URI().Host()), string(req.URI().Scheme()))
req.SetBody([]byte(`{"cpCode":"YTO","province":"广东省","city":"广州市"}`))
err := defaultClient.ProxyTimeout(addr, "", req, resp, 0, false)
err := defaultClient.ProxyTimeout(addr, "", req, resp, 0)
if err != nil {
t.Error(err)
}

View File

@@ -37,9 +37,24 @@ type cloneContext struct {
upstreamHostHandler eoscContext.UpstreamHostHandler
labels map[string]string
entry eosc.IEntry
bodyFinishes []http_service.BodyFinishFunc
responseError error
}
func (ctx *cloneContext) BodyFinish() {
for _, finishFunc := range ctx.bodyFinishes {
finishFunc(ctx)
}
return
}
func (ctx *cloneContext) AppendBodyFinishFunc(finishFunc http_service.BodyFinishFunc) {
if ctx.bodyFinishes == nil {
ctx.bodyFinishes = make([]http_service.BodyFinishFunc, 0, 10)
}
ctx.bodyFinishes = append(ctx.bodyFinishes, finishFunc)
}
func (ctx *cloneContext) ProxyClone() http_service.IRequest {
// 创建一个新的 ProxyRequest 实例
req := fasthttp.AcquireRequest()
@@ -167,7 +182,7 @@ func (ctx *cloneContext) SendTo(scheme string, node eoscContext.INode, timeout t
request.URI().SetHost(targetHost)
}
beginTime := time.Now()
ctx.responseError = fasthttp_client.ProxyTimeout(scheme, rewriteHost, node, request, ctx.response.Response, timeout, false)
ctx.responseError = fasthttp_client.ProxyTimeout(scheme, rewriteHost, node, request, ctx.response.Response, timeout)
var responseHeader fasthttp.ResponseHeader
if ctx.response.Response != nil {
responseHeader = ctx.response.Response.Header

View File

@@ -1,6 +1,7 @@
package http_context
import (
"bufio"
"context"
"fmt"
"io"
@@ -9,12 +10,11 @@ import (
"strings"
"time"
"github.com/eolinker/apinto/entries/ctx_key"
http_entry "github.com/eolinker/apinto/entries/http-entry"
"github.com/eolinker/eosc"
"github.com/eolinker/apinto/entries/ctx_key"
"github.com/eolinker/eosc/log"
"github.com/eolinker/eosc/utils/config"
@@ -45,6 +45,20 @@ type HttpContext struct {
labels map[string]string
port int
entry eosc.IEntry
bodyFinishes []http_service.BodyFinishFunc
}
func (ctx *HttpContext) BodyFinish() {
for _, finishFunc := range ctx.bodyFinishes {
finishFunc(ctx)
}
}
func (ctx *HttpContext) AppendBodyFinishFunc(finishFunc http_service.BodyFinishFunc) {
if ctx.bodyFinishes == nil {
ctx.bodyFinishes = make([]http_service.BodyFinishFunc, 0, 10)
}
ctx.bodyFinishes = append(ctx.bodyFinishes, finishFunc)
}
func (ctx *HttpContext) ProxyClone() http_service.IRequest {
@@ -182,7 +196,7 @@ func (ctx *HttpContext) SendTo(scheme string, node eoscContext.INode, timeout ti
beginTime := time.Now()
response := fasthttp.AcquireResponse()
ctx.response.responseError = fasthttp_client.ProxyTimeout(scheme, rewriteHost, node, request, response, timeout, ctx.GetLabel("stream") == "true")
ctx.response.responseError = fasthttp_client.ProxyTimeout(scheme, rewriteHost, node, request, response, timeout)
agent := newRequestAgent(&ctx.proxyRequest, host, scheme, response.Header, beginTime, time.Now())
@@ -202,23 +216,49 @@ func (ctx *HttpContext) SendTo(scheme string, node eoscContext.INode, timeout ti
response.Header.CopyTo(&ctx.response.Response.Header)
ctx.response.ResponseHeader.refresh()
if response.IsBodyStream() && response.Header.ContentLength() < 0 {
reader := response.BodyStream()
// 流式传输
ctx.response.Response.SetStatusCode(response.StatusCode())
buf := make([]byte, 4096)
for {
n, err := reader.Read(buf)
if err == io.EOF {
return nil
ctx.SetLabel("stream_running", "true")
ctx.response.Response.SetBodyStreamWriter(func(w *bufio.Writer) {
defer func() {
ctx.SetLabel("stream_running", "false")
ctx.FastFinish()
}()
reader := response.BodyStream()
buffer := make([]byte, 4096) // 4KB 缓冲区
for {
n, err := reader.Read(buffer)
if n > 0 {
chunk := buffer[:n]
for _, streamFunc := range ctx.Response().StreamFunc() {
chunk, err = streamFunc(ctx, chunk)
if err != nil {
log.Errorf("exec stream func error: %v", err)
break
}
chunk = append(chunk, []byte("\r\n")...)
}
n, err = w.Write(chunk)
if err != nil {
log.Errorf("stream write error: %v", err)
break
}
ctx.Response().SetBody(chunk)
w.Flush() // 实时发送数据
}
if err != nil {
if err == io.EOF {
break
}
log.Errorf("stream read error: %v", err)
break
}
}
if err != nil {
return err
}
agent.responseBody.Write(response.Body()[:n])
_, err = ctx.response.Response.BodyWriter().Write(buf[:n])
if err != nil {
return err
}
}
ctx.BodyFinish()
})
agent.setResponseLength(-1)
ctx.proxyRequests = append(ctx.proxyRequests, agent)
return nil
@@ -280,19 +320,27 @@ func (ctx *HttpContext) Clone() (eoscContext.EoContext, error) {
copyContext.proxyRequest.reset(req, ctx.requestReader.remoteAddr)
copyContext.response.reset(resp)
resp.Header.CopyTo(copyContext.response.header)
copyContext.response.refresh()
copyContext.completeHandler = ctx.completeHandler
copyContext.finishHandler = ctx.finishHandler
copyContext.response.Response.SetStatusCode(ctx.response.Response.StatusCode())
cloneLabels := make(map[string]string, len(ctx.labels))
for k, v := range ctx.labels {
cloneLabels[k] = v
}
copyContext.labels = cloneLabels
for _, finishFunc := range ctx.bodyFinishes {
copyContext.AppendBodyFinishFunc(finishFunc)
}
for _, streamFunc := range ctx.response.streamFuncArray {
copyContext.Response().AppendStreamFunc(streamFunc)
}
//记录请求时间
copyContext.ctx = context.WithValue(ctx.Context(), http_service.KeyCloneCtx, true)
copyContext.WithValue(ctx_key.CtxKeyRetry, 0)
return copyContext, nil
}
@@ -334,8 +382,13 @@ func (ctx *HttpContext) RequestId() string {
return ctx.requestID
}
// Finish finish
// FastFinish finish
func (ctx *HttpContext) FastFinish() {
streamRunning := ctx.GetLabel("stream_running")
if streamRunning == "true" {
// 暂时不释放
return
}
if ctx.response.responseError != nil {
ctx.fastHttpRequestCtx.SetStatusCode(504)
ctx.fastHttpRequestCtx.SetBodyString(ctx.response.responseError.Error())

View File

@@ -1,34 +0,0 @@
package http_context
import (
"io"
"strings"
"github.com/eolinker/eosc/log"
)
type Reader struct {
reader io.Reader
agent *requestAgent
record strings.Builder
requestId string
resp *Response
}
func (r *Reader) Read(p []byte) (int, error) {
n, err := r.reader.Read(p)
if err != nil {
log.Debug("read error:", err)
log.DebugF("request id %s ,read body: %s", r.requestId, r.record.String())
return 0, err
}
r.record.Write(p[:n])
if r.agent != nil {
r.agent.responseBody.Write(p[:n])
}
if r.resp != nil {
r.resp.AppendBody(p[:n])
}
return n, nil
}

View File

@@ -2,14 +2,11 @@ package http_context
import (
"bytes"
"io"
"strconv"
"strings"
"sync"
"time"
http_service "github.com/eolinker/eosc/eocontext/http-context"
"go.uber.org/zap/buffer"
"github.com/valyala/fasthttp"
)
@@ -19,75 +16,39 @@ var _ http_service.IResponse = (*Response)(nil)
type Response struct {
ResponseHeader
*fasthttp.Response
statusCode int
length int
responseTime time.Duration
proxyStatusCode int
responseError error
remoteIP string
remotePort int
buf *buffer.Buffer
streamBody *bytes.Buffer
streamFuncArray []http_service.StreamFunc
}
type BodyStream struct {
reader io.Reader
streamReadHandler []func(p []byte) error
streamWriteHandler []func(p []byte) ([]byte, error)
func (r *Response) StreamFunc() []http_service.StreamFunc {
return r.streamFuncArray
}
func NewBodyStream(reader io.Reader) *BodyStream {
buf := &bytes.Buffer{}
buf.Bytes()
return &BodyStream{reader: reader}
}
func (b *BodyStream) AppendReaderFunc(f func(p []byte) error) {
if b.streamReadHandler == nil {
b.streamReadHandler = make([]func(p []byte) error, 0)
func (r *Response) AppendStreamFunc(streamFunc http_service.StreamFunc) {
if r.streamFuncArray == nil {
r.streamFuncArray = make([]http_service.StreamFunc, 0, 10)
}
b.streamReadHandler = append(b.streamReadHandler, f)
r.streamFuncArray = append(r.streamFuncArray, streamFunc)
}
func (b *BodyStream) AppendWriterFunc(f func(p []byte) ([]byte, error)) {
if b.streamWriteHandler == nil {
b.streamWriteHandler = make([]func(p []byte) ([]byte, error), 0)
}
b.streamWriteHandler = append(b.streamWriteHandler, f)
}
var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 32*1024) // 默认 32KB 缓冲区
},
}
func (b *BodyStream) Read(p []byte) (n int, err error) {
tmp := bufferPool.Get().([]byte)
defer bufferPool.Put(tmp)
n, err = b.reader.Read(tmp)
if err != nil {
return 0, err
}
org := tmp[:n]
for _, fn := range b.streamWriteHandler {
result, err := fn(org)
func (r *Response) StreamFuncHandle(ctx http_service.IHttpContext, org []byte) ([]byte, error) {
result := make([]byte, len(org))
copy(result, org)
var err error
for _, streamFunc := range r.streamFuncArray {
result, err = streamFunc(ctx, result)
if err != nil {
return 0, err
return nil, err
}
org = result
}
org = append(org, []byte("\n")...)
copy(p, org)
return len(org), nil
}
func (b *BodyStream) Write(p []byte) (n int, err error) {
return 0, nil
}
func (r *Response) GetBodyStream() http_service.IResponseStream {
return r.bodyStream
return result, nil
}
func (r *Response) ContentLength() int {
@@ -117,6 +78,7 @@ func (r *Response) Finish() error {
r.Response = nil
r.responseError = nil
r.proxyStatusCode = 0
r.streamBody = nil
return nil
}
func (r *Response) reset(resp *fasthttp.Response) {
@@ -124,6 +86,7 @@ func (r *Response) reset(resp *fasthttp.Response) {
r.ResponseHeader.reset(&resp.Header)
r.responseError = nil
r.proxyStatusCode = 0
r.streamBody = &bytes.Buffer{}
}
func (r *Response) BodyLen() int {
@@ -138,7 +101,7 @@ func (r *Response) GetBody() []byte {
r.Response.SetBody(body)
}
if r.IsBodyStream() {
return nil
return r.streamBody.Bytes()
}
return r.Response.Body()
}
@@ -149,6 +112,8 @@ func (r *Response) IsBodyStream() bool {
func (r *Response) SetBody(bytes []byte) {
if r.IsBodyStream() {
r.streamBody.Write(bytes)
// 不处理
return
}
if strings.Contains(r.GetHeader("Content-Encoding"), "gzip") {
@@ -164,15 +129,21 @@ func (r *Response) StatusCode() int {
if r.responseError != nil {
return 504
}
return r.Response.StatusCode()
return r.statusCode
}
func (r *Response) Status() string {
return strconv.Itoa(r.StatusCode())
if r.statusCode == 0 {
r.statusCode = r.Response.StatusCode()
}
return strconv.Itoa(r.statusCode)
}
func (r *Response) SetStatus(code int, status string) {
r.Response.SetStatusCode(code)
r.statusCode = code
r.responseError = nil
}