load hadamard_block_size from config (#3797)
@@ -255,7 +255,8 @@ paddle::Tensor MoeExpertFFNFunc(
     const paddle::optional<paddle::Tensor>& down_proj_in_scale,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
     const std::string& quant_method, const bool used_in_ep_low_latency,
-    const int estimate_total_token_nums);
+    const int estimate_total_token_nums,
+    const int hadamard_block_size);
 
 paddle::Tensor MoeExpertFFNWint2Func(
     const paddle::Tensor& permute_input,
@@ -872,16 +872,14 @@ void MoeFastHardamardWrapper(const T *x_data,
                              const int64_t dim,
                              const int num_max_tokens_per_expert,
                              bool used_in_ep_low_latency,
+                             const int hadamard_block_size,
                              OutT* out,
                              cudaStream_t &stream) {
   bool FLAGS_hardamard_use_diagonal_block_matrix = true;
 
-  static const char* FLAGS_hardamard_moe_block_size = std::getenv("FLAGS_hardamard_moe_block_size");
-  static const int32_t hardamard_moe_block_size = FLAGS_hardamard_moe_block_size != nullptr ?
-      stoi(std::string(FLAGS_hardamard_moe_block_size)) : 512;
   constexpr int kThreads = 128;
   if (FLAGS_hardamard_use_diagonal_block_matrix) {
-    const int VecSize = hardamard_moe_block_size / kThreads;  // 128 / 128 = 1
+    const int VecSize = hadamard_block_size / kThreads;
     const int logN = int(ceil(std::log2(kThreads * VecSize)));
     constexpr int kNChunks = 1;
     DISPATCH_SP_VS(VecSize, VEC_SIZE, {
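The kernel's per-thread vector width and butterfly depth now follow directly from the hadamard_block_size argument instead of the old environment variable. A minimal Python sketch of that dispatch arithmetic, assuming the fixed kThreads = 128 from the source (illustrative only; the real kernel dispatches through DISPATCH_SP_VS):

    import math

    K_THREADS = 128  # fixed CUDA thread-block size in the kernel above

    def dispatch_params(hadamard_block_size: int) -> tuple[int, int]:
        # VecSize: elements each thread handles per Hadamard block
        vec_size = hadamard_block_size // K_THREADS
        # logN: number of butterfly stages, ceil(log2(kThreads * VecSize))
        log_n = math.ceil(math.log2(K_THREADS * vec_size))
        return vec_size, log_n

    assert dispatch_params(128) == (1, 7)
    assert dispatch_params(512) == (4, 9)  # 512 was the old env-var fallback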
@@ -991,6 +989,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, phi::dtype::float16>(
     const int64_t dim,
     const int num_max_tokens_per_expert,
     bool used_in_ep_low_latency,
+    const int hadamard_block_size,
     phi::dtype::float16 *out,
     cudaStream_t &stream
 );
@@ -1009,6 +1008,7 @@ template void MoeFastHardamardWrapper<phi::dtype::float16, int8_t>(
     const int64_t dim,
     const int num_max_tokens_per_expert,
     bool used_in_ep_low_latency,
+    const int hadamard_block_size,
     int8_t *out,
     cudaStream_t &stream
 );
@@ -1027,6 +1027,7 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::bfloat16>(
     const int64_t dim,
     const int num_max_tokens_per_expert,
     bool used_in_ep_low_latency,
+    const int hadamard_block_size,
     phi::dtype::bfloat16 *out,
     cudaStream_t &stream
 );
@@ -1045,6 +1046,7 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
     const int64_t dim,
     const int num_max_tokens_per_expert,
     bool used_in_ep_low_latency,
+    const int hadamard_block_size,
     int8_t *out,
     cudaStream_t &stream
 );
@@ -32,5 +32,6 @@ void MoeFastHardamardWrapper(const T *x_data,
                              const int64_t dim,
                              const int num_max_tokens_per_expert,
                              bool used_in_ep_low_latency,
+                             const int hadamard_block_size,
                              OutT* out,
                              cudaStream_t &stream);
@@ -35,7 +35,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
                   const std::string& quant_method,
                   paddle::Tensor ffn_out,
                   bool used_in_ep_low_latency,
-                  const int estimate_total_token_nums) {
+                  const int estimate_total_token_nums,
+                  const int hadamard_block_size) {
   using namespace phi;
   typedef PDTraits<T> traits_;
   typedef typename traits_::DataType DataType_;
@@ -291,6 +292,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
           inter_size / 2,
           num_max_tokens_per_expert,
           used_in_ep_low_latency,
+          hadamard_block_size,
           reinterpret_cast<int8_t *>(int8_act_out->ptr()),
           stream
       );
@@ -340,6 +342,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
           inter_size / 2,
           num_max_tokens_per_expert,
           used_in_ep_low_latency,
+          hadamard_block_size,
           act_out_tensor.data<data_t>(),
           stream
       );
@@ -403,7 +406,7 @@ paddle::Tensor MoeExpertFFNFunc(
     const paddle::optional<paddle::Tensor>& down_proj_in_scale,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
     const std::string& quant_method, const bool used_in_ep_low_latency,
-    const int estimate_total_token_nums) {
+    const int estimate_total_token_nums, const int hadamard_block_size) {
 
   const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() :
                       (quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
@@ -424,7 +427,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
                 quant_method,
                 ffn_out,
                 used_in_ep_low_latency,
-                estimate_total_token_nums);
+                estimate_total_token_nums,
+                hadamard_block_size);
       break;
     case paddle::DataType::FLOAT16:
       MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
@@ -439,7 +443,8 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
                 quant_method,
                 ffn_out,
                 used_in_ep_low_latency,
-                estimate_total_token_nums);
+                estimate_total_token_nums,
+                hadamard_block_size);
       break;
     default:
       PD_THROW("Unsupported data type for MoeExpertFFN");
@@ -458,7 +463,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
     const paddle::optional<paddle::Tensor>& down_proj_in_scale,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
     const std::string& quant_method, const bool used_in_ep_low_latency,
-    const int estimate_total_token_nums) {
+    const int estimate_total_token_nums,
+    const int hadamard_block_size) {
   return {MoeExpertFFNFunc(permute_input,
                            tokens_expert_prefix_sum,
                            up_gate_proj_weight,
@@ -470,7 +476,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
                            expert_idx_per_token,
                            quant_method,
                            used_in_ep_low_latency,
-                           estimate_total_token_nums)};
+                           estimate_total_token_nums,
+                           hadamard_block_size)};
 }
 
 std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
@@ -485,7 +492,8 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
     const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
     const std::string& quant_method,
     const bool used_in_ep_low_latency,
-    const int estimate_total_token_nums) {
+    const int estimate_total_token_nums,
+    const int hadamard_block_size) {
   return {permute_input_shape};
 }
 
@@ -499,7 +507,7 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
     const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
     const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
     const std::string &quant_method, const bool used_in_ep_low_latency,
-    const int estimate_total_token_nums) {
+    const int estimate_total_token_nums, const int hadamard_block_size) {
   if (quant_method == "w4a8" || quant_method == "w4afp8") {
     return {up_gate_proj_scale_dtype.get()};
   } else {
@@ -555,6 +563,8 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
  *                 Options: "none", "weight_only_int4", "weight_only_int8", "w4a8"
  * - used_in_ep_low_latency: Whether running in low latency mode
  *                           Affects activation function implementation
+ * - estimate_total_token_nums: estimated total number of tokens
+ * - hadamard_block_size: Hadamard block size used by w4a8/w4afp8 quantization
  *
  * Note:
  * - w4a8 mode requires additional workspace memory allocation
@@ -571,7 +581,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn)
             paddle::Optional("down_proj_in_scale"),
             paddle::Optional("expert_idx_per_token")})
     .Outputs({"output_tensor"})
-    .Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int"})
+    .Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool", "estimate_total_token_nums:int", "hadamard_block_size:int"})
     .SetKernelFn(PD_KERNEL(MoeExpertFFN))
     .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
@@ -401,10 +401,6 @@ class LLMEngine:
             "FLAGS_use_append_attn": 1,
             "NCCL_ALGO": "Ring",
             "FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
-            "FLAGS_hardamard_moe_block_size": int(os.getenv("FLAGS_hardamard_moe_block_size", 128)),
-            "FLAGS_hardamard_use_diagonal_block_matrix": int(
-                os.getenv("FLAGS_hardamard_use_diagonal_block_matrix", 0)
-            ),
         }
         # environment variables needed by Dy2St
         variables.update(
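With these engine-side exports gone, the block size no longer reaches the kernel through the environment. A hedged before/after sketch for comparison (512 was the C++ fallback removed above; note it differed from the engine's former Python default of 128):

    import os

    # Old path (removed): the engine exported an env var, which the kernel
    # parsed lazily, falling back to 512 when unset.
    legacy_block_size = int(os.getenv("FLAGS_hardamard_moe_block_size", 512))

    # New path: the value rides on the layer's quantization config, so the
    # kernel receives it as an explicit argument (see the next hunk).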
@@ -127,6 +127,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
                 self.moe_quant_type,
                 used_in_ep_low_latency,
                 estimate_total_token_nums,
+                getattr(layer.moe_quant_config, "hadamard_block_size", 128),
             )
 
     def apply_ep_prefill(
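The getattr fallback keeps older quant configs working: anything without the new field behaves as before with a block size of 128. A self-contained sketch of that lookup, using a hypothetical stub in place of layer.moe_quant_config:

    class _QuantConfigStub:  # stand-in for layer.moe_quant_config, not from the repo
        def __init__(self, hadamard_block_size: int = 128) -> None:
            self.hadamard_block_size = hadamard_block_size

    cfg = _QuantConfigStub(hadamard_block_size=512)
    assert getattr(cfg, "hadamard_block_size", 128) == 512
    # configs predating this change simply lack the attribute:
    assert getattr(object(), "hadamard_block_size", 128) == 128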
@@ -38,6 +38,7 @@ class MixQuantConfig(QuantConfigBase):
         has_zero_point: bool = False,
         is_permuted: bool = True,
         is_checkpoint_bf16: bool = False,
+        hadamard_block_size: int = 128,
     ) -> None:
         super().__init__()
         self.dense_quant_type = dense_quant_type
@@ -54,6 +55,7 @@ class MixQuantConfig(QuantConfigBase):
         self.quant_round_type = 0
         self.is_permuted = is_permuted
         self.is_checkpoint_bf16 = is_checkpoint_bf16
+        self.hadamard_block_size = hadamard_block_size
 
     def name(self) -> str:
         return "mix_quant"
@@ -69,6 +71,7 @@ class MixQuantConfig(QuantConfigBase):
             config.get("has_zero_point", False),
             config.get("is_permuted", True),
             config.get("is_checkpoint_bf16", False),
+            config.get("hadamard_block_size", 128),
         )
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -76,13 +79,25 @@ class MixQuantConfig(QuantConfigBase):
             if layer.moe_tag == "Image":
                 return (
                     get_quantization_config(self.image_moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                    .from_config(
+                        {
+                            "is_permuted": self.is_permuted,
+                            "is_checkpoint_bf16": self.is_checkpoint_bf16,
+                            "hadamard_block_size": self.hadamard_block_size,
+                        }
+                    )
                     .get_quant_method(layer)
                 )
             else:
                 return (
                     get_quantization_config(self.moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                    .from_config(
+                        {
+                            "is_permuted": self.is_permuted,
+                            "is_checkpoint_bf16": self.is_checkpoint_bf16,
+                            "hadamard_block_size": self.hadamard_block_size,
+                        }
+                    )
                     .get_quant_method(layer)
                 )
         elif isinstance(layer, Attention):
@@ -25,9 +25,10 @@ class W4A8Config(QuantConfigBase):
     quantization config for weight 4bits and activation 8bits
     """
 
-    def __init__(self, is_permuted) -> None:
+    def __init__(self, is_permuted, hadamard_block_size) -> None:
         super().__init__()
         self.is_permuted = is_permuted
+        self.hadamard_block_size = hadamard_block_size
 
     def name(self) -> str:
         return "w4a8"
@@ -35,7 +36,8 @@ class W4A8Config(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "W4A8Config":
         is_permuted = config.get("is_permuted", True)
-        return cls(is_permuted)
+        hadamard_block_size = config.get("hadamard_block_size", 128)
+        return cls(is_permuted, hadamard_block_size)
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if isinstance(layer, FusedMoE):
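A usage sketch built only from the signature above: callers can steer the block size through the config dict, and omitted keys keep their defaults.

    cfg = W4A8Config.from_config({"is_permuted": True, "hadamard_block_size": 256})
    assert cfg.hadamard_block_size == 256
    assert W4A8Config.from_config({}).hadamard_block_size == 128  # default preserved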
@@ -31,7 +31,7 @@ class W4AFP8Config(QuantConfigBase):
     quantization config for weight 4bits and activation fp8
     """
 
-    def __init__(self, weight_scale_dict, act_scale_dict, is_permuted) -> None:
+    def __init__(self, weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size) -> None:
         super().__init__()
         self.weight_scale_dict = weight_scale_dict
         self.act_scale_dict = act_scale_dict
@@ -39,6 +39,7 @@ class W4AFP8Config(QuantConfigBase):
         self.quant_min_bound = -448
         self.quant_round_type = 1
         self.is_permuted = is_permuted
+        self.hadamard_block_size = hadamard_block_size
 
     def name(self) -> str:
         return "w4afp8"
@@ -48,7 +49,8 @@ class W4AFP8Config(QuantConfigBase):
         weight_scale_dict = config.get("weight_scale_dict", None)
         act_scale_dict = config.get("act_scale_dict", None)
         is_permuted = config.get("is_permuted", True)
-        return cls(weight_scale_dict, act_scale_dict, is_permuted)
+        hadamard_block_size = config.get("hadamard_block_size", 128)
+        return cls(weight_scale_dict, act_scale_dict, is_permuted, hadamard_block_size)
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if isinstance(layer, FusedMoE):
@@ -23,10 +23,10 @@ from fastdeploy.entrypoints.llm import LLM
 
 bash_path = os.getenv("MODEL_PATH")
 FD_ENGINE_QUEUE_PORTS = [
-    [9961, 9962, 9963, 9964, 9965, 9966, 9967, 9968],
-    [9971, 9972, 9973, 9974, 9975, 9976, 9977, 9978],
-    [9981, 9982, 9983, 9984, 9985, 9986, 9987, 9988],
-    [9991, 9992, 9993, 9994, 9995, 9996, 9997, 9998],
+    [9961, 9962],
+    [9971, 9972],
+    [9981, 9982],
+    [9991, 9992],
 ]
 
 
@@ -49,7 +49,7 @@ def llm(request):
         llm_instance = LLM(
             model=model_path,
             tensor_parallel_size=1,
-            data_parallel_size=8,
+            data_parallel_size=2,
             max_model_len=8192,
             num_gpu_blocks_override=1024,
             engine_worker_queue_port=FD_ENGINE_QUEUE_PORTS[port_index],
@@ -58,7 +58,7 @@ def llm(request):
         )
         yield weakref.proxy(llm_instance)
     except Exception as e:
-        pytest.skip(f"LLM initialization failed: {e}")
+        assert False, f"LLM initialization failed: {e}"
 
 
 @pytest.mark.timeout(60)