support reduce_dialogue_repetition

This commit is contained in:
minghaipeng
2025-01-06 12:39:19 +00:00
parent 608d4be580
commit 577b7a7681
2 changed files with 13 additions and 1 deletions

View File

@@ -282,7 +282,7 @@ class DataProcessor(BaseDataProcessor):
"""
if self.config.use_hf_tokenizer:
from transformers import AutoTokenizer
return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False, vocab_file=os.path.join(self.config.model_dir, "sentencepiece.bpe.model"))
return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False)
else:
from paddlenlp.transformers import AutoTokenizer
return AutoTokenizer.from_pretrained(self.config.model_dir)

View File

@@ -52,6 +52,8 @@ class ModelRunner:
self.args.num_attention_heads = self.get_value(self.model_cfg, ["num_attention_heads", "n_head"])
self.args.hidden_size = self.model_cfg["hidden_size"]
self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))
self.nranks = dist.get_world_size()
self.init_dist_env()
self.rank = fleet.worker_index()
@@ -246,6 +248,12 @@ class ModelRunner:
self.share_inputs['free_list_len'] = paddle.full(
shape=[1], fill_value=self.free_list_len, dtype="int32")
if self.reduce_dialogue_repetition:
self.share_inputs["first_token_ids"] = paddle.full(
shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
def dy_input_preprocess(self, tasks):
"""
dynamic insertion
@@ -279,6 +287,10 @@ class ModelRunner:
self.share_inputs['max_length'][idx:idx + 1] = max_dec_len
self.share_inputs['stop_flags'][idx:idx + 1] = False
if self.reduce_dialogue_repetition:
self.share_inputs['first_token_ids'][idx:idx + 1] = self.share_inputs['input_ids'][idx:idx + 1, :1]
self.share_inputs["ori_seq_lens_encoder"][idx:idx + 1] = length
if "infer_seed" in task:
self.share_inputs['infer_seed'][idx:idx + 1] = task['infer_seed']