Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 08:16:42 +08:00
[feat] add metrics for yiyan adapter (#3219) (#3614)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
* [feat] add metrics for yiyan adapter
* [fix] fix metrics num_requests_waiting and num_requests_running
* [fix] fix metrics gpu_cache_usage_perc
* [refactor] change where requests_number increases
* [chore] rename xxx_block_num as xxx_gpu_block_num, and update their values accordingly
* [chore] delete useless code
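The changes below boil down to wiring a handful of Prometheus counters and gauges through the scheduler, cache manager, and token processor. A minimal, self-contained sketch of that metric surface (standard prometheus_client API; the names mirror the diff, but the registration and update points here are illustrative, not FastDeploy's actual MetricsManager):

# Illustrative sketch only: prometheus_client metrics resembling the ones added below.
from prometheus_client import Counter, Gauge, start_http_server

requests_number = Counter("fastdeploy_requests_number", "Total number of requests received")
num_requests_waiting = Gauge("fastdeploy_num_requests_waiting", "Requests waiting to be scheduled")
num_requests_running = Gauge("fastdeploy_num_requests_running", "Requests currently running")
gpu_cache_usage_perc = Gauge("fastdeploy_gpu_cache_usage_perc", "Fraction of KV-cache GPU blocks in use")

def on_request_received():
    # counted once per request as it enters the engine
    requests_number.inc()
    num_requests_waiting.inc()

def on_batch_scheduled(num_scheduled, used_blocks, total_blocks):
    # a scheduled batch moves requests from "waiting" to "running"
    num_requests_waiting.dec(num_scheduled)
    num_requests_running.inc(num_scheduled)
    gpu_cache_usage_perc.set(used_blocks / total_blocks if total_blocks else 0.0)

if __name__ == "__main__":
    start_http_server(8000)  # expose /metrics for Prometheus to scrape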
@@ -32,6 +32,7 @@ from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger

logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
@@ -110,6 +111,10 @@ class PrefixCacheManager:
+ f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
)

@property
def available_gpu_resource(self):
return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0

def launch_cache_manager(
self,
cache_config,
@@ -229,6 +234,9 @@ class PrefixCacheManager:
heapq.heapify(self.gpu_free_block_list)
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))

main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
main_process_metrics.available_gpu_resource.set(1.0)

def _enable_cpu_cache(self):
"""
_enable_cpu_cache function used to enable cpu cache.
@@ -264,6 +272,8 @@ class PrefixCacheManager:
logger.info(
f"allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
)
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
return allocated_block_ids

def recycle_gpu_blocks(self, gpu_block_ids):
@@ -278,6 +288,8 @@ class PrefixCacheManager:
heapq.heappush(self.gpu_free_block_list, gpu_block_id)
else:
heapq.heappush(self.gpu_free_block_list, gpu_block_ids)
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)

def allocate_cpu_blocks(self, num_blocks):
"""
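The hunks above all follow one pattern: every mutation of the GPU free-block pool is followed by refreshing free_gpu_block_num and available_gpu_resource (free blocks divided by total blocks). A condensed, self-contained sketch of that pattern, using plain prometheus_client Gauges and a simplified pool rather than the real PrefixCacheManager:

# Sketch only: a simplified block pool that refreshes its gauges after every change.
import heapq
from prometheus_client import Gauge

free_gpu_block_num = Gauge("demo_free_gpu_block_num", "Free KV-cache GPU blocks")
available_gpu_resource = Gauge("demo_available_gpu_resource", "Free blocks / total blocks")

class BlockPool:
    def __init__(self, num_gpu_blocks):
        self.num_gpu_blocks = num_gpu_blocks
        self.gpu_free_block_list = list(range(num_gpu_blocks))
        heapq.heapify(self.gpu_free_block_list)
        self._refresh_metrics()

    def _refresh_metrics(self):
        free_gpu_block_num.set(len(self.gpu_free_block_list))
        ratio = len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks else 0.0
        available_gpu_resource.set(ratio)

    def allocate(self, n):
        blocks = [heapq.heappop(self.gpu_free_block_list) for _ in range(n)]
        self._refresh_metrics()  # gauges updated right after allocation
        return blocks

    def recycle(self, block_ids):
        for block_id in block_ids:
            heapq.heappush(self.gpu_free_block_list, block_id)
        self._refresh_metrics()  # and right after blocks are returned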
@@ -552,6 +552,8 @@ class EngineSevice:
get_request_pool.submit(_fetch_request)
# 2. Schedule requests
tasks = self.resource_manager.schedule()
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
# 3. Send to engine
if tasks:
self.resource_manager.get_real_bsz()
@@ -597,6 +599,7 @@ class EngineSevice:
try:
request = Request.from_dict(data)
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
main_process_metrics.requests_number.inc()
llm_logger.debug(f"Receive request: {request}")
except Exception as e:
llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
@@ -51,14 +51,15 @@ class ResourceManager:
"""
self.cfg = config.cache_config
self.max_num_seqs = max_num_seqs
self.stop_flags = [True] * max_num_seqs
self.stop_flags = [True] * max_num_seqs # flag set to true if the slot has not been taken
self.enable_prefix_cache = config.cache_config.enable_prefix_caching
self.cache_manager = PrefixCacheManager(config, tensor_parallel_size, splitwise_role, local_data_parallel_id)
self.tasks_list = [None] * max_num_seqs
self.tasks_list = [None] * max_num_seqs # task slots
self.req_dict = dict()
# current batch status of the engine
self.real_bsz = 0
llm_logger.info(f"{self.info()}")
main_process_metrics.max_batch_size.set(max_num_seqs)

def reset_cache_config(self, cfg):
"""
@@ -222,18 +223,18 @@ class ResourceManager:
Returns:
list: processed task list
"""

allocated_position = 0
processing_task_index = 0
llm_logger.debug(f"Allocating resources for a batch of new tasks: {tasks}")
allocated_position = 0 # number of tasks that have been allocated, also the position in request slots
processing_task_index = 0 # current task
processed_tasks = list()
while allocated_position < self.max_num_seqs:
if processing_task_index >= len(tasks):
while allocated_position < self.max_num_seqs: # loop until all tasks have been allocated resources
if processing_task_index >= len(tasks): # if all tasks have been tried, don't give a second chance
break

can_insert = False
while allocated_position < self.max_num_seqs:
if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
can_insert = True
can_insert = True # if there is an empty slot, try to allocate resources for the current task
break
allocated_position += 1
if can_insert:
@@ -243,7 +244,8 @@ class ResourceManager:
task.set("seed", random.randint(0, 9223372036854775807))
task.idx = allocated_position

if self.enable_prefix_cache:
if self.enable_prefix_cache: # if prefix caching is enabled
# 1. request for enough blocks for current task
cache_prepare_time = time.time()
common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
task,
@@ -253,12 +255,13 @@ class ResourceManager:
if unique_block_ids is None:
llm_logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
return

# 2. record cache hit information, and return the number of tokens already in cache
cached_len = self._record_request_cache_info(task, common_block_ids, unique_block_ids, hit_info)
task.cache_prepare_time = time.time() - cache_prepare_time

# 3. if prefill/decode disaggregation is enabled
if task.disaggregate_info is not None:
if task.disaggregate_info["role"] == "prefill":
# record the slot position for current task, indexed by request id
self.req_dict[task.request_id] = allocated_position
task.disaggregate_info["block_tables"] = task.block_tables
self._delete_cached_data(task, cached_len)
@@ -266,17 +269,19 @@ class ResourceManager:
self.req_dict[task.request_id] = allocated_position
task.disaggregate_info["block_tables"] = task.need_block_tables
else:
# remove cached tokens from prompt token ids to avoid kv recomputation
self._delete_cached_data(task, cached_len)

else:
else: # if prefix caching is disabled
# 1. directly allocate empty block from the cache, if there is any
block_tables = self._get_block_tables(task.prompt_token_ids_len)
if not block_tables:
llm_logger.error(f"req_id: {task.request_id} block_tables is empty")
continue
continue # retry
else:
task.block_tables = block_tables
task.need_block_tables = task.block_tables

# 2. if prefill/decode disaggregation is enabled
if task.disaggregate_info is not None:
task.disaggregate_info["block_tables"] = block_tables
if task.disaggregate_info["role"] == "prefill":
@@ -284,8 +289,8 @@ class ResourceManager:
elif task.disaggregate_info["role"] == "decode":
self.req_dict[task.request_id] = allocated_position

processed_tasks.append(task)
self.stop_flags[allocated_position] = False
processed_tasks.append(task) # add current task
self.stop_flags[allocated_position] = False # mark the slot as occupied
task.inference_start_time = time.time()
task.inference_time_cost = -1.0
task.tokens_all_num = 0
@@ -299,11 +304,18 @@ class ResourceManager:
processing_task_index += 1

# batch size when the statistical engine is inferring
# determine batch size by index of the first slot that is not occupied
for i in range(self.max_num_seqs - 1, -1, -1):
if not self.stop_flags[i]:
self.real_bsz = i + 1
break

# record batch size here
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())

llm_logger.info(
f"Number of allocated requests: {len(tasks)}, number of " f"running requests in worker: {self.real_bsz}"
)
@@ -335,6 +347,11 @@ class ResourceManager:
task.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.cfg.block_size
task.cache_info = (cache_block_num, no_cache_block_num)

# Report the number of cached tokens to Prometheus metrics
main_process_metrics.prefix_cache_token_num.inc(task.num_cached_tokens)
main_process_metrics.prefix_gpu_cache_token_num.inc(task.gpu_cache_token_num)
main_process_metrics.prefix_cpu_cache_token_num.inc(task.cpu_cache_token_num)

cached_len = len(common_block_ids) * self.cfg.block_size
task.block_tables = common_block_ids + unique_block_ids
task.need_block_tables = unique_block_ids
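The bookkeeping at the end of the allocation loop above reduces to two small calculations: the real batch size is the index of the last occupied slot plus one, and the available block count is the total block count minus the blocks referenced by live tasks. A standalone sketch with plain lists and dicts standing in for the real task objects (illustrative, not the actual ResourceManager):

# Sketch only: simplified versions of the batch-size and available-block arithmetic above.
def real_batch_size(stop_flags):
    # scan from the end; the last occupied slot defines the effective batch size
    for i in range(len(stop_flags) - 1, -1, -1):
        if not stop_flags[i]:  # False means the slot is occupied
            return i + 1
    return 0

def available_gpu_block_num(total_block_number, tasks_list):
    # blocks still referenced by live tasks are considered used
    used = sum(len(task["block_tables"]) if task else 0 for task in tasks_list)
    return total_block_number - used

# Example: 4 slots, slots 0 and 2 occupied -> real batch size is 3; 3 of 10 blocks used.
stop_flags = [False, True, False, True]
tasks_list = [{"block_tables": [0, 1]}, None, {"block_tables": [2]}, None]
assert real_batch_size(stop_flags) == 3
assert available_gpu_block_num(10, tasks_list) == 7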
@@ -28,6 +28,7 @@ import paddle

from fastdeploy.engine.request import Request, RequestStatus, RequestType
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import llm_logger


@@ -77,6 +78,7 @@ class ResourceManagerV1(ResourceManager):
self.finish_execution_pool = ThreadPoolExecutor(max_workers=1)
self.lock = threading.Lock()
self.to_be_rescheduled_request_id_set = set()
main_process_metrics.max_batch_size.set(max_num_seqs)

def allocated_slots(self, request: Request):
return len(request.block_tables) * self.config.cache_config.block_size
@@ -107,6 +109,9 @@ class ResourceManagerV1(ResourceManager):
self.to_be_rescheduled_request_id_set.remove(request_id)

def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
"""
If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.
"""
can_schedule = True
while True:
if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
@@ -244,6 +249,9 @@ class ResourceManagerV1(ResourceManager):
return False

def schedule(self):
"""
Try to pull a batch of requests from the waiting queue and schedule them.
"""
with self.lock:
scheduled_reqs: list[Request] = []
preempted_reqs: list[Request] = []
@@ -305,7 +313,7 @@ class ResourceManagerV1(ResourceManager):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
# Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
else:
else: # Not enough blocks to allocate, trigger preemption
can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
if not can_schedule:
break
@@ -371,6 +379,10 @@ class ResourceManagerV1(ResourceManager):
else:
llm_logger.error("Unknown request status type")
if scheduled_reqs:
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
llm_logger.debug(f"schedued_reqs: {scheduled_reqs}")
return scheduled_reqs

@@ -412,6 +424,11 @@ class ResourceManagerV1(ResourceManager):
request.block_tables = common_block_ids
request.skip_allocate = False

# Report the number of cached tokens to Prometheus metrics
main_process_metrics.prefix_cache_token_num.inc(matched_token_num)
main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)

if matched_token_num == request.prompt_token_ids_len:
request.num_computed_tokens = matched_token_num - 1
request.skip_allocate = True
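The _trigger_preempt docstring above describes a last-in, first-out preemption policy: keep evicting the most recently started running request until the new request's blocks fit, or give up. A minimal sketch of that loop, with the running queue, cache, and request objects as simplified stand-ins rather than the real ResourceManagerV1 types:

# Sketch only: LIFO preemption until enough GPU blocks can be allocated.
def trigger_preempt(num_new_blocks, running, cache, preempted_reqs):
    """Pop running requests (newest first) until the blocks fit, or report failure."""
    while not cache.can_allocate_gpu_blocks(num_new_blocks):
        if not running:
            return False              # nothing left to preempt; cannot schedule
        victim = running.pop()        # last in, first out
        cache.recycle_gpu_blocks(victim.block_tables)  # return the victim's blocks
        victim.block_tables = []
        preempted_reqs.append(victim) # victim goes back to the waiting side
    return True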
@@ -154,6 +154,22 @@ class MetricsManager:
spec_decode_num_emitted_tokens_total: "Counter"
spec_decode_draft_single_head_acceptance_rate: "list[Gauge]"

# for YIYAN Adapter
prefix_cache_token_num: "Gauge"
prefix_gpu_cache_token_num: "Gauge"
prefix_cpu_cache_token_num: "Gauge"
prefix_ssd_cache_token_num: "Gauge"
batch_size: "Gauge"
max_batch_size: "Gauge"
available_gpu_block_num: "Gauge"
free_gpu_block_num: "Gauge"
max_gpu_block_num: "Gauge"
available_gpu_resource: "Gauge"
requests_number: "Counter"
send_cache_failed_num: "Counter"
first_token_latency: "Gauge"
infer_latency: "Gauge"

# Define all metric configurations
METRICS = {
"num_requests_running": {
@@ -258,6 +274,91 @@ class MetricsManager:
"description": "Total number of successfully processed requests",
"kwargs": {},
},
# for YIYAN Adapter
"prefix_cache_token_num": {
"type": Counter,
"name": "fastdeploy:prefix_cache_token_num",
"description": "Total number of cached tokens",
"kwargs": {},
},
"prefix_gpu_cache_token_num": {
"type": Counter,
"name": "fastdeploy:prefix_gpu_cache_token_num",
"description": "Total number of cached tokens on GPU",
"kwargs": {},
},
"prefix_cpu_cache_token_num": {
"type": Counter,
"name": "fastdeploy:prefix_cpu_cache_token_num",
"description": "Total number of cached tokens on CPU",
"kwargs": {},
},
"prefix_ssd_cache_token_num": {
"type": Counter,
"name": "fastdeploy:prefix_ssd_cache_token_num",
"description": "Total number of cached tokens on SSD",
"kwargs": {},
},
"batch_size": {
"type": Gauge,
"name": "fastdeploy:batch_size",
"description": "Real batch size during inference",
"kwargs": {},
},
"max_batch_size": {
"type": Gauge,
"name": "fastdeploy:max_batch_size",
"description": "Maximum batch size determined when service started",
"kwargs": {},
},
"available_gpu_block_num": {
"type": Gauge,
"name": "fastdeploy:available_gpu_block_num",
"description": "Number of available gpu blocks in cache, including prefix caching blocks that are not officially released",
"kwargs": {},
},
"free_gpu_block_num": {
"type": Gauge,
"name": "fastdeploy:free_gpu_block_num",
"description": "Number of free blocks in cache",
"kwargs": {},
},
"max_gpu_block_num": {
"type": Gauge,
"name": "fastdeploy:max_gpu_block_num",
"description": "Number of total blocks determined when service started",
"kwargs": {},
},
"available_gpu_resource": {
"type": Gauge,
"name": "fastdeploy:available_gpu_resource",
"description": "Available blocks percentage, i.e. available_gpu_block_num / max_gpu_block_num",
"kwargs": {},
},
"requests_number": {
"type": Counter,
"name": "fastdeploy:requests_number",
"description": "Total number of requests received",
"kwargs": {},
},
"send_cache_failed_num": {
"type": Counter,
"name": "fastdeploy:send_cache_failed_num",
"description": "Total number of failures of sending cache",
"kwargs": {},
},
"first_token_latency": {
"type": Gauge,
"name": "fastdeploy:first_token_latency",
"description": "Latest time to first token in seconds",
"kwargs": {},
},
"infer_latency": {
"type": Gauge,
"name": "fastdeploy:infer_latency",
"description": "Latest time to generate one token in seconds",
"kwargs": {},
},
}
SPECULATIVE_METRICS = {}
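A config table like METRICS above is typically turned into live metric objects by iterating the entries and calling each "type" with its "name", "description", and "kwargs". A small sketch of that pattern with prometheus_client (the real MetricsManager constructor may differ):

# Sketch only: materializing a METRICS-style table into live prometheus_client objects.
from prometheus_client import Counter, Gauge

METRICS_DEMO = {
    "requests_number": {
        "type": Counter,
        "name": "demo:requests_number",
        "description": "Total number of requests received",
        "kwargs": {},
    },
    "available_gpu_resource": {
        "type": Gauge,
        "name": "demo:available_gpu_resource",
        "description": "available_gpu_block_num / max_gpu_block_num",
        "kwargs": {},
    },
}

class DemoMetricsManager:
    def __init__(self):
        for attr, spec in METRICS_DEMO.items():
            # instantiate Counter/Gauge from the table and expose it as an attribute
            metric = spec["type"](spec["name"], spec["description"], **spec["kwargs"])
            setattr(self, attr, metric)

metrics = DemoMetricsManager()
metrics.requests_number.inc()
metrics.available_gpu_resource.set(1.0)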
@@ -247,6 +247,15 @@ class TokenProcessor:
self.resource_manager.stop_flags[index] = True
self.resource_manager.tasks_list[index] = None
self.resource_manager._recycle_block_tables(task)

task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.resource_manager.tasks_list])
main_process_metrics.available_gpu_block_num.set(
self.resource_manager.total_block_number() - task_used_block_num
)
main_process_metrics.batch_size.set(
self.resource_manager.max_num_seqs - self.resource_manager.available_batch()
)

if task_id in self.tokens_counter:
del self.tokens_counter[task_id]

@@ -437,6 +446,7 @@ class TokenProcessor:
def _record_first_token_metrics(self, task, current_time):
"""Record metrics for first token"""
task.first_token_time = current_time
main_process_metrics.first_token_latency.set(current_time - task.inference_start_time)
main_process_metrics.time_to_first_token.observe(current_time - task.inference_start_time)
main_process_metrics.request_queue_time.observe(task.schedule_start_time - task.preprocess_end_time)

@@ -448,6 +458,7 @@ class TokenProcessor:

main_process_metrics.num_requests_running.dec(1)
main_process_metrics.request_success_total.inc()
main_process_metrics.infer_latency.set(current_time - task.inference_start_time)
main_process_metrics.request_inference_time.observe(current_time - task.inference_start_time)
main_process_metrics.request_generation_tokens.observe(self.tokens_counter[task.request_id])
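The latency bookkeeping above pairs a "latest value" gauge (easy to graph directly) with an observe() call that feeds a distribution over time. A short sketch of that pairing with prometheus_client; the metric names and Histogram type here are illustrative assumptions, not FastDeploy's exact definitions:

# Sketch only: record time-to-first-token as both a latest-value gauge and a distribution.
import time
from prometheus_client import Gauge, Histogram

first_token_latency = Gauge("demo_first_token_latency", "Latest time to first token in seconds")
time_to_first_token = Histogram("demo_time_to_first_token", "Time-to-first-token distribution in seconds")

def record_first_token(inference_start_time):
    ttft = time.time() - inference_start_time
    first_token_latency.set(ttft)      # most recent value, overwritten on every request
    time_to_first_token.observe(ttft)  # accumulates buckets for quantiles/averages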
@@ -25,6 +25,7 @@ import zmq
from fastdeploy import envs
from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput
from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger

logger = get_logger("splitwise_connector", "splitwise_connector.log")
@@ -155,6 +156,7 @@ class SplitwiseConnector:
logger.warning(f"Send queue full for {addr}")
except Exception as e:
logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}")
main_process_metrics.send_cache_failed_num.inc()
self._close_connection(addr)

except Exception as e: