mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
fix scheduler bug for bs=1 (#3288)
@@ -203,6 +203,7 @@ class GPUModelRunner(ModelRunnerBase):
 
         req_len = len(req_dicts)
         has_prefill_task = False
+        has_decode_task = False
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
@@ -240,6 +241,8 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                     request.block_tables, dtype="int32"
                 )
+                if self.share_inputs["is_block_step"][idx]:  # has tasks to continue to decode
+                    has_decode_task = True
                 continue
             else:  # preempted task
                 logger.debug(f"Handle preempted request {request} at idx {idx}")
@@ -280,7 +283,7 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
                     request.get("stop_token_ids"), dtype="int64"
                 )
-        if has_prefill_task:
+        if has_prefill_task or has_decode_task:
             self.share_inputs["not_need_stop"][0] = True
 
     def insert_prefill_inputs(self, req_dicts: List[Request]):
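
In plain terms: before this change, insert_tasks_v1 only set share_inputs["not_need_stop"][0] when the batch contained a prefill task. At batch size 1, a step whose only request is a decode continuing through a block step (is_block_step) therefore left the flag False and the runner stopped early. The three hunks above add a has_decode_task flag that keeps the loop alive in that case. Below is a minimal, standalone sketch of that flag logic; Request, RequestType, and update_not_need_stop are simplified stand-ins invented for illustration, not FastDeploy's real classes — only the flag names come from the diff.

# Standalone sketch of the not_need_stop flag logic fixed in this commit.
# Request, RequestType, and update_not_need_stop are assumed/simplified
# stand-ins; only has_prefill_task, has_decode_task, is_block_step, and
# not_need_stop come from the diff above.
from dataclasses import dataclass
from enum import Enum
from typing import List


class RequestType(Enum):
    PREFILL = 0
    DECODE = 1
    PREEMPTED = 2


@dataclass
class Request:
    idx: int
    task_type: RequestType


def update_not_need_stop(req_dicts: List[Request], is_block_step: List[bool], fixed: bool = True) -> bool:
    """Value the scheduler loop would leave in not_need_stop[0]."""
    has_prefill_task = False
    has_decode_task = False
    for request in req_dicts:
        if request.task_type is RequestType.PREFILL:
            has_prefill_task = True
        elif request.task_type is RequestType.DECODE:
            if is_block_step[request.idx]:  # task continues to decode next step
                has_decode_task = True
    if fixed:
        return has_prefill_task or has_decode_task  # post-fix condition
    return has_prefill_task  # pre-fix condition


# bs=1 repro: the only scheduled request is a block-step decode.
batch = [Request(idx=0, task_type=RequestType.DECODE)]
print(update_not_need_stop(batch, is_block_step=[True], fixed=False))  # False -> runner halts early
print(update_not_need_stop(batch, is_block_step=[True], fixed=True))   # True  -> decoding continues

Keeping two separate flags rather than collapsing them into one generic "has task" boolean preserves the distinction the surrounding code already draws between inserting a prefill task and continuing a block-step decode.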