mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-23 16:44:22 +08:00 
			
		
		
		
	 0fa28b1068
			
		
	
	0fa28b1068
	
	
	
		
			
			* [fix] fix ep group all-reduce * [fix] fix clear/update lock not working when workers > 1 * [chore] add preemption triggered info log * [fix] fix code style * fix model_weights_signal (#4092) * fix model_weights_signal --------- Co-authored-by: Yuanle Liu <yuanlehome@163.com>
		
			
				
	
	
		
			855 lines
		
	
	
		
			34 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			855 lines
		
	
	
		
			34 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| # Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License"
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| 
 | |
| import argparse
 | |
| import json
 | |
| import time
 | |
| from typing import Tuple
 | |
| 
 | |
| import numpy as np
 | |
| import paddle
 | |
| import paddle.distributed as dist
 | |
| from paddle.distributed import fleet
 | |
| 
 | |
| from fastdeploy import envs
 | |
| from fastdeploy.config import (
 | |
|     CacheConfig,
 | |
|     DecodingConfig,
 | |
|     DeviceConfig,
 | |
|     EarlyStopConfig,
 | |
|     ErnieArchitectures,
 | |
|     FDConfig,
 | |
|     GraphOptimizationConfig,
 | |
|     LoadConfig,
 | |
|     ModelConfig,
 | |
|     ParallelConfig,
 | |
|     PlasAttentionConfig,
 | |
|     SpeculativeConfig,
 | |
| )
 | |
| from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
 | |
| from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
 | |
| from fastdeploy.inter_communicator import ExistTaskStatus, IPCSignal, ModelWeightsStatus
 | |
| from fastdeploy.model_executor.layers.quantization import get_quantization_config
 | |
| from fastdeploy.platforms import current_platform
 | |
| from fastdeploy.utils import get_logger, parse_quantization
 | |
| from fastdeploy.worker.worker_base import WorkerBase
 | |
| 
 | |
| logger = get_logger("worker_process", "worker_process.log")
 | |
| 
 | |
| 
 | |
def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
    """
    Build the worker implementation that matches the current hardware platform.

    Args:
        fd_config (FDConfig): Configuration object handed to the worker constructor.
        local_rank (int): Forwarded to the worker constructor as ``local_rank``.
        rank (int): Forwarded to the worker constructor as ``rank``.

    Returns:
        WorkerBase: The platform-specific worker instance.

    Raises:
        NotImplementedError: If logprob is enabled on a non-CUDA platform, or if
            none of the known platforms matches the current one.
    """
    if fd_config.model_config.enable_logprob and not current_platform.is_cuda():
        raise NotImplementedError("Only CUDA platform supports logprob.")

    # Each backend module is imported inside its own branch, so only the module
    # of the selected platform is loaded. The check order is kept as in the original.
    if current_platform.is_dcu():
        from fastdeploy.worker.dcu_worker import DcuWorker

        return DcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    if current_platform.is_cuda():
        from fastdeploy.worker.gpu_worker import GpuWorker

        return GpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    if current_platform.is_xpu():
        from fastdeploy.worker.xpu_worker import XpuWorker

        return XpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    if current_platform.is_iluvatar():
        from fastdeploy.worker.iluvatar_worker import IluvatarWorker

        return IluvatarWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    if current_platform.is_gcu():
        from fastdeploy.worker.gcu_worker import GcuWorker

        return GcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
    if current_platform.is_maca():
        from fastdeploy.worker.metax_worker import MetaxWorker

        return MetaxWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)

    # Previously this fell through and returned None, which only surfaced later
    # as an AttributeError on the first use of the worker.
    raise NotImplementedError("No worker implementation is available for the current platform.")
 | |
| 
 | |
| 
 | |
def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
    """Initialize Paddle Fleet in collective mode and report this worker's rank.

    All ranks are placed in one model-parallel group (``mp_degree`` equals the
    world size; data, pipeline and sharding degrees are 1).

    Args:
        seed (int): Value used as ``tensor_init_seed`` in the tensor parallel configs.

    Returns:
        Tuple[int, int]: ``(world_size, worker_index)``.
    """
    world_size = dist.get_world_size()

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": world_size,
        "pp_degree": 1,
        "sharding_degree": 1,
    }
    # Seed used for tensor initialization under tensor parallelism.
    strategy.tensor_parallel_configs = {"tensor_init_seed": seed}

    fleet.init(is_collective=True, strategy=strategy)

    return world_size, fleet.worker_index()
 | |
