Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[XPU] merge apply_tp, ops support token_num = 0 (#4507)
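In short, this change does two things: the XPU C++ kernels (EP combine, EP dispatch, expert FFN) now tolerate a token count of 0 by skipping their device calls, and the separate `apply_tp` implementations on the XPU MoE methods are merged into a single `XPUMoEMethod.apply_tp` that routes to either `apply_tp_fused_op` or `apply_tp_scatter_op`. A minimal sketch of that routing rule, condensed from the Python hunks below (`select_moe_path` is a hypothetical helper for illustration, not part of the commit):

    import os

    def select_moe_path(moe_quant_type: str) -> str:
        # Routing rule introduced by the merged XPUMoEMethod.apply_tp:
        # w16a16 -> fused xpu_moe_layer path, w4a8 -> scatter
        # (dispatch/FFN/combine) path, anything else -> controlled by the
        # USING_EP_MOE_ALGO environment variable.
        if moe_quant_type == "w16a16":
            using_ep_moe_algo = False
        elif moe_quant_type == "w4a8":
            using_ep_moe_algo = True
        else:
            using_ep_moe_algo = int(os.environ.get("USING_EP_MOE_ALGO", 0)) != 0
        return "apply_tp_scatter_op" if using_ep_moe_algo else "apply_tp_fused_op"

    print(select_moe_path("w4a8"))  # apply_tp_scatter_op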
@@ -40,20 +40,21 @@ std::vector<paddle::Tensor> MoeEPCombineKernel(
   auto combined_out = paddle::empty(
       {recv_token_num, hidden_dim}, ffn_out.dtype(), ffn_out.place());
 
   const float* dequant_score = nullptr;
-  int ret = infer_ops::moe_ep_ffn_post_fusion(
-      xpu_ctx->x_context(),
-      reinterpret_cast<const XPU_T*>(ffn_out.data<T>()),
-      moe_index.data<int32_t>(),
-      reinterpret_cast<const XPU_T*>(weights.data<T>()),
-      dequant_score,
-      reinterpret_cast<XPU_T*>(combined_out.mutable_data<T>()),
-      recv_token_num,
-      hidden_dim,
-      topk,
-      expand_token_num);
-  PD_CHECK(ret == 0);
+  if (recv_token_num > 0) {
+    int ret = infer_ops::moe_ep_ffn_post_fusion(
+        xpu_ctx->x_context(),
+        reinterpret_cast<const XPU_T*>(ffn_out.data<T>()),
+        moe_index.data<int32_t>(),
+        reinterpret_cast<const XPU_T*>(weights.data<T>()),
+        dequant_score,
+        reinterpret_cast<XPU_T*>(combined_out.mutable_data<T>()),
+        recv_token_num,
+        hidden_dim,
+        topk,
+        expand_token_num);
+    PD_CHECK(ret == 0);
+  }
 
   return {combined_out};
 }
@@ -60,44 +60,48 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
   if (std::is_same<TY, int8_t>::value) {
     permute_input =
         paddle::empty({token_nums_this_rank, n}, paddle::DataType::INT8, place);
-    auto ret = infer_ops::moe_ffn_pre_sorted_quant_pe<XPU_TX, int>(
-        xpu_ctx->x_context(),
-        reinterpret_cast<const XPU_TX*>(input.data<TX>()),
-        topk_ids.data<int>(),
-        input_scales.get_ptr()->data<float>(),
-        nullptr,
-        reinterpret_cast<int8_t*>(permute_input.data<int8_t>()),
-        const_cast<int*>(permute_indices_per_token.data<int>()),
-        const_cast<int*>(expert_m.data<int>()),
-        const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
-        expand_input_scales.data<float>(),
-        m,
-        n,
-        expert_num,
-        topk,
-        block_num,
-        token_nums_this_rank);
-    PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
+    if (token_nums_this_rank > 0) {
+      auto ret = infer_ops::moe_ffn_pre_sorted_quant_pe<XPU_TX, int>(
+          xpu_ctx->x_context(),
+          reinterpret_cast<const XPU_TX*>(input.data<TX>()),
+          topk_ids.data<int>(),
+          input_scales.get_ptr()->data<float>(),
+          nullptr,
+          reinterpret_cast<int8_t*>(permute_input.data<int8_t>()),
+          const_cast<int*>(permute_indices_per_token.data<int>()),
+          const_cast<int*>(expert_m.data<int>()),
+          const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
+          expand_input_scales.data<float>(),
+          m,
+          n,
+          expert_num,
+          topk,
+          block_num,
+          token_nums_this_rank);
+      PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
+    }
   } else {
     permute_input = paddle::empty({token_nums_this_rank, n}, input_type, place);
-    auto ret = infer_ops::moe_ep_ffn_pre_sorted<XPU_TX, int>(
-        xpu_ctx->x_context(),
-        reinterpret_cast<const XPU_TX*>(input.data<TX>()),
-        topk_ids.data<int>(),
-        nullptr,
-        reinterpret_cast<XPU_TX*>(permute_input.data<TX>()),
-        const_cast<int*>(permute_indices_per_token.data<int>()),
-        const_cast<int*>(expert_m.data<int>()),
-        const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
-        m,
-        n,
-        expert_num,
-        topk,
-        block_num,
-        ep_size,
-        ep_rank,
-        token_nums_this_rank);
-    PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
+    if (token_nums_this_rank > 0) {
+      auto ret = infer_ops::moe_ep_ffn_pre_sorted<XPU_TX, int>(
+          xpu_ctx->x_context(),
+          reinterpret_cast<const XPU_TX*>(input.data<TX>()),
+          topk_ids.data<int>(),
+          nullptr,
+          reinterpret_cast<XPU_TX*>(permute_input.data<TX>()),
+          const_cast<int*>(permute_indices_per_token.data<int>()),
+          const_cast<int*>(expert_m.data<int>()),
+          const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
+          m,
+          n,
+          expert_num,
+          topk,
+          block_num,
+          ep_size,
+          ep_rank,
+          token_nums_this_rank);
+      PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
+    }
   }
   return {permute_input,
           permute_indices_per_token,
@@ -441,6 +441,12 @@ std::vector<paddle::Tensor> MoeExpertFFN(
     const std::string& quant_method,
     const int hadamard_blocksize,
     const int valid_token_num) {
+  if (ffn_in.numel() == 0) {
+    paddle::Tensor ffn2_out =
+        paddle::empty_like(ffn_in, paddle::DataType::BFLOAT16);
+    return {ffn2_out};
+  }
+
   const auto x_type = ffn_in.dtype();
   const auto w_type = ffn1_weight.dtype();
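With the early return added above, MoeExpertFFN allocates an empty bf16 output and skips the expert computation when it receives no tokens, and the combine/dispatch kernels get equivalent token-count guards. A toy stand-in, in NumPy only (`expert_ffn_like` is a made-up name, not the FastDeploy op), showing the zero-token pattern:

    import numpy as np

    def expert_ffn_like(ffn_in: np.ndarray) -> np.ndarray:
        # Mirror of the new zero-token path: return an empty output of the
        # right shape instead of invoking the kernel.
        if ffn_in.size == 0:
            return np.empty_like(ffn_in)
        return ffn_in * 2.0  # placeholder for the real expert FFN

    print(expert_ffn_like(np.empty((0, 8))).shape)  # (0, 8): nothing computed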
@@ -146,14 +146,14 @@ class XPUMoEMethod(MoEMethodBase):
         layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
         layer.down_proj_weight.set_value(stacked_down_proj_weights)
 
-    def apply_tp(
+    def apply_tp_fused_op(
         self,
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
     ) -> paddle.Tensor:
         """
-        XPU compute Fused MoE.
+        Apply TP Fused Op.
         """
         from fastdeploy.model_executor.ops.xpu import xpu_moe_layer
@@ -165,9 +165,9 @@ class XPUMoEMethod(MoEMethodBase):
             layer.down_proj_weight,
             None,  # up_gate_proj bias
             None,  # down_proj bias
-            getattr(layer, "up_gate_proj_weight_scale", None),
-            getattr(layer, "down_proj_weight_scale", None),
-            getattr(layer, "up_gate_proj_in_scale", None),
+            getattr(layer, self.added_scale_attrs[0], None),
+            getattr(layer, self.added_scale_attrs[1], None),
+            getattr(layer, self.added_in_scale_attrs[0], None),
             self.moe_quant_type,
             layer.top_k,
             False,  # moe group, used in deepseek
@@ -177,21 +177,103 @@ class XPUMoEMethod(MoEMethodBase):
 
         return fused_moe_out
 
+    def apply_tp_scatter_op(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate: nn.Layer,
+    ) -> paddle.Tensor:
+        """
+        Apply TP Scatter Op.
+        """
+        gate_out = gate(x.cast("float32"))
+        topk_idx, topk_weights = moe_topk_select(
+            gate_out,
+            layer.gate_correction_bias,
+            layer.top_k,
+            True,
+        )
+        token_nums_per_expert_list = list(range(64))  # placeholder, not use
+        (
+            permute_input,
+            permute_indices_per_token,
+            token_num_lod,
+            dst_weights,
+            ffn1_act_scale_per_token,
+        ) = ep_moe_expert_dispatch(
+            x,
+            topk_idx,
+            topk_weights,
+            getattr(layer, self.added_in_scale_attrs[0], None),
+            token_nums_per_expert_list,
+            x.shape[0] * layer.top_k,
+            self.moe_quant_type,
+        )
+
+        if not hasattr(layer, self.added_in_scale_attrs[0]):
+            ffn1_act_scale_per_token = None
+        ffn_out = self.compute_ffn(
+            layer,
+            permute_input,
+            token_num_lod,
+            x.shape[0] * layer.top_k,
+            ffn1_act_scale_per_token,
+        )
+
+        topk_weights_bf16 = topk_weights.astype("bfloat16")
+        tmp_ffn_out = ep_moe_expert_combine(
+            ffn_out,
+            permute_indices_per_token,
+            topk_weights_bf16,
+            permute_indices_per_token.shape[0],
+            ffn_out.shape[0],
+            ffn_out.shape[1],
+            permute_indices_per_token.shape[1],
+        )
+
+        if layer.reduce_results and layer.tp_size > 1:
+            tmp_ffn_out = tensor_model_parallel_all_reduce(tmp_ffn_out)
+        return tmp_ffn_out
+
+    def apply_tp(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate: nn.Layer,
+    ) -> paddle.Tensor:
+        """
+        apply tp
+        """
+        if self.moe_quant_type in ["w16a16"]:
+            using_ep_moe_algo = False
+        elif self.moe_quant_type in ["w4a8"]:
+            using_ep_moe_algo = True
+        else:
+            using_ep_moe_algo = int(os.environ.get("USING_EP_MOE_ALGO", 0)) != 0
+        print(f"using_ep_moe_algo: {using_ep_moe_algo}")
+
+        if using_ep_moe_algo:
+            fused_moe_out = self.apply_tp_scatter_op(layer, x, gate)
+        else:
+            fused_moe_out = self.apply_tp_fused_op(layer, x, gate)
+
+        return fused_moe_out
+
     def compute_ffn(
         self,
         layer: nn.Layer,
         permute_input,
         token_num_lod,
-        valid_token_num=-1,
-        extra_ffn1_in_scale=None,
+        valid_token_num,
+        ffn1_act_scale_per_token=None,
     ):
         """
         Calculate moe
         """
-        # ffn1_in_scale = extra_ffn1_in_scale
         moe_ffn1_scale = None
         moe_ffn2_scale = None
 
+        if self.moe_quant_type in ["w4a8"]:
+            hadamard_block_size = getattr(layer.moe_quant_config, "hadamard_block_size", 128)
+        else:
+            hadamard_block_size = -1
         ffn_out = moe_expert_ffn(
             permute_input,
             token_num_lod,
@@ -199,14 +281,14 @@ class XPUMoEMethod(MoEMethodBase):
             getattr(layer, self.added_weight_attrs[1]),
             None,
             None,
-            moe_ffn1_scale,
-            moe_ffn2_scale,
-            getattr(layer, self.added_scale_attrs[0]),
-            getattr(layer, self.added_scale_attrs[1]),
+            ffn1_act_scale_per_token,
+            getattr(layer, self.added_in_scale_attrs[1], None),
+            getattr(layer, self.added_scale_attrs[0], None),
+            getattr(layer, self.added_scale_attrs[1], None),
             None,
             None,
             self.moe_quant_type,
-            -1,
+            hadamard_block_size,
             valid_token_num,
         )
         return ffn_out
@@ -245,45 +327,41 @@ class XPUMoEMethod(MoEMethodBase):
         token_all_num = sum(token_num_per_expert)
 
         # 4. Compute ffn
-        if token_all_num > 0:
-            moe_dispatch_scale = None
-            (
-                permute_input,
-                permute_indices_per_token,
-                token_num_lod,
-                dst_weights,
-                ffn1_act_scale_per_token,
-            ) = ep_moe_expert_dispatch(
-                recv_x,
-                recv_topk_idx,
-                recv_topk_weights,
-                moe_dispatch_scale,
-                token_num_per_expert,
-                token_all_num,
-                self.moe_quant_type,
-            )
+        moe_dispatch_scale = None
+        (
+            permute_input,
+            permute_indices_per_token,
+            token_num_lod,
+            dst_weights,
+            ffn1_act_scale_per_token,
+        ) = ep_moe_expert_dispatch(
+            recv_x,
+            recv_topk_idx,
+            recv_topk_weights,
+            moe_dispatch_scale,
+            token_num_per_expert,
+            token_all_num,
+            self.moe_quant_type,
+        )
 
-            ffn_out = self.compute_ffn(
-                layer,
-                permute_input,
-                token_num_lod,
-                token_all_num,
-            )
+        ffn_out = self.compute_ffn(
+            layer,
+            permute_input,
+            token_num_lod,
+            token_all_num,
+        )
 
-            # prmt back per rank
-            recv_topk_weights_bf16 = recv_topk_weights.astype("bfloat16")
-            tmp_ffn_out = ep_moe_expert_combine(
-                ffn_out,
-                permute_indices_per_token,
-                recv_topk_weights_bf16,
-                permute_indices_per_token.shape[0],
-                ffn_out.shape[0],
-                ffn_out.shape[1],
-                permute_indices_per_token.shape[1],
-            )
-
-        else:
-            tmp_ffn_out = paddle.empty(recv_x.shape, "bfloat16")
+        # prmt back per rank
+        recv_topk_weights_bf16 = recv_topk_weights.astype("bfloat16")
+        tmp_ffn_out = ep_moe_expert_combine(
+            ffn_out,
+            permute_indices_per_token,
+            recv_topk_weights_bf16,
+            permute_indices_per_token.shape[0],
+            ffn_out.shape[0],
+            ffn_out.shape[1],
+            permute_indices_per_token.shape[1],
+        )
 
         # 5. EP combine
         handle = None
@@ -395,98 +473,6 @@ class XPUWeightOnlyMoEMethod(XPUMoEMethod):
         quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
         getattr(layer, scale_name).set_value(quanted_weight_scale)
 
-    def apply_tp(
-        self,
-        layer: nn.Layer,
-        x: paddle.Tensor,
-        gate: nn.Layer,
-    ) -> paddle.Tensor:
-        """
-        XPU compute Fused MoE.
-        """
-        USING_EP_MOE_ALGO = int(os.environ.get("USING_EP_MOE_ALGO", 0))
-        if USING_EP_MOE_ALGO:
-            token_num = x.shape[0]
-            if token_num > 0:
-                gate_out = gate(x.cast("float32"))
-                topk_idx, topk_weights = moe_topk_select(
-                    gate_out,
-                    layer.gate_correction_bias,
-                    layer.top_k,
-                    True,
-                )
-                token_nums_per_expert_list = list(range(64))  # padding as a placeholder
-                (
-                    permute_input,
-                    permute_indices_per_token,
-                    token_num_lod,
-                    dst_weights,
-                    ffn1_act_scale_per_token,
-                ) = ep_moe_expert_dispatch(
-                    x,
-                    topk_idx,
-                    topk_weights,
-                    getattr(layer, "up_gate_proj_in_scale", None),
-                    token_nums_per_expert_list,
-                    x.shape[0] * layer.top_k,
-                    self.moe_quant_type,
-                )
-
-                ffn_out = moe_expert_ffn(
-                    permute_input,
-                    token_num_lod,
-                    layer.up_gate_proj_weight,
-                    layer.down_proj_weight,
-                    None,  # moe_ffn1_bias
-                    None,  # moe_ffn2_bias
-                    None,  # ffn1 in scale
-                    None,  # ffn2 in scale
-                    getattr(layer, "up_gate_proj_weight_scale", None),
-                    getattr(layer, "down_proj_weight_scale", None),
-                    None,  # moe_ffn2_shift
-                    None,  # moe_ffn2_smooth
-                    self.moe_quant_type,
-                    -1,
-                    x.shape[0] * layer.top_k,  # token_all_num
-                )
-                topk_weights_bf16 = topk_weights.astype("bfloat16")
-                tmp_ffn_out = ep_moe_expert_combine(
-                    ffn_out,
-                    permute_indices_per_token,
-                    topk_weights_bf16,
-                    permute_indices_per_token.shape[0],
-                    ffn_out.shape[0],
-                    ffn_out.shape[1],
-                    permute_indices_per_token.shape[1],
-                )
-            else:
-                tmp_ffn_out = paddle.empty(x.shape, x.dtype)
-
-            if layer.reduce_results and layer.tp_size > 1:
-                tensor_model_parallel_all_reduce(tmp_ffn_out)
-            return tmp_ffn_out
-        else:
-            from fastdeploy.model_executor.ops.xpu import xpu_moe_layer
-
-            fused_moe_out = xpu_moe_layer(
-                x,
-                gate.weight.transpose([1, 0]),
-                layer.gate_correction_bias,
-                layer.up_gate_proj_weight,
-                layer.down_proj_weight,
-                None,  # up_gate_proj bias
-                None,  # down_proj bias
-                (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
-                (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
-                (layer.down_proj_in_scale if hasattr(layer, "down_proj_in_scale") else None),
-                self.moe_quant_type,
-                layer.top_k,
-                False,  # moe group, used in deepseek
-            )
-            if layer.reduce_results and layer.tp_size > 1:
-                tensor_model_parallel_all_reduce(fused_moe_out)
-            return fused_moe_out
 
 
 class XPUW4A8MoEMethod(XPUMoEMethod):
     """
@@ -607,63 +593,3 @@ class XPUW4A8MoEMethod(XPUMoEMethod):
 
         for weight_scale_name in self.added_scale_attrs:
             getattr(layer, weight_scale_name).set_value(paddle.stack(scale_weight_map[weight_scale_name], axis=0))
-
-    def apply_tp(
-        self,
-        layer: nn.Layer,
-        x: paddle.Tensor,
-        gate: nn.Layer,
-    ) -> paddle.Tensor:
-        gate_out = gate(x.cast("float32"))
-        topk_idx, topk_weights = moe_topk_select(
-            gate_out,
-            layer.gate_correction_bias,
-            layer.top_k,
-            True,
-        )
-        token_nums_per_expert_list = list(range(64))  # padding as a placeholder
-        (
-            permute_input,
-            permute_indices_per_token,
-            token_num_lod,
-            dst_weights,
-            ffn1_act_scale_per_token,
-        ) = ep_moe_expert_dispatch(
-            x,
-            topk_idx,
-            topk_weights,
-            getattr(layer, "up_gate_proj_in_scale", None),
-            token_nums_per_expert_list,
-            x.shape[0] * layer.top_k,
-            self.moe_quant_type,
-        )
-        ffn_out = moe_expert_ffn(
-            permute_input,
-            token_num_lod,
-            layer.up_gate_proj_weight,
-            layer.down_proj_weight,
-            None,  # moe_ffn1_bias
-            None,  # moe_ffn2_bias
-            (ffn1_act_scale_per_token if hasattr(layer, "up_gate_proj_in_scale") else None),
-            getattr(layer, "down_proj_in_scale", None),
-            getattr(layer, "up_gate_proj_weight_scale", None),
-            getattr(layer, "down_proj_weight_scale", None),
-            None,  # moe_ffn2_shift
-            None,  # moe_ffn2_smooth
-            self.moe_quant_type,
-            getattr(layer.moe_quant_config, "hadamard_block_size", 128),
-            x.shape[0] * layer.top_k,  # token_all_num
-        )
-        topk_weights_bf16 = topk_weights.astype("bfloat16")
-        tmp_ffn_out = ep_moe_expert_combine(
-            ffn_out,
-            permute_indices_per_token,
-            topk_weights_bf16,
-            permute_indices_per_token.shape[0],
-            ffn_out.shape[0],
-            ffn_out.shape[1],
-            permute_indices_per_token.shape[1],
-        )
-        if layer.reduce_results and layer.tp_size > 1:
-            tmp_ffn_out = tensor_model_parallel_all_reduce(tmp_ffn_out)
-        return tmp_ffn_out
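After the two removals above, XPUWeightOnlyMoEMethod and XPUW4A8MoEMethod no longer carry their own `apply_tp`; they inherit the merged implementation from `XPUMoEMethod`. A rough sketch of the resulting class layout (stubs only, `MoEMethodBase` stubbed and real bodies elided; see the hunks above for the actual code):

    class MoEMethodBase:  # stub for illustration only
        pass

    class XPUMoEMethod(MoEMethodBase):
        def apply_tp(self, layer, x, gate):
            ...  # routes to self.apply_tp_fused_op / self.apply_tp_scatter_op

    class XPUWeightOnlyMoEMethod(XPUMoEMethod):
        pass  # override removed by this commit; uses the base-class apply_tp

    class XPUW4A8MoEMethod(XPUMoEMethod):
        pass  # override removed by this commit; uses the base-class apply_tp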