mirror of
https://github.com/eryajf/chatgpt-dingtalk.git
synced 2025-10-05 16:16:56 +08:00
feat: add user request limit config (#98)
This commit is contained in:
@@ -82,7 +82,7 @@
|
|||||||
|
|
||||||
```sh
|
```sh
|
||||||
# 运行项目
|
# 运行项目
|
||||||
$ docker run -itd --name chatgpt -p 8090:8090 --add-host="host.docker.internal:host-gateway" -e APIKEY=换成你的key -e BASE_URL="" -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="http://host.docker.internal:15732" -e DEFAULT_MODE="单聊" --restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest
|
$ docker run -itd --name chatgpt -p 8090:8090 --add-host="host.docker.internal:host-gateway" -e APIKEY=换成你的key -e BASE_URL="" -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="http://host.docker.internal:15732" -e DEFAULT_MODE="单聊" -e MAX_REQUEST=0 --restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest
|
||||||
```
|
```
|
||||||
|
|
||||||
`📢 注意:`如果使用docker部署,那么proxy地址可以直接使用如上方式部署,`host.docker.internal`会指向容器所在宿主机的IP,只需要更改端口为你的代理端口即可。参见:[Docker容器如何优雅地访问宿主机网络](https://wiki.eryajf.net/pages/674f53/)
|
`📢 注意:`如果使用docker部署,那么proxy地址可以直接使用如上方式部署,`host.docker.internal`会指向容器所在宿主机的IP,只需要更改端口为你的代理端口即可。参见:[Docker容器如何优雅地访问宿主机网络](https://wiki.eryajf.net/pages/674f53/)
|
||||||
@@ -243,7 +243,8 @@ $ go run main.go
|
|||||||
"model": "gpt-3.5-turbo", // 指定模型,默认为 gpt-3.5-turbo , 可选参数有: "gpt-4-32k-0314", "gpt-4-32k", "gpt-4-0314", "gpt-4", "gpt-3.5-turbo-0301", "gpt-3.5-turbo", "text-davinci-003", "text-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-001", "davinci-instruct-beta", "davinci", "curie-instruct-beta", "curie", "ada", "babbage"
|
"model": "gpt-3.5-turbo", // 指定模型,默认为 gpt-3.5-turbo , 可选参数有: "gpt-4-32k-0314", "gpt-4-32k", "gpt-4-0314", "gpt-4", "gpt-3.5-turbo-0301", "gpt-3.5-turbo", "text-davinci-003", "text-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-001", "davinci-instruct-beta", "davinci", "curie-instruct-beta", "curie", "ada", "babbage"
|
||||||
"session_timeout": 600, // 会话超时时间,默认600秒,在会话时间内所有发送给机器人的信息会作为上下文
|
"session_timeout": 600, // 会话超时时间,默认600秒,在会话时间内所有发送给机器人的信息会作为上下文
|
||||||
"http_proxy": "", // 指定请求时使用的代理,如果为空,则不使用代理
|
"http_proxy": "", // 指定请求时使用的代理,如果为空,则不使用代理
|
||||||
"default_mode": "单聊" // 默认对话模式,可根据实际场景自定义,如果不设置,默认为单聊
|
"default_mode": "单聊", // 默认对话模式,可根据实际场景自定义,如果不设置,默认为单聊
|
||||||
|
"max_request": 0 // 单人单日请求次数限制,默认为0,即不限制
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@@ -4,5 +4,6 @@
|
|||||||
"model": "gpt-3.5-turbo",
|
"model": "gpt-3.5-turbo",
|
||||||
"session_timeout": 600,
|
"session_timeout": 600,
|
||||||
"http_proxy": "",
|
"http_proxy": "",
|
||||||
"default_mode": "单聊"
|
"default_mode": "单聊",
|
||||||
|
"max_request": 0
|
||||||
}
|
}
|
@@ -25,6 +25,8 @@ type Configuration struct {
|
|||||||
DefaultMode string `json:"default_mode"`
|
DefaultMode string `json:"default_mode"`
|
||||||
// 代理地址
|
// 代理地址
|
||||||
HttpProxy string `json:"http_proxy"`
|
HttpProxy string `json:"http_proxy"`
|
||||||
|
// 用户单日最大请求次数
|
||||||
|
MaxRequest int `json:"max_request"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var config *Configuration
|
var config *Configuration
|
||||||
@@ -54,6 +56,7 @@ func LoadConfig() *Configuration {
|
|||||||
sessionTimeout := os.Getenv("SESSION_TIMEOUT")
|
sessionTimeout := os.Getenv("SESSION_TIMEOUT")
|
||||||
defaultMode := os.Getenv("DEFAULT_MODE")
|
defaultMode := os.Getenv("DEFAULT_MODE")
|
||||||
httpProxy := os.Getenv("HTTP_PROXY")
|
httpProxy := os.Getenv("HTTP_PROXY")
|
||||||
|
maxRequest := os.Getenv("MAX_REQUEST")
|
||||||
if apiKey != "" {
|
if apiKey != "" {
|
||||||
config.ApiKey = apiKey
|
config.ApiKey = apiKey
|
||||||
}
|
}
|
||||||
@@ -79,6 +82,10 @@ func LoadConfig() *Configuration {
|
|||||||
if model != "" {
|
if model != "" {
|
||||||
config.Model = model
|
config.Model = model
|
||||||
}
|
}
|
||||||
|
if maxRequest != "" {
|
||||||
|
newMR, _ := strconv.Atoi(maxRequest)
|
||||||
|
config.MaxRequest = newMR
|
||||||
|
}
|
||||||
})
|
})
|
||||||
if config.Model == "" {
|
if config.Model == "" {
|
||||||
config.DefaultMode = "gpt-3.5-turbo"
|
config.DefaultMode = "gpt-3.5-turbo"
|
||||||
|
@@ -12,6 +12,7 @@ services:
|
|||||||
SESSION_TIMEOUT: 600 # 超时时间
|
SESSION_TIMEOUT: 600 # 超时时间
|
||||||
HTTP_PROXY: http://host.docker.internal:15777 # 配置代理,注意:host.docker.internal会解析到容器所在的宿主机IP,如果你的服务部署在宿主机,只需要更改端口即可
|
HTTP_PROXY: http://host.docker.internal:15777 # 配置代理,注意:host.docker.internal会解析到容器所在的宿主机IP,如果你的服务部署在宿主机,只需要更改端口即可
|
||||||
DEFAULT_MODE: "单聊" # 聊天模式
|
DEFAULT_MODE: "单聊" # 聊天模式
|
||||||
|
MAX_REQUEST: 0 # 单人单日请求次数限制,默认为0,即不限制
|
||||||
ports:
|
ports:
|
||||||
- "8090:8090"
|
- "8090:8090"
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
|
2
go.mod
2
go.mod
@@ -3,13 +3,13 @@ module github.com/eryajf/chatgpt-dingtalk
|
|||||||
go 1.18
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/avast/retry-go v2.7.0+incompatible
|
|
||||||
github.com/go-resty/resty/v2 v2.7.0
|
github.com/go-resty/resty/v2 v2.7.0
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
github.com/solywsh/chatgpt v0.0.14
|
github.com/solywsh/chatgpt v0.0.14
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/avast/retry-go v2.7.0+incompatible // indirect
|
||||||
github.com/sashabaranov/go-openai v1.5.1 // indirect
|
github.com/sashabaranov/go-openai v1.5.1 // indirect
|
||||||
github.com/stretchr/testify v1.8.2 // indirect
|
github.com/stretchr/testify v1.8.2 // indirect
|
||||||
golang.org/x/net v0.0.0-20211029224645-99673261e6eb // indirect
|
golang.org/x/net v0.0.0-20211029224645-99673261e6eb // indirect
|
||||||
|
178
main.go
178
main.go
@@ -6,12 +6,10 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/avast/retry-go"
|
"github.com/eryajf/chatgpt-dingtalk/pkg/process"
|
||||||
"github.com/eryajf/chatgpt-dingtalk/public"
|
"github.com/eryajf/chatgpt-dingtalk/public"
|
||||||
"github.com/eryajf/chatgpt-dingtalk/public/logger"
|
"github.com/eryajf/chatgpt-dingtalk/public/logger"
|
||||||
"github.com/solywsh/chatgpt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -66,7 +64,7 @@ func Start() {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logger.Info(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj))
|
logger.Info(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj))
|
||||||
err = ProcessRequest(*msgObj)
|
err = process.ProcessRequest(*msgObj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warning(fmt.Errorf("process request failed: %v", err))
|
logger.Warning(fmt.Errorf("process request failed: %v", err))
|
||||||
}
|
}
|
||||||
@@ -86,175 +84,3 @@ func Start() {
|
|||||||
logger.Danger(err)
|
logger.Danger(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProcessRequest(rmsg public.ReceiveMsg) error {
|
|
||||||
content := strings.TrimSpace(rmsg.Text.Content)
|
|
||||||
switch content {
|
|
||||||
case "单聊":
|
|
||||||
public.UserService.SetUserMode(rmsg.SenderStaffId, content)
|
|
||||||
_, err := rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈单聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
|
||||||
}
|
|
||||||
case "串聊":
|
|
||||||
public.UserService.SetUserMode(rmsg.SenderStaffId, content)
|
|
||||||
_, err := rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈串聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
|
||||||
}
|
|
||||||
case "重置":
|
|
||||||
public.UserService.ClearUserMode(rmsg.SenderStaffId)
|
|
||||||
public.UserService.ClearUserSessionContext(rmsg.SenderStaffId)
|
|
||||||
_, err := rmsg.ReplyText(fmt.Sprintf("=====已重置与👉%s👈的对话模式,可以开始新的对话=====", rmsg.SenderNick), rmsg.SenderStaffId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
|
||||||
}
|
|
||||||
case "余额":
|
|
||||||
cacheMsg := public.UserService.GetUserMode("system_balance")
|
|
||||||
if cacheMsg == "" {
|
|
||||||
rst, err := public.GetBalance()
|
|
||||||
if err != nil {
|
|
||||||
logger.Warning(fmt.Errorf("get balance error: %v", err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
t1 := time.Unix(int64(rst.Grants.Data[0].EffectiveAt), 0)
|
|
||||||
t2 := time.Unix(int64(rst.Grants.Data[0].ExpiresAt), 0)
|
|
||||||
cacheMsg = fmt.Sprintf("💵 已用: 💲%v\n💵 剩余: 💲%v\n⏳ 有效时间: 从 %v 到 %v\n", fmt.Sprintf("%.2f", rst.TotalUsed), fmt.Sprintf("%.2f", rst.TotalAvailable), t1.Format("2006-01-02 15:04:05"), t2.Format("2006-01-02 15:04:05"))
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := rmsg.ReplyText(cacheMsg, rmsg.SenderStaffId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if public.FirstCheck(rmsg) {
|
|
||||||
return Do("串聊", rmsg)
|
|
||||||
} else {
|
|
||||||
return Do("单聊", rmsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Do(mode string, rmsg public.ReceiveMsg) error {
|
|
||||||
// 先把模式注入
|
|
||||||
public.UserService.SetUserMode(rmsg.SenderStaffId, mode)
|
|
||||||
switch mode {
|
|
||||||
case "单聊":
|
|
||||||
reply, err := SingleQa(rmsg.Text.Content, rmsg.SenderStaffId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info(fmt.Errorf("gpt request error: %v", err))
|
|
||||||
if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") {
|
|
||||||
public.UserService.ClearUserSessionContext(rmsg.SenderStaffId)
|
|
||||||
_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if reply == "" {
|
|
||||||
logger.Warning(fmt.Errorf("get gpt result falied: %v", err))
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
reply = strings.TrimSpace(reply)
|
|
||||||
reply = strings.Trim(reply, "\n")
|
|
||||||
// 回复@我的用户
|
|
||||||
// fmt.Println("单聊结果是:", reply)
|
|
||||||
_, err = rmsg.ReplyText(reply, rmsg.SenderStaffId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "串聊":
|
|
||||||
cli, reply, err := ContextQa(rmsg.Text.Content, rmsg.SenderStaffId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info(fmt.Sprintf("gpt request error: %v", err))
|
|
||||||
if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") {
|
|
||||||
public.UserService.ClearUserSessionContext(rmsg.SenderStaffId)
|
|
||||||
_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if reply == "" {
|
|
||||||
logger.Warning(fmt.Errorf("get gpt result falied: %v", err))
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
reply = strings.TrimSpace(reply)
|
|
||||||
reply = strings.Trim(reply, "\n")
|
|
||||||
// 回复@我的用户
|
|
||||||
_, err = rmsg.ReplyText(reply, rmsg.SenderStaffId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_ = cli.ChatContext.SaveConversation(rmsg.SenderStaffId)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func SingleQa(question, userId string) (answer string, err error) {
|
|
||||||
chat := chatgpt.New(userId)
|
|
||||||
defer chat.Close()
|
|
||||||
// 定义一个重试策略
|
|
||||||
retryStrategy := []retry.Option{
|
|
||||||
retry.Delay(100 * time.Millisecond),
|
|
||||||
retry.Attempts(3),
|
|
||||||
retry.LastErrorOnly(true),
|
|
||||||
}
|
|
||||||
// 使用重试策略进行重试
|
|
||||||
err = retry.Do(
|
|
||||||
func() error {
|
|
||||||
answer, err = chat.ChatWithContext(question)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
retryStrategy...)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func ContextQa(question, userId string) (chat *chatgpt.ChatGPT, answer string, err error) {
|
|
||||||
chat = chatgpt.New(userId)
|
|
||||||
if public.UserService.GetUserSessionContext(userId) != "" {
|
|
||||||
err := chat.ChatContext.LoadConversation(userId)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warning("load station failed: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
retryStrategy := []retry.Option{
|
|
||||||
retry.Delay(100 * time.Millisecond),
|
|
||||||
retry.Attempts(3),
|
|
||||||
retry.LastErrorOnly(true)}
|
|
||||||
// 使用重试策略进行重试
|
|
||||||
err = retry.Do(
|
|
||||||
func() error {
|
|
||||||
answer, err = chat.ChatWithContext(question)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
retryStrategy...)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
35
pkg/cache/user_base.go
vendored
Normal file
35
pkg/cache/user_base.go
vendored
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/patrickmn/go-cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserServiceInterface 用户业务接口
|
||||||
|
type UserServiceInterface interface {
|
||||||
|
// 用户聊天模式
|
||||||
|
GetUserMode(userId string) string
|
||||||
|
SetUserMode(userId, mode string)
|
||||||
|
ClearUserMode(userId string)
|
||||||
|
// 用户聊天上下文
|
||||||
|
GetUserSessionContext(userId string) string
|
||||||
|
SetUserSessionContext(userId, content string)
|
||||||
|
ClearUserSessionContext(userId string)
|
||||||
|
// 用户请求次数
|
||||||
|
SetUseRequestCount(userId string, current int)
|
||||||
|
GetUseRequestCount(uerId string) int
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ UserServiceInterface = (*UserService)(nil)
|
||||||
|
|
||||||
|
// UserService 用戶业务
|
||||||
|
type UserService struct {
|
||||||
|
// 缓存
|
||||||
|
cache *cache.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserService 创建新的业务层
|
||||||
|
func NewUserService() UserServiceInterface {
|
||||||
|
return &UserService{cache: cache.New(time.Hour*2, time.Hour*5)}
|
||||||
|
}
|
22
pkg/cache/user_context.go
vendored
Normal file
22
pkg/cache/user_context.go
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import "github.com/patrickmn/go-cache"
|
||||||
|
|
||||||
|
// SetUserSessionContext 设置用户会话上下文文本,question用户提问内容,GTP回复内容
|
||||||
|
func (s *UserService) SetUserSessionContext(userId string, content string) {
|
||||||
|
s.cache.Set(userId+"_content", content, cache.DefaultExpiration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserSessionContext 获取用户会话上下文文本
|
||||||
|
func (s *UserService) GetUserSessionContext(userId string) string {
|
||||||
|
sessionContext, ok := s.cache.Get(userId + "_content")
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return sessionContext.(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUserSessionContext 清空GTP上下文,接收文本中包含 SessionClearToken
|
||||||
|
func (s *UserService) ClearUserSessionContext(userId string) {
|
||||||
|
s.cache.Delete(userId + "_content")
|
||||||
|
}
|
22
pkg/cache/user_mode.go
vendored
Normal file
22
pkg/cache/user_mode.go
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import "github.com/patrickmn/go-cache"
|
||||||
|
|
||||||
|
// GetUserMode 获取当前对话模式
|
||||||
|
func (s *UserService) GetUserMode(userId string) string {
|
||||||
|
sessionContext, ok := s.cache.Get(userId + "_mode")
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return sessionContext.(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserMode 设置用户对话模式
|
||||||
|
func (s *UserService) SetUserMode(userId string, mode string) {
|
||||||
|
s.cache.Set(userId+"_mode", mode, cache.DefaultExpiration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearUserMode 重置用户对话模式
|
||||||
|
func (s *UserService) ClearUserMode(userId string) {
|
||||||
|
s.cache.Delete(userId + "_mode")
|
||||||
|
}
|
19
pkg/cache/user_requese.go
vendored
Normal file
19
pkg/cache/user_requese.go
vendored
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetUseRequestCount 设置用户请求次数
|
||||||
|
func (s *UserService) SetUseRequestCount(userId string, current int) {
|
||||||
|
s.cache.Set(userId+"_request", current, time.Hour*24)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUseRequestCount 获取当前用户已请求次数
|
||||||
|
func (s *UserService) GetUseRequestCount(userId string) int {
|
||||||
|
sessionContext, ok := s.cache.Get(userId + "_request")
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return sessionContext.(int)
|
||||||
|
}
|
56
pkg/chatgpt/export.go
Normal file
56
pkg/chatgpt/export.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package chatgpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/avast/retry-go"
|
||||||
|
"github.com/eryajf/chatgpt-dingtalk/public"
|
||||||
|
"github.com/eryajf/chatgpt-dingtalk/public/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SingleQa(question, userId string) (answer string, err error) {
|
||||||
|
chat := New(userId)
|
||||||
|
defer chat.Close()
|
||||||
|
// 定义一个重试策略
|
||||||
|
retryStrategy := []retry.Option{
|
||||||
|
retry.Delay(100 * time.Millisecond),
|
||||||
|
retry.Attempts(3),
|
||||||
|
retry.LastErrorOnly(true),
|
||||||
|
}
|
||||||
|
// 使用重试策略进行重试
|
||||||
|
err = retry.Do(
|
||||||
|
func() error {
|
||||||
|
answer, err = chat.ChatWithContext(question)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
retryStrategy...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func ContextQa(question, userId string) (chat *ChatGPT, answer string, err error) {
|
||||||
|
chat = New(userId)
|
||||||
|
if public.UserService.GetUserSessionContext(userId) != "" {
|
||||||
|
err := chat.ChatContext.LoadConversation(userId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning("load station failed: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
retryStrategy := []retry.Option{
|
||||||
|
retry.Delay(100 * time.Millisecond),
|
||||||
|
retry.Attempts(3),
|
||||||
|
retry.LastErrorOnly(true)}
|
||||||
|
// 使用重试策略进行重试
|
||||||
|
err = retry.Do(
|
||||||
|
func() error {
|
||||||
|
answer, err = chat.ChatWithContext(question)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
retryStrategy...)
|
||||||
|
return
|
||||||
|
}
|
160
pkg/process/process_request.go
Normal file
160
pkg/process/process_request.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package process
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/eryajf/chatgpt-dingtalk/public"
|
||||||
|
"github.com/eryajf/chatgpt-dingtalk/public/logger"
|
||||||
|
"github.com/solywsh/chatgpt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProcessRequest 分析处理请求逻辑
|
||||||
|
func ProcessRequest(rmsg public.ReceiveMsg) error {
|
||||||
|
if CheckRequest(rmsg) {
|
||||||
|
content := strings.TrimSpace(rmsg.Text.Content)
|
||||||
|
switch content {
|
||||||
|
case "单聊":
|
||||||
|
public.UserService.SetUserMode(rmsg.SenderStaffId, content)
|
||||||
|
_, err := rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈单聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||||||
|
}
|
||||||
|
case "串聊":
|
||||||
|
public.UserService.SetUserMode(rmsg.SenderStaffId, content)
|
||||||
|
_, err := rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈串聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||||||
|
}
|
||||||
|
case "重置":
|
||||||
|
public.UserService.ClearUserMode(rmsg.SenderStaffId)
|
||||||
|
public.UserService.ClearUserSessionContext(rmsg.SenderStaffId)
|
||||||
|
_, err := rmsg.ReplyText(fmt.Sprintf("=====已重置与👉%s👈的对话模式,可以开始新的对话=====", rmsg.SenderNick), rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||||||
|
}
|
||||||
|
case "余额":
|
||||||
|
cacheMsg := public.UserService.GetUserMode("system_balance")
|
||||||
|
if cacheMsg == "" {
|
||||||
|
rst, err := public.GetBalance()
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning(fmt.Errorf("get balance error: %v", err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t1 := time.Unix(int64(rst.Grants.Data[0].EffectiveAt), 0)
|
||||||
|
t2 := time.Unix(int64(rst.Grants.Data[0].ExpiresAt), 0)
|
||||||
|
cacheMsg = fmt.Sprintf("💵 已用: 💲%v\n💵 剩余: 💲%v\n⏳ 有效时间: 从 %v 到 %v\n", fmt.Sprintf("%.2f", rst.TotalUsed), fmt.Sprintf("%.2f", rst.TotalAvailable), t1.Format("2006-01-02 15:04:05"), t2.Format("2006-01-02 15:04:05"))
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := rmsg.ReplyText(cacheMsg, rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if public.FirstCheck(rmsg) {
|
||||||
|
return Do("串聊", rmsg)
|
||||||
|
} else {
|
||||||
|
return Do("单聊", rmsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行处理请求
|
||||||
|
func Do(mode string, rmsg public.ReceiveMsg) error {
|
||||||
|
// 先把模式注入
|
||||||
|
public.UserService.SetUserMode(rmsg.SenderStaffId, mode)
|
||||||
|
switch mode {
|
||||||
|
case "单聊":
|
||||||
|
reply, err := chatgpt.SingleQa(rmsg.Text.Content, rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info(fmt.Errorf("gpt request error: %v", err))
|
||||||
|
if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") {
|
||||||
|
public.UserService.ClearUserSessionContext(rmsg.SenderStaffId)
|
||||||
|
_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if reply == "" {
|
||||||
|
logger.Warning(fmt.Errorf("get gpt result falied: %v", err))
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
reply = strings.TrimSpace(reply)
|
||||||
|
reply = strings.Trim(reply, "\n")
|
||||||
|
// 回复@我的用户
|
||||||
|
// fmt.Println("单聊结果是:", reply)
|
||||||
|
_, err = rmsg.ReplyText(reply, rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "串聊":
|
||||||
|
cli, reply, err := chatgpt.ContextQa(rmsg.Text.Content, rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info(fmt.Sprintf("gpt request error: %v", err))
|
||||||
|
if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") {
|
||||||
|
public.UserService.ClearUserSessionContext(rmsg.SenderStaffId)
|
||||||
|
_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if reply == "" {
|
||||||
|
logger.Warning(fmt.Errorf("get gpt result falied: %v", err))
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
reply = strings.TrimSpace(reply)
|
||||||
|
reply = strings.Trim(reply, "\n")
|
||||||
|
// 回复@我的用户
|
||||||
|
_, err = rmsg.ReplyText(reply, rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_ = cli.ChatContext.SaveConversation(rmsg.SenderStaffId)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessRequest 分析处理请求逻辑
|
||||||
|
func CheckRequest(rmsg public.ReceiveMsg) bool {
|
||||||
|
if public.Config.MaxRequest == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
count := public.UserService.GetUseRequestCount(rmsg.SenderStaffId)
|
||||||
|
// 判断访问次数是否超过限制
|
||||||
|
if count >= public.Config.MaxRequest {
|
||||||
|
logger.Info(fmt.Sprintf("亲爱的: %s,您今日请求次数已达上限,请明天再来,交互发问资源有限,请务必斟酌您的问题,给您带来不便,敬请谅解!", rmsg.SenderNick))
|
||||||
|
_, err := rmsg.ReplyText(fmt.Sprintf("一个好的问题,胜过十个好的答案!\n亲爱的: %s,您今日请求次数已达上限,请明天再来,交互发问资源有限,请务必斟酌您的问题,给您带来不便,敬请谅解!", rmsg.SenderNick), rmsg.SenderStaffId)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// 访问次数未超过限制,将计数加1
|
||||||
|
public.UserService.SetUseRequestCount(rmsg.SenderStaffId, count+1)
|
||||||
|
return true
|
||||||
|
}
|
@@ -4,15 +4,15 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/eryajf/chatgpt-dingtalk/config"
|
"github.com/eryajf/chatgpt-dingtalk/config"
|
||||||
"github.com/eryajf/chatgpt-dingtalk/service"
|
"github.com/eryajf/chatgpt-dingtalk/pkg/cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
var UserService service.UserServiceInterface
|
var UserService cache.UserServiceInterface
|
||||||
var Config *config.Configuration
|
var Config *config.Configuration
|
||||||
|
|
||||||
func InitSvc() {
|
func InitSvc() {
|
||||||
Config = config.LoadConfig()
|
Config = config.LoadConfig()
|
||||||
UserService = service.NewUserService()
|
UserService = cache.NewUserService()
|
||||||
_, _ = GetBalance()
|
_, _ = GetBalance()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,68 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/patrickmn/go-cache"
|
|
||||||
)
|
|
||||||
|
|
||||||
// UserServiceInterface 用户业务接口
|
|
||||||
type UserServiceInterface interface {
|
|
||||||
GetUserMode(userId string) string
|
|
||||||
SetUserMode(userId, mode string)
|
|
||||||
ClearUserMode(userId string)
|
|
||||||
GetUserSessionContext(userId string) string
|
|
||||||
SetUserSessionContext(userId, content string)
|
|
||||||
ClearUserSessionContext(userId string)
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ UserServiceInterface = (*UserService)(nil)
|
|
||||||
|
|
||||||
// UserService 用戶业务
|
|
||||||
type UserService struct {
|
|
||||||
// 缓存
|
|
||||||
cache *cache.Cache
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewUserService 创建新的业务层
|
|
||||||
func NewUserService() UserServiceInterface {
|
|
||||||
return &UserService{cache: cache.New(time.Hour*2, time.Hour*5)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserMode 获取当前对话模式
|
|
||||||
func (s *UserService) GetUserMode(userId string) string {
|
|
||||||
sessionContext, ok := s.cache.Get(userId + "_mode")
|
|
||||||
if !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return sessionContext.(string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUserMode 设置用户对话模式
|
|
||||||
func (s *UserService) SetUserMode(userId string, mode string) {
|
|
||||||
s.cache.Set(userId+"_mode", mode, cache.DefaultExpiration)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearUserMode 重置用户对话模式
|
|
||||||
func (s *UserService) ClearUserMode(userId string) {
|
|
||||||
s.cache.Delete(userId + "_mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUserSessionContext 设置用户会话上下文文本,question用户提问内容,GTP回复内容
|
|
||||||
func (s *UserService) SetUserSessionContext(userId string, content string) {
|
|
||||||
s.cache.Set(userId+"_content", content, cache.DefaultExpiration)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserSessionContext 获取用户会话上下文文本
|
|
||||||
func (s *UserService) GetUserSessionContext(userId string) string {
|
|
||||||
sessionContext, ok := s.cache.Get(userId + "_content")
|
|
||||||
if !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return sessionContext.(string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearUserSessionContext 清空GTP上下文,接收文本中包含 SessionClearToken
|
|
||||||
func (s *UserService) ClearUserSessionContext(userId string) {
|
|
||||||
s.cache.Delete(userId + "_content")
|
|
||||||
}
|
|
Reference in New Issue
Block a user