[Bugfix] Fix model accuracy in some ops (#3231)

* fix noaux_tc op

* fix

* update

* fix qk norm

* fix linear for prequant loader

* test

* fix

* fix

* rm some print

* fix noaux_tc op

* test

* Fix the confusing enable_early_stop behavior when only early_stop_config is set (#3214)

* fix the confusing enable_early_stop behavior when only early_stop_config is set

* pre-commit

* write a general method

* Add CI case for min token and max token (#3229)

Co-authored-by: xujing43 <xujing43@baidu.com>

* add some evil cases (#3240)

* add repetition early stop cases

* add repetition early stop cases

* add bad cases

* add bad cases

* add evil cases

* qwen3_moe (#3084)

* [Feature] support seed parameter (#3161)

* support seed

* fix

* add SamplingMetadata seed test

* The next_tokens values are inconsistent!

* add air and rejection seed test

* fix

* add SamplingParams seed test

* fix seed=0

* Default to default

* fix

* fix args_utils

* fix review

* fix review

* fix

* fix

* add xpu,gcu,iluvatar support seed

* fix

* [Fix Bug] Fix fa3 support for centralized deployment (#3235)

* fix fa3 centralized deployment bug

* add qknorm parameter

* fix qk norm

* fix

* update

* fix linear for prequant loader

* fix

* fix

* rm some print

* fix

* fix moe init weight&scale

* fix moe init weight&scale

---------

Co-authored-by: bukejiyu <395822456@qq.com>
Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
Co-authored-by: Zero Rains <linjunlu@zerorains.top>
Co-authored-by: xjkmfa <108254620+xjkmfa@users.noreply.github.com>
Co-authored-by: xujing43 <xujing43@baidu.com>
Co-authored-by: Divano <dddivano@outlook.com>
Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
Co-authored-by: lizexu123 <39205361+lizexu123@users.noreply.github.com>
Co-authored-by: yangjianfengo1 <125249383+yangjianfengo1@users.noreply.github.com>
Co-authored-by: qingqing01 <dangqingqing@baidu.com>
Author: gaoziyuan
Date: 2025-08-08 17:30:37 +08:00
Committed by: GitHub
Parent: ce1f353c70
Commit: a799d14df1
8 changed files with 62 additions and 31 deletions

@@ -285,7 +285,7 @@ class FusedMoE(nn.Layer):
             dtype="float32",
         )
         up_gate_proj_output_dim = self.moe_intermediate_size * 2
-        if self.moe_quant_type in ["fp8", "wint8"]:
+        if self.moe_quant_type in ["block_wise_fp8", "wint8"]:
             up_gate_proj_weight_shape = [
                 self.num_local_experts,
                 up_gate_proj_output_dim,
@@ -309,9 +309,10 @@ class FusedMoE(nn.Layer):
         ]
         # Create parameters
-        if self.moe_quant_type == "fp8":
+        if self.moe_quant_type == "block_wise_fp8":
-            # (TODO:gaoziyuan)
-            pass
+            self.weight_dtype = "float8_e4m3fn"
+            self.init_block_wise_fp8_scale()
         elif self.moe_quant_type == "wint8":
             self.weight_dtype = "int8"
             self.init_weight_only_scale()
@@ -342,6 +343,21 @@ class FusedMoE(nn.Layer):
             dtype=self._dtype,
         )

+    def init_block_wise_fp8_scale(self):
+        """
+        Initialize the weight scale.
+        """
+        self.up_gate_proj_weight_scale = self.create_parameter(
+            shape=[self.num_local_experts, self.moe_intermediate_size * 2 // 128, self.hidden_size // 128],
+            dtype="float32",
+            is_bias=False,
+        )
+        self.down_proj_weight_scale = self.create_parameter(
+            shape=[self.num_local_experts, self.hidden_size // 128, self.moe_intermediate_size // 128],
+            dtype="float32",
+            is_bias=False,
+        )
+
     def load_experts_weight(
         self,
         state_dict: dict,
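
For readers skimming the diff: the // 128 factors in the scale shapes come from block-wise FP8 quantization, where each 128x128 tile of a weight matrix carries one float32 scale. The sketch below illustrates that bookkeeping in plain numpy with made-up sizes; it is not FastDeploy's actual kernel, which stores float8_e4m3fn weights (the weight_dtype set above) and fuses the scaling into the MoE GEMM.

# Illustrative sketch only -- not the repository's implementation. It shows how
# a per-[128, 128]-block scale tensor (as created in init_block_wise_fp8_scale)
# maps onto a 2-D weight matrix; numpy stands in for the FP8 arithmetic.
import numpy as np

BLOCK = 128           # block size implied by the "// 128" in the scale shapes
FP8_E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3fn

def quantize_block_wise(w: np.ndarray):
    """Quantize a [out_dim, in_dim] matrix with one scale per 128x128 block."""
    out_dim, in_dim = w.shape
    assert out_dim % BLOCK == 0 and in_dim % BLOCK == 0
    # View the matrix as [n_block_out, 128, n_block_in, 128] tiles.
    blocks = w.reshape(out_dim // BLOCK, BLOCK, in_dim // BLOCK, BLOCK)
    # One scale per tile, chosen so the tile fits the e4m3 range.
    amax = np.abs(blocks).max(axis=(1, 3))                 # [n_bo, n_bi]
    scale = np.maximum(amax, 1e-12) / FP8_E4M3_MAX
    q = blocks / scale[:, None, :, None]
    return q.reshape(out_dim, in_dim), scale.astype(np.float32)

def dequantize_block_wise(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Inverse: broadcast each block's scale back over its 128x128 tile."""
    out_dim, in_dim = q.shape
    blocks = q.reshape(out_dim // BLOCK, BLOCK, in_dim // BLOCK, BLOCK)
    return (blocks * scale[:, None, :, None]).reshape(out_dim, in_dim)

# Shapes line up with the diff: with hypothetical moe_intermediate_size=1536
# and hidden_size=2048, an up_gate_proj weight [1536 * 2, 2048] pairs with a
# [24, 16] scale per expert, i.e. [dim // 128 for each axis].
w = np.random.randn(1536 * 2, 2048).astype(np.float32)
q, s = quantize_block_wise(w)
assert s.shape == (1536 * 2 // 128, 2048 // 128)
round_trip = dequantize_block_wise(q, s)

The round trip is exact here only because numpy never actually rounds to 8 bits; a real float8_e4m3fn cast would add per-block quantization error.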