[BugFix] fix control signal release failed (#3374)

* [BugFix]

* [BugFix]

* [BugFix]

* [BugFix]

* fix

* fix

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
ltd0924
2025-08-14 17:01:25 +08:00
committed by GitHub
parent b2df0311b8
commit 03347626a6
4 changed files with 51 additions and 44 deletions

View File

@@ -165,9 +165,9 @@ async def connection_manager():
yield yield
except asyncio.TimeoutError: except asyncio.TimeoutError:
api_server_logger.info(f"Reach max request release: {connection_semaphore.status()}") api_server_logger.info(f"Reach max request release: {connection_semaphore.status()}")
if connection_semaphore.locked(): raise HTTPException(
connection_semaphore.release() status_code=429, detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
raise HTTPException(status_code=429, detail="Too many requests") )
def wrap_streaming_generator(original_generator: AsyncGenerator): def wrap_streaming_generator(original_generator: AsyncGenerator):
@@ -180,7 +180,7 @@ def wrap_streaming_generator(original_generator: AsyncGenerator):
async for chunk in original_generator: async for chunk in original_generator:
yield chunk yield chunk
finally: finally:
api_server_logger.debug(f"release: {connection_semaphore.status()}") api_server_logger.debug(f"current concurrency status: {connection_semaphore.status()}")
connection_semaphore.release() connection_semaphore.release()
return wrapped_generator return wrapped_generator
@@ -255,9 +255,11 @@ async def create_chat_completion(request: ChatCompletionRequest):
generator = await app.state.chat_handler.create_chat_completion(request) generator = await app.state.chat_handler.create_chat_completion(request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
connection_semaphore.release() connection_semaphore.release()
api_server_logger.debug(f"current concurrency status: {connection_semaphore.status()}")
return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code) return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
elif isinstance(generator, ChatCompletionResponse): elif isinstance(generator, ChatCompletionResponse):
connection_semaphore.release() connection_semaphore.release()
api_server_logger.debug(f"current concurrency status: {connection_semaphore.status()}")
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
else: else:
wrapped_generator = wrap_streaming_generator(generator) wrapped_generator = wrap_streaming_generator(generator)

View File

@@ -78,6 +78,13 @@ class OpenAIServingChat:
api_server_logger.error(err_msg) api_server_logger.error(err_msg)
return ErrorResponse(message=err_msg, code=400) return ErrorResponse(message=err_msg, code=400)
try:
if self.max_waiting_time < 0:
await self.engine_client.semaphore.acquire()
else:
await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
api_server_logger.debug(f"current waiting request {self.engine_client.semaphore.status()}")
if request.user is not None: if request.user is not None:
request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}" request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}"
else: else:
@@ -96,15 +103,6 @@ class OpenAIServingChat:
del current_req_dict del current_req_dict
try:
api_server_logger.debug(f"{self.engine_client.semaphore.status()}")
if self.max_waiting_time < 0:
await self.engine_client.semaphore.acquire()
else:
await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
except Exception:
return ErrorResponse(code=408, message=f"Request queued time exceed {self.max_waiting_time}")
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
request, request_id, request.model, prompt_token_ids, text_after_process request, request_id, request.model, prompt_token_ids, text_after_process
@@ -116,6 +114,8 @@ class OpenAIServingChat:
) )
except Exception as e: except Exception as e:
return ErrorResponse(code=400, message=str(e)) return ErrorResponse(code=400, message=str(e))
except Exception:
return ErrorResponse(code=408, message=f"Request queued time exceed {self.max_waiting_time}")
def _create_streaming_error_response(self, message: str) -> str: def _create_streaming_error_response(self, message: str) -> str:
error_response = ErrorResponse( error_response = ErrorResponse(

View File

@@ -101,6 +101,13 @@ class OpenAIServingCompletion:
api_server_logger.info(f"start inference for request {num_choices}") api_server_logger.info(f"start inference for request {num_choices}")
prompt_batched_token_ids = [] prompt_batched_token_ids = []
text_after_process_list = [] text_after_process_list = []
try:
if self.max_waiting_time < 0:
await self.engine_client.semaphore.acquire()
else:
await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
except Exception:
return ErrorResponse(code=408, message=f"Request queued time exceed {self.max_waiting_time}")
try: try:
for idx, prompt in enumerate(request_prompts): for idx, prompt in enumerate(request_prompts):
request_id_idx = f"{request_id}-{idx}" request_id_idx = f"{request_id}-{idx}"
@@ -117,14 +124,6 @@ class OpenAIServingCompletion:
del current_req_dict del current_req_dict
try:
if self.max_waiting_time < 0:
await self.engine_client.semaphore.acquire()
else:
await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
except Exception:
return ErrorResponse(code=408, message=f"Request queued time exceed {self.max_waiting_time}")
if request.stream: if request.stream:
return self.completion_stream_generator( return self.completion_stream_generator(
request=request, request=request,

View File

@@ -67,6 +67,7 @@ class ZmqClient:
""" """
self.router = self.context.socket(zmq.ROUTER) self.router = self.context.socket(zmq.ROUTER)
self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router.setsockopt(zmq.SNDTIMEO, -1) self.router.setsockopt(zmq.SNDTIMEO, -1)
self.router.bind(f"ipc://{self.router_path}") self.router.bind(f"ipc://{self.router_path}")
@@ -111,7 +112,6 @@ class ZmqClient:
""" """
if self.router is None: if self.router is None:
raise RuntimeError("Router socket not created. Call create_router() first.") raise RuntimeError("Router socket not created. Call create_router() first.")
while self.running: while self.running:
with self.mutex: with self.mutex:
if req_id not in self.req_dict: if req_id not in self.req_dict:
@@ -124,7 +124,11 @@ class ZmqClient:
continue continue
else: else:
break break
if self.req_dict[req_id] == -1:
if data[-1].finished:
with self.mutex:
self.req_dict.pop(req_id, None)
return
try: try:
start_send = time.time() start_send = time.time()
if self.aggregate_send: if self.aggregate_send:
@@ -133,7 +137,9 @@ class ZmqClient:
result = msgpack.packb([response.to_dict() for response in data]) result = msgpack.packb([response.to_dict() for response in data])
self.router.send_multipart([self.req_dict[req_id], b"", result]) self.router.send_multipart([self.req_dict[req_id], b"", result])
llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
except zmq.ZMQError as e:
llm_logger.error(f"[{req_id}] zmq error: {e}")
self.req_dict[req_id] = -1
except Exception as e: except Exception as e:
llm_logger.error(f"Send result to zmq client failed: {e}") llm_logger.error(f"Send result to zmq client failed: {e}")