Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-07 09:31:35 +08:00)

Commit: refactor code
@@ -91,9 +91,6 @@ class Config:
         self.block_size = int(env.get("BLOCK_SIZE", 64))
         self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0))
         self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))
 
-        # speculate decoding config
-        self.speculate_method = str(os.getenv("SPECULATE_METHOD", None))
-
         # infer config
         self.max_batch_size = int(env.get("BATCH_SIZE", 50))

@@ -47,6 +47,7 @@ class ModelRunner:
 
         self.config = Config()
         self.model_cfg = self.config.get_model_config()
+        self.is_speculate_decoding = self.model_cfg.get("speculate_method") is not None
         self.format_print_configuration()
 
         self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])

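With this change the speculative-decoding switch is derived once from the exported model config instead of the SPECULATE_METHOD environment variable that the removed Config field used to read. A minimal, self-contained sketch of the new lookup (the config values below are hypothetical):

# Sketch only: derive the flag from the model config dict rather than an env var.
model_cfg = {
    "speculate_method": "inference_with_reference",   # hypothetical exported config
    "speculate_max_draft_token_num": 5,
    "speculate_max_ngram_size": 2,
}
is_speculate_decoding = model_cfg.get("speculate_method") is not None
print(is_speculate_decoding)  # True only when the model was exported with a speculate_method
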
@@ -68,16 +69,16 @@ class ModelRunner:
         self.cache_kvs = {}
         self.init_inputs()
 
-        # whether use speculate decoding
-        if self.config.speculate_method is not None:
-            if self.config.speculate_method == "inference_with_reference":
+        if self.is_speculate_decoding:
+            logger.info(f'Using speculating decoding, method: {self.model_cfg["speculate_method"]}.')
+            if self.model_cfg["speculate_method"] == "inference_with_reference":
                 self.proposer = InferenceWithReferenceProposer(
                     self.model_cfg["speculate_max_draft_token_num"],
                     self.model_cfg["speculate_max_ngram_size"],
                     self.args.max_batch_size,
                     self.args.max_seq_len)
             else:
-                raise NotImplementedError(f'Not support {self.config.speculate_method}, only support inference_with_reference now.')
+                raise NotImplementedError(f'Not support {self.model_cfg["speculate_method"]}, only support inference_with_reference now.')
         else:
             self.proposer = None
 
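For background, inference_with_reference proposes draft tokens by matching the tail n-gram of the generated sequence against the input and copying the continuation, which the target model then verifies. A simplified illustration of that idea only (not the InferenceWithReferenceProposer implementation; propose_draft is a hypothetical helper):

# Simplified n-gram lookup proposer: if the last `ngram_size` generated tokens
# also occur in the prompt, reuse the tokens that followed them as the draft.
def propose_draft(prompt_ids, generated_ids, ngram_size=2, max_draft_token_num=5):
    if len(generated_ids) < ngram_size:
        return []
    tail = generated_ids[-ngram_size:]
    for start in range(len(prompt_ids) - ngram_size):
        if prompt_ids[start:start + ngram_size] == tail:
            end = start + ngram_size
            return prompt_ids[end:end + max_draft_token_num]
    return []

# propose_draft([5, 6, 7, 8, 9, 10], [5, 6, 7]) -> [8, 9, 10]
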
@@ -278,7 +279,7 @@ class ModelRunner:
         self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
             shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
         # speculate decoding input
-        if self.config.speculate_method is not None:
+        if self.is_speculate_decoding:
             self.share_inputs["accept_tokens"] = paddle.full(
                 shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64"
             )

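The accept_tokens buffer keeps, per request slot, the draft tokens accepted by verification plus, presumably, the one extra token sampled by the target model, hence the "+ 1" in the second dimension. A hedged shape check with hypothetical sizes:

# Hypothetical sizes: max_batch_size = 50, speculate_max_draft_token_num = 5.
import paddle
accept_tokens = paddle.full(shape=[50, 5 + 1], fill_value=0, dtype="int64")
print(accept_tokens.shape)  # [50, 6]: up to 5 accepted draft tokens + 1 sampled token
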
@@ -344,16 +345,16 @@ class ModelRunner:
                 task["stop_seqs_len"], dtype="int32")
             self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
                 task["stop_seqs"], dtype="int64")
-            if self.proposer is not None:
-                if self.config.speculate_method == "inference_with_reference":
-                    self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
-                    self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])
+            if self.is_speculate_decoding:
+                self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
+                self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])
 
     def step_cuda(self, seq_lens_this_time):
         """
         step cuda
         """
-        if self.config.speculate_method is None:
+        if not self.is_speculate_decoding:
             step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
                         self.share_inputs['step_seq_lens_encoder'],
                         self.share_inputs['seq_lens_encoder'],

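step_cuda now branches on the cached flag rather than re-reading the config. A sketch of the dispatch shape only; speculate_step is a placeholder name, since the speculative branch lies outside this hunk:

# Sketch only: choose the scheduling op based on the speculative-decoding flag.
def step_cuda(self, seq_lens_this_time):
    if not self.is_speculate_decoding:
        step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
                    self.share_inputs['step_seq_lens_encoder'],
                    self.share_inputs['seq_lens_encoder'],
                    # ... remaining share_inputs arguments elided ...
                    )
    else:
        speculate_step(self.share_inputs, seq_lens_this_time)  # hypothetical name
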
@@ -22,9 +22,8 @@ from datetime import datetime
 import numpy as np
 from paddlenlp_ops import get_output, speculate_get_output
 from server.utils import datetime_diff, model_server_logger, monitor_logger
+from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
 
-SPECULATE_MAX_BSZ = 256
-MAX_DRAFT_TOKEN_NUM = 6
 
 class TokenProcessor(object):
     """

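The locally hard-coded SPECULATE_MAX_BSZ = 256 and MAX_DRAFT_TOKEN_NUM = 6 are dropped in favour of the limits shipped with paddlenlp, presumably so the server-side buffer sizing and the speculate_get_output op agree on the same values. The buffer length used later in this file follows directly from them (actual numbers depend on the installed paddlenlp):

# Buffer length for the flat speculative output, using the shared constants.
from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ

output_tokens_len = SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2
# 2 header slots (message id, batch size) + accept counts + per-slot draft token ids
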
@@ -40,8 +39,9 @@ class TokenProcessor(object):
 
         self.tokens_counter = Counter()
 
-        if self.cfg.speculate_method is not None:
-            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKEN_NUM + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64")
+        self.is_speculate_decoding = self.cfg.get_model_config().get("speculate_method") is not None
+        if self.is_speculate_decoding:
+            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64")
         else:
             self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
         self.worker = None

@@ -71,7 +71,7 @@ class TokenProcessor(object):
         if self.worker is not None:
             raise Exception("Worker is already running!")
 
-        if self.cfg.speculate_method is not None:
+        if self.is_speculate_decoding:
             self.worker = threading.Thread(target=self.process_speculate_results, args=())
         else:
             self.worker = threading.Thread(target=self.process_sampling_results, args=())

@@ -302,7 +302,6 @@ class TokenProcessor(object):
         batch post-processing function
         """
         tokens = self.output_tokens.numpy()
-        model_server_logger.info(f"speculate_result tokens: {self.output_tokens.tolist()}")
         batch = self.output_tokens[1]
         output_token_msg_id = int(self.output_tokens[0])
         accept_num = tokens[2 : batch + 2]

@@ -317,7 +316,7 @@ class TokenProcessor(object):
             if self.resource_manager.stop_flags[i]:
                 continue
 
-            token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM + accept_num[i]].tolist()
+            token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i]].tolist()
             # skip invalid tokens
             if len(token_ids) == 0 or token_ids[-1] == 0:
                 continue

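Putting the indexing from the last two hunks together, the flat speculative output buffer reads as: slot 0 the message id, slot 1 the batch size, the next SPECULATE_MAX_BSZ slots the per-request accept counts, followed by SPECULATE_MAX_BSZ blocks of MAX_DRAFT_TOKENS token ids each. A toy decode with small, hypothetical limits:

# Decode the flat speculative output buffer (layout as read by process_speculate_results).
import numpy as np

SPECULATE_MAX_BSZ, MAX_DRAFT_TOKENS = 2, 3          # hypothetical small limits
tokens = np.array([7, 2,                             # msg_id = 7, batch = 2
                   2, 1,                             # accept_num per request slot
                   11, 12, 0,                        # slot 0 token ids (2 accepted)
                   21, 0, 0])                        # slot 1 token ids (1 accepted)
batch = tokens[1]
accept_num = tokens[2:2 + batch]
for i in range(batch):
    start = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
    print(i, tokens[start:start + accept_num[i]].tolist())
# -> 0 [11, 12]
#    1 [21]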