[MTP] update hybrid-mtp-with-ngram (#4047)

Author: freeliuzc
Date: 2025-09-15 17:13:31 +08:00
Committed by: GitHub
Parent: b1b33211e8
Commit: 46911f903d
3 changed files with 41 additions and 3 deletions

View File

@@ -18,6 +18,12 @@ This project implements an efficient **Speculative Decoding** inference framewor
- ⏳ Coming Soon: Support Chunk-prefill
- ⏳ Coming Soon: Multi-layer MTP Layer
- **Decoding with Hybrid MTP and Ngram Methods (Hybrid-MTP-with-Ngram)**
- Overview: A hybrid method that combines MTP and Ngram. MTP first generates N draft tokens, and Ngram matching then supplements additional draft tokens.
- Use Cases: Suitable when higher draft-token coverage is required, leveraging both MTP's generation capability and the efficiency of Ngram matching.
---
### Coming Soon
@@ -132,7 +138,13 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-password "scheduler_mtp" \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${path_to_mtp_model}"}' &
```
## Decoding with Hybrid MTP and Ngram Methods
When starting the service, you only need to modify the `--speculative-config` option.
For example, use MTP to generate two draft tokens and then append three additional draft tokens from Ngram matching:
```
--speculative-config '{"method": "mtp", "num_model_steps": 2, "mtp_strategy": "with_ngram", "num_speculative_tokens": 5, "model": "'$model_path'/mtp"}'
```
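To make the budget concrete: per the example above, `num_model_steps` draft tokens come from the MTP model first, and Ngram matching can top the draft up to `num_speculative_tokens`. The snippet below is only an illustrative sketch of that split (the dict mirrors the flag above; `/path/to/mtp` is a placeholder and none of this is FastDeploy API code):
```
import json

# Illustrative only: mirrors the --speculative-config value shown above.
config = {
    "method": "mtp",
    "num_model_steps": 2,          # draft tokens produced by the MTP model
    "mtp_strategy": "with_ngram",  # top up the draft with Ngram matches
    "num_speculative_tokens": 5,   # total draft-token budget
    "model": "/path/to/mtp",       # placeholder path
}

mtp_drafts = config["num_model_steps"]
ngram_drafts = config["num_speculative_tokens"] - mtp_drafts
print("--speculative-config", json.dumps(config))
print(f"MTP drafts: {mtp_drafts}, Ngram-matched drafts: up to {ngram_drafts}")
```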
## 🧠 Using Ngram-Based Decoding
This method uses an n-gram sliding window to match the prompt and generated tokens in order to predict draft tokens. It is particularly effective in scenarios with high input-output overlap (e.g., code completion, document search).
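As a rough sketch of the sliding-window idea (a toy, hypothetical helper, not the FastDeploy implementation): the last n tokens are looked up in the prompt plus previously generated tokens, and the continuation found at the match is proposed as draft tokens.
```
from typing import List

def ngram_draft(tokens: List[int], n: int = 3, max_draft: int = 5) -> List[int]:
    """Toy n-gram matcher: find the most recent earlier occurrence of the
    last n tokens and propose the tokens that followed it as drafts."""
    if len(tokens) <= n:
        return []
    key = tokens[-n:]
    # Scan backwards so the most recent earlier match wins.
    for start in range(len(tokens) - n - 1, -1, -1):
        if tokens[start : start + n] == key:
            return tokens[start + n : start + n + max_draft]
    return []

# Toy token ids containing a repeated fragment, as in code completion.
history = [11, 7, 3, 9, 4, 11, 7, 3, 9, 4, 11, 7, 3]
print(ngram_draft(history))  # -> [9, 4, 11, 7, 3]
```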

View File

@@ -14,6 +14,9 @@
- ⏳ Coming Soon: Chunk Prefill compatibility
- ⏳ Coming Soon: Multi-layer MTP layer
- **Decoding with Hybrid MTP and Ngram Methods (Hybrid-MTP-with-Ngram)**
- Overview: A hybrid of the MTP and Ngram methods. MTP first produces N draft tokens, and Ngram matching then supplements additional draft tokens.
- Use Cases: Suitable when more draft tokens are needed, combining MTP's generation capability with the efficiency of Ngram matching.
---
### ⏳ Coming Soon
@@ -110,7 +113,12 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-password "scheduler_mtp" \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${path_to_mtp_model}"}' &
```
## Decoding with Hybrid MTP and Ngram Methods
When starting the service, you only need to change the `--speculative-config` option. For example, use MTP to produce two draft tokens and then append three additional draft tokens from Ngram matching:
```
--speculative-config '{"method": "mtp", "num_model_steps": 2, "mtp_strategy": "with_ngram" ,"num_speculative_tokens": 5, "model": "'$model_path'/mtp"}'
```
## 🧠 Using Ngram-Based Decoding
This method matches against the prompt and previously generated tokens through an n-gram window to produce draft tokens. It is well suited to scenarios with large input-output overlap, such as code completion and document retrieval.
> Using 4×H100 with WINT4 quantization

View File

@@ -295,6 +295,11 @@ class MTPProposer(Proposer):
# Same shape/dtype as base model
self.model_inputs["block_tables"] = paddle.clone(self.target_model_inputs["block_tables"])
self.model_inputs["input_ids"] = paddle.clone(self.target_model_inputs["input_ids"])
self.model_inputs["input_ids_cpu"] = paddle.full(
shape=[self.max_num_seqs, self.parallel_config.max_model_len],
fill_value=-1,
dtype="int64",
).cpu()
self.seq_lens_this_time_buffer = paddle.clone(self.target_model_inputs["seq_lens_this_time"])
self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
@@ -401,11 +406,14 @@ class MTPProposer(Proposer):
input_ids = request.prompt_token_ids + request.output_token_ids
- self.input_ids_len[idx] = length
+ self.input_ids_len[idx] = length - 1
self.model_inputs["pre_ids"][idx : idx + 1] = -1 self.model_inputs["pre_ids"][idx : idx + 1] = -1
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][ self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
idx : idx + 1, 1:length idx : idx + 1, 1:length
] ]
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length].cpu()
encoder_block_num = len(request.block_tables)
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
@@ -468,10 +476,17 @@ class MTPProposer(Proposer):
request = req_dicts[i]
idx = request.idx
length = len(request.prompt_token_ids)
- self.input_ids_len[idx] = length
+ self.input_ids_len[idx] = length - 1
if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
length = len(request.prompt_token_ids)
if length > 1:
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length]
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
request.prompt_token_ids
)[1:]
self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1] self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
prefill_token_num = self.max_draft_token_num + 1 prefill_token_num = self.max_draft_token_num + 1
self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor( self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor(
@@ -500,6 +515,9 @@ class MTPProposer(Proposer):
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[ self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids" "input_ids"
][idx : idx + 1, 1:length] ][idx : idx + 1, 1:length]
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
request.prompt_token_ids
)[1:]
self.model_inputs["pre_ids"][idx : idx + 1] = -1 self.model_inputs["pre_ids"][idx : idx + 1] = -1
self.model_inputs["step_idx"][idx : idx + 1] = 0 self.model_inputs["step_idx"][idx : idx + 1] = 0
if self.cache_config.enable_chunked_prefill: if self.cache_config.enable_chunked_prefill:
@@ -800,7 +818,7 @@ class MTPProposer(Proposer):
seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu()
seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
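# input_ids_cpu is the host-side buffer maintained above; it replaces the per-call copy of input_ids that was passed in previously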
hybrid_mtp_ngram(
self.model_inputs["input_ids"]._copy_to(device, True), self.model_inputs["input_ids_cpu"],
self.input_ids_len,
self.model_inputs["pre_ids"]._copy_to(device, True),
self.model_inputs["step_idx"].cpu(),