"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import argparse
import json
import time
from typing import List, Tuple

import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

from fastdeploy.config import (
    DecodingConfig,
    DeviceConfig,
    ErnieArchitectures,
    FDConfig,
    GraphOptimizationConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
    SpeculativeConfig,
)
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.layers.quantization import get_quantization_config
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger, none_or_str
from fastdeploy.worker.worker_base import WorkerBase

logger = get_logger("worker_process", "worker_process.log")


def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
    """
    Build the worker implementation that matches the current platform.

    Args:
        fd_config (FDConfig): Full FastDeploy configuration, forwarded to the worker.
        local_rank (int): Forwarded to the worker constructor as ``local_rank``.
        rank (int): Forwarded to the worker constructor as ``rank``.

    Returns:
        WorkerBase: A DCU, GPU, XPU, Iluvatar or GCU worker.

    Raises:
        NotImplementedError: If logprob is enabled on a non-CUDA platform, or if
            the current platform matches none of the supported backends.
    """
    if fd_config.model_config.enable_logprob and not current_platform.is_cuda():
        raise NotImplementedError("Only CUDA platform supports logprob.")

    # Worker modules are imported lazily so that only the backend actually in
    # use has to be importable.
    # NOTE: the DCU check intentionally precedes the CUDA check, as in the
    # original ordering.
    if current_platform.is_dcu():
        from fastdeploy.worker.dcu_worker import DcuWorker

        return DcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    if current_platform.is_cuda():
        from fastdeploy.worker.gpu_worker import GpuWorker

        return GpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    if current_platform.is_xpu():
        from fastdeploy.worker.xpu_worker import XpuWorker

        return XpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    if current_platform.is_iluvatar():
        from fastdeploy.worker.iluvatar_worker import IluvatarWorker

        return IluvatarWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    if current_platform.is_gcu():
        from fastdeploy.worker.gcu_worker import GcuWorker

        return GcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)

    # Previously this fell through and returned None, which only surfaced later
    # as an AttributeError on the first use of the worker.
    raise NotImplementedError("No worker implementation is available for the current platform.")


def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
    """
    Initialize Paddle Fleet and report this worker's position in the job.

    All ranks are placed in a single model-parallel group (``mp_degree`` equals
    the world size; dp/pp/sharding degrees are 1).

    Args:
        seed (int): Value used for ``tensor_init_seed`` in the tensor-parallel
            configuration.

    Returns:
        Tuple[int, int]: ``(ranks, local_rank)`` where ``ranks`` is the world
        size and ``local_rank`` is ``fleet.worker_index()``.
    """
    # World size
    ranks = dist.get_world_size()
    dist_strategy = fleet.DistributedStrategy()

    dist_strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": ranks,
        "pp_degree": 1,
        "sharding_degree": 1,
    }

    # Set control in tensor parallel
    dist_strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
    fleet.init(is_collective=True, strategy=dist_strategy)

    # Index of this worker
    local_rank = fleet.worker_index()

    return ranks, local_rank


