mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Quantization] Support w4afp8 MoE dynamic quantization (#5282)
* support dynamic activation quant for w4afp8
* support dynamic w4afp8
* add test
* fix
* fix

---------

Co-authored-by: zhoutianzi666 <17801055074@163.com>
This commit is contained in:
@@ -430,7 +430,9 @@ __global__ void permute_x_kernel(
|
||||
}
|
||||
abs_max = phi::BlockAllReduce<MaxOp, float, Kthread>(abs_max);
|
||||
float scale = 440.f / abs_max; // use 440 so we do not have to clip
|
||||
dequant_scale[dst_token_idx] = abs_max;
|
||||
if (tid == 0) {
|
||||
dequant_scale[dst_token_idx] = abs_max;
|
||||
}
|
||||
for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) {
|
||||
Load<T, vec_size>(&data_smem[v_id * vec_size], &src_vec);
|
||||
#pragma unroll
|
||||
@@ -661,7 +663,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
|
||||
|
||||
int dequant_scale_size = 1;
|
||||
if (moe_quant_type == "w4afp8" && !up_gate_proj_in_scale) {
|
||||
dequant_scale_size = moe_topk * num_rows;
|
||||
dequant_scale_size = token_nums_this_rank;
|
||||
}
|
||||
|
||||
auto dequant_scale =
|
||||
|
||||
Reference in New Issue
Block a user