| 
 | |
| 
 | |
def update_fd_config_for_mm(fd_config: FDConfig) -> None:
    """
    Fill in the multimodal-related fields of ``fd_config.model_config`` in place.

    Nothing is changed unless multimodal is enabled and the model architectures
    contain an Ernie architecture.

    Args:
        fd_config (FDConfig): Configuration object that is updated in place.
    """
    model_cfg = fd_config.model_config
    archs = model_cfg.architectures
    if not model_cfg.enable_mm or not ErnieArchitectures.contains_ernie_arch(archs):
        return

    parallel_cfg = fd_config.parallel_config

    tokenizer = Ernie4_5Tokenizer.from_pretrained(
        model_cfg.model,
        model_max_length=parallel_cfg.max_model_len,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.ignored_index = -100
    # Fall back to the unknown token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token

    # Mirror the tensor-parallel layout onto the model config.
    model_cfg.tensor_parallel_degree = parallel_cfg.tensor_parallel_size
    model_cfg.tensor_parallel_rank = parallel_cfg.tensor_parallel_rank

    # The vision tower uses the same dtype as the language model.
    model_cfg.vision_config.dtype = model_cfg.dtype

    # Special token ids are looked up directly; a missing token raises KeyError.
    model_cfg.im_patch_id = tokenizer.get_vocab()["<|IMAGE_PLACEHOLDER|>"]
    model_cfg.think_end_id = tokenizer.get_vocab()["</think>"]
    model_cfg.sequence_parallel = parallel_cfg.sequence_parallel
 | |
| 
 | |
| 
 | |
def update_think_end_id_for_ernie(fd_config: FDConfig) -> None:
    """
    Set ``model_config.think_end_id`` from the tokenizer vocabulary.

    Only applies on the CUDA platform with an Ernie architecture. The id of
    '</think>' is used when the token exists; otherwise the field is set to None.

    Args:
        fd_config (FDConfig): Configuration object that is updated in place.
    """
    model_cfg = fd_config.model_config
    has_ernie_arch = ErnieArchitectures.contains_ernie_arch(model_cfg.architectures)
    if not (current_platform.is_cuda() and has_ernie_arch):
        return

    tokenizer = Ernie4_5Tokenizer.from_pretrained(
        model_cfg.model,
        model_max_length=fd_config.parallel_config.max_model_len,
        padding_side="right",
        use_fast=False,
    )

    model_cfg.think_end_id = tokenizer.get_vocab().get("</think>")
    if model_cfg.think_end_id is None:
        logger.info("No </think> token found in vocabulary, the model can not do reasoning.")
    else:
        logger.info(f"Get think_end_id {fd_config.model_config.think_end_id} from vocab.")
 | |
| 
 | |
| 
 | |
class PaddleDisWorkerProc:
    """
    Paddle Distributed wrapper for fastdeploy.worker.Worker,
        for handling single-node multi-GPU tensor parallel.
    The wrapper internally executes an event loop that continuously executes requests
        in the task queue. Control flow is transmitted by IPC.
    """

    def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0) -> None:
        """
        Store the configuration and build the platform-specific worker.

        The task queue is not connected here; see `start_task_queue_service`.

        Args:
            fd_config (FDConfig): Arguments related to inference, containing
                attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
                num_attention_heads, and ffn_hidden_size.
            ranks (int): Number of workers; forwarded to the worker as ``rank``
                (presumably the distributed world size — confirm against the caller).
            local_rank (int): Index of this worker process.
        """
        self.ranks = ranks
        self.local_rank = local_rank
        self.fd_config = fd_config
        self.parallel_config = fd_config.parallel_config
        self.cache_config = fd_config.cache_config

        # TODO(gongshaotian): Use worker factory to get worker
        self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, rank=self.ranks)

        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8

    def _attach_ipc_signal(self, name: str, size: int) -> IPCSignal:
        """
        Open the int32 IPC signal ``name`` with ``size`` elements.

        The signal is opened with ``create=False`` and suffixed with the engine
        worker queue port, i.e. it is expected to have been created elsewhere
        (presumably by the engine process — confirm).
        """
        return IPCSignal(
            name=name,
            array=np.zeros(shape=[size], dtype=np.int32),
            dtype=np.int32,
            suffix=self.parallel_config.engine_worker_queue_port,
            create=False,
        )

    def init_health_status(self) -> None:
        """
        Initialize the health status of the worker.

        Attaches to the following IPC signals and publishes this worker's state:
            launched_expert_service_signal: only when data_parallel_size > 1 and
                FD_ENABLE_MULTI_API_SERVER is off; waited on until it becomes non-zero.
            worker_ready_signal: this worker's slot is set to 1.
            worker_healthy_live_signal: this worker's slot is set to the current time.
            model_weights_status:
            exist_task_signal:
            exist_swapped_task_signal:
            exist_prefill_task_signal:
        """
        self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
        node_slot = self.local_rank % self.max_chips_per_node

        if self.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER:
            # NOTE(review): the array has min(data_parallel_size, max_chips_per_node) entries but
            # is indexed by local_rank % max_chips_per_node — confirm this cannot go out of range.
            self.launched_expert_service_signal = self._attach_ipc_signal(
                "launched_expert_service_signal",
                min(self.parallel_config.data_parallel_size, self.max_chips_per_node),
            )
            # Poll until the slot becomes non-zero. Sleep between polls instead of
            # spinning so the wait does not occupy a full CPU core.
            while self.launched_expert_service_signal.value[node_slot] == 0:
                time.sleep(0.001)

        # init worker_ready_signal
        array_size = min(
            self.max_chips_per_node,
            self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size,
        )
        self.worker_ready_signal = self._attach_ipc_signal("worker_ready_signal", array_size)
        self.worker_ready_signal.value[node_slot] = 1

        # init worker_healthy_live_signal
        self.worker_healthy_live_signal = self._attach_ipc_signal(
            "worker_healthy_live_signal",
            min(array_size, self.parallel_config.tensor_parallel_size),
        )
        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
        self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())

        # init model_weights_status
        self.model_weights_status = self._attach_ipc_signal("model_weights_status", 1)

        # init exist_task_signal
        self.exist_task_signal = self._attach_ipc_signal("exist_task_signal", 1)

        # init exist_swapped_task_signal
        self.exist_swapped_task_signal = self._attach_ipc_signal("exist_swapped_task_signal", 1)

        # init exist_prefill_task_signal
        self.exist_prefill_task_signal = self._attach_ipc_signal("exist_prefill_task_signal", 1)

    def _broadcast_model_weights_signal(self, src: int, group) -> int:
        """
        Broadcast the local model-weights signal from rank ``src`` within ``group``.

        Args:
            src (int): Source rank of the broadcast.
            group: Communication group passed to ``paddle.distributed.broadcast``.

        Returns:
            int: The signal value held after the broadcast.
        """
        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
        return model_weights_signal_tensor.item()

    def event_loop_normal(self) -> None:
        """Main event loop for Paddle Distributed Workers.

        Runs forever. Each iteration synchronizes the model-weights signal,
        refreshes the liveness timestamp, picks up new tasks from the task queue
        and executes one model step.

        TODO(gongshaotian): support remote calling of functions that control worker.
        """
        # Currently, only support single node
        # NOTE(review): 8 chips per node is hard-coded here while the rest of the class uses
        # self.max_chips_per_node (16 on Iluvatar) — confirm whether this is intended.
        self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
        num_running_requests = 0
        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
        self.model_weights_signal = np.zeros([1], dtype=np.int32)
        while True:
            # The first rank of each tensor-parallel group reads the shared weights status.
            if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                    self.model_weights_signal[0] = int(self.model_weights_status.value[0])
                if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
                    self.model_weights_signal[0] = self._broadcast_model_weights_signal(
                        src=0, group=self.parallel_config.ep_group
                    )
            # Then the value is propagated to the other ranks of the tensor-parallel group.
            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
                    src=0, group=self.parallel_config.tp_group
                )

            self.insert_step = False
            req_dicts = None
            # Heartbeat: publish the current timestamp in this worker's slot.
            self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())

            # The first worker detects whether there are tasks in the task queue
            if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                if self.task_queue.num_tasks() > 0:
                    # VL only support 1 batch to prefill
                    if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
                        self.fd_config.model_config.enable_mm and self.worker.exist_prefill()
                    ):
                        if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
                            self.task_queue.read_finish_flag.set(1)
                        else:
                            self.exist_task_signal.value[0] = ExistTaskStatus.EXIST

            if self.parallel_config.tensor_parallel_size > 1:
                # Synchronize the signal for other workers
                paddle.distributed.barrier(self.parallel_config.tp_group)

            if self.fd_config.load_config.dynamic_load_weight:
                if self.parallel_config.enable_expert_parallel:
                    paddle.distributed.barrier(self.parallel_config.ep_group)
                else:
                    paddle.distributed.barrier(self.parallel_config.tp_group)
                if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
                    logger.info(
                        f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
                    )
                    from fastdeploy.rl.dynamic_weight_manager import (
                        DynamicWeightManager,
                    )

                    # Write the agreed signal back to the shared status before handing it
                    # to the weight manager, then reset the local copy.
                    self.model_weights_status.value[0] = self.model_weights_signal[0]
                    DynamicWeightManager.check_model_weights_status(
                        self.model_weights_status,
                        self.worker.model_runner,
                        self.parallel_config.engine_worker_queue_port,
                    )
                    self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
                    logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")

            if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                logger.info(f"Rank: {self.local_rank} Detected new requests.")
                self.insert_step = True

                tasks, read_finish = self.task_queue.get_tasks()
                if read_finish:
                    # Ensure that every worker get the task
                    self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
                    self.task_queue.read_finish_flag.set(0)

                req_dicts = []
                for req_dict, bsz in tasks:
                    # The batch size of the last task entry wins.
                    num_running_requests = int(bsz)
                    req_dicts.extend(req_dict)

                req_ids = [req.request_id for req in req_dicts]
                logger.info(
                    f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, "
                    f"num_insert_requests: {len(req_dicts)}, req_ids: {req_ids}"
                )

                # Process prefill inputs
                self.worker.preprocess_new_task(req_dicts, num_running_requests)

            # Without expert parallel, skip the model step while the runner reports
            # that it has nothing to run, and back off briefly.
            if (not self.parallel_config.use_ep) and (not self.worker.model_runner.not_need_stop()):
                if self.ranks > 1:
                    paddle.distributed.barrier(self.parallel_config.tp_group)

                time.sleep(0.001)
                continue

            # Execute model to generate token. The generated token will be written to the buffer.
            # These generated tokens can be obtained through get_output op.
            self.worker.execute_model(req_dicts, num_running_requests)
            self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()

    def initialize_kv_cache(self) -> None:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        When ``do_profile`` is off, ``parallel_config.total_block_num`` is used instead.

        Raises:
            ValueError: If profiling yields no usable block.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        if self.fd_config.parallel_config.do_profile:
            # 1. Get available memory(bytes)
            available_kv_cache_memory = self.worker.determine_available_memory()
            logger.info(f"------- available_kv_cache_memory:{available_kv_cache_memory / 1024**3} GB --------")

            # 2. Calculate the appropriate number of blocks
            model_block_memory_used = self.worker.cal_theortical_kvcache()
            num_blocks_local = int(available_kv_cache_memory // model_block_memory_used)
            # NOTE(liuzichang): Too many block will lead to illegal memory access
            # We will develop dynamic limits in future.
            if num_blocks_local > 40000:
                logger.info(f"------- Reset num_blocks_local {num_blocks_local} to 40000")
                num_blocks_local = 40000
            logger.info(f"------- model_block_memory_used:{model_block_memory_used} --------")
            logger.info(f"------- num_blocks_local:{num_blocks_local} --------")

            if num_blocks_local <= 0:
                raise ValueError(
                    "The total number of blocks must be greater than zero. "
                    "Please increase gpu_memory_utilization "
                    "or decrease max_num_batched_tokens (max model length)."
                )

            if self.ranks > 1:
                # Every rank adopts the smallest block count found across ranks.
                num_blocks_tensor = paddle.full(shape=[1], fill_value=num_blocks_local, dtype="int32")
                dist.all_reduce(num_blocks_tensor, op=dist.ReduceOp.MIN)
                num_blocks_local = num_blocks_tensor.item()

            if self.local_rank % self.max_chips_per_node == 0:
                # 3. Send IPCSignal
                self.get_profile_block_num_signal = self._attach_ipc_signal("get_profile_block_num", 1)
                self.get_profile_block_num_signal.value[0] = num_blocks_local
        else:
            num_blocks_local = self.fd_config.parallel_config.total_block_num

        logger.info(f"------- num_blocks_global: {num_blocks_local} --------")

        # 4. init kv_cache with accurate num_blocks
        self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)

    def graph_optimize_and_warm_up_model(self) -> None:
        """Delegate graph optimization and warm-up to the underlying worker."""
        self.worker.graph_optimize_and_warm_up_model()

    def init_device(self) -> None:
        """Initialize device and Construct model runner"""
        self.worker.init_device()

    def start_task_queue_service(self) -> None:
        """Connect to the engine worker queue as a client and store it in ``self.task_queue``."""
        # Initialize task queue
        task_address = (
            self.parallel_config.pod_ip,
            self.parallel_config.engine_worker_queue_port,
        )
        logger.info(f"connect task queue address {task_address}")
        self.task_queue = TaskQueue(
            address=task_address,
            is_server=False,
            num_client=self.parallel_config.tensor_parallel_size,
            client_id=self.parallel_config.tensor_parallel_rank,
            local_data_parallel_id=self.parallel_config.data_parallel_rank,
        )

    def load_model(self) -> None:
        """Load weights and create model, then flag completion via ``loaded_model_signal``."""

        self.worker.load_model()
        self.loaded_model_signal = self._attach_ipc_signal("loaded_model_signal", 1)
        # Wait for all ranks to finish loading before the signal is raised.
        if self.ranks > 1:
            paddle.distributed.barrier()
        self.loaded_model_signal.value[0] = 1
 | |
