mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Speculative Decoding][MTP]Support attn mask offset (#4641)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* [MTP]Merge support attn (#4591) * support mask_offset in speculate decoding * fix dummy run output * add unit test * fix unit test import * support attn_mask_offset in mtp mode * add update_attn_mask op * fix unit test && fix code-style
This commit is contained in:
@@ -46,6 +46,7 @@ from fastdeploy.model_executor.ops.gpu import (
|
||||
share_external_data,
|
||||
speculate_get_logits,
|
||||
speculate_save_output_topk,
|
||||
update_attn_mask_offsets,
|
||||
)
|
||||
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
|
||||
|
||||
@@ -441,6 +442,21 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["cu_next_token_offset"] = paddle.full(
|
||||
shape=[self.max_num_seqs + 1], fill_value=0, dtype="int32"
|
||||
)
|
||||
self.model_inputs["mask_rollback"] = paddle.full([self.max_num_seqs, 1], 0, dtype="int32")
|
||||
# attn_mask
|
||||
if self.enable_mm:
|
||||
self.model_inputs["attn_mask_offsets"] = paddle.full(
|
||||
shape=[self.max_num_seqs * self.max_model_len], fill_value=-1, dtype="int32"
|
||||
)
|
||||
self.model_inputs["attn_mask_offsets_full"] = paddle.full(
|
||||
[self.max_num_seqs, self.max_model_len], -1, dtype="int32"
|
||||
)
|
||||
self.model_inputs["attn_mask_offsets_decoder"] = paddle.full([self.max_num_seqs, 1], -1, dtype="int32")
|
||||
self.model_inputs["decode_states"] = paddle.full(
|
||||
[self.max_num_seqs, self.max_draft_token_num + 1],
|
||||
-1,
|
||||
dtype="int32",
|
||||
)
|
||||
|
||||
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
|
||||
|
||||
@@ -482,6 +498,16 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["step_idx"][idx : idx + 1] = (
|
||||
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
|
||||
)
|
||||
if self.enable_mm:
|
||||
inputs = request.multimodal_inputs
|
||||
self.model_inputs["attn_mask_offsets_full"][idx][0 : prefill_end_index - prefill_start_index] = (
|
||||
paddle.to_tensor(
|
||||
inputs["attention_mask_offset"][prefill_start_index:prefill_end_index], dtype="int32"
|
||||
)
|
||||
)
|
||||
self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = (
|
||||
inputs["attention_mask_offset"][prefill_end_index - 1] + 1
|
||||
)
|
||||
|
||||
# has_prefill_task = True
|
||||
elif request.task_type.value == RequestType.DECODE.value: # decode task
|
||||
@@ -621,6 +647,7 @@ class MTPProposer(Proposer):
|
||||
kv_batch_ids=self.model_inputs["kv_batch_ids"],
|
||||
kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"],
|
||||
kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"],
|
||||
attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.enable_mm else None,
|
||||
)
|
||||
|
||||
# Initialize attention metadata
|
||||
@@ -754,6 +781,21 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
)
|
||||
|
||||
if self.enable_mm:
|
||||
attn_mask_offsets = update_attn_mask_offsets(
|
||||
ids_remove_padding,
|
||||
getattr(self.model_inputs, "seq_lens_this_time", self.seq_lens_this_time_buffer),
|
||||
self.model_inputs["seq_lens_encoder"],
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
cu_seqlens_q,
|
||||
self.model_inputs["attn_mask_offsets_full"],
|
||||
self.model_inputs["attn_mask_offsets_decoder"],
|
||||
self.model_inputs["is_block_step"],
|
||||
self.model_inputs["decode_states"],
|
||||
self.model_inputs["mask_rollback"],
|
||||
)[0]
|
||||
self.model_inputs["attn_mask_offsets"].copy_(attn_mask_offsets, False)
|
||||
|
||||
# Initialize forward meta data
|
||||
self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
|
||||
self.model_inputs["batch_id_per_token"][:] = -1
|
||||
|
||||
Reference in New Issue
Block a user