add: add the upstage provider

lidongjie
2024-10-14 18:02:43 +08:00
parent 4d903986a5
commit 5819f4f1ac
9 changed files with 478 additions and 0 deletions

View File

@@ -16,6 +16,7 @@ import (
"github.com/eolinker/apinto/drivers/ai-provider/perfxcloud"
"github.com/eolinker/apinto/drivers/ai-provider/stepfun"
"github.com/eolinker/apinto/drivers/ai-provider/tongyi"
"github.com/eolinker/apinto/drivers/ai-provider/upstage"
"github.com/eolinker/apinto/drivers/ai-provider/wenxin"
"github.com/eolinker/apinto/drivers/ai-provider/yi"
"github.com/eolinker/apinto/drivers/ai-provider/zhipuai"
@@ -123,4 +124,5 @@ func driverRegister(extenderRegister eosc.IExtenderDriverRegister) {
deepseek.Register(extenderRegister)
openrouter.Register(extenderRegister)
groq.Register(extenderRegister)
upstage.Register(extenderRegister)
}

View File

@@ -402,6 +402,12 @@ func ApintoProfession() []*eosc.ProfessionConfig {
Label: "GroqCloud",
Desc: "GroqCloud",
},
{
Id: "eolinker.com:apinto:upstage", // 插件ID
Name: "upstage", // 驱动名称应和定义文件的provider字段一致
Label: "upstage",
Desc: "upstage",
},
},
Mod: eosc.ProfessionConfig_Worker,
},

View File

@@ -0,0 +1,22 @@
package upstage
import (
"fmt"
"github.com/eolinker/eosc"
)
type Config struct {
APIKey string `json:"upstage_api_key"`
Organization string `json:"organization"`
}
func checkConfig(v interface{}) (*Config, error) {
conf, ok := v.(*Config)
if !ok {
return nil, eosc.ErrorConfigType
}
if conf.APIKey == "" {
return nil, fmt.Errorf("api_key is required")
}
return conf, nil
}
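For reference, a minimal test-style sketch (not part of this commit) of how checkConfig is expected to behave; the key and organization values below are placeholders:

package upstage

import "testing"

func TestCheckConfig(t *testing.T) {
	// a populated config passes validation
	if _, err := checkConfig(&Config{APIKey: "up-demo-key", Organization: "demo"}); err != nil {
		t.Fatalf("expected valid config, got %v", err)
	}
	// an empty upstage_api_key is rejected
	if _, err := checkConfig(&Config{}); err == nil {
		t.Fatal("expected an error for a missing api key")
	}
	// a value of the wrong type is rejected with eosc.ErrorConfigType
	if _, err := checkConfig(struct{}{}); err == nil {
		t.Fatal("expected an error for a wrong config type")
	}
}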

View File

@@ -0,0 +1,215 @@
package upstage
import (
"embed"
"encoding/json"
"fmt"
"github.com/eolinker/eosc/log"
"reflect"
"strconv"
"github.com/eolinker/apinto/drivers"
http_context "github.com/eolinker/eosc/eocontext/http-context"
ai_provider "github.com/eolinker/apinto/drivers/ai-provider"
"github.com/eolinker/apinto/convert"
"github.com/eolinker/eosc"
"github.com/eolinker/eosc/eocontext"
)
var (
//go:embed upstage.yaml
providerContent []byte
//go:embed *
providerDir embed.FS
modelConvert = make(map[string]convert.IConverter)
_ convert.IConverterDriver = (*executor)(nil)
)
func init() {
models, err := ai_provider.LoadModels(providerContent, providerDir)
if err != nil {
panic(err)
}
for key, value := range models {
if value.ModelProperties != nil {
if v, ok := modelModes[value.ModelProperties.Mode]; ok {
modelConvert[key] = v
}
}
}
}
type Converter struct {
apikey string
balanceHandler eocontext.BalanceHandler
converter convert.IConverter
}
func (c *Converter) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error {
if c.balanceHandler != nil {
ctx.SetBalance(c.balanceHandler)
}
httpContext, err := http_context.Assert(ctx)
if err != nil {
return err
}
httpContext.Proxy().Header().SetHeader("Authorization", "Bearer "+c.apikey)
return c.converter.RequestConvert(httpContext, extender)
}
func (c *Converter) ResponseConvert(ctx eocontext.EoContext) error {
return c.converter.ResponseConvert(ctx)
}
type executor struct {
drivers.WorkerBase
apikey string
eocontext.BalanceHandler
}
func (e *executor) GetConverter(model string) (convert.IConverter, bool) {
converter, ok := modelConvert[model]
if !ok {
return nil, false
}
return &Converter{balanceHandler: e.BalanceHandler, converter: converter, apikey: e.apikey}, true
}
func (e *executor) GetModel(model string) (convert.FGenerateConfig, bool) {
if _, ok := modelConvert[model]; !ok {
return nil, false
}
return func(cfg string) (map[string]interface{}, error) {
result := map[string]interface{}{
"model": model,
}
if cfg != "" {
tmp := make(map[string]interface{})
if err := json.Unmarshal([]byte(cfg), &tmp); err != nil {
log.Errorf("unmarshal config error: %v, cfg: %s", err, cfg)
return result, nil
}
modelCfg := mapToStruct[ModelConfig](tmp)
result["frequency_penalty"] = modelCfg.FrequencyPenalty
if modelCfg.MaxTokens >= 1 {
result["max_tokens"] = modelCfg.MaxTokens
}
result["presence_penalty"] = modelCfg.PresencePenalty
result["temperature"] = modelCfg.Temperature
result["top_p"] = modelCfg.TopP
if modelCfg.ResponseFormat == "" {
modelCfg.ResponseFormat = "text"
}
result["response_format"] = map[string]interface{}{
"type": modelCfg.ResponseFormat,
}
}
return result, nil
}, true
}
func (e *executor) Start() error {
return nil
}
func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
cfg, ok := conf.(*Config)
if !ok {
return fmt.Errorf("invalid config")
}
return e.reset(cfg, workers)
}
func (e *executor) reset(conf *Config, workers map[eosc.RequireId]eosc.IWorker) error {
e.BalanceHandler = nil
e.apikey = conf.APIKey
convert.Set(e.Id(), e)
return nil
}
func (e *executor) Stop() error {
e.BalanceHandler = nil
convert.Del(e.Id())
return nil
}
func (e *executor) CheckSkill(skill string) bool {
return convert.CheckSkill(skill)
}
type ModelConfig struct {
FrequencyPenalty float64 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
PresencePenalty float64 `json:"presence_penalty"`
ResponseFormat string `json:"response_format"`
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p"`
}
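// mapToStruct copies values from a decoded JSON map into the fields of T,
// matching map keys against the fields' json tags; numeric values may arrive
// either as JSON numbers or as string-encoded numbers.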
func mapToStruct[T any](tmp map[string]interface{}) *T {
// create an instance of the target struct
var result T
val := reflect.ValueOf(&result).Elem()
// get the struct type
t := val.Type()
// iterate over the key-value pairs in the map
for k, v := range tmp {
// find the struct field whose json tag matches the key
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
jsonTag := field.Tag.Get("json")
if jsonTag == k {
// get the field value
fieldVal := val.Field(i)
// skip fields that cannot be set
if !fieldVal.CanSet() {
continue
}
// convert the value according to the field's type
switch fieldVal.Kind() {
case reflect.Float64:
if strVal, ok := v.(string); ok && strVal != "" {
// a non-empty string value is parsed as float64
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
fieldVal.SetFloat(floatVal)
}
} else if floatVal, ok := v.(float64); ok {
fieldVal.SetFloat(floatVal)
}
case reflect.Int:
if intVal, ok := v.(int); ok {
fieldVal.SetInt(int64(intVal))
} else if floatVal, ok := v.(float64); ok {
// values decoded from JSON arrive as float64, including integers
fieldVal.SetInt(int64(floatVal))
} else if strVal, ok := v.(string); ok && strVal != "" {
if intVal, err := strconv.Atoi(strVal); err == nil {
fieldVal.SetInt(int64(intVal))
}
}
case reflect.String:
if strVal, ok := v.(string); ok {
fieldVal.SetString(strVal)
}
default:
// other kinds are left unset
}
}
}
}
return &result
}
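A test-style sketch (not part of this commit) of the string-number fallback in mapToStruct, using the ModelConfig defined above; the configuration values are illustrative:

package upstage

import (
	"encoding/json"
	"testing"
)

func TestMapToStruct(t *testing.T) {
	raw := `{"max_tokens":"512","temperature":"0.7","response_format":"json_object"}`
	tmp := make(map[string]interface{})
	if err := json.Unmarshal([]byte(raw), &tmp); err != nil {
		t.Fatal(err)
	}
	cfg := mapToStruct[ModelConfig](tmp)
	// string-encoded numbers are parsed via strconv; unspecified fields keep their zero values
	if cfg.MaxTokens != 512 || cfg.Temperature != 0.7 || cfg.ResponseFormat != "json_object" {
		t.Fatalf("unexpected result: %+v", cfg)
	}
}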

View File

@@ -0,0 +1,31 @@
package upstage
import (
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/eosc"
)
var name = "upstage"
// Register registers the upstage driver
func Register(register eosc.IExtenderDriverRegister) {
register.RegisterExtenderDriver(name, NewFactory())
}
// NewFactory creates the upstage driver factory
func NewFactory() eosc.IExtenderDriverFactory {
return drivers.NewFactory[Config](Create)
}
// Create creates a driver instance
func Create(id, name string, v *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
_, err := checkConfig(v)
if err != nil {
return nil, err
}
w := &executor{
WorkerBase: drivers.Worker(id, name),
}
w.reset(v, workers)
return w, nil
}

View File

@@ -0,0 +1,43 @@
model: solar-1-mini-chat
label:
zh_Hans: solar-1-mini-chat
en_US: solar-1-mini-chat
ko_KR: solar-1-mini-chat
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 32768
- name: seed
label:
zh_Hans: 种子
en_US: Seed
type: int
help:
zh_Hans:
如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint
响应参数来监视变化。
en_US:
If specified, model will make a best effort to sample deterministically,
such that repeated requests with the same seed and parameters should return
the same result. Determinism is not guaranteed, and you should refer to the
system_fingerprint response parameter to monitor changes in the backend.
required: false
pricing:
input: "0.5"
output: "0.5"
unit: "0.000001"
currency: USD

View File

@@ -0,0 +1,35 @@
package upstage
type ClientRequest struct {
Messages []*Message `json:"messages"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type Response struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []ResponseChoice `json:"choices"`
Usage Usage `json:"usage"`
}
type ResponseChoice struct {
Index int `json:"index"`
Message Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type CompletionTokensDetails struct {
ReasoningTokens int `json:"reasoning_tokens"`
}
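These structs appear to mirror an OpenAI-style chat completion payload. A sketch (not part of this commit) of decoding such a body into Response; the field values are illustrative rather than captured from the Upstage API:

package upstage

import (
	"encoding/json"
	"testing"
)

func TestResponseDecode(t *testing.T) {
	raw := `{
		"id": "chatcmpl-123",
		"object": "chat.completion",
		"created": 1728900000,
		"model": "solar-1-mini-chat",
		"choices": [{"index": 0, "message": {"role": "assistant", "content": "hello"}, "finish_reason": "stop"}],
		"usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12}
	}`
	var resp Response
	if err := json.Unmarshal([]byte(raw), &resp); err != nil {
		t.Fatal(err)
	}
	if resp.Choices[0].Message.Content != "hello" || resp.Usage.TotalTokens != 12 {
		t.Fatalf("unexpected response: %+v", resp)
	}
}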

View File

@@ -0,0 +1,104 @@
package upstage
import (
"encoding/json"
"github.com/eolinker/apinto/convert"
ai_provider "github.com/eolinker/apinto/drivers/ai-provider"
"github.com/eolinker/eosc"
"github.com/eolinker/eosc/eocontext"
http_context "github.com/eolinker/eosc/eocontext/http-context"
)
var (
modelModes = map[string]IModelMode{
ai_provider.ModeChat.String(): NewChat(),
}
)
type IModelMode interface {
Endpoint() string
convert.IConverter
}
type Chat struct {
endPoint string
}
func NewChat() *Chat {
return &Chat{
endPoint: "/v1/solar/chat/completions",
}
}
func (c *Chat) Endpoint() string {
return c.endPoint
}
func (c *Chat) RequestConvert(ctx eocontext.EoContext, extender map[string]interface{}) error {
httpContext, err := http_context.Assert(ctx)
if err != nil {
return err
}
body, err := httpContext.Proxy().Body().RawBody()
if err != nil {
return err
}
// set the upstream forwarding path
httpContext.Proxy().URI().SetPath(c.endPoint)
baseCfg := eosc.NewBase[ai_provider.ClientRequest]()
err = json.Unmarshal(body, baseCfg)
if err != nil {
return err
}
messages := make([]Message, 0, len(baseCfg.Config.Messages)+1)
for _, m := range baseCfg.Config.Messages {
messages = append(messages, Message{
Role: m.Role,
Content: m.Content,
})
}
baseCfg.SetAppend("messages", messages)
for k, v := range extender {
baseCfg.SetAppend(k, v)
}
body, err = json.Marshal(baseCfg)
if err != nil {
return err
}
httpContext.Proxy().Body().SetRaw("application/json", body)
return nil
}
func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error {
httpContext, err := http_context.Assert(ctx)
if err != nil {
return err
}
if httpContext.Response().StatusCode() != 200 {
return nil
}
body := httpContext.Response().GetBody()
data := eosc.NewBase[Response]()
err = json.Unmarshal(body, data)
if err != nil {
return err
}
responseBody := &ai_provider.ClientResponse{}
if len(data.Config.Choices) > 0 {
msg := data.Config.Choices[0]
responseBody.Message = ai_provider.Message{
Role: msg.Message.Role,
Content: msg.Message.Content,
}
responseBody.FinishReason = msg.FinishReason
} else {
responseBody.Code = -1
responseBody.Error = "no response"
}
body, err = json.Marshal(responseBody)
if err != nil {
return err
}
httpContext.Response().SetBody(body)
return nil
}
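A rough sketch (not part of this commit) of the per-choice mapping that ResponseConvert applies when rewriting the upstream body into the gateway's unified format; it relies only on the ai_provider.ClientResponse and ai_provider.Message fields already used above:

package upstage

import (
	"encoding/json"
	"testing"

	ai_provider "github.com/eolinker/apinto/drivers/ai-provider"
)

func TestChatResponseMapping(t *testing.T) {
	upstream := Response{
		Model: "solar-1-mini-chat",
		Choices: []ResponseChoice{{
			Message:      Message{Role: "assistant", Content: "hello"},
			FinishReason: "stop",
		}},
	}
	// the same mapping ResponseConvert performs on the first choice
	out := &ai_provider.ClientResponse{
		Message: ai_provider.Message{
			Role:    upstream.Choices[0].Message.Role,
			Content: upstream.Choices[0].Message.Content,
		},
		FinishReason: upstream.Choices[0].FinishReason,
	}
	body, err := json.Marshal(out)
	if err != nil || len(body) == 0 {
		t.Fatalf("marshal failed: %v", err)
	}
}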

View File

@@ -0,0 +1,20 @@
provider: upstage
label:
en_US: Upstage
description:
en_US: Models provided by Upstage, such as Solar-1-mini-chat.
zh_Hans: Upstage 提供的模型,例如 Solar-1-mini-chat.
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#FFFFFF"
help:
title:
en_US: Get your API Key from Upstage
zh_Hans: 从 Upstage 获取 API Key
url:
en_US: https://console.upstage.ai/api-keys
supported_model_types:
- llm
- text-embedding