optimize w4a8 decoding (#3050)

This commit is contained in:
Yuan Xiaolan
2025-07-28 22:20:13 +08:00
committed by GitHub
parent e80ea8a71b
commit 7d87aaace8
6 changed files with 253 additions and 36 deletions

View File

@@ -240,6 +240,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
MoeFastHardamardWrapper<data_t, int8_t>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
down_proj_shift, // down_proj_shift->data<T>(),
down_proj_smooth, // down_proj_smooth->data<T>(),
down_proj_in_scale ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())->data<float>() : nullptr,
@@ -248,6 +249,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
-127.0,
expanded_active_expert_rows,
inter_size / 2,
num_max_tokens_per_expert,
used_in_ep_low_latency,
reinterpret_cast<int8_t *>(int8_act_out->ptr()),
stream
);