mirror of
https://github.com/zeke-chin/cursor-api.git
synced 2025-09-26 19:51:11 +08:00
add: rs-capi
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,4 +1,7 @@
|
||||
.env
|
||||
tests/
|
||||
|
||||
node_modules/
|
||||
node_modules/
|
||||
.DS_Store
|
||||
.idea/
|
||||
__pycache__/
|
49
api.go
49
api.go
@@ -1,49 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func formatMessages(messages []Message) string {
|
||||
var formatted []string
|
||||
for _, msg := range messages {
|
||||
formatted = append(formatted, fmt.Sprintf("%s:%s", msg.Role, msg.Content))
|
||||
}
|
||||
return strings.Join(formatted, "\n")
|
||||
}
|
||||
|
||||
func sendToCursorAPI(c *gin.Context, hexData []byte) (*http.Response, error) {
|
||||
req, err := http.NewRequest("POST", "https://api2.cursor.sh/aiserver.v1.AiService/StreamChat", bytes.NewReader(hexData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取认证token
|
||||
authToken := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
|
||||
if strings.Contains(authToken, "%3A%3A") {
|
||||
authToken = strings.Split(authToken, "%3A%3A")[1]
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/connect+proto")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken))
|
||||
req.Header.Set("Connect-Accept-Encoding", "gzip,br")
|
||||
req.Header.Set("Connect-Protocol-Version", "1")
|
||||
req.Header.Set("User-Agent", "connect-es/1.4.0")
|
||||
req.Header.Set("X-Amzn-Trace-Id", fmt.Sprintf("Root=%s", uuid.New().String()))
|
||||
req.Header.Set("X-Cursor-Checksum", "zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef")
|
||||
req.Header.Set("X-Cursor-Client-Version", "0.42.3")
|
||||
req.Header.Set("X-Cursor-Timezone", "Asia/Shanghai")
|
||||
req.Header.Set("X-Ghost-Mode", "false")
|
||||
req.Header.Set("X-Request-Id", uuid.New().String())
|
||||
req.Header.Set("Host", "api2.cursor.sh")
|
||||
|
||||
client := &http.Client{}
|
||||
return client.Do(req)
|
||||
}
|
@@ -1,25 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"cursor-api-proxy/internal/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := godotenv.Load(); err != nil {
|
||||
log.Println("Warning: Error loading .env file")
|
||||
}
|
||||
|
||||
server := api.NewServer()
|
||||
|
||||
port := os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = "3000"
|
||||
}
|
||||
|
||||
log.Printf("服务器运行在端口 %s\n", port)
|
||||
server.Run(":" + port)
|
||||
}
|
41
go.mod
41
go.mod
@@ -1,5 +1,40 @@
|
||||
module cursor-api-proxy
|
||||
module go-capi
|
||||
|
||||
go 1.21
|
||||
go 1.22.0
|
||||
|
||||
// ... 其他依赖 ...
|
||||
require (
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.23.0 // indirect
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/sys v0.20.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
95
go.sum
95
go.sum
@@ -1,29 +1,33 @@
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
|
||||
github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
|
||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
@@ -34,57 +38,66 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
|
||||
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
|
||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
|
||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
|
||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
265
handlers/chat.go
Normal file
265
handlers/chat.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"unicode"
|
||||
"regexp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go-capi/models"
|
||||
"go-capi/utils"
|
||||
)
|
||||
|
||||
func ChatCompletions(c *gin.Context) {
|
||||
var chatRequest models.ChatRequest
|
||||
if err := c.ShouldBindJSON(&chatRequest); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证o1模型不支持流式输出
|
||||
if strings.HasPrefix(chatRequest.Model, "o1-") && chatRequest.Stream {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Model not supported stream"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取并处理认证令牌
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid authorization header"})
|
||||
return
|
||||
}
|
||||
|
||||
authToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if authToken == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing authorization token"})
|
||||
return
|
||||
}
|
||||
|
||||
// 处理多个密钥
|
||||
keys := strings.Split(authToken, ",")
|
||||
if len(keys) > 0 {
|
||||
authToken = strings.TrimSpace(keys[0])
|
||||
}
|
||||
|
||||
if strings.Contains(authToken, "%3A%3A") {
|
||||
parts := strings.Split(authToken, "%3A%3A")
|
||||
authToken = parts[1]
|
||||
}
|
||||
|
||||
// 格式化消息
|
||||
var messages []string
|
||||
for _, msg := range chatRequest.Messages {
|
||||
messages = append(messages, fmt.Sprintf("%s:%s", msg.Role, msg.Content))
|
||||
}
|
||||
formattedMessages := strings.Join(messages, "\n")
|
||||
|
||||
// 生成请求数据
|
||||
hexData, err := utils.StringToHex(formattedMessages, chatRequest.Model)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 准备请求
|
||||
client := &http.Client{Timeout: 300 * time.Second}
|
||||
req, err := http.NewRequest("POST", "https://api2.cursor.sh/aiserver.v1.AiService/StreamChat", bytes.NewReader(hexData))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/connect+proto")
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
req.Header.Set("Connect-Accept-Encoding", "gzip,br")
|
||||
req.Header.Set("Connect-Protocol-Version", "1")
|
||||
req.Header.Set("User-Agent", "connect-es/1.4.0")
|
||||
req.Header.Set("X-Amzn-Trace-Id", "Root="+uuid.New().String())
|
||||
req.Header.Set("X-Cursor-Checksum", "zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef")
|
||||
req.Header.Set("X-Cursor-Client-Version", "0.42.3")
|
||||
req.Header.Set("X-Cursor-Timezone", "Asia/Shanghai")
|
||||
req.Header.Set("X-Ghost-Mode", "false")
|
||||
req.Header.Set("X-Request-Id", uuid.New().String())
|
||||
req.Header.Set("Host", "api2.cursor.sh")
|
||||
// ... 设置其他请求头
|
||||
|
||||
|
||||
// 打印 请求头和请求体
|
||||
fmt.Printf("\nRequest Headers: %v\n", req.Header)
|
||||
fmt.Printf("\nRequest Body: %x\n", hexData)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if chatRequest.Stream {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
chunks := make([][]byte, 0)
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
|
||||
for {
|
||||
chunk, err := reader.ReadBytes('\n')
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
c.SSEvent("error", gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
responseID := "chatcmpl-" + uuid.New().String()
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
for _, chunk := range chunks {
|
||||
text := chunkToUTF8String(chunk)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 清理文本
|
||||
text = strings.TrimSpace(text)
|
||||
if strings.Contains(text, "<|END_USER|>") {
|
||||
parts := strings.Split(text, "<|END_USER|>")
|
||||
text = strings.TrimSpace(parts[len(parts)-1])
|
||||
}
|
||||
if len(text) > 0 && unicode.IsLetter(rune(text[0])) {
|
||||
text = strings.TrimSpace(text[1:])
|
||||
}
|
||||
text = cleanControlChars(text)
|
||||
|
||||
if text != "" {
|
||||
dataBody := map[string]interface{}{
|
||||
"id": responseID,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"index": 0,
|
||||
"delta": map[string]string{
|
||||
"content": text,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, _ := json.Marshal(dataBody)
|
||||
c.SSEvent("", string(jsonData))
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
}
|
||||
|
||||
c.SSEvent("", "[DONE]")
|
||||
return false
|
||||
})
|
||||
} else {
|
||||
// 非流式响应处理
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
var allText string
|
||||
|
||||
for {
|
||||
chunk, err := reader.ReadBytes('\n')
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
text := utils.ChunkToUTF8String(chunk)
|
||||
if text != "" {
|
||||
allText += text
|
||||
}
|
||||
}
|
||||
|
||||
// 清理响应文本
|
||||
allText = cleanResponseText(allText)
|
||||
|
||||
|
||||
response := models.ChatResponse{
|
||||
ID: "chatcmpl-" + uuid.New().String(),
|
||||
Object: "chat.completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: chatRequest.Model,
|
||||
Choices: []models.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: &models.Message{
|
||||
Role: "assistant",
|
||||
Content: allText,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: 0,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: 0,
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func chunkToUTF8String(chunk []byte) string {
|
||||
// 实现从二进制chunk转换到UTF8字符串的逻辑
|
||||
return string(chunk)
|
||||
}
|
||||
|
||||
func cleanControlChars(text string) string {
|
||||
return regexp.MustCompile(`[\x00-\x1F\x7F]`).ReplaceAllString(text, "")
|
||||
}
|
||||
|
||||
func cleanResponseText(text string) string {
|
||||
// 移除END_USER之前的所有内容
|
||||
re := regexp.MustCompile(`(?s)^.*<\|END_USER\|>`)
|
||||
text = re.ReplaceAllString(text, "")
|
||||
|
||||
// 移除开头的换行和单个字母
|
||||
text = regexp.MustCompile(`^\n[a-zA-Z]?`).ReplaceAllString(text, "")
|
||||
text = strings.TrimSpace(text)
|
||||
|
||||
// 清理控制字符
|
||||
text = cleanControlChars(text)
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
func GetModels(c *gin.Context) {
|
||||
response := models.ModelsResponse{
|
||||
Object: "list",
|
||||
Data: []models.ModelData{
|
||||
{ID: "claude-3-5-sonnet-20241022", Object: "model", Created: 1713744000, OwnedBy: "anthropic"},
|
||||
{ID: "claude-3-opus", Object: "model", Created: 1709251200, OwnedBy: "anthropic"},
|
||||
{ID: "claude-3.5-haiku", Object: "model", Created: 1711929600, OwnedBy: "anthropic"},
|
||||
{ID: "claude-3.5-sonnet", Object: "model", Created: 1711929600, OwnedBy: "anthropic"},
|
||||
{ID: "cursor-small", Object: "model", Created: 1712534400, OwnedBy: "cursor"},
|
||||
{ID: "gpt-3.5-turbo", Object: "model", Created: 1677649200, OwnedBy: "openai"},
|
||||
{ID: "gpt-4", Object: "model", Created: 1687392000, OwnedBy: "openai"},
|
||||
{ID: "gpt-4-turbo-2024-04-09", Object: "model", Created: 1712620800, OwnedBy: "openai"},
|
||||
{ID: "gpt-4o", Object: "model", Created: 1712620800, OwnedBy: "openai"},
|
||||
{ID: "gpt-4o-mini", Object: "model", Created: 1712620800, OwnedBy: "openai"},
|
||||
{ID: "o1-mini", Object: "model", Created: 1712620800, OwnedBy: "openai"},
|
||||
{ID: "o1-preview", Object: "model", Created: 1712620800, OwnedBy: "openai"},
|
||||
},
|
||||
}
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
@@ -1,33 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"cursor-api-proxy/internal/models"
|
||||
"cursor-api-proxy/internal/service"
|
||||
)
|
||||
|
||||
func handleChat(c *gin.Context) {
|
||||
var req models.ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为 o1 开头的模型且请求流式输出
|
||||
if strings.HasPrefix(req.Model, "o1-") && req.Stream {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Model not supported stream"})
|
||||
return
|
||||
}
|
||||
|
||||
// ... 其他处理逻辑 ...
|
||||
|
||||
if req.Stream {
|
||||
service.HandleStreamResponse(c, req)
|
||||
return
|
||||
}
|
||||
|
||||
service.HandleNormalResponse(c, req)
|
||||
}
|
@@ -1,15 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func NewServer() *gin.Engine {
|
||||
r := gin.Default()
|
||||
setupRoutes(r)
|
||||
return r
|
||||
}
|
||||
|
||||
func setupRoutes(r *gin.Engine) {
|
||||
r.POST("/v1/chat/completions", handleChat)
|
||||
}
|
@@ -1,14 +0,0 @@
|
||||
package models
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
// ... 其他类型定义 ...
|
@@ -1,15 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"cursor-api-proxy/internal/models"
|
||||
"cursor-api-proxy/internal/utils"
|
||||
)
|
||||
|
||||
func HandleStreamResponse(c *gin.Context, req models.ChatRequest) {
|
||||
// 从 handlers.go 移动流式响应处理逻辑到这里
|
||||
}
|
||||
|
||||
func HandleNormalResponse(c *gin.Context, req models.ChatRequest) {
|
||||
// 从 handlers.go 移动普通响应处理逻辑到这里
|
||||
}
|
@@ -1,5 +0,0 @@
|
||||
package utils
|
||||
|
||||
func StringToHex(str, modelName string) []byte {
|
||||
// 从 utils.go 移动转换逻辑到这里
|
||||
}
|
@@ -1,5 +0,0 @@
|
||||
package utils
|
||||
|
||||
func ProcessChunk(chunk []byte) string {
|
||||
// 从 process.go 移动处理逻辑到这里
|
||||
}
|
106
main.go
106
main.go
@@ -1,97 +1,35 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"go-capi/handlers"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
// 在 main() 函数之前添加以下结构体定义
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := godotenv.Load(); err != nil {
|
||||
log.Println("Warning: Error loading .env file")
|
||||
}
|
||||
// 加载环境变量
|
||||
godotenv.Load()
|
||||
|
||||
r := gin.Default()
|
||||
r.POST("/v1/chat/completions", handleChat)
|
||||
|
||||
port := os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = "3000"
|
||||
}
|
||||
|
||||
log.Printf("服务器运行在端口 %s\n", port)
|
||||
// 配置CORS
|
||||
r.Use(cors.New(cors.Config{
|
||||
AllowOrigins: []string{"*"},
|
||||
AllowMethods: []string{"GET", "POST",},
|
||||
AllowHeaders: []string{"*"},
|
||||
AllowCredentials: true,
|
||||
}))
|
||||
|
||||
// 注册路由
|
||||
r.POST("/v1/chat/completions", handlers.ChatCompletions)
|
||||
r.GET("/models", handlers.GetModels)
|
||||
|
||||
// 获取端口号
|
||||
// port := os.Getenv("PORT")
|
||||
port := "3001"
|
||||
|
||||
r.Run(":" + port)
|
||||
}
|
||||
|
||||
func handleChat(c *gin.Context) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为 o1 开头的模型且请求流式输出
|
||||
if strings.HasPrefix(req.Model, "o1-") && req.Stream {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Model not supported stream"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取并处理认证token
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
authToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
|
||||
if authToken == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Authorization is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 处理消息
|
||||
if len(req.Messages) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "Messages should be a non-empty array",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 处理流式请求
|
||||
if req.Stream {
|
||||
handleStreamResponse(c, req)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理非流式请求
|
||||
handleNormalResponse(c, req)
|
||||
}
|
||||
|
||||
// 在文件末尾添加这两个新函数
|
||||
func handleStreamResponse(c *gin.Context, req ChatRequest) {
|
||||
// TODO: 实现流式响应的逻辑
|
||||
c.JSON(http.StatusNotImplemented, gin.H{
|
||||
"error": "Stream response not implemented yet",
|
||||
})
|
||||
}
|
||||
|
||||
func handleNormalResponse(c *gin.Context, req ChatRequest) {
|
||||
// TODO: 实现普通响应的逻辑
|
||||
c.JSON(http.StatusNotImplemented, gin.H{
|
||||
"error": "Normal response not implemented yet",
|
||||
})
|
||||
}
|
||||
}
|
411
main.py
Normal file
411
main.py
Normal file
@@ -0,0 +1,411 @@
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI, Request, Response, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
import httpx
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import time
|
||||
import re
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# 定义请求模型
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Message]
|
||||
stream: bool = False
|
||||
|
||||
|
||||
def string_to_hex(text: str, model_name: str) -> bytes:
|
||||
"""将文本转换为特定格式的十六进制数据"""
|
||||
# 将输入文本转换为UTF-8字节
|
||||
text_bytes = text.encode('utf-8')
|
||||
text_length = len(text_bytes)
|
||||
|
||||
# 固定常量
|
||||
FIXED_HEADER = 2
|
||||
SEPARATOR = 1
|
||||
FIXED_SUFFIX_LENGTH = 0xA3 + len(model_name)
|
||||
|
||||
# 计算第一个长度字段
|
||||
if text_length < 128:
|
||||
text_length_field1 = format(text_length, '02x')
|
||||
text_length_field_size1 = 1
|
||||
else:
|
||||
low_byte1 = (text_length & 0x7F) | 0x80
|
||||
high_byte1 = (text_length >> 7) & 0xFF
|
||||
text_length_field1 = format(low_byte1, '02x') + format(high_byte1, '02x')
|
||||
text_length_field_size1 = 2
|
||||
|
||||
# 计算基础长度字段
|
||||
base_length = text_length + 0x2A
|
||||
if base_length < 128:
|
||||
text_length_field = format(base_length, '02x')
|
||||
text_length_field_size = 1
|
||||
else:
|
||||
low_byte = (base_length & 0x7F) | 0x80
|
||||
high_byte = (base_length >> 7) & 0xFF
|
||||
text_length_field = format(low_byte, '02x') + format(high_byte, '02x')
|
||||
text_length_field_size = 2
|
||||
|
||||
# 计算总消息长度
|
||||
message_total_length = (FIXED_HEADER + text_length_field_size + SEPARATOR +
|
||||
text_length_field_size1 + text_length + FIXED_SUFFIX_LENGTH)
|
||||
|
||||
# 构造十六进制字符串
|
||||
model_name_bytes = model_name.encode('utf-8')
|
||||
model_name_length_hex = format(len(model_name_bytes), '02X')
|
||||
model_name_hex = model_name_bytes.hex().upper()
|
||||
|
||||
hex_string = (
|
||||
f"{message_total_length:010x}"
|
||||
"12"
|
||||
f"{text_length_field}"
|
||||
"0A"
|
||||
f"{text_length_field1}"
|
||||
f"{text_bytes.hex()}"
|
||||
"10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612"
|
||||
"2002A132F643A2F6964656150726F2F656475626F73733A1E0A"
|
||||
f"{model_name_length_hex}"
|
||||
f"{model_name_hex}"
|
||||
"22004A"
|
||||
"2461383761396133342D323164642D343863372D623434662D616636633365636536663765"
|
||||
"680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061"
|
||||
"800101B00100C00100E00100E80100"
|
||||
).upper()
|
||||
|
||||
return bytes.fromhex(hex_string)
|
||||
|
||||
|
||||
def chunk_to_utf8_string(chunk: bytes) -> str:
|
||||
"""将二进制chunk转换为UTF-8字符串"""
|
||||
if not chunk or len(chunk) < 2:
|
||||
return ''
|
||||
|
||||
if chunk[0] in [0x01, 0x02] or (chunk[0] == 0x60 and chunk[1] == 0x0C):
|
||||
return ''
|
||||
|
||||
# 记录原始chunk的十六进制(调试用)
|
||||
print(f"chunk length: {len(chunk)}")
|
||||
# print(f"chunk hex: {chunk.hex()}")
|
||||
|
||||
try:
|
||||
# 去掉0x0A之前的所有字节
|
||||
try:
|
||||
chunk = chunk[chunk.index(0x0A) + 1:]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
filtered_chunk = bytearray()
|
||||
i = 0
|
||||
while i < len(chunk):
|
||||
# 检查是否有连续的0x00
|
||||
if i + 4 <= len(chunk) and all(chunk[j] == 0x00 for j in range(i, i + 4)):
|
||||
i += 4
|
||||
while i < len(chunk) and chunk[i] <= 0x0F:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if chunk[i] == 0x0C:
|
||||
i += 1
|
||||
while i < len(chunk) and chunk[i] == 0x0A:
|
||||
i += 1
|
||||
else:
|
||||
filtered_chunk.append(chunk[i])
|
||||
i += 1
|
||||
|
||||
# 过滤掉特定字节
|
||||
filtered_chunk = bytes(b for b in filtered_chunk
|
||||
if b != 0x00 and b != 0x0C)
|
||||
|
||||
if not filtered_chunk:
|
||||
return ''
|
||||
|
||||
result = filtered_chunk.decode('utf-8', errors='ignore').strip()
|
||||
# print(f"decoded result: {result}") # 调试输出
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in chunk_to_utf8_string: {str(e)}")
|
||||
return ''
|
||||
|
||||
|
||||
async def process_stream(chunks, ):
|
||||
"""处理流式响应"""
|
||||
response_id = f"chatcmpl-{str(uuid.uuid4())}"
|
||||
|
||||
# 先将所有chunks读取到列表中
|
||||
# chunks = []
|
||||
# async for chunk in response.aiter_raw():
|
||||
# chunks.append(chunk)
|
||||
|
||||
# 然后处理保存的chunks
|
||||
for chunk in chunks:
|
||||
text = chunk_to_utf8_string(chunk)
|
||||
if text:
|
||||
# 清理文本
|
||||
text = text.strip()
|
||||
if "<|END_USER|>" in text:
|
||||
text = text.split("<|END_USER|>")[-1].strip()
|
||||
if text and text[0].isalpha():
|
||||
text = text[1:].strip()
|
||||
text = re.sub(r"[\x00-\x1F\x7F]", "", text)
|
||||
|
||||
if text: # 确保清理后的文本不为空
|
||||
data_body = {
|
||||
"id": response_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": text
|
||||
}
|
||||
}]
|
||||
}
|
||||
yield f"data: {json.dumps(data_body, ensure_ascii=False)}\n\n"
|
||||
# yield "data: {\n"
|
||||
# yield f' "id": "{response_id}",\n'
|
||||
# yield ' "object": "chat.completion.chunk",\n'
|
||||
# yield f' "created": {int(time.time())},\n'
|
||||
# yield ' "choices": [{\n'
|
||||
# yield ' "index": 0,\n'
|
||||
# yield ' "delta": {\n'
|
||||
# yield f' "content": "{text}"\n'
|
||||
# yield " }\n"
|
||||
# yield " }]\n"
|
||||
# yield "}\n\n"
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request, chat_request: ChatRequest):
|
||||
# 验证o1模型不支持流式输出
|
||||
if chat_request.model.startswith('o1-') and chat_request.stream:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Model not supported stream"
|
||||
)
|
||||
|
||||
# 获取并处理认证令牌
|
||||
auth_header = request.headers.get('authorization', '')
|
||||
if not auth_header.startswith('Bearer '):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid authorization header"
|
||||
)
|
||||
|
||||
auth_token = auth_header.replace('Bearer ', '')
|
||||
if not auth_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing authorization token"
|
||||
)
|
||||
|
||||
# 处理多个密钥
|
||||
keys = [key.strip() for key in auth_token.split(',')]
|
||||
if keys:
|
||||
auth_token = keys[0] # 使用第一个密钥
|
||||
|
||||
if '%3A%3A' in auth_token:
|
||||
auth_token = auth_token.split('%3A%3A')[1]
|
||||
|
||||
# 格式化消息
|
||||
formatted_messages = "\n".join(
|
||||
f"{msg.role}:{msg.content}" for msg in chat_request.messages
|
||||
)
|
||||
|
||||
# 生成请求数据
|
||||
hex_data = string_to_hex(formatted_messages, chat_request.model)
|
||||
|
||||
# 准备请求头
|
||||
headers = {
|
||||
'Content-Type': 'application/connect+proto',
|
||||
'Authorization': f'Bearer {auth_token}',
|
||||
'Connect-Accept-Encoding': 'gzip,br',
|
||||
'Connect-Protocol-Version': '1',
|
||||
'User-Agent': 'connect-es/1.4.0',
|
||||
'X-Amzn-Trace-Id': f'Root={str(uuid.uuid4())}',
|
||||
'X-Cursor-Checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
|
||||
'X-Cursor-Client-Version': '0.42.3',
|
||||
'X-Cursor-Timezone': 'Asia/Shanghai',
|
||||
'X-Ghost-Mode': 'false',
|
||||
'X-Request-Id': str(uuid.uuid4()),
|
||||
'Host': 'api2.cursor.sh'
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
|
||||
try:
|
||||
# 使用 stream=True 参数
|
||||
# 打印 headers 和 二进制 data
|
||||
print(f"headers: {headers}")
|
||||
print(hex_data)
|
||||
async with client.stream(
|
||||
'POST',
|
||||
'https://api2.cursor.sh/aiserver.v1.AiService/StreamChat',
|
||||
headers=headers,
|
||||
content=hex_data,
|
||||
timeout=None
|
||||
) as response:
|
||||
if chat_request.stream:
|
||||
chunks = []
|
||||
async for chunk in response.aiter_raw():
|
||||
chunks.append(chunk)
|
||||
return StreamingResponse(
|
||||
process_stream(chunks),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
# 非流式响应处理
|
||||
text = ''
|
||||
async for chunk in response.aiter_raw():
|
||||
# print('chunk:', chunk.hex())
|
||||
print('chunk length:', len(chunk))
|
||||
|
||||
res = chunk_to_utf8_string(chunk)
|
||||
# print('res:', res)
|
||||
if res:
|
||||
text += res
|
||||
|
||||
# 清理响应文本
|
||||
import re
|
||||
text = re.sub(r'^.*<\|END_USER\|>', '', text, flags=re.DOTALL)
|
||||
text = re.sub(r'^\n[a-zA-Z]?', '', text).strip()
|
||||
text = re.sub(r'[\x00-\x1F\x7F]', '', text)
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{str(uuid.uuid4())}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": chat_request.model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
|
||||
@app.post("/models")
|
||||
async def models():
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "claude-3-5-sonnet-20241022",
|
||||
"object": "model",
|
||||
"created": 1713744000,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
{
|
||||
"id": "claude-3-opus",
|
||||
"object": "model",
|
||||
"created": 1709251200,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
{
|
||||
"id": "claude-3.5-haiku",
|
||||
"object": "model",
|
||||
"created": 1711929600,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
{
|
||||
"id": "claude-3.5-sonnet",
|
||||
"object": "model",
|
||||
"created": 1711929600,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
{
|
||||
"id": "cursor-small",
|
||||
"object": "model",
|
||||
"created": 1712534400,
|
||||
"owned_by": "cursor"
|
||||
},
|
||||
{
|
||||
"id": "gpt-3.5-turbo",
|
||||
"object": "model",
|
||||
"created": 1677649200,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-4",
|
||||
"object": "model",
|
||||
"created": 1687392000,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-4-turbo-2024-04-09",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-4o-mini",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "o1-mini",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "o1-preview",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
port = int(os.getenv("PORT", "3001"))
|
||||
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True, timeout_keep_alive=30)
|
50
models/models.go
Normal file
50
models/models.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package models
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type Delta struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type Choice struct {
|
||||
Index int `json:"index"`
|
||||
Delta Delta `json:"delta,omitempty"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Choices []Choice `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type ModelData struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
}
|
||||
|
||||
type ModelsResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []ModelData `json:"data"`
|
||||
}
|
190
process.go
190
process.go
@@ -1,190 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
// "encoding/hex"
|
||||
// "log"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
|
||||
// // 封装函数,用于将 chunk 转换为 UTF-8 字符串
|
||||
// function chunkToUtf8String (chunk) {
|
||||
// if (chunk[0] === 0x01 || chunk[0] === 0x02 || (chunk[0] === 0x60 && chunk[1] === 0x0C)) {
|
||||
// return ''
|
||||
// }
|
||||
|
||||
// console.log('chunk:', Buffer.from(chunk).toString('hex'))
|
||||
// console.log('chunk string:', Buffer.from(chunk).toString('utf-8'))
|
||||
|
||||
// // 去掉 chunk 中 0x0A 以及之前的字符
|
||||
// chunk = chunk.slice(chunk.indexOf(0x0A) + 1)
|
||||
|
||||
// let filteredChunk = []
|
||||
// let i = 0
|
||||
// while (i < chunk.length) {
|
||||
// // 新的条件过滤:如果遇到连续4个0x00,则移除其之后所有的以 0 开头的字节(0x00 到 0x0F)
|
||||
// if (chunk.slice(i, i + 4).every(byte => byte === 0x00)) {
|
||||
// i += 4 // 跳过这4个0x00
|
||||
// while (i < chunk.length && chunk[i] >= 0x00 && chunk[i] <= 0x0F) {
|
||||
// i++ // 跳过所有以 0 开头的字节
|
||||
// }
|
||||
// continue
|
||||
// }
|
||||
|
||||
// if (chunk[i] === 0x0C) {
|
||||
// // 遇到 0x0C 时,跳过 0x0C 以及后续的所有连续的 0x0A
|
||||
// i++ // 跳过 0x0C
|
||||
// while (i < chunk.length && chunk[i] === 0x0A) {
|
||||
// i++ // 跳过所有连续的 0x0A
|
||||
// }
|
||||
// } else if (
|
||||
// i > 0 &&
|
||||
// chunk[i] === 0x0A &&
|
||||
// chunk[i - 1] >= 0x00 &&
|
||||
// chunk[i - 1] <= 0x09
|
||||
// ) {
|
||||
// // 如果当前字节是 0x0A,且前一个字节在 0x00 至 0x09 之间,跳过前一个字节和当前字节
|
||||
// filteredChunk.pop() // 移除已添加的前一个字节
|
||||
// i++ // 跳过当前的 0x0A
|
||||
// } else {
|
||||
// filteredChunk.push(chunk[i])
|
||||
// i++
|
||||
// }
|
||||
// }
|
||||
|
||||
// // 第二步:去除所有的 0x00 和 0x0C
|
||||
// filteredChunk = filteredChunk.filter((byte) => byte !== 0x00 && byte !== 0x0C)
|
||||
|
||||
// // 去除小于 0x0A 的字节
|
||||
// filteredChunk = filteredChunk.filter((byte) => byte >= 0x0A)
|
||||
|
||||
// const hexString = Buffer.from(filteredChunk).toString('hex')
|
||||
// console.log('hexString:', hexString)
|
||||
// const utf8String = Buffer.from(filteredChunk).toString('utf-8')
|
||||
// console.log('utf8String:', utf8String)
|
||||
// return utf8String
|
||||
// }
|
||||
// func processChunk(chunk []byte) string {
|
||||
// // 检查特殊字节开头的情况
|
||||
// if len(chunk) > 0 && (chunk[0] == 0x01 || chunk[0] == 0x02 || (len(chunk) > 1 && chunk[0] == 0x60 && chunk[1] == 0x0C)) {
|
||||
// return ""
|
||||
// }
|
||||
|
||||
// // 打印调试信息
|
||||
// fmt.Printf("chunk: %x\n", chunk)
|
||||
// fmt.Printf("chunk string: %s\n", string(chunk))
|
||||
|
||||
// // 找到第一个 0x0A 并截取之后的内容
|
||||
// index := bytes.IndexByte(chunk, 0x0A)
|
||||
// if index != -1 {
|
||||
// chunk = chunk[index+1:]
|
||||
// }
|
||||
|
||||
// // 创建过滤后的切片
|
||||
// filteredChunk := make([]byte, 0, len(chunk))
|
||||
// for i := 0; i < len(chunk); {
|
||||
// // 检查连续4个0x00的情况
|
||||
// if i+4 <= len(chunk) {
|
||||
// if chunk[i] == 0x00 && chunk[i+1] == 0x00 && chunk[i+2] == 0x00 && chunk[i+3] == 0x00 {
|
||||
// i += 4
|
||||
// // 跳过所有以0开头的字节
|
||||
// for i < len(chunk) && chunk[i] <= 0x0F {
|
||||
// i++
|
||||
// }
|
||||
// continue
|
||||
// }
|
||||
// }
|
||||
|
||||
// if chunk[i] == 0x0C {
|
||||
// i++
|
||||
// // 跳过所有连续的0x0A
|
||||
// for i < len(chunk) && chunk[i] == 0x0A {
|
||||
// i++
|
||||
// }
|
||||
// } else if i > 0 && chunk[i] == 0x0A && chunk[i-1] >= 0x00 && chunk[i-1] <= 0x09 {
|
||||
// // 移除前一个字节并跳过当前的0x0A
|
||||
// filteredChunk = filteredChunk[:len(filteredChunk)-1]
|
||||
// i++
|
||||
// } else {
|
||||
// filteredChunk = append(filteredChunk, chunk[i])
|
||||
// i++
|
||||
// }
|
||||
// }
|
||||
|
||||
// // 过滤掉0x00和0x0C
|
||||
// tempChunk := make([]byte, 0, len(filteredChunk))
|
||||
// for _, b := range filteredChunk {
|
||||
// if b != 0x00 && b != 0x0C {
|
||||
// tempChunk = append(tempChunk, b)
|
||||
// }
|
||||
// }
|
||||
// filteredChunk = tempChunk
|
||||
|
||||
// // 过滤掉小于0x0A的字节
|
||||
// tempChunk = make([]byte, 0, len(filteredChunk))
|
||||
// for _, b := range filteredChunk {
|
||||
// if b >= 0x0A {
|
||||
// tempChunk = append(tempChunk, b)
|
||||
// }
|
||||
// }
|
||||
// filteredChunk = tempChunk
|
||||
|
||||
// // 打印调试信息并返回结果
|
||||
// fmt.Printf("hexString: %x\n", filteredChunk)
|
||||
// result := string(filteredChunk)
|
||||
// fmt.Printf("utf8String: %s\n", result)
|
||||
// return result
|
||||
// }
|
||||
|
||||
func processChunk(chunk []byte) string {
|
||||
// 检查特殊字节开头的情况
|
||||
if len(chunk) > 0 && (chunk[0] == 0x01 || chunk[0] == 0x02 || (len(chunk) > 1 && chunk[0] == 0x60 && chunk[1] == 0x0C)) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 打印调试信息
|
||||
fmt.Printf("chunk: %x\n", chunk)
|
||||
fmt.Printf("chunk string: %s\n", string(chunk))
|
||||
|
||||
// 找到第一个 0x0A 并截取之后的内容
|
||||
index := bytes.IndexByte(chunk, 0x0A)
|
||||
if index != -1 {
|
||||
chunk = chunk[index+1:]
|
||||
}
|
||||
|
||||
// 创建过滤后的切片
|
||||
filteredChunk := make([]byte, 0, len(chunk))
|
||||
for i := 0; i < len(chunk); {
|
||||
// 检查连续4个0x00的情况
|
||||
if i+4 <= len(chunk) {
|
||||
allZeros := true
|
||||
for j := 0; j < 4; j++ {
|
||||
if chunk[i+j] != 0x00 {
|
||||
allZeros = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allZeros {
|
||||
i += 4
|
||||
// 跳过所有以0开头的字节
|
||||
for i < len(chunk) && chunk[i] <= 0x0F {
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 保留UTF-8字符
|
||||
if chunk[i] >= 0xE0 || (chunk[i] >= 0x20 && chunk[i] <= 0x7F) {
|
||||
filteredChunk = append(filteredChunk, chunk[i])
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
// 打印调试信息并返回结果
|
||||
fmt.Printf("hexString: %x\n", filteredChunk)
|
||||
result := string(filteredChunk)
|
||||
fmt.Printf("utf8String: %s\n", result)
|
||||
return string(chunk)
|
||||
}
|
45
rs-capi/.vscode/launch.json
vendored
Normal file
45
rs-capi/.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
{
|
||||
// 使用 IntelliSense 了解相关属性。
|
||||
// 悬停以查看现有属性的描述。
|
||||
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug executable 'rs-api'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"build",
|
||||
"--bin=rs-api",
|
||||
"--package=rs-api"
|
||||
],
|
||||
"filter": {
|
||||
"name": "rs-api",
|
||||
"kind": "bin"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug unit tests in executable 'rs-api'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--bin=rs-api",
|
||||
"--package=rs-api"
|
||||
],
|
||||
"filter": {
|
||||
"name": "rs-api",
|
||||
"kind": "bin"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
}
|
||||
]
|
||||
}
|
1965
rs-capi/Cargo.lock
generated
Normal file
1965
rs-capi/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
23
rs-capi/Cargo.toml
Normal file
23
rs-capi/Cargo.toml
Normal file
@@ -0,0 +1,23 @@
|
||||
[package]
|
||||
name = "rs-api"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "0.7", features = ["json"] }
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
reqwest = { version = "0.11", features = ["json", "stream"] }
|
||||
tower-http = { version = "0.5", features = ["cors", "trace"] }
|
||||
uuid = { version = "1.0", features = ["v4"] }
|
||||
dotenv = "0.15"
|
||||
chrono = "0.4"
|
||||
futures = "0.3"
|
||||
bytes = "1.0"
|
||||
regex = "1.5"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
hex = "0.4"
|
||||
hyper = "1.5.1"
|
||||
http = "1.1.0"
|
822
rs-capi/main.py
Normal file
822
rs-capi/main.py
Normal file
@@ -0,0 +1,822 @@
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI, Request, Response, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
import httpx
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import time
|
||||
import re
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# 定义请求模型
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Message]
|
||||
stream: bool = False
|
||||
|
||||
|
||||
def string_to_hex(text: str, model_name: str) -> bytes:
|
||||
"""将文本转换为特定格式的十六进制数据"""
|
||||
# 将输入文本转换为UTF-8字节
|
||||
text_bytes = text.encode('utf-8')
|
||||
text_length = len(text_bytes)
|
||||
|
||||
# 固定常量
|
||||
FIXED_HEADER = 2
|
||||
SEPARATOR = 1
|
||||
FIXED_SUFFIX_LENGTH = 0xA3 + len(model_name)
|
||||
|
||||
# 计算第一个长度字段
|
||||
if text_length < 128:
|
||||
text_length_field1 = format(text_length, '02x')
|
||||
text_length_field_size1 = 1
|
||||
else:
|
||||
low_byte1 = (text_length & 0x7F) | 0x80
|
||||
high_byte1 = (text_length >> 7) & 0xFF
|
||||
text_length_field1 = format(low_byte1, '02x') + format(high_byte1, '02x')
|
||||
text_length_field_size1 = 2
|
||||
|
||||
# 计算基础长度字段
|
||||
base_length = text_length + 0x2A
|
||||
if base_length < 128:
|
||||
text_length_field = format(base_length, '02x')
|
||||
text_length_field_size = 1
|
||||
else:
|
||||
low_byte = (base_length & 0x7F) | 0x80
|
||||
high_byte = (base_length >> 7) & 0xFF
|
||||
text_length_field = format(low_byte, '02x') + format(high_byte, '02x')
|
||||
text_length_field_size = 2
|
||||
|
||||
# 计算总消息长度
|
||||
message_total_length = (FIXED_HEADER + text_length_field_size + SEPARATOR +
|
||||
text_length_field_size1 + text_length + FIXED_SUFFIX_LENGTH)
|
||||
|
||||
# 构造十六进制字符串
|
||||
model_name_bytes = model_name.encode('utf-8')
|
||||
model_name_length_hex = format(len(model_name_bytes), '02X')
|
||||
model_name_hex = model_name_bytes.hex().upper()
|
||||
|
||||
hex_string = (
|
||||
f"{message_total_length:010x}"
|
||||
"12"
|
||||
f"{text_length_field}"
|
||||
"0A"
|
||||
f"{text_length_field1}"
|
||||
f"{text_bytes.hex()}"
|
||||
"10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612"
|
||||
"2002A132F643A2F6964656150726F2F656475626F73733A1E0A"
|
||||
f"{model_name_length_hex}"
|
||||
f"{model_name_hex}"
|
||||
"22004A"
|
||||
"2461383761396133342D323164642D343863372D623434662D616636633365636536663765"
|
||||
"680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061"
|
||||
"800101B00100C00100E00100E80100"
|
||||
).upper()
|
||||
|
||||
return bytes.fromhex(hex_string)
|
||||
|
||||
|
||||
def chunk_to_utf8_string(chunk: bytes) -> str:
|
||||
"""将二进制chunk转换为UTF-8字符串"""
|
||||
if not chunk or len(chunk) < 2:
|
||||
return ''
|
||||
|
||||
if chunk[0] in [0x01, 0x02] or (chunk[0] == 0x60 and chunk[1] == 0x0C):
|
||||
return ''
|
||||
|
||||
# 记录原始chunk的十六进制(调试用)
|
||||
print(f"chunk length: {len(chunk)}")
|
||||
# print(f"chunk hex: {chunk.hex()}")
|
||||
|
||||
try:
|
||||
# 去掉0x0A之前的所有字节
|
||||
try:
|
||||
chunk = chunk[chunk.index(0x0A) + 1:]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
filtered_chunk = bytearray()
|
||||
i = 0
|
||||
while i < len(chunk):
|
||||
# 检查是否有连续的0x00
|
||||
if i + 4 <= len(chunk) and all(chunk[j] == 0x00 for j in range(i, i + 4)):
|
||||
i += 4
|
||||
while i < len(chunk) and chunk[i] <= 0x0F:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if chunk[i] == 0x0C:
|
||||
i += 1
|
||||
while i < len(chunk) and chunk[i] == 0x0A:
|
||||
i += 1
|
||||
else:
|
||||
filtered_chunk.append(chunk[i])
|
||||
i += 1
|
||||
|
||||
# 过滤掉特定字节
|
||||
filtered_chunk = bytes(b for b in filtered_chunk
|
||||
if b != 0x00 and b != 0x0C)
|
||||
|
||||
if not filtered_chunk:
|
||||
return ''
|
||||
|
||||
result = filtered_chunk.decode('utf-8', errors='ignore').strip()
|
||||
# print(f"decoded result: {result}") # 调试输出
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in chunk_to_utf8_string: {str(e)}")
|
||||
return ''
|
||||
|
||||
|
||||
async def process_stream(chunks, ):
|
||||
"""处理流式响应"""
|
||||
response_id = f"chatcmpl-{str(uuid.uuid4())}"
|
||||
|
||||
# 先将所有chunks读取到列表中
|
||||
# chunks = []
|
||||
# async for chunk in response.aiter_raw():
|
||||
# chunks.append(chunk)
|
||||
|
||||
# 然后处理保存的chunks
|
||||
for chunk in chunks:
|
||||
text = chunk_to_utf8_string(chunk)
|
||||
if text:
|
||||
# 清理文本
|
||||
text = text.strip()
|
||||
if "<|END_USER|>" in text:
|
||||
text = text.split("<|END_USER|>")[-1].strip()
|
||||
if text and text[0].isalpha():
|
||||
text = text[1:].strip()
|
||||
text = re.sub(r"[\x00-\x1F\x7F]", "", text)
|
||||
|
||||
if text: # 确保清理后的文本不为空
|
||||
data_body = {
|
||||
"id": response_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": text
|
||||
}
|
||||
}]
|
||||
}
|
||||
yield f"data: {json.dumps(data_body, ensure_ascii=False)}\n\n"
|
||||
# yield "data: {\n"
|
||||
# yield f' "id": "{response_id}",\n'
|
||||
# yield ' "object": "chat.completion.chunk",\n'
|
||||
# yield f' "created": {int(time.time())},\n'
|
||||
# yield ' "choices": [{\n'
|
||||
# yield ' "index": 0,\n'
|
||||
# yield ' "delta": {\n'
|
||||
# yield f' "content": "{text}"\n'
|
||||
# yield " }\n"
|
||||
# yield " }]\n"
|
||||
# yield "}\n\n"
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request, chat_request: ChatRequest):
|
||||
# 验证o1模型不支持流式输出
|
||||
if chat_request.model.startswith('o1-') and chat_request.stream:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Model not supported stream"
|
||||
)
|
||||
|
||||
# 获取并处理认证令牌
|
||||
auth_header = request.headers.get('authorization', '')
|
||||
if not auth_header.startswith('Bearer '):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid authorization header"
|
||||
)
|
||||
|
||||
auth_token = auth_header.replace('Bearer ', '')
|
||||
if not auth_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing authorization token"
|
||||
)
|
||||
|
||||
# 处理多个密钥
|
||||
keys = [key.strip() for key in auth_token.split(',')]
|
||||
if keys:
|
||||
auth_token = keys[0] # 使用第一个密钥
|
||||
|
||||
if '%3A%3A' in auth_token:
|
||||
auth_token = auth_token.split('%3A%3A')[1]
|
||||
|
||||
# 格式化消息
|
||||
formatted_messages = "\n".join(
|
||||
f"{msg.role}:{msg.content}" for msg in chat_request.messages
|
||||
)
|
||||
|
||||
# 生成请求数据
|
||||
hex_data = string_to_hex(formatted_messages, chat_request.model)
|
||||
|
||||
# 准备请求头
|
||||
headers = {
|
||||
'Content-Type': 'application/connect+proto',
|
||||
'Authorization': f'Bearer {auth_token}',
|
||||
'Connect-Accept-Encoding': 'gzip,br',
|
||||
'Connect-Protocol-Version': '1',
|
||||
'User-Agent': 'connect-es/1.4.0',
|
||||
'X-Amzn-Trace-Id': f'Root={str(uuid.uuid4())}',
|
||||
'X-Cursor-Checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
|
||||
'X-Cursor-Client-Version': '0.42.3',
|
||||
'X-Cursor-Timezone': 'Asia/Shanghai',
|
||||
'X-Ghost-Mode': 'false',
|
||||
'X-Request-Id': str(uuid.uuid4()),
|
||||
'Host': 'api2.cursor.sh'
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
|
||||
try:
|
||||
# 使用 stream=True 参数
|
||||
# 打印 headers 和 二进制 data
|
||||
print(f"headers: {headers}")
|
||||
print(hex_data)
|
||||
async with client.stream(
|
||||
'POST',
|
||||
'https://api2.cursor.sh/aiserver.v1.AiService/StreamChat',
|
||||
headers=headers,
|
||||
content=hex_data,
|
||||
timeout=None
|
||||
) as response:
|
||||
if chat_request.stream:
|
||||
chunks = []
|
||||
async for chunk in response.aiter_raw():
|
||||
chunks.append(chunk)
|
||||
return StreamingResponse(
|
||||
process_stream(chunks),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
# 非流式响应处理
|
||||
text = ''
|
||||
async for chunk in response.aiter_raw():
|
||||
# print('chunk:', chunk.hex())
|
||||
print('chunk length:', len(chunk))
|
||||
|
||||
res = chunk_to_utf8_string(chunk)
|
||||
# print('res:', res)
|
||||
if res:
|
||||
text += res
|
||||
|
||||
# 清理响应文本
|
||||
import re
|
||||
text = re.sub(r'^.*<\|END_USER\|>', '', text, flags=re.DOTALL)
|
||||
text = re.sub(r'^\n[a-zA-Z]?', '', text).strip()
|
||||
text = re.sub(r'[\x00-\x1F\x7F]', '', text)
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{str(uuid.uuid4())}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": chat_request.model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
|
||||
@app.post("/models")
|
||||
async def models():
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "claude-3-5-sonnet-20241022",
|
||||
"object": "model",
|
||||
"created": 1713744000,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
{
|
||||
"id": "claude-3-opus",
|
||||
"object": "model",
|
||||
"created": 1709251200,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
{
|
||||
"id": "claude-3.5-haiku",
|
||||
"object": "model",
|
||||
"created": 1711929600,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
{
|
||||
"id": "claude-3.5-sonnet",
|
||||
"object": "model",
|
||||
"created": 1711929600,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
{
|
||||
"id": "cursor-small",
|
||||
"object": "model",
|
||||
"created": 1712534400,
|
||||
"owned_by": "cursor"
|
||||
},
|
||||
{
|
||||
"id": "gpt-3.5-turbo",
|
||||
"object": "model",
|
||||
"created": 1677649200,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-4",
|
||||
"object": "model",
|
||||
"created": 1687392000,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-4-turbo-2024-04-09",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-4o-mini",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "o1-mini",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "o1-preview",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
port = int(os.getenv("PORT", "3001"))
|
||||
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True, timeout_keep_alive=30)
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI, Request, Response, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
import httpx
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import time
|
||||
import re
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# 定义请求模型
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Message]
|
||||
stream: bool = False
|
||||
|
||||
|
||||
def string_to_hex(text: str, model_name: str) -> bytes:
|
||||
"""将文本转换为特定格式的十六进制数据"""
|
||||
# 将输入文本转换为UTF-8字节
|
||||
text_bytes = text.encode('utf-8')
|
||||
text_length = len(text_bytes)
|
||||
|
||||
# 固定常量
|
||||
FIXED_HEADER = 2
|
||||
SEPARATOR = 1
|
||||
FIXED_SUFFIX_LENGTH = 0xA3 + len(model_name)
|
||||
|
||||
# 计算第一个长度字段
|
||||
if text_length < 128:
|
||||
text_length_field1 = format(text_length, '02x')
|
||||
text_length_field_size1 = 1
|
||||
else:
|
||||
low_byte1 = (text_length & 0x7F) | 0x80
|
||||
high_byte1 = (text_length >> 7) & 0xFF
|
||||
text_length_field1 = format(low_byte1, '02x') + format(high_byte1, '02x')
|
||||
text_length_field_size1 = 2
|
||||
|
||||
# 计算基础长度字段
|
||||
base_length = text_length + 0x2A
|
||||
if base_length < 128:
|
||||
text_length_field = format(base_length, '02x')
|
||||
text_length_field_size = 1
|
||||
else:
|
||||
low_byte = (base_length & 0x7F) | 0x80
|
||||
high_byte = (base_length >> 7) & 0xFF
|
||||
text_length_field = format(low_byte, '02x') + format(high_byte, '02x')
|
||||
text_length_field_size = 2
|
||||
|
||||
# 计算总消息长度
|
||||
message_total_length = (FIXED_HEADER + text_length_field_size + SEPARATOR +
|
||||
text_length_field_size1 + text_length + FIXED_SUFFIX_LENGTH)
|
||||
|
||||
# 构造十六进制字符串
|
||||
model_name_bytes = model_name.encode('utf-8')
|
||||
model_name_length_hex = format(len(model_name_bytes), '02X')
|
||||
model_name_hex = model_name_bytes.hex().upper()
|
||||
|
||||
hex_string = (
|
||||
f"{message_total_length:010x}"
|
||||
"12"
|
||||
f"{text_length_field}"
|
||||
"0A"
|
||||
f"{text_length_field1}"
|
||||
f"{text_bytes.hex()}"
|
||||
"10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612"
|
||||
"2002A132F643A2F6964656150726F2F656475626F73733A1E0A"
|
||||
f"{model_name_length_hex}"
|
||||
f"{model_name_hex}"
|
||||
"22004A"
|
||||
"2461383761396133342D323164642D343863372D623434662D616636633365636536663765"
|
||||
"680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061"
|
||||
"800101B00100C00100E00100E80100"
|
||||
).upper()
|
||||
|
||||
return bytes.fromhex(hex_string)
|
||||
|
||||
|
||||
def chunk_to_utf8_string(chunk: bytes) -> str:
|
||||
"""将二进制chunk转换为UTF-8字符串"""
|
||||
if not chunk or len(chunk) < 2:
|
||||
return ''
|
||||
|
||||
if chunk[0] in [0x01, 0x02] or (chunk[0] == 0x60 and chunk[1] == 0x0C):
|
||||
return ''
|
||||
|
||||
# 记录原始chunk的十六进制(调试用)
|
||||
print(f"chunk length: {len(chunk)}")
|
||||
# print(f"chunk hex: {chunk.hex()}")
|
||||
|
||||
try:
|
||||
# 去掉0x0A之前的所有字节
|
||||
try:
|
||||
chunk = chunk[chunk.index(0x0A) + 1:]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
filtered_chunk = bytearray()
|
||||
i = 0
|
||||
while i < len(chunk):
|
||||
# 检查是否有连续的0x00
|
||||
if i + 4 <= len(chunk) and all(chunk[j] == 0x00 for j in range(i, i + 4)):
|
||||
i += 4
|
||||
while i < len(chunk) and chunk[i] <= 0x0F:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if chunk[i] == 0x0C:
|
||||
i += 1
|
||||
while i < len(chunk) and chunk[i] == 0x0A:
|
||||
i += 1
|
||||
else:
|
||||
filtered_chunk.append(chunk[i])
|
||||
i += 1
|
||||
|
||||
# 过滤掉特定字节
|
||||
filtered_chunk = bytes(b for b in filtered_chunk
|
||||
if b != 0x00 and b != 0x0C)
|
||||
|
||||
if not filtered_chunk:
|
||||
return ''
|
||||
|
||||
result = filtered_chunk.decode('utf-8', errors='ignore').strip()
|
||||
# print(f"decoded result: {result}") # 调试输出
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in chunk_to_utf8_string: {str(e)}")
|
||||
return ''
|
||||
|
||||
|
||||
async def process_stream(chunks, ):
|
||||
"""处理流式响应"""
|
||||
response_id = f"chatcmpl-{str(uuid.uuid4())}"
|
||||
|
||||
# 先将所有chunks读取到列表中
|
||||
# chunks = []
|
||||
# async for chunk in response.aiter_raw():
|
||||
# chunks.append(chunk)
|
||||
|
||||
# 然后处理保存的chunks
|
||||
for chunk in chunks:
|
||||
text = chunk_to_utf8_string(chunk)
|
||||
if text:
|
||||
# 清理文本
|
||||
text = text.strip()
|
||||
if "<|END_USER|>" in text:
|
||||
text = text.split("<|END_USER|>")[-1].strip()
|
||||
if text and text[0].isalpha():
|
||||
text = text[1:].strip()
|
||||
text = re.sub(r"[\x00-\x1F\x7F]", "", text)
|
||||
|
||||
if text: # 确保清理后的文本不为空
|
||||
data_body = {
|
||||
"id": response_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": text
|
||||
}
|
||||
}]
|
||||
}
|
||||
yield f"data: {json.dumps(data_body, ensure_ascii=False)}\n\n"
|
||||
# yield "data: {\n"
|
||||
# yield f' "id": "{response_id}",\n'
|
||||
# yield ' "object": "chat.completion.chunk",\n'
|
||||
# yield f' "created": {int(time.time())},\n'
|
||||
# yield ' "choices": [{\n'
|
||||
# yield ' "index": 0,\n'
|
||||
# yield ' "delta": {\n'
|
||||
# yield f' "content": "{text}"\n'
|
||||
# yield " }\n"
|
||||
# yield " }]\n"
|
||||
# yield "}\n\n"
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request, chat_request: ChatRequest):
|
||||
# 验证o1模型不支持流式输出
|
||||
if chat_request.model.startswith('o1-') and chat_request.stream:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Model not supported stream"
|
||||
)
|
||||
|
||||
# 获取并处理认证令牌
|
||||
auth_header = request.headers.get('authorization', '')
|
||||
if not auth_header.startswith('Bearer '):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid authorization header"
|
||||
)
|
||||
|
||||
auth_token = auth_header.replace('Bearer ', '')
|
||||
if not auth_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing authorization token"
|
||||
)
|
||||
|
||||
# 处理多个密钥
|
||||
keys = [key.strip() for key in auth_token.split(',')]
|
||||
if keys:
|
||||
auth_token = keys[0] # 使用第一个密钥
|
||||
|
||||
if '%3A%3A' in auth_token:
|
||||
auth_token = auth_token.split('%3A%3A')[1]
|
||||
|
||||
# 格式化消息
|
||||
formatted_messages = "\n".join(
|
||||
f"{msg.role}:{msg.content}" for msg in chat_request.messages
|
||||
)
|
||||
|
||||
# 生成请求数据
|
||||
hex_data = string_to_hex(formatted_messages, chat_request.model)
|
||||
|
||||
# 准备请求头
|
||||
headers = {
|
||||
'Content-Type': 'application/connect+proto',
|
||||
'Authorization': f'Bearer {auth_token}',
|
||||
'Connect-Accept-Encoding': 'gzip,br',
|
||||
'Connect-Protocol-Version': '1',
|
||||
'User-Agent': 'connect-es/1.4.0',
|
||||
'X-Amzn-Trace-Id': f'Root={str(uuid.uuid4())}',
|
||||
'X-Cursor-Checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
|
||||
'X-Cursor-Client-Version': '0.42.3',
|
||||
'X-Cursor-Timezone': 'Asia/Shanghai',
|
||||
'X-Ghost-Mode': 'false',
|
||||
'X-Request-Id': str(uuid.uuid4()),
|
||||
'Host': 'api2.cursor.sh'
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
|
||||
try:
|
||||
# 使用 stream=True 参数
|
||||
# 打印 headers 和 二进制 data
|
||||
print(f"headers: {headers}")
|
||||
print(hex_data)
|
||||
async with client.stream(
|
||||
'POST',
|
||||
'https://api2.cursor.sh/aiserver.v1.AiService/StreamChat',
|
||||
headers=headers,
|
||||
content=hex_data,
|
||||
timeout=None
|
||||
) as response:
|
||||
if chat_request.stream:
|
||||
chunks = []
|
||||
async for chunk in response.aiter_raw():
|
||||
chunks.append(chunk)
|
||||
return StreamingResponse(
|
||||
process_stream(chunks),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
# 非流式响应处理
|
||||
text = ''
|
||||
async for chunk in response.aiter_raw():
|
||||
# print('chunk:', chunk.hex())
|
||||
print('chunk length:', len(chunk))
|
||||
|
||||
res = chunk_to_utf8_string(chunk)
|
||||
# print('res:', res)
|
||||
if res:
|
||||
text += res
|
||||
|
||||
# 清理响应文本
|
||||
import re
|
||||
text = re.sub(r'^.*<\|END_USER\|>', '', text, flags=re.DOTALL)
|
||||
text = re.sub(r'^\n[a-zA-Z]?', '', text).strip()
|
||||
text = re.sub(r'[\x00-\x1F\x7F]', '', text)
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{str(uuid.uuid4())}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": chat_request.model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
|
||||
@app.post("/models")
|
||||
async def models():
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "claude-3-5-sonnet-20241022",
|
||||
"object": "model",
|
||||
"created": 1713744000,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
{
|
||||
"id": "claude-3-opus",
|
||||
"object": "model",
|
||||
"created": 1709251200,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
{
|
||||
"id": "claude-3.5-haiku",
|
||||
"object": "model",
|
||||
"created": 1711929600,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
{
|
||||
"id": "claude-3.5-sonnet",
|
||||
"object": "model",
|
||||
"created": 1711929600,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
{
|
||||
"id": "cursor-small",
|
||||
"object": "model",
|
||||
"created": 1712534400,
|
||||
"owned_by": "cursor"
|
||||
},
|
||||
{
|
||||
"id": "gpt-3.5-turbo",
|
||||
"object": "model",
|
||||
"created": 1677649200,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-4",
|
||||
"object": "model",
|
||||
"created": 1687392000,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-4-turbo-2024-04-09",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "gpt-4o-mini",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "o1-mini",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
},
|
||||
{
|
||||
"id": "o1-preview",
|
||||
"object": "model",
|
||||
"created": 1712620800,
|
||||
"owned_by": "openai"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
port = int(os.getenv("PORT", "3001"))
|
||||
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True, timeout_keep_alive=30)
|
111
rs-capi/src/hex_utils.rs
Normal file
111
rs-capi/src/hex_utils.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
pub fn string_to_hex(text: &str, model_name: &str) -> Vec<u8> {
|
||||
let text_bytes = text.as_bytes();
|
||||
let text_length = text_bytes.len();
|
||||
|
||||
// 固定常量
|
||||
const FIXED_HEADER: usize = 2;
|
||||
const SEPARATOR: usize = 1;
|
||||
|
||||
let model_name_bytes = model_name.as_bytes();
|
||||
let fixed_suffix_length = 0xA3 + model_name_bytes.len();
|
||||
|
||||
// 计算第一个长度字段
|
||||
let (text_length_field1, text_length_field_size1) = if text_length < 128 {
|
||||
(format!("{:02x}", text_length), 1)
|
||||
} else {
|
||||
let low_byte1 = (text_length & 0x7F) | 0x80;
|
||||
let high_byte1 = (text_length >> 7) & 0xFF;
|
||||
(format!("{:02x}{:02x}", low_byte1, high_byte1), 2)
|
||||
};
|
||||
|
||||
// 计算基础长度字段
|
||||
let base_length = text_length + 0x2A;
|
||||
let (text_length_field, text_length_field_size) = if base_length < 128 {
|
||||
(format!("{:02x}", base_length), 1)
|
||||
} else {
|
||||
let low_byte = (base_length & 0x7F) | 0x80;
|
||||
let high_byte = (base_length >> 7) & 0xFF;
|
||||
(format!("{:02x}{:02x}", low_byte, high_byte), 2)
|
||||
};
|
||||
|
||||
// 计算总消息长度
|
||||
let message_total_length = FIXED_HEADER + text_length_field_size + SEPARATOR +
|
||||
text_length_field_size1 + text_length + fixed_suffix_length;
|
||||
|
||||
// 构造十六进制字符串
|
||||
let model_name_length_hex = format!("{:02X}", model_name_bytes.len());
|
||||
let model_name_hex = hex::encode_upper(model_name_bytes);
|
||||
|
||||
let hex_string = format!(
|
||||
"{:010x}\
|
||||
12{}\
|
||||
0A{}\
|
||||
{}\
|
||||
10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612\
|
||||
2002A132F643A2F6964656150726F2F656475626F73733A1E0A\
|
||||
{}{}\
|
||||
22004A\
|
||||
2461383761396133342D323164642D343863372D623434662D616636633365636536663765\
|
||||
680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061\
|
||||
800101B00100C00100E00100E80100",
|
||||
message_total_length,
|
||||
text_length_field,
|
||||
text_length_field1,
|
||||
hex::encode_upper(text_bytes),
|
||||
model_name_length_hex,
|
||||
model_name_hex
|
||||
).to_uppercase();
|
||||
|
||||
// 将十六进制字符串转换为字节数组
|
||||
hex::decode(hex_string).unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn chunk_to_utf8_string(chunk: &[u8]) -> String {
|
||||
if chunk.len() < 2 {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
if chunk[0] == 0x01 || chunk[0] == 0x02 || (chunk[0] == 0x60 && chunk[1] == 0x0C) {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
// 尝试找到0x0A并从其后开始处理
|
||||
let chunk = match chunk.iter().position(|&x| x == 0x0A) {
|
||||
Some(pos) => &chunk[pos + 1..],
|
||||
None => chunk
|
||||
};
|
||||
|
||||
let mut filtered_chunk = Vec::new();
|
||||
let mut i = 0;
|
||||
|
||||
while i < chunk.len() {
|
||||
// 检查是否有连续的0x00
|
||||
if i + 4 <= chunk.len() && chunk[i..i+4].iter().all(|&x| x == 0x00) {
|
||||
i += 4;
|
||||
while i < chunk.len() && chunk[i] <= 0x0F {
|
||||
i += 1;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if chunk[i] == 0x0C {
|
||||
i += 1;
|
||||
while i < chunk.len() && chunk[i] == 0x0A {
|
||||
i += 1;
|
||||
}
|
||||
} else {
|
||||
filtered_chunk.push(chunk[i]);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// 过滤掉特定字节
|
||||
filtered_chunk.retain(|&b| b != 0x00 && b != 0x0C);
|
||||
|
||||
if filtered_chunk.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
// 转换为UTF-8字符串
|
||||
String::from_utf8_lossy(&filtered_chunk).trim().to_string()
|
||||
}
|
352
rs-capi/src/main.rs
Normal file
352
rs-capi/src/main.rs
Normal file
@@ -0,0 +1,352 @@
|
||||
use axum::{
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{
|
||||
sse::{Event, Sse},
|
||||
IntoResponse, Response,
|
||||
},
|
||||
routing::post,
|
||||
Json, Router,
|
||||
};
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
stream::{Stream, StreamExt},
|
||||
SinkExt,
|
||||
};
|
||||
// use http::HeaderName as HttpHeaderName;
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::str::FromStr;
|
||||
use std::{convert::Infallible, time::Duration};
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use uuid::Uuid;
|
||||
|
||||
mod hex_utils;
|
||||
use hex_utils::{chunk_to_utf8_string, string_to_hex};
|
||||
|
||||
// 定义请求模型
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
#[serde(default)]
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
// 定义响应模型
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatResponse {
|
||||
id: String,
|
||||
object: String,
|
||||
created: i64,
|
||||
model: String,
|
||||
choices: Vec<Choice>,
|
||||
usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Choice {
|
||||
index: i32,
|
||||
message: ResponseMessage,
|
||||
finish_reason: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ResponseMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Usage {
|
||||
prompt_tokens: i32,
|
||||
completion_tokens: i32,
|
||||
total_tokens: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct StreamResponse {
|
||||
id: String,
|
||||
object: String,
|
||||
created: i64,
|
||||
choices: Vec<StreamChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct StreamChoice {
|
||||
index: i32,
|
||||
delta: Delta,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Delta {
|
||||
content: String,
|
||||
}
|
||||
|
||||
async fn process_stream(
|
||||
chunks: Vec<Bytes>,
|
||||
) -> impl Stream<Item = Result<Event, Infallible>> + Send {
|
||||
let (mut tx, rx) = mpsc::channel(100);
|
||||
let response_id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||
|
||||
tokio::spawn(async move {
|
||||
for chunk in chunks {
|
||||
let text = chunk_to_utf8_string(&chunk);
|
||||
if !text.is_empty() {
|
||||
let text = text.trim();
|
||||
let text = if let Some(idx) = text.find("<|END_USER|>") {
|
||||
text[idx + "<|END_USER|>".len()..].trim()
|
||||
} else {
|
||||
text
|
||||
};
|
||||
|
||||
let text = if !text.is_empty() && text.chars().next().unwrap().is_alphabetic() {
|
||||
text[1..].trim()
|
||||
} else {
|
||||
text
|
||||
};
|
||||
|
||||
let re = Regex::new(r"[\x00-\x1F\x7F]").unwrap();
|
||||
let text = re.replace_all(text, "");
|
||||
|
||||
if !text.is_empty() {
|
||||
let response = StreamResponse {
|
||||
id: response_id.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
choices: vec![StreamChoice {
|
||||
index: 0,
|
||||
delta: Delta {
|
||||
content: text.to_string(),
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
let json_data = serde_json::to_string(&response).unwrap();
|
||||
if !json_data.is_empty() {
|
||||
let _ = tx.send(Ok(Event::default().data(json_data))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = tx.send(Ok(Event::default().data("[DONE]"))).await;
|
||||
});
|
||||
|
||||
rx
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// 初始化日志
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
// 创建CORS中间件
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
// 创建路由
|
||||
let app = Router::new()
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/models", post(models))
|
||||
.layer(cors)
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(|request: &axum::http::Request<_>| {
|
||||
tracing::info_span!(
|
||||
"http_request",
|
||||
method = %request.method(),
|
||||
uri = %request.uri(),
|
||||
)
|
||||
})
|
||||
// .on_request(|_request: &axum::http::Request<_>, _span: &tracing::Span| { info!("started processing request"); })
|
||||
.on_response(
|
||||
|response: &axum::http::Response<_>,
|
||||
latency: std::time::Duration,
|
||||
_span: &tracing::Span| {
|
||||
tracing::info!(
|
||||
status = %response.status(),
|
||||
latency = ?latency,
|
||||
);
|
||||
},
|
||||
),
|
||||
);
|
||||
|
||||
// 启动服务器
|
||||
let addr = "0.0.0.0:3002";
|
||||
println!("Server running on {}", addr);
|
||||
|
||||
// 修改服务器启动代码
|
||||
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
|
||||
// 处理聊天完成请求
|
||||
async fn chat_completions(
|
||||
headers: HeaderMap,
|
||||
Json(chat_request): Json<ChatRequest>,
|
||||
) -> Result<Response, StatusCode> {
|
||||
// 验证认证
|
||||
let auth_header = headers
|
||||
.get("authorization")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if !auth_header.starts_with("Bearer ") {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let mut auth_token = auth_header.replace("Bearer ", "");
|
||||
|
||||
// 验证o1模型不支持流式输出
|
||||
if chat_request.model.starts_with("o1-") && chat_request.stream {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
tracing::info!("chat_request: {:?}", chat_request);
|
||||
|
||||
// 处理多个密钥
|
||||
if auth_token.contains(',') {
|
||||
auth_token = auth_token.split(',').next().unwrap().trim().to_string();
|
||||
}
|
||||
|
||||
if auth_token.contains("%3A%3A") {
|
||||
auth_token = auth_token
|
||||
.split("%3A%3A")
|
||||
.nth(1)
|
||||
.unwrap_or(&auth_token)
|
||||
.to_string();
|
||||
}
|
||||
|
||||
// 格式化消息
|
||||
let formatted_messages = chat_request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|msg| format!("{}:{}", msg.role, msg.content))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
// 生成请求数据
|
||||
let hex_data = string_to_hex(&formatted_messages, &chat_request.model);
|
||||
// 准备请求头
|
||||
let request_id = Uuid::new_v4();
|
||||
let headers = reqwest::header::HeaderMap::from_iter([
|
||||
(reqwest::header::CONTENT_TYPE, "application/connect+proto"),
|
||||
(reqwest::header::AUTHORIZATION, &format!("Bearer {}", auth_token)),
|
||||
// 对于标准 HTTP 头部,使用预定义的常量
|
||||
(reqwest::header::HeaderName::from_str("Connect-Accept-Encoding").unwrap(), "gzip,br"),
|
||||
(reqwest::header::HeaderName::from_str("Connect-Protocol-Version").unwrap(), "1"),
|
||||
(reqwest::header::HeaderName::from_str("User-Agent").unwrap(), "connect-es/1.4.0"),
|
||||
(reqwest::header::HeaderName::from_str("X-Amzn-Trace-Id").unwrap(), &format!("Root={}", Uuid::new_v4())),
|
||||
(reqwest::header::HeaderName::from_str("X-Cursor-Checksum").unwrap(), "zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef"),
|
||||
(reqwest::header::HeaderName::from_str("X-Cursor-Client-Version").unwrap(), "0.42.3"),
|
||||
(reqwest::header::HeaderName::from_str("X-Cursor-Timezone").unwrap(), "Asia/Shanghai"),
|
||||
(reqwest::header::HeaderName::from_str("X-Ghost-Mode").unwrap(), "false"),
|
||||
(reqwest::header::HeaderName::from_str("X-Request-Id").unwrap(), &request_id.to_string()),
|
||||
(reqwest::header::HeaderName::from_str("Host").unwrap(), "api2.cursor.sh"),
|
||||
].iter().map(|(k, v)| (
|
||||
k.clone(),
|
||||
reqwest::header::HeaderValue::from_str(v).unwrap()
|
||||
)));
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(300))
|
||||
.build()
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let response = client
|
||||
.post("https://api2.cursor.sh/aiserver.v1.AiService/StreamChat")
|
||||
.headers(headers)
|
||||
.body(hex_data)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if chat_request.stream {
|
||||
let mut chunks = Vec::new();
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
match chunk {
|
||||
Ok(chunk) => chunks.push(chunk),
|
||||
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
}
|
||||
}
|
||||
|
||||
let stream = process_stream(chunks).await;
|
||||
return Ok(Sse::new(stream).into_response());
|
||||
}
|
||||
|
||||
// 非流式响应
|
||||
let mut text = String::new();
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
match chunk {
|
||||
Ok(chunk) => {
|
||||
let res = chunk_to_utf8_string(&chunk);
|
||||
if !res.is_empty() {
|
||||
text.push_str(&res);
|
||||
}
|
||||
}
|
||||
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
}
|
||||
}
|
||||
|
||||
// 清理响应文本
|
||||
let re = Regex::new(r"^.*<\|END_USER\|>").unwrap();
|
||||
text = re.replace(&text, "").to_string();
|
||||
|
||||
let re = Regex::new(r"^\n[a-zA-Z]?").unwrap();
|
||||
text = re.replace(&text, "").trim().to_string();
|
||||
|
||||
let re = Regex::new(r"[\x00-\x1F\x7F]").unwrap();
|
||||
text = re.replace_all(&text, "").to_string();
|
||||
|
||||
let response = ChatResponse {
|
||||
id: format!("chatcmpl-{}", Uuid::new_v4()),
|
||||
object: "chat.completion".to_string(),
|
||||
created: chrono::Utc::now().timestamp(),
|
||||
model: chat_request.model,
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
message: ResponseMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: text,
|
||||
},
|
||||
finish_reason: "stop".to_string(),
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Json(response).into_response())
|
||||
}
|
||||
|
||||
// 处理模型列表请求
|
||||
async fn models() -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "claude-3-5-sonnet-20241022",
|
||||
"object": "model",
|
||||
"created": 1713744000,
|
||||
"owned_by": "anthropic"
|
||||
},
|
||||
// ... 其他模型
|
||||
]
|
||||
}))
|
||||
}
|
45
types.go
45
types.go
@@ -1,45 +0,0 @@
|
||||
package main
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Index int `json:"index"`
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
type StreamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Index int `json:"index"`
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
} `json:"choices"`
|
||||
}
|
83
utils.go
83
utils.go
@@ -1,83 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func stringToHex(str, modelName string) []byte {
|
||||
inputBytes := []byte(str)
|
||||
byteLength := len(inputBytes)
|
||||
|
||||
const (
|
||||
FIXED_HEADER = 2
|
||||
SEPARATOR = 1
|
||||
)
|
||||
|
||||
FIXED_SUFFIX_LENGTH := 0xA3 + len(modelName)
|
||||
|
||||
// 计算文本长度字段
|
||||
var textLengthField1, textLengthFieldSize1 int
|
||||
if byteLength < 128 {
|
||||
textLengthField1 = byteLength
|
||||
textLengthFieldSize1 = 1
|
||||
} else {
|
||||
lowByte1 := (byteLength & 0x7F) | 0x80
|
||||
highByte1 := (byteLength >> 7) & 0xFF
|
||||
textLengthField1 = (highByte1 << 8) | lowByte1
|
||||
textLengthFieldSize1 = 2
|
||||
}
|
||||
|
||||
// 计算基础长度
|
||||
baseLength := byteLength + 0x2A
|
||||
var textLengthField, textLengthFieldSize int
|
||||
if baseLength < 128 {
|
||||
textLengthField = baseLength
|
||||
textLengthFieldSize = 1
|
||||
} else {
|
||||
lowByte := (baseLength & 0x7F) | 0x80
|
||||
highByte := (baseLength >> 7) & 0xFF
|
||||
textLengthField = (highByte << 8) | lowByte
|
||||
textLengthFieldSize = 2
|
||||
}
|
||||
|
||||
// 计算总消息长度
|
||||
messageTotalLength := FIXED_HEADER + textLengthFieldSize + SEPARATOR +
|
||||
textLengthFieldSize1 + byteLength + FIXED_SUFFIX_LENGTH
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
// 写入消息长度
|
||||
fmt.Fprintf(&buf, "%010x", messageTotalLength)
|
||||
|
||||
// 写入固定头部
|
||||
buf.WriteString("12")
|
||||
|
||||
// 写入长度字段
|
||||
fmt.Fprintf(&buf, "%02x", textLengthField)
|
||||
|
||||
buf.WriteString("0A")
|
||||
fmt.Fprintf(&buf, "%02x", textLengthField1)
|
||||
|
||||
// 写入消息内容
|
||||
buf.WriteString(hex.EncodeToString(inputBytes))
|
||||
|
||||
// 写入固定后缀
|
||||
buf.WriteString("10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612")
|
||||
buf.WriteString("2002A132F643A2F6964656150726F2F656475626F73733A1E0A")
|
||||
|
||||
// 写入模型名称长度和内容
|
||||
fmt.Fprintf(&buf, "%02X", len(modelName))
|
||||
buf.WriteString(strings.ToUpper(hex.EncodeToString([]byte(modelName))))
|
||||
|
||||
// 写入剩余固定内容
|
||||
buf.WriteString("22004A")
|
||||
buf.WriteString("2461383761396133342D323164642D343863372D623434662D616636633365636536663765")
|
||||
buf.WriteString("680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061")
|
||||
buf.WriteString("800101B00100C00100E00100E80100")
|
||||
|
||||
hexBytes, _ := hex.DecodeString(strings.ToUpper(buf.String()))
|
||||
return hexBytes
|
||||
}
|
158
utils/hex.go
Normal file
158
utils/hex.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func StringToHex(text string, modelName string) ([]byte, error) {
|
||||
textBytes := []byte(text)
|
||||
textLength := len(textBytes)
|
||||
|
||||
const (
|
||||
FIXED_HEADER = 2
|
||||
SEPARATOR = 1
|
||||
)
|
||||
|
||||
modelNameBytes := []byte(modelName)
|
||||
FIXED_SUFFIX_LENGTH := 0xA3 + len(modelNameBytes)
|
||||
|
||||
// 计算第一个长度字段
|
||||
var textLengthField1 string
|
||||
var textLengthFieldSize1 int
|
||||
if textLength < 128 {
|
||||
textLengthField1 = fmt.Sprintf("%02x", textLength)
|
||||
textLengthFieldSize1 = 1
|
||||
} else {
|
||||
lowByte1 := (textLength & 0x7F) | 0x80
|
||||
highByte1 := (textLength >> 7) & 0xFF
|
||||
textLengthField1 = fmt.Sprintf("%02x%02x", lowByte1, highByte1)
|
||||
textLengthFieldSize1 = 2
|
||||
}
|
||||
|
||||
// 计算基础长度字段
|
||||
baseLength := textLength + 0x2A
|
||||
var textLengthField string
|
||||
var textLengthFieldSize int
|
||||
if baseLength < 128 {
|
||||
textLengthField = fmt.Sprintf("%02x", baseLength)
|
||||
textLengthFieldSize = 1
|
||||
} else {
|
||||
lowByte := (baseLength & 0x7F) | 0x80
|
||||
highByte := (baseLength >> 7) & 0xFF
|
||||
textLengthField = fmt.Sprintf("%02x%02x", lowByte, highByte)
|
||||
textLengthFieldSize = 2
|
||||
}
|
||||
|
||||
// 计算总消息长度
|
||||
messageTotalLength := FIXED_HEADER + textLengthFieldSize + SEPARATOR + textLengthFieldSize1 + textLength + FIXED_SUFFIX_LENGTH
|
||||
|
||||
modelNameHex := strings.ToUpper(hex.EncodeToString(modelNameBytes))
|
||||
modelNameLengthHex := fmt.Sprintf("%02X", len(modelNameBytes))
|
||||
|
||||
hexString := fmt.Sprintf(
|
||||
"%010x"+
|
||||
"12"+
|
||||
"%s"+
|
||||
"0a"+
|
||||
"%s"+
|
||||
"%x"+
|
||||
"10016a2432343163636435662d393162612d343131382d393239612d3936626330313631626432612"+
|
||||
"2002a132f643a2f6964656150726f2f656475626f73733a1e0a"+
|
||||
"%s"+
|
||||
"%s"+
|
||||
"22004a"+
|
||||
"2461383761396133342d323164642d343863372d623434662d616636633365636536663765"+
|
||||
"680070007a2436393737376535612d386332642d343835342d623564392d653062623232336163303061"+
|
||||
"800101b00100c00100e00100e80100",
|
||||
messageTotalLength,
|
||||
textLengthField,
|
||||
textLengthField1,
|
||||
textBytes,
|
||||
modelNameLengthHex,
|
||||
modelNameHex,
|
||||
)
|
||||
|
||||
hexString = strings.ToLower(hexString)
|
||||
|
||||
return hex.DecodeString(hexString)
|
||||
}
|
||||
|
||||
func ChunkToUTF8String(chunk []byte) string {
|
||||
// 基础检查
|
||||
if len(chunk) < 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if chunk[0] == 0x01 || chunk[0] == 0x02 || (chunk[0] == 0x60 && chunk[1] == 0x0C) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 修改调试输出格式
|
||||
// fmt.Printf("chunk length: %d hex: %x\n", len(chunk), chunk)
|
||||
fmt.Printf("chunk length: %d\n", len(chunk), )
|
||||
|
||||
// 去掉0x0A之前的所有字节
|
||||
if idx := bytes.IndexByte(chunk, 0x0A); idx != -1 {
|
||||
chunk = chunk[idx+1:]
|
||||
}
|
||||
|
||||
// 修改过滤逻辑,将过滤步骤分开
|
||||
filteredChunk := make([]byte, 0, len(chunk))
|
||||
i := 0
|
||||
for i < len(chunk) {
|
||||
// 检查连续的0x00
|
||||
if i+4 <= len(chunk) && allZeros(chunk[i:i+4]) {
|
||||
i += 4
|
||||
for i < len(chunk) && chunk[i] <= 0x0F {
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if chunk[i] == 0x0C {
|
||||
i++
|
||||
for i < len(chunk) && chunk[i] == 0x0A {
|
||||
i++
|
||||
}
|
||||
} else {
|
||||
filteredChunk = append(filteredChunk, chunk[i])
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// 最后统一过滤特定字节
|
||||
finalFiltered := make([]byte, 0, len(filteredChunk))
|
||||
for _, b := range filteredChunk {
|
||||
if b != 0x00 && b != 0x0C {
|
||||
finalFiltered = append(finalFiltered, b)
|
||||
}
|
||||
}
|
||||
|
||||
if len(finalFiltered) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 添加错误处理
|
||||
result := strings.TrimSpace(string(finalFiltered))
|
||||
if !utf8.Valid(finalFiltered) {
|
||||
fmt.Printf("Error: Invalid UTF-8 sequence\n")
|
||||
return ""
|
||||
}
|
||||
|
||||
fmt.Printf("decoded result: %s\n", result)
|
||||
return result
|
||||
}
|
||||
|
||||
// 辅助函数检查连续的零字节
|
||||
func allZeros(data []byte) bool {
|
||||
for _, b := range data {
|
||||
if b != 0x00 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
Reference in New Issue
Block a user