add: rs-capi

This commit is contained in:
zeke-chin
2024-11-26 11:39:29 +08:00
parent 34ef211bdf
commit d97b157dc2
26 changed files with 4320 additions and 611 deletions

5
.gitignore vendored

@@ -1,4 +1,7 @@
.env
tests/
node_modules/
node_modules/
.DS_Store
.idea/
__pycache__/

49
api.go

@@ -1,49 +0,0 @@
package main
import (
"bytes"
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
func formatMessages(messages []Message) string {
var formatted []string
for _, msg := range messages {
formatted = append(formatted, fmt.Sprintf("%s:%s", msg.Role, msg.Content))
}
return strings.Join(formatted, "\n")
}
func sendToCursorAPI(c *gin.Context, hexData []byte) (*http.Response, error) {
req, err := http.NewRequest("POST", "https://api2.cursor.sh/aiserver.v1.AiService/StreamChat", bytes.NewReader(hexData))
if err != nil {
return nil, err
}
// Get the auth token
authToken := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
if strings.Contains(authToken, "%3A%3A") {
authToken = strings.Split(authToken, "%3A%3A")[1]
}
// Set request headers
req.Header.Set("Content-Type", "application/connect+proto")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken))
req.Header.Set("Connect-Accept-Encoding", "gzip,br")
req.Header.Set("Connect-Protocol-Version", "1")
req.Header.Set("User-Agent", "connect-es/1.4.0")
req.Header.Set("X-Amzn-Trace-Id", fmt.Sprintf("Root=%s", uuid.New().String()))
req.Header.Set("X-Cursor-Checksum", "zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef")
req.Header.Set("X-Cursor-Client-Version", "0.42.3")
req.Header.Set("X-Cursor-Timezone", "Asia/Shanghai")
req.Header.Set("X-Ghost-Mode", "false")
req.Header.Set("X-Request-Id", uuid.New().String())
req.Header.Set("Host", "api2.cursor.sh")
client := &http.Client{}
return client.Do(req)
}

View File

@@ -1,25 +0,0 @@
package main
import (
"log"
"os"
"github.com/joho/godotenv"
"cursor-api-proxy/internal/api"
)
func main() {
if err := godotenv.Load(); err != nil {
log.Println("Warning: Error loading .env file")
}
server := api.NewServer()
port := os.Getenv("PORT")
if port == "" {
port = "3000"
}
log.Printf("服务器运行在端口 %s\n", port)
server.Run(":" + port)
}

41
go.mod

@@ -1,5 +1,40 @@
module cursor-api-proxy
module go-capi
go 1.21
go 1.22.0
// ... other dependencies ...
require (
github.com/gin-contrib/cors v1.7.2
github.com/gin-gonic/gin v1.10.0
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
)
require (
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/text v0.15.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

95
go.sum

@@ -1,29 +1,33 @@
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -34,57 +38,66 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

265
handlers/chat.go Normal file

@@ -0,0 +1,265 @@
package handlers
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"time"
"bufio"
"encoding/json"
"unicode"
"regexp"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go-capi/models"
"go-capi/utils"
)
func ChatCompletions(c *gin.Context) {
var chatRequest models.ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// o1 models do not support streaming output
if strings.HasPrefix(chatRequest.Model, "o1-") && chatRequest.Stream {
c.JSON(http.StatusBadRequest, gin.H{"error": "Model not supported stream"})
return
}
// Get and process the authorization token
authHeader := c.GetHeader("Authorization")
if !strings.HasPrefix(authHeader, "Bearer ") {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid authorization header"})
return
}
authToken := strings.TrimPrefix(authHeader, "Bearer ")
if authToken == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing authorization token"})
return
}
// Handle multiple keys
keys := strings.Split(authToken, ",")
if len(keys) > 0 {
authToken = strings.TrimSpace(keys[0])
}
if strings.Contains(authToken, "%3A%3A") {
parts := strings.Split(authToken, "%3A%3A")
authToken = parts[1]
}
// Format the messages
var messages []string
for _, msg := range chatRequest.Messages {
messages = append(messages, fmt.Sprintf("%s:%s", msg.Role, msg.Content))
}
formattedMessages := strings.Join(messages, "\n")
// Build the request payload
hexData, err := utils.StringToHex(formattedMessages, chatRequest.Model)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Prepare the request
client := &http.Client{Timeout: 300 * time.Second}
req, err := http.NewRequest("POST", "https://api2.cursor.sh/aiserver.v1.AiService/StreamChat", bytes.NewReader(hexData))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Set request headers
req.Header.Set("Content-Type", "application/connect+proto")
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("Connect-Accept-Encoding", "gzip,br")
req.Header.Set("Connect-Protocol-Version", "1")
req.Header.Set("User-Agent", "connect-es/1.4.0")
req.Header.Set("X-Amzn-Trace-Id", "Root="+uuid.New().String())
req.Header.Set("X-Cursor-Checksum", "zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef")
req.Header.Set("X-Cursor-Client-Version", "0.42.3")
req.Header.Set("X-Cursor-Timezone", "Asia/Shanghai")
req.Header.Set("X-Ghost-Mode", "false")
req.Header.Set("X-Request-Id", uuid.New().String())
req.Header.Set("Host", "api2.cursor.sh")
// ... set the remaining request headers
// Print the request headers and body (debug)
fmt.Printf("\nRequest Headers: %v\n", req.Header)
fmt.Printf("\nRequest Body: %x\n", hexData)
resp, err := client.Do(req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer resp.Body.Close()
if chatRequest.Stream {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
chunks := make([][]byte, 0)
reader := bufio.NewReader(resp.Body)
for {
chunk, err := reader.ReadBytes('\n')
if err == io.EOF {
break
}
if err != nil {
c.SSEvent("error", gin.H{"error": err.Error()})
return
}
chunks = append(chunks, chunk)
}
responseID := "chatcmpl-" + uuid.New().String()
c.Stream(func(w io.Writer) bool {
for _, chunk := range chunks {
text := chunkToUTF8String(chunk)
if text == "" {
continue
}
// Clean up the text
text = strings.TrimSpace(text)
if strings.Contains(text, "<|END_USER|>") {
parts := strings.Split(text, "<|END_USER|>")
text = strings.TrimSpace(parts[len(parts)-1])
}
if len(text) > 0 && unicode.IsLetter(rune(text[0])) {
text = strings.TrimSpace(text[1:])
}
text = cleanControlChars(text)
if text != "" {
dataBody := map[string]interface{}{
"id": responseID,
"object": "chat.completion.chunk",
"created": time.Now().Unix(),
"choices": []map[string]interface{}{
{
"index": 0,
"delta": map[string]string{
"content": text,
},
},
},
}
jsonData, _ := json.Marshal(dataBody)
c.SSEvent("", string(jsonData))
w.(http.Flusher).Flush()
}
}
c.SSEvent("", "[DONE]")
return false
})
} else {
// Non-streaming response handling
reader := bufio.NewReader(resp.Body)
var allText string
for {
chunk, err := reader.ReadBytes('\n')
if err == io.EOF {
break
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
text := utils.ChunkToUTF8String(chunk)
if text != "" {
allText += text
}
}
// Clean up the response text
allText = cleanResponseText(allText)
response := models.ChatResponse{
ID: "chatcmpl-" + uuid.New().String(),
Object: "chat.completion",
Created: time.Now().Unix(),
Model: chatRequest.Model,
Choices: []models.Choice{
{
Index: 0,
Message: &models.Message{
Role: "assistant",
Content: allText,
},
FinishReason: "stop",
},
},
Usage: &models.Usage{
PromptTokens: 0,
CompletionTokens: 0,
TotalTokens: 0,
},
}
c.JSON(http.StatusOK, response)
}
}
// Helper functions
func chunkToUTF8String(chunk []byte) string {
// Convert a binary chunk into a UTF-8 string
return string(chunk)
}
func cleanControlChars(text string) string {
return regexp.MustCompile(`[\x00-\x1F\x7F]`).ReplaceAllString(text, "")
}
func cleanResponseText(text string) string {
// Remove everything before <|END_USER|>
re := regexp.MustCompile(`(?s)^.*<\|END_USER\|>`)
text = re.ReplaceAllString(text, "")
// Remove a leading newline and a single letter
text = regexp.MustCompile(`^\n[a-zA-Z]?`).ReplaceAllString(text, "")
text = strings.TrimSpace(text)
// Strip control characters
text = cleanControlChars(text)
return text
}
func GetModels(c *gin.Context) {
response := models.ModelsResponse{
Object: "list",
Data: []models.ModelData{
{ID: "claude-3-5-sonnet-20241022", Object: "model", Created: 1713744000, OwnedBy: "anthropic"},
{ID: "claude-3-opus", Object: "model", Created: 1709251200, OwnedBy: "anthropic"},
{ID: "claude-3.5-haiku", Object: "model", Created: 1711929600, OwnedBy: "anthropic"},
{ID: "claude-3.5-sonnet", Object: "model", Created: 1711929600, OwnedBy: "anthropic"},
{ID: "cursor-small", Object: "model", Created: 1712534400, OwnedBy: "cursor"},
{ID: "gpt-3.5-turbo", Object: "model", Created: 1677649200, OwnedBy: "openai"},
{ID: "gpt-4", Object: "model", Created: 1687392000, OwnedBy: "openai"},
{ID: "gpt-4-turbo-2024-04-09", Object: "model", Created: 1712620800, OwnedBy: "openai"},
{ID: "gpt-4o", Object: "model", Created: 1712620800, OwnedBy: "openai"},
{ID: "gpt-4o-mini", Object: "model", Created: 1712620800, OwnedBy: "openai"},
{ID: "o1-mini", Object: "model", Created: 1712620800, OwnedBy: "openai"},
{ID: "o1-preview", Object: "model", Created: 1712620800, OwnedBy: "openai"},
},
}
c.JSON(http.StatusOK, response)
}

View File

@@ -1,33 +0,0 @@
package api
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"cursor-api-proxy/internal/models"
"cursor-api-proxy/internal/service"
)
func handleChat(c *gin.Context) {
var req models.ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Check for o1-prefixed models with streaming requested
if strings.HasPrefix(req.Model, "o1-") && req.Stream {
c.JSON(http.StatusBadRequest, gin.H{"error": "Model not supported stream"})
return
}
// ... other handling logic ...
if req.Stream {
service.HandleStreamResponse(c, req)
return
}
service.HandleNormalResponse(c, req)
}

View File

@@ -1,15 +0,0 @@
package api
import (
"github.com/gin-gonic/gin"
)
func NewServer() *gin.Engine {
r := gin.Default()
setupRoutes(r)
return r
}
func setupRoutes(r *gin.Engine) {
r.POST("/v1/chat/completions", handleChat)
}

View File

@@ -1,14 +0,0 @@
package models
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
}
// ... other type definitions ...

View File

@@ -1,15 +0,0 @@
package service
import (
"github.com/gin-gonic/gin"
"cursor-api-proxy/internal/models"
"cursor-api-proxy/internal/utils"
)
func HandleStreamResponse(c *gin.Context, req models.ChatRequest) {
// Streaming response handling moved here from handlers.go
}
func HandleNormalResponse(c *gin.Context, req models.ChatRequest) {
// Normal response handling moved here from handlers.go
}

View File

@@ -1,5 +0,0 @@
package utils
func StringToHex(str, modelName string) []byte {
// Conversion logic moved here from utils.go
}

View File

@@ -1,5 +0,0 @@
package utils
func ProcessChunk(chunk []byte) string {
// Processing logic moved here from process.go
}

106
main.go

@@ -1,97 +1,35 @@
package main
import (
"log"
"net/http"
"os"
"strings"
"go-capi/handlers"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/joho/godotenv"
)
// Add the following struct definitions before the main() function
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
func main() {
if err := godotenv.Load(); err != nil {
log.Println("Warning: Error loading .env file")
}
// Load environment variables
godotenv.Load()
r := gin.Default()
r.POST("/v1/chat/completions", handleChat)
port := os.Getenv("PORT")
if port == "" {
port = "3000"
}
log.Printf("服务器运行在端口 %s\n", port)
// 配置CORS
r.Use(cors.New(cors.Config{
AllowOrigins: []string{"*"},
AllowMethods: []string{"GET", "POST",},
AllowHeaders: []string{"*"},
AllowCredentials: true,
}))
// Register routes
r.POST("/v1/chat/completions", handlers.ChatCompletions)
r.GET("/models", handlers.GetModels)
// Get the port number
// port := os.Getenv("PORT")
port := "3001"
r.Run(":" + port)
}
func handleChat(c *gin.Context) {
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Check for o1-prefixed models with streaming requested
if strings.HasPrefix(req.Model, "o1-") && req.Stream {
c.JSON(http.StatusBadRequest, gin.H{"error": "Model not supported stream"})
return
}
// Get and process the auth token
authHeader := c.GetHeader("Authorization")
authToken := strings.TrimPrefix(authHeader, "Bearer ")
if authToken == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Authorization is required",
})
return
}
// Validate the messages
if len(req.Messages) == 0 {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Messages should be a non-empty array",
})
return
}
// Handle streaming requests
if req.Stream {
handleStreamResponse(c, req)
return
}
// Handle non-streaming requests
handleNormalResponse(c, req)
}
// Add these two new functions at the end of the file
func handleStreamResponse(c *gin.Context, req ChatRequest) {
// TODO: implement the streaming response logic
c.JSON(http.StatusNotImplemented, gin.H{
"error": "Stream response not implemented yet",
})
}
func handleNormalResponse(c *gin.Context, req ChatRequest) {
// TODO: implement the normal response logic
c.JSON(http.StatusNotImplemented, gin.H{
"error": "Normal response not implemented yet",
})
}
}

411
main.py Normal file

@@ -0,0 +1,411 @@
import json
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import uuid
import httpx
import os
from dotenv import load_dotenv
import time
import re
# Load environment variables
load_dotenv()
app = FastAPI()
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Define request models
class Message(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
model: str
messages: List[Message]
stream: bool = False
def string_to_hex(text: str, model_name: str) -> bytes:
"""将文本转换为特定格式的十六进制数据"""
# 将输入文本转换为UTF-8字节
text_bytes = text.encode('utf-8')
text_length = len(text_bytes)
# 固定常量
FIXED_HEADER = 2
SEPARATOR = 1
FIXED_SUFFIX_LENGTH = 0xA3 + len(model_name)
# 计算第一个长度字段
if text_length < 128:
text_length_field1 = format(text_length, '02x')
text_length_field_size1 = 1
else:
low_byte1 = (text_length & 0x7F) | 0x80
high_byte1 = (text_length >> 7) & 0xFF
text_length_field1 = format(low_byte1, '02x') + format(high_byte1, '02x')
text_length_field_size1 = 2
# Compute the base length field
base_length = text_length + 0x2A
if base_length < 128:
text_length_field = format(base_length, '02x')
text_length_field_size = 1
else:
low_byte = (base_length & 0x7F) | 0x80
high_byte = (base_length >> 7) & 0xFF
text_length_field = format(low_byte, '02x') + format(high_byte, '02x')
text_length_field_size = 2
# Compute the total message length
message_total_length = (FIXED_HEADER + text_length_field_size + SEPARATOR +
text_length_field_size1 + text_length + FIXED_SUFFIX_LENGTH)
# Build the hex string
model_name_bytes = model_name.encode('utf-8')
model_name_length_hex = format(len(model_name_bytes), '02X')
model_name_hex = model_name_bytes.hex().upper()
hex_string = (
f"{message_total_length:010x}"
"12"
f"{text_length_field}"
"0A"
f"{text_length_field1}"
f"{text_bytes.hex()}"
"10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612"
"2002A132F643A2F6964656150726F2F656475626F73733A1E0A"
f"{model_name_length_hex}"
f"{model_name_hex}"
"22004A"
"2461383761396133342D323164642D343863372D623434662D616636633365636536663765"
"680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061"
"800101B00100C00100E00100E80100"
).upper()
return bytes.fromhex(hex_string)
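# Note: the two length fields above use a little-endian base-128 encoding
# (protobuf-varint style). The helper below is a hypothetical sketch of that
# same logic, added only for illustration and assuming lengths stay below
# 16384, which is the range the two-byte branch above can represent.
def _encode_length(n: int) -> str:
    """Hypothetical helper mirroring the length-field encoding in string_to_hex."""
    if n < 128:
        return format(n, '02x')
    return format((n & 0x7F) | 0x80, '02x') + format((n >> 7) & 0xFF, '02x')
# Example: _encode_length(300) == 'ac02' (low 7 bits 0x2C | 0x80 = 0xAC, then 0x02)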
def chunk_to_utf8_string(chunk: bytes) -> str:
"""将二进制chunk转换为UTF-8字符串"""
if not chunk or len(chunk) < 2:
return ''
if chunk[0] in [0x01, 0x02] or (chunk[0] == 0x60 and chunk[1] == 0x0C):
return ''
# Log the raw chunk length for debugging
print(f"chunk length: {len(chunk)}")
# print(f"chunk hex: {chunk.hex()}")
try:
# Drop all bytes up to and including the first 0x0A
try:
chunk = chunk[chunk.index(0x0A) + 1:]
except ValueError:
pass
filtered_chunk = bytearray()
i = 0
while i < len(chunk):
# Check for four consecutive 0x00 bytes
if i + 4 <= len(chunk) and all(chunk[j] == 0x00 for j in range(i, i + 4)):
i += 4
while i < len(chunk) and chunk[i] <= 0x0F:
i += 1
continue
if chunk[i] == 0x0C:
i += 1
while i < len(chunk) and chunk[i] == 0x0A:
i += 1
else:
filtered_chunk.append(chunk[i])
i += 1
# Filter out specific bytes
filtered_chunk = bytes(b for b in filtered_chunk
if b != 0x00 and b != 0x0C)
if not filtered_chunk:
return ''
result = filtered_chunk.decode('utf-8', errors='ignore').strip()
# print(f"decoded result: {result}") # 调试输出
return result
except Exception as e:
print(f"Error in chunk_to_utf8_string: {str(e)}")
return ''
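# Illustrative (hypothetical) inputs, not real protocol frames:
# chunk_to_utf8_string(b"\x00\x00\x00\x02\x0aHello") drops everything through the
# first 0x0A and returns "Hello"; a chunk starting with 0x01 or 0x02 returns ''.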
async def process_stream(chunks, ):
"""处理流式响应"""
response_id = f"chatcmpl-{str(uuid.uuid4())}"
# 先将所有chunks读取到列表中
# chunks = []
# async for chunk in response.aiter_raw():
# chunks.append(chunk)
# 然后处理保存的chunks
for chunk in chunks:
text = chunk_to_utf8_string(chunk)
if text:
# Clean up the text
text = text.strip()
if "<|END_USER|>" in text:
text = text.split("<|END_USER|>")[-1].strip()
if text and text[0].isalpha():
text = text[1:].strip()
text = re.sub(r"[\x00-\x1F\x7F]", "", text)
if text:  # make sure the cleaned text is not empty
data_body = {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"choices": [{
"index": 0,
"delta": {
"content": text
}
}]
}
yield f"data: {json.dumps(data_body, ensure_ascii=False)}\n\n"
# yield "data: {\n"
# yield f' "id": "{response_id}",\n'
# yield ' "object": "chat.completion.chunk",\n'
# yield f' "created": {int(time.time())},\n'
# yield ' "choices": [{\n'
# yield ' "index": 0,\n'
# yield ' "delta": {\n'
# yield f' "content": "{text}"\n'
# yield " }\n"
# yield " }]\n"
# yield "}\n\n"
yield "data: [DONE]\n\n"
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, chat_request: ChatRequest):
# o1 models do not support streaming output
if chat_request.model.startswith('o1-') and chat_request.stream:
raise HTTPException(
status_code=400,
detail="Model not supported stream"
)
# Get and process the authorization token
auth_header = request.headers.get('authorization', '')
if not auth_header.startswith('Bearer '):
raise HTTPException(
status_code=401,
detail="Invalid authorization header"
)
auth_token = auth_header.replace('Bearer ', '')
if not auth_token:
raise HTTPException(
status_code=401,
detail="Missing authorization token"
)
# Handle multiple keys
keys = [key.strip() for key in auth_token.split(',')]
if keys:
auth_token = keys[0]  # use the first key
if '%3A%3A' in auth_token:
auth_token = auth_token.split('%3A%3A')[1]
# Format the messages
formatted_messages = "\n".join(
f"{msg.role}:{msg.content}" for msg in chat_request.messages
)
# Build the request payload
hex_data = string_to_hex(formatted_messages, chat_request.model)
# Prepare request headers
headers = {
'Content-Type': 'application/connect+proto',
'Authorization': f'Bearer {auth_token}',
'Connect-Accept-Encoding': 'gzip,br',
'Connect-Protocol-Version': '1',
'User-Agent': 'connect-es/1.4.0',
'X-Amzn-Trace-Id': f'Root={str(uuid.uuid4())}',
'X-Cursor-Checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
'X-Cursor-Client-Version': '0.42.3',
'X-Cursor-Timezone': 'Asia/Shanghai',
'X-Ghost-Mode': 'false',
'X-Request-Id': str(uuid.uuid4()),
'Host': 'api2.cursor.sh'
}
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
try:
# Use the stream=True parameter
# Print the headers and binary data (debug)
print(f"headers: {headers}")
print(hex_data)
async with client.stream(
'POST',
'https://api2.cursor.sh/aiserver.v1.AiService/StreamChat',
headers=headers,
content=hex_data,
timeout=None
) as response:
if chat_request.stream:
chunks = []
async for chunk in response.aiter_raw():
chunks.append(chunk)
return StreamingResponse(
process_stream(chunks),
media_type="text/event-stream"
)
else:
# Non-streaming response handling
text = ''
async for chunk in response.aiter_raw():
# print('chunk:', chunk.hex())
print('chunk length:', len(chunk))
res = chunk_to_utf8_string(chunk)
# print('res:', res)
if res:
text += res
# Clean up the response text
import re
text = re.sub(r'^.*<\|END_USER\|>', '', text, flags=re.DOTALL)
text = re.sub(r'^\n[a-zA-Z]?', '', text).strip()
text = re.sub(r'[\x00-\x1F\x7F]', '', text)
return {
"id": f"chatcmpl-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": chat_request.model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": text
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
except Exception as e:
print(f"Error: {str(e)}")
raise HTTPException(
status_code=500,
detail="Internal server error"
)
@app.post("/models")
async def models():
return {
"object": "list",
"data": [
{
"id": "claude-3-5-sonnet-20241022",
"object": "model",
"created": 1713744000,
"owned_by": "anthropic"
},
{
"id": "claude-3-opus",
"object": "model",
"created": 1709251200,
"owned_by": "anthropic"
},
{
"id": "claude-3.5-haiku",
"object": "model",
"created": 1711929600,
"owned_by": "anthropic"
},
{
"id": "claude-3.5-sonnet",
"object": "model",
"created": 1711929600,
"owned_by": "anthropic"
},
{
"id": "cursor-small",
"object": "model",
"created": 1712534400,
"owned_by": "cursor"
},
{
"id": "gpt-3.5-turbo",
"object": "model",
"created": 1677649200,
"owned_by": "openai"
},
{
"id": "gpt-4",
"object": "model",
"created": 1687392000,
"owned_by": "openai"
},
{
"id": "gpt-4-turbo-2024-04-09",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "gpt-4o",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "gpt-4o-mini",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "o1-mini",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "o1-preview",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
}
]
}
if __name__ == "__main__":
import uvicorn
port = int(os.getenv("PORT", "3001"))
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True, timeout_keep_alive=30)

50
models/models.go Normal file

@@ -0,0 +1,50 @@
package models
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream,omitempty"`
}
type Delta struct {
Content string `json:"content"`
}
type Choice struct {
Index int `json:"index"`
Delta Delta `json:"delta,omitempty"`
Message *Message `json:"message,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
}
type ChatResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model,omitempty"`
Choices []Choice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type ModelData struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
type ModelsResponse struct {
Object string `json:"object"`
Data []ModelData `json:"data"`
}

View File

@@ -1,190 +0,0 @@
package main
import (
"bytes"
// "encoding/hex"
// "log"
"fmt"
)
// // Helper function to convert a chunk into a UTF-8 string
// function chunkToUtf8String (chunk) {
// if (chunk[0] === 0x01 || chunk[0] === 0x02 || (chunk[0] === 0x60 && chunk[1] === 0x0C)) {
// return ''
// }
// console.log('chunk:', Buffer.from(chunk).toString('hex'))
// console.log('chunk string:', Buffer.from(chunk).toString('utf-8'))
// // Drop everything in chunk up to and including 0x0A
// chunk = chunk.slice(chunk.indexOf(0x0A) + 1)
// let filteredChunk = []
// let i = 0
// while (i < chunk.length) {
// // New filter: when four consecutive 0x00 bytes are found, remove all following low bytes (0x00 to 0x0F)
// if (chunk.slice(i, i + 4).every(byte => byte === 0x00)) {
// i += 4 // skip these four 0x00 bytes
// while (i < chunk.length && chunk[i] >= 0x00 && chunk[i] <= 0x0F) {
// i++ // skip all low bytes
// }
// continue
// }
// if (chunk[i] === 0x0C) {
// // When 0x0C is found, skip it and all consecutive 0x0A bytes that follow
// i++ // skip the 0x0C
// while (i < chunk.length && chunk[i] === 0x0A) {
// i++ // skip all consecutive 0x0A bytes
// }
// } else if (
// i > 0 &&
// chunk[i] === 0x0A &&
// chunk[i - 1] >= 0x00 &&
// chunk[i - 1] <= 0x09
// ) {
// // If the current byte is 0x0A and the previous byte is between 0x00 and 0x09, drop both
// filteredChunk.pop() // remove the previously added byte
// i++ // skip the current 0x0A
// } else {
// filteredChunk.push(chunk[i])
// i++
// }
// }
// // Step two: remove all 0x00 and 0x0C bytes
// filteredChunk = filteredChunk.filter((byte) => byte !== 0x00 && byte !== 0x0C)
// // Remove bytes below 0x0A
// filteredChunk = filteredChunk.filter((byte) => byte >= 0x0A)
// const hexString = Buffer.from(filteredChunk).toString('hex')
// console.log('hexString:', hexString)
// const utf8String = Buffer.from(filteredChunk).toString('utf-8')
// console.log('utf8String:', utf8String)
// return utf8String
// }
// func processChunk(chunk []byte) string {
// // Check for special leading bytes
// if len(chunk) > 0 && (chunk[0] == 0x01 || chunk[0] == 0x02 || (len(chunk) > 1 && chunk[0] == 0x60 && chunk[1] == 0x0C)) {
// return ""
// }
// // Print debug info
// fmt.Printf("chunk: %x\n", chunk)
// fmt.Printf("chunk string: %s\n", string(chunk))
// // Find the first 0x0A and keep everything after it
// index := bytes.IndexByte(chunk, 0x0A)
// if index != -1 {
// chunk = chunk[index+1:]
// }
// // Build the filtered slice
// filteredChunk := make([]byte, 0, len(chunk))
// for i := 0; i < len(chunk); {
// // Check for four consecutive 0x00 bytes
// if i+4 <= len(chunk) {
// if chunk[i] == 0x00 && chunk[i+1] == 0x00 && chunk[i+2] == 0x00 && chunk[i+3] == 0x00 {
// i += 4
// // Skip all low bytes
// for i < len(chunk) && chunk[i] <= 0x0F {
// i++
// }
// continue
// }
// }
// if chunk[i] == 0x0C {
// i++
// // Skip all consecutive 0x0A bytes
// for i < len(chunk) && chunk[i] == 0x0A {
// i++
// }
// } else if i > 0 && chunk[i] == 0x0A && chunk[i-1] >= 0x00 && chunk[i-1] <= 0x09 {
// // Remove the previous byte and skip the current 0x0A
// filteredChunk = filteredChunk[:len(filteredChunk)-1]
// i++
// } else {
// filteredChunk = append(filteredChunk, chunk[i])
// i++
// }
// }
// // Filter out 0x00 and 0x0C
// tempChunk := make([]byte, 0, len(filteredChunk))
// for _, b := range filteredChunk {
// if b != 0x00 && b != 0x0C {
// tempChunk = append(tempChunk, b)
// }
// }
// filteredChunk = tempChunk
// // Filter out bytes below 0x0A
// tempChunk = make([]byte, 0, len(filteredChunk))
// for _, b := range filteredChunk {
// if b >= 0x0A {
// tempChunk = append(tempChunk, b)
// }
// }
// filteredChunk = tempChunk
// // Print debug info and return the result
// fmt.Printf("hexString: %x\n", filteredChunk)
// result := string(filteredChunk)
// fmt.Printf("utf8String: %s\n", result)
// return result
// }
func processChunk(chunk []byte) string {
// Check for special leading bytes
if len(chunk) > 0 && (chunk[0] == 0x01 || chunk[0] == 0x02 || (len(chunk) > 1 && chunk[0] == 0x60 && chunk[1] == 0x0C)) {
return ""
}
// Print debug info
fmt.Printf("chunk: %x\n", chunk)
fmt.Printf("chunk string: %s\n", string(chunk))
// Find the first 0x0A and keep everything after it
index := bytes.IndexByte(chunk, 0x0A)
if index != -1 {
chunk = chunk[index+1:]
}
// Build the filtered slice
filteredChunk := make([]byte, 0, len(chunk))
for i := 0; i < len(chunk); {
// Check for four consecutive 0x00 bytes
if i+4 <= len(chunk) {
allZeros := true
for j := 0; j < 4; j++ {
if chunk[i+j] != 0x00 {
allZeros = false
break
}
}
if allZeros {
i += 4
// Skip all low bytes
for i < len(chunk) && chunk[i] <= 0x0F {
i++
}
continue
}
}
// Keep UTF-8 characters
if chunk[i] >= 0xE0 || (chunk[i] >= 0x20 && chunk[i] <= 0x7F) {
filteredChunk = append(filteredChunk, chunk[i])
}
i++
}
// Print debug info and return the result
fmt.Printf("hexString: %x\n", filteredChunk)
result := string(filteredChunk)
fmt.Printf("utf8String: %s\n", result)
return string(chunk)
}

View File

@@ -1,6 +1,3 @@
# cursor-api
Converts the Cursor editor into an OpenAI-compatible API service.
## Project Overview
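The service exposes OpenAI-style endpoints (`/v1/chat/completions` and `/models`), so a standard chat-completions request works against it. A minimal usage sketch, assuming the proxy is running locally on port 3001 and that `YOUR_CURSOR_TOKEN` is a valid Cursor session token (both values are assumptions for illustration):

import httpx

resp = httpx.post(
    "http://localhost:3001/v1/chat/completions",
    headers={"Authorization": "Bearer YOUR_CURSOR_TOKEN"},
    json={
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    },
    timeout=300,
)
print(resp.json()["choices"][0]["message"]["content"])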

45
rs-capi/.vscode/launch.json vendored Normal file

@@ -0,0 +1,45 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug executable 'rs-api'",
"cargo": {
"args": [
"build",
"--bin=rs-api",
"--package=rs-api"
],
"filter": {
"name": "rs-api",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in executable 'rs-api'",
"cargo": {
"args": [
"test",
"--no-run",
"--bin=rs-api",
"--package=rs-api"
],
"filter": {
"name": "rs-api",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
}
]
}

1965
rs-capi/Cargo.lock generated Normal file

File diff suppressed because it is too large.

23
rs-capi/Cargo.toml Normal file

@@ -0,0 +1,23 @@
[package]
name = "rs-api"
version = "0.1.0"
edition = "2021"
[dependencies]
axum = { version = "0.7", features = ["json"] }
tokio = { version = "1.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
reqwest = { version = "0.11", features = ["json", "stream"] }
tower-http = { version = "0.5", features = ["cors", "trace"] }
uuid = { version = "1.0", features = ["v4"] }
dotenv = "0.15"
chrono = "0.4"
futures = "0.3"
bytes = "1.0"
regex = "1.5"
tracing = "0.1"
tracing-subscriber = "0.3"
hex = "0.4"
hyper = "1.5.1"
http = "1.1.0"

822
rs-capi/main.py Normal file

@@ -0,0 +1,822 @@
import json
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import uuid
import httpx
import os
from dotenv import load_dotenv
import time
import re
# Load environment variables
load_dotenv()
app = FastAPI()
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Define request models
class Message(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
model: str
messages: List[Message]
stream: bool = False
def string_to_hex(text: str, model_name: str) -> bytes:
"""将文本转换为特定格式的十六进制数据"""
# 将输入文本转换为UTF-8字节
text_bytes = text.encode('utf-8')
text_length = len(text_bytes)
# 固定常量
FIXED_HEADER = 2
SEPARATOR = 1
FIXED_SUFFIX_LENGTH = 0xA3 + len(model_name)
# 计算第一个长度字段
if text_length < 128:
text_length_field1 = format(text_length, '02x')
text_length_field_size1 = 1
else:
low_byte1 = (text_length & 0x7F) | 0x80
high_byte1 = (text_length >> 7) & 0xFF
text_length_field1 = format(low_byte1, '02x') + format(high_byte1, '02x')
text_length_field_size1 = 2
# Compute the base length field
base_length = text_length + 0x2A
if base_length < 128:
text_length_field = format(base_length, '02x')
text_length_field_size = 1
else:
low_byte = (base_length & 0x7F) | 0x80
high_byte = (base_length >> 7) & 0xFF
text_length_field = format(low_byte, '02x') + format(high_byte, '02x')
text_length_field_size = 2
# Compute the total message length
message_total_length = (FIXED_HEADER + text_length_field_size + SEPARATOR +
text_length_field_size1 + text_length + FIXED_SUFFIX_LENGTH)
# Build the hex string
model_name_bytes = model_name.encode('utf-8')
model_name_length_hex = format(len(model_name_bytes), '02X')
model_name_hex = model_name_bytes.hex().upper()
hex_string = (
f"{message_total_length:010x}"
"12"
f"{text_length_field}"
"0A"
f"{text_length_field1}"
f"{text_bytes.hex()}"
"10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612"
"2002A132F643A2F6964656150726F2F656475626F73733A1E0A"
f"{model_name_length_hex}"
f"{model_name_hex}"
"22004A"
"2461383761396133342D323164642D343863372D623434662D616636633365636536663765"
"680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061"
"800101B00100C00100E00100E80100"
).upper()
return bytes.fromhex(hex_string)
def chunk_to_utf8_string(chunk: bytes) -> str:
"""将二进制chunk转换为UTF-8字符串"""
if not chunk or len(chunk) < 2:
return ''
if chunk[0] in [0x01, 0x02] or (chunk[0] == 0x60 and chunk[1] == 0x0C):
return ''
# Log the raw chunk length for debugging
print(f"chunk length: {len(chunk)}")
# print(f"chunk hex: {chunk.hex()}")
try:
# Drop all bytes up to and including the first 0x0A
try:
chunk = chunk[chunk.index(0x0A) + 1:]
except ValueError:
pass
filtered_chunk = bytearray()
i = 0
while i < len(chunk):
# Check for four consecutive 0x00 bytes
if i + 4 <= len(chunk) and all(chunk[j] == 0x00 for j in range(i, i + 4)):
i += 4
while i < len(chunk) and chunk[i] <= 0x0F:
i += 1
continue
if chunk[i] == 0x0C:
i += 1
while i < len(chunk) and chunk[i] == 0x0A:
i += 1
else:
filtered_chunk.append(chunk[i])
i += 1
# Filter out specific bytes
filtered_chunk = bytes(b for b in filtered_chunk
if b != 0x00 and b != 0x0C)
if not filtered_chunk:
return ''
result = filtered_chunk.decode('utf-8', errors='ignore').strip()
# print(f"decoded result: {result}") # 调试输出
return result
except Exception as e:
print(f"Error in chunk_to_utf8_string: {str(e)}")
return ''
async def process_stream(chunks, ):
"""处理流式响应"""
response_id = f"chatcmpl-{str(uuid.uuid4())}"
# 先将所有chunks读取到列表中
# chunks = []
# async for chunk in response.aiter_raw():
# chunks.append(chunk)
# 然后处理保存的chunks
for chunk in chunks:
text = chunk_to_utf8_string(chunk)
if text:
# Clean up the text
text = text.strip()
if "<|END_USER|>" in text:
text = text.split("<|END_USER|>")[-1].strip()
if text and text[0].isalpha():
text = text[1:].strip()
text = re.sub(r"[\x00-\x1F\x7F]", "", text)
if text:  # make sure the cleaned text is not empty
data_body = {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"choices": [{
"index": 0,
"delta": {
"content": text
}
}]
}
yield f"data: {json.dumps(data_body, ensure_ascii=False)}\n\n"
# yield "data: {\n"
# yield f' "id": "{response_id}",\n'
# yield ' "object": "chat.completion.chunk",\n'
# yield f' "created": {int(time.time())},\n'
# yield ' "choices": [{\n'
# yield ' "index": 0,\n'
# yield ' "delta": {\n'
# yield f' "content": "{text}"\n'
# yield " }\n"
# yield " }]\n"
# yield "}\n\n"
yield "data: [DONE]\n\n"
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, chat_request: ChatRequest):
# o1 models do not support streaming output
if chat_request.model.startswith('o1-') and chat_request.stream:
raise HTTPException(
status_code=400,
detail="Model not supported stream"
)
# Get and process the authorization token
auth_header = request.headers.get('authorization', '')
if not auth_header.startswith('Bearer '):
raise HTTPException(
status_code=401,
detail="Invalid authorization header"
)
auth_token = auth_header.replace('Bearer ', '')
if not auth_token:
raise HTTPException(
status_code=401,
detail="Missing authorization token"
)
# Handle multiple keys
keys = [key.strip() for key in auth_token.split(',')]
if keys:
auth_token = keys[0]  # use the first key
if '%3A%3A' in auth_token:
auth_token = auth_token.split('%3A%3A')[1]
# Format the messages
formatted_messages = "\n".join(
f"{msg.role}:{msg.content}" for msg in chat_request.messages
)
# Build the request payload
hex_data = string_to_hex(formatted_messages, chat_request.model)
# Prepare request headers
headers = {
'Content-Type': 'application/connect+proto',
'Authorization': f'Bearer {auth_token}',
'Connect-Accept-Encoding': 'gzip,br',
'Connect-Protocol-Version': '1',
'User-Agent': 'connect-es/1.4.0',
'X-Amzn-Trace-Id': f'Root={str(uuid.uuid4())}',
'X-Cursor-Checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
'X-Cursor-Client-Version': '0.42.3',
'X-Cursor-Timezone': 'Asia/Shanghai',
'X-Ghost-Mode': 'false',
'X-Request-Id': str(uuid.uuid4()),
'Host': 'api2.cursor.sh'
}
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
try:
# Use the stream=True parameter
# Print the headers and binary data (debug)
print(f"headers: {headers}")
print(hex_data)
async with client.stream(
'POST',
'https://api2.cursor.sh/aiserver.v1.AiService/StreamChat',
headers=headers,
content=hex_data,
timeout=None
) as response:
if chat_request.stream:
chunks = []
async for chunk in response.aiter_raw():
chunks.append(chunk)
return StreamingResponse(
process_stream(chunks),
media_type="text/event-stream"
)
else:
# Non-streaming response handling
text = ''
async for chunk in response.aiter_raw():
# print('chunk:', chunk.hex())
print('chunk length:', len(chunk))
res = chunk_to_utf8_string(chunk)
# print('res:', res)
if res:
text += res
# Clean up the response text
import re
text = re.sub(r'^.*<\|END_USER\|>', '', text, flags=re.DOTALL)
text = re.sub(r'^\n[a-zA-Z]?', '', text).strip()
text = re.sub(r'[\x00-\x1F\x7F]', '', text)
return {
"id": f"chatcmpl-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": chat_request.model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": text
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
except Exception as e:
print(f"Error: {str(e)}")
raise HTTPException(
status_code=500,
detail="Internal server error"
)
@app.post("/models")
async def models():
return {
"object": "list",
"data": [
{
"id": "claude-3-5-sonnet-20241022",
"object": "model",
"created": 1713744000,
"owned_by": "anthropic"
},
{
"id": "claude-3-opus",
"object": "model",
"created": 1709251200,
"owned_by": "anthropic"
},
{
"id": "claude-3.5-haiku",
"object": "model",
"created": 1711929600,
"owned_by": "anthropic"
},
{
"id": "claude-3.5-sonnet",
"object": "model",
"created": 1711929600,
"owned_by": "anthropic"
},
{
"id": "cursor-small",
"object": "model",
"created": 1712534400,
"owned_by": "cursor"
},
{
"id": "gpt-3.5-turbo",
"object": "model",
"created": 1677649200,
"owned_by": "openai"
},
{
"id": "gpt-4",
"object": "model",
"created": 1687392000,
"owned_by": "openai"
},
{
"id": "gpt-4-turbo-2024-04-09",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "gpt-4o",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "gpt-4o-mini",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "o1-mini",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "o1-preview",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
}
]
}
if __name__ == "__main__":
import uvicorn
port = int(os.getenv("PORT", "3001"))
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True, timeout_keep_alive=30)
import json
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import uuid
import httpx
import os
from dotenv import load_dotenv
import time
import re
# 加载环境变量
load_dotenv()
app = FastAPI()
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 定义请求模型
class Message(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
model: str
messages: List[Message]
stream: bool = False
def string_to_hex(text: str, model_name: str) -> bytes:
"""将文本转换为特定格式的十六进制数据"""
# 将输入文本转换为UTF-8字节
text_bytes = text.encode('utf-8')
text_length = len(text_bytes)
# 固定常量
FIXED_HEADER = 2
SEPARATOR = 1
FIXED_SUFFIX_LENGTH = 0xA3 + len(model_name)
# 计算第一个长度字段
if text_length < 128:
text_length_field1 = format(text_length, '02x')
text_length_field_size1 = 1
else:
low_byte1 = (text_length & 0x7F) | 0x80
high_byte1 = (text_length >> 7) & 0xFF
text_length_field1 = format(low_byte1, '02x') + format(high_byte1, '02x')
text_length_field_size1 = 2
# 计算基础长度字段
base_length = text_length + 0x2A
if base_length < 128:
text_length_field = format(base_length, '02x')
text_length_field_size = 1
else:
low_byte = (base_length & 0x7F) | 0x80
high_byte = (base_length >> 7) & 0xFF
text_length_field = format(low_byte, '02x') + format(high_byte, '02x')
text_length_field_size = 2
# 计算总消息长度
message_total_length = (FIXED_HEADER + text_length_field_size + SEPARATOR +
text_length_field_size1 + text_length + FIXED_SUFFIX_LENGTH)
# 构造十六进制字符串
model_name_bytes = model_name.encode('utf-8')
model_name_length_hex = format(len(model_name_bytes), '02X')
model_name_hex = model_name_bytes.hex().upper()
hex_string = (
f"{message_total_length:010x}"
"12"
f"{text_length_field}"
"0A"
f"{text_length_field1}"
f"{text_bytes.hex()}"
"10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612"
"2002A132F643A2F6964656150726F2F656475626F73733A1E0A"
f"{model_name_length_hex}"
f"{model_name_hex}"
"22004A"
"2461383761396133342D323164642D343863372D623434662D616636633365636536663765"
"680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061"
"800101B00100C00100E00100E80100"
).upper()
return bytes.fromhex(hex_string)
def chunk_to_utf8_string(chunk: bytes) -> str:
"""将二进制chunk转换为UTF-8字符串"""
if not chunk or len(chunk) < 2:
return ''
if chunk[0] in [0x01, 0x02] or (chunk[0] == 0x60 and chunk[1] == 0x0C):
return ''
# 记录原始chunk的十六进制调试用
print(f"chunk length: {len(chunk)}")
# print(f"chunk hex: {chunk.hex()}")
try:
# 去掉0x0A之前的所有字节
try:
chunk = chunk[chunk.index(0x0A) + 1:]
except ValueError:
pass
filtered_chunk = bytearray()
i = 0
while i < len(chunk):
# 检查是否有连续的0x00
if i + 4 <= len(chunk) and all(chunk[j] == 0x00 for j in range(i, i + 4)):
i += 4
while i < len(chunk) and chunk[i] <= 0x0F:
i += 1
continue
if chunk[i] == 0x0C:
i += 1
while i < len(chunk) and chunk[i] == 0x0A:
i += 1
else:
filtered_chunk.append(chunk[i])
i += 1
# 过滤掉特定字节
filtered_chunk = bytes(b for b in filtered_chunk
if b != 0x00 and b != 0x0C)
if not filtered_chunk:
return ''
result = filtered_chunk.decode('utf-8', errors='ignore').strip()
# print(f"decoded result: {result}") # 调试输出
return result
except Exception as e:
print(f"Error in chunk_to_utf8_string: {str(e)}")
return ''
async def process_stream(chunks, ):
"""处理流式响应"""
response_id = f"chatcmpl-{str(uuid.uuid4())}"
# 先将所有chunks读取到列表中
# chunks = []
# async for chunk in response.aiter_raw():
# chunks.append(chunk)
# 然后处理保存的chunks
for chunk in chunks:
text = chunk_to_utf8_string(chunk)
if text:
# 清理文本
text = text.strip()
if "<|END_USER|>" in text:
text = text.split("<|END_USER|>")[-1].strip()
if text and text[0].isalpha():
text = text[1:].strip()
text = re.sub(r"[\x00-\x1F\x7F]", "", text)
if text: # 确保清理后的文本不为空
data_body = {
"id": response_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"choices": [{
"index": 0,
"delta": {
"content": text
}
}]
}
yield f"data: {json.dumps(data_body, ensure_ascii=False)}\n\n"
# yield "data: {\n"
# yield f' "id": "{response_id}",\n'
# yield ' "object": "chat.completion.chunk",\n'
# yield f' "created": {int(time.time())},\n'
# yield ' "choices": [{\n'
# yield ' "index": 0,\n'
# yield ' "delta": {\n'
# yield f' "content": "{text}"\n'
# yield " }\n"
# yield " }]\n"
# yield "}\n\n"
yield "data: [DONE]\n\n"
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, chat_request: ChatRequest):
# 验证o1模型不支持流式输出
if chat_request.model.startswith('o1-') and chat_request.stream:
raise HTTPException(
status_code=400,
detail="Model not supported stream"
)
# 获取并处理认证令牌
auth_header = request.headers.get('authorization', '')
if not auth_header.startswith('Bearer '):
raise HTTPException(
status_code=401,
detail="Invalid authorization header"
)
auth_token = auth_header.replace('Bearer ', '')
if not auth_token:
raise HTTPException(
status_code=401,
detail="Missing authorization token"
)
# 处理多个密钥
keys = [key.strip() for key in auth_token.split(',')]
if keys:
auth_token = keys[0] # 使用第一个密钥
if '%3A%3A' in auth_token:
auth_token = auth_token.split('%3A%3A')[1]
# 格式化消息
formatted_messages = "\n".join(
f"{msg.role}:{msg.content}" for msg in chat_request.messages
)
# 生成请求数据
hex_data = string_to_hex(formatted_messages, chat_request.model)
    # Prepare the upstream request headers
headers = {
'Content-Type': 'application/connect+proto',
'Authorization': f'Bearer {auth_token}',
'Connect-Accept-Encoding': 'gzip,br',
'Connect-Protocol-Version': '1',
'User-Agent': 'connect-es/1.4.0',
'X-Amzn-Trace-Id': f'Root={str(uuid.uuid4())}',
'X-Cursor-Checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
'X-Cursor-Client-Version': '0.42.3',
'X-Cursor-Timezone': 'Asia/Shanghai',
'X-Ghost-Mode': 'false',
'X-Request-Id': str(uuid.uuid4()),
'Host': 'api2.cursor.sh'
}
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
try:
            # Log the headers and the binary request body for debugging
print(f"headers: {headers}")
print(hex_data)
async with client.stream(
'POST',
'https://api2.cursor.sh/aiserver.v1.AiService/StreamChat',
headers=headers,
content=hex_data,
timeout=None
) as response:
if chat_request.stream:
chunks = []
async for chunk in response.aiter_raw():
chunks.append(chunk)
return StreamingResponse(
process_stream(chunks),
media_type="text/event-stream"
)
else:
                # Non-streaming response handling
text = ''
async for chunk in response.aiter_raw():
# print('chunk:', chunk.hex())
print('chunk length:', len(chunk))
res = chunk_to_utf8_string(chunk)
# print('res:', res)
if res:
text += res
                # Clean up the response text (re is imported at module level)
text = re.sub(r'^.*<\|END_USER\|>', '', text, flags=re.DOTALL)
text = re.sub(r'^\n[a-zA-Z]?', '', text).strip()
text = re.sub(r'[\x00-\x1F\x7F]', '', text)
return {
"id": f"chatcmpl-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": chat_request.model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": text
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
except Exception as e:
print(f"Error: {str(e)}")
raise HTTPException(
status_code=500,
detail="Internal server error"
)
@app.post("/models")
async def models():
return {
"object": "list",
"data": [
{
"id": "claude-3-5-sonnet-20241022",
"object": "model",
"created": 1713744000,
"owned_by": "anthropic"
},
{
"id": "claude-3-opus",
"object": "model",
"created": 1709251200,
"owned_by": "anthropic"
},
{
"id": "claude-3.5-haiku",
"object": "model",
"created": 1711929600,
"owned_by": "anthropic"
},
{
"id": "claude-3.5-sonnet",
"object": "model",
"created": 1711929600,
"owned_by": "anthropic"
},
{
"id": "cursor-small",
"object": "model",
"created": 1712534400,
"owned_by": "cursor"
},
{
"id": "gpt-3.5-turbo",
"object": "model",
"created": 1677649200,
"owned_by": "openai"
},
{
"id": "gpt-4",
"object": "model",
"created": 1687392000,
"owned_by": "openai"
},
{
"id": "gpt-4-turbo-2024-04-09",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "gpt-4o",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "gpt-4o-mini",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "o1-mini",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
},
{
"id": "o1-preview",
"object": "model",
"created": 1712620800,
"owned_by": "openai"
}
]
}
if __name__ == "__main__":
import uvicorn
port = int(os.getenv("PORT", "3001"))
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True, timeout_keep_alive=30)

111
rs-capi/src/hex_utils.rs Normal file
View File

@@ -0,0 +1,111 @@
pub fn string_to_hex(text: &str, model_name: &str) -> Vec<u8> {
let text_bytes = text.as_bytes();
let text_length = text_bytes.len();
    // Fixed framing constants
const FIXED_HEADER: usize = 2;
const SEPARATOR: usize = 1;
let model_name_bytes = model_name.as_bytes();
let fixed_suffix_length = 0xA3 + model_name_bytes.len();
    // Compute the first (inner) length field
let (text_length_field1, text_length_field_size1) = if text_length < 128 {
(format!("{:02x}", text_length), 1)
} else {
let low_byte1 = (text_length & 0x7F) | 0x80;
let high_byte1 = (text_length >> 7) & 0xFF;
(format!("{:02x}{:02x}", low_byte1, high_byte1), 2)
};
    // Compute the base (outer) length field
let base_length = text_length + 0x2A;
let (text_length_field, text_length_field_size) = if base_length < 128 {
(format!("{:02x}", base_length), 1)
} else {
let low_byte = (base_length & 0x7F) | 0x80;
let high_byte = (base_length >> 7) & 0xFF;
(format!("{:02x}{:02x}", low_byte, high_byte), 2)
};
    // Compute the total message length
let message_total_length = FIXED_HEADER + text_length_field_size + SEPARATOR +
text_length_field_size1 + text_length + fixed_suffix_length;
    // Assemble the hex string
let model_name_length_hex = format!("{:02X}", model_name_bytes.len());
let model_name_hex = hex::encode_upper(model_name_bytes);
let hex_string = format!(
"{:010x}\
12{}\
0A{}\
{}\
10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612\
2002A132F643A2F6964656150726F2F656475626F73733A1E0A\
{}{}\
22004A\
2461383761396133342D323164642D343863372D623434662D616636633365636536663765\
680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061\
800101B00100C00100E00100E80100",
message_total_length,
text_length_field,
text_length_field1,
hex::encode_upper(text_bytes),
model_name_length_hex,
model_name_hex
).to_uppercase();
    // Decode the hex string into a byte vector
hex::decode(hex_string).unwrap_or_default()
}
pub fn chunk_to_utf8_string(chunk: &[u8]) -> String {
if chunk.len() < 2 {
return String::new();
}
if chunk[0] == 0x01 || chunk[0] == 0x02 || (chunk[0] == 0x60 && chunk[1] == 0x0C) {
return String::new();
}
    // Start processing after the first 0x0A byte, if present
let chunk = match chunk.iter().position(|&x| x == 0x0A) {
Some(pos) => &chunk[pos + 1..],
None => chunk
};
let mut filtered_chunk = Vec::new();
let mut i = 0;
while i < chunk.len() {
        // Skip runs that start with four consecutive 0x00 bytes
if i + 4 <= chunk.len() && chunk[i..i+4].iter().all(|&x| x == 0x00) {
i += 4;
while i < chunk.len() && chunk[i] <= 0x0F {
i += 1;
}
continue;
}
if chunk[i] == 0x0C {
i += 1;
while i < chunk.len() && chunk[i] == 0x0A {
i += 1;
}
} else {
filtered_chunk.push(chunk[i]);
i += 1;
}
}
    // Filter out any remaining 0x00 and 0x0C bytes
filtered_chunk.retain(|&b| b != 0x00 && b != 0x0C);
if filtered_chunk.is_empty() {
return String::new();
}
    // Convert to a UTF-8 string (lossy)
String::from_utf8_lossy(&filtered_chunk).trim().to_string()
}
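
The length fields assembled here (and in the Python and Go ports) follow a two-byte base-128 (protobuf-style varint) scheme: values below 128 are emitted as a single byte, larger values as a low byte with the continuation bit set followed by the remaining high bits. A small illustrative sketch of the rule, in Python (not part of the commit):

def encode_length(n: int) -> str:
    # Two-byte base-128 varint, as used by string_to_hex above
    # (sufficient for lengths up to 16383).
    if n < 128:
        return f"{n:02x}"
    low = (n & 0x7F) | 0x80   # low 7 bits, continuation bit set
    high = (n >> 7) & 0xFF    # remaining high bits
    return f"{low:02x}{high:02x}"

# Worked examples:
# encode_length(0x2A + 5) -> "2f"    (short prompt: single byte)
# encode_length(300)      -> "ac02"  (300 -> low 0xAC, high 0x02)
assert encode_length(47) == "2f"
assert encode_length(300) == "ac02"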

352
rs-capi/src/main.rs Normal file
View File

@@ -0,0 +1,352 @@
use axum::{
http::{HeaderMap, StatusCode},
response::{
sse::{Event, Sse},
IntoResponse, Response,
},
routing::post,
Json, Router,
};
use tower_http::trace::TraceLayer;
use bytes::Bytes;
use futures::{
channel::mpsc,
stream::{Stream, StreamExt},
SinkExt,
};
// use http::HeaderName as HttpHeaderName;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::{convert::Infallible, time::Duration};
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
mod hex_utils;
use hex_utils::{chunk_to_utf8_string, string_to_hex};
// Request models
#[derive(Debug, Deserialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct ChatRequest {
model: String,
messages: Vec<Message>,
#[serde(default)]
stream: bool,
}
// Response models
#[derive(Debug, Serialize)]
struct ChatResponse {
id: String,
object: String,
created: i64,
model: String,
choices: Vec<Choice>,
usage: Usage,
}
#[derive(Debug, Serialize)]
struct Choice {
index: i32,
message: ResponseMessage,
finish_reason: String,
}
#[derive(Debug, Serialize)]
struct ResponseMessage {
role: String,
content: String,
}
#[derive(Debug, Serialize)]
struct Usage {
prompt_tokens: i32,
completion_tokens: i32,
total_tokens: i32,
}
#[derive(Debug, Serialize)]
struct StreamResponse {
id: String,
object: String,
created: i64,
choices: Vec<StreamChoice>,
}
#[derive(Debug, Serialize)]
struct StreamChoice {
index: i32,
delta: Delta,
}
#[derive(Debug, Serialize)]
struct Delta {
content: String,
}
async fn process_stream(
chunks: Vec<Bytes>,
) -> impl Stream<Item = Result<Event, Infallible>> + Send {
let (mut tx, rx) = mpsc::channel(100);
let response_id = format!("chatcmpl-{}", Uuid::new_v4());
tokio::spawn(async move {
for chunk in chunks {
let text = chunk_to_utf8_string(&chunk);
if !text.is_empty() {
let text = text.trim();
let text = if let Some(idx) = text.find("<|END_USER|>") {
text[idx + "<|END_USER|>".len()..].trim()
} else {
text
};
                // Strip a stray leading alphabetic character (slice on the char boundary to avoid panics)
                let text = match text.chars().next() {
                    Some(c) if c.is_alphabetic() => text[c.len_utf8()..].trim(),
                    _ => text,
                };
let re = Regex::new(r"[\x00-\x1F\x7F]").unwrap();
let text = re.replace_all(text, "");
if !text.is_empty() {
let response = StreamResponse {
id: response_id.clone(),
object: "chat.completion.chunk".to_string(),
created: chrono::Utc::now().timestamp(),
choices: vec![StreamChoice {
index: 0,
delta: Delta {
content: text.to_string(),
},
}],
};
let json_data = serde_json::to_string(&response).unwrap();
if !json_data.is_empty() {
let _ = tx.send(Ok(Event::default().data(json_data))).await;
}
}
}
}
let _ = tx.send(Ok(Event::default().data("[DONE]"))).await;
});
rx
}
#[tokio::main]
async fn main() {
    // Initialize logging
tracing_subscriber::fmt::init();
    // Build the CORS middleware
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
    // Build the router
let app = Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/models", post(models))
.layer(cors)
.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &axum::http::Request<_>| {
tracing::info_span!(
"http_request",
method = %request.method(),
uri = %request.uri(),
)
})
// .on_request(|_request: &axum::http::Request<_>, _span: &tracing::Span| { info!("started processing request"); })
.on_response(
|response: &axum::http::Response<_>,
latency: std::time::Duration,
_span: &tracing::Span| {
tracing::info!(
status = %response.status(),
latency = ?latency,
);
},
),
);
    // Start the server
    let addr = "0.0.0.0:3002";
    println!("Server running on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}
// Handle chat completion requests
async fn chat_completions(
headers: HeaderMap,
Json(chat_request): Json<ChatRequest>,
) -> Result<Response, StatusCode> {
    // Validate the authorization header
let auth_header = headers
.get("authorization")
.and_then(|h| h.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
if !auth_header.starts_with("Bearer ") {
return Err(StatusCode::UNAUTHORIZED);
}
let mut auth_token = auth_header.replace("Bearer ", "");
    // o1 models do not support streaming output
if chat_request.model.starts_with("o1-") && chat_request.stream {
return Err(StatusCode::BAD_REQUEST);
}
tracing::info!("chat_request: {:?}", chat_request);
    // Multiple keys may be supplied, comma separated; use the first one
if auth_token.contains(',') {
auth_token = auth_token.split(',').next().unwrap().trim().to_string();
}
if auth_token.contains("%3A%3A") {
auth_token = auth_token
.split("%3A%3A")
.nth(1)
.unwrap_or(&auth_token)
.to_string();
}
    // Flatten the messages into a single prompt string
let formatted_messages = chat_request
.messages
.iter()
.map(|msg| format!("{}:{}", msg.role, msg.content))
.collect::<Vec<_>>()
.join("\n");
    // Build the binary request body
let hex_data = string_to_hex(&formatted_messages, &chat_request.model);
    // Prepare the upstream request headers
    let auth_value = format!("Bearer {}", auth_token);
    let trace_id = format!("Root={}", Uuid::new_v4());
    let request_id = Uuid::new_v4().to_string();
    let header_pairs: [(&str, &str); 12] = [
        ("Content-Type", "application/connect+proto"),
        ("Authorization", &auth_value),
        ("Connect-Accept-Encoding", "gzip,br"),
        ("Connect-Protocol-Version", "1"),
        ("User-Agent", "connect-es/1.4.0"),
        ("X-Amzn-Trace-Id", &trace_id),
        ("X-Cursor-Checksum", "zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef"),
        ("X-Cursor-Client-Version", "0.42.3"),
        ("X-Cursor-Timezone", "Asia/Shanghai"),
        ("X-Ghost-Mode", "false"),
        ("X-Request-Id", &request_id),
        ("Host", "api2.cursor.sh"),
    ];
    let mut headers = reqwest::header::HeaderMap::new();
    for (name, value) in header_pairs {
        headers.insert(
            reqwest::header::HeaderName::from_str(name).unwrap(),
            reqwest::header::HeaderValue::from_str(value).unwrap(),
        );
    }
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(300))
.build()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let response = client
.post("https://api2.cursor.sh/aiserver.v1.AiService/StreamChat")
.headers(headers)
.body(hex_data)
.send()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if chat_request.stream {
let mut chunks = Vec::new();
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(chunk) => chunks.push(chunk),
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}
let stream = process_stream(chunks).await;
return Ok(Sse::new(stream).into_response());
}
    // Non-streaming response
let mut text = String::new();
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(chunk) => {
let res = chunk_to_utf8_string(&chunk);
if !res.is_empty() {
text.push_str(&res);
}
}
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}
    // Clean up the response text
let re = Regex::new(r"^.*<\|END_USER\|>").unwrap();
text = re.replace(&text, "").to_string();
let re = Regex::new(r"^\n[a-zA-Z]?").unwrap();
text = re.replace(&text, "").trim().to_string();
let re = Regex::new(r"[\x00-\x1F\x7F]").unwrap();
text = re.replace_all(&text, "").to_string();
let response = ChatResponse {
id: format!("chatcmpl-{}", Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp(),
model: chat_request.model,
choices: vec![Choice {
index: 0,
message: ResponseMessage {
role: "assistant".to_string(),
content: text,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
};
Ok(Json(response).into_response())
}
// Handle model list requests
async fn models() -> Json<serde_json::Value> {
Json(serde_json::json!({
"object": "list",
"data": [
{
"id": "claude-3-5-sonnet-20241022",
"object": "model",
"created": 1713744000,
"owned_by": "anthropic"
},
            // ... remaining models
]
}))
}
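
For reference, a sketch of consuming the SSE stream that the handler above produces; it assumes the rs-capi server is listening on its hard-coded 0.0.0.0:3002 and that "<cursor-token>" is again only a placeholder:

import json
import httpx

# Sketch only: stream tokens from the rs-capi server (assumed at localhost:3002).
# "<cursor-token>" is a placeholder, not a value from this repository.
with httpx.stream(
    "POST",
    "http://localhost:3002/v1/chat/completions",
    headers={"Authorization": "Bearer <cursor-token>"},
    json={
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "Say hi"}],
        "stream": True,
    },
    timeout=None,
) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)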

View File

@@ -1,45 +0,0 @@
package main
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
}
type ChatResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
type StreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Delta struct {
Content string `json:"content"`
} `json:"delta"`
} `json:"choices"`
}

View File

@@ -1,83 +0,0 @@
package main
import (
"bytes"
"encoding/hex"
"fmt"
"strings"
)
func stringToHex(str, modelName string) []byte {
inputBytes := []byte(str)
byteLength := len(inputBytes)
const (
FIXED_HEADER = 2
SEPARATOR = 1
)
FIXED_SUFFIX_LENGTH := 0xA3 + len(modelName)
	// Compute the text length field
var textLengthField1, textLengthFieldSize1 int
if byteLength < 128 {
textLengthField1 = byteLength
textLengthFieldSize1 = 1
} else {
lowByte1 := (byteLength & 0x7F) | 0x80
highByte1 := (byteLength >> 7) & 0xFF
textLengthField1 = (highByte1 << 8) | lowByte1
textLengthFieldSize1 = 2
}
	// Compute the base length
baseLength := byteLength + 0x2A
var textLengthField, textLengthFieldSize int
if baseLength < 128 {
textLengthField = baseLength
textLengthFieldSize = 1
} else {
lowByte := (baseLength & 0x7F) | 0x80
highByte := (baseLength >> 7) & 0xFF
textLengthField = (highByte << 8) | lowByte
textLengthFieldSize = 2
}
	// Compute the total message length
messageTotalLength := FIXED_HEADER + textLengthFieldSize + SEPARATOR +
textLengthFieldSize1 + byteLength + FIXED_SUFFIX_LENGTH
var buf bytes.Buffer
	// Write the message length
fmt.Fprintf(&buf, "%010x", messageTotalLength)
	// Write the fixed header
buf.WriteString("12")
	// Write the length fields
fmt.Fprintf(&buf, "%02x", textLengthField)
buf.WriteString("0A")
fmt.Fprintf(&buf, "%02x", textLengthField1)
	// Write the message content
buf.WriteString(hex.EncodeToString(inputBytes))
	// Write the fixed suffix
buf.WriteString("10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612")
buf.WriteString("2002A132F643A2F6964656150726F2F656475626F73733A1E0A")
	// Write the model name length and content
fmt.Fprintf(&buf, "%02X", len(modelName))
buf.WriteString(strings.ToUpper(hex.EncodeToString([]byte(modelName))))
	// Write the remaining fixed content
buf.WriteString("22004A")
buf.WriteString("2461383761396133342D323164642D343863372D623434662D616636633365636536663765")
buf.WriteString("680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061")
buf.WriteString("800101B00100C00100E00100E80100")
hexBytes, _ := hex.DecodeString(strings.ToUpper(buf.String()))
return hexBytes
}

158
utils/hex.go Normal file
View File

@@ -0,0 +1,158 @@
package utils
import (
"bytes"
"encoding/hex"
"fmt"
"strings"
"unicode/utf8"
)
func StringToHex(text string, modelName string) ([]byte, error) {
textBytes := []byte(text)
textLength := len(textBytes)
const (
FIXED_HEADER = 2
SEPARATOR = 1
)
modelNameBytes := []byte(modelName)
FIXED_SUFFIX_LENGTH := 0xA3 + len(modelNameBytes)
	// Compute the first (inner) length field
var textLengthField1 string
var textLengthFieldSize1 int
if textLength < 128 {
textLengthField1 = fmt.Sprintf("%02x", textLength)
textLengthFieldSize1 = 1
} else {
lowByte1 := (textLength & 0x7F) | 0x80
highByte1 := (textLength >> 7) & 0xFF
textLengthField1 = fmt.Sprintf("%02x%02x", lowByte1, highByte1)
textLengthFieldSize1 = 2
}
	// Compute the base (outer) length field
baseLength := textLength + 0x2A
var textLengthField string
var textLengthFieldSize int
if baseLength < 128 {
textLengthField = fmt.Sprintf("%02x", baseLength)
textLengthFieldSize = 1
} else {
lowByte := (baseLength & 0x7F) | 0x80
highByte := (baseLength >> 7) & 0xFF
textLengthField = fmt.Sprintf("%02x%02x", lowByte, highByte)
textLengthFieldSize = 2
}
	// Compute the total message length
messageTotalLength := FIXED_HEADER + textLengthFieldSize + SEPARATOR + textLengthFieldSize1 + textLength + FIXED_SUFFIX_LENGTH
modelNameHex := strings.ToUpper(hex.EncodeToString(modelNameBytes))
modelNameLengthHex := fmt.Sprintf("%02X", len(modelNameBytes))
hexString := fmt.Sprintf(
"%010x"+
"12"+
"%s"+
"0a"+
"%s"+
"%x"+
"10016a2432343163636435662d393162612d343131382d393239612d3936626330313631626432612"+
"2002a132f643a2f6964656150726f2f656475626f73733a1e0a"+
"%s"+
"%s"+
"22004a"+
"2461383761396133342d323164642d343863372d623434662d616636633365636536663765"+
"680070007a2436393737376535612d386332642d343835342d623564392d653062623232336163303061"+
"800101b00100c00100e00100e80100",
messageTotalLength,
textLengthField,
textLengthField1,
textBytes,
modelNameLengthHex,
modelNameHex,
)
hexString = strings.ToLower(hexString)
return hex.DecodeString(hexString)
}
func ChunkToUTF8String(chunk []byte) string {
	// Basic sanity checks
if len(chunk) < 2 {
return ""
}
if chunk[0] == 0x01 || chunk[0] == 0x02 || (chunk[0] == 0x60 && chunk[1] == 0x0C) {
return ""
}
	// Debug output
	// fmt.Printf("chunk length: %d hex: %x\n", len(chunk), chunk)
	fmt.Printf("chunk length: %d\n", len(chunk))
	// Drop everything up to and including the first 0x0A byte
if idx := bytes.IndexByte(chunk, 0x0A); idx != -1 {
chunk = chunk[idx+1:]
}
	// Filter the chunk in two passes
filteredChunk := make([]byte, 0, len(chunk))
i := 0
for i < len(chunk) {
		// Skip runs that start with four consecutive 0x00 bytes
if i+4 <= len(chunk) && allZeros(chunk[i:i+4]) {
i += 4
for i < len(chunk) && chunk[i] <= 0x0F {
i++
}
continue
}
if chunk[i] == 0x0C {
i++
for i < len(chunk) && chunk[i] == 0x0A {
i++
}
} else {
filteredChunk = append(filteredChunk, chunk[i])
i++
}
}
	// Finally, filter out any remaining 0x00 and 0x0C bytes
finalFiltered := make([]byte, 0, len(filteredChunk))
for _, b := range filteredChunk {
if b != 0x00 && b != 0x0C {
finalFiltered = append(finalFiltered, b)
}
}
if len(finalFiltered) == 0 {
return ""
}
	// Validate UTF-8 before returning
result := strings.TrimSpace(string(finalFiltered))
if !utf8.Valid(finalFiltered) {
fmt.Printf("Error: Invalid UTF-8 sequence\n")
return ""
}
fmt.Printf("decoded result: %s\n", result)
return result
}
// allZeros reports whether every byte in data is zero
func allZeros(data []byte) bool {
for _, b := range data {
if b != 0x00 {
return false
}
}
return true
}