更新流式传输操作

This commit is contained in:
Liujian
2025-02-26 11:13:10 +08:00
parent 8a59d592d0
commit cad39db9a4
12 changed files with 204 additions and 185 deletions

View File

@@ -73,9 +73,7 @@ func (c *Chat) RequestConvert(ctx eocontext.EoContext, extender map[string]inter
for k, v := range extender {
baseCfg.SetAppend(k, v)
}
//if paramModel := ctx.GetLabel(convert.ParamAIModel); paramModel != "" {
// baseCfg.SetAppend("model", paramModel)
//}
// Marshal the updated configuration back into JSON.
body, err = json.Marshal(baseCfg)
if err != nil {
@@ -87,10 +85,59 @@ func (c *Chat) RequestConvert(ctx eocontext.EoContext, extender map[string]inter
// SetProvider the modified body in the HTTP context.
httpContext.Proxy().Body().SetRaw("application/json", body)
httpContext.Response().AppendStreamFunc(c.streamFunc())
return nil
}
func (c *Chat) streamFunc() http_context.StreamFunc {
return func(ctx http_context.IHttpContext, p []byte) ([]byte, error) {
data := eosc.NewBase[Response]()
err := json.Unmarshal(p, data)
if err != nil {
return nil, err
}
status := ctx.Response().StatusCode()
switch status {
case 200:
// Calculate the token consumption for a successful request.
usage := data.Config
if usage.Done {
convert.SetAIStatusNormal(ctx)
convert.SetAIModelInputToken(ctx, usage.PromptEvalCount)
convert.SetAIModelOutputToken(ctx, usage.EvalCount)
convert.SetAIModelTotalToken(ctx, usage.PromptEvalCount+usage.EvalCount)
}
case 404:
convert.SetAIStatusInvalid(ctx)
case 429:
convert.SetAIStatusExceeded(ctx)
}
// Prepare the response body for the client.
responseBody := &convert.ClientResponse{}
resp := data.Config
if resp.Message != nil {
responseBody.Message = &convert.Message{
Role: resp.Message.Role,
Content: resp.Message.Content,
}
if resp.Done {
responseBody.FinishReason = convert.FinishStop
}
} else {
responseBody.Code = -1
responseBody.Error = "response message is nil"
}
// Marshal the modified response body back into JSON.
body, err := json.Marshal(responseBody)
if err != nil {
return nil, err
}
return body, nil
}
}
// ResponseConvert converts the response body for the Chat mode.
// It processes the response to ensure it conforms to the expected format and encoding.
func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error {
@@ -99,57 +146,14 @@ func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error {
if err != nil {
return err
}
status := httpContext.Response().StatusCode()
switch status {
case 200:
convert.SetAIStatusNormal(ctx)
}
if httpContext.Response().IsBodyStream() {
bodyStream := httpContext.Response().GetBodyStream()
if bodyStream != nil {
bodyStream.AppendWriterFunc(func(p []byte) ([]byte, error) {
// Parse the response body into a base configuration.
data := eosc.NewBase[Response]()
err = json.Unmarshal(p, data)
if err != nil {
return nil, err
}
switch status {
case 200:
// Calculate the token consumption for a successful request.
usage := data.Config
if usage.Done {
convert.SetAIStatusNormal(ctx)
convert.SetAIModelInputToken(ctx, usage.PromptEvalCount)
convert.SetAIModelOutputToken(ctx, usage.EvalCount)
convert.SetAIModelTotalToken(ctx, usage.PromptEvalCount+usage.EvalCount)
}
}
// Prepare the response body for the client.
responseBody := &convert.ClientResponse{}
resp := data.Config
if resp.Message != nil {
responseBody.Message = &convert.Message{
Role: resp.Message.Role,
Content: resp.Message.Content,
}
if resp.Done {
responseBody.FinishReason = convert.FinishStop
}
} else {
responseBody.Code = -1
responseBody.Error = "response message is nil"
}
// Marshal the modified response body back into JSON.
body, err := json.Marshal(responseBody)
if err != nil {
return nil, err
}
return body, nil
})
}
return nil
}
// Retrieve the response body.

View File

@@ -24,22 +24,31 @@ func (l *accessLog) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (er
return http_service.DoHttpFilter(l, ctx, next)
}
func (l *accessLog) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) (err error) {
err = next.DoChain(ctx)
if err != nil {
log.Error(err)
}
func (l *accessLog) bodyFinish(ctx http_service.IHttpContext) {
outputs := l.proxy.List()
entry := http_entry.NewEntry(ctx)
for _, v := range outputs {
err = v.Output(entry)
err := v.Output(entry)
if err != nil {
log.Error("access log http-entry error:", err)
continue
}
}
return
}
func (l *accessLog) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) (err error) {
ctx.AppendBodyFinishFunc(l.bodyFinish)
err = next.DoChain(ctx)
if err != nil {
log.Error(err)
}
if ctx.Response().IsBodyStream() {
return nil
}
l.bodyFinish(ctx)
return nil
}

View File

@@ -213,7 +213,9 @@ func (e *executor) processKeyPool(ctx http_context.IHttpContext, cloneProxy http
return err
}
}
if ctx.Response().IsBodyStream() {
return nil
}
if err = converter.ResponseConvert(ctx); err != nil {
convert.SetAIProviderStatuses(ctx, convert.AIProviderStatus{
Provider: e.provider,
@@ -235,9 +237,10 @@ func (e *executor) processKeyPool(ctx http_context.IHttpContext, cloneProxy http
return nil
default:
continue
}
}
return fmt.Errorf("")
return fmt.Errorf("all key resources for provider %s is invalid", e.provider)
}
// handleNoKeyResource handles the case when no key resources are available.

View File

@@ -131,7 +131,7 @@ func (h *complete) Complete(org eocontext.EoContext) error {
}
response := fasthttp.AcquireResponse()
lastErr = fasthttp_client.ProxyTimeout(scheme, host, node, request, response, timeOut, false)
lastErr = fasthttp_client.ProxyTimeout(scheme, host, node, request, response, timeOut)
if lastErr == nil {
return newGRPCResponse(ctx, response, methodDesc)
}

View File

@@ -38,7 +38,6 @@ func (h *HttpComplete) Complete(org eocontext.EoContext) error {
defer func() {
//设置原始响应状态码
ctx.Response().SetProxyStatus(ctx.Response().StatusCode(), "")
//ctx.WithValue("response_time", time.Now().Sub(proxyTime).Milliseconds())
ctx.Response().SetResponseTime(time.Since(proxyTime))
ctx.SetLabel("handler", "proxy")
}()

4
go.mod
View File

@@ -11,7 +11,7 @@ require (
github.com/clbanning/mxj v1.8.4
github.com/coocood/freecache v1.2.2
github.com/dubbogo/gost v1.13.1
github.com/eolinker/eosc v0.20.1
github.com/eolinker/eosc v0.20.2
github.com/fasthttp/websocket v1.5.0
github.com/fullstorydev/grpcurl v1.8.7
github.com/go-redis/redis/v8 v8.11.5
@@ -197,4 +197,4 @@ require (
replace github.com/soheilhy/cmux v0.1.5 => github.com/hmzzrcs/cmux v0.1.6
replace github.com/eolinker/eosc => ../eosc
//replace github.com/eolinker/eosc => ../eosc

View File

@@ -7,7 +7,6 @@ import (
"net"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
@@ -17,9 +16,9 @@ import (
"github.com/valyala/fasthttp"
)
func ProxyTimeout(scheme string, host string, node eocontext.INode, req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration, stream bool) error {
func ProxyTimeout(scheme string, host string, node eocontext.INode, req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration) error {
addr := fmt.Sprintf("%s://%s", scheme, node.Addr())
err := defaultClient.ProxyTimeout(addr, host, req, resp, timeout, stream)
err := defaultClient.ProxyTimeout(addr, host, req, resp, timeout)
if err != nil {
node.Down()
}
@@ -81,7 +80,7 @@ func GenDialFunc(isTls bool) (fasthttp.DialFunc, error) {
return Dial, nil
}
func (c *Client) getHostClient(addr string, rewriteHost string, stream bool) (*fasthttp.HostClient, string, error) {
func (c *Client) getHostClient(addr string, rewriteHost string) (*fasthttp.HostClient, string, error) {
scheme, nodeAddr := readAddress(addr)
host := nodeAddr
@@ -96,7 +95,7 @@ func (c *Client) getHostClient(addr string, rewriteHost string, stream bool) (*f
if isTLS {
m = c.ms
}
key := fmt.Sprintf("%s-%s", host, strconv.FormatBool(stream))
key := host
hc := m[key]
c.mLock.RUnlock()
if hc != nil {
@@ -142,7 +141,7 @@ func (c *Client) getHostClient(addr string, rewriteHost string, stream bool) (*f
InsecureSkipVerify: true,
},
Dial: dial,
StreamResponseBody: stream,
StreamResponseBody: true,
MaxConns: DefaultMaxConns,
MaxConnWaitTimeout: DefaultMaxConnWaitTimeout,
RetryIf: func(request *fasthttp.Request) bool {
@@ -186,7 +185,7 @@ func (c *Client) getHostClient(addr string, rewriteHost string, stream bool) (*f
// continue in the background and the response will be discarded.
// If requests take too long and the connection pool gets filled up please
// try setting a ReadTimeout.
func (c *Client) ProxyTimeout(addr string, host string, req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration, stream bool) error {
func (c *Client) ProxyTimeout(addr string, host string, req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration) error {
request := req
request.Header.ResetConnectionClose()
request.Header.Set("Connection", "keep-alive")
@@ -197,7 +196,7 @@ func (c *Client) ProxyTimeout(addr string, host string, req *fasthttp.Request, r
}
}()
client, scheme, err := c.getHostClient(addr, host, stream)
client, scheme, err := c.getHostClient(addr, host)
if err != nil {
return err
}

View File

@@ -18,7 +18,7 @@ func TestMyselfProxyTimeout(t *testing.T) {
req.Header.SetContentType("application/json")
t.Log(string(req.URI().RequestURI()), req.URI().String(), string(req.URI().Host()), string(req.URI().Scheme()))
req.SetBody([]byte(`{"cpCode":"YTO","province":"广东省","city":"广州市"}`))
err := defaultClient.ProxyTimeout(addr, "", req, resp, 0, false)
err := defaultClient.ProxyTimeout(addr, "", req, resp, 0)
if err != nil {
t.Error(err)
}

View File

@@ -37,9 +37,24 @@ type cloneContext struct {
upstreamHostHandler eoscContext.UpstreamHostHandler
labels map[string]string
entry eosc.IEntry
bodyFinishes []http_service.BodyFinishFunc
responseError error
}
func (ctx *cloneContext) BodyFinish() {
for _, finishFunc := range ctx.bodyFinishes {
finishFunc(ctx)
}
return
}
func (ctx *cloneContext) AppendBodyFinishFunc(finishFunc http_service.BodyFinishFunc) {
if ctx.bodyFinishes == nil {
ctx.bodyFinishes = make([]http_service.BodyFinishFunc, 0, 10)
}
ctx.bodyFinishes = append(ctx.bodyFinishes, finishFunc)
}
func (ctx *cloneContext) ProxyClone() http_service.IRequest {
// 创建一个新的 ProxyRequest 实例
req := fasthttp.AcquireRequest()
@@ -167,7 +182,7 @@ func (ctx *cloneContext) SendTo(scheme string, node eoscContext.INode, timeout t
request.URI().SetHost(targetHost)
}
beginTime := time.Now()
ctx.responseError = fasthttp_client.ProxyTimeout(scheme, rewriteHost, node, request, ctx.response.Response, timeout, false)
ctx.responseError = fasthttp_client.ProxyTimeout(scheme, rewriteHost, node, request, ctx.response.Response, timeout)
var responseHeader fasthttp.ResponseHeader
if ctx.response.Response != nil {
responseHeader = ctx.response.Response.Header

View File

@@ -1,6 +1,7 @@
package http_context
import (
"bufio"
"context"
"fmt"
"io"
@@ -9,12 +10,11 @@ import (
"strings"
"time"
"github.com/eolinker/apinto/entries/ctx_key"
http_entry "github.com/eolinker/apinto/entries/http-entry"
"github.com/eolinker/eosc"
"github.com/eolinker/apinto/entries/ctx_key"
"github.com/eolinker/eosc/log"
"github.com/eolinker/eosc/utils/config"
@@ -45,6 +45,20 @@ type HttpContext struct {
labels map[string]string
port int
entry eosc.IEntry
bodyFinishes []http_service.BodyFinishFunc
}
func (ctx *HttpContext) BodyFinish() {
for _, finishFunc := range ctx.bodyFinishes {
finishFunc(ctx)
}
}
func (ctx *HttpContext) AppendBodyFinishFunc(finishFunc http_service.BodyFinishFunc) {
if ctx.bodyFinishes == nil {
ctx.bodyFinishes = make([]http_service.BodyFinishFunc, 0, 10)
}
ctx.bodyFinishes = append(ctx.bodyFinishes, finishFunc)
}
func (ctx *HttpContext) ProxyClone() http_service.IRequest {
@@ -182,7 +196,7 @@ func (ctx *HttpContext) SendTo(scheme string, node eoscContext.INode, timeout ti
beginTime := time.Now()
response := fasthttp.AcquireResponse()
ctx.response.responseError = fasthttp_client.ProxyTimeout(scheme, rewriteHost, node, request, response, timeout, ctx.GetLabel("stream") == "true")
ctx.response.responseError = fasthttp_client.ProxyTimeout(scheme, rewriteHost, node, request, response, timeout)
agent := newRequestAgent(&ctx.proxyRequest, host, scheme, response.Header, beginTime, time.Now())
@@ -202,23 +216,49 @@ func (ctx *HttpContext) SendTo(scheme string, node eoscContext.INode, timeout ti
response.Header.CopyTo(&ctx.response.Response.Header)
ctx.response.ResponseHeader.refresh()
if response.IsBodyStream() && response.Header.ContentLength() < 0 {
reader := response.BodyStream()
// 流式传输
ctx.response.Response.SetStatusCode(response.StatusCode())
buf := make([]byte, 4096)
for {
n, err := reader.Read(buf)
if err == io.EOF {
return nil
ctx.SetLabel("stream_running", "true")
ctx.response.Response.SetBodyStreamWriter(func(w *bufio.Writer) {
defer func() {
ctx.SetLabel("stream_running", "false")
ctx.FastFinish()
}()
reader := response.BodyStream()
buffer := make([]byte, 4096) // 4KB 缓冲区
for {
n, err := reader.Read(buffer)
if n > 0 {
chunk := buffer[:n]
for _, streamFunc := range ctx.Response().StreamFunc() {
chunk, err = streamFunc(ctx, chunk)
if err != nil {
log.Errorf("exec stream func error: %v", err)
break
}
chunk = append(chunk, []byte("\r\n")...)
}
n, err = w.Write(chunk)
if err != nil {
log.Errorf("stream write error: %v", err)
break
}
ctx.Response().SetBody(chunk)
w.Flush() // 实时发送数据
}
if err != nil {
if err == io.EOF {
break
}
log.Errorf("stream read error: %v", err)
break
}
}
if err != nil {
return err
}
agent.responseBody.Write(response.Body()[:n])
_, err = ctx.response.Response.BodyWriter().Write(buf[:n])
if err != nil {
return err
}
}
ctx.BodyFinish()
})
agent.setResponseLength(-1)
ctx.proxyRequests = append(ctx.proxyRequests, agent)
return nil
@@ -280,19 +320,27 @@ func (ctx *HttpContext) Clone() (eoscContext.EoContext, error) {
copyContext.proxyRequest.reset(req, ctx.requestReader.remoteAddr)
copyContext.response.reset(resp)
resp.Header.CopyTo(copyContext.response.header)
copyContext.response.refresh()
copyContext.completeHandler = ctx.completeHandler
copyContext.finishHandler = ctx.finishHandler
copyContext.response.Response.SetStatusCode(ctx.response.Response.StatusCode())
cloneLabels := make(map[string]string, len(ctx.labels))
for k, v := range ctx.labels {
cloneLabels[k] = v
}
copyContext.labels = cloneLabels
for _, finishFunc := range ctx.bodyFinishes {
copyContext.AppendBodyFinishFunc(finishFunc)
}
for _, streamFunc := range ctx.response.streamFuncArray {
copyContext.Response().AppendStreamFunc(streamFunc)
}
//记录请求时间
copyContext.ctx = context.WithValue(ctx.Context(), http_service.KeyCloneCtx, true)
copyContext.WithValue(ctx_key.CtxKeyRetry, 0)
return copyContext, nil
}
@@ -334,8 +382,13 @@ func (ctx *HttpContext) RequestId() string {
return ctx.requestID
}
// Finish finish
// FastFinish finish
func (ctx *HttpContext) FastFinish() {
streamRunning := ctx.GetLabel("stream_running")
if streamRunning == "true" {
// 暂时不释放
return
}
if ctx.response.responseError != nil {
ctx.fastHttpRequestCtx.SetStatusCode(504)
ctx.fastHttpRequestCtx.SetBodyString(ctx.response.responseError.Error())

View File

@@ -1,34 +0,0 @@
package http_context
import (
"io"
"strings"
"github.com/eolinker/eosc/log"
)
type Reader struct {
reader io.Reader
agent *requestAgent
record strings.Builder
requestId string
resp *Response
}
func (r *Reader) Read(p []byte) (int, error) {
n, err := r.reader.Read(p)
if err != nil {
log.Debug("read error:", err)
log.DebugF("request id %s ,read body: %s", r.requestId, r.record.String())
return 0, err
}
r.record.Write(p[:n])
if r.agent != nil {
r.agent.responseBody.Write(p[:n])
}
if r.resp != nil {
r.resp.AppendBody(p[:n])
}
return n, nil
}

View File

@@ -2,14 +2,11 @@ package http_context
import (
"bytes"
"io"
"strconv"
"strings"
"sync"
"time"
http_service "github.com/eolinker/eosc/eocontext/http-context"
"go.uber.org/zap/buffer"
"github.com/valyala/fasthttp"
)
@@ -19,75 +16,39 @@ var _ http_service.IResponse = (*Response)(nil)
type Response struct {
ResponseHeader
*fasthttp.Response
statusCode int
length int
responseTime time.Duration
proxyStatusCode int
responseError error
remoteIP string
remotePort int
buf *buffer.Buffer
streamBody *bytes.Buffer
streamFuncArray []http_service.StreamFunc
}
type BodyStream struct {
reader io.Reader
streamReadHandler []func(p []byte) error
streamWriteHandler []func(p []byte) ([]byte, error)
func (r *Response) StreamFunc() []http_service.StreamFunc {
return r.streamFuncArray
}
func NewBodyStream(reader io.Reader) *BodyStream {
buf := &bytes.Buffer{}
buf.Bytes()
return &BodyStream{reader: reader}
}
func (b *BodyStream) AppendReaderFunc(f func(p []byte) error) {
if b.streamReadHandler == nil {
b.streamReadHandler = make([]func(p []byte) error, 0)
func (r *Response) AppendStreamFunc(streamFunc http_service.StreamFunc) {
if r.streamFuncArray == nil {
r.streamFuncArray = make([]http_service.StreamFunc, 0, 10)
}
b.streamReadHandler = append(b.streamReadHandler, f)
r.streamFuncArray = append(r.streamFuncArray, streamFunc)
}
func (b *BodyStream) AppendWriterFunc(f func(p []byte) ([]byte, error)) {
if b.streamWriteHandler == nil {
b.streamWriteHandler = make([]func(p []byte) ([]byte, error), 0)
}
b.streamWriteHandler = append(b.streamWriteHandler, f)
}
var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 32*1024) // 默认 32KB 缓冲区
},
}
func (b *BodyStream) Read(p []byte) (n int, err error) {
tmp := bufferPool.Get().([]byte)
defer bufferPool.Put(tmp)
n, err = b.reader.Read(tmp)
if err != nil {
return 0, err
}
org := tmp[:n]
for _, fn := range b.streamWriteHandler {
result, err := fn(org)
func (r *Response) StreamFuncHandle(ctx http_service.IHttpContext, org []byte) ([]byte, error) {
result := make([]byte, len(org))
copy(result, org)
var err error
for _, streamFunc := range r.streamFuncArray {
result, err = streamFunc(ctx, result)
if err != nil {
return 0, err
return nil, err
}
org = result
}
org = append(org, []byte("\n")...)
copy(p, org)
return len(org), nil
}
func (b *BodyStream) Write(p []byte) (n int, err error) {
return 0, nil
}
func (r *Response) GetBodyStream() http_service.IResponseStream {
return r.bodyStream
return result, nil
}
func (r *Response) ContentLength() int {
@@ -117,6 +78,7 @@ func (r *Response) Finish() error {
r.Response = nil
r.responseError = nil
r.proxyStatusCode = 0
r.streamBody = nil
return nil
}
func (r *Response) reset(resp *fasthttp.Response) {
@@ -124,6 +86,7 @@ func (r *Response) reset(resp *fasthttp.Response) {
r.ResponseHeader.reset(&resp.Header)
r.responseError = nil
r.proxyStatusCode = 0
r.streamBody = &bytes.Buffer{}
}
func (r *Response) BodyLen() int {
@@ -138,7 +101,7 @@ func (r *Response) GetBody() []byte {
r.Response.SetBody(body)
}
if r.IsBodyStream() {
return nil
return r.streamBody.Bytes()
}
return r.Response.Body()
}
@@ -149,6 +112,8 @@ func (r *Response) IsBodyStream() bool {
func (r *Response) SetBody(bytes []byte) {
if r.IsBodyStream() {
r.streamBody.Write(bytes)
// 不处理
return
}
if strings.Contains(r.GetHeader("Content-Encoding"), "gzip") {
@@ -164,15 +129,21 @@ func (r *Response) StatusCode() int {
if r.responseError != nil {
return 504
}
return r.Response.StatusCode()
return r.statusCode
}
func (r *Response) Status() string {
return strconv.Itoa(r.StatusCode())
if r.statusCode == 0 {
r.statusCode = r.Response.StatusCode()
}
return strconv.Itoa(r.statusCode)
}
func (r *Response) SetStatus(code int, status string) {
r.Response.SetStatusCode(code)
r.statusCode = code
r.responseError = nil
}