fix scheduler bug for bs=1 (#3288)

This commit is contained in:
chenjian
2025-08-09 12:22:12 +08:00
committed by GitHub
parent ce1d4944e7
commit c208086f61

View File

@@ -203,6 +203,7 @@ class GPUModelRunner(ModelRunnerBase):
req_len = len(req_dicts) req_len = len(req_dicts)
has_prefill_task = False has_prefill_task = False
has_decode_task = False
for i in range(req_len): for i in range(req_len):
request = req_dicts[i] request = req_dicts[i]
idx = request.idx idx = request.idx
@@ -240,6 +241,8 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32" request.block_tables, dtype="int32"
) )
if self.share_inputs["is_block_step"][idx]: # has tasks to continue to decode
has_decode_task = True
continue continue
else: # preempted task else: # preempted task
logger.debug(f"Handle preempted request {request} at idx {idx}") logger.debug(f"Handle preempted request {request} at idx {idx}")
@@ -280,7 +283,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array( self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
request.get("stop_token_ids"), dtype="int64" request.get("stop_token_ids"), dtype="int64"
) )
if has_prefill_task: if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True self.share_inputs["not_need_stop"][0] = True
def insert_prefill_inputs(self, req_dicts: List[Request]): def insert_prefill_inputs(self, req_dicts: List[Request]):