[NewFeatures] support eplb (#3547)

* [NewFeatures] support eplb

* fix eplb
Author: xiaoxiaohehe001
Date: 2025-08-26 16:19:30 +08:00
Committed by: GitHub
Parent: 56e2d7e668
Commit: 9afa236e39
17 changed files with 174 additions and 67 deletions
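
All four backend files below make the same two-part change for EPLB (expert parallelism load balancing) support: process_prequanted_weights gains an is_rearrange keyword that defaults to False, so existing call sites keep their behavior, and extract_moe_ffn_weights now returns a 4-tuple, of which these backends consume only the first two elements. A minimal sketch of the updated contract, assuming the two discarded return values carry EPLB rearrangement metadata (the class name and comments are illustrative, not taken from the patch):

    from paddle import nn

    class SketchMoEMethod:
        """Illustrative stub of the updated backend contract, not a FastDeploy class."""

        def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
            # Default False keeps pre-EPLB callers working; an EPLB path would
            # pass True after experts have been redistributed across ranks.
            pass

        def process_loaded_weights(self, layer: nn.Layer, state_dict) -> None:
            # extract_moe_ffn_weights now yields four values; backends that do not
            # handle rearrangement discard the trailing two (assumed EPLB metadata).
            up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
            assert len(up_gate_proj_weights) == layer.num_local_experts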


@@ -38,7 +38,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
             "down_proj_weight_scale",
         ]
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
         """process_prequanted_weights"""
         pass
 
@@ -46,7 +46,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
         """
         Triton MoE create weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         assert len(up_gate_proj_weights) == layer.num_local_experts
         assert len(down_proj_weights) == layer.num_local_experts
         assert self.quant_method.name() == "wint8"


@@ -49,7 +49,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
         self.group_size = -1
 
     def process_loaded_weights(self, layer: nn.Layer, state_dict):
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
         stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
         layer.up_gate_proj_weight.set_value(paddle.transpose(stacked_up_gate_proj_weights, [0, 2, 1]))
 
@@ -254,7 +254,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
         self.quant_multi_process_group_size = int(os.getenv("FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE", 8))
         logger.info(f"GCUWeightOnlyMoEMethod quant_multi_process_group_size: {self.quant_multi_process_group_size}")
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle gcu process prequanted weights.
         """
 
@@ -299,7 +299,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
         """
         Paddle cutlass create weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         self.check(layer, up_gate_proj_weights, down_proj_weights)
 
         def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
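
One note on the second GCU hunk above: the multi-process quantization fan-out is read from the environment with a default of 8. The pattern, shown in isolation:

    import os

    # Number of worker processes per quantization group on GCU; falls back to 8
    # when FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE is unset.
    group_size = int(os.getenv("FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE", 8))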


@@ -59,7 +59,7 @@ class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
             is_bias=False,
         )
 
-    def process_prequanted_weights(self, layer, state_dict) -> None:
+    def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False) -> None:
         """
         Process pre-quantized weights before applying them to the model
         Args:


@@ -41,7 +41,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
             "down_proj_weight_scale",
         ]
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
         """process_prequanted_weights"""
         pass
 
@@ -50,7 +50,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
         """
         Triton MoE create weight process.
         """
-        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         assert len(up_gate_proj_weights) == layer.num_local_experts
         assert len(down_proj_weights) == layer.num_local_experts
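
Finally, a hypothetical call-site sketch for the new keyword (the method handle and the rearrangement trigger are assumptions, not from this commit):

    # Regular checkpoint load: the default preserves prior behavior.
    method.process_prequanted_weights(layer, state_dict)

    # After an EPLB expert rearrangement, a caller would opt in explicitly:
    method.process_prequanted_weights(layer, state_dict, is_rearrange=True)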