def update_fd_config_for_mm(fd_config: FDConfig) -> None:
    """
    Fill in the multimodal-specific fields of ``fd_config.model_config``.

    Does nothing unless ``fd_config.model_config.enable_mm`` is set. Otherwise a
    tokenizer is loaded and used to resolve special token ids, and the
    tensor-parallel / sequence-parallel settings are copied from
    ``parallel_config`` onto ``model_config``.

    Args:
        fd_config (FDConfig): Configuration object, modified in place.
    """
    if not fd_config.model_config.enable_mm:
        return

    tokenizer = ErnieBotTokenizer.from_pretrained(
        fd_config.parallel_config.model_name_or_path,
        model_max_length=fd_config.parallel_config.max_model_len,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.ignored_index = -100
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token

    fd_config.model_config.tensor_parallel_degree = fd_config.parallel_config.tensor_parallel_size
    fd_config.model_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
    vision_config = fd_config.model_config.vision_config
    vision_config.dtype = fd_config.model_config.dtype

    vocab = tokenizer.get_vocab()
    fd_config.model_config.im_patch_id = vocab["<|IMAGE_PLACEHOLDER|>"]
    # NOTE(review): the lookup key below is the empty string. Given the target
    # attribute name it looks like a special "end of thinking" token was
    # intended -- confirm the key against the tokenizer vocabulary.
    fd_config.model_config.think_end_id = vocab[""]
    fd_config.model_config.sequence_parallel = fd_config.parallel_config.sequence_parallel


class PaddleDisWorkerProc:
    """
    Paddle distributed wrapper around a FastDeploy worker, for handling
    single-node multi-GPU tensor parallelism.

    The wrapper runs an event loop that continuously pulls requests from the
    task queue and executes them. Control flow is exchanged through IPC signals.
    """

    def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0) -> None:
        """
        Create the platform worker and connect to the engine's task queue.

        Args:
            fd_config (FDConfig): Inference configuration.
            ranks (int): Number of workers in the job (world size).
            local_rank (int): Index of this worker.
        """
        self.ranks = ranks
        self.local_rank = local_rank
        self.fd_config = fd_config
        self.parallel_config = fd_config.parallel_config

        # TODO(gongshaotian): Use worker factory to get worker
        # NOTE: `rank` receives `self.ranks` (the world size) here, unchanged
        # from the original call.
        self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, rank=self.ranks)

        # Initialize task queue (this process is a client of the engine's queue)
        task_address = (
            self.parallel_config.pod_ip,
            self.parallel_config.engine_worker_queue_port,
        )
        self.task_queue = TaskQueue(
            address=task_address,
            is_server=False,
            num_client=self.parallel_config.tensor_parallel_size,
            client_id=self.parallel_config.tensor_parallel_rank,
            local_data_parallel_id=self.parallel_config.expert_parallel_rank,
        )

    def _heartbeat_slot(self) -> int:
        """Return this worker's index into the per-node ready/liveness arrays."""
        return self.local_rank % self.max_chips_per_node

    def init_health_status(self) -> None:
        """
        Attach to the IPC signals shared with the engine and publish readiness.

        All signals are opened with ``create=False`` and are suffixed with the
        engine PID.

        Signals attached:
            worker_ready_signal: this worker's slot is set to 1.
            worker_healthy_live_signal: this worker's slot is set to the
                current Unix timestamp.
            model_weights_status: attached only.
            exist_task_signal: attached only.
            exist_swapped_task_signal: attached only.
            exist_prefill_task_signal: attached only.
        """
        engine_pid = self.parallel_config.engine_pid

        # init worker_ready_signal
        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
        array_size = min(
            self.max_chips_per_node,
            self.parallel_config.tensor_parallel_size * self.parallel_config.expert_parallel_size,
        )
        workers_ready = np.zeros(shape=[array_size], dtype=np.int32)
        self.worker_ready_signal = IPCSignal(
            name="worker_ready_signal",
            array=workers_ready,
            dtype=np.int32,
            suffix=engine_pid,
            create=False,
        )
        self.worker_ready_signal.value[self._heartbeat_slot()] = 1

        # init worker_healthy_live_signal
        workers_alive = np.zeros(shape=[array_size], dtype=np.int32)
        self.worker_healthy_live_signal = IPCSignal(
            name="worker_healthy_live_signal",
            array=workers_alive,
            dtype=np.int32,
            suffix=engine_pid,
            create=False,
        )
        # Use the same per-node modulus as the ready signal (was a hard-coded 8,
        # which disagrees with the 16-slot layout used on Iluvatar).
        self.worker_healthy_live_signal.value[self._heartbeat_slot()] = int(time.time())

        # init model_weights_status
        workers_model_weights = np.zeros(shape=[1], dtype=np.int32)
        self.model_weights_status = IPCSignal(
            name="model_weights_status",
            array=workers_model_weights,
            dtype=np.int32,
            suffix=engine_pid,
            create=False,
        )

        # init exist_task_signal
        workers_exist_task = np.zeros([self.parallel_config.expert_parallel_size], dtype=np.int32)
        self.exist_task_signal = IPCSignal(
            name="exist_task_signal",
            array=workers_exist_task,
            dtype=np.int32,
            suffix=engine_pid,
            create=False,
        )

        # init exist_swapped_task_signal
        workers_swapped_task = np.zeros(shape=[self.parallel_config.expert_parallel_size], dtype=np.int32)
        self.exist_swapped_task_signal = IPCSignal(
            name="exist_swapped_task_signal",
            array=workers_swapped_task,
            dtype=np.int32,
            suffix=engine_pid,
            create=False,
        )

        # init exist_prefill_task_signal
        exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32)
        self.exist_prefill_task_signal = IPCSignal(
            name="exist_prefill_task_signal",
            array=exist_prefill_task_signal_data,
            dtype=np.int32,
            suffix=engine_pid,
            create=False,
        )

    def event_loop_ep(self) -> None:
        """
        Temporary event loop for expert parallelism, until DP is supported.

        Never returns. Each iteration refreshes the liveness timestamp, lets
        tensor-parallel rank 0 pull and preprocess any queued tasks, and then
        runs one model step.
        """
        while True:
            self.worker_healthy_live_signal.value[self._heartbeat_slot()] = int(time.time())

            if self.fd_config.parallel_config.tensor_parallel_rank == 0 and self.task_queue.num_tasks() > 0:
                tasks, read_finish = self.task_queue.get_tasks()

                req_dicts = []
                # Initialised so the log line below cannot hit an unbound name
                # if `tasks` comes back empty.
                num_running_requests = 0
                for req_dict, bsz in tasks:
                    num_running_requests = int(bsz)
                    req_dicts.extend(req_dict)
                logger.info(
                    f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, "
                    f"num_insert_requests: {len(req_dicts)}"
                )
                # Process prefill inputs
                self.worker.preprocess_new_task(req_dicts)

            # Execute model to generate token. The generated token will be written to the buffer.
            # These generated tokens can be obtained through get_output op.
            self.worker.execute_model()

    def event_loop_normal(self) -> None:
        """
        Main event loop for Paddle distributed workers.

        Never returns. Each iteration: rank 0 mirrors the model-weights status
        into ``exist_task_signal``; the first worker of each node flags pending
        tasks; all workers synchronise via barriers when tensor parallelism is
        in use; new tasks are fetched and preprocessed; and one model step is
        executed unless the model runner reports there is nothing to run.

        TODO(gongshaotian): support remote calling of functions that control worker.
        """
        # Currently, only support single node
        self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
        mp_num_per_node = self.parallel_config.tensor_parallel_size // self.nnode
        req_ids = []
        ep_rank = self.fd_config.parallel_config.expert_parallel_rank

        while True:
            if self.local_rank == 0:
                # 2 marks "weights status is non-zero" for the dynamic-load
                # check further down; 0 clears it.
                if self.model_weights_status.value[0] != 0:
                    self.exist_task_signal.value[0] = 2
                else:
                    self.exist_task_signal.value[0] = 0

            if self.parallel_config.tensor_parallel_size > 1:
                # Synchronize before updating weights
                paddle.distributed.barrier()

            self.insert_step = False

            # Index with the per-node modulus so the write stays inside the
            # array sized in init_health_status (was the bare local_rank).
            self.worker_healthy_live_signal.value[self._heartbeat_slot()] = int(time.time())

            # The first worker detects whether there are tasks in the task queue
            if self.local_rank % mp_num_per_node == 0:
                if self.task_queue.num_tasks() > 0:
                    # VL only support 1 batch to prefill
                    if not self.fd_config.model_config.enable_mm or self.worker.prefill_finished():
                        if self.nnode > 1:
                            self.task_queue.read_finish_flag.set(1)
                        else:
                            self.exist_task_signal.value[ep_rank] = 1

            if self.parallel_config.tensor_parallel_size > 1:
                # Synchronize the signal for other workers
                # TODO(@wufeisheng): Split TP group and EP group
                paddle.distributed.barrier()

            if self.fd_config.load_config.dynamic_load_weight:
                if self.exist_task_signal.value[0] == 2:
                    from fastdeploy.rl.dynamic_weight_manager import (
                        DynamicWeightManager,
                    )

                    DynamicWeightManager.check_model_weights_status(
                        self.model_weights_status,
                        self.worker.model_runner,
                        self.parallel_config.engine_pid,
                    )

            if self.exist_task_signal.value[ep_rank] == 1 or self.task_queue.read_finish_flag.get() == 1:
                logger.info(f"Rank: {self.local_rank} Detected new requests.")
                self.insert_step = True

                tasks, read_finish = self.task_queue.get_tasks()
                if read_finish:
                    # Ensure that every worker get the task
                    self.exist_task_signal.value[ep_rank] = 0
                    self.task_queue.read_finish_flag.set(0)

                req_dicts = []
                # Initialised so the log line below cannot hit an unbound name
                # if `tasks` comes back empty.
                num_running_requests = 0
                for req_dict, bsz in tasks:
                    num_running_requests = int(bsz)
                    req_dicts.extend(req_dict)

                req_ids = [req.request_id for req in req_dicts]

                logger.info(
                    f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, "
                    f"num_insert_requests: {len(req_dicts)}, req_ids: {req_ids}"
                )

                # Process prefill inputs
                self.worker.preprocess_new_task(req_dicts)

            if not self.worker.model_runner.not_need_stop():
                if self.ranks > 1:
                    paddle.distributed.barrier()

                time.sleep(0.001)
                continue

            # Execute model to generate token. The generated token will be written to the buffer.
            # These generated tokens can be obtained through get_output op.
            # NOTE(review): `req_dicts` keeps its value from the most recent
            # insert step -- confirm that passing it on later steps is intended.
            self.worker.execute_model(req_dicts)
            if not self.fd_config.model_config.enable_mm:
                self.exist_prefill_task_signal.value[0] = self.worker.prefill_finished()

    def initialize_kv_cache(self) -> None:
        """
        Determine the number of KV-cache blocks and initialize the cache.

        When ``do_profile`` is set, the worker reports the memory available for
        KV cache and the memory used per block; their quotient (capped at
        40000) is the local block count. With more than one rank the minimum
        across ranks is taken, and rank 0 publishes the result through the
        ``get_profile_block_num`` IPC signal. Without profiling, the configured
        ``total_block_num`` is used.

        If prefix caching is enabled or ``splitwise_role`` is not ``"mixed"``,
        this blocks until ``launched_cache_manager_signal`` becomes positive.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Raises:
            ValueError: If profiling yields a non-positive block count.
        """
        if self.fd_config.parallel_config.do_profile:
            # 1. Get available memory(bytes)
            available_kv_cache_memory = self.worker.determine_available_memory()
            logger.info(f"------- available_kv_cache_memory:{available_kv_cache_memory / 1024**3} GB --------")

            # 2. Calculate the appropriate number of blocks
            model_block_memory_used = self.worker.cal_theortical_kvcache()
            num_blocks_local = int(available_kv_cache_memory // model_block_memory_used)
            # NOTE(liuzichang): Too many block will lead to illegal memory access
            # We will develop dynamic limits in future.
            if num_blocks_local > 40000:
                logger.info(f"------- Reset num_blocks_local {num_blocks_local} to 40000")
                num_blocks_local = 40000
            logger.info(f"------- model_block_memory_used:{model_block_memory_used} --------")
            logger.info(f"------- num_blocks_local:{num_blocks_local} --------")

            if num_blocks_local <= 0:
                # The sentences were previously concatenated without separating spaces.
                raise ValueError(
                    "The total number of blocks cannot be less than zero. "
                    "Please increase gpu_memory_utilization "
                    "Or decrease max_num_batched_tokens(max model length) "
                )

            if self.ranks > 1:
                # Every rank must allocate the same number of blocks: take the minimum.
                num_blocks_local = paddle.full(shape=[1], fill_value=num_blocks_local, dtype="int32")
                dist.all_reduce(num_blocks_local, op=dist.ReduceOp.MIN)
                num_blocks_local = num_blocks_local.item()

            if self.local_rank == 0:
                # 3. Send IPCSignal
                get_profile_block_num = np.zeros(shape=[1], dtype=np.int32)
                self.get_profile_block_num_signal = IPCSignal(
                    name="get_profile_block_num",
                    array=get_profile_block_num,
                    dtype=np.int32,
                    suffix=self.parallel_config.engine_pid,
                    create=False,
                )
                self.get_profile_block_num_signal.value[0] = num_blocks_local
        else:
            num_blocks_local = self.fd_config.parallel_config.total_block_num

        logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
        # wait engine launch cache_manager
        if self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
            launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
            self.launched_cache_manager_signal = IPCSignal(
                name="launched_cache_manager_signal",
                array=launched_cache_manager_signal_data,
                dtype=np.int32,
                suffix=self.parallel_config.engine_pid,
                create=False,
            )
            while np.any(self.launched_cache_manager_signal.value[0] <= 0):
                time.sleep(0.01)
        # 4. init kv_cache with accurate num_blocks
        self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)

    def graph_optimize_and_warm_up_model(self) -> None:
        """Delegate graph optimization and warm-up to the underlying worker."""
        self.worker.graph_optimize_and_warm_up_model()

    def init_device(self) -> None:
        """Initialize device and construct model runner."""
        self.worker.init_device()

    def load_model(self) -> None:
        """Load weights and create model."""
        self.worker.load_model()


