Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-06 09:07:10 +08:00
[BugFix] Fix v1 loader MoE bf16, and support dynamic_load_weight creating the quant param (#4229)
* Fix v1 loader MoE bf16, and support dynamic_load_weight creating the quant param.
* When include_stop_str_in_output=False, do not return the eos text.
@@ -185,6 +185,9 @@ class DataProcessor(BaseDataProcessor):
         from paddleformers.trl.llm_utils import get_eos_token_id
 
         self.eos_token_ids = get_eos_token_id(self.tokenizer, self.generation_config)
+        data_processor_logger.info(
+            f"The eos_token_ids obtained by merging tokenizer and generation_config is {self.eos_token_ids}"
+        )
         self.eos_token_id_len = len(self.eos_token_ids)
         self.pad_token_id = self.get_pad_id()
         self.reasoning_parser = None
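The hunk above merges the eos ids declared by the tokenizer and by the generation config (via paddleformers' get_eos_token_id) and logs the merged list. Below is a minimal sketch of that kind of merge, for illustration only; the helper name and attribute handling are assumptions, not the paddleformers implementation.

# Illustrative sketch only -- the real merge lives in
# paddleformers.trl.llm_utils.get_eos_token_id.
def merge_eos_token_ids(tokenizer, generation_config):
    candidates = []
    for source in (tokenizer, generation_config):
        eos = getattr(source, "eos_token_id", None)
        if eos is None:
            continue
        # eos_token_id may be a single id or a list of ids
        candidates.extend(eos if isinstance(eos, (list, tuple)) else [eos])
    # de-duplicate while keeping first-seen order
    return list(dict.fromkeys(candidates))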
@@ -396,7 +399,7 @@ class DataProcessor(BaseDataProcessor):
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
         if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
-            if token_ids[-1] == self.tokenizer.eos_token_id:
+            if token_ids[-1] in self.eos_token_ids:
                 token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
         if is_end:
@@ -434,7 +437,7 @@ class DataProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
 
         if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
-            if token_ids[-1] == self.tokenizer.eos_token_id:
+            if token_ids[-1] in self.eos_token_ids:
                 token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
         response_dict["outputs"]["raw_prediction"] = delta_text
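The two hunks above compare the trailing token against the full eos_token_ids list instead of the tokenizer's single eos_token_id, so models with several stop ids also get the eos text stripped when include_stop_str_in_output is False. A standalone sketch of that trimming rule (the function below is illustrative; it is not the DataProcessor API):

# Illustrative sketch of the trimming rule applied in both hunks above.
def strip_trailing_eos(token_ids, eos_token_ids, include_stop_str_in_output=False):
    if not include_stop_str_in_output and token_ids and token_ids[-1] in eos_token_ids:
        return token_ids[:-1]
    return token_ids

# Example: with eos ids {2, 100023}, a trailing 100023 is now dropped too.
assert strip_trailing_eos([5, 17, 100023], {2, 100023}) == [5, 17]
assert strip_trailing_eos([5, 17, 100023], {2, 100023}, include_stop_str_in_output=True) == [5, 17, 100023]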
@@ -199,6 +199,7 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
             layer.up_gate_proj_weight,
             {
                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
+                "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
                 "model_format": extra_weight_attrs.get("model_format", ""),
             },
         )
@@ -206,6 +207,7 @@ class UnquantizedFusedMoEMethod(MoEMethodBase):
             layer.down_proj_weight,
             {
                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
+                "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
                 "model_format": extra_weight_attrs.get("model_format", ""),
             },
         )
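The two hunks above tag both MoE weights with weight_need_transpose whenever the checkpoint is in torch format, so the weight loader can transpose the bf16 tensors into the layout the Paddle kernels expect. A rough sketch of how a loader might consume such an attribute follows; this is a hypothetical loader written for illustration, not FastDeploy's default_weight_loader, and it assumes a 2D per-expert weight shard.

import paddle

# Hypothetical loader, only to show how a "weight_need_transpose" attribute
# set via set_weight_attrs could be consumed at load time.
def example_weight_loader(param: paddle.Tensor, loaded_weight: paddle.Tensor) -> None:
    if getattr(param, "weight_need_transpose", False):
        # torch-format linear weights are stored as (out_features, in_features)
        loaded_weight = paddle.transpose(loaded_weight, perm=[1, 0])
    assert list(param.shape) == list(loaded_weight.shape), "shape mismatch after optional transpose"
    paddle.assign(loaded_weight.astype(param.dtype), output=param)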
@@ -85,6 +85,8 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader):
     else:
         if not quantization_config.get("is_quantized"):
             quantization_config["is_quantized"] = model_config.is_quantized
+        if args.dynamic_load_weight and quantization_config is not None:
+            quantization_config["is_quantized"] = True
         quant_cls = get_quantization_config(quant_config_name)
         quant_config = quant_cls.from_config(quantization_config)
     return quant_config
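With this change, enabling dynamic weight loading forces is_quantized to True in the quantization config, which is what lets the v1 loader create the quant params for dynamically (re)loaded weights. A simplified, standalone sketch of that decision follows; the function name and arguments are assumptions made for illustration, not the parse_quant_config signature.

# Simplified sketch of the is_quantized resolution in the hunk above.
def resolve_is_quantized(quantization_config, model_is_quantized, dynamic_load_weight):
    if quantization_config is None:
        return None
    if not quantization_config.get("is_quantized"):
        quantization_config["is_quantized"] = model_is_quantized
    if dynamic_load_weight:
        quantization_config["is_quantized"] = True
    return quantization_config

# Example: a bf16 checkpoint served with dynamic_load_weight enabled still
# reports is_quantized=True, so the loader creates quant params for it.
cfg = resolve_is_quantized({"quantization": "wint8"}, model_is_quantized=False, dynamic_load_weight=True)
assert cfg["is_quantized"] is True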