diff --git a/drivers/ai-provider/openrouter/message.go b/drivers/ai-provider/openrouter/message.go index abfeb7cb..d0897c13 100644 --- a/drivers/ai-provider/openrouter/message.go +++ b/drivers/ai-provider/openrouter/message.go @@ -16,6 +16,7 @@ type Response struct { Model string `json:"model"` Choices []ResponseChoice `json:"choices"` Usage Usage `json:"usage"` + Error *Error `json:"error"` } type ResponseChoice struct { @@ -33,3 +34,14 @@ type Usage struct { type CompletionTokensDetails struct { ReasoningTokens int `json:"reasoning_tokens"` } + +// Error represents the error response from the provider. +// {"error":{"message":"Provider returned error","code":400,"metadata":{"raw":"{\n \"error\": {\n \"message\": \"Invalid value: 'yyy'. Supported values are: 'system', 'assistant', 'user', 'function', 'tool', and 'developer'.\",\n \"type\": \"invalid_request_error\",\n \"param\": \"messages[0].role\",\n \"code\": \"invalid_value\"\n }\n}","provider_name":"OpenAI"}},"user_id":"user_2nQFDPHnNOxsrry6JpmcPXFzfnC"} +type Error struct { + Message string `json:"message"` + Code int `json:"code"` + Metadata struct { + Raw string `json:"raw"` + ProviderName string `json:"provider_name"` + } `json:"metadata"` +} diff --git a/drivers/ai-provider/openrouter/mode.go b/drivers/ai-provider/openrouter/mode.go index cf943ab8..a54b9902 100644 --- a/drivers/ai-provider/openrouter/mode.go +++ b/drivers/ai-provider/openrouter/mode.go @@ -74,15 +74,61 @@ func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error { if err != nil { return err } - if httpContext.Response().StatusCode() != 200 { - return nil - } + body := httpContext.Response().GetBody() data := eosc.NewBase[Response]() err = json.Unmarshal(body, data) if err != nil { return err } + + // 400: Bad Request (invalid or missing params, CORS) + // 401: Invalid credentials (OAuth session expired, disabled/invalid API key) + // 402: Your account or API key has insufficient credits. Add more credits and retry the request. + // 403: Your chosen model requires moderation and your input was flagged + // 408: Your request timed out + // 429: You are being rate limited + // 502: Your chosen model is down or we received an invalid response from it + // 503: There is no available model provider that meets your routing requirements + switch httpContext.Response().StatusCode() { + case 200: + if data.Config.Error != nil { + // Handle the error response. + switch data.Config.Error.Code { + case 400: + ai_provider.SetAIStatusInvalidRequest(ctx) + case 401: + ai_provider.SetAIStatusInvalid(ctx) + case 402: + ai_provider.SetAIStatusQuotaExhausted(ctx) + case 429: + ai_provider.SetAIStatusExceeded(ctx) + default: + ai_provider.SetAIStatusInvalidRequest(ctx) + } + } else { + // Calculate the token consumption for a successful request. + usage := data.Config.Usage + ai_provider.SetAIStatusNormal(ctx) + ai_provider.SetAIModelInputToken(ctx, usage.PromptTokens) + ai_provider.SetAIModelOutputToken(ctx, usage.CompletionTokens) + ai_provider.SetAIModelTotalToken(ctx, usage.TotalTokens) + } + case 400: + // Handle the bad request error. + ai_provider.SetAIStatusInvalidRequest(ctx) + case 401: + // Handle the invalid key error. + ai_provider.SetAIStatusInvalid(ctx) + case 402: + // Handle the expired key error. + ai_provider.SetAIStatusQuotaExhausted(ctx) + case 429: + ai_provider.SetAIStatusExceeded(ctx) + default: + ai_provider.SetAIStatusInvalidRequest(ctx) + } + responseBody := &ai_provider.ClientResponse{} if len(data.Config.Choices) > 0 { msg := data.Config.Choices[0] diff --git a/drivers/ai-provider/openrouter/openrouter_test.go b/drivers/ai-provider/openrouter/openrouter_test.go new file mode 100644 index 00000000..88b8c96d --- /dev/null +++ b/drivers/ai-provider/openrouter/openrouter_test.go @@ -0,0 +1,232 @@ +package openrouter + +import ( + "fmt" + "net/url" + "os" + "testing" + "time" + + "github.com/eolinker/eosc/eocontext" + + "github.com/eolinker/apinto/convert" + ai_provider "github.com/eolinker/apinto/drivers/ai-provider" + http_context "github.com/eolinker/apinto/node/http-context" + "github.com/joho/godotenv" + "github.com/valyala/fasthttp" +) + +var ( + defaultConfig = `{ + "frequency_penalty": "", + "max_tokens": 512, + "presence_penalty": "", + "response_format": "", + "temperature": "", + "top_p": "" + }` + successBody = []byte(`{ + "messages": [ + { + "content": "Hello, how can I help you?", + "role": "assistant" + } + ] + }`) + failBody = []byte(`{ + "messages": [ + { + "content": "Hello, how can I help you?", + "role": "yyy" + } + ],"variables":{} + }`) +) + +func validNormalFunc(ctx eocontext.EoContext) bool { + fmt.Printf("input token: %d\n", ai_provider.GetAIModelInputToken(ctx)) + fmt.Printf("output token: %d\n", ai_provider.GetAIModelOutputToken(ctx)) + fmt.Printf("total token: %d\n", ai_provider.GetAIModelTotalToken(ctx)) + if ai_provider.GetAIModelInputToken(ctx) <= 0 { + return false + } + if ai_provider.GetAIModelOutputToken(ctx) <= 0 { + return false + } + return ai_provider.GetAIModelTotalToken(ctx) > 0 +} + +// TestSentTo tests the end-to-end execution of the openrouter integration. +func TestSentTo(t *testing.T) { + // Load .env file + err := godotenv.Load(".env") + if err != nil { + t.Fatalf("Error loading .env file: %v", err) + } + + // Test data for different scenarios + testData := []struct { + name string + apiKey string + wantStatus string + body []byte + validFunc func(ctx eocontext.EoContext) bool + }{ + { + name: "success", + apiKey: os.Getenv("ValidKey"), + wantStatus: ai_provider.StatusNormal, + body: successBody, + validFunc: validNormalFunc, + }, + { + name: "invalid request", + apiKey: os.Getenv("ValidKey"), + wantStatus: ai_provider.StatusInvalidRequest, + body: failBody, + }, + { + name: "invalid key", + apiKey: os.Getenv("InvalidKey"), + wantStatus: ai_provider.StatusInvalid, + }, + { + name: "expired key", + apiKey: os.Getenv("ExpiredKey"), + wantStatus: ai_provider.StatusInvalid, + }, + } + + // Run tests for each scenario + for _, data := range testData { + t.Run(data.name, func(t *testing.T) { + if err := runTest(data.apiKey, data.body, data.wantStatus, data.validFunc); err != nil { + t.Fatalf("Test failed: %v", err) + } + }) + } +} + +// runTest handles a single test case +func runTest(apiKey string, requestBody []byte, wantStatus string, validFunc func(ctx eocontext.EoContext) bool) error { + cfg := &Config{ + APIKey: apiKey, + Organization: "", + } + + // Create the worker + worker, err := Create("openrouter", "openrouter", cfg, nil) + if err != nil { + return fmt.Errorf("failed to create worker: %w", err) + } + + // Get the handler + handler, ok := worker.(convert.IConverterDriver) + if !ok { + return fmt.Errorf("worker does not implement IConverterDriver") + } + + // Default to success body if no body is provided + if len(requestBody) == 0 { + requestBody = successBody + } + + // Mock HTTP context + ctx := createMockHttpContext("/xxx/xxx", nil, nil, requestBody) + + // Execute the conversion process + err = executeConverter(ctx, handler, "openai/gpt-3.5-turbo", "https://openrouter.ai") + if err != nil { + return fmt.Errorf("failed to execute conversion process: %w", err) + } + + // Check the status + if ai_provider.GetAIStatus(ctx) != wantStatus { + return fmt.Errorf("unexpected status: got %s, expected %s", ai_provider.GetAIStatus(ctx), wantStatus) + } + if validFunc != nil { + if validFunc(ctx) { + return nil + } + return fmt.Errorf("execute validFunc failed") + } + + return nil +} + +// executeConverter handles the full flow of a conversion process. +func executeConverter(ctx *http_context.HttpContext, handler convert.IConverterDriver, model string, baseUrl string) error { + // Balance handler setup + balanceHandler, err := ai_provider.NewBalanceHandler("test", baseUrl, 30*time.Second) + if err != nil { + return fmt.Errorf("failed to create balance handler: %w", err) + } + + // Get model function + fn, has := handler.GetModel(model) + if !has { + return fmt.Errorf("model %s not found", model) + } + + // Generate config + extender, err := fn(defaultConfig) + if err != nil { + return fmt.Errorf("failed to generate config: %w", err) + } + + // Get converter + converter, has := handler.GetConverter(model) + if !has { + return fmt.Errorf("converter for model %s not found", model) + } + + // Convert request + if err := converter.RequestConvert(ctx, extender); err != nil { + return fmt.Errorf("request conversion failed: %w", err) + } + + // Select node via balance handler + node, _, err := balanceHandler.Select(ctx) + if err != nil { + return fmt.Errorf("node selection failed: %w", err) + } + + // Send request to the node + if err := ctx.SendTo(balanceHandler.Scheme(), node, balanceHandler.TimeOut()); err != nil { + return fmt.Errorf("failed to send request to node: %w", err) + } + + // Convert response + if err := converter.ResponseConvert(ctx); err != nil { + return fmt.Errorf("response conversion failed: %w", err) + } + + return nil +} + +// createMockHttpContext creates a mock fasthttp.RequestCtx and wraps it with HttpContext. +func createMockHttpContext(rawURL string, headers map[string]string, query url.Values, body []byte) *http_context.HttpContext { + req := fasthttp.AcquireRequest() + u := fasthttp.AcquireURI() + + // Set request URI and path + uri, _ := url.Parse(rawURL) + u.SetPath(uri.Path) + u.SetScheme(uri.Scheme) + u.SetHost(uri.Host) + u.SetQueryString(uri.RawQuery) + req.SetURI(u) + req.Header.SetMethod("POST") + + // Set headers + for k, v := range headers { + req.Header.Set(k, v) + } + req.SetBody(body) + + // Create HttpContext + return http_context.NewContext(&fasthttp.RequestCtx{ + Request: *req, + Response: fasthttp.Response{}, + }, 8099) +}