mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-27 04:46:16 +08:00
fix scheduler bug for bs=1 (#3288)
This commit is contained in:
@@ -203,6 +203,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
req_len = len(req_dicts)
|
||||
has_prefill_task = False
|
||||
has_decode_task = False
|
||||
for i in range(req_len):
|
||||
request = req_dicts[i]
|
||||
idx = request.idx
|
||||
@@ -240,6 +241,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
|
||||
request.block_tables, dtype="int32"
|
||||
)
|
||||
if self.share_inputs["is_block_step"][idx]: # has tasks to continue to decode
|
||||
has_decode_task = True
|
||||
continue
|
||||
else: # preempted task
|
||||
logger.debug(f"Handle preempted request {request} at idx {idx}")
|
||||
@@ -280,7 +283,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
|
||||
request.get("stop_token_ids"), dtype="int64"
|
||||
)
|
||||
if has_prefill_task:
|
||||
if has_prefill_task or has_decode_task:
|
||||
self.share_inputs["not_need_stop"][0] = True
|
||||
|
||||
def insert_prefill_inputs(self, req_dicts: List[Request]):
|
||||
|
Reference in New Issue
Block a user