Sync v2.0 version of code to GitHub repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions


@@ -0,0 +1,772 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import argparse
import json
import time
from typing import Tuple
import numpy as np
import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig,
GraphOptimizationConfig, LoadConfig,
ModelConfig, MoEConfig, MoEPhase,
ParallelConfig, SpeculativeConfig)
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.layers.quantization import \
get_quantization_config
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger, none_or_str
from fastdeploy.worker.worker_base import WorkerBase
logger = get_logger("worker_process", "worker_process.log")
def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
"""
    Get the worker implementation for the current device.
"""
if current_platform.is_cuda():
from fastdeploy.worker.gpu_worker import GpuWorker
return GpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    if current_platform.is_xpu():
        from fastdeploy.worker.xpu_worker import XpuWorker
        return XpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    raise NotImplementedError("Unsupported platform; only CUDA and XPU workers are available.")
class PaddleDisWorkerProc:
"""
    Paddle distributed wrapper for fastdeploy.worker.Worker,
    handling single-node multi-GPU tensor parallel.
    The wrapper internally executes an event loop that continuously pulls
    requests from the task queue. Control flow is transmitted by IPC.
"""
def __init__(
self,
fd_config: FDConfig,
):
self.fd_config = fd_config
self.parallel_config = fd_config.parallel_config
        # Initialize distributed environment
        (self.rank, self.local_rank) = self.init_distributed_environment()
        # NOTE: self.rank holds the world size here (see init_distributed_environment)
        assert self.parallel_config.tensor_parallel_degree * self.parallel_config.expert_parallel_degree == self.rank
self.fd_config.parallel_config.tensor_parallel_rank = \
self.local_rank % self.parallel_config.tensor_parallel_degree
self.fd_config.parallel_config.expert_parallel_rank = \
int(self.local_rank / self.parallel_config.tensor_parallel_degree)
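        # Example mapping (assumed TP=2, EP=2, world size 4):
        #   local_rank 0 -> (tp_rank 0, ep_rank 0)    local_rank 1 -> (tp_rank 1, ep_rank 0)
        #   local_rank 2 -> (tp_rank 0, ep_rank 1)    local_rank 3 -> (tp_rank 1, ep_rank 1)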
if self.fd_config.parallel_config.use_ep:
self.fd_config.moe_config.num_experts_per_rank = \
self.fd_config.moe_config.num_experts // self.parallel_config.expert_parallel_degree
self.fd_config.moe_config.num_experts_start_offset = \
self.fd_config.parallel_config.expert_parallel_rank * self.fd_config.moe_config.num_experts_per_rank
self.fd_config.parallel_config.column_cut = False
# For auto TP split
self.fd_config.model_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_degree
self.fd_config.model_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank
self.fd_config.model_config.use_ep = self.parallel_config.use_ep
if self.fd_config.parallel_config.use_ep:
self.fd_config.model_config.num_experts_per_rank = self.fd_config.moe_config.num_experts_per_rank
self.fd_config.model_config.num_experts_start_offset = self.fd_config.moe_config.num_experts_start_offset
# TODO(gongshaotian): Use worker factory to get worker
self.worker = get_worker(fd_config=fd_config,
local_rank=self.local_rank,
rank=self.rank)
# Initialize task queue
task_address = ('0.0.0.0',
self.parallel_config.engine_worker_queue_port)
self.task_queue = TaskQueue(
address=task_address,
is_server=False,
num_client=self.parallel_config.tensor_parallel_degree,
client_id=self.parallel_config.tensor_parallel_rank,
local_data_parallel_id=self.fd_config.parallel_config.
expert_parallel_rank)
def init_health_status(self):
"""
Initialize the health status of the worker.
        Worker status signals:
            worker_ready_signal: per-rank flag set to 1 once this worker is up
            worker_healthy_live_signal: per-rank heartbeat updated with the current time
            exist_task_signal: per-EP-rank flag indicating new tasks in the queue
            exist_swapped_task_signal: per-EP-rank flag indicating swapped-out tasks
            model_weights_status: flag tracking model weight updates
"""
# init worker_ready_signal
array_size = min(
8, self.parallel_config.tensor_parallel_degree *
self.parallel_config.expert_parallel_degree)
workers_ready = np.zeros(shape=[array_size], dtype=np.int32)
self.worker_ready_signal = IPCSignal(
name="worker_ready_signal",
array=workers_ready,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
self.worker_ready_signal.value[self.local_rank % 8] = 1
# init worker_healthy_live_signal
workers_alive = np.zeros(shape=[self.rank], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal",
array=workers_alive,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
self.worker_healthy_live_signal.value[self.local_rank % 8] = int(
time.time())
# init model_weights_status
workers_model_weights = np.zeros(shape=[1], dtype=np.int32)
self.model_weights_status = IPCSignal(
name="model_weights_status",
array=workers_model_weights,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
# init exist_task_signal
workers_exist_task = np.zeros(
[self.parallel_config.expert_parallel_degree], dtype=np.int32)
self.exist_task_signal = IPCSignal(
name="exist_task_signal",
array=workers_exist_task,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
# init exist_swapped_task_signal
workers_swapped_task = np.zeros(
shape=[self.parallel_config.expert_parallel_degree],
dtype=np.int32)
self.exist_swapped_task_signal = IPCSignal(
name="exist_swapped_task_signal",
array=workers_swapped_task,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
# init exist_prefill_task_signal
exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_prefill_task_signal = IPCSignal(
name="exist_prefill_task_signal",
array=exist_prefill_task_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
def event_loop_ep(self):
"""
        Temporary loop function for EP until DP is supported.
"""
while True:
self.worker_healthy_live_signal.value[self.local_rank] = int(
time.time())
if self.fd_config.parallel_config.tensor_parallel_rank == 0 and self.task_queue.num_tasks(
) > 0:
tasks, read_finish = self.task_queue.get_tasks()
req_dicts = []
for req_dict, bsz in tasks:
num_running_requests = int(bsz)
req_dicts.extend(req_dict)
logger.info(f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, " \
f"num_insert_requests: {len(req_dicts)}")
# Process prefill inputs
self.worker.preprocess_new_task(req_dicts)
# Execute model to generate token. The generated token will be written to the buffer.
# These generated tokens can be obtained through get_output op.
self.worker.execute_model()
def event_loop_normal(self):
""" Main event loop for Paddle Distrubuted Workers.
TODO(gongshaotian): support remote calling of functions that control worker.
"""
        # Currently, only a single node is supported
self.nnode = 1
req_ids = []
while True:
if self.parallel_config.tensor_parallel_degree > 1:
# Synchronize before updating weights
paddle.distributed.barrier()
self.insert_step = False
self.worker_healthy_live_signal.value[self.local_rank] = int(
time.time())
# The first worker detects whether there are tasks in the task queue
            mp_num_per_node = self.rank // self.nnode
if self.local_rank % mp_num_per_node == 0:
if self.task_queue.num_tasks() > 0:
if self.nnode > 1:
self.task_queue.read_finish_flag.set(1)
else:
self.exist_task_signal.value[
self.fd_config.parallel_config.
expert_parallel_rank] = 1
if self.parallel_config.tensor_parallel_degree > 1:
# Synchronize the signal for other workers
# TODO(@wufeisheng): Split TP group and EP group
paddle.distributed.barrier()
if self.exist_task_signal.value[
self.fd_config.parallel_config.expert_parallel_rank] == 1 or \
self.task_queue.read_finish_flag.get() == 1:
logger.info(f"Rank: {self.local_rank} Detected new requests.")
self.insert_step = True
tasks, read_finish = self.task_queue.get_tasks()
if read_finish:
                    # Ensure that every worker gets the task
self.exist_task_signal.value[self.fd_config.parallel_config
.expert_parallel_rank] = 0
self.task_queue.read_finish_flag.set(0)
req_dicts = []
for req_dict, bsz in tasks:
num_running_requests = int(bsz)
req_dicts.extend(req_dict)
req_ids = [req.request_id for req in req_dicts]
logger.info(f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, " \
f"num_insert_requests: {len(req_dicts)}, req_ids: {req_ids}")
# Process prefill inputs
self.worker.preprocess_new_task(req_dicts)
if not self.worker.model_runner.not_need_stop():
if self.rank > 1:
paddle.distributed.barrier()
time.sleep(0.001)
continue
# Execute model to generate token. The generated token will be written to the buffer.
# These generated tokens can be obtained through get_output op.
self.worker.execute_model(req_dicts)
self.exist_prefill_task_signal.value[
0] = self.worker.prefill_finished()
    def init_distributed_environment(self, seed=20) -> Tuple[int, int]:
        """ Initialize Paddle Fleet and get the rank of this worker """
        # World size (stored as self.rank and used as the total worker count)
        self.rank = dist.get_world_size()
dist_strategy = fleet.DistributedStrategy()
dist_strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.rank,
"pp_degree": 1,
"sharding_degree": 1,
}
# Set control in tensor parallel
dist_strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
fleet.init(is_collective=True, strategy=dist_strategy)
# Local rank
self.local_rank = fleet.worker_index()
return self.rank, self.local_rank
    def determine_num_available_blocks(self):
        """
        Profile available memory and determine the number of KV cache blocks.
        """
if self.fd_config.parallel_config.do_profile:
# 1. Get available memory(bytes)
available_kv_cache_memory = self.worker.determine_available_memory(
)
logger.info(
f"------- available_kv_cache_memory:{available_kv_cache_memory / 1024**3} GB --------"
)
# 2. Calculate the appropriate number of blocks
model_block_memory_used = self.worker.cal_theortical_kvcache()
num_blocks_local = int(available_kv_cache_memory //
model_block_memory_used)
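            # Worked example with assumed numbers: 30 GiB available memory and a
            # theoretical 1.5 MiB per block give
            # int(30 * 1024**3 // (1.5 * 1024**2)) == 20480 blocks, which the
            # cap below then resets to 20000.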
            # NOTE(liuzichang): Too many blocks will lead to illegal memory access.
            # We will develop dynamic limits in the future.
if num_blocks_local > 20000:
logger.info(
f"------- Reset num_blocks_local {num_blocks_local} to 20000"
)
num_blocks_local = min(20000, num_blocks_local)
logger.info(
f"------- model_block_memory_used:{model_block_memory_used} --------"
)
logger.info(
f"------- num_blocks_local:{num_blocks_local} --------")
logger.info(
f"self.fd_config.parallel_config.do_profile:{self.fd_config.parallel_config.do_profile}"
)
# 3. Send IPCSignal
get_profile_block_num = np.zeros(shape=[self.rank], dtype=np.int32)
self.get_profile_block_num_signal = IPCSignal(
name="get_profile_block_num",
array=get_profile_block_num,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
self.get_profile_block_num_signal.value[
self.local_rank] = num_blocks_local
            # Wait until all workers have sent the signal
while np.any(self.get_profile_block_num_signal.value <= 0):
time.sleep(0.01)
num_blocks_global = self.get_profile_block_num_signal.value.min(
).item()
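            # Consensus example: if the ranks reported [20480, 20011], every rank
            # adopts min() == 20011 so all workers allocate the same number of blocks.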
self.get_profile_block_num_signal.value[
self.local_rank] = num_blocks_global
else:
num_blocks_global = self.fd_config.parallel_config.max_block_num
        # NOTE(liuzichang): Too big a num_blocks_global will lead to error 700
        # 4. Update shared inputs
self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_global)
    def init_device(self):
        """ Initialize the device for this worker """
        self.worker.init_device()
    def load_model(self):
        """ Load the model weights onto the device """
        self.worker.load_model()
def parse_args():
"""
Parse args from command line
"""
parser = argparse.ArgumentParser("FastDeploy LLM Inference")
parser.add_argument("-m",
"--model_name_or_path",
type=str,
default="./output",
help="model dir")
parser.add_argument("-mbs",
"--max_num_seqs",
type=int,
default=34,
help="max batch size")
parser.add_argument("--total_block_num", type=int, default=2000)
parser.add_argument("--block_size", type=int, default=64)
parser.add_argument("--engine_worker_queue_port", type=int, default=9923)
parser.add_argument("--max_model_len",
type=int,
default=3072,
help="max model len")
parser.add_argument("--device_ids",
type=str,
default="0",
help="cuda visible devices")
parser.add_argument("--dtype",
type=str,
default="bfloat16",
help="input dtype")
parser.add_argument("--enc_dec_block_num",
type=int,
default=1,
help="encoder's decoder num")
parser.add_argument("--kv_cache_ratio",
type=float,
default=0.7,
help="kv cache ratio for input")
parser.add_argument("--first_token_id",
type=int,
default=1,
help="first token id")
parser.add_argument("--gpu_memory_utilization",
type=float,
default=0.9,
help="gpu memory utilization")
parser.add_argument("--engine_pid",
type=int,
default=None,
help="Process ID of engine")
parser.add_argument("--do_profile",
action='store_true',
help="do profile or not")
parser.add_argument("--dynamic_load_weight",
action='store_true',
help="dynamic load weight or not")
parser.add_argument("--pad_token_id",
type=int,
default=-1,
help="pad token id")
parser.add_argument("--eos_tokens_lens",
type=int,
default=2,
help="eos token lens")
parser.add_argument("--enable_chunked_prefill",
action='store_true',
help="enable chunked prefill")
parser.add_argument(
"--speculative_method",
default=None,
type=none_or_str,
choices=[
None,
"ngram",
"mtp",
],
)
parser.add_argument(
"--speculative_max_draft_token_num",
default=1,
type=int,
)
parser.add_argument(
"--speculative_model_name_or_path",
default="",
type=str,
)
parser.add_argument(
"--speculative_model_quantization",
default="WINT8",
type=str,
)
parser.add_argument(
"--attention_backend",
default="APPEND_ATTN",
type=str,
choices=[
"APPEND_ATTN",
],
)
parser.add_argument("--max_num_batched_tokens",
type=int,
default=2048,
help="max num batched tokens")
parser.add_argument("--enable_prefix_caching",
action='store_true',
help="enable prefix cache")
parser.add_argument("--splitwise_role",
type=str,
default="mixed",
help="splitwise role")
parser.add_argument("--tensor_parallel_size",
type=int,
default=1,
help="tensor parallel size")
parser.add_argument("--expert_parallel_size",
type=int,
default=1,
help="expert parallel size")
parser.add_argument("--enable_expert_parallell",
action='store_true',
help="enable expert parallell")
parser.add_argument("--ori_vocab_size", type=int, default=None)
parser.add_argument("--quantization",
type=str,
default="",
help="Quantization name for the model, currentlly support " \
"'wint4', 'wint8'," \
"default is None. The priority of this configuration "\
"is lower than that of the config file. " \
"More complex quantization methods need to be configured via the config file.")
parser.add_argument("--enable_static_graph_inference",
action='store_true',
help="Whether to use static mode; if enabled, " \
"'paddle.to_static' will be used to convert dynamic to static.")
parser.add_argument("--use_cudagraph",
action='store_true',
help="Flags to enable cuda graph.")
parser.add_argument("--max_capture_batch_size",
type=int,
default=64,
help="Maximum Batch Size for Cuda Graph Capture. " \
"If max_capture_batch_size set 64, FastDeploy will capture batch size in [1, 64]")
parser.add_argument("--guided_decoding_backend",
type=str,
default="off",
help="guided decoding backend")
parser.add_argument("--disable_any_whitespace",
action='store_false',
help="Disable any whitespace for guided decoding.")
args = parser.parse_args()
return args
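# Example launch (a hypothetical invocation; paths, ports and GPU ids are
# placeholders), typically started by the engine via paddle.distributed.launch:
#   python -m paddle.distributed.launch --gpus "0,1" worker_process.py \
#       --model_name_or_path ./output --tensor_parallel_size 2 \
#       --engine_worker_queue_port 9923 --dtype bfloat16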
def initialize_fd_config(args) -> FDConfig:
"""Initialize FDConfig
TODO(gongshaotian): Unified all configs to FDConfig
"""
# NOTE(gongshaotian): From build stream line model
config, _ = ModelConfig.get_config_dict(args.model_name_or_path)
if 'num_experts' in config:
config['moe_num_experts'] = config.pop('num_experts')
if 'num_experts_per_tok' in config:
config['moe_topk'] = config.pop('num_experts_per_tok')
config["head_dim"] = config.get(
"head_dim", config["hidden_size"] // config["num_attention_heads"])
config["rope_theta"] = config.get("rope_theta", 10000.0)
model_config = ModelConfig.from_dict(config)
    # TODO: set `head_dim` again, because the `ModelConfig` class does not accept head_dim as an input.
model_config.head_dim = config["head_dim"]
paddle.set_default_dtype(args.dtype)
device_config = DeviceConfig()
    decoding_config = DecodingConfig()
speculative_config = SpeculativeConfig()
parallel_config = ParallelConfig()
load_config = LoadConfig()
moe_config = MoEConfig()
graph_opt_config = GraphOptimizationConfig(
args.enable_static_graph_inference, args.use_cudagraph,
args.max_capture_batch_size)
model_config.quantization = args.quantization
    # Update speculative config
speculative_config.method = args.speculative_method
speculative_config.num_speculative_tokens = args.speculative_max_draft_token_num
speculative_config.model_name_or_path = args.speculative_model_name_or_path
speculative_config.quantization = args.speculative_model_quantization
# Update parallel config
parallel_config.engine_pid = args.engine_pid
parallel_config.model_name_or_path = args.model_name_or_path
parallel_config.max_num_seqs = args.max_num_seqs
parallel_config.max_block_num = args.total_block_num
parallel_config.block_size = args.block_size
parallel_config.engine_worker_queue_port = args.engine_worker_queue_port
parallel_config.max_model_len = args.max_model_len
model_config.max_seq_len = args.max_model_len
model_config.max_length = args.max_model_len
parallel_config.device_ids = args.device_ids
parallel_config.dtype = args.dtype
parallel_config.enc_dec_block_num = args.enc_dec_block_num
parallel_config.kv_cache_ratio = args.kv_cache_ratio
parallel_config.first_token_id = args.first_token_id
parallel_config.gpu_memory_utilization = args.gpu_memory_utilization
parallel_config.do_profile = args.do_profile
parallel_config.dynamic_load_weight = args.dynamic_load_weight
parallel_config.pad_token_id = args.pad_token_id
parallel_config.eos_tokens_lens = args.eos_tokens_lens
parallel_config.enable_chunked_prefill = args.enable_chunked_prefill
parallel_config.attention_backend = args.attention_backend
parallel_config.max_num_batched_tokens = args.max_num_batched_tokens
parallel_config.enable_prefix_caching = args.enable_prefix_caching
parallel_config.use_ep = args.enable_expert_parallell
parallel_config.tensor_parallel_degree = args.tensor_parallel_size
parallel_config.expert_parallel_degree = args.expert_parallel_size
parallel_config.splitwise_role = args.splitwise_role
parallel_config.guided_decoding_backend = args.guided_decoding_backend
parallel_config.disable_any_whitespace = args.disable_any_whitespace
logger.info(f"parallel_config.use_ep {parallel_config.use_ep}")
logger.info(
f"parallel_config.tensor_parallel_degree {parallel_config.tensor_parallel_degree}"
)
logger.info(f"args.splitwise_role {args.splitwise_role}")
if args.splitwise_role == "mixed":
parallel_config.moe_phase = MoEPhase.PREFILL
elif args.splitwise_role == "prefill":
parallel_config.moe_phase = MoEPhase.PREFILL
elif args.splitwise_role == "decode":
parallel_config.moe_phase = MoEPhase.DECODER
else:
        raise NotImplementedError(f"Unsupported splitwise_role: {args.splitwise_role}")
num_key_value_heads = config.get("num_key_value_heads", -1)
if num_key_value_heads is None:
num_key_value_heads = -1
if config.get("ffn_hidden_size", None) is not None:
ffn_hidden_size = config["ffn_hidden_size"]
elif config.get("intermediate_size", None) is not None:
ffn_hidden_size = config["intermediate_size"]
else:
ffn_hidden_size = 4 * config["hidden_size"]
if config["hidden_act"].lower() == "swiglu":
if paddle.distributed.get_world_size() > 1:
multiple_of = 8 * config["num_attention_heads"]
else:
multiple_of = 4 * config["num_attention_heads"]
ffn_hidden_size = multiple_of * (
(int(2 * ffn_hidden_size / 3) + multiple_of - 1) //
multiple_of)
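    # Worked example with assumed numbers: hidden_size=4096, num_attention_heads=32
    # and world size > 1 give multiple_of = 8 * 32 = 256 and, starting from
    # ffn_hidden_size = 4 * 4096 = 16384,
    # int(2 * 16384 / 3) = 10922 -> 256 * ((10922 + 255) // 256) = 11008.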
num_layers = config.get("num_layers", None) or config.get(
"num_hidden_layers", None)
if num_layers is None:
raise ValueError(f"num_layers<{num_layers}> is invalid")
use_moe = config.get("moe_layer_start_index", num_layers) < num_layers
model_config.ffn_hidden_size = ffn_hidden_size
model_config.num_layers = num_layers
model_config.num_key_value_heads = num_key_value_heads
model_config.start_layer_index = config.get("start_layer_index", 0)
moe_config.num_experts = config.get("moe_num_experts", None)
moe_config.moe_intermediate_size = config.get("moe_intermediate_size",
None)
moe_config.top_k = config.get("moe_k", config.get("moe_topk", 8))
moe_config.moe_num_shared_experts = config.get("moe_num_shared_experts", 0)
moe_config.moe_layer_start_index = config.get("moe_layer_start_index", 0)
moe_config.num_max_dispatch_tokens_per_rank = config.get(
"num_max_dispatch_tokens_per_rank", 256)
model_config.ori_vocab_size = config.get("vocab_size", -1)
if "Ernie4_5_ForCausalLM" in config.get("architectures"):
model_config.ori_vocab_size = args.ori_vocab_size
quantization_config = config.get("quantization_config", None)
# Note(@wufeisheng): The `is_quantized` flag should be explicitly set to `true`
# when the weights are actually quantized offline. For backward compatibility
# with preview logic:
# - If `quantization_config` is provided but `is_quantized` is not explicitly set,
# the value of `is_quantized` will be determined by whether `kv_cache_quant_type`
# has been configured.
if not model_config.is_quantized:
if quantization_config is not None:
if "kv_cache_quant_type" not in quantization_config:
model_config.is_quantized = True
quant_config_name = None
if quantization_config is not None and quantization_config.get(
"quantization", None) is None:
raise ValueError(
"quantization_config should have a key named 'quantization' for specify quant config."
)
if quantization_config is not None:
quant_config_name = quantization_config["quantization"]
elif args.quantization != "None":
quantization_config = {}
quant_config_name = args.quantization
if use_moe and quant_config_name == "wint4":
quantization_config["dense_quant_type"] = "wint8"
quantization_config["moe_quant_type"] = "wint4"
quant_config_name = "mix_quant"
else:
quant_config_name = None
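    # Example for the branch above: `--quantization wint4` on a MoE model (with
    # no quantization_config in the checkpoint) yields
    #   quantization_config == {"dense_quant_type": "wint8", "moe_quant_type": "wint4"}
    # and quant_config_name == "mix_quant".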
if quant_config_name is None:
quant_config = None
else:
quant_cls = get_quantization_config(quant_config_name)
quant_config = quant_cls.from_config(quantization_config)
logger.info("===========quantization_config==============")
if quant_config is not None:
if model_config.is_quantized:
logger.info(
"=====The currently loaded model is an offline quantized model====="
)
else:
logger.info("=====The currently loaded model is the original model\
The model will be quantized online=====")
logger.info(f"{json.dumps(quantization_config, indent=2)}")
else:
logger.info(
"No quantization config found and use original weight and act dtype."
)
logger.info("============================================")
model_config.architectures = config.get("architectures")
fd_config = FDConfig(model_config=model_config,
parallel_config=parallel_config,
speculative_config=speculative_config,
device_config=device_config,
load_config=load_config,
moe_config=moe_config,
decoding_config=decoding_config,
quant_config=quant_config,
graph_opt_config=graph_opt_config)
return fd_config
def run_worker_proc():
"""
start worker process
"""
# Get args form Engine
args = parse_args()
# Get fd_config
fd_config = initialize_fd_config(args)
# Create worker process
worker_proc = PaddleDisWorkerProc(fd_config)
# Initialize device and create model runner
worker_proc.init_device()
# Load model
worker_proc.load_model()
logger.info("determine_num_available_blocks")
worker_proc.determine_num_available_blocks()
# Trigger CUDAGraph capture
worker_proc.worker.graph_optimize_and_warm_up_model()
# Initialize health status
worker_proc.init_health_status()
# Start event loop
if fd_config.parallel_config.use_ep:
# TODO(wufeisheng): Delete this branch
worker_proc.event_loop_ep()
else:
worker_proc.event_loop_normal()
if __name__ == "__main__":
run_worker_proc()