From ffe6e6d8b9b3be9ebad24967bcb2881118ae7068 Mon Sep 17 00:00:00 2001 From: Liujian <824010343@qq.com> Date: Mon, 30 Sep 2024 17:43:24 +0800 Subject: [PATCH] =?UTF-8?q?openAI=E5=AF=B9=E6=8E=A5=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/apinto/driver.go | 4 + app/apinto/plugin.go | 6 + app/apinto/profession.go | 18 ++- convert/convert.go | 18 +++ drivers/ai-provider/converter.go | 9 -- drivers/ai-provider/node.go | 98 +++++++++++++++ drivers/ai-provider/openAI/config.go | 27 +++++ drivers/ai-provider/openAI/executor.go | 112 +++++++++++++++--- drivers/ai-provider/openAI/factory.go | 30 +++++ drivers/ai-provider/openAI/load_test.go | 21 ++++ drivers/ai-provider/openAI/message.go | 38 ++++++ drivers/ai-provider/openAI/mode.go | 107 +++++++++++++++++ drivers/ai-provider/openAI/openai.yaml | 77 ------------ drivers/ai-provider/provider.go | 78 ++++++++++++ drivers/ai-service/config.go | 35 ------ drivers/ai-service/executor.go | 42 ------- drivers/ai-service/factory.go | 28 ----- drivers/plugins/ai-formatter/config.go | 19 +++ .../ai-formatter}/driver.go | 21 ++-- drivers/plugins/ai-formatter/executor.go | 83 +++++++++++++ drivers/plugins/ai-formatter/factory.go | 28 +++++ drivers/plugins/ai-prompt/executor.go | 9 +- 22 files changed, 682 insertions(+), 226 deletions(-) create mode 100644 convert/convert.go create mode 100644 drivers/ai-provider/node.go create mode 100644 drivers/ai-provider/openAI/load_test.go create mode 100644 drivers/ai-provider/openAI/message.go create mode 100644 drivers/ai-provider/openAI/mode.go create mode 100644 drivers/ai-provider/provider.go delete mode 100644 drivers/ai-service/config.go delete mode 100644 drivers/ai-service/executor.go delete mode 100644 drivers/ai-service/factory.go create mode 100644 drivers/plugins/ai-formatter/config.go rename drivers/{ai-service => plugins/ai-formatter}/driver.go (69%) create mode 100644 drivers/plugins/ai-formatter/executor.go create mode 100644 drivers/plugins/ai-formatter/factory.go diff --git a/app/apinto/driver.go b/app/apinto/driver.go index de236992..a1d38894 100644 --- a/app/apinto/driver.go +++ b/app/apinto/driver.go @@ -1,6 +1,7 @@ package main import ( + "github.com/eolinker/apinto/drivers/ai-provider/openAI" "github.com/eolinker/apinto/drivers/certs" "github.com/eolinker/apinto/drivers/discovery/consul" "github.com/eolinker/apinto/drivers/discovery/eureka" @@ -82,4 +83,7 @@ func driverRegister(extenderRegister eosc.IExtenderDriverRegister) { // 证书 certs.Register(extenderRegister) + + // AI供应商 + openAI.Register(extenderRegister) } diff --git a/app/apinto/plugin.go b/app/apinto/plugin.go index 3f923ff6..2a8d9da2 100644 --- a/app/apinto/plugin.go +++ b/app/apinto/plugin.go @@ -3,6 +3,8 @@ package main import ( access_relational "github.com/eolinker/apinto/drivers/plugins/access-relational" "github.com/eolinker/apinto/drivers/plugins/acl" + ai_formatter "github.com/eolinker/apinto/drivers/plugins/ai-formatter" + ai_prompt "github.com/eolinker/apinto/drivers/plugins/ai-prompt" "github.com/eolinker/apinto/drivers/plugins/app" auto_redirect "github.com/eolinker/apinto/drivers/plugins/auto-redirect" "github.com/eolinker/apinto/drivers/plugins/cors" @@ -112,4 +114,8 @@ func pluginRegister(extenderRegister eosc.IExtenderDriverRegister) { // 鉴权插件 oauth2.Register(extenderRegister) + + // ai相关插件 + ai_prompt.Register(extenderRegister) + ai_formatter.Register(extenderRegister) } diff --git a/app/apinto/profession.go b/app/apinto/profession.go index c3858e84..6ad19c30 100644 --- a/app/apinto/profession.go +++ b/app/apinto/profession.go @@ -18,7 +18,7 @@ func ApintoProfession() []*eosc.ProfessionConfig { Name: "router", Label: "路由", Desc: "路由", - Dependencies: []string{"service", "template", "transcode"}, + Dependencies: []string{"service", "template", "transcode", "ai-provider"}, AppendLabels: []string{"host", "service", "listen", "disable"}, Drivers: []*eosc.DriverConfig{ { @@ -287,5 +287,21 @@ func ApintoProfession() []*eosc.ProfessionConfig { }, Mod: eosc.ProfessionConfig_Worker, }, + { + Name: "ai-provider", + Label: "AI服务提供者", + Desc: "AI服务提供者", + Dependencies: nil, + AppendLabels: nil, + Drivers: []*eosc.DriverConfig{ + { + Id: "eolinker.com:apinto:openai", + Name: "openAI", + Label: "openAI", + Desc: "openAI", + }, + }, + Mod: eosc.ProfessionConfig_Worker, + }, } } diff --git a/convert/convert.go b/convert/convert.go new file mode 100644 index 00000000..632e2e36 --- /dev/null +++ b/convert/convert.go @@ -0,0 +1,18 @@ +package convert + +import "github.com/eolinker/eosc/eocontext" + +type IConverterDriver interface { + GetModel(model string) (FGenerateConfig, bool) + GetConverter(model string) (IConverter, bool) +} + +type IConverter interface { + RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error + ResponseConvert(ctx eocontext.EoContext) error +} +type FGenerateConfig func(cfg string) (map[string]interface{}, error) + +func CheckSkill(skill string) bool { + return skill == "github.com/eolinker/apinto/convert.convert.IConverterDriver" +} diff --git a/drivers/ai-provider/converter.go b/drivers/ai-provider/converter.go index 31692311..6d763ebc 100644 --- a/drivers/ai-provider/converter.go +++ b/drivers/ai-provider/converter.go @@ -1,14 +1,5 @@ package ai_provider -import ( - "github.com/eolinker/eosc/eocontext" -) - -type IConverter interface { - RequestConvert(ctx eocontext.EoContext) error - ResponseConvert(ctx eocontext.EoContext) error -} - type ClientRequest struct { Messages []*Message `json:"messages"` } diff --git a/drivers/ai-provider/node.go b/drivers/ai-provider/node.go new file mode 100644 index 00000000..1fbf70a2 --- /dev/null +++ b/drivers/ai-provider/node.go @@ -0,0 +1,98 @@ +package ai_provider + +import ( + "fmt" + "time" + + "github.com/eolinker/eosc/eocontext" +) + +var _ eocontext.INode = (*_BaseNode)(nil) + +func NewBaseNode(id string, ip string, port int) *_BaseNode { + return &_BaseNode{id: id, ip: ip, port: port} +} + +type _BaseNode struct { + id string + ip string + port int + status eocontext.NodeStatus +} + +func (n *_BaseNode) GetAttrs() eocontext.Attrs { + return map[string]string{} +} + +func (n *_BaseNode) GetAttrByName(name string) (string, bool) { + return "", false +} + +func (n *_BaseNode) ID() string { + return n.id +} + +func (n *_BaseNode) IP() string { + return n.ip +} + +func (n *_BaseNode) Port() int { + return n.port +} + +func (n *_BaseNode) Status() eocontext.NodeStatus { + + return n.status +} + +// Addr 返回节点地址 +func (n *_BaseNode) Addr() string { + if n.port == 0 { + return n.ip + } + return fmt.Sprintf("%s:%d", n.ip, n.port) +} + +// Up 将节点状态置为运行中 +func (n *_BaseNode) Up() { + n.status = eocontext.Running +} + +// Down 将节点状态置为不可用 +func (n *_BaseNode) Down() { + n.status = eocontext.Down +} + +// Leave 将节点状态置为离开 +func (n *_BaseNode) Leave() { + n.status = eocontext.Leave +} + +func NewBalanceHandler(scheme string, timeout time.Duration, nodes []eocontext.INode) eocontext.BalanceHandler { + return &_BalanceHandler{scheme: scheme, timeout: timeout, nodes: nodes} +} + +type _BalanceHandler struct { + scheme string + timeout time.Duration + nodes []eocontext.INode +} + +func (b *_BalanceHandler) Select(ctx eocontext.EoContext) (eocontext.INode, int, error) { + if len(b.nodes) == 0 { + return nil, 0, nil + } + return b.nodes[0], 0, nil +} + +func (b *_BalanceHandler) Scheme() string { + return b.scheme +} + +func (b *_BalanceHandler) TimeOut() time.Duration { + return b.timeout +} + +func (b *_BalanceHandler) Nodes() []eocontext.INode { + return b.nodes +} diff --git a/drivers/ai-provider/openAI/config.go b/drivers/ai-provider/openAI/config.go index 4bcc1fb0..7a3f286e 100644 --- a/drivers/ai-provider/openAI/config.go +++ b/drivers/ai-provider/openAI/config.go @@ -1,7 +1,34 @@ package openAI +import ( + "fmt" + "net/url" + + "github.com/eolinker/eosc" +) + type Config struct { APIKey string `json:"api_key"` Organization string `json:"organization"` Base string `json:"base"` } + +func checkConfig(v interface{}) (*Config, error) { + conf, ok := v.(*Config) + if !ok { + return nil, eosc.ErrorConfigType + } + if conf.APIKey == "" { + return nil, fmt.Errorf("api_key is required") + } + if conf.Base != "" { + u, err := url.Parse(conf.Base) + if err != nil { + return nil, fmt.Errorf("base url is invalid") + } + if u.Scheme == "" || u.Host == "" { + return nil, fmt.Errorf("base url is invalid") + } + } + return conf, nil +} diff --git a/drivers/ai-provider/openAI/executor.go b/drivers/ai-provider/openAI/executor.go index 674ae4f7..9ff0bb8a 100644 --- a/drivers/ai-provider/openAI/executor.go +++ b/drivers/ai-provider/openAI/executor.go @@ -1,35 +1,82 @@ package openAI import ( - "time" + "embed" + "fmt" + "net/url" + "strconv" + "strings" + ai_provider "github.com/eolinker/apinto/drivers/ai-provider" + + "github.com/eolinker/apinto/convert" "github.com/eolinker/apinto/drivers" "github.com/eolinker/eosc" "github.com/eolinker/eosc/eocontext" ) +var ( + //go:embed openai.yaml + providerContent []byte + //go:embed * + providerDir embed.FS + modelConvert = make(map[string]convert.IConverter) + + _ convert.IConverterDriver = (*executor)(nil) +) + +func init() { + models, err := ai_provider.LoadModels(providerContent, providerDir) + if err != nil { + panic(err) + } + for key, value := range models { + if value.ModelProperties != nil { + if v, ok := modelModes[value.ModelProperties.Mode]; ok { + modelConvert[key] = v + } + } + } +} + type executor struct { drivers.WorkerBase + apikey string + eocontext.BalanceHandler } -func (e *executor) Select(ctx eocontext.EoContext) (eocontext.INode, int, error) { - //TODO implement me - panic("implement me") +type Converter struct { + balanceHandler eocontext.BalanceHandler + converter convert.IConverter } -func (e *executor) Scheme() string { - //TODO implement me - panic("implement me") +func (c *Converter) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error { + if c.balanceHandler != nil { + ctx.SetBalance(c.balanceHandler) + } + return c.converter.RequestConvert(ctx, extender) } -func (e *executor) TimeOut() time.Duration { - //TODO implement me - panic("implement me") +func (c *Converter) ResponseConvert(ctx eocontext.EoContext) error { + return c.converter.ResponseConvert(ctx) } -func (e *executor) Nodes() []eocontext.INode { - //TODO implement me - panic("implement me") +func (e *executor) GetConverter(model string) (convert.IConverter, bool) { + converter, ok := modelConvert[model] + if !ok { + return nil, false + } + + return &Converter{balanceHandler: e.BalanceHandler, converter: converter}, true +} + +func (e *executor) GetModel(model string) (convert.FGenerateConfig, bool) { + if _, ok := modelConvert[model]; !ok { + return nil, false + } + return func(cfg string) (map[string]interface{}, error) { + return nil, nil + }, true } func (e *executor) Start() error { @@ -37,16 +84,43 @@ func (e *executor) Start() error { } func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error { - //TODO implement me - panic("implement me") + cfg, ok := conf.(*Config) + if !ok { + return fmt.Errorf("invalid config") + } + + return e.reset(cfg, workers) +} + +func (e *executor) reset(conf *Config, workers map[eosc.RequireId]eosc.IWorker) error { + if conf.Base != "" { + u, err := url.Parse(conf.Base) + if err != nil { + return err + } + hosts := strings.Split(u.Host, ":") + ip := hosts[0] + port := 80 + if u.Scheme == "https" { + port = 443 + } + if len(hosts) > 1 { + port, _ = strconv.Atoi(hosts[1]) + } + e.BalanceHandler = ai_provider.NewBalanceHandler(u.Scheme, 0, []eocontext.INode{ai_provider.NewBaseNode(e.Id(), ip, port)}) + } else { + e.BalanceHandler = nil + } + + e.apikey = conf.APIKey + return nil } func (e *executor) Stop() error { - //TODO implement me - panic("implement me") + e.BalanceHandler = nil + return nil } func (e *executor) CheckSkill(skill string) bool { - //TODO implement me - panic("implement me") + return convert.CheckSkill(skill) } diff --git a/drivers/ai-provider/openAI/factory.go b/drivers/ai-provider/openAI/factory.go index 8cfddee3..815d8107 100644 --- a/drivers/ai-provider/openAI/factory.go +++ b/drivers/ai-provider/openAI/factory.go @@ -1 +1,31 @@ package openAI + +import ( + "github.com/eolinker/apinto/drivers" + "github.com/eolinker/eosc" +) + +var name = "openai" + +// Register 注册驱动 +func Register(register eosc.IExtenderDriverRegister) { + register.RegisterExtenderDriver(name, NewFactory()) +} + +// NewFactory 创建service_http驱动工厂 +func NewFactory() eosc.IExtenderDriverFactory { + return drivers.NewFactory[Config](Create) +} + +// Create 创建驱动实例 +func Create(id, name string, v *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) { + _, err := checkConfig(v) + if err != nil { + return nil, err + } + w := &executor{ + WorkerBase: drivers.Worker(id, name), + } + w.reset(v, workers) + return w, nil +} diff --git a/drivers/ai-provider/openAI/load_test.go b/drivers/ai-provider/openAI/load_test.go new file mode 100644 index 00000000..7fcf97bb --- /dev/null +++ b/drivers/ai-provider/openAI/load_test.go @@ -0,0 +1,21 @@ +package openAI + +import ( + _ "embed" + "testing" + + ai_provider "github.com/eolinker/apinto/drivers/ai-provider" +) + +func TestLoad(t *testing.T) { + models, err := ai_provider.LoadModels(providerContent, providerDir) + if err != nil { + t.Fatal(err) + } + for key, model := range models { + t.Logf("key:%s,type:%+v", key, model.ModelType) + if model.ModelProperties != nil { + t.Logf("mode:%s,context_size:%d", model.ModelProperties.Mode, model.ModelProperties.ContextSize) + } + } +} diff --git a/drivers/ai-provider/openAI/message.go b/drivers/ai-provider/openAI/message.go new file mode 100644 index 00000000..bc62a02f --- /dev/null +++ b/drivers/ai-provider/openAI/message.go @@ -0,0 +1,38 @@ +package openAI + +type ClientRequest struct { + Messages []*Message `json:"messages"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type Response struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []ResponseChoice `json:"choices"` + Usage Usage `json:"usage"` +} + +type ResponseChoice struct { + Index int `json:"index"` + Message Message `json:"message"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"` +} + +type CompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` +} diff --git a/drivers/ai-provider/openAI/mode.go b/drivers/ai-provider/openAI/mode.go new file mode 100644 index 00000000..6df0fe4c --- /dev/null +++ b/drivers/ai-provider/openAI/mode.go @@ -0,0 +1,107 @@ +package openAI + +import ( + "encoding/json" + + "github.com/eolinker/eosc" + + "github.com/eolinker/apinto/convert" + ai_provider "github.com/eolinker/apinto/drivers/ai-provider" + "github.com/eolinker/eosc/eocontext" + http_context "github.com/eolinker/eosc/eocontext/http-context" +) + +var ( + modelModes = map[string]IModelMode{ + ai_provider.ModeChat.String(): NewChat(), + } +) + +type IModelMode interface { + Endpoint() string + convert.IConverter +} + +type Chat struct { + endPoint string +} + +func NewChat() *Chat { + return &Chat{ + endPoint: "/v1/chat/completions", + } +} + +func (c *Chat) Endpoint() string { + return c.endPoint +} + +func (c *Chat) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error { + httpContext, err := http_context.Assert(ctx) + if err != nil { + return err + } + body, err := httpContext.Proxy().Body().RawBody() + if err != nil { + return err + } + // 设置转发地址 + httpContext.Proxy().URI().SetPath(c.endPoint) + baseCfg := eosc.NewBase[ai_provider.ClientRequest]() + err = json.Unmarshal(body, baseCfg) + if err != nil { + return err + } + messages := make([]Message, 0, len(baseCfg.Config.Messages)+1) + for _, m := range baseCfg.Config.Messages { + messages = append(messages, Message{ + Role: m.Role, + Content: m.Content, + }) + } + baseCfg.SetAppend("messages", messages) + for k, v := range extender { + baseCfg.SetAppend(k, v) + } + body, err = json.Marshal(baseCfg) + if err != nil { + return err + } + httpContext.Proxy().Body().SetRaw("application/json", body) + + return nil +} + +func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error { + httpContext, err := http_context.Assert(ctx) + if err != nil { + return err + } + if httpContext.Response().StatusCode() != 200 { + return nil + } + body := httpContext.Response().GetBody() + data := eosc.NewBase[Response]() + err = json.Unmarshal(body, data) + if err != nil { + return err + } + responseBody := &ai_provider.ClientResponse{} + if len(data.Config.Choices) > 0 { + msg := data.Config.Choices[0] + responseBody.Message = ai_provider.Message{ + Role: msg.Message.Role, + Content: msg.Message.Content, + } + responseBody.FinishReason = msg.FinishReason + } else { + responseBody.Code = -1 + responseBody.Error = "no response" + } + body, err = json.Marshal(responseBody) + if err != nil { + return err + } + httpContext.Response().SetBody(body) + return nil +} diff --git a/drivers/ai-provider/openAI/openai.yaml b/drivers/ai-provider/openAI/openai.yaml index b4dc8fd4..1dd132e3 100644 --- a/drivers/ai-provider/openAI/openai.yaml +++ b/drivers/ai-provider/openAI/openai.yaml @@ -4,86 +4,9 @@ label: description: en_US: Models provided by OpenAI, such as GPT-3.5-Turbo and GPT-4. zh_Hans: OpenAI 提供的模型,例如 GPT-3.5-Turbo 和 GPT-4。 -icon_small: - en_US: icon_s_en.svg -icon_large: - en_US: icon_l_en.svg -background: "#E5E7EB" -help: - title: - en_US: Get your API Key from OpenAI - zh_Hans: 从 OpenAI 获取 API Key - url: - en_US: https://platform.openai.com/account/api-keys supported_model_types: - llm - text-embedding - speech2text - moderation - tts -configurate_methods: - - predefined-model - - customizable-model -model_credential_schema: - model: - label: - en_US: Model Name - zh_Hans: 模型名称 - placeholder: - en_US: Enter your model name - zh_Hans: 输入模型名称 - credential_form_schemas: - - variable: openai_api_key - label: - en_US: API Key - type: secret-input - required: true - placeholder: - zh_Hans: 在此输入您的 API Key - en_US: Enter your API Key - - variable: openai_organization - label: - zh_Hans: 组织 ID - en_US: Organization - type: text-input - required: false - placeholder: - zh_Hans: 在此输入您的组织 ID - en_US: Enter your Organization ID - - variable: openai_api_base - label: - zh_Hans: API Base - en_US: API Base - type: text-input - required: false - placeholder: - zh_Hans: 在此输入您的 API Base - en_US: Enter your API Base -provider_credential_schema: - credential_form_schemas: - - variable: openai_api_key - label: - en_US: API Key - type: secret-input - required: true - placeholder: - zh_Hans: 在此输入您的 API Key - en_US: Enter your API Key - - variable: openai_organization - label: - zh_Hans: 组织 ID - en_US: Organization - type: text-input - required: false - placeholder: - zh_Hans: 在此输入您的组织 ID - en_US: Enter your Organization ID - - variable: openai_api_base - label: - zh_Hans: API Base - en_US: API Base - type: text-input - required: false - placeholder: - zh_Hans: 在此输入您的 API Base, 如:https://api.openai.com - en_US: Enter your API Base, e.g. https://api.openai.com diff --git a/drivers/ai-provider/provider.go b/drivers/ai-provider/provider.go new file mode 100644 index 00000000..1ac253c0 --- /dev/null +++ b/drivers/ai-provider/provider.go @@ -0,0 +1,78 @@ +package ai_provider + +import ( + "embed" + "strings" + + yaml "gopkg.in/yaml.v3" +) + +type ModelType string + +const ( + ModelTypeLLM ModelType = "llm" + ModelTypeTextEmbedding ModelType = "text-embedding" + ModelTypeSpeech2Text ModelType = "speech2text" + ModelTypeModeration ModelType = "moderation" + ModelTypeTTS ModelType = "tts" +) + +const ( + ModeChat Mode = "chat" + ModeComplete Mode = "complete" +) + +type Mode string + +func (m Mode) String() string { + return string(m) +} + +type Provider struct { + Provider string `json:"provider" yaml:"provider"` + SupportedModelTypes []string `json:"supported_model_types" yaml:"supported_model_types"` +} + +type Model struct { + Model string `json:"model" yaml:"model"` + ModelType ModelType `json:"model_type" yaml:"model_type"` + ModelProperties *ModelMode `json:"model_properties" yaml:"model_properties"` +} + +type ModelMode struct { + Mode string `json:"mode" yaml:"mode"` + ContextSize int `json:"context_size" yaml:"context_size"` +} + +func LoadModels(providerContent []byte, dirFs embed.FS) (map[string]*Model, error) { + var provider Provider + err := yaml.Unmarshal(providerContent, &provider) + if err != nil { + return nil, err + } + models := make(map[string]*Model) + for _, modelType := range provider.SupportedModelTypes { + dirFiles, err := dirFs.ReadDir(modelType) + if err != nil { + // 未找到模型目录 + continue + } + for _, dirFile := range dirFiles { + if dirFile.IsDir() || !strings.HasSuffix(dirFile.Name(), ".yaml") { + continue + } + modelContent, err := dirFs.ReadFile(modelType + "/" + dirFile.Name()) + if err != nil { + return nil, err + } + var m Model + err = yaml.Unmarshal(modelContent, &m) + if err != nil { + return nil, err + } + models[m.Model] = &m + } + + } + return models, nil +} diff --git a/drivers/ai-service/config.go b/drivers/ai-service/config.go deleted file mode 100644 index 10ba1894..00000000 --- a/drivers/ai-service/config.go +++ /dev/null @@ -1,35 +0,0 @@ -package ai_service - -import ( - "encoding/json" - "strings" - - "github.com/eolinker/eosc" -) - -// Config service_http驱动配置 -type Config struct { - Title string `json:"title" label:"标题"` - Timeout int64 `json:"timeout" label:"请求超时时间" default:"2000" minimum:"1" title:"单位:ms,最小值:1"` - Retry int `json:"retry" label:"失败重试次数"` - Scheme string `json:"scheme" label:"请求协议" enum:"HTTP,HTTPS"` - Provider eosc.RequireId `json:"provider" required:"false" empty_label:"使用匿名上游" label:"服务发现" skill:"github.com/eolinker/apinto/discovery.discovery.IDiscovery"` -} - -func (c *Config) String() string { - data, _ := json.Marshal(c) - return string(data) -} -func (c *Config) rebuild() { - if c.Retry < 0 { - c.Retry = 0 - } - if c.Timeout < 0 { - c.Timeout = 0 - } - c.Scheme = strings.ToLower(c.Scheme) - if c.Scheme != "http" && c.Scheme != "https" { - c.Scheme = "http" - } - -} diff --git a/drivers/ai-service/executor.go b/drivers/ai-service/executor.go deleted file mode 100644 index 606c30ed..00000000 --- a/drivers/ai-service/executor.go +++ /dev/null @@ -1,42 +0,0 @@ -package ai_service - -import ( - "github.com/eolinker/apinto/drivers" - "github.com/eolinker/apinto/service" - - "github.com/eolinker/eosc" - "github.com/eolinker/eosc/eocontext" -) - -var _ service.IService = &executor{} - -type executor struct { - drivers.WorkerBase - title string - eocontext.BalanceHandler -} - -func (e *executor) PassHost() (eocontext.PassHostMod, string) { - return eocontext.NodeHost, "" -} - -func (e *executor) Title() string { - return e.title -} - -func (e *executor) Start() error { - return nil -} - -func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error { - //TODO implement me - panic("implement me") -} - -func (e *executor) Stop() error { - return nil -} - -func (e *executor) CheckSkill(skill string) bool { - return service.CheckSkill(skill) -} diff --git a/drivers/ai-service/factory.go b/drivers/ai-service/factory.go deleted file mode 100644 index 27d058be..00000000 --- a/drivers/ai-service/factory.go +++ /dev/null @@ -1,28 +0,0 @@ -package ai_service - -import ( - "github.com/eolinker/apinto/drivers" - iphash "github.com/eolinker/apinto/upstream/ip-hash" - roundrobin "github.com/eolinker/apinto/upstream/round-robin" - "github.com/eolinker/eosc" - "github.com/eolinker/eosc/log" -) - -var DriverName = "service_ai" - -// Register 注册service_http驱动工厂 -func Register(register eosc.IExtenderDriverRegister) { - err := register.RegisterExtenderDriver(DriverName, NewFactory()) - if err != nil { - log.Errorf("register %s %s", DriverName, err) - return - - } -} - -// NewFactory 创建service_http驱动工厂 -func NewFactory() eosc.IExtenderDriverFactory { - roundrobin.Register() - iphash.Register() - return drivers.NewFactory[Config](Create) -} diff --git a/drivers/plugins/ai-formatter/config.go b/drivers/plugins/ai-formatter/config.go new file mode 100644 index 00000000..cdd004fb --- /dev/null +++ b/drivers/plugins/ai-formatter/config.go @@ -0,0 +1,19 @@ +package ai_formatter + +import ( + "github.com/eolinker/eosc" +) + +type Config struct { + Provider eosc.RequireId `json:"provider"` + Model string `json:"model"` + Config string `json:"config"` +} + +func checkConfig(v interface{}) (*Config, error) { + conf, ok := v.(*Config) + if !ok { + return nil, eosc.ErrorConfigType + } + return conf, nil +} diff --git a/drivers/ai-service/driver.go b/drivers/plugins/ai-formatter/driver.go similarity index 69% rename from drivers/ai-service/driver.go rename to drivers/plugins/ai-formatter/driver.go index 69ac132d..677a02d1 100644 --- a/drivers/ai-service/driver.go +++ b/drivers/plugins/ai-formatter/driver.go @@ -1,22 +1,21 @@ -package ai_service +package ai_formatter import ( "github.com/eolinker/apinto/drivers" "github.com/eolinker/eosc" ) -// Create 创建实例 func Create(id, name string, v *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) { - - w := &executor{ - WorkerBase: drivers.Worker(id, name), - title: v.Title, - } - - err := w.Reset(v, workers) + _, err := checkConfig(v) if err != nil { return nil, err } - - return w, nil + w := &executor{ + WorkerBase: drivers.Worker(id, name), + } + err = w.reset(v, workers) + if err != nil { + return nil, err + } + return w, err } diff --git a/drivers/plugins/ai-formatter/executor.go b/drivers/plugins/ai-formatter/executor.go new file mode 100644 index 00000000..4f6dc58e --- /dev/null +++ b/drivers/plugins/ai-formatter/executor.go @@ -0,0 +1,83 @@ +package ai_formatter + +import ( + "errors" + + "github.com/eolinker/apinto/convert" + + "github.com/eolinker/apinto/drivers" + "github.com/eolinker/eosc" + "github.com/eolinker/eosc/eocontext" + http_context "github.com/eolinker/eosc/eocontext/http-context" +) + +type executor struct { + drivers.WorkerBase + model string + extender map[string]interface{} + converter convert.IConverter +} + +func (e *executor) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) { + return http_context.DoHttpFilter(e, ctx, next) +} + +func (e *executor) DoHttpFilter(ctx http_context.IHttpContext, next eocontext.IChain) error { + err := e.converter.RequestConvert(ctx, e.extender) + if err != nil { + return err + } + if next != nil { + err = next.DoChain(ctx) + if err != nil { + return err + } + } + return e.converter.ResponseConvert(ctx) +} + +func (e *executor) Destroy() { +} + +func (e *executor) Start() error { + return nil +} + +func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error { + return nil +} + +func (e *executor) reset(cfg *Config, workers map[eosc.RequireId]eosc.IWorker) error { + w, ok := workers[cfg.Provider] + if !ok { + return errors.New("invalid provider") + } + if v, ok := w.(convert.IConverterDriver); ok { + converter, has := v.GetConverter(cfg.Model) + if !has { + return errors.New("invalid model") + } + f, has := v.GetModel(cfg.Model) + if !has { + return errors.New("invalid model") + } + + extender, err := f(cfg.Config) + if err != nil { + return err + } + e.converter = converter + e.model = cfg.Model + e.extender = extender + return nil + } + return errors.New("provider not implement IConverterDriver") +} + +func (e *executor) Stop() error { + return nil +} + +func (e *executor) CheckSkill(skill string) bool { + return http_context.FilterSkillName == skill +} diff --git a/drivers/plugins/ai-formatter/factory.go b/drivers/plugins/ai-formatter/factory.go new file mode 100644 index 00000000..bd09be52 --- /dev/null +++ b/drivers/plugins/ai-formatter/factory.go @@ -0,0 +1,28 @@ +package ai_formatter + +import ( + "github.com/eolinker/apinto/drivers" + "github.com/eolinker/eosc" +) + +const ( + Name = "ai_formatter" +) + +func Register(register eosc.IExtenderDriverRegister) { + register.RegisterExtenderDriver(Name, NewFactory()) +} + +type Factory struct { + eosc.IExtenderDriverFactory +} + +func NewFactory() *Factory { + return &Factory{ + IExtenderDriverFactory: drivers.NewFactory[Config](Create), + } +} + +func (f *Factory) Create(profession string, name string, label string, desc string, params map[string]interface{}) (eosc.IExtenderDriver, error) { + return f.IExtenderDriverFactory.Create(profession, name, label, desc, params) +} diff --git a/drivers/plugins/ai-prompt/executor.go b/drivers/plugins/ai-prompt/executor.go index b4ed9fe7..56e2f725 100644 --- a/drivers/plugins/ai-prompt/executor.go +++ b/drivers/plugins/ai-prompt/executor.go @@ -88,10 +88,11 @@ func (e *executor) Start() error { } func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error { - cfg, ok := conf.(*Config) - if !ok { - return errors.New("invalid config") - } + + return nil +} + +func (e *executor) reset(cfg *Config, workers map[eosc.RequireId]eosc.IWorker) error { variables := make(map[string]bool) required := false