def parse_args():
    """
    Parse the worker's command-line arguments.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser("FastDeploy LLM Inference")
    parser.add_argument(
        "-m",
        "--model_name_or_path",
        type=str,
        default="./output",
        help="model dir",
    )
    parser.add_argument("-mbs", "--max_num_seqs", type=int, default=34, help="max batch size")
    parser.add_argument("--total_block_num", type=int, default=2000)
    parser.add_argument("--block_size", type=int, default=64)
    parser.add_argument("--pod_ip", type=str, default="127.0.0.1")
    parser.add_argument("--engine_worker_queue_port", type=int, default=9923)
    parser.add_argument("--max_model_len", type=int, default=3072, help="max model len")
    parser.add_argument("--device_ids", type=str, default="0", help="cuda visible devices")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="input dtype")
    parser.add_argument("--enc_dec_block_num", type=int, default=1, help="encoder's decoder num")
    parser.add_argument(
        "--kv_cache_ratio",
        type=float,
        default=0.7,
        help="kv cache ratio for input",
    )
    parser.add_argument("--first_token_id", type=int, default=1, help="first token id")
    parser.add_argument(
        "--gpu_memory_utilization",
        type=float,
        default=0.9,
        help="gpu memory utilization",
    )
    parser.add_argument("--engine_pid", type=int, default=None, help="Process ID of engine")
    parser.add_argument("--do_profile", action="store_true", help="do profile or not")
    parser.add_argument("--pad_token_id", type=int, default=-1, help="pad token id")
    parser.add_argument("--eos_tokens_lens", type=int, default=2, help="eos token lens")
    parser.add_argument(
        "--enable_chunked_prefill",
        action="store_true",
        help="enable chunked prefill",
    )
    parser.add_argument(
        "--speculative_method",
        default=None,
        type=none_or_str,
        choices=[
            None,
            "ngram",
            "mtp",
        ],
    )
    parser.add_argument(
        "--speculative_max_draft_token_num",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--speculative_model_name_or_path",
        default="",
        type=str,
    )
    parser.add_argument(
        "--speculative_model_quantization",
        default="WINT8",
        type=str,
    )
    parser.add_argument(
        "--speculative_benchmark_mode",
        default="False",
        type=str,
    )
    parser.add_argument(
        "--max_num_batched_tokens",
        type=int,
        default=2048,
        help="max num batched tokens",
    )
    parser.add_argument(
        "--enable_prefix_caching",
        action="store_true",
        help="enable prefix cache",
    )
    parser.add_argument(
        "--enable_custom_all_reduce",
        action="store_true",
        help="enable custom all-reduce",
    )
    parser.add_argument("--splitwise_role", type=str, default="mixed", help="splitwise role")
    parser.add_argument(
        "--tensor_parallel_size",
        type=int,
        default=1,
        help="tensor parallel size",
    )
    parser.add_argument(
        "--expert_parallel_size",
        type=int,
        default=1,
        help="expert parallel size",
    )
    parser.add_argument(
        "--enable_expert_parallel",
        action="store_true",
        help="enable expert parallel",
    )
    parser.add_argument("--ori_vocab_size", type=int, default=None)
    parser.add_argument(
        "--quantization",
        type=str,
        default="None",
        help="Quantization name for the model, currentlly support "
        "'wint4', 'wint8',"
        "default is None. The priority of this configuration "
        "is lower than that of the config file. "
        "More complex quantization methods need to be configured via the config file.",
    )
    parser.add_argument(
        "--graph_optimization_config",
        type=json.loads,
        default=None,
        help="Configation of Graph optimization backend.",
    )
    parser.add_argument(
        "--guided_decoding_backend",
        type=str,
        default="off",
        help="guided decoding backend",
    )
    # NOTE(review): with `store_false` this attribute defaults to True and
    # becomes False when the flag is passed, which reads as the opposite of the
    # flag name -- confirm against how the engine launches the worker before
    # changing it.
    parser.add_argument(
        "--disable_any_whitespace",
        action="store_false",
        help="Disable any whitespace for guided decoding.",
    )
    parser.add_argument(
        "--dynamic_load_weight",
        action="store_true",
        help="Enable dynamic weight loading strategy",
    )
    parser.add_argument(
        "--load_strategy",
        type=str,
        choices=["ipc", "ipc_snapshot"],
        default="ipc_snapshot",
        help="Weight loading method when dynamic loading is enabled: "
        "'ipc': real-time IPC streaming with automatic resharding, "
        "'ipc_snapshot': load from disk snapshot of IPC weights.",
    )
    parser.add_argument("--enable_mm", action="store_true", help="Whether to enable vl model")
    parser.add_argument(
        "--enable_logprob",
        action="store_true",
        help="Enable output of token-level log probabilities.",
    )

    args = parser.parse_args()
    return args


def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
    """
    Build an FDConfig from parsed arguments.

    Args:
        args: Configuration object containing all parameters (either a
            RolloutModelConfig or an argparse.Namespace); it must support
            ``vars(args)`` and attribute access.
        ranks (int): World size. Accepted for interface compatibility; not
            read inside this function.
        local_rank (int): Index of this worker, used to derive the tensor- and
            expert-parallel ranks.

    Returns:
        FDConfig: Initialized FastDeploy configuration object.

    Raises:
        ValueError: If the model config has no ``num_hidden_layers``, or if the
            model's quantization config lacks a ``quantization`` key.
    """
    paddle.set_default_dtype(args.dtype)
    args_dict = vars(args)
    model_config = ModelConfig(args_dict)
    device_config = DeviceConfig(args_dict)
    decoding_config = DecodingConfig(args_dict)
    speculative_config = SpeculativeConfig(args_dict)
    parallel_config = ParallelConfig(args_dict)

    parallel_config.tensor_parallel_size = args.tensor_parallel_size
    parallel_config.tensor_parallel_rank = local_rank % args.tensor_parallel_size
    parallel_config.expert_parallel_size = args.expert_parallel_size

    # config for EP
    if args.expert_parallel_size > 1:
        # Integer division; avoids the float round-trip of int(a / b).
        expert_parallel_rank = local_rank // args.tensor_parallel_size
        if isinstance(model_config.moe_num_experts, list):
            num_experts = model_config.moe_num_experts[0]
        else:
            num_experts = model_config.moe_num_experts

        num_experts_per_rank = num_experts // args.expert_parallel_size
        num_experts_start_offset = expert_parallel_rank * num_experts_per_rank

        parallel_config.expert_parallel_rank = expert_parallel_rank
        parallel_config.num_experts_per_rank = num_experts_per_rank
        parallel_config.num_experts_start_offset = num_experts_start_offset

    load_config = LoadConfig(args_dict)

    graph_opt_config = GraphOptimizationConfig()
    if args.graph_optimization_config is not None:
        graph_opt_config = GraphOptimizationConfig(
            use_cudagraph=args.graph_optimization_config["use_cudagraph"],
            graph_opt_level=args.graph_optimization_config["graph_opt_level"],
            cudagraph_capture_sizes=args.graph_optimization_config["cudagraph_capture_sizes"],
            sot_warmup_sizes=args.graph_optimization_config["sot_warmup_sizes"],
        )

    # Note(tangbinhan): used for load_checkpoint
    model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
    model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size
    model_config.pretrained_config.is_mtp = False
    model_config.pretrained_config.head_dim = model_config.head_dim

    logger.info(f"parallel_config.use_ep {parallel_config.use_ep}")
    logger.info(f"parallel_config.tensor_parallel_size {parallel_config.tensor_parallel_size}")
    logger.info(f"parallel_config.tensor_parallel_rank {parallel_config.tensor_parallel_rank}")

    if getattr(model_config, "num_hidden_layers", None) is None:
        raise ValueError("num_hidden_layers is None")

    quantization_config = model_config.quantization_config
    # A model-provided quantization config without a KV-cache quant type marks
    # the model as already quantized.
    if (
        not model_config.is_quantized
        and quantization_config is not None
        and "kv_cache_quant_type" not in quantization_config
    ):
        model_config.is_quantized = True

    quant_config_name = None
    if quantization_config is not None and quantization_config.get("quantization", None) is None:
        raise ValueError("quantization_config should have a key named 'quantization' for specify quant config.")

    if quantization_config is not None:
        # The model's own config takes priority over the command-line flag.
        quant_config_name = quantization_config["quantization"]
    elif args.quantization != "None":
        quantization_config = {}
        quant_config_name = args.quantization
        quantization_config["quantization"] = quant_config_name
        # Special handling for Ernie models
        is_ernie = ErnieArchitectures.contains_ernie_arch(model_config.architectures)
        if quant_config_name == "wint4" and is_ernie:
            quantization_config["dense_quant_type"] = "wint8"
            quantization_config["moe_quant_type"] = "wint4"
            quantization_config["quantization"] = "mix_quant"
            quant_config_name = "mix_quant"
    else:
        quant_config_name = None

    if quant_config_name is None:
        quant_config = None
    else:
        quant_cls = get_quantization_config(quant_config_name)
        quant_config = quant_cls.from_config(quantization_config)

    # Log quantization info
    logger.info("===========quantization_config==============")
    if quant_config is not None:
        if model_config.is_quantized:
            logger.info("Model Status: Offline Quantized (pre-quantized weights loaded)")
        else:
            logger.info("Model Status: Original (will apply online quantization)")

        logger.info(f"{quantization_config}")
    else:
        logger.info("No quantization config found and use original weight and act dtype.")

    # Set VL tag
    model_config.enable_mm = args.enable_mm
    logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
    logger.info(f"- Load strategy: {load_config.load_strategy}")

    fd_config = FDConfig(
        model_config=model_config,
        parallel_config=parallel_config,
        speculative_config=speculative_config,
        device_config=device_config,
        load_config=load_config,
        decoding_config=decoding_config,
        quant_config=quant_config,
        graph_opt_config=graph_opt_config,
    )
    update_fd_config_for_mm(fd_config)

    return fd_config


def run_worker_proc() -> None:
    """
    Start the worker process.

    Parses arguments, initializes the distributed environment, builds the
    configuration and worker, prepares the device, model, KV cache and health
    signals, then enters the event loop (which does not return).
    """
    # Get args from Engine
    args = parse_args()

    ranks, local_rank = init_distributed_environment()

    # Get fd_config
    fd_config = initialize_fd_config(args, ranks, local_rank)

    # Create worker process
    worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank)

    # Initialize device and create model runner
    worker_proc.init_device()

    # Load model
    worker_proc.load_model()
    # Initialize KV Cache
    worker_proc.initialize_kv_cache()

    # Trigger CUDAGraph capture
    worker_proc.graph_optimize_and_warm_up_model()

    # Initialize health status
    worker_proc.init_health_status()

    # Start event loop
    if fd_config.parallel_config.use_ep:
        # TODO(wufeisheng): Delete this branch
        worker_proc.event_loop_ep()
    else:
        worker_proc.event_loop_normal()


if __name__ == "__main__":
    run_worker_proc()