mirror of
https://github.com/eryajf/chatgpt-dingtalk.git
synced 2025-12-24 12:57:50 +08:00
* 将ai交互切换为go-openai * add stream * ✨ feat(stream): 优化流式响应机制,实现实时卡片更新 - 将固定1.5秒更新改为基于300ms最小间隔的实时更新策略 - 新增内容缓冲区机制,避免过于频繁的API调用 - 改进流式中断处理,保护已接收的内容不丢失 🔧 chore(llm): 优化HTTP客户端配置 - 增加连接池设置(MaxIdleConns: 100, MaxIdleConnsPerHost: 10) - 设置空闲连接超时时间为90秒 - 添加HTTP/2禁用选项注释,用于解决流式错误问题 📝 docs(stream): 更新流式更新策略文档 - 详细说明实时流式更新机制和缓冲策略 - 新增HTTP/2流式错误的故障排除指南 - 更新配置参数说明和建议范围 🐛 fix(stream): 修复流式中断时的内容丢失问题 - 在流式接收中断时,确保已接收的内容不会丢失 - 改进错误处理逻辑,区分有内容和无内容的情况 * modify ai
405 lines
12 KiB
Go
405 lines
12 KiB
Go
package process
|
||
|
||
import (
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
|
||
"github.com/eryajf/chatgpt-dingtalk/pkg/db"
|
||
"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
|
||
"github.com/eryajf/chatgpt-dingtalk/pkg/llm"
|
||
"github.com/eryajf/chatgpt-dingtalk/pkg/logger"
|
||
"github.com/eryajf/chatgpt-dingtalk/public"
|
||
)
|
||
|
||
// DoStream 使用流式输出执行处理请求
|
||
func DoStream(mode string, rmsg *dingbot.ReceiveMsg) error {
|
||
// 先把模式注入
|
||
public.UserService.SetUserMode(rmsg.GetSenderIdentifier(), mode)
|
||
|
||
switch mode {
|
||
case "单聊":
|
||
return doSingleChatStream(rmsg)
|
||
case "串聊":
|
||
return doContextChatStream(rmsg)
|
||
default:
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// doSingleChatStream 单聊流式处理
|
||
func doSingleChatStream(rmsg *dingbot.ReceiveMsg) error {
|
||
// 保存问题到数据库
|
||
qObj := db.Chat{
|
||
Username: rmsg.SenderNick,
|
||
Source: rmsg.GetChatTitle(),
|
||
ChatType: db.Q,
|
||
ParentContent: 0,
|
||
Content: rmsg.Text.Content,
|
||
}
|
||
qid, err := qObj.Add()
|
||
if err != nil {
|
||
logger.Error("往MySQL新增数据失败,错误信息:", err)
|
||
}
|
||
|
||
// 获取流式内容
|
||
contentCh, cleanup, err := llm.SingleQaStream(rmsg.Text.Content, rmsg.GetSenderIdentifier())
|
||
if err != nil {
|
||
logger.Info(fmt.Errorf("gpt request error: %v", err))
|
||
if strings.Contains(fmt.Sprintf("%v", err), "maximum question length exceeded") {
|
||
public.UserService.ClearUserSessionContext(rmsg.GetSenderIdentifier())
|
||
_, err = rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), fmt.Sprintf("[Wrong] 请求 OpenAI 失败了\n\n> 错误信息:%v\n\n> 已超过最大文本限制,请缩短提问文字的字数。", err))
|
||
if err != nil {
|
||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||
}
|
||
} else {
|
||
_, err = rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), fmt.Sprintf("[Wrong] 请求 OpenAI 失败了\n\n> 错误信息:%v", err))
|
||
if err != nil {
|
||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||
}
|
||
}
|
||
return err
|
||
}
|
||
defer cleanup()
|
||
|
||
// 使用简化版本:直接累积内容后一次性回复
|
||
fullContent := ""
|
||
for content := range contentCh {
|
||
fullContent += content
|
||
}
|
||
|
||
if fullContent == "" {
|
||
logger.Warning("get gpt result failed: empty response")
|
||
return nil
|
||
}
|
||
|
||
// 格式化和处理答案
|
||
fullContent = strings.TrimSpace(fullContent)
|
||
fullContent = strings.Trim(fullContent, "\n")
|
||
|
||
// 保存答案到数据库
|
||
aObj := db.Chat{
|
||
Username: rmsg.SenderNick,
|
||
Source: rmsg.GetChatTitle(),
|
||
ChatType: db.A,
|
||
ParentContent: qid,
|
||
Content: fullContent,
|
||
}
|
||
_, err = aObj.Add()
|
||
if err != nil {
|
||
logger.Error("往MySQL新增数据失败,错误信息:", err)
|
||
}
|
||
|
||
logger.Info(fmt.Sprintf("🤖 %s得到的答案: %#v", rmsg.SenderNick, fullContent))
|
||
|
||
// 敏感词过滤
|
||
if public.JudgeSensitiveWord(fullContent) {
|
||
fullContent = public.SolveSensitiveWord(fullContent)
|
||
}
|
||
|
||
// 回复用户
|
||
_, err = rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), FormatMarkdown(fullContent))
|
||
if err != nil {
|
||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// doContextChatStream 串聊流式处理
|
||
func doContextChatStream(rmsg *dingbot.ReceiveMsg) error {
|
||
// 保存问题到数据库
|
||
lastAid := public.UserService.GetAnswerID(rmsg.SenderNick, rmsg.GetChatTitle())
|
||
qObj := db.Chat{
|
||
Username: rmsg.SenderNick,
|
||
Source: rmsg.GetChatTitle(),
|
||
ChatType: db.Q,
|
||
ParentContent: lastAid,
|
||
Content: rmsg.Text.Content,
|
||
}
|
||
qid, err := qObj.Add()
|
||
if err != nil {
|
||
logger.Error("往MySQL新增数据失败,错误信息:", err)
|
||
}
|
||
|
||
// 获取流式内容
|
||
cli, contentCh, err := llm.ContextQaStream(rmsg.Text.Content, rmsg.GetSenderIdentifier())
|
||
if err != nil {
|
||
logger.Info(fmt.Sprintf("gpt request error: %v", err))
|
||
if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") {
|
||
public.UserService.ClearUserSessionContext(rmsg.GetSenderIdentifier())
|
||
_, err = rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), fmt.Sprintf("[Wrong] 请求 OpenAI 失败了\n\n> 错误信息:%v\n\n> 串聊已超过最大文本限制,对话已重置,请重新发起。", err))
|
||
if err != nil {
|
||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||
}
|
||
} else {
|
||
_, err = rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), fmt.Sprintf("[Wrong] 请求 OpenAI 失败了\n\n> 错误信息:%v", err))
|
||
if err != nil {
|
||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||
}
|
||
}
|
||
return err
|
||
}
|
||
defer cli.Close()
|
||
|
||
// 使用简化版本:直接累积内容后一次性回复
|
||
fullContent := ""
|
||
for content := range contentCh {
|
||
fullContent += content
|
||
}
|
||
|
||
if fullContent == "" {
|
||
logger.Warning("get gpt result failed: empty response")
|
||
return nil
|
||
}
|
||
|
||
// 格式化和处理答案
|
||
fullContent = strings.TrimSpace(fullContent)
|
||
fullContent = strings.Trim(fullContent, "\n")
|
||
|
||
// 保存答案到数据库
|
||
aObj := db.Chat{
|
||
Username: rmsg.SenderNick,
|
||
Source: rmsg.GetChatTitle(),
|
||
ChatType: db.A,
|
||
ParentContent: qid,
|
||
Content: fullContent,
|
||
}
|
||
aid, err := aObj.Add()
|
||
if err != nil {
|
||
logger.Error("往MySQL新增数据失败,错误信息:", err)
|
||
}
|
||
|
||
// 将当前回答的ID放入缓存
|
||
public.UserService.SetAnswerID(rmsg.SenderNick, rmsg.GetChatTitle(), aid)
|
||
|
||
logger.Info(fmt.Sprintf("🤖 %s得到的答案: %#v", rmsg.SenderNick, fullContent))
|
||
|
||
// 敏感词过滤
|
||
if public.JudgeSensitiveWord(fullContent) {
|
||
fullContent = public.SolveSensitiveWord(fullContent)
|
||
}
|
||
|
||
// 回复用户
|
||
_, err = rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), FormatMarkdown(fullContent))
|
||
if err != nil {
|
||
logger.Warning(fmt.Errorf("send message error: %v", err))
|
||
return err
|
||
}
|
||
|
||
// 保存对话上下文
|
||
_ = cli.ChatContext.SaveConversation(rmsg.GetSenderIdentifier())
|
||
|
||
return nil
|
||
}
|
||
|
||
// DoStreamWithCard 使用流式卡片输出执行处理请求 (需要配置卡片模板)
|
||
func DoStreamWithCard(mode string, rmsg *dingbot.ReceiveMsg, cardTemplateID string) error {
|
||
// 先把模式注入
|
||
public.UserService.SetUserMode(rmsg.GetSenderIdentifier(), mode)
|
||
|
||
// 检查是否有 RobotCode,如果没有则降级为简化流式模式
|
||
clientId := rmsg.RobotCode
|
||
if clientId == "" {
|
||
logger.Warning("RobotCode is empty, fallback to simple stream mode")
|
||
return DoStream(mode, rmsg)
|
||
}
|
||
|
||
// 获取钉钉客户端
|
||
dingClient := public.DingTalkClientManager.GetClientByOAuthClientID(clientId)
|
||
if dingClient == nil {
|
||
logger.Warning(fmt.Errorf("dingtalk client not found for robot code: %s, fallback to simple stream mode", clientId))
|
||
return DoStream(mode, rmsg)
|
||
}
|
||
|
||
client, ok := dingClient.(*dingbot.DingTalkClient)
|
||
if !ok {
|
||
logger.Warning("invalid dingtalk client type, fallback to simple stream mode")
|
||
return DoStream(mode, rmsg)
|
||
}
|
||
|
||
// 生成唯一追踪ID
|
||
trackID := uuid.New().String()
|
||
|
||
// 创建并投放卡片
|
||
accessToken, err := client.GetAccessToken()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to get access token: %w", err)
|
||
}
|
||
|
||
cardClient, err := dingbot.NewStreamCardClient()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create stream card client: %w", err)
|
||
}
|
||
|
||
// 构建OpenSpaceID
|
||
var openSpaceID string
|
||
if rmsg.ConversationType == "2" { // 群聊
|
||
openSpaceID = fmt.Sprintf("dtv1.card//IM_GROUP.%s", rmsg.ConversationID)
|
||
logger.Info(fmt.Sprintf("🎴 群聊模式 - OpenSpaceID: %s, RobotCode: %s", openSpaceID, rmsg.RobotCode))
|
||
} else { // 单聊
|
||
openSpaceID = fmt.Sprintf("dtv1.card//IM_ROBOT.%s", rmsg.SenderStaffId)
|
||
logger.Info(fmt.Sprintf("🎴 私聊模式 - OpenSpaceID: %s, ConversationType: %s", openSpaceID, rmsg.ConversationType))
|
||
}
|
||
|
||
createReq := &dingbot.CreateAndDeliverCardRequest{
|
||
CardTemplateID: cardTemplateID,
|
||
OutTrackID: trackID,
|
||
ConversationID: rmsg.ConversationID,
|
||
SenderStaffID: rmsg.SenderStaffId,
|
||
RobotCode: rmsg.RobotCode,
|
||
OpenSpaceID: openSpaceID,
|
||
ConversationType: rmsg.ConversationType,
|
||
CardData: map[string]string{
|
||
"content": "",
|
||
},
|
||
}
|
||
|
||
if err := cardClient.CreateAndDeliverCard(accessToken, createReq); err != nil {
|
||
logger.Warning(fmt.Errorf("failed to create card: %v", err))
|
||
// 卡片创建失败,降级为普通消息
|
||
return DoStream(mode, rmsg)
|
||
}
|
||
|
||
// 发送初始状态
|
||
initialContent := fmt.Sprintf("**%s**\n\n%s", rmsg.Text.Content, "稍等,让我想一想……")
|
||
if err := client.UpdateAIStreamCard(trackID, initialContent, false); err != nil {
|
||
logger.Warning(fmt.Errorf("failed to update initial card: %v", err))
|
||
}
|
||
|
||
// 获取流式内容
|
||
var contentCh <-chan string
|
||
var cli *llm.Client
|
||
if mode == "单聊" {
|
||
var cleanup func()
|
||
contentCh, cleanup, err = llm.SingleQaStream(rmsg.Text.Content, rmsg.GetSenderIdentifier())
|
||
defer cleanup()
|
||
} else {
|
||
cli, contentCh, err = llm.ContextQaStream(rmsg.Text.Content, rmsg.GetSenderIdentifier())
|
||
defer cli.Close()
|
||
}
|
||
|
||
if err != nil {
|
||
errorMsg := fmt.Sprintf("**%s**\n\n出错了: %v", rmsg.Text.Content, err)
|
||
if err := client.UpdateAIStreamCard(trackID, errorMsg, true); err != nil {
|
||
logger.Warning(fmt.Errorf("failed to update error card: %v", err))
|
||
}
|
||
return err
|
||
}
|
||
|
||
// 实时流式更新卡片内容
|
||
questionHeader := fmt.Sprintf("**%s**\n\n", rmsg.Text.Content)
|
||
fullContent := questionHeader
|
||
|
||
// 使用缓冲机制避免更新过于频繁
|
||
updateBuffer := ""
|
||
lastUpdateTime := time.Now()
|
||
minUpdateInterval := 300 * time.Millisecond // 最小更新间隔300ms
|
||
|
||
for {
|
||
content, ok := <-contentCh
|
||
if !ok {
|
||
// 流结束,发送最后的更新(如果有未发送的缓冲内容)
|
||
if updateBuffer != "" {
|
||
fullContent += updateBuffer
|
||
if err := client.UpdateAIStreamCard(trackID, fullContent, true); err != nil {
|
||
logger.Error(fmt.Errorf("failed to finalize card: %v", err))
|
||
}
|
||
} else {
|
||
// 标记为完成
|
||
if err := client.UpdateAIStreamCard(trackID, fullContent, true); err != nil {
|
||
logger.Error(fmt.Errorf("failed to finalize card: %v", err))
|
||
}
|
||
}
|
||
|
||
// 保存到数据库并处理后续逻辑
|
||
saveStreamResult(mode, rmsg, fullContent[len(questionHeader):], cli)
|
||
return nil
|
||
}
|
||
|
||
// 累积接收到的内容到缓冲区
|
||
updateBuffer += content
|
||
|
||
// 检查是否应该更新(距离上次更新超过最小间隔)
|
||
if time.Since(lastUpdateTime) >= minUpdateInterval {
|
||
fullContent += updateBuffer
|
||
updateBuffer = ""
|
||
|
||
// 立即更新卡片
|
||
if err := client.UpdateAIStreamCard(trackID, fullContent, false); err != nil {
|
||
logger.Warning(fmt.Errorf("failed to update card: %v", err))
|
||
}
|
||
|
||
lastUpdateTime = time.Now()
|
||
}
|
||
}
|
||
}
|
||
|
||
// saveStreamResult 保存流式结果到数据库
|
||
func saveStreamResult(mode string, rmsg *dingbot.ReceiveMsg, answer string, cli *llm.Client) {
|
||
answer = strings.TrimSpace(answer)
|
||
answer = strings.Trim(answer, "\n")
|
||
|
||
if mode == "单聊" {
|
||
qObj := db.Chat{
|
||
Username: rmsg.SenderNick,
|
||
Source: rmsg.GetChatTitle(),
|
||
ChatType: db.Q,
|
||
ParentContent: 0,
|
||
Content: rmsg.Text.Content,
|
||
}
|
||
qid, err := qObj.Add()
|
||
if err != nil {
|
||
logger.Error("往MySQL新增数据失败,错误信息:", err)
|
||
}
|
||
|
||
aObj := db.Chat{
|
||
Username: rmsg.SenderNick,
|
||
Source: rmsg.GetChatTitle(),
|
||
ChatType: db.A,
|
||
ParentContent: qid,
|
||
Content: answer,
|
||
}
|
||
_, err = aObj.Add()
|
||
if err != nil {
|
||
logger.Error("往MySQL新增数据失败,错误信息:", err)
|
||
}
|
||
} else { // 串聊
|
||
lastAid := public.UserService.GetAnswerID(rmsg.SenderNick, rmsg.GetChatTitle())
|
||
qObj := db.Chat{
|
||
Username: rmsg.SenderNick,
|
||
Source: rmsg.GetChatTitle(),
|
||
ChatType: db.Q,
|
||
ParentContent: lastAid,
|
||
Content: rmsg.Text.Content,
|
||
}
|
||
qid, err := qObj.Add()
|
||
if err != nil {
|
||
logger.Error("往MySQL新增数据失败,错误信息:", err)
|
||
}
|
||
|
||
aObj := db.Chat{
|
||
Username: rmsg.SenderNick,
|
||
Source: rmsg.GetChatTitle(),
|
||
ChatType: db.A,
|
||
ParentContent: qid,
|
||
Content: answer,
|
||
}
|
||
aid, err := aObj.Add()
|
||
if err != nil {
|
||
logger.Error("往MySQL新增数据失败,错误信息:", err)
|
||
}
|
||
|
||
public.UserService.SetAnswerID(rmsg.SenderNick, rmsg.GetChatTitle(), aid)
|
||
|
||
if cli != nil {
|
||
_ = cli.ChatContext.SaveConversation(rmsg.GetSenderIdentifier())
|
||
}
|
||
}
|
||
|
||
logger.Info(fmt.Sprintf("🤖 %s得到的答案: %#v", rmsg.SenderNick, answer))
|
||
}
|