Files
AI2API/qwen2api.go
BlueSkyXN b63b8dae00 0.1.8
2025-03-25 10:26:22 +08:00

2144 lines
56 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"mime/multipart"
"net/http"
"os"
"os/signal"
"regexp"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
)
// Version and upstream API endpoint constants.
const (
	// Version is the version string reported in the startup log and by /health.
	Version = "1.0.0"
	// TargetURL is the upstream chat-completions endpoint.
	TargetURL = "https://chat.qwen.ai/api/chat/completions"
	// ModelsURL is the upstream endpoint that lists available models.
	ModelsURL = "https://chat.qwen.ai/api/models"
	// FilesURL is the upstream files endpoint.
	FilesURL = "https://chat.qwen.ai/api/v1/files/"
	// TasksURL is the upstream task-status endpoint.
	TasksURL = "https://chat.qwen.ai/api/v1/tasks/status/"
)
// DefaultModels is the fallback model list served when the upstream model
// list cannot be fetched.
var DefaultModels = []string{
	"qwen-max-latest",
	"qwen-plus-latest",
	"qwen2.5-vl-72b-instruct",
	"qwen2.5-14b-instruct-1m",
	"qvq-72b-preview",
	"qwq-32b-preview",
	"qwen2.5-coder-32b-instruct",
	"qwen-turbo-latest",
	"qwen2.5-72b-instruct",
}
// ModelSuffixes lists the variant suffixes appended to every base model ID
// when the model list is built; the empty suffix keeps the base model itself.
var ModelSuffixes = []string{
	"",
	"-thinking",
	"-search",
	"-thinking-search",
	"-draw",
}
// Log level names accepted by the -log-level flag, from most to least verbose.
const (
	LogLevelDebug = "debug"
	LogLevelInfo = "info"
	LogLevelWarn = "warn"
	LogLevelError = "error"
)
// WorkerPool runs a fixed number of worker goroutines that consume tasks
// from a shared queue.
type WorkerPool struct {
	// taskQueue is the buffered channel of pending tasks.
	taskQueue chan *Task
	// workerCount is the number of goroutines started by Start.
	workerCount int
	// shutdownChannel is closed by Shutdown to tell workers to exit.
	shutdownChannel chan struct{}
	// wg tracks the running worker goroutines so Shutdown can wait for them.
	wg sync.WaitGroup
}
// Task bundles everything a worker needs to serve one HTTP request.
type Task struct {
	// r and w are the request and response writer of the originating handler.
	r *http.Request
	w http.ResponseWriter
	// done is closed by the worker once the task has been processed.
	done chan struct{}
	// reqID identifies the request in log lines.
	reqID string
	// isStream selects the streaming chat handler when true.
	isStream bool
	// apiReq is the request body decoded by the HTTP handler.
	apiReq APIRequest
	// path is the route key the worker switches on to pick a handler.
	path string
}
// Semaphore is a counting semaphore backed by a buffered channel; it is used
// to cap the number of requests handled concurrently.
type Semaphore struct {
	// sem holds one element per acquired slot; its capacity is the limit.
	sem chan struct{}
}
// Config holds the runtime settings populated from command-line flags.
type Config struct {
	// Port and Address form the listen address (Address + ":" + Port).
	Port string
	Address string
	// LogLevel is one of the LogLevel* constants.
	LogLevel string
	// DevMode forces LogLevel to debug when set.
	DevMode bool
	// MaxRetries is the value of the -max-retries flag.
	MaxRetries int
	// Timeout is a duration in seconds used for the HTTP client and the
	// server read/write timeouts.
	Timeout int
	// VerifySSL disables upstream TLS certificate verification when false.
	VerifySSL bool
	// WorkerCount and QueueSize size the worker pool and its task queue.
	WorkerCount int
	QueueSize int
	// MaxConcurrent is the capacity of the request semaphore.
	MaxConcurrent int
	// APIPrefix is prepended to the /v1/... route patterns.
	APIPrefix string
}
// APIRequest is the OpenAI-compatible chat request body accepted by the proxy.
type APIRequest struct {
	// Model may carry a variant suffix such as "-thinking" or "-search".
	Model string `json:"model"`
	Messages []APIMessage `json:"messages"`
	// Stream selects the streaming (SSE) response mode.
	Stream bool `json:"stream"`
	Temperature float64 `json:"temperature,omitempty"`
	MaxTokens int `json:"max_tokens,omitempty"`
}
// APIMessage is a single chat message, used both for incoming requests and
// for the messages forwarded upstream.
type APIMessage struct {
	Role string `json:"role"`
	// Content is either a plain string or a list of content parts.
	Content interface{} `json:"content"`
	// FeatureConfig carries upstream feature switches (e.g. thinking_enabled).
	FeatureConfig interface{} `json:"feature_config,omitempty"`
	// ChatType is set to "search" or "t2i" for the corresponding modes.
	ChatType string `json:"chat_type,omitempty"`
	Extra interface{} `json:"extra,omitempty"`
}
// ContentItem is one part of a multi-part message content (text or image).
type ContentItem struct {
	Type string `json:"type,omitempty"`
	Text string `json:"text,omitempty"`
	// ImageURL is the OpenAI-style image reference.
	ImageURL *ImageURL `json:"image_url,omitempty"`
	// Image is the upstream-style image field.
	// NOTE(review): presumably an uploaded file ID — confirm against uploadImage.
	Image string `json:"image,omitempty"`
}
// ImageURL wraps an image URL; it is also the element type of
// ImagesResponse.Data.
type ImageURL struct {
	URL string `json:"url"`
}
// QwenRequest is the request body sent to the upstream Qwen chat API.
type QwenRequest struct {
	Model string `json:"model"`
	Messages []APIMessage `json:"messages"`
	Stream bool `json:"stream"`
	// ChatType is "t2t", "search" or "t2i" in the code visible here.
	ChatType string `json:"chat_type,omitempty"`
	// ID is a client-generated request identifier.
	ID string `json:"id,omitempty"`
	IncrementalOutput bool `json:"incremental_output,omitempty"`
	Size string `json:"size,omitempty"`
}
// QwenResponse is the subset of an upstream Qwen response (or one streamed
// chunk) that the proxy decodes.
type QwenResponse struct {
	// Messages carries full messages; Extra.Wanx.TaskID is decoded from the
	// "extra.wanx.task_id" JSON path.
	Messages []struct {
		Role string `json:"role"`
		Content string `json:"content"`
		Extra struct {
			Wanx struct {
				TaskID string `json:"task_id"`
			} `json:"wanx"`
		} `json:"extra"`
	} `json:"messages"`
	// Choices carries streamed deltas and the finish reason.
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	// Usage is the token accounting reported by upstream.
	Usage struct {
		PromptTokens int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens int `json:"total_tokens"`
	} `json:"usage"`
}
// FileUploadResponse is the decoded response of a file upload; only the
// "id" field is read.
type FileUploadResponse struct {
	ID string `json:"id"`
}
// TaskStatusResponse is the decoded response of a task-status query; only
// the "content" field is read.
type TaskStatusResponse struct {
	Content string `json:"content"`
}
// StreamChunk is one OpenAI-compatible streaming response chunk.
type StreamChunk struct {
	ID string `json:"id"`
	Object string `json:"object"`
	Created int64 `json:"created"`
	Model string `json:"model"`
	Choices []struct {
		Index int `json:"index"`
		// Delta holds the incremental role/content; empty fields are omitted.
		Delta struct {
			Role string `json:"role,omitempty"`
			Content string `json:"content,omitempty"`
		} `json:"delta"`
		// FinishReason is a pointer so it is omitted when nil.
		FinishReason *string `json:"finish_reason,omitempty"`
	} `json:"choices"`
}
// CompletionResponse is the OpenAI-compatible non-streaming chat completion
// response.
type CompletionResponse struct {
	ID string `json:"id"`
	Object string `json:"object"`
	Created int64 `json:"created"`
	Model string `json:"model"`
	Choices []struct {
		Index int `json:"index"`
		Message struct {
			Role string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	// Usage is the token accounting returned to the client.
	Usage struct {
		PromptTokens int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens int `json:"total_tokens"`
	} `json:"usage"`
}
// ImagesResponse is the image-generation response body: a creation
// timestamp and a list of image URLs.
type ImagesResponse struct {
	Created int64 `json:"created"`
	Data []ImageURL `json:"data"`
}
// ImagesRequest is the image-generation request body.
type ImagesRequest struct {
	Model string `json:"model"`
	Prompt string `json:"prompt"`
	// N is the requested number of images.
	N int `json:"n"`
	// Size is the requested image size, e.g. "1024*1024".
	Size string `json:"size"`
}
// ModelData describes one entry of the model list response.
type ModelData struct {
	ID string `json:"id"`
	Object string `json:"object"`
	Created int64 `json:"created"`
	OwnedBy string `json:"owned_by"`
}
// ModelsResponse is the model list response body.
type ModelsResponse struct {
	Object string `json:"object"`
	Data []ModelData `json:"data"`
}
// Process-wide state.
var (
	// appConfig is set once in main from the command-line flags.
	appConfig *Config
	// logger, logLevel and logMutex back the log* helper functions.
	logger *log.Logger
	logLevel string
	logMutex sync.Mutex
	// workerPool and requestSem are created in main.
	workerPool *WorkerPool
	requestSem *Semaphore
	// requestCount numbers incoming requests; it is guarded by countMutex.
	requestCount uint64 = 0
	countMutex sync.Mutex
	// Performance metrics, updated with sync/atomic.
	requestCounter int64
	successCounter int64
	errorCounter int64
	// avgResponseTime accumulates the total elapsed milliseconds; /health
	// divides it by requestCounter to obtain the average.
	avgResponseTime int64
	queuedRequests int64
	rejectedRequests int64
)
// NewSemaphore returns a semaphore that allows at most size slots to be held
// at the same time.
func NewSemaphore(size int) *Semaphore {
	slots := make(chan struct{}, size)
	return &Semaphore{sem: slots}
}
// Acquire takes one slot, blocking until a slot is available.
func (s *Semaphore) Acquire() {
	var slot struct{}
	s.sem <- slot
}
// Release gives back one slot previously taken with Acquire or TryAcquire.
// It blocks if no slot is currently held.
func (s *Semaphore) Release() {
	<-s.sem
}
// TryAcquire attempts to take one slot without blocking and reports whether
// it succeeded.
func (s *Semaphore) TryAcquire() bool {
	acquired := false
	select {
	case s.sem <- struct{}{}:
		acquired = true
	default:
		// Every slot is taken; leave acquired false.
	}
	return acquired
}
// NewWorkerPool builds a pool with workerCount workers and a task queue of
// capacity queueSize, starts the workers, and returns the pool.
func NewWorkerPool(workerCount int, queueSize int) *WorkerPool {
	pool := new(WorkerPool)
	pool.workerCount = workerCount
	pool.taskQueue = make(chan *Task, queueSize)
	pool.shutdownChannel = make(chan struct{})
	pool.Start()
	return pool
}
// Start launches pool.workerCount worker goroutines. Each worker is
// registered with pool.wg so that Shutdown can wait for it.
func (pool *WorkerPool) Start() {
	for i := 0; i < pool.workerCount; i++ {
		pool.wg.Add(1)
		go pool.runWorker(i)
	}
}

// runWorker is the loop of a single worker: it processes queued tasks until
// the queue is closed or the shutdown channel is closed.
func (pool *WorkerPool) runWorker(workerID int) {
	defer pool.wg.Done()
	logInfo("Worker %d 已启动", workerID)
	for {
		select {
		case task, ok := <-pool.taskQueue:
			if !ok {
				// The queue was closed: no more work will arrive.
				logInfo("Worker %d 收到队列关闭信号,准备退出", workerID)
				return
			}
			pool.runTask(workerID, task)
		case <-pool.shutdownChannel:
			logInfo("Worker %d 收到关闭信号,准备退出", workerID)
			return
		}
	}
}

// runTask dispatches one task to the handler matching task.path and always
// closes task.done afterwards.
//
// A panic inside a handler is recovered here: without this, the panic would
// terminate the whole process (it runs outside net/http's own recovery) and
// task.done would never be closed, leaving the HTTP handler waiting.
func (pool *WorkerPool) runTask(workerID int, task *Task) {
	// Deferred calls run last-in-first-out: recover first, then signal done.
	defer close(task.done)
	defer func() {
		if rec := recover(); rec != nil {
			logError("Worker %d panic while handling reqID:%s: %v", workerID, task.reqID, rec)
			// Best effort: if the response was already started this only
			// produces a "superfluous WriteHeader" log line.
			http.Error(task.w, "Internal server error", http.StatusInternalServerError)
		}
	}()

	logDebug("Worker %d 处理任务 reqID:%s", workerID, task.reqID)
	switch task.path {
	case "/v1/models":
		handleModels(task.w, task.r)
	case "/v1/chat/completions":
		if task.isStream {
			handleStreamingRequest(task.w, task.r, task.apiReq, task.reqID)
		} else {
			handleNonStreamingRequest(task.w, task.r, task.apiReq, task.reqID)
		}
	case "/v1/images/generations":
		handleImageGenerations(task.w, task.r, task.apiReq, task.reqID)
	}
}
// SubmitTask enqueues task without blocking. It returns (true, nil) when the
// task was queued and (false, error) when the queue is full.
func (pool *WorkerPool) SubmitTask(task *Task) (bool, error) {
	select {
	case pool.taskQueue <- task:
		return true, nil
	default:
	}
	// No free queue slot was available.
	return false, fmt.Errorf("任务队列已满")
}
// Shutdown stops the pool: it signals every worker to exit, waits until all
// of them have returned, and only then closes the task queue.
func (pool *WorkerPool) Shutdown() {
	logInfo("正在关闭工作池...")

	// Closing the channel wakes every worker blocked in its select.
	close(pool.shutdownChannel)
	pool.wg.Wait()

	// The workers are gone, so nothing reads from the queue any more.
	close(pool.taskQueue)

	logInfo("工作池已关闭")
}
// 日志函数
func initLogger(level string) {
logger = log.New(os.Stdout, "[QwenAPI] ", log.LstdFlags)
logLevel = level
}
// logDebug writes a "[DEBUG]" line; it is a no-op unless the log level is debug.
func logDebug(format string, v ...interface{}) {
	if logLevel != LogLevelDebug {
		return
	}
	logMutex.Lock()
	logger.Printf("[DEBUG] "+format, v...)
	logMutex.Unlock()
}
// logInfo writes an "[INFO]" line when the log level is debug or info.
func logInfo(format string, v ...interface{}) {
	switch logLevel {
	case LogLevelDebug, LogLevelInfo:
		logMutex.Lock()
		logger.Printf("[INFO] "+format, v...)
		logMutex.Unlock()
	}
}
// logWarn writes a "[WARN]" line when the log level is debug, info or warn.
func logWarn(format string, v ...interface{}) {
	switch logLevel {
	case LogLevelDebug, LogLevelInfo, LogLevelWarn:
		logMutex.Lock()
		logger.Printf("[WARN] "+format, v...)
		logMutex.Unlock()
	}
}
// logError writes an "[ERROR]" line regardless of the log level and then
// increments the global error counter.
func logError(format string, v ...interface{}) {
	prefixed := "[ERROR] " + format

	logMutex.Lock()
	logger.Printf(prefixed, v...)
	logMutex.Unlock()

	// Every logged error counts towards the "errors" metric.
	atomic.AddInt64(&errorCounter, 1)
}
// parseFlags registers the command-line flags, parses them and returns the
// resulting configuration. When development mode is on, the log level is
// forced to debug.
func parseFlags() *Config {
	cfg := new(Config)

	// Listener and logging.
	flag.StringVar(&cfg.Port, "port", "8080", "Port to listen on")
	flag.StringVar(&cfg.Address, "address", "localhost", "Address to listen on")
	flag.StringVar(&cfg.LogLevel, "log-level", LogLevelInfo, "Log level (debug, info, warn, error)")
	flag.BoolVar(&cfg.DevMode, "dev", false, "Enable development mode with enhanced logging")

	// Upstream request behaviour.
	flag.IntVar(&cfg.MaxRetries, "max-retries", 3, "Maximum number of retries for failed requests")
	flag.IntVar(&cfg.Timeout, "timeout", 300, "Request timeout in seconds")
	flag.BoolVar(&cfg.VerifySSL, "verify-ssl", true, "Verify SSL certificates")

	// Concurrency limits.
	flag.IntVar(&cfg.WorkerCount, "workers", 50, "Number of worker goroutines in the pool")
	flag.IntVar(&cfg.QueueSize, "queue-size", 500, "Size of the task queue")
	flag.IntVar(&cfg.MaxConcurrent, "max-concurrent", 100, "Maximum number of concurrent requests")

	// Routing.
	flag.StringVar(&cfg.APIPrefix, "api-prefix", "", "API prefix for all endpoints")

	flag.Parse()

	devForcesDebug := cfg.DevMode && cfg.LogLevel != LogLevelDebug
	if devForcesDebug {
		cfg.LogLevel = LogLevelDebug
		fmt.Println("开发模式已启用日志级别设置为debug")
	}
	return cfg
}
// extractToken returns the bearer token carried in the request's
// Authorization header. It returns an error when the header is missing, does
// not start with "Bearer ", or carries an empty token.
func extractToken(r *http.Request) (string, error) {
	const scheme = "Bearer "

	header := r.Header.Get("Authorization")
	switch {
	case header == "":
		return "", fmt.Errorf("missing Authorization header")
	case !strings.HasPrefix(header, scheme):
		return "", fmt.Errorf("invalid Authorization header format, must start with 'Bearer '")
	case len(header) == len(scheme):
		// Only the scheme is present; nothing follows it.
		return "", fmt.Errorf("empty token in Authorization header")
	}
	return header[len(scheme):], nil
}
// setCORSHeaders adds permissive CORS headers to the response: any origin,
// the POST/GET/OPTIONS methods, and the Content-Type and Authorization
// request headers.
func setCORSHeaders(w http.ResponseWriter) {
	h := w.Header()
	h.Set("Access-Control-Allow-Origin", "*")
	h.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
	h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
}
// generateUUID returns a random (version 4) UUID in the canonical
// 8-4-4-4-12 lowercase hexadecimal form.
//
// If the system random source fails, it falls back to the current Unix time
// in nanoseconds formatted as a decimal string.
func generateUUID() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		return fmt.Sprintf("%d", time.Now().UnixNano())
	}
	// Stamp the version (4) and variant (10xx) bits so the result is a valid
	// RFC 4122 UUID rather than 16 arbitrary bytes in UUID shape.
	b[6] = (b[6] & 0x0f) | 0x40
	b[8] = (b[8] & 0x3f) | 0x80
	return fmt.Sprintf("%x-%x-%x-%x-%x",
		b[0:4], b[4:6], b[6:8], b[8:10], b[10:])
}
// sharedTransport is the single upstream transport used by every client
// returned from getHTTPClient; it is built on first use.
var (
	sharedTransportOnce sync.Once
	sharedTransport     *http.Transport
)

