refactor code

This commit is contained in:
Wanglongzhi2001
2024-12-17 13:12:51 +00:00
parent 389015bf04
commit 08877a985d
3 changed files with 17 additions and 20 deletions

View File

@@ -91,9 +91,6 @@ class Config:
self.block_size = int(env.get("BLOCK_SIZE", 64))
self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0))
self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))
# speculate decoding config
self.speculate_method = str(os.getenv("SPECULATE_METHOD", None))
# infer config
self.max_batch_size = int(env.get("BATCH_SIZE", 50))

View File

@@ -47,6 +47,7 @@ class ModelRunner:
self.config = Config()
self.model_cfg = self.config.get_model_config()
self.is_speculate_decoding = self.model_cfg.get("speculate_method") is not None
self.format_print_configuration()
self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
@@ -68,16 +69,16 @@ class ModelRunner:
self.cache_kvs = {}
self.init_inputs()
# whether use speculate decoding
if self.config.speculate_method is not None:
if self.config.speculate_method == "inference_with_reference":
if self.is_speculate_decoding:
logger.info(f'Using speculating decoding, method: {self.model_cfg["speculate_method"]}.')
if self.model_cfg["speculate_method"] == "inference_with_reference":
self.proposer = InferenceWithReferenceProposer(
self.model_cfg["speculate_max_draft_token_num"],
self.model_cfg["speculate_max_ngram_size"],
self.args.max_batch_size,
self.args.max_seq_len)
else:
raise NotImplementedError(f'Not support {self.config.speculate_method}, only support inference_with_reference now.')
raise NotImplementedError(f'Not support {self.model_cfg["speculate_method"]}, only support inference_with_reference now.')
else:
self.proposer = None
@@ -278,7 +279,7 @@ class ModelRunner:
self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
# speculate decoding input
if self.config.speculate_method is not None:
if self.is_speculate_decoding:
self.share_inputs["accept_tokens"] = paddle.full(
shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64"
)
@@ -344,16 +345,16 @@ class ModelRunner:
task["stop_seqs_len"], dtype="int32")
self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
task["stop_seqs"], dtype="int64")
if self.proposer is not None:
if self.config.speculate_method == "inference_with_reference":
self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])
if self.is_speculate_decoding:
self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])
def step_cuda(self, seq_lens_this_time):
"""
step cuda
"""
if self.config.speculate_method is None:
if not self.is_speculate_decoding:
step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
self.share_inputs['step_seq_lens_encoder'],
self.share_inputs['seq_lens_encoder'],

View File

@@ -22,9 +22,8 @@ from datetime import datetime
import numpy as np
from paddlenlp_ops import get_output, speculate_get_output
from server.utils import datetime_diff, model_server_logger, monitor_logger
from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
SPECULATE_MAX_BSZ = 256
MAX_DRAFT_TOKEN_NUM = 6
class TokenProcessor(object):
"""
@@ -40,8 +39,9 @@ class TokenProcessor(object):
self.tokens_counter = Counter()
if self.cfg.speculate_method is not None:
self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKEN_NUM + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64")
self.is_speculate_decoding = self.cfg.get_model_config().get("speculate_method") is not None
if self.is_speculate_decoding:
self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64")
else:
self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
self.worker = None
@@ -71,7 +71,7 @@ class TokenProcessor(object):
if self.worker is not None:
raise Exception("Worker is already running!")
if self.cfg.speculate_method is not None:
if self.is_speculate_decoding:
self.worker = threading.Thread(target=self.process_speculate_results, args=())
else:
self.worker = threading.Thread(target=self.process_sampling_results, args=())
@@ -302,7 +302,6 @@ class TokenProcessor(object):
batch post-processing function
"""
tokens = self.output_tokens.numpy()
model_server_logger.info(f"speculate_result tokens: {self.output_tokens.tolist()}")
batch = self.output_tokens[1]
output_token_msg_id = int(self.output_tokens[0])
accept_num = tokens[2 : batch + 2]
@@ -317,7 +316,7 @@ class TokenProcessor(object):
if self.resource_manager.stop_flags[i]:
continue
token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM + accept_num[i]].tolist()
token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i]].tolist()
# 跳过非法token
if len(token_ids) == 0 or token_ids[-1] == 0:
continue