[BugFix] Fix v1 loader MoE bf16, and support dynamic_load_weight to create quant param (#4229)

* Fix v1 loader MoE bf16, and support dynamic_load_weight to create quant param

* When include_stop_str_in_output=False, do not return EOS text in the output
Author: chen
Date: 2025-09-24 14:12:05 +08:00
Committed by: GitHub
Commit: 3161014e49 (parent 44010cee13)
3 changed files with 9 additions and 2 deletions


@@ -185,6 +185,9 @@ class DataProcessor(BaseDataProcessor):
         from paddleformers.trl.llm_utils import get_eos_token_id

         self.eos_token_ids = get_eos_token_id(self.tokenizer, self.generation_config)
+        data_processor_logger.info(
+            f"The eos_token_ids obtained by merging tokenizer and generation_config is {self.eos_token_ids}"
+        )
         self.eos_token_id_len = len(self.eos_token_ids)
         self.pad_token_id = self.get_pad_id()
         self.reasoning_parser = None
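Note: self.eos_token_ids is the merged list returned by get_eos_token_id, so a model may expose more than one EOS id. As a rough illustration only (not the actual paddleformers implementation), the merge behaves roughly like this hypothetical helper:

    # Hypothetical sketch of the merge performed by get_eos_token_id:
    # union of generation_config and tokenizer EOS ids, order preserved.
    # The real helper lives in paddleformers.trl.llm_utils and may differ.
    def merge_eos_token_ids(tokenizer, generation_config):
        merged = []
        for source in (getattr(generation_config, "eos_token_id", None), tokenizer.eos_token_id):
            if source is None:
                continue
            ids = source if isinstance(source, (list, tuple)) else [source]
            merged.extend(i for i in ids if i not in merged)
        return merged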
@@ -396,7 +399,7 @@ class DataProcessor(BaseDataProcessor):
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
         if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
-            if token_ids[-1] == self.tokenizer.eos_token_id:
+            if token_ids[-1] in self.eos_token_ids:
                 token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
         if is_end:
@@ -434,7 +437,7 @@ class DataProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
         if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
-            if token_ids[-1] == self.tokenizer.eos_token_id:
+            if token_ids[-1] in self.eos_token_ids:
                 token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
         response_dict["outputs"]["raw_prediction"] = delta_text


@@ -199,6 +199,7 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
             layer.up_gate_proj_weight,
             {
                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
+                "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
                 "model_format": extra_weight_attrs.get("model_format", ""),
             },
         )
@@ -206,6 +207,7 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
             layer.down_proj_weight,
             {
                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
+                "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
                 "model_format": extra_weight_attrs.get("model_format", ""),
             },
         )
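The new "weight_need_transpose" attribute flags MoE weights coming from a torch-format checkpoint, whose layout is typically transposed relative to Paddle's, so the v1 loader's weight loader can transpose before copying into the BF16 parameter. A hypothetical loader honoring the attribute might look like the sketch below (FastDeploy's actual default_weight_loader may differ; the function name here is made up):

    import paddle

    # Hypothetical sketch: swap the last two axes when the parameter was
    # tagged with weight_need_transpose, then copy into the Paddle parameter.
    def load_weight(param, loaded_weight):
        if getattr(param, "weight_need_transpose", False):
            perm = list(range(loaded_weight.ndim))
            perm[-2], perm[-1] = perm[-1], perm[-2]
            loaded_weight = paddle.transpose(loaded_weight, perm=perm)
        assert list(param.shape) == list(loaded_weight.shape), "shape mismatch after optional transpose"
        param.set_value(loaded_weight)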


@@ -85,6 +85,8 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
     else:
         if not quantization_config.get("is_quantized"):
             quantization_config["is_quantized"] = model_config.is_quantized
+    if args.dynamic_load_weight and quantization_config is not None:
+        quantization_config["is_quantized"] = True
     quant_cls = get_quantization_config(quant_config_name)
     quant_config = quant_cls.from_config(quantization_config)
     return quant_config
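With dynamic weight loading enabled (args.dynamic_load_weight), the quantization config is now forced to is_quantized=True, so the quant method creates the quantized parameters (packed weights and scales) up front instead of expecting raw BF16 weights to be quantized during load. A small illustration of the added branch (the "quant_method" key below is made up; only "is_quantized" comes from the diff):

    class Args:
        dynamic_load_weight = True

    quantization_config = {"quant_method": "wint8"}
    if Args.dynamic_load_weight and quantization_config is not None:
        quantization_config["is_quantized"] = True
    # quantization_config -> {"quant_method": "wint8", "is_quantized": True}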