diff --git a/README.md b/README.md
index 5d55c865f..f18ca3da9 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-English | [简体中文](README_CN.md)
+English | [简体中文](README_CN.md)
@@ -23,9 +23,10 @@ English | [简体中文](README_CN.md)
--------------------------------------------------------------------------------
-# FastDeploy 2.0: Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
+# FastDeploy 2.1: Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
## News
+**[2025-08] 🔥 Released FastDeploy v2.1:** A brand-new KV Cache scheduling strategy has been introduced, and support for PD separation and CUDA Graph has been expanded to more models. Hardware support has been enhanced for platforms such as Kunlun and Hygon, along with comprehensive optimizations that improve the performance of both the service layer and the inference engine.
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务,即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
@@ -75,13 +76,13 @@ Learn how to use FastDeploy through our documentation:
| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
-|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| WIP |128K |
-|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| WIP | 128K |
+|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| ✅ |128K |
+|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| ✅ | 128K |
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ✅ | ✅|128K |
-|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅|128K |
-|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅| 128K |
+|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅|128K |
+|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅| 128K |
## Advanced Usage
diff --git a/README_CN.md b/README_CN.md
index 8b3777508..7afca5fb8 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -1,5 +1,4 @@
[English](README.md) | 简体中文
-[English](README.md) | 简体中文
@@ -24,9 +23,10 @@
--------------------------------------------------------------------------------
-# FastDeploy 2.0:基于飞桨的大语言模型与视觉语言模型推理部署工具包
+# FastDeploy 2.1:基于飞桨的大语言模型与视觉语言模型推理部署工具包
## 最新活动
+**[2025-08] 🔥 FastDeploy v2.1 全新发布:** 全新的KV Cache调度策略,更多模型支持PD分离和CUDA Graph,昆仑、海光等更多硬件支持增强,全方面优化服务和推理引擎的性能。
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务,即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
@@ -41,7 +41,6 @@
- ⏩ **高级加速技术**:推测解码、多令牌预测(MTP)及分块预填充
- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU等
-
## 要求
- 操作系统: Linux
@@ -73,13 +72,13 @@ FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU
| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
-|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| WIP |128K |
-|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| WIP | 128K |
+|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| ✅ |128K |
+|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| ✅ | 128K |
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ✅ | ✅|128K |
-|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅|128K |
-|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅| 128K |
+|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅|128K |
+|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅| 128K |
## 进阶用法
diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h
index 2bc07cf39..efe5b26bc 100644
--- a/custom_ops/gpu_ops/moe/fused_moe_op.h
+++ b/custom_ops/gpu_ops/moe/fused_moe_op.h
@@ -574,6 +574,7 @@ template
__launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
void topk_gating_softmax(const T* input,
+ const T* bias,
T* output,
const int64_t num_rows,
IdxT* indices,
@@ -716,7 +717,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
- row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
+ row_chunk[ii] = bias ? row_chunk[ii] * reciprocal_row_sum + bias[first_elt_read_by_thread + ii] : row_chunk[ii] * reciprocal_row_sum;
}
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find
@@ -765,6 +766,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
}
// Write the max for this k iteration to global memory.
+ T final_val = bias ? T(max_val) - bias[expert] : T(max_val);
if (thread_group_idx == 0) {
// The lead thread from each sub-group will write out the final results to
// global memory. (This will be a single) thread per row of the
@@ -772,11 +774,11 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
const int idx = k * thread_row + k_idx;
if constexpr (Norm_Weights) {
const int idx_in_cta = k * thread_row_in_cta + k_idx;
- row_output[idx_in_cta] = T(max_val);
- weight_sum += T(max_val);
+ row_output[idx_in_cta] = final_val;
+ weight_sum += final_val;
}
else {
- output[idx] = T(max_val);
+ output[idx] = final_val;
}
indices[idx] = should_process_row ? expert : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
@@ -831,6 +833,7 @@ struct TopkConstants {
template
void topk_gating_softmax_launcher_helper(const T* input,
+ const T* bias,
T* output,
IdxT* indices,
int* source_row,
@@ -851,7 +854,7 @@ void topk_gating_softmax_launcher_helper(const T* input,
static constexpr int ROWS_PER_CTA = WARPS_PER_TB * ROWS_PER_WARP;
topk_gating_softmax
<<>>(
- input, output, num_rows, indices, source_row, k);
+ input, bias, output, num_rows, indices, source_row, k);
}
template
@@ -882,7 +885,7 @@ static void run(const T* input,
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper( \
- input, output, indices, source_row, num_rows, num_experts, k, stream); \
+ input, gating_correction_bias, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
}
int64_t tem_num_experts = num_experts;
diff --git a/custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu b/custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu
index 8f780e00a..2d87c6bae 100644
--- a/custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu
+++ b/custom_ops/gpu_ops/moe/moe_redundant_topk_select.cu
@@ -51,7 +51,7 @@ void moe_redundant_topk_select_kernel(const T* input,
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper( \
- input, output, indices, source_row, num_rows, num_experts, k, stream); \
+ input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
}
int64_t tem_num_experts = num_experts;
diff --git a/custom_ops/gpu_ops/moe/moe_topk_select.cu b/custom_ops/gpu_ops/moe/moe_topk_select.cu
index 1798689c0..bbdaabdf2 100644
--- a/custom_ops/gpu_ops/moe/moe_topk_select.cu
+++ b/custom_ops/gpu_ops/moe/moe_topk_select.cu
@@ -47,17 +47,14 @@ void moe_topk_select_kernel(const T* input,
case N: { \
if (apply_norm_weight) { \
topk_gating_softmax_launcher_helper( \
- input, output, indices, source_row, num_rows, num_experts, k, stream); \
+ input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
} else { \
topk_gating_softmax_launcher_helper( \
- input, output, indices, source_row, num_rows, num_experts, k, stream); \
+ input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
} \
break; \
}
- int64_t tem_num_experts = num_experts;
- // when bias is not none, set tem_num_experts to 0 to follow the default branch
- if(bias != nullptr) tem_num_experts = 0;
- switch (tem_num_experts) {
+ switch (num_experts) {
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
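
To make the effect of the new `bias` argument concrete, the following is a rough NumPy sketch (illustrative only, not the CUDA kernel; the function name and example values are invented) of the gating behavior the changes above implement: the correction bias is added to the softmax scores only for expert selection, and is subtracted again from the routing weight that is written out (`final_val = max_val - bias[expert]`).

```python
import numpy as np

def topk_gating_with_bias(logits, bias, k):
    """Illustrative sketch of biased top-k gating (not the CUDA kernel)."""
    # Row-wise softmax over the expert logits.
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)

    # Bias is applied after normalization, mirroring
    # `row_chunk[ii] * reciprocal_row_sum + bias[...]` in the kernel.
    scores = probs + bias if bias is not None else probs

    # Experts are selected by the biased score ...
    topk_idx = np.argsort(-scores, axis=-1)[:, :k]
    topk_val = np.take_along_axis(scores, topk_idx, axis=-1)

    # ... but the reported routing weights have the bias removed again.
    if bias is not None:
        topk_val = topk_val - bias[topk_idx]
    return topk_idx, topk_val

# Example: 2 tokens, 4 experts, top-2 routing with a small correction bias.
logits = np.array([[0.1, 2.0, 0.3, 0.4], [1.0, 0.2, 0.2, 1.1]])
bias = np.array([0.0, -0.5, 0.3, 0.0])
print(topk_gating_with_bias(logits, bias, k=2))
```

When `Norm_Weights` is set, the kernel additionally accumulates `weight_sum` from the de-biased values for later normalization; the sketch omits that step.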
diff --git a/dockerfiles/Dockerfile.gpu b/dockerfiles/Dockerfile.gpu
index 057f30228..9e1d97834 100644
--- a/dockerfiles/Dockerfile.gpu
+++ b/dockerfiles/Dockerfile.gpu
@@ -1,6 +1,6 @@
FROM ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.0.0
-ARG PADDLE_VERSION=3.1.0
-ARG FD_VERSION=2.0.0
+ARG PADDLE_VERSION=3.1.1
+ARG FD_VERSION=2.1.0
ENV DEBIAN_FRONTEND=noninteractive
diff --git a/docs/best_practices/ERNIE-4.5-0.3B-Paddle.md b/docs/best_practices/ERNIE-4.5-0.3B-Paddle.md
index 890822c29..333cd3843 100644
--- a/docs/best_practices/ERNIE-4.5-0.3B-Paddle.md
+++ b/docs/best_practices/ERNIE-4.5-0.3B-Paddle.md
@@ -25,12 +25,12 @@ The minimum number of GPUs required to deploy `ERNIE-4.5-0.3B` on the following
### 2.1 Basic: Launching the Service
Start the service by following command:
```bash
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-0.3B-Paddle \
--tensor-parallel-size 1 \
--quantization wint4 \
--max-model-len 32768 \
- --kv-cache-ratio 0.75 \
--max-num-seqs 128
```
- `--quantization`: indicates the quantization strategy used by the model. Different quantization strategies will result in different performance and accuracy of the model. It could be one of `wint8` / `wint4` / `block_wise_fp8`(Hopper is needed).
@@ -77,8 +77,8 @@ Add the following lines to the startup parameters
```
Notes:
1. Usually, no additional parameters need to be set, but CUDAGraph will generate some additional memory overhead, which may need to be adjusted in some scenarios with limited memory. For detailed parameter adjustments, please refer to [GraphOptimizationBackend](../parameters.md) for related configuration parameter descriptions
-2. When CUDAGraph is enabled, only single-card inference is supported, that is, `--tensor-parallel-size 1`
-3. When CUDAGraph is enabled, it is not supported to enable `Chunked Prefill` and `Prefix Caching` at the same time
+2. When CUDAGraph is enabled with multi-GPU inference (TP > 1), `--enable-custom-all-reduce` must also be specified.
+3. When CUDAGraph is enabled, `max-model-len > 32768` is not currently supported.
#### 2.2.6 Rejection Sampling
**Idea:**
diff --git a/docs/best_practices/ERNIE-4.5-21B-A3B-Paddle.md b/docs/best_practices/ERNIE-4.5-21B-A3B-Paddle.md
index 5754d6b0a..35dd27671 100644
--- a/docs/best_practices/ERNIE-4.5-21B-A3B-Paddle.md
+++ b/docs/best_practices/ERNIE-4.5-21B-A3B-Paddle.md
@@ -25,12 +25,12 @@ The minimum number of GPUs required to deploy `ERNIE-4.5-21B-A3B` on the followi
### 2.1 Basic: Launching the Service
Start the service by following command:
```bash
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-21B-A3B-Paddle \
--tensor-parallel-size 1 \
--quantization wint4 \
--max-model-len 32768 \
- --kv-cache-ratio 0.75 \
--max-num-seqs 128
```
- `--quantization`: indicates the quantization strategy used by the model. Different quantization strategies will result in different performance and accuracy of the model. It could be one of `wint8` / `wint4` / `block_wise_fp8`(Hopper is needed).
@@ -87,8 +87,8 @@ Add the following lines to the startup parameters
```
Notes:
1. Usually, no additional parameters need to be set, but CUDAGraph will generate some additional memory overhead, which may need to be adjusted in some scenarios with limited memory. For detailed parameter adjustments, please refer to [GraphOptimizationBackend](../parameters.md) for related configuration parameter descriptions
-2. When CUDAGraph is enabled, only single-card inference is supported, that is, `--tensor-parallel-size 1`
-3. When CUDAGraph is enabled, it is not supported to enable `Chunked Prefill` and `Prefix Caching` at the same time
+2. When CUDAGraph is enabled with multi-GPU inference (TP > 1), `--enable-custom-all-reduce` must also be specified.
+3. When CUDAGraph is enabled, `max-model-len > 32768` is not currently supported.
#### 2.2.6 Rejection Sampling
**Idea:**
@@ -111,6 +111,7 @@ export INFERENCE_MSG_QUEUE_ID=1315
export FLAGS_max_partition_size=2048
export FD_ATTENTION_BACKEND=FLASH_ATTN
export FD_LOG_DIR="prefill_log"
+export ENABLE_V1_KVCACHE_SCHEDULER=1
quant_type=block_wise_fp8
export FD_USE_DEEP_GEMM=0
@@ -120,7 +121,7 @@ python -m fastdeploy.entrypoints.openai.api_server --model baidu/ERNIE-4.5-21B-A
--max-num-seqs 20 \
--num-gpu-blocks-override 40000 \
--quantization ${quant_type} \
- --gpu-memory-utilization 0.9 --kv-cache-ratio 0.9 \
+ --gpu-memory-utilization 0.9 \
--port 7012 --engine-worker-queue-port 7013 --metrics-port 7014 --tensor-parallel-size 4 \
--cache-queue-port 7015 \
--splitwise-role "prefill" \
@@ -131,6 +132,7 @@ export CUDA_VISIBLE_DEVICES=4,5,6,7
export INFERENCE_MSG_QUEUE_ID=1215
export FLAGS_max_partition_size=2048
export FD_LOG_DIR="decode_log"
+export ENABLE_V1_KVCACHE_SCHEDULER=1
quant_type=block_wise_fp8
export FD_USE_DEEP_GEMM=0
@@ -139,7 +141,7 @@ python -m fastdeploy.entrypoints.openai.api_server --model baidu/ERNIE-4.5-21B-A
--max-model-len 131072 \
--max-num-seqs 20 \
--quantization ${quant_type} \
- --gpu-memory-utilization 0.85 --kv-cache-ratio 0.1 \
+ --gpu-memory-utilization 0.85 \
--port 9012 --engine-worker-queue-port 8013 --metrics-port 8014 --tensor-parallel-size 4 \
--cache-queue-port 8015 \
--innode-prefill-ports 7013 \
diff --git a/docs/best_practices/ERNIE-4.5-300B-A47B-Paddle.md b/docs/best_practices/ERNIE-4.5-300B-A47B-Paddle.md
index 285e2e044..0a09b6517 100644
--- a/docs/best_practices/ERNIE-4.5-300B-A47B-Paddle.md
+++ b/docs/best_practices/ERNIE-4.5-300B-A47B-Paddle.md
@@ -22,12 +22,12 @@ The minimum number of GPUs required to deploy `ERNIE-4.5-300B-A47B` on the follo
### 2.1 Basic: Launching the Service
Start the service by following command:
```bash
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--tensor-parallel-size 8 \
--quantization wint4 \
--max-model-len 32768 \
- --kv-cache-ratio 0.75 \
--max-num-seqs 128
```
- `--quantization`: indicates the quantization strategy used by the model. Different quantization strategies will result in different performance and accuracy of the model. It could be one of `wint8` / `wint4` / `block_wise_fp8`(Hopper is needed).
@@ -99,6 +99,7 @@ export FD_SAMPLING_CLASS=rejection
**How to enable:** Take the deployment of a single machine with 8 GPUs and 1P1D (4 GPUs each) as an example. Compared with the default hybrid deployment method, `--splitwise-role` is required to specify the role of the node. And the GPUs and logs of the two nodes are isolated through the environment variables `FD_LOG_DIR` and `CUDA_VISIBLE_DEVICES`.
```
export FD_LOG_DIR="log_prefill"
+export ENABLE_V1_KVCACHE_SCHEDULER=1
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
@@ -111,6 +112,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
```
```
export FD_LOG_DIR="log_decode"
+export ENABLE_V1_KVCACHE_SCHEDULER=1
export CUDA_VISIBLE_DEVICES=4,5,6,7
# Note that innode-prefill-ports is set to the engine-worker-queue-port of the Prefill service
python -m fastdeploy.entrypoints.openai.api_server \
@@ -124,5 +126,20 @@ python -m fastdeploy.entrypoints.openai.api_server \
--splitwise-role "decode"
```
+#### 2.2.8 CUDAGraph
+**Idea:**
+CUDAGraph is a GPU computing acceleration technology provided by NVIDIA. It captures sequences of CUDA operations into a graph structure so that GPU tasks can be executed and optimized efficiently. The core idea of CUDAGraph is to encapsulate a series of GPU computation and memory operations into a re-executable graph, thereby reducing CPU-GPU communication overhead, lowering kernel launch latency, and improving overall computing performance.
+
+**How to enable:**
+Add the following lines to the startup parameters
+```
+--use-cudagraph
+--enable-custom-all-reduce
+```
+Notes:
+1. Usually, no additional parameters need to be set, but CUDAGraph will generate some additional memory overhead, which may need to be adjusted in some scenarios with limited memory. For detailed parameter adjustments, please refer to [GraphOptimizationBackend](../parameters.md) for related configuration parameter descriptions
+2. When CUDAGraph is enabled with multi-GPU inference (TP > 1), `--enable-custom-all-reduce` must also be specified.
+3. When CUDAGraph is enabled, `max-model-len > 32768` is not currently supported.
+
## FAQ
If you encounter any problems during use, you can refer to [FAQ](./FAQ.md).
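
For the new section 2.2.8, the sketch below combines the CUDAGraph flags with the basic launch command from section 2.1. The Python launcher is purely illustrative (exporting the variable and running the command directly in a shell is equivalent); the flag names and values are the ones shown in this document.

```python
import os
import subprocess

# New scheduler required on 2.1+ (see section 2.1).
os.environ["ENABLE_V1_KVCACHE_SCHEDULER"] = "1"

subprocess.run([
    "python", "-m", "fastdeploy.entrypoints.openai.api_server",
    "--model", "baidu/ERNIE-4.5-300B-A47B-Paddle",
    "--tensor-parallel-size", "8",
    "--quantization", "wint4",
    "--max-model-len", "32768",     # CUDAGraph does not currently support > 32768
    "--max-num-seqs", "128",
    "--use-cudagraph",              # flag from section 2.2.8
    "--enable-custom-all-reduce",   # required together with CUDAGraph when TP > 1
], check=True)
```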
diff --git a/docs/best_practices/ERNIE-4.5-VL-28B-A3B-Paddle.md b/docs/best_practices/ERNIE-4.5-VL-28B-A3B-Paddle.md
index d839049d2..3fc933fb2 100644
--- a/docs/best_practices/ERNIE-4.5-VL-28B-A3B-Paddle.md
+++ b/docs/best_practices/ERNIE-4.5-VL-28B-A3B-Paddle.md
@@ -27,7 +27,6 @@ Installation process reference documentation [FastDeploy GPU Install](../get_sta
**Example 1:** Deploying a 32K Context Service on a Single RTX 4090 GPU
```shell
export ENABLE_V1_KVCACHE_SCHEDULER=1
-
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-VL-28B-A3B-Paddle \
--port 8180 \
@@ -47,7 +46,6 @@ python -m fastdeploy.entrypoints.openai.api_server \
**Example 2:** Deploying a 128K Context Service on Dual H800 GPUs
```shell
export ENABLE_V1_KVCACHE_SCHEDULER=1
-
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-VL-28B-A3B-Paddle \
--port 8180 \
@@ -64,6 +62,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
--quantization wint4 \
--enable-mm
```
+
+> ⚠️ For versions 2.1 and above, the new scheduler needs to be enabled via the environment variable `ENABLE_V1_KVCACHE_SCHEDULER=1`. Otherwise, some requests may be truncated before reaching the maximum length or may return empty results.
+
An example is a set of configurations that can run stably while also delivering relatively good performance. If you have further requirements for precision or performance, please continue reading the content below.
### 2.2 Advanced: How to Achieve Better Performance
@@ -109,6 +110,15 @@ An example is a set of configurations that can run stably while also delivering
- If slightly higher precision is required, you may try WINT8.
- Only consider using BFLOAT16 if your application scenario demands extreme precision, as it requires significantly more GPU memory.
+#### 2.2.4 **Adjustable environment variables**
+> **Rejection sampling:**`FD_SAMPLING_CLASS=rejection`
+- **Description:** Rejection sampling draws samples from a proposal distribution that is easy to sample from, avoiding explicit sorting and thereby speeding up sampling, which can improve inference performance.
+- **Recommendation:** This is a relatively aggressive optimization strategy that affects the results, and we are still conducting comprehensive validation of its impact. If you have high performance requirements and can accept potential compromises in results, you may consider enabling this strategy.
+
+> **Attention Hyperparameter:**`FLAGS_max_partition_size=1024`
+- **Description:** This is a hyperparameter of the Append Attention (default) backend. Tests on commonly used datasets show that setting it to 1024 significantly improves decoding speed, especially in long-text scenarios.
+- **Recommendation:** An automatic tuning mechanism will replace this setting in the future. If you have high performance requirements, you may try setting it.
+
## 3. FAQ
**Note:** Deploying multimodal services requires adding the `--enable-mm` parameter to the configuration.
diff --git a/docs/best_practices/ERNIE-4.5-VL-424B-A47B-Paddle.md b/docs/best_practices/ERNIE-4.5-VL-424B-A47B-Paddle.md
index ea536ffb0..2741a417e 100644
--- a/docs/best_practices/ERNIE-4.5-VL-424B-A47B-Paddle.md
+++ b/docs/best_practices/ERNIE-4.5-VL-424B-A47B-Paddle.md
@@ -24,7 +24,6 @@ Installation process reference documentation [FastDeploy GPU Install](../get_sta
**Example 1:** Deploying a 128K context service on 8x H800 GPUs.
```shell
export ENABLE_V1_KVCACHE_SCHEDULER=1
-
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-VL-424B-A47B-Paddle \
--port 8180 \
@@ -42,6 +41,8 @@ python -m fastdeploy.entrypoints.openai.api_server \
--enable-mm
```
+> ⚠️ For versions 2.1 and above, the new scheduler needs to be enabled via the environment variable `ENABLE_V1_KVCACHE_SCHEDULER=1`. Otherwise, some requests may be truncated before reaching the maximum length or may return empty results.
+
An example is a set of configurations that can run stably while also delivering relatively good performance. If you have further requirements for precision or performance, please continue reading the content below.
### 2.2 Advanced: How to Achieve Better Performance
@@ -87,6 +88,15 @@ An example is a set of configurations that can run stably while also delivering
- If slightly higher precision is required, you may try wint8.
- Only consider using bfloat16 if your application scenario demands extreme precision, as it requires significantly more GPU memory.
+#### 2.2.4 **Adjustable environment variables**
+> **Rejection sampling:**`FD_SAMPLING_CLASS=rejection`
+- **Description:** Rejection sampling draws samples from a proposal distribution that is easy to sample from, avoiding explicit sorting and thereby speeding up sampling, which can improve inference performance.
+- **Recommendation:** This is a relatively aggressive optimization strategy that affects the results, and we are still conducting comprehensive validation of its impact. If you have high performance requirements and can accept potential compromises in results, you may consider enabling this strategy.
+
+> **Attention Hyperparameter:**`FLAGS_max_partition_size=1024`
+- **Description:** This is a hyperparameter of the Append Attention (default) backend. Tests on commonly used datasets show that setting it to 1024 significantly improves decoding speed, especially in long-text scenarios.
+- **Recommendation:** An automatic tuning mechanism will replace this setting in the future. If you have high performance requirements, you may try setting it.
+
## 3. FAQ
**Note:** Deploying multimodal services requires adding the `--enable-mm` parameter to the configuration.
diff --git a/docs/best_practices/README.md b/docs/best_practices/README.md
index c3ab9875c..b7ff016a5 100644
--- a/docs/best_practices/README.md
+++ b/docs/best_practices/README.md
@@ -1,4 +1,7 @@
# Optimal Deployment
+- [ERNIE-4.5-0.3B-Paddle](ERNIE-4.5-0.3B-Paddle.md)
+- [ERNIE-4.5-21B-A3B-Paddle](ERNIE-4.5-21B-A3B-Paddle.md)
+- [ERNIE-4.5-300B-A47B-Paddle](ERNIE-4.5-300B-A47B-Paddle.md)
- [ERNIE-4.5-VL-28B-A3B-Paddle](ERNIE-4.5-VL-28B-A3B-Paddle.md)
- [ERNIE-4.5-VL-424B-A47B-Paddle](ERNIE-4.5-VL-424B-A47B-Paddle.md)
diff --git a/docs/get_started/ernie-4.5-vl.md b/docs/get_started/ernie-4.5-vl.md
index 1092ed19f..015fc6e5a 100644
--- a/docs/get_started/ernie-4.5-vl.md
+++ b/docs/get_started/ernie-4.5-vl.md
@@ -23,6 +23,7 @@ Execute the following command to start the service. For parameter configurations
>💡 **Note**: Since the model parameter size is 424B-A47B, on an 80G * 8 GPU machine, specify ```--quantization wint4``` (wint8 is also supported).
```shell
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-VL-424B-A47B-Paddle \
--port 8180 --engine-worker-queue-port 8181 \
diff --git a/docs/get_started/ernie-4.5.md b/docs/get_started/ernie-4.5.md
index 2d05c8c1a..ebfc4f514 100644
--- a/docs/get_started/ernie-4.5.md
+++ b/docs/get_started/ernie-4.5.md
@@ -21,6 +21,7 @@ Specify `--model baidu/ERNIE-4.5-300B-A47B-Paddle` during deployment to automati
Execute the following command to start the service. For configuration details, refer to the [Parameter Guide](../parameters.md):
```shell
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--port 8180 --engine-worker-queue-port 8181 \
diff --git a/docs/get_started/installation/README.md b/docs/get_started/installation/README.md
index ba7042e26..77b037896 100644
--- a/docs/get_started/installation/README.md
+++ b/docs/get_started/installation/README.md
@@ -3,6 +3,7 @@
FastDeploy currently supports installation on the following hardware platforms:
- [NVIDIA GPU Installation](nvidia_gpu.md)
+- [Hygon DCU Installation](hygon_dcu.md)
- [Kunlun XPU Installation](kunlunxin_xpu.md)
- [Enflame S60 GCU Installation](Enflame_gcu.md)
- [Iluvatar GPU Installation](iluvatar_gpu.md)
diff --git a/docs/get_started/installation/kunlunxin_xpu.md b/docs/get_started/installation/kunlunxin_xpu.md
index 44aa350f7..aeaae3bac 100644
--- a/docs/get_started/installation/kunlunxin_xpu.md
+++ b/docs/get_started/installation/kunlunxin_xpu.md
@@ -25,9 +25,9 @@ Verified platform:
```bash
mkdir Work
cd Work
-docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.3
+docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
docker run --name fastdeploy-xpu --net=host -itd --privileged -v $PWD:/Work -w /Work \
- ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.3 \
+ ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0 \
/bin/bash
docker exec -it fastdeploy-xpu /bin/bash
```
@@ -37,7 +37,7 @@ docker exec -it fastdeploy-xpu /bin/bash
### Install PaddlePaddle
```bash
-python -m pip install paddlepaddle-xpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/
+python -m pip install paddlepaddle-xpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/
```
Alternatively, you can install the latest version of PaddlePaddle (Not recommended)
@@ -49,7 +49,7 @@ python -m pip install --pre paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/
### Install FastDeploy (**Do NOT install via PyPI source**)
```bash
-python -m pip install fastdeploy-xpu==2.0.3 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-xpu-p800/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
+python -m pip install fastdeploy-xpu==2.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-xpu-p800/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
```
Alternatively, you can install the latest version of FastDeploy (Not recommended)
@@ -63,7 +63,7 @@ python -m pip install --pre fastdeploy-xpu -i https://www.paddlepaddle.org.cn/pa
### Install PaddlePaddle
```bash
-python -m pip install paddlepaddle-xpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/
+python -m pip install paddlepaddle-xpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/
```
Alternatively, you can install the latest version of PaddlePaddle (Not recommended)
diff --git a/docs/get_started/installation/nvidia_gpu.md b/docs/get_started/installation/nvidia_gpu.md
index 97e3dc750..8cf69d5d9 100644
--- a/docs/get_started/installation/nvidia_gpu.md
+++ b/docs/get_started/installation/nvidia_gpu.md
@@ -20,7 +20,7 @@ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12
First install paddlepaddle-gpu. For detailed instructions, refer to [PaddlePaddle Installation](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html)
```shell
-python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
+python -m pip install paddlepaddle-gpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
```
Then install fastdeploy. **Do not install from PyPI**. Use the following methods instead:
@@ -58,7 +58,7 @@ docker build -f dockerfiles/Dockerfile.gpu -t fastdeploy:gpu .
First install paddlepaddle-gpu. For detailed instructions, refer to [PaddlePaddle Installation](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html)
```shell
-python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
+python -m pip install paddlepaddle-gpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
```
Then clone the source code and build:
diff --git a/docs/get_started/quick_start.md b/docs/get_started/quick_start.md
index a9d2331ee..75dc0cc19 100644
--- a/docs/get_started/quick_start.md
+++ b/docs/get_started/quick_start.md
@@ -16,6 +16,7 @@ For more information about how to install FastDeploy, refer to the [installation
After installing FastDeploy, execute the following command in the terminal to start the service. For the configuration method of the startup command, refer to [Parameter Description](../parameters.md)
```
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-0.3B-Paddle \
--port 8180 \
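
Once the service above is running, it exposes an OpenAI-compatible API on the configured port. A minimal request sketch using the `openai` Python client (the client package, API key placeholder, and prompt are illustrative; any HTTP client that speaks the OpenAI chat API works):

```python
from openai import OpenAI  # pip install openai

# Point the client at the local FastDeploy server started above.
client = OpenAI(base_url="http://localhost:8180/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="baidu/ERNIE-4.5-0.3B-Paddle",
    messages=[{"role": "user", "content": "Introduce PaddlePaddle in one sentence."}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```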
diff --git a/docs/get_started/quick_start_vl.md b/docs/get_started/quick_start_vl.md
index 6e0b9a780..b9c50a1c2 100644
--- a/docs/get_started/quick_start_vl.md
+++ b/docs/get_started/quick_start_vl.md
@@ -19,6 +19,7 @@ For more information about how to install FastDeploy, refer to the [installation
After installing FastDeploy, execute the following command in the terminal to start the service. For the configuration method of the startup command, refer to [Parameter Description](../parameters.md)
```shell
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-VL-28B-A3B-Paddle \
--port 8180 \
diff --git a/docs/usage/kunlunxin_xpu_deployment.md b/docs/usage/kunlunxin_xpu_deployment.md
index 9aa19da29..51acb6868 100644
--- a/docs/usage/kunlunxin_xpu_deployment.md
+++ b/docs/usage/kunlunxin_xpu_deployment.md
@@ -5,8 +5,14 @@
|ERNIE-4.5-300B-A47B|32K|WINT4|4 (Recommended)|export XPU_VISIBLE_DEVICES="0,1,2,3" or "4,5,6,7"<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-300B-A47B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 4 \<br> --max-model-len 32768 \<br> --max-num-seqs 64 \<br> --quantization "wint4" \<br> --gpu-memory-utilization 0.9|>=2.0.0|
|ERNIE-4.5-300B-A47B|32K|WINT4|8|export XPU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-300B-A47B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 8 \<br> --max-model-len 32768 \<br> --max-num-seqs 64 \<br> --quantization "wint4" \<br> --gpu-memory-utilization 0.9|>=2.0.0|
|ERNIE-4.5-300B-A47B|128K|WINT4|8 (Recommended)|export XPU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-300B-A47B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 8 \<br> --max-model-len 131072 \<br> --max-num-seqs 64 \<br> --quantization "wint4" \<br> --gpu-memory-utilization 0.9|>=2.0.0|
+|ERNIE-4.5-21B-A3B|32K|BF16|1|export XPU_VISIBLE_DEVICES="0" # Specify any card<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-21B-A3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 32768 \<br> --max-num-seqs 128 \<br> --gpu-memory-utilization 0.9|>=2.1.0|
+|ERNIE-4.5-21B-A3B|32K|WINT8|1|export XPU_VISIBLE_DEVICES="0" # Specify any card<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-21B-A3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 32768 \<br> --max-num-seqs 128 \<br> --quantization "wint8" \<br> --gpu-memory-utilization 0.9|>=2.1.0|
+|ERNIE-4.5-21B-A3B|32K|WINT4|1|export XPU_VISIBLE_DEVICES="0" # Specify any card<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-21B-A3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 32768 \<br> --max-num-seqs 128 \<br> --quantization "wint4" \<br> --gpu-memory-utilization 0.9|>=2.1.0|
+|ERNIE-4.5-21B-A3B|128K|BF16|1|export XPU_VISIBLE_DEVICES="0" # Specify any card<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-21B-A3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 131072 \<br> --max-num-seqs 128 \<br> --gpu-memory-utilization 0.9|>=2.1.0|
+|ERNIE-4.5-21B-A3B|128K|WINT8|1|export XPU_VISIBLE_DEVICES="0" # Specify any card<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-21B-A3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 131072 \<br> --max-num-seqs 128 \<br> --quantization "wint8" \<br> --gpu-memory-utilization 0.9|>=2.1.0|
+|ERNIE-4.5-21B-A3B|128K|WINT4|1|export XPU_VISIBLE_DEVICES="0" # Specify any card<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-21B-A3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 131072 \<br> --max-num-seqs 128 \<br> --quantization "wint4" \<br> --gpu-memory-utilization 0.9|>=2.1.0|
|ERNIE-4.5-0.3B|32K|BF16|1|export XPU_VISIBLE_DEVICES="0" # Specify any card<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-0.3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 32768 \<br> --max-num-seqs 128 \<br> --gpu-memory-utilization 0.9|>=2.0.3|
-|ERNIE-4.5-0.3B|32K|WINT8|1|export XPU_VISIBLE_DEVICES="x" # Specify any card<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-0.3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 32768 \<br> --max-num-seqs 128 \<br> --quantization "wint8" \<br> --gpu-memory-utilization 0.9|>=2.0.3|
+|ERNIE-4.5-0.3B|32K|WINT8|1|export XPU_VISIBLE_DEVICES="0" # Specify any card<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-0.3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 32768 \<br> --max-num-seqs 128 \<br> --quantization "wint8" \<br> --gpu-memory-utilization 0.9|>=2.0.3|
|ERNIE-4.5-0.3B|128K|BF16|1|export XPU_VISIBLE_DEVICES="0" # Specify any card<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-0.3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 131072 \<br> --max-num-seqs 128 \<br> --gpu-memory-utilization 0.9|>=2.0.3|
|ERNIE-4.5-0.3B|128K|WINT8|1|export XPU_VISIBLE_DEVICES="0" # Specify any card<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-0.3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 131072 \<br> --max-num-seqs 128 \<br> --quantization "wint8" \<br> --gpu-memory-utilization 0.9|>=2.0.3|
diff --git a/docs/zh/best_practices/ERNIE-4.5-0.3B-Paddle.md b/docs/zh/best_practices/ERNIE-4.5-0.3B-Paddle.md
index bdfdbb275..4b9eb3343 100644
--- a/docs/zh/best_practices/ERNIE-4.5-0.3B-Paddle.md
+++ b/docs/zh/best_practices/ERNIE-4.5-0.3B-Paddle.md
@@ -25,12 +25,12 @@ ERNIE-4.5-0.3B 各量化精度,在下列硬件上部署所需要的最小卡
### 2.1 基础:启动服务
通过下列命令启动服务
```bash
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-0.3B-Paddle \
--tensor-parallel-size 1 \
--quantization wint4 \
--max-model-len 32768 \
- --kv-cache-ratio 0.75 \
--max-num-seqs 128
```
其中:
@@ -77,8 +77,8 @@ CUDAGraph 是 NVIDIA 提供的一项 GPU 计算加速技术,通过将 CUDA 操
```
注:
1. 通常情况下不需要额外设置其他参数,但CUDAGraph会产生一些额外的显存开销,在一些显存受限的场景下可能需要调整。详细的参数调整请参考[GraphOptimizationBackend](../parameters.md) 相关配置参数说明
-2. 开启CUDAGraph时,暂时只支持单卡推理,即`--tensor-parallel-size 1`
-3. 开启CUDAGraph时,暂时不支持同时开启`Chunked Prefill`和`Prefix Caching`
+2. 开启CUDAGraph时,如果是TP>1的多卡推理场景,需要同时指定 `--enable-custom-all-reduce`
+3. 开启CUDAGraph时,暂时不支持`max-model-len > 32768`的场景。
#### 2.2.5 拒绝采样
**原理:**
diff --git a/docs/zh/best_practices/ERNIE-4.5-21B-A3B-Paddle.md b/docs/zh/best_practices/ERNIE-4.5-21B-A3B-Paddle.md
index 8b494d890..e7be4601b 100644
--- a/docs/zh/best_practices/ERNIE-4.5-21B-A3B-Paddle.md
+++ b/docs/zh/best_practices/ERNIE-4.5-21B-A3B-Paddle.md
@@ -25,12 +25,12 @@ ERNIE-4.5-21B-A3B 各量化精度,在下列硬件上部署所需要的最小
### 2.1 基础:启动服务
通过下列命令启动服务
```bash
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-21B-A3B-Paddle \
--tensor-parallel-size 1 \
--quantization wint4 \
--max-model-len 32768 \
- --kv-cache-ratio 0.75 \
--max-num-seqs 128
```
其中:
@@ -87,8 +87,8 @@ CUDAGraph 是 NVIDIA 提供的一项 GPU 计算加速技术,通过将 CUDA 操
```
注:
1. 通常情况下不需要额外设置其他参数,但CUDAGraph会产生一些额外的显存开销,在一些显存受限的场景下可能需要调整。详细的参数调整请参考[GraphOptimizationBackend](../parameters.md) 相关配置参数说明
-2. 开启CUDAGraph时,暂时只支持单卡推理,即`--tensor-parallel-size 1`
-3. 开启CUDAGraph时,暂时不支持同时开启`Chunked Prefill`和`Prefix Caching`
+2. 开启CUDAGraph时,如果是TP>1的多卡推理场景,需要同时指定 `--enable-custom-all-reduce`
+3. 开启CUDAGraph时,暂时不支持`max-model-len > 32768`的场景。
#### 2.2.6 拒绝采样
**原理:**
@@ -111,6 +111,7 @@ export INFERENCE_MSG_QUEUE_ID=1315
export FLAGS_max_partition_size=2048
export FD_ATTENTION_BACKEND=FLASH_ATTN
export FD_LOG_DIR="prefill_log"
+export ENABLE_V1_KVCACHE_SCHEDULER=1
quant_type=block_wise_fp8
export FD_USE_DEEP_GEMM=0
@@ -120,7 +121,7 @@ python -m fastdeploy.entrypoints.openai.api_server --model baidu/ERNIE-4.5-21B-A
--max-num-seqs 20 \
--num-gpu-blocks-override 40000 \
--quantization ${quant_type} \
- --gpu-memory-utilization 0.9 --kv-cache-ratio 0.9 \
+ --gpu-memory-utilization 0.9 \
--port 7012 --engine-worker-queue-port 7013 --metrics-port 7014 --tensor-parallel-size 4 \
--cache-queue-port 7015 \
--splitwise-role "prefill" \
@@ -131,6 +132,7 @@ export CUDA_VISIBLE_DEVICES=4,5,6,7
export INFERENCE_MSG_QUEUE_ID=1215
export FLAGS_max_partition_size=2048
export FD_LOG_DIR="decode_log"
+export ENABLE_V1_KVCACHE_SCHEDULER=1
quant_type=block_wise_fp8
export FD_USE_DEEP_GEMM=0
@@ -139,7 +141,7 @@ python -m fastdeploy.entrypoints.openai.api_server --model baidu/ERNIE-4.5-21B-A
--max-model-len 131072 \
--max-num-seqs 20 \
--quantization ${quant_type} \
- --gpu-memory-utilization 0.85 --kv-cache-ratio 0.1 \
+ --gpu-memory-utilization 0.85 \
--port 9012 --engine-worker-queue-port 8013 --metrics-port 8014 --tensor-parallel-size 4 \
--cache-queue-port 8015 \
--innode-prefill-ports 7013 \
diff --git a/docs/zh/best_practices/ERNIE-4.5-300B-A47B-Paddle.md b/docs/zh/best_practices/ERNIE-4.5-300B-A47B-Paddle.md
index b265c75a1..108e879b4 100644
--- a/docs/zh/best_practices/ERNIE-4.5-300B-A47B-Paddle.md
+++ b/docs/zh/best_practices/ERNIE-4.5-300B-A47B-Paddle.md
@@ -22,12 +22,12 @@ ERNIE-4.5-300B-A47B各量化精度,在下列硬件上部署所需要的最小
### 2.1 基础:启动服务
通过下列命令启动服务
```bash
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--tensor-parallel-size 8 \
--quantization wint4 \
--max-model-len 32768 \
- --kv-cache-ratio 0.75 \
--max-num-seqs 128
```
其中:
@@ -100,6 +100,7 @@ export FD_SAMPLING_CLASS=rejection
**启用方式:** 以单机8GPU,1P1D(各4GPU)部署为例,与默认的混合式部署方式相比, 需要`--splitwise-role`指定节点的角色。并通过环境变量`FD_LOG_DIR`和`CUDA_VISIBLE_DEVICES`将两个节点的GPU 和日志隔离开
```
export FD_LOG_DIR="log_prefill"
+export ENABLE_V1_KVCACHE_SCHEDULER=1
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
@@ -112,6 +113,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
```
```
export FD_LOG_DIR="log_decode"
+export ENABLE_V1_KVCACHE_SCHEDULER=1
export CUDA_VISIBLE_DEVICES=4,5,6,7
# 注意innode-prefill-ports指定为Prefill服务的engine-worker-queue-port
python -m fastdeploy.entrypoints.openai.api_server \
@@ -125,5 +127,20 @@ python -m fastdeploy.entrypoints.openai.api_server \
--splitwise-role "decode"
```
+#### 2.2.8 CUDAGraph
+**原理:**
+CUDAGraph 是 NVIDIA 提供的一项 GPU 计算加速技术,通过将 CUDA 操作序列捕获(capture)为图结构(graph),实现 GPU 任务的高效执行和优化。CUDAGraph 的核心思想是将一系列 GPU 计算和内存操作封装为一个可重复执行的图,从而减少 CPU-GPU 通信开销、降低内核启动延迟,并提升整体计算性能。
+
+**启用方式:**
+在启动命令中增加
+```
+--use-cudagraph
+--enable-custom-all-reduce
+```
+注:
+1. 通常情况下不需要额外设置其他参数,但CUDAGraph会产生一些额外的显存开销,在一些显存受限的场景下可能需要调整。详细的参数调整请参考[GraphOptimizationBackend](../parameters.md) 相关配置参数说明
+2. 开启CUDAGraph时,如果是TP>1的多卡推理场景,需要同时指定 `--enable-custom-all-reduce`
+3. 开启CUDAGraph时,暂时不支持`max-model-len > 32768`的场景。
+
## 三、常见问题FAQ
如果您在使用过程中遇到问题,可以在[FAQ](./FAQ.md)中查阅。
diff --git a/docs/zh/best_practices/ERNIE-4.5-VL-28B-A3B-Paddle.md b/docs/zh/best_practices/ERNIE-4.5-VL-28B-A3B-Paddle.md
index f5b18de53..12ebb2696 100644
--- a/docs/zh/best_practices/ERNIE-4.5-VL-28B-A3B-Paddle.md
+++ b/docs/zh/best_practices/ERNIE-4.5-VL-28B-A3B-Paddle.md
@@ -9,9 +9,9 @@
|:----------:|:----------:|:------:| :------:|
| A30 [24G] | 2 | 2 | 4 |
| L20 [48G] | 1 | 1 | 2 |
-| H20 [144G] | 1 | 1 | 1 |
-| A100 [80G] | 1 | 1 | 1 |
-| H800 [80G] | 1 | 1 | 1 |
+| H20 [144G] | 1 | 1 | 1 |
+| A100 [80G] | 1 | 1 | 1 |
+| H800 [80G] | 1 | 1 | 1 |
### 1.2 安装fastdeploy
@@ -26,7 +26,6 @@
**示例1:** 4090上单卡部署32K上下文的服务
```shell
export ENABLE_V1_KVCACHE_SCHEDULER=1
-
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-VL-28B-A3B-Paddle \
--port 8180 \
@@ -46,7 +45,6 @@ python -m fastdeploy.entrypoints.openai.api_server \
**示例2:** H800上双卡部署128K上下文的服务
```shell
export ENABLE_V1_KVCACHE_SCHEDULER=1
-
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-VL-28B-A3B-Paddle \
--port 8180 \
@@ -63,6 +61,8 @@ python -m fastdeploy.entrypoints.openai.api_server \
--quantization wint4 \
--enable-mm
```
+> ⚠️ 2.1及以上版本需要通过环境变量开启新调度器 `ENABLE_V1_KVCACHE_SCHEDULER=1`,否则可能会有部分请求最大长度前截断或返空。
+
示例是可以稳定运行的一组配置,同时也能得到比较好的性能。
如果对精度、性能有进一步的要求,请继续阅读下面的内容。
### 2.2 进阶:如何获取更优性能
@@ -110,6 +110,15 @@ python -m fastdeploy.entrypoints.openai.api_server \
- 若需要稍高的精度,可尝试WINT8。
- 仅当您的应用场景对精度有极致要求时候才尝试使用BFLOAT16,因为它需要更多显存。
+#### 2.2.4 **可调整的环境变量**
+> **拒绝采样:**`FD_SAMPLING_CLASS=rejection`
+- **描述**:拒绝采样即从一个易于采样的提议分布(proposal distribution)中生成样本,避免显式排序从而达到提升采样速度的效果,可以提升推理性能。
+- **推荐**:这是一种影响效果的较为激进的优化策略,我们还在全面验证影响。如果对性能有较高要求,也可以接受对效果的影响时可以尝试开启。
+
+> **Attention超参:**`FLAGS_max_partition_size=1024`
+- **描述**:Append Attention(默认)后端的超参,我们在常用数据集上的测试结果表明,设置为1024后可以大幅提升解码速度,尤其是长文场景。
+- **推荐**:未来会修改为自动调整的机制。如果对性能有较高要求可以尝试开启。
+
## 三、常见问题FAQ
**注意:** 使用多模服务部署需要在配置中添加参数 `--enable-mm`。
diff --git a/docs/zh/best_practices/ERNIE-4.5-VL-424B-A47B-Paddle.md b/docs/zh/best_practices/ERNIE-4.5-VL-424B-A47B-Paddle.md
index fafaefa7d..bb83c02fe 100644
--- a/docs/zh/best_practices/ERNIE-4.5-VL-424B-A47B-Paddle.md
+++ b/docs/zh/best_practices/ERNIE-4.5-VL-424B-A47B-Paddle.md
@@ -23,8 +23,6 @@
### 2.1 基础:启动服务
**示例1:** H800上8卡部署128K上下文的服务
```shell
-export ENABLE_V1_KVCACHE_SCHEDULER=1
-
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-VL-424B-A47B-Paddle \
--port 8180 \
@@ -41,6 +39,8 @@ python -m fastdeploy.entrypoints.openai.api_server \
--quantization wint4 \
--enable-mm
```
+> ⚠️ 2.1及以上版本需要通过环境变量开启新调度器 `ENABLE_V1_KVCACHE_SCHEDULER=1`,否则可能会有部分请求最大长度前截断或返空。
+
示例是可以稳定运行的一组配置,同时也能得到比较好的性能。
如果对精度、性能有进一步的要求,请继续阅读下面的内容。
### 2.2 进阶:如何获取更优性能
@@ -87,6 +87,15 @@ python -m fastdeploy.entrypoints.openai.api_server \
- 若需要稍高的精度,可尝试WINT8。
- 仅当您的应用场景对精度有极致要求时候才尝试使用BFLOAT16,因为它需要更多显存。
+#### 2.2.4 **可调整的环境变量**
+> **拒绝采样:**`FD_SAMPLING_CLASS=rejection`
+- **描述**:拒绝采样即从一个易于采样的提议分布(proposal distribution)中生成样本,避免显式排序从而达到提升采样速度的效果,可以提升推理性能。
+- **推荐**:这是一种影响效果的较为激进的优化策略,我们还在全面验证影响。如果对性能有较高要求,也可以接受对效果的影响时可以尝试开启。
+
+> **Attention超参:**`FLAGS_max_partition_size=1024`
+- **描述**:Append Attention(默认)后端的超参,我们在常用数据集上的测试结果表明,设置为1024后可以大幅提升解码速度,尤其是长文场景。
+- **推荐**:未来会修改为自动调整的机制。如果对性能有较高要求可以尝试开启。
+
## 三、常见问题FAQ
**注意:** 使用多模服务部署需要在配置中添加参数 `--enable-mm`。
diff --git a/docs/zh/best_practices/README.md b/docs/zh/best_practices/README.md
index b4e7401a0..daf9758b1 100644
--- a/docs/zh/best_practices/README.md
+++ b/docs/zh/best_practices/README.md
@@ -1,4 +1,7 @@
# 最佳实践
+- [ERNIE-4.5-0.3B-Paddle](ERNIE-4.5-0.3B-Paddle.md)
+- [ERNIE-4.5-21B-A3B-Paddle](ERNIE-4.5-21B-A3B-Paddle.md)
+- [ERNIE-4.5-300B-A47B-Paddle](ERNIE-4.5-300B-A47B-Paddle.md)
- [ERNIE-4.5-VL-28B-A3B-Paddle](ERNIE-4.5-VL-28B-A3B-Paddle.md)
- [ERNIE-4.5-VL-424B-A47B-Paddle](ERNIE-4.5-VL-424B-A47B-Paddle.md)
diff --git a/docs/zh/get_started/ernie-4.5-vl.md b/docs/zh/get_started/ernie-4.5-vl.md
index 3f12904c5..6fed957d4 100644
--- a/docs/zh/get_started/ernie-4.5-vl.md
+++ b/docs/zh/get_started/ernie-4.5-vl.md
@@ -23,6 +23,7 @@
**注意**: 由于模型参数量为424B-A47B,在80G * 8卡的机器上,需指定```--quantization wint4```(wint8也可部署)。
```shell
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-VL-424B-A47B-Paddle \
--port 8180 --engine-worker-queue-port 8181 \
diff --git a/docs/zh/get_started/ernie-4.5.md b/docs/zh/get_started/ernie-4.5.md
index 4c8bc6ea0..666b081e9 100644
--- a/docs/zh/get_started/ernie-4.5.md
+++ b/docs/zh/get_started/ernie-4.5.md
@@ -21,6 +21,7 @@
执行如下命令,启动服务,其中启动命令配置方式参考[参数说明](../parameters.md)。
```shell
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--port 8180 --engine-worker-queue-port 8181 \
diff --git a/docs/zh/get_started/installation/README.md b/docs/zh/get_started/installation/README.md
index 80638604b..051259a30 100644
--- a/docs/zh/get_started/installation/README.md
+++ b/docs/zh/get_started/installation/README.md
@@ -1,8 +1,9 @@
-# FastDeploy Installation Guide
+# FastDeploy 安装
-FastDeploy currently supports installation on the following hardware platforms:
+FastDeploy支持如下硬件平台:
- [NVIDIA GPU Installation](nvidia_gpu.md)
+- [Hygon DCU Installation](hygon_dcu.md)
- [Kunlunxin XPU Installation](kunlunxin_xpu.md)
- [Enflame S60 GCU Installation](Enflame_gcu.md)
- [Iluvatar GPU Installation](iluvatar_gpu.md)
diff --git a/docs/zh/get_started/installation/kunlunxin_xpu.md b/docs/zh/get_started/installation/kunlunxin_xpu.md
index a5ee97dca..29fb801fc 100644
--- a/docs/zh/get_started/installation/kunlunxin_xpu.md
+++ b/docs/zh/get_started/installation/kunlunxin_xpu.md
@@ -25,9 +25,9 @@
```bash
mkdir Work
cd Work
-docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.3
+docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
docker run --name fastdeploy-xpu --net=host -itd --privileged -v $PWD:/Work -w /Work \
- ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.3 \
+ ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0 \
/bin/bash
docker exec -it fastdeploy-xpu /bin/bash
```
@@ -37,7 +37,7 @@ docker exec -it fastdeploy-xpu /bin/bash
### 安装 PaddlePaddle
```bash
-python -m pip install paddlepaddle-xpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/
+python -m pip install paddlepaddle-xpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/
```
或者您也可以安装最新版 PaddlePaddle(不推荐)
@@ -49,7 +49,7 @@ python -m pip install --pre paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/
### 安装 FastDeploy(**注意不要通过 pypi 源安装**)
```bash
-python -m pip install fastdeploy-xpu==2.0.3 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-xpu-p800/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
+python -m pip install fastdeploy-xpu==2.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-xpu-p800/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
```
或者你也可以安装最新版 FastDeploy(不推荐)
@@ -63,7 +63,7 @@ python -m pip install --pre fastdeploy-xpu -i https://www.paddlepaddle.org.cn/pa
### 安装 PaddlePaddle
```bash
-python -m pip install paddlepaddle-xpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/
+python -m pip install paddlepaddle-xpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/
```
或者您也可以安装最新版 PaddlePaddle(不推荐)
diff --git a/docs/zh/get_started/installation/nvidia_gpu.md b/docs/zh/get_started/installation/nvidia_gpu.md
index 94c111fe1..5744a1892 100644
--- a/docs/zh/get_started/installation/nvidia_gpu.md
+++ b/docs/zh/get_started/installation/nvidia_gpu.md
@@ -23,7 +23,7 @@ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12
首先安装 paddlepaddle-gpu,详细安装方式参考 [PaddlePaddle安装](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html)
``` shell
-python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
+python -m pip install paddlepaddle-gpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
```
再安装 fastdeploy,**注意不要通过pypi源安装**,需要通过如下方式安装
@@ -64,7 +64,7 @@ docker build -f dockerfiles/Dockerfile.gpu -t fastdeploy:gpu .
首先安装 paddlepaddle-gpu,详细安装方式参考 [PaddlePaddle安装](https://www.paddlepaddle.org.cn/)
``` shell
-python -m pip install paddlepaddle-gpu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
+python -m pip install paddlepaddle-gpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
```
接着克隆源代码,编译安装
diff --git a/docs/zh/get_started/quick_start.md b/docs/zh/get_started/quick_start.md
index 46da9fa05..178c7ba02 100644
--- a/docs/zh/get_started/quick_start.md
+++ b/docs/zh/get_started/quick_start.md
@@ -17,6 +17,7 @@
安装FastDeploy后,在终端执行如下命令,启动服务,其中启动命令配置方式参考[参数说明](../parameters.md)
```shell
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-0.3B-Paddle \
--port 8180 \
diff --git a/docs/zh/get_started/quick_start_vl.md b/docs/zh/get_started/quick_start_vl.md
index b3b153817..b031378ac 100644
--- a/docs/zh/get_started/quick_start_vl.md
+++ b/docs/zh/get_started/quick_start_vl.md
@@ -19,6 +19,7 @@
安装FastDeploy后,在终端执行如下命令,启动服务,其中启动命令配置方式参考[参数说明](../parameters.md)
```shell
+export ENABLE_V1_KVCACHE_SCHEDULER=1
python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-VL-28B-A3B-Paddle \
--port 8180 \
diff --git a/docs/zh/usage/kunlunxin_xpu_deployment.md b/docs/zh/usage/kunlunxin_xpu_deployment.md
index 743775f07..721c429fa 100644
--- a/docs/zh/usage/kunlunxin_xpu_deployment.md
+++ b/docs/zh/usage/kunlunxin_xpu_deployment.md
@@ -5,6 +5,12 @@
|ERNIE-4.5-300B-A47B|32K|WINT4|4 (推荐)|export XPU_VISIBLE_DEVICES="0,1,2,3" or "4,5,6,7"<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-300B-A47B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 4 \<br> --max-model-len 32768 \<br> --max-num-seqs 64 \<br> --quantization "wint4" \<br> --gpu-memory-utilization 0.9|>=2.0.0|
|ERNIE-4.5-300B-A47B|32K|WINT4|8|export XPU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-300B-A47B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 8 \<br> --max-model-len 32768 \<br> --max-num-seqs 64 \<br> --quantization "wint4" \<br> --gpu-memory-utilization 0.9|>=2.0.0|
|ERNIE-4.5-300B-A47B|128K|WINT4|8 (推荐)|export XPU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-300B-A47B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 8 \<br> --max-model-len 131072 \<br> --max-num-seqs 64 \<br> --quantization "wint4" \<br> --gpu-memory-utilization 0.9|>=2.0.0|
+|ERNIE-4.5-21B-A3B|32K|BF16|1|export XPU_VISIBLE_DEVICES="0" # 指定任意一张卡<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-21B-A3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 32768 \<br> --max-num-seqs 128 \<br> --gpu-memory-utilization 0.9|>=2.1.0|
+|ERNIE-4.5-21B-A3B|32K|WINT8|1|export XPU_VISIBLE_DEVICES="0" # 指定任意一张卡<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-21B-A3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 32768 \<br> --max-num-seqs 128 \<br> --quantization "wint8" \<br> --gpu-memory-utilization 0.9|>=2.1.0|
+|ERNIE-4.5-21B-A3B|32K|WINT4|1|export XPU_VISIBLE_DEVICES="0" # 指定任意一张卡<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-21B-A3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 32768 \<br> --max-num-seqs 128 \<br> --quantization "wint4" \<br> --gpu-memory-utilization 0.9|>=2.1.0|
+|ERNIE-4.5-21B-A3B|128K|BF16|1|export XPU_VISIBLE_DEVICES="0" # 指定任意一张卡<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-21B-A3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 131072 \<br> --max-num-seqs 128 \<br> --gpu-memory-utilization 0.9|>=2.1.0|
+|ERNIE-4.5-21B-A3B|128K|WINT8|1|export XPU_VISIBLE_DEVICES="0" # 指定任意一张卡<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-21B-A3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 131072 \<br> --max-num-seqs 128 \<br> --quantization "wint8" \<br> --gpu-memory-utilization 0.9|>=2.1.0|
+|ERNIE-4.5-21B-A3B|128K|WINT4|1|export XPU_VISIBLE_DEVICES="0" # 指定任意一张卡<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-21B-A3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 131072 \<br> --max-num-seqs 128 \<br> --quantization "wint4" \<br> --gpu-memory-utilization 0.9|>=2.1.0|
|ERNIE-4.5-0.3B|32K|BF16|1|export XPU_VISIBLE_DEVICES="0" # 指定任意一张卡<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-0.3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 32768 \<br> --max-num-seqs 128 \<br> --gpu-memory-utilization 0.9|>=2.0.3|
|ERNIE-4.5-0.3B|32K|WINT8|1|export XPU_VISIBLE_DEVICES="x" # 指定任意一张卡<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-0.3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 32768 \<br> --max-num-seqs 128 \<br> --quantization "wint8" \<br> --gpu-memory-utilization 0.9|>=2.0.3|
|ERNIE-4.5-0.3B|128K|BF16|1|export XPU_VISIBLE_DEVICES="0" # 指定任意一张卡<br>python -m fastdeploy.entrypoints.openai.api_server \<br> --model PaddlePaddle/ERNIE-4.5-0.3B-Paddle \<br> --port 8188 \<br> --tensor-parallel-size 1 \<br> --max-model-len 131072 \<br> --max-num-seqs 128 \<br> --gpu-memory-utilization 0.9|>=2.0.3|