merge logprob into batch_output (#3266)

This commit is contained in:
chen
2025-08-11 10:03:00 +08:00
committed by GitHub
parent 566badb83c
commit 46c8491201
2 changed files with 41 additions and 157 deletions

View File

@@ -30,6 +30,7 @@ from fastdeploy.config import (
TaskOption,
)
from fastdeploy.engine.config import Config
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler.config import SchedulerConfig
from fastdeploy.utils import DeprecatedOptionWarning, FlexibleArgumentParser
@@ -344,6 +345,13 @@ class EngineArgs:
"""
if not self.tokenizer:
self.tokenizer = self.model
if self.enable_logprob:
if self.speculative_config is not None:
raise NotImplementedError("Logprob does not support speculation_config.")
if self.enable_expert_parallel:
raise NotImplementedError("Logprob does not support enable_expert_parallel.")
if not current_platform.is_cuda():
raise NotImplementedError("Only CUDA platform supports logprob.")
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

View File

@@ -57,6 +57,7 @@ class TokenProcessor:
self.split_connector = split_connector
self.speculative_decoding = self.cfg.speculative_config.method is not None
self.use_logprobs = self.cfg.enable_logprob
if self.speculative_decoding:
self.output_tokens = paddle.full(
@@ -64,7 +65,7 @@ class TokenProcessor:
fill_value=2,
dtype="int64",
)
elif self.cfg.enable_logprob:
elif self.use_logprobs:
self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
self.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64")
@@ -125,53 +126,12 @@ class TokenProcessor:
assert self.resource_manager is not None, "The resource manager is None, cannot run."
if self.worker is not None:
raise Exception("Worker is already running!")
use_logprobs = (
self.cfg.enable_logprob
and not self.speculative_decoding
and not self.cfg.parallel_config.enable_expert_parallel
)
target_func = self.process_sampling_with_logprob_results if use_logprobs else self.process_sampling_results
self.worker = threading.Thread(target=target_func)
self.worker = threading.Thread(target=self.process_sampling_results)
self.worker.daemon = True
self.worker.start()
def process_sampling_with_logprob_results(self):
"""
read tokens from paddle inference engine and process logprob results
"""
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import get_output_topk
else:
raise NotImplementedError("Only CUDA platform supports logprob.")
rank_id = self.cfg.parallel_config.local_data_parallel_id
while True:
try:
is_blocking = True
get_output_topk(
self.output_tokens,
self.output_scores,
self.output_ranks,
K,
rank_id,
is_blocking,
)
if self.output_tokens[0, 0] == -2:
continue
llm_logger.debug(
f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}"
f"rank_id {rank_id} self.output_scores[0, 0] {self.output_scores[0, 0]}"
)
self._process_prefill_metrics()
self._process_sampling_with_logprob_batch_output()
except Exception as e:
llm_logger.info(f"while get input_data error: {e} {traceback.format_exc()!s}")
def process_sampling_results(self):
"""
read tokens from paddle inference engine and process
@@ -187,6 +147,7 @@ class TokenProcessor:
from fastdeploy.model_executor.ops.gpu import (
get_output,
get_output_ep,
get_output_topk,
speculate_get_output,
)
rank_id = self.cfg.parallel_config.local_data_parallel_id
@@ -206,6 +167,16 @@ class TokenProcessor:
):
get_output_ep(self.output_tokens, rank_id, is_blocking)
else:
if self.use_logprobs:
get_output_topk(
self.output_tokens,
self.output_scores,
self.output_ranks,
K,
rank_id,
is_blocking,
)
else:
get_output(self.output_tokens, rank_id, is_blocking)
@@ -305,129 +276,23 @@ class TokenProcessor:
self.total_step = 0
self.speculative_stats_step += 1
def _process_sampling_with_logprob_batch_output(self):
"""
batch post-processing logprob output function
"""
batch = self.output_tokens[1, 0]
tokens = self.output_tokens[2 : batch * (K + 1) + 2].numpy().reshape([batch, K + 1])[:, : (K + 1)]
scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
ranks = self.output_ranks[:batch].numpy()
batch_result = list()
for i in range(batch):
if self.resource_manager.stop_flags[i]:
continue
task = self.resource_manager.tasks_list[i]
task_id = task.request_id
token_id = int(tokens[i, 0])
token_ids = [token_id]
recovery_stop = token_id == RECOVERY_STOP_SIGNAL
if recovery_stop:
llm_logger.info(f"recovery stop signal found at task {task_id}")
if not recovery_stop and token_id < 0:
continue
if task.get("prefill_chunk_info", None) is not None:
prefill_chunk_num = task.get("prefill_chunk_num", 0)
task.prefill_chunk_num = prefill_chunk_num + 1
if task.prefill_chunk_num < len(task.prefill_chunk_info):
continue
self.total_step += 1
current_time = time.time()
if self.tokens_counter[task_id] == 0:
metrics = RequestMetrics(
arrival_time=task.arrival_time,
inference_start_time=task.inference_start_time,
first_token_time=time.time() - task.inference_start_time,
time_in_queue=task.schedule_start_time - task.preprocess_end_time,
preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
request_start_time=task.arrival_time,
)
self._record_first_token_metrics(task, current_time)
else:
metrics = RequestMetrics(
arrival_time=time.time(),
request_start_time=task.arrival_time,
)
self.number_of_output_tokens += len(token_ids)
self._record_metrics(task, current_time, token_ids)
result = RequestOutput(
request_id=task_id,
outputs=CompletionOutput(
index=i,
send_idx=self.tokens_counter[task_id],
token_ids=[],
logprob=None,
draft_token_ids=[],
top_logprobs=None,
),
finished=False,
metrics=metrics,
)
if self.tokens_counter[task_id] == 0:
if task.messages is not None:
result.prompt = task.messages
result.num_cached_tokens = task.num_cached_tokens
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
if is_prefill and len(token_ids) > 1:
result.outputs.draft_token_ids = copy.deepcopy(token_ids)
for idx, token_id in enumerate(token_ids):
self.tokens_counter[task_id] += 1
if token_id != RECOVERY_STOP_SIGNAL:
result.outputs.token_ids.append(token_id)
task.output_token_ids.append(token_id)
result.outputs.logprob = float(scores[i, 0])
# Construct top_logprobs
topk_token_ids = tokens[i, :].tolist()
topk_logprobs = scores[i, :].tolist()
sampled_rank = ranks[i].item()
result.outputs.top_logprobs = LogprobsLists(
logprob_token_ids=[topk_token_ids],
logprobs=[topk_logprobs],
sampled_token_ranks=[sampled_rank],
)
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
result.finished = True
if recovery_stop:
result.error_msg = "Recover is not supported, the result is incomplete!"
llm_logger.info(
f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}."
)
llm_logger.info(
f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
)
llm_logger.info(f"{self.resource_manager.info()}")
if self.cfg.speculative_config.method:
self._compute_speculative_status()
if not is_prefill:
self._record_completion_metrics(task, current_time)
self._recycle_resources(task_id, i, task, result, is_prefill)
break
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
batch_result.append(result)
self.postprocess(batch_result)
def _process_batch_output(self):
"""
batch post-processing function
"""
tokens = self.output_tokens.numpy()
scores = None
ranks = None
if self.cfg.speculative_config.method:
batch = self.output_tokens[1]
accept_num = tokens[2 : batch + 2]
self._record_speculative_decoding_mertics(accept_num)
elif self.use_logprobs:
batch = self.output_tokens[1, 0]
tokens = tokens[2 : batch * (K + 1) + 2].reshape([batch, K + 1])[:, : (K + 1)]
scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
ranks = self.output_ranks[:batch].numpy()
else:
batch = self.output_tokens[1, 0]
tokens = tokens[2 : batch + 2]
@@ -522,6 +387,17 @@ class TokenProcessor:
if token_id != RECOVERY_STOP_SIGNAL:
result.outputs.token_ids.append(token_id)
task.output_token_ids.append(token_id)
if self.use_logprobs:
result.outputs.logprob = float(scores[i, 0])
# Construct top_logprobs
topk_token_ids = tokens[i, :].tolist()
topk_logprobs = scores[i, :].tolist()
sampled_rank = ranks[i].item()
result.outputs.top_logprobs = LogprobsLists(
logprob_token_ids=[topk_token_ids],
logprobs=[topk_logprobs],
sampled_token_ranks=[sampled_rank],
)
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
result.finished = True
if recovery_stop: