diff --git a/.gitignore b/.gitignore index 28ceae6..99876ce 100644 --- a/.gitignore +++ b/.gitignore @@ -18,5 +18,8 @@ chatgpt-dingtalk # Dependency directories (remove the comment below to include it) # vendor/ config.yml +dingtalkbot.sqlite tmp test/ +images/ +data/ \ No newline at end of file diff --git a/README.md b/README.md index ddf16dc..b52cf2b 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ - 🔗 自定义api域名:通过配置指定,解决国内服务器无法直接访问openai的问题 - 🪜 添加代理:通过配置指定,通过给应用注入代理解决国内服务器无法访问的问题 - 👐 默认模式:支持自定义默认的聊天模式,通过配置化指定 +- 📝 查询对话:通过发送`#查对话 username:xxx`查询xxx的对话历史,可在线预览,可下载到本地。 ## 使用前提 @@ -143,7 +144,7 @@ ``` 第一种:基于环境变量运行 # 运行项目 -$ docker run -itd --name chatgpt -p 8090:8090 --add-host="host.docker.internal:host-gateway" -e APIKEY=换成你的key -e BASE_URL="" -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="http://host.docker.internal:15732" -e DEFAULT_MODE="单聊" -e MAX_REQUEST=0 -e PORT=8090 -e SERVICE_URL="你当前服务外网可访问的URL" -e CHAT_TYPE="0" --restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest +$ docker run -itd --name chatgpt -p 8090:8090 -v ./data:/app/data --add-host="host.docker.internal:host-gateway" -e APIKEY=换成你的key -e BASE_URL="" -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="http://host.docker.internal:15732" -e DEFAULT_MODE="单聊" -e MAX_REQUEST=0 -e PORT=8090 -e SERVICE_URL="你当前服务外网可访问的URL" -e CHAT_TYPE="0" --restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest ``` `📢 注意:`如果使用docker部署,那么PORT参数不需要进行任何调整。 diff --git a/docker-compose.yml b/docker-compose.yml index 801434f..e022a8e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -17,7 +17,7 @@ services: SERVICE_URL: "" # 指定服务的地址,就是当前服务可供外网访问的地址(或者直接理解为你配置在钉钉回调那里的地址),用于生成图片时给钉钉做渲染 CHAT_TYPE: "0" # 限定对话类型 0:不限 1:只能单聊 2:只能群聊 volumes: - - ./data/images:/app/images + - ./data:/app/data ports: - "8090:8090" extra_hosts: diff --git a/go.mod b/go.mod index 22aec30..2d820b1 100644 --- a/go.mod +++ b/go.mod @@ -4,27 +4,39 @@ go 1.18 require ( github.com/charmbracelet/log v0.2.1 + github.com/glebarez/sqlite v1.7.0 github.com/go-resty/resty/v2 v2.7.0 github.com/patrickmn/go-cache v2.1.0+incompatible - github.com/sashabaranov/go-openai v1.5.7 github.com/solywsh/chatgpt v0.0.14 github.com/xgfone/ship/v5 v5.3.1 gopkg.in/yaml.v2 v2.4.0 + gorm.io/gorm v1.24.6 ) require ( github.com/avast/retry-go v2.7.0+incompatible // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/charmbracelet/lipgloss v0.7.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/glebarez/go-sqlite v1.20.3 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-isatty v0.0.18 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.15.1 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/sashabaranov/go-openai v1.5.7 // indirect golang.org/x/net v0.0.0-20211029224645-99673261e6eb // indirect golang.org/x/sys v0.6.0 // indirect + modernc.org/libc v1.22.2 // indirect + modernc.org/mathutil v1.5.0 // indirect + modernc.org/memory v1.5.0 // indirect + modernc.org/sqlite v1.20.3 // indirect ) replace github.com/solywsh/chatgpt => ./pkg/chatgpt diff --git a/go.sum b/go.sum index 1713aa1..23da329 100644 --- a/go.sum +++ b/go.sum @@ -7,10 +7,23 @@ github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNW github.com/charmbracelet/log v0.2.1 h1:1z7jpkk4yKyjwlmKmKMM5qnEDSpV32E7XtWhuv0mTZE= github.com/charmbracelet/log v0.2.1/go.mod h1:GwFfjewhcVDWLrpAbY5A0Hin9YOlEn40eWT4PNaxFT4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/glebarez/go-sqlite v1.20.3 h1:89BkqGOXR9oRmG58ZrzgoY/Fhy5x0M+/WV48U5zVrZ4= +github.com/glebarez/go-sqlite v1.20.3/go.mod h1:u3N6D/wftiAzIOJtZl6BmedqxmmkDfH3q+ihjqxC9u0= +github.com/glebarez/sqlite v1.7.0 h1:A7Xj/KN2Lvie4Z4rrgQHY8MsbebX3NyWsL3n2i82MVI= +github.com/glebarez/sqlite v1.7.0/go.mod h1:PkeevrRlF/1BhQBCnzcMWzgrIk7IOop+qS2jUYLfHhk= github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY= github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98= @@ -25,6 +38,9 @@ github.com/muesli/termenv v0.15.1/go.mod h1:HeAQPTzpfs016yGtA4g00CsdYnVLJvxsS4AN github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 h1:VstopitMQi3hZP0fzvnsLmzXZdQGc4bEcgu24cp+d4M= +github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -47,3 +63,13 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gorm.io/gorm v1.24.6 h1:wy98aq9oFEetsc4CAbKD2SoBCdMzsbSIvSUUFJuHi5s= +gorm.io/gorm v1.24.6/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +modernc.org/libc v1.22.2 h1:4U7v51GyhlWqQmwCHj28Rdq2Yzwk55ovjFrdPjs8Hb0= +modernc.org/libc v1.22.2/go.mod h1:uvQavJ1pZ0hIoC/jfqNoMLURIMhKzINIWypNM17puug= +modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= +modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= +modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= +modernc.org/sqlite v1.20.3 h1:SqGJMMxjj1PHusLxdYxeQSodg7Jxn9WWkaAQjKrntZs= +modernc.org/sqlite v1.20.3/go.mod h1:zKcGyrICaxNTMEHSr1HQ2GUraP0j+845GYw37+EyT6A= diff --git a/main.go b/main.go index 0bf8b3a..5c118ba 100644 --- a/main.go +++ b/main.go @@ -39,7 +39,7 @@ func Start() { // 去除问题的前后空格 msgObj.Text.Content = strings.TrimSpace(msgObj.Text.Content) // 打印钉钉回调过来的请求明细 - logger.Info(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj)) + // logger.Info(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj)) // TODO: 校验请求 if public.Config.ChatType != "0" && msgObj.ConversationType != public.Config.ChatType { _, err = msgObj.ReplyToDingtalk(string(dingbot.TEXT), "抱歉,管理员禁用了这种聊天方式,请选择其他聊天方式与机器人对话!") @@ -57,10 +57,13 @@ func Start() { return ship.ErrBadRequest.New(fmt.Errorf("send message error: %v", err)) } } else { + logger.Info(fmt.Sprintf("🙋 %s发起的问题: %#v", msgObj.SenderNick, msgObj.Text.Content)) // 除去帮助之外的逻辑分流在这里处理 switch { case strings.HasPrefix(msgObj.Text.Content, "#图片"): return process.ImageGenerate(&msgObj) + case strings.HasPrefix(msgObj.Text.Content, "#查对话"): + return process.SelectHistory(&msgObj) default: msgObj.Text.Content, err = process.GeneratePrompt(msgObj.Text.Content) // err不为空:提示词之后没有文本 -> 直接返回提示词所代表的内容 @@ -72,7 +75,6 @@ func Start() { } return nil } - logger.Info(fmt.Sprintf("after generate prompt: %#v", msgObj.Text.Content)) return process.ProcessRequest(&msgObj) } } @@ -81,9 +83,21 @@ func Start() { // 解析生成后的图片 app.Route("/images/:filename").GET(func(c *ship.Context) error { filename := c.Param("filename") - root := "./images/" + root := "./data/images/" return c.File(filepath.Join(root, filename)) }) + // 解析生成后的历史聊天 + app.Route("/history/:filename").GET(func(c *ship.Context) error { + filename := c.Param("filename") + root := "./data/chatHistory/" + return c.File(filepath.Join(root, filename)) + }) + // 直接下载文件 + app.Route("/download/:filename").GET(func(c *ship.Context) error { + filename := c.Param("filename") + root := "./data/chatHistory/" + return c.Attachment(filepath.Join(root, filename), "") + }) port := ":" + public.Config.Port srv := &http.Server{ diff --git a/pkg/cache/user_base.go b/pkg/cache/user_base.go index ae7cf0c..04b7420 100644 --- a/pkg/cache/user_base.go +++ b/pkg/cache/user_base.go @@ -19,6 +19,10 @@ type UserServiceInterface interface { // 用户请求次数 SetUseRequestCount(userId string, current int) GetUseRequestCount(uerId string) int + // 用户对话ID + SetAnswerID(userId, chattype string, current uint) + GetAnswerID(uerId, chattype string) uint + ClearAnswerID(userId, chattitle string) } var _ UserServiceInterface = (*UserService)(nil) diff --git a/pkg/cache/user_chatid.go b/pkg/cache/user_chatid.go new file mode 100644 index 0000000..a7a3f81 --- /dev/null +++ b/pkg/cache/user_chatid.go @@ -0,0 +1,22 @@ +package cache + +import "time" + +// SetAnswerID 设置用户获得答案的ID +func (s *UserService) SetAnswerID(userId, chattitle string, current uint) { + s.cache.Set(userId+"_"+chattitle, current, time.Hour*24) +} + +// GetAnswerID 获取当前用户获得答案的ID +func (s *UserService) GetAnswerID(userId, chattitle string) uint { + sessionContext, ok := s.cache.Get(userId + "_" + chattitle) + if !ok { + return 0 + } + return sessionContext.(uint) +} + +// ClearUserSessionContext 清空GTP上下文,接收文本中包含 SessionClearToken +func (s *UserService) ClearAnswerID(userId, chattitle string) { + s.cache.Delete(userId + "_" + chattitle) +} diff --git a/pkg/chatgpt/context.go b/pkg/chatgpt/context.go index 9c9d06d..ae5bab4 100644 --- a/pkg/chatgpt/context.go +++ b/pkg/chatgpt/context.go @@ -250,7 +250,7 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) { if err != nil { return "", err } - file, err := os.Create("images/" + imageName) + file, err := os.Create("data/images/" + imageName) if err != nil { return "", err } diff --git a/pkg/db/chat.go b/pkg/db/chat.go new file mode 100644 index 0000000..01fcb77 --- /dev/null +++ b/pkg/db/chat.go @@ -0,0 +1,60 @@ +package db + +import ( + "fmt" + "strings" + + "gorm.io/gorm" +) + +type ChatType uint + +const Q ChatType = 1 +const A ChatType = 2 + +type Chat struct { + gorm.Model + Username string `gorm:"type:varchar(50);not null;comment:'用户名'" json:"username"` // 用户名 + Source string `gorm:"type:varchar(50);comment:'用户来源:群聊名字,私聊'" json:"source"` // 对话来源 + ChatType ChatType `gorm:"type:tinyint(1);default:1;comment:'类型:1问, 2答'" json:"chat_type"` // 状态 + ParentContent uint `gorm:"default:0;comment:'父消息编号(编号为0时表示为首条)'" json:"parent_content"` + Content string `gorm:"type:varchar(128);comment:'内容'" json:"content"` // 问题或回答的内容 +} + +// 需要考虑下如何处理一个完整对话的问题 +// 如果是单聊,那么就记录上下两句就好了 +// 如果是串聊,则需要知道哪条是第一条,并依次往下记录 + +// Add 添加资源 +func (c Chat) Add() (uint, error) { + err := DB.Create(&c).Error + return c.ID, err +} + +// Find 获取单个资源 +func (c Chat) Find(filter map[string]interface{}, data *Chat) error { + return DB.Where(filter).First(&data).Error +} + +type ChatListReq struct { + Username string `json:"username" form:"username"` + Source string `json:"source" form:"source"` +} + +// List 获取数据列表 +func (c Chat) List(req ChatListReq) ([]*Chat, error) { + var list []*Chat + db := DB.Model(&Chat{}).Order("created_at ASC") + + userName := strings.TrimSpace(req.Username) + if userName != "" { + db = db.Where("username LIKE ?", fmt.Sprintf("%%%s%%", userName)) + } + source := strings.TrimSpace(req.Source) + if source != "" { + db = db.Where("source LIKE ?", fmt.Sprintf("%%%s%%", source)) + } + + err := db.Find(&list).Error + return list, err +} diff --git a/pkg/db/sqlite.go b/pkg/db/sqlite.go new file mode 100644 index 0000000..87f0473 --- /dev/null +++ b/pkg/db/sqlite.go @@ -0,0 +1,41 @@ +package db + +import ( + "github.com/eryajf/chatgpt-dingtalk/pkg/logger" + "github.com/glebarez/sqlite" + "gorm.io/gorm" +) + +// 全局数据库对象 +var DB *gorm.DB + +// 初始化数据库 +func InitDB() { + DB = ConnSqlite() + + dbAutoMigrate() +} + +// 自动迁移表结构 +func dbAutoMigrate() { + _ = DB.AutoMigrate( + Chat{}, + ) +} + +func ConnSqlite() *gorm.DB { + db, err := gorm.Open(sqlite.Open("data/dingtalkbot.sqlite"), &gorm.Config{ + // 禁用外键(指定外键时不会在mysql创建真实的外键约束) + DisableForeignKeyConstraintWhenMigrating: true, + }) + if err != nil { + logger.Fatal("failed to connect sqlite3: %v", err) + } + dbObj, err := db.DB() + if err != nil { + logger.Fatal("failed to get sqlite3 obj: %v", err) + } + // 参见: https://github.com/glebarez/sqlite/issues/52 + dbObj.SetMaxOpenConns(1) + return db +} diff --git a/pkg/dingbot/dingbot.go b/pkg/dingbot/dingbot.go index ad61bec..7d06698 100644 --- a/pkg/dingbot/dingbot.go +++ b/pkg/dingbot/dingbot.go @@ -68,13 +68,22 @@ type At struct { IsAtAll bool `json:"isAtAll"` } -// 获取用户标识,兼容当 SenderStaffId 字段为空的场景 -func (r ReceiveMsg) GetSenderIdentifier() string { - if r.SenderStaffId != "" { - return r.SenderStaffId - } else { - return r.SenderNick +// 获取用户标识,兼容当 SenderStaffId 字段为空的场景,此处提供给发送消息是艾特使用 +func (r ReceiveMsg) GetSenderIdentifier() (uid string) { + uid = r.SenderStaffId + if uid == "" { + uid = r.SenderNick } + return +} + +// GetChatTitle 获取聊天的群名字,如果是私聊,则命名为 昵称_私聊 +func (r ReceiveMsg) GetChatTitle() (chatType string) { + chatType = r.ConversationTitle + if chatType == "" { + chatType = r.SenderNick + "_私聊" + } + return } // 发消息给钉钉 diff --git a/pkg/process/process_request.go b/pkg/process/process_request.go index ece02c4..97a6bf3 100644 --- a/pkg/process/process_request.go +++ b/pkg/process/process_request.go @@ -5,6 +5,7 @@ import ( "strings" "time" + "github.com/eryajf/chatgpt-dingtalk/pkg/db" "github.com/eryajf/chatgpt-dingtalk/pkg/dingbot" "github.com/eryajf/chatgpt-dingtalk/pkg/logger" "github.com/eryajf/chatgpt-dingtalk/public" @@ -29,8 +30,12 @@ func ProcessRequest(rmsg *dingbot.ReceiveMsg) error { logger.Warning(fmt.Errorf("send message error: %v", err)) } case "重置": + // 重置用户对话模式 public.UserService.ClearUserMode(rmsg.GetSenderIdentifier()) + // 清空用户对话上下文 public.UserService.ClearUserSessionContext(rmsg.GetSenderIdentifier()) + // 清空用户对话的答案ID + public.UserService.ClearAnswerID(rmsg.SenderNick, rmsg.GetChatTitle()) _, err := rmsg.ReplyToDingtalk(string(dingbot.TEXT), fmt.Sprintf("=====已重置与👉%s👈的对话模式,可以开始新的对话=====", rmsg.SenderNick)) if err != nil { logger.Warning(fmt.Errorf("send message error: %v", err)) @@ -83,6 +88,17 @@ func Do(mode string, rmsg *dingbot.ReceiveMsg) error { public.UserService.SetUserMode(rmsg.GetSenderIdentifier(), mode) switch mode { case "单聊": + qObj := db.Chat{ + Username: rmsg.SenderNick, + Source: rmsg.GetChatTitle(), + ChatType: db.Q, + ParentContent: 0, + Content: rmsg.Text.Content, + } + qid, err := qObj.Add() + if err != nil { + logger.Error("往MySQL新增数据失败,错误信息:", err) + } reply, err := chatgpt.SingleQa(rmsg.Text.Content, rmsg.GetSenderIdentifier()) if err != nil { logger.Info(fmt.Errorf("gpt request error: %v", err)) @@ -107,6 +123,18 @@ func Do(mode string, rmsg *dingbot.ReceiveMsg) error { } else { reply = strings.TrimSpace(reply) reply = strings.Trim(reply, "\n") + aObj := db.Chat{ + Username: rmsg.SenderNick, + Source: rmsg.GetChatTitle(), + ChatType: db.A, + ParentContent: qid, + Content: reply, + } + _, err := aObj.Add() + if err != nil { + logger.Error("往MySQL新增数据失败,错误信息:", err) + } + logger.Info(fmt.Sprintf("🤖 %s得到的答案: %#v", rmsg.SenderNick, reply)) // 回复@我的用户 _, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), reply) if err != nil { @@ -115,6 +143,18 @@ func Do(mode string, rmsg *dingbot.ReceiveMsg) error { } } case "串聊": + lastAid := public.UserService.GetAnswerID(rmsg.SenderNick, rmsg.GetChatTitle()) + qObj := db.Chat{ + Username: rmsg.SenderNick, + Source: rmsg.GetChatTitle(), + ChatType: db.Q, + ParentContent: lastAid, + Content: rmsg.Text.Content, + } + qid, err := qObj.Add() + if err != nil { + logger.Error("往MySQL新增数据失败,错误信息:", err) + } cli, reply, err := chatgpt.ContextQa(rmsg.Text.Content, rmsg.GetSenderIdentifier()) if err != nil { logger.Info(fmt.Sprintf("gpt request error: %v", err)) @@ -139,6 +179,20 @@ func Do(mode string, rmsg *dingbot.ReceiveMsg) error { } else { reply = strings.TrimSpace(reply) reply = strings.Trim(reply, "\n") + aObj := db.Chat{ + Username: rmsg.SenderNick, + Source: rmsg.GetChatTitle(), + ChatType: db.A, + ParentContent: qid, + Content: reply, + } + aid, err := aObj.Add() + if err != nil { + logger.Error("往MySQL新增数据失败,错误信息:", err) + } + // 将当前回答的ID放入缓存 + public.UserService.SetAnswerID(rmsg.SenderNick, rmsg.GetChatTitle(), aid) + logger.Info(fmt.Sprintf("🤖 %s得到的答案: %#v", rmsg.SenderNick, reply)) // 回复@我的用户 _, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), reply) if err != nil { @@ -154,12 +208,23 @@ func Do(mode string, rmsg *dingbot.ReceiveMsg) error { } func ImageGenerate(rmsg *dingbot.ReceiveMsg) error { + qObj := db.Chat{ + Username: rmsg.SenderNick, + Source: rmsg.GetChatTitle(), + ChatType: db.Q, + ParentContent: 0, + Content: rmsg.Text.Content, + } + qid, err := qObj.Add() + if err != nil { + logger.Error("往MySQL新增数据失败,错误信息:", err) + } reply, err := chatgpt.ImageQa(rmsg.Text.Content, rmsg.GetSenderIdentifier()) if err != nil { logger.Info(fmt.Errorf("gpt request error: %v", err)) _, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v", err)) if err != nil { - logger.Warning(fmt.Errorf("send message error: %v", err)) + logger.Error(fmt.Errorf("send message error: %v", err)) return err } } @@ -169,12 +234,68 @@ func ImageGenerate(rmsg *dingbot.ReceiveMsg) error { } else { reply = strings.TrimSpace(reply) reply = strings.Trim(reply, "\n") - // 回复@我的用户 - _, err = rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), fmt.Sprintf(">点击图片可旋转或放大。\n![](%s)", reply)) + reply = fmt.Sprintf(">点击图片可旋转或放大。\n![](%s)", reply) + aObj := db.Chat{ + Username: rmsg.SenderNick, + Source: rmsg.GetChatTitle(), + ChatType: db.A, + ParentContent: qid, + Content: reply, + } + _, err := aObj.Add() if err != nil { - logger.Warning(fmt.Errorf("send message error: %v", err)) + logger.Error("往MySQL新增数据失败,错误信息:", err) + } + logger.Info(fmt.Sprintf("🤖 %s得到的答案: %#v", rmsg.SenderNick, reply)) + // 回复@我的用户 + _, err = rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), reply) + if err != nil { + logger.Error(fmt.Errorf("send message error: %v", err)) return err } } return nil } +func SelectHistory(rmsg *dingbot.ReceiveMsg) error { + name := strings.TrimSpace(strings.Split(rmsg.Text.Content, ":")[1]) + if !rmsg.IsAdmin || name != rmsg.SenderNick { + _, err := rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), "**🤷 抱歉,您没有权限查询其他人的对话记录!**") + if err != nil { + logger.Error(fmt.Errorf("send message error: %v", err)) + return err + } + return nil + } + // 获取数据列表 + var chat db.Chat + chats, err := chat.List(db.ChatListReq{ + Username: name, + }) + if err != nil { + return err + } + var rst string + for _, chatTmp := range chats { + ctime := chatTmp.CreatedAt.Format("2006-01-02 15:04:05") + if chatTmp.ChatType == 1 { + rst += fmt.Sprintf("## 🙋 %s 问\n\n**时间:** %v\n\n**问题为:** %s\n\n", chatTmp.Username, ctime, chatTmp.Content) + } else { + rst += fmt.Sprintf("## 🤖 机器人答\n\n**时间:** %v\n\n**回答如下:** \n\n%s\n\n", ctime, chatTmp.Content) + } + // TODO: 答案应该严格放在问题之后,目前只根据ID排序进行的陈列,当一个用户同时提出多个问题时,最终展示的可能会有点问题 + } + fileName := time.Now().Format("20060102-150405") + ".md" + // 写入文件 + if err = public.WriteToFile("./data/chatHistory/"+fileName, []byte(rst)); err != nil { + return err + } + // 回复@我的用户 + reply := fmt.Sprintf("- 在线查看: [点我](%s)\n- 下载文件: [点我](%s)\n- 在线预览请安装插件:[Markdown Preview Plus](https://chrome.google.com/webstore/detail/markdown-preview-plus/febilkbfcbhebfnokafefeacimjdckgl)", public.Config.ServiceURL+"/history/"+fileName, public.Config.ServiceURL+"/download/"+fileName) + logger.Info(fmt.Sprintf("🤖 %s得到的答案: %#v", rmsg.SenderNick, reply)) + _, err = rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), reply) + if err != nil { + logger.Error(fmt.Errorf("send message error: %v", err)) + return err + } + return nil +} diff --git a/public/public.go b/public/public.go index 6c8de76..ab0e513 100644 --- a/public/public.go +++ b/public/public.go @@ -6,6 +6,7 @@ import ( "github.com/eryajf/chatgpt-dingtalk/config" "github.com/eryajf/chatgpt-dingtalk/pkg/cache" + "github.com/eryajf/chatgpt-dingtalk/pkg/db" "github.com/eryajf/chatgpt-dingtalk/pkg/dingbot" "github.com/eryajf/chatgpt-dingtalk/pkg/logger" ) @@ -15,9 +16,14 @@ var Config *config.Configuration var Prompt *[]config.Prompt func InitSvc() { + // 加载配置 Config = config.LoadConfig() + // 加载prompt Prompt = config.LoadPrompt() + // 初始化缓存 UserService = cache.NewUserService() + // 初始化数据库 + db.InitDB() // 暂时不在初始化时获取余额 // if Config.Model == openai.GPT3Dot5Turbo0301 || Config.Model == openai.GPT3Dot5Turbo { // _, _ = GetBalance() diff --git a/pkg/chatgpt/tools.go b/public/tools.go similarity index 91% rename from pkg/chatgpt/tools.go rename to public/tools.go index 6f557b2..fc29240 100644 --- a/pkg/chatgpt/tools.go +++ b/public/tools.go @@ -1,7 +1,6 @@ -package chatgpt +package public import ( - "fmt" "io/ioutil" "os" "strings" @@ -13,7 +12,7 @@ func WriteToFile(path string, data []byte) error { if len(tmp) > 0 { tmp = tmp[:len(tmp)-1] } - fmt.Println(tmp) + err := os.MkdirAll(strings.Join(tmp, "/"), os.ModePerm) if err != nil { return err