Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 08:16:42 +08:00
[BugFix] support real batch_size (#3109)
* support real bsz
* fix
* fix xpu_model_runner.py, gpu_model_runner.py, gcu_model_runner.py, iluvatar_model_runner.py
* add event_loop_ep
* fix
* Add comments
* fix
* support mtp real_batch_size
* fix
* self.tmp_seq_lens_this_time -> self.seq_lens_this_time_buffer
* fix
* fix VL real_seq_lens_this_time
* fix
* fix mtp
* fix
* fix mtp
* fix xpu
* fix
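In essence, the diff below stages each request's per-step token count in a persistent seq_lens_this_time_buffer and then publishes only the first num_running_requests slots as seq_lens_this_time, so the model sees the real batch size rather than the full padded buffer. The following is a minimal sketch of that pattern only; the RealBatchSizeSketch class, max_num_seqs, the numpy buffer, and the example lengths are illustrative stand-ins, not FastDeploy's actual runner code.

import numpy as np

class RealBatchSizeSketch:
    """Sketch (not FastDeploy code) of the real-batch-size pattern:
    write every request's length into a persistent buffer, then expose
    only the slots of requests that are actually running."""

    def __init__(self, max_num_seqs: int):
        # Persistent staging buffer sized for the maximum batch
        # (numpy stands in for the paddle tensors used in the real runner).
        self.seq_lens_this_time_buffer = np.zeros([max_num_seqs], dtype="int32")
        self.model_inputs = {}

    def insert_prefill_inputs(self, lengths: list, num_running_requests: int):
        # Stage each request's token count in its slot, as the diff does with
        # self.seq_lens_this_time_buffer[idx : idx + 1] = ...
        for idx, length in enumerate(lengths):
            self.seq_lens_this_time_buffer[idx : idx + 1] = length
        # Hand the model only the real batch, not the padded buffer.
        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]


# Hypothetical usage: three running requests in a buffer sized for eight slots.
runner = RealBatchSizeSketch(max_num_seqs=8)
runner.insert_prefill_inputs(lengths=[17, 5, 42], num_running_requests=3)
print(runner.model_inputs["seq_lens_this_time"].shape)  # (3,), not (8,)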
@@ -107,7 +107,7 @@ class MTPProposer(Proposer):
             idx = i
             self.model_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.model_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
-            self.model_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
+            self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
             self.model_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
             self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
             self.model_inputs["step_idx"][idx : idx + 1] = 0
@@ -118,6 +118,7 @@ class MTPProposer(Proposer):
             self.model_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
                 idx * block_num, (idx + 1) * block_num, 1
             )
+        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

     def initialize_kv_cache(self):
         """
@@ -263,7 +264,8 @@ class MTPProposer(Proposer):
         # Same shape/dytpe with base model
         self.model_inputs["block_tables"] = paddle.clone(self.main_model_inputs["block_tables"])
         self.model_inputs["input_ids"] = paddle.clone(self.main_model_inputs["input_ids"])
         self.model_inputs["seq_lens_this_time"] = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
+        self.seq_lens_this_time_buffer = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
         self.model_inputs["seq_lens_encoder"] = paddle.clone(self.main_model_inputs["seq_lens_encoder"])
         self.model_inputs["seq_lens_decoder"] = paddle.clone(self.main_model_inputs["seq_lens_decoder"])
         self.model_inputs["step_idx"] = paddle.clone(self.main_model_inputs["step_idx"])
@@ -338,7 +340,7 @@ class MTPProposer(Proposer):
             self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
         )

-    def insert_prefill_inputs(self, req_dicts: List[Request]):
+    def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
         """
         Process inputs for prefill tasks and insert it to model_inputs buffer
         """
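The hunk above widens insert_prefill_inputs to also take num_running_requests. A hedged sketch of what a call site could look like under that signature follows; the dispatch_prefill helper and the len(req_dicts) heuristic are assumptions for illustration, not code from this commit.

def dispatch_prefill(proposer, req_dicts):
    # Assumption for illustration: every entry in req_dicts corresponds to a
    # running request, so its count is the real batch size to expose.
    num_running_requests = len(req_dicts)
    proposer.insert_prefill_inputs(req_dicts, num_running_requests)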
@@ -372,7 +374,7 @@ class MTPProposer(Proposer):

                 self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                 self.model_inputs["seq_lens_decoder"][idx : idx + 1] = length
-                self.model_inputs["seq_lens_this_time"][idx : idx + 1] = prefill_token_num
+                self.seq_lens_this_time_buffer[idx : idx + 1] = prefill_token_num

                 self.model_inputs["stop_flags"][idx : idx + 1] = False
                 self.model_inputs["batch_drop"][idx : idx + 1] = False
@@ -397,10 +399,10 @@ class MTPProposer(Proposer):
                 if self.cache_config.enable_chunked_prefill:
                     token_chunk_size = request.prefill_chunk_info[0]
                     self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
-                    self.model_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
+                    self.seq_lens_this_time_buffer[idx : idx + 1] = token_chunk_size
                 else:
                     self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
-                    self.model_inputs["seq_lens_this_time"][idx : idx + 1] = length
+                    self.seq_lens_this_time_buffer[idx : idx + 1] = length

                 self.model_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                 self.model_inputs["stop_flags"][idx : idx + 1] = False
@@ -413,6 +415,7 @@ class MTPProposer(Proposer):
                     request.get("block_tables"), dtype="int32"
                 )
        self.model_inputs["not_need_stop"][0] = True
+        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]

     def _initialize_forward_meta(self):
         """