[feat] add metrics for yiyan adapter (#3219) (#3614)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled

* [feat] add metrics for yiyan adapter

* [fix] fix metrics num_requests_waiting and num_requests_running

* [fix] fix metrics gpu_cache_usage_perc

* [refactor] change where requests_number increases

* [chore] rename xxx_block_num as xxx_gpu_block_num, and update their values accordingly

* [chore] delete useless code
This commit is contained in:
李泳桦
2025-08-30 23:20:58 +08:00
committed by GitHub
parent fe5d09f9ee
commit 98e03fb4ea
7 changed files with 180 additions and 17 deletions

View File

@@ -32,6 +32,7 @@ from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
@@ -110,6 +111,10 @@ class PrefixCacheManager:
+ f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
)
@property
def available_gpu_resource(self):
return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0
def launch_cache_manager(
self,
cache_config,
@@ -229,6 +234,9 @@ class PrefixCacheManager:
heapq.heapify(self.gpu_free_block_list)
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
main_process_metrics.available_gpu_resource.set(1.0)
def _enable_cpu_cache(self):
"""
_enable_cpu_cache function used to enable cpu cache.
@@ -264,6 +272,8 @@ class PrefixCacheManager:
logger.info(
f"allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
)
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
return allocated_block_ids
def recycle_gpu_blocks(self, gpu_block_ids):
@@ -278,6 +288,8 @@ class PrefixCacheManager:
heapq.heappush(self.gpu_free_block_list, gpu_block_id)
else:
heapq.heappush(self.gpu_free_block_list, gpu_block_ids)
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
def allocate_cpu_blocks(self, num_blocks):
"""

View File

@@ -552,6 +552,8 @@ class EngineSevice:
get_request_pool.submit(_fetch_request)
# 2. Schedule requests
tasks = self.resource_manager.schedule()
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
# 3. Send to engine
if tasks:
self.resource_manager.get_real_bsz()
@@ -597,6 +599,7 @@ class EngineSevice:
try:
request = Request.from_dict(data)
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
main_process_metrics.requests_number.inc()
llm_logger.debug(f"Receive request: {request}")
except Exception as e:
llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")

View File

@@ -51,14 +51,15 @@ class ResourceManager:
"""
self.cfg = config.cache_config
self.max_num_seqs = max_num_seqs
self.stop_flags = [True] * max_num_seqs
self.stop_flags = [True] * max_num_seqs # flag set to true if the slot has not been taken
self.enable_prefix_cache = config.cache_config.enable_prefix_caching
self.cache_manager = PrefixCacheManager(config, tensor_parallel_size, splitwise_role, local_data_parallel_id)
self.tasks_list = [None] * max_num_seqs
self.tasks_list = [None] * max_num_seqs # task slots
self.req_dict = dict()
# current batch status of the engine
self.real_bsz = 0
llm_logger.info(f"{self.info()}")
main_process_metrics.max_batch_size.set(max_num_seqs)
def reset_cache_config(self, cfg):
"""
@@ -222,18 +223,18 @@ class ResourceManager:
Returns:
list: processed task list
"""
allocated_position = 0
processing_task_index = 0
llm_logger.debug(f"Allocating resources for a batch of new tasks: {tasks}")
allocated_position = 0 # number of tasks that have been allocated, also the position in request slots
processing_task_index = 0 # current task
processed_tasks = list()
while allocated_position < self.max_num_seqs:
if processing_task_index >= len(tasks):
while allocated_position < self.max_num_seqs: # loop until all tasks are allocated resources for
if processing_task_index >= len(tasks): # if all taskes have been tried, don't give a second chance
break
can_insert = False
while allocated_position < self.max_num_seqs:
if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
can_insert = True
can_insert = True # if there is a empty slot, try to allocate resources for current task
break
allocated_position += 1
if can_insert:
@@ -243,7 +244,8 @@ class ResourceManager:
task.set("seed", random.randint(0, 9223372036854775807))
task.idx = allocated_position
if self.enable_prefix_cache:
if self.enable_prefix_cache: # if prefix caching is enabled
# 1. request for enough blocks for current task
cache_prepare_time = time.time()
common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
task,
@@ -253,12 +255,13 @@ class ResourceManager:
if unique_block_ids is None:
llm_logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
return
# 2. record cache hit information, and return the number of tokens already in cache
cached_len = self._record_request_cache_info(task, common_block_ids, unique_block_ids, hit_info)
task.cache_prepare_time = time.time() - cache_prepare_time
# 3. if prefill/decode disaggregation is enabled
if task.disaggregate_info is not None:
if task.disaggregate_info["role"] == "prefill":
# record the slot position for current task, indexed by request id
self.req_dict[task.request_id] = allocated_position
task.disaggregate_info["block_tables"] = task.block_tables
self._delete_cached_data(task, cached_len)
@@ -266,17 +269,19 @@ class ResourceManager:
self.req_dict[task.request_id] = allocated_position
task.disaggregate_info["block_tables"] = task.need_block_tables
else:
# remove cached tokens from prompt token ids to avoid kv recomputation
self._delete_cached_data(task, cached_len)
else:
else: # if prefix caching is disabled
# 1. directly allocate empty block from the cache, if there is any
block_tables = self._get_block_tables(task.prompt_token_ids_len)
if not block_tables:
llm_logger.error(f"req_id: {task.request_id} block_tables is empty")
continue
continue # retry
else:
task.block_tables = block_tables
task.need_block_tables = task.block_tables
# 2. if prefill/decode disaggregation is enabled
if task.disaggregate_info is not None:
task.disaggregate_info["block_tables"] = block_tables
if task.disaggregate_info["role"] == "prefill":
@@ -284,8 +289,8 @@ class ResourceManager:
elif task.disaggregate_info["role"] == "decode":
self.req_dict[task.request_id] = allocated_position
processed_tasks.append(task)
self.stop_flags[allocated_position] = False
processed_tasks.append(task) # add current task
self.stop_flags[allocated_position] = False # mark the slot as occupied
task.inference_start_time = time.time()
task.inference_time_cost = -1.0
task.tokens_all_num = 0
@@ -299,11 +304,18 @@ class ResourceManager:
processing_task_index += 1
# batch size when the statistical engine is inferring
# determine batch size by index of the first slot that is not occupied
for i in range(self.max_num_seqs - 1, -1, -1):
if not self.stop_flags[i]:
self.real_bsz = i + 1
break
# record batch size here
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
llm_logger.info(
f"Number of allocated requests: {len(tasks)}, number of " f"running requests in worker: {self.real_bsz}"
)
@@ -335,6 +347,11 @@ class ResourceManager:
task.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.cfg.block_size
task.cache_info = (cache_block_num, no_cache_block_num)
# Report the number of cached tokens to Prometheus metrics
main_process_metrics.prefix_cache_token_num.inc(task.num_cached_tokens)
main_process_metrics.prefix_gpu_cache_token_num.inc(task.gpu_cache_token_num)
main_process_metrics.prefix_cpu_cache_token_num.inc(task.cpu_cache_token_num)
cached_len = len(common_block_ids) * self.cfg.block_size
task.block_tables = common_block_ids + unique_block_ids
task.need_block_tables = unique_block_ids

View File

@@ -28,6 +28,7 @@ import paddle
from fastdeploy.engine.request import Request, RequestStatus, RequestType
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import llm_logger
@@ -77,6 +78,7 @@ class ResourceManagerV1(ResourceManager):
self.finish_execution_pool = ThreadPoolExecutor(max_workers=1)
self.lock = threading.Lock()
self.to_be_rescheduled_request_id_set = set()
main_process_metrics.max_batch_size.set(max_num_seqs)
def allocated_slots(self, request: Request):
return len(request.block_tables) * self.config.cache_config.block_size
@@ -107,6 +109,9 @@ class ResourceManagerV1(ResourceManager):
self.to_be_rescheduled_request_id_set.remove(request_id)
def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
"""
If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.
"""
can_schedule = True
while True:
if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
@@ -244,6 +249,9 @@ class ResourceManagerV1(ResourceManager):
return False
def schedule(self):
"""
Try to pull a batch of requests from the waiting queue and schedule them.
"""
with self.lock:
scheduled_reqs: list[Request] = []
preempted_reqs: list[Request] = []
@@ -305,7 +313,7 @@ class ResourceManagerV1(ResourceManager):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
# Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
else:
else: # Not enough blocks to allocate, trigger preemption
can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
if not can_schedule:
break
@@ -371,6 +379,10 @@ class ResourceManagerV1(ResourceManager):
else:
llm_logger.error("Unknown request status type")
if scheduled_reqs:
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
llm_logger.debug(f"schedued_reqs: {scheduled_reqs}")
return scheduled_reqs
@@ -412,6 +424,11 @@ class ResourceManagerV1(ResourceManager):
request.block_tables = common_block_ids
request.skip_allocate = False
# Report the number of cached tokens to Prometheus metrics
main_process_metrics.prefix_cache_token_num.inc(matched_token_num)
main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
if matched_token_num == request.prompt_token_ids_len:
request.num_computed_tokens = matched_token_num - 1
request.skip_allocate = True

View File

@@ -154,6 +154,22 @@ class MetricsManager:
spec_decode_num_emitted_tokens_total: "Counter"
spec_decode_draft_single_head_acceptance_rate: "list[Gauge]"
# for YIYAN Adapter
prefix_cache_token_num: "Gauge"
prefix_gpu_cache_token_num: "Gauge"
prefix_cpu_cache_token_num: "Gauge"
prefix_ssd_cache_token_num: "Gauge"
batch_size: "Gauge"
max_batch_size: "Gauge"
available_gpu_block_num: "Gauge"
free_gpu_block_num: "Gauge"
max_gpu_block_num: "Gauge"
available_gpu_resource: "Gauge"
requests_number: "Counter"
send_cache_failed_num: "Counter"
first_token_latency: "Gauge"
infer_latency: "Gauge"
# 定义所有指标配置
METRICS = {
"num_requests_running": {
@@ -258,6 +274,91 @@ class MetricsManager:
"description": "Total number of successfully processed requests",
"kwargs": {},
},
# for YIYAN Adapter
"prefix_cache_token_num": {
"type": Counter,
"name": "fastdeploy:prefix_cache_token_num",
"description": "Total number of cached tokens",
"kwargs": {},
},
"prefix_gpu_cache_token_num": {
"type": Counter,
"name": "fastdeploy:prefix_gpu_cache_token_num",
"description": "Total number of cached tokens on GPU",
"kwargs": {},
},
"prefix_cpu_cache_token_num": {
"type": Counter,
"name": "fastdeploy:prefix_cpu_cache_token_num",
"description": "Total number of cached tokens on CPU",
"kwargs": {},
},
"prefix_ssd_cache_token_num": {
"type": Counter,
"name": "fastdeploy:prefix_ssd_cache_token_num",
"description": "Total number of cached tokens on SSD",
"kwargs": {},
},
"batch_size": {
"type": Gauge,
"name": "fastdeploy:batch_size",
"description": "Real batch size during inference",
"kwargs": {},
},
"max_batch_size": {
"type": Gauge,
"name": "fastdeploy:max_batch_size",
"description": "Maximum batch size determined when service started",
"kwargs": {},
},
"available_gpu_block_num": {
"type": Gauge,
"name": "fastdeploy:available_gpu_block_num",
"description": "Number of available gpu blocks in cache, including prefix caching blocks that are not officially released",
"kwargs": {},
},
"free_gpu_block_num": {
"type": Gauge,
"name": "fastdeploy:free_gpu_block_num",
"description": "Number of free blocks in cache",
"kwargs": {},
},
"max_gpu_block_num": {
"type": Gauge,
"name": "fastdeploy:max_gpu_block_num",
"description": "Number of total blocks determined when service started",
"kwargs": {},
},
"available_gpu_resource": {
"type": Gauge,
"name": "fastdeploy:available_gpu_resource",
"description": "Available blocks percentage, i.e. available_gpu_block_num / max_gpu_block_num",
"kwargs": {},
},
"requests_number": {
"type": Counter,
"name": "fastdeploy:requests_number",
"description": "Total number of requests received",
"kwargs": {},
},
"send_cache_failed_num": {
"type": Counter,
"name": "fastdeploy:send_cache_failed_num",
"description": "Total number of failures of sending cache",
"kwargs": {},
},
"first_token_latency": {
"type": Gauge,
"name": "fastdeploy:first_token_latency",
"description": "Latest time to first token in seconds",
"kwargs": {},
},
"infer_latency": {
"type": Gauge,
"name": "fastdeploy:infer_latency",
"description": "Latest time to generate one token in seconds",
"kwargs": {},
},
}
SPECULATIVE_METRICS = {}

View File

@@ -247,6 +247,15 @@ class TokenProcessor:
self.resource_manager.stop_flags[index] = True
self.resource_manager.tasks_list[index] = None
self.resource_manager._recycle_block_tables(task)
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.resource_manager.tasks_list])
main_process_metrics.available_gpu_block_num.set(
self.resource_manager.total_block_number() - task_used_block_num
)
main_process_metrics.batch_size.set(
self.resource_manager.max_num_seqs - self.resource_manager.available_batch()
)
if task_id in self.tokens_counter:
del self.tokens_counter[task_id]
@@ -437,6 +446,7 @@ class TokenProcessor:
def _record_first_token_metrics(self, task, current_time):
"""Record metrics for first token"""
task.first_token_time = current_time
main_process_metrics.first_token_latency.set(current_time - task.inference_start_time)
main_process_metrics.time_to_first_token.observe(current_time - task.inference_start_time)
main_process_metrics.request_queue_time.observe(task.schedule_start_time - task.preprocess_end_time)
@@ -448,6 +458,7 @@ class TokenProcessor:
main_process_metrics.num_requests_running.dec(1)
main_process_metrics.request_success_total.inc()
main_process_metrics.infer_latency.set(current_time - task.inference_start_time)
main_process_metrics.request_inference_time.observe(current_time - task.inference_start_time)
main_process_metrics.request_generation_tokens.observe(self.tokens_counter[task.request_id])

View File

@@ -25,6 +25,7 @@ import zmq
from fastdeploy import envs
from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput
from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
logger = get_logger("splitwise_connector", "splitwise_connector.log")
@@ -155,6 +156,7 @@ class SplitwiseConnector:
logger.warning(f"Send queue full for {addr}")
except Exception as e:
logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
main_process_metrics.send_cache_failed_num.inc()
self._close_connection(addr)
except Exception as e: