load hadamard_block_size from config (#3797)

Yuan Xiaolan
2025-09-05 17:07:58 +08:00
committed by GitHub
parent 41aee08982
commit 2cf55168ca
10 changed files with 60 additions and 30 deletions

View File

@@ -255,7 +255,8 @@ paddle::Tensor MoeExpertFFNFunc(
     const paddle::optional<paddle::Tensor>& down_proj_in_scale,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
     const std::string& quant_method, const bool used_in_ep_low_latency,
-    const int estimate_total_token_nums);
+    const int estimate_total_token_nums,
+    const int hadamard_block_size);
 
 paddle::Tensor MoeExpertFFNWint2Func(
     const paddle::Tensor& permute_input,

View File

@@ -872,16 +872,14 @@ void MoeFastHardamardWrapper(const T *x_data,
                              const int64_t dim,
                              const int num_max_tokens_per_expert,
                              bool used_in_ep_low_latency,
+                             const int hadamard_block_size,
                              OutT* out,
                              cudaStream_t &stream) {
   bool FLAGS_hardamard_use_diagonal_block_matrix = true;
-  static const char* FLAGS_hardamard_moe_block_size = std::getenv("FLAGS_hardamard_moe_block_size");
-  static const int32_t hardamard_moe_block_size = FLAGS_hardamard_moe_block_size != nullptr ?
-      stoi(std::string(FLAGS_hardamard_moe_block_size)) : 512;
   constexpr int kThreads = 128;
   if (FLAGS_hardamard_use_diagonal_block_matrix) {
-    const int VecSize = hardamard_moe_block_size / kThreads; // 128 / 128 = 1
+    const int VecSize = hadamard_block_size / kThreads;
     const int logN = int(ceil(std::log2(kThreads * VecSize)));
     constexpr int kNChunks = 1;
     DISPATCH_SP_VS(VecSize, VEC_SIZE, {
@@ -991,6 +989,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
     const int64_t dim,
     const int num_max_tokens_per_expert,
     bool used_in_ep_low_latency,
+    const int hadamard_block_size,
     phi::dtype::float16 *out,
     cudaStream_t &stream
 );
@@ -1009,6 +1008,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
     const int64_t dim,
     const int num_max_tokens_per_expert,
     bool used_in_ep_low_latency,
+    const int hadamard_block_size,
     int8_t *out,
     cudaStream_t &stream
 );
@@ -1027,6 +1027,7 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16
     const int64_t dim,
     const int num_max_tokens_per_expert,
     bool used_in_ep_low_latency,
+    const int hadamard_block_size,
     phi::dtype::bfloat16 *out,
     cudaStream_t &stream
 );
@@ -1045,6 +1046,7 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
     const int64_t dim,
     const int num_max_tokens_per_expert,
     bool used_in_ep_low_latency,
+    const int hadamard_block_size,
     int8_t *out,
     cudaStream_t &stream
 );
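The first hunk in this file replaces the FLAGS_hardamard_moe_block_size environment variable (old default 512) with the new hadamard_block_size argument when sizing the kernel launch. A minimal Python sketch of that arithmetic, assuming the fixed kThreads = 128 from the wrapper above; the helper name is illustrative and not part of this commit:

import math

K_THREADS = 128  # kThreads in MoeFastHardamardWrapper

def hadamard_launch_params(hadamard_block_size: int):
    """Illustrative only: derive (VecSize, logN) the way the wrapper does."""
    vec_size = hadamard_block_size // K_THREADS          # elements handled per thread
    log_n = math.ceil(math.log2(K_THREADS * vec_size))   # depth of the Hadamard transform
    return vec_size, log_n

print(hadamard_launch_params(128))  # (1, 7) -- the new config default
print(hadamard_launch_params(512))  # (4, 9) -- the old env-var default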

View File

@@ -32,5 +32,6 @@ void MoeFastHardamardWrapper(const T *x_data,
                              const int64_t dim,
                              const int num_max_tokens_per_expert,
                              bool used_in_ep_low_latency,
+                             const int hadamard_block_size,
                              OutT* out,
                              cudaStream_t &stream);

View File

@@ -35,7 +35,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
                   const std::string& quant_method,
                   paddle::Tensor ffn_out,
                   bool used_in_ep_low_latency,
-                  const int estimate_total_token_nums) {
+                  const int estimate_total_token_nums,
+                  const int hadamard_block_size) {
   using namespace phi;
   typedef PDTraits<T> traits_;
   typedef typename traits_::DataType DataType_;
@@ -291,6 +292,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         inter_size / 2,
         num_max_tokens_per_expert,
         used_in_ep_low_latency,
+        hadamard_block_size,
         reinterpret_cast<int8_t *>(int8_act_out->ptr()),
         stream
     );
@@ -340,6 +342,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         inter_size / 2,
         num_max_tokens_per_expert,
         used_in_ep_low_latency,
+        hadamard_block_size,
         act_out_tensor.data<data_t>(),
         stream
     );
@@ -403,7 +406,7 @@ paddle::Tensor MoeExpertFFNFunc(
     const paddle::optional<paddle::Tensor>& down_proj_in_scale,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
     const std::string& quant_method, const bool used_in_ep_low_latency,
-    const int estimate_total_token_nums) {
+    const int estimate_total_token_nums, const int hadamard_block_size) {
 
 const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
                     (quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
@@ -424,7 +427,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
                 quant_method,
                 ffn_out,
                 used_in_ep_low_latency,
-                estimate_total_token_nums);
+                estimate_total_token_nums,
+                hadamard_block_size);
             break;
         case paddle::DataType::FLOAT16:
             MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
@@ -439,7 +443,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
                 quant_method,
                 ffn_out,
                 used_in_ep_low_latency,
-                estimate_total_token_nums);
+                estimate_total_token_nums,
+                hadamard_block_size);
             break;
         default:
             PD_THROW("Unsupported data type for MoeExpertFFN");
@@ -458,7 +463,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
     const paddle::optional<paddle::Tensor>& down_proj_in_scale,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
     const std::string& quant_method, const bool used_in_ep_low_latency,
-    const int estimate_total_token_nums) {
+    const int estimate_total_token_nums,
+    const int hadamard_block_size) {
     return {MoeExpertFFNFunc(permute_input,
                              tokens_expert_prefix_sum,
                              up_gate_proj_weight,
@@ -470,7 +476,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
                              expert_idx_per_token,
                              quant_method,
                              used_in_ep_low_latency,
-                             estimate_total_token_nums)};
+                             estimate_total_token_nums,
+                             hadamard_block_size)};
 }
 
 std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
@@ -485,7 +492,8 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
     const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
     const std::string& quant_method,
     const bool used_in_ep_low_latency,
-    const int estimate_total_token_nums) {
+    const int estimate_total_token_nums,
+    const int hadamard_block_size) {
     return {permute_input_shape};
 }
@@ -499,7 +507,7 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
     const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
     const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
     const std::string &quant_method, const bool used_in_ep_low_latency,
-    const int estimate_total_token_nums) {
+    const int estimate_total_token_nums, const int hadamard_block_size) {
     if (quant_method == "w4a8" || quant_method == "w4afp8") {
         return {up_gate_proj_scale_dtype.get()};
     } else {
@@ -555,6 +563,8 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
  *                 Options: "none", "weight_only_int4", "weight_only_int8", "w4a8"
  * - used_in_ep_low_latency: Whether running in low latency mode
  *                           Affects activation function implementation
+ * - estimate_total_token_nums: estimate total token numbers
+ * - hadamard_block_size: hadamard block size for w4a8/w4afp8 quantization
  *
  * Note:
  * - w4a8 mode requires additional workspace memory allocation
@@ -571,7 +581,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
paddle::Optional("down_proj_in_scale"), paddle::Optional("down_proj_in_scale"),
paddle::Optional("expert_idx_per_token")}) paddle::Optional("expert_idx_per_token")})
.Outputs({"output_tensor"}) .Outputs({"output_tensor"})
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int"}) .Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int"})
.SetKernelFn(PD_KERNEL(MoeExpertFFN)) .SetKernelFn(PD_KERNEL(MoeExpertFFN))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape)) .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype)); .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
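A hedged sketch of how a caller now assembles these attributes; the helper and the layer.moe_quant_config access are assumptions for illustration, while the attribute names and the default of 128 come from this diff:

def build_moe_expert_ffn_attrs(layer, quant_method, used_in_ep_low_latency, estimate_total_token_nums):
    # Mirrors the .Attrs() list registered above; hadamard_block_size is read
    # from the layer's quantization config instead of an environment variable.
    return {
        "quant_method": quant_method,
        "used_in_ep_low_latency": used_in_ep_low_latency,
        "estimate_total_token_nums": estimate_total_token_nums,
        "hadamard_block_size": getattr(layer.moe_quant_config, "hadamard_block_size", 128),
    }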

View File

@@ -401,10 +401,6 @@ class LLMEngine:
"FLAGS_use_append_attn": 1, "FLAGS_use_append_attn": 1,
"NCCL_ALGO": "Ring", "NCCL_ALGO": "Ring",
"FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)), "FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
"FLAGS_hardamard_moe_block_size": int(os.getenv("FLAGS_hardamard_moe_block_size", 128)),
"FLAGS_hardamard_use_diagonal_block_matrix": int(
os.getenv("FLAGS_hardamard_use_diagonal_block_matrix", 0)
),
} }
# environment variables needed by Dy2St # environment variables needed by Dy2St
variables.update( variables.update(
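With these FLAGS gone from the engine environment, the block size is supplied through the quantization config instead. A sketch of the replacement; the surrounding config keys are assumptions, only hadamard_block_size and its default of 128 are defined by this commit:

quantization_config = {
    "quant_type": "w4a8",        # illustrative key; any config routed to W4A8Config / W4AFP8Config
    "is_permuted": True,
    "hadamard_block_size": 512,  # replaces FLAGS_hardamard_moe_block_size; 128 when omitted
}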

View File

@@ -127,6 +127,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
             self.moe_quant_type,
             used_in_ep_low_latency,
             estimate_total_token_nums,
+            getattr(layer.moe_quant_config, "hadamard_block_size", 128),
         )
 
     def apply_ep_prefill(

View File

@@ -38,6 +38,7 @@ class MixQuantConfig(QuantConfigBase):
         has_zero_point: bool = False,
         is_permuted: bool = True,
         is_checkpoint_bf16: bool = False,
+        hadamard_block_size: int = 128,
     ) -> None:
         super().__init__()
         self.dense_quant_type = dense_quant_type
@@ -54,6 +55,7 @@ class MixQuantConfig(QuantConfigBase):
         self.quant_round_type = 0
         self.is_permuted = is_permuted
         self.is_checkpoint_bf16 = is_checkpoint_bf16
+        self.hadamard_block_size = hadamard_block_size
 
     def name(self) -> str:
         return "mix_quant"
@@ -69,6 +71,7 @@ class MixQuantConfig(QuantConfigBase):
config.get("has_zero_point", False), config.get("has_zero_point", False),
config.get("is_permuted", True), config.get("is_permuted", True),
config.get("is_checkpoint_bf16", False), config.get("is_checkpoint_bf16", False),
config.get("hadamard_block_size", 128),
) )
def get_quant_method(self, layer) -> Optional[QuantMethodBase]: def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -76,13 +79,25 @@ class MixQuantConfig(QuantConfigBase):
             if layer.moe_tag == "Image":
                 return (
                     get_quantization_config(self.image_moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                    .from_config(
+                        {
+                            "is_permuted": self.is_permuted,
+                            "self.is_checkpoint_bf16": self.is_checkpoint_bf16,
+                            "hadamard_block_size": self.hadamard_block_size,
+                        }
+                    )
                     .get_quant_method(layer)
                 )
             else:
                 return (
                     get_quantization_config(self.moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                    .from_config(
+                        {
+                            "is_permuted": self.is_permuted,
+                            "self.is_checkpoint_bf16": self.is_checkpoint_bf16,
+                            "hadamard_block_size": self.hadamard_block_size,
+                        }
+                    )
                     .get_quant_method(layer)
                 )
         elif isinstance(layer, Attention):

View File

@@ -25,9 +25,10 @@ class W4A8Config(QuantConfigBase):
     quantization config for weight 4bits and activation 8bits
     """
 
-    def __init__(self, is_permuted) -> None:
+    def __init__(self, is_permuted, hadamard_block_size) -> None:
         super().__init__()
         self.is_permuted = is_permuted
+        self.hadamard_block_size = hadamard_block_size
 
     def name(self) -> str:
         return "w4a8"
@@ -35,7 +36,8 @@ class W4A8Config(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "W4A8Config":
         is_permuted = config.get("is_permuted", True)
-        return cls(is_permuted)
+        hadamard_block_size = config.get("hadamard_block_size", 128)
+        return cls(is_permuted, hadamard_block_size)
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if isinstance(layer, FusedMoE):
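A small self-contained check of the fallback behaviour shown in this hunk; names are taken directly from the diff:

cfg = W4A8Config.from_config({"is_permuted": True})
assert cfg.hadamard_block_size == 128  # default when the key is absent

cfg = W4A8Config.from_config({"is_permuted": True, "hadamard_block_size": 512})
assert cfg.hadamard_block_size == 512  # explicit override from the quant config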

View File

@@ -31,7 +31,7 @@ class W4AFP8Config(QuantConfigBase):
     quantization config for weight 4bits and activation fp8
     """
 
-    def __init__(self, weight_scale_dict, act_scale_dict, is_permuted) -> None:
+    def __init__(self, weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size) -> None:
         super().__init__()
         self.weight_scale_dict = weight_scale_dict
         self.act_scale_dict = act_scale_dict
@@ -39,6 +39,7 @@ class W4AFP8Config(QuantConfigBase):
         self.quant_min_bound = -448
         self.quant_round_type = 1
         self.is_permuted = is_permuted
+        self.hadamard_block_size = hadamard_block_size
 
     def name(self) -> str:
         return "w4afp8"
@@ -48,7 +49,8 @@ class W4AFP8Config(QuantConfigBase):
         weight_scale_dict = config.get("weight_scale_dict", None)
         act_scale_dict = config.get("act_scale_dict", None)
         is_permuted = config.get("is_permuted", True)
-        return cls(weight_scale_dict, act_scale_dict, is_permuted)
+        hadamard_block_size = config.get("hadamard_block_size", 128)
+        return cls(weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size)
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if isinstance(layer, FusedMoE):

View File

@@ -23,10 +23,10 @@ from fastdeploy.entrypoints.llm import LLM
 bash_path = os.getenv("MODEL_PATH")
 
 FD_ENGINE_QUEUE_PORTS = [
-    [9961, 9962, 9963, 9964, 9965, 9966, 9967, 9968],
-    [9971, 9972, 9973, 9974, 9975, 9976, 9977, 9978],
-    [9981, 9982, 9983, 9984, 9985, 9986, 9987, 9988],
-    [9991, 9992, 9993, 9994, 9995, 9996, 9997, 9998],
+    [9961, 9962],
+    [9971, 9972],
+    [9981, 9982],
+    [9991, 9992],
 ]
@@ -49,7 +49,7 @@ def llm(request):
         llm_instance = LLM(
             model=model_path,
             tensor_parallel_size=1,
-            data_parallel_size=8,
+            data_parallel_size=2,
             max_model_len=8192,
             num_gpu_blocks_override=1024,
             engine_worker_queue_port=FD_ENGINE_QUEUE_PORTS[port_index],
@@ -58,7 +58,7 @@ def llm(request):
         )
         yield weakref.proxy(llm_instance)
     except Exception as e:
-        pytest.skip(f"LLM initialization failed: {e}")
+        assert False, f"LLM initialization failed: {e}"
 
 @pytest.mark.timeout(60)