Files
FastDeploy/fastdeploy/worker/worker_process.py
chen 823a47e64a [Feature] Support return logprob of generated tokens (#2784)
* online chat support logprobs

* check xpu

* check vl_gpu_model_runner

* only cuda support logprob

* get_worker() check platform

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-07-10 15:47:42 +08:00

843 lines
36 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import argparse
import time
from typing import List
import numpy as np
import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from fastdeploy import envs
from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig,
GraphOptimizationConfig, LoadConfig,
ModelConfig, MoEConfig, MoEPhase,
ParallelConfig, SpeculativeConfig)
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.layers.quantization import \
get_quantization_config
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger, none_or_str
from fastdeploy.worker.worker_base import WorkerBase
logger = get_logger("worker_process", "worker_process.log")
def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
"""
get worker of different device
"""
if fd_config.model_config.enable_logprob and not current_platform.is_cuda():
raise NotImplementedError("Only CUDA platform supports logprob.")
if current_platform.is_cuda():
from fastdeploy.worker.gpu_worker import GpuWorker
return GpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
if current_platform.is_xpu():
from fastdeploy.worker.xpu_worker import XpuWorker
return XpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
class PaddleDisWorkerProc():
"""
Paddle Distrubuted wrapper for fastdeploy.worker.Worker,
for handling single-node multi-GPU tensor parallel.
The wrapper internally executea an event loop that continuously executes requests
in the task queue. Control flow is transmitted by IPC.
"""
def __init__(
self,
fd_config: FDConfig,
) -> None:
"""
Initialize a distributed worker and task queue for single-node multi-GPU setup.
Args:
fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
"""
self.fd_config = fd_config
self.parallel_config = fd_config.parallel_config
# Initialize distributed enviroment
(self.ranks, self.local_rank) = self.init_distributed_enviroment()
assert self.parallel_config.tensor_parallel_degree * self.parallel_config.expert_parallel_degree == self.ranks
self.fd_config.parallel_config.tensor_parallel_rank = \
self.local_rank % self.parallel_config.tensor_parallel_degree
self.fd_config.parallel_config.expert_parallel_rank = \
int(self.local_rank / self.parallel_config.tensor_parallel_degree)
if self.fd_config.parallel_config.use_ep:
self.fd_config.moe_config.num_experts_per_rank = \
self.fd_config.moe_config.num_experts // self.parallel_config.expert_parallel_degree
self.fd_config.moe_config.num_experts_start_offset = \
self.fd_config.parallel_config.expert_parallel_rank * self.fd_config.moe_config.num_experts_per_rank
# For auto TP split
self.fd_config.model_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_degree
self.fd_config.model_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank
self.fd_config.model_config.use_ep = self.parallel_config.use_ep
if self.fd_config.parallel_config.use_ep:
self.fd_config.model_config.num_experts_per_rank = self.fd_config.moe_config.num_experts_per_rank
self.fd_config.model_config.num_experts_start_offset = self.fd_config.moe_config.num_experts_start_offset
# TODO(gongshaotian): Use worker factory to get worker
self.worker = get_worker(fd_config=fd_config,
local_rank=self.local_rank,
rank=self.ranks)
# Initialize task queue
task_address = ('0.0.0.0',
self.parallel_config.engine_worker_queue_port)
self.task_queue = TaskQueue(
address=task_address,
is_server=False,
num_client=self.parallel_config.tensor_parallel_degree,
client_id=self.parallel_config.tensor_parallel_rank,
local_data_parallel_id=self.fd_config.parallel_config.
expert_parallel_rank)
def init_health_status(self) -> None:
"""
Initialize the health status of the worker.
Worker Status:
worker_ready_signal:
worker_healthy_live_signal:
exist_task_signal:
exist_swapped_task_signal:
model_weights_status:
"""
# init worker_ready_signal
array_size = min(
8, self.parallel_config.tensor_parallel_degree *
self.parallel_config.expert_parallel_degree)
workers_ready = np.zeros(shape=[array_size], dtype=np.int32)
self.worker_ready_signal = IPCSignal(
name="worker_ready_signal",
array=workers_ready,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
self.worker_ready_signal.value[self.local_rank % 8] = 1
# init worker_healthy_live_signal
workers_alive = np.zeros(shape=[self.ranks], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal",
array=workers_alive,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
self.worker_healthy_live_signal.value[self.local_rank % 8] = int(
time.time())
# init model_weights_status
workers_model_weights = np.zeros(shape=[1], dtype=np.int32)
self.model_weights_status = IPCSignal(
name="model_weights_status",
array=workers_model_weights,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
# init exist_task_signal
workers_exist_task = np.zeros(
[self.parallel_config.expert_parallel_degree], dtype=np.int32)
self.exist_task_signal = IPCSignal(
name="exist_task_signal",
array=workers_exist_task,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
# init exist_swapped_task_signal
workers_swapped_task = np.zeros(
shape=[self.parallel_config.expert_parallel_degree],
dtype=np.int32)
self.exist_swapped_task_signal = IPCSignal(
name="exist_swapped_task_signal",
array=workers_swapped_task,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
# init exist_prefill_task_signal
exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_prefill_task_signal = IPCSignal(
name="exist_prefill_task_signal",
array=exist_prefill_task_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
def event_loop_ep(self) -> None:
"""
Tmp loop function for ep utill DP is supported
"""
while True:
self.worker_healthy_live_signal.value[self.local_rank] = int(
time.time())
if self.fd_config.parallel_config.tensor_parallel_rank == 0 and self.task_queue.num_tasks(
) > 0:
tasks, read_finish = self.task_queue.get_tasks()
req_dicts = []
for req_dict, bsz in tasks:
num_running_requests = int(bsz)
req_dicts.extend(req_dict)
logger.info(f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, " \
f"num_insert_requests: {len(req_dicts)}")
# Process prefill inputs
self.worker.preprocess_new_task(req_dicts)
# Execute model to generate token. The generated token will be written to the buffer.
# These generated tokens can be obtained through get_output op.
self.worker.execute_model()
def event_loop_normal(self) -> None:
""" Main event loop for Paddle Distrubuted Workers.
TODO(gongshaotian): support remote calling of functions that control worker.
"""
# Currently, only support single node
self.nnode = 1
req_ids = []
while True:
if self.local_rank == 0:
if self.model_weights_status.value[0] != 0:
self.exist_task_signal.value[0] = 2
else:
self.exist_task_signal.value[0] = 0
if self.parallel_config.tensor_parallel_degree > 1:
# Synchronize before updating weights
paddle.distributed.barrier()
self.insert_step = False
self.worker_healthy_live_signal.value[self.local_rank] = int(
time.time())
# The first worker detects whether there are tasks in the task queue
mp_num_per_node = self.ranks / self.nnode
if self.local_rank % mp_num_per_node == 0:
if self.task_queue.num_tasks() > 0:
if self.nnode > 1:
self.task_queue.read_finish_flag.set(1)
else:
self.exist_task_signal.value[
self.fd_config.parallel_config.
expert_parallel_rank] = 1
if self.parallel_config.tensor_parallel_degree > 1:
# Synchronize the signal for other workers
# TODO(@wufeisheng): Split TP group and EP group
paddle.distributed.barrier()
if self.fd_config.load_config.dynamic_load_weight:
if self.exist_task_signal.value[0] == 2:
from fastdeploy.rl.dynamic_weight_manager import \
DynamicWeightManager
DynamicWeightManager.check_model_weights_status(
self.model_weights_status, self.worker.model_runner,
self.parallel_config.engine_pid)
if self.exist_task_signal.value[
self.fd_config.parallel_config.expert_parallel_rank] == 1 or \
self.task_queue.read_finish_flag.get() == 1:
logger.info(f"Rank: {self.local_rank} Detected new requests.")
self.insert_step = True
tasks, read_finish = self.task_queue.get_tasks()
if read_finish:
# Ensure that every worker get the task
self.exist_task_signal.value[self.fd_config.parallel_config
.expert_parallel_rank] = 0
self.task_queue.read_finish_flag.set(0)
req_dicts = []
for req_dict, bsz in tasks:
num_running_requests = int(bsz)
req_dicts.extend(req_dict)
req_ids = [req.request_id for req in req_dicts]
logger.info(f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, " \
f"num_insert_requests: {len(req_dicts)}, req_ids: {req_ids}")
# Process prefill inputs
self.worker.preprocess_new_task(req_dicts)
if not self.worker.model_runner.not_need_stop():
if self.ranks > 1:
paddle.distributed.barrier()
time.sleep(0.001)
continue
# Execute model to generate token. The generated token will be written to the buffer.
# These generated tokens can be obtained through get_output op.
self.worker.execute_model(req_dicts)
self.exist_prefill_task_signal.value[
0] = self.worker.prefill_finished()
def init_distributed_enviroment(self, seed: int = 20) -> List[int]:
""" Initialize Paddle Fleet and get rank of worker """
# Global rank
self.ranks = dist.get_world_size()
dist_strategy = fleet.DistributedStrategy()
dist_strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.ranks,
"pp_degree": 1,
"sharding_degree": 1,
}
# Set control in tensor parallel
dist_strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
fleet.init(is_collective=True, strategy=dist_strategy)
# Local rank
self.local_rank = fleet.worker_index()
return self.ranks, self.local_rank
def determine_num_available_blocks(self) -> None:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
if self.fd_config.parallel_config.do_profile:
# 1. Get available memory(bytes)
available_kv_cache_memory = self.worker.determine_available_memory(
)
logger.info(
f"------- available_kv_cache_memory:{available_kv_cache_memory / 1024**3} GB --------"
)
# 2. Calculate the appropriate number of blocks
model_block_memory_used = self.worker.cal_theortical_kvcache()
num_blocks_local = int(available_kv_cache_memory //
model_block_memory_used)
# NOTE(liuzichang): Too many block will lead to illegal memory access
# We will develop dynamic limits in future.
if num_blocks_local > 40000:
logger.info(
f"------- Reset num_blocks_local {num_blocks_local} to 40000"
)
num_blocks_local = min(40000, num_blocks_local)
logger.info(
f"------- model_block_memory_used:{model_block_memory_used} --------"
)
logger.info(
f"------- num_blocks_local:{num_blocks_local} --------")
logger.info(
f"self.fd_config.parallel_config.do_profile:{self.fd_config.parallel_config.do_profile}"
)
# 3. Send IPCSignal
get_profile_block_num = np.zeros(shape=[self.ranks],
dtype=np.int32)
self.get_profile_block_num_signal = IPCSignal(
name="get_profile_block_num",
array=get_profile_block_num,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False)
self.get_profile_block_num_signal.value[
self.local_rank] = num_blocks_local
# Wait all worker send the signal
while np.any(self.get_profile_block_num_signal.value <= 0):
time.sleep(0.01)
num_blocks_global = self.get_profile_block_num_signal.value.min(
).item()
self.get_profile_block_num_signal.value[
self.local_rank] = num_blocks_global
else:
num_blocks_global = self.fd_config.parallel_config.max_block_num
# NOTE(liuzichang): Too big num_blocks_global will lead to error 700
# 4. Updata share inputs
self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_global)
def init_device(self) -> None:
""" Initialize device and Construct model runner """
self.worker.init_device()
def load_model(self) -> None:
""" Load weights and create model """
self.worker.load_model()
def parse_args():
"""
Parse args from command line
"""
parser = argparse.ArgumentParser("FastDeploy LLM Inference")
parser.add_argument("-m",
"--model_name_or_path",
type=str,
default="./output",
help="model dir")
parser.add_argument("-mbs",
"--max_num_seqs",
type=int,
default=34,
help="max batch size")
parser.add_argument("--total_block_num", type=int, default=2000)
parser.add_argument("--block_size", type=int, default=64)
parser.add_argument("--engine_worker_queue_port", type=int, default=9923)
parser.add_argument("--max_model_len",
type=int,
default=3072,
help="max model len")
parser.add_argument("--device_ids",
type=str,
default="0",
help="cuda visible devices")
parser.add_argument("--dtype",
type=str,
default="bfloat16",
help="input dtype")
parser.add_argument("--enc_dec_block_num",
type=int,
default=1,
help="encoder's decoder num")
parser.add_argument("--kv_cache_ratio",
type=float,
default=0.7,
help="kv cache ratio for input")
parser.add_argument("--first_token_id",
type=int,
default=1,
help="first token id")
parser.add_argument("--gpu_memory_utilization",
type=float,
default=0.9,
help="gpu memory utilization")
parser.add_argument("--engine_pid",
type=int,
default=None,
help="Process ID of engine")
parser.add_argument("--do_profile",
action='store_true',
help="do profile or not")
parser.add_argument("--pad_token_id",
type=int,
default=-1,
help="pad token id")
parser.add_argument("--eos_tokens_lens",
type=int,
default=2,
help="eos token lens")
parser.add_argument("--enable_chunked_prefill",
action='store_true',
help="enable chunked prefill")
parser.add_argument(
"--speculative_method",
default=None,
type=none_or_str,
choices=[
None,
"ngram",
"mtp",
],
)
parser.add_argument(
"--speculative_max_draft_token_num",
default=1,
type=int,
)
parser.add_argument(
"--speculative_model_name_or_path",
default="",
type=str,
)
parser.add_argument(
"--speculative_model_quantization",
default="WINT8",
type=str,
)
parser.add_argument("--max_num_batched_tokens",
type=int,
default=2048,
help="max num batched tokens")
parser.add_argument("--enable_prefix_caching",
action='store_true',
help="enable prefix cache")
parser.add_argument("--splitwise_role",
type=str,
default="mixed",
help="splitwise role")
parser.add_argument("--tensor_parallel_size",
type=int,
default=1,
help="tensor parallel size")
parser.add_argument("--expert_parallel_size",
type=int,
default=1,
help="expert parallel size")
parser.add_argument("--enable_expert_parallell",
action='store_true',
help="enable expert parallell")
parser.add_argument("--ori_vocab_size", type=int, default=None)
parser.add_argument("--quantization",
type=str,
default="None",
help="Quantization name for the model, currentlly support " \
"'wint4', 'wint8'," \
"default is None. The priority of this configuration "\
"is lower than that of the config file. " \
"More complex quantization methods need to be configured via the config file.")
parser.add_argument("--enable_static_graph_inference",
action='store_true',
help="Whether to use static mode; if enabled, " \
"'paddle.to_static' will be used to convert dynamic to static.")
parser.add_argument("--use_cudagraph",
action='store_true',
help="Flags to enable cuda graph.")
parser.add_argument("--max_capture_batch_size",
type=int,
default=64,
help="Maximum Batch Size for Cuda Graph Capture. " \
"If max_capture_batch_size set 64, FastDeploy will capture batch size in [1, 64]")
parser.add_argument("--guided_decoding_backend",
type=str,
default="off",
help="guided decoding backend")
parser.add_argument("--disable_any_whitespace",
action='store_false',
help="Disable any whitespace for guided decoding.")
parser.add_argument("--dynamic_load_weight",
action='store_true',
help="Enable dynamic weight loading strategy")
parser.add_argument(
"--load_strategy",
type=str,
choices=['ipc', 'ipc_no_reshard', 'ipc_snapshot', 'meta', 'normal'],
default='meta',
help="Weight loading method when dynamic loading is enabled: "
"'ipc': real-time IPC streaming with automatic resharding, "
"'ipc_no_reshard': IPC streaming without weight processing, "
"'ipc_snapshot': load from disk snapshot of IPC weights, "
"'meta': provide RL traing worker, no_weights_load"
"'normal':normal load weight")
parser.add_argument("--enable_logprob",
action='store_true',
help="Enable output of token-level log probabilities.")
args = parser.parse_args()
return args
def initialize_fd_config(config) -> FDConfig:
"""Initialize FDConfig from either RolloutModelConfig or argparse.Namespace
Args:
config: Configuration object containing all parameters (either RolloutModelConfig or argparse.Namespace)
Returns:
FDConfig: Initialized FastDeploy configuration object
"""
# Get model config from model directory
model_config_dict, _ = ModelConfig.get_config_dict(config.model_name_or_path)
# Handle MoE related configs
if 'num_experts' in model_config_dict:
model_config_dict['moe_num_experts'] = model_config_dict.pop('num_experts')
if 'num_experts_per_tok' in model_config_dict:
model_config_dict['moe_topk'] = model_config_dict.pop('num_experts_per_tok')
# Set default values for model config
model_config_dict["head_dim"] = model_config_dict.get(
"head_dim", model_config_dict["hidden_size"] // model_config_dict["num_attention_heads"])
model_config_dict["rope_theta"] = model_config_dict.get("rope_theta", 10000.0)
# Create model config object
model_config = ModelConfig.from_dict(model_config_dict)
model_config.head_dim = model_config_dict["head_dim"]
paddle.set_default_dtype(config.dtype)
# Initialize all config components
device_config = DeviceConfig()
decoding_config = DecodingConfig()
speculative_config = SpeculativeConfig()
parallel_config = ParallelConfig()
load_config = LoadConfig()
moe_config = MoEConfig()
# Handle graph optimization config (check for attribute existence for backward compatibility)
enable_static_graph_inference = getattr(config, 'enable_static_graph_inference', False)
use_cudagraph = getattr(config, 'use_cudagraph', False)
max_capture_batch_size = getattr(config, 'max_capture_batch_size', 0)
graph_opt_config = GraphOptimizationConfig(
enable_static_graph_inference,
use_cudagraph,
max_capture_batch_size
)
# Handle quantization (check for attribute existence)
model_config.quantization = getattr(config, 'quantization', None)
# Update speculative config
speculative_config.method = getattr(config, 'speculative_method', None)
speculative_config.num_speculative_tokens = getattr(config, 'speculative_max_draft_token_num', 0)
speculative_config.model_name_or_path = getattr(config, 'speculative_model_name_or_path', None)
speculative_config.quantization = getattr(config, 'speculative_model_quantization', None)
# Update parallel config
parallel_config.engine_pid = getattr(config, 'engine_pid', None)
parallel_config.model_name_or_path = config.model_name_or_path
parallel_config.max_num_seqs = getattr(config, 'max_num_seqs', 0)
parallel_config.max_block_num = getattr(config, 'total_block_num', 0)
parallel_config.block_size = getattr(config, 'block_size', 0)
parallel_config.engine_worker_queue_port = getattr(config, 'engine_worker_queue_port', 0)
parallel_config.max_model_len = getattr(config, 'max_model_len', 0)
model_config.max_seq_len = getattr(config, 'max_model_len', 0)
model_config.max_length = getattr(config, 'max_model_len', 0)
parallel_config.device_ids = getattr(config, 'device_ids', [])
parallel_config.dtype = config.dtype
parallel_config.enc_dec_block_num = getattr(config, 'enc_dec_block_num', 0)
parallel_config.kv_cache_ratio = getattr(config, 'kv_cache_ratio', 1.0)
parallel_config.first_token_id = getattr(config, 'first_token_id', None)
parallel_config.gpu_memory_utilization = getattr(config, 'gpu_memory_utilization', 0.9)
parallel_config.engine_pid = getattr(config, 'engine_pid', None)
parallel_config.do_profile = getattr(config, 'do_profile', False)
parallel_config.dynamic_load_weight = getattr(config, 'dynamic_load_weight', False)
parallel_config.pad_token_id = getattr(config, 'pad_token_id', None)
parallel_config.eos_tokens_lens = getattr(config, 'eos_tokens_lens', 0)
parallel_config.enable_chunked_prefill = getattr(config, 'enable_chunked_prefill', False)
parallel_config.max_num_batched_tokens = getattr(config, 'max_num_batched_tokens', 0)
parallel_config.enable_prefix_caching = getattr(config, 'enable_prefix_caching', False)
parallel_config.use_ep = getattr(config, 'enable_expert_parallell', False)
parallel_config.tensor_parallel_degree = getattr(config, 'tensor_parallel_size', 1)
parallel_config.expert_parallel_degree = getattr(config, 'expert_parallel_size', 1)
parallel_config.splitwise_role = getattr(config, 'splitwise_role', None)
parallel_config.guided_decoding_backend = getattr(config, 'guided_decoding_backend', None)
parallel_config.disable_any_whitespace = getattr(config, 'disable_any_whitespace', False)
# Handle load config (check for environment variable)
load_config.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
# Log parallel config info
logger.info(f"parallel_config.use_ep {parallel_config.use_ep}")
logger.info(f"parallel_config.tensor_parallel_degree {parallel_config.tensor_parallel_degree}")
logger.info(f"splitwise_role {parallel_config.splitwise_role}")
# Set MoE phase based on splitwise role
if parallel_config.splitwise_role == "mixed":
parallel_config.moe_phase = MoEPhase.PREFILL
elif parallel_config.splitwise_role == "prefill":
parallel_config.moe_phase = MoEPhase.PREFILL
elif parallel_config.splitwise_role == "decode":
parallel_config.moe_phase = MoEPhase.DECODER
elif parallel_config.splitwise_role is not None:
raise NotImplementedError
# Handle model architecture specific configurations
num_key_value_heads = model_config_dict.get("num_key_value_heads", -1)
if num_key_value_heads is None:
num_key_value_heads = -1
# Calculate FFN hidden size
if model_config_dict.get("ffn_hidden_size", None) is not None:
ffn_hidden_size = model_config_dict["ffn_hidden_size"]
elif model_config_dict.get("intermediate_size", None) is not None:
ffn_hidden_size = model_config_dict["intermediate_size"]
else:
ffn_hidden_size = 4 * model_config_dict["hidden_size"]
if model_config_dict["hidden_act"].lower() == "swiglu":
if paddle.distributed.get_world_size() > 1:
multiple_of = 8 * model_config_dict["num_attention_heads"]
else:
multiple_of = 4 * model_config_dict["num_attention_heads"]
ffn_hidden_size = multiple_of * (
(int(2 * ffn_hidden_size / 3) + multiple_of - 1) //
multiple_of)
# Get number of layers
num_layers = model_config_dict.get("num_layers", None) or model_config_dict.get(
"num_hidden_layers", None)
if num_layers is None:
raise ValueError(f"num_layers<{num_layers}> is invalid")
use_moe = model_config_dict.get("moe_layer_start_index", num_layers) < num_layers
# Update model config
model_config.ffn_hidden_size = ffn_hidden_size
model_config.num_layers = num_layers
model_config.num_key_value_heads = num_key_value_heads
model_config.start_layer_index = model_config_dict.get("start_layer_index", 0)
# Update MoE config
moe_config.num_experts = model_config_dict.get("moe_num_experts", None)
moe_config.moe_intermediate_size = model_config_dict.get("moe_intermediate_size", None)
moe_config.top_k = model_config_dict.get("moe_k", model_config_dict.get("moe_topk", 8))
moe_config.moe_num_shared_experts = model_config_dict.get("moe_num_shared_experts", 0)
moe_config.moe_layer_start_index = model_config_dict.get("moe_layer_start_index", 0)
moe_config.num_max_dispatch_tokens_per_rank = model_config_dict.get(
"num_max_dispatch_tokens_per_rank", 256)
moe_config.moe_use_aux_free = model_config_dict.get("moe_use_aux_free", False)
# Handle vocabulary size
model_config.ori_vocab_size = model_config_dict.get("vocab_size", -1)
archs = model_config_dict.get("architectures", [])
if "Ernie4_5_ForCausalLM" in archs or "Ernie4_5_MoeForCausalLM" in archs:
model_config.ori_vocab_size = getattr(config, 'ori_vocab_size', model_config.ori_vocab_size)
# Handle DeepseekV3 specific config
if "DeepseekV3ForCausalLM" in model_config_dict.get("architectures", []):
from paddleformers.transformers import AutoConfig
model_config.deepseekv3 = AutoConfig.from_pretrained(
config.model_name_or_path)
# Handle quantization config
quantization_config = model_config_dict.get("quantization_config", None)
if not model_config.is_quantized:
if quantization_config is not None:
if "kv_cache_quant_type" not in quantization_config:
model_config.is_quantized = True
quant_config_name = None
if quantization_config is not None and quantization_config.get(
"quantization", None) is None:
raise ValueError(
"quantization_config should have a key named 'quantization' for specify quant config."
)
if quantization_config is not None:
quant_config_name = quantization_config["quantization"]
elif getattr(config, 'quantization', None) != "None":
quantization_config = {}
quant_config_name = getattr(config, 'quantization', None)
quantization_config["quantization"] = quant_config_name
# Special handling for Ernie models
is_ernie = "Ernie4_5_ForCausalLM" in model_config_dict.get("architectures", []) or \
"Ernie4_5_MoeForCausalLM" in model_config_dict.get("architectures", [])
if use_moe and quant_config_name == "wint4" and is_ernie:
quantization_config["dense_quant_type"] = "wint8"
quantization_config["moe_quant_type"] = "wint4"
quantization_config["quantization"] = "mix_quant"
quant_config_name = "mix_quant"
else:
quant_config_name = None
if quant_config_name is None:
quant_config = None
else:
quant_cls = get_quantization_config(quant_config_name)
quant_config = quant_cls.from_config(quantization_config)
# Log quantization info
logger.info("===========quantization_config==============")
if quant_config is not None:
if model_config.is_quantized:
logger.info(
"Model Status: Offline Quantized (pre-quantized weights loaded)"
)
else:
logger.info(
"Model Status: Original (will apply online quantization)")
logger.info(f"Quantization Method: {getattr(config, 'quantization', 'None')}")
else:
logger.info(
"No quantization config found and use original weight and act dtype."
)
model_config.enable_logprob = config.enable_logprob
model_config.architectures = model_config_dict.get("architectures")
# Update load config
logger.info("===========load_config==============")
load_config.dynamic_load_weight = getattr(config, 'dynamic_load_weight', False)
load_config.load_strategy = getattr(config, 'load_strategy', None)
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
logger.info(f"- Load strategy: {load_config.load_strategy}")
# Create and return FDConfig
fd_config = FDConfig(
model_config=model_config,
parallel_config=parallel_config,
speculative_config=speculative_config,
device_config=device_config,
load_config=load_config,
moe_config=moe_config,
decoding_config=decoding_config,
quant_config=quant_config,
graph_opt_config=graph_opt_config
)
return fd_config
def run_worker_proc() -> None:
"""
start worker process
"""
# Get args form Engine
args = parse_args()
# Get fd_config
fd_config = initialize_fd_config(args)
# Create worker process
worker_proc = PaddleDisWorkerProc(fd_config)
# Initialize device and create model runner
worker_proc.init_device()
# Load model
worker_proc.load_model()
logger.info("determine_num_available_blocks")
worker_proc.determine_num_available_blocks()
# Trigger CUDAGraph capture
worker_proc.worker.graph_optimize_and_warm_up_model()
# Initialize health status
worker_proc.init_health_status()
# Start event loop
if fd_config.parallel_config.use_ep:
# TODO(wufeisheng): Delete this branch
worker_proc.event_loop_ep()
else:
worker_proc.event_loop_normal()
if __name__ == "__main__":
run_worker_proc()