Mirror of https://github.com/PaddlePaddle/FastDeploy.git
refactor code
@@ -92,9 +92,6 @@ class Config:
         self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0))
         self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))

-        # speculate decoding config
-        self.speculate_method = str(os.getenv("SPECULATE_METHOD", None))
-
         # infer config
         self.max_batch_size = int(env.get("BATCH_SIZE", 50))
         self.max_seq_len = int(env.get("MAX_SEQ_LEN", 8192))

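Note on the Config change: with the SPECULATE_METHOD environment variable removed, Config no longer decides whether speculative decoding runs; the decision now comes from the model config returned by get_model_config(). A minimal sketch of the new detection, assuming a JSON-style model config (the surrounding keys and values are illustrative; only the speculate_* keys and num_hidden_layers are taken from this diff):

```python
# Hypothetical model config as the runner would receive it from
# Config.get_model_config(); values here are illustrative.
model_cfg = {
    "num_hidden_layers": 32,
    "speculate_method": "inference_with_reference",
    "speculate_max_draft_token_num": 5,
    "speculate_max_ngram_size": 3,
}

# The flag introduced in ModelRunner.__init__ (and mirrored in TokenProcessor):
# speculative decoding is active whenever the model config names a method.
is_speculate_decoding = model_cfg.get("speculate_method") is not None
print(is_speculate_decoding)  # True
```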
@@ -47,6 +47,7 @@ class ModelRunner:
         self.config = Config()
         self.model_cfg = self.config.get_model_config()
+        self.is_speculate_decoding = self.model_cfg.get("speculate_method") is not None
         self.format_print_configuration()

         self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])

@@ -68,16 +69,16 @@ class ModelRunner:
         self.cache_kvs = {}
         self.init_inputs()

-        # whether use speculate decoding
-        if self.config.speculate_method is not None:
-            if self.config.speculate_method == "inference_with_reference":
+        if self.is_speculate_decoding:
+            logger.info(f'Using speculating decoding, method: {self.model_cfg["speculate_method"]}.')
+            if self.model_cfg["speculate_method"] == "inference_with_reference":
                 self.proposer = InferenceWithReferenceProposer(
                     self.model_cfg["speculate_max_draft_token_num"],
                     self.model_cfg["speculate_max_ngram_size"],
                     self.args.max_batch_size,
                     self.args.max_seq_len)
             else:
-                raise NotImplementedError(f'Not support {self.config.speculate_method}, only support inference_with_reference now.')
+                raise NotImplementedError(f'Not support {self.model_cfg["speculate_method"]}, only support inference_with_reference now.')
         else:
             self.proposer = None

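For background, inference_with_reference drafts candidate tokens by matching the tail of the generated sequence against text that is already available in the context (an n-gram lookup), rather than running a separate draft model. The sketch below is only a rough, standalone illustration of that idea using the two parameters this hunk reads from model_cfg; it is not the InferenceWithReferenceProposer implementation.

```python
from typing import List

def propose_draft_tokens(context_ids: List[int],
                         max_ngram_size: int = 3,
                         max_draft_token_num: int = 5) -> List[int]:
    """Toy n-gram lookup: find the most recent earlier occurrence of the
    sequence's last n-gram and propose the tokens that followed it."""
    for n in range(max_ngram_size, 0, -1):
        if len(context_ids) <= n:
            continue
        tail = context_ids[-n:]
        # Search the context (excluding the tail itself) from right to left.
        for start in range(len(context_ids) - n - 1, -1, -1):
            if context_ids[start:start + n] == tail:
                follow = context_ids[start + n:start + n + max_draft_token_num]
                if follow:
                    return follow
    return []  # no match: fall back to ordinary one-token-at-a-time decoding

# Example: a repeated phrase lets the proposer guess the continuation.
print(propose_draft_tokens([7, 8, 9, 10, 11, 7, 8, 9]))  # [10, 11, 7, 8, 9]
```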
@@ -278,7 +279,7 @@ class ModelRunner:
         self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
             shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
         # speculate decoding input
-        if self.config.speculate_method is not None:
+        if self.is_speculate_decoding:
             self.share_inputs["accept_tokens"] = paddle.full(
                 shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64"
             )

@@ -344,8 +345,8 @@ class ModelRunner:
             task["stop_seqs_len"], dtype="int32")
         self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
             task["stop_seqs"], dtype="int64")
-        if self.proposer is not None:
-            if self.config.speculate_method == "inference_with_reference":
+
+        if self.is_speculate_decoding:
             self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
             self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])

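Seen together with the init_inputs hunk above, the speculative path keeps fixed-shape per-slot buffers that are reset whenever a task is dispatched to a slot. A condensed sketch, with numpy standing in for the paddle tensors; the shapes of draft_tokens and actual_draft_token_num are assumptions inferred from how they are indexed here, and the draft-token count is an illustrative value:

```python
import numpy as np

max_batch_size = 50        # args.max_batch_size (BATCH_SIZE default in Config)
max_draft_token_num = 5    # model_cfg["speculate_max_draft_token_num"], assumed value

# Allocated once in init_inputs (paddle.full in the real code); only
# accept_tokens' shape is spelled out in this diff, the other two are assumed.
share_inputs = {
    "accept_tokens": np.zeros((max_batch_size, max_draft_token_num + 1), dtype="int64"),
    "draft_tokens": np.zeros((max_batch_size, max_draft_token_num + 1), dtype="int64"),
    "actual_draft_token_num": np.zeros((max_batch_size, 1), dtype="int64"),
}

# Reset for slot `idx` when a new task is inserted, as in the hunk above.
idx = 0
share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([max_draft_token_num + 1])
share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([max_draft_token_num])
```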
@@ -353,7 +354,7 @@ class ModelRunner:
         """
         step cuda
         """
-        if self.config.speculate_method is None:
+        if not self.is_speculate_decoding:
             step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
                         self.share_inputs['step_seq_lens_encoder'],
                         self.share_inputs['seq_lens_encoder'],

@@ -22,9 +22,8 @@ from datetime import datetime
 import numpy as np
 from paddlenlp_ops import get_output, speculate_get_output
 from server.utils import datetime_diff, model_server_logger, monitor_logger
+from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ

-SPECULATE_MAX_BSZ = 256
-MAX_DRAFT_TOKEN_NUM = 6

 class TokenProcessor(object):
     """
@@ -40,8 +39,9 @@ class TokenProcessor(object):

         self.tokens_counter = Counter()

-        if self.cfg.speculate_method is not None:
-            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKEN_NUM + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64")
+        self.is_speculate_decoding = self.cfg.get_model_config().get("speculate_method") is not None
+        if self.is_speculate_decoding:
+            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64")
         else:
             self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
         self.worker = None
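The speculative output buffer packs one step's results into a single flat int64 tensor: index 0 carries the message id, index 1 the batch size, the next SPECULATE_MAX_BSZ entries the per-slot accept counts, and the remainder up to MAX_DRAFT_TOKENS token ids per slot (this follows from the allocation here and the slicing in the hunks below). A small layout sketch, using the previously hard-coded values 256 and 6 as stand-ins for the constants now imported from paddlenlp.utils.env:

```python
import numpy as np

SPECULATE_MAX_BSZ = 256    # stand-in; imported from paddlenlp.utils.env in the real code
MAX_DRAFT_TOKENS = 6       # stand-in; imported from paddlenlp.utils.env in the real code

# Flat layout of output_tokens:
# [ msg_id | batch | accept_num[0..SPECULATE_MAX_BSZ) | tokens of slot 0 (MAX_DRAFT_TOKENS) | tokens of slot 1 | ... ]
total = SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2
buf = np.full([total], 2, dtype="int64")   # mirrors paddle.full(..., fill_value=2, dtype="int64")

def slot_tokens(buf: np.ndarray, i: int) -> np.ndarray:
    """Token ids written for request slot i, before trimming by accept_num[i]."""
    base = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
    return buf[base:base + MAX_DRAFT_TOKENS]

print(slot_tokens(buf, 0).shape)  # (6,)
```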
@@ -71,7 +71,7 @@ class TokenProcessor(object):
         if self.worker is not None:
             raise Exception("Worker is already running!")

-        if self.cfg.speculate_method is not None:
+        if self.is_speculate_decoding:
             self.worker = threading.Thread(target=self.process_speculate_results, args=())
         else:
             self.worker = threading.Thread(target=self.process_sampling_results, args=())
@@ -302,7 +302,6 @@ class TokenProcessor(object):
         batch post-processing function
         """
         tokens = self.output_tokens.numpy()
-        model_server_logger.info(f"speculate_result tokens: {self.output_tokens.tolist()}")
         batch = self.output_tokens[1]
         output_token_msg_id = int(self.output_tokens[0])
         accept_num = tokens[2 : batch + 2]
@@ -317,7 +316,7 @@ class TokenProcessor(object):
             if self.resource_manager.stop_flags[i]:
                 continue

-            token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM + accept_num[i]].tolist()
+            token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i]].tolist()
             # skip invalid tokens
             if len(token_ids) == 0 or token_ids[-1] == 0:
                 continue
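A condensed, standalone version of the host-side parsing that process_speculate_results performs over that buffer, following the slicing in the two hunks above (stop-flag bookkeeping and result messaging are omitted):

```python
def parse_speculate_output(tokens, SPECULATE_MAX_BSZ, MAX_DRAFT_TOKENS):
    """tokens: 1-D int64 array laid out as described in the allocation hunk."""
    output_token_msg_id = int(tokens[0])
    batch = int(tokens[1])
    accept_num = tokens[2:batch + 2]

    per_request = []
    for i in range(batch):
        base = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
        token_ids = tokens[base:base + int(accept_num[i])].tolist()
        # Mirror the "skip invalid tokens" check above: ignore slots that
        # produced nothing usable this step.
        if len(token_ids) == 0 or token_ids[-1] == 0:
            per_request.append([])
            continue
        per_request.append(token_ids)
    return output_token_msg_id, per_request
```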