mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Feature] DeepseekV3 use pd_build_static_op (#2948)
Co-authored-by: K11OntheBoat <ruianmaidanglao@163.com>
This commit is contained in:
@@ -13,6 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "helper.h"
|
||||||
#include "mla_cache_kernel.cuh"
|
#include "mla_cache_kernel.cuh"
|
||||||
|
|
||||||
template <paddle::DataType T>
|
template <paddle::DataType T>
|
||||||
@@ -259,7 +260,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
PD_BUILD_OP(prefill_mla_write_cache)
|
PD_BUILD_STATIC_OP(prefill_mla_write_cache)
|
||||||
.Inputs({"kv_nope",
|
.Inputs({"kv_nope",
|
||||||
"kv_pe",
|
"kv_pe",
|
||||||
"kv_cache",
|
"kv_cache",
|
||||||
@@ -274,7 +275,7 @@ PD_BUILD_OP(prefill_mla_write_cache)
|
|||||||
"max_seq_len: int"})
|
"max_seq_len: int"})
|
||||||
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
|
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
|
||||||
|
|
||||||
PD_BUILD_OP(decode_mla_write_cache)
|
PD_BUILD_STATIC_OP(decode_mla_write_cache)
|
||||||
.Inputs({"kv_nope",
|
.Inputs({"kv_nope",
|
||||||
"kv_pe",
|
"kv_pe",
|
||||||
"kv_cache",
|
"kv_cache",
|
||||||
|
@@ -15,6 +15,7 @@
|
|||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
|
|
||||||
|
|
||||||
template <typename T, bool IS_NEOX>
|
template <typename T, bool IS_NEOX>
|
||||||
inline __device__ void apply_token_rotary_embedding_kernel(
|
inline __device__ void apply_token_rotary_embedding_kernel(
|
||||||
T* __restrict__ arr,
|
T* __restrict__ arr,
|
||||||
@@ -138,7 +139,7 @@ void FusedRotaryPositionEncoding(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_OP(fused_rotary_position_encoding)
|
PD_BUILD_STATIC_OP(fused_rotary_position_encoding)
|
||||||
.Inputs({"query", "key", "position_ids", "cos_sin_cache"})
|
.Inputs({"query", "key", "position_ids", "cos_sin_cache"})
|
||||||
.Outputs({"query_out", "key_out"})
|
.Outputs({"query_out", "key_out"})
|
||||||
.Attrs({"head_size: int", "is_neox: bool"})
|
.Attrs({"head_size: int", "is_neox: bool"})
|
||||||
|
@@ -15,6 +15,7 @@
|
|||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
|
|
||||||
|
|
||||||
__global__ void GetPositionIdsAndMaskEncoderBatchKernel(
|
__global__ void GetPositionIdsAndMaskEncoderBatchKernel(
|
||||||
const int* seq_lens_encoder, // [bsz] 每个批次的 encoder 长度
|
const int* seq_lens_encoder, // [bsz] 每个批次的 encoder 长度
|
||||||
const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度
|
const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度
|
||||||
@@ -74,7 +75,7 @@ void GetPositionIdsAndMaskEncoderBatch(
|
|||||||
bsz);
|
bsz);
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_OP(get_position_ids_and_mask_encoder_batch)
|
PD_BUILD_STATIC_OP(get_position_ids_and_mask_encoder_batch)
|
||||||
.Inputs({"seq_lens_encoder",
|
.Inputs({"seq_lens_encoder",
|
||||||
"seq_lens_decoder",
|
"seq_lens_decoder",
|
||||||
"seq_lens_this_time",
|
"seq_lens_this_time",
|
||||||
|
@@ -12,9 +12,9 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "helper.h"
|
||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
|
|
||||||
|
|
||||||
#define CEILDIV(a,b) (((a+b-1)/b))
|
#define CEILDIV(a,b) (((a+b-1)/b))
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
@@ -189,7 +189,7 @@ std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& to
|
|||||||
return {sorted_ids, expert_ids, num_tokens_post_pad};
|
return {sorted_ids, expert_ids, num_tokens_post_pad};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_OP(tritonmoe_preprocess)
|
PD_BUILD_STATIC_OP(tritonmoe_preprocess)
|
||||||
.Inputs({"topk_ids"})
|
.Inputs({"topk_ids"})
|
||||||
.Attrs({"num_experts: int64_t", "GEMM_BLOCK_SIZE_M: int64_t"})
|
.Attrs({"num_experts: int64_t", "GEMM_BLOCK_SIZE_M: int64_t"})
|
||||||
.Outputs({"sorted_ids", "expert_ids", "num_tokens_post_pad"})
|
.Outputs({"sorted_ids", "expert_ids", "num_tokens_post_pad"})
|
||||||
|
@@ -13,6 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "append_attn/multi_head_latent_attention_kernel.h"
|
#include "append_attn/multi_head_latent_attention_kernel.h"
|
||||||
|
#include "helper.h"
|
||||||
#include "mla_attn/batch_mla_with_paged_kv_cache.h"
|
#include "mla_attn/batch_mla_with_paged_kv_cache.h"
|
||||||
|
|
||||||
template <paddle::DataType D>
|
template <paddle::DataType D>
|
||||||
@@ -410,7 +411,7 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_OP(multi_head_latent_attention)
|
PD_BUILD_STATIC_OP(multi_head_latent_attention)
|
||||||
.Inputs({"query",
|
.Inputs({"query",
|
||||||
"key_cache",
|
"key_cache",
|
||||||
"value_cache",
|
"value_cache",
|
||||||
|
@@ -18,6 +18,7 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
|
#include "helper.h"
|
||||||
#include "noauxtc_kernel.h"
|
#include "noauxtc_kernel.h"
|
||||||
|
|
||||||
std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
||||||
@@ -60,7 +61,7 @@ std::vector<std::vector<int64_t>> NoauxTcInferShape(
|
|||||||
return {scores_shape};
|
return {scores_shape};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_OP(noaux_tc)
|
PD_BUILD_STATIC_OP(noaux_tc)
|
||||||
.Inputs({"scores", "scores_with_bias"})
|
.Inputs({"scores", "scores_with_bias"})
|
||||||
.Outputs({"output_tensor"})
|
.Outputs({"output_tensor"})
|
||||||
.Attrs({"n_group: int",
|
.Attrs({"n_group: int",
|
||||||
|
Reference in New Issue
Block a user