diff --git a/docs/online_serving/README.md b/docs/online_serving/README.md
index 761e79720..6378652cd 100644
--- a/docs/online_serving/README.md
+++ b/docs/online_serving/README.md
@@ -45,8 +45,9 @@
 curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
 -H "Content-Type: application/json" \
 -d '{
   "messages": [
-    {"role": "user", "content": "Hello!"}, "logprobs": true, "top_logprobs": 5
-  ]
+    {"role": "user", "content": "Hello!"}
+  ],
+  "logprobs": true, "top_logprobs": 0
 }'
 ```
@@ -190,6 +191,12 @@
 max_streaming_response_tokens: Optional[int] = None
 disable_chat_template: Optional[bool] = False
 # Whether to disable chat template rendering, using raw input directly (default False means template is enabled).
+
+temp_scaled_logprobs: Optional[bool] = False
+# Whether to divide the logits by the temperature when computing logprobs (default False: no temperature scaling).
+
+top_p_normalized_logprobs: Optional[bool] = False
+# Whether to apply top-p normalization when computing logprobs (default False: no top-p normalization).
 ```

 ### Differences in Return Fields
diff --git a/docs/parameters.md b/docs/parameters.md
index ab8361d2a..0d8d3dead 100644
--- a/docs/parameters.md
+++ b/docs/parameters.md
@@ -48,6 +48,7 @@
 | ```enable_logprob``` | `bool` | Whether to enable return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message.If logrpob is not used, this parameter can be omitted when starting |
 | ```tool_call_parser``` | `str` | Specify the function call parser to be used for extracting function call content from the model's output. |
 | ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to be registered, so as to register parsers that are not in the code repository. The code format within these parsers must adhere to the format used in the code repository. |
+| ```lm_head_fp32``` | `bool` | Specify the dtype of the lm_head layer as FP32. |

 ## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?
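For readers following the updated `docs/online_serving/README.md` above: a request that actually exercises the two new switches could look like the sketch below. This is illustrative only; it assumes the server from the docs example is listening on `0.0.0.0:8188` and that the new fields are sent at the top level of the request body, mirroring the `ChatCompletionRequest` fields added later in this patch.

```python
# Illustrative sketch, not part of the patch. Assumes a FastDeploy OpenAI-compatible
# server at 0.0.0.0:8188 (as in the docs example above).
import requests

payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "logprobs": True,
    "top_logprobs": 5,
    # New switches introduced by this change (both default to False):
    "temp_scaled_logprobs": True,       # divide logits by the request temperature before log_softmax
    "top_p_normalized_logprobs": True,  # renormalize probabilities within the top-p nucleus
}

resp = requests.post("http://0.0.0.0:8188/v1/chat/completions", json=payload)
print(resp.json())
```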
diff --git a/docs/zh/online_serving/README.md b/docs/zh/online_serving/README.md
index a68eedbdb..06c45efdc 100644
--- a/docs/zh/online_serving/README.md
+++ b/docs/zh/online_serving/README.md
@@ -45,8 +45,9 @@
 curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
 -H "Content-Type: application/json" \
 -d '{
   "messages": [
-    {"role": "user", "content": "Hello!"}, "logprobs": true, "top_logprobs": 5
-  ]
+    {"role": "user", "content": "Hello!"}
+  ],
+  "logprobs": true, "top_logprobs": 0
 }'
 ```
@@ -189,6 +190,12 @@
 max_streaming_response_tokens: Optional[int] = None
 disable_chat_template: Optional[bool] = False
 # 是否禁用聊天模板渲染,直接使用原始输入(默认 False 表示启用模板)。
+
+temp_scaled_logprobs: Optional[bool] = False
+# 计算 logprobs 时是否先将 logits 除以温度系数(默认 False,即不做温度缩放)。
+
+top_p_normalized_logprobs: Optional[bool] = False
+# 计算 logprobs 时是否进行 top_p 归一化(默认 False,即不做 top_p 归一化)。
 ```

 ### 返回字段差异
diff --git a/docs/zh/parameters.md b/docs/zh/parameters.md
index 7ed2ea4d5..a83c8a7ef 100644
--- a/docs/zh/parameters.md
+++ b/docs/zh/parameters.md
@@ -46,6 +46,7 @@
 | ```enable_logprob``` | `bool` | 是否启用输出token返回logprob。如果未使用 logrpob,则在启动时可以省略此参数。 |
 | ```tool_call_parser``` | `str` | 指定要使用的function call解析器,以便从模型输出中抽取 function call内容|
 | ```tool_parser_plugin``` | `str` | 指定要注册的tool parser文件路径,以便注册不在代码库中的parser,parser中代码格式需遵循代码库中格式|
+| ```lm_head_fp32``` | `bool` | 指定 lm_head 层的权重类型为 FP32 |

 ## 1. KVCache分配与```num_gpu_blocks_override```、```block_size```的关系?
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 5e97c07b4..d8a1f288f 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -119,6 +119,7 @@ class ModelConfig:
         self.redundant_experts_num = 0
         self.quantization = None
         self.think_end_id = None
+        self.lm_head_fp32 = False
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 054077c13..9f47b7a05 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -344,6 +344,11 @@ class EngineArgs:
     - "new_loader": new loader.
     """

+    lm_head_fp32: bool = False
+    """
+    Flag to specify the dtype of the lm_head layer as FP32. Default is False (use the model's default dtype).
+    """
+
     def __post_init__(self):
         """
         Post-initialization processing to set default tokenizer if not provided.
@@ -519,6 +524,12 @@ class EngineArgs:
             default=EngineArgs.early_stop_config,
             help="the config for early stop.",
         )
+        model_group.add_argument(
+            "--lm_head_fp32",
+            action="store_true",
+            default=EngineArgs.lm_head_fp32,
+            help="Specify the dtype of the lm_head weight as float32.",
+        )

         # Parallel processing parameters group
         parallel_group = parser.add_argument_group("Parallel Configuration")
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 7196bdc0b..7648f9bfc 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -1139,6 +1139,7 @@ class LLMEngine:
             "enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce,
             "enable_logprob": self.cfg.enable_logprob,
             "enable_mm": self.cfg.enable_mm,
+            "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
         }
         for worker_flag, value in worker_append_flag.items():
             if value:
diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py
index 1cd77d2b1..f95f09bd5 100644
--- a/fastdeploy/engine/sampling_params.py
+++ b/fastdeploy/engine/sampling_params.py
@@ -98,6 +98,9 @@ class SamplingParams:
     reasoning_max_tokens: Optional[int] = None
     min_tokens: int = 1
     logprobs: Optional[int] = None
+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    top_p_normalized_logprobs: bool = False
     bad_words: Optional[List[str]] = None
     _bad_words_token_ids: Optional[List[int]] = None
diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index 1e4444124..ed71f5676 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -371,6 +371,9 @@ class CompletionRequest(BaseModel):
     echo: Optional[bool] = False
     frequency_penalty: Optional[float] = None
     logprobs: Optional[int] = None
+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    top_p_normalized_logprobs: bool = False
     max_tokens: Optional[int] = None
     n: int = 1
     presence_penalty: Optional[float] = None
@@ -502,6 +505,11 @@ class ChatCompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = 0
+
+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    top_p_normalized_logprobs: bool = False
+
     # remove max_tokens when field is removed from OpenAI API
     max_tokens: Optional[int] = Field(
         default=None,
@@ -558,6 +566,8 @@ class ChatCompletionRequest(BaseModel):
             req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens

         req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
+        req_dict["temp_scaled_logprobs"] = self.temp_scaled_logprobs
+        req_dict["top_p_normalized_logprobs"] = self.top_p_normalized_logprobs

         # parse request model into dict, priority: request params > metadata params
         if self.metadata is not None:
diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py
index 5c1fd3c15..32f617160 100644
--- a/fastdeploy/model_executor/layers/lm_head.py
+++ b/fastdeploy/model_executor/layers/lm_head.py
@@ -22,7 +22,7 @@ from paddle import nn
 from paddle.distributed import fleet

 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.models.utils import set_weight_attrs
+from fastdeploy.model_executor.models.utils import set_weight_attrs, temporary_dtype

 from .utils import get_tensor

@@ -39,6 +39,7 @@ class ParallelLMHead(nn.Layer):
         embedding_dim: int,
         prefix: str = "",
         with_bias: bool = False,
+        dtype: str = None,
     ) -> None:
         """
         Parallelized LMhead.
@@ -51,6 +52,7 @@
             embedding_dim (int): size of hidden state.
             prefix (str): The name of current layer. Defaults to "".
             with_bias (bool): whether to have bias. Default: False.
+            dtype (str): The dtype of the weight. Default: None.
         """
         super(ParallelLMHead, self).__init__()
         self.weight_key: str = prefix + ".weight"
@@ -63,39 +65,40 @@
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear

+        self.dtype = "float32" if fd_config.model_config.lm_head_fp32 else dtype
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
-
-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
-                dtype=paddle.get_default_dtype(),
-                is_bias=False,
-            )
-        else:
-            if self.column_cut:
-                need_gather = True
-                self.linear = ColumnParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
-                    fuse_matmul_bias=False,  # False diff更小
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": True})
-            else:
-                self.linear = RowParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,  # False diff更小
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
+        with temporary_dtype(self.dtype):
+            if self.use_ep:
+                self.weight = self.create_parameter(
+                    shape=[embedding_dim, num_embeddings],
+                    dtype=paddle.get_default_dtype(),
+                    is_bias=False,
+                )
+            else:
+                if self.column_cut:
+                    need_gather = True
+                    self.linear = ColumnParallelLinear(
+                        embedding_dim,
+                        num_embeddings,
+                        mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
+                        weight_attr=None,
+                        has_bias=True if self.bias_key is not None else False,
+                        gather_output=need_gather,
+                        fuse_matmul_bias=False,  # False diff更小
+                    )
+                    set_weight_attrs(self.linear.weight, {"output_dim": True})
+                else:
+                    self.linear = RowParallelLinear(
+                        embedding_dim,
+                        num_embeddings,
+                        mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
+                        weight_attr=None,
+                        has_bias=True if self.bias_key is not None else False,
+                        input_is_parallel=False,
+                        fuse_matmul_bias=False,  # False diff更小
+                    )
+                    set_weight_attrs(self.linear.weight, {"output_dim": False})

     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
@@ -106,20 +109,20 @@
         """

         if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
+            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(self.weight.dtype))
         else:
             if self.tie_word_embeddings:
                 self.linear.weight.set_value(
-                    get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
+                    get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype).transpose([1, 0])
                 )
             else:
-                weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+                weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype)
                 if self.linear.weight.shape != weight_tensor.shape:
                     weight_tensor = weight_tensor.transpose([1, 0])
                 self.linear.weight.set_value(weight_tensor)

             if self.bias_key is not None:
-                bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
+                bias = get_tensor(state_dict.pop(self.bias_key)).astype(self.linear.weight.dtype)
                 self.linear.bias.set_value(bias)

     def forward(self, input: paddle.Tensor) -> paddle.Tensor:
         """
@@ -134,7 +137,7 @@
         """
         logits = input
         if self.use_ep:
-            logits = paddle.matmul(logits, self.weight)
+            logits = paddle.matmul(logits.astype(self.weight.dtype), self.weight)
         else:
-            logits = self.linear(logits)
+            logits = self.linear(logits.astype(self.linear.weight.dtype))
         return logits
diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py
index 9cca5af27..9d9d52efa 100644
--- a/fastdeploy/model_executor/layers/sample/meta_data.py
+++ b/fastdeploy/model_executor/layers/sample/meta_data.py
@@ -15,7 +15,7 @@
 """

 from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, Optional

 import paddle

@@ -48,3 +48,6 @@ class SamplingMetadata:
     stop_flags: Optional[paddle.Tensor] = None
     prompt_ids: Optional[paddle.Tensor] = None
     prompt_lens: Optional[paddle.Tensor] = None
+    temp_scaled_logprobs: Optional[paddle.Tensor] = None
+    top_p_normalized_logprobs: Optional[paddle.Tensor] = None
+    share_inputs: Optional[Dict[str, paddle.Tensor]] = None
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index 412a7eda7..d67339aef 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -40,6 +40,18 @@ from fastdeploy.platforms import current_platform
 from fastdeploy.worker.output import LogprobsTensors, SamplerOutput


+def top_p_normalize_probs_paddle(
+    probs: paddle.Tensor,
+    top_ps: paddle.Tensor,
+):
+    probs_idx = probs.argsort(axis=-1, descending=True)
+    probs_sort = paddle.take_along_axis(probs, probs_idx, axis=-1)
+    probs_sum = paddle.cumsum(probs_sort, axis=-1)
+    probs_sort = paddle.where((probs_sum - probs_sort) > top_ps, paddle.zeros_like(probs_sort), probs_sort)
+    probs_sort.divide_(probs_sort.sum(axis=-1, keepdim=True))
+    return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
+
+
 class SamplerProcessor:
     """
     SamplingProcessor for guided decoding.
@@ -206,9 +218,45 @@ class Sampler(nn.Layer):
         """pre process before running"""
         self.processor.pre_process(skip_idx_list)

-    def compute_logprobs(self, logits: paddle.Tensor) -> paddle.Tensor:
+    def compute_logprobs(
+        self,
+        logits: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> paddle.Tensor:
         """ """
-        return F.log_softmax(logits, axis=-1)
+        last_logits = logits
+        real_bsz = last_logits.shape[0]
+        temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
+        top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
+        share_inputs = sampling_metadata.share_inputs
+        if temp_scaled_logprobs is not None:
+            real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
+            temperature = sampling_metadata.temperature[:real_bsz]
+            temp_temperature = paddle.where(real_bsz_temp_scaled, temperature, paddle.ones_like(temperature))
+            last_logits = last_logits / temp_temperature
+
+        last_logprobs = F.log_softmax(last_logits, axis=-1)
+        top_p_logprob = None
+        top_p_req_mask = None
+
+        if top_p_normalized_logprobs is not None and share_inputs is not None:
+            seq_lens_this_time = share_inputs["seq_lens_this_time"].reshape([-1, 1])[:real_bsz]
+            seq_lens_encoder = share_inputs["seq_lens_encoder"].reshape([-1, 1])[:real_bsz]
+            seq_lens_decoder = share_inputs["seq_lens_decoder"].reshape([-1, 1])[:real_bsz]
+            seq_lens_time_sum = seq_lens_this_time + seq_lens_encoder + seq_lens_decoder
+            real_req_mask = seq_lens_time_sum > 0
+            top_p_req_mask = paddle.logical_and(top_p_normalized_logprobs[:real_bsz], real_req_mask)
+            real_req_top_p = sampling_metadata.top_p[:real_bsz]
+            # Normalize logprobs if top_p normalization is enabled
+            # NOTE: only normalize logprobs when top_p is set and not equal to 1.0
+            top_p_req_mask = paddle.logical_and(top_p_req_mask, real_req_top_p != 1.0)
+            if top_p_req_mask.any():
+                probs = F.softmax(last_logits, axis=-1)
+                probs = top_p_normalize_probs_paddle(probs, real_req_top_p)
+                top_p_logprob = paddle.log(probs)
+        if top_p_logprob is not None:
+            last_logprobs = paddle.where(top_p_req_mask, top_p_logprob, last_logprobs)
+        return last_logprobs

     def gather_logprobs(
         self,
@@ -233,6 +281,7 @@ class Sampler(nn.Layer):
             Sampled token rank tensor, (num tokens)
         """
         assert token_ids.dtype == paddle.int64
+        logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)

         # Get with the logprob of the prompt or sampled token.
         token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
@@ -259,7 +308,7 @@ class Sampler(nn.Layer):
         """ """
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
-            raw_logprobs = self.compute_logprobs(logits)
+            raw_logprobs = self.compute_logprobs(logits, sampling_metadata)

         logits = self.processor.apply_token_mask(logits, skip_idx_list)
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 8cbd4a0bd..d695d5a22 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -621,7 +621,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits
diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py
index 460170b7d..3d9695a18 100644
--- a/fastdeploy/model_executor/models/ernie4_5_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -412,13 +412,13 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
         """
         self.ernie.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)

     def compute_logits(self, hidden_states: paddle.Tensor):
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits
diff --git a/fastdeploy/model_executor/models/ernie4_5_mtp.py b/fastdeploy/model_executor/models/ernie4_5_mtp.py
index b52d8ed71..6e8e83603 100644
--- a/fastdeploy/model_executor/models/ernie4_5_mtp.py
+++ b/fastdeploy/model_executor/models/ernie4_5_mtp.py
@@ -363,7 +363,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
         compute logits
         """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits
diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
index 2dd562135..fe303ee2d 100644
--- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
@@ -570,13 +570,13 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
         self.vision_model.load_state_dict(state_dict)
         self.resampler_model.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)

     def compute_logits(self, hidden_states: paddle.Tensor):
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits
diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py
index af2af00b1..acf74e8e1 100644
--- a/fastdeploy/model_executor/models/qwen2.py
+++ b/fastdeploy/model_executor/models/qwen2.py
@@ -326,7 +326,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits
diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py
index 5aa00bfa9..4c2314950 100644
--- a/fastdeploy/model_executor/models/qwen3.py
+++ b/fastdeploy/model_executor/models/qwen3.py
@@ -285,6 +285,8 @@ class Qwen3ForCausalLM(ModelForCasualLM):
             param = params_dict[loaded_weight_name]
             weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
             weight_loader(param, loaded_weight)
+        if self.tie_word_embeddings:
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})

     @paddle.no_grad()
     def set_state_dict(self, state_dict):
@@ -298,14 +300,14 @@ class Qwen3ForCausalLM(ModelForCasualLM):
         """
         self.model.load_state_dict(state_dict)
         if self.tie_word_embeddings:
-            self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
+            self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})
         else:
             self.lm_head.load_state_dict(state_dict)

     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits
diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py
index 7064ceafc..b63b0ad7e 100644
--- a/fastdeploy/model_executor/models/qwen3moe.py
+++ b/fastdeploy/model_executor/models/qwen3moe.py
@@ -295,7 +295,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
-        logits = paddle.cast(logits, paddle.float32)
+        logits = logits.astype(paddle.float32)
         logits[:, self.ori_vocab_size :] = -float("inf")

         return logits
diff --git a/fastdeploy/model_executor/models/utils.py b/fastdeploy/model_executor/models/utils.py
index 48da4736f..243a7cda4 100644
--- a/fastdeploy/model_executor/models/utils.py
+++ b/fastdeploy/model_executor/models/utils.py
@@ -23,6 +23,7 @@
 import os
 import random
 import re
 import struct
+from contextlib import contextmanager
 from functools import partial
 from typing import Any, NamedTuple, Optional, Union

@@ -533,3 +534,15 @@ def parser_quant_type(quant_type):
             quant_type_list.append(default_type)

     return quant_type_list[0], quant_type_list[1], quant_type_list[2]
+
+
+@contextmanager
+def temporary_dtype(dtype: str):
+    """Temporarily set the Paddle default dtype."""
+    orig_dtype = paddle.get_default_dtype()
+    try:
+        if dtype is not None and dtype == "float32":
+            paddle.set_default_dtype(dtype)
+        yield
+    finally:
+        paddle.set_default_dtype(orig_dtype)
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 05f3e83dd..7cf5434ad 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -315,6 +315,10 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
         self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
request.get("frequency_penalty", 0.0) self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0) + self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = request.get("temp_scaled_logprobs", False) + self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = request.get( + "top_p_normalized_logprobs", False + ) self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( @@ -493,6 +497,12 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request( request, "presence_penalty", 0.0 ) + self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request( + request, "temp_scaled_logprobs", False + ) + self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = get_attr_from_request( + request, "top_p_normalized_logprobs", False + ) self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( @@ -622,6 +632,8 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["presence_score"] = paddle.full( [max_num_seqs, 1], self.model_config.presence_score, dtype="float32" ) + self.share_inputs["temp_scaled_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool") + self.share_inputs["top_p_normalized_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool") self.share_inputs["min_dec_len"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64") self.share_inputs["max_dec_len"] = paddle.full( @@ -841,6 +853,9 @@ class GPUModelRunner(ModelRunnerBase): max_num_logprobs=20 if self.enable_logprob else None, enable_early_stop=self.enable_early_stop, stop_flags=self.share_inputs["stop_flags"], + temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"], + top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"], + share_inputs=self.share_inputs, ) def load_model(self) -> None: diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 828cdfd14..527316349 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -624,6 +624,12 @@ def parse_args(): help="The format of the model weights to load. default/new_loader.", ) + parser.add_argument( + "--lm_head_fp32", + action="store_true", + help="Flag to specify dtype of lm_head as FP32", + ) + args = parser.parse_args() return args