[Bug fix] Fix bug for seq_len_encoder is 1 (#3467)

Author: chenjian (committed by GitHub)
Date: 2025-08-19 15:21:32 +08:00
Parent: aba94169dc
Commit: d2f6c3b998

3 changed files with 22 additions and 15 deletions

@@ -325,8 +325,8 @@ class ResourceManager:
         Delete cached data from the task's prompt token ids based on the cached length.
         """
         if cached_len == len(task.prompt_token_ids):
-            task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :]
-            task.seq_lens_decoder = cached_len - 1
+            task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :]
+            task.seq_lens_decoder = cached_len - self.cfg.block_size
         else:
             task.prompt_token_ids = task.prompt_token_ids[cached_len:]
             task.seq_lens_decoder = cached_len

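In the old code, a fully cached prompt was trimmed down to its last token, so the next prefill ran with a sequence length of 1, which appears to be the condition named in the commit title; the fix rolls back a full block instead. Below is a minimal standalone sketch of the new trimming rule. `ToyTask`, the free function, and `block_size=64` are illustrative assumptions; only the slicing mirrors the hunk above.

```python
# Sketch of the trimming rule from the ResourceManager hunk above.
# ToyTask and block_size=64 are assumptions for illustration.
from dataclasses import dataclass


@dataclass
class ToyTask:
    prompt_token_ids: list
    seq_lens_decoder: int = 0


def delete_cached_data(task: ToyTask, cached_len: int, block_size: int = 64) -> None:
    if cached_len == len(task.prompt_token_ids):
        # Fully cached prompt: keep the last full block so prefill still sees
        # block_size tokens rather than a single one.
        task.prompt_token_ids = task.prompt_token_ids[cached_len - block_size :]
        task.seq_lens_decoder = cached_len - block_size
    else:
        task.prompt_token_ids = task.prompt_token_ids[cached_len:]
        task.seq_lens_decoder = cached_len


task = ToyTask(prompt_token_ids=list(range(128)))
delete_cached_data(task, cached_len=128)  # whole prompt was cached
print(len(task.prompt_token_ids), task.seq_lens_decoder)  # 64 64
```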
@@ -445,8 +445,8 @@ class MTPSampler(nn.Layer):
             sampling_metadata.min_dec_lens,
             sampling_metadata.eos_token_ids,
             share_inputs["seq_lens_this_time"],
-            share_inputs["seq_lens_encoder"],
-            share_inputs["seq_lens_decoder"],
+            share_inputs["output_padding_offset"],
+            share_inputs["output_cum_offsets"],
             max_model_len,
         )
         probs = F.softmax(logits)

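The sampling call in MTPSampler now receives output_padding_offset and output_cum_offsets instead of the raw encoder/decoder sequence lengths. As a rough, hedged illustration of what such offset tensors usually encode for a padded-then-packed batch, here is one common construction in NumPy; the names and the formula are assumptions, not FastDeploy's actual implementation.

```python
# Hedged sketch: one common way cumulative / padding offsets are derived from
# per-request token counts. Illustrative only; FastDeploy's kernels may define
# output_cum_offsets / output_padding_offset differently.
import numpy as np


def build_offsets(seq_lens_this_time: np.ndarray, max_len: int):
    # cum_offsets[i]: total padding contributed by requests before request i
    # in a [batch, max_len] layout.
    pad_per_req = max_len - seq_lens_this_time
    cum_offsets = np.concatenate([[0], np.cumsum(pad_per_req)[:-1]])
    # padding_offset[j]: for packed token j, how far it moved left relative to
    # its padded position, i.e. the cumulative padding of earlier requests.
    padding_offset = np.repeat(cum_offsets, seq_lens_this_time)
    return cum_offsets, padding_offset


cum, pad = build_offsets(np.array([3, 1, 2]), max_len=4)
print(cum)  # [0 1 4]
print(pad)  # [0 0 0 1 4 4]
```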
@@ -438,6 +438,7 @@ class TokenProcessor:
             batch = self.output_tokens[1]
             accept_num = tokens[2 : batch + 2]
             self._record_speculative_decoding_mertics(accept_num)
         else:
             batch = self.output_tokens[1, 0]
             tokens = tokens[2 : batch + 2]
@@ -452,6 +453,12 @@ class TokenProcessor:
             task_id = task.request_id
             if self.cfg.speculative_config.method:
+                if accept_num[i] == -3:
+                    recovery_stop = True
+                    if recovery_stop:
+                        llm_logger.info(f"recovery stop signal found at task {task_id}")
+                    token_ids = [RECOVERY_STOP_SIGNAL]
+                else:
                     token_ids = tokens[
                         2
                         + SPECULATE_MAX_BSZ
@@ -460,7 +467,7 @@ class TokenProcessor:
                         + i * MAX_DRAFT_TOKENS
                         + accept_num[i]
                     ].tolist()
-                if len(token_ids) == 0 or token_ids[-1] <= 0:
+                if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
                     continue
             else:
                 token_id = int(tokens[i, 0])
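The hunks above make the speculative path treat accept_num[i] == -3 as a recovery-stop marker: such a task emits RECOVERY_STOP_SIGNAL instead of sliced draft tokens, and the usual "nothing accepted, skip" check is bypassed for it. A condensed sketch of that per-task branch follows, with the nested ifs folded into one boolean; RECOVERY_STOP_SIGNAL, SPECULATE_MAX_BSZ, and MAX_DRAFT_TOKENS are toy values here, not the real constants.

```python
# Condensed sketch of the per-task branch above; constants are toy values.
import numpy as np

RECOVERY_STOP_SIGNAL = -3
SPECULATE_MAX_BSZ = 4
MAX_DRAFT_TOKENS = 6


def extract_task_tokens(tokens: np.ndarray, accept_num: list, i: int):
    recovery_stop = accept_num[i] == -3
    if recovery_stop:
        # Preempted / recovered request: forward the stop signal instead of drafts.
        token_ids = [RECOVERY_STOP_SIGNAL]
    else:
        start = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
        token_ids = tokens[start : start + accept_num[i]].tolist()
    if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
        return None  # nothing accepted this step; caller skips the task
    return token_ids


buf = np.arange(2 + SPECULATE_MAX_BSZ + 2 * MAX_DRAFT_TOKENS)
print(extract_task_tokens(buf, accept_num=[2, -3], i=0))  # [6, 7]
print(extract_task_tokens(buf, accept_num=[2, -3], i=1))  # [-3]
```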
@@ -589,7 +596,7 @@ class TokenProcessor:
             self.cfg.speculative_config.num_speculative_tokens,
         )
-        real_accept_num = [x for x in accept_num if x != 0]
+        real_accept_num = [x for x in accept_num if x > 0]
         num_accepted_tokens = sum([x - 1 for x in real_accept_num])
         self.num_accepted_tokens += num_accepted_tokens
         num_emitted_tokens = sum(real_accept_num)
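The last hunk keeps those recovery-stop markers out of the speculative-decoding metrics: filtering with x > 0 drops both empty slots (0) and the new -3 signal, while the old x != 0 let negative markers leak into the counters. A tiny worked example with made-up accept counts:

```python
# Made-up per-request accept counts: 0 = empty slot, -3 = recovery stop.
accept_num = [3, 0, -3, 2]

old = [x for x in accept_num if x != 0]  # [3, -3, 2] -> marker leaks in
new = [x for x in accept_num if x > 0]   # [3, 2]

print(sum(x - 1 for x in old), sum(old))  # -1 2  (skewed accepted / emitted totals)
print(sum(x - 1 for x in new), sum(new))  # 3 5
```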