From 79def2f2dc0a5d66c6cf4e37f7d2b6ee7c03a3bd Mon Sep 17 00:00:00 2001 From: Liujian <824010343@qq.com> Date: Wed, 12 Mar 2025 17:06:52 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0bedrock?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai-convert/convert.go | 5 - ai-convert/manager.go | 42 ++++ app/apinto/driver.go | 2 + app/apinto/profession.go | 6 + drivers/ai-model/config.go | 45 +++++ drivers/ai-model/executor.go | 68 +++++++ drivers/ai-model/factory.go | 37 ++++ drivers/ai-provider/bedrock/config.go | 234 ++++++++++++++++++++--- drivers/ai-provider/bedrock/converter.go | 134 ------------- drivers/ai-provider/bedrock/executor.go | 73 ------- drivers/ai-provider/bedrock/factory.go | 46 ----- drivers/ai-provider/bedrock/mode.go | 119 ------------ encoder/gzip.go | 2 +- 13 files changed, 412 insertions(+), 401 deletions(-) create mode 100644 drivers/ai-model/config.go create mode 100644 drivers/ai-model/executor.go create mode 100644 drivers/ai-model/factory.go delete mode 100644 drivers/ai-provider/bedrock/converter.go delete mode 100644 drivers/ai-provider/bedrock/executor.go delete mode 100644 drivers/ai-provider/bedrock/factory.go delete mode 100644 drivers/ai-provider/bedrock/mode.go diff --git a/ai-convert/convert.go b/ai-convert/convert.go index e52c2934..36c18817 100644 --- a/ai-convert/convert.go +++ b/ai-convert/convert.go @@ -8,11 +8,6 @@ type IConverterFactory interface { type IConverterCreateFunc func(cfg string) (IConverter, error) -//type IConverterDriver interface { -// GetModel(model string) (FGenerateConfig, bool) -// GetConverter(model string) (IConverter, bool) -//} - type IConverter interface { RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error ResponseConvert(ctx eocontext.EoContext) error diff --git a/ai-convert/manager.go b/ai-convert/manager.go index 83a711ed..94db7072 100644 --- a/ai-convert/manager.go +++ b/ai-convert/manager.go @@ -237,3 +237,45 @@ func SetProvider(id string, p IProvider) { func GetProvider(provider string) (IProvider, bool) { return balanceManager.Get(provider) } + +type IModelAccessConfigManager interface { + Get(id string) (IModelAccessConfig, bool) + Set(id string, config IModelAccessConfig) + Del(id string) +} + +type IModelAccessConfig interface { + Provider() string + Model() string + Config() map[string]string +} + +type modelAccessConfigManager struct { + configs eosc.Untyped[string, IModelAccessConfig] +} + +func (m *modelAccessConfigManager) Get(id string) (IModelAccessConfig, bool) { + return m.configs.Get(id) +} + +func (m *modelAccessConfigManager) Set(id string, config IModelAccessConfig) { + m.configs.Set(id, config) +} + +func (m *modelAccessConfigManager) Del(id string) { + m.configs.Del(id) +} + +func NewModelAccessConfigManager() *modelAccessConfigManager { + return &modelAccessConfigManager{ + configs: eosc.BuildUntyped[string, IModelAccessConfig](), + } +} + +var _ IModelAccessConfigManager = (*modelAccessConfigManager)(nil) + +var modelAcManager = NewModelAccessConfigManager() + +func init() { + bean.Injection(&modelAcManager) +} diff --git a/app/apinto/driver.go b/app/apinto/driver.go index 66e60175..2a987501 100644 --- a/app/apinto/driver.go +++ b/app/apinto/driver.go @@ -2,6 +2,7 @@ package main import ( ai_key "github.com/eolinker/apinto/drivers/ai-key" + ai_model "github.com/eolinker/apinto/drivers/ai-model" ai_provider "github.com/eolinker/apinto/drivers/ai-provider" "github.com/eolinker/apinto/drivers/certs" "github.com/eolinker/apinto/drivers/discovery/consul" @@ -125,4 +126,5 @@ func driverRegister(extenderRegister eosc.IExtenderDriverRegister) { ai_provider.Register(extenderRegister) ai_key.Register(extenderRegister) + ai_model.Register(extenderRegister) } diff --git a/app/apinto/profession.go b/app/apinto/profession.go index eb02e1fb..1cd324fd 100644 --- a/app/apinto/profession.go +++ b/app/apinto/profession.go @@ -291,6 +291,12 @@ func ApintoProfession() []*eosc.ProfessionConfig { Label: "ai-provider", Desc: "ai-provider", }, + { + Id: "eolinker.com:apinto:ai-model", + Name: "ai-model", + Label: "ai-model", + Desc: "ai-model", + }, }, }, //{ diff --git a/drivers/ai-model/config.go b/drivers/ai-model/config.go new file mode 100644 index 00000000..e10c4b26 --- /dev/null +++ b/drivers/ai-model/config.go @@ -0,0 +1,45 @@ +package ai_model + +import ( + "fmt" + + "github.com/eolinker/apinto/drivers" + + "github.com/eolinker/eosc" +) + +type Config struct { + Provider string `json:"provider"` + Model string `json:"model"` + AccessConfig string `json:"access_config"` +} + +// Create 创建驱动实例 +func Create(id, name string, v *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) { + cfg, err := checkConfig(v) + if err != nil { + return nil, err + } + w := &executor{ + WorkerBase: drivers.Worker(id, name), + } + err = w.reset(cfg) + return w, err +} + +func checkConfig(v interface{}) (*Config, error) { + conf, ok := v.(*Config) + if !ok { + return nil, eosc.ErrorConfigType + } + + if conf.Provider == "" { + return nil, fmt.Errorf("provider is required") + } + + if conf.Model == "" { + return nil, fmt.Errorf("model is required") + } + + return conf, nil +} diff --git a/drivers/ai-model/executor.go b/drivers/ai-model/executor.go new file mode 100644 index 00000000..6088e89f --- /dev/null +++ b/drivers/ai-model/executor.go @@ -0,0 +1,68 @@ +package ai_model + +import ( + "encoding/json" + + "github.com/eolinker/apinto/drivers" + "github.com/eolinker/eosc" +) + +var _ eosc.IWorker = (*executor)(nil) + +type executor struct { + drivers.WorkerBase +} + +func (e *executor) Start() error { + return nil +} + +func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error { + cfg, err := checkConfig(conf) + if err != nil { + return err + } + return e.reset(cfg) +} + +func (e *executor) reset(conf *Config) error { + tmp := make(map[string]string) + err := json.Unmarshal([]byte(conf.AccessConfig), &tmp) + if err != nil { + return err + } + accessConfigManager.Set(e.Name(), &modelAccessConfig{ + provider: conf.Provider, + model: conf.Model, + accessConfig: tmp, + }) + + return nil +} + +func (e *executor) Stop() error { + accessConfigManager.Del(e.Name()) + return nil +} + +func (e *executor) CheckSkill(skill string) bool { + return false +} + +type modelAccessConfig struct { + provider string + model string + accessConfig map[string]string +} + +func (m *modelAccessConfig) Provider() string { + return m.provider +} + +func (m *modelAccessConfig) Model() string { + return m.model +} + +func (m *modelAccessConfig) Config() map[string]string { + return m.accessConfig +} diff --git a/drivers/ai-model/factory.go b/drivers/ai-model/factory.go new file mode 100644 index 00000000..17d4ce85 --- /dev/null +++ b/drivers/ai-model/factory.go @@ -0,0 +1,37 @@ +package ai_model + +import ( + "sync" + + "github.com/eolinker/eosc/common/bean" + + ai_convert "github.com/eolinker/apinto/ai-convert" + "github.com/eolinker/apinto/drivers" + "github.com/eolinker/eosc" +) + +var name = "ai-model" + +var ( + once sync.Once + accessConfigManager ai_convert.IModelAccessConfigManager +) + +func init() { + once.Do(func() { + bean.Autowired(&accessConfigManager) + }) +} + +type Factory struct { +} + +// Register AI供应商Factory +func Register(register eosc.IExtenderDriverRegister) { + register.RegisterExtenderDriver(name, NewFactory()) +} + +// NewFactory 创建service_http驱动工厂 +func NewFactory() eosc.IExtenderDriverFactory { + return drivers.NewFactory[Config](Create) +} diff --git a/drivers/ai-provider/bedrock/config.go b/drivers/ai-provider/bedrock/config.go index ff8f7cfd..5cbee4b1 100644 --- a/drivers/ai-provider/bedrock/config.go +++ b/drivers/ai-provider/bedrock/config.go @@ -1,30 +1,38 @@ package bedrock import ( + "encoding/json" "fmt" + "net/http" + "strings" + "time" + + "github.com/aws/aws-sdk-go/aws/credentials" + + v4 "github.com/aws/aws-sdk-go/aws/signer/v4" + "github.com/eolinker/eosc/common/bean" + http_service "github.com/eolinker/eosc/eocontext/http-context" + + "github.com/eolinker/eosc/eocontext" + + ai_convert "github.com/eolinker/apinto/ai-convert" "github.com/eolinker/eosc" ) -type Config struct { - AccessKey string `json:"aws_access_key_id"` - SecretKey string `json:"aws_secret_access_key"` - Region string `json:"aws_region"` - ModelForValidation string `json:"model_for_validation"` +var ( + accessConfigManager ai_convert.IModelAccessConfigManager +) + +func init() { + bean.Autowired(&accessConfigManager) } -var ( - availableRegions = map[string]struct{}{ - "us-east-1": {}, - "us-west-2": {}, - "ap-southeast-1": {}, - "ap-northeast-1": {}, - "eu-central-1": {}, - "eu-west-2": {}, - "us-gov-west-1": {}, - "ap-southeast-2": {}, - } -) +type Config struct { + AccessKey string `json:"aws_access_key_id"` + SecretKey string `json:"aws_secret_access_key"` + Region string `json:"aws_region"` +} func checkConfig(v interface{}) (*Config, error) { conf, ok := v.(*Config) @@ -37,11 +45,191 @@ func checkConfig(v interface{}) (*Config, error) { if conf.SecretKey == "" { return nil, fmt.Errorf("aws_secret_access_key is required") } - //if conf.Region == "" { - // return nil, fmt.Errorf("aws_region is required") - //} - //if _, ok := availableRegions[conf.Region]; !ok { - // return nil, fmt.Errorf("aws_region is invalid") - //} return conf, nil } + +func Create(cfg string) (ai_convert.IConverter, error) { + var conf Config + err := json.Unmarshal([]byte(cfg), &conf) + if err != nil { + return nil, err + } + if conf.AccessKey == "" { + return nil, fmt.Errorf("aws_access_key_id is required") + } + if conf.SecretKey == "" { + return nil, fmt.Errorf("aws_secret_access_key is required") + } + return NewConvert(conf.AccessKey, conf.SecretKey, conf.Region), nil +} + +type Convert struct { + signer *v4.Signer + region string +} + +func NewConvert(ak string, sk string, region string) *Convert { + return &Convert{ + signer: v4.NewSigner(credentials.NewStaticCredentials(ak, sk, "")), + region: region, + } +} + +var ( + currentPath = "/model/%s/converse" + streamPath = "/model/%s/converse-stream" +) + +func (c *Convert) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error { + provider := ai_convert.GetAIProvider(ctx) + model := ai_convert.GetAIModel(ctx) + modelCfg, has := accessConfigManager.Get(fmt.Sprintf("%s#%s", provider, model)) + region := "" + if has { + model = modelCfg.Config()["model"] + region = modelCfg.Config()["region"] + } + if region == "" { + region = c.region + } + base := fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", region) + + balanceHandler, err := ai_convert.NewBalanceHandler("", base, 0) + if err != nil { + return err + } + ctx.SetBalance(balanceHandler) + httpContext, err := http_service.Assert(ctx) + if err != nil { + return err + } + body, err := httpContext.Proxy().Body().RawBody() + if err != nil { + return err + } + chatRequest := eosc.NewBase[ai_convert.Request](extender) + err = json.Unmarshal(body, chatRequest) + if err != nil { + return fmt.Errorf("unmarshal body error: %v, body: %s", err, string(body)) + } + messages := make([]Message, 0, len(chatRequest.Config.Messages)) + systemMessage := make([]*Content, 0) + for _, m := range chatRequest.Config.Messages { + if m.Role == "system" { + systemMessage = append(systemMessage, &Content{Text: m.Content}) + } else { + messages = append(messages, Message{ + Role: m.Role, + Content: []*Content{{Text: m.Content}}, + }) + } + } + chatRequest.SetAppend("messages", messages) + chatRequest.SetAppend("system", systemMessage) + path := fmt.Sprintf(currentPath, model) + if chatRequest.Config.Stream { + path = fmt.Sprintf(streamPath, model) + } + uri := fmt.Sprintf("%s%s", base, path) + httpContext.Proxy().URI().SetPath(path) + + body, _ = json.Marshal(chatRequest) + httpContext.Proxy().Body().SetRaw("application/json", body) + headers, err := signRequest(c.signer, region, uri, http.Header{}, string(body)) + if err != nil { + return err + } + for k, v := range headers { + httpContext.Proxy().Header().SetHeader(k, strings.Join(v, ";")) + } + httpContext.Proxy().Body().SetRaw("application/json", body) + return nil +} + +func (c *Convert) ResponseConvert(ctx eocontext.EoContext) error { + httpContext, err := http_service.Assert(ctx) + if err != nil { + return err + } + if httpContext.Response().StatusCode() != 200 { + return nil + } + body := httpContext.Response().GetBody() + data := eosc.NewBase[Response](nil) + err = json.Unmarshal(body, data) + if err != nil { + return err + } + responseBody := &ai_convert.Response{} + + body, err = json.Marshal(responseBody) + if err != nil { + return err + } + httpContext.Response().SetBody(body) + return nil +} + +func (c *Convert) streamHandler(ctx http_service.IHttpContext, p []byte) ([]byte, error) { + //// 对响应数据进行划分 + //inputToken := GetAIModelInputToken(ctx) + //outputToken := 0 + //totalToken := inputToken + //scanner := bufio.NewScanner(bytes.NewReader(p)) + //// Check the content encoding and convert to UTF-8 if necessary. + //encoding := ctx.Response().Headers().Get("content-encoding") + //for scanner.Scan() { + // line := scanner.Text() + // if encoding != "utf-8" && encoding != "" { + // tmp, err := encoder.ToUTF8(encoding, []byte(line)) + // if err != nil { + // log.Errorf("convert to utf-8 error: %v, line: %s", err, line) + // return p, nil + // } + // if ctx.Response().StatusCode() != 200 || (o.checkErr != nil && !o.checkErr(ctx, tmp)) { + // if o.errorCallback != nil { + // o.errorCallback(ctx, tmp) + // } + // return p, nil + // } + // line = string(tmp) + // } + // line = strings.TrimPrefix(line, "data:") + // if line == "" || strings.Trim(line, " ") == "[DONE]" { + // return p, nil + // } + // var resp openai.ChatCompletionResponse + // err := json.Unmarshal([]byte(line), &resp) + // if err != nil { + // return p, nil + // } + // if len(resp.Choices) > 0 { + // outputToken += getTokens(resp.Choices[0].Message.Content) + // totalToken += outputToken + // } + //} + //if err := scanner.Err(); err != nil { + // log.Errorf("scan error: %v", err) + // return p, nil + //} + // + //SetAIModelInputToken(ctx, inputToken) + //SetAIModelOutputToken(ctx, outputToken) + //SetAIModelTotalToken(ctx, totalToken) + return p, nil +} + +func signRequest(signer *v4.Signer, region string, uri string, headers http.Header, body string) (http.Header, error) { + request, err := http.NewRequest(http.MethodPost, uri, nil) + if err != nil { + return nil, err + } + request.Header = headers.Clone() + + _, err = signer.Sign(request, strings.NewReader(body), "bedrock", region, time.Now()) + if err != nil { + return nil, err + } + return request.Header, nil + +} diff --git a/drivers/ai-provider/bedrock/converter.go b/drivers/ai-provider/bedrock/converter.go deleted file mode 100644 index af5fd7e7..00000000 --- a/drivers/ai-provider/bedrock/converter.go +++ /dev/null @@ -1,134 +0,0 @@ -package bedrock - -import ( - "encoding/json" - "fmt" - "net/http" - "strings" - - ai_convert "github.com/eolinker/apinto/ai-convert" - - "github.com/aws/aws-sdk-go/aws/credentials" - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" - - "github.com/eolinker/eosc/eocontext" - http_context "github.com/eolinker/eosc/eocontext/http-context" - "github.com/eolinker/eosc/log" -) - -var _ ai_convert.IConverterFactory = &convertFactory{} - -type convertFactory struct { -} - -func (c *convertFactory) Create(cfg string) (ai_convert.IConverterDriver, error) { - var tmp Config - err := json.Unmarshal([]byte(cfg), &tmp) - if err != nil { - return nil, err - } - return newConverterDriver(&tmp) -} - -var _ ai_convert.IConverterDriver = &converterDriver{} - -type basicConfig struct { - signer *v4.Signer - region string - eocontext.BalanceHandler -} - -type converterDriver struct { - cfg *basicConfig - eocontext.BalanceHandler -} - -func newConverterDriver(cfg *Config) (ai_convert.IConverterDriver, error) { - base := fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", cfg.Region) - balanceHandler, err := ai_convert.NewBalanceHandler("", base, 0) - if err != nil { - return nil, err - } - return &converterDriver{ - cfg: &basicConfig{ - signer: v4.NewSigner(credentials.NewStaticCredentials(cfg.AccessKey, cfg.SecretKey, "")), - region: cfg.Region, - BalanceHandler: balanceHandler, - }, - BalanceHandler: balanceHandler, - }, nil - -} - -func (c *converterDriver) GetModel(model string) (ai_convert.FGenerateConfig, bool) { - if _, ok := modelConvert[model]; !ok { - return nil, false - } - return func(cfg string) (map[string]interface{}, error) { - result := map[string]interface{}{} - if cfg != "" { - tmp := make(map[string]interface{}) - if err := json.Unmarshal([]byte(cfg), &tmp); err != nil { - log.Errorf("unmarshal config error: %v, cfg: %s", err, cfg) - return result, nil - } - modelCfg := ai_convert.MapToStruct[ModelConfig](tmp) - if modelCfg.MaxTokens >= 1 { - result["maxTokens"] = modelCfg.MaxTokens - } - result["temperature"] = modelCfg.Temperature - result["topP"] = modelCfg.TopP - } - return result, nil - }, true -} - -func (c *converterDriver) GetConverter(model string) (ai_convert.IConverter, bool) { - converter, ok := modelConvert[model] - if !ok { - return nil, false - } - - return &Converter{ - converter: converter, - model: model, - basicConfig: c.cfg, - }, true -} - -type Converter struct { - converter ai_convert.IConverter - model string - *basicConfig -} - -func (c *Converter) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error { - if c.BalanceHandler != nil { - ctx.SetBalance(c.BalanceHandler) - } - httpContext, err := http_context.Assert(ctx) - if err != nil { - return err - } - - err = c.converter.RequestConvert(httpContext, extender) - if err != nil { - return err - } - body, _ := httpContext.Proxy().Body().RawBody() - headers, err := signRequest(c.signer, c.region, c.model, http.Header{}, string(body)) - if err != nil { - return err - } - for k, v := range headers { - - httpContext.Proxy().Header().SetHeader(k, strings.Join(v, ";")) - } - //httpContext.Proxy().Header().SetHeader("Authorization", authorization) - //httpContext.Proxy().Header().SetHeader("X-Amz-Date", date) - return nil -} - -func (c *Converter) ResponseConvert(ctx eocontext.EoContext) error { - return c.converter.ResponseConvert(ctx) -} diff --git a/drivers/ai-provider/bedrock/executor.go b/drivers/ai-provider/bedrock/executor.go deleted file mode 100644 index b7a2fc99..00000000 --- a/drivers/ai-provider/bedrock/executor.go +++ /dev/null @@ -1,73 +0,0 @@ -package bedrock - -import ( - "fmt" - "net/http" - "strings" - "time" - - ai_convert "github.com/eolinker/apinto/ai-convert" - - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" - - "github.com/eolinker/apinto/drivers" - - "github.com/eolinker/eosc" -) - -type executor struct { - drivers.WorkerBase -} - -func (e *executor) Start() error { - return nil -} - -func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error { - cfg, ok := conf.(*Config) - if !ok { - return fmt.Errorf("invalid config") - } - - return e.reset(cfg, workers) -} - -func (e *executor) reset(conf *Config, workers map[eosc.RequireId]eosc.IWorker) error { - d, err := newConverterDriver(conf) - if err != nil { - return err - } - e.IConverterDriver = d - - return nil -} - -func (e *executor) Stop() error { - e.IConverterDriver = nil - return nil -} - -func (e *executor) CheckSkill(skill string) bool { - return ai_convert.CheckKeySourceSkill(skill) -} - -type ModelConfig struct { - MaxTokens int `json:"max_tokens"` - Temperature float64 `json:"temperature"` - TopP float64 `json:"top_p"` -} - -func signRequest(signer *v4.Signer, region string, model string, headers http.Header, body string) (http.Header, error) { - request, err := http.NewRequest(http.MethodPost, fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", region, model), nil) - if err != nil { - return nil, err - } - request.Header = headers.Clone() - - _, err = signer.Sign(request, strings.NewReader(body), "bedrock", region, time.Now()) - if err != nil { - return nil, err - } - return request.Header, nil - -} diff --git a/drivers/ai-provider/bedrock/factory.go b/drivers/ai-provider/bedrock/factory.go deleted file mode 100644 index 3ab5e06b..00000000 --- a/drivers/ai-provider/bedrock/factory.go +++ /dev/null @@ -1,46 +0,0 @@ -package bedrock - -import ( - "sync" - - ai_convert "github.com/eolinker/apinto/ai-convert" - - "github.com/eolinker/eosc/common/bean" - - "github.com/eolinker/apinto/drivers" - "github.com/eolinker/eosc" -) - -var name = "bedrock" -var ( - converterManager ai_convert.IManager - once sync.Once -) - -// Register 注册驱动 -func Register(register eosc.IExtenderDriverRegister) { - register.RegisterExtenderDriver(name, NewFactory()) -} - -// NewFactory 创建service_http驱动工厂 -func NewFactory() eosc.IExtenderDriverFactory { - once.Do(func() { - bean.Autowired(&converterManager) - converterManager.Set(name, &convertFactory{}) - }) - return drivers.NewFactory[Config](Create) -} - -// Create 创建驱动实例 -func Create(id, name string, v *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) { - - _, err := checkConfig(v) - if err != nil { - return nil, err - } - w := &executor{ - WorkerBase: drivers.Worker(id, name), - } - w.reset(v, workers) - return w, nil -} diff --git a/drivers/ai-provider/bedrock/mode.go b/drivers/ai-provider/bedrock/mode.go deleted file mode 100644 index 687ea36c..00000000 --- a/drivers/ai-provider/bedrock/mode.go +++ /dev/null @@ -1,119 +0,0 @@ -package bedrock - -import ( - "encoding/json" - "fmt" - - ai_convert "github.com/eolinker/apinto/ai-convert" - - "github.com/eolinker/eosc" - - "github.com/eolinker/eosc/eocontext" - http_context "github.com/eolinker/eosc/eocontext/http-context" -) - -type FNewModelMode func(string) IModelMode - -var ( - modelModes = map[string]FNewModelMode{ - ai_convert.ModeChat.String(): NewChat, - ai_convert.ModeCompletion.String(): NewChat, - } -) - -type IModelMode interface { - Endpoint() string - ai_convert.IConverter -} - -type Chat struct { - endPoint string -} - -func NewChat(model string) IModelMode { - return &Chat{ - endPoint: fmt.Sprintf("/model/%s/converse", model), - } -} - -func (c *Chat) Endpoint() string { - return c.endPoint -} - -func (c *Chat) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error { - httpContext, err := http_context.Assert(ctx) - if err != nil { - return err - } - body, err := httpContext.Proxy().Body().RawBody() - if err != nil { - return err - } - // 设置转发地址 - httpContext.Proxy().URI().SetPath(c.endPoint) - baseCfg := eosc.NewBase[ai_convert.ClientRequest]() - err = json.Unmarshal(body, baseCfg) - if err != nil { - return err - } - messages := make([]Message, 0, len(baseCfg.Config.Messages)) - systemMessage := make([]Content, 0) - - for _, m := range baseCfg.Config.Messages { - if m.Role == "system" { - systemMessage = append(systemMessage, Content{Text: m.Content}) - } else { - messages = append(messages, Message{ - Role: m.Role, - Content: []*Content{{Text: m.Content}}, - }) - } - } - baseCfg.SetAppend("messages", messages) - baseCfg.SetAppend("system", systemMessage) - - for k, v := range extender { - baseCfg.SetAppend(k, v) - } - body, err = json.Marshal(baseCfg) - if err != nil { - return err - } - httpContext.Proxy().Body().SetRaw("application/json", body) - - return nil -} - -func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error { - httpContext, err := http_context.Assert(ctx) - if err != nil { - return err - } - if httpContext.Response().StatusCode() != 200 { - return nil - } - body := httpContext.Response().GetBody() - data := eosc.NewBase[Response]() - err = json.Unmarshal(body, data) - if err != nil { - return err - } - responseBody := &ai_convert.ClientResponse{} - if data.Config.Output.Message != nil && len(data.Config.Output.Message.Content) > 0 { - msg := data.Config.Output.Message - responseBody.Message = &ai_convert.Message{ - Role: msg.Role, - Content: msg.Content[0].Text, - } - responseBody.FinishReason = data.Config.StopReason - } else { - responseBody.Code = -1 - responseBody.Error = "no response" - } - body, err = json.Marshal(responseBody) - if err != nil { - return err - } - httpContext.Response().SetBody(body) - return nil -} diff --git a/encoder/gzip.go b/encoder/gzip.go index b2a30522..525b1585 100644 --- a/encoder/gzip.go +++ b/encoder/gzip.go @@ -7,7 +7,7 @@ import ( ) func init() { - encoderManger.Set("gzip", &Gzip{}) + //encoderManger.Set("gzip", &Gzip{}) } type Gzip struct {