Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 00:33:03 +08:00
merge logprob into batch_output (#3266)
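
This change drops the dedicated logprob worker path (process_sampling_with_logprob_results and _process_sampling_with_logprob_batch_output) and folds logprob handling into the shared process_sampling_results / _process_batch_output flow, gated by a new TokenProcessor.use_logprobs flag. EngineArgs now rejects enable_logprob when speculative decoding or expert parallelism is enabled, or when the platform is not CUDA.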
@@ -30,6 +30,7 @@ from fastdeploy.config import (
     TaskOption,
 )
 from fastdeploy.engine.config import Config
+from fastdeploy.platforms import current_platform
 from fastdeploy.scheduler.config import SchedulerConfig
 from fastdeploy.utils import DeprecatedOptionWarning, FlexibleArgumentParser
 
@@ -344,6 +345,13 @@ class EngineArgs:
         """
         if not self.tokenizer:
             self.tokenizer = self.model
+        if self.enable_logprob:
+            if self.speculative_config is not None:
+                raise NotImplementedError("Logprob does not support speculation_config.")
+            if self.enable_expert_parallel:
+                raise NotImplementedError("Logprob does not support enable_expert_parallel.")
+            if not current_platform.is_cuda():
+                raise NotImplementedError("Only CUDA platform supports logprob.")
 
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
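
The gate above is the only behavioural change to EngineArgs: enable_logprob is rejected up front whenever it cannot be honoured. Below is a minimal, self-contained sketch of the same checks, using an illustrative stand-in dataclass rather than FastDeploy's real config objects (the is_cuda field stands in for current_platform.is_cuda()):

# Sketch only: field names mirror the diff, the class itself is hypothetical.
from dataclasses import dataclass
from typing import Optional


@dataclass
class _ArgsSketch:
    enable_logprob: bool = False
    speculative_config: Optional[dict] = None
    enable_expert_parallel: bool = False
    is_cuda: bool = True  # stands in for current_platform.is_cuda()

    def validate(self) -> None:
        # Same three constraints as the __post_init__ gate above.
        if self.enable_logprob:
            if self.speculative_config is not None:
                raise NotImplementedError("Logprob does not support speculation_config.")
            if self.enable_expert_parallel:
                raise NotImplementedError("Logprob does not support enable_expert_parallel.")
            if not self.is_cuda:
                raise NotImplementedError("Only CUDA platform supports logprob.")


_ArgsSketch(enable_logprob=True).validate()  # passes
# _ArgsSketch(enable_logprob=True, is_cuda=False).validate()  # would raise NotImplementedError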
@@ -57,6 +57,7 @@ class TokenProcessor:
         self.split_connector = split_connector
 
         self.speculative_decoding = self.cfg.speculative_config.method is not None
+        self.use_logprobs = self.cfg.enable_logprob
 
         if self.speculative_decoding:
             self.output_tokens = paddle.full(
@@ -64,7 +65,7 @@ class TokenProcessor:
                 fill_value=2,
                 dtype="int64",
             )
-        elif self.cfg.enable_logprob:
+        elif self.use_logprobs:
             self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
             self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
             self.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64")
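
The three buffers allocated in the use_logprobs branch share one flat layout. A NumPy sketch of that layout follows (sizes are illustrative; the real buffers are paddle tensors sized by FastDeploy's MAX_BSZ and K constants):

import numpy as np

MAX_BSZ, K = 4, 3  # illustrative sizes, not FastDeploy's real constants

output_tokens = np.full((MAX_BSZ * (K + 1) + 2, 1), 2, dtype=np.int64)
output_scores = np.full((MAX_BSZ * (K + 1), 1), 0.0, dtype=np.float32)
output_ranks = np.zeros((MAX_BSZ,), dtype=np.int64)

# Slot [0, 0] is a status flag (-2 means "nothing to read" in the polling loop)
# and slot [1, 0] is the number of live requests; each request i then owns
# (K + 1) consecutive rows holding its sampled token plus K more candidates.
output_tokens[0, 0] = 0   # data available
output_tokens[1, 0] = 2   # two live requests this step

def rows_for(i: int) -> slice:
    return slice(2 + i * (K + 1), 2 + (i + 1) * (K + 1))

print(output_tokens[rows_for(0)].shape, output_scores.shape, output_ranks.shape)
# (4, 1) (16, 1) (4,)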
@@ -125,53 +126,12 @@ class TokenProcessor:
         assert self.resource_manager is not None, "The resource manager is None, cannot run."
         if self.worker is not None:
             raise Exception("Worker is already running!")
-        use_logprobs = (
-            self.cfg.enable_logprob
-            and not self.speculative_decoding
-            and not self.cfg.parallel_config.enable_expert_parallel
-        )
 
-        target_func = self.process_sampling_with_logprob_results if use_logprobs else self.process_sampling_results
-
-        self.worker = threading.Thread(target=target_func)
+        self.worker = threading.Thread(target=self.process_sampling_results)
 
         self.worker.daemon = True
         self.worker.start()
 
-    def process_sampling_with_logprob_results(self):
-        """
-        read tokens from paddle inference engine and process logprob results
-        """
-        if current_platform.is_cuda():
-            from fastdeploy.model_executor.ops.gpu import get_output_topk
-        else:
-            raise NotImplementedError("Only CUDA platform supports logprob.")
-
-        rank_id = self.cfg.parallel_config.local_data_parallel_id
-
-        while True:
-            try:
-                is_blocking = True
-                get_output_topk(
-                    self.output_tokens,
-                    self.output_scores,
-                    self.output_ranks,
-                    K,
-                    rank_id,
-                    is_blocking,
-                )
-
-                if self.output_tokens[0, 0] == -2:
-                    continue
-                llm_logger.debug(
-                    f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}"
-                    f"rank_id {rank_id} self.output_scores[0, 0] {self.output_scores[0, 0]}"
-                )
-                self._process_prefill_metrics()
-                self._process_sampling_with_logprob_batch_output()
-            except Exception as e:
-                llm_logger.info(f"while get input_data error: {e} {traceback.format_exc()!s}")
-
     def process_sampling_results(self):
         """
         read tokens from paddle inference engine and process
@@ -187,6 +147,7 @@ class TokenProcessor:
             from fastdeploy.model_executor.ops.gpu import (
                 get_output,
                 get_output_ep,
+                get_output_topk,
                 speculate_get_output,
             )
         rank_id = self.cfg.parallel_config.local_data_parallel_id
@@ -207,7 +168,17 @@ class TokenProcessor:
                         get_output_ep(self.output_tokens, rank_id, is_blocking)
 
                     else:
-                        get_output(self.output_tokens, rank_id, is_blocking)
+                        if self.use_logprobs:
+                            get_output_topk(
+                                self.output_tokens,
+                                self.output_scores,
+                                self.output_ranks,
+                                K,
+                                rank_id,
+                                is_blocking,
+                            )
+                        else:
+                            get_output(self.output_tokens, rank_id, is_blocking)
 
                     if self.output_tokens[0, 0] == -2:
                         continue
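
After this hunk there is a single polling loop: the reader op is chosen per iteration by use_logprobs instead of picking a whole worker function at thread-start time. A toy sketch of that shape, with mock readers standing in for get_output / get_output_topk (nothing here calls the real ops):

import threading
import queue

results = queue.Queue()

def mock_get_output(buf):             # stand-in for ops.gpu.get_output
    buf["tokens"] = [101]

def mock_get_output_topk(buf):        # stand-in for ops.gpu.get_output_topk
    buf["tokens"], buf["scores"], buf["ranks"] = [101], [-0.3], [0]

def process_sampling_results(use_logprobs, steps=3):
    # One loop for both modes; the reader is picked per iteration, mirroring
    # the if self.use_logprobs / else branch added above.
    for _ in range(steps):
        buf = {}
        if use_logprobs:
            mock_get_output_topk(buf)  # also yields scores and ranks
        else:
            mock_get_output(buf)
        results.put(buf)               # handed to batch post-processing

worker = threading.Thread(target=process_sampling_results, args=(True,))
worker.daemon = True
worker.start()
worker.join()
print(results.qsize())  # 3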
@@ -305,129 +276,23 @@ class TokenProcessor:
                 self.total_step = 0
         self.speculative_stats_step += 1
 
-    def _process_sampling_with_logprob_batch_output(self):
-        """
-        batch post-processing logprob output function
-        """
-
-        batch = self.output_tokens[1, 0]
-        tokens = self.output_tokens[2 : batch * (K + 1) + 2].numpy().reshape([batch, K + 1])[:, : (K + 1)]
-        scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
-        ranks = self.output_ranks[:batch].numpy()
-        batch_result = list()
-        for i in range(batch):
-            if self.resource_manager.stop_flags[i]:
-                continue
-            task = self.resource_manager.tasks_list[i]
-            task_id = task.request_id
-            token_id = int(tokens[i, 0])
-            token_ids = [token_id]
-            recovery_stop = token_id == RECOVERY_STOP_SIGNAL
-            if recovery_stop:
-                llm_logger.info(f"recovery stop signal found at task {task_id}")
-            if not recovery_stop and token_id < 0:
-                continue
-
-            if task.get("prefill_chunk_info", None) is not None:
-                prefill_chunk_num = task.get("prefill_chunk_num", 0)
-                task.prefill_chunk_num = prefill_chunk_num + 1
-
-                if task.prefill_chunk_num < len(task.prefill_chunk_info):
-                    continue
-
-            self.total_step += 1
-            current_time = time.time()
-            if self.tokens_counter[task_id] == 0:
-                metrics = RequestMetrics(
-                    arrival_time=task.arrival_time,
-                    inference_start_time=task.inference_start_time,
-                    first_token_time=time.time() - task.inference_start_time,
-                    time_in_queue=task.schedule_start_time - task.preprocess_end_time,
-                    preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
-                    request_start_time=task.arrival_time,
-                )
-
-                self._record_first_token_metrics(task, current_time)
-
-            else:
-                metrics = RequestMetrics(
-                    arrival_time=time.time(),
-                    request_start_time=task.arrival_time,
-                )
-            self.number_of_output_tokens += len(token_ids)
-            self._record_metrics(task, current_time, token_ids)
-            result = RequestOutput(
-                request_id=task_id,
-                outputs=CompletionOutput(
-                    index=i,
-                    send_idx=self.tokens_counter[task_id],
-                    token_ids=[],
-                    logprob=None,
-                    draft_token_ids=[],
-                    top_logprobs=None,
-                ),
-                finished=False,
-                metrics=metrics,
-            )
-            if self.tokens_counter[task_id] == 0:
-                if task.messages is not None:
-                    result.prompt = task.messages
-                result.num_cached_tokens = task.num_cached_tokens
-
-            is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
-
-            if is_prefill and len(token_ids) > 1:
-                result.outputs.draft_token_ids = copy.deepcopy(token_ids)
-
-            for idx, token_id in enumerate(token_ids):
-                self.tokens_counter[task_id] += 1
-                if token_id != RECOVERY_STOP_SIGNAL:
-                    result.outputs.token_ids.append(token_id)
-                    task.output_token_ids.append(token_id)
-                    result.outputs.logprob = float(scores[i, 0])
-                    # Construct top_logprobs
-                    topk_token_ids = tokens[i, :].tolist()
-                    topk_logprobs = scores[i, :].tolist()
-                    sampled_rank = ranks[i].item()
-
-                    result.outputs.top_logprobs = LogprobsLists(
-                        logprob_token_ids=[topk_token_ids],
-                        logprobs=[topk_logprobs],
-                        sampled_token_ranks=[sampled_rank],
-                    )
-
-                if token_id in task.eos_token_ids or is_prefill or recovery_stop:
-                    result.finished = True
-                    if recovery_stop:
-                        result.error_msg = "Recover is not supported, the result is incomplete!"
-                    llm_logger.info(
-                        f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}."
-                    )
-                    llm_logger.info(
-                        f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
-                    )
-                    llm_logger.info(f"{self.resource_manager.info()}")
-                    if self.cfg.speculative_config.method:
-                        self._compute_speculative_status()
-                    if not is_prefill:
-                        self._record_completion_metrics(task, current_time)
-                    self._recycle_resources(task_id, i, task, result, is_prefill)
-                    break
-            if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
-                batch_result.append(result)
-
-        self.postprocess(batch_result)
-
     def _process_batch_output(self):
         """
         batch post-processing function
         """
 
         tokens = self.output_tokens.numpy()
+        scores = None
+        ranks = None
         if self.cfg.speculative_config.method:
             batch = self.output_tokens[1]
             accept_num = tokens[2 : batch + 2]
             self._record_speculative_decoding_mertics(accept_num)
+        elif self.use_logprobs:
+            batch = self.output_tokens[1, 0]
+            tokens = tokens[2 : batch * (K + 1) + 2].reshape([batch, K + 1])[:, : (K + 1)]
+            scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
+            ranks = self.output_ranks[:batch].numpy()
         else:
             batch = self.output_tokens[1, 0]
             tokens = tokens[2 : batch + 2]
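
The new elif branch decodes the top-k buffers in place of the deleted helper. A NumPy sketch of that decode with illustrative sizes and synthetic data: column 0 of the reshaped [batch, K + 1] arrays is each request's sampled token and score, and ranks[i] is the sampled token's rank among the candidates.

import numpy as np

MAX_BSZ, K = 4, 3
batch = 2

output_tokens = np.zeros((MAX_BSZ * (K + 1) + 2, 1), dtype=np.int64)
output_scores = np.random.uniform(-5, 0, size=(MAX_BSZ * (K + 1), 1)).astype(np.float32)
output_ranks = np.array([0, 1, 0, 0], dtype=np.int64)
output_tokens[2 : 2 + batch * (K + 1), 0] = np.arange(100, 100 + batch * (K + 1))  # fake token ids

# Same slicing as the elif self.use_logprobs branch above.
tokens = output_tokens[2 : batch * (K + 1) + 2].reshape([batch, K + 1])[:, : (K + 1)]
scores = output_scores[: batch * (K + 1)].reshape([batch, K + 1])[:, : (K + 1)]
ranks = output_ranks[:batch]

for i in range(batch):
    sampled_token = int(tokens[i, 0])       # what the request actually emitted
    sampled_logprob = float(scores[i, 0])   # its log-probability
    topk_token_ids = tokens[i, :].tolist()  # sampled token + K alternatives
    topk_logprobs = scores[i, :].tolist()
    sampled_rank = int(ranks[i])            # rank of the sampled token in the top-K
    print(sampled_token, round(sampled_logprob, 3), sampled_rank, len(topk_token_ids))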
@@ -522,6 +387,17 @@ class TokenProcessor:
                 if token_id != RECOVERY_STOP_SIGNAL:
                     result.outputs.token_ids.append(token_id)
                     task.output_token_ids.append(token_id)
+                    if self.use_logprobs:
+                        result.outputs.logprob = float(scores[i, 0])
+                        # Construct top_logprobs
+                        topk_token_ids = tokens[i, :].tolist()
+                        topk_logprobs = scores[i, :].tolist()
+                        sampled_rank = ranks[i].item()
+                        result.outputs.top_logprobs = LogprobsLists(
+                            logprob_token_ids=[topk_token_ids],
+                            logprobs=[topk_logprobs],
+                            sampled_token_ranks=[sampled_rank],
+                        )
                 if token_id in task.eos_token_ids or is_prefill or recovery_stop:
                     result.finished = True
                     if recovery_stop:
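
On the consumer side, the fields attached here arrive as parallel lists (field names as in the diff). A small sketch of reading them back, with made-up values:

# Stand-in for one LogprobsLists payload; values are synthetic.
top_logprobs = {
    "logprob_token_ids": [[101, 7, 42, 9]],   # one request, K + 1 = 4 candidates
    "logprobs": [[-0.3, -1.2, -2.5, -3.1]],
    "sampled_token_ranks": [0],
}

for ids, lps, rank in zip(
    top_logprobs["logprob_token_ids"],
    top_logprobs["logprobs"],
    top_logprobs["sampled_token_ranks"],
):
    per_token = dict(zip(ids, lps))  # token id -> logprob
    sampled_id = ids[0]              # column 0 is the sampled token
    print(sampled_id, per_token[sampled_id], "rank", rank)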