polish code with new pre-commit rule (#2923)
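Note: the diff below is mechanical reformatting driven by the new pre-commit hook: wrapped argument lists are joined onto single lines, str.format logging becomes f-strings, single-quoted attribute names become double-quoted, and class TokenProcessor(object) becomes class TokenProcessor. A minimal before/after sketch of that kind of change, assuming a black/ruff-style formatter with a long line-length limit (the actual hook configuration is not part of this diff, and the function below is purely illustrative, not from the FastDeploy codebase):

# Before: hanging-indent argument wrapping and str.format logging.
def describe_old(name,
                 count,
                 total):
    return "request {0}: {1}/{2} tokens".format(name, count, total)

# After: arguments joined onto one line, f-string, double quotes throughout.
def describe_new(name, count, total):
    return f"request {name}: {count}/{total} tokens"
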
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import copy
import os
import threading
@@ -24,8 +25,7 @@ from concurrent.futures import ThreadPoolExecutor

import numpy as np

from fastdeploy.engine.request import (CompletionOutput, RequestMetrics,
RequestOutput)
from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.platforms import current_platform
@@ -39,13 +39,12 @@ MAX_DRAFT_TOKENS = 6
SPECULATE_MAX_BSZ = 256


class TokenProcessor(object):
class TokenProcessor:
"""
get Token/Score from Paddle inference engine
"""

def __init__(self, cfg, cached_generated_tokens, engine_worker_queue,
split_connector):
def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_connector):
import paddle

paddle.device.set_device("cpu")
@@ -59,22 +58,17 @@ class TokenProcessor(object):
self.speculative_decoding = self.cfg.speculative_config.method is not None

if self.speculative_decoding:
self.output_tokens = paddle.full(shape=[
SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2
],
fill_value=2,
dtype="int64")
elif self.cfg.enable_logprob:
self.output_tokens = paddle.full(
shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
self.output_scores = paddle.full(
shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
self.output_ranks = paddle.full(
shape=[MAX_BSZ], fill_value=0, dtype="int64")
shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2],
fill_value=2,
dtype="int64",
)
elif self.cfg.enable_logprob:
self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
self.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64")
else:
self.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1],
fill_value=2,
dtype="int64")
self.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64")
self.worker = None

self.statics_start_time = time.time()
@@ -94,21 +88,23 @@ class TokenProcessor(object):
0,
] * MAX_DRAFT_TOKENS
prefill_time_data = np.zeros([100], dtype=np.float32)
self.prefill_time_signal = IPCSignal(name="prefill_time_signal",
array=prefill_time_data,
dtype=np.float32,
suffix=os.getpid(),
create=True)
self.prefill_time_signal = IPCSignal(
name="prefill_time_signal",
array=prefill_time_data,
dtype=np.float32,
suffix=os.getpid(),
create=True,
)
self.executor = ThreadPoolExecutor(max_workers=1)
self.prefill_result_status = dict()
self._finalizer = weakref.finalize(self, self._cleanup_resources)

def _cleanup_resources(self):
"""Cleaning up shared memory resources"""
if hasattr(self, 'prefill_time_signal'):
if hasattr(self, "prefill_time_signal"):
self.prefill_time_signal.clear()

if hasattr(self, 'executor'):
if hasattr(self, "executor"):
self.executor.shutdown(wait=False)

def set_resource_manager(self, resource_manager):
@@ -129,16 +125,12 @@ class TokenProcessor(object):
if self.worker is not None:
raise Exception("Worker is already running!")
use_logprobs = (
self.cfg.enable_logprob
and not self.speculative_decoding
and not self.cfg.parallel_config.enable_expert_parallel
self.cfg.enable_logprob
and not self.speculative_decoding
and not self.cfg.parallel_config.enable_expert_parallel
)

target_func = (
self.process_sampling_with_logprob_results
if use_logprobs else
self.process_sampling_results
)
target_func = self.process_sampling_with_logprob_results if use_logprobs else self.process_sampling_results

self.worker = threading.Thread(target=target_func)

@@ -159,7 +151,14 @@ class TokenProcessor(object):
while True:
try:
is_blocking = True
get_output_topk(self.output_tokens, self.output_scores, self.output_ranks, K, rank_id, is_blocking)
get_output_topk(
self.output_tokens,
self.output_scores,
self.output_ranks,
K,
rank_id,
is_blocking,
)

if self.output_tokens[0, 0] == -2:
continue
@@ -170,8 +169,7 @@ class TokenProcessor(object):
self._process_prefill_metrics()
self._process_sampling_with_logprob_batch_output()
except Exception as e:
llm_logger.info("while get input_data error: {0} {1}".format(
e, str(traceback.format_exc())))
llm_logger.info(f"while get input_data error: {e} {traceback.format_exc()!s}")

def process_sampling_results(self):
"""
@@ -186,21 +184,25 @@ class TokenProcessor(object):
from fastdeploy.model_executor.ops.gcu import get_output
else:
from fastdeploy.model_executor.ops.gpu import (
get_output, get_output_ep, speculate_get_output)
get_output,
get_output_ep,
speculate_get_output,
)
rank_id = self.cfg.parallel_config.local_data_parallel_id

while True:
try:
is_blocking = True
if self.speculative_decoding:
speculate_get_output(self.output_tokens, rank_id,
is_blocking, False)
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
if self.output_tokens[0] == -2:
continue

else:
if self.cfg.parallel_config.enable_expert_parallel and \
self.cfg.parallel_config.data_parallel_size > 1:
if (
self.cfg.parallel_config.enable_expert_parallel
and self.cfg.parallel_config.data_parallel_size > 1
):
get_output_ep(self.output_tokens, rank_id, is_blocking)

else:
@@ -208,14 +210,11 @@ class TokenProcessor(object):

if self.output_tokens[0, 0] == -2:
continue
llm_logger.debug(
f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}"
)
llm_logger.debug(f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}")
self._process_prefill_metrics()
self._process_batch_output()
except Exception as e:
llm_logger.info("while get input_data error: {0} {1}".format(
e, str(traceback.format_exc())))
llm_logger.info(f"while get input_data error: {e} {traceback.format_exc()!s}")

def _process_prefill_metrics(self):
"""Asynchronous processing prefill time indicators"""
@@ -224,11 +223,9 @@ class TokenProcessor(object):
try:
current_index = 0
while current_index < len(self.prefill_time_signal.value):
prefill_time = self.prefill_time_signal.value[
current_index]
prefill_time = self.prefill_time_signal.value[current_index]
if prefill_time > 0:
main_process_metrics.request_prefill_time.observe(
prefill_time)
main_process_metrics.request_prefill_time.observe(prefill_time)
self.prefill_time_signal.value[current_index] = 0
current_index += 1
except Exception as e:
@@ -248,12 +245,7 @@ class TokenProcessor(object):
except Exception as e:
llm_logger.error(f"Error in TokenProcessor's postprocess: {e}")

def _recycle_resources(self,
task_id,
index,
task,
result=None,
is_prefill=False):
def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False):
"""
recycle resources
"""
@@ -262,13 +254,10 @@ class TokenProcessor(object):
finished_task_ids = self.engine_worker_queue.get_finished_req()
if len(finished_task_ids) > 0:
for finished_task_id in finished_task_ids:
llm_logger.info(
f"finished_task_id: {finished_task_id}")
self.prefill_result_status[
finished_task_id[0]] = finished_task_id[1]
llm_logger.info(f"finished_task_id: {finished_task_id}")
self.prefill_result_status[finished_task_id[0]] = finished_task_id[1]
if task_id in self.prefill_result_status:
self.split_connector.send_first_token(
task.disaggregate_info, [result])
self.split_connector.send_first_token(task.disaggregate_info, [result])
self.resource_manager.stop_flags[index] = True
self.resource_manager.tasks_list[index] = None
self.resource_manager._recycle_block_tables(task)
@@ -300,8 +289,7 @@ class TokenProcessor(object):
single_head_acceptance_rates = []
for head in range(self.cfg.speculative_config.num_speculative_tokens):
single_head_acceptance_rates.append(
self.num_accept_requests_per_head[head]
/ self.num_rest_requests_per_head[head]
self.num_accept_requests_per_head[head] / self.num_rest_requests_per_head[head]
)
spec_logger.info(f" Single head accept ratio: {single_head_acceptance_rates}")

@@ -316,10 +304,8 @@ class TokenProcessor(object):
"""

batch = self.output_tokens[1, 0]
tokens = self.output_tokens[2:batch * (K + 1) + 2].numpy().reshape(
[batch, K + 1])[:, :(K + 1)]
scores = self.output_scores[:batch * (K + 1)].numpy().reshape(
[batch, K + 1])[:, :(K + 1)]
tokens = self.output_tokens[2 : batch * (K + 1) + 2].numpy().reshape([batch, K + 1])[:, : (K + 1)]
scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
ranks = self.output_ranks[:batch].numpy()
batch_result = list()
for i in range(batch):
@@ -331,8 +317,7 @@ class TokenProcessor(object):
token_ids = [token_id]
recovery_stop = token_id == RECOVERY_STOP_SIGNAL
if recovery_stop:
llm_logger.info(
f"recovery stop signal found at task {task_id}")
llm_logger.info(f"recovery stop signal found at task {task_id}")
if not recovery_stop and token_id < 0:
continue

@@ -350,10 +335,9 @@ class TokenProcessor(object):
arrival_time=task.arrival_time,
inference_start_time=task.inference_start_time,
first_token_time=time.time() - task.inference_start_time,
time_in_queue=task.schedule_start_time -
task.preprocess_end_time,
preprocess_cost_time=task.preprocess_end_time -
task.preprocess_start_time)
time_in_queue=task.schedule_start_time - task.preprocess_end_time,
preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
)

self._record_first_token_metrics(task, current_time)

@@ -364,24 +348,25 @@ class TokenProcessor(object):
)
self.number_of_output_tokens += len(token_ids)
self._record_metrics(task, current_time, token_ids)
result = RequestOutput(request_id=task_id,
outputs=CompletionOutput(
index=i,
send_idx=self.tokens_counter[task_id],
token_ids=[],
logprob = None,
draft_token_ids=[],
top_logprobs=None,
),
finished=False,
metrics=metrics)
result = RequestOutput(
request_id=task_id,
outputs=CompletionOutput(
index=i,
send_idx=self.tokens_counter[task_id],
token_ids=[],
logprob=None,
draft_token_ids=[],
top_logprobs=None,
),
finished=False,
metrics=metrics,
)
if self.tokens_counter[task_id] == 0:
if task.messages is not None:
result.prompt = task.messages
result.num_cached_tokens = task.num_cached_tokens

is_prefill = task.disaggregate_info is not None and task.disaggregate_info[
"role"] == "prefill"
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"

if is_prefill and len(token_ids) > 1:
result.outputs.draft_token_ids = copy.deepcopy(token_ids)
@@ -399,7 +384,7 @@ class TokenProcessor(object):
result.outputs.top_logprobs = LogprobsLists(
logprob_token_ids=[topk_token_ids],
logprobs=[topk_logprobs],
sampled_token_ranks=[sampled_rank]
sampled_token_ranks=[sampled_rank],
)
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
result.finished = True
@@ -408,8 +393,8 @@ class TokenProcessor(object):
if recovery_stop:
result.error_msg = "Recover is not supported, the result is incomplete!"
llm_logger.info(
f"Request: {task_id} finished, number of "
f"generated tokens: {self.tokens_counter[task_id]}.")
f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}."
)
llm_logger.info(
f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
)
@@ -418,8 +403,7 @@ class TokenProcessor(object):
self._compute_speculative_status()
if not is_prefill:
self._record_completion_metrics(task, current_time)
self._recycle_resources(task_id, i, task, result,
is_prefill)
self._recycle_resources(task_id, i, task, result, is_prefill)
break
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
batch_result.append(result)
@@ -434,11 +418,11 @@ class TokenProcessor(object):
tokens = self.output_tokens.numpy()
if self.cfg.speculative_config.method:
batch = self.output_tokens[1]
accept_num = tokens[2:batch + 2]
accept_num = tokens[2 : batch + 2]
self._record_speculative_decoding_mertics(accept_num)
else:
batch = self.output_tokens[1, 0]
tokens = tokens[2:batch + 2]
tokens = tokens[2 : batch + 2]

batch_result = list()
for i in range(batch):
@@ -450,10 +434,14 @@ class TokenProcessor(object):

task_id = task.request_id
if self.cfg.speculative_config.method:
token_ids = tokens[2 + SPECULATE_MAX_BSZ +
i * MAX_DRAFT_TOKENS:2 + SPECULATE_MAX_BSZ +
i * MAX_DRAFT_TOKENS +
accept_num[i]].tolist()
token_ids = tokens[
2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS : 2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS
+ accept_num[i]
].tolist()
if len(token_ids) == 0 or token_ids[-1] <= 0:
continue
else:
@@ -461,8 +449,7 @@ class TokenProcessor(object):
token_ids = [token_id]
recovery_stop = token_id == RECOVERY_STOP_SIGNAL
if recovery_stop:
llm_logger.info(
f"recovery stop signal found at task {task_id}")
llm_logger.info(f"recovery stop signal found at task {task_id}")
if not recovery_stop and token_id < 0:
continue

@@ -480,10 +467,9 @@ class TokenProcessor(object):
arrival_time=task.arrival_time,
inference_start_time=task.inference_start_time,
first_token_time=time.time() - task.inference_start_time,
time_in_queue=task.schedule_start_time -
task.preprocess_end_time,
preprocess_cost_time=task.preprocess_end_time -
task.preprocess_start_time)
time_in_queue=task.schedule_start_time - task.preprocess_end_time,
preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
)

self._record_first_token_metrics(task, current_time)

@@ -494,21 +480,23 @@ class TokenProcessor(object):
)
self.number_of_output_tokens += len(token_ids)
self._record_metrics(task, current_time, token_ids)
result = RequestOutput(request_id=task_id,
outputs=CompletionOutput(
index=i,
send_idx=self.tokens_counter[task_id],
token_ids=[],
draft_token_ids=[]),
finished=False,
metrics=metrics)
result = RequestOutput(
request_id=task_id,
outputs=CompletionOutput(
index=i,
send_idx=self.tokens_counter[task_id],
token_ids=[],
draft_token_ids=[],
),
finished=False,
metrics=metrics,
)
if self.tokens_counter[task_id] == 0:
if task.messages is not None:
result.prompt = task.messages
result.num_cached_tokens = task.num_cached_tokens

is_prefill = task.disaggregate_info is not None and task.disaggregate_info[
"role"] == "prefill"
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"

if is_prefill and len(token_ids) > 1:
result.outputs.draft_token_ids = copy.deepcopy(token_ids)
@@ -522,8 +510,8 @@ class TokenProcessor(object):
if recovery_stop:
result.error_msg = "Recover is not supported, the result is incomplete!"
llm_logger.info(
f"Request: {task_id} finished, number of "
f"generated tokens: {self.tokens_counter[task_id]}.")
f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}."
)
llm_logger.info(
f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
)
@@ -532,8 +520,7 @@ class TokenProcessor(object):
self._compute_speculative_status()
if not is_prefill:
self._record_completion_metrics(task, current_time)
self._recycle_resources(task_id, i, task, result,
is_prefill)
self._recycle_resources(task_id, i, task, result, is_prefill)
break
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
batch_result.append(result)
@@ -542,8 +529,7 @@ class TokenProcessor(object):

def _record_metrics(self, task, current_time, token_ids):
"""Record all metrics for a task"""
if hasattr(task,
'last_token_time') and task.last_token_time is not None:
if hasattr(task, "last_token_time") and task.last_token_time is not None:
token_gen_time = current_time - task.last_token_time
main_process_metrics.time_per_output_token.observe(token_gen_time)
task.last_token_time = current_time
@@ -554,23 +540,19 @@ class TokenProcessor(object):
def _record_first_token_metrics(self, task, current_time):
"""Record metrics for first token"""
task.first_token_time = current_time
main_process_metrics.time_to_first_token.observe(
current_time - task.inference_start_time)
main_process_metrics.request_queue_time.observe(
task.schedule_start_time - task.preprocess_end_time)
main_process_metrics.time_to_first_token.observe(current_time - task.inference_start_time)
main_process_metrics.request_queue_time.observe(task.schedule_start_time - task.preprocess_end_time)

def _record_completion_metrics(self, task, current_time):
"""Record metrics when request completes"""
if hasattr(task, 'first_token_time'):
if hasattr(task, "first_token_time"):
decode_time = current_time - task.first_token_time
main_process_metrics.request_decode_time.observe(decode_time)

main_process_metrics.num_requests_running.dec(1)
main_process_metrics.request_success_total.inc()
main_process_metrics.request_inference_time.observe(
current_time - task.inference_start_time)
main_process_metrics.request_generation_tokens.observe(
self.tokens_counter[task.request_id])
main_process_metrics.request_inference_time.observe(current_time - task.inference_start_time)
main_process_metrics.request_generation_tokens.observe(self.tokens_counter[task.request_id])

def _record_speculative_decoding_mertics(self, accept_num):
"""Record metrics of speculative decoding"""
@@ -586,12 +568,8 @@ class TokenProcessor(object):
num_emitted_tokens = sum(real_accept_num)
self.num_emitted_tokens += num_emitted_tokens

main_process_metrics.spec_decode_num_accepted_tokens_total.inc(
num_accepted_tokens
)
main_process_metrics.spec_decode_num_emitted_tokens_total.inc(
num_emitted_tokens
)
main_process_metrics.spec_decode_num_accepted_tokens_total.inc(num_accepted_tokens)
main_process_metrics.spec_decode_num_emitted_tokens_total.inc(num_emitted_tokens)

if self.cfg.speculative_config.method in ["ngram"]:
main_process_metrics.spec_decode_draft_acceptance_rate.set(
@@ -599,10 +577,7 @@ class TokenProcessor(object):
)

if self.cfg.speculative_config.method in ["mtp"]:
num_draft_tokens = (
len(real_accept_num)
* self.cfg.speculative_config.num_speculative_tokens
)
num_draft_tokens = len(real_accept_num) * self.cfg.speculative_config.num_speculative_tokens
self.num_draft_tokens += num_draft_tokens

self.max_num_emitted_tokens += len(real_accept_num) * (
@@ -612,12 +587,8 @@ class TokenProcessor(object):
main_process_metrics.spec_decode_draft_acceptance_rate.set(
self.num_accepted_tokens / self.num_draft_tokens
)
main_process_metrics.spec_decode_efficiency.set(
self.num_emitted_tokens / self.max_num_emitted_tokens
)
main_process_metrics.spec_decode_num_draft_tokens_total.inc(
num_draft_tokens
)
main_process_metrics.spec_decode_efficiency.set(self.num_emitted_tokens / self.max_num_emitted_tokens)
main_process_metrics.spec_decode_num_draft_tokens_total.inc(num_draft_tokens)

num_rest_requests = len(real_accept_num)
for head in range(self.cfg.speculative_config.num_speculative_tokens):
@@ -629,12 +600,11 @@ class TokenProcessor(object):
num_rest_requests = num_accept_requests
# Calculate the acceptance rate for each head
single_head_acceptance_rate = (
self.num_accept_requests_per_head[head]
/ self.num_rest_requests_per_head[head]
self.num_accept_requests_per_head[head] / self.num_rest_requests_per_head[head]
)
main_process_metrics.spec_decode_draft_single_head_acceptance_rate[head].set(
single_head_acceptance_rate
)
main_process_metrics.spec_decode_draft_single_head_acceptance_rate[
head
].set(single_head_acceptance_rate)


class WarmUpTokenProcessor(TokenProcessor):
@@ -661,14 +631,15 @@ class WarmUpTokenProcessor(TokenProcessor):
from fastdeploy.model_executor.ops.iluvatar import get_output
else:
from fastdeploy.model_executor.ops.gpu import (
get_output, speculate_get_output)
get_output,
speculate_get_output,
)

while self._is_running:
try:
rank_id = 0
if self.speculative_decoding:
speculate_get_output(self.output_tokens, rank_id,
self._is_blocking)
speculate_get_output(self.output_tokens, rank_id, self._is_blocking)
if self.output_tokens[0] == -2:
continue
else:
@@ -678,8 +649,7 @@ class WarmUpTokenProcessor(TokenProcessor):
continue
self._process_batch_output()
except Exception as e:
llm_logger.info("while get input_data error: {0} {1}".format(
e, str(traceback.format_exc())))
llm_logger.info(f"while get input_data error: {e} {traceback.format_exc()!s}")

def stop(self):
"""