mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
WINT4/WINT8 dense gemm default use Machete (#4451)
This commit is contained in:
@@ -86,3 +86,52 @@ std::vector<paddle::Tensor> MacheteMMKernel(
|
||||
maybe_schedule);
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MacheteMMKernelInferShape(
|
||||
std::vector<int64_t> const& A_shape,
|
||||
std::vector<int64_t> const& B_shape,
|
||||
paddle::optional<std::vector<int64_t>> const& maybe_group_scales_shape,
|
||||
paddle::optional<std::vector<int64_t>> const& maybe_group_zeros_shape,
|
||||
paddle::optional<std::vector<int64_t>> const& maybe_channel_scales_shape,
|
||||
paddle::optional<std::vector<int64_t>> const& maybe_token_scales_shape,
|
||||
std::string const& b_type_str,
|
||||
std::string const& maybe_out_type_str,
|
||||
int64_t const& maybe_group_size,
|
||||
std::string const& maybe_schedule) {
|
||||
return {{A_shape[0], B_shape[1]}};
|
||||
}
|
||||
|
||||
// Static-graph dtype inference for machete_mm.
// The requested output dtype string selects FLOAT16 or BFLOAT16; any other
// value (including an empty string) falls back to the dtype of A.
std::vector<paddle::DataType> MacheteMMKernelInferDtype(
    paddle::DataType const& A_dtype,
    paddle::DataType const& B_dtype,
    paddle::optional<paddle::DataType> const& maybe_group_scales_dtype,
    paddle::optional<paddle::DataType> const& maybe_group_zeros_dtype,
    paddle::optional<paddle::DataType> const& maybe_channel_scales_dtype,
    paddle::optional<paddle::DataType> const& maybe_token_scales_dtype,
    std::string const& b_type_str,
    std::string const& maybe_out_type_str,
    int64_t const& maybe_group_size,
    std::string const& maybe_schedule) {
  // Default to A's dtype, then override when an explicit choice was given.
  paddle::DataType out_dtype = A_dtype;
  if (maybe_out_type_str == "float16") {
    out_dtype = paddle::DataType::FLOAT16;
  } else if (maybe_out_type_str == "bfloat16") {
    out_dtype = paddle::DataType::BFLOAT16;
  }
  return {out_dtype};
}
|
||||
|
||||
// Registers the machete_mm custom op for Paddle static graphs.
// Inputs: dense activation A, quantized weight B, plus four optional
// quantization tensors (group scales/zeros, channel scales, token scales).
// Attrs mirror the trailing string/int parameters of MacheteMMKernel and the
// two infer functions above; the attr spec format is "name:type".
PD_BUILD_STATIC_OP(machete_mm)
    .Inputs({"A", "B",
             paddle::Optional("maybe_group_scales"),
             paddle::Optional("maybe_group_zeros"),
             paddle::Optional("maybe_channel_scales"),
             paddle::Optional("maybe_token_scales")})
    .Outputs({"out"})
    .Attrs({"b_type_str:std::string", "maybe_out_type_str:std::string", "maybe_group_size:int64_t", "maybe_schedule:std::string"})
    .SetKernelFn(PD_KERNEL(MacheteMMKernel))
    .SetInferShapeFn(PD_INFER_SHAPE(MacheteMMKernelInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MacheteMMKernelInferDtype));
|
||||
|
||||
@@ -71,3 +71,23 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(
|
||||
return {B_prepacked};
|
||||
|
||||
}
|
||||
|
||||
// Static-graph shape inference for machete_prepack_B.
// The prepacked tensor is reported with B's two dimensions swapped
// ([B_shape[1], B_shape[0]]); the type-string attrs do not affect the shape.
std::vector<std::vector<int64_t>> MachetePrepackBKernelInferShape(
    std::vector<int64_t> const& B_shape, std::string const& a_type_str, std::string const& b_type_str,
    std::string const& maybe_group_scales_type_str) {
  std::vector<int64_t> packed_shape{B_shape[1], B_shape[0]};
  return {packed_shape};
}
|
||||
|
||||
// Static-graph dtype inference for machete_prepack_B.
// Prepacking reorders the weight layout only, so the output keeps B's dtype;
// the type-string attrs are not consulted here.
std::vector<paddle::DataType> MachetePrepackBKernelInferDtype(
    paddle::DataType const& B_dtype, std::string const& a_type_str, std::string const& b_type_str,
    std::string const& maybe_group_scales_type_str) {
  std::vector<paddle::DataType> out_dtypes;
  out_dtypes.push_back(B_dtype);
  return out_dtypes;
}
|
||||
|
||||
// Registers the machete_prepack_B custom op for Paddle static graphs.
// Takes the quantized weight B and emits B_prepacked (Machete's weight
// layout). The three string attrs mirror the trailing parameters of
// MachetePrepackBKernel and its infer functions; attr spec format is
// "name:type".
PD_BUILD_STATIC_OP(machete_prepack_B)
    .Inputs({"B"})
    .Outputs({"B_prepacked"})
    .Attrs({"a_type_str:std::string", "b_type_str:std::string", "maybe_group_scales_type_str:std::string"})
    .SetKernelFn(PD_KERNEL(MachetePrepackBKernel))
    .SetInferShapeFn(PD_INFER_SHAPE(MachetePrepackBKernelInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MachetePrepackBKernelInferDtype));
|
||||
|
||||
Reference in New Issue
Block a user