diff --git a/go.mod b/go.mod index dd28297..20782d7 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/eryajf/chatgpt-dingtalk go 1.17 require ( - github.com/eatmoreapple/openwechat v1.2.3 + github.com/go-resty/resty/v2 v2.7.0 github.com/patrickmn/go-cache v2.1.0+incompatible ) + +require golang.org/x/net v0.0.0-20211029224645-99673261e6eb // indirect diff --git a/go.sum b/go.sum index 7543d91..41cead4 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,11 @@ -github.com/eatmoreapple/openwechat v1.2.3 h1:8AO+nvXwHVTM/7Gk7y6IZ2/hjnILTLQztWmJnPhPB+k= -github.com/eatmoreapple/openwechat v1.2.3/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8= +github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY= +github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +golang.org/x/net v0.0.0-20211029224645-99673261e6eb h1:pirldcYWx7rx7kE5r+9WsOXPXK0+WH5+uZ7uPmJ44uM= +golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/gpt/gpt.go b/gpt/gpt.go index 5debe6b..6eb3ce6 100644 --- a/gpt/gpt.go +++ b/gpt/gpt.go @@ -1,19 +1,26 @@ package gpt import ( - "bytes" "encoding/json" "fmt" - "io/ioutil" - "net/http" + "time" "github.com/eryajf/chatgpt-dingtalk/config" "github.com/eryajf/chatgpt-dingtalk/public/logger" + "github.com/go-resty/resty/v2" ) const BASEURL = "https://api.openai.com/v1/" -// ChatGPTResponseBody 请求体 +// ChatGPTRequestBody 请求体 +type ChatGPTRequestBody struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + MaxTokens uint `json:"max_tokens"` + Temperature float64 `json:"temperature"` +} + +// ChatGPTResponseBody 响应体 type ChatGPTResponseBody struct { ID string `json:"id"` Object string `json:"object"` @@ -30,14 +37,6 @@ type ChoiceItem struct { FinishReason string `json:"finish_reason"` } -// ChatGPTRequestBody 响应体 -type ChatGPTRequestBody struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - MaxTokens uint `json:"max_tokens"` - Temperature float64 `json:"temperature"` -} - // Completions gtp文本模型回复 //curl https://api.openai.com/v1/completions //-H "Content-Type: application/json" @@ -51,41 +50,29 @@ func Completions(msg string) (string, error) { MaxTokens: cfg.MaxTokens, Temperature: cfg.Temperature, } - requestData, err := json.Marshal(requestBody) - if err != nil { - return "", err - } - logger.Info(fmt.Sprintf("request gtp json string : %v", string(requestData))) - req, err := http.NewRequest("POST", BASEURL+"completions", bytes.NewBuffer(requestData)) - if err != nil { - return "", err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+cfg.ApiKey) - client := &http.Client{Timeout: cfg.SessionTimeout} - response, err := client.Do(req) - if err != nil { - return "", err - } - defer response.Body.Close() + client := resty.New(). + SetRetryCount(2). + SetRetryWaitTime(1*time.Second). + SetTimeout(cfg.SessionTimeout). + SetHeader("Content-Type", "application/json"). + SetHeader("Authorization", "Bearer "+cfg.ApiKey) - body, err := ioutil.ReadAll(response.Body) + rsp, err := client.R().SetBody(requestBody).Post(BASEURL + "completions") if err != nil { - return "", err + return "", fmt.Errorf("request openai failed, err : %v", err) } - - if response.StatusCode != 200 { - return "", fmt.Errorf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details: %v ", response.StatusCode, string(body)) + if rsp.StatusCode() != 200 { + return "", fmt.Errorf("gtp api status code not equals 200, code is %d ,details: %v ", rsp.StatusCode(), string(rsp.Body())) + } else { + logger.Info(fmt.Sprintf("response gtp json string : %v", string(rsp.Body()))) } - logger.Info(fmt.Sprintf("response gtp json string : %v", string(body))) gptResponseBody := &ChatGPTResponseBody{} - err = json.Unmarshal(body, gptResponseBody) + err = json.Unmarshal(rsp.Body(), gptResponseBody) if err != nil { return "", err } - var reply string if len(gptResponseBody.Choices) > 0 { reply = gptResponseBody.Choices[0].Text diff --git a/main.go b/main.go index 042ca19..c4955a3 100644 --- a/main.go +++ b/main.go @@ -110,7 +110,6 @@ func getRequestText(rmsg public.ReceiveMsg) string { // 1.去除空格以及换行 requestText := strings.TrimSpace(rmsg.Text.Content) requestText = strings.Trim(rmsg.Text.Content, "\n") - // 2.替换掉当前用户名称 replaceText := "@" + rmsg.SenderNick requestText = strings.TrimSpace(strings.ReplaceAll(rmsg.Text.Content, replaceText, ""))