fix mtp quant param (#4469)

This commit is contained in:
GoldPancake
2025-10-17 14:53:01 +08:00
committed by GitHub
parent 6c15945e4d
commit 631a1e2339
2 changed files with 5 additions and 1 deletions

View File

@@ -315,7 +315,9 @@ class Ernie4_5_MTPModel(nn.Layer):
hidden_states = hidden_states + residual
hidden_states = self.norm(hidden_states)
# NOTE@wangyuanpeng04 Whether the final norm is applied here must match
# whether the norm was applied during the MTP training phase.
# hidden_states = self.norm(hidden_states)
return hidden_states

View File

@@ -82,6 +82,8 @@ class MTPProposer(Proposer):
self.model_config.quantization = self.speculative_config.quantization
self.model_config.start_layer_index = self.num_main_model_layers
self.speculative_config.model_type = "mtp"
if self.speculative_config.quantization is not None:
self.model_config.is_quantized = False
def _load_model(self):
"""