mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
merge logprob into batch_output (#3266)
This commit is contained in:
@@ -30,6 +30,7 @@ from fastdeploy.config import (
|
||||
TaskOption,
|
||||
)
|
||||
from fastdeploy.engine.config import Config
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.scheduler.config import SchedulerConfig
|
||||
from fastdeploy.utils import DeprecatedOptionWarning, FlexibleArgumentParser
|
||||
|
||||
@@ -344,6 +345,13 @@ class EngineArgs:
|
||||
"""
|
||||
if not self.tokenizer:
|
||||
self.tokenizer = self.model
|
||||
if self.enable_logprob:
|
||||
if self.speculative_config is not None:
|
||||
raise NotImplementedError("Logprob does not support speculation_config.")
|
||||
if self.enable_expert_parallel:
|
||||
raise NotImplementedError("Logprob does not support enable_expert_parallel.")
|
||||
if not current_platform.is_cuda():
|
||||
raise NotImplementedError("Only CUDA platform supports logprob.")
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
|
@@ -57,6 +57,7 @@ class TokenProcessor:
|
||||
self.split_connector = split_connector
|
||||
|
||||
self.speculative_decoding = self.cfg.speculative_config.method is not None
|
||||
self.use_logprobs = self.cfg.enable_logprob
|
||||
|
||||
if self.speculative_decoding:
|
||||
self.output_tokens = paddle.full(
|
||||
@@ -64,7 +65,7 @@ class TokenProcessor:
|
||||
fill_value=2,
|
||||
dtype="int64",
|
||||
)
|
||||
elif self.cfg.enable_logprob:
|
||||
elif self.use_logprobs:
|
||||
self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
|
||||
self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
|
||||
self.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64")
|
||||
@@ -125,53 +126,12 @@ class TokenProcessor:
|
||||
assert self.resource_manager is not None, "The resource manager is None, cannot run."
|
||||
if self.worker is not None:
|
||||
raise Exception("Worker is already running!")
|
||||
use_logprobs = (
|
||||
self.cfg.enable_logprob
|
||||
and not self.speculative_decoding
|
||||
and not self.cfg.parallel_config.enable_expert_parallel
|
||||
)
|
||||
|
||||
target_func = self.process_sampling_with_logprob_results if use_logprobs else self.process_sampling_results
|
||||
|
||||
self.worker = threading.Thread(target=target_func)
|
||||
self.worker = threading.Thread(target=self.process_sampling_results)
|
||||
|
||||
self.worker.daemon = True
|
||||
self.worker.start()
|
||||
|
||||
def process_sampling_with_logprob_results(self):
|
||||
"""
|
||||
read tokens from paddle inference engine and process logprob results
|
||||
"""
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import get_output_topk
|
||||
else:
|
||||
raise NotImplementedError("Only CUDA platform supports logprob.")
|
||||
|
||||
rank_id = self.cfg.parallel_config.local_data_parallel_id
|
||||
|
||||
while True:
|
||||
try:
|
||||
is_blocking = True
|
||||
get_output_topk(
|
||||
self.output_tokens,
|
||||
self.output_scores,
|
||||
self.output_ranks,
|
||||
K,
|
||||
rank_id,
|
||||
is_blocking,
|
||||
)
|
||||
|
||||
if self.output_tokens[0, 0] == -2:
|
||||
continue
|
||||
llm_logger.debug(
|
||||
f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}"
|
||||
f"rank_id {rank_id} self.output_scores[0, 0] {self.output_scores[0, 0]}"
|
||||
)
|
||||
self._process_prefill_metrics()
|
||||
self._process_sampling_with_logprob_batch_output()
|
||||
except Exception as e:
|
||||
llm_logger.info(f"while get input_data error: {e} {traceback.format_exc()!s}")
|
||||
|
||||
def process_sampling_results(self):
|
||||
"""
|
||||
read tokens from paddle inference engine and process
|
||||
@@ -187,6 +147,7 @@ class TokenProcessor:
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
get_output,
|
||||
get_output_ep,
|
||||
get_output_topk,
|
||||
speculate_get_output,
|
||||
)
|
||||
rank_id = self.cfg.parallel_config.local_data_parallel_id
|
||||
@@ -206,6 +167,16 @@ class TokenProcessor:
|
||||
):
|
||||
get_output_ep(self.output_tokens, rank_id, is_blocking)
|
||||
|
||||
else:
|
||||
if self.use_logprobs:
|
||||
get_output_topk(
|
||||
self.output_tokens,
|
||||
self.output_scores,
|
||||
self.output_ranks,
|
||||
K,
|
||||
rank_id,
|
||||
is_blocking,
|
||||
)
|
||||
else:
|
||||
get_output(self.output_tokens, rank_id, is_blocking)
|
||||
|
||||
@@ -305,129 +276,23 @@ class TokenProcessor:
|
||||
self.total_step = 0
|
||||
self.speculative_stats_step += 1
|
||||
|
||||
def _process_sampling_with_logprob_batch_output(self):
|
||||
"""
|
||||
batch post-processing logprob output function
|
||||
"""
|
||||
|
||||
batch = self.output_tokens[1, 0]
|
||||
tokens = self.output_tokens[2 : batch * (K + 1) + 2].numpy().reshape([batch, K + 1])[:, : (K + 1)]
|
||||
scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
|
||||
ranks = self.output_ranks[:batch].numpy()
|
||||
batch_result = list()
|
||||
for i in range(batch):
|
||||
if self.resource_manager.stop_flags[i]:
|
||||
continue
|
||||
task = self.resource_manager.tasks_list[i]
|
||||
task_id = task.request_id
|
||||
token_id = int(tokens[i, 0])
|
||||
token_ids = [token_id]
|
||||
recovery_stop = token_id == RECOVERY_STOP_SIGNAL
|
||||
if recovery_stop:
|
||||
llm_logger.info(f"recovery stop signal found at task {task_id}")
|
||||
if not recovery_stop and token_id < 0:
|
||||
continue
|
||||
|
||||
if task.get("prefill_chunk_info", None) is not None:
|
||||
prefill_chunk_num = task.get("prefill_chunk_num", 0)
|
||||
task.prefill_chunk_num = prefill_chunk_num + 1
|
||||
|
||||
if task.prefill_chunk_num < len(task.prefill_chunk_info):
|
||||
continue
|
||||
|
||||
self.total_step += 1
|
||||
current_time = time.time()
|
||||
if self.tokens_counter[task_id] == 0:
|
||||
metrics = RequestMetrics(
|
||||
arrival_time=task.arrival_time,
|
||||
inference_start_time=task.inference_start_time,
|
||||
first_token_time=time.time() - task.inference_start_time,
|
||||
time_in_queue=task.schedule_start_time - task.preprocess_end_time,
|
||||
preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
|
||||
request_start_time=task.arrival_time,
|
||||
)
|
||||
|
||||
self._record_first_token_metrics(task, current_time)
|
||||
|
||||
else:
|
||||
metrics = RequestMetrics(
|
||||
arrival_time=time.time(),
|
||||
request_start_time=task.arrival_time,
|
||||
)
|
||||
self.number_of_output_tokens += len(token_ids)
|
||||
self._record_metrics(task, current_time, token_ids)
|
||||
result = RequestOutput(
|
||||
request_id=task_id,
|
||||
outputs=CompletionOutput(
|
||||
index=i,
|
||||
send_idx=self.tokens_counter[task_id],
|
||||
token_ids=[],
|
||||
logprob=None,
|
||||
draft_token_ids=[],
|
||||
top_logprobs=None,
|
||||
),
|
||||
finished=False,
|
||||
metrics=metrics,
|
||||
)
|
||||
if self.tokens_counter[task_id] == 0:
|
||||
if task.messages is not None:
|
||||
result.prompt = task.messages
|
||||
result.num_cached_tokens = task.num_cached_tokens
|
||||
|
||||
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
|
||||
|
||||
if is_prefill and len(token_ids) > 1:
|
||||
result.outputs.draft_token_ids = copy.deepcopy(token_ids)
|
||||
|
||||
for idx, token_id in enumerate(token_ids):
|
||||
self.tokens_counter[task_id] += 1
|
||||
if token_id != RECOVERY_STOP_SIGNAL:
|
||||
result.outputs.token_ids.append(token_id)
|
||||
task.output_token_ids.append(token_id)
|
||||
result.outputs.logprob = float(scores[i, 0])
|
||||
# Construct top_logprobs
|
||||
topk_token_ids = tokens[i, :].tolist()
|
||||
topk_logprobs = scores[i, :].tolist()
|
||||
sampled_rank = ranks[i].item()
|
||||
|
||||
result.outputs.top_logprobs = LogprobsLists(
|
||||
logprob_token_ids=[topk_token_ids],
|
||||
logprobs=[topk_logprobs],
|
||||
sampled_token_ranks=[sampled_rank],
|
||||
)
|
||||
|
||||
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
|
||||
result.finished = True
|
||||
if recovery_stop:
|
||||
result.error_msg = "Recover is not supported, the result is incomplete!"
|
||||
llm_logger.info(
|
||||
f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}."
|
||||
)
|
||||
llm_logger.info(
|
||||
f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
|
||||
)
|
||||
llm_logger.info(f"{self.resource_manager.info()}")
|
||||
if self.cfg.speculative_config.method:
|
||||
self._compute_speculative_status()
|
||||
if not is_prefill:
|
||||
self._record_completion_metrics(task, current_time)
|
||||
self._recycle_resources(task_id, i, task, result, is_prefill)
|
||||
break
|
||||
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
|
||||
batch_result.append(result)
|
||||
|
||||
self.postprocess(batch_result)
|
||||
|
||||
def _process_batch_output(self):
|
||||
"""
|
||||
batch post-processing function
|
||||
"""
|
||||
|
||||
tokens = self.output_tokens.numpy()
|
||||
scores = None
|
||||
ranks = None
|
||||
if self.cfg.speculative_config.method:
|
||||
batch = self.output_tokens[1]
|
||||
accept_num = tokens[2 : batch + 2]
|
||||
self._record_speculative_decoding_mertics(accept_num)
|
||||
elif self.use_logprobs:
|
||||
batch = self.output_tokens[1, 0]
|
||||
tokens = tokens[2 : batch * (K + 1) + 2].reshape([batch, K + 1])[:, : (K + 1)]
|
||||
scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
|
||||
ranks = self.output_ranks[:batch].numpy()
|
||||
else:
|
||||
batch = self.output_tokens[1, 0]
|
||||
tokens = tokens[2 : batch + 2]
|
||||
@@ -522,6 +387,17 @@ class TokenProcessor:
|
||||
if token_id != RECOVERY_STOP_SIGNAL:
|
||||
result.outputs.token_ids.append(token_id)
|
||||
task.output_token_ids.append(token_id)
|
||||
if self.use_logprobs:
|
||||
result.outputs.logprob = float(scores[i, 0])
|
||||
# Construct top_logprobs
|
||||
topk_token_ids = tokens[i, :].tolist()
|
||||
topk_logprobs = scores[i, :].tolist()
|
||||
sampled_rank = ranks[i].item()
|
||||
result.outputs.top_logprobs = LogprobsLists(
|
||||
logprob_token_ids=[topk_token_ids],
|
||||
logprobs=[topk_logprobs],
|
||||
sampled_token_ranks=[sampled_rank],
|
||||
)
|
||||
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
|
||||
result.finished = True
|
||||
if recovery_stop:
|
||||
|
Reference in New Issue
Block a user