support w4afp8 EP inference (#3044)
@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .fused_moe_cutlass_backend import CutlassW4A8MoEMethod, CutlassWeightOnlyMoEMethod
+from .fused_moe_cutlass_backend import (
+    CutlassW4A8MoEMethod,
+    CutlassW4AFP8MoEMethod,
+    CutlassWeightOnlyMoEMethod,
+)
 from .fused_moe_triton_backend import TritonWeightOnlyMoEMethod
 from .moe import FusedMoE
 
 __all__ = [
     CutlassWeightOnlyMoEMethod,
     CutlassW4A8MoEMethod,
+    CutlassW4AFP8MoEMethod,
     FusedMoE,
     TritonWeightOnlyMoEMethod,
 ]
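With the export added above, the new backend is importable directly from the moe package alongside the existing methods; a minimal smoke check (assuming a CUDA build of fastdeploy is installed, since the backend module imports GPU ops):

from fastdeploy.model_executor.layers.moe import CutlassW4AFP8MoEMethod, FusedMoE

print(CutlassW4AFP8MoEMethod.__name__)  # CutlassW4AFP8MoEMethod
print(FusedMoE.__name__)                # FusedMoE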
@@ -389,7 +389,7 @@ class EPPrefillRunner(EPRunner):
     ):
         (
             num_tokens_per_rank,
-            _,
+            num_tokens_per_rdma_rank,
             num_tokens_per_expert,
             is_token_in_rank,
             _,
@@ -399,6 +399,7 @@ class EPPrefillRunner(EPRunner):
         dispatch_args = {
             "x": (x, x_scale_tensor) if x_scale_tensor is not None else x,
             "num_tokens_per_rank": num_tokens_per_rank,
+            "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
             "is_token_in_rank": is_token_in_rank,
             "num_tokens_per_expert": num_tokens_per_expert,
             "config": self.ep_engine.ep_config,
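The tuple unpacked above is the dispatch layout that the prefill runner forwards to EP dispatch; with this change the per-RDMA-rank counts are kept instead of being discarded. A minimal pure-Python sketch (an illustration only, not the DeepEP/FastDeploy API) of what these layout quantities describe:

from collections import Counter

num_ranks, num_experts = 4, 8
experts_per_rank = num_experts // num_ranks
topk_idx = [[0, 5], [3, 6], [1, 2]]  # chosen experts for 3 tokens, top_k = 2

num_tokens_per_expert = Counter(e for row in topk_idx for e in row)
# A token counts toward a rank if any of its chosen experts lives on that rank.
is_token_in_rank = [
    [any(e // experts_per_rank == r for e in row) for r in range(num_ranks)]
    for row in topk_idx
]
num_tokens_per_rank = [sum(col) for col in zip(*is_token_in_rank)]
# num_tokens_per_rdma_rank aggregates the same counts per RDMA (inter-node) peer group.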
@@ -31,6 +31,7 @@ if current_platform.is_cuda():
         moe_expert_dispatch,
         moe_expert_reduce,
         noaux_tc,
+        w4afp8_gemm_scale_permute,
     )
 elif current_platform.is_iluvatar():
     from fastdeploy.model_executor.ops.iluvatar import (
@@ -87,6 +88,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         token_nums_per_expert: paddle.Tensor,
         expert_idx_per_token: paddle.Tensor,
         used_in_ep_low_latency: bool = False,
+        estimate_total_token_nums: int = -1,
     ):
         """
         Paddle Cutlass compute Fused MoE.
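estimate_total_token_nums is a new hint threaded through compute_ffn into the moe_expert_ffn calls below; in the decode path further down it is set to gate_out.shape[0] * layer.top_k, i.e. an upper bound on routed token slots (presumably so buffers can be sized without a device-side count). A trivial worked example with toy numbers:

num_decode_tokens, top_k = 128, 8          # toy values, not from the diff
estimate_total_token_nums = num_decode_tokens * top_k
print(estimate_total_token_nums)           # 1024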
@@ -104,6 +106,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
                 expert_idx_per_token,
                 self.moe_quant_type,
                 used_in_ep_low_latency,
+                estimate_total_token_nums,
             )
         return fastdeploy.model_executor.ops.gpu.moe_expert_ffn(
             permute_input,
@@ -117,6 +120,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
             expert_idx_per_token,
             self.moe_quant_type,
             used_in_ep_low_latency,
+            estimate_total_token_nums,
         )
 
     def apply_ep_prefill(
@@ -157,13 +161,13 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            (self.up_gate_proj_in_scale if hasattr(self, "up_gate_proj_in_scale") else None),
+            (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
             recv_num_tokens_per_expert_list,
             token_all_num,
             self.moe_quant_type,
         )
-        if self.moe_quant_type != "w4a8":
-            # only w4a8 need expert_idx_per_token
+        if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
+            # only w4a8 and w4afp8 need expert_idx_per_token
             # Other need not this tensor, so we make it None.
             expert_idx_per_token = None
         else:
@@ -202,18 +206,19 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         Apply the EP decoder method.
         """
         gate_out = gate(x.cast("float32"))
+        estimate_total_token_nums = gate_out.shape[0] * layer.top_k
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
         expertwise_scale = None
         if hasattr(layer, "up_gate_proj_in_scale_all_experts"):  # only use in w4a8
             expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
 
+        use_fp8 = self.moe_quant_type == "w4afp8"
         # 2. EP Dispatch
         permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
-            x, topk_idx, topk_weights, expertwise_scale=expertwise_scale
+            x, topk_idx, topk_weights, expertwise_scale=expertwise_scale, use_fp8=use_fp8
         )
         # 3. Compute ffn
-        if self.moe_quant_type == "w4a8":
+        if self.moe_quant_type == "w4a8" or self.moe_quant_type == "w4afp8":
             num_local_experts, max_num, _ = permute_input.shape
             expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
         elif self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
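For the w4a8/w4afp8 low-latency decode path, permute_input is laid out as one padded slab per local expert, so expert_idx_per_token can be synthesized directly: row i of a [num_local_experts, max_num] grid is filled with expert id i, one entry per token slot of that expert. A minimal sketch of that construction with toy sizes:

import paddle

num_local_experts, max_num = 3, 4
expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
print(expert_idx_per_token.numpy())
# [[0 0 0 0]
#  [1 1 1 1]
#  [2 2 2 2]]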
@@ -227,6 +232,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
             token_nums_per_expert.cast("int64"),
             expert_idx_per_token,
             True,
+            estimate_total_token_nums,
         )
 
         # 4. EP combine
@@ -290,7 +296,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
             topk_only_mode=False,
         )
 
-        if self.moe_quant_type != "w4a8":
+        if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
             # only w4a8 need expert_idx_per_token
             # Other need not this tensor, so we make it None.
             expert_idx_per_token = None
@@ -373,9 +379,9 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
         down_proj_weight = paddle.stack(down_proj_weights, axis=0)
         up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
         down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0).cast(paddle.get_default_dtype())
-        up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0)
-        up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0)
-        down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0)
+        up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).unsqueeze()
+        up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).unsqueeze()
+        down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).unsqueeze()
 
         name_tensor_map = {
             "up_gate_proj_weight": up_gate_proj_weight,
@@ -448,7 +454,6 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
         Args:
             layer (nn.Layer): The layer to add parameters to.
             weight_key_map (dict): The weight key map.
-            state_dict (dict): The state dict.
         """
         self.default_dtype = layer._helper.get_default_dtype()
         if layer.ep_size > 1:
@@ -572,6 +577,263 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
             )
 
 
+class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
+    """
+    w4afp8 MoE Method
+    """
+
+    def __init__(self, quant_config):
+        super().__init__(quant_config)
+        self.quant_config = quant_config
+        self.moe_quant_type = "w4afp8"
+        self.pack_num = 2
+
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+        """
+        Paddle cutlass process prequanted weights.
+        """
+        up_gate_proj_expert_weight_key = layer.weight_key_map.get("up_gate_proj_expert_weight_key", None)
+        down_proj_expert_weight_key = layer.weight_key_map.get("down_proj_expert_weight_key", None)
+        up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
+        down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
+        up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
+        down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
+
+        up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
+            layer.load_experts_weight(
+                state_dict,
+                up_gate_proj_expert_weight_key,
+                down_proj_expert_weight_key,
+            )
+        )
+
+        up_gate_proj_weight_scale = []
+        down_proj_weight_scale = []
+        up_gate_proj_in_scale_all_experts = []
+        up_gate_proj_in_scale = []
+        down_proj_in_scale = []
+
+        if layer.ep_size > 1:
+            for expert_idx in ep_rank_to_expert_id_list:
+                scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
+                up_gate_proj_in_scale_all_experts.append(scale_tensor)
+
+        for expert_idx in logical_expert_ids:
+            up_gate_proj_weight_scale.append(
+                get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
+            )
+            down_proj_weight_scale.append(
+                get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
+            )
+            up_gate_proj_in_scale.append(
+                get_tensor(state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx)))
+            )
+            down_proj_in_scale.append(get_tensor(state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))))
+
+        up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
+        down_proj_weight = paddle.stack(down_proj_weights, axis=0)
+        up_gate_proj_weight_scale = paddle.stack(up_gate_proj_weight_scale, axis=0)
+        down_proj_weight_scale = paddle.stack(down_proj_weight_scale, axis=0)
+        up_gate_proj_in_scale_all_experts = paddle.stack(up_gate_proj_in_scale_all_experts, axis=0).squeeze()
+        up_gate_proj_in_scale = paddle.stack(up_gate_proj_in_scale, axis=0).squeeze()
+        down_proj_in_scale = paddle.stack(down_proj_in_scale, axis=0).squeeze()
+
+        name_tensor_map = {
+            "up_gate_proj_weight": up_gate_proj_weight,
+            "down_proj_weight": down_proj_weight,
+            "up_gate_proj_weight_scale": up_gate_proj_weight_scale,
+            "down_proj_weight_scale": down_proj_weight_scale,
+            "up_gate_proj_in_scale_all_experts": up_gate_proj_in_scale_all_experts,
+            "up_gate_proj_in_scale": up_gate_proj_in_scale,
+            "down_proj_in_scale": down_proj_in_scale,
+        }
+        for name, tensor in name_tensor_map.items():
+            getattr(layer, name).set_value(tensor)
+
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
+        """
+        Paddle cutlass create weight process.
+        """
+        self.weight_dtype = "int8"
+        self.ffn1_weight_shape = [
+            layer.num_local_experts,
+            layer.hidden_size // 2,
+            layer.moe_intermediate_size * 2,
+        ]
+        self.ffn2_weight_shape = [
+            layer.num_local_experts,
+            layer.moe_intermediate_size // 2,
+            layer.hidden_size,
+        ]
+        setattr(
+            layer,
+            self.added_weight_attrs[0],
+            layer.create_parameter(
+                shape=self.ffn1_weight_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        setattr(
+            layer,
+            self.added_weight_attrs[1],
+            layer.create_parameter(
+                shape=self.ffn2_weight_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+
+        self.create_w4afp8_scale_weights(layer, layer.weight_key_map)
+
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
+        """
+        Paddle cutlass load weight process.
+        """
+        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        self.check(layer, up_gate_proj_weights, down_proj_weights)
+        for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
+            weight_name = self.added_weight_attrs[idx]
+            weight_list = []
+            for i in range(layer.num_local_experts):
+                quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80)
+                weight_list.append(quant_weight)
+            quanted_weight = paddle.stack(weight_list, axis=0)
+            getattr(layer, weight_name).set_value(quanted_weight)
+
+        self.load_w4afp8_scale_weights(layer, layer.weight_key_map, state_dict)
+
+    def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
+        """
+        Get w4afp8 weights from state dict and process them.
+        Args:
+            layer (nn.Layer): The layer to add parameters to.
+            weight_key_map (dict): The weight key map.
+        """
+
+        self.default_dtype = layer._helper.get_default_dtype()
+        if layer.ep_size > 1:
+            setattr(
+                layer,
+                "up_gate_proj_in_scale_all_experts",
+                layer.create_parameter(
+                    shape=[layer.num_experts],
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+
+        # in_scales
+        for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
+            setattr(
+                layer,
+                in_scale_name,
+                layer.create_parameter(
+                    shape=[layer.num_local_experts],
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+
+        # weight_scales
+        setattr(
+            layer,
+            "up_gate_proj_weight_scale",
+            layer.create_parameter(
+                shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
+                dtype="float32",
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        setattr(
+            layer,
+            "down_proj_weight_scale",
+            layer.create_parameter(
+                shape=[layer.num_local_experts, layer.hidden_size],
+                dtype="float32",
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+
+    def load_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict):
+        """
+        Get w4afp8 weights from state dict and process them.
+        Args:
+            layer (nn.Layer): The layer to add parameters to.
+            weight_key_map (dict): The weight key map.
+            state_dict (dict): The state dict.
+        """
+
+        def _extract_scale_tensor(state_dict, key_template, expert_idx):
+            return get_tensor(state_dict.pop(key_template.format(expert_idx)))
+
+        def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
+            processed_in_scale = 1 / paddle.concat(in_scales)
+            getattr(layer, name).set_value(processed_in_scale)
+            return processed_in_scale
+
+        def _permute_weight_scale(weight_scale: paddle.Tensor):
+            weight_scale = w4afp8_gemm_scale_permute(weight_scale)
+            return weight_scale
+
+        def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor):
+            processed_weight_scale = (
+                paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None]
+            )
+            processed_weight_scale = _permute_weight_scale(processed_weight_scale)
+            getattr(layer, name).set_value(processed_weight_scale)
+
+        # 1. Init scale containers and maps
+        up_gate_proj_weight_scales = []
+        down_proj_weight_scales = []
+        up_gate_proj_in_scales_all_experts = []
+        up_gate_proj_in_scales = []
+        down_proj_in_scales = []
+
+        scale_weight_map = {
+            "up_gate_proj_weight_scale": up_gate_proj_weight_scales,
+            "down_proj_weight_scale": down_proj_weight_scales,
+            "up_gate_proj_in_scale": up_gate_proj_in_scales,
+            "down_proj_in_scale": down_proj_in_scales,
+        }
+        scale_key_map = {
+            "up_gate_proj_weight_scale": weight_key_map.get("up_gate_proj_expert_weight_scale_key", None),
+            "down_proj_weight_scale": weight_key_map.get("down_proj_expert_weight_scale_key", None),
+            "up_gate_proj_in_scale": weight_key_map.get("up_gate_proj_expert_in_scale_key", None),
+            "down_proj_in_scale": weight_key_map.get("down_proj_expert_in_scale_key", None),
+        }
+        for name, value in scale_key_map.items():
+            if value is None:
+                raise ValueError(f"scale {name} should not be none in w4a8 mode.")
+
+        # 2. Extract scale tensor from state dict
+        if layer.ep_size > 1:
+            for expert_idx in range(layer.num_experts):
+                scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)])
+                up_gate_proj_in_scales_all_experts.append(1 / scale_tensor)
+            getattr(layer, "up_gate_proj_in_scale_all_experts").set_value(
+                paddle.concat(up_gate_proj_in_scales_all_experts)
+            )
+
+        for local_expert_idx in range(layer.num_local_experts):
+            expert_idx = local_expert_idx + layer.expert_id_offset
+            for name, scale_key_template in scale_key_map.items():
+                scale_tensor = _extract_scale_tensor(state_dict, scale_key_template, expert_idx)
+                scale_weight_map[name].append(scale_tensor)
+
+        # 3. Process scale tensor and set to layer
+        in_scales = []
+        for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
+            in_scales.append(_process_in_scale(in_scale_name, scale_weight_map[in_scale_name]))
+
+        for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]):
+            _process_weight_scale(
+                weight_scale_name,
+                scale_weight_map[weight_scale_name],
+                in_scales[i],
+            )
+
+
 class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
     """
     weight only for moe
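Two details of the new class worth pulling out: the int4 weights are stored two-per-int8 byte (hence pack_num = 2 and the // 2 in the ffn weight shapes), and load_w4afp8_scale_weights stores the reciprocal of each activation scale and folds it into the weight scale before permuting. A small numeric sketch of that scale arithmetic (toy values; reading the constants as the FP8-E4M3 max, the int4 max, and a fixed 2**-9 shift is an assumption, not stated in the diff):

act_scale = 0.02      # one expert's "in_scale" from the checkpoint
weight_scale = 0.5    # one output-channel weight scale from the checkpoint

processed_in_scale = 1 / act_scale                 # stored on the layer: 50.0
processed_weight_scale = weight_scale / (448 * 7 * 2 ** (-9)) / processed_in_scale
print(processed_weight_scale)                      # ~0.0016327
# The full implementation then reorders the result with w4afp8_gemm_scale_permute
# before calling set_value() on the layer parameter.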
@@ -20,6 +20,7 @@ import paddle
 
 import fastdeploy
 
+from ..moe import FusedMoE
 from .quant_base import QuantConfigBase, QuantMethodBase
 
 QUANT_SCALING_FACTOR = 448
@@ -30,24 +31,32 @@ class W4AFP8Config(QuantConfigBase):
     quantization config for weight 4bits and activation fp8
     """
 
-    def __init__(self, weight_scale_dict, act_scale_dict) -> None:
+    def __init__(self, weight_scale_dict, act_scale_dict, is_permuted) -> None:
         super().__init__()
         self.weight_scale_dict = weight_scale_dict
         self.act_scale_dict = act_scale_dict
         self.quant_max_bound = 448
         self.quant_min_bound = -448
         self.quant_round_type = 1
+        self.is_permuted = is_permuted
 
     def name(self) -> str:
         return "w4afp8"
 
     @classmethod
     def from_config(cls, config: dict) -> "W4AFP8Config":
-        weight_scale_dict = config["weight_scale_dict"]
-        act_scale_dict = config["act_scale_dict"]
-        return cls(weight_scale_dict, act_scale_dict)
+        weight_scale_dict = config.get("weight_scale_dict", None)
+        act_scale_dict = config.get("act_scale_dict", None)
+        is_permuted = config.get("is_permuted", True)
+        return cls(weight_scale_dict, act_scale_dict, is_permuted)
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
+        if isinstance(layer, FusedMoE):
+            from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
+                CutlassW4AFP8MoEMethod,
+            )
+
+            return CutlassW4AFP8MoEMethod(self)
         return W4AFP8LinearMethod(self)
 
 
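A hedged usage sketch of the updated config (module path assumed from the relative imports above; not part of the diff): from_config now tolerates missing scale dicts and defaults is_permuted to True, and get_quant_method routes FusedMoE layers to the new Cutlass method:

from fastdeploy.model_executor.layers.quantization.w4afp8 import W4AFP8Config  # path assumed

cfg = W4AFP8Config.from_config({"is_permuted": True})  # scale dicts may be omitted
print(cfg.name())       # "w4afp8"
print(cfg.is_permuted)  # True
# cfg.get_quant_method(layer) returns CutlassW4AFP8MoEMethod(cfg) when `layer` is a
# FusedMoE instance, and W4AFP8LinearMethod(cfg) for ordinary quantized linear layers.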