""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import argparse import time from collections import defaultdict from concurrent.futures import ThreadPoolExecutor import numpy as np import paddle import paddle.distributed as dist import paddle.distributed.fleet as fleet from fastdeploy.engine.config import ModelConfig from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal from fastdeploy.utils import get_logger, none_or_str from fastdeploy.worker.worker_process import initialize_fd_config, parse_args logger = get_logger("worker", "worker.log") class PrefillTracker: """ Record the prefill time of the request """ def __init__( self, engine_pid: int, ) -> None: """ Initialize the PrefillTracker. """ super().__init__() self.start_times = defaultdict(float) prefill_time_data = np.zeros([100], dtype=np.float32) self.prefill_time_signal = IPCSignal(name="prefill_time_signal", array=prefill_time_data, dtype=np.float32, suffix=engine_pid, create=False) self.current_index = 0 self.executor = ThreadPoolExecutor(max_workers=1) def start_prefill(self, task_idx: int): """ Record the start time of the prefill process for a given task index. Args: task_idx (int): The index of the task being prefetched. """ self.start_times[task_idx] = time.time() def end_prefill(self, task_idx: int): """ Record the end time of the prefill process for a given task index and asynchronously submit the duration for metric recording. Args: task_idx (int): The index of the task being prefetched. """ if task_idx in self.start_times: duration = time.time() - self.start_times[task_idx] # Submit metric recording to the executor for asynchronous execution self.executor.submit(self._record_metrics, duration) del self.start_times[task_idx] def _record_metrics(self, duration: float): """ Internal method to record the prefill duration into the signal buffer. Logs the duration and updates a circular buffer of timing metrics. Args: duration (float): Time taken for the prefill process in seconds. """ self.prefill_time_signal.value[self.current_index] = duration self.current_index = (self.current_index + 1) % len( self.prefill_time_signal.value) def __del__(self): """Clean up resources""" if hasattr(self, 'executor'): self.executor.shutdown(wait=False) class Worker: """ Engine -> (WIP)Executor -> Worker -> ModelRunner -> Model Worker interface that allows inference framwork to cleanly separate implementations for different harware. """ def __init__( self, args, ) -> None: """ Initialize the Worker. """ super().__init__() self.args = args self.MAX_INFER_SEED = 9223372036854775806 paddle.set_default_dtype(args.dtype) self.device_ids = self.args.device_ids.split(",") self.model_cfg = ModelConfig(args.model_name_or_path) from fastdeploy.worker.vl_gpu_model_runner import GPUVLModelRunner self.init_dist_env() self.format_print_configuration() self.helper_tensors = {} local_rank = self.rank % self.args.tensor_parallel_size self.local_data_parallel_id = self.rank // self.args.tensor_parallel_size self.infer_engine = GPUVLModelRunner(config=self.model_cfg, args=self.args, nranks=self.nranks, rank=self.rank) self.prefill_tracker = PrefillTracker(args.engine_pid) address = (self.args.pod_ip, self.args.engine_worker_queue_port) self.engine_worker_queue = EngineWorkerQueue( address=address, is_server=False, num_client=self.nranks, client_id=local_rank, local_data_parallel_id=self.local_data_parallel_id) self.init_health() def init_dist_env(self, seed=20): """ init distributed env """ self.nranks = dist.get_world_size() strategy = fleet.DistributedStrategy() strategy.hybrid_configs = { "dp_degree": 1, "mp_degree": self.nranks, "pp_degree": 1, "sharding_degree": 1, } # Set control in tensor parallel strategy.tensor_parallel_configs = {"tensor_init_seed": seed} fleet.init(is_collective=True, strategy=strategy) self.rank = fleet.worker_index() def init_health(self): """ init health signals """ # To perceive whether each worker process is ready worker_ready_signal_data = np.zeros(shape=[self.nranks], dtype=np.int32) self.worker_ready_signal = IPCSignal(name="worker_ready_signal", array=worker_ready_signal_data, dtype=np.int32, suffix=self.args.engine_pid, create=False) self.worker_ready_signal.value[self.rank] = 1 # To monitor the liveness of worker processes and record each step's timestamp worker_healthy_live_recorded_time_array = np.zeros(shape=[self.nranks], dtype=np.int32) self.worker_healthy_live_signal = IPCSignal( name="worker_healthy_live_signal", array=worker_healthy_live_recorded_time_array, dtype=np.int32, suffix=self.args.engine_pid, create=False) self.worker_healthy_live_signal.value[self.rank] = int(time.time()) # To perceive whether there is a new task to be processed exist_task_signal_data = np.zeros([1], dtype=np.int32) self.exist_task_signal = IPCSignal(name="exist_task_signal", array=exist_task_signal_data, dtype=np.int32, suffix=self.args.engine_pid, create=False) # To detect whether there are swapped tasks in the worker exist_swapped_task_signal_data = np.zeros([1], dtype=np.int32) self.exist_swapped_task_signal = IPCSignal( name="exist_swapped_task_signal", array=exist_swapped_task_signal_data, dtype=np.int32, suffix=self.args.engine_pid, create=False) model_weights_status = np.zeros([1], dtype=np.int32) self.model_weights_status_signal = IPCSignal( name="model_weights_status", array=model_weights_status, dtype=np.int32, suffix=self.args.engine_pid, create=False) def format_print_configuration(self): """ print model config """ logger.info("=============== Model Information ==============") for k, v in self.model_cfg.__dict__.items(): logger.info("{:<20}:{:<6}{}".format(k, "", v)) logger.info("=============== Service Configuration ===============") for k, v in vars(self.args).items(): logger.info("{:<20}:{:<6}{}".format(k, "", v)) logger.info("=====================================================\n") def step_cuda(self): """ step cuda """ from fastdeploy.model_executor.ops.gpu import (step_reschedule, step_system_cache) if self.args.enable_prefix_caching: step_system_cache( self.infer_engine.share_inputs["stop_flags"], self.infer_engine.share_inputs["seq_lens_this_time"], self.infer_engine.share_inputs["step_seq_lens_encoder"], self.infer_engine.share_inputs["step_seq_lens_decoder"], self.infer_engine.share_inputs["seq_lens_encoder"], self.infer_engine.share_inputs["seq_lens_decoder"], self.infer_engine.share_inputs["block_tables"], self.infer_engine.share_inputs["encoder_block_lens"], self.infer_engine.share_inputs["is_block_step"], self.infer_engine.share_inputs["step_block_list"], self.infer_engine.share_inputs["step_lens"], self.infer_engine.share_inputs["recover_block_list"], self.infer_engine.share_inputs["recover_lens"], self.infer_engine.share_inputs["need_block_list"], self.infer_engine.share_inputs["need_block_len"], self.infer_engine.share_inputs["used_list_len"], self.infer_engine.share_inputs["free_list"], self.infer_engine.share_inputs["free_list_len"], self.infer_engine.share_inputs["input_ids"], self.infer_engine.share_inputs["pre_ids"], self.infer_engine.share_inputs["step_idx"], self.infer_engine.share_inputs["next_tokens"], self.infer_engine.share_inputs["first_token_ids"], self.args.block_size, self.args.enc_dec_block_num) else: step_reschedule( self.infer_engine.share_inputs["stop_flags"], self.infer_engine.share_inputs["seq_lens_this_time"], self.infer_engine.share_inputs["step_seq_lens_encoder"], self.infer_engine.share_inputs["seq_lens_encoder"], self.infer_engine.share_inputs["seq_lens_decoder"], self.infer_engine.share_inputs["block_tables"], self.infer_engine.share_inputs["encoder_block_lens"], self.infer_engine.share_inputs["is_block_step"], self.infer_engine.share_inputs["step_block_list"], self.infer_engine.share_inputs["step_lens"], self.infer_engine.share_inputs["recover_block_list"], self.infer_engine.share_inputs["recover_lens"], self.infer_engine.share_inputs["need_block_list"], self.infer_engine.share_inputs["need_block_len"], self.infer_engine.share_inputs["used_list_len"], self.infer_engine.share_inputs["free_list"], self.infer_engine.share_inputs["free_list_len"], self.infer_engine.share_inputs["input_ids"], self.infer_engine.share_inputs["pre_ids"], self.infer_engine.share_inputs["step_idx"], self.infer_engine.share_inputs["next_tokens"], self.infer_engine.share_inputs["first_token_ids"], self.args.block_size, self.args.enc_dec_block_num, ) def check_model_weights_status(self): """ check model weights status """ is_stop = 0 while self.model_weights_status_signal.value[0] != 0: if self.model_weights_status_signal.value[0] == 1: logger.info( f"infer engine stopped! start to load new checkpoint... {self.rank}" ) self.infer_engine.update_parameters(self.args.engine_pid) elif self.model_weights_status_signal.value[0] == -1: logger.info( f"infer engine stopped! start to clear checkpoint... {self.rank}" ) self.infer_engine.clear_parameters(self.args.engine_pid) while True: if self.model_weights_status_signal.value[0] == 0: logger.info(f"finished loading new checkpoint {self.rank}") break elif is_stop == 1 or (self.model_weights_status_signal.value[0] == -2 and is_stop == 0): if is_stop == 0: logger.info( f"finished clearing checkpoint {self.rank}") is_stop = 1 time.sleep(0.001) break else: time.sleep(0.001) def run(self): """ run function, continuously get tasks and do inference. """ infer_seed_increment = paddle.full(shape=[self.args.max_num_seqs, 1], fill_value=4, dtype="int64") self.nnode = int((self.nranks + 7) // 8) mp_num_per_node = self.nranks // self.nnode while True: if self.rank == 0: if self.model_weights_status_signal.value[0] != 0: self.exist_task_signal.value[0] = 2 else: self.exist_task_signal.value[0] = 0 if self.nranks > 1: paddle.distributed.barrier() if self.exist_task_signal.value[0] == 2: self.check_model_weights_status() self.insert_step = False self.worker_healthy_live_signal.value[self.rank] = int(time.time()) if self.rank % mp_num_per_node == 0: if self.engine_worker_queue.num_tasks( ) > 0 and self.infer_engine.prefill_finished(): if self.nnode > 1: self.engine_worker_queue.read_finish_flag.set(1) else: self.exist_task_signal.value[0] = 1 if self.nranks > 1: paddle.distributed.barrier() if self.exist_task_signal.value[ 0] == 1 or self.engine_worker_queue.read_finish_flag.get( ) == 1: logger.info(f"Rank: {self.rank} Detected new requests.") self.insert_step = True tasks, read_finish = self.engine_worker_queue.get_tasks() if read_finish: self.exist_task_signal.value[0] = 0 self.engine_worker_queue.read_finish_flag.set(0) req_dicts = [] for req_dict, bsz in tasks: num_running_requests = int(bsz) req_dicts.extend(req_dict) req_ids = [req.request_id for req in req_dicts] logger.info(f"Rank: {self.rank}, num_running_requests: {num_running_requests}, " \ f"num_insert_requests: {len(req_dicts)}. {req_ids}") self.infer_engine.dy_input_preprocess(req_dicts) for req_dict in req_dicts: if self.infer_engine.share_inputs["seq_lens_this_time"][ req_dict.idx] > 1: self.prefill_tracker.start_prefill(req_dict.idx) self.infer_engine.share_inputs["not_need_stop"][0] = True if not self.infer_engine.share_inputs["not_need_stop"]: time.sleep(0.001) continue self.infer_engine.generate() self.infer_engine.share_inputs["infer_seed"].add_( infer_seed_increment) self.infer_engine.share_inputs[ "infer_seed"][:] %= self.MAX_INFER_SEED for req_dict in req_dicts: if (self.infer_engine.share_inputs["seq_lens_this_time"][ req_dict.idx] == 1 and req_dict.idx in self.prefill_tracker.start_times): self.prefill_tracker.end_prefill(req_dict.idx) self.infer_engine.update_chunked_prefill(req_dicts) self.step_cuda() def determine_num_available_blocks(self): """Profiles the peak memory usage of the model to determine how many KV blocks may be allocated without OOMs. The engine will first conduct a profiling of the existing memory usage. Then, it calculate the maximum possible number of GPU and CPU blocks that can be allocated with the remaining free memory. .. tip:: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. """ # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. start_time = time.time() GiB = 1024**3 paddle.device.cuda.empty_cache() paddle.device.cuda.reset_max_memory_allocated() before_activation_gpu_memory = paddle.device.cuda.max_memory_allocated( ) / GiB logger.info( f"before activate gpu memory: {before_activation_gpu_memory} GiB.") import gc import pynvml pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex( int(self.device_ids[self.rank])) meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) total_gpu_memory = meminfo.total / GiB used_gpu_memory = meminfo.used / GiB pynvml.nvmlShutdown() logger.info(f"used gpu memory: {used_gpu_memory} GiB.") self.run_profile() current_max_peak_gpu_memory = paddle.device.cuda.max_memory_reserved( ) / GiB logger.info( f"current max peak gpu memory: {current_max_peak_gpu_memory} GiB.") per_block_memory_used = self.infer_engine._cal_theortical_kvcache( ) / GiB logger.info(f"each kv cache block takes {per_block_memory_used} GiB.") used_cache_gpu_memory = self.args.total_block_num * per_block_memory_used logger.info(f"used cache gpu memory: {used_cache_gpu_memory} GiB.") model_weights_memory = used_gpu_memory - used_cache_gpu_memory paddle_peak_increase = current_max_peak_gpu_memory - before_activation_gpu_memory memory_for_current_instance = total_gpu_memory * self.args.gpu_memory_utilization available_kv_cache_memory = memory_for_current_instance - used_gpu_memory - \ paddle_peak_increase + used_cache_gpu_memory num_gpu_blocks = max( int(available_kv_cache_memory // per_block_memory_used), self.args.total_block_num) profile_time = time.time() - start_time msg = (f"Memory profiling takes {profile_time:.2f} seconds\n" "the current instance can use " "total_gpu_memory " f"({(total_gpu_memory):.2f}GiB)" " x gpu_memory_utilization " f"({self.args.gpu_memory_utilization})" f" = {(memory_for_current_instance):.2f}GiB\n" "model weights take " f"{(model_weights_memory ):.2f}GiB;" " Paddle activation peak memory takes " f"{(paddle_peak_increase):.2f}GiB;" " the rest of the memory reserved for KV Cache is " f"{(available_kv_cache_memory):.2f}GiB.") self.infer_engine.record_profile_msg = { "per_block_memory_used": per_block_memory_used, "paddle_peak_increase": paddle_peak_increase, } logger.info(msg) # Final cleanup get_profile_block_num = np.zeros(shape=[self.nranks], dtype=np.int32) self.get_profile_block_num_signal = IPCSignal( name="get_profile_block_num", array=get_profile_block_num, dtype=np.int32, suffix=self.args.engine_pid, create=False) self.get_profile_block_num_signal.value[self.rank] = int( num_gpu_blocks) while np.any(self.get_profile_block_num_signal.value <= 0): time.sleep(0.01) num_gpu_blocks = self.get_profile_block_num_signal.value.min().item() self.get_profile_block_num_signal.value[self.rank] = int( num_gpu_blocks) logger.info( f"{self.get_profile_block_num_signal.value[self.rank]} GPU KV blocks can be allocated." ) self.infer_engine.num_gpu_blocks = num_gpu_blocks self.infer_engine._update_share_input_block_num() paddle.device.cuda.empty_cache() gc.collect() def run_profile(self): """ run profile """ infer_seed_increment = paddle.full(shape=[self.args.max_num_seqs, 1], fill_value=4, dtype="int64") self.infer_engine.dummy_input(self.args.max_num_batched_tokens, self.args.max_num_seqs) while True: if self.nranks > 1: paddle.distributed.barrier() self.infer_engine.generate() self.infer_engine.share_inputs["infer_seed"].add_( infer_seed_increment) self.infer_engine.share_inputs[ "infer_seed"][:] %= self.MAX_INFER_SEED self.step_cuda() if int((self.infer_engine.share_inputs['seq_lens_this_time'] > 0).sum()) == 0: break def main(): """ start worker """ args = parse_args() worker = Worker(args) if args.do_profile: worker.determine_num_available_blocks() worker.run() if __name__ == "__main__": main()