[BugFix] fix prompt token ids type (#2994)

* Update serving_completion.py

* fix

* fix
This commit is contained in:
ltd0924
2025-07-23 21:00:56 +08:00
committed by GitHub
parent 5d1788c7b5
commit fb0f284e67
2 changed files with 8 additions and 4 deletions

View File

@@ -20,7 +20,7 @@ import time
 import traceback
 import uuid
 from typing import List, Optional
+import numpy as np
 import msgpack
 import aiozmq
 from aiozmq import zmq
@@ -75,6 +75,8 @@ class OpenAIServingChat:
         current_req_dict = request.to_dict_for_infer(request_id)
         current_req_dict["arrival_time"] = time.time()
         prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
+        if isinstance(prompt_token_ids, np.ndarray):
+            prompt_token_ids = prompt_token_ids.tolist()
     except Exception as e:
         return ErrorResponse(code=400, message=str(e))

View File

@@ -18,6 +18,7 @@ import asyncio
 import aiozmq
 import json
 import msgpack
+import numpy as np
 from aiozmq import zmq
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
 import time
@@ -105,9 +106,10 @@ class OpenAIServingCompletion:
             current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
             try:
                 current_req_dict["arrival_time"] = time.time()
-                prompt_batched_token_ids.append(
-                    self.engine_client.format_and_add_data(current_req_dict)
-                )
+                prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
+                if isinstance(prompt_token_ids, np.ndarray):
+                    prompt_token_ids = prompt_token_ids.tolist()
+                prompt_batched_token_ids.append(prompt_token_ids)
             except Exception as e:
                 return ErrorResponse(message=str(e), code=400)