From 2cf55168ca99406094fc6eefd98163196f90473f Mon Sep 17 00:00:00 2001 From: Yuan Xiaolan <845594810@qq.com> Date: Fri, 5 Sep 2025 17:07:58 +0800 Subject: [PATCH] load hadamard_block_size from config (#3797) --- custom_ops/gpu_ops/cpp_extensions.cc | 3 +- .../gpu_ops/moe/fast_hardamard_kernel.cu | 10 ++++--- .../gpu_ops/moe/fast_hardamard_kernel.h | 1 + custom_ops/gpu_ops/moe/moe_ffn.cu | 28 +++++++++++++------ fastdeploy/engine/engine.py | 4 --- .../layers/moe/fused_moe_cutlass_backend.py | 1 + .../layers/quantization/mix_quant.py | 19 +++++++++++-- .../layers/quantization/w4a8.py | 6 ++-- .../layers/quantization/w4afp8.py | 6 ++-- tests/model_loader/test_w4a8_model.py | 12 ++++---- 10 files changed, 60 insertions(+), 30 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 1a3588491..b0bb23604 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -255,7 +255,8 @@ paddle::Tensor MoeExpertFFNFunc( const paddle::optional& down_proj_in_scale, const paddle::optional& expert_idx_per_token, const std::string& quant_method, const bool used_in_ep_low_latency, - const int estimate_total_token_nums); + const int estimate_total_token_nums, + const int hadamard_block_size); paddle::Tensor MoeExpertFFNWint2Func( const paddle::Tensor& permute_input, diff --git a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu index 63b45b743..1323cb483 100644 --- a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu +++ b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu @@ -872,16 +872,14 @@ void MoeFastHardamardWrapper(const T *x_data, const int64_t dim, const int num_max_tokens_per_expert, bool used_in_ep_low_latency, + const int hadamard_block_size, OutT* out, cudaStream_t &stream) { bool FLAGS_hardamard_use_diagonal_block_matrix = true; - static const char* FLAGS_hardamard_moe_block_size = std::getenv("FLAGS_hardamard_moe_block_size"); - static const int32_t hardamard_moe_block_size = FLAGS_hardamard_moe_block_size != nullptr ? - stoi(std::string(FLAGS_hardamard_moe_block_size)) : 512; constexpr int kThreads = 128; if (FLAGS_hardamard_use_diagonal_block_matrix) { - const int VecSize = hardamard_moe_block_size / kThreads; // 128 / 128 = 1 + const int VecSize = hadamard_block_size / kThreads; const int logN = int(ceil(std::log2(kThreads * VecSize))); constexpr int kNChunks = 1; DISPATCH_SP_VS(VecSize, VEC_SIZE, { @@ -991,6 +989,7 @@ template void MoeFastHardamardWrapper( const int64_t dim, const int num_max_tokens_per_expert, bool used_in_ep_low_latency, + const int hadamard_block_size, phi::dtype::float16 *out, cudaStream_t &stream ); @@ -1009,6 +1008,7 @@ template void MoeFastHardamardWrapper( const int64_t dim, const int num_max_tokens_per_expert, bool used_in_ep_low_latency, + const int hadamard_block_size, int8_t *out, cudaStream_t &stream ); @@ -1027,6 +1027,7 @@ template void MoeFastHardamardWrapper( const int64_t dim, const int num_max_tokens_per_expert, bool used_in_ep_low_latency, + const int hadamard_block_size, int8_t *out, cudaStream_t &stream ); diff --git a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h index 64c5c20ad..ccb624e5c 100644 --- a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h +++ b/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h @@ -32,5 +32,6 @@ void MoeFastHardamardWrapper(const T *x_data, const int64_t dim, const int num_max_tokens_per_expert, bool used_in_ep_low_latency, + const int hadamard_block_size, OutT* out, cudaStream_t &stream); diff --git a/custom_ops/gpu_ops/moe/moe_ffn.cu b/custom_ops/gpu_ops/moe/moe_ffn.cu index 117f1c63e..7387246ab 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn.cu @@ -35,7 +35,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, const std::string& quant_method, paddle::Tensor ffn_out, bool used_in_ep_low_latency, - const int estimate_total_token_nums) { + const int estimate_total_token_nums, + const int hadamard_block_size) { using namespace phi; typedef PDTraits traits_; typedef typename traits_::DataType DataType_; @@ -291,6 +292,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, inter_size / 2, num_max_tokens_per_expert, used_in_ep_low_latency, + hadamard_block_size, reinterpret_cast(int8_act_out->ptr()), stream ); @@ -340,6 +342,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, inter_size / 2, num_max_tokens_per_expert, used_in_ep_low_latency, + hadamard_block_size, act_out_tensor.data(), stream ); @@ -403,7 +406,7 @@ paddle::Tensor MoeExpertFFNFunc( const paddle::optional& down_proj_in_scale, const paddle::optional& expert_idx_per_token, const std::string& quant_method, const bool used_in_ep_low_latency, - const int estimate_total_token_nums) { + const int estimate_total_token_nums, const int hadamard_block_size) { const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() : (quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 : @@ -424,7 +427,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() quant_method, ffn_out, used_in_ep_low_latency, - estimate_total_token_nums); + estimate_total_token_nums, + hadamard_block_size); break; case paddle::DataType::FLOAT16: MoeFFNKernel(permute_input, @@ -439,7 +443,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() quant_method, ffn_out, used_in_ep_low_latency, - estimate_total_token_nums); + estimate_total_token_nums, + hadamard_block_size); break; default: PD_THROW("Unsupported data type for MoeExpertFFN"); @@ -458,7 +463,8 @@ std::vector MoeExpertFFN( const paddle::optional& down_proj_in_scale, const paddle::optional& expert_idx_per_token, const std::string& quant_method, const bool used_in_ep_low_latency, - const int estimate_total_token_nums) { + const int estimate_total_token_nums, + const int hadamard_block_size) { return {MoeExpertFFNFunc(permute_input, tokens_expert_prefix_sum, up_gate_proj_weight, @@ -470,7 +476,8 @@ std::vector MoeExpertFFN( expert_idx_per_token, quant_method, used_in_ep_low_latency, - estimate_total_token_nums)}; + estimate_total_token_nums, + hadamard_block_size)}; } std::vector> MoeExpertFFNInferShape( @@ -485,7 +492,8 @@ std::vector> MoeExpertFFNInferShape( const paddle::optional>& expert_idx_per_token_shape, const std::string& quant_method, const bool used_in_ep_low_latency, - const int estimate_total_token_nums) { + const int estimate_total_token_nums, + const int hadamard_block_size) { return {permute_input_shape}; } @@ -499,7 +507,7 @@ std::vector MoeExpertFFNInferDtype( const paddle::optional &down_proj_scale_dtype, const paddle::optional &down_proj_in_scale_dtype, const std::string &quant_method, const bool used_in_ep_low_latency, - const int estimate_total_token_nums) { + const int estimate_total_token_nums, const int hadamard_block_size) { if (quant_method == "w4a8" || quant_method == "w4afp8") { return {up_gate_proj_scale_dtype.get()}; } else { @@ -555,6 +563,8 @@ std::vector MoeExpertFFNInferDtype( * Options: "none", "weight_only_int4", "weight_only_int8", "w4a8" * - used_in_ep_low_latency: Whether running in low latency mode * Affects activation function implementation + * - estimate_total_token_nums: estimate total token numbers + * - hadamard_block_size: hadamard block size for w4a8/w4afp8 quantization * * Note: * - w4a8 mode requires additional workspace memory allocation @@ -571,7 +581,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn) paddle::Optional("down_proj_in_scale"), paddle::Optional("expert_idx_per_token")}) .Outputs({"output_tensor"}) - .Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int"}) + .Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int"}) .SetKernelFn(PD_KERNEL(MoeExpertFFN)) .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype)); diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index a710d8cc9..e24db25b4 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -401,10 +401,6 @@ class LLMEngine: "FLAGS_use_append_attn": 1, "NCCL_ALGO": "Ring", "FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)), - "FLAGS_hardamard_moe_block_size": int(os.getenv("FLAGS_hardamard_moe_block_size", 128)), - "FLAGS_hardamard_use_diagonal_block_matrix": int( - os.getenv("FLAGS_hardamard_use_diagonal_block_matrix", 0) - ), } # environment variables needed by Dy2St variables.update( diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 78c011330..589b4b838 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -127,6 +127,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): self.moe_quant_type, used_in_ep_low_latency, estimate_total_token_nums, + getattr(layer.moe_quant_config, "hadamard_block_size", 128), ) def apply_ep_prefill( diff --git a/fastdeploy/model_executor/layers/quantization/mix_quant.py b/fastdeploy/model_executor/layers/quantization/mix_quant.py index 05c456d55..b36b71938 100644 --- a/fastdeploy/model_executor/layers/quantization/mix_quant.py +++ b/fastdeploy/model_executor/layers/quantization/mix_quant.py @@ -38,6 +38,7 @@ class MixQuantConfig(QuantConfigBase): has_zero_point: bool = False, is_permuted: bool = True, is_checkpoint_bf16: bool = False, + hadamard_block_size: int = 128, ) -> None: super().__init__() self.dense_quant_type = dense_quant_type @@ -54,6 +55,7 @@ class MixQuantConfig(QuantConfigBase): self.quant_round_type = 0 self.is_permuted = is_permuted self.is_checkpoint_bf16 = is_checkpoint_bf16 + self.hadamard_block_size = hadamard_block_size def name(self) -> str: return "mix_quant" @@ -69,6 +71,7 @@ class MixQuantConfig(QuantConfigBase): config.get("has_zero_point", False), config.get("is_permuted", True), config.get("is_checkpoint_bf16", False), + config.get("hadamard_block_size", 128), ) def get_quant_method(self, layer) -> Optional[QuantMethodBase]: @@ -76,13 +79,25 @@ class MixQuantConfig(QuantConfigBase): if layer.moe_tag == "Image": return ( get_quantization_config(self.image_moe_quant_type) - .from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16}) + .from_config( + { + "is_permuted": self.is_permuted, + "self.is_checkpoint_bf16": self.is_checkpoint_bf16, + "hadamard_block_size": self.hadamard_block_size, + } + ) .get_quant_method(layer) ) else: return ( get_quantization_config(self.moe_quant_type) - .from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16}) + .from_config( + { + "is_permuted": self.is_permuted, + "self.is_checkpoint_bf16": self.is_checkpoint_bf16, + "hadamard_block_size": self.hadamard_block_size, + } + ) .get_quant_method(layer) ) elif isinstance(layer, Attention): diff --git a/fastdeploy/model_executor/layers/quantization/w4a8.py b/fastdeploy/model_executor/layers/quantization/w4a8.py index 944d5219a..3806a74a8 100644 --- a/fastdeploy/model_executor/layers/quantization/w4a8.py +++ b/fastdeploy/model_executor/layers/quantization/w4a8.py @@ -25,9 +25,10 @@ class W4A8Config(QuantConfigBase): quantization config for weight 4bits and activation 8bits """ - def __init__(self, is_permuted) -> None: + def __init__(self, is_permuted, hadamard_block_size) -> None: super().__init__() self.is_permuted = is_permuted + self.hadamard_block_size = hadamard_block_size def name(self) -> str: return "w4a8" @@ -35,7 +36,8 @@ class W4A8Config(QuantConfigBase): @classmethod def from_config(cls, config: dict) -> "W4A8Config": is_permuted = config.get("is_permuted", True) - return cls(is_permuted) + hadamard_block_size = config.get("hadamard_block_size", 128) + return cls(is_permuted, hadamard_block_size) def get_quant_method(self, layer) -> Optional[QuantMethodBase]: if isinstance(layer, FusedMoE): diff --git a/fastdeploy/model_executor/layers/quantization/w4afp8.py b/fastdeploy/model_executor/layers/quantization/w4afp8.py index 4afc8fa58..e7be78b06 100644 --- a/fastdeploy/model_executor/layers/quantization/w4afp8.py +++ b/fastdeploy/model_executor/layers/quantization/w4afp8.py @@ -31,7 +31,7 @@ class W4AFP8Config(QuantConfigBase): quantization config for weight 4bits and activation fp8 """ - def __init__(self, weight_scale_dict, act_scale_dict, is_permuted) -> None: + def __init__(self, weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size) -> None: super().__init__() self.weight_scale_dict = weight_scale_dict self.act_scale_dict = act_scale_dict @@ -39,6 +39,7 @@ class W4AFP8Config(QuantConfigBase): self.quant_min_bound = -448 self.quant_round_type = 1 self.is_permuted = is_permuted + self.hadamard_block_size = hadamard_block_size def name(self) -> str: return "w4afp8" @@ -48,7 +49,8 @@ class W4AFP8Config(QuantConfigBase): weight_scale_dict = config.get("weight_scale_dict", None) act_scale_dict = config.get("act_scale_dict", None) is_permuted = config.get("is_permuted", True) - return cls(weight_scale_dict, act_scale_dict, is_permuted) + hadamard_block_size = config.get("hadamard_block_size", 128) + return cls(weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size) def get_quant_method(self, layer) -> Optional[QuantMethodBase]: if isinstance(layer, FusedMoE): diff --git a/tests/model_loader/test_w4a8_model.py b/tests/model_loader/test_w4a8_model.py index 3007b0a1b..af3108aff 100644 --- a/tests/model_loader/test_w4a8_model.py +++ b/tests/model_loader/test_w4a8_model.py @@ -23,10 +23,10 @@ from fastdeploy.entrypoints.llm import LLM bash_path = os.getenv("MODEL_PATH") FD_ENGINE_QUEUE_PORTS = [ - [9961, 9962, 9963, 9964, 9965, 9966, 9967, 9968], - [9971, 9972, 9973, 9974, 9975, 9976, 9977, 9978], - [9981, 9982, 9983, 9984, 9985, 9986, 9987, 9988], - [9991, 9992, 9993, 9994, 9995, 9996, 9997, 9998], + [9961, 9962], + [9971, 9972], + [9981, 9982], + [9991, 9992], ] @@ -49,7 +49,7 @@ def llm(request): llm_instance = LLM( model=model_path, tensor_parallel_size=1, - data_parallel_size=8, + data_parallel_size=2, max_model_len=8192, num_gpu_blocks_override=1024, engine_worker_queue_port=FD_ENGINE_QUEUE_PORTS[port_index], @@ -58,7 +58,7 @@ def llm(request): ) yield weakref.proxy(llm_instance) except Exception as e: - pytest.skip(f"LLM initialization failed: {e}") + assert False, f"LLM initialization failed: {e}" @pytest.mark.timeout(60)