diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py
index e08e86eab..42dd153e4 100644
--- a/fastdeploy/cache_manager/prefix_cache_manager.py
+++ b/fastdeploy/cache_manager/prefix_cache_manager.py
@@ -31,6 +31,7 @@ from fastdeploy import envs
 from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
 from fastdeploy.cache_manager.cache_metrics import CacheMetrics
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
+from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.utils import get_logger
 
 logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
@@ -106,6 +107,10 @@ class PrefixCacheManager:
             + f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
         )
 
+    @property
+    def available_gpu_resource(self):
+        return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0
+
     def launch_cache_manager(
         self,
         cache_config,
@@ -289,6 +294,9 @@ class PrefixCacheManager:
         heapq.heapify(self.gpu_free_block_list)
         self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
 
+        main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
+        main_process_metrics.available_gpu_resource.set(1.0)
+
     def _enable_cpu_cache(self):
         """
         _enable_cpu_cache function used to enable cpu cache.
         """
@@ -324,6 +332,8 @@ class PrefixCacheManager:
         logger.info(
             f"allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
         )
+        main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
+        main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
 
         return allocated_block_ids
 
     def recycle_gpu_blocks(self, gpu_block_ids):
@@ -338,6 +348,8 @@ class PrefixCacheManager:
                 heapq.heappush(self.gpu_free_block_list, gpu_block_id)
         else:
             heapq.heappush(self.gpu_free_block_list, gpu_block_ids)
+        main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
+        main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
 
     def allocate_cpu_blocks(self, num_blocks):
         """
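The hunks above refresh `free_gpu_block_num` and `available_gpu_resource` at every allocation and recycle site rather than from a sampling loop. A minimal standalone sketch of that pattern follows; the block pool, metric names, and numbers are illustrative only, not FastDeploy code:

```python
import heapq

from prometheus_client import Gauge

free_gpu_block_num = Gauge("free_gpu_block_num_demo", "Number of free blocks in the demo pool")
available_gpu_resource = Gauge("available_gpu_resource_demo", "Free-block ratio of the demo pool")


class BlockPool:
    """Toy stand-in for the GPU block free list kept by a cache manager."""

    def __init__(self, num_blocks: int):
        self.num_blocks = num_blocks
        self.free_blocks = list(range(num_blocks))
        heapq.heapify(self.free_blocks)
        self._report()  # start at full availability, like available_gpu_resource.set(1.0)

    def _report(self) -> None:
        # Refresh both gauges every time the free list changes.
        free = len(self.free_blocks)
        free_gpu_block_num.set(free)
        available_gpu_resource.set(free / self.num_blocks if self.num_blocks > 0 else 0.0)

    def allocate(self, n: int) -> list:
        blocks = [heapq.heappop(self.free_blocks) for _ in range(n)]
        self._report()
        return blocks

    def recycle(self, blocks) -> None:
        for block in blocks:
            heapq.heappush(self.free_blocks, block)
        self._report()


pool = BlockPool(8)
ids = pool.allocate(3)  # available_gpu_resource -> 0.625
pool.recycle(ids)       # available_gpu_resource -> 1.0
```

Updating the gauges at the point of mutation keeps them exact without a separate polling thread.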
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 8d34bade0..5c4936f11 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -439,6 +439,8 @@ class LLMEngine:
                 get_request_pool.submit(_fetch_request)
                 # 2. Schedule requests
                 tasks = self.resource_manager.schedule()
+                main_process_metrics.num_requests_waiting.dec(len(tasks))
+                main_process_metrics.num_requests_running.inc(len(tasks))
                 # 3. Send to engine
                 if tasks:
                     self.resource_manager.get_real_bsz()
@@ -476,6 +478,7 @@ class LLMEngine:
                 request = Request.from_dict(data)
                 start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
+                main_process_metrics.requests_number.inc()
                 llm_logger.debug(f"Receive request: {request}")
 
                 err_msg = None
diff --git a/fastdeploy/engine/resource_manager.py b/fastdeploy/engine/resource_manager.py
index 086cf6aeb..5374acf8f 100644
--- a/fastdeploy/engine/resource_manager.py
+++ b/fastdeploy/engine/resource_manager.py
@@ -57,14 +57,15 @@ class ResourceManager:
         self.logger = llm_logger
         self.cfg = config.cache_config
         self.max_num_seqs = max_num_seqs
-        self.stop_flags = [True] * max_num_seqs
+        self.stop_flags = [True] * max_num_seqs  # flag set to True if the slot has not been taken
         self.enable_prefix_cache = config.cache_config.enable_prefix_caching
         self.cache_manager = PrefixCacheManager(config, tensor_parallel_size, splitwise_role, local_data_parallel_id)
-        self.tasks_list = [None] * max_num_seqs
+        self.tasks_list = [None] * max_num_seqs  # task slots
         self.req_dict = dict()
         # current batch status of the engine
         self.real_bsz = 0
         self.logger.info(f"{self.info()}")
+        main_process_metrics.max_batch_size.set(max_num_seqs)
 
     def reset_cache_config(self, cfg):
         """
@@ -228,30 +229,31 @@ class ResourceManager:
         Returns:
             list: processed task list
         """
-
-        allocated_position = 0
-        processing_task_index = 0
+        llm_logger.debug(f"Allocating resources for a batch of new tasks: {tasks}")
+        allocated_position = 0  # index of the candidate slot in the request slots
+        processing_task_index = 0  # index of the task currently being processed
         processed_tasks = list()
-        while allocated_position < self.max_num_seqs:
-            if processing_task_index >= len(tasks):
+        while allocated_position < self.max_num_seqs:  # loop until resources have been allocated for all tasks
+            if processing_task_index >= len(tasks):  # if all tasks have been tried, don't give them a second chance
                 break
             can_insert = False
             while allocated_position + 1 <= self.max_num_seqs:
                 if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
-                    can_insert = True
+                    can_insert = True  # if there is an empty slot, try to allocate resources for the current task
                     break
                 allocated_position += 1
            if can_insert:
                 if self.stop_flags[allocated_position]:
-                    task = tasks[processing_task_index]
+                    task = tasks[processing_task_index]  # retrieve the current task
                     if task.get("seed") is None:
                         task.set("seed", random.randint(0, 9223372036854775807))
                     task.idx = allocated_position
-                    if self.enable_prefix_cache:
+                    if self.enable_prefix_cache:  # if prefix caching is enabled
+                        # 1. request enough blocks for the current task
                         cache_prepare_time = time.time()
                         common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
                             task, self.cfg.block_size, self.cfg.dec_token_num
@@ -259,14 +261,15 @@ class ResourceManager:
                         if unique_block_ids is None:
                             self.logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
                             return
-
+                        # 2. record cache hit information, and get the number of tokens already in the cache
                         cached_len = self._record_request_cache_info(
                             task, common_block_ids, unique_block_ids, hit_info
                         )
                         task.cache_prepare_time = time.time() - cache_prepare_time
-
+                        # 3. if prefill/decode disaggregation is enabled
                         if task.disaggregate_info is not None:
                             if task.disaggregate_info["role"] == "prefill":
+                                # record the slot position for the current task, indexed by request id
                                 self.req_dict[task.request_id] = allocated_position
                                 task.disaggregate_info["block_tables"] = task.block_tables
                                 self._delete_cached_data(task, cached_len)
@@ -274,17 +277,19 @@ class ResourceManager:
                                 self.req_dict[task.request_id] = allocated_position
                                 task.disaggregate_info["block_tables"] = task.need_block_tables
                         else:
+                            # remove cached tokens from the prompt token ids to avoid recomputing their KV cache
                             self._delete_cached_data(task, cached_len)
 
-                    else:
+                    else:  # if prefix caching is disabled
+                        # 1. directly allocate empty blocks from the cache, if any are available
                         block_tables = self._get_block_tables(task.prompt_token_ids_len)
                         if not block_tables:
                             llm_logger.error(f"req_id: {task.request_id} block_tables is empty")
-                            continue
+                            continue  # retry
                         else:
                             task.block_tables = block_tables
                             task.need_block_tables = task.block_tables
-
+                        # 2. if prefill/decode disaggregation is enabled
                         if task.disaggregate_info is not None:
                             task.disaggregate_info["block_tables"] = block_tables
                             if task.disaggregate_info["role"] == "prefill":
@@ -292,8 +297,8 @@ class ResourceManager:
                                 self.req_dict[task.request_id] = allocated_position
                             elif task.disaggregate_info["role"] == "decode":
                                 self.req_dict[task.request_id] = allocated_position
-                    processed_tasks.append(task)
-                    self.stop_flags[allocated_position] = False
+                    processed_tasks.append(task)  # add the current task to the processed list
+                    self.stop_flags[allocated_position] = False  # mark the slot as occupied
                     task.inference_start_time = time.time()
                     task.inference_time_cost = -1.0
                     task.tokens_all_num = 0
@@ -307,11 +312,18 @@ class ResourceManager:
             processing_task_index += 1
 
         # batch size when the statistical engine is inferring
+        # determine the real batch size from the highest-indexed occupied slot
         for i in range(self.max_num_seqs - 1, -1, -1):
             if not self.stop_flags[i]:
                 self.real_bsz = i + 1
                 break
 
+        # record gpu block usage and real batch size metrics
+        task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
+        main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
+        main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
+        main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
+
         self.logger.info(
             f"Number of allocated requests: {len(tasks)}, number of " f"running requests in worker: {self.real_bsz}"
         )
@@ -343,6 +355,11 @@ class ResourceManager:
         task.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.cfg.block_size
         task.cache_info = (cache_block_num, no_cache_block_num)
 
+        # Report the number of cached tokens to Prometheus metrics
+        main_process_metrics.prefix_cache_token_num.inc(task.num_cached_tokens)
+        main_process_metrics.prefix_gpu_cache_token_num.inc(task.gpu_cache_token_num)
+        main_process_metrics.prefix_cpu_cache_token_num.inc(task.cpu_cache_token_num)
+
         cached_len = len(common_block_ids) * self.cfg.block_size
         task.block_tables = common_block_ids + unique_block_ids
         task.need_block_tables = unique_block_ids
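The metrics block added at the end of the allocation loop derives its values from the task slots rather than from the cache manager: blocks still referenced by occupied slots count as used, everything else as available. A small sketch of that accounting with a hypothetical `Task` holder (names and numbers are illustrative):

```python
from dataclasses import dataclass, field


@dataclass
class Task:
    """Hypothetical stand-in for an entry in tasks_list; only block_tables matters here."""

    block_tables: list = field(default_factory=list)


total_block_number = 100  # what total_block_number() would report
tasks_list = [Task(block_tables=[0, 1, 2]), None, Task(block_tables=[3, 4])]  # None = free slot

# Blocks referenced by occupied slots are "used" ...
task_used_block_num = sum(len(t.block_tables) if t else 0 for t in tasks_list)
# ... and the remainder is reported as available, even if the prefix cache still holds some of it.
available_gpu_block_num = total_block_number - task_used_block_num
batch_size = sum(1 for t in tasks_list if t is not None)  # occupied slots = real batch size

print(task_used_block_num, available_gpu_block_num, batch_size)  # 5 95 2
```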
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 764e71de7..0a8237702 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -27,6 +27,7 @@ import paddle
 
 from fastdeploy.engine.request import Request, RequestStatus, RequestType
 from fastdeploy.engine.resource_manager import ResourceManager
+from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.utils import llm_logger
 
 
@@ -75,6 +76,7 @@ class ResourceManagerV1(ResourceManager):
         self.running: list[Request] = []
         self.finish_execution_pool = ThreadPoolExecutor(max_workers=1)
         self.lock = threading.Lock()
+        main_process_metrics.max_batch_size.set(max_num_seqs)
 
     def allocated_slots(self, request: Request):
         return len(request.block_tables) * self.config.cache_config.block_size
@@ -98,6 +100,9 @@ class ResourceManagerV1(ResourceManager):
         return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
 
     def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
+        """
+        If the request cannot be scheduled, preempt running requests one by one (last in, first out) until it can be scheduled.
+        """
         can_schedule = True
         while True:
             if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
@@ -201,6 +206,9 @@ class ResourceManagerV1(ResourceManager):
         return False
 
     def schedule(self):
+        """
+        Try to pull a batch of requests from the waiting queue and schedule them.
+        """
         with self.lock:
             scheduled_reqs: list[Request] = []
             preempted_reqs: list[Request] = []
@@ -262,7 +270,7 @@ class ResourceManagerV1(ResourceManager):
                         request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
                         # Prepare prefill task
                         scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
-                    else:
+                    else:  # not enough blocks to allocate, trigger preemption
                         can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
                         if not can_schedule:
                             break
@@ -328,6 +336,10 @@ class ResourceManagerV1(ResourceManager):
                 else:
                     llm_logger.error("Unknown request status type")
             if scheduled_reqs:
+                task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
+                main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
+                main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
+                main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
                 llm_logger.debug(f"schedued_reqs: {scheduled_reqs}")
             return scheduled_reqs
 
@@ -369,6 +381,11 @@ class ResourceManagerV1(ResourceManager):
         request.block_tables = common_block_ids
         request.skip_allocate = False
 
+        # Report the number of cached tokens to Prometheus metrics
+        main_process_metrics.prefix_cache_token_num.inc(matched_token_num)
+        main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
+        main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
+
         if matched_token_num == request.prompt_token_ids_len:
             request.num_computed_tokens = matched_token_num - 1
             request.skip_allocate = True
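Both resource managers report prefix-cache hits through the same three counters, once per scheduled request. A rough sketch of how such monotonically increasing counters accumulate; the metric names carry a `_demo` suffix and the helper function is assumed for illustration:

```python
from prometheus_client import Counter

prefix_cache_token_num = Counter("prefix_cache_token_num_demo", "Total number of cached tokens")
prefix_gpu_cache_token_num = Counter("prefix_gpu_cache_token_num_demo", "Cached tokens served from GPU")
prefix_cpu_cache_token_num = Counter("prefix_cpu_cache_token_num_demo", "Cached tokens served from CPU")


def record_prefix_hit(matched: int, gpu_matched: int, cpu_matched: int) -> None:
    """Accumulate one request's cache-hit token counts (hypothetical helper)."""
    prefix_cache_token_num.inc(matched)
    prefix_gpu_cache_token_num.inc(gpu_matched)
    prefix_cpu_cache_token_num.inc(cpu_matched)


record_prefix_hit(matched=96, gpu_matched=64, cpu_matched=32)
record_prefix_hit(matched=128, gpu_matched=128, cpu_matched=0)  # counters only ever grow
```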
diff --git a/fastdeploy/metrics/metrics.py b/fastdeploy/metrics/metrics.py
index a09273fc8..0798d89af 100644
--- a/fastdeploy/metrics/metrics.py
+++ b/fastdeploy/metrics/metrics.py
@@ -154,6 +154,22 @@ class MetricsManager:
     spec_decode_num_emitted_tokens_total: "Counter"
     spec_decode_draft_single_head_acceptance_rate: "list[Gauge]"
 
+    # for YIYAN Adapter
+    prefix_cache_token_num: "Gauge"
+    prefix_gpu_cache_token_num: "Gauge"
+    prefix_cpu_cache_token_num: "Gauge"
+    prefix_ssd_cache_token_num: "Gauge"
+    batch_size: "Gauge"
+    max_batch_size: "Gauge"
+    available_gpu_block_num: "Gauge"
+    free_gpu_block_num: "Gauge"
+    max_gpu_block_num: "Gauge"
+    available_gpu_resource: "Gauge"
+    requests_number: "Counter"
+    send_cache_failed_num: "Counter"
+    first_token_latency: "Gauge"
+    infer_latency: "Gauge"
+
     # Define all metric configurations
     METRICS = {
         "num_requests_running": {
@@ -258,6 +274,91 @@ class MetricsManager:
             "description": "Total number of successfully processed requests",
             "kwargs": {},
         },
+        # for YIYAN Adapter
+        "prefix_cache_token_num": {
+            "type": Counter,
+            "name": "fastdeploy:prefix_cache_token_num",
+            "description": "Total number of cached tokens",
+            "kwargs": {},
+        },
+        "prefix_gpu_cache_token_num": {
+            "type": Counter,
+            "name": "fastdeploy:prefix_gpu_cache_token_num",
+            "description": "Total number of cached tokens on GPU",
+            "kwargs": {},
+        },
+        "prefix_cpu_cache_token_num": {
+            "type": Counter,
+            "name": "fastdeploy:prefix_cpu_cache_token_num",
+            "description": "Total number of cached tokens on CPU",
+            "kwargs": {},
+        },
+        "prefix_ssd_cache_token_num": {
+            "type": Counter,
+            "name": "fastdeploy:prefix_ssd_cache_token_num",
+            "description": "Total number of cached tokens on SSD",
+            "kwargs": {},
+        },
+        "batch_size": {
+            "type": Gauge,
+            "name": "fastdeploy:batch_size",
+            "description": "Real batch size during inference",
+            "kwargs": {},
+        },
+        "max_batch_size": {
+            "type": Gauge,
+            "name": "fastdeploy:max_batch_size",
+            "description": "Maximum batch size determined when the service started",
+            "kwargs": {},
+        },
+        "available_gpu_block_num": {
+            "type": Gauge,
+            "name": "fastdeploy:available_gpu_block_num",
+            "description": "Number of available GPU blocks in cache, including prefix caching blocks that are not officially released",
+            "kwargs": {},
+        },
+        "free_gpu_block_num": {
+            "type": Gauge,
+            "name": "fastdeploy:free_gpu_block_num",
+            "description": "Number of free blocks in cache",
+            "kwargs": {},
+        },
+        "max_gpu_block_num": {
+            "type": Gauge,
+            "name": "fastdeploy:max_gpu_block_num",
+            "description": "Total number of GPU blocks determined when the service started",
+            "kwargs": {},
+        },
+        "available_gpu_resource": {
+            "type": Gauge,
+            "name": "fastdeploy:available_gpu_resource",
+            "description": "Available blocks percentage, i.e. available_gpu_block_num / max_gpu_block_num",
+            "kwargs": {},
+        },
+        "requests_number": {
+            "type": Counter,
+            "name": "fastdeploy:requests_number",
+            "description": "Total number of requests received",
+            "kwargs": {},
+        },
+        "send_cache_failed_num": {
+            "type": Counter,
+            "name": "fastdeploy:send_cache_failed_num",
+            "description": "Total number of failures when sending cache",
+            "kwargs": {},
+        },
+        "first_token_latency": {
+            "type": Gauge,
+            "name": "fastdeploy:first_token_latency",
+            "description": "Latest time to first token in seconds",
+            "kwargs": {},
+        },
+        "infer_latency": {
+            "type": Gauge,
+            "name": "fastdeploy:infer_latency",
+            "description": "Latest time to generate one token in seconds",
+            "kwargs": {},
+        },
     }
 
     SPECULATIVE_METRICS = {}
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index e26c0b057..3ef61c352 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -283,6 +283,15 @@ class TokenProcessor:
         self.resource_manager.stop_flags[index] = True
         self.resource_manager.tasks_list[index] = None
         self.resource_manager._recycle_block_tables(task)
+
+        task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.resource_manager.tasks_list])
+        main_process_metrics.available_gpu_block_num.set(
+            self.resource_manager.total_block_number() - task_used_block_num
+        )
+        main_process_metrics.batch_size.set(
+            self.resource_manager.max_num_seqs - self.resource_manager.available_batch()
+        )
+
         if task_id in self.tokens_counter:
             del self.tokens_counter[task_id]
 
@@ -574,6 +583,7 @@ class TokenProcessor:
     def _record_first_token_metrics(self, task, current_time):
         """Record metrics for first token"""
         task.first_token_time = current_time
+        main_process_metrics.first_token_latency.set(current_time - task.inference_start_time)
         main_process_metrics.time_to_first_token.observe(current_time - task.inference_start_time)
         main_process_metrics.request_queue_time.observe(task.schedule_start_time - task.preprocess_end_time)
 
@@ -585,6 +595,7 @@ class TokenProcessor:
 
         main_process_metrics.num_requests_running.dec(1)
         main_process_metrics.request_success_total.inc()
+        main_process_metrics.infer_latency.set(current_time - task.inference_start_time)
         main_process_metrics.request_inference_time.observe(current_time - task.inference_start_time)
         main_process_metrics.request_generation_tokens.observe(self.tokens_counter[task.request_id])
diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py
index dbcb46b47..b1fdd091c 100644
--- a/fastdeploy/splitwise/splitwise_connector.py
+++ b/fastdeploy/splitwise/splitwise_connector.py
@@ -24,6 +24,7 @@ import zmq
 
 from fastdeploy import envs
 from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput
 from fastdeploy.inter_communicator import EngineWorkerQueue
+from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.utils import get_logger
 
@@ -158,6 +159,7 @@ class SplitwiseConnector:
         except zmq.Again:
             self.logger.warning(f"Send queue full for {addr}")
         except Exception as e:
+            main_process_metrics.send_cache_failed_num.inc()
             self.logger.error(f"Send to {addr} failed: {e}")
             self._close_connection(addr)
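`first_token_latency` and `infer_latency` are gauges set next to the existing `.observe()`-based metrics, so they always hold the most recent request's timings rather than a distribution. A small illustrative sketch of the two measurements (demo metric names, simulated delays):

```python
import time

from prometheus_client import Gauge

first_token_latency = Gauge("first_token_latency_demo", "Latest time to first token in seconds")
infer_latency = Gauge("infer_latency_demo", "Latest whole-request inference time in seconds")

inference_start_time = time.time()

time.sleep(0.05)  # pretend the model just produced its first token
first_token_latency.set(time.time() - inference_start_time)

time.sleep(0.10)  # pretend the rest of the generation finished
infer_latency.set(time.time() - inference_start_time)
```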