[Features] Add speculative metrics (#2857)

This commit is contained in:
GoldPancake
2025-07-17 11:08:55 +08:00
committed by GitHub
parent 52aca233e8
commit 42d4001400
2 changed files with 164 additions and 10 deletions

View File

@@ -107,8 +107,6 @@ REQUEST_LATENCY_BUCKETS = [
]
class MetricsManager:
"""Prometheus Metrics Manager handles all metric updates """
@@ -126,6 +124,12 @@ class MetricsManager:
request_decode_time: 'Histogram'
request_generation_tokens: 'Histogram'
request_success_total: 'Counter'
spec_decode_draft_acceptance_rate: 'Gauge'
spec_decode_efficiency: 'Gauge'
spec_decode_num_accepted_tokens_total: 'Counter'
spec_decode_num_draft_tokens_total: 'Counter'
spec_decode_num_emitted_tokens_total: 'Counter'
spec_decode_draft_single_head_acceptance_rate: 'list[Gauge]'
# Definition of all metric configurations
METRICS = {
@@ -216,8 +220,9 @@ class MetricsManager:
'name': 'fastdeploy:request_success_total',
'description': 'Total number of successfully processed requests',
'kwargs': {}
}
},
}
SPECULATIVE_METRICS = {}
def __init__(self):
"""Initializes the Prometheus metrics and starts the HTTP server if not already initialized."""
@@ -229,6 +234,75 @@ class MetricsManager:
**config['kwargs']
))
def _init_speculative_metrics(self, speculative_method, num_speculative_tokens):
    """Build the speculative-decoding metric configs and instantiate them.

    Populates ``self.SPECULATIVE_METRICS`` with the base metrics, adds the
    MTP-only entries when ``speculative_method == "mtp"``, then creates one
    attribute per metric on this manager. The per-head acceptance-rate
    metric is materialized as a list of Gauges, one per draft head.
    """
    base_metrics = {
        "spec_decode_draft_acceptance_rate": {
            "type": Gauge,
            "name": "fastdeploy:spec_decode_draft_acceptance_rate",
            "description": "Acceptance rate of speculative decoding",
            "kwargs": {},
        },
        "spec_decode_num_accepted_tokens_total": {
            "type": Counter,
            "name": "fastdeploy:spec_decode_num_accepted_tokens_total",
            "description": "Total number of tokens accepted by the scoring model and verification program",
            "kwargs": {},
        },
        "spec_decode_num_emitted_tokens_total": {
            "type": Counter,
            "name": "fastdeploy:spec_decode_num_emitted_tokens_total",
            "description": "Total number of tokens output by the entire system",
            "kwargs": {},
        },
    }
    if speculative_method == "mtp":
        base_metrics.update({
            "spec_decode_efficiency": {
                "type": Gauge,
                "name": "fastdeploy:spec_decode_efficiency",
                "description": "Efficiency of speculative decoding",
                "kwargs": {},
            },
            "spec_decode_num_draft_tokens_total": {
                "type": Counter,
                "name": "fastdeploy:spec_decode_num_draft_tokens_total",
                "description": "Total number of speculative tokens generated by the proposal method",
                "kwargs": {},
            },
            "spec_decode_draft_single_head_acceptance_rate": {
                # NOTE: "type" here is informational only; the loop below
                # special-cases this key and builds the Gauges itself.
                "type": list[Gauge],
                "name": "fastdeploy:spec_decode_draft_single_head_acceptance_rate",
                "description": "Single head acceptance rate of speculative decoding",
                "kwargs": {},
            },
        })
    self.SPECULATIVE_METRICS = base_metrics

    for metric_name, config in self.SPECULATIVE_METRICS.items():
        if metric_name == "spec_decode_draft_single_head_acceptance_rate":
            # One Gauge per draft head, suffixed with the head index.
            head_gauges = [
                Gauge(
                    f"{config['name']}_{i}",
                    f"{config['description']} (head {i})",
                )
                for i in range(num_speculative_tokens)
            ]
            setattr(self, metric_name, head_gauges)
        else:
            metric = config["type"](
                config["name"], config["description"], **config["kwargs"]
            )
            setattr(self, metric_name, metric)
