mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
@@ -749,7 +749,6 @@ class LoadChoices(str, Enum):
|
||||
"""LoadChoices"""
|
||||
|
||||
DEFAULT = "default"
|
||||
# only support qwen3-bf16 now
|
||||
DEFAULT_V1 = "default_v1"
|
||||
|
||||
|
||||
|
@@ -22,6 +22,7 @@ import fastdeploy
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
|
||||
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
|
||||
from fastdeploy.utils import ceil_div
|
||||
|
||||
from .fused_moe_backend_base import MoEMethodBase
|
||||
@@ -36,24 +37,64 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
"""
|
||||
deepgemm create weight process.
|
||||
"""
|
||||
self.weight_dtype = paddle.float8_e4m3fn
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
self.ffn1_weight_shape = [
|
||||
self.up_gate_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2,
|
||||
layer.hidden_size,
|
||||
]
|
||||
self.ffn2_weight_shape = [
|
||||
self.down_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size,
|
||||
]
|
||||
self.up_gate_proj_scale_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
|
||||
layer.hidden_size // self.quant_config.weight_block_size[1],
|
||||
]
|
||||
self.down_proj_scale_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size // self.quant_config.weight_block_size[0],
|
||||
layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
|
||||
]
|
||||
if self.quant_config.is_checkpoint_bf16:
|
||||
layer.up_gate_proj_weight = layer.create_parameter(
|
||||
shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
layer.down_proj_weight = layer.create_parameter(
|
||||
shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size],
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.up_gate_proj_weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.down_proj_weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.weight_dtype = paddle.float8_e4m3fn
|
||||
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
up_gate_proj_scale_name = self.added_scale_attrs[0]
|
||||
down_proj_scale_name = self.added_scale_attrs[1]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
shape=self.up_gate_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
@@ -62,7 +103,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn2_weight_shape,
|
||||
shape=self.down_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
@@ -70,30 +111,96 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
# weight_scale
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[0],
|
||||
up_gate_proj_scale_name,
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
|
||||
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
|
||||
],
|
||||
shape=self.up_gate_proj_scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[1],
|
||||
down_proj_scale_name,
|
||||
layer.create_parameter(
|
||||
shape=self.down_proj_scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
""" """
|
||||
if not self.quant_config.is_checkpoint_bf16:
|
||||
return
|
||||
weight_id_map = {"gate_up": 0, "down": 1}
|
||||
if (
|
||||
hasattr(layer.up_gate_proj_weight, "tensor_track")
|
||||
and layer.up_gate_proj_weight.tensor_track is not None
|
||||
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
|
||||
):
|
||||
weight_type = "gate_up"
|
||||
layer.up_gate_proj_weight.tensor_track = None
|
||||
else:
|
||||
weight_type = "down"
|
||||
layer.down_proj_weight.tensor_track = None
|
||||
|
||||
# 1.init shape and type
|
||||
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
|
||||
# weight
|
||||
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
|
||||
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
|
||||
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
|
||||
weight_dtype = paddle.float8_e4m3fn
|
||||
# scale
|
||||
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
|
||||
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
|
||||
scale_dtype = "float32"
|
||||
|
||||
# 2.crate tmp tensor
|
||||
|
||||
weight = paddle.empty(shape=[weight_shape[0], weight_shape[2], weight_shape[1]], dtype=weight_dtype)
|
||||
scale = paddle.empty(shape=[scale_shape[0], scale_shape[2], scale_shape[1]], dtype=scale_dtype)
|
||||
|
||||
# 3.quantize weight
|
||||
from fastdeploy.model_executor.layers.utils import per_block_cast_to_fp8
|
||||
|
||||
for expert_id in range(layer.num_experts):
|
||||
weight_quant, scale[expert_id] = per_block_cast_to_fp8(
|
||||
getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
|
||||
)
|
||||
weight[expert_id].copy_(weight_quant, False)
|
||||
getattr(layer, unquantized_weight_name).value().get_tensor()._clear()
|
||||
|
||||
# create weight
|
||||
setattr(
|
||||
layer,
|
||||
weight_name,
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
|
||||
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
|
||||
],
|
||||
dtype=weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# create scale
|
||||
setattr(
|
||||
layer,
|
||||
scale_name,
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
|
||||
ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
|
||||
],
|
||||
dtype="float32",
|
||||
dtype=scale_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]).contiguous(), False)
|
||||
getattr(layer, scale_name).copy_(scale.transpose([0, 2, 1]).contiguous(), False)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
@@ -244,12 +351,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
|
||||
# up_gate_proj
|
||||
ffn_out = paddle.empty(
|
||||
(permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
|
||||
(permute_input.shape[0], getattr(layer, self.added_weight_attrs[0]).shape[1]),
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(permute_input, permute_scale),
|
||||
(layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
|
||||
(getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_scale_attrs[0])),
|
||||
ffn_out,
|
||||
m_indices,
|
||||
)
|
||||
@@ -264,12 +371,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
|
||||
|
||||
ffn_out = paddle.empty(
|
||||
(ffn_out.shape[0], layer.down_proj_weight.shape[1]),
|
||||
(ffn_out.shape[0], getattr(layer, self.added_weight_attrs[1]).shape[1]),
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(ffn_in_x, ffn_in_x_scale_tensor),
|
||||
(layer.down_proj_weight, layer.down_proj_weight_scale),
|
||||
(getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_scale_attrs[1])),
|
||||
ffn_out,
|
||||
m_indices,
|
||||
)
|
||||
@@ -331,8 +438,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
permute_input,
|
||||
(
|
||||
layer.up_gate_proj_weight,
|
||||
layer.up_gate_proj_weight_scale,
|
||||
getattr(layer, self.added_weight_attrs[0]),
|
||||
getattr(layer, self.added_scale_attrs[0]),
|
||||
),
|
||||
up_gate_proj_out,
|
||||
token_nums_per_expert,
|
||||
@@ -350,8 +457,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
(act_out_fp8, scale),
|
||||
(
|
||||
layer.down_proj_weight,
|
||||
layer.down_proj_weight_scale,
|
||||
getattr(layer, self.added_weight_attrs[1]),
|
||||
getattr(layer, self.added_scale_attrs[1]),
|
||||
),
|
||||
ffn_out,
|
||||
token_nums_per_expert,
|
||||
@@ -423,12 +530,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
|
||||
# up_gate_proj
|
||||
ffn_out = paddle.empty(
|
||||
(permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
|
||||
(permute_input.shape[0], getattr(layer, self.added_weight_attrs[0]).shape[1]),
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(permute_input, permute_scale),
|
||||
(layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
|
||||
(getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_scale_attrs[0])),
|
||||
ffn_out,
|
||||
m_indices,
|
||||
)
|
||||
@@ -444,12 +551,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
|
||||
|
||||
ffn_out = paddle.empty(
|
||||
(ffn_out.shape[0], layer.down_proj_weight.shape[1]),
|
||||
(ffn_out.shape[0], getattr(layer, self.added_weight_attrs[1]).shape[1]),
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(ffn_in_x, ffn_in_x_scale_tensor),
|
||||
(layer.down_proj_weight, layer.down_proj_weight_scale),
|
||||
(getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_scale_attrs[1])),
|
||||
ffn_out,
|
||||
m_indices,
|
||||
)
|
||||
|
@@ -20,6 +20,7 @@ from paddle import nn
|
||||
import fastdeploy
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
|
||||
from fastdeploy.utils import ceil_div
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
@@ -604,24 +605,64 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
"""
|
||||
Triton MoE create weight process.
|
||||
"""
|
||||
self.weight_dtype = paddle.float8_e4m3fn
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
self.ffn1_weight_shape = [
|
||||
self.up_gate_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2,
|
||||
layer.hidden_size,
|
||||
]
|
||||
self.ffn2_weight_shape = [
|
||||
self.down_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size,
|
||||
]
|
||||
self.up_gate_proj_scale_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
|
||||
layer.hidden_size // self.quant_config.weight_block_size[1],
|
||||
]
|
||||
self.down_proj_scale_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size // self.quant_config.weight_block_size[0],
|
||||
layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
|
||||
]
|
||||
if self.quant_config.is_checkpoint_bf16:
|
||||
layer.up_gate_proj_weight = layer.create_parameter(
|
||||
shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
layer.down_proj_weight = layer.create_parameter(
|
||||
shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size],
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.up_gate_proj_weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.down_proj_weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.weight_dtype = paddle.float8_e4m3fn
|
||||
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
up_gate_proj_scale_name = self.added_scale_attrs[0]
|
||||
down_proj_scale_name = self.added_scale_attrs[1]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
shape=self.up_gate_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
@@ -630,7 +671,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn2_weight_shape,
|
||||
shape=self.up_gate_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
@@ -638,30 +679,96 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
# weight_scale
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[0],
|
||||
up_gate_proj_scale_name,
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
|
||||
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
|
||||
],
|
||||
shape=self.up_gate_proj_scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[1],
|
||||
down_proj_scale_name,
|
||||
layer.create_parameter(
|
||||
shape=self.down_proj_scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
""" """
|
||||
if not self.quant_config.is_checkpoint_bf16:
|
||||
return
|
||||
weight_id_map = {"gate_up": 0, "down": 1}
|
||||
if (
|
||||
hasattr(layer.up_gate_proj_weight, "tensor_track")
|
||||
and layer.up_gate_proj_weight.tensor_track is not None
|
||||
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
|
||||
):
|
||||
weight_type = "gate_up"
|
||||
layer.up_gate_proj_weight.tensor_track = None
|
||||
else:
|
||||
weight_type = "down"
|
||||
layer.down_proj_weight.tensor_track = None
|
||||
|
||||
# 1.init shape and type
|
||||
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
|
||||
# weight
|
||||
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
|
||||
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
|
||||
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
|
||||
weight_dtype = paddle.float8_e4m3fn
|
||||
# scale
|
||||
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
|
||||
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
|
||||
scale_dtype = "float32"
|
||||
|
||||
# 2.crate tmp tensor
|
||||
|
||||
weight = paddle.empty(shape=[weight_shape[0], weight_shape[2], weight_shape[1]], dtype=weight_dtype)
|
||||
scale = paddle.empty(shape=[scale_shape[0], scale_shape[2], scale_shape[1]], dtype=scale_dtype)
|
||||
|
||||
# 3.quantize weight
|
||||
from fastdeploy.model_executor.layers.utils import per_block_cast_to_fp8
|
||||
|
||||
for expert_id in range(layer.num_experts):
|
||||
weight_quant, scale[expert_id] = per_block_cast_to_fp8(
|
||||
getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
|
||||
)
|
||||
weight[expert_id].copy_(weight_quant, False)
|
||||
getattr(layer, unquantized_weight_name).value().get_tensor()._clear()
|
||||
|
||||
# create weight
|
||||
setattr(
|
||||
layer,
|
||||
weight_name,
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
|
||||
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
|
||||
],
|
||||
dtype=weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# create scale
|
||||
setattr(
|
||||
layer,
|
||||
scale_name,
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
|
||||
ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
|
||||
],
|
||||
dtype="float32",
|
||||
dtype=scale_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]).contiguous(), False)
|
||||
getattr(layer, scale_name).copy_(scale.transpose([0, 2, 1]).contiguous(), False)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
@@ -719,8 +826,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
num_local_experts = layer.num_local_experts
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
E, N1, _ = layer.up_gate_proj_weight.shape
|
||||
N2 = layer.down_proj_weight.shape[1]
|
||||
E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape
|
||||
N2 = getattr(layer, self.added_weight_attrs[1]).shape[1]
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
@@ -759,10 +866,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x_q,
|
||||
layer.up_gate_proj_weight,
|
||||
getattr(layer, self.added_weight_attrs[0]),
|
||||
intermediate_cache1,
|
||||
x_scale,
|
||||
layer.up_gate_proj_weight_scale,
|
||||
getattr(layer, self.added_scale_attrs[0]),
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -773,17 +880,17 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
K=hidden_size,
|
||||
stride_am=x_q.strides[0],
|
||||
stride_ak=x_q.strides[1],
|
||||
stride_be=layer.up_gate_proj_weight.strides[0],
|
||||
stride_bk=layer.up_gate_proj_weight.strides[2],
|
||||
stride_bn=layer.up_gate_proj_weight.strides[1],
|
||||
stride_be=getattr(layer, self.added_weight_attrs[0]).strides[0],
|
||||
stride_bk=getattr(layer, self.added_weight_attrs[0]).strides[2],
|
||||
stride_bn=getattr(layer, self.added_weight_attrs[0]).strides[1],
|
||||
stride_cm=intermediate_cache1.strides[0],
|
||||
stride_cn=intermediate_cache1.strides[1],
|
||||
#
|
||||
stride_asm=x_scale.strides[0], # only used in blockwise fp8
|
||||
stride_ask=x_scale.strides[1], # only used in blockwise fp8
|
||||
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
|
||||
stride_bsk=layer.up_gate_proj_weight_scale.strides[2],
|
||||
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
|
||||
stride_bse=getattr(layer, self.added_scale_attrs[0]).strides[0],
|
||||
stride_bsk=getattr(layer, self.added_scale_attrs[0]).strides[2],
|
||||
stride_bsn=getattr(layer, self.added_scale_attrs[0]).strides[1],
|
||||
group_n=self.quant_config.weight_block_size[1],
|
||||
group_k=self.quant_config.weight_block_size[0],
|
||||
# Meta-parameters
|
||||
@@ -813,10 +920,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x_q,
|
||||
layer.down_proj_weight,
|
||||
getattr(layer, self.added_weight_attrs[1]),
|
||||
intermediate_cache3,
|
||||
x_scale,
|
||||
layer.down_proj_weight_scale,
|
||||
getattr(layer, self.added_scale_attrs[1]),
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -827,16 +934,16 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
K=moe_intermediate_size,
|
||||
stride_am=x_q.strides[0],
|
||||
stride_ak=x_q.strides[1],
|
||||
stride_be=layer.down_proj_weight.strides[0],
|
||||
stride_bk=layer.down_proj_weight.strides[2],
|
||||
stride_bn=layer.down_proj_weight.strides[1],
|
||||
stride_be=getattr(layer, self.added_weight_attrs[1]).strides[0],
|
||||
stride_bk=getattr(layer, self.added_weight_attrs[1]).strides[2],
|
||||
stride_bn=getattr(layer, self.added_weight_attrs[1]).strides[1],
|
||||
stride_cm=intermediate_cache3.strides[0],
|
||||
stride_cn=intermediate_cache3.strides[1],
|
||||
stride_asm=x_scale.strides[0], # only used in blockwise fp8
|
||||
stride_ask=x_scale.strides[1], # only used in blockwise fp8
|
||||
stride_bse=layer.down_proj_weight_scale.strides[0],
|
||||
stride_bsk=layer.down_proj_weight_scale.strides[2],
|
||||
stride_bsn=layer.down_proj_weight_scale.strides[1],
|
||||
stride_bse=getattr(layer, self.added_scale_attrs[1]).strides[0],
|
||||
stride_bsk=getattr(layer, self.added_scale_attrs[1]).strides[2],
|
||||
stride_bsn=getattr(layer, self.added_scale_attrs[1]).strides[1],
|
||||
group_n=self.quant_config.weight_block_size[1],
|
||||
group_k=self.quant_config.weight_block_size[0],
|
||||
# Meta-parameters
|
||||
|
@@ -20,7 +20,12 @@ import paddle
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.moe import FusedMoE
|
||||
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
|
||||
|
||||
from ..utils import get_tensor, per_block_cast_to_fp8
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
@@ -33,13 +38,14 @@ class BlockWiseFP8Config(QuantConfigBase):
|
||||
per-token quantization of activations during inference.
|
||||
"""
|
||||
|
||||
def __init__(self, weight_block_size: list = [-1, -1]) -> None:
|
||||
def __init__(self, weight_block_size: list = [-1, -1], is_checkpoint_bf16: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.weight_block_size = weight_block_size
|
||||
self.quant_max_bound = 448
|
||||
self.quant_min_bound = -448
|
||||
self.quant_round_type = 1
|
||||
self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)
|
||||
self.is_checkpoint_bf16 = is_checkpoint_bf16
|
||||
|
||||
def name(self) -> str:
|
||||
return "block_wise_fp8"
|
||||
@@ -47,7 +53,8 @@ class BlockWiseFP8Config(QuantConfigBase):
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "BlockWiseFP8Config":
|
||||
weight_block_size = config.get("weight_block_size", [128, 128])
|
||||
return cls(weight_block_size)
|
||||
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
|
||||
return cls(weight_block_size, is_checkpoint_bf16)
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
"""
|
||||
@@ -82,6 +89,26 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer, **extra_weight_attrs):
|
||||
if self.quant_config.is_checkpoint_bf16:
|
||||
layer.weight = layer.create_parameter(
|
||||
shape=layer.weight_shape,
|
||||
dtype=layer.weight_dtype,
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
quant_attrs = extra_weight_attrs
|
||||
if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
|
||||
quant_attrs = {
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(
|
||||
shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
|
||||
),
|
||||
}
|
||||
set_weight_attrs(
|
||||
layer.weight,
|
||||
quant_attrs,
|
||||
)
|
||||
else:
|
||||
layer.weight_shape.reverse()
|
||||
layer.weight_dtype = "float8_e4m3fn"
|
||||
layer.weight = layer.create_parameter(
|
||||
@@ -91,7 +118,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
layer.weight_scale = layer.create_parameter(
|
||||
layer.weight_scale_inv = layer.create_parameter(
|
||||
shape=[
|
||||
(layer.output_size + self.quant_config.weight_block_size[0] - 1)
|
||||
// self.quant_config.weight_block_size[0],
|
||||
@@ -102,11 +129,38 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
is_bias=False,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
if not self.quant_config.is_checkpoint_bf16:
|
||||
return
|
||||
weight_tensor = layer.weight.transpose([1, 0])
|
||||
quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
|
||||
|
||||
if hasattr(layer.weight, "tensor_track"):
|
||||
layer.weight.tensor_track = None
|
||||
layer.weight.value().get_tensor()._clear()
|
||||
del layer.weight
|
||||
|
||||
layer.weight = layer.create_parameter(
|
||||
shape=quanted_weight_tensor.shape,
|
||||
dtype="float8_e4m3fn",
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
layer.weight_scale_inv = layer.create_parameter(
|
||||
shape=weight_block_scale_tensor.shape,
|
||||
dtype="float32",
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
layer.weight.copy_(quanted_weight_tensor, False)
|
||||
layer.weight_scale_inv.copy_(weight_block_scale_tensor, False)
|
||||
|
||||
def process_loaded_weights(self, layer, weights) -> None:
|
||||
weight_tensor = weights.transpose([1, 0])
|
||||
quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
|
||||
layer.weight.copy_(quanted_weight_tensor, False)
|
||||
layer.weight_scale.set_value(weight_block_scale_tensor)
|
||||
layer.weight_scale_inv.set_value(weight_block_scale_tensor)
|
||||
|
||||
def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False):
|
||||
"""
|
||||
@@ -119,7 +173,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)
|
||||
|
||||
weight_scale = weight_scale.transpose([1, 0])
|
||||
layer.weight_scale.set_value(weight_scale)
|
||||
layer.weight_scale_inv.set_value(weight_scale)
|
||||
|
||||
def apply(self, layer, x):
|
||||
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
|
||||
@@ -130,7 +184,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt(
|
||||
(x, x_scale_tensor),
|
||||
(layer.weight, layer.weight_scale),
|
||||
(layer.weight, layer.weight_scale_inv),
|
||||
linear_out,
|
||||
)
|
||||
if layer.with_bias:
|
||||
|
@@ -37,6 +37,7 @@ class MixQuantConfig(QuantConfigBase):
|
||||
is_channel_wise: bool = False,
|
||||
has_zero_point: bool = False,
|
||||
is_permuted: bool = True,
|
||||
is_checkpoint_bf16: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dense_quant_type = dense_quant_type
|
||||
@@ -52,6 +53,7 @@ class MixQuantConfig(QuantConfigBase):
|
||||
self.quant_min_bound = 0
|
||||
self.quant_round_type = 0
|
||||
self.is_permuted = is_permuted
|
||||
self.is_checkpoint_bf16 = is_checkpoint_bf16
|
||||
|
||||
def name(self) -> str:
|
||||
return "mix_quant"
|
||||
@@ -66,6 +68,7 @@ class MixQuantConfig(QuantConfigBase):
|
||||
config.get("is_channel_wise", False),
|
||||
config.get("has_zero_point", False),
|
||||
config.get("is_permuted", True),
|
||||
config.get("is_checkpoint_bf16", False),
|
||||
)
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
@@ -73,13 +76,13 @@ class MixQuantConfig(QuantConfigBase):
|
||||
if layer.moe_tag == "Image":
|
||||
return (
|
||||
get_quantization_config(self.image_moe_quant_type)
|
||||
.from_config({"is_permuted": self.is_permuted})
|
||||
.from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
|
||||
.get_quant_method(layer)
|
||||
)
|
||||
else:
|
||||
return (
|
||||
get_quantization_config(self.moe_quant_type)
|
||||
.from_config({"is_permuted": self.is_permuted})
|
||||
.from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
|
||||
.get_quant_method(layer)
|
||||
)
|
||||
elif isinstance(layer, Attention):
|
||||
@@ -92,4 +95,8 @@ class MixQuantConfig(QuantConfigBase):
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return get_quantization_config(self.dense_quant_type).from_config({}).get_quant_method(layer)
|
||||
return (
|
||||
get_quantization_config(self.dense_quant_type)
|
||||
.from_config({"self.is_checkpoint_bf16": self.is_checkpoint_bf16})
|
||||
.get_quant_method(layer)
|
||||
)
|
||||
|
@@ -660,6 +660,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
quantization_config = {}
|
||||
quant_config_name = args.quantization
|
||||
quantization_config["quantization"] = quant_config_name
|
||||
# Only v1 loader sets is_checkpoint_bf16=True during dynamic quantization.
|
||||
if load_config.load_choices == "default_v1":
|
||||
quantization_config["is_checkpoint_bf16"] = True
|
||||
# Special handling for Ernie models
|
||||
is_ernie = ErnieArchitectures.contains_ernie_arch(model_config.architectures)
|
||||
if quant_config_name == "wint4" and is_ernie:
|
||||
|
@@ -13,12 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import traceback
|
||||
import warnings
|
||||
from multiprocessing import Process, Queue
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ["LOAD_STATE_DICT_THREAD_NUM"] = "1"
|
||||
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
|
||||
MAX_WAIT_SECONDS = 60 * 5
|
||||
|
||||
@@ -46,6 +48,33 @@ def get_model_paths(base_model_name: str) -> tuple[str, str]:
|
||||
return fd_model_path, torch_model_path
|
||||
|
||||
|
||||
def clear_logs():
|
||||
log_path = os.path.join(os.getcwd(), "log")
|
||||
if os.path.exists(log_path):
|
||||
try:
|
||||
shutil.rmtree(log_path)
|
||||
print(f"Deleted log directory: {log_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed to delete log directory {log_path}: {e}")
|
||||
else:
|
||||
print(f"No log directory found at {log_path}")
|
||||
|
||||
|
||||
def print_logs():
|
||||
log_dir = os.path.join(os.getcwd(), "log")
|
||||
log_file = os.path.join(log_dir, "workerlog.0")
|
||||
|
||||
if not os.path.exists(log_file):
|
||||
print(f"Log file {log_file} does not exist.")
|
||||
return
|
||||
|
||||
print(f"\n===== {log_file} start =====")
|
||||
with open(log_file, "r") as f:
|
||||
for line in f:
|
||||
print(line, end="")
|
||||
print(f"\n===== {log_file} end =====\n")
|
||||
|
||||
|
||||
def check_tokens_id_and_text_close(
|
||||
*,
|
||||
outputs_0_lst: TokensIdText,
|
||||
@@ -110,37 +139,72 @@ def form_model_get_output(
|
||||
pytest.fail(f"Failed to initialize LLM model from {model_path}")
|
||||
|
||||
|
||||
def run_with_timeout(target, args, timeout=60 * 5):
|
||||
clear_logs()
|
||||
result_queue = Queue()
|
||||
p = Process(target=target, args=(*args, result_queue))
|
||||
p.start()
|
||||
p.join(timeout)
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
print_logs()
|
||||
raise RuntimeError("Worker process hung and was terminated")
|
||||
try:
|
||||
return result_queue.get(timeout=60)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to get result from worker: {e}")
|
||||
|
||||
|
||||
model_param_map = {
|
||||
"Qwen3-0.6B": {
|
||||
"quantizations": ["None", "wint4", "wint8"],
|
||||
},
|
||||
"ernie-4_5-21b-a3b-bf16-paddle": {
|
||||
"tensor_parallel_size": 2,
|
||||
"quantizations": ["wint8"],
|
||||
"quantizations": [
|
||||
"wint8",
|
||||
],
|
||||
},
|
||||
"Qwen2-7B-Instruct": {
|
||||
"quantizations": ["None", "wint8"],
|
||||
},
|
||||
"Qwen3-30B-A3B": {
|
||||
"tensor_parallel_size": 2,
|
||||
"quantizations": [
|
||||
{
|
||||
"quant_type": "block_wise_fp8",
|
||||
"backend": "triton",
|
||||
"env": {"FD_USE_DEEP_GEMM": "0", "DG_NVCC_OVERRIDE_CPP_STANDARD": "17"},
|
||||
},
|
||||
{"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"}},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
params = []
|
||||
for model, cfg in model_param_map.items():
|
||||
for q in cfg["quantizations"]:
|
||||
if isinstance(q, dict):
|
||||
quant, backend, env = q["quant_type"], q.get("backend", "default"), q.get("env", {})
|
||||
else:
|
||||
quant, backend, env = q, "default", {}
|
||||
params.append(
|
||||
pytest.param(
|
||||
model,
|
||||
cfg.get("tensor_parallel_size", 1),
|
||||
cfg.get("max_model_len", 1024),
|
||||
q,
|
||||
quant,
|
||||
cfg.get("max_tokens", 32),
|
||||
env,
|
||||
marks=[pytest.mark.core_model],
|
||||
id=f"{model}.{quant}.{backend}",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens",
|
||||
"model_name_or_path,tensor_parallel_size,max_model_len,quantization,max_tokens,env",
|
||||
params,
|
||||
)
|
||||
def test_common_model(
|
||||
@@ -150,46 +214,26 @@ def test_common_model(
|
||||
max_model_len: int,
|
||||
max_tokens: int,
|
||||
quantization: str,
|
||||
env,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
base_path = os.getenv("MODEL_PATH")
|
||||
if base_path:
|
||||
model_path = os.path.join(base_path, model_name_or_path)
|
||||
else:
|
||||
model_path = model_name_or_path
|
||||
result_queue = Queue()
|
||||
p = Process(
|
||||
target=form_model_get_output,
|
||||
args=(
|
||||
fd_runner,
|
||||
model_path,
|
||||
tensor_parallel_size,
|
||||
max_model_len,
|
||||
max_tokens,
|
||||
quantization,
|
||||
"default",
|
||||
result_queue,
|
||||
),
|
||||
)
|
||||
p.start()
|
||||
p.join()
|
||||
fd_outputs_v0 = result_queue.get(timeout=60)
|
||||
if env:
|
||||
for k, v in env.items():
|
||||
monkeypatch.setenv(k, v)
|
||||
|
||||
p = Process(
|
||||
fd_outputs_v0 = run_with_timeout(
|
||||
target=form_model_get_output,
|
||||
args=(
|
||||
fd_runner,
|
||||
model_path,
|
||||
tensor_parallel_size,
|
||||
max_model_len,
|
||||
max_tokens,
|
||||
quantization,
|
||||
"default_v1",
|
||||
result_queue,
|
||||
),
|
||||
args=(fd_runner, model_path, tensor_parallel_size, max_model_len, max_tokens, quantization, "default"),
|
||||
)
|
||||
fd_outputs_v1 = run_with_timeout(
|
||||
target=form_model_get_output,
|
||||
args=(fd_runner, model_path, tensor_parallel_size, max_model_len, max_tokens, quantization, "default_v1"),
|
||||
)
|
||||
p.start()
|
||||
p.join()
|
||||
fd_outputs_v1 = result_queue.get(timeout=60)
|
||||
check_tokens_id_and_text_close(
|
||||
outputs_0_lst=fd_outputs_v0,
|
||||
outputs_1_lst=fd_outputs_v1,
|
||||
@@ -235,26 +279,12 @@ def test_paddle_vs_torch_model(
|
||||
|
||||
fd_model_path, torch_model_path = get_model_paths(model_name_or_path)
|
||||
|
||||
result_queue = Queue()
|
||||
|
||||
p_paddle = Process(
|
||||
paddle_outputs = run_with_timeout(
|
||||
target=form_model_get_output,
|
||||
args=(
|
||||
fd_runner,
|
||||
fd_model_path,
|
||||
tensor_parallel_size,
|
||||
max_model_len,
|
||||
max_tokens,
|
||||
quantization,
|
||||
"default",
|
||||
result_queue,
|
||||
),
|
||||
args=(fd_runner, fd_model_path, tensor_parallel_size, max_model_len, max_tokens, quantization, "default"),
|
||||
)
|
||||
p_paddle.start()
|
||||
p_paddle.join()
|
||||
paddle_outputs = result_queue.get(timeout=60)
|
||||
|
||||
p_hf = Process(
|
||||
hf_outputs = run_with_timeout(
|
||||
target=form_model_get_output,
|
||||
args=(
|
||||
fd_runner,
|
||||
@@ -264,12 +294,8 @@ def test_paddle_vs_torch_model(
|
||||
max_tokens,
|
||||
quantization,
|
||||
"default_v1",
|
||||
result_queue,
|
||||
),
|
||||
)
|
||||
p_hf.start()
|
||||
p_hf.join()
|
||||
hf_outputs = result_queue.get(timeout=60)
|
||||
|
||||
check_tokens_id_and_text_close(
|
||||
outputs_0_lst=paddle_outputs,
|
||||
|
Reference in New Issue
Block a user