diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index a86ef2432..ddcf31d28 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -27,6 +27,7 @@ import time
 import traceback
 import uuid
 import weakref
+from dataclasses import asdict
 
 import numpy as np
 import paddle
@@ -190,6 +191,8 @@ class LLMEngine:
 
         """
         # TODO: confirm input/output lengths
+        if sampling_params is not None:
+            task.update(asdict(sampling_params))
         request = Request.from_dict(task)
         llm_logger.info(f"Receive request {request}")
         if sampling_params is not None:
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 291388aa6..85b144805 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -263,7 +263,10 @@ class GPUModelRunner(ModelRunnerBase):
                 position_ids, request.get("max_tokens", 2048)
             )
 
-            input_ids = request.prompt_token_ids + request.output_token_ids
+            if len(request.output_token_ids) == 0:
+                input_ids = request.prompt_token_ids
+            else:
+                input_ids = request.prompt_token_ids + request.output_token_ids
             logger.debug(
                 f"Handle prefill request {request} at idx {idx}, "
                 f"{prefill_start_index=}, {prefill_end_index=}, "
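For reference, a minimal sketch of what the new `asdict(sampling_params)` merge in `engine.py` does: `dataclasses.asdict` flattens the dataclass into a plain dict, so `task.update(...)` overlays its fields onto the task before `Request.from_dict(task)` consumes them. The `SamplingParams` fields below are illustrative placeholders, not FastDeploy's actual schema:

```python
from dataclasses import dataclass, asdict

# Hypothetical stand-in for FastDeploy's SamplingParams; the field
# names here are illustrative only.
@dataclass
class SamplingParams:
    temperature: float = 0.8
    top_p: float = 0.95
    max_tokens: int = 2048

task = {"prompt": "Hello", "request_id": "req-0"}
sampling_params = SamplingParams(temperature=0.2)

# asdict() converts the dataclass to a plain dict, so update() merges
# the sampling fields into the task dict that Request.from_dict() reads.
if sampling_params is not None:
    task.update(asdict(sampling_params))

print(task)
# {'prompt': 'Hello', 'request_id': 'req-0',
#  'temperature': 0.2, 'top_p': 0.95, 'max_tokens': 2048}
```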