diff --git a/fastdeploy/engine/resource_manager.py b/fastdeploy/engine/resource_manager.py
index a0f0b6d59..086cf6aeb 100644
--- a/fastdeploy/engine/resource_manager.py
+++ b/fastdeploy/engine/resource_manager.py
@@ -325,8 +325,8 @@ class ResourceManager:
         Delete cached data from the task's prompt token ids based on the cached length.
         """
         if cached_len == len(task.prompt_token_ids):
-            task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :]
-            task.seq_lens_decoder = cached_len - 1
+            task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :]
+            task.seq_lens_decoder = cached_len - self.cfg.block_size
         else:
             task.prompt_token_ids = task.prompt_token_ids[cached_len:]
             task.seq_lens_decoder = cached_len
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index 412a7eda7..65c044201 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -445,8 +445,8 @@ class MTPSampler(nn.Layer):
             sampling_metadata.min_dec_lens,
             sampling_metadata.eos_token_ids,
             share_inputs["seq_lens_this_time"],
-            share_inputs["seq_lens_encoder"],
-            share_inputs["seq_lens_decoder"],
+            share_inputs["output_padding_offset"],
+            share_inputs["output_cum_offsets"],
             max_model_len,
         )
         probs = F.softmax(logits)
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index 3f590b73c..e26c0b057 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -438,6 +438,7 @@ class TokenProcessor:
             batch = self.output_tokens[1]
             accept_num = tokens[2 : batch + 2]
             self._record_speculative_decoding_mertics(accept_num)
+
         else:
             batch = self.output_tokens[1, 0]
             tokens = tokens[2 : batch + 2]
@@ -452,16 +453,22 @@ class TokenProcessor:
 
             task_id = task.request_id
             if self.cfg.speculative_config.method:
-                token_ids = tokens[
-                    2
-                    + SPECULATE_MAX_BSZ
-                    + i * MAX_DRAFT_TOKENS : 2
-                    + SPECULATE_MAX_BSZ
-                    + i * MAX_DRAFT_TOKENS
-                    + accept_num[i]
-                ].tolist()
-                if len(token_ids) == 0 or token_ids[-1] <= 0:
-                    continue
+                if accept_num[i] == -3:
+                    recovery_stop = True
+                    if recovery_stop:
+                        llm_logger.info(f"recovery stop signal found at task {task_id}")
+                    token_ids = [RECOVERY_STOP_SIGNAL]
+                else:
+                    token_ids = tokens[
+                        2
+                        + SPECULATE_MAX_BSZ
+                        + i * MAX_DRAFT_TOKENS : 2
+                        + SPECULATE_MAX_BSZ
+                        + i * MAX_DRAFT_TOKENS
+                        + accept_num[i]
+                    ].tolist()
+                if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
+                    continue
             else:
                 token_id = int(tokens[i, 0])
                 token_ids = [token_id]
@@ -589,7 +596,7 @@ class TokenProcessor:
                 self.cfg.speculative_config.num_speculative_tokens,
             )
 
-        real_accept_num = [x for x in accept_num if x != 0]
+        real_accept_num = [x for x in accept_num if x > 0]
         num_accepted_tokens = sum([x - 1 for x in real_accept_num])
         self.num_accepted_tokens += num_accepted_tokens
         num_emitted_tokens = sum(real_accept_num)
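
For reference, below is a minimal standalone sketch (plain Python, not the FastDeploy API; the buffer sizes and the assumption that RECOVERY_STOP_SIGNAL equals the -3 accept-count sentinel are illustrative only) of how the new accept_num handling in TokenProcessor behaves: an accept count of -3 marks a recovery stop and emits the stop sentinel instead of slicing draft tokens, which is also why the metrics filter above changes from x != 0 to x > 0 so that negative sentinels are excluded.

# Standalone sketch (hypothetical layout, not the FastDeploy API): the output
# buffer is assumed to hold [2 header slots][SPECULATE_MAX_BSZ accept counts]
# [MAX_DRAFT_TOKENS draft-token slots per request], mirroring the slicing above.
SPECULATE_MAX_BSZ = 4       # assumed small sizes for the example
MAX_DRAFT_TOKENS = 3
RECOVERY_STOP_SIGNAL = -3   # assumed to equal the -3 accept-count sentinel

def extract_accepted_tokens(tokens, batch):
    """Per request: accepted token ids, the stop sentinel, or [] (skip)."""
    accept_num = tokens[2 : batch + 2]
    results = []
    for i in range(batch):
        if accept_num[i] == -3:
            # Recovery stop: emit the sentinel instead of slicing draft tokens.
            results.append([RECOVERY_STOP_SIGNAL])
            continue
        start = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
        token_ids = tokens[start : start + accept_num[i]]
        # Same emptiness check as the normal path; recovery stops bypass it.
        results.append(token_ids if token_ids and token_ids[-1] > 0 else [])
    return results

# Header [flag, batch], accept counts [2, -3, 0, 0], then 4 x 3 draft slots.
buf = [0, 2] + [2, -3, 0, 0] + [11, 12, 0] + [0] * 9
print(extract_accepted_tokens(buf, 2))   # -> [[11, 12], [-3]]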