[v1 loader] support fp8 (#3593)

* support fp8

* update ci
Author: bukejiyu
Date: 2025-08-26 17:42:46 +08:00
Committed by: GitHub
Parent: 00898603c8
Commit: 3200a80de3
7 changed files with 463 additions and 160 deletions

==== File 1 of 7 ====

@@ -749,7 +749,6 @@ class LoadChoices(str, Enum):
     """LoadChoices"""
     DEFAULT = "default"
-    # only support qwen3-bf16 now
     DEFAULT_V1 = "default_v1"

==== File 2 of 7 ====

@@ -22,6 +22,7 @@ import fastdeploy
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
+from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
 from fastdeploy.utils import ceil_div
 from .fused_moe_backend_base import MoEMethodBase
@@ -36,64 +37,170 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         """
         deepgemm create weight process.
         """
-        self.weight_dtype = paddle.float8_e4m3fn
-        up_gate_proj_weight_name = self.added_weight_attrs[0]
-        down_proj_weight_name = self.added_weight_attrs[1]
-        self.ffn1_weight_shape = [
-            layer.num_local_experts,
-            layer.moe_intermediate_size * 2,
-            layer.hidden_size,
-        ]
-        self.ffn2_weight_shape = [
-            layer.num_local_experts,
-            layer.hidden_size,
-            layer.moe_intermediate_size,
-        ]
-        setattr(
-            layer,
-            up_gate_proj_weight_name,
-            layer.create_parameter(
-                shape=self.ffn1_weight_shape,
-                dtype=self.weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            down_proj_weight_name,
-            layer.create_parameter(
-                shape=self.ffn2_weight_shape,
-                dtype=self.weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        # weight_scale
-        setattr(
-            layer,
-            self.added_scale_attrs[0],
-            layer.create_parameter(
-                shape=[
-                    layer.num_local_experts,
-                    ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
-                    ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
-                ],
-                dtype="float32",
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            self.added_scale_attrs[1],
-            layer.create_parameter(
-                shape=[
-                    layer.num_local_experts,
-                    ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
-                    ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
-                ],
-                dtype="float32",
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
+        self.up_gate_proj_weight_shape = [
+            layer.num_local_experts,
+            layer.moe_intermediate_size * 2,
+            layer.hidden_size,
+        ]
+        self.down_proj_weight_shape = [
+            layer.num_local_experts,
+            layer.hidden_size,
+            layer.moe_intermediate_size,
+        ]
+        self.up_gate_proj_scale_shape = [
+            layer.num_local_experts,
+            layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
+            layer.hidden_size // self.quant_config.weight_block_size[1],
+        ]
+        self.down_proj_scale_shape = [
+            layer.num_local_experts,
+            layer.hidden_size // self.quant_config.weight_block_size[0],
+            layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
+        ]
+        if self.quant_config.is_checkpoint_bf16:
+            layer.up_gate_proj_weight = layer.create_parameter(
+                shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+
+            layer.down_proj_weight = layer.create_parameter(
+                shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size],
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            set_weight_attrs(
+                layer.up_gate_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
+                },
+            )
+            set_weight_attrs(
+                layer.down_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
+                },
+            )
+        else:
+            self.weight_dtype = paddle.float8_e4m3fn
+            self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
+            up_gate_proj_weight_name = self.added_weight_attrs[0]
+            down_proj_weight_name = self.added_weight_attrs[1]
+            up_gate_proj_scale_name = self.added_scale_attrs[0]
+            down_proj_scale_name = self.added_scale_attrs[1]
+            setattr(
+                layer,
+                up_gate_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.down_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            # weight_scale
+            setattr(
+                layer,
+                up_gate_proj_scale_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_scale_shape,
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_scale_name,
+                layer.create_parameter(
+                    shape=self.down_proj_scale_shape,
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+
+    def process_weights_after_loading(self, layer):
+        """ """
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+        weight_id_map = {"gate_up": 0, "down": 1}
+        if (
+            hasattr(layer.up_gate_proj_weight, "tensor_track")
+            and layer.up_gate_proj_weight.tensor_track is not None
+            and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
+        ):
+            weight_type = "gate_up"
+            layer.up_gate_proj_weight.tensor_track = None
+        else:
+            weight_type = "down"
+            layer.down_proj_weight.tensor_track = None
+
+        # 1. init shape and type
+        self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
+        # weight
+        weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
+        unquantized_weight_name = weight_name.replace("quant_weight", "weight")
+        weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
+        weight_dtype = paddle.float8_e4m3fn
+        # scale
+        scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
+        scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
+        scale_dtype = "float32"
+
+        # 2. create tmp tensor
+        weight = paddle.empty(shape=[weight_shape[0], weight_shape[2], weight_shape[1]], dtype=weight_dtype)
+        scale = paddle.empty(shape=[scale_shape[0], scale_shape[2], scale_shape[1]], dtype=scale_dtype)
+
+        # 3. quantize weight
+        from fastdeploy.model_executor.layers.utils import per_block_cast_to_fp8
+
+        for expert_id in range(layer.num_experts):
+            weight_quant, scale[expert_id] = per_block_cast_to_fp8(
+                getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
+            )
+            weight[expert_id].copy_(weight_quant, False)
+
+        getattr(layer, unquantized_weight_name).value().get_tensor()._clear()
+
+        # create weight
+        setattr(
+            layer,
+            weight_name,
+            layer.create_parameter(
+                shape=[
+                    layer.num_local_experts,
+                    ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
+                    ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
+                ],
+                dtype=weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        # create scale
+        setattr(
+            layer,
+            scale_name,
+            layer.create_parameter(
+                shape=[
+                    layer.num_local_experts,
+                    ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
+                    ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
+                ],
+                dtype=scale_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]).contiguous(), False)
+        getattr(layer, scale_name).copy_(scale.transpose([0, 2, 1]).contiguous(), False)

     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
@@ -244,12 +351,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         # up_gate_proj
         ffn_out = paddle.empty(
-            (permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
+            (permute_input.shape[0], getattr(layer, self.added_weight_attrs[0]).shape[1]),
             dtype=paddle.bfloat16,
         )
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (permute_input, permute_scale),
-            (layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
+            (getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_scale_attrs[0])),
             ffn_out,
             m_indices,
         )
@@ -264,12 +371,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
         ffn_out = paddle.empty(
-            (ffn_out.shape[0], layer.down_proj_weight.shape[1]),
+            (ffn_out.shape[0], getattr(layer, self.added_weight_attrs[1]).shape[1]),
             dtype=paddle.bfloat16,
         )
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (ffn_in_x, ffn_in_x_scale_tensor),
-            (layer.down_proj_weight, layer.down_proj_weight_scale),
+            (getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_scale_attrs[1])),
             ffn_out,
             m_indices,
         )
@@ -331,8 +438,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
             permute_input,
             (
-                layer.up_gate_proj_weight,
-                layer.up_gate_proj_weight_scale,
+                getattr(layer, self.added_weight_attrs[0]),
+                getattr(layer, self.added_scale_attrs[0]),
             ),
             up_gate_proj_out,
             token_nums_per_expert,
@@ -350,8 +457,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
             (act_out_fp8, scale),
             (
-                layer.down_proj_weight,
-                layer.down_proj_weight_scale,
+                getattr(layer, self.added_weight_attrs[1]),
+                getattr(layer, self.added_scale_attrs[1]),
             ),
             ffn_out,
             token_nums_per_expert,
@@ -423,12 +530,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         # up_gate_proj
         ffn_out = paddle.empty(
-            (permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
+            (permute_input.shape[0], getattr(layer, self.added_weight_attrs[0]).shape[1]),
             dtype=paddle.bfloat16,
         )
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (permute_input, permute_scale),
-            (layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
+            (getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_scale_attrs[0])),
             ffn_out,
             m_indices,
         )
@@ -444,12 +551,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
         ffn_out = paddle.empty(
-            (ffn_out.shape[0], layer.down_proj_weight.shape[1]),
+            (ffn_out.shape[0], getattr(layer, self.added_weight_attrs[1]).shape[1]),
             dtype=paddle.bfloat16,
         )
         deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (ffn_in_x, ffn_in_x_scale_tensor),
-            (layer.down_proj_weight, layer.down_proj_weight_scale),
+            (getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_scale_attrs[1])),
             ffn_out,
             m_indices,
         )
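The contiguous grouped GEMM above multiplies each permuted token row against the weight of the expert it was routed to, selected by m_indices. A NumPy reference for those semantics (fp8 types and scales omitted; this only explains the call shape, not the deep_gemm kernel itself):

    import numpy as np

    def grouped_gemm_reference(x, weights, m_indices):
        # x: [M, K] with rows of the same expert contiguous; weights: [E, N, K] ("nt" layout);
        # m_indices[i] is the expert that owns row i of x.
        out = np.zeros((x.shape[0], weights.shape[1]), dtype=np.float32)
        for e in range(weights.shape[0]):
            rows = np.nonzero(m_indices == e)[0]
            if rows.size:
                out[rows] = x[rows] @ weights[e].T
        return out

    x = np.random.randn(6, 8).astype(np.float32)
    w = np.random.randn(2, 4, 8).astype(np.float32)
    m_indices = np.array([0, 0, 0, 1, 1, 1])
    print(grouped_gemm_reference(x, w, m_indices).shape)  # (6, 4)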

==== File 3 of 7 ====

@@ -20,6 +20,7 @@ from paddle import nn
 import fastdeploy
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import get_tensor
+from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
 from fastdeploy.utils import ceil_div
 from ..quantization.quant_base import QuantMethodBase
@@ -604,64 +605,170 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         """
         Triton MoE create weight process.
         """
-        self.weight_dtype = paddle.float8_e4m3fn
-        up_gate_proj_weight_name = self.added_weight_attrs[0]
-        down_proj_weight_name = self.added_weight_attrs[1]
-        self.ffn1_weight_shape = [
-            layer.num_local_experts,
-            layer.moe_intermediate_size * 2,
-            layer.hidden_size,
-        ]
-        self.ffn2_weight_shape = [
-            layer.num_local_experts,
-            layer.hidden_size,
-            layer.moe_intermediate_size,
-        ]
-        setattr(
-            layer,
-            up_gate_proj_weight_name,
-            layer.create_parameter(
-                shape=self.ffn1_weight_shape,
-                dtype=self.weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            down_proj_weight_name,
-            layer.create_parameter(
-                shape=self.ffn2_weight_shape,
-                dtype=self.weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        # weight_scale
-        setattr(
-            layer,
-            self.added_scale_attrs[0],
-            layer.create_parameter(
-                shape=[
-                    layer.num_local_experts,
-                    ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
-                    ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
-                ],
-                dtype="float32",
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            self.added_scale_attrs[1],
-            layer.create_parameter(
-                shape=[
-                    layer.num_local_experts,
-                    ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
-                    ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
-                ],
-                dtype="float32",
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
+        self.up_gate_proj_weight_shape = [
+            layer.num_local_experts,
+            layer.moe_intermediate_size * 2,
+            layer.hidden_size,
+        ]
+        self.down_proj_weight_shape = [
+            layer.num_local_experts,
+            layer.hidden_size,
+            layer.moe_intermediate_size,
+        ]
+        self.up_gate_proj_scale_shape = [
+            layer.num_local_experts,
+            layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
+            layer.hidden_size // self.quant_config.weight_block_size[1],
+        ]
+        self.down_proj_scale_shape = [
+            layer.num_local_experts,
+            layer.hidden_size // self.quant_config.weight_block_size[0],
+            layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
+        ]
+        if self.quant_config.is_checkpoint_bf16:
+            layer.up_gate_proj_weight = layer.create_parameter(
+                shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+
+            layer.down_proj_weight = layer.create_parameter(
+                shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size],
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            set_weight_attrs(
+                layer.up_gate_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
+                },
+            )
+            set_weight_attrs(
+                layer.down_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
+                },
+            )
+        else:
+            self.weight_dtype = paddle.float8_e4m3fn
+            self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
+            up_gate_proj_weight_name = self.added_weight_attrs[0]
+            down_proj_weight_name = self.added_weight_attrs[1]
+            up_gate_proj_scale_name = self.added_scale_attrs[0]
+            down_proj_scale_name = self.added_scale_attrs[1]
+            setattr(
+                layer,
+                up_gate_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            # weight_scale
+            setattr(
+                layer,
+                up_gate_proj_scale_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_scale_shape,
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_scale_name,
+                layer.create_parameter(
+                    shape=self.down_proj_scale_shape,
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+
+    def process_weights_after_loading(self, layer):
+        """ """
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+        weight_id_map = {"gate_up": 0, "down": 1}
+        if (
+            hasattr(layer.up_gate_proj_weight, "tensor_track")
+            and layer.up_gate_proj_weight.tensor_track is not None
+            and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
+        ):
+            weight_type = "gate_up"
+            layer.up_gate_proj_weight.tensor_track = None
+        else:
+            weight_type = "down"
+            layer.down_proj_weight.tensor_track = None
+
+        # 1. init shape and type
+        self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
+        # weight
+        weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
+        unquantized_weight_name = weight_name.replace("quant_weight", "weight")
+        weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
+        weight_dtype = paddle.float8_e4m3fn
+        # scale
+        scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
+        scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
+        scale_dtype = "float32"
+
+        # 2. create tmp tensor
+        weight = paddle.empty(shape=[weight_shape[0], weight_shape[2], weight_shape[1]], dtype=weight_dtype)
+        scale = paddle.empty(shape=[scale_shape[0], scale_shape[2], scale_shape[1]], dtype=scale_dtype)
+
+        # 3. quantize weight
+        from fastdeploy.model_executor.layers.utils import per_block_cast_to_fp8
+
+        for expert_id in range(layer.num_experts):
+            weight_quant, scale[expert_id] = per_block_cast_to_fp8(
+                getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
+            )
+            weight[expert_id].copy_(weight_quant, False)
+
+        getattr(layer, unquantized_weight_name).value().get_tensor()._clear()
+
+        # create weight
+        setattr(
+            layer,
+            weight_name,
+            layer.create_parameter(
+                shape=[
+                    layer.num_local_experts,
+                    ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
+                    ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
+                ],
+                dtype=weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        # create scale
+        setattr(
+            layer,
+            scale_name,
+            layer.create_parameter(
+                shape=[
+                    layer.num_local_experts,
+                    ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
+                    ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
+                ],
+                dtype=scale_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]).contiguous(), False)
+        getattr(layer, scale_name).copy_(scale.transpose([0, 2, 1]).contiguous(), False)

     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
@@ -719,8 +826,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         num_local_experts = layer.num_local_experts
         moe_intermediate_size = layer.moe_intermediate_size
         hidden_size = layer.hidden_size
-        E, N1, _ = layer.up_gate_proj_weight.shape
-        N2 = layer.down_proj_weight.shape[1]
+        E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape
+        N2 = getattr(layer, self.added_weight_attrs[1]).shape[1]
         topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
             gate_out,
@@ -759,10 +866,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         fused_moe_kernel_paddle[grid](
             x_q,
-            layer.up_gate_proj_weight,
+            getattr(layer, self.added_weight_attrs[0]),
             intermediate_cache1,
             x_scale,
-            layer.up_gate_proj_weight_scale,
+            getattr(layer, self.added_scale_attrs[0]),
             None,
             sorted_token_ids,
             expert_ids,
@@ -773,17 +880,17 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             K=hidden_size,
             stride_am=x_q.strides[0],
             stride_ak=x_q.strides[1],
-            stride_be=layer.up_gate_proj_weight.strides[0],
-            stride_bk=layer.up_gate_proj_weight.strides[2],
-            stride_bn=layer.up_gate_proj_weight.strides[1],
+            stride_be=getattr(layer, self.added_weight_attrs[0]).strides[0],
+            stride_bk=getattr(layer, self.added_weight_attrs[0]).strides[2],
+            stride_bn=getattr(layer, self.added_weight_attrs[0]).strides[1],
             stride_cm=intermediate_cache1.strides[0],
             stride_cn=intermediate_cache1.strides[1],
             #
             stride_asm=x_scale.strides[0],  # only used in blockwise fp8
             stride_ask=x_scale.strides[1],  # only used in blockwise fp8
-            stride_bse=layer.up_gate_proj_weight_scale.strides[0],
-            stride_bsk=layer.up_gate_proj_weight_scale.strides[2],
-            stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
+            stride_bse=getattr(layer, self.added_scale_attrs[0]).strides[0],
+            stride_bsk=getattr(layer, self.added_scale_attrs[0]).strides[2],
+            stride_bsn=getattr(layer, self.added_scale_attrs[0]).strides[1],
             group_n=self.quant_config.weight_block_size[1],
             group_k=self.quant_config.weight_block_size[0],
             # Meta-parameters
@@ -813,10 +920,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         fused_moe_kernel_paddle[grid](
             x_q,
-            layer.down_proj_weight,
+            getattr(layer, self.added_weight_attrs[1]),
             intermediate_cache3,
             x_scale,
-            layer.down_proj_weight_scale,
+            getattr(layer, self.added_scale_attrs[1]),
             topk_weights,
             sorted_token_ids,
             expert_ids,
@@ -827,16 +934,16 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             K=moe_intermediate_size,
             stride_am=x_q.strides[0],
             stride_ak=x_q.strides[1],
-            stride_be=layer.down_proj_weight.strides[0],
-            stride_bk=layer.down_proj_weight.strides[2],
-            stride_bn=layer.down_proj_weight.strides[1],
+            stride_be=getattr(layer, self.added_weight_attrs[1]).strides[0],
+            stride_bk=getattr(layer, self.added_weight_attrs[1]).strides[2],
+            stride_bn=getattr(layer, self.added_weight_attrs[1]).strides[1],
             stride_cm=intermediate_cache3.strides[0],
             stride_cn=intermediate_cache3.strides[1],
             stride_asm=x_scale.strides[0],  # only used in blockwise fp8
             stride_ask=x_scale.strides[1],  # only used in blockwise fp8
-            stride_bse=layer.down_proj_weight_scale.strides[0],
-            stride_bsk=layer.down_proj_weight_scale.strides[2],
-            stride_bsn=layer.down_proj_weight_scale.strides[1],
+            stride_bse=getattr(layer, self.added_scale_attrs[1]).strides[0],
+            stride_bsk=getattr(layer, self.added_scale_attrs[1]).strides[2],
+            stride_bsn=getattr(layer, self.added_scale_attrs[1]).strides[1],
             group_n=self.quant_config.weight_block_size[1],
             group_k=self.quant_config.weight_block_size[0],
             # Meta-parameters

==== File 4 of 7 ====

@@ -20,7 +20,12 @@ import paddle
 import fastdeploy
 from fastdeploy import envs
+from fastdeploy.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+)
 from fastdeploy.model_executor.layers.moe import FusedMoE
+from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
 from ..utils import get_tensor, per_block_cast_to_fp8
 from .quant_base import QuantConfigBase, QuantMethodBase
@@ -33,13 +38,14 @@ class BlockWiseFP8Config(QuantConfigBase):
     per-token quantization of activations during inference.
     """

-    def __init__(self, weight_block_size: list = [-1, -1]) -> None:
+    def __init__(self, weight_block_size: list = [-1, -1], is_checkpoint_bf16: bool = False) -> None:
         super().__init__()
         self.weight_block_size = weight_block_size
         self.quant_max_bound = 448
         self.quant_min_bound = -448
         self.quant_round_type = 1
         self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)
+        self.is_checkpoint_bf16 = is_checkpoint_bf16

     def name(self) -> str:
         return "block_wise_fp8"
@@ -47,7 +53,8 @@ class BlockWiseFP8Config(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "BlockWiseFP8Config":
         weight_block_size = config.get("weight_block_size", [128, 128])
-        return cls(weight_block_size)
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(weight_block_size, is_checkpoint_bf16)

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         """
@@ -82,31 +89,78 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
         self.quant_config = quant_config

     def create_weights(self, layer, **extra_weight_attrs):
-        layer.weight_shape.reverse()
-        layer.weight_dtype = "float8_e4m3fn"
-        layer.weight = layer.create_parameter(
-            shape=layer.weight_shape,
-            dtype=layer.weight_dtype,
-            is_bias=False,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-        layer.weight_scale = layer.create_parameter(
-            shape=[
-                (layer.output_size + self.quant_config.weight_block_size[0] - 1)
-                // self.quant_config.weight_block_size[0],
-                (layer.input_size + self.quant_config.weight_block_size[1] - 1)
-                // self.quant_config.weight_block_size[1],
-            ],
-            dtype="float32",
-            is_bias=False,
-        )
+        if self.quant_config.is_checkpoint_bf16:
+            layer.weight = layer.create_parameter(
+                shape=layer.weight_shape,
+                dtype=layer.weight_dtype,
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            quant_attrs = extra_weight_attrs
+            if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
+                quant_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(
+                        shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
+                    ),
+                }
+            set_weight_attrs(
+                layer.weight,
+                quant_attrs,
+            )
+        else:
+            layer.weight_shape.reverse()
+            layer.weight_dtype = "float8_e4m3fn"
+            layer.weight = layer.create_parameter(
+                shape=layer.weight_shape,
+                dtype=layer.weight_dtype,
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            layer.weight_scale_inv = layer.create_parameter(
+                shape=[
+                    (layer.output_size + self.quant_config.weight_block_size[0] - 1)
+                    // self.quant_config.weight_block_size[0],
+                    (layer.input_size + self.quant_config.weight_block_size[1] - 1)
+                    // self.quant_config.weight_block_size[1],
+                ],
+                dtype="float32",
+                is_bias=False,
+            )
+
+    def process_weights_after_loading(self, layer) -> None:
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+        weight_tensor = layer.weight.transpose([1, 0])
+        quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
+
+        if hasattr(layer.weight, "tensor_track"):
+            layer.weight.tensor_track = None
+        layer.weight.value().get_tensor()._clear()
+        del layer.weight
+        layer.weight = layer.create_parameter(
+            shape=quanted_weight_tensor.shape,
+            dtype="float8_e4m3fn",
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+        layer.weight_scale_inv = layer.create_parameter(
+            shape=weight_block_scale_tensor.shape,
+            dtype="float32",
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+        layer.weight.copy_(quanted_weight_tensor, False)
+        layer.weight_scale_inv.copy_(weight_block_scale_tensor, False)

     def process_loaded_weights(self, layer, weights) -> None:
         weight_tensor = weights.transpose([1, 0])
         quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
         layer.weight.copy_(quanted_weight_tensor, False)
-        layer.weight_scale.set_value(weight_block_scale_tensor)
+        layer.weight_scale_inv.set_value(weight_block_scale_tensor)

     def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False):
         """
@@ -119,7 +173,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
         layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)
         weight_scale = weight_scale.transpose([1, 0])
-        layer.weight_scale.set_value(weight_scale)
+        layer.weight_scale_inv.set_value(weight_scale)

     def apply(self, layer, x):
         x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
@@ -130,7 +184,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
         deep_gemm.gemm_fp8_fp8_bf16_nt(
             (x, x_scale_tensor),
-            (layer.weight, layer.weight_scale),
+            (layer.weight, layer.weight_scale_inv),
             linear_out,
         )
         if layer.with_bias:
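At inference time the dense path hands (layer.weight, layer.weight_scale_inv) to deep_gemm.gemm_fp8_fp8_bf16_nt. A NumPy reference for what that pair means, treating the stored per-tile value as the dequantization factor (illustrative only; the real kernel fuses the scaling and operates on float8_e4m3fn data):

    import numpy as np

    def block_scaled_matmul_reference(x, w_q, w_scale_inv, block=128):
        # Effective weight = fp8 payload * per-(128, 128)-tile scale, broadcast over each tile.
        n, k = w_q.shape
        w_eff = w_q * np.kron(w_scale_inv, np.ones((block, block), dtype=np.float32))[:n, :k]
        return x @ w_eff.T  # "nt": the weight is stored as [N, K]

    x = np.random.randn(4, 512).astype(np.float32)
    w_q = np.random.randn(1024, 512).astype(np.float32)              # stand-in for the fp8 payload
    w_scale_inv = np.random.rand(8, 4).astype(np.float32)            # one scale per (128, 128) tile
    print(block_scaled_matmul_reference(x, w_q, w_scale_inv).shape)  # (4, 1024)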

==== File 5 of 7 ====

@@ -37,6 +37,7 @@ class MixQuantConfig(QuantConfigBase):
         is_channel_wise: bool = False,
         has_zero_point: bool = False,
         is_permuted: bool = True,
+        is_checkpoint_bf16: bool = False,
     ) -> None:
         super().__init__()
         self.dense_quant_type = dense_quant_type
@@ -52,6 +53,7 @@ class MixQuantConfig(QuantConfigBase):
         self.quant_min_bound = 0
         self.quant_round_type = 0
         self.is_permuted = is_permuted
+        self.is_checkpoint_bf16 = is_checkpoint_bf16

     def name(self) -> str:
         return "mix_quant"
@@ -66,6 +68,7 @@ class MixQuantConfig(QuantConfigBase):
             config.get("is_channel_wise", False),
             config.get("has_zero_point", False),
             config.get("is_permuted", True),
+            config.get("is_checkpoint_bf16", False),
         )

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -73,13 +76,13 @@ class MixQuantConfig(QuantConfigBase):
             if layer.moe_tag == "Image":
                 return (
                     get_quantization_config(self.image_moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted})
+                    .from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
                     .get_quant_method(layer)
                 )
             else:
                 return (
                     get_quantization_config(self.moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted})
+                    .from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
                     .get_quant_method(layer)
                 )
         elif isinstance(layer, Attention):
@@ -92,4 +95,8 @@ class MixQuantConfig(QuantConfigBase):
             else:
                 return None
         else:
-            return get_quantization_config(self.dense_quant_type).from_config({}).get_quant_method(layer)
+            return (
+                get_quantization_config(self.dense_quant_type)
+                .from_config({"self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                .get_quant_method(layer)
+            )
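The intent of the mix_quant change is that the is_checkpoint_bf16 flag parsed by MixQuantConfig.from_config is forwarded into each per-layer sub-config, which then chooses between the bf16 quantize-at-load branch and the pre-quantized fp8 branch shown earlier. A toy restatement of that from_config hand-off, with hypothetical classes standing in for the FastDeploy ones:

    class ToyBlockWiseFP8Config:
        def __init__(self, weight_block_size=(128, 128), is_checkpoint_bf16=False):
            self.weight_block_size = weight_block_size
            self.is_checkpoint_bf16 = is_checkpoint_bf16

        @classmethod
        def from_config(cls, config: dict) -> "ToyBlockWiseFP8Config":
            return cls(
                config.get("weight_block_size", (128, 128)),
                config.get("is_checkpoint_bf16", False),
            )

    class ToyMixQuantConfig:
        def __init__(self, moe_quant_type: str, is_checkpoint_bf16: bool):
            self.moe_quant_type = moe_quant_type
            self.is_checkpoint_bf16 = is_checkpoint_bf16

        def moe_sub_config(self) -> ToyBlockWiseFP8Config:
            # Forward the flag so the sub-config picks the dynamic-quantization branch.
            return ToyBlockWiseFP8Config.from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})

    print(ToyMixQuantConfig("block_wise_fp8", True).moe_sub_config().is_checkpoint_bf16)  # True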

==== File 6 of 7 ====

@@ -660,6 +660,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     quantization_config = {}
     quant_config_name = args.quantization
     quantization_config["quantization"] = quant_config_name
+    # Only v1 loader sets is_checkpoint_bf16=True during dynamic quantization.
+    if load_config.load_choices == "default_v1":
+        quantization_config["is_checkpoint_bf16"] = True
     # Special handling for Ernie models
     is_ernie = ErnieArchitectures.contains_ernie_arch(model_config.architectures)
     if quant_config_name == "wint4" and is_ernie:
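The worker-side rule is small: whenever the v1 loader (load_choices == "default_v1") is active and a quantization method is requested, the quantization config is tagged with is_checkpoint_bf16=True so the quant methods above know the checkpoint still holds bf16 weights. A condensed restatement (hypothetical helper, not the FastDeploy API):

    def build_quantization_config(quantization: str, load_choices: str) -> dict:
        cfg = {"quantization": quantization}
        if load_choices == "default_v1":
            # v1 loader reads the original bf16 checkpoint and quantizes after loading.
            cfg["is_checkpoint_bf16"] = True
        return cfg

    print(build_quantization_config("block_wise_fp8", "default_v1"))
    # {'quantization': 'block_wise_fp8', 'is_checkpoint_bf16': True}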

==== File 7 of 7 ====

@@ -13,12 +13,14 @@
 # limitations under the License.

 import os
+import shutil
 import traceback
 import warnings
 from multiprocessing import Process, Queue

 import pytest

+os.environ["LOAD_STATE_DICT_THREAD_NUM"] = "1"
 FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
 MAX_WAIT_SECONDS = 60 * 5
@@ -46,6 +48,33 @@ def get_model_paths(base_model_name: str) -> tuple[str, str]:
     return fd_model_path, torch_model_path

+def clear_logs():
+    log_path = os.path.join(os.getcwd(), "log")
+    if os.path.exists(log_path):
+        try:
+            shutil.rmtree(log_path)
+            print(f"Deleted log directory: {log_path}")
+        except Exception as e:
+            print(f"Failed to delete log directory {log_path}: {e}")
+    else:
+        print(f"No log directory found at {log_path}")
+
+
+def print_logs():
+    log_dir = os.path.join(os.getcwd(), "log")
+    log_file = os.path.join(log_dir, "workerlog.0")
+    if not os.path.exists(log_file):
+        print(f"Log file {log_file} does not exist.")
+        return
+    print(f"\n===== {log_file} start =====")
+    with open(log_file, "r") as f:
+        for line in f:
+            print(line, end="")
+    print(f"\n===== {log_file} end =====\n")
+
+
 def check_tokens_id_and_text_close(
     *,
     outputs_0_lst: TokensIdText,
@@ -110,37 +139,72 @@ def form_model_get_output(
         pytest.fail(f"Failed to initialize LLM model from {model_path}")

+def run_with_timeout(target, args, timeout=60 * 5):
+    clear_logs()
+    result_queue = Queue()
+    p = Process(target=target, args=(*args, result_queue))
+    p.start()
+    p.join(timeout)
+    if p.is_alive():
+        p.terminate()
+        print_logs()
+        raise RuntimeError("Worker process hung and was terminated")
+    try:
+        return result_queue.get(timeout=60)
+    except Exception as e:
+        raise RuntimeError(f"Failed to get result from worker: {e}")
+
+
 model_param_map = {
     "Qwen3-0.6B": {
         "quantizations": ["None", "wint4", "wint8"],
     },
     "ernie-4_5-21b-a3b-bf16-paddle": {
         "tensor_parallel_size": 2,
-        "quantizations": ["wint8"],
+        "quantizations": [
+            "wint8",
+        ],
     },
     "Qwen2-7B-Instruct": {
         "quantizations": ["None", "wint8"],
     },
+    "Qwen3-30B-A3B": {
+        "tensor_parallel_size": 2,
+        "quantizations": [
+            {
+                "quant_type": "block_wise_fp8",
+                "backend": "triton",
+                "env": {"FD_USE_DEEP_GEMM": "0", "DG_NVCC_OVERRIDE_CPP_STANDARD": "17"},
+            },
+            {"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"}},
+        ],
+    },
 }

 params = []
 for model, cfg in model_param_map.items():
     for q in cfg["quantizations"]:
+        if isinstance(q, dict):
+            quant, backend, env = q["quant_type"], q.get("backend", "default"), q.get("env", {})
+        else:
+            quant, backend, env = q, "default", {}
         params.append(
             pytest.param(
                 model,
                 cfg.get("tensor_parallel_size", 1),
                 cfg.get("max_model_len", 1024),
-                q,
+                quant,
                 cfg.get("max_tokens", 32),
+                env,
                 marks=[pytest.mark.core_model],
+                id=f"{model}.{quant}.{backend}",
             )
         )

 @pytest.mark.parametrize(
-    "model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens",
+    "model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens,env",
     params,
 )
 def test_common_model(
@@ -150,46 +214,26 @@ def test_common_model(
     max_model_len: int,
     max_tokens: int,
     quantization: str,
+    env,
+    monkeypatch,
 ) -> None:
     base_path = os.getenv("MODEL_PATH")
     if base_path:
         model_path = os.path.join(base_path, model_name_or_path)
     else:
         model_path = model_name_or_path
-    result_queue = Queue()
-    p = Process(
-        target=form_model_get_output,
-        args=(
-            fd_runner,
-            model_path,
-            tensor_parallel_size,
-            max_model_len,
-            max_tokens,
-            quantization,
-            "default",
-            result_queue,
-        ),
-    )
-    p.start()
-    p.join()
-    fd_outputs_v0 = result_queue.get(timeout=60)
-    p = Process(
-        target=form_model_get_output,
-        args=(
-            fd_runner,
-            model_path,
-            tensor_parallel_size,
-            max_model_len,
-            max_tokens,
-            quantization,
-            "default_v1",
-            result_queue,
-        ),
-    )
-    p.start()
-    p.join()
-    fd_outputs_v1 = result_queue.get(timeout=60)
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)
+
+    fd_outputs_v0 = run_with_timeout(
+        target=form_model_get_output,
+        args=(fd_runner, model_path, tensor_parallel_size, max_model_len, max_tokens, quantization, "default"),
+    )
+    fd_outputs_v1 = run_with_timeout(
+        target=form_model_get_output,
+        args=(fd_runner, model_path, tensor_parallel_size, max_model_len, max_tokens, quantization, "default_v1"),
+    )
     check_tokens_id_and_text_close(
         outputs_0_lst=fd_outputs_v0,
         outputs_1_lst=fd_outputs_v1,
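With the parametrization above, dict-valued quantization entries carry a backend tag and extra environment variables, and the backend ends up in the generated test id (for example Qwen3-30B-A3B.block_wise_fp8.triton versus ...deepgemm, where FD_USE_DEEP_GEMM=0 forces the Triton path). A minimal, self-contained reproduction of that id scheme:

    import pytest

    cases = []
    for q in [
        {"quant_type": "block_wise_fp8", "backend": "triton", "env": {"FD_USE_DEEP_GEMM": "0"}},
        {"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {}},
    ]:
        cases.append(
            pytest.param(
                "Qwen3-30B-A3B",
                q["quant_type"],
                q["env"],
                id=f"Qwen3-30B-A3B.{q['quant_type']}.{q['backend']}",
            )
        )

    @pytest.mark.parametrize("model, quant, env", cases)
    def test_id_scheme(model, quant, env):
        assert quant == "block_wise_fp8"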
@@ -235,26 +279,12 @@ def test_paddle_vs_torch_model(
     fd_model_path, torch_model_path = get_model_paths(model_name_or_path)

-    result_queue = Queue()
-    p_paddle = Process(
-        target=form_model_get_output,
-        args=(
-            fd_runner,
-            fd_model_path,
-            tensor_parallel_size,
-            max_model_len,
-            max_tokens,
-            quantization,
-            "default",
-            result_queue,
-        ),
-    )
-    p_paddle.start()
-    p_paddle.join()
-    paddle_outputs = result_queue.get(timeout=60)
-
-    p_hf = Process(
+    paddle_outputs = run_with_timeout(
+        target=form_model_get_output,
+        args=(fd_runner, fd_model_path, tensor_parallel_size, max_model_len, max_tokens, quantization, "default"),
+    )
+
+    hf_outputs = run_with_timeout(
         target=form_model_get_output,
         args=(
             fd_runner,
@@ -264,12 +294,8 @@
             max_tokens,
             quantization,
             "default_v1",
-            result_queue,
         ),
     )
-    p_hf.start()
-    p_hf.join()
-    hf_outputs = result_queue.get(timeout=60)
     check_tokens_id_and_text_close(
         outputs_0_lst=paddle_outputs,