diff --git a/docs/usage/environment_variables.md b/docs/usage/environment_variables.md
index a8f3ac17b..31f895370 100644
--- a/docs/usage/environment_variables.md
+++ b/docs/usage/environment_variables.md
@@ -38,7 +38,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
 
     # Whether to use HuggingFace tokenizer (0 or 1)
     "FD_USE_HF_TOKENIZER":
-    lambda: os.getenv("FD_USE_HF_TOKENIZER", 0),
+    lambda: bool(int(os.getenv("FD_USE_HF_TOKENIZER", 0))),
 
     # ZMQ send high-water mark (HWM) during initialization
     "FD_ZMQ_SNDHWM":
diff --git a/docs/zh/usage/environment_variables.md b/docs/zh/usage/environment_variables.md
index 8037c3362..cda1fc4f0 100644
--- a/docs/zh/usage/environment_variables.md
+++ b/docs/zh/usage/environment_variables.md
@@ -1,4 +1,5 @@
 # FastDeploy 环境变量说明
+
 FastDeploy 的环境变量保存在了代码库根目录下 fastdeploy/envs.py 文件中，以下是其对应的中文版说明：
 
 ```python
@@ -37,7 +38,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
 
     # 是否使用 HuggingFace 分词器
     "FD_USE_HF_TOKENIZER":
-    lambda: os.getenv("FD_USE_HF_TOKENIZER", 0),
+    lambda: bool(int(os.getenv("FD_USE_HF_TOKENIZER", 0))),
 
     # 设置 ZMQ 初始化期间接收数据的高水位标记（HWM）
     "FD_ZMQ_SNDHWM":
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 865a09082..99278f7d1 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -125,6 +125,8 @@ class ModelConfig:
         self.redundant_experts_num = 0
         self.seed = 0
         self.quantization = None
+        self.pad_token_id: int = -1
+        self.eos_tokens_lens: int = 2
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
@@ -258,10 +260,6 @@ class ParallelConfig:
         self.engine_pid: Optional[int] = None
         # Do profile or not
         self.do_profile: bool = False
-        #
-        self.pad_token_id: int = -1
-        #
-        self.eos_tokens_lens: int = 2
         self.max_num_batched_tokens: int = 2048
 
         # splitwise role
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index d9f5beb9c..1c310961c 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -42,7 +42,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # splited by comma, such as 0,1,2.
     "CUDA_VISIBLE_DEVICES": lambda: os.getenv("CUDA_VISIBLE_DEVICES", None),
     # Whether to use HuggingFace tokenizer.
-    "FD_USE_HF_TOKENIZER": lambda: os.getenv("FD_USE_HF_TOKENIZER", 0),
+    "FD_USE_HF_TOKENIZER": lambda: bool(int(os.getenv("FD_USE_HF_TOKENIZER", "0"))),
     # Set the high watermark (HWM) for receiving data during ZMQ initialization
     "FD_ZMQ_SNDHWM": lambda: os.getenv("FD_ZMQ_SNDHWM", 10000),
     # cache kv quant params directory
@@ -61,7 +61,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Whether transition from standalone PD decoupling to centralized inference
     "FD_PD_CHANGEABLE": lambda: os.getenv("FD_PD_CHANGEABLE", "0"),
     # Whether to use fastsafetensor load weight (0 or 1)
-    "FD_USE_FASTSAFETENSOR": lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),
+    "FD_USE_FASTSAFETENSOR": lambda: bool(int(os.getenv("FD_USE_FASTSAFETENSOR", "0"))),
     # Whether to use DeepGemm for FP8 blockwise MoE.
     "FD_USE_DEEP_GEMM": lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
     # Whether to use aggregate send.
diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py
index 63feda934..28d91bdbf 100644
--- a/fastdeploy/input/ernie_processor.py
+++ b/fastdeploy/input/ernie_processor.py
@@ -19,7 +19,6 @@ import os
 import numpy as np
 from paddleformers.generation import GenerationConfig
 
-from fastdeploy import envs
 from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
 from fastdeploy.input.text_processor import BaseDataProcessor
 from fastdeploy.utils import data_processor_logger
@@ -47,25 +46,6 @@ class ErnieProcessor(BaseDataProcessor):
         self.model_name_or_path = model_name_or_path
         data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
 
-        self._init_config()
-
-        self.decode_status = dict()
-        self.thinking_parser_dict = dict()
-        self._load_tokenizer()
-        data_processor_logger.info(
-            f"tokenizer information: bos_token is {self.tokenizer.bos_token} \
-            {self.tokenizer.bos_token_id}, \
-            eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id} "
-        )
-        self.eos_token_ids = [self.tokenizer.eos_token_id]
-        self.eos_token_id_len = len(self.eos_token_ids)
-        self.pad_token_id = self.get_pad_id()
-        self.reasoning_parser = None
-        if reasoning_parser_obj:
-            self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
-
-    def _init_config(self):
-        self.use_hf_tokenizer = int(envs.FD_USE_HF_TOKENIZER) == 1
 
         # Generation config
         try:
@@ -77,6 +57,23 @@ class ErnieProcessor(BaseDataProcessor):
             )
             self.generation_config = None
 
+        self.decode_status = dict()
+        self.thinking_parser_dict = dict()
+        self._load_tokenizer()
+        data_processor_logger.info(
+            f"tokenizer information: bos_token is {self.tokenizer.bos_token} \
+            {self.tokenizer.bos_token_id}, \
+            eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id} "
+        )
+        from paddleformers.trl.llm_utils import get_eos_token_id
+
+        self.eos_token_ids = get_eos_token_id(self.tokenizer, self.generation_config)
+        self.eos_token_id_len = len(self.eos_token_ids)
+        self.pad_token_id = self.get_pad_id()
+        self.reasoning_parser = None
+        if reasoning_parser_obj:
+            self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
+
     def process_request(self, request, max_model_len=None, **kwargs):
         """
         Preprocess the request
diff --git a/fastdeploy/input/ernie_vl_processor.py b/fastdeploy/input/ernie_vl_processor.py
index a2c4dd1e5..63ae5bc31 100644
--- a/fastdeploy/input/ernie_vl_processor.py
+++ b/fastdeploy/input/ernie_vl_processor.py
@@ -14,8 +14,6 @@
 # limitations under the License.
 """
 
-import os
-
 import numpy as np
 from paddleformers.generation import GenerationConfig
 
@@ -35,10 +33,6 @@ class ErnieMoEVLProcessor(ErnieProcessor):
         mm_processor_kwargs=None,
         reasoning_parser_obj=None,
     ):
-        self.use_hf_tokenizer = False
-
-        if "merge_llm_model" in model_name_or_path:
-            model_name_or_path = os.path.dirname(model_name_or_path)
         data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
         tokenizer_path = model_name_or_path
         preprocessor_path = model_name_or_path
@@ -55,13 +49,6 @@ class ErnieMoEVLProcessor(ErnieProcessor):
 
         self.decode_status = dict()
         self._load_tokenizer()
-        self.eos_token_ids = [self.tokenizer.eos_token_id]
-        self.eos_token_id_len = len(self.eos_token_ids)
-        self.pad_token_id = self.get_pad_id()
-        self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
-        self.reasoning_parser = None
-        if reasoning_parser_obj:
-            self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
 
         # Generation config
         try:
@@ -72,6 +59,17 @@ class ErnieMoEVLProcessor(ErnieProcessor):
             )
             self.generation_config = None
 
+        # self.eos_token_ids = [self.tokenizer.eos_token_id]
+        from paddleformers.trl.llm_utils import get_eos_token_id
+
+        self.eos_token_ids = get_eos_token_id(self.tokenizer, self.generation_config)
+        self.eos_token_id_len = len(self.eos_token_ids)
+        self.pad_token_id = self.get_pad_id()
+        self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
+        self.reasoning_parser = None
+        if reasoning_parser_obj:
+            self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
+
     def get_pad_id(self):
         """get pad id"""
         return self.tokenizer.pad_token_id
diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py
index 664868a59..cbaca990c 100644
--- a/fastdeploy/input/text_processor.py
+++ b/fastdeploy/input/text_processor.py
@@ -165,7 +165,14 @@ class DataProcessor(BaseDataProcessor):
 
         self.model_name_or_path = model_name_or_path
 
-        self._init_config()
+        # Generation config
+        try:
+            self.generation_config = GenerationConfig.from_pretrained(self.model_name_or_path)
+        except Exception as e:
+            data_processor_logger.warning(
+                f"Can't find generation config: {e}, so it will not use generation_config field in the model config"
+            )
+            self.generation_config = None
 
         self.decode_status = dict()
         self.tokenizer = self._load_tokenizer()
@@ -184,30 +191,6 @@ class DataProcessor(BaseDataProcessor):
             self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
         self.tokenizer.pad_token_id = self.pad_token_id
 
-    def _init_config(self):
-        """
-        初始化配置，包括模型名称、使用Hugging Face Tokenizer等。
-
-        Args:
-            无参数，但是会从环境变量中获取一些配置信息。
-
-        Returns:
-            无返回值，直接修改了类的属性。
-
-        Raises:
-            无异常抛出。
-        """
-        self.use_hf_tokenizer = int(envs.FD_USE_HF_TOKENIZER) == 1
-
-        # Generation config
-        try:
-            self.generation_config = GenerationConfig.from_pretrained(self.model_name_or_path)
-        except Exception as e:
-            data_processor_logger.warning(
-                f"Can't find generation config: {e}, so it will not use generation_config field in the model config"
-            )
-            self.generation_config = None
-
     def process_request(self, request, max_model_len=None, **kwargs):
         """
         Preprocess the request
@@ -433,7 +416,7 @@ class DataProcessor(BaseDataProcessor):
         Returns:
             List[int]: token ids list
         """
-        if self.use_hf_tokenizer:
+        if envs.FD_USE_HF_TOKENIZER:
             tokens = self.tokenizer(
                 text,
                 return_tensors="np",
@@ -491,7 +474,7 @@ class DataProcessor(BaseDataProcessor):
         Returns:
             List[str]: strings
         """
-        if self.use_hf_tokenizer:
+        if envs.FD_USE_HF_TOKENIZER:
             if task_id not in self.decode_status:
                 # history token ids & history token strings & befer decode str
                 self.decode_status[task_id] = [[], [], ""]
@@ -536,7 +519,7 @@ class DataProcessor(BaseDataProcessor):
         Returns:
             tokenizer (AutoTokenizer)
         """
-        if self.use_hf_tokenizer:
+        if envs.FD_USE_HF_TOKENIZER:
            from transformers import AutoTokenizer

            return AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=False)
@@ -557,7 +540,7 @@ class DataProcessor(BaseDataProcessor):
         """
         results_all = ""
         if task_id in self.decode_status:
-            if self.use_hf_tokenizer:
+            if envs.FD_USE_HF_TOKENIZER:
                results_all = self.decode_status[task_id][2]
            else:
                results_all = "".join(self.decode_status[task_id][3])
diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index 0794e42cf..f622a6e39 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -181,7 +181,8 @@ def post_process_normal(
         )
 
         stop_wo_think = (
-            (sampler_output.sampled_token_ids == model_output.eos_token_id) | (model_output.reasoning_index == 0)
+            (sampler_output.sampled_token_ids == model_output.eos_token_id.T).any(axis=1, keepdim=True)
+            | (model_output.reasoning_index == 0)
         ) & (model_output.need_think_end > 0)
         sampler_output.sampled_token_ids = paddle.where(
             stop_wo_think,
diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py
index b52b35bc4..d7054cf93 100644
--- a/fastdeploy/worker/gcu_model_runner.py
+++ b/fastdeploy/worker/gcu_model_runner.py
@@ -236,8 +236,6 @@ class GCUModelRunner(ModelRunnerBase):
             self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
             self.share_inputs["prompt_lens"][idx : idx + 1] = length
 
-            if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens:
-                request.eos_token_ids.append(request.eos_token_ids[0])
             self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
             self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
             self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
@@ -315,7 +313,9 @@ class GCUModelRunner(ModelRunnerBase):
             idx = i
             self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
-            self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
+            self.share_inputs["eos_token_id"][:] = np.array(
+                [2] * self.model_config.eos_tokens_lens, dtype="int64"
+            ).reshape(-1, 1)
             self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
             self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
@@ -350,15 +350,15 @@ class GCUModelRunner(ModelRunnerBase):
         )
         self.share_inputs["input_ids"] = paddle.full(
             [max_num_seqs, self.parallel_config.max_model_len],
-            self.parallel_config.pad_token_id,
+            self.model_config.pad_token_id,
             dtype="int64",
         )
         self.share_inputs["prompt_ids"] = paddle.full(
             [max_num_seqs, self.parallel_config.max_model_len],
-            self.parallel_config.pad_token_id,
+            self.model_config.pad_token_id,
             dtype="int64",
         )
-        self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
+        self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
         self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
         self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
         self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index c22993adf..9a5895742 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -265,7 +265,11 @@ class GPUModelRunner(ModelRunnerBase):
                 )
 
                 input_ids = request.prompt_token_ids + request.output_token_ids
-                logger.debug(f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}")
+                logger.debug(
+                    f"Handle prefill request {request} at idx {idx}, "
+                    f"{prefill_start_index=}, {prefill_end_index=}, "
+                    f"need_prefilled_token_num={len(input_ids)}"
+                )
                 self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
                     input_ids[prefill_start_index:prefill_end_index]
                 )
@@ -307,8 +311,7 @@ class GPUModelRunner(ModelRunnerBase):
                     self.share_inputs["is_block_step"][idx : idx + 1] = False
                     continue
 
-            if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens:
-                request.eos_token_ids.append(request.eos_token_ids[0])
+            assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
             self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
             self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
 
@@ -471,8 +474,7 @@ class GPUModelRunner(ModelRunnerBase):
             else:
                 return default_value
 
-        if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens:
-            request.eos_token_ids.append(request.eos_token_ids[0])
+        assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
         self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
         self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
         self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
@@ -562,7 +564,9 @@ class GPUModelRunner(ModelRunnerBase):
             idx = i
             self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
-            self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
+            self.share_inputs["eos_token_id"][:] = np.array(
+                [2] * self.model_config.eos_tokens_lens, dtype="int64"
+            ).reshape(-1, 1)
             self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
             self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
@@ -597,15 +601,15 @@ class GPUModelRunner(ModelRunnerBase):
         )
         self.share_inputs["input_ids"] = paddle.full(
             [max_num_seqs, self.parallel_config.max_model_len],
-            self.parallel_config.pad_token_id,
+            self.model_config.pad_token_id,
             dtype="int64",
         )
         self.share_inputs["prompt_ids"] = paddle.full(
             [max_num_seqs, self.parallel_config.max_model_len],
-            self.parallel_config.pad_token_id,
+            self.model_config.pad_token_id,
             dtype="int64",
         )
-        self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
+        self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
         self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
         self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
         self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index 6ea5ff7e5..8c06481de 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -431,8 +431,7 @@ class XPUModelRunner(ModelRunnerBase):
                    self.share_inputs["is_block_step"][idx : idx + 1] = False
                    continue

-            if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens:
-                request.eos_token_ids.append(request.eos_token_ids[0])
+            assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
             self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
             self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
 
@@ -472,8 +471,7 @@ class XPUModelRunner(ModelRunnerBase):
             idx = request.idx
             length = request.prompt_token_ids_len
             self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
-            if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens:
-                request.eos_token_ids.append(request.eos_token_ids[0])
+            assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
             self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
             self.share_inputs["pre_ids"][idx : idx + 1] = -1
             self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
@@ -543,10 +541,10 @@ class XPUModelRunner(ModelRunnerBase):
         )
         self.share_inputs["input_ids"] = paddle.full(
             [max_num_seqs, self.parallel_config.max_model_len],
-            self.parallel_config.pad_token_id,
+            self.model_config.pad_token_id,
             dtype="int64",
         )
-        self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
+        self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
         self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
         self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
         self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
@@ -815,7 +813,9 @@ class XPUModelRunner(ModelRunnerBase):
         for i in range(batch_size):
             idx = i
             self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
-            self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
+            self.share_inputs["eos_token_id"][:] = np.array(
+                [2] * self.model_config.eos_tokens_lens, dtype="int64"
+            ).reshape(-1, 1)
             self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
             self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
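
Note (not part of the patch): the FD_USE_HF_TOKENIZER and FD_USE_FASTSAFETENSOR hunks above replace the raw os.getenv(...) lookup with bool(int(os.getenv(...))). The short standalone Python sketch below illustrates the behavior being fixed; flag_from_env is a made-up helper name for illustration, not FastDeploy API.

    import os

    def flag_from_env(name: str, default: str = "0") -> bool:
        # Same parsing pattern the patch adopts: read the variable as a string,
        # convert to int, then to bool, so "0" maps to False and "1" to True.
        return bool(int(os.getenv(name, default)))

    os.environ["FD_USE_HF_TOKENIZER"] = "0"

    # Raw getenv returns the string "0", and any non-empty string is truthy.
    print(bool(os.getenv("FD_USE_HF_TOKENIZER", 0)))  # True  (the old, misleading behavior)

    # Parsing through int first yields the intended boolean.
    print(flag_from_env("FD_USE_HF_TOKENIZER"))  # False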