mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-06 00:57:33 +08:00)
[stop sequence] support stop sequence (#3025)
* stop seqs in multi-ends
* unittest for gpu stop op
* kernel tid==0
@@ -275,11 +275,16 @@ class GPUModelRunner(ModelRunnerBase):
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
-                    request.stop_seqs_len.append(0)
-                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
-                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
-                    request.get("stop_token_ids"), dtype="int64"
-                )
+                    request.sampling_params.stop_seqs_len.append(0)
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
+                    request.sampling_params.stop_seqs_len, dtype="int32"
+                )
+                self.share_inputs["stop_seqs"][
+                    idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
+                ] = np.array(request.get("stop_token_ids"), dtype="int64")
+            else:
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0

         if has_prefill_task:
             self.share_inputs["not_need_stop"][0] = True
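The hunk above switches from one shared stop-sequence buffer to a per-request slot indexed by `idx`: `stop_seqs_len` is padded out to `max_stop_seqs_num`, and a slot with no stop sequences is cleared to zero. A minimal NumPy sketch of that per-slot layout, with illustrative capacities and a hypothetical helper name (not the FastDeploy API):

import numpy as np

max_num_seqs = 4          # assumed batch capacity (illustrative)
max_stop_seqs_num = 3     # assumed max stop sequences per request (illustrative)
stop_seqs_max_len = 8     # assumed max tokens per stop sequence (illustrative)

stop_seqs_len = np.zeros([max_num_seqs, max_stop_seqs_num], dtype="int32")
stop_seqs = np.full([max_num_seqs, max_stop_seqs_num, stop_seqs_max_len], -1, dtype="int64")

def insert_stop_seqs(idx, stop_token_ids, seqs_len):
    """Write one request's (already rectangular) stop sequences into batch slot `idx`."""
    # Pad the per-sequence lengths out to the buffer width, as the diff does.
    padded_len = list(seqs_len) + [0] * (max_stop_seqs_num - len(seqs_len))
    stop_seqs_len[idx, :] = np.array(padded_len, dtype="int32")
    num = len(seqs_len)
    width = len(stop_token_ids[0])
    stop_seqs[idx, :num, :width] = np.array(stop_token_ids, dtype="int64")

# Slot 1 gets two stop sequences, padded to equal width with -1.
insert_stop_seqs(1, [[151645, -1], [198, 151644]], [1, 2])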
@@ -446,11 +451,15 @@ class GPUModelRunner(ModelRunnerBase):
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
-                    request.stop_seqs_len.append(0)
-                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
-                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
-                    request.get("stop_token_ids"), dtype="int64"
-                )
+                    request.sampling_params.stop_seqs_len.append(0)
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
+                    request.sampling_params.stop_seqs_len, dtype="int32"
+                )
+                self.share_inputs["stop_seqs"][
+                    idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
+                ] = np.array(request.get("stop_token_ids"), dtype="int64")
+            else:
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0

             self.sampler.apply_logits_processor(idx, request.get("logits_processor"), prefill_tokens)
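This second insertion path applies the same per-slot write. For context, a hypothetical sketch of how stop strings might be turned into the rectangular `stop_token_ids` plus `stop_seqs_len` pair that the runner consumes; the tokenizer interface and the -1 padding value are assumptions for illustration, not the project's actual preprocessing:

def build_stop_token_ids(stop_strings, tokenizer, pad_id=-1):
    """Encode stop strings and pad them to a common width so they form a rectangle."""
    token_lists = [tokenizer.encode(s, add_special_tokens=False) for s in stop_strings]
    stop_seqs_len = [len(t) for t in token_lists]
    width = max(stop_seqs_len)
    stop_token_ids = [t + [pad_id] * (width - len(t)) for t in token_lists]
    return stop_token_ids, stop_seqs_len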
@@ -619,14 +628,17 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["free_list_len"] = paddle.full([1], self.free_list_len, dtype="int32")

         # Initialize stop seqs
-        self.share_inputs["stop_seqs_len"] = paddle.full([self.model_config.max_stop_seqs_num], 0, dtype="int32")
+        self.share_inputs["stop_seqs_len"] = paddle.full(
+            [max_num_seqs, self.model_config.max_stop_seqs_num], 0, dtype="int32"
+        )
         self.share_inputs["stop_seqs"] = paddle.full(
             [
+                max_num_seqs,
                 self.model_config.max_stop_seqs_num,
                 self.model_config.stop_seqs_max_len,
             ],
             -1,
-            dtype="int32",
+            dtype="int64",
         )
         if self.speculative_decoding:
             max_draft_token_num = self.speculative_config.num_speculative_tokens
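After this hunk both buffers gain a leading `max_num_seqs` dimension and the token buffer becomes int64. A minimal sketch of the resulting shapes, assuming illustrative values for `max_num_seqs` and the `model_config` fields:

import paddle

max_num_seqs = 4          # illustrative
max_stop_seqs_num = 3     # illustrative
stop_seqs_max_len = 8     # illustrative

share_inputs = {}
# One row of stop-sequence lengths per batch slot instead of a single shared row.
share_inputs["stop_seqs_len"] = paddle.full([max_num_seqs, max_stop_seqs_num], 0, dtype="int32")
# Token ids now carry a leading batch dimension and are int64; -1 marks padding.
share_inputs["stop_seqs"] = paddle.full(
    [max_num_seqs, max_stop_seqs_num, stop_seqs_max_len], -1, dtype="int64"
)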
@@ -1012,6 +1024,8 @@ class GPUModelRunner(ModelRunnerBase):
                 think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
                 need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
                 reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
+                stop_token_ids=self.share_inputs["stop_seqs"],
+                stop_seqs_len=self.share_inputs["stop_seqs_len"],
             )

             post_process(
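Passing `stop_seqs` and `stop_seqs_len` into the sampling/post-processing step is what lets the GPU stop op compare each slot's newest tokens against that slot's configured stop sequences. A CPU reference sketch of that check (plain NumPy, not the actual kernel; function and variable names are illustrative):

import numpy as np

def check_stop_seqs(generated, stop_seqs, stop_seqs_len):
    """generated: list of token ids for one slot; True if any configured stop sequence matches the tail."""
    for seq, seq_len in zip(stop_seqs, stop_seqs_len):
        n = int(seq_len)
        if n == 0:
            continue  # padded (unused) stop-sequence entry
        tail = generated[-n:]
        if len(tail) == n and list(tail) == list(seq[:n]):
            return True
    return False

# Slot with one 2-token stop sequence [198, 151644], padded to width 4 with -1.
stop_seqs = np.array([[198, 151644, -1, -1], [-1, -1, -1, -1]], dtype="int64")
stop_seqs_len = np.array([2, 0], dtype="int32")
print(check_stop_seqs([11, 22, 198, 151644], stop_seqs, stop_seqs_len))  # True
print(check_stop_seqs([11, 22, 198, 151645], stop_seqs, stop_seqs_len))  # False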
@@ -1276,6 +1290,8 @@ class GPUModelRunner(ModelRunnerBase):
                 think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
                 need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
                 reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
+                stop_token_ids=self.share_inputs["stop_seqs"],
+                stop_seqs_len=self.share_inputs["stop_seqs_len"],
             )

             if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
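The commit message also mentions a unit test for the GPU stop op. A hedged sketch of the kind of assertions such a test could make, using a small CPU reference helper rather than the real op (all names here are illustrative):

import numpy as np

def matches_stop(tail_tokens, stop_seqs, stop_seqs_len):
    """Reference check: does any configured stop sequence equal the tail of the output?"""
    for seq, n in zip(stop_seqs, stop_seqs_len):
        n = int(n)
        if n and list(tail_tokens[-n:]) == list(seq[:n]):
            return True
    return False

stop_seqs = np.array([[7, 9, -1, -1]], dtype="int64")
stop_seqs_len = np.array([2], dtype="int32")

assert matches_stop([1, 2, 7, 9], stop_seqs, stop_seqs_len)      # exact suffix match stops generation
assert not matches_stop([1, 2, 9, 7], stop_seqs, stop_seqs_len)  # wrong order does not match
assert not matches_stop([7, 9, 1, 2], stop_seqs, stop_seqs_len)  # match must be at the tail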