mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
@@ -22,6 +22,7 @@ import fastdeploy
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
|
||||
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
|
||||
from fastdeploy.utils import ceil_div
|
||||
|
||||
from .fused_moe_backend_base import MoEMethodBase
|
||||
@@ -36,64 +37,170 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
"""
|
||||
deepgemm create weight process.
|
||||
"""
|
||||
self.weight_dtype = paddle.float8_e4m3fn
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
self.ffn1_weight_shape = [
|
||||
self.up_gate_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2,
|
||||
layer.hidden_size,
|
||||
]
|
||||
self.ffn2_weight_shape = [
|
||||
self.down_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size,
|
||||
]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
self.up_gate_proj_scale_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
|
||||
layer.hidden_size // self.quant_config.weight_block_size[1],
|
||||
]
|
||||
self.down_proj_scale_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size // self.quant_config.weight_block_size[0],
|
||||
layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
|
||||
]
|
||||
if self.quant_config.is_checkpoint_bf16:
|
||||
layer.up_gate_proj_weight = layer.create_parameter(
|
||||
shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn2_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
)
|
||||
|
||||
layer.down_proj_weight = layer.create_parameter(
|
||||
shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size],
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# weight_scale
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.up_gate_proj_weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.down_proj_weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.weight_dtype = paddle.float8_e4m3fn
|
||||
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
up_gate_proj_scale_name = self.added_scale_attrs[0]
|
||||
down_proj_scale_name = self.added_scale_attrs[1]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.up_gate_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.down_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# weight_scale
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_scale_name,
|
||||
layer.create_parameter(
|
||||
shape=self.up_gate_proj_scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_scale_name,
|
||||
layer.create_parameter(
|
||||
shape=self.down_proj_scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
""" """
|
||||
if not self.quant_config.is_checkpoint_bf16:
|
||||
return
|
||||
weight_id_map = {"gate_up": 0, "down": 1}
|
||||
if (
|
||||
hasattr(layer.up_gate_proj_weight, "tensor_track")
|
||||
and layer.up_gate_proj_weight.tensor_track is not None
|
||||
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
|
||||
):
|
||||
weight_type = "gate_up"
|
||||
layer.up_gate_proj_weight.tensor_track = None
|
||||
else:
|
||||
weight_type = "down"
|
||||
layer.down_proj_weight.tensor_track = None
|
||||
|
||||
# 1.init shape and type
|
||||
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
|
||||
# weight
|
||||
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
|
||||
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
|
||||
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
|
||||
weight_dtype = paddle.float8_e4m3fn
|
||||
# scale
|
||||
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
|
||||
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
|
||||
scale_dtype = "float32"
|
||||
|
||||
# 2.crate tmp tensor
|
||||
|
||||
weight = paddle.empty(shape=[weight_shape[0], weight_shape[2], weight_shape[1]], dtype=weight_dtype)
|
||||
scale = paddle.empty(shape=[scale_shape[0], scale_shape[2], scale_shape[1]], dtype=scale_dtype)
|
||||
|
||||
# 3.quantize weight
|
||||
from fastdeploy.model_executor.layers.utils import per_block_cast_to_fp8
|
||||
|
||||
for expert_id in range(layer.num_experts):
|
||||
weight_quant, scale[expert_id] = per_block_cast_to_fp8(
|
||||
getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
|
||||
)
|
||||
weight[expert_id].copy_(weight_quant, False)
|
||||
getattr(layer, unquantized_weight_name).value().get_tensor()._clear()
|
||||
|
||||
# create weight
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[0],
|
||||
weight_name,
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
|
||||
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
|
||||
],
|
||||
dtype="float32",
|
||||
dtype=weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# create scale
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[1],
|
||||
scale_name,
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
|
||||
ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
|
||||
],
|
||||
dtype="float32",
|
||||
dtype=scale_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]).contiguous(), False)
|
||||
getattr(layer, scale_name).copy_(scale.transpose([0, 2, 1]).contiguous(), False)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
@@ -244,12 +351,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
|
||||
# up_gate_proj
|
||||
ffn_out = paddle.empty(
|
||||
(permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
|
||||
(permute_input.shape[0], getattr(layer, self.added_weight_attrs[0]).shape[1]),
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(permute_input, permute_scale),
|
||||
(layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
|
||||
(getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_scale_attrs[0])),
|
||||
ffn_out,
|
||||
m_indices,
|
||||
)
|
||||
@@ -264,12 +371,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
|
||||
|
||||
ffn_out = paddle.empty(
|
||||
(ffn_out.shape[0], layer.down_proj_weight.shape[1]),
|
||||
(ffn_out.shape[0], getattr(layer, self.added_weight_attrs[1]).shape[1]),
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(ffn_in_x, ffn_in_x_scale_tensor),
|
||||
(layer.down_proj_weight, layer.down_proj_weight_scale),
|
||||
(getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_scale_attrs[1])),
|
||||
ffn_out,
|
||||
m_indices,
|
||||
)
|
||||
@@ -331,8 +438,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
permute_input,
|
||||
(
|
||||
layer.up_gate_proj_weight,
|
||||
layer.up_gate_proj_weight_scale,
|
||||
getattr(layer, self.added_weight_attrs[0]),
|
||||
getattr(layer, self.added_scale_attrs[0]),
|
||||
),
|
||||
up_gate_proj_out,
|
||||
token_nums_per_expert,
|
||||
@@ -350,8 +457,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
(act_out_fp8, scale),
|
||||
(
|
||||
layer.down_proj_weight,
|
||||
layer.down_proj_weight_scale,
|
||||
getattr(layer, self.added_weight_attrs[1]),
|
||||
getattr(layer, self.added_scale_attrs[1]),
|
||||
),
|
||||
ffn_out,
|
||||
token_nums_per_expert,
|
||||
@@ -423,12 +530,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
|
||||
# up_gate_proj
|
||||
ffn_out = paddle.empty(
|
||||
(permute_input.shape[0], layer.up_gate_proj_weight.shape[1]),
|
||||
(permute_input.shape[0], getattr(layer, self.added_weight_attrs[0]).shape[1]),
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(permute_input, permute_scale),
|
||||
(layer.up_gate_proj_weight, layer.up_gate_proj_weight_scale),
|
||||
(getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_scale_attrs[0])),
|
||||
ffn_out,
|
||||
m_indices,
|
||||
)
|
||||
@@ -444,12 +551,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.transpose([1, 0])
|
||||
|
||||
ffn_out = paddle.empty(
|
||||
(ffn_out.shape[0], layer.down_proj_weight.shape[1]),
|
||||
(ffn_out.shape[0], getattr(layer, self.added_weight_attrs[1]).shape[1]),
|
||||
dtype=paddle.bfloat16,
|
||||
)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(ffn_in_x, ffn_in_x_scale_tensor),
|
||||
(layer.down_proj_weight, layer.down_proj_weight_scale),
|
||||
(getattr(layer, self.added_weight_attrs[1]), getattr(layer, self.added_scale_attrs[1])),
|
||||
ffn_out,
|
||||
m_indices,
|
||||
)
|
||||
|
@@ -20,6 +20,7 @@ from paddle import nn
|
||||
import fastdeploy
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
|
||||
from fastdeploy.utils import ceil_div
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
@@ -604,64 +605,170 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
"""
|
||||
Triton MoE create weight process.
|
||||
"""
|
||||
self.weight_dtype = paddle.float8_e4m3fn
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
self.ffn1_weight_shape = [
|
||||
self.up_gate_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2,
|
||||
layer.hidden_size,
|
||||
]
|
||||
self.ffn2_weight_shape = [
|
||||
self.down_proj_weight_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size,
|
||||
layer.moe_intermediate_size,
|
||||
]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn1_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
self.up_gate_proj_scale_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
|
||||
layer.hidden_size // self.quant_config.weight_block_size[1],
|
||||
]
|
||||
self.down_proj_scale_shape = [
|
||||
layer.num_local_experts,
|
||||
layer.hidden_size // self.quant_config.weight_block_size[0],
|
||||
layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
|
||||
]
|
||||
if self.quant_config.is_checkpoint_bf16:
|
||||
layer.up_gate_proj_weight = layer.create_parameter(
|
||||
shape=[layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.ffn2_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
)
|
||||
|
||||
layer.down_proj_weight = layer.create_parameter(
|
||||
shape=[layer.num_experts, layer.moe_intermediate_size, layer.hidden_size],
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# weight_scale
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.up_gate_proj_weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.down_proj_weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.weight_dtype = paddle.float8_e4m3fn
|
||||
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
|
||||
up_gate_proj_weight_name = self.added_weight_attrs[0]
|
||||
down_proj_weight_name = self.added_weight_attrs[1]
|
||||
up_gate_proj_scale_name = self.added_scale_attrs[0]
|
||||
down_proj_scale_name = self.added_scale_attrs[1]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.up_gate_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.up_gate_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# weight_scale
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_scale_name,
|
||||
layer.create_parameter(
|
||||
shape=self.up_gate_proj_scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_scale_name,
|
||||
layer.create_parameter(
|
||||
shape=self.down_proj_scale_shape,
|
||||
dtype="float32",
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
""" """
|
||||
if not self.quant_config.is_checkpoint_bf16:
|
||||
return
|
||||
weight_id_map = {"gate_up": 0, "down": 1}
|
||||
if (
|
||||
hasattr(layer.up_gate_proj_weight, "tensor_track")
|
||||
and layer.up_gate_proj_weight.tensor_track is not None
|
||||
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
|
||||
):
|
||||
weight_type = "gate_up"
|
||||
layer.up_gate_proj_weight.tensor_track = None
|
||||
else:
|
||||
weight_type = "down"
|
||||
layer.down_proj_weight.tensor_track = None
|
||||
|
||||
# 1.init shape and type
|
||||
self.added_scale_attrs = ["up_gate_proj_weight_scale_inv", "down_proj_weight_scale_inv"]
|
||||
# weight
|
||||
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
|
||||
unquantized_weight_name = weight_name.replace("quant_weight", "weight")
|
||||
weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
|
||||
weight_dtype = paddle.float8_e4m3fn
|
||||
# scale
|
||||
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
|
||||
scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
|
||||
scale_dtype = "float32"
|
||||
|
||||
# 2.crate tmp tensor
|
||||
|
||||
weight = paddle.empty(shape=[weight_shape[0], weight_shape[2], weight_shape[1]], dtype=weight_dtype)
|
||||
scale = paddle.empty(shape=[scale_shape[0], scale_shape[2], scale_shape[1]], dtype=scale_dtype)
|
||||
|
||||
# 3.quantize weight
|
||||
from fastdeploy.model_executor.layers.utils import per_block_cast_to_fp8
|
||||
|
||||
for expert_id in range(layer.num_experts):
|
||||
weight_quant, scale[expert_id] = per_block_cast_to_fp8(
|
||||
getattr(layer, unquantized_weight_name)[expert_id], self.quant_config.weight_block_size
|
||||
)
|
||||
weight[expert_id].copy_(weight_quant, False)
|
||||
getattr(layer, unquantized_weight_name).value().get_tensor()._clear()
|
||||
|
||||
# create weight
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[0],
|
||||
weight_name,
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
ceil_div(layer.moe_intermediate_size * 2, self.quant_config.weight_block_size[0]),
|
||||
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[1]),
|
||||
],
|
||||
dtype="float32",
|
||||
dtype=weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# create scale
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[1],
|
||||
scale_name,
|
||||
layer.create_parameter(
|
||||
shape=[
|
||||
layer.num_local_experts,
|
||||
ceil_div(layer.hidden_size, self.quant_config.weight_block_size[0]),
|
||||
ceil_div(layer.moe_intermediate_size, self.quant_config.weight_block_size[1]),
|
||||
],
|
||||
dtype="float32",
|
||||
dtype=scale_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]).contiguous(), False)
|
||||
getattr(layer, scale_name).copy_(scale.transpose([0, 2, 1]).contiguous(), False)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
@@ -719,8 +826,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
num_local_experts = layer.num_local_experts
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
E, N1, _ = layer.up_gate_proj_weight.shape
|
||||
N2 = layer.down_proj_weight.shape[1]
|
||||
E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape
|
||||
N2 = getattr(layer, self.added_weight_attrs[1]).shape[1]
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
@@ -759,10 +866,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x_q,
|
||||
layer.up_gate_proj_weight,
|
||||
getattr(layer, self.added_weight_attrs[0]),
|
||||
intermediate_cache1,
|
||||
x_scale,
|
||||
layer.up_gate_proj_weight_scale,
|
||||
getattr(layer, self.added_scale_attrs[0]),
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -773,17 +880,17 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
K=hidden_size,
|
||||
stride_am=x_q.strides[0],
|
||||
stride_ak=x_q.strides[1],
|
||||
stride_be=layer.up_gate_proj_weight.strides[0],
|
||||
stride_bk=layer.up_gate_proj_weight.strides[2],
|
||||
stride_bn=layer.up_gate_proj_weight.strides[1],
|
||||
stride_be=getattr(layer, self.added_weight_attrs[0]).strides[0],
|
||||
stride_bk=getattr(layer, self.added_weight_attrs[0]).strides[2],
|
||||
stride_bn=getattr(layer, self.added_weight_attrs[0]).strides[1],
|
||||
stride_cm=intermediate_cache1.strides[0],
|
||||
stride_cn=intermediate_cache1.strides[1],
|
||||
#
|
||||
stride_asm=x_scale.strides[0], # only used in blockwise fp8
|
||||
stride_ask=x_scale.strides[1], # only used in blockwise fp8
|
||||
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
|
||||
stride_bsk=layer.up_gate_proj_weight_scale.strides[2],
|
||||
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
|
||||
stride_bse=getattr(layer, self.added_scale_attrs[0]).strides[0],
|
||||
stride_bsk=getattr(layer, self.added_scale_attrs[0]).strides[2],
|
||||
stride_bsn=getattr(layer, self.added_scale_attrs[0]).strides[1],
|
||||
group_n=self.quant_config.weight_block_size[1],
|
||||
group_k=self.quant_config.weight_block_size[0],
|
||||
# Meta-parameters
|
||||
@@ -813,10 +920,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
|
||||
fused_moe_kernel_paddle[grid](
|
||||
x_q,
|
||||
layer.down_proj_weight,
|
||||
getattr(layer, self.added_weight_attrs[1]),
|
||||
intermediate_cache3,
|
||||
x_scale,
|
||||
layer.down_proj_weight_scale,
|
||||
getattr(layer, self.added_scale_attrs[1]),
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -827,16 +934,16 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
K=moe_intermediate_size,
|
||||
stride_am=x_q.strides[0],
|
||||
stride_ak=x_q.strides[1],
|
||||
stride_be=layer.down_proj_weight.strides[0],
|
||||
stride_bk=layer.down_proj_weight.strides[2],
|
||||
stride_bn=layer.down_proj_weight.strides[1],
|
||||
stride_be=getattr(layer, self.added_weight_attrs[1]).strides[0],
|
||||
stride_bk=getattr(layer, self.added_weight_attrs[1]).strides[2],
|
||||
stride_bn=getattr(layer, self.added_weight_attrs[1]).strides[1],
|
||||
stride_cm=intermediate_cache3.strides[0],
|
||||
stride_cn=intermediate_cache3.strides[1],
|
||||
stride_asm=x_scale.strides[0], # only used in blockwise fp8
|
||||
stride_ask=x_scale.strides[1], # only used in blockwise fp8
|
||||
stride_bse=layer.down_proj_weight_scale.strides[0],
|
||||
stride_bsk=layer.down_proj_weight_scale.strides[2],
|
||||
stride_bsn=layer.down_proj_weight_scale.strides[1],
|
||||
stride_bse=getattr(layer, self.added_scale_attrs[1]).strides[0],
|
||||
stride_bsk=getattr(layer, self.added_scale_attrs[1]).strides[2],
|
||||
stride_bsn=getattr(layer, self.added_scale_attrs[1]).strides[1],
|
||||
group_n=self.quant_config.weight_block_size[1],
|
||||
group_k=self.quant_config.weight_block_size[0],
|
||||
# Meta-parameters
|
||||
|
@@ -20,7 +20,12 @@ import paddle
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.moe import FusedMoE
|
||||
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
|
||||
|
||||
from ..utils import get_tensor, per_block_cast_to_fp8
|
||||
from .quant_base import QuantConfigBase, QuantMethodBase
|
||||
@@ -33,13 +38,14 @@ class BlockWiseFP8Config(QuantConfigBase):
|
||||
per-token quantization of activations during inference.
|
||||
"""
|
||||
|
||||
def __init__(self, weight_block_size: list = [-1, -1]) -> None:
|
||||
def __init__(self, weight_block_size: list = [-1, -1], is_checkpoint_bf16: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.weight_block_size = weight_block_size
|
||||
self.quant_max_bound = 448
|
||||
self.quant_min_bound = -448
|
||||
self.quant_round_type = 1
|
||||
self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)
|
||||
self.is_checkpoint_bf16 = is_checkpoint_bf16
|
||||
|
||||
def name(self) -> str:
|
||||
return "block_wise_fp8"
|
||||
@@ -47,7 +53,8 @@ class BlockWiseFP8Config(QuantConfigBase):
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "BlockWiseFP8Config":
|
||||
weight_block_size = config.get("weight_block_size", [128, 128])
|
||||
return cls(weight_block_size)
|
||||
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
|
||||
return cls(weight_block_size, is_checkpoint_bf16)
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
"""
|
||||
@@ -82,31 +89,78 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer, **extra_weight_attrs):
|
||||
layer.weight_shape.reverse()
|
||||
layer.weight_dtype = "float8_e4m3fn"
|
||||
if self.quant_config.is_checkpoint_bf16:
|
||||
layer.weight = layer.create_parameter(
|
||||
shape=layer.weight_shape,
|
||||
dtype=layer.weight_dtype,
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
quant_attrs = extra_weight_attrs
|
||||
if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
|
||||
quant_attrs = {
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(
|
||||
shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
|
||||
),
|
||||
}
|
||||
set_weight_attrs(
|
||||
layer.weight,
|
||||
quant_attrs,
|
||||
)
|
||||
else:
|
||||
layer.weight_shape.reverse()
|
||||
layer.weight_dtype = "float8_e4m3fn"
|
||||
layer.weight = layer.create_parameter(
|
||||
shape=layer.weight_shape,
|
||||
dtype=layer.weight_dtype,
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
layer.weight_scale_inv = layer.create_parameter(
|
||||
shape=[
|
||||
(layer.output_size + self.quant_config.weight_block_size[0] - 1)
|
||||
// self.quant_config.weight_block_size[0],
|
||||
(layer.input_size + self.quant_config.weight_block_size[1] - 1)
|
||||
// self.quant_config.weight_block_size[1],
|
||||
],
|
||||
dtype="float32",
|
||||
is_bias=False,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
if not self.quant_config.is_checkpoint_bf16:
|
||||
return
|
||||
weight_tensor = layer.weight.transpose([1, 0])
|
||||
quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
|
||||
|
||||
if hasattr(layer.weight, "tensor_track"):
|
||||
layer.weight.tensor_track = None
|
||||
layer.weight.value().get_tensor()._clear()
|
||||
del layer.weight
|
||||
|
||||
layer.weight = layer.create_parameter(
|
||||
shape=layer.weight_shape,
|
||||
dtype=layer.weight_dtype,
|
||||
shape=quanted_weight_tensor.shape,
|
||||
dtype="float8_e4m3fn",
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
layer.weight_scale_inv = layer.create_parameter(
|
||||
shape=weight_block_scale_tensor.shape,
|
||||
dtype="float32",
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
)
|
||||
|
||||
layer.weight_scale = layer.create_parameter(
|
||||
shape=[
|
||||
(layer.output_size + self.quant_config.weight_block_size[0] - 1)
|
||||
// self.quant_config.weight_block_size[0],
|
||||
(layer.input_size + self.quant_config.weight_block_size[1] - 1)
|
||||
// self.quant_config.weight_block_size[1],
|
||||
],
|
||||
dtype="float32",
|
||||
is_bias=False,
|
||||
)
|
||||
layer.weight.copy_(quanted_weight_tensor, False)
|
||||
layer.weight_scale_inv.copy_(weight_block_scale_tensor, False)
|
||||
|
||||
def process_loaded_weights(self, layer, weights) -> None:
|
||||
weight_tensor = weights.transpose([1, 0])
|
||||
quanted_weight_tensor, weight_block_scale_tensor = per_block_cast_to_fp8(weight_tensor)
|
||||
layer.weight.copy_(quanted_weight_tensor, False)
|
||||
layer.weight_scale.set_value(weight_block_scale_tensor)
|
||||
layer.weight_scale_inv.set_value(weight_block_scale_tensor)
|
||||
|
||||
def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False):
|
||||
"""
|
||||
@@ -119,7 +173,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
layer.weight.copy_(quant_weight.view("float8_e4m3fn"), False)
|
||||
|
||||
weight_scale = weight_scale.transpose([1, 0])
|
||||
layer.weight_scale.set_value(weight_scale)
|
||||
layer.weight_scale_inv.set_value(weight_scale)
|
||||
|
||||
def apply(self, layer, x):
|
||||
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant_padding(
|
||||
@@ -130,7 +184,7 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt(
|
||||
(x, x_scale_tensor),
|
||||
(layer.weight, layer.weight_scale),
|
||||
(layer.weight, layer.weight_scale_inv),
|
||||
linear_out,
|
||||
)
|
||||
if layer.with_bias:
|
||||
|
@@ -37,6 +37,7 @@ class MixQuantConfig(QuantConfigBase):
|
||||
is_channel_wise: bool = False,
|
||||
has_zero_point: bool = False,
|
||||
is_permuted: bool = True,
|
||||
is_checkpoint_bf16: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dense_quant_type = dense_quant_type
|
||||
@@ -52,6 +53,7 @@ class MixQuantConfig(QuantConfigBase):
|
||||
self.quant_min_bound = 0
|
||||
self.quant_round_type = 0
|
||||
self.is_permuted = is_permuted
|
||||
self.is_checkpoint_bf16 = is_checkpoint_bf16
|
||||
|
||||
def name(self) -> str:
|
||||
return "mix_quant"
|
||||
@@ -66,6 +68,7 @@ class MixQuantConfig(QuantConfigBase):
|
||||
config.get("is_channel_wise", False),
|
||||
config.get("has_zero_point", False),
|
||||
config.get("is_permuted", True),
|
||||
config.get("is_checkpoint_bf16", False),
|
||||
)
|
||||
|
||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||
@@ -73,13 +76,13 @@ class MixQuantConfig(QuantConfigBase):
|
||||
if layer.moe_tag == "Image":
|
||||
return (
|
||||
get_quantization_config(self.image_moe_quant_type)
|
||||
.from_config({"is_permuted": self.is_permuted})
|
||||
.from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
|
||||
.get_quant_method(layer)
|
||||
)
|
||||
else:
|
||||
return (
|
||||
get_quantization_config(self.moe_quant_type)
|
||||
.from_config({"is_permuted": self.is_permuted})
|
||||
.from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
|
||||
.get_quant_method(layer)
|
||||
)
|
||||
elif isinstance(layer, Attention):
|
||||
@@ -92,4 +95,8 @@ class MixQuantConfig(QuantConfigBase):
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return get_quantization_config(self.dense_quant_type).from_config({}).get_quant_method(layer)
|
||||
return (
|
||||
get_quantization_config(self.dense_quant_type)
|
||||
.from_config({"self.is_checkpoint_bf16": self.is_checkpoint_bf16})
|
||||
.get_quant_method(layer)
|
||||
)
|
||||
|
Reference in New Issue
Block a user