| 
 | |
| 
 | |
def parse_args():
    """
    Parse args from command line

    Returns:
        argparse.Namespace: The parsed worker options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser("FastDeploy LLM Inference")

    # Model, batching and cache sizing.
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="./output",
        help="model dir",
    )
    parser.add_argument("-mbs", "--max_num_seqs", type=int, default=34, help="max batch size")
    parser.add_argument("--total_block_num", type=int, default=2000)
    parser.add_argument("--block_size", type=int, default=64)
    parser.add_argument("--pod_ip", type=str, default="127.0.0.1")
    parser.add_argument("--engine_worker_queue_port", type=str, default="9923")
    parser.add_argument("--max_model_len", type=int, default=3072, help="max model len")
    parser.add_argument("--device_ids", type=str, default="0", help="cuda visible devices")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="input dtype")
    parser.add_argument("--enc_dec_block_num", type=int, default=1, help="encoder's decoder num")
    parser.add_argument(
        "--kv_cache_ratio",
        type=float,
        default=0.7,
        help="kv cache ratio for input",
    )
    parser.add_argument("--first_token_id", type=int, default=1, help="first token id")
    parser.add_argument(
        "--gpu_memory_utilization",
        type=float,
        default=0.9,
        help="gpu memory utilization",
    )
    parser.add_argument("--engine_pid", type=int, default=None, help="Process ID of engine")
    parser.add_argument("--do_profile", action="store_true", help="do profile or not")
    parser.add_argument("--pad_token_id", type=int, default=-1, help="pad token id")
    parser.add_argument("--eos_tokens_lens", type=int, default=2, help="eos token lens")
    parser.add_argument(
        "--enable_chunked_prefill",
        action="store_true",
        help="enable chunked prefill",
    )
    # JSON-valued options are decoded by argparse through json.loads.
    parser.add_argument(
        "--speculative_config",
        type=json.loads,
        default=None,
        help="Configuration of SpeculativeConfig.",
    )
    parser.add_argument(
        "--max_num_batched_tokens",
        type=int,
        default=2048,
        help="max num batched tokens",
    )

    parser.add_argument(
        "--enable_prefix_caching",
        action="store_true",
        help="enable prefix cache",
    )
    parser.add_argument(
        "--disable_custom_all_reduce",
        action="store_true",
        help="disable custom all-reduce",
    )

    # Parallelism layout.
    parser.add_argument("--splitwise_role", type=str, default="mixed", help="splitwise role")
    parser.add_argument(
        "--tensor_parallel_size",
        type=int,
        default=1,
        help="tensor parallel size",
    )
    parser.add_argument(
        "--expert_parallel_size",
        type=int,
        default=1,
        help="expert parallel size",
    )
    parser.add_argument(
        "--data_parallel_size",
        type=int,
        default=1,
        help="data parallel size",
    )
    parser.add_argument(
        "--enable_expert_parallel",
        action="store_true",
        help="enable expert parallel",
    )
    parser.add_argument("--ori_vocab_size", type=int, default=None)

    parser.add_argument(
        "--quantization",
        type=json.loads,
        default=None,
        help="Quantization name for the model, currently support "
        "'wint4', 'wint8',"
        "default is None. The priority of this configuration "
        "is lower than that of the config file. "
        "More complex quantization methods need to be configured via the config file.",
    )
    parser.add_argument(
        "--graph_optimization_config",
        type=json.loads,
        default=None,
        help="Configuration of Graph optimization backend.",
    )
    parser.add_argument(
        "--plas_attention_config",
        type=json.loads,
        default=None,
        help="Configuration of plas attention.",
    )
    parser.add_argument(
        "--guided_decoding_backend",
        type=str,
        default="off",
        help="guided decoding backend",
    )
    # NOTE(review): with action="store_false" this attribute defaults to True and becomes
    # False when the flag is passed, which looks inverted for a "disable" flag — confirm
    # against the process that builds this command line before changing it.
    parser.add_argument(
        "--disable_any_whitespace",
        action="store_false",
        help="Disable any whitespace for guided decoding.",
    )

    # Weight loading.
    parser.add_argument(
        "--dynamic_load_weight",
        action="store_true",
        help="Enable dynamic weight loading strategy",
    )
    parser.add_argument(
        "--load_strategy",
        type=str,
        choices=["ipc", "ipc_snapshot"],
        default="ipc_snapshot",
        help="Weight loading method when dynamic loading is enabled: "
        "'ipc': real-time IPC streaming with automatic resharding, "
        "'ipc_snapshot': load from disk snapshot of IPC weights.",
    )
    parser.add_argument(
        "--enable_logprob",
        action="store_true",
        help="Enable output of token-level log probabilities.",
    )
    parser.add_argument(
        "--early_stop_config",
        type=json.loads,
        default=None,
        help="Configuration of early stop.",
    )

    parser.add_argument(
        "--load_choices",
        type=str,
        default="default",
        help="The format of the model weights to load. default/new_loader.",
    )

    parser.add_argument(
        "--ips",
        type=str,
        default=None,
        help="The ips of multinode deployment.",
    )

    parser.add_argument(
        "--lm_head_fp32",
        action="store_true",
        help="Flag to specify dtype of lm_head as FP32",
    )

    return parser.parse_args()
 | |
| 
 | |
| 
 | |
def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
    """Build the FDConfig for one worker process from parsed arguments.

    Args:
        args: Configuration object (either RolloutModelConfig or
            argparse.Namespace). It is read both through ``vars(args)`` and
            through direct attribute access. If ``args.quantization`` is a
            string it is replaced in place by the result of
            ``parse_quantization``.
        ranks: Accepted for call-site compatibility; not read by this function.
        local_rank: Rank of this worker. Used to derive the tensor-parallel
            rank (``local_rank % tensor_parallel_size``), the data-parallel
            rank (``local_rank // tensor_parallel_size``) and, when expert
            parallelism is enabled, the expert-parallel rank.

    Returns:
        FDConfig: Initialized FastDeploy configuration object.

    Raises:
        ValueError: If ``model_config.num_hidden_layers`` is missing/None, or
            if the model's ``quantization_config`` has no ``"quantization"`` key.

    Side effects:
        Sets the paddle default dtype to ``args.dtype`` and may set
        ``envs.ENABLE_V1_KVCACHE_SCHEDULER`` to 0.
    """
    # RL rollout: quantization may arrive as a string and must be parsed first.
    if args.quantization is not None and isinstance(args.quantization, str):
        args.quantization = parse_quantization(args.quantization)
    paddle.set_default_dtype(args.dtype)
    model_config = ModelConfig(vars(args))
    device_config = DeviceConfig(vars(args))
    decoding_config = DecodingConfig(vars(args))
    speculative_config = SpeculativeConfig(args.speculative_config)
    parallel_config = ParallelConfig(vars(args))
    cache_config = CacheConfig(vars(args))
    parallel_config.tensor_parallel_rank = local_rank % parallel_config.tensor_parallel_size
    parallel_config.data_parallel_rank = local_rank // parallel_config.tensor_parallel_size

    # config for EP
    if parallel_config.expert_parallel_size > 1:
        expert_parallel_rank = int(local_rank % parallel_config.expert_parallel_size)
        # moe_num_experts may be a list; only its first entry is used here.
        if isinstance(model_config.moe_num_experts, list):
            num_experts = model_config.moe_num_experts[0]
        else:
            num_experts = model_config.moe_num_experts

        num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
        num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
        max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
        parallel_config.local_data_parallel_id = parallel_config.data_parallel_rank % (
            max_chips_per_node // parallel_config.tensor_parallel_size
        )

        parallel_config.expert_parallel_rank = expert_parallel_rank
        parallel_config.num_experts_per_rank = num_experts_per_rank
        parallel_config.num_experts_start_offset = num_experts_start_offset

    # Select this worker's queue port from the per-data-parallel-id collection.
    # NOTE(review): local_data_parallel_id is only assigned above when
    # expert_parallel_size > 1; otherwise the value already present on
    # parallel_config is used (presumably a ParallelConfig default — confirm).
    parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[
        parallel_config.local_data_parallel_id
    ]
    parallel_config.set_tp_group()

    load_config = LoadConfig(vars(args))

    graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config)

    plas_attention_config = PlasAttentionConfig(args.plas_attention_config)

    early_stop_config = EarlyStopConfig(args.early_stop_config)

    # Note(tangbinhan): used for load_checkpoint
    model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
    model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size
    model_config.pretrained_config.is_mtp = False
    model_config.pretrained_config.head_dim = model_config.head_dim

    logger.info(f"parallel_config.use_ep {parallel_config.use_ep}")
    logger.info(f"parallel_config.tensor_parallel_size {parallel_config.tensor_parallel_size}")
    logger.info(f"parallel_config.tensor_parallel_rank {parallel_config.tensor_parallel_rank}")
    logger.info(f"parallel_config.engine_worker_queue_port {parallel_config.engine_worker_queue_port}")

    if getattr(model_config, "num_hidden_layers", None) is None:
        raise ValueError("num_hidden_layers is None")

    # Decide whether the checkpoint is already quantized, based on the
    # quantization_config carried by the model config.
    quantization_config = model_config.quantization_config
    if not model_config.is_quantized:
        if quantization_config is not None:
            if "is_quantized" in quantization_config:
                model_config.is_quantized = quantization_config["is_quantized"]
            elif "kv_cache_quant_type" not in quantization_config:
                model_config.is_quantized = True

    quant_config_name = None
    if quantization_config is not None and quantization_config.get("quantization", None) is None:
        raise ValueError("quantization_config should have a key named 'quantization' for specify quant config.")

    if quantization_config is not None:
        # The model's own quantization_config takes precedence over args.
        quant_config_name = quantization_config["quantization"]
    elif args.quantization is not None:
        quantization_config = {}
        try:
            quantization_config.update(args.quantization)
            quant_config_name = quantization_config["quantization"]
        except Exception:
            # Best-effort fallback for values that cannot be merged with
            # dict.update(). Catch Exception rather than using a bare except
            # so KeyboardInterrupt/SystemExit are never swallowed here.
            quant_config_name = args.quantization["quantization"]
            quantization_config["quantization"] = quant_config_name
        # Only v1 loader sets is_checkpoint_bf16=True during dynamic quantization.
        if load_config.load_choices == "default_v1" and not load_config.dynamic_load_weight:
            quantization_config["is_checkpoint_bf16"] = True
        # Special handling for Ernie models: "wint4" is rewritten to a mixed
        # scheme (wint8 dense layers, wint4 MoE layers).
        is_ernie = ErnieArchitectures.contains_ernie_arch(model_config.architectures)
        if quant_config_name == "wint4" and is_ernie:
            quantization_config["dense_quant_type"] = "wint8"
            quantization_config["moe_quant_type"] = "wint4"
            quantization_config["quantization"] = "mix_quant"
            quant_config_name = "mix_quant"

    if quant_config_name is None:
        quant_config = None
    else:
        quant_cls = get_quantization_config(quant_config_name)
        quant_config = quant_cls.from_config(quantization_config)

    # Log quantization info
    logger.info("===========quantization_config==============")
    if quant_config is not None:
        if model_config.is_quantized:
            logger.info("Model Status: Offline Quantized (pre-quantized weights loaded)")
        else:
            logger.info("Model Status: Original (will apply online quantization)")

        logger.info(f"{quantization_config}")
    else:
        logger.info("No quantization config found and use original weight and act dtype.")

    logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
    logger.info(f"- Load strategy: {load_config.load_strategy}")

    # Each condition below independently turns the V1 KV-cache scheduler off.
    if (
        args.speculative_config is not None
        and ("method" in args.speculative_config)
        and (args.speculative_config["method"] is not None)
    ):
        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.")
        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
    if args.splitwise_role != "mixed":
        logger.info(f"Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported {args.splitwise_role} now.")
        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
    if not current_platform.is_cuda():
        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
    if parallel_config.guided_decoding_backend != "off":
        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported guided_decoding.")
        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0

    fd_config = FDConfig(
        model_config=model_config,
        parallel_config=parallel_config,
        speculative_config=speculative_config,
        device_config=device_config,
        load_config=load_config,
        decoding_config=decoding_config,
        quant_config=quant_config,
        graph_opt_config=graph_opt_config,
        early_stop_config=early_stop_config,
        cache_config=cache_config,
        engine_worker_queue_port=args.engine_worker_queue_port,
        ips=args.ips,
        plas_attention_config=plas_attention_config,
    )
    update_fd_config_for_mm(fd_config)
    update_think_end_id_for_ernie(fd_config)

    return fd_config
 | |
| 
 | |
| 
 | |
def run_worker_proc() -> None:
    """
    Start a worker process: parse arguments, build the config, bring the
    worker up step by step, then enter its event loop.
    """
    # Get args from Engine
    args = parse_args()

    ranks, local_rank = init_distributed_environment()

    # Build the FastDeploy config for this rank
    fd_config = initialize_fd_config(args, ranks, local_rank)

    # Pick the worker process class; the Iluvatar variant is imported lazily
    # so that it is only loaded on that platform.
    if current_platform.is_iluvatar():
        from fastdeploy.worker.iluvatar_worker import (
            IluvatarPaddleDisWorkerProc as worker_proc_cls,
        )
    else:
        worker_proc_cls = PaddleDisWorkerProc
    worker_proc = worker_proc_cls(fd_config, ranks, local_rank)

    # Bring-up sequence; the order of these calls is preserved as-is.
    worker_proc.init_device()  # initialize device and create model runner
    worker_proc.load_model()  # load model
    worker_proc.initialize_kv_cache()  # initialize KV cache
    worker_proc.worker.graph_optimize_and_warm_up_model()  # trigger CUDAGraph capture
    worker_proc.init_health_status()  # initialize health status
    worker_proc.start_task_queue_service()

    # Start event loop
    worker_proc.event_loop_normal()
 | |
| 
 | |
| 
 | |
# Script entry point: launch the worker only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_worker_proc()
 |