support stop_seqs

This commit is contained in:
minghaipeng
2025-01-07 06:35:25 +00:00
parent cbd77205f3
commit 093614e47d
2 changed files with 42 additions and 0 deletions

View File

@@ -143,6 +143,9 @@ class DataProcessor(BaseDataProcessor):
request["eos_token_ids"] = []
request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config))
if "stop_seqs" not in request or (isinstance(request["stop_seqs"], (list, tuple)) and len(request["stop_seqs"]) == 0):
self.update_stop_seq(request)
if "input_ids" not in request or \
(isinstance(request["input_ids"], (list, tuple)) and len(request["input_ids"]) == 0):
if "text" in request:
@@ -334,3 +337,19 @@ class DataProcessor(BaseDataProcessor):
if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
return self.tokenizer.eos_token
return self.tokenizer.pad_token_id
def update_stop_seq(self, request):
"""
Update stop sequences from request.
"""
stop_seqs = [[2], [100273]]
for seq in request.get("stop_sequences", []):
if seq != self._get_eos_token_id():
stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
request["stop_seqs"], request["stop_seqs_len"] = self.pad_batch_data(
stop_seqs,
pad_id=-1,
return_seq_len=True,
return_array=False
)
data_processor_logger.debug(f"processed request: {request['stop_seqs'], request['stop_seqs_len']}")

View File

@@ -54,6 +54,9 @@ class ModelRunner:
self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))
self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", 5))
self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", 8))
self.nranks = dist.get_world_size()
self.init_dist_env()
self.rank = fleet.worker_index()
@@ -248,6 +251,13 @@ class ModelRunner:
self.share_inputs['free_list_len'] = paddle.full(
shape=[1], fill_value=self.free_list_len, dtype="int32")
self.share_inputs['stop_seqs_len'] = paddle.full(shape=[max_stop_seqs_num,],
fill_value=0,
dtype="int32")
self.share_inputs['stop_seqs'] = paddle.full(shape=[max_stop_seqs_num, stop_seqs_max_len],
fill_value=-1,
dtype="int64")
if self.reduce_dialogue_repetition:
self.share_inputs["first_token_ids"] = paddle.full(
shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
@@ -300,6 +310,14 @@ class ModelRunner:
self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
task['block_tables'], dtype="int32")
if "stop_seqs_len" in task:
stop_seqs_num = len(task["stop_seqs_len"])
for i in range(stop_seqs_num, max_stop_seqs_num):
task["stop_seqs_len"].append(0)
share_inputs['stop_seqs_len'][:] = np.array(
task["stop_seqs_len"], dtype="int32")
share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
task["stop_seqs"], dtype="int64")
def step_cuda(self, seq_lens_this_time):
"""
step cuda
@@ -486,6 +504,11 @@ class InferenceEngine(object):
config.switch_ir_optim(False)
config.enable_use_gpu(100, device_id)
pir_flag = int(os.environ.get("FLAGS_enable_pir_api", 0))
if pir_flag == 1:
config.enable_new_executor()
config.enable_new_ir()
# distributed config
if self.mp_degree > 1:
trainer_endpoints = fleet.worker_endpoints()