Mirror of https://github.com/PaddlePaddle/FastDeploy.git
support stop_seqs
@@ -143,6 +143,9 @@ class DataProcessor(BaseDataProcessor):
         request["eos_token_ids"] = []
         request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config))
 
+        if "stop_seqs" not in request or (isinstance(request["stop_seqs"], (list, tuple)) and len(request["stop_seqs"]) == 0):
+            self.update_stop_seq(request)
+
         if "input_ids" not in request or \
                 (isinstance(request["input_ids"], (list, tuple)) and len(request["input_ids"]) == 0):
             if "text" in request:
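What the new guard does, as a minimal sketch: assuming the request is a plain dict as the diff suggests (the field values below are made up), update_stop_seq is only called when the request arrives without precomputed stop_seqs.

    # Request carrying user-supplied stop strings but no precomputed stop token ids.
    request = {"text": "Hello", "stop_sequences": ["Observation:"]}

    # "stop_seqs" is absent (or an empty list/tuple), so the processor calls
    # self.update_stop_seq(request), which fills request["stop_seqs"] and
    # request["stop_seqs_len"] before input_ids are built.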
@@ -334,3 +337,19 @@ class DataProcessor(BaseDataProcessor):
         if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
             return self.tokenizer.eos_token
         return self.tokenizer.pad_token_id
+
+    def update_stop_seq(self, request):
+        """
+        Update stop sequences from request.
+        """
+        stop_seqs = [[2], [100273]]
+        for seq in request.get("stop_sequences", []):
+            if seq != self._get_eos_token_id():
+                stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
+        request["stop_seqs"], request["stop_seqs_len"] = self.pad_batch_data(
+            stop_seqs,
+            pad_id=-1,
+            return_seq_len=True,
+            return_array=False
+        )
+        data_processor_logger.debug(f"processed request: {request['stop_seqs'], request['stop_seqs_len']}")
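pad_batch_data itself is not shown in this diff; the sketch below is an assumption about its behavior, written only to be consistent with the call above (pad_id=-1, return_seq_len=True, return_array=False), not the actual FastDeploy helper. The token ids of the third sequence are made up for illustration.

    import numpy as np

    def pad_batch_data_sketch(seqs, pad_id=-1, return_seq_len=True, return_array=False):
        """Pad variable-length token-id lists to a rectangle and report their lengths."""
        seq_lens = [len(s) for s in seqs]
        max_len = max(seq_lens)
        padded = [s + [pad_id] * (max_len - len(s)) for s in seqs]
        if return_array:
            padded = np.array(padded, dtype="int64")
            seq_lens = np.array(seq_lens, dtype="int32")
        return (padded, seq_lens) if return_seq_len else padded

    # The two hard-coded default stop ids plus one tokenized user stop string (hypothetical ids).
    stop_seqs = [[2], [100273], [27, 91, 9125, 91, 29]]
    stop_seqs, stop_seqs_len = pad_batch_data_sketch(stop_seqs, pad_id=-1)
    # stop_seqs     -> [[2, -1, -1, -1, -1], [100273, -1, -1, -1, -1], [27, 91, 9125, 91, 29]]
    # stop_seqs_len -> [1, 1, 5]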
@@ -54,6 +54,9 @@ class ModelRunner:
 
         self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))
 
+        self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", 5))
+        self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", 8))
+
         self.nranks = dist.get_world_size()
         self.init_dist_env()
         self.rank = fleet.worker_index()
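These two environment variables cap how many stop sequences a request may carry and how many tokens each may hold; the GPU-side buffers allocated in the next hunk are sized from them. A hypothetical launch-side override (the values are illustrative):

    import os

    # Allow up to 10 stop sequences of at most 16 tokens each.
    os.environ["MAX_STOP_SEQS_NUM"] = "10"
    os.environ["STOP_SEQS_MAX_LEN"] = "16"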
@@ -248,6 +251,13 @@ class ModelRunner:
         self.share_inputs['free_list_len'] = paddle.full(
             shape=[1], fill_value=self.free_list_len, dtype="int32")
 
+        self.share_inputs['stop_seqs_len'] = paddle.full(shape=[max_stop_seqs_num, ],
+                                                         fill_value=0,
+                                                         dtype="int32")
+        self.share_inputs['stop_seqs'] = paddle.full(shape=[max_stop_seqs_num, stop_seqs_max_len],
+                                                     fill_value=-1,
+                                                     dtype="int64")
+
         if self.reduce_dialogue_repetition:
             self.share_inputs["first_token_ids"] = paddle.full(
                 shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
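The diff does not show the op that consumes these buffers; below is a NumPy sketch of the suffix comparison they make possible, under the assumption that decoding stops when the newest tokens match any row of stop_seqs up to that row's stop_seqs_len. Function and variable names here are illustrative only.

    import numpy as np

    def hits_stop_seq(generated_ids, stop_seqs, stop_seqs_len):
        """Return True if the tail of generated_ids equals any configured stop sequence."""
        for row, seq_len in zip(stop_seqs, stop_seqs_len):
            seq_len = int(seq_len)
            if seq_len == 0 or seq_len > len(generated_ids):
                continue
            if list(generated_ids[-seq_len:]) == list(row[:seq_len]):
                return True
        return False

    # Buffers shaped and filled like the paddle.full defaults above (5 x 8, pad -1).
    stop_seqs = np.full((5, 8), -1, dtype="int64")
    stop_seqs_len = np.zeros((5,), dtype="int32")
    stop_seqs[0, :2] = [100273, 2]
    stop_seqs_len[0] = 2
    print(hits_stop_seq([11, 100273, 2], stop_seqs, stop_seqs_len))  # True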
@@ -300,6 +310,14 @@ class ModelRunner:
             self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
                 task['block_tables'], dtype="int32")
 
+            if "stop_seqs_len" in task:
+                stop_seqs_num = len(task["stop_seqs_len"])
+                for i in range(stop_seqs_num, max_stop_seqs_num):
+                    task["stop_seqs_len"].append(0)
+                share_inputs['stop_seqs_len'][:] = np.array(
+                    task["stop_seqs_len"], dtype="int32")
+                share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
+                    task["stop_seqs"], dtype="int64")
+
     def step_cuda(self, seq_lens_this_time):
         """
         step cuda
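A worked example of the copy above, with made-up sizes (max_stop_seqs_num=5, stop_seqs_max_len=8), made-up token ids, and NumPy arrays standing in for the paddle tensors in share_inputs:

    import numpy as np

    max_stop_seqs_num, stop_seqs_max_len = 5, 8
    share_inputs = {
        "stop_seqs_len": np.zeros((max_stop_seqs_num,), dtype="int32"),
        "stop_seqs": np.full((max_stop_seqs_num, stop_seqs_max_len), -1, dtype="int64"),
    }
    # A task as produced by DataProcessor.update_stop_seq: 3 stop sequences, already padded with -1.
    task = {"stop_seqs_len": [1, 1, 3],
            "stop_seqs": [[2, -1, -1], [100273, -1, -1], [58, 9399, 60]]}

    stop_seqs_num = len(task["stop_seqs_len"])            # 3
    for i in range(stop_seqs_num, max_stop_seqs_num):     # pad the length list with zeros up to 5
        task["stop_seqs_len"].append(0)
    share_inputs["stop_seqs_len"][:] = np.array(task["stop_seqs_len"], dtype="int32")
    share_inputs["stop_seqs"][:stop_seqs_num, :len(task["stop_seqs"][0])] = np.array(
        task["stop_seqs"], dtype="int64")
    # share_inputs["stop_seqs_len"] -> [1, 1, 3, 0, 0]; unused rows of stop_seqs stay -1.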
@@ -486,6 +504,11 @@ class InferenceEngine(object):
         config.switch_ir_optim(False)
         config.enable_use_gpu(100, device_id)
 
+        pir_flag = int(os.environ.get("FLAGS_enable_pir_api", 0))
+        if pir_flag == 1:
+            config.enable_new_executor()
+            config.enable_new_ir()
+
         # distributed config
         if self.mp_degree > 1:
             trainer_endpoints = fleet.worker_endpoints()
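The PIR switch is read from the environment, so opting in to the new executor path only requires setting the flag before the engine builds its predictor config; a hypothetical launch-side snippet:

    import os

    # Enable the PIR-based executor (off by default in this code path).
    os.environ["FLAGS_enable_pir_api"] = "1"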