mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-31 03:46:40 +08:00
[CP]Glm45 air 2.2 (#4073)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [Feature] Support zai-org/GLM-4.5-Air BF16 model (#3928) * support glm45_air * [Feature] GLM-45-AIR Support Mix Quantization(Dense wfp8afp8 and wint8 triton_moe_backend) (#4051) * check * fix v1 load for mix and wint8 * check --quantizations 'None' * check * support RL rollout * check v1 loader * check glm rollout_model, change wfp8afp8 per_token_cast_to_fp8 to native impl * check rollout moe gate begin layer_id * check rollout e_score_correction_bias * delete infer_to_train_mapping={} * code check
This commit is contained in:
@@ -28,38 +28,9 @@ except:
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.config import MoEPhase
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
from fastdeploy.utils import singleton
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import noaux_tc
|
||||
except:
|
||||
logger.warning("import noaux_tc Failed!")
|
||||
|
||||
|
||||
def get_moe_scores(
|
||||
gating_output: paddle.Tensor,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
compute moe scores using e_score_correction_bias.
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
|
||||
scores_with_bias = scores + e_score_correction_bias
|
||||
scores, topk_values, topk_idx = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group if n_group > 0 else 1,
|
||||
topk_group if topk_group > 0 else 1,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores, topk_values, topk_idx
|
||||
|
||||
|
||||
@singleton
|
||||
class DeepEPEngine:
|
||||
|
||||
@@ -27,11 +27,8 @@ from ..utils import get_tensor
|
||||
from .fused_moe_backend_base import UnquantizedFusedMoEMethod
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
moe_expert_dispatch,
|
||||
moe_expert_reduce,
|
||||
noaux_tc,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
|
||||
@@ -46,31 +43,6 @@ elif current_platform.is_iluvatar():
|
||||
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
|
||||
|
||||
|
||||
# used for deepseek_v3
|
||||
def get_moe_scores(
|
||||
gating_output: paddle.Tensor,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
compute moe scores using e_score_correction_bias.
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
scores_with_bias = scores + e_score_correction_bias
|
||||
scores, topk_values, topk_idx = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores, topk_values, topk_idx
|
||||
|
||||
|
||||
class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
"""
|
||||
Use Cutlass Group Gemm to compute Fused MoE.
|
||||
|
||||
@@ -481,7 +481,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
gate_out = gate(x.cast("float32"))
|
||||
|
||||
if layer.topk_method == "noaux_tc":
|
||||
from .ep import get_moe_scores
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
|
||||
_, topk_weights, topk_ids = get_moe_scores(
|
||||
gate_out,
|
||||
|
||||
@@ -19,39 +19,15 @@ from paddle import nn
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
MoeWna16MarlinGemmApi,
|
||||
noaux_tc,
|
||||
tritonmoe_preprocess_func,
|
||||
)
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
|
||||
|
||||
def get_moe_scores(
|
||||
gating_output: paddle.Tensor,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
compute moe scores using e_score_correction_bias.
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
||||
scores, topk_values, topk_idx = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores, topk_values, topk_idx
|
||||
|
||||
|
||||
def gptq_marlin_moe_repack(
|
||||
b_q_weight: paddle.Tensor,
|
||||
perm: paddle.Tensor,
|
||||
|
||||
@@ -24,7 +24,6 @@ from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
|
||||
from fastdeploy.utils import ceil_div
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
from .ep import get_moe_scores
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
|
||||
@@ -32,6 +31,7 @@ try:
|
||||
from .triton_moe_kernels import fused_moe_kernel_paddle
|
||||
except ImportError:
|
||||
pass
|
||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||
|
||||
|
||||
class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
@@ -72,43 +72,70 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
layer.moe_intermediate_size,
|
||||
layer.hidden_size,
|
||||
]
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
if self.quant_config.is_checkpoint_bf16:
|
||||
layer.up_gate_proj_weight = layer.create_parameter(
|
||||
shape=self.up_gate_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
)
|
||||
|
||||
layer.down_proj_weight = layer.create_parameter(
|
||||
shape=self.down_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
dtype=layer.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# weight_scale
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[0],
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[1],
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.hidden_size],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.up_gate_proj_weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
|
||||
},
|
||||
)
|
||||
set_weight_attrs(
|
||||
layer.down_proj_weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
|
||||
},
|
||||
)
|
||||
else:
|
||||
setattr(
|
||||
layer,
|
||||
up_gate_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.up_gate_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.down_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# weight_scale
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[0],
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
setattr(
|
||||
layer,
|
||||
self.added_scale_attrs[1],
|
||||
layer.create_parameter(
|
||||
shape=[layer.num_local_experts, layer.hidden_size],
|
||||
dtype=self.default_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
|
||||
def process_loaded_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
@@ -151,6 +178,62 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
getattr(layer, weight_name).set_value(quanted_weight)
|
||||
getattr(layer, scale_name).set_value(quanted_weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
""" """
|
||||
if not self.quant_config.is_checkpoint_bf16:
|
||||
return
|
||||
|
||||
algo = layer.quant_method.quant_config.name()
|
||||
assert algo == "wint8"
|
||||
max_bound = 127
|
||||
weight_id_map = {"gate_up": 0, "down": 1}
|
||||
if (
|
||||
hasattr(layer.up_gate_proj_weight, "tensor_track")
|
||||
and layer.up_gate_proj_weight.tensor_track is not None
|
||||
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
|
||||
):
|
||||
weight_type = "gate_up"
|
||||
layer.up_gate_proj_weight.tensor_track = None
|
||||
else:
|
||||
weight_type = "down"
|
||||
layer.down_proj_weight.tensor_track = None
|
||||
|
||||
# weight
|
||||
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
|
||||
# scale
|
||||
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
|
||||
|
||||
weight_tensor = getattr(layer, weight_name)
|
||||
quanted_weight_scale = weight_tensor.abs().max(axis=1)
|
||||
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
|
||||
quanted_weight = paddle.round(quanted_weight).astype("int8")
|
||||
quanted_weight_scale = quanted_weight_scale / max_bound
|
||||
|
||||
getattr(layer, weight_name).value().get_tensor()._clear()
|
||||
|
||||
# create weight
|
||||
setattr(
|
||||
layer,
|
||||
weight_name,
|
||||
layer.create_parameter(
|
||||
shape=weight_tensor.shape,
|
||||
dtype=quanted_weight.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
# create scale
|
||||
setattr(
|
||||
layer,
|
||||
scale_name,
|
||||
layer.create_parameter(
|
||||
shape=quanted_weight_scale.shape,
|
||||
dtype=quanted_weight_scale.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
)
|
||||
getattr(layer, weight_name).copy_(quanted_weight, False)
|
||||
getattr(layer, scale_name).copy_(quanted_weight_scale, False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
@@ -164,12 +247,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
token_num = x.shape[0]
|
||||
top_k = layer.top_k
|
||||
num_local_experts = layer.num_local_experts
|
||||
top_k = layer.top_k
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
|
||||
if layer.topk_method == "noaux_tc":
|
||||
_, topk_weights, topk_ids = get_moe_scores(
|
||||
gate_out, topk_weights, topk_ids = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
layer.topk_group,
|
||||
@@ -177,15 +259,15 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
)
|
||||
topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
|
||||
else:
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
layer.top_k,
|
||||
True, # apply_norm_weight
|
||||
top_k,
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
|
||||
up_gate_proj_out = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
@@ -302,6 +384,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
|
||||
down_proj_out.reshape_([token_num, top_k, hidden_size])
|
||||
out = down_proj_out.sum(axis=1)
|
||||
if layer.reduce_results and layer.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -432,7 +517,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
hidden_size = layer.hidden_size
|
||||
|
||||
if layer.topk_method == "noaux_tc":
|
||||
|
||||
_, topk_weights, topk_ids = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
|
||||
@@ -27,6 +27,11 @@ from fastdeploy.model_executor.utils import slice_fn
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.worker.experts_manager import RedundantExpertManger
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import noaux_tc
|
||||
except:
|
||||
logger.warning("import noaux_tc Failed!")
|
||||
|
||||
|
||||
def get_moe_method():
|
||||
"""
|
||||
@@ -54,6 +59,31 @@ def get_moe_method():
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_moe_scores(
|
||||
gating_output: paddle.Tensor,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
compute moe scores using e_score_correction_bias.
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
|
||||
scores_with_bias = scores + e_score_correction_bias
|
||||
scores, topk_values, topk_idx = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group if n_group > 0 else 1,
|
||||
topk_group if topk_group > 0 else 1,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores, topk_values, topk_idx
|
||||
|
||||
|
||||
class FusedMoE(nn.Layer):
|
||||
"""
|
||||
FusedMoE is a layer that performs MoE (Mixture of Experts) computation.
|
||||
|
||||
Reference in New Issue
Block a user