Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Bug fix] Fix bug for seq_len_encoder is 1 (#3467)
@@ -325,8 +325,8 @@ class ResourceManager:
         Delete cached data from the task's prompt token ids based on the cached length.
         """
         if cached_len == len(task.prompt_token_ids):
-            task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :]
-            task.seq_lens_decoder = cached_len - 1
+            task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :]
+            task.seq_lens_decoder = cached_len - self.cfg.block_size
         else:
             task.prompt_token_ids = task.prompt_token_ids[cached_len:]
             task.seq_lens_decoder = cached_len
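Note: per the commit title, the bug was that a fully cached prompt rolled back only one token, leaving seq_lens_encoder at 1; the fix rolls back a whole KV-cache block so the encoder still receives block_size tokens. A minimal sketch of the arithmetic, using a hypothetical Task stand-in rather than FastDeploy's real request class:

```python
# Minimal sketch of the cache-rollback arithmetic; Task is a hypothetical
# stand-in for illustration, not FastDeploy's request class.
from dataclasses import dataclass, field

@dataclass
class Task:
    prompt_token_ids: list = field(default_factory=list)
    seq_lens_decoder: int = 0

def trim_cached_prefix(task: Task, cached_len: int, block_size: int) -> None:
    if cached_len == len(task.prompt_token_ids):
        # Fully cached prompt: roll back one whole block so the encoder
        # still gets block_size tokens (the old "- 1" left it with just 1).
        task.prompt_token_ids = task.prompt_token_ids[cached_len - block_size :]
        task.seq_lens_decoder = cached_len - block_size
    else:
        task.prompt_token_ids = task.prompt_token_ids[cached_len:]
        task.seq_lens_decoder = cached_len

task = Task(prompt_token_ids=list(range(128)))
trim_cached_prefix(task, cached_len=128, block_size=64)
print(len(task.prompt_token_ids), task.seq_lens_decoder)  # 64 64
```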
@@ -445,8 +445,8 @@ class MTPSampler(nn.Layer):
             sampling_metadata.min_dec_lens,
             sampling_metadata.eos_token_ids,
             share_inputs["seq_lens_this_time"],
-            share_inputs["seq_lens_encoder"],
-            share_inputs["seq_lens_decoder"],
+            share_inputs["output_padding_offset"],
+            share_inputs["output_cum_offsets"],
             max_model_len,
         )
         probs = F.softmax(logits)
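Note: the sampler op now receives packed-output offsets instead of raw encoder/decoder lengths. A hedged illustration of what cumulative offsets of this shape enable (the semantics here are assumed for exposition and not taken from FastDeploy's custom op):

```python
# Hedged illustration only: the meaning of output_cum_offsets is an
# assumption for this sketch, not FastDeploy's actual kernel contract.
seq_lens_this_time = [3, 1, 2]  # tokens each request contributes this step

output_cum_offsets = [0]
for n in seq_lens_this_time[:-1]:
    output_cum_offsets.append(output_cum_offsets[-1] + n)

# Row output_cum_offsets[r] + j of a packed logits tensor belongs to request r,
# token j, independent of that request's encoder/decoder lengths.
print(output_cum_offsets)  # [0, 3, 4]
```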
@@ -438,6 +438,7 @@ class TokenProcessor:
             batch = self.output_tokens[1]
             accept_num = tokens[2 : batch + 2]
+            self._record_speculative_decoding_mertics(accept_num)

         else:
             batch = self.output_tokens[1, 0]
             tokens = tokens[2 : batch + 2]
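Note: in speculative-decoding mode the flat output buffer carries a small header followed by per-request accept counts, which the newly wired-in _record_speculative_decoding_mertics hook receives (the misspelling is the identifier the repo actually uses). A hedged sketch of that layout with made-up sizes; only the tokens[2 : batch + 2] slice comes from the code above:

```python
import numpy as np

# Hedged sketch of the output buffer layout implied by the diff;
# SPECULATE_MAX_BSZ is an assumed value for illustration.
SPECULATE_MAX_BSZ = 4
batch = 3
tokens = np.zeros(2 + SPECULATE_MAX_BSZ, dtype=np.int64)
tokens[2 : 2 + batch] = [2, -3, 1]  # per-request accept counts; -3 = recovery stop

accept_num = tokens[2 : batch + 2]
print(accept_num.tolist())  # [2, -3, 1]
```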
@@ -452,16 +453,22 @@ class TokenProcessor:

                 task_id = task.request_id
                 if self.cfg.speculative_config.method:
-                    token_ids = tokens[
-                        2
-                        + SPECULATE_MAX_BSZ
-                        + i * MAX_DRAFT_TOKENS : 2
-                        + SPECULATE_MAX_BSZ
-                        + i * MAX_DRAFT_TOKENS
-                        + accept_num[i]
-                    ].tolist()
-                    if len(token_ids) == 0 or token_ids[-1] <= 0:
-                        continue
+                    if accept_num[i] == -3:
+                        recovery_stop = True
+                        if recovery_stop:
+                            llm_logger.info(f"recovery stop signal found at task {task_id}")
+                        token_ids = [RECOVERY_STOP_SIGNAL]
+                    else:
+                        token_ids = tokens[
+                            2
+                            + SPECULATE_MAX_BSZ
+                            + i * MAX_DRAFT_TOKENS : 2
+                            + SPECULATE_MAX_BSZ
+                            + i * MAX_DRAFT_TOKENS
+                            + accept_num[i]
+                        ].tolist()
+                    if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
+                        continue
                 else:
                     token_id = int(tokens[i, 0])
                     token_ids = [token_id]
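Note: each request i owns a fixed MAX_DRAFT_TOKENS-wide slot after the accept-count region, and accept_num[i] == -3 is the recovery-stop sentinel the new branch handles before any slicing. A sketch of the index arithmetic; the constants and the RECOVERY_STOP_SIGNAL value are illustrative assumptions, not the repo's real configuration:

```python
import numpy as np

# Illustrative constants; the real values come from FastDeploy's config.
SPECULATE_MAX_BSZ = 4
MAX_DRAFT_TOKENS = 3
RECOVERY_STOP_SIGNAL = -1  # placeholder value, an assumption for this sketch

tokens = np.arange(2 + SPECULATE_MAX_BSZ + SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS)
accept_num = [2, -3, 1, 0]

for i in range(SPECULATE_MAX_BSZ):
    if accept_num[i] == -3:
        token_ids = [RECOVERY_STOP_SIGNAL]  # sentinel replaces the slice entirely
    else:
        start = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
        token_ids = tokens[start : start + accept_num[i]].tolist()
    print(i, token_ids)  # request 3 yields [], which the real loop skips
```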
@@ -589,7 +596,7 @@ class TokenProcessor:
                 self.cfg.speculative_config.num_speculative_tokens,
             )

-        real_accept_num = [x for x in accept_num if x != 0]
+        real_accept_num = [x for x in accept_num if x > 0]
         num_accepted_tokens = sum([x - 1 for x in real_accept_num])
         self.num_accepted_tokens += num_accepted_tokens
         num_emitted_tokens = sum(real_accept_num)
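Note: once -3 is a legal accept_num value, the old "x != 0" filter would let the sentinel through and skew the accepted/emitted token counters; "x > 0" drops idle slots and recovery stops alike. A tiny worked example:

```python
accept_num = [3, 0, -3, 1]  # -3 is the recovery-stop sentinel from the hunk above

old = [x for x in accept_num if x != 0]  # [3, -3, 1]: keeps the sentinel
new = [x for x in accept_num if x > 0]   # [3, 1]: drops idle slots and stops

print(sum(x - 1 for x in old), sum(old))  # -2 1  (counters skewed by -3)
print(sum(x - 1 for x in new), sum(new))  # 2 4   (only real acceptances)
```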