From b63b8dae00a209c4ec14f787e1136f9bfd57db52 Mon Sep 17 00:00:00 2001 From: BlueSkyXN <63384277+BlueSkyXN@users.noreply.github.com> Date: Tue, 25 Mar 2025 10:26:22 +0800 Subject: [PATCH] 0.1.8 --- qwen2api-cf.js | 167 +- qwen2api.go | 4640 ++---------------------------------------------- 2 files changed, 256 insertions(+), 4551 deletions(-) diff --git a/qwen2api-cf.js b/qwen2api-cf.js index 4b9e8ce..3de59db 100644 --- a/qwen2api-cf.js +++ b/qwen2api-cf.js @@ -204,7 +204,7 @@ export default { } }, - // ---------------------- 流式响应处理(重构并去重) ---------------------- + // ---------------------- 流式响应处理(改进) ---------------------- async handleStreamResponse(fetchResponse, requestId, modelName) { const { readable, writable } = new TransformStream(); const writer = writable.getWriter(); @@ -215,8 +215,9 @@ export default { await writer.write(encoder.encode(`data: ${payload}\n\n`)); }; - // 用于去重:记录上一次完整接收到的 delta 内容 + // 用于去重和累积内容 let previousDelta = ""; + let cumulativeContent = ""; // 累积完整内容,解决断流问题 const processStream = async () => { try { @@ -227,67 +228,29 @@ export default { while (true) { const { done, value } = await reader.read(); - if (done) break; + if (done) { + // 确保最后一个缓冲区也被处理 + if (buffer.trim()) { + await processBuffer(buffer); + } + break; + } const chunkStr = decoder.decode(value, { stream: true }); buffer += chunkStr; - // SSE 消息通常以 "\n\n" 分隔 - const parts = buffer.split('\n\n'); - buffer = parts.pop() || ''; - - for (const part of parts) { - if (!part.trim()) continue; - - const lines = part.split('\n'); - for (const line of lines) { - if (!line.startsWith('data: ')) continue; - - const dataStr = line.slice('data: '.length).trim(); - if (dataStr === '[DONE]') { - await sendSSE('[DONE]'); - console.log('收到 [DONE],流结束'); - break; - } - - try { - const jsonData = JSON.parse(dataStr); - const delta = jsonData?.choices?.[0]?.delta; - if (!delta) continue; - - let currentDelta = delta.content || ""; - // 去除重复:如果当前内容以上次完整内容为前缀,则只保留新增部分 - let newContent = currentDelta; - if (previousDelta && currentDelta.startsWith(previousDelta)) { - newContent = currentDelta.substring(previousDelta.length); - } - previousDelta = currentDelta; - if (!newContent) continue; - - const openaiChunk = { - id: `chatcmpl-${requestId}`, - object: 'chat.completion.chunk', - created: Date.now(), - model: modelName, - choices: [ - { - index: 0, - delta: isFirstChunk - ? { role: 'assistant', content: newContent } - : { content: newContent }, - finish_reason: null - } - ] - }; - - if (isFirstChunk) isFirstChunk = false; - await sendSSE(JSON.stringify(openaiChunk)); - } catch (err) { - console.error('解析 SSE JSON 失败:', dataStr, err); - } - } + // 更可靠的处理方式:按照 SSE 规范处理双换行符分隔的消息 + await processBuffer(buffer); + + // 仅保留可能不完整的最后一部分 + const lastBoundaryIndex = buffer.lastIndexOf('\n\n'); + if (lastBoundaryIndex !== -1) { + buffer = buffer.substring(lastBoundaryIndex + 2); } } + + // 确保发送最终 DONE 信号 + console.log(`流处理完成,累积内容长度: ${cumulativeContent.length}`); await sendSSE('[DONE]'); } catch (err) { console.error('处理 SSE 流时出错:', err); @@ -305,14 +268,100 @@ export default { ] }; try { - await writer.write(encoder.encode(`data: ${JSON.stringify(errorChunk)}\n\n`)); - await writer.write(encoder.encode('data: [DONE]\n\n')); + await sendSSE(JSON.stringify(errorChunk)); + await sendSSE('[DONE]'); } catch (_) {} } finally { await writer.close(); } }; + // 处理缓冲区内的完整 SSE 消息 + const processBuffer = async (buffer) => { + // 按 data: 行分割 + const dataLineRegex = /^data: (.+)$/gm; + let match; + + while ((match = dataLineRegex.exec(buffer)) !== null) { + const dataStr = match[1].trim(); + + if (dataStr === '[DONE]') { + await sendSSE('[DONE]'); + console.log('收到 [DONE],流结束'); + continue; + } + + try { + const jsonData = JSON.parse(dataStr); + const delta = jsonData?.choices?.[0]?.delta; + if (!delta) continue; + + let currentDelta = delta.content || ""; + + // 改进的去重逻辑:如果有完整内容,检查是否为前缀 + if (currentDelta) { + let newContent = currentDelta; + let needsSending = true; + + if (previousDelta && currentDelta.startsWith(previousDelta)) { + // 只提取新增部分 + newContent = currentDelta.substring(previousDelta.length); + // 如果没有新增内容,跳过发送 + if (!newContent) needsSending = false; + } + + if (needsSending) { + // 创建并发送内容块 + const openaiChunk = { + id: `chatcmpl-${requestId}`, + object: 'chat.completion.chunk', + created: Date.now(), + model: modelName, + choices: [ + { + index: 0, + delta: isFirstChunk + ? { role: 'assistant', content: newContent } + : { content: newContent }, + finish_reason: null + } + ] + }; + + if (isFirstChunk) isFirstChunk = false; + await sendSSE(JSON.stringify(openaiChunk)); + + // 累积内容 + cumulativeContent += newContent; + } + + // 更新之前的内容为当前完整内容 + previousDelta = currentDelta; + } + + // 处理完成标志 + if (jsonData?.choices?.[0]?.finish_reason) { + const finishChunk = { + id: `chatcmpl-${requestId}`, + object: 'chat.completion.chunk', + created: Date.now(), + model: modelName, + choices: [ + { + index: 0, + delta: {}, + finish_reason: jsonData.choices[0].finish_reason + } + ] + }; + await sendSSE(JSON.stringify(finishChunk)); + } + } catch (err) { + console.error('解析 SSE JSON 失败:', dataStr, err); + } + } + }; + processStream(); return new Response(readable, { headers: { diff --git a/qwen2api.go b/qwen2api.go index a61b3e3..c8c38f3 100644 --- a/qwen2api.go +++ b/qwen2api.go @@ -324,30 +324,19 @@ func (pool *WorkerPool) Start() { logDebug("Worker %d 处理任务 reqID:%s", workerID, task.reqID) // 处理任务 - var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() - - switch task.path { - case "/v1/models": - handleModels(task.w, task.r) - case "/v1/chat/completions": - if task.isStream { - handleStreamingRequest(task.w, task.r, task.apiReq, task.reqID) - } else { - handleNonStreamingRequest(task.w, task.r, task.apiReq, task.reqID) - } - case "/v1/images/generations": - handleImageGenerations(task.w, task.r, task.apiReq, task.reqID) + switch task.path { + case "/v1/models": + handleModels(task.w, task.r) + case "/v1/chat/completions": + if task.isStream { + handleStreamingRequest(task.w, task.r, task.apiReq, task.reqID) + } else { + handleNonStreamingRequest(task.w, task.r, task.apiReq, task.reqID) } - - logInfo("Worker %d 任务 reqID:%s 处理完成", workerID, task.reqID) - }() + case "/v1/images/generations": + handleImageGenerations(task.w, task.r, task.apiReq, task.reqID) + } - // 等待任务完成后再发送通知 - wg.Wait() // 通知任务完成 close(task.done) @@ -605,23 +594,14 @@ func main() { logInfo("[reqID:%s] 任务已提交到队列", reqID) - // 增加健壮的超时处理 - timeoutCtx, cancel := context.WithTimeout(r.Context(), time.Duration(appConfig.Timeout)*time.Second) - defer cancel() - // 等待任务完成或超时 select { case <-task.done: // 任务已完成 - logInfo("[reqID:%s] 任务通道已关闭", reqID) - case <-timeoutCtx.Done(): - if r.Context().Err() != nil { - // 客户端取消 - logWarn("[reqID:%s] 请求被客户端取消", reqID) - } else { - // 服务器端超时 - logWarn("[reqID:%s] 请求处理超时", reqID) - } + logInfo("[reqID:%s] 任务已完成", reqID) + case <-r.Context().Done(): + // 请求被取消或超时 + logWarn("[reqID:%s] 请求被取消或超时", reqID) } // 请求处理完成,更新指标 @@ -1199,1322 +1179,117 @@ func handleStreamingRequest(w http.ResponseWriter, r *http.Request, apiReq APIRe previousContent := "" // 创建正则表达式来查找 data: 行 - dataRegex := regexp.MustCompile(`(?m)^data: (.+)package main + dataRegex := regexp.MustCompile(`(?m)^data: (.+)$`) -import ( - "bufio" - "bytes" - "context" - "crypto/rand" - "crypto/tls" - "encoding/base64" - "encoding/json" - "flag" - "fmt" - "io" - "log" - "mime/multipart" - "net/http" - "os" - "os/signal" - "regexp" - "strings" - "sync" - "sync/atomic" - "syscall" - "time" -) - -// 版本和API常量 -const ( - Version = "1.0.0" - TargetURL = "https://chat.qwen.ai/api/chat/completions" - ModelsURL = "https://chat.qwen.ai/api/models" - FilesURL = "https://chat.qwen.ai/api/v1/files/" - TasksURL = "https://chat.qwen.ai/api/v1/tasks/status/" -) - -// 默认模型列表(当获取接口失败时使用) -var DefaultModels = []string{ - "qwen-max-latest", - "qwen-plus-latest", - "qwen2.5-vl-72b-instruct", - "qwen2.5-14b-instruct-1m", - "qvq-72b-preview", - "qwq-32b-preview", - "qwen2.5-coder-32b-instruct", - "qwen-turbo-latest", - "qwen2.5-72b-instruct", +// 持续读取响应 +buffer := "" +pendingContent := "" // 用于累积内容,解决流处理断开问题 +for { +// 添加超时检测 +select { +case <-r.Context().Done(): +logWarn("[reqID:%s] 请求超时或被客户端取消", reqID) +return +default: +// 继续处理 } -// 扩展模型变种后缀 -var ModelSuffixes = []string{ - "", - "-thinking", - "-search", - "-thinking-search", - "-draw", +// 读取一块数据 +chunk := make([]byte, 4096) +n, err := reader.Read(chunk) +if err != nil { +if err != io.EOF { +logError("[reqID:%s] 读取响应出错: %v", reqID, err) +return +} +break } -// 日志级别常量 -const ( - LogLevelDebug = "debug" - LogLevelInfo = "info" - LogLevelWarn = "warn" - LogLevelError = "error" -) +// 添加到缓冲区 +buffer += string(chunk[:n]) -// WorkerPool 工作池结构体,用于管理goroutine -type WorkerPool struct { - taskQueue chan *Task - workerCount int - shutdownChannel chan struct{} - wg sync.WaitGroup +// 更稳健的处理方式:按行分割并只处理完整行 +lines := strings.Split(buffer, "\n") +// 保留最后可能不完整的行 +if len(lines) > 0 { +buffer = lines[len(lines)-1] } -// Task 任务结构体,包含请求处理所需数据 -type Task struct { - r *http.Request - w http.ResponseWriter - done chan struct{} - reqID string - isStream bool - apiReq APIRequest - path string +// 处理所有完整的行(除最后一行外) +for i := 0; i < len(lines)-1; i++ { +line := lines[i] +if !strings.HasPrefix(line, "data: ") { +continue } -// Semaphore 信号量实现,用于限制并发数量 -type Semaphore struct { - sem chan struct{} +// 提取数据部分 +dataStr := strings.TrimPrefix(line, "data: ") + +// 处理[DONE]消息 +if dataStr == "[DONE]" { +logDebug("[reqID:%s] 收到[DONE]消息", reqID) +w.Write([]byte("data: [DONE]\n\n")) +flusher.Flush() +continue } -// 配置结构体 -type Config struct { - Port string - Address string - LogLevel string - DevMode bool - MaxRetries int - Timeout int - VerifySSL bool - WorkerCount int - QueueSize int - MaxConcurrent int - APIPrefix string +// 解析JSON +var qwenResp QwenResponse +if err := json.Unmarshal([]byte(dataStr), &qwenResp); err != nil { +logWarn("[reqID:%s] 解析JSON失败: %v, data: %s", reqID, err, dataStr) +continue } -// APIRequest OpenAI兼容的请求结构体 -type APIRequest struct { - Model string `json:"model"` - Messages []APIMessage `json:"messages"` - Stream bool `json:"stream"` - Temperature float64 `json:"temperature,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` +// 处理块 +for _, choice := range qwenResp.Choices { +content := choice.Delta.Content + +// 改进去重逻辑 - 只处理重复前缀 +if previousContent != "" && strings.HasPrefix(content, previousContent) { +// 计算新增内容 +newContent := content[len(previousContent):] +if newContent != "" { +// 创建内容块 - 只发送新部分 +contentChunk := createContentChunk(respID, createdTime, modelName, newContent) +w.Write([]byte("data: " + string(contentChunk) + "\n\n")) +flusher.Flush() +pendingContent += newContent // 累积内容 +} +} else if content != "" { +// 直接发送完整内容 +contentChunk := createContentChunk(respID, createdTime, modelName, content) +w.Write([]byte("data: " + string(contentChunk) + "\n\n")) +flusher.Flush() +pendingContent += content // 累积内容 } -// APIMessage 消息结构体 -type APIMessage struct { - Role string `json:"role"` - Content interface{} `json:"content"` - FeatureConfig interface{} `json:"feature_config,omitempty"` - ChatType string `json:"chat_type,omitempty"` - Extra interface{} `json:"extra,omitempty"` +// 更新前一个内容为完整内容 +if content != "" { +previousContent = content } -// 内容项目结构体(处理图像等内容) -type ContentItem struct { - Type string `json:"type,omitempty"` - Text string `json:"text,omitempty"` - ImageURL *ImageURL `json:"image_url,omitempty"` - Image string `json:"image,omitempty"` +// 处理完成标志 +if choice.FinishReason != "" { +finishReason := choice.FinishReason +doneChunk := createDoneChunk(respID, createdTime, modelName, finishReason) +w.Write([]byte("data: " + string(doneChunk) + "\n\n")) +flusher.Flush() +} +} +} } -// ImageURL 图像URL结构体 -type ImageURL struct { - URL string `json:"url"` +// 检查是否有累积的内容需要作为最终响应 +if pendingContent != "" { +logInfo("[reqID:%s] 流处理完成,累积内容长度: %d", reqID, len(pendingContent)) } - -// QwenRequest 通义千问API请求结构体 -type QwenRequest struct { - Model string `json:"model"` - Messages []APIMessage `json:"messages"` - Stream bool `json:"stream"` - ChatType string `json:"chat_type,omitempty"` - ID string `json:"id,omitempty"` - IncrementalOutput bool `json:"incremental_output,omitempty"` - Size string `json:"size,omitempty"` -} - -// QwenResponse 通义千问API响应结构体 -type QwenResponse struct { - Messages []struct { - Role string `json:"role"` - Content string `json:"content"` - Extra struct { - Wanx struct { - TaskID string `json:"task_id"` - } `json:"wanx"` - } `json:"extra"` - } `json:"messages"` - Choices []struct { - Delta struct { - Content string `json:"content"` - } `json:"delta"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` -} - -// FileUploadResponse 文件上传响应 -type FileUploadResponse struct { - ID string `json:"id"` -} - -// TaskStatusResponse 任务状态响应 -type TaskStatusResponse struct { - Content string `json:"content"` -} - -// StreamChunk OpenAI兼容的流式响应块 -type StreamChunk struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []struct { - Index int `json:"index"` - Delta struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` - } `json:"choices"` -} - -// CompletionResponse OpenAI兼容的完成响应 -type CompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []struct { - Index int `json:"index"` - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` -} - -// ImagesResponse 图像生成响应 -type ImagesResponse struct { - Created int64 `json:"created"` - Data []ImageURL `json:"data"` -} - -// ImagesRequest 图像生成请求 -type ImagesRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - N int `json:"n"` - Size string `json:"size"` -} - -// ModelData 模型数据 -type ModelData struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - OwnedBy string `json:"owned_by"` -} - -// ModelsResponse 模型列表响应 -type ModelsResponse struct { - Object string `json:"object"` - Data []ModelData `json:"data"` -} - -// 全局变量 -var ( - appConfig *Config - logger *log.Logger - logLevel string - logMutex sync.Mutex - workerPool *WorkerPool - requestSem *Semaphore - requestCount uint64 = 0 - countMutex sync.Mutex - // 性能指标 - requestCounter int64 - successCounter int64 - errorCounter int64 - avgResponseTime int64 - queuedRequests int64 - rejectedRequests int64 -) - -// NewSemaphore 创建新的信号量 -func NewSemaphore(size int) *Semaphore { - return &Semaphore{ - sem: make(chan struct{}, size), - } -} - -// Acquire 获取信号量(阻塞) -func (s *Semaphore) Acquire() { - s.sem <- struct{}{} -} - -// Release 释放信号量 -func (s *Semaphore) Release() { - <-s.sem -} - -// TryAcquire 尝试获取信号量(非阻塞) -func (s *Semaphore) TryAcquire() bool { - select { - case s.sem <- struct{}{}: - return true - default: - return false - } -} - -// NewWorkerPool 创建并启动一个新的工作池 -func NewWorkerPool(workerCount int, queueSize int) *WorkerPool { - pool := &WorkerPool{ - taskQueue: make(chan *Task, queueSize), - workerCount: workerCount, - shutdownChannel: make(chan struct{}), - } - - pool.Start() - return pool -} - -// Start 启动工作池中的worker goroutines -func (pool *WorkerPool) Start() { - // 启动工作goroutine - for i := 0; i < pool.workerCount; i++ { - pool.wg.Add(1) - go func(workerID int) { - defer pool.wg.Done() - - logInfo("Worker %d 已启动", workerID) - - for { - select { - case task, ok := <-pool.taskQueue: - if !ok { - // 队列已关闭,退出worker - logInfo("Worker %d 收到队列关闭信号,准备退出", workerID) - return - } - - logDebug("Worker %d 处理任务 reqID:%s", workerID, task.reqID) - - // 处理任务 - var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() - - switch task.path { - case "/v1/models": - handleModels(task.w, task.r) - case "/v1/chat/completions": - if task.isStream { - handleStreamingRequest(task.w, task.r, task.apiReq, task.reqID) - } else { - handleNonStreamingRequest(task.w, task.r, task.apiReq, task.reqID) - } - case "/v1/images/generations": - handleImageGenerations(task.w, task.r, task.apiReq, task.reqID) - } - - logInfo("Worker %d 任务 reqID:%s 处理完成", workerID, task.reqID) - }() - - // 等待任务完成后再发送通知 - wg.Wait() - // 通知任务完成 - close(task.done) - - case <-pool.shutdownChannel: - // 收到关闭信号,退出worker - logInfo("Worker %d 收到关闭信号,准备退出", workerID) - return - } - } - }(i) - } -} - -// SubmitTask 提交任务到工作池,非阻塞 -func (pool *WorkerPool) SubmitTask(task *Task) (bool, error) { - select { - case pool.taskQueue <- task: - // 任务成功添加到队列 - return true, nil - default: - // 队列已满 - return false, fmt.Errorf("任务队列已满") - } -} - -// Shutdown 关闭工作池 -func (pool *WorkerPool) Shutdown() { - logInfo("正在关闭工作池...") - - // 发送关闭信号给所有worker - close(pool.shutdownChannel) - - // 等待所有worker退出 - pool.wg.Wait() - - // 关闭任务队列 - close(pool.taskQueue) - - logInfo("工作池已关闭") -} - -// 日志函数 -func initLogger(level string) { - logger = log.New(os.Stdout, "[QwenAPI] ", log.LstdFlags) - logLevel = level -} - -func logDebug(format string, v ...interface{}) { - if logLevel == LogLevelDebug { - logMutex.Lock() - logger.Printf("[DEBUG] "+format, v...) - logMutex.Unlock() - } -} - -func logInfo(format string, v ...interface{}) { - if logLevel == LogLevelDebug || logLevel == LogLevelInfo { - logMutex.Lock() - logger.Printf("[INFO] "+format, v...) - logMutex.Unlock() - } -} - -func logWarn(format string, v ...interface{}) { - if logLevel == LogLevelDebug || logLevel == LogLevelInfo || logLevel == LogLevelWarn { - logMutex.Lock() - logger.Printf("[WARN] "+format, v...) - logMutex.Unlock() - } -} - -func logError(format string, v ...interface{}) { - logMutex.Lock() - logger.Printf("[ERROR] "+format, v...) - logMutex.Unlock() - - // 错误计数 - atomic.AddInt64(&errorCounter, 1) -} - -// 解析命令行参数 -func parseFlags() *Config { - cfg := &Config{} - flag.StringVar(&cfg.Port, "port", "8080", "Port to listen on") - flag.StringVar(&cfg.Address, "address", "localhost", "Address to listen on") - flag.StringVar(&cfg.LogLevel, "log-level", LogLevelInfo, "Log level (debug, info, warn, error)") - flag.BoolVar(&cfg.DevMode, "dev", false, "Enable development mode with enhanced logging") - flag.IntVar(&cfg.MaxRetries, "max-retries", 3, "Maximum number of retries for failed requests") - flag.IntVar(&cfg.Timeout, "timeout", 300, "Request timeout in seconds") - flag.BoolVar(&cfg.VerifySSL, "verify-ssl", true, "Verify SSL certificates") - flag.IntVar(&cfg.WorkerCount, "workers", 50, "Number of worker goroutines in the pool") - flag.IntVar(&cfg.QueueSize, "queue-size", 500, "Size of the task queue") - flag.IntVar(&cfg.MaxConcurrent, "max-concurrent", 100, "Maximum number of concurrent requests") - flag.StringVar(&cfg.APIPrefix, "api-prefix", "", "API prefix for all endpoints") - flag.Parse() - - // 如果开发模式开启,自动设置日志级别为debug - if cfg.DevMode && cfg.LogLevel != LogLevelDebug { - cfg.LogLevel = LogLevelDebug - fmt.Println("开发模式已启用,日志级别设置为debug") - } - - return cfg -} - -// 从请求头中提取令牌 -func extractToken(r *http.Request) (string, error) { - // 获取 Authorization 头部 - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", fmt.Errorf("missing Authorization header") - } - - // 验证格式并提取令牌 - if !strings.HasPrefix(authHeader, "Bearer ") { - return "", fmt.Errorf("invalid Authorization header format, must start with 'Bearer '") - } - - // 提取令牌值 - token := strings.TrimPrefix(authHeader, "Bearer ") - if token == "" { - return "", fmt.Errorf("empty token in Authorization header") - } - - return token, nil -} - -// 设置CORS头 -func setCORSHeaders(w http.ResponseWriter) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") -} - -// 生成UUID -func generateUUID() string { - b := make([]byte, 16) - _, err := rand.Read(b) - if err != nil { - return fmt.Sprintf("%d", time.Now().UnixNano()) - } - - return fmt.Sprintf("%x-%x-%x-%x-%x", - b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) -} - -// 安全的HTTP客户端,支持禁用SSL验证 -func getHTTPClient() *http.Client { - tr := &http.Transport{ - MaxIdleConnsPerHost: 100, - IdleConnTimeout: 90 * time.Second, - TLSClientConfig: nil, // 默认配置 - } - - // 如果配置了禁用SSL验证 - if !appConfig.VerifySSL { - tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - - return &http.Client{ - Timeout: time.Duration(appConfig.Timeout) * time.Second, - Transport: tr, - } -} - -// 主入口函数 -func main() { - // 解析配置 - appConfig = parseFlags() - - // 初始化日志 - initLogger(appConfig.LogLevel) - - logInfo("启动服务: 地址=%s, 端口=%s, 版本=%s, 日志级别=%s", - appConfig.Address, appConfig.Port, Version, appConfig.LogLevel) - - // 创建工作池和信号量 - workerPool = NewWorkerPool(appConfig.WorkerCount, appConfig.QueueSize) - requestSem = NewSemaphore(appConfig.MaxConcurrent) - - logInfo("工作池已创建: %d个worker, 队列大小为%d", appConfig.WorkerCount, appConfig.QueueSize) - - // 配置更高的并发处理能力 - http.DefaultTransport.(*http.Transport).MaxIdleConnsPerHost = 100 - http.DefaultTransport.(*http.Transport).MaxIdleConns = 100 - http.DefaultTransport.(*http.Transport).IdleConnTimeout = 90 * time.Second - - // 创建自定义服务器,支持更高并发 - server := &http.Server{ - Addr: appConfig.Address + ":" + appConfig.Port, - ReadTimeout: time.Duration(appConfig.Timeout) * time.Second, - WriteTimeout: time.Duration(appConfig.Timeout) * time.Second, - IdleTimeout: 120 * time.Second, - Handler: nil, // 使用默认的ServeMux - } - - // API路径前缀 - apiPrefix := appConfig.APIPrefix - - // 创建处理器 - http.HandleFunc(apiPrefix+"/v1/models", func(w http.ResponseWriter, r *http.Request) { - setCORSHeaders(w) - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - // 计数 - countMutex.Lock() - requestCount++ - currentCount := requestCount - countMutex.Unlock() - - reqID := generateRequestID() - logInfo("[reqID:%s] 收到模型列表请求 #%d", reqID, currentCount) - - // 请求计数 - atomic.AddInt64(&requestCounter, 1) - - startTime := time.Now() - - // 创建任务 - task := &Task{ - r: r, - w: w, - done: make(chan struct{}), - reqID: reqID, - path: "/v1/models", - } - - // 尝试获取信号量 - if !requestSem.TryAcquire() { - // 请求数量超过限制 - atomic.AddInt64(&rejectedRequests, 1) - logWarn("[reqID:%s] 请求被拒绝: 当前并发请求数已达上限", reqID) - w.Header().Set("Retry-After", "30") - http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable) - return - } - - // 释放信号量(在函数返回时) - defer requestSem.Release() - - // 添加到任务队列 - atomic.AddInt64(&queuedRequests, 1) - submitted, err := workerPool.SubmitTask(task) - if !submitted { - atomic.AddInt64(&queuedRequests, -1) - atomic.AddInt64(&rejectedRequests, 1) - logError("[reqID:%s] 提交任务失败: %v", reqID, err) - w.Header().Set("Retry-After", "60") - http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable) - return - } - - logInfo("[reqID:%s] 任务已提交到队列", reqID) - - // 等待任务完成或超时 - select { - case <-task.done: - // 任务已完成 - logInfo("[reqID:%s] 任务已完成", reqID) - case <-r.Context().Done(): - // 请求被取消或超时 - logWarn("[reqID:%s] 请求被取消或超时", reqID) - } - - // 请求处理完成,更新指标 - atomic.AddInt64(&queuedRequests, -1) - elapsed := time.Since(startTime).Milliseconds() - - // 更新平均响应时间 - atomic.AddInt64(&avgResponseTime, elapsed) - - if r.Context().Err() == nil { - // 成功计数增加 - atomic.AddInt64(&successCounter, 1) - logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed) - } else { - logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed) - } - }) - - http.HandleFunc(apiPrefix+"/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - setCORSHeaders(w) - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - // 计数器增加 - countMutex.Lock() - requestCount++ - currentCount := requestCount - countMutex.Unlock() - - reqID := generateRequestID() - logInfo("[reqID:%s] 收到新请求 #%d", reqID, currentCount) - - // 请求计数 - atomic.AddInt64(&requestCounter, 1) - - startTime := time.Now() - - // 尝试获取信号量 - if !requestSem.TryAcquire() { - // 请求数量超过限制 - atomic.AddInt64(&rejectedRequests, 1) - logWarn("[reqID:%s] 请求 #%d 被拒绝: 当前并发请求数已达上限", reqID, currentCount) - w.Header().Set("Retry-After", "30") - http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable) - return - } - - // 释放信号量(在函数返回时) - defer requestSem.Release() - - // 解析请求体 - var apiReq APIRequest - if err := json.NewDecoder(r.Body).Decode(&apiReq); err != nil { - logError("[reqID:%s] 解析请求失败: %v", reqID, err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return - } - - // 创建任务 - task := &Task{ - r: r, - w: w, - done: make(chan struct{}), - reqID: reqID, - isStream: apiReq.Stream, - apiReq: apiReq, - path: "/v1/chat/completions", - } - - // 添加到任务队列 - atomic.AddInt64(&queuedRequests, 1) - submitted, err := workerPool.SubmitTask(task) - if !submitted { - atomic.AddInt64(&queuedRequests, -1) - atomic.AddInt64(&rejectedRequests, 1) - logError("[reqID:%s] 提交任务失败: %v", reqID, err) - w.Header().Set("Retry-After", "60") - http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable) - return - } - - logInfo("[reqID:%s] 任务已提交到队列", reqID) - - // 等待任务完成或超时 - select { - case <-task.done: - // 任务已完成 - logInfo("[reqID:%s] 任务已完成", reqID) - case <-r.Context().Done(): - // 请求被取消或超时 - logWarn("[reqID:%s] 请求被取消或超时", reqID) - } - - // 请求处理完成,更新指标 - atomic.AddInt64(&queuedRequests, -1) - elapsed := time.Since(startTime).Milliseconds() - - // 更新平均响应时间 - atomic.AddInt64(&avgResponseTime, elapsed) - - if r.Context().Err() == nil { - // 成功计数增加 - atomic.AddInt64(&successCounter, 1) - logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed) - } else { - logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed) - } - }) - - http.HandleFunc(apiPrefix+"/v1/images/generations", func(w http.ResponseWriter, r *http.Request) { - setCORSHeaders(w) - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - // 计数器增加 - countMutex.Lock() - requestCount++ - currentCount := requestCount - countMutex.Unlock() - - reqID := generateRequestID() - logInfo("[reqID:%s] 收到图像生成请求 #%d", reqID, currentCount) - - // 请求计数 - atomic.AddInt64(&requestCounter, 1) - - startTime := time.Now() - - // 尝试获取信号量 - if !requestSem.TryAcquire() { - // 请求数量超过限制 - atomic.AddInt64(&rejectedRequests, 1) - logWarn("[reqID:%s] 请求 #%d 被拒绝: 当前并发请求数已达上限", reqID, currentCount) - w.Header().Set("Retry-After", "30") - http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable) - return - } - - // 释放信号量(在函数返回时) - defer requestSem.Release() - - // 解析请求体 - var apiReq APIRequest - if err := json.NewDecoder(r.Body).Decode(&apiReq); err != nil { - logError("[reqID:%s] 解析请求失败: %v", reqID, err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return - } - - // 创建任务 - task := &Task{ - r: r, - w: w, - done: make(chan struct{}), - reqID: reqID, - apiReq: apiReq, - path: "/v1/images/generations", - } - - // 添加到任务队列 - atomic.AddInt64(&queuedRequests, 1) - submitted, err := workerPool.SubmitTask(task) - if !submitted { - atomic.AddInt64(&queuedRequests, -1) - atomic.AddInt64(&rejectedRequests, 1) - logError("[reqID:%s] 提交任务失败: %v", reqID, err) - w.Header().Set("Retry-After", "60") - http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable) - return - } - - logInfo("[reqID:%s] 任务已提交到队列", reqID) - - // 等待任务完成或超时 - select { - case <-task.done: - // 任务已完成 - logInfo("[reqID:%s] 任务已完成", reqID) - case <-r.Context().Done(): - // 请求被取消或超时 - logWarn("[reqID:%s] 请求被取消或超时", reqID) - } - - // 请求处理完成,更新指标 - atomic.AddInt64(&queuedRequests, -1) - elapsed := time.Since(startTime).Milliseconds() - - // 更新平均响应时间 - atomic.AddInt64(&avgResponseTime, elapsed) - - if r.Context().Err() == nil { - // 成功计数增加 - atomic.AddInt64(&successCounter, 1) - logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed) - } else { - logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed) - } - }) - - // 添加健康检查端点 - http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { - setCORSHeaders(w) - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - // 获取各种计数器的值 - reqCount := atomic.LoadInt64(&requestCounter) - succCount := atomic.LoadInt64(&successCounter) - errCount := atomic.LoadInt64(&errorCounter) - queuedCount := atomic.LoadInt64(&queuedRequests) - rejectedCount := atomic.LoadInt64(&rejectedRequests) - - // 计算平均响应时间 - var avgTime int64 = 0 - if reqCount > 0 { - avgTime = atomic.LoadInt64(&avgResponseTime) / reqCount - } - - // 构建响应 - stats := map[string]interface{}{ - "status": "ok", - "version": Version, - "requests": reqCount, - "success": succCount, - "errors": errCount, - "queued": queuedCount, - "rejected": rejectedCount, - "avg_time_ms": avgTime, - "worker_count": workerPool.workerCount, - "queue_size": len(workerPool.taskQueue), - "queue_capacity": cap(workerPool.taskQueue), - "queue_percent": float64(len(workerPool.taskQueue)) / float64(cap(workerPool.taskQueue)) * 100, - "concurrent_limit": appConfig.MaxConcurrent, - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(stats) - }) - - // 创建停止通道 - stop := make(chan os.Signal, 1) - signal.Notify(stop, os.Interrupt, syscall.SIGTERM) - - // 在goroutine中启动服务器 - go func() { - logInfo("Starting proxy server on %s", server.Addr) - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logError("Failed to start server: %v", err) - os.Exit(1) - } - }() - - // 等待停止信号 - <-stop - - // 创建上下文用于优雅关闭 - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // 优雅关闭服务器 - logInfo("Server is shutting down...") - if err := server.Shutdown(ctx); err != nil { - logError("Server shutdown failed: %v", err) - } - - // 关闭工作池 - workerPool.Shutdown() - - logInfo("Server gracefully stopped") -} - -// 生成请求ID -func generateRequestID() string { - return fmt.Sprintf("%x", time.Now().UnixNano()) -} - -// 处理模型列表请求 -func handleModels(w http.ResponseWriter, r *http.Request) { - logInfo("处理模型列表请求") - - // 从请求中提取token - authToken, err := extractToken(r) - if err != nil { - logWarn("提取token失败: %v", err) - // 使用默认模型列表 - returnDefaultModels(w) - return - } - - // 请求通义千问API获取模型列表 - client := getHTTPClient() - req, err := http.NewRequest("GET", ModelsURL, nil) - if err != nil { - logError("创建请求失败: %v", err) - returnDefaultModels(w) - return - } - - // 设置请求头 - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Mozilla/5.0") - - // 发送请求 - resp, err := client.Do(req) - if err != nil { - logError("请求模型列表失败: %v", err) - returnDefaultModels(w) - return - } - defer resp.Body.Close() - - // 检查响应状态 - if resp.StatusCode != http.StatusOK { - logError("获取模型列表返回非200状态码: %d", resp.StatusCode) - returnDefaultModels(w) - return - } - - // 解析响应 - var qwenResp struct { - Data []struct { - ID string `json:"id"` - } `json:"data"` - } - if err := json.NewDecoder(resp.Body).Decode(&qwenResp); err != nil { - logError("解析模型列表响应失败: %v", err) - returnDefaultModels(w) - return - } - - // 提取模型ID - models := make([]string, 0, len(qwenResp.Data)) - for _, model := range qwenResp.Data { - models = append(models, model.ID) - } - - // 如果没有获取到模型,使用默认列表 - if len(models) == 0 { - logWarn("未获取到模型,使用默认列表") - returnDefaultModels(w) - return - } - - // 扩展模型列表,增加变种后缀 - expandedModels := make([]ModelData, 0, len(models)*len(ModelSuffixes)) - for _, model := range models { - for _, suffix := range ModelSuffixes { - expandedModels = append(expandedModels, ModelData{ - ID: model + suffix, - Object: "model", - Created: time.Now().Unix(), - OwnedBy: "qwen", - }) - } - } - - // 构建响应 - modelsResp := ModelsResponse{ - Object: "list", - Data: expandedModels, - } - - // 返回响应 - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(modelsResp) -} - -// 返回默认模型列表 -func returnDefaultModels(w http.ResponseWriter) { - // 扩展默认模型列表,增加变种后缀 - expandedModels := make([]ModelData, 0, len(DefaultModels)*len(ModelSuffixes)) - for _, model := range DefaultModels { - for _, suffix := range ModelSuffixes { - expandedModels = append(expandedModels, ModelData{ - ID: model + suffix, - Object: "model", - Created: time.Now().Unix(), - OwnedBy: "qwen", - }) - } - } - - // 构建响应 - modelsResp := ModelsResponse{ - Object: "list", - Data: expandedModels, - } - - // 返回响应 - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(modelsResp) -} - -// 处理聊天完成请求(流式) -func handleStreamingRequest(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string) { - logInfo("[reqID:%s] 处理流式请求", reqID) - - // 从请求中提取token - authToken, err := extractToken(r) - if err != nil { - logError("[reqID:%s] 提取token失败: %v", reqID, err) - http.Error(w, "无效的认证信息", http.StatusUnauthorized) - return - } - - // 检查消息 - if len(apiReq.Messages) == 0 { - logError("[reqID:%s] 消息为空", reqID) - http.Error(w, "消息为空", http.StatusBadRequest) - return - } - - // 准备模型名和聊天类型 - modelName := "qwen-turbo-latest" - if apiReq.Model != "" { - modelName = apiReq.Model - } - chatType := "t2t" - - // 处理特殊模型名后缀 - if strings.Contains(modelName, "-draw") { - handleDrawRequest(w, r, apiReq, reqID, authToken) - return - } - - // 处理思考模式 - if strings.Contains(modelName, "-thinking") { - modelName = strings.Replace(modelName, "-thinking", "", 1) - lastMsgIdx := len(apiReq.Messages) - 1 - if lastMsgIdx >= 0 { - apiReq.Messages[lastMsgIdx].FeatureConfig = map[string]interface{}{ - "thinking_enabled": true, - } - } - } - - // 处理搜索模式 - if strings.Contains(modelName, "-search") { - modelName = strings.Replace(modelName, "-search", "", 1) - chatType = "search" - lastMsgIdx := len(apiReq.Messages) - 1 - if lastMsgIdx >= 0 { - apiReq.Messages[lastMsgIdx].ChatType = "search" - } - } - - // 处理图片消息 - lastMsgIdx := len(apiReq.Messages) - 1 - if lastMsgIdx >= 0 { - lastMsg := apiReq.Messages[lastMsgIdx] - - // 检查内容是否为数组 - contentArray, ok := lastMsg.Content.([]interface{}) - if ok { - // 处理内容数组 - for i, item := range contentArray { - itemMap, isMap := item.(map[string]interface{}) - if !isMap { - continue - } - - // 检查是否包含图像URL - if imageURL, hasImageURL := itemMap["image_url"]; hasImageURL { - imageURLMap, isMap := imageURL.(map[string]interface{}) - if !isMap { - continue - } - - // 获取URL - url, hasURL := imageURLMap["url"].(string) - if !hasURL { - continue - } - - // 上传图像 - imageID, uploadErr := uploadImage(url, authToken) - if uploadErr != nil { - logError("[reqID:%s] 上传图像失败: %v", reqID, uploadErr) - continue - } - - // 替换内容 - contentArrayCopy := make([]interface{}, len(contentArray)) - copy(contentArrayCopy, contentArray) - contentArrayCopy[i] = map[string]interface{}{ - "type": "image", - "image": imageID, - } - apiReq.Messages[lastMsgIdx].Content = contentArrayCopy - break - } - } - } - } - - // 创建通义千问请求 - qwenReq := QwenRequest{ - Model: modelName, - Messages: apiReq.Messages, - Stream: true, - ChatType: chatType, - ID: generateUUID(), - } - - // 序列化请求 - reqData, err := json.Marshal(qwenReq) - if err != nil { - logError("[reqID:%s] 序列化请求失败: %v", reqID, err) - http.Error(w, "内部服务器错误", http.StatusInternalServerError) - return - } - - // 创建HTTP请求 - req, err := http.NewRequestWithContext(r.Context(), "POST", TargetURL, bytes.NewBuffer(reqData)) - if err != nil { - logError("[reqID:%s] 创建请求失败: %v", reqID, err) - http.Error(w, "内部服务器错误", http.StatusInternalServerError) - return - } - - // 设置请求头 - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Mozilla/5.0") - - // 发送请求 - client := getHTTPClient() - resp, err := client.Do(req) - if err != nil { - logError("[reqID:%s] 发送请求失败: %v", reqID, err) - http.Error(w, "连接到API失败", http.StatusBadGateway) - return - } - defer resp.Body.Close() - - // 检查响应状态 - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes)) - http.Error(w, fmt.Sprintf("API错误,状态码: %d", resp.StatusCode), resp.StatusCode) - return - } - - // 设置响应头 - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - - // 创建响应ID和时间戳 - respID := fmt.Sprintf("chatcmpl-%s", generateUUID()) - createdTime := time.Now().Unix() - - // 创建读取器和Flusher - reader := bufio.NewReaderSize(resp.Body, 16384) - flusher, ok := w.(http.Flusher) - if !ok { - logError("[reqID:%s] 流式传输不支持", reqID) - http.Error(w, "流式传输不支持", http.StatusInternalServerError) - return - } - - // 发送角色块 - roleChunk := createRoleChunk(respID, createdTime, modelName) - w.Write([]byte("data: " + string(roleChunk) + "\n\n")) + // 发送结束信号(如果没有正常结束) + finishReason := "stop" + doneChunk := createDoneChunk(respID, createdTime, modelName, finishReason) + w.Write([]byte("data: " + string(doneChunk) + "\n\n")) + w.Write([]byte("data: [DONE]\n\n")) flusher.Flush() - - // 用于去重的前一个内容 - previousContent := "" - - ) - - // 流式传输状态跟踪 - var isCompleted bool - var totalChunks int - buffer := "" - - logInfo("[reqID:%s] 开始流式传输...", reqID) - - // 持续读取响应 - for { - // 添加超时检测 - select { - case <-r.Context().Done(): - logWarn("[reqID:%s] 请求超时或被客户端取消", reqID) - return - default: - // 继续处理 - } - - // 读取一块数据 - chunk := make([]byte, 4096) - n, err := reader.Read(chunk) - - // 处理读取结果 - if err != nil { - if err != io.EOF { - logError("[reqID:%s] 读取响应出错: %v", reqID, err) - return - } - - // EOF处理 - logInfo("[reqID:%s] 读取到文件末尾,完成流式传输", reqID) - - // 如果buffer中还有内容,尝试处理 - if len(buffer) > 0 { - logInfo("[reqID:%s] 处理最后的缓冲区 (长度: %d)", reqID, len(buffer)) - } - - if !isCompleted { - // 发送结束信号(如果没有正常结束) - finishReason := "stop" - doneChunk := createDoneChunk(respID, createdTime, modelName, finishReason) - w.Write([]byte("data: " + string(doneChunk) + "\n\n")) - w.Write([]byte("data: [DONE]\n\n")) - flusher.Flush() - logInfo("[reqID:%s] 发送结束信号,共传输 %d 个数据块", reqID, totalChunks) - } - break - } - - if n > 0 { - // 添加到缓冲区 - buffer += string(chunk[:n]) - - // 查找所有的data行 - matches := dataRegex.FindAllStringSubmatch(buffer, -1) - - // 处理匹配到的行 - for _, match := range matches { - // 获取数据部分 - dataStr := match[1] - - // 从缓冲区中移除已处理的行 - buffer = strings.Replace(buffer, "data: "+dataStr+"\n", "", 1) - - // 处理[DONE]消息 - if dataStr == "[DONE]" { - w.Write([]byte("data: [DONE]\n\n")) - flusher.Flush() - isCompleted = true - logInfo("[reqID:%s] 处理完成信号 [DONE]", reqID) - continue - } - - // 解析JSON - var qwenResp QwenResponse - if err := json.Unmarshal([]byte(dataStr), &qwenResp); err != nil { - logWarn("[reqID:%s] 解析JSON失败: %v, data: %s", reqID, err, dataStr) - continue - } - - // 处理块 - for _, choice := range qwenResp.Choices { - content := choice.Delta.Content - - // 去重 - if strings.HasPrefix(content, previousContent) { - content = content[len(previousContent):] - } - - if content != "" { - previousContent += content - - // 创建内容块 - contentChunk := createContentChunk(respID, createdTime, modelName, content) - w.Write([]byte("data: " + string(contentChunk) + "\n\n")) - flusher.Flush() - totalChunks++ - - if totalChunks % 10 == 0 { - logInfo("[reqID:%s] 已传输 %d 个数据块", reqID, totalChunks) - } - } - - // 处理完成标志 - if choice.FinishReason != "" { - finishReason := choice.FinishReason - doneChunk := createDoneChunk(respID, createdTime, modelName, finishReason) - w.Write([]byte("data: " + string(doneChunk) + "\n\n")) - flusher.Flush() - logInfo("[reqID:%s] 处理完成标志: %s", reqID, finishReason) - } - } - } - } - - // 如果已完成,退出循环 - if isCompleted && buffer == "" { - logInfo("[reqID:%s] 处理完成并且缓冲区为空,结束流式传输", reqID) - break - } - } - - logInfo("[reqID:%s] 流式传输完成,总共发送 %d 个数据块", reqID, totalChunks) } // 处理聊天完成请求(非流式) @@ -3118,3189 +1893,70 @@ func handleDrawRequest(w http.ResponseWriter, r *http.Request, apiReq APIRequest // 从流式响应中提取完整内容 func extractFullContentFromStream(body io.ReadCloser, reqID string) (string, error) { - var contentBuilder strings.Builder - - // 创建读取器 - reader := bufio.NewReaderSize(body, 16384) - - // 创建正则表达式来查找 data: 行 - dataRegex := regexp.MustCompile(`(?m)^data: (.+)package main +var contentBuilder strings.Builder -import ( - "bufio" - "bytes" - "context" - "crypto/rand" - "crypto/tls" - "encoding/base64" - "encoding/json" - "flag" - "fmt" - "io" - "log" - "mime/multipart" - "net/http" - "os" - "os/signal" - "regexp" - "strings" - "sync" - "sync/atomic" - "syscall" - "time" -) +// 创建读取器 +reader := bufio.NewReaderSize(body, 16384) -// 版本和API常量 -const ( - Version = "1.0.0" - TargetURL = "https://chat.qwen.ai/api/chat/completions" - ModelsURL = "https://chat.qwen.ai/api/models" - FilesURL = "https://chat.qwen.ai/api/v1/files/" - TasksURL = "https://chat.qwen.ai/api/v1/tasks/status/" -) - -// 默认模型列表(当获取接口失败时使用) -var DefaultModels = []string{ - "qwen-max-latest", - "qwen-plus-latest", - "qwen2.5-vl-72b-instruct", - "qwen2.5-14b-instruct-1m", - "qvq-72b-preview", - "qwq-32b-preview", - "qwen2.5-coder-32b-instruct", - "qwen-turbo-latest", - "qwen2.5-72b-instruct", -} - -// 扩展模型变种后缀 -var ModelSuffixes = []string{ - "", - "-thinking", - "-search", - "-thinking-search", - "-draw", -} - -// 日志级别常量 -const ( - LogLevelDebug = "debug" - LogLevelInfo = "info" - LogLevelWarn = "warn" - LogLevelError = "error" -) - -// WorkerPool 工作池结构体,用于管理goroutine -type WorkerPool struct { - taskQueue chan *Task - workerCount int - shutdownChannel chan struct{} - wg sync.WaitGroup -} - -// Task 任务结构体,包含请求处理所需数据 -type Task struct { - r *http.Request - w http.ResponseWriter - done chan struct{} - reqID string - isStream bool - apiReq APIRequest - path string -} - -// Semaphore 信号量实现,用于限制并发数量 -type Semaphore struct { - sem chan struct{} -} - -// 配置结构体 -type Config struct { - Port string - Address string - LogLevel string - DevMode bool - MaxRetries int - Timeout int - VerifySSL bool - WorkerCount int - QueueSize int - MaxConcurrent int - APIPrefix string -} - -// APIRequest OpenAI兼容的请求结构体 -type APIRequest struct { - Model string `json:"model"` - Messages []APIMessage `json:"messages"` - Stream bool `json:"stream"` - Temperature float64 `json:"temperature,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` -} - -// APIMessage 消息结构体 -type APIMessage struct { - Role string `json:"role"` - Content interface{} `json:"content"` - FeatureConfig interface{} `json:"feature_config,omitempty"` - ChatType string `json:"chat_type,omitempty"` - Extra interface{} `json:"extra,omitempty"` -} - -// 内容项目结构体(处理图像等内容) -type ContentItem struct { - Type string `json:"type,omitempty"` - Text string `json:"text,omitempty"` - ImageURL *ImageURL `json:"image_url,omitempty"` - Image string `json:"image,omitempty"` -} - -// ImageURL 图像URL结构体 -type ImageURL struct { - URL string `json:"url"` -} - -// QwenRequest 通义千问API请求结构体 -type QwenRequest struct { - Model string `json:"model"` - Messages []APIMessage `json:"messages"` - Stream bool `json:"stream"` - ChatType string `json:"chat_type,omitempty"` - ID string `json:"id,omitempty"` - IncrementalOutput bool `json:"incremental_output,omitempty"` - Size string `json:"size,omitempty"` -} - -// QwenResponse 通义千问API响应结构体 -type QwenResponse struct { - Messages []struct { - Role string `json:"role"` - Content string `json:"content"` - Extra struct { - Wanx struct { - TaskID string `json:"task_id"` - } `json:"wanx"` - } `json:"extra"` - } `json:"messages"` - Choices []struct { - Delta struct { - Content string `json:"content"` - } `json:"delta"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` -} - -// FileUploadResponse 文件上传响应 -type FileUploadResponse struct { - ID string `json:"id"` -} - -// TaskStatusResponse 任务状态响应 -type TaskStatusResponse struct { - Content string `json:"content"` -} - -// StreamChunk OpenAI兼容的流式响应块 -type StreamChunk struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []struct { - Index int `json:"index"` - Delta struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` - } `json:"choices"` -} - -// CompletionResponse OpenAI兼容的完成响应 -type CompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []struct { - Index int `json:"index"` - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` -} - -// ImagesResponse 图像生成响应 -type ImagesResponse struct { - Created int64 `json:"created"` - Data []ImageURL `json:"data"` -} - -// ImagesRequest 图像生成请求 -type ImagesRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - N int `json:"n"` - Size string `json:"size"` -} - -// ModelData 模型数据 -type ModelData struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - OwnedBy string `json:"owned_by"` -} - -// ModelsResponse 模型列表响应 -type ModelsResponse struct { - Object string `json:"object"` - Data []ModelData `json:"data"` -} - -// 全局变量 -var ( - appConfig *Config - logger *log.Logger - logLevel string - logMutex sync.Mutex - workerPool *WorkerPool - requestSem *Semaphore - requestCount uint64 = 0 - countMutex sync.Mutex - - // 性能指标 - requestCounter int64 - successCounter int64 - errorCounter int64 - avgResponseTime int64 - queuedRequests int64 - rejectedRequests int64 -) - -// NewSemaphore 创建新的信号量 -func NewSemaphore(size int) *Semaphore { - return &Semaphore{ - sem: make(chan struct{}, size), - } -} - -// Acquire 获取信号量(阻塞) -func (s *Semaphore) Acquire() { - s.sem <- struct{}{} -} - -// Release 释放信号量 -func (s *Semaphore) Release() { - <-s.sem -} - -// TryAcquire 尝试获取信号量(非阻塞) -func (s *Semaphore) TryAcquire() bool { - select { - case s.sem <- struct{}{}: - return true - default: - return false - } -} - -// NewWorkerPool 创建并启动一个新的工作池 -func NewWorkerPool(workerCount int, queueSize int) *WorkerPool { - pool := &WorkerPool{ - taskQueue: make(chan *Task, queueSize), - workerCount: workerCount, - shutdownChannel: make(chan struct{}), - } - - pool.Start() - return pool -} - -// Start 启动工作池中的worker goroutines -func (pool *WorkerPool) Start() { - // 启动工作goroutine - for i := 0; i < pool.workerCount; i++ { - pool.wg.Add(1) - go func(workerID int) { - defer pool.wg.Done() - - logInfo("Worker %d 已启动", workerID) - - for { - select { - case task, ok := <-pool.taskQueue: - if !ok { - // 队列已关闭,退出worker - logInfo("Worker %d 收到队列关闭信号,准备退出", workerID) - return - } - - logDebug("Worker %d 处理任务 reqID:%s", workerID, task.reqID) - - // 处理任务 - var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() - - switch task.path { - case "/v1/models": - handleModels(task.w, task.r) - case "/v1/chat/completions": - if task.isStream { - handleStreamingRequest(task.w, task.r, task.apiReq, task.reqID) - } else { - handleNonStreamingRequest(task.w, task.r, task.apiReq, task.reqID) - } - case "/v1/images/generations": - handleImageGenerations(task.w, task.r, task.apiReq, task.reqID) - } - - logInfo("Worker %d 任务 reqID:%s 处理完成", workerID, task.reqID) - }() - - // 等待任务完成后再发送通知 - wg.Wait() - // 通知任务完成 - close(task.done) - - case <-pool.shutdownChannel: - // 收到关闭信号,退出worker - logInfo("Worker %d 收到关闭信号,准备退出", workerID) - return - } - } - }(i) - } -} - -// SubmitTask 提交任务到工作池,非阻塞 -func (pool *WorkerPool) SubmitTask(task *Task) (bool, error) { - select { - case pool.taskQueue <- task: - // 任务成功添加到队列 - return true, nil - default: - // 队列已满 - return false, fmt.Errorf("任务队列已满") - } -} - -// Shutdown 关闭工作池 -func (pool *WorkerPool) Shutdown() { - logInfo("正在关闭工作池...") - - // 发送关闭信号给所有worker - close(pool.shutdownChannel) - - // 等待所有worker退出 - pool.wg.Wait() - - // 关闭任务队列 - close(pool.taskQueue) - - logInfo("工作池已关闭") -} - -// 日志函数 -func initLogger(level string) { - logger = log.New(os.Stdout, "[QwenAPI] ", log.LstdFlags) - logLevel = level -} - -func logDebug(format string, v ...interface{}) { - if logLevel == LogLevelDebug { - logMutex.Lock() - logger.Printf("[DEBUG] "+format, v...) - logMutex.Unlock() - } -} - -func logInfo(format string, v ...interface{}) { - if logLevel == LogLevelDebug || logLevel == LogLevelInfo { - logMutex.Lock() - logger.Printf("[INFO] "+format, v...) - logMutex.Unlock() - } -} - -func logWarn(format string, v ...interface{}) { - if logLevel == LogLevelDebug || logLevel == LogLevelInfo || logLevel == LogLevelWarn { - logMutex.Lock() - logger.Printf("[WARN] "+format, v...) - logMutex.Unlock() - } -} - -func logError(format string, v ...interface{}) { - logMutex.Lock() - logger.Printf("[ERROR] "+format, v...) - logMutex.Unlock() - - // 错误计数 - atomic.AddInt64(&errorCounter, 1) -} - -// 解析命令行参数 -func parseFlags() *Config { - cfg := &Config{} - flag.StringVar(&cfg.Port, "port", "8080", "Port to listen on") - flag.StringVar(&cfg.Address, "address", "localhost", "Address to listen on") - flag.StringVar(&cfg.LogLevel, "log-level", LogLevelInfo, "Log level (debug, info, warn, error)") - flag.BoolVar(&cfg.DevMode, "dev", false, "Enable development mode with enhanced logging") - flag.IntVar(&cfg.MaxRetries, "max-retries", 3, "Maximum number of retries for failed requests") - flag.IntVar(&cfg.Timeout, "timeout", 300, "Request timeout in seconds") - flag.BoolVar(&cfg.VerifySSL, "verify-ssl", true, "Verify SSL certificates") - flag.IntVar(&cfg.WorkerCount, "workers", 50, "Number of worker goroutines in the pool") - flag.IntVar(&cfg.QueueSize, "queue-size", 500, "Size of the task queue") - flag.IntVar(&cfg.MaxConcurrent, "max-concurrent", 100, "Maximum number of concurrent requests") - flag.StringVar(&cfg.APIPrefix, "api-prefix", "", "API prefix for all endpoints") - flag.Parse() - - // 如果开发模式开启,自动设置日志级别为debug - if cfg.DevMode && cfg.LogLevel != LogLevelDebug { - cfg.LogLevel = LogLevelDebug - fmt.Println("开发模式已启用,日志级别设置为debug") - } - - return cfg -} - -// 从请求头中提取令牌 -func extractToken(r *http.Request) (string, error) { - // 获取 Authorization 头部 - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", fmt.Errorf("missing Authorization header") - } - - // 验证格式并提取令牌 - if !strings.HasPrefix(authHeader, "Bearer ") { - return "", fmt.Errorf("invalid Authorization header format, must start with 'Bearer '") - } - - // 提取令牌值 - token := strings.TrimPrefix(authHeader, "Bearer ") - if token == "" { - return "", fmt.Errorf("empty token in Authorization header") - } - - return token, nil -} - -// 设置CORS头 -func setCORSHeaders(w http.ResponseWriter) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") -} - -// 生成UUID -func generateUUID() string { - b := make([]byte, 16) - _, err := rand.Read(b) - if err != nil { - return fmt.Sprintf("%d", time.Now().UnixNano()) - } - - return fmt.Sprintf("%x-%x-%x-%x-%x", - b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) -} - -// 安全的HTTP客户端,支持禁用SSL验证 -func getHTTPClient() *http.Client { - tr := &http.Transport{ - MaxIdleConnsPerHost: 100, - IdleConnTimeout: 90 * time.Second, - TLSClientConfig: nil, // 默认配置 - } - - // 如果配置了禁用SSL验证 - if !appConfig.VerifySSL { - tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - - return &http.Client{ - Timeout: time.Duration(appConfig.Timeout) * time.Second, - Transport: tr, - } +// 持续读取响应 +buffer := "" +for { +// 读取一块数据 +chunk := make([]byte, 4096) +n, err := reader.Read(chunk) +if err != nil { +if err != io.EOF { +return contentBuilder.String(), err } - -// 主入口函数 -func main() { - // 解析配置 - appConfig = parseFlags() - - // 初始化日志 - initLogger(appConfig.LogLevel) - - logInfo("启动服务: 地址=%s, 端口=%s, 版本=%s, 日志级别=%s", - appConfig.Address, appConfig.Port, Version, appConfig.LogLevel) - - // 创建工作池和信号量 - workerPool = NewWorkerPool(appConfig.WorkerCount, appConfig.QueueSize) - requestSem = NewSemaphore(appConfig.MaxConcurrent) - - logInfo("工作池已创建: %d个worker, 队列大小为%d", appConfig.WorkerCount, appConfig.QueueSize) - - // 配置更高的并发处理能力 - http.DefaultTransport.(*http.Transport).MaxIdleConnsPerHost = 100 - http.DefaultTransport.(*http.Transport).MaxIdleConns = 100 - http.DefaultTransport.(*http.Transport).IdleConnTimeout = 90 * time.Second - - // 创建自定义服务器,支持更高并发 - server := &http.Server{ - Addr: appConfig.Address + ":" + appConfig.Port, - ReadTimeout: time.Duration(appConfig.Timeout) * time.Second, - WriteTimeout: time.Duration(appConfig.Timeout) * time.Second, - IdleTimeout: 120 * time.Second, - Handler: nil, // 使用默认的ServeMux - } - - // API路径前缀 - apiPrefix := appConfig.APIPrefix - - // 创建处理器 - http.HandleFunc(apiPrefix+"/v1/models", func(w http.ResponseWriter, r *http.Request) { - setCORSHeaders(w) - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - // 计数 - countMutex.Lock() - requestCount++ - currentCount := requestCount - countMutex.Unlock() - - reqID := generateRequestID() - logInfo("[reqID:%s] 收到模型列表请求 #%d", reqID, currentCount) - - // 请求计数 - atomic.AddInt64(&requestCounter, 1) - - startTime := time.Now() - - // 创建任务 - task := &Task{ - r: r, - w: w, - done: make(chan struct{}), - reqID: reqID, - path: "/v1/models", - } - - // 尝试获取信号量 - if !requestSem.TryAcquire() { - // 请求数量超过限制 - atomic.AddInt64(&rejectedRequests, 1) - logWarn("[reqID:%s] 请求被拒绝: 当前并发请求数已达上限", reqID) - w.Header().Set("Retry-After", "30") - http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable) - return - } - - // 释放信号量(在函数返回时) - defer requestSem.Release() - - // 添加到任务队列 - atomic.AddInt64(&queuedRequests, 1) - submitted, err := workerPool.SubmitTask(task) - if !submitted { - atomic.AddInt64(&queuedRequests, -1) - atomic.AddInt64(&rejectedRequests, 1) - logError("[reqID:%s] 提交任务失败: %v", reqID, err) - w.Header().Set("Retry-After", "60") - http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable) - return - } - - logInfo("[reqID:%s] 任务已提交到队列", reqID) - - // 等待任务完成或超时 - select { - case <-task.done: - // 任务已完成 - logInfo("[reqID:%s] 任务已完成", reqID) - case <-r.Context().Done(): - // 请求被取消或超时 - logWarn("[reqID:%s] 请求被取消或超时", reqID) - } - - // 请求处理完成,更新指标 - atomic.AddInt64(&queuedRequests, -1) - elapsed := time.Since(startTime).Milliseconds() - - // 更新平均响应时间 - atomic.AddInt64(&avgResponseTime, elapsed) - - if r.Context().Err() == nil { - // 成功计数增加 - atomic.AddInt64(&successCounter, 1) - logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed) - } else { - logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed) - } - }) - - http.HandleFunc(apiPrefix+"/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - setCORSHeaders(w) - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - // 计数器增加 - countMutex.Lock() - requestCount++ - currentCount := requestCount - countMutex.Unlock() - - reqID := generateRequestID() - logInfo("[reqID:%s] 收到新请求 #%d", reqID, currentCount) - - // 请求计数 - atomic.AddInt64(&requestCounter, 1) - - startTime := time.Now() - - // 尝试获取信号量 - if !requestSem.TryAcquire() { - // 请求数量超过限制 - atomic.AddInt64(&rejectedRequests, 1) - logWarn("[reqID:%s] 请求 #%d 被拒绝: 当前并发请求数已达上限", reqID, currentCount) - w.Header().Set("Retry-After", "30") - http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable) - return - } - - // 释放信号量(在函数返回时) - defer requestSem.Release() - - // 解析请求体 - var apiReq APIRequest - if err := json.NewDecoder(r.Body).Decode(&apiReq); err != nil { - logError("[reqID:%s] 解析请求失败: %v", reqID, err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return - } - - // 创建任务 - task := &Task{ - r: r, - w: w, - done: make(chan struct{}), - reqID: reqID, - isStream: apiReq.Stream, - apiReq: apiReq, - path: "/v1/chat/completions", - } - - // 添加到任务队列 - atomic.AddInt64(&queuedRequests, 1) - submitted, err := workerPool.SubmitTask(task) - if !submitted { - atomic.AddInt64(&queuedRequests, -1) - atomic.AddInt64(&rejectedRequests, 1) - logError("[reqID:%s] 提交任务失败: %v", reqID, err) - w.Header().Set("Retry-After", "60") - http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable) - return - } - - logInfo("[reqID:%s] 任务已提交到队列", reqID) - - // 等待任务完成或超时 - select { - case <-task.done: - // 任务已完成 - logInfo("[reqID:%s] 任务已完成", reqID) - case <-r.Context().Done(): - // 请求被取消或超时 - logWarn("[reqID:%s] 请求被取消或超时", reqID) - } - - // 请求处理完成,更新指标 - atomic.AddInt64(&queuedRequests, -1) - elapsed := time.Since(startTime).Milliseconds() - - // 更新平均响应时间 - atomic.AddInt64(&avgResponseTime, elapsed) - - if r.Context().Err() == nil { - // 成功计数增加 - atomic.AddInt64(&successCounter, 1) - logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed) - } else { - logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed) - } - }) - - http.HandleFunc(apiPrefix+"/v1/images/generations", func(w http.ResponseWriter, r *http.Request) { - setCORSHeaders(w) - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - // 计数器增加 - countMutex.Lock() - requestCount++ - currentCount := requestCount - countMutex.Unlock() - - reqID := generateRequestID() - logInfo("[reqID:%s] 收到图像生成请求 #%d", reqID, currentCount) - - // 请求计数 - atomic.AddInt64(&requestCounter, 1) - - startTime := time.Now() - - // 尝试获取信号量 - if !requestSem.TryAcquire() { - // 请求数量超过限制 - atomic.AddInt64(&rejectedRequests, 1) - logWarn("[reqID:%s] 请求 #%d 被拒绝: 当前并发请求数已达上限", reqID, currentCount) - w.Header().Set("Retry-After", "30") - http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable) - return - } - - // 释放信号量(在函数返回时) - defer requestSem.Release() - - // 解析请求体 - var apiReq APIRequest - if err := json.NewDecoder(r.Body).Decode(&apiReq); err != nil { - logError("[reqID:%s] 解析请求失败: %v", reqID, err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return - } - - // 创建任务 - task := &Task{ - r: r, - w: w, - done: make(chan struct{}), - reqID: reqID, - apiReq: apiReq, - path: "/v1/images/generations", - } - - // 添加到任务队列 - atomic.AddInt64(&queuedRequests, 1) - submitted, err := workerPool.SubmitTask(task) - if !submitted { - atomic.AddInt64(&queuedRequests, -1) - atomic.AddInt64(&rejectedRequests, 1) - logError("[reqID:%s] 提交任务失败: %v", reqID, err) - w.Header().Set("Retry-After", "60") - http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable) - return - } - - logInfo("[reqID:%s] 任务已提交到队列", reqID) - - // 等待任务完成或超时 - select { - case <-task.done: - // 任务已完成 - logInfo("[reqID:%s] 任务已完成", reqID) - case <-r.Context().Done(): - // 请求被取消或超时 - logWarn("[reqID:%s] 请求被取消或超时", reqID) - } - - // 请求处理完成,更新指标 - atomic.AddInt64(&queuedRequests, -1) - elapsed := time.Since(startTime).Milliseconds() - - // 更新平均响应时间 - atomic.AddInt64(&avgResponseTime, elapsed) - - if r.Context().Err() == nil { - // 成功计数增加 - atomic.AddInt64(&successCounter, 1) - logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed) - } else { - logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed) - } - }) - - // 添加健康检查端点 - http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { - setCORSHeaders(w) - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - // 获取各种计数器的值 - reqCount := atomic.LoadInt64(&requestCounter) - succCount := atomic.LoadInt64(&successCounter) - errCount := atomic.LoadInt64(&errorCounter) - queuedCount := atomic.LoadInt64(&queuedRequests) - rejectedCount := atomic.LoadInt64(&rejectedRequests) - - // 计算平均响应时间 - var avgTime int64 = 0 - if reqCount > 0 { - avgTime = atomic.LoadInt64(&avgResponseTime) / reqCount - } - - // 构建响应 - stats := map[string]interface{}{ - "status": "ok", - "version": Version, - "requests": reqCount, - "success": succCount, - "errors": errCount, - "queued": queuedCount, - "rejected": rejectedCount, - "avg_time_ms": avgTime, - "worker_count": workerPool.workerCount, - "queue_size": len(workerPool.taskQueue), - "queue_capacity": cap(workerPool.taskQueue), - "queue_percent": float64(len(workerPool.taskQueue)) / float64(cap(workerPool.taskQueue)) * 100, - "concurrent_limit": appConfig.MaxConcurrent, - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(stats) - }) - - // 创建停止通道 - stop := make(chan os.Signal, 1) - signal.Notify(stop, os.Interrupt, syscall.SIGTERM) - - // 在goroutine中启动服务器 - go func() { - logInfo("Starting proxy server on %s", server.Addr) - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logError("Failed to start server: %v", err) - os.Exit(1) - } - }() - - // 等待停止信号 - <-stop - - // 创建上下文用于优雅关闭 - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // 优雅关闭服务器 - logInfo("Server is shutting down...") - if err := server.Shutdown(ctx); err != nil { - logError("Server shutdown failed: %v", err) - } - - // 关闭工作池 - workerPool.Shutdown() - - logInfo("Server gracefully stopped") -} - -// 生成请求ID -func generateRequestID() string { - return fmt.Sprintf("%x", time.Now().UnixNano()) -} - -// 处理模型列表请求 -func handleModels(w http.ResponseWriter, r *http.Request) { - logInfo("处理模型列表请求") - - // 从请求中提取token - authToken, err := extractToken(r) - if err != nil { - logWarn("提取token失败: %v", err) - // 使用默认模型列表 - returnDefaultModels(w) - return - } - - // 请求通义千问API获取模型列表 - client := getHTTPClient() - req, err := http.NewRequest("GET", ModelsURL, nil) - if err != nil { - logError("创建请求失败: %v", err) - returnDefaultModels(w) - return - } - - // 设置请求头 - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Mozilla/5.0") - - // 发送请求 - resp, err := client.Do(req) - if err != nil { - logError("请求模型列表失败: %v", err) - returnDefaultModels(w) - return - } - defer resp.Body.Close() - - // 检查响应状态 - if resp.StatusCode != http.StatusOK { - logError("获取模型列表返回非200状态码: %d", resp.StatusCode) - returnDefaultModels(w) - return - } - - // 解析响应 - var qwenResp struct { - Data []struct { - ID string `json:"id"` - } `json:"data"` - } - if err := json.NewDecoder(resp.Body).Decode(&qwenResp); err != nil { - logError("解析模型列表响应失败: %v", err) - returnDefaultModels(w) - return - } - - // 提取模型ID - models := make([]string, 0, len(qwenResp.Data)) - for _, model := range qwenResp.Data { - models = append(models, model.ID) - } - - // 如果没有获取到模型,使用默认列表 - if len(models) == 0 { - logWarn("未获取到模型,使用默认列表") - returnDefaultModels(w) - return - } - - // 扩展模型列表,增加变种后缀 - expandedModels := make([]ModelData, 0, len(models)*len(ModelSuffixes)) - for _, model := range models { - for _, suffix := range ModelSuffixes { - expandedModels = append(expandedModels, ModelData{ - ID: model + suffix, - Object: "model", - Created: time.Now().Unix(), - OwnedBy: "qwen", - }) - } - } - - // 构建响应 - modelsResp := ModelsResponse{ - Object: "list", - Data: expandedModels, - } - - // 返回响应 - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(modelsResp) -} - -// 返回默认模型列表 -func returnDefaultModels(w http.ResponseWriter) { - // 扩展默认模型列表,增加变种后缀 - expandedModels := make([]ModelData, 0, len(DefaultModels)*len(ModelSuffixes)) - for _, model := range DefaultModels { - for _, suffix := range ModelSuffixes { - expandedModels = append(expandedModels, ModelData{ - ID: model + suffix, - Object: "model", - Created: time.Now().Unix(), - OwnedBy: "qwen", - }) - } - } - - // 构建响应 - modelsResp := ModelsResponse{ - Object: "list", - Data: expandedModels, - } - - // 返回响应 - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(modelsResp) -} - -// 处理聊天完成请求(流式) -func handleStreamingRequest(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string) { - logInfo("[reqID:%s] 处理流式请求", reqID) - - // 从请求中提取token - authToken, err := extractToken(r) - if err != nil { - logError("[reqID:%s] 提取token失败: %v", reqID, err) - http.Error(w, "无效的认证信息", http.StatusUnauthorized) - return - } - - // 检查消息 - if len(apiReq.Messages) == 0 { - logError("[reqID:%s] 消息为空", reqID) - http.Error(w, "消息为空", http.StatusBadRequest) - return - } - - // 准备模型名和聊天类型 - modelName := "qwen-turbo-latest" - if apiReq.Model != "" { - modelName = apiReq.Model - } - chatType := "t2t" - - // 处理特殊模型名后缀 - if strings.Contains(modelName, "-draw") { - handleDrawRequest(w, r, apiReq, reqID, authToken) - return - } - - // 处理思考模式 - if strings.Contains(modelName, "-thinking") { - modelName = strings.Replace(modelName, "-thinking", "", 1) - lastMsgIdx := len(apiReq.Messages) - 1 - if lastMsgIdx >= 0 { - apiReq.Messages[lastMsgIdx].FeatureConfig = map[string]interface{}{ - "thinking_enabled": true, - } - } - } - - // 处理搜索模式 - if strings.Contains(modelName, "-search") { - modelName = strings.Replace(modelName, "-search", "", 1) - chatType = "search" - lastMsgIdx := len(apiReq.Messages) - 1 - if lastMsgIdx >= 0 { - apiReq.Messages[lastMsgIdx].ChatType = "search" - } - } - - // 处理图片消息 - lastMsgIdx := len(apiReq.Messages) - 1 - if lastMsgIdx >= 0 { - lastMsg := apiReq.Messages[lastMsgIdx] - - // 检查内容是否为数组 - contentArray, ok := lastMsg.Content.([]interface{}) - if ok { - // 处理内容数组 - for i, item := range contentArray { - itemMap, isMap := item.(map[string]interface{}) - if !isMap { - continue - } - - // 检查是否包含图像URL - if imageURL, hasImageURL := itemMap["image_url"]; hasImageURL { - imageURLMap, isMap := imageURL.(map[string]interface{}) - if !isMap { - continue - } - - // 获取URL - url, hasURL := imageURLMap["url"].(string) - if !hasURL { - continue - } - - // 上传图像 - imageID, uploadErr := uploadImage(url, authToken) - if uploadErr != nil { - logError("[reqID:%s] 上传图像失败: %v", reqID, uploadErr) - continue - } - - // 替换内容 - contentArrayCopy := make([]interface{}, len(contentArray)) - copy(contentArrayCopy, contentArray) - contentArrayCopy[i] = map[string]interface{}{ - "type": "image", - "image": imageID, - } - apiReq.Messages[lastMsgIdx].Content = contentArrayCopy - break - } - } - } - } - - // 创建通义千问请求 - qwenReq := QwenRequest{ - Model: modelName, - Messages: apiReq.Messages, - Stream: true, - ChatType: chatType, - ID: generateUUID(), - } - - // 序列化请求 - reqData, err := json.Marshal(qwenReq) - if err != nil { - logError("[reqID:%s] 序列化请求失败: %v", reqID, err) - http.Error(w, "内部服务器错误", http.StatusInternalServerError) - return - } - - // 创建HTTP请求 - req, err := http.NewRequestWithContext(r.Context(), "POST", TargetURL, bytes.NewBuffer(reqData)) - if err != nil { - logError("[reqID:%s] 创建请求失败: %v", reqID, err) - http.Error(w, "内部服务器错误", http.StatusInternalServerError) - return - } - - // 设置请求头 - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Mozilla/5.0") - - // 发送请求 - client := getHTTPClient() - resp, err := client.Do(req) - if err != nil { - logError("[reqID:%s] 发送请求失败: %v", reqID, err) - http.Error(w, "连接到API失败", http.StatusBadGateway) - return - } - defer resp.Body.Close() - - // 检查响应状态 - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes)) - http.Error(w, fmt.Sprintf("API错误,状态码: %d", resp.StatusCode), resp.StatusCode) - return - } - - // 设置响应头 - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - - // 创建响应ID和时间戳 - respID := fmt.Sprintf("chatcmpl-%s", generateUUID()) - createdTime := time.Now().Unix() - - // 创建读取器和Flusher - reader := bufio.NewReaderSize(resp.Body, 16384) - flusher, ok := w.(http.Flusher) - if !ok { - logError("[reqID:%s] 流式传输不支持", reqID) - http.Error(w, "流式传输不支持", http.StatusInternalServerError) - return - } - - // 发送角色块 - roleChunk := createRoleChunk(respID, createdTime, modelName) - w.Write([]byte("data: " + string(roleChunk) + "\n\n")) - flusher.Flush() - - // 用于去重的前一个内容 - previousContent := "" - - // 创建正则表达式来查找 data: 行 - dataRegex := regexp.MustCompile(`(?m)^data: (.+)package main - -import ( - "bufio" - "bytes" - "context" - "crypto/rand" - "crypto/tls" - "encoding/base64" - "encoding/json" - "flag" - "fmt" - "io" - "log" - "mime/multipart" - "net/http" - "os" - "os/signal" - "regexp" - "strings" - "sync" - "sync/atomic" - "syscall" - "time" -) - -// 版本和API常量 -const ( - Version = "1.0.0" - TargetURL = "https://chat.qwen.ai/api/chat/completions" - ModelsURL = "https://chat.qwen.ai/api/models" - FilesURL = "https://chat.qwen.ai/api/v1/files/" - TasksURL = "https://chat.qwen.ai/api/v1/tasks/status/" -) - -// 默认模型列表(当获取接口失败时使用) -var DefaultModels = []string{ - "qwen-max-latest", - "qwen-plus-latest", - "qwen2.5-vl-72b-instruct", - "qwen2.5-14b-instruct-1m", - "qvq-72b-preview", - "qwq-32b-preview", - "qwen2.5-coder-32b-instruct", - "qwen-turbo-latest", - "qwen2.5-72b-instruct", -} - -// 扩展模型变种后缀 -var ModelSuffixes = []string{ - "", - "-thinking", - "-search", - "-thinking-search", - "-draw", -} - -// 日志级别常量 -const ( - LogLevelDebug = "debug" - LogLevelInfo = "info" - LogLevelWarn = "warn" - LogLevelError = "error" -) - -// WorkerPool 工作池结构体,用于管理goroutine -type WorkerPool struct { - taskQueue chan *Task - workerCount int - shutdownChannel chan struct{} - wg sync.WaitGroup +break } -// Task 任务结构体,包含请求处理所需数据 -type Task struct { - r *http.Request - w http.ResponseWriter - done chan struct{} - reqID string - isStream bool - apiReq APIRequest - path string -} +// 添加到缓冲区 +buffer += string(chunk[:n]) -// Semaphore 信号量实现,用于限制并发数量 -type Semaphore struct { - sem chan struct{} +// 更稳健的处理方式:按行分割并只处理完整行 +lines := strings.Split(buffer, "\n") +// 保留最后可能不完整的行 +if len(lines) > 0 { +buffer = lines[len(lines)-1] } -// 配置结构体 -type Config struct { - Port string - Address string - LogLevel string - DevMode bool - MaxRetries int - Timeout int - VerifySSL bool - WorkerCount int - QueueSize int - MaxConcurrent int - APIPrefix string +// 处理所有完整的行(除最后一行外) +for i := 0; i < len(lines)-1; i++ { +line := lines[i] +if !strings.HasPrefix(line, "data: ") { +continue } -// APIRequest OpenAI兼容的请求结构体 -type APIRequest struct { - Model string `json:"model"` - Messages []APIMessage `json:"messages"` - Stream bool `json:"stream"` - Temperature float64 `json:"temperature,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` -} +// 提取数据部分 +dataStr := strings.TrimPrefix(line, "data: ") -// APIMessage 消息结构体 -type APIMessage struct { - Role string `json:"role"` - Content interface{} `json:"content"` - FeatureConfig interface{} `json:"feature_config,omitempty"` - ChatType string `json:"chat_type,omitempty"` - Extra interface{} `json:"extra,omitempty"` +// 处理[DONE]消息 +if dataStr == "[DONE]" { +logDebug("[reqID:%s] 非流式模式收到[DONE]消息", reqID) +continue } -// 内容项目结构体(处理图像等内容) -type ContentItem struct { - Type string `json:"type,omitempty"` - Text string `json:"text,omitempty"` - ImageURL *ImageURL `json:"image_url,omitempty"` - Image string `json:"image,omitempty"` +// 解析JSON +var qwenResp QwenResponse +if err := json.Unmarshal([]byte(dataStr), &qwenResp); err != nil { +logWarn("[reqID:%s] 解析JSON失败: %v, data: %s", reqID, err, dataStr) +continue } -// ImageURL 图像URL结构体 -type ImageURL struct { - URL string `json:"url"` +// 提取内容 - 累积所有delta内容片段 +for _, choice := range qwenResp.Choices { +if choice.Delta.Content != "" { +contentBuilder.WriteString(choice.Delta.Content) } - -// QwenRequest 通义千问API请求结构体 -type QwenRequest struct { - Model string `json:"model"` - Messages []APIMessage `json:"messages"` - Stream bool `json:"stream"` - ChatType string `json:"chat_type,omitempty"` - ID string `json:"id,omitempty"` - IncrementalOutput bool `json:"incremental_output,omitempty"` - Size string `json:"size,omitempty"` } - -// QwenResponse 通义千问API响应结构体 -type QwenResponse struct { - Messages []struct { - Role string `json:"role"` - Content string `json:"content"` - Extra struct { - Wanx struct { - TaskID string `json:"task_id"` - } `json:"wanx"` - } `json:"extra"` - } `json:"messages"` - Choices []struct { - Delta struct { - Content string `json:"content"` - } `json:"delta"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` -} - -// FileUploadResponse 文件上传响应 -type FileUploadResponse struct { - ID string `json:"id"` -} - -// TaskStatusResponse 任务状态响应 -type TaskStatusResponse struct { - Content string `json:"content"` -} - -// StreamChunk OpenAI兼容的流式响应块 -type StreamChunk struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []struct { - Index int `json:"index"` - Delta struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` - } `json:"choices"` -} - -// CompletionResponse OpenAI兼容的完成响应 -type CompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []struct { - Index int `json:"index"` - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` -} - -// ImagesResponse 图像生成响应 -type ImagesResponse struct { - Created int64 `json:"created"` - Data []ImageURL `json:"data"` -} - -// ImagesRequest 图像生成请求 -type ImagesRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - N int `json:"n"` - Size string `json:"size"` -} - -// ModelData 模型数据 -type ModelData struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - OwnedBy string `json:"owned_by"` -} - -// ModelsResponse 模型列表响应 -type ModelsResponse struct { - Object string `json:"object"` - Data []ModelData `json:"data"` -} - -// 全局变量 -var ( - appConfig *Config - logger *log.Logger - logLevel string - logMutex sync.Mutex - workerPool *WorkerPool - requestSem *Semaphore - requestCount uint64 = 0 - countMutex sync.Mutex - - // 性能指标 - requestCounter int64 - successCounter int64 - errorCounter int64 - avgResponseTime int64 - queuedRequests int64 - rejectedRequests int64 -) - -// NewSemaphore 创建新的信号量 -func NewSemaphore(size int) *Semaphore { - return &Semaphore{ - sem: make(chan struct{}, size), - } } - -// Acquire 获取信号量(阻塞) -func (s *Semaphore) Acquire() { - s.sem <- struct{}{} -} - -// Release 释放信号量 -func (s *Semaphore) Release() { - <-s.sem -} - -// TryAcquire 尝试获取信号量(非阻塞) -func (s *Semaphore) TryAcquire() bool { - select { - case s.sem <- struct{}{}: - return true - default: - return false - } -} - -// NewWorkerPool 创建并启动一个新的工作池 -func NewWorkerPool(workerCount int, queueSize int) *WorkerPool { - pool := &WorkerPool{ - taskQueue: make(chan *Task, queueSize), - workerCount: workerCount, - shutdownChannel: make(chan struct{}), - } - - pool.Start() - return pool -} - -// Start 启动工作池中的worker goroutines -func (pool *WorkerPool) Start() { - // 启动工作goroutine - for i := 0; i < pool.workerCount; i++ { - pool.wg.Add(1) - go func(workerID int) { - defer pool.wg.Done() - - logInfo("Worker %d 已启动", workerID) - - for { - select { - case task, ok := <-pool.taskQueue: - if !ok { - // 队列已关闭,退出worker - logInfo("Worker %d 收到队列关闭信号,准备退出", workerID) - return - } - - logDebug("Worker %d 处理任务 reqID:%s", workerID, task.reqID) - - // 处理任务 - var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() - - switch task.path { - case "/v1/models": - handleModels(task.w, task.r) - case "/v1/chat/completions": - if task.isStream { - handleStreamingRequest(task.w, task.r, task.apiReq, task.reqID) - } else { - handleNonStreamingRequest(task.w, task.r, task.apiReq, task.reqID) - } - case "/v1/images/generations": - handleImageGenerations(task.w, task.r, task.apiReq, task.reqID) - } - - logInfo("Worker %d 任务 reqID:%s 处理完成", workerID, task.reqID) - }() - - // 等待任务完成后再发送通知 - wg.Wait() - // 通知任务完成 - close(task.done) - - case <-pool.shutdownChannel: - // 收到关闭信号,退出worker - logInfo("Worker %d 收到关闭信号,准备退出", workerID) - return - } - } - }(i) - } -} - -// SubmitTask 提交任务到工作池,非阻塞 -func (pool *WorkerPool) SubmitTask(task *Task) (bool, error) { - select { - case pool.taskQueue <- task: - // 任务成功添加到队列 - return true, nil - default: - // 队列已满 - return false, fmt.Errorf("任务队列已满") - } -} - -// Shutdown 关闭工作池 -func (pool *WorkerPool) Shutdown() { - logInfo("正在关闭工作池...") - - // 发送关闭信号给所有worker - close(pool.shutdownChannel) - - // 等待所有worker退出 - pool.wg.Wait() - - // 关闭任务队列 - close(pool.taskQueue) - - logInfo("工作池已关闭") -} - -// 日志函数 -func initLogger(level string) { - logger = log.New(os.Stdout, "[QwenAPI] ", log.LstdFlags) - logLevel = level -} - -func logDebug(format string, v ...interface{}) { - if logLevel == LogLevelDebug { - logMutex.Lock() - logger.Printf("[DEBUG] "+format, v...) - logMutex.Unlock() - } -} - -func logInfo(format string, v ...interface{}) { - if logLevel == LogLevelDebug || logLevel == LogLevelInfo { - logMutex.Lock() - logger.Printf("[INFO] "+format, v...) - logMutex.Unlock() - } -} - -func logWarn(format string, v ...interface{}) { - if logLevel == LogLevelDebug || logLevel == LogLevelInfo || logLevel == LogLevelWarn { - logMutex.Lock() - logger.Printf("[WARN] "+format, v...) - logMutex.Unlock() - } -} - -func logError(format string, v ...interface{}) { - logMutex.Lock() - logger.Printf("[ERROR] "+format, v...) - logMutex.Unlock() - - // 错误计数 - atomic.AddInt64(&errorCounter, 1) -} - -// 解析命令行参数 -func parseFlags() *Config { - cfg := &Config{} - flag.StringVar(&cfg.Port, "port", "8080", "Port to listen on") - flag.StringVar(&cfg.Address, "address", "localhost", "Address to listen on") - flag.StringVar(&cfg.LogLevel, "log-level", LogLevelInfo, "Log level (debug, info, warn, error)") - flag.BoolVar(&cfg.DevMode, "dev", false, "Enable development mode with enhanced logging") - flag.IntVar(&cfg.MaxRetries, "max-retries", 3, "Maximum number of retries for failed requests") - flag.IntVar(&cfg.Timeout, "timeout", 300, "Request timeout in seconds") - flag.BoolVar(&cfg.VerifySSL, "verify-ssl", true, "Verify SSL certificates") - flag.IntVar(&cfg.WorkerCount, "workers", 50, "Number of worker goroutines in the pool") - flag.IntVar(&cfg.QueueSize, "queue-size", 500, "Size of the task queue") - flag.IntVar(&cfg.MaxConcurrent, "max-concurrent", 100, "Maximum number of concurrent requests") - flag.StringVar(&cfg.APIPrefix, "api-prefix", "", "API prefix for all endpoints") - flag.Parse() - - // 如果开发模式开启,自动设置日志级别为debug - if cfg.DevMode && cfg.LogLevel != LogLevelDebug { - cfg.LogLevel = LogLevelDebug - fmt.Println("开发模式已启用,日志级别设置为debug") - } - - return cfg -} - -// 从请求头中提取令牌 -func extractToken(r *http.Request) (string, error) { - // 获取 Authorization 头部 - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", fmt.Errorf("missing Authorization header") - } - - // 验证格式并提取令牌 - if !strings.HasPrefix(authHeader, "Bearer ") { - return "", fmt.Errorf("invalid Authorization header format, must start with 'Bearer '") - } - - // 提取令牌值 - token := strings.TrimPrefix(authHeader, "Bearer ") - if token == "" { - return "", fmt.Errorf("empty token in Authorization header") - } - - return token, nil -} - -// 设置CORS头 -func setCORSHeaders(w http.ResponseWriter) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") -} - -// 生成UUID -func generateUUID() string { - b := make([]byte, 16) - _, err := rand.Read(b) - if err != nil { - return fmt.Sprintf("%d", time.Now().UnixNano()) - } - - return fmt.Sprintf("%x-%x-%x-%x-%x", - b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) -} - -// 安全的HTTP客户端,支持禁用SSL验证 -func getHTTPClient() *http.Client { - tr := &http.Transport{ - MaxIdleConnsPerHost: 100, - IdleConnTimeout: 90 * time.Second, - TLSClientConfig: nil, // 默认配置 - } - - // 如果配置了禁用SSL验证 - if !appConfig.VerifySSL { - tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - - return &http.Client{ - Timeout: time.Duration(appConfig.Timeout) * time.Second, - Transport: tr, - } -} - -// 主入口函数 -func main() { - // 解析配置 - appConfig = parseFlags() - - // 初始化日志 - initLogger(appConfig.LogLevel) - - logInfo("启动服务: 地址=%s, 端口=%s, 版本=%s, 日志级别=%s", - appConfig.Address, appConfig.Port, Version, appConfig.LogLevel) - - // 创建工作池和信号量 - workerPool = NewWorkerPool(appConfig.WorkerCount, appConfig.QueueSize) - requestSem = NewSemaphore(appConfig.MaxConcurrent) - - logInfo("工作池已创建: %d个worker, 队列大小为%d", appConfig.WorkerCount, appConfig.QueueSize) - - // 配置更高的并发处理能力 - http.DefaultTransport.(*http.Transport).MaxIdleConnsPerHost = 100 - http.DefaultTransport.(*http.Transport).MaxIdleConns = 100 - http.DefaultTransport.(*http.Transport).IdleConnTimeout = 90 * time.Second - - // 创建自定义服务器,支持更高并发 - server := &http.Server{ - Addr: appConfig.Address + ":" + appConfig.Port, - ReadTimeout: time.Duration(appConfig.Timeout) * time.Second, - WriteTimeout: time.Duration(appConfig.Timeout) * time.Second, - IdleTimeout: 120 * time.Second, - Handler: nil, // 使用默认的ServeMux - } - - // API路径前缀 - apiPrefix := appConfig.APIPrefix - - // 创建处理器 - http.HandleFunc(apiPrefix+"/v1/models", func(w http.ResponseWriter, r *http.Request) { - setCORSHeaders(w) - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - // 计数 - countMutex.Lock() - requestCount++ - currentCount := requestCount - countMutex.Unlock() - - reqID := generateRequestID() - logInfo("[reqID:%s] 收到模型列表请求 #%d", reqID, currentCount) - - // 请求计数 - atomic.AddInt64(&requestCounter, 1) - - startTime := time.Now() - - // 创建任务 - task := &Task{ - r: r, - w: w, - done: make(chan struct{}), - reqID: reqID, - path: "/v1/models", - } - - // 尝试获取信号量 - if !requestSem.TryAcquire() { - // 请求数量超过限制 - atomic.AddInt64(&rejectedRequests, 1) - logWarn("[reqID:%s] 请求被拒绝: 当前并发请求数已达上限", reqID) - w.Header().Set("Retry-After", "30") - http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable) - return - } - - // 释放信号量(在函数返回时) - defer requestSem.Release() - - // 添加到任务队列 - atomic.AddInt64(&queuedRequests, 1) - submitted, err := workerPool.SubmitTask(task) - if !submitted { - atomic.AddInt64(&queuedRequests, -1) - atomic.AddInt64(&rejectedRequests, 1) - logError("[reqID:%s] 提交任务失败: %v", reqID, err) - w.Header().Set("Retry-After", "60") - http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable) - return - } - - logInfo("[reqID:%s] 任务已提交到队列", reqID) - - // 等待任务完成或超时 - select { - case <-task.done: - // 任务已完成 - logInfo("[reqID:%s] 任务已完成", reqID) - case <-r.Context().Done(): - // 请求被取消或超时 - logWarn("[reqID:%s] 请求被取消或超时", reqID) - } - - // 请求处理完成,更新指标 - atomic.AddInt64(&queuedRequests, -1) - elapsed := time.Since(startTime).Milliseconds() - - // 更新平均响应时间 - atomic.AddInt64(&avgResponseTime, elapsed) - - if r.Context().Err() == nil { - // 成功计数增加 - atomic.AddInt64(&successCounter, 1) - logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed) - } else { - logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed) - } - }) - - http.HandleFunc(apiPrefix+"/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { - setCORSHeaders(w) - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - // 计数器增加 - countMutex.Lock() - requestCount++ - currentCount := requestCount - countMutex.Unlock() - - reqID := generateRequestID() - logInfo("[reqID:%s] 收到新请求 #%d", reqID, currentCount) - - // 请求计数 - atomic.AddInt64(&requestCounter, 1) - - startTime := time.Now() - - // 尝试获取信号量 - if !requestSem.TryAcquire() { - // 请求数量超过限制 - atomic.AddInt64(&rejectedRequests, 1) - logWarn("[reqID:%s] 请求 #%d 被拒绝: 当前并发请求数已达上限", reqID, currentCount) - w.Header().Set("Retry-After", "30") - http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable) - return - } - - // 释放信号量(在函数返回时) - defer requestSem.Release() - - // 解析请求体 - var apiReq APIRequest - if err := json.NewDecoder(r.Body).Decode(&apiReq); err != nil { - logError("[reqID:%s] 解析请求失败: %v", reqID, err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return - } - - // 创建任务 - task := &Task{ - r: r, - w: w, - done: make(chan struct{}), - reqID: reqID, - isStream: apiReq.Stream, - apiReq: apiReq, - path: "/v1/chat/completions", - } - - // 添加到任务队列 - atomic.AddInt64(&queuedRequests, 1) - submitted, err := workerPool.SubmitTask(task) - if !submitted { - atomic.AddInt64(&queuedRequests, -1) - atomic.AddInt64(&rejectedRequests, 1) - logError("[reqID:%s] 提交任务失败: %v", reqID, err) - w.Header().Set("Retry-After", "60") - http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable) - return - } - - logInfo("[reqID:%s] 任务已提交到队列", reqID) - - // 等待任务完成或超时 - select { - case <-task.done: - // 任务已完成 - logInfo("[reqID:%s] 任务已完成", reqID) - case <-r.Context().Done(): - // 请求被取消或超时 - logWarn("[reqID:%s] 请求被取消或超时", reqID) - } - - // 请求处理完成,更新指标 - atomic.AddInt64(&queuedRequests, -1) - elapsed := time.Since(startTime).Milliseconds() - - // 更新平均响应时间 - atomic.AddInt64(&avgResponseTime, elapsed) - - if r.Context().Err() == nil { - // 成功计数增加 - atomic.AddInt64(&successCounter, 1) - logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed) - } else { - logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed) - } - }) - - http.HandleFunc(apiPrefix+"/v1/images/generations", func(w http.ResponseWriter, r *http.Request) { - setCORSHeaders(w) - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - // 计数器增加 - countMutex.Lock() - requestCount++ - currentCount := requestCount - countMutex.Unlock() - - reqID := generateRequestID() - logInfo("[reqID:%s] 收到图像生成请求 #%d", reqID, currentCount) - - // 请求计数 - atomic.AddInt64(&requestCounter, 1) - - startTime := time.Now() - - // 尝试获取信号量 - if !requestSem.TryAcquire() { - // 请求数量超过限制 - atomic.AddInt64(&rejectedRequests, 1) - logWarn("[reqID:%s] 请求 #%d 被拒绝: 当前并发请求数已达上限", reqID, currentCount) - w.Header().Set("Retry-After", "30") - http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable) - return - } - - // 释放信号量(在函数返回时) - defer requestSem.Release() - - // 解析请求体 - var apiReq APIRequest - if err := json.NewDecoder(r.Body).Decode(&apiReq); err != nil { - logError("[reqID:%s] 解析请求失败: %v", reqID, err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return - } - - // 创建任务 - task := &Task{ - r: r, - w: w, - done: make(chan struct{}), - reqID: reqID, - apiReq: apiReq, - path: "/v1/images/generations", - } - - // 添加到任务队列 - atomic.AddInt64(&queuedRequests, 1) - submitted, err := workerPool.SubmitTask(task) - if !submitted { - atomic.AddInt64(&queuedRequests, -1) - atomic.AddInt64(&rejectedRequests, 1) - logError("[reqID:%s] 提交任务失败: %v", reqID, err) - w.Header().Set("Retry-After", "60") - http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable) - return - } - - logInfo("[reqID:%s] 任务已提交到队列", reqID) - - // 等待任务完成或超时 - select { - case <-task.done: - // 任务已完成 - logInfo("[reqID:%s] 任务已完成", reqID) - case <-r.Context().Done(): - // 请求被取消或超时 - logWarn("[reqID:%s] 请求被取消或超时", reqID) - } - - // 请求处理完成,更新指标 - atomic.AddInt64(&queuedRequests, -1) - elapsed := time.Since(startTime).Milliseconds() - - // 更新平均响应时间 - atomic.AddInt64(&avgResponseTime, elapsed) - - if r.Context().Err() == nil { - // 成功计数增加 - atomic.AddInt64(&successCounter, 1) - logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed) - } else { - logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed) - } - }) - - // 添加健康检查端点 - http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { - setCORSHeaders(w) - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - // 获取各种计数器的值 - reqCount := atomic.LoadInt64(&requestCounter) - succCount := atomic.LoadInt64(&successCounter) - errCount := atomic.LoadInt64(&errorCounter) - queuedCount := atomic.LoadInt64(&queuedRequests) - rejectedCount := atomic.LoadInt64(&rejectedRequests) - - // 计算平均响应时间 - var avgTime int64 = 0 - if reqCount > 0 { - avgTime = atomic.LoadInt64(&avgResponseTime) / reqCount - } - - // 构建响应 - stats := map[string]interface{}{ - "status": "ok", - "version": Version, - "requests": reqCount, - "success": succCount, - "errors": errCount, - "queued": queuedCount, - "rejected": rejectedCount, - "avg_time_ms": avgTime, - "worker_count": workerPool.workerCount, - "queue_size": len(workerPool.taskQueue), - "queue_capacity": cap(workerPool.taskQueue), - "queue_percent": float64(len(workerPool.taskQueue)) / float64(cap(workerPool.taskQueue)) * 100, - "concurrent_limit": appConfig.MaxConcurrent, - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(stats) - }) - - // 创建停止通道 - stop := make(chan os.Signal, 1) - signal.Notify(stop, os.Interrupt, syscall.SIGTERM) - - // 在goroutine中启动服务器 - go func() { - logInfo("Starting proxy server on %s", server.Addr) - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logError("Failed to start server: %v", err) - os.Exit(1) - } - }() - - // 等待停止信号 - <-stop - - // 创建上下文用于优雅关闭 - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // 优雅关闭服务器 - logInfo("Server is shutting down...") - if err := server.Shutdown(ctx); err != nil { - logError("Server shutdown failed: %v", err) - } - - // 关闭工作池 - workerPool.Shutdown() - - logInfo("Server gracefully stopped") -} - -// 生成请求ID -func generateRequestID() string { - return fmt.Sprintf("%x", time.Now().UnixNano()) -} - -// 处理模型列表请求 -func handleModels(w http.ResponseWriter, r *http.Request) { - logInfo("处理模型列表请求") - - // 从请求中提取token - authToken, err := extractToken(r) - if err != nil { - logWarn("提取token失败: %v", err) - // 使用默认模型列表 - returnDefaultModels(w) - return - } - - // 请求通义千问API获取模型列表 - client := getHTTPClient() - req, err := http.NewRequest("GET", ModelsURL, nil) - if err != nil { - logError("创建请求失败: %v", err) - returnDefaultModels(w) - return - } - - // 设置请求头 - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Mozilla/5.0") - - // 发送请求 - resp, err := client.Do(req) - if err != nil { - logError("请求模型列表失败: %v", err) - returnDefaultModels(w) - return - } - defer resp.Body.Close() - - // 检查响应状态 - if resp.StatusCode != http.StatusOK { - logError("获取模型列表返回非200状态码: %d", resp.StatusCode) - returnDefaultModels(w) - return - } - - // 解析响应 - var qwenResp struct { - Data []struct { - ID string `json:"id"` - } `json:"data"` - } - if err := json.NewDecoder(resp.Body).Decode(&qwenResp); err != nil { - logError("解析模型列表响应失败: %v", err) - returnDefaultModels(w) - return - } - - // 提取模型ID - models := make([]string, 0, len(qwenResp.Data)) - for _, model := range qwenResp.Data { - models = append(models, model.ID) - } - - // 如果没有获取到模型,使用默认列表 - if len(models) == 0 { - logWarn("未获取到模型,使用默认列表") - returnDefaultModels(w) - return - } - - // 扩展模型列表,增加变种后缀 - expandedModels := make([]ModelData, 0, len(models)*len(ModelSuffixes)) - for _, model := range models { - for _, suffix := range ModelSuffixes { - expandedModels = append(expandedModels, ModelData{ - ID: model + suffix, - Object: "model", - Created: time.Now().Unix(), - OwnedBy: "qwen", - }) - } - } - - // 构建响应 - modelsResp := ModelsResponse{ - Object: "list", - Data: expandedModels, - } - - // 返回响应 - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(modelsResp) -} - -// 返回默认模型列表 -func returnDefaultModels(w http.ResponseWriter) { - // 扩展默认模型列表,增加变种后缀 - expandedModels := make([]ModelData, 0, len(DefaultModels)*len(ModelSuffixes)) - for _, model := range DefaultModels { - for _, suffix := range ModelSuffixes { - expandedModels = append(expandedModels, ModelData{ - ID: model + suffix, - Object: "model", - Created: time.Now().Unix(), - OwnedBy: "qwen", - }) - } - } - - // 构建响应 - modelsResp := ModelsResponse{ - Object: "list", - Data: expandedModels, - } - - // 返回响应 - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(modelsResp) -} - -// 处理聊天完成请求(流式) -func handleStreamingRequest(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string) { - logInfo("[reqID:%s] 处理流式请求", reqID) - - // 从请求中提取token - authToken, err := extractToken(r) - if err != nil { - logError("[reqID:%s] 提取token失败: %v", reqID, err) - http.Error(w, "无效的认证信息", http.StatusUnauthorized) - return - } - - // 检查消息 - if len(apiReq.Messages) == 0 { - logError("[reqID:%s] 消息为空", reqID) - http.Error(w, "消息为空", http.StatusBadRequest) - return - } - - // 准备模型名和聊天类型 - modelName := "qwen-turbo-latest" - if apiReq.Model != "" { - modelName = apiReq.Model - } - chatType := "t2t" - - // 处理特殊模型名后缀 - if strings.Contains(modelName, "-draw") { - handleDrawRequest(w, r, apiReq, reqID, authToken) - return - } - - // 处理思考模式 - if strings.Contains(modelName, "-thinking") { - modelName = strings.Replace(modelName, "-thinking", "", 1) - lastMsgIdx := len(apiReq.Messages) - 1 - if lastMsgIdx >= 0 { - apiReq.Messages[lastMsgIdx].FeatureConfig = map[string]interface{}{ - "thinking_enabled": true, - } - } - } - - // 处理搜索模式 - if strings.Contains(modelName, "-search") { - modelName = strings.Replace(modelName, "-search", "", 1) - chatType = "search" - lastMsgIdx := len(apiReq.Messages) - 1 - if lastMsgIdx >= 0 { - apiReq.Messages[lastMsgIdx].ChatType = "search" - } - } - - // 处理图片消息 - lastMsgIdx := len(apiReq.Messages) - 1 - if lastMsgIdx >= 0 { - lastMsg := apiReq.Messages[lastMsgIdx] - - // 检查内容是否为数组 - contentArray, ok := lastMsg.Content.([]interface{}) - if ok { - // 处理内容数组 - for i, item := range contentArray { - itemMap, isMap := item.(map[string]interface{}) - if !isMap { - continue - } - - // 检查是否包含图像URL - if imageURL, hasImageURL := itemMap["image_url"]; hasImageURL { - imageURLMap, isMap := imageURL.(map[string]interface{}) - if !isMap { - continue - } - - // 获取URL - url, hasURL := imageURLMap["url"].(string) - if !hasURL { - continue - } - - // 上传图像 - imageID, uploadErr := uploadImage(url, authToken) - if uploadErr != nil { - logError("[reqID:%s] 上传图像失败: %v", reqID, uploadErr) - continue - } - - // 替换内容 - contentArrayCopy := make([]interface{}, len(contentArray)) - copy(contentArrayCopy, contentArray) - contentArrayCopy[i] = map[string]interface{}{ - "type": "image", - "image": imageID, - } - apiReq.Messages[lastMsgIdx].Content = contentArrayCopy - break - } - } - } - } - - // 创建通义千问请求 - qwenReq := QwenRequest{ - Model: modelName, - Messages: apiReq.Messages, - Stream: true, - ChatType: chatType, - ID: generateUUID(), - } - - // 序列化请求 - reqData, err := json.Marshal(qwenReq) - if err != nil { - logError("[reqID:%s] 序列化请求失败: %v", reqID, err) - http.Error(w, "内部服务器错误", http.StatusInternalServerError) - return - } - - // 创建HTTP请求 - req, err := http.NewRequestWithContext(r.Context(), "POST", TargetURL, bytes.NewBuffer(reqData)) - if err != nil { - logError("[reqID:%s] 创建请求失败: %v", reqID, err) - http.Error(w, "内部服务器错误", http.StatusInternalServerError) - return - } - - // 设置请求头 - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Mozilla/5.0") - - // 发送请求 - client := getHTTPClient() - resp, err := client.Do(req) - if err != nil { - logError("[reqID:%s] 发送请求失败: %v", reqID, err) - http.Error(w, "连接到API失败", http.StatusBadGateway) - return - } - defer resp.Body.Close() - - // 检查响应状态 - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes)) - http.Error(w, fmt.Sprintf("API错误,状态码: %d", resp.StatusCode), resp.StatusCode) - return - } - - // 设置响应头 - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - - // 创建响应ID和时间戳 - respID := fmt.Sprintf("chatcmpl-%s", generateUUID()) - createdTime := time.Now().Unix() - - // 创建读取器和Flusher - reader := bufio.NewReaderSize(resp.Body, 16384) - flusher, ok := w.(http.Flusher) - if !ok { - logError("[reqID:%s] 流式传输不支持", reqID) - http.Error(w, "流式传输不支持", http.StatusInternalServerError) - return - } - - // 发送角色块 - roleChunk := createRoleChunk(respID, createdTime, modelName) - w.Write([]byte("data: " + string(roleChunk) + "\n\n")) - flusher.Flush() - - // 用于去重的前一个内容 - previousContent := "" - - ) - - // 流式传输状态跟踪 - var isCompleted bool - var totalChunks int - buffer := "" - - logInfo("[reqID:%s] 开始流式传输...", reqID) - - // 持续读取响应 - for { - // 添加超时检测 - select { - case <-r.Context().Done(): - logWarn("[reqID:%s] 请求超时或被客户端取消", reqID) - return - default: - // 继续处理 - } - - // 读取一块数据 - chunk := make([]byte, 4096) - n, err := reader.Read(chunk) - - // 处理读取结果 - if err != nil { - if err != io.EOF { - logError("[reqID:%s] 读取响应出错: %v", reqID, err) - return - } - - // EOF处理 - logInfo("[reqID:%s] 读取到文件末尾,完成流式传输", reqID) - - // 如果buffer中还有内容,尝试处理 - if len(buffer) > 0 { - logInfo("[reqID:%s] 处理最后的缓冲区 (长度: %d)", reqID, len(buffer)) - } - - if !isCompleted { - // 发送结束信号(如果没有正常结束) - finishReason := "stop" - doneChunk := createDoneChunk(respID, createdTime, modelName, finishReason) - w.Write([]byte("data: " + string(doneChunk) + "\n\n")) - w.Write([]byte("data: [DONE]\n\n")) - flusher.Flush() - logInfo("[reqID:%s] 发送结束信号,共传输 %d 个数据块", reqID, totalChunks) - } - break - } - - if n > 0 { - // 添加到缓冲区 - buffer += string(chunk[:n]) - - // 查找所有的data行 - matches := dataRegex.FindAllStringSubmatch(buffer, -1) - - // 处理匹配到的行 - for _, match := range matches { - // 获取数据部分 - dataStr := match[1] - - // 从缓冲区中移除已处理的行 - buffer = strings.Replace(buffer, "data: "+dataStr+"\n", "", 1) - - // 处理[DONE]消息 - if dataStr == "[DONE]" { - w.Write([]byte("data: [DONE]\n\n")) - flusher.Flush() - isCompleted = true - logInfo("[reqID:%s] 处理完成信号 [DONE]", reqID) - continue - } - - // 解析JSON - var qwenResp QwenResponse - if err := json.Unmarshal([]byte(dataStr), &qwenResp); err != nil { - logWarn("[reqID:%s] 解析JSON失败: %v, data: %s", reqID, err, dataStr) - continue - } - - // 处理块 - for _, choice := range qwenResp.Choices { - content := choice.Delta.Content - - // 去重 - if strings.HasPrefix(content, previousContent) { - content = content[len(previousContent):] - } - - if content != "" { - previousContent += content - - // 创建内容块 - contentChunk := createContentChunk(respID, createdTime, modelName, content) - w.Write([]byte("data: " + string(contentChunk) + "\n\n")) - flusher.Flush() - totalChunks++ - - if totalChunks % 10 == 0 { - logInfo("[reqID:%s] 已传输 %d 个数据块", reqID, totalChunks) - } - } - - // 处理完成标志 - if choice.FinishReason != "" { - finishReason := choice.FinishReason - doneChunk := createDoneChunk(respID, createdTime, modelName, finishReason) - w.Write([]byte("data: " + string(doneChunk) + "\n\n")) - flusher.Flush() - logInfo("[reqID:%s] 处理完成标志: %s", reqID, finishReason) - } - } - } - } - - // 如果已完成,退出循环 - if isCompleted && buffer == "" { - logInfo("[reqID:%s] 处理完成并且缓冲区为空,结束流式传输", reqID) - break - } - } - - logInfo("[reqID:%s] 流式传输完成,总共发送 %d 个数据块", reqID, totalChunks) -} - -// 处理聊天完成请求(非流式) -func handleNonStreamingRequest(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string) { - logInfo("[reqID:%s] 处理非流式请求", reqID) - - // 从请求中提取token - authToken, err := extractToken(r) - if err != nil { - logError("[reqID:%s] 提取token失败: %v", reqID, err) - http.Error(w, "无效的认证信息", http.StatusUnauthorized) - return - } - - // 检查消息 - if len(apiReq.Messages) == 0 { - logError("[reqID:%s] 消息为空", reqID) - http.Error(w, "消息为空", http.StatusBadRequest) - return - } - - // 准备模型名和聊天类型 - modelName := "qwen-turbo-latest" - if apiReq.Model != "" { - modelName = apiReq.Model - } - chatType := "t2t" - - // 处理特殊模型名后缀 - if strings.Contains(modelName, "-draw") { - handleDrawRequest(w, r, apiReq, reqID, authToken) - return - } - - // 处理思考模式 - if strings.Contains(modelName, "-thinking") { - modelName = strings.Replace(modelName, "-thinking", "", 1) - lastMsgIdx := len(apiReq.Messages) - 1 - if lastMsgIdx >= 0 { - apiReq.Messages[lastMsgIdx].FeatureConfig = map[string]interface{}{ - "thinking_enabled": true, - } - } - } - - // 处理搜索模式 - if strings.Contains(modelName, "-search") { - modelName = strings.Replace(modelName, "-search", "", 1) - chatType = "search" - lastMsgIdx := len(apiReq.Messages) - 1 - if lastMsgIdx >= 0 { - apiReq.Messages[lastMsgIdx].ChatType = "search" - } - } - - // 处理图片消息 - lastMsgIdx := len(apiReq.Messages) - 1 - if lastMsgIdx >= 0 { - lastMsg := apiReq.Messages[lastMsgIdx] - - // 检查内容是否为数组 - contentArray, ok := lastMsg.Content.([]interface{}) - if ok { - // 处理内容数组 - for i, item := range contentArray { - itemMap, isMap := item.(map[string]interface{}) - if !isMap { - continue - } - - // 检查是否包含图像URL - if imageURL, hasImageURL := itemMap["image_url"]; hasImageURL { - imageURLMap, isMap := imageURL.(map[string]interface{}) - if !isMap { - continue - } - - // 获取URL - url, hasURL := imageURLMap["url"].(string) - if !hasURL { - continue - } - - // 上传图像 - imageID, uploadErr := uploadImage(url, authToken) - if uploadErr != nil { - logError("[reqID:%s] 上传图像失败: %v", reqID, uploadErr) - continue - } - - // 替换内容 - contentArrayCopy := make([]interface{}, len(contentArray)) - copy(contentArrayCopy, contentArray) - contentArrayCopy[i] = map[string]interface{}{ - "type": "image", - "image": imageID, - } - apiReq.Messages[lastMsgIdx].Content = contentArrayCopy - break - } - } - } - } - - // 创建通义千问请求 - 通过流式请求来获取非流式响应 - qwenReq := QwenRequest{ - Model: modelName, - Messages: apiReq.Messages, - Stream: true, // 使用流式API - ChatType: chatType, - ID: generateUUID(), - } - - // 序列化请求 - reqData, err := json.Marshal(qwenReq) - if err != nil { - logError("[reqID:%s] 序列化请求失败: %v", reqID, err) - http.Error(w, "内部服务器错误", http.StatusInternalServerError) - return - } - - // 创建HTTP请求 - req, err := http.NewRequestWithContext(r.Context(), "POST", TargetURL, bytes.NewBuffer(reqData)) - if err != nil { - logError("[reqID:%s] 创建请求失败: %v", reqID, err) - http.Error(w, "内部服务器错误", http.StatusInternalServerError) - return - } - - // 设置请求头 - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Mozilla/5.0") - - // 发送请求 - client := getHTTPClient() - resp, err := client.Do(req) - if err != nil { - logError("[reqID:%s] 发送请求失败: %v", reqID, err) - http.Error(w, "连接到API失败", http.StatusBadGateway) - return - } - defer resp.Body.Close() - - // 检查响应状态 - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes)) - http.Error(w, fmt.Sprintf("API错误,状态码: %d", resp.StatusCode), resp.StatusCode) - return - } - - // 从流式响应中提取完整内容 - fullContent, err := extractFullContentFromStream(resp.Body, reqID) - if err != nil { - logError("[reqID:%s] 提取内容失败: %v", reqID, err) - http.Error(w, "解析响应失败", http.StatusInternalServerError) - return - } - - // 创建非流式响应 - completionResponse := CompletionResponse{ - ID: fmt.Sprintf("chatcmpl-%s", generateUUID()), - Object: "chat.completion", - Created: time.Now().Unix(), - Model: modelName, - Choices: []struct { - Index int `json:"index"` - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - }{ - { - Index: 0, - Message: struct { - Role string `json:"role"` - Content string `json:"content"` - }{ - Role: "assistant", - Content: fullContent, - }, - FinishReason: "stop", - }, - }, - Usage: struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - }{ - PromptTokens: estimateTokens(apiReq.Messages), - CompletionTokens: len(fullContent) / 4, - TotalTokens: estimateTokens(apiReq.Messages) + len(fullContent)/4, - }, - } - - // 返回响应 - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(completionResponse) -} - -// 处理图像生成请求 -func handleImageGenerations(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string) { - logInfo("[reqID:%s] 处理图像生成请求", reqID) - - // 从请求中提取token - authToken, err := extractToken(r) - if err != nil { - logError("[reqID:%s] 提取token失败: %v", reqID, err) - http.Error(w, "无效的认证信息", http.StatusUnauthorized) - return - } - - // 解析图像生成请求 - var imgReq ImagesRequest - if err := json.NewDecoder(r.Body).Decode(&imgReq); err != nil { - logError("[reqID:%s] 解析图像请求失败: %v", reqID, err) - http.Error(w, "无效的请求体", http.StatusBadRequest) - return - } - - // 默认值设置 - if imgReq.Model == "" { - imgReq.Model = "qwen-max-latest-draw" - } - if imgReq.Size == "" { - imgReq.Size = "1024*1024" - } - if imgReq.N <= 0 { - imgReq.N = 1 - } - - // 获取纯模型名(去除-draw后缀) - modelName := strings.Replace(imgReq.Model, "-draw", "", 1) - modelName = strings.Replace(modelName, "-thinking", "", 1) - modelName = strings.Replace(modelName, "-search", "", 1) - - // 创建图像生成任务 - qwenReq := QwenRequest{ - Stream: false, - IncrementalOutput: true, - ChatType: "t2i", - Model: modelName, - Messages: []APIMessage{ - { - Role: "user", - Content: imgReq.Prompt, - ChatType: "t2i", - Extra: map[string]interface{}{}, - FeatureConfig: map[string]interface{}{ - "thinking_enabled": false, - }, - }, - }, - ID: generateUUID(), - Size: imgReq.Size, - } - - // 序列化请求 - reqData, err := json.Marshal(qwenReq) - if err != nil { - logError("[reqID:%s] 序列化请求失败: %v", reqID, err) - http.Error(w, "内部服务器错误", http.StatusInternalServerError) - return - } - - // 创建HTTP请求 - req, err := http.NewRequestWithContext(r.Context(), "POST", TargetURL, bytes.NewBuffer(reqData)) - if err != nil { - logError("[reqID:%s] 创建请求失败: %v", reqID, err) - http.Error(w, "内部服务器错误", http.StatusInternalServerError) - return - } - - // 设置请求头 - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Mozilla/5.0") - - // 发送请求 - client := getHTTPClient() - resp, err := client.Do(req) - if err != nil { - logError("[reqID:%s] 发送请求失败: %v", reqID, err) - http.Error(w, "连接到API失败", http.StatusBadGateway) - return - } - defer resp.Body.Close() - - // 检查响应状态 - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes)) - http.Error(w, fmt.Sprintf("API错误,状态码: %d", resp.StatusCode), resp.StatusCode) - return - } - - // 解析响应获取任务ID - var qwenResp QwenResponse - if err := json.NewDecoder(resp.Body).Decode(&qwenResp); err != nil { - logError("[reqID:%s] 解析响应失败: %v", reqID, err) - http.Error(w, "解析响应失败", http.StatusInternalServerError) - return - } - - // 提取任务ID - taskID := "" - for _, msg := range qwenResp.Messages { - if msg.Role == "assistant" && msg.Extra.Wanx.TaskID != "" { - taskID = msg.Extra.Wanx.TaskID - break - } - } - - if taskID == "" { - logError("[reqID:%s] 无法获取图像生成任务ID", reqID) - http.Error(w, "无法获取图像生成任务ID", http.StatusInternalServerError) - return - } - - // 轮询等待图像生成完成 - var imageURL string - for i := 0; i < 30; i++ { - select { - case <-r.Context().Done(): - logWarn("[reqID:%s] 请求超时或被客户端取消", reqID) - http.Error(w, "请求超时", http.StatusGatewayTimeout) - return - default: - // 继续处理 - } - - // 检查任务状态 - statusURL := TasksURL + taskID - statusReq, err := http.NewRequestWithContext(r.Context(), "GET", statusURL, nil) - if err != nil { - logError("[reqID:%s] 创建状态请求失败: %v", reqID, err) - time.Sleep(6 * time.Second) - continue - } - - // 设置请求头 - statusReq.Header.Set("Authorization", "Bearer "+authToken) - statusReq.Header.Set("User-Agent", "Mozilla/5.0") - - // 发送请求 - statusResp, err := client.Do(statusReq) - if err != nil { - logError("[reqID:%s] 发送状态请求失败: %v", reqID, err) - time.Sleep(6 * time.Second) - continue - } - - // 解析响应 - var statusData TaskStatusResponse - if err := json.NewDecoder(statusResp.Body).Decode(&statusData); err != nil { - logError("[reqID:%s] 解析状态响应失败: %v", reqID, err) - statusResp.Body.Close() - time.Sleep(6 * time.Second) - continue - } - statusResp.Body.Close() - - // 检查是否有内容 - if statusData.Content != "" { - imageURL = statusData.Content - break - } - - time.Sleep(6 * time.Second) - } - - if imageURL == "" { - logError("[reqID:%s] 图像生成超时", reqID) - http.Error(w, "图像生成超时", http.StatusGatewayTimeout) - return - } - - // 构造图像列表 - images := make([]ImageURL, imgReq.N) - for i := 0; i < imgReq.N; i++ { - images[i] = ImageURL{URL: imageURL} - } - - // 返回响应 - imgResp := ImagesResponse{ - Created: time.Now().Unix(), - Data: images, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(imgResp) -} - -// 处理特殊的绘图请求 -func handleDrawRequest(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string, authToken string) { - logInfo("[reqID:%s] 处理绘图请求", reqID) - - // 获取绘图提示 - var prompt string - if len(apiReq.Messages) > 0 { - lastMsg := apiReq.Messages[len(apiReq.Messages)-1] - prompt, _ = lastMsg.Content.(string) - } - - if prompt == "" { - logError("[reqID:%s] 绘图提示为空", reqID) - http.Error(w, "绘图提示为空", http.StatusBadRequest) - return - } - - // 准备绘图请求参数 - size := "1024*1024" - modelName := strings.Replace(apiReq.Model, "-draw", "", 1) - modelName = strings.Replace(modelName, "-thinking", "", 1) - modelName = strings.Replace(modelName, "-search", "", 1) - - // 创建绘图请求 - qwenReq := QwenRequest{ - Stream: false, - IncrementalOutput: true, - ChatType: "t2i", - Model: modelName, - Messages: []APIMessage{ - { - Role: "user", - Content: prompt, - ChatType: "t2i", - Extra: map[string]interface{}{}, - FeatureConfig: map[string]interface{}{ - "thinking_enabled": false, - }, - }, - }, - ID: generateUUID(), - Size: size, - } - - // 序列化请求 - reqData, err := json.Marshal(qwenReq) - if err != nil { - logError("[reqID:%s] 序列化请求失败: %v", reqID, err) - http.Error(w, "内部服务器错误", http.StatusInternalServerError) - return - } - - // 创建HTTP请求 - req, err := http.NewRequestWithContext(r.Context(), "POST", TargetURL, bytes.NewBuffer(reqData)) - if err != nil { - logError("[reqID:%s] 创建请求失败: %v", reqID, err) - http.Error(w, "内部服务器错误", http.StatusInternalServerError) - return - } - - // 设置请求头 - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+authToken) - req.Header.Set("User-Agent", "Mozilla/5.0") - - // 发送请求 - client := getHTTPClient() - resp, err := client.Do(req) - if err != nil { - logError("[reqID:%s] 发送请求失败: %v", reqID, err) - http.Error(w, "连接到API失败", http.StatusBadGateway) - return - } - defer resp.Body.Close() - - // 检查响应状态 - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes)) - http.Error(w, fmt.Sprintf("API错误,状态码: %d", resp.StatusCode), resp.StatusCode) - return - } - - // 解析响应获取任务ID - var qwenResp QwenResponse - if err := json.NewDecoder(resp.Body).Decode(&qwenResp); err != nil { - logError("[reqID:%s] 解析响应失败: %v", reqID, err) - http.Error(w, "解析响应失败", http.StatusInternalServerError) - return - } - - // 提取任务ID - taskID := "" - for _, msg := range qwenResp.Messages { - if msg.Role == "assistant" && msg.Extra.Wanx.TaskID != "" { - taskID = msg.Extra.Wanx.TaskID - break - } - } - - if taskID == "" { - logError("[reqID:%s] 无法获取图像生成任务ID", reqID) - http.Error(w, "无法获取图像生成任务ID", http.StatusInternalServerError) - return - } - - // 轮询等待图像生成完成 - var imageURL string - for i := 0; i < 30; i++ { - select { - case <-r.Context().Done(): - logWarn("[reqID:%s] 请求超时或被客户端取消", reqID) - http.Error(w, "请求超时", http.StatusGatewayTimeout) - return - default: - // 继续处理 - } - - // 检查任务状态 - statusURL := TasksURL + taskID - statusReq, err := http.NewRequestWithContext(r.Context(), "GET", statusURL, nil) - if err != nil { - logError("[reqID:%s] 创建状态请求失败: %v", reqID, err) - time.Sleep(6 * time.Second) - continue - } - - // 设置请求头 - statusReq.Header.Set("Authorization", "Bearer "+authToken) - statusReq.Header.Set("User-Agent", "Mozilla/5.0") - - // 发送请求 - statusResp, err := client.Do(statusReq) - if err != nil { - logError("[reqID:%s] 发送状态请求失败: %v", reqID, err) - time.Sleep(6 * time.Second) - continue - } - - // 解析响应 - var statusData TaskStatusResponse - if err := json.NewDecoder(statusResp.Body).Decode(&statusData); err != nil { - logError("[reqID:%s] 解析状态响应失败: %v", reqID, err) - statusResp.Body.Close() - time.Sleep(6 * time.Second) - continue - } - statusResp.Body.Close() - - // 检查是否有内容 - if statusData.Content != "" { - imageURL = statusData.Content - break - } - - time.Sleep(6 * time.Second) - } - - if imageURL == "" { - logError("[reqID:%s] 图像生成超时", reqID) - http.Error(w, "图像生成超时", http.StatusGatewayTimeout) - return - } - - // 返回OpenAI标准格式响应(使用Markdown嵌入图片) - completionResponse := CompletionResponse{ - ID: fmt.Sprintf("chatcmpl-%s", generateUUID()), - Object: "chat.completion", - Created: time.Now().Unix(), - Model: apiReq.Model, - Choices: []struct { - Index int `json:"index"` - Message struct { - Role string `json:"role"` - Content string `json:"content"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - }{ - { - Index: 0, - Message: struct { - Role string `json:"role"` - Content string `json:"content"` - }{ - Role: "assistant", - Content: fmt.Sprintf("![%s](%s)", imageURL, imageURL), - }, - FinishReason: "stop", - }, - }, - Usage: struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - }{ - PromptTokens: 1024, - CompletionTokens: 1024, - TotalTokens: 2048, - }, - } - - // 返回响应 - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(completionResponse) } -) - - // 持续读取响应 - buffer := "" - totalChunks := 0 - - logInfo("[reqID:%s] 开始提取非流式内容...", reqID) - - for { - // 读取一块数据 - chunk := make([]byte, 4096) - n, err := reader.Read(chunk) - - if n > 0 { - // 添加到缓冲区 - buffer += string(chunk[:n]) - - // 查找所有的data行 - matches := dataRegex.FindAllStringSubmatch(buffer, -1) - - // 处理匹配到的行 - for _, match := range matches { - // 获取数据部分 - dataStr := match[1] - - // 从缓冲区中移除已处理的行 - buffer = strings.Replace(buffer, "data: "+dataStr+"\n", "", 1) - - // 处理[DONE]消息 - if dataStr == "[DONE]" { - continue - } - - // 解析JSON - var qwenResp QwenResponse - if err := json.Unmarshal([]byte(dataStr), &qwenResp); err != nil { - logWarn("[reqID:%s] 解析JSON失败: %v, data: %s", reqID, err, dataStr) - continue - } - - // 提取内容 - for _, choice := range qwenResp.Choices { - contentBuilder.WriteString(choice.Delta.Content) - totalChunks++ - } - } - } - - // 处理读取结果 - if err != nil { - if err != io.EOF { - logError("[reqID:%s] 读取响应出错: %v", reqID, err) - return contentBuilder.String(), err - } - - // EOF处理 - 完成 - logInfo("[reqID:%s] 非流式内容提取完成,共 %d 个数据块", reqID, totalChunks) - break - } - } - - if contentBuilder.Len() == 0 { - logWarn("[reqID:%s] 提取的内容为空", reqID) - } else { - logInfo("[reqID:%s] 成功提取内容,长度: %d 字符", reqID, contentBuilder.Len()) - } - - return contentBuilder.String(), nil +// 记录提取的内容长度 +contentStr := contentBuilder.String() +logInfo("[reqID:%s] 非流式模式:成功提取完整内容,长度: %d", reqID, len(contentStr)) +return contentStr, nil } // 上传图像到千问API @@ -6484,4 +2140,4 @@ func estimateTokens(messages []APIMessage) int { } } return total -} \ No newline at end of file +}