[CP] CP Lm head fp32 and temp_logprob to release/2.1 (#3766)

* [Feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing (#3552)

* [feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing

* infer engine support temp_scaled_logprobs and top_p_normalized_logprobs

* delete some code

* code check

* code check and add doc

* fix tokenizer.decode(-1) to return 'Invalid Token'

* add ci for temp_scaled and top_p logprobs

* check test

* check seq len time shape

* logprob clip inf

---------

Co-authored-by: sunlei1024 <sunlei5788@gmail.com>

* [Precision] Support lm_head layer running in float32 (#3597)

* support lm_head fp32 bf16 fp16

* support lm_head fp32 bf16 fp16

* add doc and check code

* lm_head_fp32 specify lm_head as fp32

* code check

* check doc

* code check

---------

Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
Authored by chen on 2025-09-01 19:56:54 +08:00, committed by GitHub
parent 4da603daec
commit 1e19833ba5
22 changed files with 188 additions and 54 deletions

View File

@@ -45,8 +45,9 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "Hello!"}, "logprobs": true, "top_logprobs": 5
]
{"role": "user", "content": "Hello!"}
],
"logprobs": true, "top_logprobs": 0,
}'
```
@@ -190,6 +191,12 @@ max_streaming_response_tokens: Optional[int] = None
disable_chat_template: Optional[bool] = False
# Whether to disable chat template rendering, using raw input directly (default False means template is enabled).
temp_scaled_logprobs: Optional[bool] = False
# Whether to divide the logits by the temperature coefficient when calculating logprobs (default is False, meaning the logits are not divided by the temperature coefficient).
top_p_normalized_logprobs: Optional[bool] = False
# Whether to perform top-p normalization when calculating logprobs (default is False, indicating that top-p normalization is not performed).
```
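For reference, a request sketch that exercises the new fields through the OpenAI-compatible endpoint from the curl example above; the model name is a placeholder, and passing the two new fields via `extra_body` is an assumption, not something confirmed by this diff.

```python
from openai import OpenAI

# Endpoint and port taken from the curl example above; model name is a placeholder.
client = OpenAI(base_url="http://0.0.0.0:8188/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello!"}],
    logprobs=True,
    top_logprobs=5,
    # Non-standard fields are forwarded in the request body via extra_body.
    extra_body={
        "temp_scaled_logprobs": True,
        "top_p_normalized_logprobs": True,
    },
)
print(resp.choices[0].logprobs)
```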
### Differences in Return Fields

View File

@@ -48,6 +48,7 @@ When using FastDeploy to deploy models (including offline inference and service
| ```enable_logprob``` | `bool` | Whether to return log probabilities of the output tokens. If true, the log probabilities of each output token are returned in the message content. If logprobs are not used, this parameter can be omitted at startup. |
| ```tool_call_parser``` | `str` | Specify the function call parser to be used for extracting function call content from the model's output. |
| ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to be registered, so as to register parsers that are not in the code repository. The code format within these parsers must adhere to the format used in the code repository. |
| ```lm_head_fp32``` | `bool` | Specify the dtype of the lm_head layer as FP32. |
## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?

View File

@@ -45,8 +45,9 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "Hello!"}, "logprobs": true, "top_logprobs": 5
]
{"role": "user", "content": "Hello!"}
],
"logprobs": true, "top_logprobs": 0,
}'
```
@@ -189,6 +190,12 @@ max_streaming_response_tokens: Optional[int] = None
disable_chat_template: Optional[bool] = False
# Whether to disable chat template rendering and use the raw input directly (default False means the template is enabled).
temp_scaled_logprobs: Optional[bool] = False
# Whether to divide the logits by the temperature coefficient when computing logprobs (default False means the logits are not divided by the temperature).
top_p_normalized_logprobs: Optional[bool] = False
# Whether to apply top-p normalization when computing logprobs (default False means no top-p normalization is performed).
```
### Differences in Return Fields

View File

@@ -46,6 +46,7 @@
| ```enable_logprob``` | `bool` | Whether to return logprobs for the output tokens. If logprobs are not used, this parameter can be omitted at startup. |
| ```tool_call_parser``` | `str` | Specify the function call parser used to extract function call content from the model's output. |
| ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to be registered, so that parsers outside the code repository can be registered; the code in these parsers must follow the format used in the repository. |
| ```lm_head_fp32``` | `bool` | Specify the dtype of the lm_head layer as FP32. |
## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?

View File

@@ -119,6 +119,7 @@ class ModelConfig:
self.redundant_experts_num = 0
self.quantization = None
self.think_end_id = None
self.lm_head_fp32 = False
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)

View File

@@ -344,6 +344,11 @@ class EngineArgs:
- "new_loader": new loader.
"""
lm_head_fp32: bool = False
"""
Flag to specify the dtype of lm_head as FP32. Default is False (the model's default dtype is used).
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -519,6 +524,12 @@ class EngineArgs:
default=EngineArgs.early_stop_config,
help="the config for early stop.",
)
model_group.add_argument(
"--lm_head-fp32",
action="store_true",
default=EngineArgs.lm_head_fp32,
help="Specify the dtype of lm_head weight as float32.",
)
# Parallel processing parameters group
parallel_group = parser.add_argument_group("Parallel Configuration")
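For reference, a minimal sketch of enabling the new option; the `fastdeploy.LLM` offline entrypoint and the model name are assumptions, only the field name comes from this commit. When serving, the equivalent switch is the `--lm_head-fp32` flag registered above.

```python
from fastdeploy import LLM

# Assumption: lm_head_fp32=True is forwarded to EngineArgs/ModelConfig so the
# lm_head weights are created in float32 while the rest of the model keeps its default dtype.
llm = LLM(model="baidu/ERNIE-4.5-0.3B-Paddle", lm_head_fp32=True)
```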

View File

@@ -1139,6 +1139,7 @@ class LLMEngine:
"enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce,
"enable_logprob": self.cfg.enable_logprob,
"enable_mm": self.cfg.enable_mm,
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
}
for worker_flag, value in worker_append_flag.items():
if value:

View File

@@ -98,6 +98,9 @@ class SamplingParams:
reasoning_max_tokens: Optional[int] = None
min_tokens: int = 1
logprobs: Optional[int] = None
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
bad_words: Optional[List[str]] = None
_bad_words_token_ids: Optional[List[int]] = None
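For reference, a minimal offline-inference sketch using the new `SamplingParams` fields above; the `fastdeploy.LLM` usage and the model name are assumptions:

```python
from fastdeploy import LLM, SamplingParams

sampling = SamplingParams(
    temperature=0.7,
    top_p=0.8,
    logprobs=5,                      # return top-5 logprobs per sampled token
    temp_scaled_logprobs=True,       # divide logits by temperature before log_softmax
    top_p_normalized_logprobs=True,  # renormalize probabilities inside the top-p nucleus
)
llm = LLM(model="baidu/ERNIE-4.5-0.3B-Paddle", enable_logprob=True)
outputs = llm.generate(["Hello!"], sampling)
print(outputs[0])
```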

View File

@@ -371,6 +371,9 @@ class CompletionRequest(BaseModel):
echo: Optional[bool] = False
frequency_penalty: Optional[float] = None
logprobs: Optional[int] = None
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
max_tokens: Optional[int] = None
n: int = 1
presence_penalty: Optional[float] = None
@@ -502,6 +505,11 @@ class ChatCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
# remove max_tokens when field is removed from OpenAI API
max_tokens: Optional[int] = Field(
default=None,
@@ -558,6 +566,8 @@ class ChatCompletionRequest(BaseModel):
req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
req_dict["temp_scaled_logprobs"] = self.temp_scaled_logprobs
req_dict["top_p_normalized_logprobs"] = self.top_p_normalized_logprobs
# parse request model into dict, priority: request params > metadata params
if self.metadata is not None:

View File

@@ -22,7 +22,7 @@ from paddle import nn
from paddle.distributed import fleet
from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.models.utils import set_weight_attrs
+from fastdeploy.model_executor.models.utils import set_weight_attrs, temporary_dtype
from .utils import get_tensor
@@ -39,6 +39,7 @@ class ParallelLMHead(nn.Layer):
embedding_dim: int,
prefix: str = "",
with_bias: bool = False,
dtype: str = None,
) -> None:
"""
Parallelized LMhead.
@@ -51,6 +52,7 @@ class ParallelLMHead(nn.Layer):
embedding_dim (int): size of hidden state.
prefix (str): The name of current layer. Defaults to "".
with_bias (bool): whether to have bias. Default: False.
dtype (str): The dtype of weight. Default: None.
"""
super(ParallelLMHead, self).__init__()
self.weight_key: str = prefix + ".weight"
@@ -63,9 +65,10 @@ class ParallelLMHead(nn.Layer):
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
self.dtype = "float32" if fd_config.model_config.lm_head_fp32 else dtype
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
with temporary_dtype(self.dtype):
if self.use_ep:
self.weight = self.create_parameter(
shape=[embedding_dim, num_embeddings],
@@ -106,20 +109,20 @@ class ParallelLMHead(nn.Layer):
"""
if self.use_ep:
-self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
+self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(self.weight.dtype))
else:
if self.tie_word_embeddings:
self.linear.weight.set_value(
-get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
+get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype).transpose([1, 0])
)
else:
-weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(self.linear.weight.dtype)
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.linear.weight.set_value(weight_tensor)
if self.bias_key is not None:
-bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
+bias = get_tensor(state_dict.pop(self.bias_key)).astype(self.linear.weight.dtype)
self.linear.bias.set_value(bias)
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
@@ -134,7 +137,8 @@ class ParallelLMHead(nn.Layer):
"""
logits = input
if self.use_ep:
-logits = paddle.matmul(logits, self.weight)
+logits = paddle.matmul(logits.astype(self.weight.dtype), self.weight)
else:
-logits = self.linear(logits)
+logits = self.linear(logits.astype(self.linear.weight.dtype))
+print(self.linear.weight.dtype)
return logits

View File

@@ -15,7 +15,7 @@
"""
from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, Optional
import paddle
@@ -48,3 +48,6 @@ class SamplingMetadata:
stop_flags: Optional[paddle.Tensor] = None
prompt_ids: Optional[paddle.Tensor] = None
prompt_lens: Optional[paddle.Tensor] = None
temp_scaled_logprobs: Optional[paddle.Tensor] = None
top_p_normalized_logprobs: Optional[paddle.Tensor] = None
share_inputs: Optional[Dict[str, paddle.Tensor]] = None

View File

@@ -40,6 +40,18 @@ from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
def top_p_normalize_probs_paddle(
probs: paddle.Tensor,
top_ps: paddle.Tensor,
):
probs_idx = probs.argsort(axis=-1, descending=True)
probs_sort = paddle.take_along_axis(probs, probs_idx, axis=-1)
probs_sum = paddle.cumsum(probs_sort, axis=-1)
probs_sort = paddle.where((probs_sum - probs_sort) > top_ps, paddle.zeros_like(probs_sort), probs_sort)
probs_sort.divide_(probs_sort.sum(axis=-1, keepdim=True))
return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
class SamplerProcessor:
"""
SamplingProcessor for guided decoding.
@@ -206,9 +218,45 @@ class Sampler(nn.Layer):
"""pre process before running"""
self.processor.pre_process(skip_idx_list)
-def compute_logprobs(self, logits: paddle.Tensor) -> paddle.Tensor:
+def compute_logprobs(
+    self,
+    logits: paddle.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> paddle.Tensor:
""" """
-return F.log_softmax(logits, axis=-1)
last_logits = logits
real_bsz = last_logits.shape[0]
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
share_inputs = sampling_metadata.share_inputs
if temp_scaled_logprobs is not None:
real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
temperature = sampling_metadata.temperature[:real_bsz]
temp_temperature = paddle.where(real_bsz_temp_scaled, temperature, paddle.ones_like(temperature))
last_logits = last_logits / temp_temperature
last_logprobs = F.log_softmax(last_logits, axis=-1)
top_p_logprob = None
top_p_req_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None:
seq_lens_this_time = share_inputs["seq_lens_this_time"].reshape([-1, 1])[:real_bsz]
seq_lens_encoder = share_inputs["seq_lens_encoder"].reshape([-1, 1])[:real_bsz]
seq_lens_decoder = share_inputs["seq_lens_decoder"].reshape([-1, 1])[:real_bsz]
seq_lens_time_sum = seq_lens_this_time + seq_lens_encoder + seq_lens_decoder
real_req_mask = seq_lens_time_sum > 0
top_p_req_mask = paddle.logical_and(top_p_normalized_logprobs[:real_bsz], real_req_mask)
real_req_top_p = sampling_metadata.top_p[:real_bsz]
# Normalize logprobs if top_p normalization is enabled
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
top_p_req_mask = paddle.logical_and(top_p_req_mask, real_req_top_p != 1.0)
if top_p_req_mask.any():
probs = F.softmax(last_logits, axis=-1)
probs = top_p_normalize_probs_paddle(probs, real_req_top_p)
top_p_logprob = paddle.log(probs)
if top_p_logprob is not None:
last_logprobs = paddle.where(top_p_req_mask, top_p_logprob, last_logprobs)
return last_logprobs
def gather_logprobs(
self,
@@ -233,6 +281,7 @@ class Sampler(nn.Layer):
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == paddle.int64
logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
# Get with the logprob of the prompt or sampled token.
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
@@ -259,7 +308,7 @@ class Sampler(nn.Layer):
""" """
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
-raw_logprobs = self.compute_logprobs(logits)
+raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
logits = self.processor.apply_token_mask(logits, skip_idx_list)
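To summarize the semantics added above: `temp_scaled_logprobs` divides the logits by the per-request temperature before `log_softmax`, and `top_p_normalized_logprobs` renormalizes the probability mass inside the top-p nucleus before taking the log. A standalone NumPy sketch of the same rule (illustrative only, not FastDeploy code):

```python
import numpy as np

def post_process_logprobs(logits, temperature=1.0, top_p=1.0,
                          temp_scaled=False, top_p_normalized=False):
    z = np.asarray(logits, dtype=np.float64)
    if temp_scaled:
        z = z / temperature                      # temperature-scaled logits
    probs = np.exp(z - z.max())
    probs /= probs.sum()                         # plain softmax
    if top_p_normalized and top_p != 1.0:
        order = np.argsort(-probs)               # sort descending, as in the kernel above
        csum = np.cumsum(probs[order])
        keep = np.zeros_like(probs, dtype=bool)
        # Same rule as top_p_normalize_probs_paddle: drop a token once the
        # probability mass strictly before it already exceeds top_p.
        keep[order] = (csum - probs[order]) <= top_p
        probs = np.where(keep, probs, 0.0)
        probs /= probs.sum()                     # renormalize inside the nucleus
    return np.log(np.clip(probs, np.finfo(np.float64).tiny, None))

print(post_process_logprobs([2.0, 1.0, 0.5, -1.0],
                            temperature=0.7, top_p=0.9,
                            temp_scaled=True, top_p_normalized=True))
```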

View File

@@ -621,7 +621,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
def compute_logits(self, hidden_states: paddle.Tensor):
""" """
logits = self.lm_head(hidden_states)
-logits = paddle.cast(logits, paddle.float32)
+logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -412,13 +412,13 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
"""
self.ernie.load_state_dict(state_dict)
if self.tie_word_embeddings:
-self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
else:
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
logits = self.lm_head(hidden_states)
-logits = paddle.cast(logits, paddle.float32)
+logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -363,7 +363,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
compute logits
"""
logits = self.lm_head(hidden_states)
-logits = paddle.cast(logits, paddle.float32)
+logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -570,13 +570,13 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
self.vision_model.load_state_dict(state_dict)
self.resampler_model.load_state_dict(state_dict)
if self.tie_word_embeddings:
-self.lm_head.linear.weight.set_value(self.ernie.embed_tokens.embeddings.weight.transpose([1, 0]))
+self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
else:
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
logits = self.lm_head(hidden_states)
-logits = paddle.cast(logits, paddle.float32)
+logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -326,7 +326,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
def compute_logits(self, hidden_states: paddle.Tensor):
""" """
logits = self.lm_head(hidden_states)
-logits = paddle.cast(logits, paddle.float32)
+logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -285,6 +285,8 @@ class Qwen3ForCausalLM(ModelForCasualLM):
param = params_dict[loaded_weight_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
if self.tie_word_embeddings:
self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})
@paddle.no_grad()
def set_state_dict(self, state_dict):
@@ -298,14 +300,14 @@ class Qwen3ForCausalLM(ModelForCasualLM):
"""
self.model.load_state_dict(state_dict)
if self.tie_word_embeddings:
-self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0]))
+self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight})
else:
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
""" """
logits = self.lm_head(hidden_states)
-logits = paddle.cast(logits, paddle.float32)
+logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -295,7 +295,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
def compute_logits(self, hidden_states: paddle.Tensor):
""" """
logits = self.lm_head(hidden_states)
-logits = paddle.cast(logits, paddle.float32)
+logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits

View File

@@ -23,6 +23,7 @@ import os
import random
import re
import struct
from contextlib import contextmanager
from functools import partial
from typing import Any, NamedTuple, Optional, Union
@@ -533,3 +534,15 @@ def parser_quant_type(quant_type):
quant_type_list.append(default_type)
return quant_type_list[0], quant_type_list[1], quant_type_list[2]
@contextmanager
def temporary_dtype(dtype: str):
"""Temporarily set Paddle default dtype"""
orig_dtype = paddle.get_default_dtype()
try:
if dtype is not None and dtype == "float32":
paddle.set_default_dtype(dtype)
yield
finally:
paddle.set_default_dtype(orig_dtype)
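For reference, a minimal sketch of how this context manager is meant to be used; the import path matches the lm_head change above, the rest is illustrative:

```python
import paddle
from fastdeploy.model_executor.models.utils import temporary_dtype

paddle.set_default_dtype("float16")             # stand-in for the model's runtime dtype
with temporary_dtype("float32"):
    # Parameters created inside the block (e.g. the lm_head weight) default to float32.
    w = paddle.create_parameter(shape=[8, 8], dtype=paddle.get_default_dtype())
print(w.dtype)                                  # paddle.float32
print(paddle.get_default_dtype())               # restored to float16 afterwards
```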

View File

@@ -315,6 +315,10 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = request.get("temp_scaled_logprobs", False)
self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = request.get(
"top_p_normalized_logprobs", False
)
self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
@@ -493,6 +497,12 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request(
request, "presence_penalty", 0.0
)
self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request(
request, "temp_scaled_logprobs", False
)
self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = get_attr_from_request(
request, "top_p_normalized_logprobs", False
)
self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
@@ -622,6 +632,8 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["presence_score"] = paddle.full(
[max_num_seqs, 1], self.model_config.presence_score, dtype="float32"
)
self.share_inputs["temp_scaled_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool")
self.share_inputs["top_p_normalized_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool")
self.share_inputs["min_dec_len"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64")
self.share_inputs["max_dec_len"] = paddle.full(
@@ -841,6 +853,9 @@ class GPUModelRunner(ModelRunnerBase):
max_num_logprobs=20 if self.enable_logprob else None,
enable_early_stop=self.enable_early_stop,
stop_flags=self.share_inputs["stop_flags"],
temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"],
share_inputs=self.share_inputs,
)
def load_model(self) -> None:

View File

@@ -624,6 +624,12 @@ def parse_args():
help="The format of the model weights to load. default/new_loader.",
)
parser.add_argument(
"--lm_head_fp32",
action="store_true",
help="Flag to specify dtype of lm_head as FP32",
)
args = parser.parse_args()
return args