diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index b86060dcf..9b0192ec8 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -749,7 +749,6 @@ class LoadChoices(str, Enum):
     """LoadChoices"""

     DEFAULT = "default"
-    # only support qwen3-bf16 now
     DEFAULT_V1 = "default_v1"
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index b53302093..5bbce2982 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -22,6 +22,7 @@ import fastdeploy
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
+from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
 from fastdeploy.utils import ceil_div

 from .fused_moe_backend_base import MoEMethodBase
@@ -36,64 +37,170 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         """
         deepgemm create weight process.
         """
-        self.weight_dtype = paddle.float8_e4m3fn
-        up_gate_proj_weight_name = self.added_weight_attrs[0]
-        down_proj_weight_name = self.added_weight_attrs[1]
-        self.ffn1_weight_shape = [
+        self.up_gate_proj_weight_shape = [
             layer.num_local_experts,
             layer.moe_intermediate_size * 2,
             layer.hidden_size,
         ]
-        self.ffn2_weight_shape = [
+        self.down_proj_weight_shape = [
             layer.num_local_experts,
             layer.hidden_size,
             layer.moe_intermediate_size,
         ]
-        setattr(
-            layer,
-            up_gate_proj_weight_name,
-            layer.create_parameter(
-                shape=self.ffn1_weight_shape,
-                dtype=self.weight_dtype,
+        self.up_gate_proj_scale_shape = [
+            layer.num_local_experts,
+            layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
+            layer.hidden_size // self.quant_config.weight_block_size[1],
+        ]
+        self.down_proj_scale_shape = [
+            layer.num_local_experts,
+            layer.hidden_size // self.quant_config.weight_block_size[0],
+            layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
+        ]
+        if self.quant_config.is_checkpoint_bf16:
+            layer.up_gate_proj_weight = layer.create_parameter(
+                shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
+                dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            down_proj_weight_name,
-            layer.create_parameter(
-                shape=self.ffn2_weight_shape,
-                dtype=self.weight_dtype,
+            )
+
+            layer.down_proj_weight = layer.create_parameter(
+                shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size],
+                dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        # weight_scale
+            )
+            set_weight_attrs(
+                layer.up_gate_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
+                },
+            )
+            set_weight_attrs(
+                layer.down_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
+                },
+            )
+        else:
+            self.weight_dtype = paddle.float8_e4m3fn
+            self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
+            up_gate_proj_weight_name = self.added_weight_attrs[0]
+            down_proj_weight_name = self.added_weight_attrs[1]
+            up_gate_proj_scale_name = self.added_scale_attrs[0]
+            down_proj_scale_name = self.added_scale_attrs[1]
+            setattr(
+                layer,
+                up_gate_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.down_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            # weight_scale
+            setattr(
+                layer,
+                up_gate_proj_scale_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_scale_shape,
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_scale_name,
+                layer.create_parameter(
+                    shape=self.down_proj_scale_shape,
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+
+    def process_weights_after_loading(self, layer):
+        """Quantize the loaded BF16 expert weights to block-wise FP8."""
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+        weight_id_map = {"gate_up": 0, "down": 1}
+        if (
+            hasattr(layer.up_gate_proj_weight, "tensor_track")
+            and layer.up_gate_proj_weight.tensor_track is not None
+            and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
+        ):
+            weight_type = "gate_up"
+            layer.up_gate_proj_weight.tensor_track = None
+        else:
+            weight_type = "down"
+            layer.down_proj_weight.tensor_track = None
+
+        # 1. init names, shapes and dtypes
+        self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
+        # weight
+        weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
+        unquantized_weight_name = weight_name.replace("quant_weight", "weight")
+        weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
+        weight_dtype = paddle.float8_e4m3fn
+        # scale
+        scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
+        scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
+        scale_dtype = "float32"
+
+        # 2. create tmp tensors in transposed layout ([E, K, N])
+        weight = paddle.empty(shape=[weight_shape[0], weight_shape[2], weight_shape[1]], dtype=weight_dtype)
+        scale = paddle.empty(shape=[scale_shape[0], scale_shape[2], scale_shape[1]], dtype=scale_dtype)
+
+        # 3. quantize the weight per expert
+        from fastdeploy.model_executor.layers.utils import per_block_cast_to_fp8
+
+        for expert_id in range(layer.num_experts):
+            weight_quant, scale[expert_id] = per_block_cast_to_fp8(
+                getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
+            )
+            weight[expert_id].copy_(weight_quant, False)
+        getattr(layer, unquantized_weight_name).value().get_tensor()._clear()
+
+        # create weight
         setattr(
             layer,
-            self.added_scale_attrs[0],
+            weight_name,
             layer.create_parameter(
-                shape=[
-                    layer.num_local_experts,
-                    ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
-                    ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
-                ],
-                dtype="float32",
+                shape=weight_shape,
+                dtype=weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             ),
         )
+        # create scale
         setattr(
             layer,
-            self.added_scale_attrs[1],
+            scale_name,
             layer.create_parameter(
-                shape=[
-                    layer.num_local_experts,
-                    ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
-                    ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
-                ],
-                dtype="float32",
+                shape=scale_shape,
+                dtype=scale_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             ),
         )
+        getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]).contiguous(), False)
+        getattr(layer, scale_name).copy_(scale.transpose([0, 2, 1]).contiguous(), False)

     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
@@ -244,12 +351,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):

         # up_gate_proj
         ffn_out = paddle.empty(
-            (permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
+            (permute_input.shape[0], getattr(layer, self.added_weight_attrs[0]).shape[1]),
             dtype=paddle.bfloat16,
         )
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (permute_input, permute_scale),
-            (layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
+            (getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_scale_attrs[0])),
             ffn_out,
             m_indices,
         )
@@ -264,12 +371,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])

         ffn_out = paddle.empty(
-            (ffn_out.shape[0], layer.down_proj_weight.shape[1]),
+            (ffn_out.shape[0], getattr(layer, self.added_weight_attrs[1]).shape[1]),
             dtype=paddle.bfloat16,
         )
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (ffn_in_x, ffn_in_x_scale_tensor),
-            (layer.down_proj_weight, layer.down_proj_weight_scale),
+            (getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_scale_attrs[1])),
             ffn_out,
             m_indices,
         )
@@ -331,8 +438,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
             permute_input,
             (
-                layer.up_gate_proj_weight,
-                layer.up_gate_proj_weight_scale,
+                getattr(layer, self.added_weight_attrs[0]),
+                getattr(layer, self.added_scale_attrs[0]),
             ),
             up_gate_proj_out,
             token_nums_per_expert,
@@ -350,8 +457,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
             (act_out_fp8, scale),
             (
-                layer.down_proj_weight,
-                layer.down_proj_weight_scale,
+                getattr(layer, self.added_weight_attrs[1]),
+                getattr(layer, self.added_scale_attrs[1]),
             ),
             ffn_out,
             token_nums_per_expert,
@@ -423,12 +530,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):

         # up_gate_proj
         ffn_out = paddle.empty(
-            (permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
+            (permute_input.shape[0], getattr(layer, self.added_weight_attrs[0]).shape[1]),
             dtype=paddle.bfloat16,
         )
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (permute_input, permute_scale),
-            (layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
+            (getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_scale_attrs[0])),
             ffn_out,
             m_indices,
         )
@@ -444,12 +551,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])

         ffn_out = paddle.empty(
-            (ffn_out.shape[0], layer.down_proj_weight.shape[1]),
+            (ffn_out.shape[0], getattr(layer, self.added_weight_attrs[1]).shape[1]),
             dtype=paddle.bfloat16,
         )
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (ffn_in_x, ffn_in_x_scale_tensor),
-            (layer.down_proj_weight, layer.down_proj_weight_scale),
+            (getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_scale_attrs[1])),
             ffn_out,
             m_indices,
         )
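
Note on the quantization step above: for each expert, process_weights_after_loading quantizes the transposed BF16 matrix in 128x128 blocks, then stores the result back in [E, N, K] layout next to an [E, N//128, K//128] float32 scale tensor that deep_gemm consumes alongside the FP8 weight. A minimal sketch of the block-wise cast, assuming dimensions divisible by the block size (the production helper is per_block_cast_to_fp8 in fastdeploy/model_executor/layers/utils.py, which also handles padding):

import paddle

def per_block_cast_sketch(w, block=128):
    """Illustrative block-wise absmax quantization to float8_e4m3fn (not the FastDeploy kernel)."""
    n, k = w.shape  # assumes n % block == 0 and k % block == 0
    tiles = w.astype("float32").reshape([n // block, block, k // block, block])
    absmax = tiles.abs().max(axis=[1, 3], keepdim=True).clip(min=1e-4)
    scale = absmax / 448.0  # 448 == quant_max_bound for e4m3
    q = (tiles / scale).reshape([n, k]).astype(paddle.float8_e4m3fn)
    return q, scale.reshape([n // block, k // block])

Dequantization is q * scale per block, so the (FP8 weight, float32 scale) pair can be handed to the grouped DeepGEMM kernels unchanged.
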
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index 5526df882..7d99b5d9f 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -20,6 +20,7 @@ from paddle import nn
 import fastdeploy
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import get_tensor
+from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
 from fastdeploy.utils import ceil_div

 from ..quantization.quant_base import QuantMethodBase
@@ -604,64 +605,170 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         """
         Triton MoE create weight process.
         """
-        self.weight_dtype = paddle.float8_e4m3fn
-        up_gate_proj_weight_name = self.added_weight_attrs[0]
-        down_proj_weight_name = self.added_weight_attrs[1]
-        self.ffn1_weight_shape = [
+        self.up_gate_proj_weight_shape = [
             layer.num_local_experts,
             layer.moe_intermediate_size * 2,
             layer.hidden_size,
         ]
-        self.ffn2_weight_shape = [
+        self.down_proj_weight_shape = [
             layer.num_local_experts,
             layer.hidden_size,
             layer.moe_intermediate_size,
         ]
-        setattr(
-            layer,
-            up_gate_proj_weight_name,
-            layer.create_parameter(
-                shape=self.ffn1_weight_shape,
-                dtype=self.weight_dtype,
+        self.up_gate_proj_scale_shape = [
+            layer.num_local_experts,
+            layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
+            layer.hidden_size // self.quant_config.weight_block_size[1],
+        ]
+        self.down_proj_scale_shape = [
+            layer.num_local_experts,
+            layer.hidden_size // self.quant_config.weight_block_size[0],
+            layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
+        ]
+        if self.quant_config.is_checkpoint_bf16:
+            layer.up_gate_proj_weight = layer.create_parameter(
+                shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
+                dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            down_proj_weight_name,
-            layer.create_parameter(
-                shape=self.ffn2_weight_shape,
-                dtype=self.weight_dtype,
+            )
+
+            layer.down_proj_weight = layer.create_parameter(
+                shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size],
+                dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        # weight_scale
+            )
+            set_weight_attrs(
+                layer.up_gate_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
+                },
+            )
+            set_weight_attrs(
+                layer.down_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
+                },
+            )
+        else:
+            self.weight_dtype = paddle.float8_e4m3fn
+            self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
+            up_gate_proj_weight_name = self.added_weight_attrs[0]
+            down_proj_weight_name = self.added_weight_attrs[1]
+            up_gate_proj_scale_name = self.added_scale_attrs[0]
+            down_proj_scale_name = self.added_scale_attrs[1]
+            setattr(
+                layer,
+                up_gate_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.down_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            # weight_scale
+            setattr(
+                layer,
+                up_gate_proj_scale_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_scale_shape,
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_scale_name,
+                layer.create_parameter(
+                    shape=self.down_proj_scale_shape,
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+
+    def process_weights_after_loading(self, layer):
+        """Quantize the loaded BF16 expert weights to block-wise FP8."""
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+        weight_id_map = {"gate_up": 0, "down": 1}
+        if (
+            hasattr(layer.up_gate_proj_weight, "tensor_track")
+            and layer.up_gate_proj_weight.tensor_track is not None
+            and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
+        ):
+            weight_type = "gate_up"
+            layer.up_gate_proj_weight.tensor_track = None
+        else:
+            weight_type = "down"
+            layer.down_proj_weight.tensor_track = None
+
+        # 1. init names, shapes and dtypes
+        self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
+        # weight
+        weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
+        unquantized_weight_name = weight_name.replace("quant_weight", "weight")
+        weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
+        weight_dtype = paddle.float8_e4m3fn
+        # scale
+        scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
+        scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
+        scale_dtype = "float32"
+
+        # 2. create tmp tensors in transposed layout ([E, K, N])
+        weight = paddle.empty(shape=[weight_shape[0], weight_shape[2], weight_shape[1]], dtype=weight_dtype)
+        scale = paddle.empty(shape=[scale_shape[0], scale_shape[2], scale_shape[1]], dtype=scale_dtype)
+
+        # 3. quantize the weight per expert
+        from fastdeploy.model_executor.layers.utils import per_block_cast_to_fp8
+
+        for expert_id in range(layer.num_experts):
+            weight_quant, scale[expert_id] = per_block_cast_to_fp8(
+                getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
+            )
+            weight[expert_id].copy_(weight_quant, False)
+        getattr(layer, unquantized_weight_name).value().get_tensor()._clear()
+
+        # create weight
         setattr(
             layer,
-            self.added_scale_attrs[0],
+            weight_name,
             layer.create_parameter(
-                shape=[
-                    layer.num_local_experts,
-                    ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
-                    ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
-                ],
-                dtype="float32",
+                shape=weight_shape,
+                dtype=weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             ),
         )
+        # create scale
         setattr(
             layer,
-            self.added_scale_attrs[1],
+            scale_name,
             layer.create_parameter(
-                shape=[
-                    layer.num_local_experts,
-                    ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
-                    ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
-                ],
-                dtype="float32",
+                shape=scale_shape,
+                dtype=scale_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             ),
         )
+        getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]).contiguous(), False)
+        getattr(layer, scale_name).copy_(scale.transpose([0, 2, 1]).contiguous(), False)

     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
@@ -719,8 +826,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         num_local_experts = layer.num_local_experts
         moe_intermediate_size = layer.moe_intermediate_size
         hidden_size = layer.hidden_size
-        E, N1, _ = layer.up_gate_proj_weight.shape
-        N2 = layer.down_proj_weight.shape[1]
+        E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape
+        N2 = getattr(layer, self.added_weight_attrs[1]).shape[1]

         topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
             gate_out,
@@ -759,10 +866,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):

         fused_moe_kernel_paddle[grid](
             x_q,
-            layer.up_gate_proj_weight,
+            getattr(layer, self.added_weight_attrs[0]),
             intermediate_cache1,
             x_scale,
-            layer.up_gate_proj_weight_scale,
+            getattr(layer, self.added_scale_attrs[0]),
             None,
             sorted_token_ids,
             expert_ids,
@@ -773,17 +880,17 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             K=hidden_size,
             stride_am=x_q.strides[0],
             stride_ak=x_q.strides[1],
-            stride_be=layer.up_gate_proj_weight.strides[0],
-            stride_bk=layer.up_gate_proj_weight.strides[2],
-            stride_bn=layer.up_gate_proj_weight.strides[1],
+            stride_be=getattr(layer, self.added_weight_attrs[0]).strides[0],
+            stride_bk=getattr(layer, self.added_weight_attrs[0]).strides[2],
+            stride_bn=getattr(layer, self.added_weight_attrs[0]).strides[1],
             stride_cm=intermediate_cache1.strides[0],
             stride_cn=intermediate_cache1.strides[1],
             #
             stride_asm=x_scale.strides[0],  # only used in blockwise fp8
             stride_ask=x_scale.strides[1],  # only used in blockwise fp8
-            stride_bse=layer.up_gate_proj_weight_scale.strides[0],
-            stride_bsk=layer.up_gate_proj_weight_scale.strides[2],
-            stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
+            stride_bse=getattr(layer, self.added_scale_attrs[0]).strides[0],
+            stride_bsk=getattr(layer, self.added_scale_attrs[0]).strides[2],
+            stride_bsn=getattr(layer, self.added_scale_attrs[0]).strides[1],
             group_n=self.quant_config.weight_block_size[1],
             group_k=self.quant_config.weight_block_size[0],
             # Meta-parameters
@@ -813,10 +920,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):

         fused_moe_kernel_paddle[grid](
             x_q,
-            layer.down_proj_weight,
+            getattr(layer, self.added_weight_attrs[1]),
             intermediate_cache3,
             x_scale,
-            layer.down_proj_weight_scale,
+            getattr(layer, self.added_scale_attrs[1]),
             topk_weights,
             sorted_token_ids,
             expert_ids,
@@ -827,16 +934,16 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             K=moe_intermediate_size,
             stride_am=x_q.strides[0],
             stride_ak=x_q.strides[1],
-            stride_be=layer.down_proj_weight.strides[0],
-            stride_bk=layer.down_proj_weight.strides[2],
-            stride_bn=layer.down_proj_weight.strides[1],
+            stride_be=getattr(layer, self.added_weight_attrs[1]).strides[0],
+            stride_bk=getattr(layer, self.added_weight_attrs[1]).strides[2],
+            stride_bn=getattr(layer, self.added_weight_attrs[1]).strides[1],
             stride_cm=intermediate_cache3.strides[0],
             stride_cn=intermediate_cache3.strides[1],
             stride_asm=x_scale.strides[0],  # only used in blockwise fp8
             stride_ask=x_scale.strides[1],  # only used in blockwise fp8
-            stride_bse=layer.down_proj_weight_scale.strides[0],
-            stride_bsk=layer.down_proj_weight_scale.strides[2],
-            stride_bsn=layer.down_proj_weight_scale.strides[1],
+            stride_bse=getattr(layer, self.added_scale_attrs[1]).strides[0],
+            stride_bsk=getattr(layer, self.added_scale_attrs[1]).strides[2],
+            stride_bsn=getattr(layer, self.added_scale_attrs[1]).strides[1],
             group_n=self.quant_config.weight_block_size[1],
             group_k=self.quant_config.weight_block_size[0],
             # Meta-parameters
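
The gate_up/down dispatch in process_weights_after_loading leans on TensorTracker: the v1 loader copies checkpoint shards into the staged BF16 parameter, and is_fully_copied() tells the backend which of the two fused weights is complete and safe to quantize. The real class lives in fastdeploy/model_executor/utils.py; a rough sketch of the contract both MoE backends assume (granularity and field names here are illustrative, not the actual implementation):

import numpy as np

class TensorTrackerSketch:
    """Marks regions of a staged weight as written; full coverage means ready to quantize."""

    def __init__(self, shape, output_dim):
        self.shape = shape
        self.output_dim = output_dim  # axis along which shards arrive
        self.copied = np.zeros(shape[0], dtype=bool)  # e.g. one flag per expert

    def mark(self, index):
        self.copied[index] = True

    def is_fully_copied(self):
        return bool(self.copied.all())
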
""" - def __init__(self, weight_block_size: list = [-1, -1]) -> None: + def __init__(self, weight_block_size: list = [-1, -1], is_checkpoint_bf16: bool = False) -> None: super().__init__() self.weight_block_size = weight_block_size self.quant_max_bound = 448 self.quant_min_bound = -448 self.quant_round_type = 1 self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM) + self.is_checkpoint_bf16 = is_checkpoint_bf16 def name(self) -> str: return "block_wise_fp8" @@ -47,7 +53,8 @@ class BlockWiseFP8Config(QuantConfigBase): @classmethod def from_config(cls, config: dict) -> "BlockWiseFP8Config": weight_block_size = config.get("weight_block_size", [128, 128]) - return cls(weight_block_size) + is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False) + return cls(weight_block_size, is_checkpoint_bf16) def get_quant_method(self, layer) -> Optional[QuantMethodBase]: """ @@ -82,31 +89,78 @@ class BlockWiseFP8LinearMethod(QuantMethodBase): self.quant_config = quant_config def create_weights(self, layer, **extra_weight_attrs): - layer.weight_shape.reverse() - layer.weight_dtype = "float8_e4m3fn" + if self.quant_config.is_checkpoint_bf16: + layer.weight = layer.create_parameter( + shape=layer.weight_shape, + dtype=layer.weight_dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + quant_attrs = extra_weight_attrs + if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear): + quant_attrs = { + **extra_weight_attrs, + "tensor_track": TensorTracker( + shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim") + ), + } + set_weight_attrs( + layer.weight, + quant_attrs, + ) + else: + layer.weight_shape.reverse() + layer.weight_dtype = "float8_e4m3fn" + layer.weight = layer.create_parameter( + shape=layer.weight_shape, + dtype=layer.weight_dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + layer.weight_scale_inv = layer.create_parameter( + shape=[ + (layer.output_size + self.quant_config.weight_block_size[0] - 1) + // self.quant_config.weight_block_size[0], + (layer.input_size + self.quant_config.weight_block_size[1] - 1) + // self.quant_config.weight_block_size[1], + ], + dtype="float32", + is_bias=False, + ) + + def process_weights_after_loading(self, layer) -> None: + if not self.quant_config.is_checkpoint_bf16: + return + weight_tensor = layer.weight.transpose([1, 0]) + quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor) + + if hasattr(layer.weight, "tensor_track"): + layer.weight.tensor_track = None + layer.weight.value().get_tensor()._clear() + del layer.weight + layer.weight = layer.create_parameter( - shape=layer.weight_shape, - dtype=layer.weight_dtype, + shape=quanted_weight_tensor.shape, + dtype="float8_e4m3fn", + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.weight_scale_inv = layer.create_parameter( + shape=weight_block_scale_tensor.shape, + dtype="float32", is_bias=False, default_initializer=paddle.nn.initializer.Constant(0), ) - layer.weight_scale = layer.create_parameter( - shape=[ - (layer.output_size + self.quant_config.weight_block_size[0] - 1) - // self.quant_config.weight_block_size[0], - (layer.input_size + self.quant_config.weight_block_size[1] - 1) - // self.quant_config.weight_block_size[1], - ], - dtype="float32", - is_bias=False, - ) + layer.weight.copy_(quanted_weight_tensor, False) + layer.weight_scale_inv.copy_(weight_block_scale_tensor, False) def process_loaded_weights(self, layer, weights) -> 
diff --git a/fastdeploy/model_executor/layers/quantization/mix_quant.py b/fastdeploy/model_executor/layers/quantization/mix_quant.py
index f9c3a42f8..05c456d55 100644
--- a/fastdeploy/model_executor/layers/quantization/mix_quant.py
+++ b/fastdeploy/model_executor/layers/quantization/mix_quant.py
@@ -37,6 +37,7 @@ class MixQuantConfig(QuantConfigBase):
         is_channel_wise: bool = False,
         has_zero_point: bool = False,
         is_permuted: bool = True,
+        is_checkpoint_bf16: bool = False,
     ) -> None:
         super().__init__()
         self.dense_quant_type = dense_quant_type
@@ -52,6 +53,7 @@ class MixQuantConfig(QuantConfigBase):
         self.quant_min_bound = 0
         self.quant_round_type = 0
         self.is_permuted = is_permuted
+        self.is_checkpoint_bf16 = is_checkpoint_bf16

     def name(self) -> str:
         return "mix_quant"
@@ -66,6 +68,7 @@ class MixQuantConfig(QuantConfigBase):
             config.get("is_channel_wise", False),
             config.get("has_zero_point", False),
             config.get("is_permuted", True),
+            config.get("is_checkpoint_bf16", False),
         )

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -73,13 +76,13 @@ class MixQuantConfig(QuantConfigBase):
             if layer.moe_tag == "Image":
                 return (
                     get_quantization_config(self.image_moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted})
+                    .from_config({"is_permuted": self.is_permuted, "is_checkpoint_bf16": self.is_checkpoint_bf16})
                     .get_quant_method(layer)
                 )
             else:
                 return (
                     get_quantization_config(self.moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted})
+                    .from_config({"is_permuted": self.is_permuted, "is_checkpoint_bf16": self.is_checkpoint_bf16})
                     .get_quant_method(layer)
                 )
         elif isinstance(layer, Attention):
@@ -92,4 +95,8 @@ class MixQuantConfig(QuantConfigBase):
             else:
                 return None
         else:
-            return get_quantization_config(self.dense_quant_type).from_config({}).get_quant_method(layer)
+            return (
+                get_quantization_config(self.dense_quant_type)
+                .from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
+                .get_quant_method(layer)
+            )
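
MixQuantConfig forwards the flag to the nested quant configs through from_config, so the dict key must be the bare attribute name those configs read (config.get("is_checkpoint_bf16", ...)). A quick round-trip of the intended flow:

from fastdeploy.model_executor.layers.quantization.block_wise_fp8 import BlockWiseFP8Config

cfg = BlockWiseFP8Config.from_config({"is_permuted": True, "is_checkpoint_bf16": True})
assert cfg.is_checkpoint_bf16 is True
assert cfg.weight_block_size == [128, 128]  # the from_config default
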
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 71c314841..b63cca841 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -660,6 +660,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         quantization_config = {}
         quant_config_name = args.quantization
         quantization_config["quantization"] = quant_config_name
+        # Only the v1 loader sets is_checkpoint_bf16=True, which enables dynamic (load-time) quantization.
+        if load_config.load_choices == "default_v1":
+            quantization_config["is_checkpoint_bf16"] = True
     # Special handling for Ernie models
     is_ernie = ErnieArchitectures.contains_ernie_arch(model_config.architectures)
     if quant_config_name == "wint4" and is_ernie:
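
For context, launching with --quantization block_wise_fp8 and --load-choices default_v1 now yields a quantization config roughly equivalent to:

quantization_config = {
    "quantization": "block_wise_fp8",
    "is_checkpoint_bf16": True,  # v1 loader: BF16 checkpoint, quantized at load time
}
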
diff --git a/tests/model_loader/test_common_model.py b/tests/model_loader/test_common_model.py
index 3b3f4fbf7..b7c918411 100644
--- a/tests/model_loader/test_common_model.py
+++ b/tests/model_loader/test_common_model.py
@@ -13,12 +13,14 @@
 # limitations under the License.

 import os
+import shutil
 import traceback
 import warnings
 from multiprocessing import Process, Queue

 import pytest

+os.environ["LOAD_STATE_DICT_THREAD_NUM"] = "1"
 FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
 MAX_WAIT_SECONDS = 60 * 5

@@ -46,6 +48,33 @@ def get_model_paths(base_model_name: str) -> tuple[str, str]:
     return fd_model_path, torch_model_path


+def clear_logs():
+    log_path = os.path.join(os.getcwd(), "log")
+    if os.path.exists(log_path):
+        try:
+            shutil.rmtree(log_path)
+            print(f"Deleted log directory: {log_path}")
+        except Exception as e:
+            print(f"Failed to delete log directory {log_path}: {e}")
+    else:
+        print(f"No log directory found at {log_path}")
+
+
+def print_logs():
+    log_dir = os.path.join(os.getcwd(), "log")
+    log_file = os.path.join(log_dir, "workerlog.0")
+
+    if not os.path.exists(log_file):
+        print(f"Log file {log_file} does not exist.")
+        return
+
+    print(f"\n===== {log_file} start =====")
+    with open(log_file, "r") as f:
+        for line in f:
+            print(line, end="")
+    print(f"\n===== {log_file} end =====\n")
+
+
 def check_tokens_id_and_text_close(
     *,
     outputs_0_lst: TokensIdText,
@@ -110,37 +139,72 @@ def form_model_get_output(
         pytest.fail(f"Failed to initialize LLM model from {model_path}")


+def run_with_timeout(target, args, timeout=60 * 5):
+    clear_logs()
+    result_queue = Queue()
+    p = Process(target=target, args=(*args, result_queue))
+    p.start()
+    p.join(timeout)
+    if p.is_alive():
+        p.terminate()
+        print_logs()
+        raise RuntimeError("Worker process hung and was terminated")
+    try:
+        return result_queue.get(timeout=60)
+    except Exception as e:
+        raise RuntimeError(f"Failed to get result from worker: {e}")
+
+
 model_param_map = {
     "Qwen3-0.6B": {
         "quantizations": ["None", "wint4", "wint8"],
     },
     "ernie-4_5-21b-a3b-bf16-paddle": {
         "tensor_parallel_size": 2,
-        "quantizations": ["wint8"],
+        "quantizations": [
+            "wint8",
+        ],
     },
     "Qwen2-7B-Instruct": {
         "quantizations": ["None", "wint8"],
     },
+    "Qwen3-30B-A3B": {
+        "tensor_parallel_size": 2,
+        "quantizations": [
+            {
+                "quant_type": "block_wise_fp8",
+                "backend": "triton",
+                "env": {"FD_USE_DEEP_GEMM": "0", "DG_NVCC_OVERRIDE_CPP_STANDARD": "17"},
+            },
+            {"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"}},
+        ],
+    },
 }

 params = []
 for model, cfg in model_param_map.items():
     for q in cfg["quantizations"]:
+        if isinstance(q, dict):
+            quant, backend, env = q["quant_type"], q.get("backend", "default"), q.get("env", {})
+        else:
+            quant, backend, env = q, "default", {}
         params.append(
             pytest.param(
                 model,
                 cfg.get("tensor_parallel_size", 1),
                 cfg.get("max_model_len", 1024),
-                q,
+                quant,
                 cfg.get("max_tokens", 32),
+                env,
                 marks=[pytest.mark.core_model],
+                id=f"{model}.{quant}.{backend}",
             )
         )


 @pytest.mark.parametrize(
-    "model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens",
+    "model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens,env",
     params,
 )
@@ -150,46 +214,26 @@ def test_common_model(
     max_model_len: int,
     max_tokens: int,
     quantization: str,
+    env,
+    monkeypatch,
 ) -> None:
     base_path = os.getenv("MODEL_PATH")
     if base_path:
         model_path = os.path.join(base_path, model_name_or_path)
     else:
         model_path = model_name_or_path
-    result_queue = Queue()
-    p = Process(
-        target=form_model_get_output,
-        args=(
-            fd_runner,
-            model_path,
-            tensor_parallel_size,
-            max_model_len,
-            max_tokens,
-            quantization,
-            "default",
-            result_queue,
-        ),
-    )
-    p.start()
-    p.join()
-    fd_outputs_v0 = result_queue.get(timeout=60)
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)

-    p = Process(
+    fd_outputs_v0 = run_with_timeout(
         target=form_model_get_output,
-        args=(
-            fd_runner,
-            model_path,
-            tensor_parallel_size,
-            max_model_len,
-            max_tokens,
-            quantization,
-            "default_v1",
-            result_queue,
-        ),
+        args=(fd_runner, model_path, tensor_parallel_size, max_model_len, max_tokens, quantization, "default"),
+    )
+    fd_outputs_v1 = run_with_timeout(
+        target=form_model_get_output,
+        args=(fd_runner, model_path, tensor_parallel_size, max_model_len, max_tokens, quantization, "default_v1"),
     )
-    p.start()
-    p.join()
-    fd_outputs_v1 = result_queue.get(timeout=60)
     check_tokens_id_and_text_close(
         outputs_0_lst=fd_outputs_v0,
         outputs_1_lst=fd_outputs_v1,
@@ -235,26 +279,12 @@ def test_paddle_vs_torch_model(

     fd_model_path, torch_model_path = get_model_paths(model_name_or_path)

-    result_queue = Queue()
-
-    p_paddle = Process(
+    paddle_outputs = run_with_timeout(
         target=form_model_get_output,
-        args=(
-            fd_runner,
-            fd_model_path,
-            tensor_parallel_size,
-            max_model_len,
-            max_tokens,
-            quantization,
-            "default",
-            result_queue,
-        ),
+        args=(fd_runner, fd_model_path, tensor_parallel_size, max_model_len, max_tokens, quantization, "default"),
     )
-    p_paddle.start()
-    p_paddle.join()
-    paddle_outputs = result_queue.get(timeout=60)

-    p_hf = Process(
+    hf_outputs = run_with_timeout(
         target=form_model_get_output,
         args=(
             fd_runner,
@@ -264,12 +294,8 @@
             max_tokens,
             quantization,
             "default_v1",
-            result_queue,
         ),
     )
-    p_hf.start()
-    p_hf.join()
-    hf_outputs = result_queue.get(timeout=60)

     check_tokens_id_and_text_close(
         outputs_0_lst=paddle_outputs,
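
run_with_timeout centralizes the fork/join/result-fetch pattern the tests previously repeated inline, and it dumps the worker log when a run hangs instead of blocking CI indefinitely. A call site now reduces to:

outputs = run_with_timeout(
    target=form_model_get_output,
    args=(fd_runner, model_path, tensor_parallel_size, max_model_len, max_tokens, quantization, "default_v1"),
)
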