Rename top_p_sampling to top_k_top_p_sampling (#2791)

This commit is contained in:
Sunny-bot1
2025-07-10 15:09:25 +08:00
committed by GitHub
parent e45050cae3
commit 1e2319cbef
5 changed files with 23 additions and 16 deletions

View File

@@ -4,6 +4,7 @@
FastDeploy supports offline inference by loading models locally and processing user data. Usage examples:
### Chat Interface (LLM.chat)
```python
from fastdeploy import LLM, SamplingParams
@@ -77,10 +78,12 @@ for output in outputs:
prompt = output.prompt
generated_text = output.outputs.text
```
> Note: Text completion interface, suitable for scenarios where users have predefined the context input and expect the model to output only the continuation content. No additional `prompt` concatenation will be added during the inference process.
> For the `chat` model, it is recommended to use the Chat Interface (`LLM.chat`).
For multimodal models, such as `baidu/ERNIE-4.5-VL-28B-A3B-Paddle`, when calling the `generate interface`, you need to provide a prompt that includes images. The usage is as follows:
```python
import io
import os
@@ -141,6 +144,7 @@ for output in outputs:
reasoning_text = output.outputs.reasoning_content
```
>Note: The `generate interface` does not currently support passing parameters to control the thinking function (on/off). It always uses the model's default parameters.
## 2. API Documentation
@@ -176,6 +180,7 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
* repetition_penalty(float): Direct penalty for repeated tokens (>1 penalizes, <1 encourages)
* temperature(float): Controls randomness (higher = more random)
* top_p(float): Probability threshold for token selection
* top_k(int): Number of tokens considered for sampling
* max_tokens(int): Maximum generated tokens (input + output)
* min_tokens(int): Minimum forced generation length

View File

@@ -78,10 +78,12 @@ for output in outputs:
prompt = output.prompt
generated_text = output.outputs.text
```
> 注: 续写接口, 适应于用户自定义好上下文输入, 并希望模型仅输出续写内容的场景; 推理过程不会增加其他 `prompt`拼接。
> 对于 `chat`模型, 建议使用对话接口(LLM.chat)。
对于多模模型, 例如`baidu/ERNIE-4.5-VL-28B-A3B-Paddle`, 在调用`generate接口`时, 需要提供包含图片的prompt, 使用方式如下:
```python
import io
import os
@@ -142,6 +144,7 @@ for output in outputs:
reasoning_text = output.outputs.reasoning_content
```
> 注: `generate` 接口, 暂时不支持思考开关参数控制, 均使用模型默认思考能力。
## 2. 接口说明
@@ -155,7 +158,6 @@ for output in outputs:
> 2. 模型服务启动后会在日志文件log/fastdeploy.log中打印如 `Doing profile, the total_block_num:640` 的日志其中640即表示自动计算得到的KV Cache block数量将它乘以block_size(默认值64)即可得到部署后总共可以在KV Cache中缓存的Token数。
> 3. `max_num_seqs` 用于配置decode阶段最大并发处理请求数该参数可以基于第1点中缓存的Token数来计算一个较优值例如线上统计输入平均token数800, 输出平均token数500本次计>算得到KV Cache block为640 block_size为64。那么我们可以配置 `kv_cache_ratio = 800 / (800 + 500) = 0.6` , 配置 `max_seq_len = 640 * 64 / (800 + 500) = 31`。
### 2.2 fastdeploy.LLM.chat
* messages(list[dict],list[list[dict]]): 输入的message, 支持batch message 输入
@@ -178,7 +180,7 @@ for output in outputs:
* repetition_penalty(float): 直接对重复生成的token进行惩罚的系数>1时惩罚重复<1时鼓励重复
* temperature(float): 控制生成随机性的参数,值越高结果越随机,值越低结果越确定
* top_p(float): 概率累积分布截断阈值仅考虑累计概率达到此阈值的最可能token集合
* top_k(int): 采样概率最高的token数量考虑概率最高的k个token进行采样
* top_k(int): 采样概率最高的token数量考虑概率最高的k个token进行采样
* max_tokens(int): 限制模型生成的最大token数量包括输入和输出
* min_tokens(int): 强制模型生成的最少token数量避免过早结束

View File

@@ -16,10 +16,10 @@
from .apply_penalty_multi_scores import (
apply_penalty_multi_scores, apply_speculative_penalty_multi_scores)
from .top_p_sampling import top_p_sampling
from .top_k_top_p_sampling import top_k_top_p_sampling
__all__ = [
"apply_penalty_multi_scores",
"apply_speculative_penalty_multi_scores",
"top_p_sampling",
"top_k_top_p_sampling",
]

View File

@@ -25,7 +25,7 @@ if current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import \
top_p_sampling as gcu_top_p_sampling
def top_p_sampling(
def top_k_top_p_sampling(
x: paddle.Tensor,
top_p: paddle.Tensor,
top_k: Optional[paddle.Tensor] = None,

View File

@@ -27,7 +27,7 @@ from fastdeploy.model_executor.guided_decoding.base_guided_decoding import \
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores, apply_speculative_penalty_multi_scores,
top_p_sampling)
top_k_top_p_sampling)
from fastdeploy.platforms import current_platform
@@ -214,7 +214,7 @@ class Sampler(nn.Layer):
probs = F.softmax(logits)
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
self.processor.update_output_tokens(next_tokens, skip_idx_list)
return next_tokens
@@ -367,5 +367,5 @@ class MTPSampler(nn.Layer):
)
probs = F.softmax(logits)
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
return next_tokens