mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-16 21:51:31 +08:00
fix ep wint8 (#4102)
This commit is contained in:
@@ -448,137 +448,71 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
|||||||
auto place = input.place();
|
auto place = input.place();
|
||||||
const int gridx = min(132 * 8, num_rows);
|
const int gridx = min(132 * 8, num_rows);
|
||||||
if (moe_quant_type == "w4a8") {
|
if (moe_quant_type == "w4a8") {
|
||||||
if (num_experts_per_rank == 8) {
|
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
|
||||||
permute_x_kernel<data_t, int8_t, 8><<<gridx, 512, 0, stream>>>(
|
permute_x_kernel<data_t, int8_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
|
||||||
input.data<data_t>(),
|
input.data<data_t>(),
|
||||||
topk_ids.data<int64_t>(),
|
topk_ids.data<int64_t>(),
|
||||||
topk_weights.data<float>(),
|
topk_weights.data<float>(),
|
||||||
token_nums_per_expert.data<int>(),
|
token_nums_per_expert.data<int>(),
|
||||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||||
moe_topk,
|
moe_topk,
|
||||||
num_rows,
|
num_rows,
|
||||||
token_nums_this_rank,
|
token_nums_this_rank,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
permute_input->data<int8_t>(),
|
permute_input->data<int8_t>(),
|
||||||
permute_indices_per_token->data<int>(),
|
permute_indices_per_token->data<int>(),
|
||||||
dst_weights->data<float>(),
|
dst_weights->data<float>(),
|
||||||
dst_indices->data<int>(),
|
dst_indices->data<int>(),
|
||||||
cumsum_idx_gpu->data<int>(),
|
cumsum_idx_gpu->data<int>(),
|
||||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||||
expert_idx_per_token->data<int64_t>(),
|
expert_idx_per_token->data<int64_t>(),
|
||||||
127.0,
|
127.0,
|
||||||
-127.0
|
-127.0
|
||||||
);
|
);)
|
||||||
} else if (num_experts_per_rank == 16) {
|
|
||||||
permute_x_kernel<data_t, int8_t, 16><<<gridx, 512, 0, stream>>>(
|
|
||||||
input.data<data_t>(),
|
|
||||||
topk_ids.data<int64_t>(),
|
|
||||||
topk_weights.data<float>(),
|
|
||||||
token_nums_per_expert.data<int>(),
|
|
||||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
|
||||||
moe_topk,
|
|
||||||
num_rows,
|
|
||||||
token_nums_this_rank,
|
|
||||||
hidden_size,
|
|
||||||
permute_input->data<int8_t>(),
|
|
||||||
permute_indices_per_token->data<int>(),
|
|
||||||
dst_weights->data<float>(),
|
|
||||||
dst_indices->data<int>(),
|
|
||||||
cumsum_idx_gpu->data<int>(),
|
|
||||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
|
||||||
expert_idx_per_token->data<int64_t>(),
|
|
||||||
127.0,
|
|
||||||
-127.0
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} else if (moe_quant_type == "w4afp8") {
|
} else if (moe_quant_type == "w4afp8") {
|
||||||
if (num_experts_per_rank == 8) {
|
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
|
||||||
permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
|
permute_x_kernel<data_t, data_t_fp8, NUM_EXPERTS_PER_RANK, 512><<<gridx, 512, 0, stream>>>(
|
||||||
input.data<data_t>(),
|
input.data<data_t>(),
|
||||||
topk_ids.data<int64_t>(),
|
topk_ids.data<int64_t>(),
|
||||||
topk_weights.data<float>(),
|
topk_weights.data<float>(),
|
||||||
token_nums_per_expert.data<int>(),
|
token_nums_per_expert.data<int>(),
|
||||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||||
moe_topk,
|
moe_topk,
|
||||||
num_rows,
|
num_rows,
|
||||||
token_nums_this_rank,
|
token_nums_this_rank,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
permute_input->data<data_t_fp8>(),
|
permute_input->data<data_t_fp8>(),
|
||||||
permute_indices_per_token->data<int>(),
|
permute_indices_per_token->data<int>(),
|
||||||
dst_weights->data<float>(),
|
dst_weights->data<float>(),
|
||||||
dst_indices->data<int>(),
|
dst_indices->data<int>(),
|
||||||
cumsum_idx_gpu->data<int>(),
|
cumsum_idx_gpu->data<int>(),
|
||||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||||
expert_idx_per_token->data<int64_t>(),
|
expert_idx_per_token->data<int64_t>(),
|
||||||
448.0f,
|
448.0f,
|
||||||
-448.0f
|
-448.0f
|
||||||
);
|
);)
|
||||||
} else if (num_experts_per_rank == 16) {
|
|
||||||
permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
|
|
||||||
input.data<data_t>(),
|
|
||||||
topk_ids.data<int64_t>(),
|
|
||||||
topk_weights.data<float>(),
|
|
||||||
token_nums_per_expert.data<int>(),
|
|
||||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
|
||||||
moe_topk,
|
|
||||||
num_rows,
|
|
||||||
token_nums_this_rank,
|
|
||||||
hidden_size,
|
|
||||||
permute_input->data<data_t_fp8>(),
|
|
||||||
permute_indices_per_token->data<int>(),
|
|
||||||
dst_weights->data<float>(),
|
|
||||||
dst_indices->data<int>(),
|
|
||||||
cumsum_idx_gpu->data<int>(),
|
|
||||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
|
||||||
expert_idx_per_token->data<int64_t>(),
|
|
||||||
448.0f,
|
|
||||||
-448.0f
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if (num_experts_per_rank == 8) {
|
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
|
||||||
permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
|
permute_x_kernel<data_t, data_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
|
||||||
input.data<data_t>(),
|
input.data<data_t>(),
|
||||||
topk_ids.data<int64_t>(),
|
topk_ids.data<int64_t>(),
|
||||||
topk_weights.data<float>(),
|
topk_weights.data<float>(),
|
||||||
token_nums_per_expert.data<int>(),
|
token_nums_per_expert.data<int>(),
|
||||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||||
moe_topk,
|
moe_topk,
|
||||||
num_rows,
|
num_rows,
|
||||||
token_nums_this_rank,
|
token_nums_this_rank,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
permute_input->data<data_t>(),
|
permute_input->data<data_t>(),
|
||||||
permute_indices_per_token->data<int>(),
|
permute_indices_per_token->data<int>(),
|
||||||
dst_weights->data<float>(),
|
dst_weights->data<float>(),
|
||||||
dst_indices->data<int>(),
|
dst_indices->data<int>(),
|
||||||
cumsum_idx_gpu->data<int>(),
|
cumsum_idx_gpu->data<int>(),
|
||||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||||
expert_idx_per_token->data<int64_t>(),
|
expert_idx_per_token->data<int64_t>(),
|
||||||
127.0,
|
127.0,
|
||||||
-127.0
|
-127.0
|
||||||
);
|
);)
|
||||||
} else if (num_experts_per_rank == 16) {
|
|
||||||
permute_x_kernel<data_t, data_t, 16><<<gridx, 512, 0, stream>>>(
|
|
||||||
input.data<data_t>(),
|
|
||||||
topk_ids.data<int64_t>(),
|
|
||||||
topk_weights.data<float>(),
|
|
||||||
token_nums_per_expert.data<int>(),
|
|
||||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
|
||||||
moe_topk,
|
|
||||||
num_rows,
|
|
||||||
token_nums_this_rank,
|
|
||||||
hidden_size,
|
|
||||||
permute_input->data<data_t>(),
|
|
||||||
permute_indices_per_token->data<int>(),
|
|
||||||
dst_weights->data<float>(),
|
|
||||||
dst_indices->data<int>(),
|
|
||||||
cumsum_idx_gpu->data<int>(),
|
|
||||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
|
||||||
expert_idx_per_token->data<int64_t>(),
|
|
||||||
127.0,
|
|
||||||
-127.0
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user