[Bug fix] Fix bug for seq_len_encoder is 1 (#3467)

This commit is contained in:
chenjian
2025-08-19 15:21:32 +08:00
committed by GitHub
parent aba94169dc
commit d2f6c3b998
3 changed files with 22 additions and 15 deletions

View File

@@ -325,8 +325,8 @@ class ResourceManager:
Delete cached data from the task's prompt token ids based on the cached length.
"""
if cached_len == len(task.prompt_token_ids):
task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :]
task.seq_lens_decoder = cached_len - 1
task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :]
task.seq_lens_decoder = cached_len - self.cfg.block_size
else:
task.prompt_token_ids = task.prompt_token_ids[cached_len:]
task.seq_lens_decoder = cached_len

View File

@@ -445,8 +445,8 @@ class MTPSampler(nn.Layer):
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
max_model_len,
)
probs = F.softmax(logits)

View File

@@ -438,6 +438,7 @@ class TokenProcessor:
batch = self.output_tokens[1]
accept_num = tokens[2 : batch + 2]
self._record_speculative_decoding_mertics(accept_num)
else:
batch = self.output_tokens[1, 0]
tokens = tokens[2 : batch + 2]
@@ -452,16 +453,22 @@ class TokenProcessor:
task_id = task.request_id
if self.cfg.speculative_config.method:
token_ids = tokens[
2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS : 2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS
+ accept_num[i]
].tolist()
if len(token_ids) == 0 or token_ids[-1] <= 0:
continue
if accept_num[i] == -3:
recovery_stop = True
if recovery_stop:
llm_logger.info(f"recovery stop signal found at task {task_id}")
token_ids = [RECOVERY_STOP_SIGNAL]
else:
token_ids = tokens[
2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS : 2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS
+ accept_num[i]
].tolist()
if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
continue
else:
token_id = int(tokens[i, 0])
token_ids = [token_id]
@@ -589,7 +596,7 @@ class TokenProcessor:
self.cfg.speculative_config.num_speculative_tokens,
)
real_accept_num = [x for x in accept_num if x != 0]
real_accept_num = [x for x in accept_num if x > 0]
num_accepted_tokens = sum([x - 1 for x in real_accept_num])
self.num_accepted_tokens += num_accepted_tokens
num_emitted_tokens = sum(real_accept_num)