mirror of
https://github.com/eolinker/apinto
synced 2025-12-24 13:28:15 +08:00
更新bedrock
This commit is contained in:
@@ -8,11 +8,6 @@ type IConverterFactory interface {
|
||||
|
||||
type IConverterCreateFunc func(cfg string) (IConverter, error)
|
||||
|
||||
//type IConverterDriver interface {
|
||||
// GetModel(model string) (FGenerateConfig, bool)
|
||||
// GetConverter(model string) (IConverter, bool)
|
||||
//}
|
||||
|
||||
type IConverter interface {
|
||||
RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error
|
||||
ResponseConvert(ctx eocontext.EoContext) error
|
||||
|
||||
@@ -237,3 +237,45 @@ func SetProvider(id string, p IProvider) {
|
||||
func GetProvider(provider string) (IProvider, bool) {
|
||||
return balanceManager.Get(provider)
|
||||
}
|
||||
|
||||
type IModelAccessConfigManager interface {
|
||||
Get(id string) (IModelAccessConfig, bool)
|
||||
Set(id string, config IModelAccessConfig)
|
||||
Del(id string)
|
||||
}
|
||||
|
||||
type IModelAccessConfig interface {
|
||||
Provider() string
|
||||
Model() string
|
||||
Config() map[string]string
|
||||
}
|
||||
|
||||
type modelAccessConfigManager struct {
|
||||
configs eosc.Untyped[string, IModelAccessConfig]
|
||||
}
|
||||
|
||||
func (m *modelAccessConfigManager) Get(id string) (IModelAccessConfig, bool) {
|
||||
return m.configs.Get(id)
|
||||
}
|
||||
|
||||
func (m *modelAccessConfigManager) Set(id string, config IModelAccessConfig) {
|
||||
m.configs.Set(id, config)
|
||||
}
|
||||
|
||||
func (m *modelAccessConfigManager) Del(id string) {
|
||||
m.configs.Del(id)
|
||||
}
|
||||
|
||||
func NewModelAccessConfigManager() *modelAccessConfigManager {
|
||||
return &modelAccessConfigManager{
|
||||
configs: eosc.BuildUntyped[string, IModelAccessConfig](),
|
||||
}
|
||||
}
|
||||
|
||||
var _ IModelAccessConfigManager = (*modelAccessConfigManager)(nil)
|
||||
|
||||
var modelAcManager = NewModelAccessConfigManager()
|
||||
|
||||
func init() {
|
||||
bean.Injection(&modelAcManager)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
ai_key "github.com/eolinker/apinto/drivers/ai-key"
|
||||
ai_model "github.com/eolinker/apinto/drivers/ai-model"
|
||||
ai_provider "github.com/eolinker/apinto/drivers/ai-provider"
|
||||
"github.com/eolinker/apinto/drivers/certs"
|
||||
"github.com/eolinker/apinto/drivers/discovery/consul"
|
||||
@@ -125,4 +126,5 @@ func driverRegister(extenderRegister eosc.IExtenderDriverRegister) {
|
||||
|
||||
ai_provider.Register(extenderRegister)
|
||||
ai_key.Register(extenderRegister)
|
||||
ai_model.Register(extenderRegister)
|
||||
}
|
||||
|
||||
@@ -291,6 +291,12 @@ func ApintoProfession() []*eosc.ProfessionConfig {
|
||||
Label: "ai-provider",
|
||||
Desc: "ai-provider",
|
||||
},
|
||||
{
|
||||
Id: "eolinker.com:apinto:ai-model",
|
||||
Name: "ai-model",
|
||||
Label: "ai-model",
|
||||
Desc: "ai-model",
|
||||
},
|
||||
},
|
||||
},
|
||||
//{
|
||||
|
||||
45
drivers/ai-model/config.go
Normal file
45
drivers/ai-model/config.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package ai_model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
AccessConfig string `json:"access_config"`
|
||||
}
|
||||
|
||||
// Create 创建驱动实例
|
||||
func Create(id, name string, v *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
|
||||
cfg, err := checkConfig(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w := &executor{
|
||||
WorkerBase: drivers.Worker(id, name),
|
||||
}
|
||||
err = w.reset(cfg)
|
||||
return w, err
|
||||
}
|
||||
|
||||
func checkConfig(v interface{}) (*Config, error) {
|
||||
conf, ok := v.(*Config)
|
||||
if !ok {
|
||||
return nil, eosc.ErrorConfigType
|
||||
}
|
||||
|
||||
if conf.Provider == "" {
|
||||
return nil, fmt.Errorf("provider is required")
|
||||
}
|
||||
|
||||
if conf.Model == "" {
|
||||
return nil, fmt.Errorf("model is required")
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
68
drivers/ai-model/executor.go
Normal file
68
drivers/ai-model/executor.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package ai_model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
var _ eosc.IWorker = (*executor)(nil)
|
||||
|
||||
type executor struct {
|
||||
drivers.WorkerBase
|
||||
}
|
||||
|
||||
func (e *executor) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
|
||||
cfg, err := checkConfig(conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return e.reset(cfg)
|
||||
}
|
||||
|
||||
func (e *executor) reset(conf *Config) error {
|
||||
tmp := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(conf.AccessConfig), &tmp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
accessConfigManager.Set(e.Name(), &modelAccessConfig{
|
||||
provider: conf.Provider,
|
||||
model: conf.Model,
|
||||
accessConfig: tmp,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *executor) Stop() error {
|
||||
accessConfigManager.Del(e.Name())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *executor) CheckSkill(skill string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type modelAccessConfig struct {
|
||||
provider string
|
||||
model string
|
||||
accessConfig map[string]string
|
||||
}
|
||||
|
||||
func (m *modelAccessConfig) Provider() string {
|
||||
return m.provider
|
||||
}
|
||||
|
||||
func (m *modelAccessConfig) Model() string {
|
||||
return m.model
|
||||
}
|
||||
|
||||
func (m *modelAccessConfig) Config() map[string]string {
|
||||
return m.accessConfig
|
||||
}
|
||||
37
drivers/ai-model/factory.go
Normal file
37
drivers/ai-model/factory.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package ai_model
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/eolinker/eosc/common/bean"
|
||||
|
||||
ai_convert "github.com/eolinker/apinto/ai-convert"
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
var name = "ai-model"
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
accessConfigManager ai_convert.IModelAccessConfigManager
|
||||
)
|
||||
|
||||
func init() {
|
||||
once.Do(func() {
|
||||
bean.Autowired(&accessConfigManager)
|
||||
})
|
||||
}
|
||||
|
||||
type Factory struct {
|
||||
}
|
||||
|
||||
// Register AI供应商Factory
|
||||
func Register(register eosc.IExtenderDriverRegister) {
|
||||
register.RegisterExtenderDriver(name, NewFactory())
|
||||
}
|
||||
|
||||
// NewFactory 创建service_http驱动工厂
|
||||
func NewFactory() eosc.IExtenderDriverFactory {
|
||||
return drivers.NewFactory[Config](Create)
|
||||
}
|
||||
@@ -1,30 +1,38 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
|
||||
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
"github.com/eolinker/eosc/common/bean"
|
||||
http_service "github.com/eolinker/eosc/eocontext/http-context"
|
||||
|
||||
"github.com/eolinker/eosc/eocontext"
|
||||
|
||||
ai_convert "github.com/eolinker/apinto/ai-convert"
|
||||
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
AccessKey string `json:"aws_access_key_id"`
|
||||
SecretKey string `json:"aws_secret_access_key"`
|
||||
Region string `json:"aws_region"`
|
||||
ModelForValidation string `json:"model_for_validation"`
|
||||
var (
|
||||
accessConfigManager ai_convert.IModelAccessConfigManager
|
||||
)
|
||||
|
||||
func init() {
|
||||
bean.Autowired(&accessConfigManager)
|
||||
}
|
||||
|
||||
var (
|
||||
availableRegions = map[string]struct{}{
|
||||
"us-east-1": {},
|
||||
"us-west-2": {},
|
||||
"ap-southeast-1": {},
|
||||
"ap-northeast-1": {},
|
||||
"eu-central-1": {},
|
||||
"eu-west-2": {},
|
||||
"us-gov-west-1": {},
|
||||
"ap-southeast-2": {},
|
||||
}
|
||||
)
|
||||
type Config struct {
|
||||
AccessKey string `json:"aws_access_key_id"`
|
||||
SecretKey string `json:"aws_secret_access_key"`
|
||||
Region string `json:"aws_region"`
|
||||
}
|
||||
|
||||
func checkConfig(v interface{}) (*Config, error) {
|
||||
conf, ok := v.(*Config)
|
||||
@@ -37,11 +45,191 @@ func checkConfig(v interface{}) (*Config, error) {
|
||||
if conf.SecretKey == "" {
|
||||
return nil, fmt.Errorf("aws_secret_access_key is required")
|
||||
}
|
||||
//if conf.Region == "" {
|
||||
// return nil, fmt.Errorf("aws_region is required")
|
||||
//}
|
||||
//if _, ok := availableRegions[conf.Region]; !ok {
|
||||
// return nil, fmt.Errorf("aws_region is invalid")
|
||||
//}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func Create(cfg string) (ai_convert.IConverter, error) {
|
||||
var conf Config
|
||||
err := json.Unmarshal([]byte(cfg), &conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conf.AccessKey == "" {
|
||||
return nil, fmt.Errorf("aws_access_key_id is required")
|
||||
}
|
||||
if conf.SecretKey == "" {
|
||||
return nil, fmt.Errorf("aws_secret_access_key is required")
|
||||
}
|
||||
return NewConvert(conf.AccessKey, conf.SecretKey, conf.Region), nil
|
||||
}
|
||||
|
||||
type Convert struct {
|
||||
signer *v4.Signer
|
||||
region string
|
||||
}
|
||||
|
||||
func NewConvert(ak string, sk string, region string) *Convert {
|
||||
return &Convert{
|
||||
signer: v4.NewSigner(credentials.NewStaticCredentials(ak, sk, "")),
|
||||
region: region,
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
currentPath = "/model/%s/converse"
|
||||
streamPath = "/model/%s/converse-stream"
|
||||
)
|
||||
|
||||
func (c *Convert) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error {
|
||||
provider := ai_convert.GetAIProvider(ctx)
|
||||
model := ai_convert.GetAIModel(ctx)
|
||||
modelCfg, has := accessConfigManager.Get(fmt.Sprintf("%s#%s", provider, model))
|
||||
region := ""
|
||||
if has {
|
||||
model = modelCfg.Config()["model"]
|
||||
region = modelCfg.Config()["region"]
|
||||
}
|
||||
if region == "" {
|
||||
region = c.region
|
||||
}
|
||||
base := fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", region)
|
||||
|
||||
balanceHandler, err := ai_convert.NewBalanceHandler("", base, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.SetBalance(balanceHandler)
|
||||
httpContext, err := http_service.Assert(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body, err := httpContext.Proxy().Body().RawBody()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chatRequest := eosc.NewBase[ai_convert.Request](extender)
|
||||
err = json.Unmarshal(body, chatRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal body error: %v, body: %s", err, string(body))
|
||||
}
|
||||
messages := make([]Message, 0, len(chatRequest.Config.Messages))
|
||||
systemMessage := make([]*Content, 0)
|
||||
for _, m := range chatRequest.Config.Messages {
|
||||
if m.Role == "system" {
|
||||
systemMessage = append(systemMessage, &Content{Text: m.Content})
|
||||
} else {
|
||||
messages = append(messages, Message{
|
||||
Role: m.Role,
|
||||
Content: []*Content{{Text: m.Content}},
|
||||
})
|
||||
}
|
||||
}
|
||||
chatRequest.SetAppend("messages", messages)
|
||||
chatRequest.SetAppend("system", systemMessage)
|
||||
path := fmt.Sprintf(currentPath, model)
|
||||
if chatRequest.Config.Stream {
|
||||
path = fmt.Sprintf(streamPath, model)
|
||||
}
|
||||
uri := fmt.Sprintf("%s%s", base, path)
|
||||
httpContext.Proxy().URI().SetPath(path)
|
||||
|
||||
body, _ = json.Marshal(chatRequest)
|
||||
httpContext.Proxy().Body().SetRaw("application/json", body)
|
||||
headers, err := signRequest(c.signer, region, uri, http.Header{}, string(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range headers {
|
||||
httpContext.Proxy().Header().SetHeader(k, strings.Join(v, ";"))
|
||||
}
|
||||
httpContext.Proxy().Body().SetRaw("application/json", body)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Convert) ResponseConvert(ctx eocontext.EoContext) error {
|
||||
httpContext, err := http_service.Assert(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if httpContext.Response().StatusCode() != 200 {
|
||||
return nil
|
||||
}
|
||||
body := httpContext.Response().GetBody()
|
||||
data := eosc.NewBase[Response](nil)
|
||||
err = json.Unmarshal(body, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
responseBody := &ai_convert.Response{}
|
||||
|
||||
body, err = json.Marshal(responseBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httpContext.Response().SetBody(body)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Convert) streamHandler(ctx http_service.IHttpContext, p []byte) ([]byte, error) {
|
||||
//// 对响应数据进行划分
|
||||
//inputToken := GetAIModelInputToken(ctx)
|
||||
//outputToken := 0
|
||||
//totalToken := inputToken
|
||||
//scanner := bufio.NewScanner(bytes.NewReader(p))
|
||||
//// Check the content encoding and convert to UTF-8 if necessary.
|
||||
//encoding := ctx.Response().Headers().Get("content-encoding")
|
||||
//for scanner.Scan() {
|
||||
// line := scanner.Text()
|
||||
// if encoding != "utf-8" && encoding != "" {
|
||||
// tmp, err := encoder.ToUTF8(encoding, []byte(line))
|
||||
// if err != nil {
|
||||
// log.Errorf("convert to utf-8 error: %v, line: %s", err, line)
|
||||
// return p, nil
|
||||
// }
|
||||
// if ctx.Response().StatusCode() != 200 || (o.checkErr != nil && !o.checkErr(ctx, tmp)) {
|
||||
// if o.errorCallback != nil {
|
||||
// o.errorCallback(ctx, tmp)
|
||||
// }
|
||||
// return p, nil
|
||||
// }
|
||||
// line = string(tmp)
|
||||
// }
|
||||
// line = strings.TrimPrefix(line, "data:")
|
||||
// if line == "" || strings.Trim(line, " ") == "[DONE]" {
|
||||
// return p, nil
|
||||
// }
|
||||
// var resp openai.ChatCompletionResponse
|
||||
// err := json.Unmarshal([]byte(line), &resp)
|
||||
// if err != nil {
|
||||
// return p, nil
|
||||
// }
|
||||
// if len(resp.Choices) > 0 {
|
||||
// outputToken += getTokens(resp.Choices[0].Message.Content)
|
||||
// totalToken += outputToken
|
||||
// }
|
||||
//}
|
||||
//if err := scanner.Err(); err != nil {
|
||||
// log.Errorf("scan error: %v", err)
|
||||
// return p, nil
|
||||
//}
|
||||
//
|
||||
//SetAIModelInputToken(ctx, inputToken)
|
||||
//SetAIModelOutputToken(ctx, outputToken)
|
||||
//SetAIModelTotalToken(ctx, totalToken)
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func signRequest(signer *v4.Signer, region string, uri string, headers http.Header, body string) (http.Header, error) {
|
||||
request, err := http.NewRequest(http.MethodPost, uri, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header = headers.Clone()
|
||||
|
||||
_, err = signer.Sign(request, strings.NewReader(body), "bedrock", region, time.Now())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return request.Header, nil
|
||||
|
||||
}
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
ai_convert "github.com/eolinker/apinto/ai-convert"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
|
||||
"github.com/eolinker/eosc/eocontext"
|
||||
http_context "github.com/eolinker/eosc/eocontext/http-context"
|
||||
"github.com/eolinker/eosc/log"
|
||||
)
|
||||
|
||||
var _ ai_convert.IConverterFactory = &convertFactory{}
|
||||
|
||||
type convertFactory struct {
|
||||
}
|
||||
|
||||
func (c *convertFactory) Create(cfg string) (ai_convert.IConverterDriver, error) {
|
||||
var tmp Config
|
||||
err := json.Unmarshal([]byte(cfg), &tmp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newConverterDriver(&tmp)
|
||||
}
|
||||
|
||||
var _ ai_convert.IConverterDriver = &converterDriver{}
|
||||
|
||||
type basicConfig struct {
|
||||
signer *v4.Signer
|
||||
region string
|
||||
eocontext.BalanceHandler
|
||||
}
|
||||
|
||||
type converterDriver struct {
|
||||
cfg *basicConfig
|
||||
eocontext.BalanceHandler
|
||||
}
|
||||
|
||||
func newConverterDriver(cfg *Config) (ai_convert.IConverterDriver, error) {
|
||||
base := fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", cfg.Region)
|
||||
balanceHandler, err := ai_convert.NewBalanceHandler("", base, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &converterDriver{
|
||||
cfg: &basicConfig{
|
||||
signer: v4.NewSigner(credentials.NewStaticCredentials(cfg.AccessKey, cfg.SecretKey, "")),
|
||||
region: cfg.Region,
|
||||
BalanceHandler: balanceHandler,
|
||||
},
|
||||
BalanceHandler: balanceHandler,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (c *converterDriver) GetModel(model string) (ai_convert.FGenerateConfig, bool) {
|
||||
if _, ok := modelConvert[model]; !ok {
|
||||
return nil, false
|
||||
}
|
||||
return func(cfg string) (map[string]interface{}, error) {
|
||||
result := map[string]interface{}{}
|
||||
if cfg != "" {
|
||||
tmp := make(map[string]interface{})
|
||||
if err := json.Unmarshal([]byte(cfg), &tmp); err != nil {
|
||||
log.Errorf("unmarshal config error: %v, cfg: %s", err, cfg)
|
||||
return result, nil
|
||||
}
|
||||
modelCfg := ai_convert.MapToStruct[ModelConfig](tmp)
|
||||
if modelCfg.MaxTokens >= 1 {
|
||||
result["maxTokens"] = modelCfg.MaxTokens
|
||||
}
|
||||
result["temperature"] = modelCfg.Temperature
|
||||
result["topP"] = modelCfg.TopP
|
||||
}
|
||||
return result, nil
|
||||
}, true
|
||||
}
|
||||
|
||||
func (c *converterDriver) GetConverter(model string) (ai_convert.IConverter, bool) {
|
||||
converter, ok := modelConvert[model]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return &Converter{
|
||||
converter: converter,
|
||||
model: model,
|
||||
basicConfig: c.cfg,
|
||||
}, true
|
||||
}
|
||||
|
||||
type Converter struct {
|
||||
converter ai_convert.IConverter
|
||||
model string
|
||||
*basicConfig
|
||||
}
|
||||
|
||||
func (c *Converter) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error {
|
||||
if c.BalanceHandler != nil {
|
||||
ctx.SetBalance(c.BalanceHandler)
|
||||
}
|
||||
httpContext, err := http_context.Assert(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.converter.RequestConvert(httpContext, extender)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body, _ := httpContext.Proxy().Body().RawBody()
|
||||
headers, err := signRequest(c.signer, c.region, c.model, http.Header{}, string(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range headers {
|
||||
|
||||
httpContext.Proxy().Header().SetHeader(k, strings.Join(v, ";"))
|
||||
}
|
||||
//httpContext.Proxy().Header().SetHeader("Authorization", authorization)
|
||||
//httpContext.Proxy().Header().SetHeader("X-Amz-Date", date)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Converter) ResponseConvert(ctx eocontext.EoContext) error {
|
||||
return c.converter.ResponseConvert(ctx)
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
ai_convert "github.com/eolinker/apinto/ai-convert"
|
||||
|
||||
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
type executor struct {
|
||||
drivers.WorkerBase
|
||||
}
|
||||
|
||||
func (e *executor) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
|
||||
cfg, ok := conf.(*Config)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid config")
|
||||
}
|
||||
|
||||
return e.reset(cfg, workers)
|
||||
}
|
||||
|
||||
func (e *executor) reset(conf *Config, workers map[eosc.RequireId]eosc.IWorker) error {
|
||||
d, err := newConverterDriver(conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.IConverterDriver = d
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *executor) Stop() error {
|
||||
e.IConverterDriver = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *executor) CheckSkill(skill string) bool {
|
||||
return ai_convert.CheckKeySourceSkill(skill)
|
||||
}
|
||||
|
||||
type ModelConfig struct {
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
TopP float64 `json:"top_p"`
|
||||
}
|
||||
|
||||
func signRequest(signer *v4.Signer, region string, model string, headers http.Header, body string) (http.Header, error) {
|
||||
request, err := http.NewRequest(http.MethodPost, fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", region, model), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header = headers.Clone()
|
||||
|
||||
_, err = signer.Sign(request, strings.NewReader(body), "bedrock", region, time.Now())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return request.Header, nil
|
||||
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
ai_convert "github.com/eolinker/apinto/ai-convert"
|
||||
|
||||
"github.com/eolinker/eosc/common/bean"
|
||||
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
"github.com/eolinker/eosc"
|
||||
)
|
||||
|
||||
var name = "bedrock"
|
||||
var (
|
||||
converterManager ai_convert.IManager
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// Register 注册驱动
|
||||
func Register(register eosc.IExtenderDriverRegister) {
|
||||
register.RegisterExtenderDriver(name, NewFactory())
|
||||
}
|
||||
|
||||
// NewFactory 创建service_http驱动工厂
|
||||
func NewFactory() eosc.IExtenderDriverFactory {
|
||||
once.Do(func() {
|
||||
bean.Autowired(&converterManager)
|
||||
converterManager.Set(name, &convertFactory{})
|
||||
})
|
||||
return drivers.NewFactory[Config](Create)
|
||||
}
|
||||
|
||||
// Create 创建驱动实例
|
||||
func Create(id, name string, v *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
|
||||
|
||||
_, err := checkConfig(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w := &executor{
|
||||
WorkerBase: drivers.Worker(id, name),
|
||||
}
|
||||
w.reset(v, workers)
|
||||
return w, nil
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
ai_convert "github.com/eolinker/apinto/ai-convert"
|
||||
|
||||
"github.com/eolinker/eosc"
|
||||
|
||||
"github.com/eolinker/eosc/eocontext"
|
||||
http_context "github.com/eolinker/eosc/eocontext/http-context"
|
||||
)
|
||||
|
||||
type FNewModelMode func(string) IModelMode
|
||||
|
||||
var (
|
||||
modelModes = map[string]FNewModelMode{
|
||||
ai_convert.ModeChat.String(): NewChat,
|
||||
ai_convert.ModeCompletion.String(): NewChat,
|
||||
}
|
||||
)
|
||||
|
||||
type IModelMode interface {
|
||||
Endpoint() string
|
||||
ai_convert.IConverter
|
||||
}
|
||||
|
||||
type Chat struct {
|
||||
endPoint string
|
||||
}
|
||||
|
||||
func NewChat(model string) IModelMode {
|
||||
return &Chat{
|
||||
endPoint: fmt.Sprintf("/model/%s/converse", model),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Chat) Endpoint() string {
|
||||
return c.endPoint
|
||||
}
|
||||
|
||||
func (c *Chat) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error {
|
||||
httpContext, err := http_context.Assert(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body, err := httpContext.Proxy().Body().RawBody()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置转发地址
|
||||
httpContext.Proxy().URI().SetPath(c.endPoint)
|
||||
baseCfg := eosc.NewBase[ai_convert.ClientRequest]()
|
||||
err = json.Unmarshal(body, baseCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
messages := make([]Message, 0, len(baseCfg.Config.Messages))
|
||||
systemMessage := make([]Content, 0)
|
||||
|
||||
for _, m := range baseCfg.Config.Messages {
|
||||
if m.Role == "system" {
|
||||
systemMessage = append(systemMessage, Content{Text: m.Content})
|
||||
} else {
|
||||
messages = append(messages, Message{
|
||||
Role: m.Role,
|
||||
Content: []*Content{{Text: m.Content}},
|
||||
})
|
||||
}
|
||||
}
|
||||
baseCfg.SetAppend("messages", messages)
|
||||
baseCfg.SetAppend("system", systemMessage)
|
||||
|
||||
for k, v := range extender {
|
||||
baseCfg.SetAppend(k, v)
|
||||
}
|
||||
body, err = json.Marshal(baseCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httpContext.Proxy().Body().SetRaw("application/json", body)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error {
|
||||
httpContext, err := http_context.Assert(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if httpContext.Response().StatusCode() != 200 {
|
||||
return nil
|
||||
}
|
||||
body := httpContext.Response().GetBody()
|
||||
data := eosc.NewBase[Response]()
|
||||
err = json.Unmarshal(body, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
responseBody := &ai_convert.ClientResponse{}
|
||||
if data.Config.Output.Message != nil && len(data.Config.Output.Message.Content) > 0 {
|
||||
msg := data.Config.Output.Message
|
||||
responseBody.Message = &ai_convert.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content[0].Text,
|
||||
}
|
||||
responseBody.FinishReason = data.Config.StopReason
|
||||
} else {
|
||||
responseBody.Code = -1
|
||||
responseBody.Error = "no response"
|
||||
}
|
||||
body, err = json.Marshal(responseBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httpContext.Response().SetBody(body)
|
||||
return nil
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
encoderManger.Set("gzip", &Gzip{})
|
||||
//encoderManger.Set("gzip", &Gzip{})
|
||||
}
|
||||
|
||||
type Gzip struct {
|
||||
|
||||
Reference in New Issue
Block a user