Repository: https://github.com/PaddlePaddle/FastDeploy.git

[Bug fix] Fix bug for seq_len_encoder is 1 (#3467)
@@ -325,8 +325,8 @@ class ResourceManager:
         Delete cached data from the task's prompt token ids based on the cached length.
         """
         if cached_len == len(task.prompt_token_ids):
-            task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :]
-            task.seq_lens_decoder = cached_len - 1
+            task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :]
+            task.seq_lens_decoder = cached_len - self.cfg.block_size
         else:
             task.prompt_token_ids = task.prompt_token_ids[cached_len:]
             task.seq_lens_decoder = cached_len
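
In the branch where the prefix cache covers the entire prompt, the old code kept only the final cached token, which appears to be what left seq_len_encoder at 1 (the bug named in the title); the fix rolls back one full cache block instead, so the prefill step still has a block's worth of tokens to encode. Below is a minimal sketch of the slicing arithmetic only, using plain Python lists and assuming block_size is the engine's cache block size (self.cfg.block_size in the diff); trim_cached_prefix is a made-up helper name.

# Hypothetical standalone helper mirroring the slicing above; Task/ResourceManager
# internals are reduced to plain values for illustration.
def trim_cached_prefix(prompt_token_ids, cached_len, block_size):
    """Return (remaining_prompt_ids, seq_lens_decoder) after dropping cached tokens."""
    if cached_len == len(prompt_token_ids):
        # Whole prompt is cached: roll back one block so prefill still has a
        # full block to encode (the old `cached_len - 1` left a single token).
        return prompt_token_ids[cached_len - block_size :], cached_len - block_size
    # Partial cache hit: keep only the uncached suffix.
    return prompt_token_ids[cached_len:], cached_len

# Example: a 128-token prompt fully covered by the cache, block_size = 64.
remaining, seq_lens_decoder = trim_cached_prefix(list(range(128)), cached_len=128, block_size=64)
assert len(remaining) == 64 and seq_lens_decoder == 64
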
@@ -445,8 +445,8 @@ class MTPSampler(nn.Layer):
             sampling_metadata.min_dec_lens,
             sampling_metadata.eos_token_ids,
             share_inputs["seq_lens_this_time"],
-            share_inputs["seq_lens_encoder"],
-            share_inputs["seq_lens_decoder"],
+            share_inputs["output_padding_offset"],
+            share_inputs["output_cum_offsets"],
             max_model_len,
         )
         probs = F.softmax(logits)
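
The MTPSampler call now passes the packed-output offsets (output_padding_offset, output_cum_offsets) in place of the encoder/decoder sequence lengths. As a rough illustration only, one common way such offsets can be derived from per-sequence lengths in a [batch, max_len] padded layout is sketched below; the exact convention FastDeploy uses for these tensors may differ, and build_offsets is a made-up name for this sketch.

def build_offsets(seq_lens, max_len):
    """Toy construction: cum_offsets[i] = padding accumulated before sequence i,
    padding_offsets[t] = padding slots sitting before real token t once padding
    is removed and tokens are packed."""
    cum_offsets = [0]
    padding_offsets = []
    for i, n in enumerate(seq_lens):
        padding_offsets.extend([cum_offsets[i]] * n)
        cum_offsets.append(cum_offsets[i] + (max_len - n))
    return cum_offsets, padding_offsets

cum, pad = build_offsets([3, 1], max_len=4)
assert cum == [0, 1, 4] and pad == [0, 0, 0, 1]
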
@@ -438,6 +438,7 @@ class TokenProcessor:
             batch = self.output_tokens[1]
             accept_num = tokens[2 : batch + 2]
             self._record_speculative_decoding_mertics(accept_num)
+
         else:
             batch = self.output_tokens[1, 0]
             tokens = tokens[2 : batch + 2]
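
For context, the speculative branch reads the batch size from output_tokens[1] and the per-request accepted-token counts (accept_num) from the next `batch` slots, while the non-speculative branch indexes a 2-D buffer (output_tokens[1, 0]). The snippet below is a toy parse of the speculative header only, with small assumed values for SPECULATE_MAX_BSZ and MAX_DRAFT_TOKENS and a simplified flat buffer; the real buffer shapes and the meaning of index 0 are not shown in this diff.

import numpy as np

SPECULATE_MAX_BSZ = 4    # assumed toy values; the real constants are larger
MAX_DRAFT_TOKENS = 3

# Layout sketch: [flag, batch, accept_num[0..SPECULATE_MAX_BSZ), per-request token table ...]
output_tokens = np.full(2 + SPECULATE_MAX_BSZ + SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS, -1, dtype=np.int64)
output_tokens[1] = 2                 # two active requests in this batch
output_tokens[2:4] = [2, 1]          # accepted-token counts for requests 0 and 1

batch = int(output_tokens[1])
accept_num = output_tokens[2 : batch + 2].tolist()
assert accept_num == [2, 1]
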
@@ -452,16 +453,22 @@ class TokenProcessor:
 
             task_id = task.request_id
             if self.cfg.speculative_config.method:
-                token_ids = tokens[
-                    2
-                    + SPECULATE_MAX_BSZ
-                    + i * MAX_DRAFT_TOKENS : 2
-                    + SPECULATE_MAX_BSZ
-                    + i * MAX_DRAFT_TOKENS
-                    + accept_num[i]
-                ].tolist()
-                if len(token_ids) == 0 or token_ids[-1] <= 0:
-                    continue
+                if accept_num[i] == -3:
+                    recovery_stop = True
+                if recovery_stop:
+                    llm_logger.info(f"recovery stop signal found at task {task_id}")
+                    token_ids = [RECOVERY_STOP_SIGNAL]
+                else:
+                    token_ids = tokens[
+                        2
+                        + SPECULATE_MAX_BSZ
+                        + i * MAX_DRAFT_TOKENS : 2
+                        + SPECULATE_MAX_BSZ
+                        + i * MAX_DRAFT_TOKENS
+                        + accept_num[i]
+                    ].tolist()
+                if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
+                    continue
             else:
                 token_id = int(tokens[i, 0])
                 token_ids = [token_id]
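
The rewritten branch treats accept_num[i] == -3 as a recovery-stop sentinel: it logs the event, emits RECOVERY_STOP_SIGNAL for that request, and only applies the empty/invalid-token skip when no recovery stop was seen, so the stop signal is no longer dropped by the old `continue`. Below is a simplified mirror of that decision, with RECOVERY_STOP_SIGNAL assumed to be a negative marker value, logging replaced by a plain print, and extract_token_ids a made-up helper name.

RECOVERY_STOP_SIGNAL = -3    # assumption for this sketch only

def extract_token_ids(tokens, accept_num, i, speculate_max_bsz, max_draft_tokens):
    """Return (token_ids, skip) for request i, honoring the recovery-stop sentinel."""
    recovery_stop = accept_num[i] == -3
    if recovery_stop:
        print(f"recovery stop signal found at request {i}")
        token_ids = [RECOVERY_STOP_SIGNAL]
    else:
        start = 2 + speculate_max_bsz + i * max_draft_tokens
        token_ids = list(tokens[start : start + accept_num[i]])
    # Only skip empty or non-positive outputs when this is NOT a recovery stop.
    skip = (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0)
    return token_ids, skip

# e.g. with the toy buffer from the previous sketch:
# extract_token_ids(output_tokens, accept_num, 0, SPECULATE_MAX_BSZ, MAX_DRAFT_TOKENS)
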
@@ -589,7 +596,7 @@ class TokenProcessor:
                 self.cfg.speculative_config.num_speculative_tokens,
             )
 
-            real_accept_num = [x for x in accept_num if x != 0]
+            real_accept_num = [x for x in accept_num if x > 0]
             num_accepted_tokens = sum([x - 1 for x in real_accept_num])
             self.num_accepted_tokens += num_accepted_tokens
             num_emitted_tokens = sum(real_accept_num)
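
The metrics filter changes from `x != 0` to `x > 0` so that negative sentinel values (such as the -3 recovery-stop marker handled above) are excluded from the acceptance statistics rather than summed in. A tiny numeric check of the difference:

accept_num = [3, 0, -3, 2]

old_real = [x for x in accept_num if x != 0]   # [3, -3, 2] -> the -3 entry skews the sums
new_real = [x for x in accept_num if x > 0]    # [3, 2]

num_accepted_tokens = sum(x - 1 for x in new_real)   # 3
num_emitted_tokens = sum(new_real)                   # 5
assert (num_accepted_tokens, num_emitted_tokens) == (3, 5)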