[Optimize] Support WINT8 and group scale for Machete (#3905)

This commit is contained in:
Sunny-bot1
2025-09-15 12:01:34 +08:00
committed by GitHub
parent 4408dc7f67
commit b1a5b756a3
5 changed files with 125 additions and 42 deletions

View File

@@ -30,10 +30,12 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
std::optional<paddle::Tensor> const& maybe_token_scales,
std::string maybe_schedule) {
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
std::optional<int64_t> maybe_group_size_opt;
std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
std::optional<std::string> maybe_schedule_opt;
if (maybe_schedule == "") {
maybe_schedule_opt = std::nullopt;
} else {
maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
}
return machete::mm_dispatch({.A = A,
.B = B,
@@ -63,6 +65,8 @@ std::vector<paddle::Tensor> MacheteMMKernel(
paddle::DataType maybe_out_type;
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else if (b_type_str == "uint8b128") {
b_type_id = machete::kU8B128.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}

View File

@@ -51,6 +51,8 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else if (b_type_str == "uint8b128") {
b_type_id = machete::kU8B128.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}