gemini完成

This commit is contained in:
Liujian
2024-10-10 18:58:42 +08:00
parent e8f0fa09ff
commit de17b202f5
4 changed files with 85 additions and 70 deletions

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"net/url"
"reflect"
"strconv"
"strings"
@@ -96,8 +95,13 @@ func (e *executor) GetModel(model string) (convert.FGenerateConfig, bool) {
log.Errorf("unmarshal config error: %v, cfg: %s", err, cfg)
return result, nil
}
modelCfg := mapToStruct[ModelConfig](tmp)
result["generationConfig"] = modelCfg
modelCfg := ai_provider.MapToStruct[ModelConfig](tmp)
generationConfig := make(map[string]interface{})
generationConfig["maxOutputTokens"] = modelCfg.MaxOutputTokens
generationConfig["temperature"] = modelCfg.Temperature
generationConfig["topP"] = modelCfg.TopP
generationConfig["topK"] = modelCfg.TopK
result["generationConfig"] = generationConfig
}
return result, nil
}, true
@@ -156,61 +160,3 @@ type ModelConfig struct {
TopP float64 `json:"top_p"`
TopK int `json:"top_k"`
}
func mapToStruct[T any](tmp map[string]interface{}) *T {
// 创建目标结构体的实例
var result T
val := reflect.ValueOf(&result).Elem()
// 获取结构体的类型
t := val.Type()
// 遍历 map 中的键值对
for k, v := range tmp {
// 查找结构体中与键名匹配的字段
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
jsonTag := field.Tag.Get("json")
if jsonTag == k {
// 获取字段的值
fieldVal := val.Field(i)
// 如果字段不可设置,跳过
if !fieldVal.CanSet() {
continue
}
// 根据字段的类型,进行类型转换
switch fieldVal.Kind() {
case reflect.Float64:
if strVal, ok := v.(string); ok && strVal != "" {
// 如果是 string 类型且非空,转换为 float64
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
fieldVal.SetFloat(floatVal)
}
} else if floatVal, ok := v.(float64); ok {
fieldVal.SetFloat(floatVal)
}
case reflect.Int:
if intVal, ok := v.(int); ok {
fieldVal.SetInt(int64(intVal))
} else if strVal, ok := v.(string); ok && strVal != "" {
if intVal, err := strconv.Atoi(strVal); err == nil {
fieldVal.SetInt(int64(intVal))
}
}
case reflect.String:
if strVal, ok := v.(string); ok {
fieldVal.SetString(strVal)
}
default:
// 其他类型不进行转换
}
}
}
}
return &result
}

View File

@@ -5,8 +5,8 @@ type ClientRequest struct {
}
type Content struct {
Part map[string]interface{} `json:"part"`
Role string `json:"role"`
Parts []map[string]interface{} `json:"parts"`
Role string `json:"role"`
}
type Response struct {

View File

@@ -61,14 +61,18 @@ func (c *Chat) RequestConvert(ctx eocontext.EoContext, extender map[string]inter
messages := make([]Content, 0, len(baseCfg.Config.Messages)+1)
for _, m := range baseCfg.Config.Messages {
role := "user"
if m.Role == "system" {
if m.Role == "system" && len(baseCfg.Config.Messages) > 1 {
role = "model"
}
messages = append(messages, Content{
Role: role,
Part: map[string]interface{}{
parts := make([]map[string]interface{}, 0, 1)
if m.Content != "" {
parts = append(parts, map[string]interface{}{
"text": m.Content,
},
})
}
messages = append(messages, Content{
Role: role,
Parts: parts,
})
}
baseCfg.SetAppend("contents", messages)
@@ -106,9 +110,12 @@ func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error {
role = "assistant"
}
text := ""
if v, ok := msg.Content.Part["text"]; ok {
text = v.(string)
if len(msg.Content.Parts) > 0 {
if v, ok := msg.Content.Parts[0]["text"]; ok {
text = v.(string)
}
}
responseBody.Message = ai_provider.Message{
Role: role,
Content: text,

View File

@@ -2,6 +2,8 @@ package ai_provider
import (
"embed"
"reflect"
"strconv"
"strings"
yaml "gopkg.in/yaml.v3"
@@ -76,3 +78,63 @@ func LoadModels(providerContent []byte, dirFs embed.FS) (map[string]*Model, erro
}
return models, nil
}
func MapToStruct[T any](tmp map[string]interface{}) *T {
// 创建目标结构体的实例
var result T
val := reflect.ValueOf(&result).Elem()
// 获取结构体的类型
t := val.Type()
// 遍历 map 中的键值对
for k, v := range tmp {
// 查找结构体中与键名匹配的字段
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
jsonTag := field.Tag.Get("json")
if jsonTag == k {
// 获取字段的值
fieldVal := val.Field(i)
// 如果字段不可设置,跳过
if !fieldVal.CanSet() {
continue
}
// 根据字段的类型,进行类型转换
switch fieldVal.Kind() {
case reflect.Float64:
if strVal, ok := v.(string); ok && strVal != "" {
// 如果是 string 类型且非空,转换为 float64
if floatVal, err := strconv.ParseFloat(strVal, 64); err == nil {
fieldVal.SetFloat(floatVal)
}
} else if floatVal, ok := v.(float64); ok {
fieldVal.SetFloat(floatVal)
}
case reflect.Int:
if intVal, ok := v.(int); ok {
fieldVal.SetInt(int64(intVal))
} else if strVal, ok := v.(string); ok && strVal != "" {
if intVal, err := strconv.Atoi(strVal); err == nil {
fieldVal.SetInt(int64(intVal))
}
} else if floatVal, ok := v.(float64); ok {
fieldVal.SetInt(int64(floatVal))
}
case reflect.String:
if strVal, ok := v.(string); ok {
fieldVal.SetString(strVal)
}
default:
// 其他类型不进行转换
}
}
}
}
return &result
}