mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
support reduce_dialogue_repetition
This commit is contained in:
@@ -282,7 +282,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
"""
|
||||
if self.config.use_hf_tokenizer:
|
||||
from transformers import AutoTokenizer
|
||||
return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False, vocab_file=os.path.join(self.config.model_dir, "sentencepiece.bpe.model"))
|
||||
return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False)
|
||||
else:
|
||||
from paddlenlp.transformers import AutoTokenizer
|
||||
return AutoTokenizer.from_pretrained(self.config.model_dir)
|
||||
|
||||
@@ -52,6 +52,8 @@ class ModelRunner:
|
||||
self.args.num_attention_heads = self.get_value(self.model_cfg, ["num_attention_heads", "n_head"])
|
||||
self.args.hidden_size = self.model_cfg["hidden_size"]
|
||||
|
||||
self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))
|
||||
|
||||
self.nranks = dist.get_world_size()
|
||||
self.init_dist_env()
|
||||
self.rank = fleet.worker_index()
|
||||
@@ -246,6 +248,12 @@ class ModelRunner:
|
||||
self.share_inputs['free_list_len'] = paddle.full(
|
||||
shape=[1], fill_value=self.free_list_len, dtype="int32")
|
||||
|
||||
if self.reduce_dialogue_repetition:
|
||||
self.share_inputs["first_token_ids"] = paddle.full(
|
||||
shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
|
||||
self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
|
||||
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
|
||||
|
||||
def dy_input_preprocess(self, tasks):
|
||||
"""
|
||||
dynamic insertion
|
||||
@@ -279,6 +287,10 @@ class ModelRunner:
|
||||
self.share_inputs['max_length'][idx:idx + 1] = max_dec_len
|
||||
self.share_inputs['stop_flags'][idx:idx + 1] = False
|
||||
|
||||
if self.reduce_dialogue_repetition:
|
||||
self.share_inputs['first_token_ids'][idx:idx + 1] = self.share_inputs['input_ids'][idx:idx + 1, :1]
|
||||
self.share_inputs["ori_seq_lens_encoder"][idx:idx + 1] = length
|
||||
|
||||
if "infer_seed" in task:
|
||||
self.share_inputs['infer_seed'][idx:idx + 1] = task['infer_seed']
|
||||
|
||||
|
||||
Reference in New Issue
Block a user