[Bug fix] Fix bug when seq_len_encoder is 1 (#3467)

This commit is contained in:
chenjian
2025-08-19 15:21:32 +08:00
committed by GitHub
parent aba94169dc
commit d2f6c3b998
3 changed files with 22 additions and 15 deletions

View File

@@ -438,6 +438,7 @@ class TokenProcessor:
batch = self.output_tokens[1]
accept_num = tokens[2 : batch + 2]
self._record_speculative_decoding_mertics(accept_num)
else:
batch = self.output_tokens[1, 0]
tokens = tokens[2 : batch + 2]
@@ -452,16 +453,22 @@ class TokenProcessor:
task_id = task.request_id
if self.cfg.speculative_config.method:
token_ids = tokens[
2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS : 2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS
+ accept_num[i]
].tolist()
if len(token_ids) == 0 or token_ids[-1] <= 0:
continue
if accept_num[i] == -3:
recovery_stop = True
if recovery_stop:
llm_logger.info(f"recovery stop signal found at task {task_id}")
token_ids = [RECOVERY_STOP_SIGNAL]
else:
token_ids = tokens[
2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS : 2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS
+ accept_num[i]
].tolist()
if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
continue
else:
token_id = int(tokens[i, 0])
token_ids = [token_id]
@@ -589,7 +596,7 @@ class TokenProcessor:
self.cfg.speculative_config.num_speculative_tokens,
)
real_accept_num = [x for x in accept_num if x != 0]
real_accept_num = [x for x in accept_num if x > 0]
num_accepted_tokens = sum([x - 1 for x in real_accept_num])
self.num_accepted_tokens += num_accepted_tokens
num_emitted_tokens = sum(real_accept_num)