def register_speculative_metrics(self, registry: CollectorRegistry):
    """Register every speculative-decoding metric with *registry*.

    The per-head acceptance-rate entry is stored as a list of Gauges, so
    each element is registered individually; every other entry is a single
    collector registered directly.
    """
    for name in self.SPECULATIVE_METRICS:
        metric = getattr(self, name)
        if name == "spec_decode_draft_single_head_acceptance_rate":
            for head_gauge in metric:
                registry.register(head_gauge)
        else:
            registry.register(metric)
def register_all(self, registry: CollectorRegistry, workers: int = 1):
"""Register all metrics to the specified registry"""
for metric_name in self.METRICS:
@@ -238,6 +312,8 @@ class MetricsManager:
registry.register(work_process_metrics.request_params_max_tokens)
registry.register(work_process_metrics.prompt_tokens_total)
registry.register(work_process_metrics.request_prompt_tokens)
if hasattr(main_process_metrics, "spec_decode_draft_acceptance_rate"):
self.register_speculative_metrics(registry)
@classmethod
def get_excluded_metrics(cls) -> Set[str]:

View File

@@ -83,6 +83,16 @@ class TokenProcessor(object):
self.number_of_output_tokens = 0
self.total_step = 0
self.speculative_stats_step = 0
self.num_draft_tokens = 0
self.num_accepted_tokens = 0
self.num_emitted_tokens = 0
self.max_num_emitted_tokens = 0
self.num_rest_requests_per_head = [
0,
] * MAX_DRAFT_TOKENS
self.num_accept_requests_per_head = [
0,
] * MAX_DRAFT_TOKENS
prefill_time_data = np.zeros([100], dtype=np.float32)
self.prefill_time_signal = IPCSignal(name="prefill_time_signal",
array=prefill_time_data,
@@ -278,8 +288,7 @@ class TokenProcessor(object):
def _compute_speculative_status(self):
    """Log speculative-decoding statistics every ``interval`` invocations.

    Reports the overall draft accept ratio and, for MTP, the per-head
    acceptance rates accumulated in ``num_accept_requests_per_head`` /
    ``num_rest_requests_per_head``. The rolling totals are reset after
    ~1M output tokens so the logged ratios track recent behavior.
    """
    # TODO(liuzichang): Supplement more statistics
    interval = 50
    # Guard the ratio denominator: skip logging until tokens were produced.
    if self.speculative_stats_step % interval == 0 and self.number_of_output_tokens > 0:
        accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
        # NOTE(review): the leading message text was reconstructed; the
        # original first f-string fragment was not visible in this diff.
        spec_logger.info(
            f"Speculative decoding accept ratio: {accept_ratio}."
            f" total step: {self.total_step}. total output token num: {self.number_of_output_tokens}"
        )
        if self.cfg.speculative_config.method in ["mtp"]:
            single_head_acceptance_rates = []
            for head in range(self.cfg.speculative_config.num_speculative_tokens):
                rest = self.num_rest_requests_per_head[head]
                # Avoid ZeroDivisionError before any request reached this head.
                rate = self.num_accept_requests_per_head[head] / rest if rest else 0.0
                single_head_acceptance_rates.append(rate)
            spec_logger.info(f" Single head accept ratio: {single_head_acceptance_rates}")

        if self.number_of_output_tokens > 1000000:
            self.number_of_output_tokens = 0
            self.total_step = 0
    self.speculative_stats_step += 1
def _process_sampling_with_logprob_batch_output(self):
"""
@@ -422,6 +435,7 @@ class TokenProcessor(object):
if self.cfg.speculative_config.method:
batch = self.output_tokens[1]
accept_num = tokens[2:batch + 2]
self._record_speculative_decoding_mertics(accept_num)
else:
batch = self.output_tokens[1, 0]
tokens = tokens[2:batch + 2]
@@ -558,6 +572,70 @@ class TokenProcessor(object):
main_process_metrics.request_generation_tokens.observe(
self.tokens_counter[task.request_id])
def _record_speculative_decoding_mertics(self, accept_num):
    """Record metrics of speculative decoding.

    Args:
        accept_num: per-request accepted-token counts for this step; zero
            entries mean the request produced nothing and are ignored.

    Lazily initializes the speculative metrics on first use, then updates
    the accepted/emitted totals, the overall draft acceptance rate, and
    (for MTP) the decoding efficiency and per-head acceptance rates.
    All ratio updates are guarded against zero denominators, which occur
    when every entry of ``accept_num`` is zero on the first calls.
    """
    # NOTE: "mertics" typo is kept — callers reference this exact name.
    if not hasattr(main_process_metrics, "spec_decode_draft_acceptance_rate"):
        main_process_metrics._init_speculative_metrics(
            self.cfg.speculative_config.method,
            self.cfg.speculative_config.num_speculative_tokens,
        )

    real_accept_num = [x for x in accept_num if x != 0]
    num_accepted_tokens = sum(x - 1 for x in real_accept_num)
    self.num_accepted_tokens += num_accepted_tokens
    num_emitted_tokens = sum(real_accept_num)
    self.num_emitted_tokens += num_emitted_tokens
    main_process_metrics.spec_decode_num_accepted_tokens_total.inc(
        num_accepted_tokens
    )
    main_process_metrics.spec_decode_num_emitted_tokens_total.inc(
        num_emitted_tokens
    )

    if self.cfg.speculative_config.method in ["ngram"]:
        if self.num_emitted_tokens:
            main_process_metrics.spec_decode_draft_acceptance_rate.set(
                self.num_accepted_tokens / self.num_emitted_tokens
            )

    if self.cfg.speculative_config.method in ["mtp"]:
        num_draft_tokens = (
            len(real_accept_num)
            * self.cfg.speculative_config.num_speculative_tokens
        )
        self.num_draft_tokens += num_draft_tokens
        self.max_num_emitted_tokens += len(real_accept_num) * (
            self.cfg.speculative_config.num_speculative_tokens + 1
        )
        if self.num_draft_tokens:
            main_process_metrics.spec_decode_draft_acceptance_rate.set(
                self.num_accepted_tokens / self.num_draft_tokens
            )
        if self.max_num_emitted_tokens:
            main_process_metrics.spec_decode_efficiency.set(
                self.num_emitted_tokens / self.max_num_emitted_tokens
            )
        main_process_metrics.spec_decode_num_draft_tokens_total.inc(
            num_draft_tokens
        )

        # Head h only sees the requests still "alive" after head h-1, so
        # the rest-count cascades down through the accept counts.
        num_rest_requests = len(real_accept_num)
        for head in range(self.cfg.speculative_config.num_speculative_tokens):
            num_accept_requests = sum(
                1 for x in real_accept_num if x >= head + 2
            )
            # Accumulate the number of requests for each head
            self.num_accept_requests_per_head[head] += num_accept_requests
            self.num_rest_requests_per_head[head] += num_rest_requests
            # Update the rest requests for each head
            num_rest_requests = num_accept_requests
            # Calculate the acceptance rate for each head; skip heads that
            # no request has ever reached (zero denominator).
            rest_total = self.num_rest_requests_per_head[head]
            if rest_total:
                main_process_metrics.spec_decode_draft_single_head_acceptance_rate[
                    head
                ].set(self.num_accept_requests_per_head[head] / rest_total)
class WarmUpTokenProcessor(TokenProcessor):
"""