FastDeploy/fastdeploy/engine/async_llm.py
zhouchong 5d9b5e4a5b [Engine] [Feature] Refactor async_llm: cross-process with EngineService, based on zmq communication (#4868)
* Refactor async_llm: cross-process with EngineService

* fix: async_llm output process

* fix: return prompt_token_ids and prompt_tokens in first res

* optimize common_engine start func
2025-12-09 10:53:40 +08:00


"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations

import inspect
import os
import signal
import time
import traceback
import uuid
import weakref
from dataclasses import asdict
from typing import Any, AsyncGenerator, Dict, List, Optional, Union

import numpy as np
import zmq

from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.common_engine import EngineService
from fastdeploy.engine.request import RequestOutput
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.inter_communicator.zmq_client import ZmqIpcClient
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import EngineError, llm_logger


class AsyncOutputProcessor:
"""Async output processor responsible for distributing engine outputs to corresponding request queues"""
def __init__(self, data_processor=None):
"""
Args:
data_processor: The data processor created by InputPreprocessor,
used to post-process RequestOutput (decode token_ids, reasoning, tools, etc.).
"""
self.data_processor = data_processor
def _process_output(
self,
response_dict: Dict[str, Any],
stream: bool = True,
enable_thinking: bool = False,
include_stop_str_in_output: bool = False,
) -> Dict[str, Any]:
"""Process a single response dict via data_processor.process_response_dict.
This mirrors the behavior of ChatResponseProcessor in the OpenAI serving
path: operate on a dict representation and return a dict. On any error
we fall back to the original dict and ensure ``outputs.text`` exists to
avoid cascading failures.
"""
try:
processed = self.data_processor.process_response_dict(
response_dict,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
# Some processors may return None when there is no valid text.
if processed is None:
outputs = response_dict.get("outputs") or {}
if "text" not in outputs:
outputs["text"] = ""
response_dict["outputs"] = outputs
return response_dict
return processed
except Exception:
outputs = response_dict.get("outputs") or {}
if "text" not in outputs:
outputs["text"] = ""
response_dict["outputs"] = outputs
return response_dict


class EngineServiceClient:
    """
    Base engine service client responsible for managing the EngineService
    process lifecycle: spawning the engine process, waiting for readiness,
    and shutting it down.
    """
def __init__(self, cfg, pid):
self.cfg = cfg
self.engine_process = None
self.engine_pid = pid
self._running = False
llm_logger.info(f"EngineServiceClient initialized with engine_pid: {self.engine_pid}")
async def start(self):
"""Start engine service process"""
try:
# Start independent engine process
self._start_engine_process()
# Wait for engine to be ready
if not self._wait_engine_ready():
raise EngineError("Engine failed to start within timeout", error_code=500)
self._running = True
llm_logger.info("EngineServiceClient started successfully")
except Exception as e:
llm_logger.error(f"Failed to start EngineServiceClient: {e}")
raise
return True
def _start_engine_process(self):
"""Start engine process"""
try:
import multiprocessing
self.shutdown_signal = multiprocessing.Value("i", 0) # 0=running, 1=shutdown
def run_engine():
engine = None
def signal_handler(signum, frame):
llm_logger.info(f"Engine process received signal {signum}, initiating shutdown...")
if engine:
engine.running = False
# Register signal handlers
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
try:
engine = EngineService(self.cfg, use_async_llm=True)
# Start engine with ZMQ service
engine.start(async_llm_pid=self.engine_pid)
# Keep engine running until shutdown signal is received
while self.shutdown_signal.value == 0 and getattr(engine, "running", True):
time.sleep(0.5)
except Exception as e:
llm_logger.error(f"Engine process error: {e}, {str(traceback.format_exc())}")
finally:
if engine and hasattr(engine, "_exit_sub_services"):
try:
engine._exit_sub_services()
llm_logger.info("Engine process cleanup completed")
except Exception as e:
llm_logger.error(f"Error during engine cleanup: {e}")
self.engine_process = multiprocessing.Process(target=run_engine)
self.engine_process.start()
llm_logger.info(f"Started engine process with PID: {self.engine_process.pid}")
except Exception as e:
llm_logger.error(f"Failed to start engine process: {e}")
raise
def _wait_engine_ready(self) -> bool:
"""Wait for engine and workers to be fully ready"""
max_wait_time = 180 # seconds
wait_interval = 1
elapsed_time = 0
llm_logger.info("Waiting for engine and workers to be ready...")
# Use IPC signals to check engine readiness
# Get the correct suffix
ipc_suffix = (
self.cfg.parallel_config.engine_worker_queue_port[0]
if hasattr(self.cfg, "parallel_config")
else self.engine_pid
)
# Check if loaded_model_signal exists and is ready
loaded_model_signal = None
while elapsed_time < max_wait_time:
# Try to connect to loaded_model_signal
if loaded_model_signal is None:
try:
loaded_model_signal = IPCSignal(
name="loaded_model_signal",
array=np.zeros([1], dtype=np.int32),
dtype=np.int32,
suffix=ipc_suffix,
create=False,
)
                except Exception:
# Signal not ready yet
time.sleep(wait_interval)
elapsed_time += wait_interval
continue
# Check if workers have loaded models
if loaded_model_signal.value[0] > 0:
llm_logger.info("Workers have loaded models successfully")
# Give ZMQ service more time to fully start
llm_logger.info("Waiting additional time for ZMQ service to be ready...")
time.sleep(5) # Wait for ZMQ service startup + recv_result_handle
return True
time.sleep(wait_interval)
elapsed_time += wait_interval
if elapsed_time % 10 == 0: # Log every 10 seconds
llm_logger.info(f"Waiting for workers to load models... ({elapsed_time}s)")
return False
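    # Readiness handshake (hedged summary): the engine/worker side is expected to
    # create the "loaded_model_signal" IPCSignal under the same suffix and set it
    # to a positive value once the model has loaded, conceptually:
    #
    #     loaded = IPCSignal(name="loaded_model_signal", array=np.zeros([1], dtype=np.int32),
    #                        dtype=np.int32, suffix=ipc_suffix, create=True)
    #     loaded.value[0] = 1
    #
    # This client only polls the signal; the extra 5s sleep above pads for the
    # ZMQ routing service to come up.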
def shutdown(self):
"""Shutdown engine service process"""
llm_logger.info("Shutting down EngineServiceClient...")
self._running = False
# Send graceful shutdown signal to engine process
if hasattr(self, "shutdown_signal"):
llm_logger.info("Sending shutdown signal to engine process...")
self.shutdown_signal.value = 1
# Wait for engine process to shutdown
if self.engine_process and self.engine_process.is_alive():
llm_logger.info("Waiting for engine process to shutdown...")
self.engine_process.terminate()
self.engine_process.join(timeout=5)
if self.engine_process.is_alive():
llm_logger.warning("Force killing engine process...")
self.engine_process.kill()
llm_logger.info("EngineServiceClient shutdown completed")


class AsyncLLM(EngineServiceClient):
    """
    Async engine client responsible for managing Large Language Model (LLM)
    operations in a separate EngineService process.

    Attributes:
        cfg (Config): Configuration object containing all the parameters.
        input_processor (InputPreprocessor): Preprocessor for input data.
        data_processor: Request/response processor created by input_processor.
        output_processor (AsyncOutputProcessor): Post-processor for engine outputs.
        connection_manager (DealerConnectionManager): Async manager for response connections.
        request_client (ZmqIpcClient): ZMQ PUSH client used to send requests to the engine service.
    """
@classmethod
def from_engine_args(cls, engine_args: EngineArgs, pid):
"""
Creates an AsyncLLM client from the provided engine arguments.
        Args:
            engine_args (EngineArgs): Engine arguments object.
            pid: Process id used as the suffix for the IPC/ZMQ channels shared with the engine service.

        Returns:
            AsyncLLM: Instance of the AsyncLLM class.
"""
# Create the engine configs.
config = engine_args.create_engine_config()
# Create the AsyncLLM client.
return cls(cfg=config, pid=pid)
def __init__(self, cfg, pid):
"""
Initializes the AsyncLLM client with the provided configuration.
        Args:
            cfg (Config): Config object containing all the configuration parameters.
            pid: Process id used as the suffix for the IPC/ZMQ channels shared with the engine service.
        """
super().__init__(cfg, pid)
self.cfg = cfg
self.running = True
self._prompt_metadata: Dict[str, Dict[str, Any]] = {}
self.input_processor = InputPreprocessor(
cfg.model_config,
cfg.structured_outputs_config.reasoning_parser,
cfg.limit_mm_per_prompt,
cfg.mm_processor_kwargs,
cfg.tool_parser,
)
# Create data processor
self.data_processor = self.input_processor.create_processor()
# Create high-performance async connection manager
self.connection_manager = None
self.request_client = None
# Output processor uses data_processor for post-processing engine outputs
self.output_processor = AsyncOutputProcessor(self.data_processor)
self._finalizer = weakref.finalize(self, self._exit_sub_services)
main_process_metrics.set_cache_config_info(obj=self.cfg.cache_config)
async def init_connections(self):
"""Initialize high-performance ZMQ connections"""
try:
# Create ZMQ client for sending requests
self.request_client = ZmqIpcClient(name=self.engine_pid, mode=zmq.PUSH)
self.request_client.connect()
# Create high-performance async connection manager for receiving responses
self.connection_manager = DealerConnectionManager(
pid=self.engine_pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
)
if not self.connection_manager.running:
await self.connection_manager.initialize()
llm_logger.info("High-performance ZMQ connections initialized successfully")
except Exception as e:
llm_logger.error(f"Failed to initialize ZMQ connections: {e}")
raise
async def get_model_config(self):
"""Get model configuration"""
return self.cfg.model_config
async def get_tokenizer(self):
"""Get tokenizer"""
if hasattr(self, "data_processor"):
return self.data_processor.tokenizer
return None
def _has_guided_input(self, request):
"""
Check if the request has any guided input.
"""
return any(
x is not None
for x in (
request.guided_json,
request.guided_regex,
request.guided_choice,
request.structural_tag,
request.guided_grammar,
request.guided_json_object,
)
)
async def add_request(
self,
request_id: str,
        prompt: Union[str, List[int], Dict[str, Any]],
sampling_params: Optional[SamplingParams] = None,
arrival_time: Optional[float] = None,
**kwargs,
):
"""
        Asynchronously preprocess a request and submit it to the engine service.

        Args:
            request_id: Request ID
            prompt: Input prompt (text, a list of token ids, or a request dict)
            sampling_params: Sampling parameters
            arrival_time: Arrival time of the request
            **kwargs: Other parameters
"""
if request_id is None:
request_id = str(uuid.uuid4())
if arrival_time is None:
arrival_time = time.time()
if isinstance(prompt, str):
prompt = {
"prompt": prompt,
"request_id": request_id,
}
elif isinstance(prompt, list) and isinstance(prompt[0], int):
prompt = {
"prompt_token_ids": prompt,
"request_id": request_id,
}
elif isinstance(prompt, dict):
prompt["request_id"] = request_id
else:
raise TypeError(f"Invalid type for 'prompt': {type(prompt)}, expected one of ['str', 'list', 'dict'].")
if sampling_params is not None:
prompt.update(asdict(sampling_params))
try:
# Check if already preprocessed by api_server
is_preprocessed = prompt.get("_preprocessed", False)
if inspect.iscoroutinefunction(self.data_processor.process_request_dict):
request = await self.data_processor.process_request_dict(prompt, self.cfg.model_config.max_model_len)
else:
request = self.data_processor.process_request_dict(prompt, self.cfg.model_config.max_model_len)
request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
# Cache prompt metadata for later enrichment of async responses
req_id = request.get("request_id")
self._prompt_metadata[req_id] = {
"prompt_token_ids": request.get("prompt_token_ids"),
"prompt_tokens": request.get("prompt_tokens"),
}
if not is_preprocessed:
request["preprocess_start_time"] = arrival_time
input_ids_len = request["prompt_token_ids_len"]
request["max_tokens"] = min(
self.cfg.model_config.max_model_len - input_ids_len, request.get("max_tokens")
)
min_tokens = request.get("min_tokens", 1)
if input_ids_len + min_tokens >= self.cfg.model_config.max_model_len:
error_msg = (
f"Input text is too long, length of prompt token({input_ids_len}) "
f"+ min_dec_len ({min_tokens}) >= max_model_len "
)
llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=400)
request["preprocess_end_time"] = time.time()
preprocess_cost_time = request["preprocess_end_time"] - request["preprocess_start_time"]
llm_logger.info(
f"Cache request with request_id ({request.get('request_id')}), "
f"preprocess time cost {preprocess_cost_time}"
)
if not self.cfg.model_config.enable_mm:
self.request_client.send_json(request)
else:
self.request_client.send_pyobj(request)
except EngineError:
raise
except Exception as e:
raise EngineError(f"async_llm add request failed: {e}", error_code=400)
async def generate(
self,
        prompt: Union[str, List[int], Dict[str, Any]],
sampling_params: Optional[SamplingParams] = None,
request_id: Optional[str] = None,
**kwargs,
) -> AsyncGenerator[RequestOutput, None]:
"""
Async generation interface
Args:
prompt: Input prompt
            sampling_params: Sampling parameters. If `sampling_params.n > 1`,
                `n` sub-requests are submitted and their outputs are streamed back interleaved.
request_id: Request ID
**kwargs: Other parameters
Yields:
RequestOutput: Generated output
"""
num_choices = sampling_params.n if sampling_params is not None and sampling_params.n else 1
stream = True
include_stop_str_in_output = False
enable_thinking = kwargs.pop("enable_thinking", False)
if isinstance(prompt, dict):
            num_choices = prompt.get("n") or num_choices  # keep the prior value if "n" is absent
stream = prompt.get("stream", True)
include_stop_str_in_output = prompt.get("include_stop_str_in_output", False)
# Ensure ZMQ client and connection manager are initialized in current process
if (
self.request_client is None
or self.connection_manager is None
or not getattr(self.connection_manager, "running", False)
):
raise EngineError(
"AsyncLLM engine not initialized. Call init_connections() before generate.",
error_code=500,
)
# Build request ids and connection key
if num_choices <= 1:
# Single-choice: keep user-provided request_id semantics
child_request_ids = [request_id or str(uuid.uuid4())]
conn_request_id = child_request_ids[0]
else:
# Multi-choice: use unified "cmpl-" base id so DealerConnectionManager
# can merge cmpl-xxx_0, cmpl-xxx_1, ... back to the same response queue.
user_request_id = request_id or str(uuid.uuid4())
conn_request_id = f"cmpl-{user_request_id}"
child_request_ids = [f"{conn_request_id}_{i}" for i in range(num_choices)]
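        # Illustrative mapping (hedged example): with request_id="abc" and n=2,
        # conn_request_id == "cmpl-abc" and child_request_ids == ["cmpl-abc_0", "cmpl-abc_1"];
        # the DealerConnectionManager merges both sub-streams into one response queue.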
try:
# 1) Send all sub-requests to engine
for child_request_id in child_request_ids:
await self.add_request(child_request_id, prompt, sampling_params, **kwargs)
# 2) Get a shared connection for conn_request_id and handshake all sub-requests
dealer, response_queue = await self.connection_manager.get_connection(
request_id=conn_request_id, num_choices=num_choices
)
for child_request_id in child_request_ids:
dealer.write([b"", child_request_id.encode("utf-8")])
# 3) Stream responses from all choices interleaved
remaining = num_choices
while remaining > 0:
response_list = await response_queue.get()
for response_item in response_list:
if isinstance(response_item, dict) and "request_id" in response_item:
req_id = response_item.get("request_id")
# First, use output_processor to post-process the raw dict
if hasattr(self, "output_processor"):
processed_output = self.output_processor._process_output(
response_item,
stream=stream,
enable_thinking=enable_thinking,
include_stop_str_in_output=include_stop_str_in_output,
)
else:
processed_output = response_item
# Then convert processed dict to RequestOutput
request_output = RequestOutput.from_dict(processed_output)
# Enrich outputs with prompt metadata on the first packet
if req_id:
prompt_meta = self._prompt_metadata.get(req_id)
if prompt_meta is not None and request_output.outputs.send_idx == 0:
request_output.prompt_token_ids = prompt_meta.get("prompt_token_ids")
request_output.prompt = prompt_meta.get("prompt_tokens")
self._prompt_metadata.pop(req_id, None)
if request_output.finished:
remaining -= 1
yield request_output
except GeneratorExit:
llm_logger.info(f"Request {conn_request_id} generator exit (outer)")
return
except Exception as e:
llm_logger.error(f"Request {conn_request_id} failed: {e}")
raise EngineError(str(e), error_code=500) from e
finally:
# Ensure request_map/request_num are cleaned up
try:
await self.connection_manager.cleanup_request(conn_request_id)
except Exception:
pass
async def abort_request(self, request_id: str) -> None:
"""
Abort the specified request
Args:
request_id: Request ID to abort
"""
try:
# Clean up request through DealerConnectionManager
if hasattr(self, "connection_manager") and self.connection_manager:
await self.connection_manager.cleanup_request(request_id)
llm_logger.info(f"Aborted request {request_id}")
except Exception as e:
llm_logger.error(f"Failed to abort request {request_id}: {e}")
async def shutdown(self):
"""
Gracefully shutdown AsyncLLM engine
"""
llm_logger.info("Starting AsyncLLM shutdown...")
self.running = False
# Close high-performance connection manager
if hasattr(self, "connection_manager") and self.connection_manager is not None:
llm_logger.info("Stopping connection manager...")
try:
await self.connection_manager.close()
except Exception as e:
llm_logger.error(f"Error while stopping connection manager: {e}")
# Close ZMQ client
if hasattr(self, "request_client") and self.request_client is not None:
llm_logger.info("Closing request client...")
try:
self.request_client.close()
except Exception as e:
llm_logger.warning(f"Error closing request client: {e}")
# Shutdown engine service process
try:
super().shutdown()
except Exception as e:
llm_logger.error(f"Error while stopping engine service process: {e}")
llm_logger.info("AsyncLLM shutdown completed")
def _exit_sub_services(self):
"""
Clean up any remaining resources
"""
pass