mirror of https://github.com/PaddlePaddle/FastDeploy.git
support mm mtp (#4013)
@@ -286,6 +286,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         cache_v_zp,
         cache_quant_type_str,
         use_neox_rotary_style,
+        rope_3d,
         max_input_length,
         exec_stream,
         &qkv_out,
@@ -309,6 +310,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         cache_v_zp,
         cache_quant_type_str,
         use_neox_rotary_style,
+        rope_3d,
         max_input_length,
         exec_stream,
         &qkv_out,
@@ -193,7 +193,8 @@ __global__ void append_speculate_cache_rope_kernel(
     const int head_size,
     const int block_size,
     const int elem_cnt,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   using LoadT = AlignedVector<T, VecSize>;
   using LoadFloat = AlignedVector<float, VecSize>;
   using LoadInT = AlignedVector<InT, VecSize>;
@@ -253,8 +254,9 @@ __global__ void append_speculate_cache_rope_kernel(
     if (hi < num_heads + gqa_group_size) {
       // q k rope
       const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
-      Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-      Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+      int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
+      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
     }
 #pragma unroll
     for (int i = 0; i < HalfVecSize; i++) {
@@ -476,7 +478,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     const int block_size,
     const float max_bound,
     const float min_bound,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
   static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
   constexpr int NUM_WARPS = 4;
@@ -522,8 +525,9 @@ __global__ void append_speculate_cache_int8_rope_kernel(

     // q rope
     const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-    Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-    Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+    Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+    Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
     if (qkv_out_scales) {
       Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
     }
@@ -583,10 +587,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     T scale;
     if (head_idx < num_heads + gqa_group_size) {
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
-      Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
-      Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
-      Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
+      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
+      Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
+      Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
+      Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
       scale = __ldg(&cache_k_scales[kv_head_idx]);
     } else {
       scale = __ldg(&cache_v_scales[kv_head_idx]);
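The rope_3d path in the kernels above only changes how the cos/sin tables are indexed: when rope_3d is enabled the tables are assumed to hold one [max_seq_len, head_dim] slice per batch entry, so the kernel adds a per-batch offset before loading. A tiny illustrative helper (hypothetical name, not part of this commit) expressing the same indexing:

def rope_emb_index(emb_idx, batch_idx, max_seq_len, head_dim, rope_3d):
    # rope_3d: cos/sin laid out as a flattened [bsz, max_seq_len, head_dim];
    # otherwise a single [max_seq_len, head_dim] table shared by the batch.
    return emb_idx + batch_idx * max_seq_len * head_dim if rope_3d else emb_idx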
@@ -39,7 +39,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
                                  const int bsz,
                                  const int token_num,
                                  const cudaStream_t& stream,
-                                 const bool use_neox_style) {
+                                 const bool use_neox_style,
+                                 const bool rope_3d) {
   int output_inner_dim = num_heads + 2 * kv_num_heads;

   const uint32_t elem_nums =
@@ -96,7 +97,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
         dim_head,
         block_size,
         elem_nums,
-        kv_num_heads);
+        kv_num_heads,
+        rope_3d);
   }
 }

@@ -125,7 +127,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
                                       const int bsz,
                                       const int token_num,
                                       const cudaStream_t& stream,
-                                      const bool use_neox_style) {
+                                      const bool use_neox_style,
+                                      const bool rope_3d) {
   constexpr int num_warps = 4;
   const int all_warps =
       ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -191,7 +194,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
         block_size,
         127.0f,
         -127.0f,
-        kv_num_heads);
+        kv_num_heads,
+        rope_3d);
   }
 }

@@ -313,6 +317,7 @@ void SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -368,7 +373,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else if (cache_quant_type_str == "cache_int8") {
     append_speculate_cache_int8_rope(
         reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -401,7 +407,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else if (cache_quant_type_str == "cache_fp8") {
     append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
         reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -434,7 +441,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else if (cache_quant_type_str == "cache_int4_zp") {
     append_speculate_cache_int4_rope(
         reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -500,6 +508,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -526,6 +535,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -551,6 +561,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -578,6 +589,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -35,6 +35,7 @@ void SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -335,6 +335,19 @@ void TextImageIndexOut(const paddle::Tensor &token_type_ids,
                        const paddle::Tensor &text_input,
                        const paddle::Tensor &image_input);

+void LimitContentLen(const paddle::Tensor& next_tokens,
+                     const paddle::Tensor& end_thinking_tokens,
+                     const paddle::Tensor& max_content_len,
+                     const paddle::Tensor& max_think_len,
+                     const paddle::Tensor& step_idx,
+                     const paddle::Tensor& eos_token_ids,
+                     const paddle::Tensor& max_dec_len,
+                     const paddle::Tensor& limit_content_status,
+                     const paddle::Tensor& enable_thinking,
+                     const paddle::Tensor& accept_num,
+                     const paddle::Tensor& seq_lens_decoder,
+                     const paddle::Tensor& stop_flags);
+
 void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
                             paddle::Tensor &image_input,
                             paddle::Tensor &token_type_ids,
custom_ops/gpu_ops/limit_content_len.cu (new file, 186 lines)
@@ -0,0 +1,186 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "helper.h"
+
+__global__ void limit_content_len(
+    int64_t* next_tokens,
+    const int64_t* end_thinking_tokens,
+    int* max_content_lens,
+    const int* max_think_lens,
+    int64_t* step_idx,
+    const int64_t* eos_token_ids,
+    int64_t* max_dec_lens,
+    int* limit_content_status,
+    const bool* enable_thinking,
+    int* accept_num,
+    int* seq_lens_decoder,
+    bool* stop_flags,
+    const int tokens_per_step,
+    const int bs,
+    const int end_thinking_token_num,
+    const int eos_token_id_len) {
+
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= bs) return;
+
+  if (!enable_thinking[idx]) return;
+
+  const int original_accept_num = accept_num[idx];
+  if (original_accept_num <= 0) return;
+
+  int current_limit_content_status = limit_content_status[idx];
+  // If already in the response phase and the stop flag is set, return early; nothing more to do.
+  if (current_limit_content_status == 2 && stop_flags[idx]) {
+    return;
+  }
+
+  const int max_think_len_reg = max_think_lens[idx];
+
+  const int64_t end_thinking_token_reg = end_thinking_tokens[0];
+
+  int64_t current_max_dec_len = max_dec_lens[idx];
+  int new_accept_num = original_accept_num;
+
+  const int64_t current_base_step = step_idx[idx] - original_accept_num + 1;
+
+  for (int token_offset = 0; token_offset < original_accept_num; token_offset++) {
+    const int token_idx = idx * tokens_per_step + token_offset;
+    int64_t next_token_reg = next_tokens[token_idx];
+    const int64_t current_step = current_base_step + token_offset;
+
+    bool condition_triggered = false;
+    bool is_eos = false;
+
+    // ======================= Thinking-phase control =======================
+    // Stage 1: still thinking (status == 0); check whether thinking must be force-ended.
+    if (current_limit_content_status < 1) {
+      bool should_transform = false;
+
+      // When the thinking-length limit is enabled, check whether it has been exceeded.
+      if (max_think_len_reg > 0 && current_step >= max_think_len_reg) {
+        should_transform = true;
+      } else {
+        // Check whether an EOS token was generated.
+        for (int j = 0; j < eos_token_id_len; j++) {
+          if (eos_token_ids[j] == next_token_reg) {
+            is_eos = true;
+            should_transform = true;
+            break;
+          }
+        }
+      }
+
+      if (should_transform) {
+        // Force-replace the current token with the end-of-thinking token.
+        next_token_reg = end_thinking_token_reg;
+        // Advance the status to 1, i.e. "ending thinking".
+        current_limit_content_status = 1;
+        condition_triggered = true;  // The token was modified, so the remaining draft tokens must be truncated.
+        // Clear stop_flags only when the transition was triggered by EOS.
+        if (is_eos && stop_flags[idx]) {
+          stop_flags[idx] = false;
+        }
+      }
+    }
+
+    // ======================= End-of-thinking handling =======================
+    // Stage 2: check whether the end-of-thinking condition has been met (status < 2).
+    // This covers two cases:
+    // 1. status == 0: the model generated the end_thinking_token on its own.
+    // 2. status == 1: the end_thinking_token was force-injected in the previous stage.
+    if (current_limit_content_status < 2) {
+      if (next_token_reg == end_thinking_token_reg) {
+        // Thinking has finished; advance the status to 2 (response phase).
+        current_limit_content_status = 2;
+      }
+    }
+
+    next_tokens[token_idx] = next_token_reg;
+
+    if (condition_triggered) {
+      new_accept_num = token_offset + 1;
+      break;
+    }
+  }
+
+  // Update the global state.
+  int discarded_tokens = original_accept_num - new_accept_num;
+  if (discarded_tokens > 0) {
+    step_idx[idx] -= discarded_tokens;
+    seq_lens_decoder[idx] -= discarded_tokens;
+  }
+
+  accept_num[idx] = new_accept_num;
+  limit_content_status[idx] = current_limit_content_status;
+  max_dec_lens[idx] = current_max_dec_len;
+}
+
+void LimitContentLen(const paddle::Tensor& next_tokens,
+                     const paddle::Tensor& end_thinking_tokens,
+                     const paddle::Tensor& max_content_len,
+                     const paddle::Tensor& max_think_len,
+                     const paddle::Tensor& step_idx,
+                     const paddle::Tensor& eos_token_ids,
+                     const paddle::Tensor& max_dec_len,
+                     const paddle::Tensor& limit_content_status,
+                     const paddle::Tensor& enable_thinking,
+                     const paddle::Tensor& accept_num,
+                     const paddle::Tensor& seq_lens_decoder,
+                     const paddle::Tensor& stop_flags) {
+
+  const int batch_size = next_tokens.shape()[0];
+  const int tokens_per_step = next_tokens.shape()[1];
+  const int end_thinking_token_num = end_thinking_tokens.shape()[0];
+  const int end_length = eos_token_ids.shape()[0];
+  PD_CHECK(end_thinking_token_num == 1, "limit_content_len only support end_thinking_token_num = 1 for now.");
+
+  dim3 grid(1);
+  dim3 block(1024);
+
+  limit_content_len<<<grid, block>>>(
+      const_cast<int64_t *>(next_tokens.data<int64_t>()),
+      end_thinking_tokens.data<int64_t>(),
+      const_cast<int *>(max_content_len.data<int>()),
+      max_think_len.data<int>(),
+      const_cast<int64_t *>(step_idx.data<int64_t>()),
+      eos_token_ids.data<int64_t>(),
+      const_cast<int64_t *>(max_dec_len.data<int64_t>()),
+      const_cast<int *>(limit_content_status.data<int>()),
+      enable_thinking.data<bool>(),
+      const_cast<int *>(accept_num.data<int>()),
+      const_cast<int *>(seq_lens_decoder.data<int>()),
+      const_cast<bool *>(stop_flags.data<bool>()),
+      tokens_per_step,
+      batch_size,
+      end_thinking_token_num,
+      end_length);
+}
+PD_BUILD_STATIC_OP(limit_content_len)
+    .Inputs({"next_tokens",
+             "end_thinking_tokens",
+             "max_content_len",
+             "max_think_len",
+             "step_idx",
+             "eos_token_ids",
+             "max_dec_len",
+             "limit_content_status",
+             "enable_thinking",
+             "accept_num",
+             "seq_lens_decoder",
+             "stop_flags"})
+    .Outputs({"next_tokens_out", "max_dec_len_out"})
+    .SetInplaceMap({{"next_tokens", "next_tokens_out"},
+                    {"max_dec_len", "max_dec_len_out"}})
+    .SetKernelFn(PD_KERNEL(LimitContentLen));
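The kernel above walks the tokens accepted in one speculative step and drives a small per-sequence state machine: status 0 means the request is still thinking, 1 means an end-of-thinking token has just been forced in, and 2 means the response phase has begun. A minimal pure-Python sketch of that per-sequence logic (hypothetical helper name; assumes a single end-thinking token, as enforced by the PD_CHECK, and omits the seq_lens_decoder/max_dec_len bookkeeping):

def limit_content_len_ref(tokens, status, step_idx, max_think_len,
                          end_thinking_token, eos_token_ids, stop_flag):
    """Reference for one sequence: returns (tokens, status, step_idx, stop_flag)."""
    accepted = len(tokens)
    if status == 2 and stop_flag:
        return tokens, status, step_idx, stop_flag  # already responding and stopped
    base_step = step_idx - accepted + 1
    new_accept = accepted
    for i, tok in enumerate(tokens):
        truncated = False
        if status < 1:  # still thinking: check the think budget and EOS
            is_eos = tok in eos_token_ids
            if (max_think_len > 0 and base_step + i >= max_think_len) or is_eos:
                tok = end_thinking_token      # force the end-of-thinking token
                status = 1
                truncated = True              # later draft tokens will be discarded
                if is_eos:
                    stop_flag = False         # EOS was swallowed, keep generating
        if status < 2 and tok == end_thinking_token:
            status = 2                        # enter the response phase
        tokens[i] = tok
        if truncated:
            new_accept = i + 1
            break
    step_idx -= accepted - new_accept         # roll back discarded draft tokens
    return tokens[:new_accept], status, step_idx, stop_flag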
@@ -293,6 +293,7 @@ elif paddle.is_compiled_with_cuda():
         "gpu_ops/fused_rotary_position_encoding.cu",
         "gpu_ops/noaux_tc.cu",
         "gpu_ops/custom_all_reduce/all_reduce.cu",
+        "gpu_ops/limit_content_len.cu",
     ]

     # pd_disaggregation
@@ -33,6 +33,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
     min_p_sampling,
     top_k_top_p_sampling,
 )
+from fastdeploy.model_executor.ops.gpu import limit_content_len
 from fastdeploy.platforms import current_platform
 from fastdeploy.worker.output import LogprobsTensors, SamplerOutput

@@ -304,6 +305,7 @@ class SpeculativeSampler(nn.Layer):
         self.speculative_verify_window = fd_config.speculative_config.verify_window
         self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
         self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
+        self.fd_config = fd_config

     def pre_process(self, skip_idx_list: List[int] = []):
         """pre process before running"""
@@ -382,6 +384,22 @@ class SpeculativeSampler(nn.Layer):
             self.speculative_benchmark_mode,
         )

+        if hasattr(self.fd_config.model_config, "think_end_id") and self.fd_config.model_config.think_end_id > 0:
+            limit_content_len(
+                share_inputs["accept_tokens"],
+                self.fd_config.model_config.think_end_id,
+                share_inputs["max_content_len"],
+                share_inputs["max_think_len"],
+                share_inputs["step_idx"],
+                sampling_metadata.eos_token_ids,
+                share_inputs["max_dec_len"],
+                share_inputs["limit_content_status"],
+                share_inputs["enable_thinking"],
+                share_inputs["accept_num"],
+                share_inputs["seq_lens_decoder"],
+                share_inputs["stop_flags"],
+            )
+
         return None

@@ -46,7 +46,6 @@ from fastdeploy.platforms import current_platform

 if current_platform.is_cuda():
     from fastdeploy.model_executor.ops.gpu import (
-        extract_text_token_output,
         text_image_gather_scatter,
         text_image_index_out,
     )
@@ -472,26 +471,6 @@ class Ernie4_5_VLModel(nn.Layer):
                 )

             hidden_states = hidden_states + residual

-        # -----------------------
-        hidden_states = hidden_states.cast("float32")
-        score_text = hidden_states
-
-        if image_input is not None:
-            token_type_ids = token_type_ids.reshape([-1])
-            text_pos_shifted = token_type_ids[:token_num] == 0
-            score_text = hidden_states[text_pos_shifted.reshape([-1])]
-        max_seq_len, max_seq_len_index = paddle.topk(forward_meta.seq_lens_this_time.squeeze(-1), k=1)
-        hidden_states = extract_text_token_output(
-            max_seq_len,
-            max_seq_len_index.cast("int32"),
-            image_token_num.cast("int32"),
-            forward_meta.seq_lens_this_time,
-            forward_meta.cu_seqlens_q,
-            score_text,
-        ).cast(self._dtype)
-        # -----------------------
-
         out = self.norm(hidden_states)

         return out
@@ -269,12 +269,19 @@ class MTPProposer(Proposer):
         )

         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
-        self.model_inputs["rope_emb"] = get_rope(
-            rotary_dim=self.model_config.head_dim,
-            position_ids=tmp_position_ids,
-            base=self.model_config.rope_theta,
-            model_config=self.model_config,
-        )
+        if len(self.main_model_inputs["rope_emb"].shape) == 5:
+            self.model_inputs["rope_emb"] = get_rope(
+                rotary_dim=self.model_config.head_dim,
+                position_ids=tmp_position_ids,
+                base=self.model_config.rope_theta,
+                model_config=self.model_config,
+            )
+        else:
+            self.model_inputs["max_content_len"] = paddle.clone(self.main_model_inputs["max_content_len"])
+            self.model_inputs["max_think_len"] = paddle.clone(self.main_model_inputs["max_think_len"])
+            self.model_inputs["limit_content_status"] = paddle.clone(self.main_model_inputs["limit_content_status"])
+            self.model_inputs["enable_thinking"] = paddle.clone(self.main_model_inputs["enable_thinking"])
+            self.model_inputs["rope_emb"] = paddle.clone(self.main_model_inputs["rope_emb"])
         # self.model_inputs["caches"] = self.cache_kvs
         # Inherit generation hyperparameters from the main model for consistency
         self.model_inputs["top_p"] = self.main_model_inputs["top_p"]
@@ -294,6 +301,7 @@ class MTPProposer(Proposer):
         # Integrate the updated results in model forward
         self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
         self.model_inputs["substep"] = 0
+        self.max_num_seqs = self.main_model_inputs["draft_tokens"].shape[0]

         # Input tokens
         self.model_inputs["draft_tokens"] = paddle.full(
@@ -1210,7 +1210,6 @@ class GPUModelRunner(ModelRunnerBase):
                     self.share_inputs["image_features"],
                     self.forward_meta,
                 )
-                hidden_states = model_output
             else:
                 model_output = self.model(
                     ids_remove_padding=self.share_inputs["ids_remove_padding"],