diff --git a/custom_ops/xpu_ops/src/ops/moe_layer.cc b/custom_ops/xpu_ops/src/ops/moe_layer.cc index 4e8d54cb7..937580d2c 100644 --- a/custom_ops/xpu_ops/src/ops/moe_layer.cc +++ b/custom_ops/xpu_ops/src/ops/moe_layer.cc @@ -228,8 +228,9 @@ MoeLayer(const paddle::Tensor &x, const paddle::Tensor &gate_weight, quant_method == "weight_only_int4") { APPLY_MOE_LAYER_KERNEL(paddle::bfloat16, int4_t); } else { - PD_THROW("MoeLayer not support x_type==%d, w_type==%d", - static_cast(x_type), static_cast(w_type)); + PD_THROW("MoeLayer not support x_type=", static_cast(x_type), + ", w_type=", static_cast(w_type), + ", quant_method=", quant_method); return {}; } #undef APPLY_MOE_LAYER_KERNEL diff --git a/fastdeploy/model_executor/layers/backends/xpu/__init__.py b/fastdeploy/model_executor/layers/backends/xpu/__init__.py index ac530637e..e3cf1e1cc 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/__init__.py +++ b/fastdeploy/model_executor/layers/backends/xpu/__init__.py @@ -16,16 +16,11 @@ xpu backend methods """ -from .moe.fused_moe import ( - XPUMoEMethod, - XPUWeightOnlyMoeEpMethod, - XPUWeightOnlyMoEMethod, -) +from .moe.fused_moe import XPUMoEMethod, XPUWeightOnlyMoEMethod from .quantization.weight_only import XPUWeightOnlyLinearMethod __all__ = [ "XPUWeightOnlyLinearMethod", "XPUMoEMethod", "XPUWeightOnlyMoEMethod", - "XPUWeightOnlyMoeEpMethod", ] diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py index 4636d82dd..2e74e5346 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py +++ b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py @@ -17,10 +17,8 @@ import paddle from paddle import nn -from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import ( - UnquantizedFusedMoEMethod, -) -from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase +from fastdeploy.distributed.communication import 
tensor_model_parallel_all_reduce +from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.ops.xpu import ( @@ -32,11 +30,109 @@ from fastdeploy.model_executor.ops.xpu import ( ) -class XPUMoEMethod(UnquantizedFusedMoEMethod): +class XPUMoEMethod(MoEMethodBase): """ XPU MOE """ + def __init__( + self, + quant_config: WeightOnlyConfig, + ) -> None: + super().__init__(quant_config) + + if self.moe_quant_type in ["w16a16"]: + self.weight_dtype = "bfloat16" + elif self.moe_quant_type in ["weight_only_int8", "w8a8", "weight_only_int4", "w4a8"]: + self.weight_dtype = "int8" + else: + raise ValueError(f"Unsupported moe quant type: {self.moe_quant_type}") + self.scale_dtype = "float32" + self.bias_dtype = "float32" + + def import_backend_ep_runner(self) -> None: + from .ep import XPUEPDecoderRunner, XPUEPPrefillRunner + + self.EPPrefillRunner = XPUEPPrefillRunner + self.EPDecoderRunner = XPUEPDecoderRunner + + def create_weights(self, layer: nn.Layer, **extra_weight_attrs): + """ + create weight process. 
+ """ + self.up_gate_proj_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + layer.hidden_size, + ] + self.down_proj_weight_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size, + ] + if self.moe_quant_type in ["weight_only_int4", "w4a8"]: + self.up_gate_proj_weight_shape[-1] //= 2 + self.down_proj_weight_shape[-1] //= 2 + + setattr( + layer, + self.added_weight_attrs[0], + layer.create_parameter( + shape=self.up_gate_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + self.added_weight_attrs[1], + layer.create_parameter( + shape=self.down_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + if self.moe_quant_type in ["weight_only_int8", "w8a8", "weight_only_int4", "w4a8"]: + self.up_gate_proj_scale_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + ] + self.down_proj_scale_shape = [ + layer.num_local_experts, + layer.hidden_size, + ] + setattr( + layer, + self.added_scale_attrs[0], + layer.create_parameter( + shape=self.up_gate_proj_scale_shape, + dtype=self.scale_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + self.added_scale_attrs[1], + layer.create_parameter( + shape=self.down_proj_scale_shape, + dtype=self.scale_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + if self.moe_quant_type in ["w8a8", "w4a8"]: + for in_scale_name in self.added_in_scale_attrs: + setattr( + layer, + in_scale_name, + layer.create_parameter( + shape=[layer.num_local_experts], + dtype=self.scale_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + def process_loaded_weights(self, layer: nn.Layer, state_dict): up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict) for weights in [up_gate_proj_weights, down_proj_weights]: @@ -55,7 
+151,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod): gate: nn.Layer, ) -> paddle.Tensor: """ - Paddle Cutlass compute Fused MoE. + XPU compute Fused MoE. """ from fastdeploy.model_executor.ops.xpu import xpu_moe_layer @@ -67,22 +163,52 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod): layer.down_proj_weight, None, # up_gate_proj bias None, # down_proj bias - None, # up_gate_proj scale - None, # down_proj scale - None, # up_gate_proj_in_scale - "", # moe_quant_type + getattr(layer, "up_gate_proj_weight_scale", None), + getattr(layer, "down_proj_weight_scale", None), + getattr(layer, "up_gate_proj_in_scale", None), + self.moe_quant_type, layer.top_k, False, # moe group, used in deepseek ) if layer.reduce_results and layer.tp_size > 1: - from fastdeploy.distributed.communication import ( - tensor_model_parallel_all_reduce, - ) - tensor_model_parallel_all_reduce(fused_moe_out) return fused_moe_out + def compute_ffn( + self, + layer: nn.Layer, + permute_input, + token_num_lod, + valid_token_num=-1, + extra_ffn1_in_scale=None, + ): + """ + Calculate moe + """ + # ffn1_in_scale = extra_ffn1_in_scale + moe_ffn1_scale = None + moe_ffn2_scale = None + + ffn_out = moe_expert_ffn( + permute_input, + token_num_lod, + getattr(layer, self.added_weight_attrs[0]), + getattr(layer, self.added_weight_attrs[1]), + None, + None, + moe_ffn1_scale, + moe_ffn2_scale, + getattr(layer, self.added_scale_attrs[0]), + getattr(layer, self.added_scale_attrs[1]), + None, + None, + self.moe_quant_type, + -1, + valid_token_num, + ) + return ffn_out + def apply_ep_prefill( self, layer: nn.Layer, @@ -92,7 +218,74 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod): """ Apply the EP prefill method. """ - raise NotImplementedError + gate_out = gate(x.cast("float32")) + # 1. Select topk experts and weights + topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) + # 2. 
Dynamic compute blockwise quantization scales + # x, x_scale_tensor = fastdeploy.model_executor.ops.xpu.per_token_quant(x) + x_scale_tensor = None + # 3. EP Dispatch + ( + recv_x, + recv_x_scales, + recv_topk_idx, + recv_topk_weights, + recv_num_tokens_per_expert_list, + _, + ) = self.ep_prefill_runner.dispatch( + x, + topk_idx, + topk_weights, + x_scale_tensor=x_scale_tensor, + ) + + token_num_per_expert = recv_num_tokens_per_expert_list.numpy().tolist() + token_all_num = sum(token_num_per_expert) + + # 4. Compute ffn + if token_all_num > 0: + moe_dispatch_scale = None + ( + permute_input, + permute_indices_per_token, + token_num_lod, + dst_weights, + ffn1_act_scale_per_token, + ) = ep_moe_expert_dispatch( + recv_x, + recv_topk_idx, + recv_topk_weights, + moe_dispatch_scale, + token_num_per_expert, + token_all_num, + self.moe_quant_type, + ) + + ffn_out = self.compute_ffn( + layer, + permute_input, + token_num_lod, + token_all_num, + ) + + # prmt back per rank + recv_topk_weights_bf16 = recv_topk_weights.astype("bfloat16") + tmp_ffn_out = ep_moe_expert_combine( + ffn_out, + permute_indices_per_token, + recv_topk_weights_bf16, + permute_indices_per_token.shape[0], + ffn_out.shape[0], + ffn_out.shape[1], + permute_indices_per_token.shape[1], + ) + + else: + tmp_ffn_out = paddle.empty(recv_x.shape, "bfloat16") + + # 5. EP combine + handle = None + return self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights) def apply_ep_decode( self, @@ -103,138 +296,42 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod): """ Apply the EP decoder method. """ - raise NotImplementedError + gate_out = gate(x.cast("float32")) + # 1. Select topk experts and weights + topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) -class XPUWeightOnlyMoEMethod(QuantMethodBase): - """ - XPU Fused MoE Method. 
- """ - - def __init__( - self, - quant_config: WeightOnlyConfig, - ) -> None: - super().__init__() - self.quant_config = quant_config - self.moe_quant_type = self.quant_config.algo - self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] - self.added_scale_attrs = [ - "up_gate_proj_weight_scale", - "down_proj_weight_scale", - ] - - def create_weights(self, layer: nn.Layer, **extra_weight_attrs): - """ - Paddle cutlass create weight process. - """ - self.default_dtype = "float32" - self.weight_dtype = "int8" - - if self.moe_quant_type in ["weight_only_int4", "w4a8"]: - self.up_gate_proj_weight_shape = [ - layer.num_local_experts, - layer.moe_intermediate_size * 2, - layer.hidden_size // 2, - ] - else: - self.up_gate_proj_weight_shape = [ - layer.num_local_experts, - layer.moe_intermediate_size * 2, - layer.hidden_size, - ] - if self.moe_quant_type in ["weight_only_int4", "w4a8"]: - self.down_proj_weight_shape = [ - layer.num_local_experts, - layer.hidden_size, - layer.moe_intermediate_size // 2, - ] - else: - self.down_proj_weight_shape = [ - layer.num_local_experts, - layer.hidden_size, - layer.moe_intermediate_size, - ] - - setattr( - layer, - self.added_weight_attrs[0], - layer.create_parameter( - shape=self.up_gate_proj_weight_shape, - dtype=self.weight_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - self.added_weight_attrs[1], - layer.create_parameter( - shape=self.down_proj_weight_shape, - dtype=self.weight_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - # weight_scale - setattr( - layer, - self.added_scale_attrs[0], - layer.create_parameter( - shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], - dtype=self.default_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - self.added_scale_attrs[1], - layer.create_parameter( - shape=[layer.num_local_experts, layer.hidden_size], - dtype=self.default_dtype, - 
default_initializer=paddle.nn.initializer.Constant(0), - ), + # 2. EP Dispatch + expertwise_scale = None + use_fp8 = False + ( + permute_input, + token_nums_per_expert, + handle, + valid_token_num, + ) = self.ep_decoder_runner.dispatch( + x, + topk_idx, + topk_weights, + expertwise_scale=expertwise_scale, + use_fp8=use_fp8, ) - def process_loaded_weights(self, layer: nn.Layer, state_dict): - """ - Paddle xpu load weight process. - """ + # 3. Compute ffn + ffn_out = self.compute_ffn( + layer, + permute_input, + token_nums_per_expert, + valid_token_num, + ) - # for k, v in state_dict.items(): - # print(f"k : {k}, value.shape {v.shape}, value.dtype : {v.dtype}") - up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict) - assert len(up_gate_proj_weights) == layer.num_local_experts - assert len(down_proj_weights) == layer.num_local_experts - assert up_gate_proj_weights[0].shape == [ - layer.hidden_size, - layer.moe_intermediate_size * 2, - ] - assert down_proj_weights[0].shape == [ - layer.moe_intermediate_size, - layer.hidden_size, - ] - - for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): - weight_name = self.added_weight_attrs[idx] - scale_name = self.added_scale_attrs[idx] - - weight_list = [] - weight_scale_list = [] - for i in range(layer.num_local_experts): - # print(f"=======================第{i}层=======================") - # print(f" wint4 未量化前权重: {weight_tensor[i]}") - quant_weight, scale = weight_quantize_xpu( - weight_tensor[i], self.moe_quant_type, -1, -1 - ) # weight is [k,n] - - # print(f" wint4 量化后权重: {quant_weight}") - # print(f" wint4 量化后scale: {scale}") - weight_list.append(quant_weight.transpose([1, 0])) # transpose weight to [n,k] - weight_scale_list.append(scale) - quanted_weight = paddle.stack(weight_list, axis=0) - getattr(layer, weight_name).set_value(quanted_weight) - - quanted_weight_scale = paddle.stack(weight_scale_list, axis=0) - getattr(layer, 
scale_name).set_value(quanted_weight_scale) + # 4. EP combine + return self.ep_decoder_runner.combine( + ffn_out, + topk_idx, + topk_weights, + handle, + ) def apply( self, @@ -243,185 +340,24 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase): gate: nn.Layer, ) -> paddle.Tensor: """ - XPU compute Fused MoE. + compute Fused MoE. """ - # from fastdeploy.model_executor.ops.xpu import xpu_moe_layer - - # fused_moe_out = xpu_moe_layer( - # x, - # gate.weight.transpose([1, 0]), - # layer.gate_correction_bias, - # layer.up_gate_proj_weight, - # layer.down_proj_weight, - # None, # up_gate_proj bias - # None, # down_proj bias - # (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), - # (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), - # (layer.down_proj_in_scale if hasattr(layer, "down_proj_in_scale") else None), - # self.moe_quant_type, - # layer.top_k, - # False, # moe group, used in deepseek - # ) - # if layer.reduce_results and layer.tp_size > 1: - # from fastdeploy.distributed.communication import ( - # tensor_model_parallel_all_reduce, - # ) - - # tensor_model_parallel_all_reduce(fused_moe_out) - - # return fused_moe_out - - token_num = x.shape[0] - if token_num > 0: - gate_out = paddle.matmul(x.cast("float32"), gate.weight.transpose([1, 0]), transpose_y=True) - topk_idx, topk_weights = moe_topk_select(gate_out, layer.gate_correction_bias, layer.top_k, True) - token_nums_per_expert_list = list(range(64)) # 填充做占位符 - permute_input, permute_indices_per_token, token_num_lod, dst_weights, ffn1_act_scale_per_token = ( - ep_moe_expert_dispatch( - x, - topk_idx, - topk_weights, - (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None), - token_nums_per_expert_list, - x.shape[0] * layer.top_k, - self.moe_quant_type, - ) - ) - - ffn_out = moe_expert_ffn( - permute_input, - token_num_lod, - layer.up_gate_proj_weight, - layer.down_proj_weight, - None, # moe_ffn1_bias - None, # 
moe_ffn2_bias - None, # ffn1 in scale - None, # ffn2 in scale - (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), - (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), - None, # moe_ffn2_shift - None, # moe_ffn2_smooth - self.moe_quant_type, - -1, - x.shape[0] * layer.top_k, # token_all_num - ) - topk_weights_bf16 = topk_weights.astype("bfloat16") - tmp_ffn_out = ep_moe_expert_combine( - ffn_out, - permute_indices_per_token, - topk_weights_bf16, - permute_indices_per_token.shape[0], - ffn_out.shape[0], - ffn_out.shape[1], - permute_indices_per_token.shape[1], - ) + if layer.ep_size > 1: + if layer.fd_config.model_config.moe_phase.phase == "prefill": + return self.apply_ep_prefill(layer, x, gate) + elif layer.fd_config.model_config.moe_phase.phase == "decode": + return self.apply_ep_decode(layer, x, gate) + else: + raise ValueError(f"Unsupported phase: {layer.fd_config.model_config.moe_phase.phase}") else: - tmp_ffn_out = paddle.empty(x.shape, x.dtype) - - if layer.reduce_results and layer.tp_size > 1: - from fastdeploy.distributed.communication import ( - tensor_model_parallel_all_reduce, - ) - - tensor_model_parallel_all_reduce(tmp_ffn_out) - return tmp_ffn_out + return self.apply_tp(layer, x, gate) -class XPUWeightOnlyMoeEpMethod(XPUMoEMethod): +class XPUWeightOnlyMoEMethod(XPUMoEMethod): """ - XPU Fused MoE EP Method. + XPU Fused MoE Method. """ - def __init__( - self, - quant_config: WeightOnlyConfig, - ) -> None: - super().__init__(quant_config) - self.moe_quant_type = self.quant_config.algo - self.weight_dtype = "int8" - self.scale_dtype = "float32" - - def import_backend_ep_runner(self) -> None: - from .ep import XPUEPDecoderRunner, XPUEPPrefillRunner - - self.EPPrefillRunner = XPUEPPrefillRunner - self.EPDecoderRunner = XPUEPDecoderRunner - - def create_weights(self, layer: nn.Layer, **extra_weight_attrs): - """ - create weight process. 
- """ - if self.moe_quant_type in ["weight_only_int8"]: - self.up_gate_proj_weight_shape = [ - layer.num_local_experts, - layer.moe_intermediate_size * 2, - layer.hidden_size, - ] - self.down_proj_weight_shape = [ - layer.num_local_experts, - layer.hidden_size, - layer.moe_intermediate_size, - ] - elif self.moe_quant_type in ["weight_only_int4"]: - self.up_gate_proj_weight_shape = [ - layer.num_local_experts, - layer.moe_intermediate_size * 2, - layer.hidden_size // 2, - ] - self.down_proj_weight_shape = [ - layer.num_local_experts, - layer.hidden_size, - layer.moe_intermediate_size // 2, - ] - else: - raise ValueError(f"Unsupported moe quant type: {self.moe_quant_type}") - - self.up_gate_proj_scale_shape = [ - layer.num_local_experts, - layer.moe_intermediate_size * 2, - ] - self.down_proj_scale_shape = [ - layer.num_local_experts, - layer.hidden_size, - ] - - setattr( - layer, - self.added_weight_attrs[0], - layer.create_parameter( - shape=self.up_gate_proj_weight_shape, - dtype=self.weight_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - self.added_weight_attrs[1], - layer.create_parameter( - shape=self.down_proj_weight_shape, - dtype=self.weight_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - self.added_scale_attrs[0], - layer.create_parameter( - shape=self.up_gate_proj_scale_shape, - dtype=self.scale_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - self.added_scale_attrs[1], - layer.create_parameter( - shape=self.down_proj_scale_shape, - dtype=self.scale_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - def process_loaded_weights(self, layer: nn.Layer, state_dict): """ Paddle xpu load weight process. 
@@ -457,37 +393,25 @@ class XPUWeightOnlyMoeEpMethod(XPUMoEMethod): quanted_weight_scale = paddle.stack(weight_scale_list, axis=0) getattr(layer, scale_name).set_value(quanted_weight_scale) - def apply_ep_prefill( + def apply_tp( self, layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, ) -> paddle.Tensor: """ - Apply the EP prefill method. + XPU compute Fused MoE. """ - gate_out = gate(x.cast("float32")) - # 1. Select topk experts and weights - topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) - # 2. Dynamic compute blockwise quantization scales - # x, x_scale_tensor = fastdeploy.model_executor.ops.xpu.per_token_quant(x) - x_scale_tensor = None - # 3. EP Dispatch - ( - recv_x, - recv_x_scales, - recv_topk_idx, - recv_topk_weights, - recv_num_tokens_per_expert_list, - _, - ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor) - - token_num_per_expert = recv_num_tokens_per_expert_list.numpy().tolist() - token_all_num = sum(token_num_per_expert) - - # 4. 
Compute ffn - if token_all_num > 0: - moe_dispatch_scale = None + token_num = x.shape[0] + if token_num > 0: + gate_out = gate(x.cast("float32")) + topk_idx, topk_weights = moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, + ) + token_nums_per_expert_list = list(range(64)) # 填充做占位符 ( permute_input, permute_indices_per_token, @@ -495,124 +419,48 @@ class XPUWeightOnlyMoeEpMethod(XPUMoEMethod): dst_weights, ffn1_act_scale_per_token, ) = ep_moe_expert_dispatch( - recv_x, - recv_topk_idx, - recv_topk_weights, - moe_dispatch_scale, - token_num_per_expert, - token_all_num, + x, + topk_idx, + topk_weights, + getattr(layer, "up_gate_proj_in_scale", None), + token_nums_per_expert_list, + x.shape[0] * layer.top_k, self.moe_quant_type, ) - moe_ffn1_scale = None - moe_ffn2_scale = None ffn_out = moe_expert_ffn( permute_input, token_num_lod, - getattr(layer, self.added_weight_attrs[0]), - getattr(layer, self.added_weight_attrs[1]), - None, - None, - moe_ffn1_scale, - moe_ffn2_scale, - getattr(layer, self.added_scale_attrs[0]), - getattr(layer, self.added_scale_attrs[1]), - None, - None, + layer.up_gate_proj_weight, + layer.down_proj_weight, + None, # moe_ffn1_bias + None, # moe_ffn2_bias + None, # ffn1 in scale + None, # ffn2 in scale + getattr(layer, "up_gate_proj_weight_scale", None), + getattr(layer, "down_proj_weight_scale", None), + None, # moe_ffn2_shift + None, # moe_ffn2_smooth self.moe_quant_type, -1, - token_all_num, + x.shape[0] * layer.top_k, # token_all_num ) - - # prmt back per rank - recv_topk_weights_bf16 = recv_topk_weights.astype("bfloat16") + topk_weights_bf16 = topk_weights.astype("bfloat16") tmp_ffn_out = ep_moe_expert_combine( ffn_out, permute_indices_per_token, - recv_topk_weights_bf16, + topk_weights_bf16, permute_indices_per_token.shape[0], ffn_out.shape[0], ffn_out.shape[1], permute_indices_per_token.shape[1], ) - else: - tmp_ffn_out = paddle.empty(recv_x.shape, "bfloat16") + tmp_ffn_out = paddle.empty(x.shape, x.dtype) 
- # 5. EP combine - handle = None - return self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights) - - def compute_ffn( - self, - layer: nn.Layer, - permute_input, - token_nums_per_expert, - valid_token_num=-1, - extra_ffn1_in_scale=None, - ): - """ - Calculate moe - """ - # ffn1_in_scale = extra_ffn1_in_scale - moe_ffn1_scale = None - moe_ffn2_scale = None - - ffn_out = moe_expert_ffn( - permute_input, - token_nums_per_expert, - getattr(layer, self.added_weight_attrs[0]), - getattr(layer, self.added_weight_attrs[1]), - None, - None, - moe_ffn1_scale, - moe_ffn2_scale, - getattr(layer, self.added_scale_attrs[0]), - getattr(layer, self.added_scale_attrs[1]), - None, - None, - self.moe_quant_type, - -1, - valid_token_num, - ) - return ffn_out - - def apply_ep_decode( - self, - layer: nn.Layer, - x: paddle.Tensor, - gate: nn.Layer, - ) -> paddle.Tensor: - """ - Apply the EP decoder method. - """ - gate_out = gate(x.cast("float32")) - - # 1. Select topk experts and weights - topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) - - # 2. EP Dispatch - expertwise_scale = None - use_fp8 = False - ( - permute_input, - token_nums_per_expert, - handle, - valid_token_num, - ) = self.ep_decoder_runner.dispatch( - x, topk_idx, topk_weights, expertwise_scale=expertwise_scale, use_fp8=use_fp8 - ) - - # 3. Compute ffn - ffn_out = self.compute_ffn( - layer, - permute_input, - token_nums_per_expert, - valid_token_num, - ) - - # 4. 
EP combine - return self.ep_decoder_runner.combine(ffn_out, topk_idx, topk_weights, handle) + if layer.reduce_results and layer.tp_size > 1: + tensor_model_parallel_all_reduce(tmp_ffn_out) + return tmp_ffn_out class XPUW4A8MoEMethod(XPUMoEMethod): @@ -620,100 +468,6 @@ class XPUW4A8MoEMethod(XPUMoEMethod): XPU w4a8 MoE Method """ - def __init__( - self, - quant_config: WeightOnlyConfig, - ) -> None: - super().__init__(quant_config) - self.quant_config = quant_config - self.moe_quant_type = "w4a8" - self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] - self.added_scale_attrs = [ - "up_gate_proj_weight_scale", - "down_proj_weight_scale", - ] - - def create_weights(self, layer: nn.Layer, **extra_weight_attrs): - """ - Paddle cutlass create weight process. - """ - self.weight_dtype = "int8" - self.scale_dtype = "float32" - # get weight shape - if self.moe_quant_type in ["weight_only_int4", "w4a8"]: - self.up_gate_proj_weight_shape = [ - layer.num_local_experts, - layer.moe_intermediate_size * 2, - layer.hidden_size // 2, - ] - else: - self.up_gate_proj_weight_shape = [ - layer.num_local_experts, - layer.moe_intermediate_size * 2, - layer.hidden_size, - ] - if self.moe_quant_type in ["weight_only_int4", "w4a8"]: - self.down_proj_weight_shape = [ - layer.num_local_experts, - layer.hidden_size, - layer.moe_intermediate_size // 2, - ] - else: - self.down_proj_weight_shape = [ - layer.num_local_experts, - layer.hidden_size, - layer.moe_intermediate_size, - ] - # set weight param - setattr( - layer, - self.added_weight_attrs[0], - layer.create_parameter( - shape=self.up_gate_proj_weight_shape, - dtype=self.weight_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - self.added_weight_attrs[1], - layer.create_parameter( - shape=self.down_proj_weight_shape, - dtype=self.weight_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - # weight_scales - setattr( - layer, - self.added_scale_attrs[0], - 
layer.create_parameter( - shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], - dtype=self.scale_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - self.added_scale_attrs[1], - layer.create_parameter( - shape=[layer.num_local_experts, layer.hidden_size], - dtype=self.scale_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - # in_scale - for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]: - setattr( - layer, - in_scale_name, - layer.create_parameter( - shape=[layer.num_local_experts], - dtype=self.scale_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - def paddle_swap_int4_pack_int4_0123_to_int8_1032in_int8(self, weight_tensor: paddle.Tensor) -> paddle.Tensor: """ Pack the last dimension of a tensor into int8 format by combining adjacent int4 values. @@ -728,9 +482,12 @@ class XPUW4A8MoEMethod(XPUMoEMethod): """ load weight and process. """ - up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = ( - layer.extract_moe_ffn_weights(state_dict) - ) + ( + up_gate_proj_weights, + down_proj_weights, + logical_expert_ids, + ep_rank_to_expert_id_list, + ) = layer.extract_moe_ffn_weights(state_dict) assert len(up_gate_proj_weights) == layer.num_local_experts assert len(down_proj_weights) == layer.num_local_experts assert up_gate_proj_weights[0].shape == [ @@ -772,7 +529,12 @@ class XPUW4A8MoEMethod(XPUMoEMethod): state_dict (dict): The state dict. 
""" - def _extract_scale_tensor(layer: nn.Layer, state_dict, key_template, expert_idx): + def _extract_scale_tensor( + layer: nn.Layer, + state_dict, + key_template, + expert_idx, + ): return get_tensor( ( state_dict.pop(key_template.format(expert_idx)) @@ -806,35 +568,49 @@ class XPUW4A8MoEMethod(XPUMoEMethod): for expert_idx in logical_expert_ids: for name, scale_key_template in scale_key_map.items(): - scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx) + scale_tensor = _extract_scale_tensor( + layer, + state_dict, + scale_key_template, + expert_idx, + ) scale_weight_map[name].append(scale_tensor) # 2. Process scale tensor and set to layer - for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]: + for in_scale_name in self.added_in_scale_attrs: getattr(layer, in_scale_name).set_value(paddle.concat(scale_weight_map[in_scale_name])) - for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]): + for weight_scale_name in self.added_scale_attrs: getattr(layer, weight_scale_name).set_value(paddle.stack(scale_weight_map[weight_scale_name], axis=0)) - def apply( + def apply_tp( self, layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, ) -> paddle.Tensor: - gate_out = paddle.matmul(x.cast("float32"), gate.weight.transpose([1, 0]), transpose_y=True) - topk_idx, topk_weights = moe_topk_select(gate_out, layer.gate_correction_bias, layer.top_k, True) + gate_out = gate(x.cast("float32")) + topk_idx, topk_weights = moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, + ) token_nums_per_expert_list = list(range(64)) # 填充做占位符 - permute_input, permute_indices_per_token, token_num_lod, dst_weights, ffn1_act_scale_per_token = ( - ep_moe_expert_dispatch( - x, - topk_idx, - topk_weights, - (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None), - token_nums_per_expert_list, - x.shape[0] * layer.top_k, - self.moe_quant_type, - ) + ( + 
permute_input, + permute_indices_per_token, + token_num_lod, + dst_weights, + ffn1_act_scale_per_token, + ) = ep_moe_expert_dispatch( + x, + topk_idx, + topk_weights, + getattr(layer, "up_gate_proj_in_scale", None), + token_nums_per_expert_list, + x.shape[0] * layer.top_k, + self.moe_quant_type, ) ffn_out = moe_expert_ffn( permute_input, @@ -843,14 +619,14 @@ class XPUW4A8MoEMethod(XPUMoEMethod): layer.down_proj_weight, None, # moe_ffn1_bias None, # moe_ffn2_bias - (ffn1_act_scale_per_token if hasattr(layer, "up_gate_proj_in_scale") else None), - (layer.down_proj_in_scale if hasattr(layer, "down_proj_in_scale") else None), - (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), - (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), + getattr(layer, "up_gate_proj_in_scale", None), + getattr(layer, "down_proj_in_scale", None), + getattr(layer, "up_gate_proj_weight_scale", None), + getattr(layer, "down_proj_weight_scale", None), None, # moe_ffn2_shift None, # moe_ffn2_smooth self.moe_quant_type, - getattr(layer.moe_quant_config, "hadamard_block_size", 128), # hadamard_blocksize defalue 128 + getattr(layer.moe_quant_config, "hadamard_block_size", 128), x.shape[0] * layer.top_k, # token_all_num ) topk_weights_bf16 = topk_weights.astype("bfloat16") @@ -864,9 +640,5 @@ class XPUW4A8MoEMethod(XPUMoEMethod): permute_indices_per_token.shape[1], ) if layer.reduce_results and layer.tp_size > 1: - from fastdeploy.distributed.communication import ( - tensor_model_parallel_all_reduce, - ) - tensor_model_parallel_all_reduce(tmp_ffn_out) return tmp_ffn_out diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index 9078021fd..f3414e8b8 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -30,15 +30,22 @@ class 
MoEMethodBase(QuantMethodBase): def __init__(self, quant_config): super().__init__() - if quant_config is None: + self.quant_config = quant_config + if self.quant_config is None: self.moe_quant_type = "w16a16" + elif hasattr(quant_config, "algo"): + self.moe_quant_type = quant_config.algo else: - self.quant_config = quant_config + self.moe_quant_type = quant_config.name() self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] self.added_scale_attrs = [ "up_gate_proj_weight_scale", "down_proj_weight_scale", ] + self.added_in_scale_attrs = [ + "up_gate_proj_in_scale", + "down_proj_in_scale", + ] self.pack_num = 1 self.ep_prefill_runner = None self.ep_decoder_runner = None diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index 0f6b4e6ad..b448afa12 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -79,18 +79,11 @@ class WeightOnlyConfig(QuantConfigBase): def get_quant_method(self, layer) -> Optional[QuantMethodBase]: if current_platform.is_xpu(): if isinstance(layer, FusedMoE): - if layer.ep_size > 1: - from fastdeploy.model_executor.layers.backends import ( - XPUWeightOnlyMoeEpMethod, - ) + from fastdeploy.model_executor.layers.backends import ( + XPUWeightOnlyMoEMethod, + ) - return XPUWeightOnlyMoeEpMethod(self) - else: - from fastdeploy.model_executor.layers.backends import ( - XPUWeightOnlyMoEMethod, - ) - - return XPUWeightOnlyMoEMethod(self) + return XPUWeightOnlyMoEMethod(self) else: from fastdeploy.model_executor.layers.backends import ( XPUWeightOnlyLinearMethod, diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 0f27fde5c..ce96c00e0 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -404,9 +404,9 @@ class PaddleDisWorkerProc: if num_blocks_local <= 0: raise ValueError( - "The total 
number of blocks cannot be less than zero."
-                "Please increase gpu_memory_utilization"
-                "Or decrease max_num_batched_tokens(max model length) "
+                "The total number of blocks cannot be less than zero. "
+                "Please increase gpu_memory_utilization "
+                "Or decrease max_num_batched_tokens(max model length)."
             )
 
         if self.ranks > 1:
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index 985e2a911..19a970f87 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -1227,7 +1227,8 @@ class XPUModelRunner(ModelRunnerBase):
         """
         Clear the block tables and kv cache after profiling.
         """
-        del self.share_inputs["caches"]
+        if "caches" in self.share_inputs:
+            del self.share_inputs["caches"]
         if self.forward_meta is not None:
             del self.forward_meta.caches
         paddle.device.xpu.empty_cache()