mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] Support return logprob of generated tokens (#2784)
* online chat support logprobs * check xpu * check vl_gpu_model_runner * only cuda support logprob * get_worker() check platform --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -30,9 +30,11 @@ from fastdeploy.inter_communicator import IPCSignal
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.utils import llm_logger, spec_logger
|
||||
from fastdeploy.worker.output import LogprobsLists
|
||||
|
||||
RECOVERY_STOP_SIGNAL = -3
|
||||
MAX_BSZ = 512
|
||||
K = 20
|
||||
MAX_DRAFT_TOKENS = 6
|
||||
SPECULATE_MAX_BSZ = 256
|
||||
|
||||
@@ -62,6 +64,13 @@ class TokenProcessor(object):
|
||||
],
|
||||
fill_value=2,
|
||||
dtype="int64")
|
||||
elif self.cfg.enable_logprob:
|
||||
self.output_tokens = paddle.full(
|
||||
shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
|
||||
self.output_scores = paddle.full(
|
||||
shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
|
||||
self.output_ranks = paddle.full(
|
||||
shape=[MAX_BSZ], fill_value=0, dtype="int64")
|
||||
else:
|
||||
self.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1],
|
||||
fill_value=2,
|
||||
@@ -109,12 +118,51 @@ class TokenProcessor(object):
|
||||
assert self.resource_manager is not None, "The resource manager is None, cannot run."
|
||||
if self.worker is not None:
|
||||
raise Exception("Worker is already running!")
|
||||
use_logprobs = (
|
||||
self.cfg.enable_logprob
|
||||
and not self.speculative_decoding
|
||||
and not self.cfg.parallel_config.enable_expert_parallel
|
||||
)
|
||||
|
||||
target_func = (
|
||||
self.process_sampling_with_logprob_results
|
||||
if use_logprobs else
|
||||
self.process_sampling_results
|
||||
)
|
||||
|
||||
self.worker = threading.Thread(target=target_func)
|
||||
|
||||
self.worker = threading.Thread(target=self.process_sampling_results,
|
||||
args=())
|
||||
self.worker.daemon = True
|
||||
self.worker.start()
|
||||
|
||||
def process_sampling_with_logprob_results(self):
|
||||
"""
|
||||
read tokens from paddle inference engine and process logprob results
|
||||
"""
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import get_output_topk
|
||||
else:
|
||||
raise NotImplementedError("Only CUDA platform supports logprob.")
|
||||
rank_id = self.cfg.parallel_config.local_data_parallel_id
|
||||
|
||||
while True:
|
||||
try:
|
||||
is_blocking = True
|
||||
get_output_topk(self.output_tokens, self.output_scores, self.output_ranks, K, rank_id, is_blocking)
|
||||
|
||||
if self.output_tokens[0, 0] == -2:
|
||||
continue
|
||||
llm_logger.debug(
|
||||
f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}"
|
||||
f"rank_id {rank_id} self.output_scores[0, 0] {self.output_scores[0, 0]}"
|
||||
)
|
||||
self._process_prefill_metrics()
|
||||
self._process_sampling_with_logprob_batch_output()
|
||||
except Exception as e:
|
||||
llm_logger.info("while get input_data error: {0} {1}".format(
|
||||
e, str(traceback.format_exc())))
|
||||
|
||||
def process_sampling_results(self):
|
||||
"""
|
||||
read tokens from paddle inference engine and process
|
||||
@@ -245,6 +293,122 @@ class TokenProcessor(object):
|
||||
self.number_of_output_tokens = 0
|
||||
self.total_step = 0
|
||||
|
||||
def _process_sampling_with_logprob_batch_output(self):
|
||||
"""
|
||||
batch post-processing logprob output function
|
||||
"""
|
||||
|
||||
batch = self.output_tokens[1, 0]
|
||||
tokens = self.output_tokens[2:batch * (K + 1) + 2].numpy().reshape(
|
||||
[batch, K + 1])[:, :(K + 1)]
|
||||
scores = self.output_scores[:batch * (K + 1)].numpy().reshape(
|
||||
[batch, K + 1])[:, :(K + 1)]
|
||||
ranks = self.output_ranks[:batch].numpy()
|
||||
batch_result = list()
|
||||
for i in range(batch):
|
||||
if self.resource_manager.stop_flags[i]:
|
||||
continue
|
||||
task = self.resource_manager.tasks_list[i]
|
||||
task_id = task.request_id
|
||||
token_id = int(tokens[i, 0])
|
||||
token_ids = [token_id]
|
||||
recovery_stop = token_id == RECOVERY_STOP_SIGNAL
|
||||
if recovery_stop:
|
||||
llm_logger.info(
|
||||
f"recovery stop signal found at task {task_id}")
|
||||
if not recovery_stop and token_id < 0:
|
||||
continue
|
||||
|
||||
if task.get("prefill_chunk_info", None) is not None:
|
||||
prefill_chunk_num = task.get("prefill_chunk_num", 0)
|
||||
task.prefill_chunk_num = prefill_chunk_num + 1
|
||||
|
||||
if task.prefill_chunk_num < len(task.prefill_chunk_info):
|
||||
continue
|
||||
|
||||
self.total_step += 1
|
||||
current_time = time.time()
|
||||
if self.tokens_counter[task_id] == 0:
|
||||
metrics = RequestMetrics(
|
||||
arrival_time=task.arrival_time,
|
||||
inference_start_time=task.inference_start_time,
|
||||
first_token_time=time.time() - task.inference_start_time,
|
||||
time_in_queue=task.schedule_start_time -
|
||||
task.preprocess_end_time,
|
||||
preprocess_cost_time=task.preprocess_end_time -
|
||||
task.preprocess_start_time)
|
||||
|
||||
self._record_first_token_metrics(task, current_time)
|
||||
|
||||
else:
|
||||
metrics = RequestMetrics(
|
||||
arrival_time=time.time(),
|
||||
request_start_time=task.arrival_time,
|
||||
)
|
||||
self.number_of_output_tokens += len(token_ids)
|
||||
self._record_metrics(task, current_time, token_ids)
|
||||
result = RequestOutput(request_id=task_id,
|
||||
outputs=CompletionOutput(
|
||||
index=i,
|
||||
send_idx=self.tokens_counter[task_id],
|
||||
token_ids=[],
|
||||
logprob = None,
|
||||
draft_token_ids=[],
|
||||
top_logprobs=None,
|
||||
),
|
||||
finished=False,
|
||||
metrics=metrics)
|
||||
if self.tokens_counter[task_id] == 0:
|
||||
if task.messages is not None:
|
||||
result.prompt = task.messages
|
||||
result.num_cached_tokens = task.num_cached_tokens
|
||||
|
||||
is_prefill = task.disaggregate_info is not None and task.disaggregate_info[
|
||||
"role"] == "prefill"
|
||||
|
||||
if is_prefill and len(token_ids) > 1:
|
||||
result.outputs.draft_token_ids = copy.deepcopy(token_ids)
|
||||
|
||||
for idx, token_id in enumerate(token_ids):
|
||||
self.tokens_counter[task_id] += 1
|
||||
if token_id != RECOVERY_STOP_SIGNAL:
|
||||
result.outputs.token_ids.append(token_id)
|
||||
result.outputs.logprob = float(scores[i, 0])
|
||||
# Construct top_logprobs
|
||||
topk_token_ids = tokens[i, :].tolist()
|
||||
topk_logprobs = scores[i, :].tolist()
|
||||
sampled_rank = ranks[i].item()
|
||||
|
||||
result.outputs.top_logprobs = LogprobsLists(
|
||||
logprob_token_ids=[topk_token_ids],
|
||||
logprobs=[topk_logprobs],
|
||||
sampled_token_ranks=[sampled_rank]
|
||||
)
|
||||
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
|
||||
result.finished = True
|
||||
result.prompt = task.prompt
|
||||
result.prompt_token_ids = task.prompt_token_ids
|
||||
if recovery_stop:
|
||||
result.error_msg = "Recover is not supported, the result is incomplete!"
|
||||
llm_logger.info(
|
||||
f"Request: {task_id} finished, number of "
|
||||
f"generated tokens: {self.tokens_counter[task_id]}.")
|
||||
llm_logger.info(
|
||||
f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
|
||||
)
|
||||
llm_logger.info(f"{self.resource_manager.info()}")
|
||||
if self.cfg.speculative_config.method:
|
||||
self._compute_speculative_status()
|
||||
if not is_prefill:
|
||||
self._record_completion_metrics(task, current_time)
|
||||
self._recycle_resources(task_id, i, task, result,
|
||||
is_prefill)
|
||||
break
|
||||
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
|
||||
batch_result.append(result)
|
||||
|
||||
self.postprocess(batch_result)
|
||||
|
||||
def _process_batch_output(self):
|
||||
"""
|
||||
batch post-processing function
|
||||
|
Reference in New Issue
Block a user