WINT4/WINT8 dense GEMM now uses Machete by default (#4451)

This commit is contained in:
Sunny-bot1
2025-10-23 17:57:59 +08:00
committed by GitHub
parent a240425db9
commit 4ffe41a747
12 changed files with 310 additions and 15 deletions

View File

@@ -86,3 +86,52 @@ std::vector<paddle::Tensor> MacheteMMKernel(
maybe_schedule);
return {out};
}
// Static-graph shape inference for machete_mm: the output is [M, N],
// taken from A's first dim and B's second dim. Only A_shape/B_shape are
// consulted; the remaining arguments mirror the kernel signature so the
// registration macro can bind them.
// NOTE(review): assumes A is [M, K] and (prepacked) B is [K, N] — confirm
// against MacheteMMKernel.
std::vector<std::vector<int64_t>> MacheteMMKernelInferShape(
    std::vector<int64_t> const& A_shape,
    std::vector<int64_t> const& B_shape,
    paddle::optional<std::vector<int64_t>> const& maybe_group_scales_shape,
    paddle::optional<std::vector<int64_t>> const& maybe_group_zeros_shape,
    paddle::optional<std::vector<int64_t>> const& maybe_channel_scales_shape,
    paddle::optional<std::vector<int64_t>> const& maybe_token_scales_shape,
    std::string const& b_type_str,
    std::string const& maybe_out_type_str,
    int64_t const& maybe_group_size,
    std::string const& maybe_schedule) {
  int64_t const m = A_shape[0];
  int64_t const n = B_shape[1];
  return {{m, n}};
}
// Static-graph dtype inference for machete_mm.
// "float16"/"bfloat16" select that dtype explicitly; any other string
// (including empty) falls back to A's dtype. Only maybe_out_type_str and
// A_dtype are consulted; the other arguments mirror the kernel signature.
std::vector<paddle::DataType> MacheteMMKernelInferDtype(
    paddle::DataType const& A_dtype,
    paddle::DataType const& B_dtype,
    paddle::optional<paddle::DataType> const& maybe_group_scales_dtype,
    paddle::optional<paddle::DataType> const& maybe_group_zeros_dtype,
    paddle::optional<paddle::DataType> const& maybe_channel_scales_dtype,
    paddle::optional<paddle::DataType> const& maybe_token_scales_dtype,
    std::string const& b_type_str,
    std::string const& maybe_out_type_str,
    int64_t const& maybe_group_size,
    std::string const& maybe_schedule) {
  if (maybe_out_type_str == "float16") {
    return {paddle::DataType::FLOAT16};
  }
  if (maybe_out_type_str == "bfloat16") {
    return {paddle::DataType::BFLOAT16};
  }
  // Unrecognized / unset: output inherits the activation dtype.
  return {A_dtype};
}
// Register the machete_mm custom op for static graph mode.
// Inputs: dense activation A, (prepacked) quantized weight B, plus four
// optional quantization tensors (group scales/zeros, channel scales,
// token scales). Attrs carry the weight type, requested output dtype,
// group size, and schedule as strings/int64 so they survive serialization.
PD_BUILD_STATIC_OP(machete_mm)
    .Inputs({"A", "B",
             paddle::Optional("maybe_group_scales"),
             paddle::Optional("maybe_group_zeros"),
             paddle::Optional("maybe_channel_scales"),
             paddle::Optional("maybe_token_scales")})
    .Outputs({"out"})
    .Attrs({"b_type_str:std::string", "maybe_out_type_str:std::string", "maybe_group_size:int64_t", "maybe_schedule:std::string"})
    .SetKernelFn(PD_KERNEL(MacheteMMKernel))
    .SetInferShapeFn(PD_INFER_SHAPE(MacheteMMKernelInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MacheteMMKernelInferDtype));

View File

@@ -71,3 +71,23 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(
return {B_prepacked};
}
// Static-graph shape inference for machete_prepack_B: the prepacked
// output reports B's two dims swapped.
// NOTE(review): presumably this mirrors a transpose done by the Machete
// prepack routine — confirm against MachetePrepackBKernel.
std::vector<std::vector<int64_t>> MachetePrepackBKernelInferShape(
    std::vector<int64_t> const& B_shape, std::string const& a_type_str, std::string const& b_type_str,
    std::string const& maybe_group_scales_type_str) {
  auto const out_rows = B_shape[1];
  auto const out_cols = B_shape[0];
  return {{out_rows, out_cols}};
}
// Static-graph dtype inference for machete_prepack_B: the prepacked
// tensor keeps B's dtype unchanged; the type-string attrs are unused here.
std::vector<paddle::DataType> MachetePrepackBKernelInferDtype(
    paddle::DataType const& B_dtype, std::string const& a_type_str, std::string const& b_type_str,
    std::string const& maybe_group_scales_type_str) {
  paddle::DataType const out_dtype = B_dtype;
  return {out_dtype};
}
// Register the machete_prepack_B custom op for static graph mode.
// Takes the quantized weight B and reorders it into Machete's packed
// layout; the activation/weight/scale type strings are passed as attrs
// so they survive program serialization.
PD_BUILD_STATIC_OP(machete_prepack_B)
    .Inputs({"B"})
    .Outputs({"B_prepacked"})
    .Attrs({"a_type_str:std::string", "b_type_str:std::string", "maybe_group_scales_type_str:std::string"})
    .SetKernelFn(PD_KERNEL(MachetePrepackBKernel))
    .SetInferShapeFn(PD_INFER_SHAPE(MachetePrepackBKernelInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MachetePrepackBKernelInferDtype));