feat: 支持聊天记录存入以及查询的能力 (#158)

This commit is contained in:
二丫讲梵
2023-04-02 20:19:21 +08:00
committed by GitHub
parent 2a5ef877fa
commit b14d6eabcc
15 changed files with 338 additions and 20 deletions

3
.gitignore vendored
View File

@@ -18,5 +18,8 @@ chatgpt-dingtalk
# Dependency directories (remove the comment below to include it)
# vendor/
config.yml
dingtalkbot.sqlite
tmp
test/
images/
data/

View File

@@ -80,6 +80,7 @@
- 🔗 自定义api域名通过配置指定解决国内服务器无法直接访问openai的问题
- 🪜 添加代理:通过配置指定,通过给应用注入代理解决国内服务器无法访问的问题
- 👐 默认模式:支持自定义默认的聊天模式,通过配置化指定
- 📝 查询对话:通过发送`#查对话 username:xxx`查询xxx的对话历史可在线预览可下载到本地。
## 使用前提
@@ -143,7 +144,7 @@
```
第一种:基于环境变量运行
# 运行项目
$ docker run -itd --name chatgpt -p 8090:8090 --add-host="host.docker.internal:host-gateway" -e APIKEY=换成你的key -e BASE_URL="" -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="http://host.docker.internal:15732" -e DEFAULT_MODE="单聊" -e MAX_REQUEST=0 -e PORT=8090 -e SERVICE_URL="你当前服务外网可访问的URL" -e CHAT_TYPE="0" --restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest
$ docker run -itd --name chatgpt -p 8090:8090 -v ./data:/app/data --add-host="host.docker.internal:host-gateway" -e APIKEY=换成你的key -e BASE_URL="" -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="http://host.docker.internal:15732" -e DEFAULT_MODE="单聊" -e MAX_REQUEST=0 -e PORT=8090 -e SERVICE_URL="你当前服务外网可访问的URL" -e CHAT_TYPE="0" --restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest
```
`📢 注意:`如果使用docker部署那么PORT参数不需要进行任何调整。

View File

@@ -17,7 +17,7 @@ services:
SERVICE_URL: "" # 指定服务的地址,就是当前服务可供外网访问的地址(或者直接理解为你配置在钉钉回调那里的地址),用于生成图片时给钉钉做渲染
CHAT_TYPE: "0" # 限定对话类型 0不限 1只能单聊 2只能群聊
volumes:
- ./data/images:/app/images
- ./data:/app/data
ports:
- "8090:8090"
extra_hosts:

14
go.mod
View File

@@ -4,27 +4,39 @@ go 1.18
require (
github.com/charmbracelet/log v0.2.1
github.com/glebarez/sqlite v1.7.0
github.com/go-resty/resty/v2 v2.7.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/sashabaranov/go-openai v1.5.7
github.com/solywsh/chatgpt v0.0.14
github.com/xgfone/ship/v5 v5.3.1
gopkg.in/yaml.v2 v2.4.0
gorm.io/gorm v1.24.6
)
require (
github.com/avast/retry-go v2.7.0+incompatible // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/charmbracelet/lipgloss v0.7.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/glebarez/go-sqlite v1.20.3 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.18 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.15.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/sashabaranov/go-openai v1.5.7 // indirect
golang.org/x/net v0.0.0-20211029224645-99673261e6eb // indirect
golang.org/x/sys v0.6.0 // indirect
modernc.org/libc v1.22.2 // indirect
modernc.org/mathutil v1.5.0 // indirect
modernc.org/memory v1.5.0 // indirect
modernc.org/sqlite v1.20.3 // indirect
)
replace github.com/solywsh/chatgpt => ./pkg/chatgpt

26
go.sum
View File

@@ -7,10 +7,23 @@ github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNW
github.com/charmbracelet/log v0.2.1 h1:1z7jpkk4yKyjwlmKmKMM5qnEDSpV32E7XtWhuv0mTZE=
github.com/charmbracelet/log v0.2.1/go.mod h1:GwFfjewhcVDWLrpAbY5A0Hin9YOlEn40eWT4PNaxFT4=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/glebarez/go-sqlite v1.20.3 h1:89BkqGOXR9oRmG58ZrzgoY/Fhy5x0M+/WV48U5zVrZ4=
github.com/glebarez/go-sqlite v1.20.3/go.mod h1:u3N6D/wftiAzIOJtZl6BmedqxmmkDfH3q+ihjqxC9u0=
github.com/glebarez/sqlite v1.7.0 h1:A7Xj/KN2Lvie4Z4rrgQHY8MsbebX3NyWsL3n2i82MVI=
github.com/glebarez/sqlite v1.7.0/go.mod h1:PkeevrRlF/1BhQBCnzcMWzgrIk7IOop+qS2jUYLfHhk=
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY=
github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
@@ -25,6 +38,9 @@ github.com/muesli/termenv v0.15.1/go.mod h1:HeAQPTzpfs016yGtA4g00CsdYnVLJvxsS4AN
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 h1:VstopitMQi3hZP0fzvnsLmzXZdQGc4bEcgu24cp+d4M=
github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
@@ -47,3 +63,13 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gorm.io/gorm v1.24.6 h1:wy98aq9oFEetsc4CAbKD2SoBCdMzsbSIvSUUFJuHi5s=
gorm.io/gorm v1.24.6/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
modernc.org/libc v1.22.2 h1:4U7v51GyhlWqQmwCHj28Rdq2Yzwk55ovjFrdPjs8Hb0=
modernc.org/libc v1.22.2/go.mod h1:uvQavJ1pZ0hIoC/jfqNoMLURIMhKzINIWypNM17puug=
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
modernc.org/sqlite v1.20.3 h1:SqGJMMxjj1PHusLxdYxeQSodg7Jxn9WWkaAQjKrntZs=
modernc.org/sqlite v1.20.3/go.mod h1:zKcGyrICaxNTMEHSr1HQ2GUraP0j+845GYw37+EyT6A=

20
main.go
View File

@@ -39,7 +39,7 @@ func Start() {
// 去除问题的前后空格
msgObj.Text.Content = strings.TrimSpace(msgObj.Text.Content)
// 打印钉钉回调过来的请求明细
logger.Info(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj))
// logger.Info(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj))
// TODO: 校验请求
if public.Config.ChatType != "0" && msgObj.ConversationType != public.Config.ChatType {
_, err = msgObj.ReplyToDingtalk(string(dingbot.TEXT), "抱歉,管理员禁用了这种聊天方式,请选择其他聊天方式与机器人对话!")
@@ -57,10 +57,13 @@ func Start() {
return ship.ErrBadRequest.New(fmt.Errorf("send message error: %v", err))
}
} else {
logger.Info(fmt.Sprintf("🙋 %s发起的问题: %#v", msgObj.SenderNick, msgObj.Text.Content))
// 除去帮助之外的逻辑分流在这里处理
switch {
case strings.HasPrefix(msgObj.Text.Content, "#图片"):
return process.ImageGenerate(&msgObj)
case strings.HasPrefix(msgObj.Text.Content, "#查对话"):
return process.SelectHistory(&msgObj)
default:
msgObj.Text.Content, err = process.GeneratePrompt(msgObj.Text.Content)
// err不为空提示词之后没有文本 -> 直接返回提示词所代表的内容
@@ -72,7 +75,6 @@ func Start() {
}
return nil
}
logger.Info(fmt.Sprintf("after generate prompt: %#v", msgObj.Text.Content))
return process.ProcessRequest(&msgObj)
}
}
@@ -81,9 +83,21 @@ func Start() {
// 解析生成后的图片
app.Route("/images/:filename").GET(func(c *ship.Context) error {
filename := c.Param("filename")
root := "./images/"
root := "./data/images/"
return c.File(filepath.Join(root, filename))
})
// 解析生成后的历史聊天
app.Route("/history/:filename").GET(func(c *ship.Context) error {
filename := c.Param("filename")
root := "./data/chatHistory/"
return c.File(filepath.Join(root, filename))
})
// 直接下载文件
app.Route("/download/:filename").GET(func(c *ship.Context) error {
filename := c.Param("filename")
root := "./data/chatHistory/"
return c.Attachment(filepath.Join(root, filename), "")
})
port := ":" + public.Config.Port
srv := &http.Server{

View File

@@ -19,6 +19,10 @@ type UserServiceInterface interface {
// 用户请求次数
SetUseRequestCount(userId string, current int)
GetUseRequestCount(uerId string) int
// 用户对话ID
SetAnswerID(userId, chattype string, current uint)
GetAnswerID(uerId, chattype string) uint
ClearAnswerID(userId, chattitle string)
}
var _ UserServiceInterface = (*UserService)(nil)

22
pkg/cache/user_chatid.go vendored Normal file
View File

@@ -0,0 +1,22 @@
package cache
import "time"
// SetAnswerID 设置用户获得答案的ID
func (s *UserService) SetAnswerID(userId, chattitle string, current uint) {
s.cache.Set(userId+"_"+chattitle, current, time.Hour*24)
}
// GetAnswerID 获取当前用户获得答案的ID
func (s *UserService) GetAnswerID(userId, chattitle string) uint {
sessionContext, ok := s.cache.Get(userId + "_" + chattitle)
if !ok {
return 0
}
return sessionContext.(uint)
}
// ClearUserSessionContext 清空GTP上下文接收文本中包含 SessionClearToken
func (s *UserService) ClearAnswerID(userId, chattitle string) {
s.cache.Delete(userId + "_" + chattitle)
}

View File

@@ -250,7 +250,7 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
if err != nil {
return "", err
}
file, err := os.Create("images/" + imageName)
file, err := os.Create("data/images/" + imageName)
if err != nil {
return "", err
}

60
pkg/db/chat.go Normal file
View File

@@ -0,0 +1,60 @@
package db
import (
"fmt"
"strings"
"gorm.io/gorm"
)
type ChatType uint
const Q ChatType = 1
const A ChatType = 2
type Chat struct {
gorm.Model
Username string `gorm:"type:varchar(50);not null;comment:'用户名'" json:"username"` // 用户名
Source string `gorm:"type:varchar(50);comment:'用户来源:群聊名字,私聊'" json:"source"` // 对话来源
ChatType ChatType `gorm:"type:tinyint(1);default:1;comment:'类型:1问, 2答'" json:"chat_type"` // 状态
ParentContent uint `gorm:"default:0;comment:'父消息编号(编号为0时表示为首条)'" json:"parent_content"`
Content string `gorm:"type:varchar(128);comment:'内容'" json:"content"` // 问题或回答的内容
}
// 需要考虑下如何处理一个完整对话的问题
// 如果是单聊,那么就记录上下两句就好了
// 如果是串聊,则需要知道哪条是第一条,并依次往下记录
// Add 添加资源
func (c Chat) Add() (uint, error) {
err := DB.Create(&c).Error
return c.ID, err
}
// Find 获取单个资源
func (c Chat) Find(filter map[string]interface{}, data *Chat) error {
return DB.Where(filter).First(&data).Error
}
type ChatListReq struct {
Username string `json:"username" form:"username"`
Source string `json:"source" form:"source"`
}
// List 获取数据列表
func (c Chat) List(req ChatListReq) ([]*Chat, error) {
var list []*Chat
db := DB.Model(&Chat{}).Order("created_at ASC")
userName := strings.TrimSpace(req.Username)
if userName != "" {
db = db.Where("username LIKE ?", fmt.Sprintf("%%%s%%", userName))
}
source := strings.TrimSpace(req.Source)
if source != "" {
db = db.Where("source LIKE ?", fmt.Sprintf("%%%s%%", source))
}
err := db.Find(&list).Error
return list, err
}

41
pkg/db/sqlite.go Normal file
View File

@@ -0,0 +1,41 @@
package db
import (
"github.com/eryajf/chatgpt-dingtalk/pkg/logger"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
)
// 全局数据库对象
var DB *gorm.DB
// 初始化数据库
func InitDB() {
DB = ConnSqlite()
dbAutoMigrate()
}
// 自动迁移表结构
func dbAutoMigrate() {
_ = DB.AutoMigrate(
Chat{},
)
}
func ConnSqlite() *gorm.DB {
db, err := gorm.Open(sqlite.Open("data/dingtalkbot.sqlite"), &gorm.Config{
// 禁用外键(指定外键时不会在mysql创建真实的外键约束)
DisableForeignKeyConstraintWhenMigrating: true,
})
if err != nil {
logger.Fatal("failed to connect sqlite3: %v", err)
}
dbObj, err := db.DB()
if err != nil {
logger.Fatal("failed to get sqlite3 obj: %v", err)
}
// 参见: https://github.com/glebarez/sqlite/issues/52
dbObj.SetMaxOpenConns(1)
return db
}

View File

@@ -68,13 +68,22 @@ type At struct {
IsAtAll bool `json:"isAtAll"`
}
// 获取用户标识,兼容当 SenderStaffId 字段为空的场景
func (r ReceiveMsg) GetSenderIdentifier() string {
if r.SenderStaffId != "" {
return r.SenderStaffId
} else {
return r.SenderNick
// 获取用户标识,兼容当 SenderStaffId 字段为空的场景,此处提供给发送消息是艾特使用
func (r ReceiveMsg) GetSenderIdentifier() (uid string) {
uid = r.SenderStaffId
if uid == "" {
uid = r.SenderNick
}
return
}
// GetChatTitle 获取聊天的群名字,如果是私聊,则命名为 昵称_私聊
func (r ReceiveMsg) GetChatTitle() (chatType string) {
chatType = r.ConversationTitle
if chatType == "" {
chatType = r.SenderNick + "_私聊"
}
return
}
// 发消息给钉钉

View File

@@ -5,6 +5,7 @@ import (
"strings"
"time"
"github.com/eryajf/chatgpt-dingtalk/pkg/db"
"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
"github.com/eryajf/chatgpt-dingtalk/pkg/logger"
"github.com/eryajf/chatgpt-dingtalk/public"
@@ -29,8 +30,12 @@ func ProcessRequest(rmsg *dingbot.ReceiveMsg) error {
logger.Warning(fmt.Errorf("send message error: %v", err))
}
case "重置":
// 重置用户对话模式
public.UserService.ClearUserMode(rmsg.GetSenderIdentifier())
// 清空用户对话上下文
public.UserService.ClearUserSessionContext(rmsg.GetSenderIdentifier())
// 清空用户对话的答案ID
public.UserService.ClearAnswerID(rmsg.SenderNick, rmsg.GetChatTitle())
_, err := rmsg.ReplyToDingtalk(string(dingbot.TEXT), fmt.Sprintf("=====已重置与👉%s👈的对话模式可以开始新的对话=====", rmsg.SenderNick))
if err != nil {
logger.Warning(fmt.Errorf("send message error: %v", err))
@@ -83,6 +88,17 @@ func Do(mode string, rmsg *dingbot.ReceiveMsg) error {
public.UserService.SetUserMode(rmsg.GetSenderIdentifier(), mode)
switch mode {
case "单聊":
qObj := db.Chat{
Username: rmsg.SenderNick,
Source: rmsg.GetChatTitle(),
ChatType: db.Q,
ParentContent: 0,
Content: rmsg.Text.Content,
}
qid, err := qObj.Add()
if err != nil {
logger.Error("往MySQL新增数据失败,错误信息:", err)
}
reply, err := chatgpt.SingleQa(rmsg.Text.Content, rmsg.GetSenderIdentifier())
if err != nil {
logger.Info(fmt.Errorf("gpt request error: %v", err))
@@ -107,6 +123,18 @@ func Do(mode string, rmsg *dingbot.ReceiveMsg) error {
} else {
reply = strings.TrimSpace(reply)
reply = strings.Trim(reply, "\n")
aObj := db.Chat{
Username: rmsg.SenderNick,
Source: rmsg.GetChatTitle(),
ChatType: db.A,
ParentContent: qid,
Content: reply,
}
_, err := aObj.Add()
if err != nil {
logger.Error("往MySQL新增数据失败,错误信息:", err)
}
logger.Info(fmt.Sprintf("🤖 %s得到的答案: %#v", rmsg.SenderNick, reply))
// 回复@我的用户
_, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), reply)
if err != nil {
@@ -115,6 +143,18 @@ func Do(mode string, rmsg *dingbot.ReceiveMsg) error {
}
}
case "串聊":
lastAid := public.UserService.GetAnswerID(rmsg.SenderNick, rmsg.GetChatTitle())
qObj := db.Chat{
Username: rmsg.SenderNick,
Source: rmsg.GetChatTitle(),
ChatType: db.Q,
ParentContent: lastAid,
Content: rmsg.Text.Content,
}
qid, err := qObj.Add()
if err != nil {
logger.Error("往MySQL新增数据失败,错误信息:", err)
}
cli, reply, err := chatgpt.ContextQa(rmsg.Text.Content, rmsg.GetSenderIdentifier())
if err != nil {
logger.Info(fmt.Sprintf("gpt request error: %v", err))
@@ -139,6 +179,20 @@ func Do(mode string, rmsg *dingbot.ReceiveMsg) error {
} else {
reply = strings.TrimSpace(reply)
reply = strings.Trim(reply, "\n")
aObj := db.Chat{
Username: rmsg.SenderNick,
Source: rmsg.GetChatTitle(),
ChatType: db.A,
ParentContent: qid,
Content: reply,
}
aid, err := aObj.Add()
if err != nil {
logger.Error("往MySQL新增数据失败,错误信息:", err)
}
// 将当前回答的ID放入缓存
public.UserService.SetAnswerID(rmsg.SenderNick, rmsg.GetChatTitle(), aid)
logger.Info(fmt.Sprintf("🤖 %s得到的答案: %#v", rmsg.SenderNick, reply))
// 回复@我的用户
_, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), reply)
if err != nil {
@@ -154,12 +208,23 @@ func Do(mode string, rmsg *dingbot.ReceiveMsg) error {
}
func ImageGenerate(rmsg *dingbot.ReceiveMsg) error {
qObj := db.Chat{
Username: rmsg.SenderNick,
Source: rmsg.GetChatTitle(),
ChatType: db.Q,
ParentContent: 0,
Content: rmsg.Text.Content,
}
qid, err := qObj.Add()
if err != nil {
logger.Error("往MySQL新增数据失败,错误信息:", err)
}
reply, err := chatgpt.ImageQa(rmsg.Text.Content, rmsg.GetSenderIdentifier())
if err != nil {
logger.Info(fmt.Errorf("gpt request error: %v", err))
_, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), fmt.Sprintf("请求openai失败了错误信息%v", err))
if err != nil {
logger.Warning(fmt.Errorf("send message error: %v", err))
logger.Error(fmt.Errorf("send message error: %v", err))
return err
}
}
@@ -169,12 +234,68 @@ func ImageGenerate(rmsg *dingbot.ReceiveMsg) error {
} else {
reply = strings.TrimSpace(reply)
reply = strings.Trim(reply, "\n")
// 回复@我的用户
_, err = rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), fmt.Sprintf(">点击图片可旋转或放大。\n![](%s)", reply))
reply = fmt.Sprintf(">点击图片可旋转或放大。\n![](%s)", reply)
aObj := db.Chat{
Username: rmsg.SenderNick,
Source: rmsg.GetChatTitle(),
ChatType: db.A,
ParentContent: qid,
Content: reply,
}
_, err := aObj.Add()
if err != nil {
logger.Warning(fmt.Errorf("send message error: %v", err))
logger.Error("往MySQL新增数据失败,错误信息:", err)
}
logger.Info(fmt.Sprintf("🤖 %s得到的答案: %#v", rmsg.SenderNick, reply))
// 回复@我的用户
_, err = rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), reply)
if err != nil {
logger.Error(fmt.Errorf("send message error: %v", err))
return err
}
}
return nil
}
func SelectHistory(rmsg *dingbot.ReceiveMsg) error {
name := strings.TrimSpace(strings.Split(rmsg.Text.Content, ":")[1])
if !rmsg.IsAdmin || name != rmsg.SenderNick {
_, err := rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), "**🤷 抱歉,您没有权限查询其他人的对话记录!**")
if err != nil {
logger.Error(fmt.Errorf("send message error: %v", err))
return err
}
return nil
}
// 获取数据列表
var chat db.Chat
chats, err := chat.List(db.ChatListReq{
Username: name,
})
if err != nil {
return err
}
var rst string
for _, chatTmp := range chats {
ctime := chatTmp.CreatedAt.Format("2006-01-02 15:04:05")
if chatTmp.ChatType == 1 {
rst += fmt.Sprintf("## 🙋 %s 问\n\n**时间:** %v\n\n**问题为:** %s\n\n", chatTmp.Username, ctime, chatTmp.Content)
} else {
rst += fmt.Sprintf("## 🤖 机器人答\n\n**时间:** %v\n\n**回答如下:** \n\n%s\n\n", ctime, chatTmp.Content)
}
// TODO: 答案应该严格放在问题之后目前只根据ID排序进行的陈列当一个用户同时提出多个问题时最终展示的可能会有点问题
}
fileName := time.Now().Format("20060102-150405") + ".md"
// 写入文件
if err = public.WriteToFile("./data/chatHistory/"+fileName, []byte(rst)); err != nil {
return err
}
// 回复@我的用户
reply := fmt.Sprintf("- 在线查看: [点我](%s)\n- 下载文件: [点我](%s)\n- 在线预览请安装插件:[Markdown Preview Plus](https://chrome.google.com/webstore/detail/markdown-preview-plus/febilkbfcbhebfnokafefeacimjdckgl)", public.Config.ServiceURL+"/history/"+fileName, public.Config.ServiceURL+"/download/"+fileName)
logger.Info(fmt.Sprintf("🤖 %s得到的答案: %#v", rmsg.SenderNick, reply))
_, err = rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), reply)
if err != nil {
logger.Error(fmt.Errorf("send message error: %v", err))
return err
}
return nil
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/eryajf/chatgpt-dingtalk/config"
"github.com/eryajf/chatgpt-dingtalk/pkg/cache"
"github.com/eryajf/chatgpt-dingtalk/pkg/db"
"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
"github.com/eryajf/chatgpt-dingtalk/pkg/logger"
)
@@ -15,9 +16,14 @@ var Config *config.Configuration
var Prompt *[]config.Prompt
func InitSvc() {
// 加载配置
Config = config.LoadConfig()
// 加载prompt
Prompt = config.LoadPrompt()
// 初始化缓存
UserService = cache.NewUserService()
// 初始化数据库
db.InitDB()
// 暂时不在初始化时获取余额
// if Config.Model == openai.GPT3Dot5Turbo0301 || Config.Model == openai.GPT3Dot5Turbo {
// _, _ = GetBalance()

View File

@@ -1,7 +1,6 @@
package chatgpt
package public
import (
"fmt"
"io/ioutil"
"os"
"strings"
@@ -13,7 +12,7 @@ func WriteToFile(path string, data []byte) error {
if len(tmp) > 0 {
tmp = tmp[:len(tmp)-1]
}
fmt.Println(tmp)
err := os.MkdirAll(strings.Join(tmp, "/"), os.ModePerm)
if err != nil {
return err