diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index e609475b6..ba712ed0c 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -30,6 +30,7 @@ from fastdeploy.config import ( TaskOption, ) from fastdeploy.engine.config import Config +from fastdeploy.platforms import current_platform from fastdeploy.scheduler.config import SchedulerConfig from fastdeploy.utils import DeprecatedOptionWarning, FlexibleArgumentParser @@ -344,6 +345,13 @@ class EngineArgs: """ if not self.tokenizer: self.tokenizer = self.model + if self.enable_logprob: + if self.speculative_config is not None: + raise NotImplementedError("Logprob does not support speculation_config.") + if self.enable_expert_parallel: + raise NotImplementedError("Logprob does not support enable_expert_parallel.") + if not current_platform.is_cuda(): + raise NotImplementedError("Only CUDA platform supports logprob.") @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 000c4c0dc..d12fec2ac 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -57,6 +57,7 @@ class TokenProcessor: self.split_connector = split_connector self.speculative_decoding = self.cfg.speculative_config.method is not None + self.use_logprobs = self.cfg.enable_logprob if self.speculative_decoding: self.output_tokens = paddle.full( @@ -64,7 +65,7 @@ class TokenProcessor: fill_value=2, dtype="int64", ) - elif self.cfg.enable_logprob: + elif self.use_logprobs: self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") self.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64") @@ -125,53 +126,12 @@ class TokenProcessor: assert self.resource_manager is not None, "The resource manager is None, cannot run." if self.worker is not None: raise Exception("Worker is already running!") - use_logprobs = ( - self.cfg.enable_logprob - and not self.speculative_decoding - and not self.cfg.parallel_config.enable_expert_parallel - ) - target_func = self.process_sampling_with_logprob_results if use_logprobs else self.process_sampling_results - - self.worker = threading.Thread(target=target_func) + self.worker = threading.Thread(target=self.process_sampling_results) self.worker.daemon = True self.worker.start() - def process_sampling_with_logprob_results(self): - """ - read tokens from paddle inference engine and process logprob results - """ - if current_platform.is_cuda(): - from fastdeploy.model_executor.ops.gpu import get_output_topk - else: - raise NotImplementedError("Only CUDA platform supports logprob.") - - rank_id = self.cfg.parallel_config.local_data_parallel_id - - while True: - try: - is_blocking = True - get_output_topk( - self.output_tokens, - self.output_scores, - self.output_ranks, - K, - rank_id, - is_blocking, - ) - - if self.output_tokens[0, 0] == -2: - continue - llm_logger.debug( - f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}" - f"rank_id {rank_id} self.output_scores[0, 0] {self.output_scores[0, 0]}" - ) - self._process_prefill_metrics() - self._process_sampling_with_logprob_batch_output() - except Exception as e: - llm_logger.info(f"while get input_data error: {e} {traceback.format_exc()!s}") - def process_sampling_results(self): """ read tokens from paddle inference engine and process @@ -187,6 +147,7 @@ class TokenProcessor: from fastdeploy.model_executor.ops.gpu import ( get_output, get_output_ep, + get_output_topk, speculate_get_output, ) rank_id = self.cfg.parallel_config.local_data_parallel_id @@ -207,7 +168,17 @@ class TokenProcessor: get_output_ep(self.output_tokens, rank_id, is_blocking) else: - get_output(self.output_tokens, rank_id, is_blocking) + if self.use_logprobs: + get_output_topk( + self.output_tokens, + self.output_scores, + self.output_ranks, + K, + rank_id, + is_blocking, + ) + else: + get_output(self.output_tokens, rank_id, is_blocking) if self.output_tokens[0, 0] == -2: continue @@ -305,129 +276,23 @@ class TokenProcessor: self.total_step = 0 self.speculative_stats_step += 1 - def _process_sampling_with_logprob_batch_output(self): - """ - batch post-processing logprob output function - """ - - batch = self.output_tokens[1, 0] - tokens = self.output_tokens[2 : batch * (K + 1) + 2].numpy().reshape([batch, K + 1])[:, : (K + 1)] - scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)] - ranks = self.output_ranks[:batch].numpy() - batch_result = list() - for i in range(batch): - if self.resource_manager.stop_flags[i]: - continue - task = self.resource_manager.tasks_list[i] - task_id = task.request_id - token_id = int(tokens[i, 0]) - token_ids = [token_id] - recovery_stop = token_id == RECOVERY_STOP_SIGNAL - if recovery_stop: - llm_logger.info(f"recovery stop signal found at task {task_id}") - if not recovery_stop and token_id < 0: - continue - - if task.get("prefill_chunk_info", None) is not None: - prefill_chunk_num = task.get("prefill_chunk_num", 0) - task.prefill_chunk_num = prefill_chunk_num + 1 - - if task.prefill_chunk_num < len(task.prefill_chunk_info): - continue - - self.total_step += 1 - current_time = time.time() - if self.tokens_counter[task_id] == 0: - metrics = RequestMetrics( - arrival_time=task.arrival_time, - inference_start_time=task.inference_start_time, - first_token_time=time.time() - task.inference_start_time, - time_in_queue=task.schedule_start_time - task.preprocess_end_time, - preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time, - request_start_time=task.arrival_time, - ) - - self._record_first_token_metrics(task, current_time) - - else: - metrics = RequestMetrics( - arrival_time=time.time(), - request_start_time=task.arrival_time, - ) - self.number_of_output_tokens += len(token_ids) - self._record_metrics(task, current_time, token_ids) - result = RequestOutput( - request_id=task_id, - outputs=CompletionOutput( - index=i, - send_idx=self.tokens_counter[task_id], - token_ids=[], - logprob=None, - draft_token_ids=[], - top_logprobs=None, - ), - finished=False, - metrics=metrics, - ) - if self.tokens_counter[task_id] == 0: - if task.messages is not None: - result.prompt = task.messages - result.num_cached_tokens = task.num_cached_tokens - - is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill" - - if is_prefill and len(token_ids) > 1: - result.outputs.draft_token_ids = copy.deepcopy(token_ids) - - for idx, token_id in enumerate(token_ids): - self.tokens_counter[task_id] += 1 - if token_id != RECOVERY_STOP_SIGNAL: - result.outputs.token_ids.append(token_id) - task.output_token_ids.append(token_id) - result.outputs.logprob = float(scores[i, 0]) - # Construct top_logprobs - topk_token_ids = tokens[i, :].tolist() - topk_logprobs = scores[i, :].tolist() - sampled_rank = ranks[i].item() - - result.outputs.top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) - - if token_id in task.eos_token_ids or is_prefill or recovery_stop: - result.finished = True - if recovery_stop: - result.error_msg = "Recover is not supported, the result is incomplete!" - llm_logger.info( - f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}." - ) - llm_logger.info( - f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}" - ) - llm_logger.info(f"{self.resource_manager.info()}") - if self.cfg.speculative_config.method: - self._compute_speculative_status() - if not is_prefill: - self._record_completion_metrics(task, current_time) - self._recycle_resources(task_id, i, task, result, is_prefill) - break - if not is_prefill or self.cfg.scheduler_config.name == "splitwise": - batch_result.append(result) - - self.postprocess(batch_result) - def _process_batch_output(self): """ batch post-processing function """ tokens = self.output_tokens.numpy() + scores = None + ranks = None if self.cfg.speculative_config.method: batch = self.output_tokens[1] accept_num = tokens[2 : batch + 2] self._record_speculative_decoding_mertics(accept_num) + elif self.use_logprobs: + batch = self.output_tokens[1, 0] + tokens = tokens[2 : batch * (K + 1) + 2].reshape([batch, K + 1])[:, : (K + 1)] + scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)] + ranks = self.output_ranks[:batch].numpy() else: batch = self.output_tokens[1, 0] tokens = tokens[2 : batch + 2] @@ -522,6 +387,17 @@ class TokenProcessor: if token_id != RECOVERY_STOP_SIGNAL: result.outputs.token_ids.append(token_id) task.output_token_ids.append(token_id) + if self.use_logprobs: + result.outputs.logprob = float(scores[i, 0]) + # Construct top_logprobs + topk_token_ids = tokens[i, :].tolist() + topk_logprobs = scores[i, :].tolist() + sampled_rank = ranks[i].item() + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True if recovery_stop: