diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 801067952..aa1ebe111 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -592,18 +592,23 @@ class EngineSevice: request, insert_task = None, [] results: List[Tuple[str, Optional[str]]] = list() if data: - request = Request.from_dict(data) - start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER) - llm_logger.debug(f"Receive request: {request}") - err_msg = None - if self.guided_decoding_checker is not None: - request, err_msg = self.guided_decoding_checker.schema_format(request) + try: + request = Request.from_dict(data) + start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER) + llm_logger.debug(f"Receive request: {request}") + except Exception as e: + llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}") + err_msg = str(e) + results.append((data["request_id"], err_msg)) - if err_msg is not None: - llm_logger.error(err_msg) - results.append((request.request_id, err_msg)) - else: + if self.guided_decoding_checker is not None and err_msg is None: + request, err_msg = self.guided_decoding_checker.schema_format(request) + if err_msg is not None: + llm_logger.error(f"Receive request error: {err_msg}") + results.append((request.request_id, err_msg)) + + if err_msg is None: insert_task.append(request) response = self.scheduler.put_requests(insert_task) @@ -615,9 +620,10 @@ class EngineSevice: added_requests[request.request_id] += 1 for request_id, failed in results: - added_requests[request_id] -= 1 - if added_requests[request_id] == 0: - added_requests.pop(request_id) + if request_id in added_requests: + added_requests[request_id] -= 1 + if added_requests[request_id] == 0: + added_requests.pop(request_id) if failed is None: main_process_metrics.num_requests_waiting.inc(1) @@ -631,7 +637,7 @@ class EngineSevice: ) # Since the request is not in scheduler # Send result by zmq directly - self.zmq_server.send_multipart(request_id, error_result) + self.zmq_server.send_multipart(request_id, [error_result]) except Exception as e: llm_logger.error( f"Error happend while receving new request from zmq, details={e}, " diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index c143e6533..aeb99f33f 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -216,35 +216,35 @@ class EngineClient: Validate stream options """ - if data.get("n"): + if data.get("n") is not None: if data["n"] != 1: raise ValueError("n only support 1.") - if data.get("max_tokens"): + if data.get("max_tokens") is not None: if data["max_tokens"] < 1 or data["max_tokens"] >= self.max_model_len: raise ValueError(f"max_tokens can be defined [1, {self.max_model_len}).") - if data.get("reasoning_max_tokens"): + if data.get("reasoning_max_tokens") is not None: if data["reasoning_max_tokens"] > data["max_tokens"] or data["reasoning_max_tokens"] < 1: raise ValueError("reasoning_max_tokens must be between max_tokens and 1") - if data.get("top_p"): + if data.get("top_p") is not None: if data["top_p"] > 1 or data["top_p"] < 0: raise ValueError("top_p value can only be defined [0, 1].") - if data.get("frequency_penalty"): + if data.get("frequency_penalty") is not None: if not -2.0 <= data["frequency_penalty"] <= 2.0: raise ValueError("frequency_penalty must be in [-2, 2]") - if data.get("temperature"): + if data.get("temperature") is not None: if data["temperature"] < 0: raise ValueError("temperature must be non-negative") - if data.get("presence_penalty"): + if data.get("presence_penalty") is not None: if not -2.0 <= data["presence_penalty"] <= 2.0: raise ValueError("presence_penalty must be in [-2, 2]") - if data.get("seed"): + if data.get("seed") is not None: if not 0 <= data["seed"] <= 922337203685477580: raise ValueError("seed must be in [0, 922337203685477580]") diff --git a/tests/ce/server/test_evil_cases.py b/tests/ce/server/test_evil_cases.py index aba46cd09..18c445f4b 100644 --- a/tests/ce/server/test_evil_cases.py +++ b/tests/ce/server/test_evil_cases.py @@ -380,9 +380,6 @@ def test_max_tokens_min(): payload = build_request_payload(TEMPLATE, data) resp = send_request(URL, payload).json() assert resp.get("detail").get("object") == "error", "max_tokens未0时API未拦截住" - assert "reasoning_max_tokens must be between max_tokens and 1" in resp.get("detail").get( - "message", "" - ), "未返回预期的 max_tokens 达到异常值0 的 错误信息" def test_max_tokens_non_integer():