add: enhance Novita AI API response structure, implement token usage tracking, and handle error situations

This commit is contained in:
dashen
2024-12-30 11:35:08 +08:00
parent 49f54f45ba
commit c444b0400d
2 changed files with 263 additions and 5 deletions

View File

@@ -74,17 +74,43 @@ func (c *Chat) ResponseConvert(ctx eocontext.EoContext) error {
if err != nil {
return err
}
if httpContext.Response().StatusCode() != 200 {
return nil
}
body := httpContext.Response().GetBody()
log.Println(string(body))
log.Println(httpContext.Response().StatusCode())
data := eosc.NewBase[Response]()
err = json.Unmarshal(body, data)
if err != nil {
return err
}
// 401 INVALID_API_KEY The API key is invalid. You can check your API key here: Manage API Key
// 403 NOT_ENOUGH_BALANCE Your credit is not enough. You can top up more credit here: Top Up Credit
// 404 MODEL_NOT_FOUND The requested model is not found. You can find all the models we support here: https://novita.ai/llm-api or request the Models API to get all available models.
// 429 RATE_LIMIT_EXCEEDED You have exceeded the rate limit. Please refer to Rate Limits for more information.
// 500 MODEL_NOT_AVAILABLE The requested model is not available now. This is usually due to the model being under maintenance. You can contact us on Discord for more information.
switch httpContext.Response().StatusCode() {
case 200:
// Calculate the token consumption for a successful request.
usage := data.Config.Usage
ai_provider.SetAIStatusNormal(ctx)
ai_provider.SetAIModelInputToken(ctx, usage.PromptTokens)
ai_provider.SetAIModelOutputToken(ctx, usage.CompletionTokens)
ai_provider.SetAIModelTotalToken(ctx, usage.TotalTokens)
case 400:
// Handle the bad request error.
ai_provider.SetAIStatusInvalidRequest(ctx)
case 401:
// Handle the key error.
ai_provider.SetAIStatusInvalid(ctx)
case 403:
// handle credit is exhausted
ai_provider.SetAIStatusQuotaExhausted(ctx)
case 429:
// Handle the rate limit error.
ai_provider.SetAIStatusExceeded(ctx)
default:
ai_provider.SetAIStatusInvalidRequest(ctx)
}
responseBody := &ai_provider.ClientResponse{}
if len(data.Config.Choices) > 0 {
msg := data.Config.Choices[0]

View File

@@ -0,0 +1,232 @@
package novita
import (
"fmt"
"net/url"
"os"
"testing"
"time"
"github.com/eolinker/eosc/eocontext"
"github.com/eolinker/apinto/convert"
ai_provider "github.com/eolinker/apinto/drivers/ai-provider"
http_context "github.com/eolinker/apinto/node/http-context"
"github.com/joho/godotenv"
"github.com/valyala/fasthttp"
)
var (
	// defaultConfig is the extender configuration handed to the model
	// function in executeConverter. Values are JSON strings, not numbers;
	// an empty string leaves that parameter unset — presumably the config
	// parser coerces "0.1"/512 as needed; TODO confirm against fn(defaultConfig).
	defaultConfig = `{
"frequency_penalty": "",
"max_tokens": 512,
"presence_penalty": "",
"response_format": "",
"temperature": "",
"top_p": "0.1"
}`
	// successBody is a well-formed chat request expected to return 200.
	successBody = []byte(`{
"messages": [
{
"content": "Hello, how can I help you?",
"role": "user"
}
]
}`)
	// failBody carries an invalid role ("yyy") and top_p so the provider
	// rejects it, driving the invalid-request path.
	failBody = []byte(`{
"messages": [
{
"content": "Hello, how can I help you?",
"role": "yyy"
}
],"top_p":"0.0"
}`)
)
// validNormalFunc prints the token counters recorded on the context and
// reports whether the input, output, and total counts are all positive.
func validNormalFunc(ctx eocontext.EoContext) bool {
	input := ai_provider.GetAIModelInputToken(ctx)
	output := ai_provider.GetAIModelOutputToken(ctx)
	total := ai_provider.GetAIModelTotalToken(ctx)
	fmt.Printf("input token: %d\n", input)
	fmt.Printf("output token: %d\n", output)
	fmt.Printf("total token: %d\n", total)
	return input > 0 && output > 0 && total > 0
}
// TestSentTo tests the end-to-end execution of the novita integration.
func TestSentTo(t *testing.T) {
	// The .env file supplies the API keys referenced by the cases below.
	if err := godotenv.Load(".env"); err != nil {
		t.Fatalf("Error loading .env file: %v", err)
	}
	// Each case pairs an API key and request body with the status the
	// provider driver is expected to record on the context.
	cases := []struct {
		name       string
		apiKey     string
		wantStatus string
		body       []byte
		validFunc  func(ctx eocontext.EoContext) bool
	}{
		{
			name:       "success",
			apiKey:     os.Getenv("ValidKey"),
			wantStatus: ai_provider.StatusNormal,
			body:       successBody,
			validFunc:  validNormalFunc,
		},
		{
			name:       "invalid request",
			apiKey:     os.Getenv("ValidKey"),
			wantStatus: ai_provider.StatusInvalidRequest,
			body:       failBody,
		},
		{
			name:       "invalid key",
			apiKey:     os.Getenv("InvalidKey"),
			wantStatus: ai_provider.StatusInvalid,
		},
		{
			name:       "expired key",
			apiKey:     os.Getenv("ExpiredKey"),
			wantStatus: ai_provider.StatusInvalid,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if err := runTest(tc.apiKey, tc.body, tc.wantStatus, tc.validFunc); err != nil {
				t.Fatalf("Test failed: %v", err)
			}
		})
	}
}
// runTest handles a single test case
func runTest(apiKey string, requestBody []byte, wantStatus string, validFunc func(ctx eocontext.EoContext) bool) error {
	// Build the worker from the scenario's API key.
	worker, err := Create("novita", "novita", &Config{
		APIKey:       apiKey,
		Organization: "",
	}, nil)
	if err != nil {
		return fmt.Errorf("failed to create worker: %w", err)
	}
	// The worker must expose the converter interface to be usable here.
	handler, ok := worker.(convert.IConverterDriver)
	if !ok {
		return fmt.Errorf("worker does not implement IConverterDriver")
	}
	// Fall back to the success payload when the case supplies no body.
	if len(requestBody) == 0 {
		requestBody = successBody
	}
	// Drive the full request/response conversion through a mock context.
	ctx := createMockHttpContext("/xxx/xxx", nil, nil, requestBody)
	if err := executeConverter(ctx, handler, "meta-llama/llama-3.1-8b-instruct", "https://api.novita.ai"); err != nil {
		return fmt.Errorf("failed to execute conversion process: %w", err)
	}
	// The driver records the outcome as an AI status on the context.
	if got := ai_provider.GetAIStatus(ctx); got != wantStatus {
		return fmt.Errorf("unexpected status: got %s, expected %s", got, wantStatus)
	}
	// Optional extra validation (e.g. token counters) for success cases.
	if validFunc != nil && !validFunc(ctx) {
		return fmt.Errorf("execute validFunc failed")
	}
	return nil
}
// executeConverter handles the full flow of a conversion process.
func executeConverter(ctx *http_context.HttpContext, handler convert.IConverterDriver, model string, baseUrl string) error {
	// A balance handler picks the upstream node and carries the timeout.
	lb, err := ai_provider.NewBalanceHandler("test", baseUrl, 30*time.Second)
	if err != nil {
		return fmt.Errorf("failed to create balance handler: %w", err)
	}
	// Resolve the per-model config function and build the extender from it.
	modelFn, has := handler.GetModel(model)
	if !has {
		return fmt.Errorf("model %s not found", model)
	}
	extender, err := modelFn(defaultConfig)
	if err != nil {
		return fmt.Errorf("failed to generate config: %w", err)
	}
	// Resolve the request/response converter for the same model.
	conv, has := handler.GetConverter(model)
	if !has {
		return fmt.Errorf("converter for model %s not found", model)
	}
	// Rewrite the outbound request into the provider's format.
	if err := conv.RequestConvert(ctx, extender); err != nil {
		return fmt.Errorf("request conversion failed: %w", err)
	}
	// Pick a node and send the converted request upstream.
	node, _, err := lb.Select(ctx)
	if err != nil {
		return fmt.Errorf("node selection failed: %w", err)
	}
	if err := ctx.SendTo(lb.Scheme(), node, lb.TimeOut()); err != nil {
		return fmt.Errorf("failed to send request to node: %w", err)
	}
	// Translate the provider response back and record token/status data.
	if err := conv.ResponseConvert(ctx); err != nil {
		return fmt.Errorf("response conversion failed: %w", err)
	}
	return nil
}
// createMockHttpContext creates a mock fasthttp.RequestCtx and wraps it with
// HttpContext. rawURL supplies path/scheme/host/query; headers are copied onto
// the request; body becomes the POST payload.
// NOTE(review): the query parameter is currently unused — confirm whether it
// should be merged into the request URI.
func createMockHttpContext(rawURL string, headers map[string]string, query url.Values, body []byte) *http_context.HttpContext {
	req := fasthttp.AcquireRequest()
	// Safe to release: the request is deep-copied into the RequestCtx below,
	// so returning it to the pool cannot corrupt the returned context.
	// (The original code leaked both pooled objects and shallow-copied the
	// request struct, leaving the context aliasing pooled buffers.)
	defer fasthttp.ReleaseRequest(req)
	u := fasthttp.AcquireURI()
	// Safe to release: Request.SetURI copies the URI into the request.
	defer fasthttp.ReleaseURI(u)

	// rawURL comes from test code; on a parse failure we simply leave the
	// URI fields empty instead of dereferencing a nil *url.URL.
	if uri, err := url.Parse(rawURL); err == nil {
		u.SetPath(uri.Path)
		u.SetScheme(uri.Scheme)
		u.SetHost(uri.Host)
		u.SetQueryString(uri.RawQuery)
	}
	req.SetURI(u)
	req.Header.SetMethod("POST")
	// Set headers
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	req.SetBody(body)

	// Deep-copy the request so the pooled instance can be recycled.
	reqCtx := &fasthttp.RequestCtx{}
	req.CopyTo(&reqCtx.Request)
	return http_context.NewContext(reqCtx, 8099)
}