mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-12 20:11:20 +08:00
support stop_seqs
This commit is contained in:
@@ -143,6 +143,9 @@ class DataProcessor(BaseDataProcessor):
|
||||
request["eos_token_ids"] = []
|
||||
request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config))
|
||||
|
||||
if "stop_seqs" not in request or (isinstance(request["stop_seqs"], (list, tuple)) and len(request["stop_seqs"]) == 0):
|
||||
self.update_stop_seq(request)
|
||||
|
||||
if "input_ids" not in request or \
|
||||
(isinstance(request["input_ids"], (list, tuple)) and len(request["input_ids"]) == 0):
|
||||
if "text" in request:
|
||||
@@ -334,3 +337,19 @@ class DataProcessor(BaseDataProcessor):
|
||||
if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
|
||||
return self.tokenizer.eos_token
|
||||
return self.tokenizer.pad_token_id
|
||||
|
||||
def update_stop_seq(self, request):
|
||||
"""
|
||||
Update stop sequences from request.
|
||||
"""
|
||||
stop_seqs = [[2], [100273]]
|
||||
for seq in request.get("stop_sequences", []):
|
||||
if seq != self._get_eos_token_id():
|
||||
stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
|
||||
request["stop_seqs"], request["stop_seqs_len"] = self.pad_batch_data(
|
||||
stop_seqs,
|
||||
pad_id=-1,
|
||||
return_seq_len=True,
|
||||
return_array=False
|
||||
)
|
||||
data_processor_logger.debug(f"processed request: {request['stop_seqs'], request['stop_seqs_len']}")
|
||||
|
@@ -54,6 +54,9 @@ class ModelRunner:
|
||||
|
||||
self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))
|
||||
|
||||
self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", 5))
|
||||
self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", 8))
|
||||
|
||||
self.nranks = dist.get_world_size()
|
||||
self.init_dist_env()
|
||||
self.rank = fleet.worker_index()
|
||||
@@ -248,6 +251,13 @@ class ModelRunner:
|
||||
self.share_inputs['free_list_len'] = paddle.full(
|
||||
shape=[1], fill_value=self.free_list_len, dtype="int32")
|
||||
|
||||
self.share_inputs['stop_seqs_len'] = paddle.full(shape=[max_stop_seqs_num,],
|
||||
fill_value=0,
|
||||
dtype="int32")
|
||||
self.share_inputs['stop_seqs'] = paddle.full(shape=[max_stop_seqs_num, stop_seqs_max_len],
|
||||
fill_value=-1,
|
||||
dtype="int64")
|
||||
|
||||
if self.reduce_dialogue_repetition:
|
||||
self.share_inputs["first_token_ids"] = paddle.full(
|
||||
shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
|
||||
@@ -300,6 +310,14 @@ class ModelRunner:
|
||||
self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
|
||||
task['block_tables'], dtype="int32")
|
||||
|
||||
if "stop_seqs_len" in task:
|
||||
stop_seqs_num = len(task["stop_seqs_len"])
|
||||
for i in range(stop_seqs_num, max_stop_seqs_num):
|
||||
task["stop_seqs_len"].append(0)
|
||||
share_inputs['stop_seqs_len'][:] = np.array(
|
||||
task["stop_seqs_len"], dtype="int32")
|
||||
share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
|
||||
task["stop_seqs"], dtype="int64")
|
||||
def step_cuda(self, seq_lens_this_time):
|
||||
"""
|
||||
step cuda
|
||||
@@ -486,6 +504,11 @@ class InferenceEngine(object):
|
||||
config.switch_ir_optim(False)
|
||||
config.enable_use_gpu(100, device_id)
|
||||
|
||||
pir_flag = int(os.environ.get("FLAGS_enable_pir_api", 0))
|
||||
if pir_flag == 1:
|
||||
config.enable_new_executor()
|
||||
config.enable_new_ir()
|
||||
|
||||
# distributed config
|
||||
if self.mp_degree > 1:
|
||||
trainer_endpoints = fleet.worker_endpoints()
|
||||
|
Reference in New Issue
Block a user