[Engine] [Feature] Refactor async_llm:cross-process with EngineService,based on zmq communication (#4868)

* Refactor async_llm:cross-process with EngineService

* fix: async_llm output process

* fix: return prompt_token_ids and prompt_tokens in first res

* optimize common_engine start func
This commit is contained in:
zhouchong
2025-12-09 10:53:40 +08:00
committed by GitHub
parent 2f208db4e9
commit 5d9b5e4a5b
8 changed files with 2217 additions and 1790 deletions

View File

@@ -17,7 +17,13 @@
from __future__ import annotations
import copy
import json
import multiprocessing
import os
import re
import signal
import subprocess
import sys
import threading
import time
import traceback
@@ -30,6 +36,7 @@ import paddle
import requests
import zmq
from opentelemetry import trace
from tqdm import tqdm
from fastdeploy.engine.request import Request, RequestOutput, RequestType
from fastdeploy.engine.resource_manager import ResourceManager
@@ -66,7 +73,7 @@ class EngineService:
Base class containing common engine functionality
"""
def __init__(self, cfg, start_queue=True):
def __init__(self, cfg, start_queue=True, use_async_llm=False):
"""
Initializes the LLMEngine with the provided configuration.
@@ -74,6 +81,7 @@ class EngineService:
cfg (Config): Config object containing all the configuration parameters.
"""
self.cfg = cfg
self.use_async_llm = use_async_llm
if cfg.scheduler_config.splitwise_role != "mixed" or cfg.cache_config.enable_prefix_caching:
if isinstance(self.cfg.cache_config.cache_queue_port, str):
self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port.split(",")
@@ -149,10 +157,21 @@ class EngineService:
)
init_eplb_signals(cfg, current_suffix)
if self.use_async_llm:
# Add worker management attributes
self.worker_proc = None
self.do_profile = 1 if self.cfg.cache_config.num_gpu_blocks_override is None else 0
self.ipc_signal_suffix = None
self.cache_manager_processes = None
self._finalizer = weakref.finalize(self, self._exit_sub_services)
def start(self):
def start(self, async_llm_pid=None):
self.running = True
if self.use_async_llm:
self.start_worker_service(async_llm_pid)
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.insert_task_to_worker_thread = threading.Thread(
target=self._schedule_request_to_worker_v1, daemon=True
@@ -167,6 +186,69 @@ class EngineService:
self._register_to_router()
def start_worker_service(self, async_llm_pid=None):
# Initialize IPC signals for worker management
self.ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
self._init_worker_signals()
# Create data processor if not exists
if not hasattr(self, "data_processor"):
self.create_data_processor()
# Launch components: scheduler, cache_manager, expert_service et.al.
self.launch_components()
# If block number is specified and model is deployed in splitwise mode, start cache manager first
if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed":
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
# Start worker processes
self.worker_proc = self._start_worker_service()
time.sleep(5)
self.worker_init_status = dict()
result_container = {}
def check_worker_initialize_status_func(res: dict):
res["worker_is_alive"] = True
if not self.check_worker_initialize_status():
llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
res["worker_is_alive"] = False
self.check_worker_initialize_status_func_thread = threading.Thread(
target=check_worker_initialize_status_func, args=(result_container,), daemon=True
)
self.check_worker_initialize_status_func_thread.start()
# Wait model loading
while self.loaded_model_signal.value[0] == 0:
# Make sure worker process is alive
if not self.check_worker_initialize_status_func_thread.is_alive():
return False
time.sleep(1)
# If block number is not specified, let workers do profiling to determine the block number,
# and then start the cache manager
if self.do_profile:
self._stop_profile()
elif self.cfg.scheduler_config.splitwise_role == "mixed" and self.cfg.cache_config.enable_prefix_caching:
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
# Set cache manager signal
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.launched_cache_manager_signal.value[0] = 1
# Worker launched
self.check_worker_initialize_status_func_thread.join()
if not result_container["worker_is_alive"]:
llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
return False
# Start ZMQ service for communication with AsyncLLM
if async_llm_pid:
self.start_zmq_service(async_llm_pid)
def create_data_processor(self):
self.input_processor = InputPreprocessor(
self.cfg.model_config,
@@ -970,7 +1052,13 @@ class EngineService:
else:
err, data = self.recv_request_server.receive_pyobj_once(block)
if err is not None:
self.llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
# The message "Context was terminated" is normal when closing a ZMQ context
if "Context was terminated" in str(err):
self.llm_logger.info(
"Engine stops inserting zmq task into scheduler due to ZMQ context termination (normal shutdown)."
)
else:
self.llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
break
request, insert_task = None, []
@@ -1336,6 +1424,58 @@ class EngineService:
"""
llm_logger.info("Exit sub services.....")
self.running = False
if self.use_async_llm:
# Clean up worker processes first (before closing multiprocessing services)
if hasattr(self, "worker_proc") and self.worker_proc is not None:
llm_logger.info("Cleaning up worker processes...")
try:
pgid = os.getpgid(self.worker_proc.pid)
os.killpg(pgid, signal.SIGTERM)
except Exception as e:
llm_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}")
# Clean up cache manager processes
if hasattr(self, "cache_manager_processes"):
llm_logger.info("Cleaning up cache manager processes...")
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")
try:
pgid = os.getpgid(p.pid)
os.killpg(pgid, signal.SIGTERM)
except Exception as e:
llm_logger.error(
f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}"
)
if hasattr(self, "cache_task_queue") and self.cache_task_queue is not None:
llm_logger.info("Cleaning up cache_task_queue...")
# Check if cleanup method exists
if hasattr(self.cache_task_queue, "cleanup"):
self.cache_task_queue.cleanup()
elif hasattr(self.cache_task_queue, "manager"):
try:
llm_logger.info("Shutting down cache_task_queue manager...")
self.cache_task_queue.manager.shutdown()
except Exception as e:
llm_logger.warning(f"Error shutting down cache_task_queue manager: {e}")
if hasattr(self, "get_profile_block_num_signal"):
self.get_profile_block_num_signal.clear()
self.worker_ready_signal.clear()
self.loaded_model_signal.clear()
# Clean up other services
if hasattr(self, "dp_processed"):
for p in self.dp_processed:
llm_logger.info(f"Waiting for worker {p.pid} to exit")
p.join()
for p in self.dp_engine_worker_queue_server:
p.cleanup()
if hasattr(self, "engine_worker_queue_server") and self.engine_worker_queue_server is not None:
self.engine_worker_queue_server.cleanup()
self.exist_task_signal.clear()
@@ -1353,3 +1493,395 @@ class EngineService:
self.recv_request_server.close()
if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
self.recv_control_cmd_server.close()
# 从 async_llm 移到 common_engine
def _worker_processes_ready(self):
"""
judge if all worker processes are ready
"""
if np.sum(self.worker_ready_signal.value) == self.cfg.worker_num_per_node:
return True
return False
def _init_worker_signals(self):
"""
Initialize shared memory to indicate engine status
"""
# worker_ready_signal 用于worker进程感知engine是否启动完成
worker_ready_signal_data = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
self.worker_ready_signal = IPCSignal(
name="worker_ready_signal",
array=worker_ready_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True,
)
# launched_cache_manager_signal 用于感知engine是否启动了cache_manager
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
self.launched_cache_manager_signal = IPCSignal(
name="launched_cache_manager_signal",
array=launched_cache_manager_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True,
)
# launched_expert_service_signal: Used to sense whether each expet_servic is started successfully
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
launched_expert_service_signal_data = np.zeros(
shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
)
self.launched_expert_service_signal = IPCSignal(
name="launched_expert_service_signal",
array=launched_expert_service_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True,
)
# loaded_model_signal: Used to detect whether each worker has completed model loading
loaded_model_signal_data = np.zeros([1], dtype=np.int32)
self.loaded_model_signal = IPCSignal(
name="loaded_model_signal",
array=loaded_model_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True,
)
if self.do_profile:
if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
else:
get_profile_block_num = np.zeros([1], dtype=np.int32)
self.get_profile_block_num_signal = IPCSignal(
name="get_profile_block_num",
array=get_profile_block_num,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True,
)
def _setting_environ_variables(self):
"""
配置环境变量
"""
variables = {
"ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0,
"LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.parallel_config.device_ids.split(",")),
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
"FLAGS_use_append_attn": 1,
"NCCL_ALGO": "Ring",
"FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
"OMP_NUM_THREADS": 3,
}
# environment variables needed by Dy2St
variables.update(
{
"SOT_LOG_LEVEL": os.getenv("SOT_LOG_LEVEL", default="0"),
"SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
"SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
"SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
"FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
"FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
"FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
"FLAGS_pir_interpreter_record_stream_for_gc_cache", default="1"
),
"FLAGS_parameters_persistent_mode_in_dy2st": os.getenv(
"FLAGS_parameters_persistent_mode_in_dy2st", default="1"
),
}
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1
else:
variables["FLAGS_use_pd_disaggregation"] = 1
# TODO dynamic load environment variable
if self.cfg.scheduler_config.splitwise_role == "prefill":
variables["FLAGS_fmt_write_cache_completed_signal"] = 1
if self.cfg.model_config.enable_mm:
variables["FLAGS_max_partition_size"] = 1024
command_prefix = ""
for k, v in variables.items():
command_prefix += f"{k}={v} "
return command_prefix
def _start_worker_service(self):
"""
start gpu worker service
"""
log_dir = os.getenv("FD_LOG_DIR", default="log")
command_prefix = self._setting_environ_variables()
current_file_path = os.path.abspath(__file__)
current_dir_path = os.path.split(current_file_path)[0]
# TODO
uncache_worker_stdout = "" if os.getenv("UNCACHE_WORKER_STDOUT", "0") == "1" else "-u"
pd_cmd = f"{command_prefix} {sys.executable} {uncache_worker_stdout} -m paddle.distributed.launch"
pd_cmd = pd_cmd + f" --log_dir {log_dir}"
worker_path = "../worker/worker_process.py"
py_script = os.path.join(current_dir_path, worker_path)
ori_vocab_size = (
len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, "sp_model")
else len(self.data_processor.tokenizer.vocab)
)
think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1)
if think_end_id > 0:
llm_logger.info(f"Get think_end_id {think_end_id} from vocab.")
else:
llm_logger.info("No </think> token found in vocabulary, the model can not do reasoning.")
image_patch_id = self.data_processor.tokenizer.get_vocab().get("<|IMAGE_PLACEHOLDER|>", -1)
line_break_id = self.data_processor.tokenizer.get_vocab().get("\n", -1)
ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port)
ips = None
if self.cfg.ips is not None:
ips = ",".join(self.cfg.ips)
arguments = (
f" --devices {self.cfg.parallel_config.device_ids} {py_script}"
f" --max_num_seqs {self.cfg.scheduler_config.max_num_seqs} --max_model_len {self.cfg.model_config.max_model_len}"
f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}"
f" --model {self.cfg.model_config.model!s}"
f" --device_ids {self.cfg.parallel_config.device_ids}"
f" --tensor_parallel_size {self.cfg.parallel_config.tensor_parallel_size}"
f" --engine_worker_queue_port {ports}"
f" --pod_ip {self.cfg.master_ip}"
f" --block_size {self.cfg.cache_config.block_size}"
f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
f" --eos_tokens_lens {self.data_processor.eos_token_id_len}"
f" --pad_token_id {self.data_processor.pad_token_id}"
f" --engine_pid {self.cfg.parallel_config.engine_worker_queue_port[0]}"
f" --max_num_batched_tokens {self.cfg.scheduler_config.max_num_batched_tokens}"
f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
f" --chunked_moe_size {self.cfg.parallel_config.chunked_moe_size}"
f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
f" --ori_vocab_size {ori_vocab_size}"
f" --think_end_id {think_end_id}"
f" --image_patch_id {image_patch_id}"
f" --line_break_id {line_break_id}"
f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
f" --load_strategy {self.cfg.load_config.load_strategy}"
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
f" --load_choices {self.cfg.load_config.load_choices}"
f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
f" --ips {ips}"
f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
f" --runner {self.cfg.model_config.runner}"
f" --convert {self.cfg.model_config.convert}"
f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
f" --max_logprobs {self.cfg.model_config.max_logprobs}"
f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
)
if self.cfg.structured_outputs_config.logits_processors is not None:
arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
worker_store_true_flag = {
"enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
"enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching,
"enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill,
"do_profile": self.do_profile,
"dynamic_load_weight": self.cfg.load_config.dynamic_load_weight,
"disable_any_whitespace": self.cfg.structured_outputs_config.disable_any_whitespace,
"disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
"use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage,
"disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
"enable_logprob": self.cfg.model_config.enable_logprob,
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
arguments = arguments + f" --{worker_flag}"
worker_default_none_flag = {
"num_gpu_blocks_override": self.cfg.cache_config.num_gpu_blocks_override,
}
for worker_flag, value in worker_default_none_flag.items():
if value:
arguments = arguments + f" --{worker_flag} {value}"
if self.cfg.nnode > 1:
pd_cmd = pd_cmd + f" --ips {ips} --nnodes {len(self.cfg.ips)}"
pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
llm_logger.info(f"Launch worker service command: {pd_cmd}")
p = subprocess.Popen(
pd_cmd,
stdout=subprocess.PIPE,
shell=True,
preexec_fn=os.setsid,
)
return p
def _stop_profile(self):
"""
Stop profiling of the model server and reset variables.
"""
self.do_profile = 0
while self.get_profile_block_num_signal.value[0] == 0:
time.sleep(1)
num_gpu_blocks = self.get_profile_block_num_signal.value[0]
self.cfg.cache_config.reset(num_gpu_blocks)
self.resource_manager.reset_cache_config(self.cfg.cache_config)
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
def check_health(self, time_interval_threashold=30):
"""
Check the health of the model server by checking whether all workers are alive.
"""
if self.worker_healthy_live_signal.value[0]:
elapsed_time = time.time() - self.worker_healthy_live_signal.value[0]
if elapsed_time > time_interval_threashold:
return False, "Worker Service Not Healthy"
return True, ""
def launch_components(self):
if self.cfg.scheduler_config.splitwise_role != "mixed":
# 单机逻辑
self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
self.splitwise_receive_thread.daemon = True
self.splitwise_receive_thread.start()
role = self.cfg.scheduler_config.splitwise_role
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info
request_queues_for_dp_ipc = None
result_queue_for_dp_ipc = None
if self.cfg.scheduler_config.name == "splitwise":
self.scheduler.start(role, host_ip, disaggregate)
elif self.cfg.scheduler_config.name == "dp":
request_queues_for_dp_ipc = []
result_queue_for_dp_ipc = multiprocessing.Queue()
for i in range(self.cfg.parallel_config.data_parallel_size):
request_queues_for_dp_ipc.append(multiprocessing.Queue())
self.scheduler.start(
self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node,
request_queues_for_dp_ipc,
result_queue_for_dp_ipc,
)
if not envs.FD_ENABLE_MULTI_API_SERVER:
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.launched_expert_service_signal.value[0] = 1
self.dp_processed = []
self.dp_engine_worker_queue_server = []
for i in range(
1,
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
):
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
address = (
self.cfg.master_ip,
int(self.cfg.parallel_config.engine_worker_queue_port[i]),
)
else:
address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock"
llm_logger.info(f"dp start queue service {address}")
self.dp_engine_worker_queue_server.append(
EngineWorkerQueue(
address=address,
is_server=True,
num_client=self.cfg.parallel_config.tensor_parallel_size,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
)
from fastdeploy.engine.expert_service import (
start_data_parallel_service,
)
self.dp_processed.append(
multiprocessing.Process(
target=start_data_parallel_service,
args=(
self.cfg,
i,
),
)
)
llm_logger.info(
f"Engine is initialized successfully with {self.cfg.parallel_config.tensor_parallel_size}"
+ f" data parallel id {i}"
)
self.dp_processed[-1].start()
while self.launched_expert_service_signal.value[i] == 0:
time.sleep(1)
def check_worker_initialize_status(self):
"""
Check the initlialize status of workers by stdout logging
"""
def detect_thread():
for line in self.worker_proc.stdout:
line = line.decode("utf-8", errors="ignore")
if self.worker_init_status.get("finished", False):
break
if match := re.search(
r"Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)",
line,
):
self.worker_init_status["weight_loadding"] = eval(match.group(1)) * 1.0 / 100
elif (match := re.search(r"Start load layer (\d+)", line)) or (
match := re.search(r"set state for layer (\d+)", line)
):
progress = eval(match.group(1)) * 1.0 / self.cfg.model_config.num_hidden_layers
self.worker_init_status["layer_loadding"] = progress
if self.worker_init_status["layer_loadding"] == self.cfg.model_config.num_hidden_layers - 1:
self.worker_init_status["finished"] = True
self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
self.checking_worker_status_thread.start()
# display weight loadding progress
with tqdm(total=100, desc="Loading Weights") as pbar:
progress = 0
while progress < 100:
progress = int(self.worker_init_status.get("weight_loadding", 0) * 100)
if self.worker_init_status.get("layer_loadding", 0) > 0 or self._worker_processes_ready():
progress = 100
pbar.update(progress - pbar.n)
pbar.refresh()
time.sleep(0.5)
if self.worker_proc.poll() is not None:
return False
# display layer loadding progress
with tqdm(total=100, desc="Loading Layers") as pbar:
progress = 0
while progress < 100:
progress = int(self.worker_init_status.get("layer_loadding", 0) * 100)
if self._worker_processes_ready():
progress = 100
pbar.update(progress - pbar.n)
pbar.refresh()
time.sleep(0.5)
if self.worker_proc.poll() is not None:
return False
self.worker_init_status["finished"] = True
try:
self.checking_worker_status_thread.join(timeout=1)
except Exception:
pass
return True