mirror of
https://github.com/eolinker/apinto
synced 2025-11-02 22:54:02 +08:00
更新流式传输操作
This commit is contained in:
@@ -73,9 +73,7 @@ func (c *Chat) RequestConvert(ctx eocontext.EoContext, extender map[string]inter
|
||||
for k, v := range extender {
|
||||
baseCfg.SetAppend(k, v)
|
||||
}
|
||||
//if paramModel := ctx.GetLabel(convert.ParamAIModel); paramModel != "" {
|
||||
// baseCfg.SetAppend("model", paramModel)
|
||||
//}
|
||||
|
||||
// Marshal the updated configuration back into JSON.
|
||||
body, err = json.Marshal(baseCfg)
|
||||
if err != nil {
|
||||
@@ -87,10 +85,59 @@ func (c *Chat) RequestConvert(ctx eocontext.EoContext, extender map[string]inter
|
||||
|
||||
// SetProvider the modified body in the HTTP context.
|
||||
httpContext.Proxy().Body().SetRaw("application/json", body)
|
||||
|
||||
httpContext.Response().AppendStreamFunc(c.streamFunc())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Chat) streamFunc() http_context.StreamFunc {
|
||||
return func(ctx http_context.IHttpContext, p []byte) ([]byte, error) {
|
||||
data := eosc.NewBase[Response]()
|
||||
err := json.Unmarshal(p, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status := ctx.Response().StatusCode()
|
||||
switch status {
|
||||
case 200:
|
||||
// Calculate the token consumption for a successful request.
|
||||
usage := data.Config
|
||||
if usage.Done {
|
||||
convert.SetAIStatusNormal(ctx)
|
||||
convert.SetAIModelInputToken(ctx, usage.PromptEvalCount)
|
||||
convert.SetAIModelOutputToken(ctx, usage.EvalCount)
|
||||
convert.SetAIModelTotalToken(ctx, usage.PromptEvalCount+usage.EvalCount)
|
||||
}
|
||||
case 404:
|
||||
convert.SetAIStatusInvalid(ctx)
|
||||
case 429:
|
||||
convert.SetAIStatusExceeded(ctx)
|
||||
}
|
||||
|
||||
// Prepare the response body for the client.
|
||||
responseBody := &convert.ClientResponse{}
|
||||
resp := data.Config
|
||||
if resp.Message != nil {
|
||||
responseBody.Message = &convert.Message{
|
||||
Role: resp.Message.Role,
|
||||
Content: resp.Message.Content,
|
||||
}
|
||||
if resp.Done {
|
||||
responseBody.FinishReason = convert.FinishStop
|
||||
}
|
||||
} else {
|
||||
responseBody.Code = -1
|
||||
responseBody.Error = "response message is nil"
|
||||
}
|
||||
|
||||
// Marshal the modified response body back into JSON.
|
||||
body, err := json.Marshal(responseBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ResponseConvert converts the response body for the Chat mode.
|
||||
// It processes the response to ensure it conforms to the expected format and encoding.
|
||||
func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error {
|
||||
@@ -99,57 +146,14 @@ func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status := httpContext.Response().StatusCode()
|
||||
switch status {
|
||||
case 200:
|
||||
convert.SetAIStatusNormal(ctx)
|
||||
}
|
||||
if httpContext.Response().IsBodyStream() {
|
||||
bodyStream := httpContext.Response().GetBodyStream()
|
||||
if bodyStream != nil {
|
||||
bodyStream.AppendWriterFunc(func(p []byte) ([]byte, error) {
|
||||
// Parse the response body into a base configuration.
|
||||
data := eosc.NewBase[Response]()
|
||||
err = json.Unmarshal(p, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch status {
|
||||
case 200:
|
||||
// Calculate the token consumption for a successful request.
|
||||
usage := data.Config
|
||||
if usage.Done {
|
||||
convert.SetAIStatusNormal(ctx)
|
||||
convert.SetAIModelInputToken(ctx, usage.PromptEvalCount)
|
||||
convert.SetAIModelOutputToken(ctx, usage.EvalCount)
|
||||
convert.SetAIModelTotalToken(ctx, usage.PromptEvalCount+usage.EvalCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare the response body for the client.
|
||||
responseBody := &convert.ClientResponse{}
|
||||
resp := data.Config
|
||||
if resp.Message != nil {
|
||||
responseBody.Message = &convert.Message{
|
||||
Role: resp.Message.Role,
|
||||
Content: resp.Message.Content,
|
||||
}
|
||||
if resp.Done {
|
||||
responseBody.FinishReason = convert.FinishStop
|
||||
}
|
||||
} else {
|
||||
responseBody.Code = -1
|
||||
responseBody.Error = "response message is nil"
|
||||
}
|
||||
|
||||
// Marshal the modified response body back into JSON.
|
||||
body, err := json.Marshal(responseBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Retrieve the response body.
|
||||
|
||||
@@ -24,22 +24,31 @@ func (l *accessLog) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (er
|
||||
return http_service.DoHttpFilter(l, ctx, next)
|
||||
}
|
||||
|
||||
func (l *accessLog) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) (err error) {
|
||||
err = next.DoChain(ctx)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
func (l *accessLog) bodyFinish(ctx http_service.IHttpContext) {
|
||||
outputs := l.proxy.List()
|
||||
entry := http_entry.NewEntry(ctx)
|
||||
for _, v := range outputs {
|
||||
|
||||
err = v.Output(entry)
|
||||
err := v.Output(entry)
|
||||
if err != nil {
|
||||
log.Error("access log http-entry error:", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *accessLog) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) (err error) {
|
||||
ctx.AppendBodyFinishFunc(l.bodyFinish)
|
||||
err = next.DoChain(ctx)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
if ctx.Response().IsBodyStream() {
|
||||
return nil
|
||||
}
|
||||
|
||||
l.bodyFinish(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -213,7 +213,9 @@ func (e *executor) processKeyPool(ctx http_context.IHttpContext, cloneProxy http
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.Response().IsBodyStream() {
|
||||
return nil
|
||||
}
|
||||
if err = converter.ResponseConvert(ctx); err != nil {
|
||||
convert.SetAIProviderStatuses(ctx, convert.AIProviderStatus{
|
||||
Provider: e.provider,
|
||||
@@ -235,9 +237,10 @@ func (e *executor) processKeyPool(ctx http_context.IHttpContext, cloneProxy http
|
||||
return nil
|
||||
default:
|
||||
continue
|
||||
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("")
|
||||
return fmt.Errorf("all key resources for provider %s is invalid", e.provider)
|
||||
}
|
||||
|
||||
// handleNoKeyResource handles the case when no key resources are available.
|
||||
|
||||
@@ -131,7 +131,7 @@ func (h *complete) Complete(org eocontext.EoContext) error {
|
||||
}
|
||||
response := fasthttp.AcquireResponse()
|
||||
|
||||
lastErr = fasthttp_client.ProxyTimeout(scheme, host, node, request, response, timeOut, false)
|
||||
lastErr = fasthttp_client.ProxyTimeout(scheme, host, node, request, response, timeOut)
|
||||
if lastErr == nil {
|
||||
return newGRPCResponse(ctx, response, methodDesc)
|
||||
}
|
||||
|
||||
@@ -38,7 +38,6 @@ func (h *HttpComplete) Complete(org eocontext.EoContext) error {
|
||||
defer func() {
|
||||
//设置原始响应状态码
|
||||
ctx.Response().SetProxyStatus(ctx.Response().StatusCode(), "")
|
||||
//ctx.WithValue("response_time", time.Now().Sub(proxyTime).Milliseconds())
|
||||
ctx.Response().SetResponseTime(time.Since(proxyTime))
|
||||
ctx.SetLabel("handler", "proxy")
|
||||
}()
|
||||
|
||||
4
go.mod
4
go.mod
@@ -11,7 +11,7 @@ require (
|
||||
github.com/clbanning/mxj v1.8.4
|
||||
github.com/coocood/freecache v1.2.2
|
||||
github.com/dubbogo/gost v1.13.1
|
||||
github.com/eolinker/eosc v0.20.1
|
||||
github.com/eolinker/eosc v0.20.2
|
||||
github.com/fasthttp/websocket v1.5.0
|
||||
github.com/fullstorydev/grpcurl v1.8.7
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
@@ -197,4 +197,4 @@ require (
|
||||
|
||||
replace github.com/soheilhy/cmux v0.1.5 => github.com/hmzzrcs/cmux v0.1.6
|
||||
|
||||
replace github.com/eolinker/eosc => ../eosc
|
||||
//replace github.com/eolinker/eosc => ../eosc
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -17,9 +16,9 @@ import (
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func ProxyTimeout(scheme string, host string, node eocontext.INode, req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration, stream bool) error {
|
||||
func ProxyTimeout(scheme string, host string, node eocontext.INode, req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration) error {
|
||||
addr := fmt.Sprintf("%s://%s", scheme, node.Addr())
|
||||
err := defaultClient.ProxyTimeout(addr, host, req, resp, timeout, stream)
|
||||
err := defaultClient.ProxyTimeout(addr, host, req, resp, timeout)
|
||||
if err != nil {
|
||||
node.Down()
|
||||
}
|
||||
@@ -81,7 +80,7 @@ func GenDialFunc(isTls bool) (fasthttp.DialFunc, error) {
|
||||
return Dial, nil
|
||||
}
|
||||
|
||||
func (c *Client) getHostClient(addr string, rewriteHost string, stream bool) (*fasthttp.HostClient, string, error) {
|
||||
func (c *Client) getHostClient(addr string, rewriteHost string) (*fasthttp.HostClient, string, error) {
|
||||
|
||||
scheme, nodeAddr := readAddress(addr)
|
||||
host := nodeAddr
|
||||
@@ -96,7 +95,7 @@ func (c *Client) getHostClient(addr string, rewriteHost string, stream bool) (*f
|
||||
if isTLS {
|
||||
m = c.ms
|
||||
}
|
||||
key := fmt.Sprintf("%s-%s", host, strconv.FormatBool(stream))
|
||||
key := host
|
||||
hc := m[key]
|
||||
c.mLock.RUnlock()
|
||||
if hc != nil {
|
||||
@@ -142,7 +141,7 @@ func (c *Client) getHostClient(addr string, rewriteHost string, stream bool) (*f
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
Dial: dial,
|
||||
StreamResponseBody: stream,
|
||||
StreamResponseBody: true,
|
||||
MaxConns: DefaultMaxConns,
|
||||
MaxConnWaitTimeout: DefaultMaxConnWaitTimeout,
|
||||
RetryIf: func(request *fasthttp.Request) bool {
|
||||
@@ -186,7 +185,7 @@ func (c *Client) getHostClient(addr string, rewriteHost string, stream bool) (*f
|
||||
// continue in the background and the response will be discarded.
|
||||
// If requests take too long and the connection pool gets filled up please
|
||||
// try setting a ReadTimeout.
|
||||
func (c *Client) ProxyTimeout(addr string, host string, req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration, stream bool) error {
|
||||
func (c *Client) ProxyTimeout(addr string, host string, req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration) error {
|
||||
request := req
|
||||
request.Header.ResetConnectionClose()
|
||||
request.Header.Set("Connection", "keep-alive")
|
||||
@@ -197,7 +196,7 @@ func (c *Client) ProxyTimeout(addr string, host string, req *fasthttp.Request, r
|
||||
}
|
||||
}()
|
||||
|
||||
client, scheme, err := c.getHostClient(addr, host, stream)
|
||||
client, scheme, err := c.getHostClient(addr, host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestMyselfProxyTimeout(t *testing.T) {
|
||||
req.Header.SetContentType("application/json")
|
||||
t.Log(string(req.URI().RequestURI()), req.URI().String(), string(req.URI().Host()), string(req.URI().Scheme()))
|
||||
req.SetBody([]byte(`{"cpCode":"YTO","province":"广东省","city":"广州市"}`))
|
||||
err := defaultClient.ProxyTimeout(addr, "", req, resp, 0, false)
|
||||
err := defaultClient.ProxyTimeout(addr, "", req, resp, 0)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
@@ -37,9 +37,24 @@ type cloneContext struct {
|
||||
upstreamHostHandler eoscContext.UpstreamHostHandler
|
||||
labels map[string]string
|
||||
entry eosc.IEntry
|
||||
bodyFinishes []http_service.BodyFinishFunc
|
||||
responseError error
|
||||
}
|
||||
|
||||
func (ctx *cloneContext) BodyFinish() {
|
||||
for _, finishFunc := range ctx.bodyFinishes {
|
||||
finishFunc(ctx)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ctx *cloneContext) AppendBodyFinishFunc(finishFunc http_service.BodyFinishFunc) {
|
||||
if ctx.bodyFinishes == nil {
|
||||
ctx.bodyFinishes = make([]http_service.BodyFinishFunc, 0, 10)
|
||||
}
|
||||
ctx.bodyFinishes = append(ctx.bodyFinishes, finishFunc)
|
||||
}
|
||||
|
||||
func (ctx *cloneContext) ProxyClone() http_service.IRequest {
|
||||
// 创建一个新的 ProxyRequest 实例
|
||||
req := fasthttp.AcquireRequest()
|
||||
@@ -167,7 +182,7 @@ func (ctx *cloneContext) SendTo(scheme string, node eoscContext.INode, timeout t
|
||||
request.URI().SetHost(targetHost)
|
||||
}
|
||||
beginTime := time.Now()
|
||||
ctx.responseError = fasthttp_client.ProxyTimeout(scheme, rewriteHost, node, request, ctx.response.Response, timeout, false)
|
||||
ctx.responseError = fasthttp_client.ProxyTimeout(scheme, rewriteHost, node, request, ctx.response.Response, timeout)
|
||||
var responseHeader fasthttp.ResponseHeader
|
||||
if ctx.response.Response != nil {
|
||||
responseHeader = ctx.response.Response.Header
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package http_context
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -9,12 +10,11 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/eolinker/apinto/entries/ctx_key"
|
||||
http_entry "github.com/eolinker/apinto/entries/http-entry"
|
||||
|
||||
"github.com/eolinker/eosc"
|
||||
|
||||
"github.com/eolinker/apinto/entries/ctx_key"
|
||||
|
||||
"github.com/eolinker/eosc/log"
|
||||
"github.com/eolinker/eosc/utils/config"
|
||||
|
||||
@@ -45,6 +45,20 @@ type HttpContext struct {
|
||||
labels map[string]string
|
||||
port int
|
||||
entry eosc.IEntry
|
||||
bodyFinishes []http_service.BodyFinishFunc
|
||||
}
|
||||
|
||||
func (ctx *HttpContext) BodyFinish() {
|
||||
for _, finishFunc := range ctx.bodyFinishes {
|
||||
finishFunc(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *HttpContext) AppendBodyFinishFunc(finishFunc http_service.BodyFinishFunc) {
|
||||
if ctx.bodyFinishes == nil {
|
||||
ctx.bodyFinishes = make([]http_service.BodyFinishFunc, 0, 10)
|
||||
}
|
||||
ctx.bodyFinishes = append(ctx.bodyFinishes, finishFunc)
|
||||
}
|
||||
|
||||
func (ctx *HttpContext) ProxyClone() http_service.IRequest {
|
||||
@@ -182,7 +196,7 @@ func (ctx *HttpContext) SendTo(scheme string, node eoscContext.INode, timeout ti
|
||||
|
||||
beginTime := time.Now()
|
||||
response := fasthttp.AcquireResponse()
|
||||
ctx.response.responseError = fasthttp_client.ProxyTimeout(scheme, rewriteHost, node, request, response, timeout, ctx.GetLabel("stream") == "true")
|
||||
ctx.response.responseError = fasthttp_client.ProxyTimeout(scheme, rewriteHost, node, request, response, timeout)
|
||||
|
||||
agent := newRequestAgent(&ctx.proxyRequest, host, scheme, response.Header, beginTime, time.Now())
|
||||
|
||||
@@ -202,23 +216,49 @@ func (ctx *HttpContext) SendTo(scheme string, node eoscContext.INode, timeout ti
|
||||
response.Header.CopyTo(&ctx.response.Response.Header)
|
||||
ctx.response.ResponseHeader.refresh()
|
||||
if response.IsBodyStream() && response.Header.ContentLength() < 0 {
|
||||
reader := response.BodyStream()
|
||||
// 流式传输
|
||||
ctx.response.Response.SetStatusCode(response.StatusCode())
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
ctx.SetLabel("stream_running", "true")
|
||||
ctx.response.Response.SetBodyStreamWriter(func(w *bufio.Writer) {
|
||||
defer func() {
|
||||
ctx.SetLabel("stream_running", "false")
|
||||
ctx.FastFinish()
|
||||
}()
|
||||
reader := response.BodyStream()
|
||||
buffer := make([]byte, 4096) // 4KB 缓冲区
|
||||
for {
|
||||
n, err := reader.Read(buffer)
|
||||
if n > 0 {
|
||||
chunk := buffer[:n]
|
||||
for _, streamFunc := range ctx.Response().StreamFunc() {
|
||||
chunk, err = streamFunc(ctx, chunk)
|
||||
if err != nil {
|
||||
log.Errorf("exec stream func error: %v", err)
|
||||
break
|
||||
}
|
||||
chunk = append(chunk, []byte("\r\n")...)
|
||||
}
|
||||
|
||||
n, err = w.Write(chunk)
|
||||
if err != nil {
|
||||
log.Errorf("stream write error: %v", err)
|
||||
break
|
||||
}
|
||||
ctx.Response().SetBody(chunk)
|
||||
|
||||
w.Flush() // 实时发送数据
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
log.Errorf("stream read error: %v", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
agent.responseBody.Write(response.Body()[:n])
|
||||
_, err = ctx.response.Response.BodyWriter().Write(buf[:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
ctx.BodyFinish()
|
||||
})
|
||||
|
||||
agent.setResponseLength(-1)
|
||||
ctx.proxyRequests = append(ctx.proxyRequests, agent)
|
||||
return nil
|
||||
@@ -280,19 +320,27 @@ func (ctx *HttpContext) Clone() (eoscContext.EoContext, error) {
|
||||
|
||||
copyContext.proxyRequest.reset(req, ctx.requestReader.remoteAddr)
|
||||
copyContext.response.reset(resp)
|
||||
|
||||
resp.Header.CopyTo(copyContext.response.header)
|
||||
copyContext.response.refresh()
|
||||
copyContext.completeHandler = ctx.completeHandler
|
||||
copyContext.finishHandler = ctx.finishHandler
|
||||
|
||||
copyContext.response.Response.SetStatusCode(ctx.response.Response.StatusCode())
|
||||
cloneLabels := make(map[string]string, len(ctx.labels))
|
||||
for k, v := range ctx.labels {
|
||||
cloneLabels[k] = v
|
||||
}
|
||||
copyContext.labels = cloneLabels
|
||||
for _, finishFunc := range ctx.bodyFinishes {
|
||||
copyContext.AppendBodyFinishFunc(finishFunc)
|
||||
}
|
||||
for _, streamFunc := range ctx.response.streamFuncArray {
|
||||
copyContext.Response().AppendStreamFunc(streamFunc)
|
||||
}
|
||||
|
||||
//记录请求时间
|
||||
copyContext.ctx = context.WithValue(ctx.Context(), http_service.KeyCloneCtx, true)
|
||||
copyContext.WithValue(ctx_key.CtxKeyRetry, 0)
|
||||
|
||||
return copyContext, nil
|
||||
}
|
||||
|
||||
@@ -334,8 +382,13 @@ func (ctx *HttpContext) RequestId() string {
|
||||
return ctx.requestID
|
||||
}
|
||||
|
||||
// Finish finish
|
||||
// FastFinish finish
|
||||
func (ctx *HttpContext) FastFinish() {
|
||||
streamRunning := ctx.GetLabel("stream_running")
|
||||
if streamRunning == "true" {
|
||||
// 暂时不释放
|
||||
return
|
||||
}
|
||||
if ctx.response.responseError != nil {
|
||||
ctx.fastHttpRequestCtx.SetStatusCode(504)
|
||||
ctx.fastHttpRequestCtx.SetBodyString(ctx.response.responseError.Error())
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
package http_context
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/eolinker/eosc/log"
|
||||
)
|
||||
|
||||
type Reader struct {
|
||||
reader io.Reader
|
||||
|
||||
agent *requestAgent
|
||||
record strings.Builder
|
||||
requestId string
|
||||
resp *Response
|
||||
}
|
||||
|
||||
func (r *Reader) Read(p []byte) (int, error) {
|
||||
n, err := r.reader.Read(p)
|
||||
if err != nil {
|
||||
log.Debug("read error:", err)
|
||||
log.DebugF("request id %s ,read body: %s", r.requestId, r.record.String())
|
||||
return 0, err
|
||||
}
|
||||
r.record.Write(p[:n])
|
||||
if r.agent != nil {
|
||||
r.agent.responseBody.Write(p[:n])
|
||||
}
|
||||
if r.resp != nil {
|
||||
r.resp.AppendBody(p[:n])
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
@@ -2,14 +2,11 @@ package http_context
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
"go.uber.org/zap/buffer"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
@@ -19,75 +16,39 @@ var _ http_service.IResponse = (*Response)(nil)
|
||||
type Response struct {
|
||||
ResponseHeader
|
||||
*fasthttp.Response
|
||||
statusCode int
|
||||
length int
|
||||
responseTime time.Duration
|
||||
proxyStatusCode int
|
||||
responseError error
|
||||
remoteIP string
|
||||
remotePort int
|
||||
buf *buffer.Buffer
|
||||
streamBody *bytes.Buffer
|
||||
streamFuncArray []http_service.StreamFunc
|
||||
}
|
||||
|
||||
type BodyStream struct {
|
||||
reader io.Reader
|
||||
streamReadHandler []func(p []byte) error
|
||||
streamWriteHandler []func(p []byte) ([]byte, error)
|
||||
func (r *Response) StreamFunc() []http_service.StreamFunc {
|
||||
return r.streamFuncArray
|
||||
}
|
||||
|
||||
func NewBodyStream(reader io.Reader) *BodyStream {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.Bytes()
|
||||
return &BodyStream{reader: reader}
|
||||
}
|
||||
|
||||
func (b *BodyStream) AppendReaderFunc(f func(p []byte) error) {
|
||||
if b.streamReadHandler == nil {
|
||||
b.streamReadHandler = make([]func(p []byte) error, 0)
|
||||
func (r *Response) AppendStreamFunc(streamFunc http_service.StreamFunc) {
|
||||
if r.streamFuncArray == nil {
|
||||
r.streamFuncArray = make([]http_service.StreamFunc, 0, 10)
|
||||
}
|
||||
b.streamReadHandler = append(b.streamReadHandler, f)
|
||||
r.streamFuncArray = append(r.streamFuncArray, streamFunc)
|
||||
}
|
||||
|
||||
func (b *BodyStream) AppendWriterFunc(f func(p []byte) ([]byte, error)) {
|
||||
if b.streamWriteHandler == nil {
|
||||
b.streamWriteHandler = make([]func(p []byte) ([]byte, error), 0)
|
||||
}
|
||||
b.streamWriteHandler = append(b.streamWriteHandler, f)
|
||||
}
|
||||
|
||||
var bufferPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 32*1024) // 默认 32KB 缓冲区
|
||||
},
|
||||
}
|
||||
|
||||
func (b *BodyStream) Read(p []byte) (n int, err error) {
|
||||
tmp := bufferPool.Get().([]byte)
|
||||
defer bufferPool.Put(tmp)
|
||||
n, err = b.reader.Read(tmp)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
org := tmp[:n]
|
||||
for _, fn := range b.streamWriteHandler {
|
||||
result, err := fn(org)
|
||||
func (r *Response) StreamFuncHandle(ctx http_service.IHttpContext, org []byte) ([]byte, error) {
|
||||
result := make([]byte, len(org))
|
||||
copy(result, org)
|
||||
var err error
|
||||
for _, streamFunc := range r.streamFuncArray {
|
||||
result, err = streamFunc(ctx, result)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return nil, err
|
||||
}
|
||||
org = result
|
||||
}
|
||||
org = append(org, []byte("\n")...)
|
||||
copy(p, org)
|
||||
return len(org), nil
|
||||
}
|
||||
|
||||
func (b *BodyStream) Write(p []byte) (n int, err error) {
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *Response) GetBodyStream() http_service.IResponseStream {
|
||||
return r.bodyStream
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *Response) ContentLength() int {
|
||||
@@ -117,6 +78,7 @@ func (r *Response) Finish() error {
|
||||
r.Response = nil
|
||||
r.responseError = nil
|
||||
r.proxyStatusCode = 0
|
||||
r.streamBody = nil
|
||||
return nil
|
||||
}
|
||||
func (r *Response) reset(resp *fasthttp.Response) {
|
||||
@@ -124,6 +86,7 @@ func (r *Response) reset(resp *fasthttp.Response) {
|
||||
r.ResponseHeader.reset(&resp.Header)
|
||||
r.responseError = nil
|
||||
r.proxyStatusCode = 0
|
||||
r.streamBody = &bytes.Buffer{}
|
||||
}
|
||||
|
||||
func (r *Response) BodyLen() int {
|
||||
@@ -138,7 +101,7 @@ func (r *Response) GetBody() []byte {
|
||||
r.Response.SetBody(body)
|
||||
}
|
||||
if r.IsBodyStream() {
|
||||
return nil
|
||||
return r.streamBody.Bytes()
|
||||
}
|
||||
return r.Response.Body()
|
||||
}
|
||||
@@ -149,6 +112,8 @@ func (r *Response) IsBodyStream() bool {
|
||||
|
||||
func (r *Response) SetBody(bytes []byte) {
|
||||
if r.IsBodyStream() {
|
||||
r.streamBody.Write(bytes)
|
||||
// 不处理
|
||||
return
|
||||
}
|
||||
if strings.Contains(r.GetHeader("Content-Encoding"), "gzip") {
|
||||
@@ -164,15 +129,21 @@ func (r *Response) StatusCode() int {
|
||||
if r.responseError != nil {
|
||||
return 504
|
||||
}
|
||||
return r.Response.StatusCode()
|
||||
|
||||
return r.statusCode
|
||||
}
|
||||
|
||||
func (r *Response) Status() string {
|
||||
return strconv.Itoa(r.StatusCode())
|
||||
if r.statusCode == 0 {
|
||||
r.statusCode = r.Response.StatusCode()
|
||||
}
|
||||
|
||||
return strconv.Itoa(r.statusCode)
|
||||
}
|
||||
|
||||
func (r *Response) SetStatus(code int, status string) {
|
||||
r.Response.SetStatusCode(code)
|
||||
r.statusCode = code
|
||||
r.responseError = nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user