// getHTTPClient returns an HTTP client for upstream calls whose timeout is
// appConfig.Timeout seconds.
//
// All returned clients share one Transport so that idle upstream connections
// are actually reused; building a new Transport per call (as before) gives
// every request its own connection pool that is never hit again. A fresh
// Client value is still returned each time, so callers may adjust it freely.
//
// TLS verification is disabled when appConfig.VerifySSL is false. The
// setting is read once, on the first call, so appConfig must be populated
// before then.
func getHTTPClient() *http.Client {
	sharedTransportOnce.Do(func() {
		tr := &http.Transport{
			MaxIdleConnsPerHost: 100,
			IdleConnTimeout:     90 * time.Second,
		}
		if !appConfig.VerifySSL {
			// Explicit operator opt-out of certificate verification.
			tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
		}
		sharedTransport = tr
	})
	return &http.Client{
		Timeout:   time.Duration(appConfig.Timeout) * time.Second,
		Transport: sharedTransport,
	}
}
// 主入口函数
func main() {
// 解析配置
appConfig = parseFlags()
// 初始化日志
initLogger(appConfig.LogLevel)
logInfo("启动服务: 地址=%s, 端口=%s, 版本=%s, 日志级别=%s",
appConfig.Address, appConfig.Port, Version, appConfig.LogLevel)
// 创建工作池和信号量
workerPool = NewWorkerPool(appConfig.WorkerCount, appConfig.QueueSize)
requestSem = NewSemaphore(appConfig.MaxConcurrent)
logInfo("工作池已创建: %d个worker, 队列大小为%d", appConfig.WorkerCount, appConfig.QueueSize)
// 配置更高的并发处理能力
http.DefaultTransport.(*http.Transport).MaxIdleConnsPerHost = 100
http.DefaultTransport.(*http.Transport).MaxIdleConns = 100
http.DefaultTransport.(*http.Transport).IdleConnTimeout = 90 * time.Second
// 创建自定义服务器,支持更高并发
server := &http.Server{
Addr: appConfig.Address + ":" + appConfig.Port,
ReadTimeout: time.Duration(appConfig.Timeout) * time.Second,
WriteTimeout: time.Duration(appConfig.Timeout) * time.Second,
IdleTimeout: 120 * time.Second,
Handler: nil, // 使用默认的ServeMux
}
// API路径前缀
apiPrefix := appConfig.APIPrefix
// 创建处理器
http.HandleFunc(apiPrefix+"/v1/models", func(w http.ResponseWriter, r *http.Request) {
setCORSHeaders(w)
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 计数
countMutex.Lock()
requestCount++
currentCount := requestCount
countMutex.Unlock()
reqID := generateRequestID()
logInfo("[reqID:%s] 收到模型列表请求 #%d", reqID, currentCount)
// 请求计数
atomic.AddInt64(&requestCounter, 1)
startTime := time.Now()
// 创建任务
task := &Task{
r: r,
w: w,
done: make(chan struct{}),
reqID: reqID,
path: "/v1/models",
}
// 尝试获取信号量
if !requestSem.TryAcquire() {
// 请求数量超过限制
atomic.AddInt64(&rejectedRequests, 1)
logWarn("[reqID:%s] 请求被拒绝: 当前并发请求数已达上限", reqID)
w.Header().Set("Retry-After", "30")
http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable)
return
}
// 释放信号量(在函数返回时)
defer requestSem.Release()
// 添加到任务队列
atomic.AddInt64(&queuedRequests, 1)
submitted, err := workerPool.SubmitTask(task)
if !submitted {
atomic.AddInt64(&queuedRequests, -1)
atomic.AddInt64(&rejectedRequests, 1)
logError("[reqID:%s] 提交任务失败: %v", reqID, err)
w.Header().Set("Retry-After", "60")
http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable)
return
}
logInfo("[reqID:%s] 任务已提交到队列", reqID)
// 等待任务完成或超时
select {
case <-task.done:
// 任务已完成
logInfo("[reqID:%s] 任务已完成", reqID)
case <-r.Context().Done():
// 请求被取消或超时
logWarn("[reqID:%s] 请求被取消或超时", reqID)
}
// 请求处理完成,更新指标
atomic.AddInt64(&queuedRequests, -1)
elapsed := time.Since(startTime).Milliseconds()
// 更新平均响应时间
atomic.AddInt64(&avgResponseTime, elapsed)
if r.Context().Err() == nil {
// 成功计数增加
atomic.AddInt64(&successCounter, 1)
logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed)
} else {
logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed)
}
})
http.HandleFunc(apiPrefix+"/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
setCORSHeaders(w)
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 计数器增加
countMutex.Lock()
requestCount++
currentCount := requestCount
countMutex.Unlock()
reqID := generateRequestID()
logInfo("[reqID:%s] 收到新请求 #%d", reqID, currentCount)
// 请求计数
atomic.AddInt64(&requestCounter, 1)
startTime := time.Now()
// 尝试获取信号量
if !requestSem.TryAcquire() {
// 请求数量超过限制
atomic.AddInt64(&rejectedRequests, 1)
logWarn("[reqID:%s] 请求 #%d 被拒绝: 当前并发请求数已达上限", reqID, currentCount)
w.Header().Set("Retry-After", "30")
http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable)
return
}
// 释放信号量(在函数返回时)
defer requestSem.Release()
// 解析请求体
var apiReq APIRequest
if err := json.NewDecoder(r.Body).Decode(&apiReq); err != nil {
logError("[reqID:%s] 解析请求失败: %v", reqID, err)
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
// 创建任务
task := &Task{
r: r,
w: w,
done: make(chan struct{}),
reqID: reqID,
isStream: apiReq.Stream,
apiReq: apiReq,
path: "/v1/chat/completions",
}
// 添加到任务队列
atomic.AddInt64(&queuedRequests, 1)
submitted, err := workerPool.SubmitTask(task)
if !submitted {
atomic.AddInt64(&queuedRequests, -1)
atomic.AddInt64(&rejectedRequests, 1)
logError("[reqID:%s] 提交任务失败: %v", reqID, err)
w.Header().Set("Retry-After", "60")
http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable)
return
}
logInfo("[reqID:%s] 任务已提交到队列", reqID)
// 等待任务完成或超时
select {
case <-task.done:
// 任务已完成
logInfo("[reqID:%s] 任务已完成", reqID)
case <-r.Context().Done():
// 请求被取消或超时
logWarn("[reqID:%s] 请求被取消或超时", reqID)
}
// 请求处理完成,更新指标
atomic.AddInt64(&queuedRequests, -1)
elapsed := time.Since(startTime).Milliseconds()
// 更新平均响应时间
atomic.AddInt64(&avgResponseTime, elapsed)
if r.Context().Err() == nil {
// 成功计数增加
atomic.AddInt64(&successCounter, 1)
logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed)
} else {
logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed)
}
})
http.HandleFunc(apiPrefix+"/v1/images/generations", func(w http.ResponseWriter, r *http.Request) {
setCORSHeaders(w)
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 计数器增加
countMutex.Lock()
requestCount++
currentCount := requestCount
countMutex.Unlock()
reqID := generateRequestID()
logInfo("[reqID:%s] 收到图像生成请求 #%d", reqID, currentCount)
// 请求计数
atomic.AddInt64(&requestCounter, 1)
startTime := time.Now()
// 尝试获取信号量
if !requestSem.TryAcquire() {
// 请求数量超过限制
atomic.AddInt64(&rejectedRequests, 1)
logWarn("[reqID:%s] 请求 #%d 被拒绝: 当前并发请求数已达上限", reqID, currentCount)
w.Header().Set("Retry-After", "30")
http.Error(w, "Server is busy, please try again later", http.StatusServiceUnavailable)
return
}
// 释放信号量(在函数返回时)
defer requestSem.Release()
// 解析请求体
var apiReq APIRequest
if err := json.NewDecoder(r.Body).Decode(&apiReq); err != nil {
logError("[reqID:%s] 解析请求失败: %v", reqID, err)
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
// 创建任务
task := &Task{
r: r,
w: w,
done: make(chan struct{}),
reqID: reqID,
apiReq: apiReq,
path: "/v1/images/generations",
}
// 添加到任务队列
atomic.AddInt64(&queuedRequests, 1)
submitted, err := workerPool.SubmitTask(task)
if !submitted {
atomic.AddInt64(&queuedRequests, -1)
atomic.AddInt64(&rejectedRequests, 1)
logError("[reqID:%s] 提交任务失败: %v", reqID, err)
w.Header().Set("Retry-After", "60")
http.Error(w, "Server queue is full, please try again later", http.StatusServiceUnavailable)
return
}
logInfo("[reqID:%s] 任务已提交到队列", reqID)
// 等待任务完成或超时
select {
case <-task.done:
// 任务已完成
logInfo("[reqID:%s] 任务已完成", reqID)
case <-r.Context().Done():
// 请求被取消或超时
logWarn("[reqID:%s] 请求被取消或超时", reqID)
}
// 请求处理完成,更新指标
atomic.AddInt64(&queuedRequests, -1)
elapsed := time.Since(startTime).Milliseconds()
// 更新平均响应时间
atomic.AddInt64(&avgResponseTime, elapsed)
if r.Context().Err() == nil {
// 成功计数增加
atomic.AddInt64(&successCounter, 1)
logInfo("[reqID:%s] 请求处理成功,耗时: %dms", reqID, elapsed)
} else {
logError("[reqID:%s] 请求处理失败: %v, 耗时: %dms", reqID, r.Context().Err(), elapsed)
}
})
// 添加健康检查端点
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
setCORSHeaders(w)
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// 获取各种计数器的值
reqCount := atomic.LoadInt64(&requestCounter)
succCount := atomic.LoadInt64(&successCounter)
errCount := atomic.LoadInt64(&errorCounter)
queuedCount := atomic.LoadInt64(&queuedRequests)
rejectedCount := atomic.LoadInt64(&rejectedRequests)
// 计算平均响应时间
var avgTime int64 = 0
if reqCount > 0 {
avgTime = atomic.LoadInt64(&avgResponseTime) / reqCount
}
// 构建响应
stats := map[string]interface{}{
"status": "ok",
"version": Version,
"requests": reqCount,
"success": succCount,
"errors": errCount,
"queued": queuedCount,
"rejected": rejectedCount,
"avg_time_ms": avgTime,
"worker_count": workerPool.workerCount,
"queue_size": len(workerPool.taskQueue),
"queue_capacity": cap(workerPool.taskQueue),
"queue_percent": float64(len(workerPool.taskQueue)) / float64(cap(workerPool.taskQueue)) * 100,
"concurrent_limit": appConfig.MaxConcurrent,
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(stats)
})
// 创建停止通道
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
// 在goroutine中启动服务器
go func() {
logInfo("Starting proxy server on %s", server.Addr)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logError("Failed to start server: %v", err)
os.Exit(1)
}
}()
// 等待停止信号
<-stop
// 创建上下文用于优雅关闭
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 优雅关闭服务器
logInfo("Server is shutting down...")
if err := server.Shutdown(ctx); err != nil {
logError("Server shutdown failed: %v", err)
}
// 关闭工作池
workerPool.Shutdown()
logInfo("Server gracefully stopped")
}
// generateRequestID returns the current Unix time in nanoseconds rendered as
// lowercase hexadecimal; it is used to tag log lines of one request.
func generateRequestID() string {
	nowNanos := time.Now().UnixNano()
	return fmt.Sprintf("%x", nowNanos)
}
// handleModels serves the model list. It asks the upstream API for its
// models using the caller's bearer token and expands every model ID with
// each entry of ModelSuffixes. Whenever the upstream list cannot be obtained
// (no token, request failure, non-200 status, undecodable or empty response)
// the default model list is returned instead.
func handleModels(w http.ResponseWriter, r *http.Request) {
	logInfo("处理模型列表请求")

	authToken, err := extractToken(r)
	if err != nil {
		logWarn("提取token失败: %v", err)
		returnDefaultModels(w)
		return
	}

	// Bind the upstream call to the client's request context so that it is
	// abandoned when the client goes away, like the chat handlers do.
	req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, ModelsURL, nil)
	if err != nil {
		logError("创建请求失败: %v", err)
		returnDefaultModels(w)
		return
	}
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Mozilla/5.0")

	resp, err := getHTTPClient().Do(req)
	if err != nil {
		logError("请求模型列表失败: %v", err)
		returnDefaultModels(w)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		logError("获取模型列表返回非200状态码: %d", resp.StatusCode)
		returnDefaultModels(w)
		return
	}

	// Only the model IDs are needed from the upstream payload.
	var qwenResp struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&qwenResp); err != nil {
		logError("解析模型列表响应失败: %v", err)
		returnDefaultModels(w)
		return
	}
	if len(qwenResp.Data) == 0 {
		logWarn("未获取到模型,使用默认列表")
		returnDefaultModels(w)
		return
	}

	// Expand every upstream model with each variant suffix.
	created := time.Now().Unix()
	expandedModels := make([]ModelData, 0, len(qwenResp.Data)*len(ModelSuffixes))
	for _, model := range qwenResp.Data {
		for _, suffix := range ModelSuffixes {
			expandedModels = append(expandedModels, ModelData{
				ID:      model.ID + suffix,
				Object:  "model",
				Created: created,
				OwnedBy: "qwen",
			})
		}
	}

	w.Header().Set("Content-Type", "application/json")
	modelsResp := ModelsResponse{Object: "list", Data: expandedModels}
	if err := json.NewEncoder(w).Encode(modelsResp); err != nil {
		logWarn("failed to write model list response: %v", err)
	}
}
// 返回默认模型列表
func returnDefaultModels(w http.ResponseWriter) {
// 扩展默认模型列表,增加变种后缀
expandedModels := make([]ModelData, 0, len(DefaultModels)*len(ModelSuffixes))
for _, model := range DefaultModels {
for _, suffix := range ModelSuffixes {
expandedModels = append(expandedModels, ModelData{
ID: model + suffix,
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "qwen",
})
}
}
// 构建响应
modelsResp := ModelsResponse{
Object: "list",
Data: expandedModels,
}
// 返回响应
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(modelsResp)
}
// handleStreamingRequest serves a chat completion as an OpenAI-compatible
// server-sent-event stream.
//
// It validates the bearer token and the message list, interprets the model
// name suffixes ("-draw" is delegated to handleDrawRequest, "-thinking" and
// "-search" adjust the last message), uploads the first image of the last
// message if there is one, forwards the request upstream in streaming mode
// and relays the upstream chunks to the client.
//
// The stream sent to the client always ends with exactly one "[DONE]" line;
// a "stop" finish chunk is added only when upstream did not report a finish
// reason itself.
func handleStreamingRequest(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string) {
	logInfo("[reqID:%s] 处理流式请求", reqID)

	authToken, err := extractToken(r)
	if err != nil {
		logError("[reqID:%s] 提取token失败: %v", reqID, err)
		http.Error(w, "无效的认证信息", http.StatusUnauthorized)
		return
	}
	if len(apiReq.Messages) == 0 {
		logError("[reqID:%s] 消息为空", reqID)
		http.Error(w, "消息为空", http.StatusBadRequest)
		return
	}

	modelName := "qwen-turbo-latest"
	if apiReq.Model != "" {
		modelName = apiReq.Model
	}
	chatType := "t2t"

	// Drawing models are handled by a dedicated code path.
	if strings.Contains(modelName, "-draw") {
		handleDrawRequest(w, r, apiReq, reqID, authToken)
		return
	}

	// Check streaming support before doing any upstream work.
	flusher, ok := w.(http.Flusher)
	if !ok {
		logError("[reqID:%s] 流式传输不支持", reqID)
		http.Error(w, "流式传输不支持", http.StatusInternalServerError)
		return
	}

	// The message list is non-empty here, so the last message always exists.
	last := &apiReq.Messages[len(apiReq.Messages)-1]

	// Thinking mode: strip the suffix and enable the feature on the last message.
	if strings.Contains(modelName, "-thinking") {
		modelName = strings.Replace(modelName, "-thinking", "", 1)
		last.FeatureConfig = map[string]interface{}{
			"thinking_enabled": true,
		}
	}
	// Search mode: strip the suffix and switch the chat type.
	if strings.Contains(modelName, "-search") {
		modelName = strings.Replace(modelName, "-search", "", 1)
		chatType = "search"
		last.ChatType = "search"
	}

	// Replace the first uploadable image_url part of the last message with
	// the uploaded image reference. Only one image is replaced.
	if parts, isList := last.Content.([]interface{}); isList {
		for i, part := range parts {
			partMap, isMap := part.(map[string]interface{})
			if !isMap {
				continue
			}
			imageURL, hasImageURL := partMap["image_url"]
			if !hasImageURL {
				continue
			}
			imageURLMap, isMap := imageURL.(map[string]interface{})
			if !isMap {
				continue
			}
			url, hasURL := imageURLMap["url"].(string)
			if !hasURL {
				continue
			}
			imageID, uploadErr := uploadImage(url, authToken)
			if uploadErr != nil {
				logError("[reqID:%s] 上传图像失败: %v", reqID, uploadErr)
				continue
			}
			// Work on a copy so the decoded request content is not mutated.
			replaced := make([]interface{}, len(parts))
			copy(replaced, parts)
			replaced[i] = map[string]interface{}{
				"type":  "image",
				"image": imageID,
			}
			last.Content = replaced
			break
		}
	}

	qwenReq := QwenRequest{
		Model:    modelName,
		Messages: apiReq.Messages,
		Stream:   true,
		ChatType: chatType,
		ID:       generateUUID(),
	}
	reqData, err := json.Marshal(qwenReq)
	if err != nil {
		logError("[reqID:%s] 序列化请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}

	req, err := http.NewRequestWithContext(r.Context(), "POST", TargetURL, bytes.NewBuffer(reqData))
	if err != nil {
		logError("[reqID:%s] 创建请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Mozilla/5.0")

	client := getHTTPClient()
	resp, err := client.Do(req)
	if err != nil {
		logError("[reqID:%s] 发送请求失败: %v", reqID, err)
		http.Error(w, "连接到API失败", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body) // best effort, for the log only
		logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes))
		http.Error(w, fmt.Sprintf("API错误状态码: %d", resp.StatusCode), resp.StatusCode)
		return
	}

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	respID := fmt.Sprintf("chatcmpl-%s", generateUUID())
	createdTime := time.Now().Unix()

	// emit writes one SSE "data:" event and flushes it; it reports false when
	// the client can no longer be written to.
	emit := func(payload string) bool {
		if _, werr := w.Write([]byte("data: " + payload + "\n\n")); werr != nil {
			logWarn("[reqID:%s] failed to write to client: %v", reqID, werr)
			return false
		}
		flusher.Flush()
		return true
	}

	if !emit(string(createRoleChunk(respID, createdTime, modelName))) {
		return
	}

	previousContent := "" // last non-empty upstream content, for prefix de-duplication
	sentLen := 0          // total bytes of content forwarded to the client
	finishSent := false   // whether a finish chunk was already forwarded
	clientGone := false   // whether a write to the client failed

	// processLine relays one upstream line. It returns false when reading
	// should stop: upstream sent its terminator or the client went away.
	processLine := func(line string) bool {
		if !strings.HasPrefix(line, "data: ") {
			return true
		}
		dataStr := strings.TrimPrefix(line, "data: ")
		if dataStr == "[DONE]" {
			// The single "[DONE]" for the client is written after the loop.
			logDebug("[reqID:%s] 收到[DONE]消息", reqID)
			return false
		}

		var qwenResp QwenResponse
		if err := json.Unmarshal([]byte(dataStr), &qwenResp); err != nil {
			logWarn("[reqID:%s] 解析JSON失败: %v, data: %s", reqID, err, dataStr)
			return true
		}

		for _, choice := range qwenResp.Choices {
			content := choice.Delta.Content

			// When the new content repeats the previous content as a prefix,
			// forward only the part that was not sent yet.
			delta := content
			if previousContent != "" && strings.HasPrefix(content, previousContent) {
				delta = content[len(previousContent):]
			}
			if delta != "" {
				if !emit(string(createContentChunk(respID, createdTime, modelName, delta))) {
					clientGone = true
					return false
				}
				sentLen += len(delta)
			}
			if content != "" {
				previousContent = content
			}

			if choice.FinishReason != "" {
				if !emit(string(createDoneChunk(respID, createdTime, modelName, choice.FinishReason))) {
					clientGone = true
					return false
				}
				finishSent = true
			}
		}
		return true
	}

	reader := bufio.NewReaderSize(resp.Body, 16384)
	for {
		// Stop promptly when the client cancelled or the request timed out.
		select {
		case <-r.Context().Done():
			logWarn("[reqID:%s] 请求超时或被客户端取消", reqID)
			return
		default:
		}

		line, readErr := reader.ReadString('\n')
		if readErr != nil && readErr != io.EOF {
			logError("[reqID:%s] 读取响应出错: %v", reqID, readErr)
			return
		}
		// At EOF, line holds whatever followed the last newline (possibly
		// nothing); it is still processed so a final unterminated event is
		// not lost. Both LF and CRLF line endings are accepted.
		if !processLine(strings.TrimRight(line, "\r\n")) {
			break
		}
		if readErr == io.EOF {
			break
		}
	}
	if clientGone {
		return
	}

	if sentLen > 0 {
		logInfo("[reqID:%s] 流处理完成,累积内容长度: %d", reqID, sentLen)
	}

	// Terminate the stream: add a finish chunk only if upstream did not
	// provide one, then send the single "[DONE]" marker.
	if !finishSent {
		if !emit(string(createDoneChunk(respID, createdTime, modelName, "stop"))) {
			return
		}
	}
	emit("[DONE]")
}
// handleNonStreamingRequest serves a chat completion as a single JSON
// response.
//
// It performs the same validation and model-suffix handling as the streaming
// handler, then calls upstream in streaming mode, collects the whole stream
// with extractFullContentFromStream, and wraps the text in an
// OpenAI-compatible CompletionResponse. Token usage is an estimate: the
// prompt side comes from estimateTokens and the completion side is the
// content length in bytes divided by four.
func handleNonStreamingRequest(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string) {
	logInfo("[reqID:%s] 处理非流式请求", reqID)

	authToken, err := extractToken(r)
	if err != nil {
		logError("[reqID:%s] 提取token失败: %v", reqID, err)
		http.Error(w, "无效的认证信息", http.StatusUnauthorized)
		return
	}
	if len(apiReq.Messages) == 0 {
		logError("[reqID:%s] 消息为空", reqID)
		http.Error(w, "消息为空", http.StatusBadRequest)
		return
	}

	modelName := "qwen-turbo-latest"
	if apiReq.Model != "" {
		modelName = apiReq.Model
	}
	chatType := "t2t"

	// Drawing models are handled by a dedicated code path.
	if strings.Contains(modelName, "-draw") {
		handleDrawRequest(w, r, apiReq, reqID, authToken)
		return
	}

	// The message list is non-empty here, so the last message always exists.
	last := &apiReq.Messages[len(apiReq.Messages)-1]

	// Thinking mode: strip the suffix and enable the feature on the last message.
	if strings.Contains(modelName, "-thinking") {
		modelName = strings.Replace(modelName, "-thinking", "", 1)
		last.FeatureConfig = map[string]interface{}{
			"thinking_enabled": true,
		}
	}
	// Search mode: strip the suffix and switch the chat type.
	if strings.Contains(modelName, "-search") {
		modelName = strings.Replace(modelName, "-search", "", 1)
		chatType = "search"
		last.ChatType = "search"
	}

	// Replace the first uploadable image_url part of the last message with
	// the uploaded image reference. Only one image is replaced.
	if parts, isList := last.Content.([]interface{}); isList {
		for i, part := range parts {
			partMap, isMap := part.(map[string]interface{})
			if !isMap {
				continue
			}
			imageURL, hasImageURL := partMap["image_url"]
			if !hasImageURL {
				continue
			}
			imageURLMap, isMap := imageURL.(map[string]interface{})
			if !isMap {
				continue
			}
			url, hasURL := imageURLMap["url"].(string)
			if !hasURL {
				continue
			}
			imageID, uploadErr := uploadImage(url, authToken)
			if uploadErr != nil {
				logError("[reqID:%s] 上传图像失败: %v", reqID, uploadErr)
				continue
			}
			// Work on a copy so the decoded request content is not mutated.
			replaced := make([]interface{}, len(parts))
			copy(replaced, parts)
			replaced[i] = map[string]interface{}{
				"type":  "image",
				"image": imageID,
			}
			last.Content = replaced
			break
		}
	}

	// Upstream is always called in streaming mode; the stream is collapsed
	// into one response below.
	qwenReq := QwenRequest{
		Model:    modelName,
		Messages: apiReq.Messages,
		Stream:   true,
		ChatType: chatType,
		ID:       generateUUID(),
	}
	reqData, err := json.Marshal(qwenReq)
	if err != nil {
		logError("[reqID:%s] 序列化请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}

	req, err := http.NewRequestWithContext(r.Context(), "POST", TargetURL, bytes.NewBuffer(reqData))
	if err != nil {
		logError("[reqID:%s] 创建请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Mozilla/5.0")

	client := getHTTPClient()
	resp, err := client.Do(req)
	if err != nil {
		logError("[reqID:%s] 发送请求失败: %v", reqID, err)
		http.Error(w, "连接到API失败", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body) // best effort, for the log only
		logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes))
		http.Error(w, fmt.Sprintf("API错误状态码: %d", resp.StatusCode), resp.StatusCode)
		return
	}

	fullContent, err := extractFullContentFromStream(resp.Body, reqID)
	if err != nil {
		logError("[reqID:%s] 提取内容失败: %v", reqID, err)
		http.Error(w, "解析响应失败", http.StatusInternalServerError)
		return
	}

	// Estimate token usage once and reuse the numbers for the total.
	promptTokens := estimateTokens(apiReq.Messages)
	completionTokens := len(fullContent) / 4

	var completionResponse CompletionResponse
	completionResponse.ID = fmt.Sprintf("chatcmpl-%s", generateUUID())
	completionResponse.Object = "chat.completion"
	completionResponse.Created = time.Now().Unix()
	completionResponse.Model = modelName
	completionResponse.Choices = []struct {
		Index   int `json:"index"`
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	}{
		{Index: 0, FinishReason: "stop"},
	}
	completionResponse.Choices[0].Message.Role = "assistant"
	completionResponse.Choices[0].Message.Content = fullContent
	completionResponse.Usage.PromptTokens = promptTokens
	completionResponse.Usage.CompletionTokens = completionTokens
	completionResponse.Usage.TotalTokens = promptTokens + completionTokens

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(completionResponse); err != nil {
		logWarn("[reqID:%s] failed to write completion response: %v", reqID, err)
	}
}
// handleImageGenerations serves an image generation request.
//
// It extracts the bearer token from r, decodes an ImagesRequest from the
// request body, fills in defaults (model, size, n), submits a "t2i" request to
// TargetURL and then polls TasksURL+taskID until the task status carries
// non-empty content, which is used as the image URL. The reply is an
// ImagesResponse in which that single URL is repeated N times.
//
// Parameters:
//   - w, r:   the HTTP response writer and incoming request.
//   - apiReq: not read by this handler; the request is decoded from r.Body.
//   - reqID:  request identifier used as a log prefix.
//
// Failures are reported to the client with http.Error: 401 for a missing or
// invalid token, 400 for an undecodable body, 502 when the upstream cannot be
// reached, the upstream status code when it is not 200, 500 for
// (de)serialisation problems or a missing task ID, and 504 when the request
// context ends or the task does not finish within the polling budget.
func handleImageGenerations(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string) {
	const (
		// Polling budget: at most maxPollAttempts status checks, spaced
		// pollInterval apart.
		maxPollAttempts = 30
		pollInterval    = 6 * time.Second
		// Upper bound for the client-supplied image count. The value sizes an
		// allocation below, so it must not be taken from the request unchecked.
		maxImages = 10
	)

	logInfo("[reqID:%s] 处理图像生成请求", reqID)
	ctx := r.Context()

	// Extract the token from the request.
	authToken, err := extractToken(r)
	if err != nil {
		logError("[reqID:%s] 提取token失败: %v", reqID, err)
		http.Error(w, "无效的认证信息", http.StatusUnauthorized)
		return
	}

	// Parse the image generation request.
	var imgReq ImagesRequest
	if err := json.NewDecoder(r.Body).Decode(&imgReq); err != nil {
		logError("[reqID:%s] 解析图像请求失败: %v", reqID, err)
		http.Error(w, "无效的请求体", http.StatusBadRequest)
		return
	}

	// Apply defaults and clamp the image count.
	if imgReq.Model == "" {
		imgReq.Model = "qwen-max-latest-draw"
	}
	if imgReq.Size == "" {
		imgReq.Size = "1024*1024"
	}
	if imgReq.N <= 0 {
		imgReq.N = 1
	}
	if imgReq.N > maxImages {
		imgReq.N = maxImages
	}

	// Derive the plain model name by removing the first occurrence of each
	// variant suffix.
	modelName := strings.Replace(imgReq.Model, "-draw", "", 1)
	modelName = strings.Replace(modelName, "-thinking", "", 1)
	modelName = strings.Replace(modelName, "-search", "", 1)

	// Build the image generation task.
	qwenReq := QwenRequest{
		Stream:            false,
		IncrementalOutput: true,
		ChatType:          "t2i",
		Model:             modelName,
		Messages: []APIMessage{
			{
				Role:     "user",
				Content:  imgReq.Prompt,
				ChatType: "t2i",
				Extra:    map[string]interface{}{},
				FeatureConfig: map[string]interface{}{
					"thinking_enabled": false,
				},
			},
		},
		ID:   generateUUID(),
		Size: imgReq.Size,
	}

	// Serialise the request.
	reqData, err := json.Marshal(qwenReq)
	if err != nil {
		logError("[reqID:%s] 序列化请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}

	// Create the upstream HTTP request, bound to the client's context.
	req, err := http.NewRequestWithContext(ctx, "POST", TargetURL, bytes.NewBuffer(reqData))
	if err != nil {
		logError("[reqID:%s] 创建请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Mozilla/5.0")

	// Send the request.
	client := getHTTPClient()
	resp, err := client.Do(req)
	if err != nil {
		logError("[reqID:%s] 发送请求失败: %v", reqID, err)
		http.Error(w, "连接到API失败", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// Relay any non-200 upstream status to the client.
	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body) // best effort: body is only used for logging
		logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes))
		http.Error(w, fmt.Sprintf("API错误状态码: %d", resp.StatusCode), resp.StatusCode)
		return
	}

	// Decode the response to obtain the task ID.
	var qwenResp QwenResponse
	if err := json.NewDecoder(resp.Body).Decode(&qwenResp); err != nil {
		logError("[reqID:%s] 解析响应失败: %v", reqID, err)
		http.Error(w, "解析响应失败", http.StatusInternalServerError)
		return
	}

	// The task ID is taken from the first assistant message that carries one.
	taskID := ""
	for _, msg := range qwenResp.Messages {
		if msg.Role == "assistant" && msg.Extra.Wanx.TaskID != "" {
			taskID = msg.Extra.Wanx.TaskID
			break
		}
	}
	if taskID == "" {
		logError("[reqID:%s] 无法获取图像生成任务ID", reqID)
		http.Error(w, "无法获取图像生成任务ID", http.StatusInternalServerError)
		return
	}

	// Poll until the task reports content. The pause between attempts is a
	// select on the request context rather than time.Sleep, so a cancelled
	// request stops waiting immediately instead of blocking for a full interval.
	statusURL := TasksURL + taskID
	var imageURL string
	for attempt := 0; attempt < maxPollAttempts && imageURL == ""; attempt++ {
		if attempt > 0 {
			timer := time.NewTimer(pollInterval)
			select {
			case <-ctx.Done():
			case <-timer.C:
			}
			timer.Stop()
		}
		if ctx.Err() != nil {
			logWarn("[reqID:%s] 请求超时或被客户端取消", reqID)
			http.Error(w, "请求超时", http.StatusGatewayTimeout)
			return
		}

		// Query the task status.
		statusReq, err := http.NewRequestWithContext(ctx, "GET", statusURL, nil)
		if err != nil {
			logError("[reqID:%s] 创建状态请求失败: %v", reqID, err)
			continue
		}
		statusReq.Header.Set("Authorization", "Bearer "+authToken)
		statusReq.Header.Set("User-Agent", "Mozilla/5.0")

		statusResp, err := client.Do(statusReq)
		if err != nil {
			logError("[reqID:%s] 发送状态请求失败: %v", reqID, err)
			continue
		}

		// Decode, then close right away: a defer here would keep every poll
		// response open until the handler returns.
		var statusData TaskStatusResponse
		err = json.NewDecoder(statusResp.Body).Decode(&statusData)
		statusResp.Body.Close()
		if err != nil {
			logError("[reqID:%s] 解析状态响应失败: %v", reqID, err)
			continue
		}

		// Non-empty content ends the loop via the loop condition.
		imageURL = statusData.Content
	}
	if imageURL == "" {
		logError("[reqID:%s] 图像生成超时", reqID)
		http.Error(w, "图像生成超时", http.StatusGatewayTimeout)
		return
	}

	// Build the image list: the same URL repeated N times.
	images := make([]ImageURL, imgReq.N)
	for i := range images {
		images[i] = ImageURL{URL: imageURL}
	}

	// Send the response.
	imgResp := ImagesResponse{
		Created: time.Now().Unix(),
		Data:    images,
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(imgResp); err != nil {
		logError("[reqID:%s] 写入响应失败: %v", reqID, err)
	}
}
// handleDrawRequest serves a chat-completion request that targets a drawing
// model variant.
//
// The prompt is the content of the last message in apiReq.Messages, which must
// be a plain string. The handler submits a "t2i" request to TargetURL, polls
// TasksURL+taskID until the task status carries non-empty content (used as the
// image URL), and replies with a CompletionResponse whose assistant message
// embeds the image as Markdown.
//
// Parameters:
//   - w, r:      the HTTP response writer and incoming request.
//   - apiReq:    the already-decoded chat request; Model and Messages are read.
//   - reqID:     request identifier used as a log prefix.
//   - authToken: bearer token forwarded to the upstream API.
//
// Failures are reported to the client with http.Error: 400 for an empty
// prompt, 502 when the upstream cannot be reached, the upstream status code
// when it is not 200, 500 for (de)serialisation problems or a missing task ID,
// and 504 when the request context ends or the task does not finish within the
// polling budget.
func handleDrawRequest(w http.ResponseWriter, r *http.Request, apiReq APIRequest, reqID string, authToken string) {
	const (
		// Polling budget: at most maxPollAttempts status checks, spaced
		// pollInterval apart.
		maxPollAttempts = 30
		pollInterval    = 6 * time.Second
	)

	logInfo("[reqID:%s] 处理绘图请求", reqID)
	ctx := r.Context()

	// Take the drawing prompt from the last message; non-string content
	// leaves the prompt empty.
	var prompt string
	if len(apiReq.Messages) > 0 {
		lastMsg := apiReq.Messages[len(apiReq.Messages)-1]
		prompt, _ = lastMsg.Content.(string)
	}
	if prompt == "" {
		logError("[reqID:%s] 绘图提示为空", reqID)
		http.Error(w, "绘图提示为空", http.StatusBadRequest)
		return
	}

	// Prepare the drawing parameters; the plain model name is obtained by
	// removing the first occurrence of each variant suffix.
	size := "1024*1024"
	modelName := strings.Replace(apiReq.Model, "-draw", "", 1)
	modelName = strings.Replace(modelName, "-thinking", "", 1)
	modelName = strings.Replace(modelName, "-search", "", 1)

	// Build the drawing request.
	qwenReq := QwenRequest{
		Stream:            false,
		IncrementalOutput: true,
		ChatType:          "t2i",
		Model:             modelName,
		Messages: []APIMessage{
			{
				Role:     "user",
				Content:  prompt,
				ChatType: "t2i",
				Extra:    map[string]interface{}{},
				FeatureConfig: map[string]interface{}{
					"thinking_enabled": false,
				},
			},
		},
		ID:   generateUUID(),
		Size: size,
	}

	// Serialise the request.
	reqData, err := json.Marshal(qwenReq)
	if err != nil {
		logError("[reqID:%s] 序列化请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}

	// Create the upstream HTTP request, bound to the client's context.
	req, err := http.NewRequestWithContext(ctx, "POST", TargetURL, bytes.NewBuffer(reqData))
	if err != nil {
		logError("[reqID:%s] 创建请求失败: %v", reqID, err)
		http.Error(w, "内部服务器错误", http.StatusInternalServerError)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Mozilla/5.0")

	// Send the request.
	client := getHTTPClient()
	resp, err := client.Do(req)
	if err != nil {
		logError("[reqID:%s] 发送请求失败: %v", reqID, err)
		http.Error(w, "连接到API失败", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// Relay any non-200 upstream status to the client.
	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body) // best effort: body is only used for logging
		logError("[reqID:%s] API返回非200状态码: %d, 响应: %s", reqID, resp.StatusCode, string(bodyBytes))
		http.Error(w, fmt.Sprintf("API错误状态码: %d", resp.StatusCode), resp.StatusCode)
		return
	}

	// Decode the response to obtain the task ID.
	var qwenResp QwenResponse
	if err := json.NewDecoder(resp.Body).Decode(&qwenResp); err != nil {
		logError("[reqID:%s] 解析响应失败: %v", reqID, err)
		http.Error(w, "解析响应失败", http.StatusInternalServerError)
		return
	}

	// The task ID is taken from the first assistant message that carries one.
	taskID := ""
	for _, msg := range qwenResp.Messages {
		if msg.Role == "assistant" && msg.Extra.Wanx.TaskID != "" {
			taskID = msg.Extra.Wanx.TaskID
			break
		}
	}
	if taskID == "" {
		logError("[reqID:%s] 无法获取图像生成任务ID", reqID)
		http.Error(w, "无法获取图像生成任务ID", http.StatusInternalServerError)
		return
	}

	// Poll until the task reports content. The pause between attempts is a
	// select on the request context rather than time.Sleep, so a cancelled
	// request stops waiting immediately instead of blocking for a full interval.
	statusURL := TasksURL + taskID
	var imageURL string
	for attempt := 0; attempt < maxPollAttempts && imageURL == ""; attempt++ {
		if attempt > 0 {
			timer := time.NewTimer(pollInterval)
			select {
			case <-ctx.Done():
			case <-timer.C:
			}
			timer.Stop()
		}
		if ctx.Err() != nil {
			logWarn("[reqID:%s] 请求超时或被客户端取消", reqID)
			http.Error(w, "请求超时", http.StatusGatewayTimeout)
			return
		}

		// Query the task status.
		statusReq, err := http.NewRequestWithContext(ctx, "GET", statusURL, nil)
		if err != nil {
			logError("[reqID:%s] 创建状态请求失败: %v", reqID, err)
			continue
		}
		statusReq.Header.Set("Authorization", "Bearer "+authToken)
		statusReq.Header.Set("User-Agent", "Mozilla/5.0")

		statusResp, err := client.Do(statusReq)
		if err != nil {
			logError("[reqID:%s] 发送状态请求失败: %v", reqID, err)
			continue
		}

		// Decode, then close right away: a defer here would keep every poll
		// response open until the handler returns.
		var statusData TaskStatusResponse
		err = json.NewDecoder(statusResp.Body).Decode(&statusData)
		statusResp.Body.Close()
		if err != nil {
			logError("[reqID:%s] 解析状态响应失败: %v", reqID, err)
			continue
		}

		// Non-empty content ends the loop via the loop condition.
		imageURL = statusData.Content
	}
	if imageURL == "" {
		logError("[reqID:%s] 图像生成超时", reqID)
		http.Error(w, "图像生成超时", http.StatusGatewayTimeout)
		return
	}

	// Reply in chat-completion format with the image embedded as Markdown.
	// The usage figures are fixed placeholder values.
	completionResponse := CompletionResponse{
		ID:      fmt.Sprintf("chatcmpl-%s", generateUUID()),
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   apiReq.Model,
		Choices: []struct {
			Index   int `json:"index"`
			Message struct {
				Role    string `json:"role"`
				Content string `json:"content"`
			} `json:"message"`
			FinishReason string `json:"finish_reason"`
		}{
			{
				Index: 0,
				Message: struct {
					Role    string `json:"role"`
					Content string `json:"content"`
				}{
					Role:    "assistant",
					Content: fmt.Sprintf("![%s](%s)", imageURL, imageURL),
				},
				FinishReason: "stop",
			},
		},
		Usage: struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		}{
			PromptTokens:     1024,
			CompletionTokens: 1024,
			TotalTokens:      2048,
		},
	}

	// Send the response.
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(completionResponse); err != nil {
		logError("[reqID:%s] 写入响应失败: %v", reqID, err)
	}
}
// extractFullContentFromStream reads a line-oriented event stream from body
// and concatenates the Delta.Content of every choice found in its "data: "
// lines.
//
// Lines without the "data: " prefix are skipped, the "[DONE]" marker is
// logged and skipped, and payloads that fail to decode as a QwenResponse are
// logged and skipped. A final line that is not newline-terminated is still
// processed, and a trailing "\r" is removed so CRLF streams parse cleanly.
//
// It returns the accumulated content. On a read error other than io.EOF it
// returns the content gathered so far together with that error. body is not
// closed here.
func extractFullContentFromStream(body io.ReadCloser, reqID string) (string, error) {
	var contentBuilder strings.Builder
	reader := bufio.NewReaderSize(body, 16384)

	// appendDelta handles one raw line of the stream.
	appendDelta := func(rawLine string) {
		line := strings.TrimSuffix(rawLine, "\n")
		line = strings.TrimSuffix(line, "\r")
		if !strings.HasPrefix(line, "data: ") {
			return
		}
		dataStr := strings.TrimPrefix(line, "data: ")

		// End-of-stream marker: nothing to accumulate.
		if dataStr == "[DONE]" {
			logDebug("[reqID:%s] 非流式模式收到[DONE]消息", reqID)
			return
		}

		var qwenResp QwenResponse
		if err := json.Unmarshal([]byte(dataStr), &qwenResp); err != nil {
			logWarn("[reqID:%s] 解析JSON失败: %v, data: %s", reqID, err, dataStr)
			return
		}

		// Accumulate every delta fragment.
		for _, choice := range qwenResp.Choices {
			if choice.Delta.Content != "" {
				contentBuilder.WriteString(choice.Delta.Content)
			}
		}
	}

	for {
		// ReadString returns whatever was read before an error, so at EOF the
		// last, possibly unterminated, line is still delivered here.
		rawLine, readErr := reader.ReadString('\n')
		if readErr != nil && readErr != io.EOF {
			// A hard read error: the partial line is likely truncated, so it
			// is not parsed.
			return contentBuilder.String(), readErr
		}
		appendDelta(rawLine)
		if readErr == io.EOF {
			break
		}
	}

	// Log the length of the extracted content.
	contentStr := contentBuilder.String()
	logInfo("[reqID:%s] 非流式模式:成功提取完整内容,长度: %d", reqID, len(contentStr))
	return contentStr, nil
}
// uploadImage uploads an image supplied as a data URI to FilesURL and returns
// the file ID reported by the API.
//
// base64Data must start with "data:" and contain a comma; everything after
// the first comma is decoded as standard base64. The decoded bytes are sent
// as the "file" field of a multipart form, under a generated
// "image-<nanoseconds>.jpg" file name, authorised with authToken.
//
// It returns an error when the input is malformed, the request cannot be
// built or sent, the API answers with a non-200 status, the response cannot
// be decoded, or the response carries no file ID. Underlying errors are
// wrapped with %w so callers can inspect them with errors.Is / errors.As.
func uploadImage(base64Data string, authToken string) (string, error) {
	// Split the data URI into its header and base64 payload.
	if !strings.HasPrefix(base64Data, "data:") {
		return "", fmt.Errorf("invalid base64 data format")
	}
	parts := strings.SplitN(base64Data, ",", 2)
	if len(parts) != 2 {
		return "", fmt.Errorf("invalid base64 data format")
	}
	imageData, err := base64.StdEncoding.DecodeString(parts[1])
	if err != nil {
		return "", fmt.Errorf("failed to decode base64 data: %w", err)
	}

	// Build the multipart form with a single file part.
	var body bytes.Buffer
	writer := multipart.NewWriter(&body)
	part, err := writer.CreateFormFile("file", fmt.Sprintf("image-%d.jpg", time.Now().UnixNano()))
	if err != nil {
		return "", fmt.Errorf("failed to create form file: %w", err)
	}
	if _, err := part.Write(imageData); err != nil {
		return "", fmt.Errorf("failed to write image data: %w", err)
	}
	// Close writes the terminating boundary; it must happen before sending.
	if err := writer.Close(); err != nil {
		return "", fmt.Errorf("failed to close writer: %w", err)
	}

	// Create the HTTP request.
	req, err := http.NewRequest("POST", FilesURL, &body)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", writer.FormDataContentType())
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Mozilla/5.0")

	// Send the request.
	client := getHTTPClient()
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	// Check the response status.
	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body) // best effort: body is only used in the error text
		return "", fmt.Errorf("API returned non-200 status code: %d, response: %s", resp.StatusCode, string(bodyBytes))
	}

	// Decode the response.
	var uploadResp FileUploadResponse
	if err := json.NewDecoder(resp.Body).Decode(&uploadResp); err != nil {
		return "", fmt.Errorf("failed to parse response: %w", err)
	}
	// A 200 response without an ID is not a usable upload; report it rather
	// than handing the caller an empty identifier.
	if uploadResp.ID == "" {
		return "", fmt.Errorf("upload response did not contain a file id")
	}
	return uploadResp.ID, nil
}
// createRoleChunk returns the JSON encoding of a "chat.completion.chunk"
// StreamChunk with a single choice at index 0 whose delta carries only the
// role "assistant". id, created and model are copied into the chunk as-is.
func createRoleChunk(id string, created int64, model string) []byte {
	// Alias for the anonymous element type of StreamChunk.Choices, so the
	// choice can be filled in field by field.
	type chunkChoice = struct {
		Index int `json:"index"`
		Delta struct {
			Role    string `json:"role,omitempty"`
			Content string `json:"content,omitempty"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason,omitempty"`
	}

	var roleChoice chunkChoice
	roleChoice.Index = 0
	roleChoice.Delta.Role = "assistant"

	var out StreamChunk
	out.ID = id
	out.Object = "chat.completion.chunk"
	out.Created = created
	out.Model = model
	out.Choices = []chunkChoice{roleChoice}

	// The marshal error is deliberately ignored, as in the rest of the chunk
	// builders.
	encoded, _ := json.Marshal(out)
	return encoded
}
// createContentChunk returns the JSON encoding of a "chat.completion.chunk"
// StreamChunk with a single choice at index 0 whose delta carries content.
// id, created and model are copied into the chunk as-is.
func createContentChunk(id string, created int64, model string, content string) []byte {
	// Alias for the anonymous element type of StreamChunk.Choices, so the
	// choice can be filled in field by field.
	type chunkChoice = struct {
		Index int `json:"index"`
		Delta struct {
			Role    string `json:"role,omitempty"`
			Content string `json:"content,omitempty"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason,omitempty"`
	}

	var textChoice chunkChoice
	textChoice.Index = 0
	textChoice.Delta.Content = content

	var out StreamChunk
	out.ID = id
	out.Object = "chat.completion.chunk"
	out.Created = created
	out.Model = model
	out.Choices = []chunkChoice{textChoice}

	// The marshal error is deliberately ignored, as in the rest of the chunk
	// builders.
	encoded, _ := json.Marshal(out)
	return encoded
}
// createDoneChunk returns the JSON encoding of a "chat.completion.chunk"
// StreamChunk with a single choice at index 0 that has an empty delta and its
// finish reason set to reason. id, created and model are copied into the chunk
// as-is.
func createDoneChunk(id string, created int64, model string, reason string) []byte {
	// Alias for the anonymous element type of StreamChunk.Choices, so the
	// choice can be filled in field by field.
	type chunkChoice = struct {
		Index int `json:"index"`
		Delta struct {
			Role    string `json:"role,omitempty"`
			Content string `json:"content,omitempty"`
		} `json:"delta"`
		FinishReason *string `json:"finish_reason,omitempty"`
	}

	var finalChoice chunkChoice
	finalChoice.Index = 0
	// FinishReason is a pointer field; point it at the local parameter copy.
	finalChoice.FinishReason = &reason

	var out StreamChunk
	out.ID = id
	out.Object = "chat.completion.chunk"
	out.Created = created
	out.Model = model
	out.Choices = []chunkChoice{finalChoice}

	// The marshal error is deliberately ignored, as in the rest of the chunk
	// builders.
	encoded, _ := json.Marshal(out)
	return encoded
}
// estimateTokens returns a rough token estimate for messages (simple
// implementation): one token per four bytes of text, rounded down per text
// fragment.
//
// String content is counted directly. For list content, only elements that
// are maps with a string "text" entry are counted. Any other content shape
// contributes nothing.
func estimateTokens(messages []APIMessage) int {
	estimated := 0
	for idx := range messages {
		// Plain-text message.
		if text, isText := messages[idx].Content.(string); isText {
			estimated += len(text) / 4
			continue
		}

		// Multi-part message: sum the text parts.
		parts, isList := messages[idx].Content.([]interface{})
		if !isList {
			continue
		}
		for _, part := range parts {
			fields, isObject := part.(map[string]interface{})
			if !isObject {
				continue
			}
			if text, hasText := fields["text"].(string); hasText {
				estimated += len(text) / 4
			}
		}
	}
	return estimated
}