openAI integration commit

Liujian committed on 2024-09-30 17:43:24 +08:00
parent f6931a05a8 · commit ffe6e6d8b9
22 changed files with 682 additions and 226 deletions


@@ -1,6 +1,7 @@
package main
import (
"github.com/eolinker/apinto/drivers/ai-provider/openAI"
"github.com/eolinker/apinto/drivers/certs"
"github.com/eolinker/apinto/drivers/discovery/consul"
"github.com/eolinker/apinto/drivers/discovery/eureka"
@@ -82,4 +83,7 @@ func driverRegister(extenderRegister eosc.IExtenderDriverRegister) {
// Certificates
certs.Register(extenderRegister)
// AI providers
openAI.Register(extenderRegister)
}


@@ -3,6 +3,8 @@ package main
import (
access_relational "github.com/eolinker/apinto/drivers/plugins/access-relational"
"github.com/eolinker/apinto/drivers/plugins/acl"
ai_formatter "github.com/eolinker/apinto/drivers/plugins/ai-formatter"
ai_prompt "github.com/eolinker/apinto/drivers/plugins/ai-prompt"
"github.com/eolinker/apinto/drivers/plugins/app"
auto_redirect "github.com/eolinker/apinto/drivers/plugins/auto-redirect"
"github.com/eolinker/apinto/drivers/plugins/cors"
@@ -112,4 +114,8 @@ func pluginRegister(extenderRegister eosc.IExtenderDriverRegister) {
// Authentication plugin
oauth2.Register(extenderRegister)
// AI-related plugins
ai_prompt.Register(extenderRegister)
ai_formatter.Register(extenderRegister)
}


@@ -18,7 +18,7 @@ func ApintoProfession() []*eosc.ProfessionConfig {
Name: "router",
Label: "路由",
Desc: "路由",
Dependencies: []string{"service", "template", "transcode"},
Dependencies: []string{"service", "template", "transcode", "ai-provider"},
AppendLabels: []string{"host", "service", "listen", "disable"},
Drivers: []*eosc.DriverConfig{
{
@@ -287,5 +287,21 @@ func ApintoProfession() []*eosc.ProfessionConfig {
},
Mod: eosc.ProfessionConfig_Worker,
},
{
Name: "ai-provider",
Label: "AI服务提供者",
Desc: "AI服务提供者",
Dependencies: nil,
AppendLabels: nil,
Drivers: []*eosc.DriverConfig{
{
Id: "eolinker.com:apinto:openai",
Name: "openAI",
Label: "openAI",
Desc: "openAI",
},
},
Mod: eosc.ProfessionConfig_Worker,
},
}
}

convert/convert.go Normal file

@@ -0,0 +1,18 @@
package convert
import "github.com/eolinker/eosc/eocontext"
type IConverterDriver interface {
GetModel(model string) (FGenerateConfig, bool)
GetConverter(model string) (IConverter, bool)
}
type IConverter interface {
RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error
ResponseConvert(ctx eocontext.EoContext) error
}
type FGenerateConfig func(cfg string) (map[string]interface{}, error)
func CheckSkill(skill string) bool {
return skill == "github.com/eolinker/apinto/convert.convert.IConverterDriver"
}
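
For orientation, a minimal sketch of how a caller might drive this interface; it is not part of the commit, and the helper name, the forward callback, and the empty config string are assumptions:

package convert

import (
	"fmt"

	"github.com/eolinker/eosc/eocontext"
)

// applyConverter shows the intended call order: resolve the converter and the
// config generator for a model, rewrite the request, forward it upstream, then
// rewrite the response.
func applyConverter(driver IConverterDriver, model string, ctx eocontext.EoContext, forward func(eocontext.EoContext) error) error {
	converter, ok := driver.GetConverter(model)
	if !ok {
		return fmt.Errorf("model %s is not supported by this provider", model)
	}
	generate, ok := driver.GetModel(model)
	if !ok {
		return fmt.Errorf("model %s is not supported by this provider", model)
	}
	extender, err := generate("") // raw model config; its format is driver-defined
	if err != nil {
		return err
	}
	if err = converter.RequestConvert(ctx, extender); err != nil {
		return err
	}
	if err = forward(ctx); err != nil {
		return err
	}
	return converter.ResponseConvert(ctx)
}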


@@ -1,14 +1,5 @@
package ai_provider
import (
"github.com/eolinker/eosc/eocontext"
)
type IConverter interface {
RequestConvert(ctx eocontext.EoContext) error
ResponseConvert(ctx eocontext.EoContext) error
}
type ClientRequest struct {
Messages []*Message `json:"messages"`
}


@@ -0,0 +1,98 @@
package ai_provider
import (
"fmt"
"time"
"github.com/eolinker/eosc/eocontext"
)
var _ eocontext.INode = (*_BaseNode)(nil)
func NewBaseNode(id string, ip string, port int) *_BaseNode {
return &_BaseNode{id: id, ip: ip, port: port}
}
type _BaseNode struct {
id string
ip string
port int
status eocontext.NodeStatus
}
func (n *_BaseNode) GetAttrs() eocontext.Attrs {
return map[string]string{}
}
func (n *_BaseNode) GetAttrByName(name string) (string, bool) {
return "", false
}
func (n *_BaseNode) ID() string {
return n.id
}
func (n *_BaseNode) IP() string {
return n.ip
}
func (n *_BaseNode) Port() int {
return n.port
}
func (n *_BaseNode) Status() eocontext.NodeStatus {
return n.status
}
// Addr returns the node address
func (n *_BaseNode) Addr() string {
if n.port == 0 {
return n.ip
}
return fmt.Sprintf("%s:%d", n.ip, n.port)
}
// Up marks the node as running
func (n *_BaseNode) Up() {
n.status = eocontext.Running
}
// Down marks the node as unavailable
func (n *_BaseNode) Down() {
n.status = eocontext.Down
}
// Leave marks the node as left
func (n *_BaseNode) Leave() {
n.status = eocontext.Leave
}
func NewBalanceHandler(scheme string, timeout time.Duration, nodes []eocontext.INode) eocontext.BalanceHandler {
return &_BalanceHandler{scheme: scheme, timeout: timeout, nodes: nodes}
}
type _BalanceHandler struct {
scheme string
timeout time.Duration
nodes []eocontext.INode
}
func (b *_BalanceHandler) Select(ctx eocontext.EoContext) (eocontext.INode, int, error) {
if len(b.nodes) == 0 {
return nil, 0, nil
}
return b.nodes[0], 0, nil
}
func (b *_BalanceHandler) Scheme() string {
return b.scheme
}
func (b *_BalanceHandler) TimeOut() time.Duration {
return b.timeout
}
func (b *_BalanceHandler) Nodes() []eocontext.INode {
return b.nodes
}
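
As a usage sketch (not part of the commit; the worker id and upstream address are made up), the helpers above combine into a single-node balancer whose Select always returns that node:

package ai_provider

import (
	"fmt"

	"github.com/eolinker/eosc/eocontext"
)

// exampleBalance builds one static upstream node and wraps it in a balance
// handler; Select ignores the context and returns the first (only) node.
func exampleBalance() {
	node := NewBaseNode("openai@ai-provider", "api.openai.com", 443)
	node.Up() // mark the node as running
	handler := NewBalanceHandler("https", 0, []eocontext.INode{node})
	selected, _, _ := handler.Select(nil)
	fmt.Println(selected.Addr()) // api.openai.com:443
}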


@@ -1,7 +1,34 @@
package openAI
import (
"fmt"
"net/url"
"github.com/eolinker/eosc"
)
type Config struct {
APIKey string `json:"api_key"`
Organization string `json:"organization"`
Base string `json:"base"`
}
func checkConfig(v interface{}) (*Config, error) {
conf, ok := v.(*Config)
if !ok {
return nil, eosc.ErrorConfigType
}
if conf.APIKey == "" {
return nil, fmt.Errorf("api_key is required")
}
if conf.Base != "" {
u, err := url.Parse(conf.Base)
if err != nil {
return nil, fmt.Errorf("base url is invalid")
}
if u.Scheme == "" || u.Host == "" {
return nil, fmt.Errorf("base url is invalid")
}
}
return conf, nil
}
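
A test sketch (not part of the commit) of the validation rules above: api_key is mandatory, and base, when present, must carry both a scheme and a host:

package openAI

import "testing"

func TestCheckConfig(t *testing.T) {
	if _, err := checkConfig(&Config{}); err == nil {
		t.Error("expected an error when api_key is empty")
	}
	if _, err := checkConfig(&Config{APIKey: "sk-test", Base: "api.openai.com"}); err == nil {
		t.Error("expected an error for a base url without a scheme")
	}
	if _, err := checkConfig(&Config{APIKey: "sk-test", Base: "https://api.openai.com"}); err != nil {
		t.Errorf("unexpected error: %v", err)
	}
}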


@@ -1,35 +1,82 @@
package openAI
import (
"time"
"embed"
"fmt"
"net/url"
"strconv"
"strings"
ai_provider "github.com/eolinker/apinto/drivers/ai-provider"
"github.com/eolinker/apinto/convert"
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/eosc"
"github.com/eolinker/eosc/eocontext"
)
var (
//go:embed openai.yaml
providerContent []byte
//go:embed *
providerDir embed.FS
modelConvert = make(map[string]convert.IConverter)
_ convert.IConverterDriver = (*executor)(nil)
)
func init() {
models, err := ai_provider.LoadModels(providerContent, providerDir)
if err != nil {
panic(err)
}
for key, value := range models {
if value.ModelProperties != nil {
if v, ok := modelModes[value.ModelProperties.Mode]; ok {
modelConvert[key] = v
}
}
}
}
type executor struct {
drivers.WorkerBase
apikey string
eocontext.BalanceHandler
}
func (e *executor) Select(ctx eocontext.EoContext) (eocontext.INode, int, error) {
//TODO implement me
panic("implement me")
type Converter struct {
balanceHandler eocontext.BalanceHandler
converter convert.IConverter
}
func (e *executor) Scheme() string {
//TODO implement me
panic("implement me")
func (c *Converter) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error {
if c.balanceHandler != nil {
ctx.SetBalance(c.balanceHandler)
}
return c.converter.RequestConvert(ctx, extender)
}
func (e *executor) TimeOut() time.Duration {
//TODO implement me
panic("implement me")
func (c *Converter) ResponseConvert(ctx eocontext.EoContext) error {
return c.converter.ResponseConvert(ctx)
}
func (e *executor) Nodes() []eocontext.INode {
//TODO implement me
panic("implement me")
func (e *executor) GetConverter(model string) (convert.IConverter, bool) {
converter, ok := modelConvert[model]
if !ok {
return nil, false
}
return &Converter{balanceHandler: e.BalanceHandler, converter: converter}, true
}
func (e *executor) GetModel(model string) (convert.FGenerateConfig, bool) {
if _, ok := modelConvert[model]; !ok {
return nil, false
}
return func(cfg string) (map[string]interface{}, error) {
return nil, nil
}, true
}
func (e *executor) Start() error {
@@ -37,16 +84,43 @@ func (e *executor) Start() error {
}
func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
//TODO implement me
panic("implement me")
cfg, ok := conf.(*Config)
if !ok {
return fmt.Errorf("invalid config")
}
return e.reset(cfg, workers)
}
func (e *executor) reset(conf *Config, workers map[eosc.RequireId]eosc.IWorker) error {
if conf.Base != "" {
u, err := url.Parse(conf.Base)
if err != nil {
return err
}
hosts := strings.Split(u.Host, ":")
ip := hosts[0]
port := 80
if u.Scheme == "https" {
port = 443
}
if len(hosts) > 1 {
port, _ = strconv.Atoi(hosts[1])
}
e.BalanceHandler = ai_provider.NewBalanceHandler(u.Scheme, 0, []eocontext.INode{ai_provider.NewBaseNode(e.Id(), ip, port)})
} else {
e.BalanceHandler = nil
}
e.apikey = conf.APIKey
return nil
}
func (e *executor) Stop() error {
//TODO implement me
panic("implement me")
e.BalanceHandler = nil
return nil
}
func (e *executor) CheckSkill(skill string) bool {
//TODO implement me
panic("implement me")
return convert.CheckSkill(skill)
}
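
A table-test sketch of the base-URL handling in reset (not part of the commit; the worker id, hosts, and API key are made up): the port defaults to 80, switches to 443 for https, and an explicit port wins:

package openAI

import (
	"testing"

	"github.com/eolinker/apinto/drivers"
)

func TestResetBase(t *testing.T) {
	cases := map[string]string{
		"https://api.openai.com":         "api.openai.com:443",
		"http://10.0.0.1:8080":           "10.0.0.1:8080",
		"https://proxy.example.com:8443": "proxy.example.com:8443",
	}
	for base, want := range cases {
		e := &executor{WorkerBase: drivers.Worker("openai@ai-provider", "openai")}
		if err := e.reset(&Config{APIKey: "sk-test", Base: base}, nil); err != nil {
			t.Fatal(err)
		}
		node, _, _ := e.BalanceHandler.Select(nil)
		if got := node.Addr(); got != want {
			t.Errorf("base %q: got %s, want %s", base, got, want)
		}
	}
}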


@@ -1 +1,31 @@
package openAI
import (
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/eosc"
)
var name = "openai"
// Register registers the driver
func Register(register eosc.IExtenderDriverRegister) {
register.RegisterExtenderDriver(name, NewFactory())
}
// NewFactory creates the openAI driver factory
func NewFactory() eosc.IExtenderDriverFactory {
return drivers.NewFactory[Config](Create)
}
// Create creates a driver instance
func Create(id, name string, v *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
_, err := checkConfig(v)
if err != nil {
return nil, err
}
w := &executor{
WorkerBase: drivers.Worker(id, name),
}
w.reset(v, workers)
return w, nil
}


@@ -0,0 +1,21 @@
package openAI
import (
_ "embed"
"testing"
ai_provider "github.com/eolinker/apinto/drivers/ai-provider"
)
func TestLoad(t *testing.T) {
models, err := ai_provider.LoadModels(providerContent, providerDir)
if err != nil {
t.Fatal(err)
}
for key, model := range models {
t.Logf("key:%s,type:%+v", key, model.ModelType)
if model.ModelProperties != nil {
t.Logf("mode:%s,context_size:%d", model.ModelProperties.Mode, model.ModelProperties.ContextSize)
}
}
}


@@ -0,0 +1,38 @@
package openAI
type ClientRequest struct {
Messages []*Message `json:"messages"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type Response struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []ResponseChoice `json:"choices"`
Usage Usage `json:"usage"`
}
type ResponseChoice struct {
Index int `json:"index"`
Message Message `json:"message"`
Logprobs interface{} `json:"logprobs"`
FinishReason string `json:"finish_reason"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"`
}
type CompletionTokensDetails struct {
ReasoningTokens int `json:"reasoning_tokens"`
}
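
For reference, a test sketch (not part of the commit; the payload values are illustrative) showing that a standard chat-completion response unmarshals into the Response type above:

package openAI

import (
	"encoding/json"
	"testing"
)

func TestResponseUnmarshal(t *testing.T) {
	sample := []byte(`{"id":"chatcmpl-1","object":"chat.completion","created":1727684604,"model":"gpt-4o-mini","system_fingerprint":"fp_abc","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7,"completion_tokens_details":{"reasoning_tokens":0}}}`)
	var resp Response
	if err := json.Unmarshal(sample, &resp); err != nil {
		t.Fatal(err)
	}
	if got := resp.Choices[0].FinishReason; got != "stop" {
		t.Errorf("finish_reason: got %s, want stop", got)
	}
	if resp.Usage.TotalTokens != 7 {
		t.Errorf("total_tokens: got %d, want 7", resp.Usage.TotalTokens)
	}
}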


@@ -0,0 +1,107 @@
package openAI
import (
"encoding/json"
"github.com/eolinker/eosc"
"github.com/eolinker/apinto/convert"
ai_provider "github.com/eolinker/apinto/drivers/ai-provider"
"github.com/eolinker/eosc/eocontext"
http_context "github.com/eolinker/eosc/eocontext/http-context"
)
var (
modelModes = map[string]IModelMode{
ai_provider.ModeChat.String(): NewChat(),
}
)
type IModelMode interface {
Endpoint() string
convert.IConverter
}
type Chat struct {
endPoint string
}
func NewChat() *Chat {
return &Chat{
endPoint: "/v1/chat/completions",
}
}
func (c *Chat) Endpoint() string {
return c.endPoint
}
func (c *Chat) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error {
httpContext, err := http_context.Assert(ctx)
if err != nil {
return err
}
body, err := httpContext.Proxy().Body().RawBody()
if err != nil {
return err
}
// Set the forwarding path
httpContext.Proxy().URI().SetPath(c.endPoint)
baseCfg := eosc.NewBase[ai_provider.ClientRequest]()
err = json.Unmarshal(body, baseCfg)
if err != nil {
return err
}
messages := make([]Message, 0, len(baseCfg.Config.Messages)+1)
for _, m := range baseCfg.Config.Messages {
messages = append(messages, Message{
Role: m.Role,
Content: m.Content,
})
}
baseCfg.SetAppend("messages", messages)
for k, v := range extender {
baseCfg.SetAppend(k, v)
}
body, err = json.Marshal(baseCfg)
if err != nil {
return err
}
httpContext.Proxy().Body().SetRaw("application/json", body)
return nil
}
func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error {
httpContext, err := http_context.Assert(ctx)
if err != nil {
return err
}
if httpContext.Response().StatusCode() != 200 {
return nil
}
body := httpContext.Response().GetBody()
data := eosc.NewBase[Response]()
err = json.Unmarshal(body, data)
if err != nil {
return err
}
responseBody := &ai_provider.ClientResponse{}
if len(data.Config.Choices) > 0 {
msg := data.Config.Choices[0]
responseBody.Message = ai_provider.Message{
Role: msg.Message.Role,
Content: msg.Message.Content,
}
responseBody.FinishReason = msg.FinishReason
} else {
responseBody.Code = -1
responseBody.Error = "no response"
}
body, err = json.Marshal(responseBody)
if err != nil {
return err
}
httpContext.Response().SetBody(body)
return nil
}


@@ -4,86 +4,9 @@ label:
description:
en_US: Models provided by OpenAI, such as GPT-3.5-Turbo and GPT-4.
zh_Hans: OpenAI 提供的模型,例如 GPT-3.5-Turbo 和 GPT-4。
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#E5E7EB"
help:
title:
en_US: Get your API Key from OpenAI
zh_Hans: 从 OpenAI 获取 API Key
url:
en_US: https://platform.openai.com/account/api-keys
supported_model_types:
- llm
- text-embedding
- speech2text
- moderation
- tts
configurate_methods:
- predefined-model
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: openai_api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: openai_organization
label:
zh_Hans: 组织 ID
en_US: Organization
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的组织 ID
en_US: Enter your Organization ID
- variable: openai_api_base
label:
zh_Hans: API Base
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 API Base
en_US: Enter your API Base
provider_credential_schema:
credential_form_schemas:
- variable: openai_api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: openai_organization
label:
zh_Hans: 组织 ID
en_US: Organization
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的组织 ID
en_US: Enter your Organization ID
- variable: openai_api_base
label:
zh_Hans: API Base
en_US: API Base
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 API Base, 如https://api.openai.com
en_US: Enter your API Base, e.g. https://api.openai.com


@@ -0,0 +1,78 @@
package ai_provider
import (
"embed"
"strings"
yaml "gopkg.in/yaml.v3"
)
type ModelType string
const (
ModelTypeLLM ModelType = "llm"
ModelTypeTextEmbedding ModelType = "text-embedding"
ModelTypeSpeech2Text ModelType = "speech2text"
ModelTypeModeration ModelType = "moderation"
ModelTypeTTS ModelType = "tts"
)
const (
ModeChat Mode = "chat"
ModeComplete Mode = "complete"
)
type Mode string
func (m Mode) String() string {
return string(m)
}
type Provider struct {
Provider string `json:"provider" yaml:"provider"`
SupportedModelTypes []string `json:"supported_model_types" yaml:"supported_model_types"`
}
type Model struct {
Model string `json:"model" yaml:"model"`
ModelType ModelType `json:"model_type" yaml:"model_type"`
ModelProperties *ModelMode `json:"model_properties" yaml:"model_properties"`
}
type ModelMode struct {
Mode string `json:"mode" yaml:"mode"`
ContextSize int `json:"context_size" yaml:"context_size"`
}
func LoadModels(providerContent []byte, dirFs embed.FS) (map[string]*Model, error) {
var provider Provider
err := yaml.Unmarshal(providerContent, &provider)
if err != nil {
return nil, err
}
models := make(map[string]*Model)
for _, modelType := range provider.SupportedModelTypes {
dirFiles, err := dirFs.ReadDir(modelType)
if err != nil {
// model type directory not found, skip it
continue
}
for _, dirFile := range dirFiles {
if dirFile.IsDir() || !strings.HasSuffix(dirFile.Name(), ".yaml") {
continue
}
modelContent, err := dirFs.ReadFile(modelType + "/" + dirFile.Name())
if err != nil {
return nil, err
}
var m Model
err = yaml.Unmarshal(modelContent, &m)
if err != nil {
return nil, err
}
models[m.Model] = &m
}
}
return models, nil
}
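
A sketch of the embedded layout LoadModels expects, shown for a hypothetical provider (file names and values are assumptions, not from the commit): one provider file listing supported_model_types, plus one YAML file per model inside a directory named after the model type:

package myprovider

import (
	"embed"

	ai_provider "github.com/eolinker/apinto/drivers/ai-provider"
)

// Expected layout:
//
//	provider.yaml      provider: my-provider
//	                   supported_model_types: [llm]
//	llm/my-model.yaml  model: my-model
//	                   model_type: llm
//	                   model_properties: {mode: chat, context_size: 8192}

//go:embed provider.yaml
var providerContent []byte

//go:embed *
var providerDir embed.FS

// loadAll returns a map keyed by model name, e.g. models["my-model"].
func loadAll() (map[string]*ai_provider.Model, error) {
	return ai_provider.LoadModels(providerContent, providerDir)
}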


@@ -1,35 +0,0 @@
package ai_service
import (
"encoding/json"
"strings"
"github.com/eolinker/eosc"
)
// Config service_http driver configuration
type Config struct {
Title string `json:"title" label:"标题"`
Timeout int64 `json:"timeout" label:"请求超时时间" default:"2000" minimum:"1" title:"单位ms最小值1"`
Retry int `json:"retry" label:"失败重试次数"`
Scheme string `json:"scheme" label:"请求协议" enum:"HTTP,HTTPS"`
Provider eosc.RequireId `json:"provider" required:"false" empty_label:"使用匿名上游" label:"服务发现" skill:"github.com/eolinker/apinto/discovery.discovery.IDiscovery"`
}
func (c *Config) String() string {
data, _ := json.Marshal(c)
return string(data)
}
func (c *Config) rebuild() {
if c.Retry < 0 {
c.Retry = 0
}
if c.Timeout < 0 {
c.Timeout = 0
}
c.Scheme = strings.ToLower(c.Scheme)
if c.Scheme != "http" && c.Scheme != "https" {
c.Scheme = "http"
}
}


@@ -1,42 +0,0 @@
package ai_service
import (
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/apinto/service"
"github.com/eolinker/eosc"
"github.com/eolinker/eosc/eocontext"
)
var _ service.IService = &executor{}
type executor struct {
drivers.WorkerBase
title string
eocontext.BalanceHandler
}
func (e *executor) PassHost() (eocontext.PassHostMod, string) {
return eocontext.NodeHost, ""
}
func (e *executor) Title() string {
return e.title
}
func (e *executor) Start() error {
return nil
}
func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
//TODO implement me
panic("implement me")
}
func (e *executor) Stop() error {
return nil
}
func (e *executor) CheckSkill(skill string) bool {
return service.CheckSkill(skill)
}


@@ -1,28 +0,0 @@
package ai_service
import (
"github.com/eolinker/apinto/drivers"
iphash "github.com/eolinker/apinto/upstream/ip-hash"
roundrobin "github.com/eolinker/apinto/upstream/round-robin"
"github.com/eolinker/eosc"
"github.com/eolinker/eosc/log"
)
var DriverName = "service_ai"
// Register registers the service_http driver factory
func Register(register eosc.IExtenderDriverRegister) {
err := register.RegisterExtenderDriver(DriverName, NewFactory())
if err != nil {
log.Errorf("register %s %s", DriverName, err)
return
}
}
// NewFactory creates the service_http driver factory
func NewFactory() eosc.IExtenderDriverFactory {
roundrobin.Register()
iphash.Register()
return drivers.NewFactory[Config](Create)
}


@@ -0,0 +1,19 @@
package ai_formatter
import (
"github.com/eolinker/eosc"
)
type Config struct {
Provider eosc.RequireId `json:"provider"`
Model string `json:"model"`
Config string `json:"config"`
}
func checkConfig(v interface{}) (*Config, error) {
conf, ok := v.(*Config)
if !ok {
return nil, eosc.ErrorConfigType
}
return conf, nil
}
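
As a hedged illustration of the plugin's JSON shape (the provider id, model name, and nested config are assumptions): provider must reference a worker in the ai-provider profession, and config is handed verbatim to that provider's FGenerateConfig function:

package ai_formatter

// exampleConfig is a hypothetical ai_formatter configuration.
const exampleConfig = `{
	"provider": "openai@ai-provider",
	"model": "gpt-4o",
	"config": "{\"temperature\": 0.7}"
}`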


@@ -1,22 +1,21 @@
package ai_service
package ai_formatter
import (
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/eosc"
)
// Create creates an instance
func Create(id, name string, v *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
w := &executor{
WorkerBase: drivers.Worker(id, name),
title: v.Title,
}
err := w.Reset(v, workers)
_, err := checkConfig(v)
if err != nil {
return nil, err
}
return w, nil
w := &executor{
WorkerBase: drivers.Worker(id, name),
}
err = w.reset(v, workers)
if err != nil {
return nil, err
}
return w, err
}


@@ -0,0 +1,83 @@
package ai_formatter
import (
"errors"
"github.com/eolinker/apinto/convert"
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/eosc"
"github.com/eolinker/eosc/eocontext"
http_context "github.com/eolinker/eosc/eocontext/http-context"
)
type executor struct {
drivers.WorkerBase
model string
extender map[string]interface{}
converter convert.IConverter
}
func (e *executor) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) {
return http_context.DoHttpFilter(e, ctx, next)
}
func (e *executor) DoHttpFilter(ctx http_context.IHttpContext, next eocontext.IChain) error {
err := e.converter.RequestConvert(ctx, e.extender)
if err != nil {
return err
}
if next != nil {
err = next.DoChain(ctx)
if err != nil {
return err
}
}
return e.converter.ResponseConvert(ctx)
}
func (e *executor) Destroy() {
}
func (e *executor) Start() error {
return nil
}
func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
return nil
}
func (e *executor) reset(cfg *Config, workers map[eosc.RequireId]eosc.IWorker) error {
w, ok := workers[cfg.Provider]
if !ok {
return errors.New("invalid provider")
}
if v, ok := w.(convert.IConverterDriver); ok {
converter, has := v.GetConverter(cfg.Model)
if !has {
return errors.New("invalid model")
}
f, has := v.GetModel(cfg.Model)
if !has {
return errors.New("invalid model")
}
extender, err := f(cfg.Config)
if err != nil {
return err
}
e.converter = converter
e.model = cfg.Model
e.extender = extender
return nil
}
return errors.New("provider not implement IConverterDriver")
}
func (e *executor) Stop() error {
return nil
}
func (e *executor) CheckSkill(skill string) bool {
return http_context.FilterSkillName == skill
}


@@ -0,0 +1,28 @@
package ai_formatter
import (
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/eosc"
)
const (
Name = "ai_formatter"
)
func Register(register eosc.IExtenderDriverRegister) {
register.RegisterExtenderDriver(Name, NewFactory())
}
type Factory struct {
eosc.IExtenderDriverFactory
}
func NewFactory() *Factory {
return &Factory{
IExtenderDriverFactory: drivers.NewFactory[Config](Create),
}
}
func (f *Factory) Create(profession string, name string, label string, desc string, params map[string]interface{}) (eosc.IExtenderDriver, error) {
return f.IExtenderDriverFactory.Create(profession, name, label, desc, params)
}


@@ -88,10 +88,11 @@ func (e *executor) Start() error {
}
func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
cfg, ok := conf.(*Config)
if !ok {
return errors.New("invalid config")
}
return nil
}
func (e *executor) reset(cfg *Config, workers map[eosc.RequireId]eosc.IWorker) error {
variables := make(map[string]bool)
required := false