support mm mtp (#4013)

This commit is contained in:
xiaoxiaohehe001
2025-09-09 13:55:45 +08:00
committed by GitHub
parent c753f1fc9e
commit 5223065d59
11 changed files with 278 additions and 54 deletions

View File

@@ -286,6 +286,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
@@ -309,6 +310,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,

View File

@@ -193,7 +193,8 @@ __global__ void append_speculate_cache_rope_kernel(
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size) {
const int gqa_group_size,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
@@ -253,8 +254,9 @@ __global__ void append_speculate_cache_rope_kernel(
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -476,7 +478,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size) {
const int gqa_group_size,
const bool rope_3d) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -522,8 +525,9 @@ __global__ void append_speculate_cache_int8_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
}
@@ -583,10 +587,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
T scale;
if (head_idx < num_heads + gqa_group_size) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
scale = __ldg(&cache_k_scales[kv_head_idx]);
} else {
scale = __ldg(&cache_v_scales[kv_head_idx]);
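For orientation, a minimal sketch of the rope_3d lookup used in the hunks above. It assumes the cos/sin tables have a per-position stride of head_dim / 2 for 2D rope, and that for 3D rope each batch element owns a slab of max_seq_len * head_dim entries, matching the offsets in the kernels; the helper name and layout are assumptions, not part of the commit:

// Sketch only: mirrors emb_idx / new_emb_idx from the kernels above.
// batch_id corresponds to ori_bi / bid, pos to write_seq_id, dim_offset to h_bias / 2.
__device__ __forceinline__ int64_t rope_emb_index(const bool rope_3d,
                                                  const int batch_id,
                                                  const int pos,
                                                  const int dim_offset,
                                                  const int max_seq_len,
                                                  const int head_dim) {
  const int64_t emb_idx = static_cast<int64_t>(pos) * (head_dim / 2) + dim_offset;
  // For 3D rope, skip over the slabs belonging to earlier batch elements.
  return rope_3d ? emb_idx + static_cast<int64_t>(batch_id) * max_seq_len * head_dim
                 : emb_idx;
}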

View File

@@ -39,7 +39,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style) {
const bool use_neox_style,
const bool rope_3d) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
@@ -96,7 +97,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads);
kv_num_heads,
rope_3d);
}
}
@@ -125,7 +127,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style) {
const bool use_neox_style,
const bool rope_3d) {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -191,7 +194,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads);
kv_num_heads,
rope_3d);
}
}
@@ -313,6 +317,7 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
@@ -368,7 +373,8 @@ void SpeculateWriteCacheWithRoPEKernel(
bsz,
token_nums,
stream,
use_neox_rotary_style);
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -401,7 +407,8 @@ void SpeculateWriteCacheWithRoPEKernel(
bsz,
token_nums,
stream,
use_neox_rotary_style);
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -434,7 +441,8 @@ void SpeculateWriteCacheWithRoPEKernel(
bsz,
token_nums,
stream,
use_neox_rotary_style);
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -500,6 +508,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
@@ -526,6 +535,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
@@ -551,6 +561,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
@@ -578,6 +589,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,

View File

@@ -35,6 +35,7 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,

View File

@@ -335,6 +335,19 @@ void TextImageIndexOut(const paddle::Tensor &token_type_ids,
const paddle::Tensor &text_input,
const paddle::Tensor &image_input);
void LimitContentLen(const paddle::Tensor& next_tokens,
const paddle::Tensor& end_thinking_tokens,
const paddle::Tensor& max_content_len,
const paddle::Tensor& max_think_len,
const paddle::Tensor& step_idx,
const paddle::Tensor& eos_token_ids,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& limit_content_status,
const paddle::Tensor& enable_thinking,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags);
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
paddle::Tensor &image_input,
paddle::Tensor &token_type_ids,

View File

@@ -0,0 +1,186 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
__global__ void limit_content_len(
int64_t* next_tokens,
const int64_t* end_thinking_tokens,
int* max_content_lens,
const int* max_think_lens,
int64_t* step_idx,
const int64_t* eos_token_ids,
int64_t* max_dec_lens,
int* limit_content_status,
const bool* enable_thinking,
int* accept_num,
int* seq_lens_decoder,
bool* stop_flags,
const int tokens_per_step,
const int bs,
const int end_thinking_token_num,
const int eos_token_id_len) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= bs) return;
if (!enable_thinking[idx]) return;
const int original_accept_num = accept_num[idx];
if (original_accept_num <= 0) return;
int current_limit_content_status = limit_content_status[idx];
// If already in the answering phase and the stop flag is set, return early; nothing left to do.
if (current_limit_content_status == 2 && stop_flags[idx]) {
return;
}
const int max_think_len_reg = max_think_lens[idx];
const int64_t end_thinking_token_reg = end_thinking_tokens[0];
int64_t current_max_dec_len = max_dec_lens[idx];
int new_accept_num = original_accept_num;
const int64_t current_base_step = step_idx[idx] - original_accept_num + 1;
for (int token_offset = 0; token_offset < original_accept_num; token_offset++) {
const int token_idx = idx * tokens_per_step + token_offset;
int64_t next_token_reg = next_tokens[token_idx];
const int64_t current_step = current_base_step + token_offset;
bool condition_triggered = false;
bool is_eos = false;
// ======================= Thinking-phase control =======================
// Stage 1: still thinking (status == 0); check whether thinking must be force-ended
if (current_limit_content_status < 1) {
bool should_transform = false;
// When a thinking-length budget is enabled, check whether it has been exceeded
if (max_think_len_reg > 0 && current_step >= max_think_len_reg) {
should_transform = true;
} else {
// Check whether an EOS token was generated
for (int j = 0; j < eos_token_id_len; j++) {
if (eos_token_ids[j] == next_token_reg) {
is_eos = true;
should_transform = true;
break;
}
}
}
if (should_transform) {
// Force-replace the current token with the end-of-thinking token
next_token_reg = end_thinking_token_reg;
// Advance the status to 1, meaning "ending thinking"
current_limit_content_status = 1;
condition_triggered = true;  // the token was modified, so the remaining draft tokens must be truncated
// Clear stop_flags only when the transition was triggered by EOS
if (is_eos && stop_flags[idx]) {
stop_flags[idx] = false;
}
}
}
// ======================= End-of-thinking handling =======================
// Stage 2: check whether the end-of-thinking condition has been met (status < 2)
// This covers two cases:
// 1. status == 0: the model generated end_thinking_token on its own
// 2. status == 1: end_thinking_token was force-injected in the previous stage
if (current_limit_content_status < 2) {
if (next_token_reg == end_thinking_token_reg) {
// Thinking is confirmed finished; advance the status to 2 (answering phase)
current_limit_content_status = 2;
}
}
next_tokens[token_idx] = next_token_reg;
if (condition_triggered) {
new_accept_num = token_offset + 1;
break;
}
}
// Update the per-request global state
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[idx] -= discarded_tokens;
seq_lens_decoder[idx] -= discarded_tokens;
}
accept_num[idx] = new_accept_num;
limit_content_status[idx] = current_limit_content_status;
max_dec_lens[idx] = current_max_dec_len;
}
void LimitContentLen(const paddle::Tensor& next_tokens,
const paddle::Tensor& end_thinking_tokens,
const paddle::Tensor& max_content_len,
const paddle::Tensor& max_think_len,
const paddle::Tensor& step_idx,
const paddle::Tensor& eos_token_ids,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& limit_content_status,
const paddle::Tensor& enable_thinking,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags) {
const int batch_size = next_tokens.shape()[0];
const int tokens_per_step = next_tokens.shape()[1];
const int end_thinking_token_num = end_thinking_tokens.shape()[0];
const int end_length = eos_token_ids.shape()[0];
PD_CHECK(end_thinking_token_num == 1, "limit_content_len only supports end_thinking_token_num = 1 for now.");
dim3 grid(1);
dim3 block(1024);
limit_content_len<<<grid, block>>>(
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_thinking_tokens.data<int64_t>(),
const_cast<int *>(max_content_len.data<int>()),
max_think_len.data<int>(),
const_cast<int64_t *>(step_idx.data<int64_t>()),
eos_token_ids.data<int64_t>(),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
const_cast<int *>(limit_content_status.data<int>()),
enable_thinking.data<bool>(),
const_cast<int *>(accept_num.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
tokens_per_step,
batch_size,
end_thinking_token_num,
end_length);
}
PD_BUILD_STATIC_OP(limit_content_len)
.Inputs({"next_tokens",
"end_thinking_tokens",
"max_content_len",
"max_think_len",
"step_idx",
"eos_token_ids",
"max_dec_len",
"limit_content_status",
"enable_thinking",
"accept_num",
"seq_lens_decoder",
"stop_flags"})
.Outputs({"next_tokens_out", "max_dec_len_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"},
{"max_dec_len", "max_dec_len_out"}})
.SetKernelFn(PD_KERNEL(LimitContentLen));
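To make the control flow above easier to follow, here is a host-side reference sketch of the per-token state machine the kernel applies to each accepted draft token of one request. The helper and its names are hypothetical, stop_flags bookkeeping is omitted, and this is not part of the commit:

// Reference sketch, assuming status 0 = thinking, 1 = end token just forced, 2 = answering.
#include <cstdint>

struct ThinkState { int status = 0; };

// Returns true when the remaining draft tokens for this request must be discarded.
inline bool limit_one_token(ThinkState& st, int64_t& token, int64_t step,
                            int64_t end_thinking_token, int max_think_len,
                            const int64_t* eos_ids, int num_eos) {
  bool truncate = false;
  if (st.status < 1) {
    // Force-end thinking when the budget is exhausted or the model emits EOS early.
    bool force_end = (max_think_len > 0 && step >= max_think_len);
    for (int j = 0; !force_end && j < num_eos; ++j) force_end = (eos_ids[j] == token);
    if (force_end) {
      token = end_thinking_token;  // overwrite the accepted token
      st.status = 1;
      truncate = true;             // later draft tokens are no longer valid
    }
  }
  // Seeing the end-of-thinking token (forced or model-generated) enters the answer phase.
  if (st.status < 2 && token == end_thinking_token) st.status = 2;
  return truncate;
}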

View File

@@ -293,6 +293,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/custom_all_reduce/all_reduce.cu",
"gpu_ops/limit_content_len.cu",
]
# pd_disaggregation

View File

@@ -33,6 +33,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
min_p_sampling,
top_k_top_p_sampling,
)
from fastdeploy.model_executor.ops.gpu import limit_content_len
from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
@@ -304,6 +305,7 @@ class SpeculativeSampler(nn.Layer):
self.speculative_verify_window = fd_config.speculative_config.verify_window
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
self.fd_config = fd_config
def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running"""
@@ -382,6 +384,22 @@ class SpeculativeSampler(nn.Layer):
self.speculative_benchmark_mode,
)
if hasattr(self.fd_config.model_config, "think_end_id") and self.fd_config.model_config.think_end_id > 0:
limit_content_len(
share_inputs["accept_tokens"],
self.fd_config.model_config.think_end_id,
share_inputs["max_content_len"],
share_inputs["max_think_len"],
share_inputs["step_idx"],
sampling_metadata.eos_token_ids,
share_inputs["max_dec_len"],
share_inputs["limit_content_status"],
share_inputs["enable_thinking"],
share_inputs["accept_num"],
share_inputs["seq_lens_decoder"],
share_inputs["stop_flags"],
)
return None

View File

@@ -46,7 +46,6 @@ from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
extract_text_token_output,
text_image_gather_scatter,
text_image_index_out,
)
@@ -472,26 +471,6 @@ class Ernie4_5_VLModel(nn.Layer):
)
hidden_states = hidden_states + residual
# -----------------------
hidden_states = hidden_states.cast("float32")
score_text = hidden_states
if image_input is not None:
token_type_ids = token_type_ids.reshape([-1])
text_pos_shifted = token_type_ids[:token_num] == 0
score_text = hidden_states[text_pos_shifted.reshape([-1])]
max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time.squeeze(-1), k=1)
hidden_states = extract_text_token_output(
max_seq_len,
max_seq_len_index.cast("int32"),
image_token_num.cast("int32"),
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
score_text,
).cast(self._dtype)
# -----------------------
out = self.norm(hidden_states)
return out

View File

@@ -269,12 +269,19 @@ class MTPProposer(Proposer):
)
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
if len(self.main_model_inputs["rope_emb"].shape) == 5:
self.model_inputs["rope_emb"] = get_rope(
rotary_dim=self.model_config.head_dim,
position_ids=tmp_position_ids,
base=self.model_config.rope_theta,
model_config=self.model_config,
)
else:
self.model_inputs["max_content_len"] = paddle.clone(self.main_model_inputs["max_content_len"])
self.model_inputs["max_think_len"] = paddle.clone(self.main_model_inputs["max_think_len"])
self.model_inputs["limit_content_status"] = paddle.clone(self.main_model_inputs["limit_content_status"])
self.model_inputs["enable_thinking"] = paddle.clone(self.main_model_inputs["enable_thinking"])
self.model_inputs["rope_emb"] = paddle.clone(self.main_model_inputs["rope_emb"])
# self.model_inputs["caches"] = self.cache_kvs
# Inherit generation hyperparameters from the main model for consistency
self.model_inputs["top_p"] = self.main_model_inputs["top_p"]
@@ -294,6 +301,7 @@ class MTPProposer(Proposer):
# Integrate the updated results in model forward
self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
self.model_inputs["substep"] = 0
self.max_num_seqs = self.main_model_inputs["draft_tokens"].shape[0]
# Input tokens
self.model_inputs["draft_tokens"] = paddle.full(

View File

@@ -1210,7 +1210,6 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["image_features"],
self.forward_meta,
)
hidden_states = model_output
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],