diff --git a/.clang-format b/.clang-format
new file mode 100644
index 000000000..3bb927623
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,29 @@
+# This file is used by clang-format to autoformat Paddle source code.
+#
+# clang-format is part of the LLVM toolchain.
+# LLVM and Clang need to be installed to format the source code.
+#
+# The basic usage is:
+#   clang-format -i -style=file PATH/TO/SOURCE/CODE
+#
+# -style=file implicitly uses the ".clang-format" file located in one of the
+# parent directories.
+# -i means the file is changed in place.
+#
+# The clang-format documentation is available at:
+#   http://clang.llvm.org/docs/ClangFormat.html
+#   http://clang.llvm.org/docs/ClangFormatStyleOptions.html
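+#
+# For example (illustrative only; adjust the path and extensions to the actual
+# repo layout), all C++/CUDA sources under custom_ops could be formatted with:
+#   find custom_ops -name '*.cc' -o -name '*.cu' | xargs clang-format -i -style=file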
+---
+Language: Cpp
+BasedOnStyle: Google
+IndentWidth: 4
+TabWidth: 2
+ContinuationIndentWidth: 4
+AccessModifierOffset: -1 # The private/protected/public has no indent in class
+Standard: Cpp11
+AllowAllParametersOfDeclarationOnNextLine: true
+BinPackParameters: false
+BinPackArguments: false
+IncludeBlocks: Preserve
+IncludeIsMainSourceRegex: (\.cu)$
+...
diff --git a/.gitignore b/.gitignore
index 35c771cf5..f94e8f7cc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -121,7 +121,7 @@ dmypy.json
FETCH_HEAD
#log
-log/
+log*/
checkpoints/
checkpoints_origin/
@@ -158,3 +158,7 @@ custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute
# buff
custom_ops/tmp*
+
+build
+
+.ccls-cache
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4b08b23db..faa05efbf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -16,7 +16,7 @@ repos:
rev: v0.11.7
hooks:
- id: ruff
- args: [--output-format, github, --fix]
+ args: [--output-format, github, --fix, --line-length=120]
# # Spell check
# - repo: https://github.com/codespell-project/codespell
# rev: v2.4.1
@@ -29,14 +29,15 @@ repos:
rev: 6.0.1
hooks:
- id: isort
-# Formatting
-- repo: https://github.com/pre-commit/mirrors-clang-format
- rev: v20.1.3
- hooks:
- - id: clang-format
- # exclude: '.*'
- types_or: [c++, cuda]
- args: [--style=file, --verbose]
+# # Formatting
+# - repo: https://github.com/pre-commit/mirrors-clang-format
+# rev: v20.1.3
+# hooks:
+# - id: clang-format
+# # exclude: '.*'
+# types_or: [c++, cuda]
+# args: [--style=file, --verbose]
+
# markdown
- repo: https://github.com/jackdewinter/pymarkdown
rev: v0.9.29
diff --git a/README.md b/README.md
index 86ebda86d..55963d04d 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,8 @@
-# FastDeploy 2.0: 大模型推理部署
-
-
-
-
+
+
+
+
@@ -11,105 +10,78 @@
-FastDeploy 2.0 supports inference for multiple large models (currently only Qwen2, with support for more models coming soon). Its inference and deployment features cover:
+
+ Installation | Quick Start | Supported Models
+
-- Deploy a model as a service with a single command, with streaming generation support
-- Accelerated model inference via tensor parallelism
-- Support for PagedAttention and continuous batching (dynamic batching)
-- Compatibility with the OpenAI HTTP protocol
-- Lossless weight-only int8/int4 compression
-- Support for Prometheus metrics
+--------------------------------------------------------------------------------
+# FastDeploy 2.0: Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
-> Note: If you are still using FastDeploy to deploy small models (such as PaddleClas/PaddleOCR CV suite models), please check out the [release/1.1.0 branch](https://github.com/PaddlePaddle/FastDeploy/tree/release/1.1.0).
+## News
-## Requirements
-- A800/H800/H100
-- Python>=3.10
-- CUDA>=12.3
-- CUDNN>=9.5
-- Linux X64
+**[2025-06] 🔥 Released FastDeploy v2.0:** Supports inference and deployment for ERNIE 4.5. Furthermore, we open-source an industrial-grade PD disaggregation solution with context caching and dynamic role switching, which improves resource utilization and further enhances inference performance for MoE models.
-## Installation
+## About
-### Docker installation (recommended)
-```
-docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy:2.0.0.0-alpha
-```
+**FastDeploy** is an inference and deployment toolkit for large language models and visual language models based on PaddlePaddle. It delivers **production-ready, out-of-the-box deployment solutions** with core acceleration technologies:
-### Install from source
-#### Install PaddlePaddle
-> Note: Install a nightly build; the code must be newer than 2025.05.30. See [PaddlePaddle installation](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html) and choose the CUDA 12.6 develop (Nightly build) version.
-```
-python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
-```
+- 🚀 **Load-Balanced PD Disaggregation**: Industrial-grade solution featuring context caching and dynamic instance role switching. Optimizes resource utilization while balancing SLO compliance and throughput.
+- 🔄 **Unified KV Cache Transmission**: Lightweight high-performance transport library with intelligent NVLink/RDMA selection.
+- 🤝 **OpenAI API Server and vLLM Compatible**: One-command deployment with [vLLM](https://github.com/vllm-project/vllm/) interface compatibility (see the example request below).
+- 🧮 **Comprehensive Quantization Format Support**: W8A16, W8A8, W4A16, W4A8, W2A16, FP8, and more.
+- ⏩ **Advanced Acceleration Techniques**: Speculative decoding, Multi-Token Prediction (MTP) and Chunked Prefill.
+- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU etc.
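+
+As a quick illustration (a sketch only; it assumes an API server is already running and listening on port 8188, so adjust host, port and payload to your own deployment), the OpenAI-compatible endpoint can be queried with a standard chat completions request:
+
+``` shell
+curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
+  -H "Content-Type: application/json" \
+  -d '{"messages": [{"role": "user", "content": "Hello, what is your name?"}]}'
+```
+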
-#### Build and install FastDeploy
+## Requirements
-```
-# Build
-cd FastDeploy
-bash build.sh
-# Install
-pip install dist/fastdeploy-2.0.0a0-py3-none-any.whl
-```
+- OS: Linux
+- Python: 3.10 ~ 3.12
-## Quick Start
+## Installation
-After installation, run the following commands to quickly deploy the Qwen2 model. For more parameters and their meanings, see the [parameter documentation](docs/serving.md).
+FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, and other hardware. For detailed installation instructions:
-``` shell
-# Download and extract the Qwen model
-wget https://fastdeploy.bj.bcebos.com/llm/models/Qwen2-7B-Instruct.tar.gz && tar xvf Qwen2-7B-Instruct.tar.gz
-# Deploy on a single GPU
-python -m fastdeploy.entrypoints.openai.api_server --model ./Qwen2-7B-Instruct --port 8188 --tensor-parallel-size 1
-```
+- [NVIDIA GPU](./docs/installation/nvidia_cuda.md)
+- [Kunlunxin XPU](./docs/en/get_started/installation/kunlunxin_xpu.md)
+- [Iluvatar GPU](./docs/en/get_started/installation/iluvatar_gpu.md)
+- [Enflame GCU](./docs/en/get_started/installation/Enflame_gcu.md)
-Send requests to the model service with the following command:
-``` shell
-curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
--H "Content-Type: application/json" \
--d '{
- "messages": [
- {"role": "user", "content": "你好,你的名字是什么?"}
- ]
-}'
-```
-The response looks like this:
-``` json
-{
- "id": "chatcmpl-db662f47-7c8c-4945-9a7a-db563b2ddd8d",
- "object": "chat.completion",
- "created": 1749451045,
- "model": "default",
- "choices": [
- {
- "index": 0,
- "message": {
- "role": "assistant",
- "content": "你好!我叫通义千问。",
- "reasoning_content": null
- },
- "finish_reason": "stop"
- }
- ],
- "usage": {
- "prompt_tokens": 25,
- "total_tokens": 35,
- "completion_tokens": 10,
- "prompt_tokens_details": null
- }
-}
-```
-FastDeploy provides a service API fully compatible with OpenAI (the `model` and `api_key` fields are currently not supported; any values set for them are ignored). Users can also send requests to the service via the openai Python API.
+**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU, Hygon DCU, and MetaX GPU are currently under development and testing. Stay tuned for updates!
-## Deployment documentation
-- [Local deployment](docs/offline_inference.md)
-- [Service deployment](docs/serving.md)
-- [Service metrics](docs/metrics.md)
+## Get Started
-# Code overview
-- [Code directory overview](docs/code_guide.md)
-- Suggestions and issues about using FastDeploy are welcome via GitHub issues.
+Learn how to use FastDeploy through our documentation:
+- [10-Minutes Quick Deployment](./docs/get_started/quick_start.md)
+- [ERNIE-4.5 Large Language Model Deployment](./docs/get_started/ernie-4.5.md)
+- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/get_started/ernie-4.5-vl.md)
+- [Offline Inference Development](./docs/offline_inference.md)
+- [Online Service Deployment](./docs/serving/README.md)
+- [Full Supported Models List](./docs/supported_models.md)
-# License
-FastDeploy is licensed under the [Apache-2.0 license](./LICENSE). During development, to align with the [vLLM](https://github.com/vllm-project/vllm) interfaces, parts of the vLLM code were referenced and used directly, for which we are grateful.
+## Supported Models
+
+| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
+|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
+|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅(WINT4/W4A8C8/Expert Parallelism)| ✅ | ✅|✅(WINT4)| WIP |128K |
+|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅(WINT4/Expert Parallelism)| ✅ | ✅|✅(WINT4)| ❌ | 128K |
+|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
+|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
+|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
+|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
+|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅| 128K |
+
+## Advanced Usage
+
+- [Quantization](./docs/quantization/README.md)
+- [PD Disaggregation Deployment](./docs/features/pd_disaggregation.md)
+- [Speculative Decoding](./docs/features/speculative_decoding.md)
+- [Prefix Caching](./docs/features/prefix_caching.md)
+- [Chunked Prefill](./docs/features/chunked_prefill.md)
+
+## Acknowledgement
+
+FastDeploy is licensed under the [Apache-2.0 open-source license](./LICENSE). During development, portions of [vLLM](https://github.com/vllm-project/vllm) code were referenced and incorporated to maintain interface compatibility, for which we express our gratitude.
diff --git a/benchmarks/README.md b/benchmarks/README.md
new file mode 100644
index 000000000..d7a7e5007
--- /dev/null
+++ b/benchmarks/README.md
@@ -0,0 +1,106 @@
+### FastDeploy Serving Performance Benchmark Tool
+
+#### Dataset
+
+Download the dataset locally with wget for performance testing. An example data line is shown after the table below.
+
+| Dataset | Data Path |
+| :--- | :--- |
+| Open-source dataset (2k samples) | https://fastdeploy.bj.bcebos.com/eb_query/filtered_sharedgpt_2000_input_1136_output_200_fd.json |
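+
+Each line of this FD-format dataset file is a standalone JSON object (JSONL). Judging from how `EBChatDataset` in `benchmark_dataset.py` parses it, an entry looks roughly like the following (values are illustrative):
+
+```
+{"messages": [{"role": "user", "content": "Example question ..."}], "max_tokens": 12288}
+```
+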
+#### Usage
+
+```
+# Install dependencies
+python -m pip install -r requirements.txt
+```
+
+##### Parameter description
+
+```bash
+--backend openai-chat: backend interface used for the benchmark; set to "openai-chat" to use the chat/completions endpoint
+--model EB45T: model name; it can be chosen freely and only affects the name of the saved result file
+--endpoint /v1/chat/completions: endpoint, used to build the request URL
+--host 0.0.0.0: service IP address, used to build the request URL
+--port 9812: service HTTP port, used to build the request URL
+--dataset-name EBChat: dataset class; set to "EBChat" to read datasets saved in FD format
+--dataset-path ./eb45t_spv4_dataserver_1w_waigua_fd: path to the benchmark dataset
+--hyperparameter-path EB45T.yaml: (optional) hyperparameter file; its contents are merged into the request payload (no hyperparameters are sent by default; a sketch is shown below)
+--percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len: set of metrics reported in the performance results
+--metric-percentiles 80,95,99,99.9,99.95,99.99: percentile values reported for each metric
+--num-prompts 1: total number of requests to send
+--max-concurrency 1: benchmark concurrency
+--save-result: enable saving results; results are written to a JSON file
+```
+
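+A hyperparameter file passed via `--hyperparameter-path` is merged verbatim into the request payload. The exact fields depend on your model; the sketch below only lists commonly used sampling fields as an illustrative assumption, not the actual contents of `yaml/request_yaml/eb45t-32k.yaml`:
+
+```
+temperature: 0.8
+top_p: 0.8
+max_tokens: 12288
+```
+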
+##### Single-request debugging against the /v1/chat/completions endpoint
+
+```
+python benchmark_serving.py \
+ --backend openai-chat \
+ --model EB45T \
+ --endpoint /v1/chat/completions \
+ --host 0.0.0.0 \
+ --port 9812 \
+ --dataset-name EBChat \
+ --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json \
+ --hyperparameter-path yaml/request_yaml/eb45t-32k.yaml \
+ --percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \
+ --metric-percentiles 80,95,99,99.9,99.95,99.99 \
+ --num-prompts 1 \
+ --max-concurrency 1 \
+ --save-result
+```
+
+##### Full benchmark against /v1/chat/completions: 100 concurrency, 2000 requests
+
+```
+# Save logs to infer_log.txt
+python benchmark_serving.py \
+ --backend openai-chat \
+ --model EB45T \
+ --endpoint /v1/chat/completions \
+ --host 0.0.0.0 \
+ --port 9812 \
+ --dataset-name EBChat \
+ --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json \
+ --hyperparameter-path yaml/request_yaml/eb45t-32k.yaml \
+ --percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \
+ --metric-percentiles 80,95,99,99.9,99.95,99.99 \
+ --num-prompts 2000 \
+ --max-concurrency 100 \
+ --save-result > infer_log.txt 2>&1 &
+```
+
+##### Benchmarking the /v1/completions endpoint
+
+Set the endpoint to /v1/completions and the backend to openai to benchmark the /v1/completions endpoint.
+
+```
+# Save logs to infer_log.txt
+python benchmark_serving.py \
+ --backend openai \
+ --model EB45T \
+ --endpoint /v1/completions \
+ --host 0.0.0.0 \
+ --port 9812 \
+ --dataset-name EBChat \
+ --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json \
+ --hyperparameter-path yaml/request_yaml/eb45t-32k.yaml \
+ --percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \
+ --metric-percentiles 80,95,99,99.9,99.95,99.99 \
+ --num-prompts 2000 \
+ --max-concurrency 100 \
+ --save-result > infer_log.txt 2>&1 &
+```
+
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
new file mode 100644
index 000000000..84b11d7a9
--- /dev/null
+++ b/benchmarks/backend_request_func.py
@@ -0,0 +1,700 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+# This file is modified from https://github.com/vllm-project/vllm/blob/main/benchmarks/backend_request_func.py
+
+
+import io
+import json
+import os
+import sys
+import time
+import traceback
+from dataclasses import dataclass, field
+from typing import Optional
+
+import aiohttp
+from tqdm.asyncio import tqdm
+
+
+AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
+
+
+@dataclass
+class RequestFuncInput:
+ """Input for requesting LLMs via API"""
+ prompt: str
+ history_QA: Optional[dict]
+ hyper_parameters: dict
+ api_url: str
+ prompt_len: int
+ output_len: int
+ model: str
+ model_name: Optional[str] = None
+ logprobs: Optional[int] = None
+ extra_body: Optional[dict] = None
+ multi_modal_content: Optional[dict] = None
+ ignore_eos: bool = False
+ language: Optional[str] = None
+
+
+@dataclass
+class RequestFuncOutput:
+ """Output for requesting LLMs via API"""
+ generated_text: str = ""
+ reasoning_content: str = ""
+ success: bool = False
+ latency: float = 0.0
+ output_tokens: int = 0
+ ttft: float = 0.0 # Time to first token
+ arrival_time: list = field(default_factory=list) # arrival_time
+ itl: list = field(default_factory=list) # list of inter-token latencies
+ tpot: float = 0.0 # avg next-token latencies
+ prompt_len: int = 0
+    prompt_tokens: int = 0  # number of input tokens reported by the inference side
+ error: str = ""
+
+
+async def async_request_eb_openai_chat_completions(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+ """Request an LLM using EB OpenAI"""
+ api_url = request_func_input.api_url
+ assert api_url.endswith(
+ ("completions", "profile")
+ ), "OpenAI Chat Completions API URL must end with 'completions'."
+
+ async with aiohttp.ClientSession(trust_env=True,
+ timeout=AIOHTTP_TIMEOUT) as session:
+ content = [{"type": "text", "text": request_func_input.prompt}]
+ if request_func_input.multi_modal_content:
+ content.append(request_func_input.multi_modal_content)
+ payload = {
+ "model": "default",
+ "messages": request_func_input.history_QA,
+ "stream": True,
+ "stream_options": {
+ "include_usage": True,
+ "continuous_usage_stats": True
+ },
+ }
+        # Hyperparameters are passed in from the YAML file
+ payload.update(request_func_input.hyper_parameters)
+
+ if request_func_input.ignore_eos:
+ payload["ignore_eos"] = request_func_input.ignore_eos
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+ }
+
+ output = RequestFuncOutput()
+ output.prompt_len = 0
+
+ ttft = 0.0
+ st = time.perf_counter()
+ most_recent_timestamp = st
+ try:
+ async with session.post(url=api_url, json=payload,
+ headers=headers) as response:
+ if response.status == 200:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ chunk = chunk_bytes.decode("utf-8").removeprefix(
+ "data: ")
+ if chunk != "[DONE]":
+ # print("####chunk:", chunk, type(chunk))
+ timestamp = time.perf_counter()
+ data = json.loads(chunk)
+
+ if choices := data.get("choices"):
+ content = choices[0]["delta"].get("content")
+ reason_content = choices[0]["delta"].get("reasoning_content")
+ # First token
+ if ttft == 0.0:
+ ttft = timestamp - st
+ output.ttft = ttft
+ # cached_tokens
+ output.prompt_len = data["usage"]["prompt_tokens_details"]["cached_tokens"]
+
+ # Decoding phase
+ else:
+ output.itl.append(timestamp -
+ most_recent_timestamp)
+
+ output.generated_text += content or ""
+ output.reasoning_content += reason_content or ""
+ output.arrival_time.append(choices[0].get("arrival_time"))
+ elif usage := data.get("usage"):
+ output.output_tokens = usage.get(
+ "completion_tokens")
+ output.prompt_tokens = usage.get(
+ "prompt_tokens")
+
+ most_recent_timestamp = timestamp
+
+ # output.generated_text = generated_text
+ if output.generated_text.strip() == "":
+ output.success = False
+ output.error = "No generated text found!"
+ else:
+ output.success = True
+ output.latency = most_recent_timestamp - st
+ else:
+ error_text = await response.text()
+ print("####error response:", error_text, "####payload:", payload)
+ output.error = error_text or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+    # Save results of failed requests
+ if not output.success:
+ with open("error_output.txt", "a") as f:
+ f.write(str(output) + "\n")
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def async_request_eb_openai_completions(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+ """Request an LLM using EB OpenAI"""
+ api_url = request_func_input.api_url
+ assert api_url.endswith(
+ ("completions", "profile")
+ ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
+
+ async with aiohttp.ClientSession(trust_env=True,
+ timeout=AIOHTTP_TIMEOUT) as session:
+ payload = {
+ "model": "default",
+ "prompt": request_func_input.prompt,
+ "stream": True,
+ "stream_options": {
+ "include_usage": True,
+ "continuous_usage_stats": True
+ },
+ }
+        # Hyperparameters are passed in from the YAML file
+ payload.update(request_func_input.hyper_parameters)
+
+ if request_func_input.ignore_eos:
+ payload["ignore_eos"] = request_func_input.ignore_eos
+ headers = {
+ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
+ }
+
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+
+ generated_text = ""
+ st = time.perf_counter()
+ most_recent_timestamp = st
+ try:
+ async with session.post(url=api_url, json=payload,
+ headers=headers) as response:
+ if response.status == 200:
+ first_chunk_received = False
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ chunk = chunk_bytes.decode("utf-8").removeprefix(
+ "data: ")
+ if chunk != "[DONE]":
+ # print("####chunk:", chunk, chunk.usage)
+ data = json.loads(chunk)
+
+ # NOTE: Some completion API might have a last
+ # usage summary response without a token so we
+ # want to check a token was generated
+ if choices := data.get("choices"):
+ # Note that text could be empty here
+ # e.g. for special tokens
+ text = choices[0].get("text")
+ timestamp = time.perf_counter()
+ # First token
+ if not first_chunk_received:
+ first_chunk_received = True
+ ttft = time.perf_counter() - st
+ output.ttft = ttft
+
+ # Decoding phase
+ else:
+ output.itl.append(timestamp -
+ most_recent_timestamp)
+
+ most_recent_timestamp = timestamp
+ output.arrival_time.append(choices[0].get("arrival_time"))
+ generated_text += text or ""
+ elif usage := data.get("usage"):
+ output.prompt_tokens = usage.get(
+ "prompt_tokens")
+ output.output_tokens = usage.get(
+ "completion_tokens")
+ if first_chunk_received:
+ output.success = True
+ else:
+ output.success = False
+ output.error = (
+ "Never received a valid chunk to calculate TTFT."
+ "This response will be marked as failed!")
+ output.generated_text = generated_text
+ output.latency = most_recent_timestamp - st
+ else:
+ output.error = response.reason or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def async_request_tgi(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+ """Request an LLM using the TGI API"""
+ api_url = request_func_input.api_url
+ assert api_url.endswith("generate_stream")
+
+ async with aiohttp.ClientSession(trust_env=True,
+ timeout=AIOHTTP_TIMEOUT) as session:
+ params = {
+ "max_new_tokens": request_func_input.output_len,
+ "do_sample": True,
+ "temperature": 0.01, # TGI does not accept 0.0 temperature.
+ "top_p": 0.99, # TGI does not accept 1.0 top_p.
+ "truncate": request_func_input.prompt_len,
+ "ignore_eos_token": request_func_input.ignore_eos,
+ }
+ payload = {
+ "inputs": request_func_input.prompt,
+ "parameters": params,
+ }
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+ if request_func_input.ignore_eos:
+ output.output_tokens = request_func_input.output_len
+ else:
+ output.output_tokens = None
+
+ ttft = 0.0
+ st = time.perf_counter()
+ most_recent_timestamp = st
+ try:
+ async with session.post(url=api_url, json=payload) as response:
+ if response.status == 200:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+ chunk_bytes = chunk_bytes.decode("utf-8")
+
+ # NOTE: Sometimes TGI returns a ping response without
+ # any data, we should skip it.
+ if chunk_bytes.startswith(":"):
+ continue
+ chunk = chunk_bytes.removeprefix("data:")
+
+ data = json.loads(chunk)
+ timestamp = time.perf_counter()
+ # First token
+ if ttft == 0.0:
+ ttft = time.perf_counter() - st
+ output.ttft = ttft
+
+ # Decoding phase
+ else:
+ output.itl.append(timestamp -
+ most_recent_timestamp)
+
+ most_recent_timestamp = timestamp
+ output.arrival_time.append(data["arrival_time"])
+
+ output.latency = most_recent_timestamp - st
+ output.success = True
+ output.generated_text = data["generated_text"]
+ else:
+ output.error = response.reason or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def async_request_trt_llm(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+ """Request an LLM using TRT's llm_server"""
+ api_url = request_func_input.api_url
+ assert api_url.endswith("generate_stream")
+
+ async with aiohttp.ClientSession(trust_env=True,
+ timeout=AIOHTTP_TIMEOUT) as session:
+ payload = {
+ "accumulate_tokens": True,
+ "text_input": request_func_input.prompt,
+ "temperature": 0.0,
+ "top_p": 1.0,
+ "max_tokens": request_func_input.output_len,
+ "stream": True,
+ }
+ if request_func_input.ignore_eos:
+ payload["min_length"] = request_func_input.output_len
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+
+ ttft = 0.0
+ st = time.perf_counter()
+ most_recent_timestamp = st
+ try:
+ async with session.post(url=api_url, json=payload) as response:
+ if response.status == 200:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ chunk = chunk_bytes.decode("utf-8").removeprefix(
+ "data:")
+
+ data = json.loads(chunk)
+ output.generated_text += data["text_output"]
+ timestamp = time.perf_counter()
+ # First token
+ if ttft == 0.0:
+ ttft = timestamp - st
+ output.ttft = ttft
+
+ # Decoding phase
+ else:
+ output.itl.append(timestamp -
+ most_recent_timestamp)
+
+ most_recent_timestamp = timestamp
+
+ output.latency = most_recent_timestamp - st
+ output.success = True
+
+ else:
+ output.error = response.reason or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def async_request_deepspeed_mii(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+ """Request an LLM using Deepspeed MII"""
+ async with aiohttp.ClientSession(trust_env=True,
+ timeout=AIOHTTP_TIMEOUT) as session:
+
+ payload = {
+ "prompt": request_func_input.prompt,
+ "max_tokens": request_func_input.output_len,
+ "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
+ "top_p": 1.0,
+ }
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+
+ # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
+ # will use 0 as placeholder.
+ # See https://github.com/microsoft/DeepSpeed-MII/pull/311
+ output.ttft = 0
+
+ st = time.perf_counter()
+ try:
+ async with session.post(url=request_func_input.api_url,
+ json=payload) as response:
+ if response.status == 200:
+ parsed_resp = await response.json()
+ output.latency = time.perf_counter() - st
+ if "choices" in parsed_resp:
+ output.generated_text = parsed_resp["choices"][0][
+ "text"]
+ elif "text" in parsed_resp:
+ output.generated_text = parsed_resp["text"][0]
+ else:
+ output.error = ("Unexpected response format: "
+ "neither 'choices' nor 'text' found")
+ output.success = False
+ output.success = True
+ else:
+ output.error = response.reason or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def async_request_openai_completions(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+ """Request an LLM using OpenAI"""
+ api_url = request_func_input.api_url
+ assert api_url.endswith(
+ ("completions", "profile")
+ ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
+
+ async with aiohttp.ClientSession(trust_env=True,
+ timeout=AIOHTTP_TIMEOUT) as session:
+ payload = {
+ "model": request_func_input.model_name \
+ if request_func_input.model_name else request_func_input.model,
+ "prompt": request_func_input.prompt,
+ # "temperature": 0.0,
+ "max_tokens": request_func_input.output_len,
+ "logprobs": request_func_input.logprobs,
+ "stream": True,
+ #"stream_options": {
+ # "include_usage": True,
+ #},
+ }
+ if request_func_input.ignore_eos:
+ payload["ignore_eos"] = request_func_input.ignore_eos
+
+ headers = {
+ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
+ }
+
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+
+ generated_text = ""
+ st = time.perf_counter()
+ most_recent_timestamp = st
+ try:
+ async with session.post(url=api_url, json=payload,
+ headers=headers) as response:
+ if response.status == 200:
+ first_chunk_received = False
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ chunk = chunk_bytes.decode("utf-8").removeprefix(
+ "data: ")
+ if chunk != "[DONE]":
+ # print("####chunk:", chunk, type(chunk))
+ data = json.loads(chunk)
+
+ # NOTE: Some completion API might have a last
+ # usage summary response without a token so we
+ # want to check a token was generated
+ if choices := data.get("choices"):
+ # Note that text could be empty here
+ # e.g. for special tokens
+ text = choices[0].get("text")
+ timestamp = time.perf_counter()
+ # First token
+ if not first_chunk_received:
+ first_chunk_received = True
+ ttft = time.perf_counter() - st
+ output.ttft = ttft
+
+ # Decoding phase
+ else:
+ output.itl.append(timestamp -
+ most_recent_timestamp)
+
+ most_recent_timestamp = timestamp
+ generated_text += text or ""
+ elif usage := data.get("usage"):
+ output.output_tokens = usage.get(
+ "completion_tokens")
+ if first_chunk_received:
+ output.success = True
+ else:
+ output.success = False
+ output.error = (
+ "Never received a valid chunk to calculate TTFT."
+ "This response will be marked as failed!")
+ output.generated_text = generated_text
+ output.latency = most_recent_timestamp - st
+ else:
+ output.error = response.reason or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def async_request_openai_audio(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+ """Request an LLM using OpenAI"""
+ # Lazy import without PlaceholderModule to avoid vllm dep.
+ import soundfile
+ api_url = request_func_input.api_url
+    assert api_url.endswith(
+        ("transcriptions", "translations")
+    ), ("OpenAI Audio API URL must end with 'transcriptions' "
+        "or 'translations'.")
+
+ async with aiohttp.ClientSession(trust_env=True,
+ timeout=AIOHTTP_TIMEOUT) as session:
+ content = [{"type": "text", "text": request_func_input.prompt}]
+ payload = {
+ "model": request_func_input.model_name \
+ if request_func_input.model_name else request_func_input.model,
+ "temperature": 0.0,
+ "max_completion_tokens": request_func_input.output_len,
+ "stream": True,
+ "language": "en",
+ # Flattened due to multipart/form-data
+ "stream_include_usage": True,
+ "stream_continuous_usage_stats": True
+ }
+ if request_func_input.extra_body:
+ payload.update(request_func_input.extra_body)
+ headers = {
+ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+ }
+
+ # Send audio file
+ def to_bytes(y, sr):
+ buffer = io.BytesIO()
+ soundfile.write(buffer, y, sr, format="WAV")
+ buffer.seek(0)
+ return buffer
+
+ with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
+ form = aiohttp.FormData()
+ form.add_field('file', f, content_type='audio/wav')
+ for key, value in payload.items():
+ form.add_field(key, str(value))
+
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+
+ generated_text = ""
+ ttft = 0.0
+ st = time.perf_counter()
+ most_recent_timestamp = st
+ try:
+ async with session.post(url=api_url,
+ data=form,
+ headers=headers) as response:
+ if response.status == 200:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ chunk = chunk_bytes.decode("utf-8").removeprefix(
+ "data: ")
+ if chunk != "[DONE]":
+ timestamp = time.perf_counter()
+ data = json.loads(chunk)
+
+ if choices := data.get("choices"):
+ content = choices[0]["delta"].get(
+ "content")
+ # First token
+ if ttft == 0.0:
+ ttft = timestamp - st
+ output.ttft = ttft
+
+ # Decoding phase
+ else:
+ output.itl.append(
+ timestamp - most_recent_timestamp)
+
+ generated_text += content or ""
+ elif usage := data.get("usage"):
+ output.output_tokens = usage.get(
+ "completion_tokens")
+
+ most_recent_timestamp = timestamp
+
+ output.generated_text = generated_text
+ output.success = True
+ output.latency = most_recent_timestamp - st
+ else:
+ output.error = response.reason or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+ASYNC_REQUEST_FUNCS = {
+ "tgi": async_request_tgi,
+ "vllm": async_request_openai_completions,
+ "lmdeploy": async_request_openai_completions,
+ "deepspeed-mii": async_request_deepspeed_mii,
+ "openai": async_request_eb_openai_completions,
+ "openai-chat": async_request_eb_openai_chat_completions,
+ "openai-audio": async_request_openai_audio,
+ "tensorrt-llm": async_request_trt_llm,
+ "scalellm": async_request_openai_completions,
+ "sglang": async_request_openai_completions,
+}
+
+OPENAI_COMPATIBLE_BACKENDS = [
+ k for k, v in ASYNC_REQUEST_FUNCS.items()
+ if v in (async_request_openai_completions,
+ async_request_eb_openai_chat_completions)
+]
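+
+# Usage sketch (illustrative; not invoked by the benchmark itself, which drives
+# these coroutines from benchmark_serving.py):
+#
+#   import asyncio
+#   request_func = ASYNC_REQUEST_FUNCS["openai-chat"]
+#   req = RequestFuncInput(
+#       prompt="Hello", history_QA=[{"role": "user", "content": "Hello"}],
+#       hyper_parameters={}, api_url="http://0.0.0.0:9812/v1/chat/completions",
+#       prompt_len=0, output_len=128, model="default")
+#   output = asyncio.run(request_func(request_func_input=req))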
+
diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
new file mode 100644
index 000000000..2d8bcca34
--- /dev/null
+++ b/benchmarks/benchmark_dataset.py
@@ -0,0 +1,309 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+# This file is modified from https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_dataset.py
+
+
+import base64
+import io
+import json
+import logging
+import random
+from abc import ABC, abstractmethod
+from collections.abc import Mapping
+from dataclasses import dataclass
+from io import BytesIO
+from typing import Any, Callable, Optional, Union
+from PIL import Image
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SampleRequest:
+ """
+ Represents a single inference request for benchmarking.
+ """
+
+    prompt: Union[str, Any]
+    history_QA: Union[str, Any]
+    prompt_len: int
+    expected_output_len: int
+    json_data: Optional[dict] = None  # optional; EBDataset does not provide it
+
+
+class BenchmarkDataset(ABC):
+ """BenchmarkDataset"""
+ DEFAULT_SEED = 0
+ IS_MULTIMODAL = False
+
+ def __init__(
+ self,
+ dataset_path: Optional[str] = None,
+ random_seed: int = DEFAULT_SEED,
+ hyperparameter_path: Optional[str] = None,
+ ) -> None:
+ """
+        Initialize the BenchmarkDataset with an optional dataset path and random
+        seed.
+
+        Args:
+            dataset_path (Optional[str]): Path to the dataset. If None, it
+                indicates that a default or random dataset might be used.
+            random_seed (int): Seed value for reproducible shuffling or
+                sampling. Defaults to DEFAULT_SEED.
+ """
+ self.dataset_path = dataset_path
+ # Set the random seed, ensuring that a None value is replaced with the
+ # default seed.
+ self.random_seed = (random_seed
+ if random_seed is not None else self.DEFAULT_SEED)
+ self.data = None
+ self.hyperparameter_path = hyperparameter_path
+ self.hyperparameters = {}
+
+ def load_data(self) -> None:
+ """
+ Load data from the dataset path into self.data.
+
+ This method must be overridden by subclasses since the method to load
+ data will vary depending on the dataset format and source.
+
+ Raises:
+ NotImplementedError: If a subclass does not implement this method.
+ """
+ # TODO (jenniferzhao): add support for downloading data
+ raise NotImplementedError(
+ "load_data must be implemented in subclasses.")
+
+ @abstractmethod
+ def sample(self, num_requests: int) -> list[SampleRequest]:
+ """
+ Abstract method to generate sample requests from the dataset.
+
+ Subclasses must override this method to implement dataset-specific logic
+ for generating a list of SampleRequest objects.
+
+ Args:
+ num_requests (int): The number of sample requests to generate.
+
+ Returns:
+ list[SampleRequest]: A list of sample requests generated from the
+ dataset.
+ """
+ raise NotImplementedError("sample must be implemented in subclasses.")
+
+ def maybe_oversample_requests(self, requests: list[SampleRequest],
+ num_requests: int) -> None:
+ """
+ Oversamples the list of requests if its size is less than the desired
+ number.
+
+ Args:
+            requests (List[SampleRequest]): The current list of sampled
+                requests.
+            num_requests (int): The target number of requests.
+ """
+ if len(requests) < num_requests:
+ random.seed(self.random_seed)
+ additional = random.choices(requests,
+ k=num_requests - len(requests))
+ requests.extend(additional)
+ logger.info("Oversampled requests to reach %d total samples.",
+ num_requests)
+
+
+def is_valid_sequence(
+ prompt_len: int,
+ output_len: int,
+ min_len: int = 4,
+ max_prompt_len: int = 1024,
+ max_total_len: int = 2048,
+ skip_min_output_len_check: bool = False,
+) -> bool:
+ """
+ Validate a sequence based on prompt and output lengths.
+
+ Default pruning criteria are copied from the original `sample_hf_requests`
+ and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as
+ from `sample_requests` in benchmark_throughput.py.
+ """
+ # Check for invalid conditions
+ prompt_too_short = prompt_len < min_len
+ output_too_short = (not skip_min_output_len_check) and (output_len
+ < min_len)
+ prompt_too_long = prompt_len > max_prompt_len
+ combined_too_long = (prompt_len + output_len) > max_total_len
+
+ # Return True if none of the invalid conditions are met
+ return not (prompt_too_short or output_too_short or prompt_too_long
+ or combined_too_long)
+
+
+def process_image(image: Any) -> Mapping[str, Any]:
+ """
+ Process a single image input and return a multimedia content dictionary.
+
+ Supports three input types:
+
+ 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key
+ containing raw image data. - Loads the bytes as a PIL.Image.Image.
+
+ 2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as
+ a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns
+ a dictionary with the image as a base64 data URL.
+
+ 3. String input: - Treats the string as a URL or local file path. -
+ Prepends "file://" if the string doesn't start with "http://" or
+ "file://". - Returns a dictionary with the image URL.
+
+ Raises:
+ ValueError: If the input is not a supported type.
+ """
+ if isinstance(image, dict) and 'bytes' in image:
+ image = Image.open(BytesIO(image['bytes']))
+ if isinstance(image, Image.Image):
+ image = image.convert("RGB")
+ with io.BytesIO() as image_data:
+ image.save(image_data, format="JPEG")
+ image_base64 = base64.b64encode(
+ image_data.getvalue()).decode("utf-8")
+ return {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/jpeg;base64,{image_base64}"
+ },
+ }
+
+ if isinstance(image, str):
+ image_url = (image if image.startswith(
+ ("http://", "file://")) else f"file://{image}")
+ return {"type": "image_url", "image_url": {"url": image_url}}
+
+ raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
+ " or str or dictionary with raw image bytes.")
+
+
+class EBDataset(BenchmarkDataset):
+ """
+    Implements the EB dataset (FD-format JSONL, e.g. converted ShareGPT data).
+    Loads data from a JSONL file and generates one sample request per entry.
+ """
+
+ temperature: float
+ repetition_penalty: float
+ frequency_penalty: float
+ presence_penalty: float
+ top_p: float
+ prompt_len: int
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.load_data()
+
+ def load_data(self) -> None:
+ if self.dataset_path is None:
+ raise ValueError("dataset_path must be provided for loading data.")
+
+ with open(self.dataset_path, encoding="utf-8") as f:
+ self.data = [json.loads(i.strip()) for i in f.readlines()]
+
+ def sample(
+ self,
+ num_requests: int,
+ lora_path: Optional[str] = None,
+ max_loras: Optional[int] = None,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ **kwargs,
+ ) -> list:
+ samples: list = []
+ for entry in self.data:
+ if len(samples) >= num_requests:
+ break
+ prompt = entry["text"]
+ self.temperature = float(entry["temperature"])
+ self.repetition_penalty = float(entry["penalty_score"])
+ self.frequency_penalty = float(entry["frequency_score"])
+ self.presence_penalty = float(entry["presence_score"])
+ self.top_p = float(entry["topp"])
+ self.prompt_len = int(entry["input_token_num"])
+ new_output_len = int(entry["max_dec_len"])
+
+ if enable_multimodal_chat:
+ prompt = self.apply_multimodal_chat_transformation(
+ prompt, None)
+ samples.append(
+ SampleRequest(
+ prompt=prompt,
+ prompt_len=self.prompt_len,
+ history_QA=[],
+ expected_output_len=new_output_len,
+ ))
+
+ self.maybe_oversample_requests(samples, num_requests)
+ return samples
+
+
+class EBChatDataset(BenchmarkDataset):
+ """
+    Implements the EB chat dataset (FD-format JSONL with OpenAI-style messages).
+    Loads data from a JSONL file and generates sample requests based on
+    conversation turns.
+ """
+ prompt_len: int
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.load_data()
+
+ def load_data(self) -> None:
+ if self.dataset_path is None:
+ raise ValueError("dataset_path must be provided for loading data.")
+
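+        # Each line of the file is expected to be a standalone JSON object
+        # (JSONL); sample() reads the OpenAI-style "messages" list and an
+        # optional "max_tokens" field from every entry.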
+ with open(self.dataset_path, encoding="utf-8") as f:
+ self.data = [json.loads(i.strip()) for i in f.readlines()]
+
+ def sample(
+ self,
+ num_requests: int,
+ lora_path: Optional[str] = None,
+ max_loras: Optional[int] = None,
+ output_len: Optional[int] = None,
+ enable_multimodal_chat: bool = False,
+ **kwargs,
+ ) -> list:
+ samples: list = []
+ for entry in self.data:
+ if len(samples) >= num_requests:
+ break
+ json_data = entry
+ prompt = entry["messages"][-1].get("content", "")
+ history_QA = entry.get("messages", [])
+ new_output_len = int(entry.get("max_tokens", 12288))
+
+ if enable_multimodal_chat:
+ prompt = self.apply_multimodal_chat_transformation(
+ prompt, None)
+ samples.append(
+ SampleRequest(
+ json_data=json_data,
+ prompt=prompt,
+ prompt_len=0,
+ history_QA=history_QA,
+ expected_output_len=new_output_len,
+ ))
+
+ self.maybe_oversample_requests(samples, num_requests)
+ return samples
+
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
new file mode 100644
index 000000000..924f96ad4
--- /dev/null
+++ b/benchmarks/benchmark_serving.py
@@ -0,0 +1,1141 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+# This file is modified from https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_serving.py
+
+
+import argparse
+import asyncio
+import gc
+import json
+import os
+import random
+import time
+import warnings
+import yaml
+from collections.abc import AsyncGenerator, Iterable
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, Optional
+
+import numpy as np
+from backend_request_func import (ASYNC_REQUEST_FUNCS,
+ OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
+ RequestFuncOutput)
+from tqdm.asyncio import tqdm
+
+from argparse import ArgumentParser as FlexibleArgumentParser
+
+from benchmark_dataset import (SampleRequest, EBDataset, EBChatDataset)
+from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
+
+MILLISECONDS_TO_SECONDS_CONVERSION = 1000
+
+
+@dataclass
+class BenchmarkMetrics:
+ """Class containing all metrics that are used in this script"""
+ completed: int
+ total_input: int
+ total_output: int
+ request_throughput: float
+ request_goodput: float
+ output_throughput: float
+ total_token_throughput: float
+ mean_s_decode: float
+ median_s_decode: float
+ std_s_decode: float
+ percentiles_s_decode: list[tuple[float, float]]
+ mean_ttft_ms: float
+ median_ttft_ms: float
+ std_ttft_ms: float
+ percentiles_ttft_ms: list[tuple[float, float]]
+ mean_s_ttft_ms: float
+ median_s_ttft_ms: float
+ std_s_ttft_ms: float
+ percentiles_s_ttft_ms: list[tuple[float, float]]
+ mean_tpot_ms: float
+ median_tpot_ms: float
+ std_tpot_ms: float
+ percentiles_tpot_ms: list[tuple[float, float]]
+ mean_itl_ms: float
+ median_itl_ms: float
+ std_itl_ms: float
+ percentiles_itl_ms: list[tuple[float, float]]
+ mean_s_itl_ms: float
+ median_s_itl_ms: float
+ std_s_itl_ms: float
+ percentiles_s_itl_ms: list[tuple[float, float]]
+ # E2EL stands for end-to-end latency per request.
+ # It is the time taken on the client side from sending
+ # a request to receiving a complete response.
+ mean_e2el_ms: float
+ median_e2el_ms: float
+ std_e2el_ms: float
+ percentiles_e2el_ms: list[tuple[float, float]]
+ mean_s_e2el_ms: float
+ median_s_e2el_ms: float
+ std_s_e2el_ms: float
+ percentiles_s_e2el_ms: list[tuple[float, float]]
+ mean_input_len: float
+ median_input_len: float
+ std_input_len: float
+ percentiles_input_len: list[tuple[float, float]]
+ mean_s_input_len: float
+ median_s_input_len: float
+ std_s_input_len: float
+ percentiles_s_input_len: list[tuple[float, float]]
+ mean_output_len: float
+ median_output_len: float
+ std_output_len: float
+ percentiles_output_len: list[tuple[float, float]]
+
+
+async def get_request(
+ input_requests: list[SampleRequest],
+ request_rate: float,
+ burstiness: float = 1.0,
+) -> AsyncGenerator[SampleRequest, None]:
+ """
+ Asynchronously generates requests at a specified rate
+ with OPTIONAL burstiness.
+
+ Args:
+ input_requests:
+ A list of input requests, each represented as a SampleRequest.
+ request_rate:
+ The rate at which requests are generated (requests/s).
+ burstiness (optional):
+ The burstiness factor of the request generation.
+ Only takes effect when request_rate is not inf.
+ Default value is 1, which follows a Poisson process.
+ Otherwise, the request intervals follow a gamma distribution.
+ A lower burstiness value (0 < burstiness < 1) results
+ in more bursty requests, while a higher burstiness value
+ (burstiness > 1) results in a more uniform arrival of requests.
+ """
+ input_requests: Iterable[SampleRequest] = iter(input_requests)
+
+ # Calculate scale parameter theta to maintain the desired request_rate.
+ assert burstiness > 0, (
+ f"A positive burstiness factor is expected, but given {burstiness}.")
+ theta = 1.0 / (request_rate * burstiness)
+
+ for request in input_requests:
+ yield request
+
+ if request_rate == float("inf"):
+ # If the request rate is infinity, then we don't need to wait.
+ continue
+
+ # Sample the request interval from the gamma distribution.
+ # If burstiness is 1, it follows exponential distribution.
+ interval = np.random.gamma(shape=burstiness, scale=theta)
+ # The next request will be sent after the interval.
+ await asyncio.sleep(interval)
+
+
+def calculate_metrics(
+ input_requests: list[SampleRequest],
+ outputs: list[RequestFuncOutput],
+ dur_s: float,
+ selected_percentiles: list[float],
+ goodput_config_dict: dict[str, float],
+) -> tuple[BenchmarkMetrics, list[int]]:
+ """Calculates various performance metrics based on the inputs and outputs."""
+ input_lens: list[int] = []
+    infer_input_lens: list[int] = []  # input token counts reported by the inference side
+ actual_output_lens: list[int] = []
+ total_input = 0
+ completed = 0
+ good_completed = 0
+ itls: list[float] = []
+ s_itls: list[float] = []
+ tpots: list[float] = []
+ all_tpots: list[float] = []
+ ttfts: list[float] = []
+ s_ttfts: list[float] = []
+ e2els: list[float] = []
+ s_e2els: list[float] = []
+ s_decodes: list[float] = []
+ for i in range(len(outputs)):
+ if outputs[i].success:
+ output_len = outputs[i].output_tokens
+
+            if not output_len:
+                print("no output_len")
+                # Guard against a missing completion_tokens field from the backend
+                output_len = 0
+ # We use the tokenizer to count the number of output tokens
+ # for some serving backends instead of looking at
+ # len(outputs[i].itl) since multiple output tokens may be
+ # bundled together
+ # Note : this may inflate the output token count slightly
+
+ actual_output_lens.append(output_len)
+ input_lens.append(outputs[i].prompt_len)
+ infer_input_lens.append(outputs[i].prompt_tokens)
+ total_input += outputs[i].prompt_tokens
+ tpot = 0
+ if output_len > 1:
+ latency_minus_ttft = outputs[i].latency - outputs[i].ttft
+ tpot = latency_minus_ttft / (output_len - 1)
+ tpots.append(tpot)
+ # Note: if output_len <= 1, we regard tpot as 0 for goodput
+ all_tpots.append(tpot)
+ itls += outputs[i].itl
+            # inference-side ITL (from arrival_time)
+ s_a = outputs[i].arrival_time[1:]
+ for j in range(len(s_a) - 2):
+ s_itls.append(s_a[j + 1] - s_a[j])
+ ttfts.append(outputs[i].ttft)
+            # inference-side TTFT
+ s_ttfts.append(outputs[i].arrival_time[1])
+ e2els.append(outputs[i].latency)
+            # inference-side end-to-end latency
+ s_e2els.append(outputs[i].arrival_time[-1])
+            # decode speed, excluding the first token
+ if len(outputs[i].arrival_time) > 2:
+ s_decodes.append((outputs[i].output_tokens - 1) /
+ (outputs[i].arrival_time[-1] - outputs[i].arrival_time[1]))
+ completed += 1
+ else:
+ actual_output_lens.append(0)
+ input_lens.append(0)
+ infer_input_lens.append(0)
+
+ if goodput_config_dict:
+ valid_metrics = []
+ slo_values = []
+
+ if "ttft" in goodput_config_dict:
+ valid_metrics.append(ttfts)
+ slo_values.append(goodput_config_dict["ttft"] /
+ MILLISECONDS_TO_SECONDS_CONVERSION)
+ if "tpot" in goodput_config_dict:
+ valid_metrics.append(all_tpots)
+ slo_values.append(goodput_config_dict["tpot"] /
+ MILLISECONDS_TO_SECONDS_CONVERSION)
+ if "e2el" in goodput_config_dict:
+ valid_metrics.append(e2els)
+ slo_values.append(goodput_config_dict["e2el"] /
+ MILLISECONDS_TO_SECONDS_CONVERSION)
+
+ for req_metric in zip(*valid_metrics):
+ is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
+ if is_good_req:
+ good_completed += 1
+
+ if completed == 0:
+ warnings.warn(
+ "All requests failed. This is likely due to a misconfiguration "
+ "on the benchmark arguments.",
+ stacklevel=2)
+ metrics = BenchmarkMetrics(
+ completed=completed,
+ total_input=total_input,
+ total_output=sum(actual_output_lens),
+ request_throughput=completed / dur_s,
+ request_goodput=good_completed / dur_s,
+ output_throughput=sum(actual_output_lens) / dur_s,
+ total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
+ mean_s_decode=np.mean(s_decodes or 0) *
+        1,  # s_decodes is empty if streaming is not supported by the backend
+ std_s_decode=np.std(s_decodes or 0) * 1,
+ median_s_decode=np.median(s_decodes or 0) * 1,
+ percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1)
+ for p in selected_percentiles],
+ mean_ttft_ms=np.mean(ttfts or 0) *
+ 1000, # ttfts is empty if streaming is not supported by backend
+ std_ttft_ms=np.std(ttfts or 0) * 1000,
+ median_ttft_ms=np.median(ttfts or 0) * 1000,
+ percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_s_ttft_ms=np.mean(s_ttfts or 0) *
+ 1000, # ttfts is empty if streaming is not supported by backend
+ std_s_ttft_ms=np.std(s_ttfts or 0) * 1000,
+ median_s_ttft_ms=np.median(s_ttfts or 0) * 1000,
+ percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_tpot_ms=np.mean(tpots or 0) * 1000,
+ std_tpot_ms=np.std(tpots or 0) * 1000,
+ median_tpot_ms=np.median(tpots or 0) * 1000,
+ percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_itl_ms=np.mean(itls or 0) * 1000,
+ std_itl_ms=np.std(itls or 0) * 1000,
+ median_itl_ms=np.median(itls or 0) * 1000,
+ percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_s_itl_ms=np.mean(s_itls or 0) * 1000,
+ std_s_itl_ms=np.std(s_itls or 0) * 1000,
+ median_s_itl_ms=np.median(s_itls or 0) * 1000,
+ percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_e2el_ms=np.mean(e2els or 0) * 1000,
+ std_e2el_ms=np.std(e2els or 0) * 1000,
+ median_e2el_ms=np.median(e2els or 0) * 1000,
+ percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_s_e2el_ms=np.mean(s_e2els or 0) * 1000,
+ std_s_e2el_ms=np.std(s_e2els or 0) * 1000,
+ median_s_e2el_ms=np.median(s_e2els or 0) * 1000,
+ percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_input_len=np.mean(input_lens or 0) * 1,
+ std_input_len=np.std(input_lens or 0) * 1,
+ median_input_len=np.median(input_lens or 0) * 1,
+ percentiles_input_len=[(p, np.percentile(input_lens or 0, p))
+ for p in selected_percentiles],
+ mean_s_input_len=np.mean(infer_input_lens or 0) * 1,
+ std_s_input_len=np.std(infer_input_lens or 0) * 1,
+ median_s_input_len=np.median(infer_input_lens or 0) * 1,
+ percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p))
+ for p in selected_percentiles],
+ mean_output_len=np.mean(actual_output_lens or 0) * 1,
+ std_output_len=np.std(actual_output_lens or 0) * 1,
+ median_output_len=np.median(actual_output_lens or 0) * 1,
+ percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p))
+ for p in selected_percentiles],
+ )
+
+ return metrics, actual_output_lens
+
+
+async def benchmark(
+ backend: str,
+ api_url: str,
+ base_url: str,
+ model_id: str,
+ model_name: str,
+ input_requests: list[SampleRequest],
+ hyper_parameters: dict,
+ logprobs: Optional[int],
+ request_rate: float,
+ burstiness: float,
+ disable_tqdm: bool,
+ profile: bool,
+ selected_percentile_metrics: list[str],
+ selected_percentiles: list[float],
+ ignore_eos: bool,
+ goodput_config_dict: dict[str, float],
+ max_concurrency: Optional[int],
+ lora_modules: Optional[Iterable[str]],
+ extra_body: Optional[dict],
+):
+ """Benchmarks an API endpoint using a given set of sample inputs and returns"""
+ if backend in ASYNC_REQUEST_FUNCS:
+ request_func = ASYNC_REQUEST_FUNCS[backend]
+ else:
+ raise ValueError(f"Unknown backend: {backend}")
+
+ print("Starting initial single prompt test run...")
+ test_prompt, test_output_len = \
+ input_requests[0].prompt, \
+ input_requests[0].expected_output_len
+ test_history_QA = input_requests[0].history_QA
+
+ test_input = RequestFuncInput(
+ model=model_id,
+ model_name=model_name,
+ prompt=test_prompt,
+ prompt_len=0,
+ history_QA=test_history_QA,
+ hyper_parameters=hyper_parameters,
+ api_url=api_url,
+ output_len=test_output_len,
+ logprobs=logprobs,
+ ignore_eos=ignore_eos,
+ extra_body=extra_body,
+ )
+
+ print("test_input:", test_input)
+
+ test_output = await request_func(request_func_input=test_input)
+
+ print("test_output:", test_output)
+
+ if not test_output.success:
+ raise ValueError(
+ "Initial test run failed - Please make sure benchmark arguments "
+ f"are correctly specified. Error: {test_output.error}")
+ else:
+ print("Initial test run completed. Starting main benchmark run...")
+
+ if lora_modules:
+ # For each input request, choose a LoRA module at random.
+ lora_modules = iter(
+ [random.choice(lora_modules) \
+ for _ in range(len(input_requests))])
+
+ if profile:
+ print("Starting profiler...")
+ profile_input = RequestFuncInput(model=model_id,
+ model_name=model_name,
+ prompt=test_prompt,
+ api_url=base_url + "/start_profile",
+ output_len=test_output_len,
+ logprobs=logprobs,
+ ignore_eos=ignore_eos,
+ extra_body=extra_body)
+ profile_output = await request_func(request_func_input=profile_input)
+ if profile_output.success:
+ print("Profiler started")
+
+ if burstiness == 1.0:
+ distribution = "Poisson process"
+ else:
+ distribution = "Gamma distribution"
+
+ print(f"Traffic request rate: {request_rate}")
+ print(f"Burstiness factor: {burstiness} ({distribution})")
+ print(f"Maximum request concurrency: {max_concurrency}")
+
+ pbar = None if disable_tqdm else tqdm(total=len(input_requests))
+
+ # This can be used once the minimum Python version is 3.10 or higher,
+ # and it will simplify the code in limited_request_func.
+ # semaphore = (asyncio.Semaphore(max_concurrency)
+ # if max_concurrency else contextlib.nullcontext())
+ semaphore = (asyncio.Semaphore(max_concurrency)
+ if max_concurrency else None)
+
+ async def limited_request_func(request_func_input, pbar):
+ if semaphore is None:
+ return await request_func(request_func_input=request_func_input,
+ pbar=pbar)
+ async with semaphore:
+ return await request_func(request_func_input=request_func_input,
+ pbar=pbar)
+
+ benchmark_start_time = time.perf_counter()
+ tasks: list[asyncio.Task] = []
+ async for request in get_request(input_requests, request_rate, burstiness):
+ prompt, output_len = request.prompt, request.expected_output_len
+ history_QA = request.history_QA
+
+ req_model_id, req_model_name = model_id, model_name
+ if lora_modules:
+ req_lora_module = next(lora_modules)
+ req_model_id, req_model_name = req_lora_module, req_lora_module
+
+ request_func_input = RequestFuncInput(model=req_model_id,
+ model_name=req_model_name,
+ prompt=prompt,
+ prompt_len=0,
+ history_QA=history_QA,
+ hyper_parameters=hyper_parameters,
+ api_url=api_url,
+ output_len=output_len,
+ logprobs=logprobs,
+ ignore_eos=ignore_eos,
+ extra_body=extra_body)
+ tasks.append(
+ asyncio.create_task(
+ limited_request_func(request_func_input=request_func_input,
+ pbar=pbar)))
+ outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
+
+ if profile:
+ print("Stopping profiler...")
+ profile_input = RequestFuncInput(
+ model=model_id,
+ prompt=test_prompt,
+ api_url=base_url + "/stop_profile",
+ output_len=test_output_len,
+ logprobs=logprobs,
+ )
+ profile_output = await request_func(request_func_input=profile_input)
+ if profile_output.success:
+ print("Profiler stopped")
+
+ if pbar is not None:
+ pbar.close()
+
+ benchmark_duration = time.perf_counter() - benchmark_start_time
+
+ metrics, actual_output_lens = calculate_metrics(
+ input_requests=input_requests,
+ outputs=outputs,
+ dur_s=benchmark_duration,
+ # tokenizer=tokenizer,
+ selected_percentiles=selected_percentiles,
+ goodput_config_dict=goodput_config_dict,
+ )
+
+ print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
+ print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
+ print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
+ benchmark_duration))
+ print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
+ print("{:<40} {:<10}".format("Total generated tokens:",
+ metrics.total_output))
+ print("{:<40} {:<10.3f}".format("Request throughput (req/s):",
+ metrics.request_throughput))
+ if goodput_config_dict:
+ print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
+ metrics.request_goodput))
+ print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
+ metrics.output_throughput))
+ print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
+ metrics.total_token_throughput))
+
+ result = {
+ "duration": benchmark_duration,
+ "completed": metrics.completed,
+ "total_input_tokens": metrics.total_input,
+ "total_output_tokens": metrics.total_output,
+ "request_throughput": metrics.request_throughput,
+ "request_goodput:":
+ metrics.request_goodput if goodput_config_dict else None,
+ "output_throughput": metrics.output_throughput,
+ "total_token_throughput": metrics.total_token_throughput,
+ "input_lens": [output.prompt_len for output in outputs],
+ "infer_input_lens": [output.prompt_tokens for output in outputs],
+ "output_lens": actual_output_lens,
+ "ttfts": [output.ttft for output in outputs],
+ "itls": [output.itl for output in outputs],
+ "input_texts": [input.prompt for input in input_requests],
+ "generated_texts": [output.generated_text for output in outputs],
+ "reasoning_contents": [output.reasoning_content for output in outputs],
+ "errors": [output.error for output in outputs],
+ }
+
+ def process_one_metric(
+ # E.g., "ttft"
+ metric_attribute_name: str,
+ # E.g., "TTFT"
+ metric_name: str,
+ # E.g., "Time to First Token"
+ metric_header: str,
+ ):
+ # This function prints and adds statistics of the specified
+ # metric.
+ if metric_attribute_name not in selected_percentile_metrics:
+ return
+ print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
+ print("{:<40} {:<10.2f}".format(
+ f"Mean {metric_name} (ms):",
+ getattr(metrics, f"mean_{metric_attribute_name}_ms")))
+ print("{:<40} {:<10.2f}".format(
+ f"Median {metric_name} (ms):",
+ getattr(metrics, f"median_{metric_attribute_name}_ms")))
+ result[f"mean_{metric_attribute_name}_ms"] = getattr(
+ metrics, f"mean_{metric_attribute_name}_ms")
+ result[f"median_{metric_attribute_name}_ms"] = getattr(
+ metrics, f"median_{metric_attribute_name}_ms")
+ result[f"std_{metric_attribute_name}_ms"] = getattr(
+ metrics, f"std_{metric_attribute_name}_ms")
+ for p, value in getattr(metrics,
+ f"percentiles_{metric_attribute_name}_ms"):
+ p_word = str(int(p)) if int(p) == p else str(p)
+ print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
+ value))
+ result[f"p{p_word}_{metric_attribute_name}_ms"] = value
+
+ def process_one_length(
+ # E.g., "ttft"
+ metric_attribute_name: str,
+ # E.g., "TTFT"
+ metric_name: str,
+ # E.g., "Time to First Token"
+ metric_header: str,
+ ):
+ # This function prints and adds statistics of the specified
+ # metric.
+ if metric_attribute_name not in selected_percentile_metrics:
+ return
+ print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
+ print("{:<40} {:<10.2f}".format(
+ f"Mean {metric_name}:",
+ getattr(metrics, f"mean_{metric_attribute_name}")))
+ print("{:<40} {:<10.2f}".format(
+ f"Median {metric_name}:",
+ getattr(metrics, f"median_{metric_attribute_name}")))
+ result[f"mean_{metric_attribute_name}"] = getattr(
+ metrics, f"mean_{metric_attribute_name}")
+ result[f"median_{metric_attribute_name}"] = getattr(
+ metrics, f"median_{metric_attribute_name}")
+ result[f"std_{metric_attribute_name}"] = getattr(
+ metrics, f"std_{metric_attribute_name}")
+ for p, value in getattr(metrics,
+ f"percentiles_{metric_attribute_name}"):
+ p_word = str(int(p)) if int(p) == p else str(p)
+ print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:",
+ value))
+ result[f"p{p_word}_{metric_attribute_name}"] = value
+
+ process_one_length("s_decode", "Decode", "解码速度(tok/s)")
+ process_one_metric("ttft", "TTFT", "Time to First Token")
+ process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
+ process_one_metric("tpot", "TPOT",
+ "Time per Output Token (excl. 1st token)")
+ process_one_metric("itl", "ITL", "Inter-token Latency")
+ process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
+ process_one_metric("e2el", "E2EL", "End-to-end Latency")
+ process_one_metric("s_e2el", "S_E2EL", "Infer End-to-end Latency")
+ process_one_length("input_len", "Cached Tokens", "Cached Tokens")
+ process_one_length("s_input_len", "Input Length", "Infer Input Length")
+ process_one_length("output_len", "Output Length", "Output Length")
+
+ print("=" * 50)
+
+ return result
+
+
+def check_goodput_args(args):
+ """Check whether the given argument has valid goodput configuration or not"""
+ # Check and parse goodput arguments
+ goodput_config_dict = {}
+ VALID_NAMES = ["ttft", "tpot", "e2el"]
+ if args.goodput:
+ goodput_config_dict = parse_goodput(args.goodput)
+ for slo_name, slo_val in goodput_config_dict.items():
+ if slo_name not in VALID_NAMES:
+ raise ValueError(
+ f"Invalid metric name found, {slo_name}: {slo_val}. "
+ "The service level objective name should be one of "
+ f"{str(VALID_NAMES)}. ")
+ if slo_val < 0:
+ raise ValueError(
+ f"Invalid value found, {slo_name}: {slo_val}. "
+ "The service level objective value should be "
+ "non-negative.")
+ return goodput_config_dict
+
+
+def parse_goodput(slo_pairs):
+ """Parse the string into a dictionary with keys being names of SLOS and values being their corresponding values"""
+ goodput_config_dict = {}
+ try:
+ for slo_pair in slo_pairs:
+ slo_name, slo_val = slo_pair.split(":")
+ goodput_config_dict[slo_name] = float(slo_val)
+ except ValueError as err:
+ raise argparse.ArgumentTypeError(
+ "Invalid format found for service level objectives. "
+ "Specify service level objectives for goodput as \"KEY:VALUE\" "
+ "pairs, where the key is a metric name, and the value is a "
+ "number in milliseconds.") from err
+ return goodput_config_dict
+
+
+def save_to_pytorch_benchmark_format(args: argparse.Namespace,
+ results: dict[str, Any],
+ file_name: str) -> None:
+ """Save the benchmarking results to PyTorch Benchmark Format JSON file"""
+ metrics = [
+ "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
+ "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
+ "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
+ ]
+ # These raw data might be useful, but they are rather big. They can be added
+ # later if needed
+ ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
+ pt_records = convert_to_pytorch_benchmark_format(
+ args=args,
+ metrics={k: [results[k]]
+ for k in metrics},
+ extra_info={
+ k: results[k]
+ for k in results if k not in metrics and k not in ignored_metrics
+ })
+ if pt_records:
+ # Don't use json suffix here as we don't want CI to pick it up
+ pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
+ write_to_json(pt_file, pt_records)
+
+
+def main(args: argparse.Namespace):
+ """Main entry point"""
+ print(args)
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+
+ backend = args.backend
+ model_id = args.model
+ model_name = args.served_model_name
+ tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
+ tokenizer_mode = args.tokenizer_mode
+
+ if args.base_url is not None:
+ api_url = f"{args.base_url}{args.endpoint}"
+ base_url = f"{args.base_url}"
+ else:
+ api_url = f"http://{args.host}:{args.port}{args.endpoint}"
+ base_url = f"http://{args.host}:{args.port}"
+
+ if args.dataset_name is None:
+ raise ValueError(
+ "Please specify '--dataset-name' and the corresponding "
+ "'--dataset-path' if required.")
+
+ # For datasets that follow a similar structure, use a mapping.
+ dataset_mapping = {
+ "EB":
+ lambda: EBDataset(random_seed=args.seed,
+ dataset_path=args.dataset_path).sample(
+ num_requests=args.num_prompts,
+ output_len=args.sharegpt_output_len,
+ ),
+ "EBChat":
+ lambda: EBChatDataset(random_seed=args.seed,
+ dataset_path=args.dataset_path).sample(
+ num_requests=args.num_prompts,
+ output_len=args.sharegpt_output_len,
+ ),
+ }
+
+ try:
+ input_requests = dataset_mapping[args.dataset_name]()
+ except KeyError as err:
+ raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
+
+ goodput_config_dict = check_goodput_args(args)
+
+ # Collect the sampling parameters.
+ sampling_params = {
+ k: v
+ for k, v in {
+ "top_p": args.top_p,
+ "top_k": args.top_k,
+ "min_p": args.min_p,
+ "temperature": args.temperature
+ }.items() if v is not None
+ }
+
+ # Sampling parameters are only supported by openai-compatible backend.
+ if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
+ raise ValueError(
+ "Sampling parameters are only supported by openai-compatible "
+ "backends.")
+
+ if "temperature" not in sampling_params:
+ sampling_params["temperature"] = 0.0 # Default to greedy decoding.
+
+ # Avoid GC processing "static" data - reduce pause times.
+ gc.collect()
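+    # gc.freeze() moves all currently tracked objects into a permanent
+    # generation so the collector skips them for the rest of the run.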
+ gc.freeze()
+
+    # Hyperparameters are passed in via a YAML file
+ if args.hyperparameter_path:
+ with open(args.hyperparameter_path, "r") as f:
+ hyper_parameters = yaml.safe_load(f)
+ else:
+ hyper_parameters = {}
+
+ benchmark_result = asyncio.run(
+ benchmark(
+ backend=backend,
+ api_url=api_url,
+ base_url=base_url,
+ model_id=model_id,
+ model_name=model_name,
+ input_requests=input_requests,
+ hyper_parameters=hyper_parameters,
+ logprobs=args.logprobs,
+ request_rate=args.request_rate,
+ burstiness=args.burstiness,
+ disable_tqdm=args.disable_tqdm,
+ profile=args.profile,
+ selected_percentile_metrics=args.percentile_metrics.split(","),
+ selected_percentiles=[
+ float(p) for p in args.metric_percentiles.split(",")
+ ],
+ ignore_eos=args.ignore_eos,
+ goodput_config_dict=goodput_config_dict,
+ max_concurrency=args.max_concurrency,
+ lora_modules=args.lora_modules,
+ extra_body=sampling_params,
+ ))
+
+ # Save config and results to json
+ if args.save_result:
+ result_json: dict[str, Any] = {}
+
+ # Setup
+ current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
+ result_json["date"] = current_dt
+ result_json["backend"] = backend
+ result_json["model_id"] = model_id
+ result_json["tokenizer_id"] = tokenizer_id
+ result_json["num_prompts"] = args.num_prompts
+
+ # Metadata
+ if args.metadata:
+ for item in args.metadata:
+ if "=" in item:
+ kvstring = item.split("=")
+ result_json[kvstring[0].strip()] = kvstring[1].strip()
+ else:
+ raise ValueError(
+ "Invalid metadata format. Please use KEY=VALUE format."
+ )
+
+ if not args.save_detailed:
+ # Remove fields with too many data points
+ for field in [
+ "input_lens", "output_lens", "ttfts", "itls",
+ "generated_texts", "errors"
+ ]:
+ if field in result_json:
+ del result_json[field]
+
+ # Traffic
+ result_json["request_rate"] = (args.request_rate if args.request_rate
+ < float("inf") else "inf")
+ result_json["burstiness"] = args.burstiness
+ result_json["max_concurrency"] = args.max_concurrency
+
+ # Merge with benchmark result
+ result_json = {**result_json, **benchmark_result}
+
+ # Save to file
+ base_model_id = model_id.split("/")[-1]
+ max_concurrency_str = (f"-concurrency{args.max_concurrency}"
+ if args.max_concurrency is not None else "")
+ file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
+ if args.result_filename:
+ file_name = args.result_filename
+ if args.result_dir:
+ file_name = os.path.join(args.result_dir, file_name)
+ with open(file_name, "w", encoding='utf-8') as outfile:
+ json.dump(result_json, outfile)
+ save_to_pytorch_benchmark_format(args, result_json, file_name)
+
+
+if __name__ == "__main__":
+ parser = FlexibleArgumentParser(
+ description="Benchmark the online serving throughput.")
+ parser.add_argument(
+ "--backend",
+ type=str,
+ default="vllm",
+ choices=list(ASYNC_REQUEST_FUNCS.keys()),
+ )
+ parser.add_argument(
+ "--base-url",
+ type=str,
+ default=None,
+ help="Server or API base url if not using http host and port.",
+ )
+ # Use 127.0.0.1 here instead of localhost to force the use of ipv4
+ parser.add_argument("--host", type=str, default="127.0.0.1")
+ parser.add_argument("--port", type=int, default=8000)
+ parser.add_argument(
+ "--endpoint",
+ type=str,
+ default="/v1/completions",
+ help="API endpoint.",
+ )
+ parser.add_argument(
+ "--dataset-name",
+ type=str,
+ default="sharegpt",
+ choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "EB", "EBChat"],
+ help="Name of the dataset to benchmark on.",
+ )
+ parser.add_argument("--dataset-path",
+ type=str,
+ default=None,
+ help="Path to the sharegpt/sonnet dataset. "
+ "Or the huggingface dataset ID if using HF dataset.")
+ parser.add_argument("--hyperparameter-path",
+ type=str,
+ default=None,
+                        help="Path to the hyperparameter YAML file.")
+ parser.add_argument(
+ "--max-concurrency",
+ type=int,
+ default=None,
+ help="Maximum number of concurrent requests. This can be used "
+ "to help simulate an environment where a higher level component "
+ "is enforcing a maximum number of concurrent requests. While the "
+ "--request-rate argument controls the rate at which requests are "
+ "initiated, this argument will control how many are actually allowed "
+ "to execute at a time. This means that when used in combination, the "
+ "actual request rate may be lower than specified with --request-rate, "
+ "if the server is not processing requests fast enough to keep up.")
+
+ parser.add_argument(
+ "--model",
+ type=str,
+ required=True,
+ help="Name of the model.",
+ )
+ parser.add_argument(
+ "--tokenizer",
+ type=str,
+ help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
+ )
+ parser.add_argument("--use-beam-search", action="store_true")
+ parser.add_argument(
+ "--num-prompts",
+ type=int,
+ default=1000,
+ help="Number of prompts to process.",
+ )
+ parser.add_argument(
+ "--logprobs",
+ type=int,
+ default=None,
+ help=("Number of logprobs-per-token to compute & return as part of "
+ "the request. If unspecified, then either (1) if beam search "
+ "is disabled, no logprobs are computed & a single dummy "
+ "logprob is returned for each token; or (2) if beam search "
+ "is enabled 1 logprob per token is computed"),
+ )
+ parser.add_argument(
+ "--request-rate",
+ type=float,
+ default=float("inf"),
+ help="Number of requests per second. If this is inf, "
+ "then all the requests are sent at time 0. "
+ "Otherwise, we use Poisson process or gamma distribution "
+ "to synthesize the request arrival times.",
+ )
+ parser.add_argument(
+ "--burstiness",
+ type=float,
+ default=1.0,
+ help="Burstiness factor of the request generation. "
+ "Only take effect when request_rate is not inf. "
+ "Default value is 1, which follows Poisson process. "
+ "Otherwise, the request intervals follow a gamma distribution. "
+ "A lower burstiness value (0 < burstiness < 1) results in more "
+ "bursty requests. A higher burstiness value (burstiness > 1) "
+ "results in a more uniform arrival of requests.",
+ )
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument(
+ "--trust-remote-code",
+ action="store_true",
+ help="Trust remote code from huggingface",
+ )
+ parser.add_argument(
+ "--disable-tqdm",
+ action="store_true",
+ help="Specify to disable tqdm progress bar.",
+ )
+ parser.add_argument(
+ "--profile",
+ action="store_true",
+ help="Use Torch Profiler. The endpoint must be launched with "
+ "VLLM_TORCH_PROFILER_DIR to enable profiler.",
+ )
+ parser.add_argument(
+ "--save-result",
+ action="store_true",
+ help="Specify to save benchmark results to a json file",
+ )
+ parser.add_argument(
+ "--save-detailed",
+ action="store_true",
+ help="When saving the results, whether to include per request "
+ "information such as response, error, ttfs, tpots, etc.",
+ )
+ parser.add_argument(
+ "--metadata",
+ metavar="KEY=VALUE",
+ nargs="*",
+ help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) "
+ "for metadata of this run to be saved in the result JSON file "
+ "for record keeping purposes.",
+ )
+ parser.add_argument(
+ "--result-dir",
+ type=str,
+ default=None,
+ help="Specify directory to save benchmark json results."
+ "If not specified, results are saved in the current directory.",
+ )
+ parser.add_argument(
+ "--result-filename",
+ type=str,
+ default=None,
+ help="Specify the filename to save benchmark json results."
+ "If not specified, results will be saved in "
+ "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
+ " format.",
+ )
+ parser.add_argument(
+ "--ignore-eos",
+ action="store_true",
+ help="Set ignore_eos flag when sending the benchmark request."
+ "Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
+ parser.add_argument(
+ "--percentile-metrics",
+ type=str,
+ default="ttft,tpot,itl",
+ help="Comma-separated list of selected metrics to report percentils. "
+ "This argument specifies the metrics to report percentiles. "
+ "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
+ "Default value is \"ttft,tpot,itl\".")
+ parser.add_argument(
+ "--metric-percentiles",
+ type=str,
+ default="99",
+ help="Comma-separated list of percentiles for selected metrics. "
+ "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
+ "Default value is \"99\". "
+ "Use \"--percentile-metrics\" to select metrics.",
+ )
+ parser.add_argument(
+ "--goodput",
+ nargs="+",
+ required=False,
+ help="Specify service level objectives for goodput as \"KEY:VALUE\" "
+ "pairs, where the key is a metric name, and the value is in "
+ "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
+ "separated by spaces. Allowed request level metric names are "
+ "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
+ "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
+ "and the blog: https://hao-ai-lab.github.io/blogs/distserve")
+
+ # group for dataset specific arguments
+ sonnet_group = parser.add_argument_group("sonnet dataset options")
+ sonnet_group.add_argument(
+ "--sonnet-input-len",
+ type=int,
+ default=550,
+ help="Number of input tokens per request, used only for sonnet dataset.",
+ )
+ sonnet_group.add_argument(
+ "--sonnet-output-len",
+ type=int,
+ default=150,
+ help="Number of output tokens per request, used only for sonnet dataset.",
+ )
+ sonnet_group.add_argument(
+ "--sonnet-prefix-len",
+ type=int,
+ default=200,
+ help="Number of prefix tokens per request, used only for sonnet dataset.",
+ )
+
+ sharegpt_group = parser.add_argument_group("sharegpt dataset options")
+ sharegpt_group.add_argument(
+ "--sharegpt-output-len",
+ type=int,
+ default=None,
+ help="Output length for each request. Overrides the output length "
+ "from the ShareGPT dataset.")
+
+ random_group = parser.add_argument_group("random dataset options")
+ random_group.add_argument(
+ "--random-input-len",
+ type=int,
+ default=1024,
+ help="Number of input tokens per request, used only for random sampling.",
+ )
+ random_group.add_argument(
+ "--random-output-len",
+ type=int,
+ default=128,
+ help="Number of output tokens per request, used only for random sampling.",
+ )
+ random_group.add_argument(
+ "--random-range-ratio",
+ type=float,
+ default=0.0,
+ help="Range ratio for sampling input/output length, "
+ "used only for random sampling. Must be in the range [0, 1) to define "
+ "a symmetric sampling range"
+ "[length * (1 - range_ratio), length * (1 + range_ratio)].",
+ )
+ random_group.add_argument(
+ "--random-prefix-len",
+ type=int,
+ default=0,
+ help=("Number of fixed prefix tokens before the random context "
+ "in a request. "
+ "The total input length is the sum of `random-prefix-len` and "
+ "a random "
+ "context length sampled from [input_len * (1 - range_ratio), "
+ "input_len * (1 + range_ratio)]."),
+ )
+
+ hf_group = parser.add_argument_group("hf dataset options")
+ hf_group.add_argument("--hf-subset",
+ type=str,
+ default=None,
+ help="Subset of the HF dataset.")
+ hf_group.add_argument("--hf-split",
+ type=str,
+ default=None,
+ help="Split of the HF dataset.")
+ hf_group.add_argument(
+ "--hf-output-len",
+ type=int,
+ default=None,
+ help="Output length for each request. Overrides the output lengths "
+ "from the sampled HF dataset.",
+ )
+
+ sampling_group = parser.add_argument_group("sampling parameters")
+ sampling_group.add_argument(
+ "--top-p",
+ type=float,
+ default=None,
+ help="Top-p sampling parameter. Only has effect on openai-compatible "
+ "backends.")
+ sampling_group.add_argument(
+ "--top-k",
+ type=int,
+ default=None,
+ help="Top-k sampling parameter. Only has effect on openai-compatible "
+ "backends.")
+ sampling_group.add_argument(
+ "--min-p",
+ type=float,
+ default=None,
+ help="Min-p sampling parameter. Only has effect on openai-compatible "
+ "backends.")
+ sampling_group.add_argument(
+ "--temperature",
+ type=float,
+ default=None,
+ help="Temperature sampling parameter. Only has effect on "
+ "openai-compatible backends. If not specified, default to greedy "
+ "decoding (i.e. temperature==0.0).")
+
+ parser.add_argument(
+ '--tokenizer-mode',
+ type=str,
+ default="auto",
+ choices=['auto', 'slow', 'mistral', 'custom'],
+ help='The tokenizer mode.\n\n* "auto" will use the '
+ 'fast tokenizer if available.\n* "slow" will '
+ 'always use the slow tokenizer. \n* '
+ '"mistral" will always use the `mistral_common` tokenizer. \n*'
+ '"custom" will use --tokenizer to select the preregistered tokenizer.')
+
+ parser.add_argument("--served-model-name",
+ type=str,
+ default=None,
+ help="The model name used in the API. "
+ "If not specified, the model name will be the "
+ "same as the ``--model`` argument. ")
+
+ parser.add_argument("--lora-modules",
+ nargs='+',
+ default=None,
+ help="A subset of LoRA module names passed in when "
+ "launching the server. For each request, the "
+ "script chooses a LoRA module at random.")
+
+ args = parser.parse_args()
+
+ main(args)
+
diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py
new file mode 100644
index 000000000..6c149bf5f
--- /dev/null
+++ b/benchmarks/benchmark_utils.py
@@ -0,0 +1,90 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+# This file is modified from https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_utils.py
+
+
+import argparse
+import json
+import math
+import os
+from typing import Any
+
+
+def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
+ metrics: dict[str, list],
+ extra_info: dict[str, Any]) -> list:
+ """
+    Save the benchmark results in the format used by PyTorch OSS benchmark,
+    with one metric per record
+ https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
+ """
+ records = []
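+    # Note: os.environ.get returns a string when the variable is set, so any
+    # non-empty value (even "0") enables saving in this format.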
+ if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
+ return records
+
+ for name, benchmark_values in metrics.items():
+ record = {
+ "benchmark": {
+ "name": "vLLM benchmark",
+ "extra_info": {
+ "args": vars(args),
+ },
+ },
+ "model": {
+ "name": args.model,
+ },
+ "metric": {
+ "name": name,
+ "benchmark_values": benchmark_values,
+ "extra_info": extra_info,
+ },
+ }
+
+ tp = record["benchmark"]["extra_info"]["args"].get(
+ "tensor_parallel_size")
+ # Save tensor_parallel_size parameter if it's part of the metadata
+ if not tp and "tensor_parallel_size" in extra_info:
+ record["benchmark"]["extra_info"]["args"][
+ "tensor_parallel_size"] = extra_info["tensor_parallel_size"]
+
+ records.append(record)
+
+ return records
+
+
+class InfEncoder(json.JSONEncoder):
+ """InfEncoder"""
+ def clear_inf(self, o: Any):
+ """clear_inf"""
+ if isinstance(o, dict):
+ return {k: self.clear_inf(v) for k, v in o.items()}
+ elif isinstance(o, list):
+ return [self.clear_inf(v) for v in o]
+ elif isinstance(o, float) and math.isinf(o):
+ return "inf"
+ return o
+
+ def iterencode(self, o: Any, *args, **kwargs) -> Any:
+ """iterencode"""
+ return super().iterencode(self.clear_inf(o), *args, **kwargs)
+
+
+def write_to_json(filename: str, records: list) -> None:
+ """write_to_json"""
+ with open(filename, "w") as f:
+ json.dump(records, f, cls=InfEncoder)
+
diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt
new file mode 100644
index 000000000..1ad085b79
--- /dev/null
+++ b/benchmarks/requirements.txt
@@ -0,0 +1,5 @@
+aiohttp
+tqdm
+numpy
+Pillow
+pyyaml
diff --git a/benchmarks/yaml/eb45-128k-wint4-a800-tp8.yaml b/benchmarks/yaml/eb45-128k-wint4-a800-tp8.yaml
new file mode 100644
index 000000000..280f8e336
--- /dev/null
+++ b/benchmarks/yaml/eb45-128k-wint4-a800-tp8.yaml
@@ -0,0 +1,8 @@
+enable_chunked_prefill: True
+max_model_len: 131072
+max_num_seqs: 16
+kv_cache_ratio: 0.75
+tensor_parallel_size: 8
+max_num_batched_tokens: 4096
+max_num_partial_prefills: 3
+max_long_partial_prefills: 3
diff --git a/benchmarks/yaml/eb45-128k-wint4-p800-tp8.yaml b/benchmarks/yaml/eb45-128k-wint4-p800-tp8.yaml
new file mode 100644
index 000000000..d3aaa9243
--- /dev/null
+++ b/benchmarks/yaml/eb45-128k-wint4-p800-tp8.yaml
@@ -0,0 +1,5 @@
+max_model_len: 131072
+max_num_seqs: 40
+gpu_memory_utilization: 0.9
+tensor_parallel_size: 8
+quantization: wint4
diff --git a/benchmarks/yaml/eb45-128k-wint8-a800-tp8.yaml b/benchmarks/yaml/eb45-128k-wint8-a800-tp8.yaml
new file mode 100644
index 000000000..280f8e336
--- /dev/null
+++ b/benchmarks/yaml/eb45-128k-wint8-a800-tp8.yaml
@@ -0,0 +1,8 @@
+enable_chunked_prefill: True
+max_model_len: 131072
+max_num_seqs: 16
+kv_cache_ratio: 0.75
+tensor_parallel_size: 8
+max_num_batched_tokens: 4096
+max_num_partial_prefills: 3
+max_long_partial_prefills: 3
diff --git a/benchmarks/yaml/eb45-21B-vl-128k-wint4-h800-tp1.yaml b/benchmarks/yaml/eb45-21B-vl-128k-wint4-h800-tp1.yaml
new file mode 100644
index 000000000..db8a20b86
--- /dev/null
+++ b/benchmarks/yaml/eb45-21B-vl-128k-wint4-h800-tp1.yaml
@@ -0,0 +1,10 @@
+enable_mm: True
+max_model_len: 32768
+max_num_seqs: 128
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.71
+tensor_parallel_size: 1
+enable_chunked_prefill: True
+max_num_batched_tokens: 384
+quantization: wint4
+reasoning_parser: ernie-45-vl
\ No newline at end of file
diff --git a/benchmarks/yaml/eb45-21b-a3b-32k-bf16.yaml b/benchmarks/yaml/eb45-21b-a3b-32k-bf16.yaml
new file mode 100644
index 000000000..f57706607
--- /dev/null
+++ b/benchmarks/yaml/eb45-21b-a3b-32k-bf16.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+max_num_batched_tokens: 32768
diff --git a/benchmarks/yaml/eb45-21b-a3b-32k-wint4-a10.yaml b/benchmarks/yaml/eb45-21b-a3b-32k-wint4-a10.yaml
new file mode 100644
index 000000000..783a42c6b
--- /dev/null
+++ b/benchmarks/yaml/eb45-21b-a3b-32k-wint4-a10.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 32
+kv_cache_ratio: 0.5
+tensor_parallel_size: 1
+quantization: wint4
diff --git a/benchmarks/yaml/eb45-21b-a3b-32k-wint4.yaml b/benchmarks/yaml/eb45-21b-a3b-32k-wint4.yaml
new file mode 100644
index 000000000..366b4952e
--- /dev/null
+++ b/benchmarks/yaml/eb45-21b-a3b-32k-wint4.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+max_num_batched_tokens: 32768
+quantization: wint4
diff --git a/benchmarks/yaml/eb45-21b-a3b-32k-wint8.yaml b/benchmarks/yaml/eb45-21b-a3b-32k-wint8.yaml
new file mode 100644
index 000000000..b5add626e
--- /dev/null
+++ b/benchmarks/yaml/eb45-21b-a3b-32k-wint8.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+max_num_batched_tokens: 32768
+quantization: wint8
diff --git a/benchmarks/yaml/eb45-32k-bf16-a30-tp1.yaml b/benchmarks/yaml/eb45-32k-bf16-a30-tp1.yaml
new file mode 100644
index 000000000..f57706607
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-bf16-a30-tp1.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+max_num_batched_tokens: 32768
diff --git a/benchmarks/yaml/eb45-32k-blockwise-fp8-h800-tp8.yaml b/benchmarks/yaml/eb45-32k-blockwise-fp8-h800-tp8.yaml
new file mode 100644
index 000000000..b2f9a7457
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-blockwise-fp8-h800-tp8.yaml
@@ -0,0 +1,12 @@
+max_model_len: 32768
+max_num_seqs: 256
+tensor_parallel_size: 8
+quantization: block_wise_fp8
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.8
+enable_chunked_prefill: True
+max_num_batched_tokens: 1024
+max_num_partial_prefills: 3
+max_long_partial_prefills: 3
+enable_prefix_caching: True
+swap_space: 200
diff --git a/benchmarks/yaml/eb45-32k-tensorwise-fp8-h800-tp8.yaml b/benchmarks/yaml/eb45-32k-tensorwise-fp8-h800-tp8.yaml
new file mode 100644
index 000000000..47d1bfbcd
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-tensorwise-fp8-h800-tp8.yaml
@@ -0,0 +1,11 @@
+max_model_len: 32768
+max_num_seqs: 256
+tensor_parallel_size: 8
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.8
+enable_chunked_prefill: True
+max_num_batched_tokens: 1024
+max_num_partial_prefills: 3
+max_long_partial_prefills: 3
+enable_prefix_caching: True
+swap_space: 200
diff --git a/benchmarks/yaml/eb45-32k-w4a8c8-a800-tp4.yaml b/benchmarks/yaml/eb45-32k-w4a8c8-a800-tp4.yaml
new file mode 100644
index 000000000..6ac9a2188
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-w4a8c8-a800-tp4.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 96
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.71
+tensor_parallel_size: 4
diff --git a/benchmarks/yaml/eb45-32k-w4a8c8-tp4_decode.yaml b/benchmarks/yaml/eb45-32k-w4a8c8-tp4_decode.yaml
new file mode 100644
index 000000000..957f59d2a
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-w4a8c8-tp4_decode.yaml
@@ -0,0 +1,15 @@
+max_model_len: 32768
+max_num_seqs: 256
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.8
+tensor_parallel_size: 4
+cache_queue_port: 55663
+enable_chunked_prefill: True
+splitwise_role: decode
+engine_worker_queue_port: 6678
+cache_transfer_protocol: "rdma,ipc"
+rdma_comm_ports: "7671,7672,7673,7674"
+pd_comm_port: "2334"
+max_num_batched_tokens: 384
+max_num_partial_prefills: 3
+max_long_partial_prefills: 3
\ No newline at end of file
diff --git a/benchmarks/yaml/eb45-32k-w4a8c8-tp4_prefill.yaml b/benchmarks/yaml/eb45-32k-w4a8c8-tp4_prefill.yaml
new file mode 100644
index 000000000..c1466160d
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-w4a8c8-tp4_prefill.yaml
@@ -0,0 +1,12 @@
+max_model_len: 32768
+max_num_seqs: 16
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.9
+tensor_parallel_size: 4
+splitwise_role: prefill
+enable_prefix_caching: True
+cache_queue_port: 55664
+engine_worker_queue_port: 6677
+cache_transfer_protocol: "rdma,ipc"
+rdma_comm_ports: "7675,7676,7677,7678"
+pd_comm_port: "2333"
\ No newline at end of file
diff --git a/benchmarks/yaml/eb45-32k-wint2-h20-tp1.yaml b/benchmarks/yaml/eb45-32k-wint2-h20-tp1.yaml
new file mode 100644
index 000000000..af8d49e80
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint2-h20-tp1.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+enable_prefix_caching: true
+enable_chunked_prefill: true
diff --git a/benchmarks/yaml/eb45-32k-wint4-a800-tp4.yaml b/benchmarks/yaml/eb45-32k-wint4-a800-tp4.yaml
new file mode 100644
index 000000000..6ac9a2188
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint4-a800-tp4.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 96
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.71
+tensor_parallel_size: 4
diff --git a/benchmarks/yaml/eb45-32k-wint4-h800-dp8_decode.yaml b/benchmarks/yaml/eb45-32k-wint4-h800-dp8_decode.yaml
new file mode 100644
index 000000000..2e00aad6d
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint4-h800-dp8_decode.yaml
@@ -0,0 +1,13 @@
+max_model_len: 32768
+max_num_seqs: 256
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.8
+tensor_parallel_size: 1
+data_parallel_size: 8
+num_gpu_blocks_override: 1024
+cache_queue_port: 55663
+splitwise_role: decode
+engine_worker_queue_port: 6678
+cache_transfer_protocol: "rdma"
+rdma_comm_ports: "7671,7672,7673,7674,7675,7676,7677,7678"
+pd_comm_port: "2334"
diff --git a/benchmarks/yaml/eb45-32k-wint4-h800-dp8_prefill.yaml b/benchmarks/yaml/eb45-32k-wint4-h800-dp8_prefill.yaml
new file mode 100644
index 000000000..e6d0fa6e0
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint4-h800-dp8_prefill.yaml
@@ -0,0 +1,13 @@
+max_model_len: 32768
+max_num_seqs: 16
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.9
+tensor_parallel_size: 1
+data_parallel_size: 8
+splitwise_role: prefill
+cache_queue_port: 55664
+engine_worker_queue_port: 6677
+num_gpu_blocks_override: 1024
+cache_transfer_protocol: "rdma"
+rdma_comm_ports: "7671,7672,7673,7674,7675,7676,7677,7678"
+pd_comm_port: "2334"
\ No newline at end of file
diff --git a/benchmarks/yaml/eb45-32k-wint4-mtp-h800-tp4.yaml b/benchmarks/yaml/eb45-32k-wint4-mtp-h800-tp4.yaml
new file mode 100644
index 000000000..c609fba49
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint4-mtp-h800-tp4.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 96
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.71
+tensor_parallel_size: 4
+quantization: wint4
diff --git a/benchmarks/yaml/eb45-32k-wint4-mtp-tp4-decode.yaml b/benchmarks/yaml/eb45-32k-wint4-mtp-tp4-decode.yaml
new file mode 100644
index 000000000..e239cea89
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint4-mtp-tp4-decode.yaml
@@ -0,0 +1,13 @@
+max_model_len: 32768
+max_num_seqs: 128
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.7
+tensor_parallel_size: 4
+cache_queue_port: 55663
+enable_chunked_prefill: False
+enable_prefix_caching: False
+splitwise_role: decode
+engine_worker_queue_port: 6678
+cache_transfer_protocol: "rdma,ipc"
+rdma_comm_ports: "7671,7672,7673,7674"
+pd_comm_port: "2334"
\ No newline at end of file
diff --git a/benchmarks/yaml/eb45-32k-wint4-mtp-tp4-prefill.yaml b/benchmarks/yaml/eb45-32k-wint4-mtp-tp4-prefill.yaml
new file mode 100644
index 000000000..6d759c843
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint4-mtp-tp4-prefill.yaml
@@ -0,0 +1,12 @@
+max_model_len: 32768
+max_num_seqs: 16
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.9
+tensor_parallel_size: 4
+splitwise_role: prefill
+enable_prefix_caching: False
+cache_queue_port: 55664
+engine_worker_queue_port: 6677
+cache_transfer_protocol: "rdma,ipc"
+rdma_comm_ports: "7675,7676,7677,7678"
+pd_comm_port: "2333"
\ No newline at end of file
diff --git a/benchmarks/yaml/eb45-32k-wint4-p800-tp4.yaml b/benchmarks/yaml/eb45-32k-wint4-p800-tp4.yaml
new file mode 100644
index 000000000..14f025dc0
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint4-p800-tp4.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 40
+tensor_parallel_size: 4
+quantization: wint4
+gpu_memory_utilization: 0.9
diff --git a/benchmarks/yaml/eb45-32k-wint4-p800-tp8.yaml b/benchmarks/yaml/eb45-32k-wint4-p800-tp8.yaml
new file mode 100644
index 000000000..b5059f185
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint4-p800-tp8.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 160
+tensor_parallel_size: 8
+quantization: wint4
+gpu_memory_utilization: 0.9
diff --git a/benchmarks/yaml/eb45-32k-wint4-prefixcache-a800-tp4.yaml b/benchmarks/yaml/eb45-32k-wint4-prefixcache-a800-tp4.yaml
new file mode 100644
index 000000000..5a5de2aba
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint4-prefixcache-a800-tp4.yaml
@@ -0,0 +1,8 @@
+enable_prefix_caching: True
+max_model_len: 32768
+max_num_seqs: 128
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.71
+tensor_parallel_size: 4
+swap_space: 200
+cache_queue_port: 55664
diff --git a/benchmarks/yaml/eb45-32k-wint4-tp4_decode.yaml b/benchmarks/yaml/eb45-32k-wint4-tp4_decode.yaml
new file mode 100644
index 000000000..957f59d2a
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint4-tp4_decode.yaml
@@ -0,0 +1,15 @@
+max_model_len: 32768
+max_num_seqs: 256
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.8
+tensor_parallel_size: 4
+cache_queue_port: 55663
+enable_chunked_prefill: True
+splitwise_role: decode
+engine_worker_queue_port: 6678
+cache_transfer_protocol: "rdma,ipc"
+rdma_comm_ports: "7671,7672,7673,7674"
+pd_comm_port: "2334"
+max_num_batched_tokens: 384
+max_num_partial_prefills: 3
+max_long_partial_prefills: 3
\ No newline at end of file
diff --git a/benchmarks/yaml/eb45-32k-wint4-tp4_prefill.yaml b/benchmarks/yaml/eb45-32k-wint4-tp4_prefill.yaml
new file mode 100644
index 000000000..c1466160d
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint4-tp4_prefill.yaml
@@ -0,0 +1,12 @@
+max_model_len: 32768
+max_num_seqs: 16
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.9
+tensor_parallel_size: 4
+splitwise_role: prefill
+enable_prefix_caching: True
+cache_queue_port: 55664
+engine_worker_queue_port: 6677
+cache_transfer_protocol: "rdma,ipc"
+rdma_comm_ports: "7675,7676,7677,7678"
+pd_comm_port: "2333"
\ No newline at end of file
diff --git a/benchmarks/yaml/eb45-32k-wint8-a800-tp8.yaml b/benchmarks/yaml/eb45-32k-wint8-a800-tp8.yaml
new file mode 100644
index 000000000..a8a51c086
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint8-a800-tp8.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 96
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.71
+tensor_parallel_size: 8
diff --git a/benchmarks/yaml/eb45-32k-wint8-p800-tp8.yaml b/benchmarks/yaml/eb45-32k-wint8-p800-tp8.yaml
new file mode 100644
index 000000000..f1fde433f
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint8-p800-tp8.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 80
+tensor_parallel_size: 8
+quantization: wint8
+gpu_memory_utilization: 0.9
diff --git a/benchmarks/yaml/eb45-32k-wint8-prefixcache-a800-tp8.yaml b/benchmarks/yaml/eb45-32k-wint8-prefixcache-a800-tp8.yaml
new file mode 100644
index 000000000..e597f5bb7
--- /dev/null
+++ b/benchmarks/yaml/eb45-32k-wint8-prefixcache-a800-tp8.yaml
@@ -0,0 +1,9 @@
+enable_prefix_caching: True
+max_model_len: 32768
+max_num_batched_tokens: 68304
+max_num_seqs: 128
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.71
+tensor_parallel_size: 8
+swap_space: 100
+cache_queue_port: 55664
diff --git a/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8.yaml b/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8.yaml
new file mode 100644
index 000000000..1a53f9b9a
--- /dev/null
+++ b/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8.yaml
@@ -0,0 +1,9 @@
+enable_mm: True
+max_model_len: 32768
+max_num_seqs: 56
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.8
+tensor_parallel_size: 8
+quantization: wint4
+limit_mm_per_prompt: '{"image": 100, "video": 100}'
+reasoning_parser: ernie-45-vl
diff --git a/benchmarks/yaml/eb45-vl-32k-wint4-h800-tp8.yaml b/benchmarks/yaml/eb45-vl-32k-wint4-h800-tp8.yaml
new file mode 100644
index 000000000..31d3f5a14
--- /dev/null
+++ b/benchmarks/yaml/eb45-vl-32k-wint4-h800-tp8.yaml
@@ -0,0 +1,11 @@
+enable_mm: True
+max_model_len: 32768
+max_num_seqs: 56
+gpu_memory_utilization: 0.8
+kv_cache_ratio: 0.8
+tensor_parallel_size: 8
+quantization: wint4
+limit_mm_per_prompt: '{"image": 100, "video": 100}'
+enable_chunked_prefill: True
+max_num_batched_tokens: 384
+reasoning_parser: ernie-45-vl
diff --git a/benchmarks/yaml/eb45-vl-32k-wint4-tp4.yaml b/benchmarks/yaml/eb45-vl-32k-wint4-tp4.yaml
new file mode 100644
index 000000000..9646a4c61
--- /dev/null
+++ b/benchmarks/yaml/eb45-vl-32k-wint4-tp4.yaml
@@ -0,0 +1,9 @@
+enable_mm: True
+max_model_len: 32768
+max_num_seqs: 36
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.8
+tensor_parallel_size: 4
+quantization: wint4
+limit_mm_per_prompt: '{"image": 100, "video": 100}'
+reasoning_parser: ernie-45-vl
diff --git a/benchmarks/yaml/eb45-vl-32k-wint8-a800-tp8.yaml b/benchmarks/yaml/eb45-vl-32k-wint8-a800-tp8.yaml
new file mode 100644
index 000000000..3c803e662
--- /dev/null
+++ b/benchmarks/yaml/eb45-vl-32k-wint8-a800-tp8.yaml
@@ -0,0 +1,9 @@
+enable_mm: True
+max_model_len: 32768
+max_num_seqs: 36
+gpu_memory_utilization: 0.95
+kv_cache_ratio: 0.8
+tensor_parallel_size: 8
+quantization: wint8
+limit_mm_per_prompt: '{"image": 100, "video": 100}'
+reasoning_parser: ernie-45-vl
diff --git a/benchmarks/yaml/eb45-vl-32k-wint8-h800-tp8.yaml b/benchmarks/yaml/eb45-vl-32k-wint8-h800-tp8.yaml
new file mode 100644
index 000000000..ff9611f5d
--- /dev/null
+++ b/benchmarks/yaml/eb45-vl-32k-wint8-h800-tp8.yaml
@@ -0,0 +1,11 @@
+enable_mm: True
+max_model_len: 32768
+max_num_seqs: 36
+gpu_memory_utilization: 0.8
+kv_cache_ratio: 0.8
+tensor_parallel_size: 8
+quantization: wint8
+limit_mm_per_prompt: '{"image": 100, "video": 100}'
+enable_chunked_prefill: True
+max_num_batched_tokens: 384
+reasoning_parser: ernie-45-vl
diff --git a/benchmarks/yaml/eb45-vl-32k-wint8-tp4.yaml b/benchmarks/yaml/eb45-vl-32k-wint8-tp4.yaml
new file mode 100644
index 000000000..e01db1566
--- /dev/null
+++ b/benchmarks/yaml/eb45-vl-32k-wint8-tp4.yaml
@@ -0,0 +1,9 @@
+enable_mm: True
+max_model_len: 32768
+max_num_seqs: 36
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.8
+tensor_parallel_size: 4
+quantization: wint8
+limit_mm_per_prompt: '{"image": 100, "video": 100}'
+reasoning_parser: ernie-45-vl
diff --git a/benchmarks/yaml/eb45t_0dot3b-32k-bf16-a30-tp1-static.yaml b/benchmarks/yaml/eb45t_0dot3b-32k-bf16-a30-tp1-static.yaml
new file mode 100644
index 000000000..55a37e029
--- /dev/null
+++ b/benchmarks/yaml/eb45t_0dot3b-32k-bf16-a30-tp1-static.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/eb45t_0dot3b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/eb45t_0dot3b-32k-bf16-h800-tp1-static.yaml
new file mode 100644
index 000000000..55a37e029
--- /dev/null
+++ b/benchmarks/yaml/eb45t_0dot3b-32k-bf16-h800-tp1-static.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/eb45t_0dot3b-32k-wint8-a30-tp1-static.yaml b/benchmarks/yaml/eb45t_0dot3b-32k-wint8-a30-tp1-static.yaml
new file mode 100644
index 000000000..14024b565
--- /dev/null
+++ b/benchmarks/yaml/eb45t_0dot3b-32k-wint8-a30-tp1-static.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+quantization: wint8
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/eb45t_0dot3b-32k-wint8-h800-tp1-static.yaml b/benchmarks/yaml/eb45t_0dot3b-32k-wint8-h800-tp1-static.yaml
new file mode 100644
index 000000000..14024b565
--- /dev/null
+++ b/benchmarks/yaml/eb45t_0dot3b-32k-wint8-h800-tp1-static.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+quantization: wint8
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/eb45t_21b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/eb45t_21b-32k-bf16-h800-tp1-static.yaml
new file mode 100644
index 000000000..55a37e029
--- /dev/null
+++ b/benchmarks/yaml/eb45t_21b-32k-bf16-h800-tp1-static.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/eb45t_21b-32k-wint4-h800-tp1-static.yaml b/benchmarks/yaml/eb45t_21b-32k-wint4-h800-tp1-static.yaml
new file mode 100644
index 000000000..010dd3bc3
--- /dev/null
+++ b/benchmarks/yaml/eb45t_21b-32k-wint4-h800-tp1-static.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+quantization: wint4
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/eb45t_300b-32k-wint4-h800-tp4-static.yaml b/benchmarks/yaml/eb45t_300b-32k-wint4-h800-tp4-static.yaml
new file mode 100644
index 000000000..eec95559d
--- /dev/null
+++ b/benchmarks/yaml/eb45t_300b-32k-wint4-h800-tp4-static.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 96
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.71
+tensor_parallel_size: 4
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/qwen2_7b-32k-bf16-a30-tp1-static.yaml b/benchmarks/yaml/qwen2_7b-32k-bf16-a30-tp1-static.yaml
new file mode 100644
index 000000000..55a37e029
--- /dev/null
+++ b/benchmarks/yaml/qwen2_7b-32k-bf16-a30-tp1-static.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1-static.yaml
new file mode 100644
index 000000000..55a37e029
--- /dev/null
+++ b/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1-static.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1.yaml b/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1.yaml
new file mode 100644
index 000000000..c88178259
--- /dev/null
+++ b/benchmarks/yaml/qwen2_7b-32k-bf16-h800-tp1.yaml
@@ -0,0 +1,4 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
diff --git a/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1-static.yaml b/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1-static.yaml
new file mode 100644
index 000000000..8cdc10498
--- /dev/null
+++ b/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1-static.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+quantization: wfp8afp8
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1.yaml b/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1.yaml
new file mode 100644
index 000000000..d766c9f53
--- /dev/null
+++ b/benchmarks/yaml/qwen2_7b-32k-fp8-h800-tp1.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+quantization: wfp8afp8
diff --git a/benchmarks/yaml/qwen2_7b-32k-wint8-h800-tp1.yaml b/benchmarks/yaml/qwen2_7b-32k-wint8-h800-tp1.yaml
new file mode 100644
index 000000000..90af4a558
--- /dev/null
+++ b/benchmarks/yaml/qwen2_7b-32k-wint8-h800-tp1.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+quantization: wint8
diff --git a/benchmarks/yaml/qwen3_0dot6b-32k-bf16-a30-tp1-static.yaml b/benchmarks/yaml/qwen3_0dot6b-32k-bf16-a30-tp1-static.yaml
new file mode 100644
index 000000000..55a37e029
--- /dev/null
+++ b/benchmarks/yaml/qwen3_0dot6b-32k-bf16-a30-tp1-static.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/qwen3_0dot6b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/qwen3_0dot6b-32k-bf16-h800-tp1-static.yaml
new file mode 100644
index 000000000..55a37e029
--- /dev/null
+++ b/benchmarks/yaml/qwen3_0dot6b-32k-bf16-h800-tp1-static.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/qwen3_0dot6b-32k-wint8-a30-tp1-static.yaml b/benchmarks/yaml/qwen3_0dot6b-32k-wint8-a30-tp1-static.yaml
new file mode 100644
index 000000000..14024b565
--- /dev/null
+++ b/benchmarks/yaml/qwen3_0dot6b-32k-wint8-a30-tp1-static.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+quantization: wint8
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/qwen3_0dot6b-32k-wint8-h800-tp1-static.yaml b/benchmarks/yaml/qwen3_0dot6b-32k-wint8-h800-tp1-static.yaml
new file mode 100644
index 000000000..14024b565
--- /dev/null
+++ b/benchmarks/yaml/qwen3_0dot6b-32k-wint8-h800-tp1-static.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+quantization: wint8
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/qwen3_30b-32k-bf16-h800-tp1-static.yaml b/benchmarks/yaml/qwen3_30b-32k-bf16-h800-tp1-static.yaml
new file mode 100644
index 000000000..55a37e029
--- /dev/null
+++ b/benchmarks/yaml/qwen3_30b-32k-bf16-h800-tp1-static.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/qwen3_30b-32k-wint4-h800-tp1-static.yaml b/benchmarks/yaml/qwen3_30b-32k-wint4-h800-tp1-static.yaml
new file mode 100644
index 000000000..010dd3bc3
--- /dev/null
+++ b/benchmarks/yaml/qwen3_30b-32k-wint4-h800-tp1-static.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 128
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
+quantization: wint4
+enable_static_graph_inference: True
diff --git a/benchmarks/yaml/qwen3dot6b-32k-bf16-a30-tp1.yaml b/benchmarks/yaml/qwen3dot6b-32k-bf16-a30-tp1.yaml
new file mode 100644
index 000000000..45ee7d14e
--- /dev/null
+++ b/benchmarks/yaml/qwen3dot6b-32k-bf16-a30-tp1.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 256
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
diff --git a/benchmarks/yaml/qwen3dot6b-32k-bf16-a800-tp1.yaml b/benchmarks/yaml/qwen3dot6b-32k-bf16-a800-tp1.yaml
new file mode 100644
index 000000000..45ee7d14e
--- /dev/null
+++ b/benchmarks/yaml/qwen3dot6b-32k-bf16-a800-tp1.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 256
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
diff --git a/benchmarks/yaml/qwen3dot6b-32k-bf16-h800-tp1.yaml b/benchmarks/yaml/qwen3dot6b-32k-bf16-h800-tp1.yaml
new file mode 100644
index 000000000..45ee7d14e
--- /dev/null
+++ b/benchmarks/yaml/qwen3dot6b-32k-bf16-h800-tp1.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 256
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
diff --git a/benchmarks/yaml/qwen3dot6b-32k-wint8-a30-tp1.yaml b/benchmarks/yaml/qwen3dot6b-32k-wint8-a30-tp1.yaml
new file mode 100644
index 000000000..60a6dbeef
--- /dev/null
+++ b/benchmarks/yaml/qwen3dot6b-32k-wint8-a30-tp1.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 256
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.75
+quantization: wint8
+tensor_parallel_size: 1
diff --git a/benchmarks/yaml/qwen3dot6b-32k-wint8-a800-tp1.yaml b/benchmarks/yaml/qwen3dot6b-32k-wint8-a800-tp1.yaml
new file mode 100644
index 000000000..60a6dbeef
--- /dev/null
+++ b/benchmarks/yaml/qwen3dot6b-32k-wint8-a800-tp1.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 256
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.75
+quantization: wint8
+tensor_parallel_size: 1
diff --git a/benchmarks/yaml/qwen3dot6b-32k-wint8-h800-tp1.yaml b/benchmarks/yaml/qwen3dot6b-32k-wint8-h800-tp1.yaml
new file mode 100644
index 000000000..60a6dbeef
--- /dev/null
+++ b/benchmarks/yaml/qwen3dot6b-32k-wint8-h800-tp1.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 256
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.75
+quantization: wint8
+tensor_parallel_size: 1
diff --git a/benchmarks/yaml/qwen3moe235b-32k-wint4-h800-tp4.yaml b/benchmarks/yaml/qwen3moe235b-32k-wint4-h800-tp4.yaml
new file mode 100644
index 000000000..7a127995e
--- /dev/null
+++ b/benchmarks/yaml/qwen3moe235b-32k-wint4-h800-tp4.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 75
+gpu_memory_utilization: 0.85
+kv_cache_ratio: 0.75
+quantization: wint4
+tensor_parallel_size: 4
\ No newline at end of file
diff --git a/benchmarks/yaml/qwen3moe235b-32k-wint8-h800-tp4.yaml b/benchmarks/yaml/qwen3moe235b-32k-wint8-h800-tp4.yaml
new file mode 100644
index 000000000..4d6cff601
--- /dev/null
+++ b/benchmarks/yaml/qwen3moe235b-32k-wint8-h800-tp4.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 25
+gpu_memory_utilization: 0.9
+kv_cache_ratio: 0.75
+quantization: wint8
+tensor_parallel_size: 4
\ No newline at end of file
diff --git a/benchmarks/yaml/qwen3moe30b-32k-bf16-a800-tp1.yaml b/benchmarks/yaml/qwen3moe30b-32k-bf16-a800-tp1.yaml
new file mode 100644
index 000000000..00fa7bef0
--- /dev/null
+++ b/benchmarks/yaml/qwen3moe30b-32k-bf16-a800-tp1.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 50
+gpu_memory_utilization: 0.85
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
diff --git a/benchmarks/yaml/qwen3moe30b-32k-bf16-h800-tp1.yaml b/benchmarks/yaml/qwen3moe30b-32k-bf16-h800-tp1.yaml
new file mode 100644
index 000000000..00fa7bef0
--- /dev/null
+++ b/benchmarks/yaml/qwen3moe30b-32k-bf16-h800-tp1.yaml
@@ -0,0 +1,5 @@
+max_model_len: 32768
+max_num_seqs: 50
+gpu_memory_utilization: 0.85
+kv_cache_ratio: 0.75
+tensor_parallel_size: 1
diff --git a/benchmarks/yaml/qwen3moe30b-32k-wint4-a800-tp1.yaml b/benchmarks/yaml/qwen3moe30b-32k-wint4-a800-tp1.yaml
new file mode 100644
index 000000000..8ed7b40b3
--- /dev/null
+++ b/benchmarks/yaml/qwen3moe30b-32k-wint4-a800-tp1.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 50
+gpu_memory_utilization: 0.8
+kv_cache_ratio: 0.75
+quantization: wint4
+tensor_parallel_size: 1
diff --git a/benchmarks/yaml/qwen3moe30b-32k-wint4-h800-tp1.yaml b/benchmarks/yaml/qwen3moe30b-32k-wint4-h800-tp1.yaml
new file mode 100644
index 000000000..8ed7b40b3
--- /dev/null
+++ b/benchmarks/yaml/qwen3moe30b-32k-wint4-h800-tp1.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 50
+gpu_memory_utilization: 0.8
+kv_cache_ratio: 0.75
+quantization: wint4
+tensor_parallel_size: 1
diff --git a/benchmarks/yaml/request_yaml/eb45-128k.yaml b/benchmarks/yaml/request_yaml/eb45-128k.yaml
new file mode 100644
index 000000000..052d20997
--- /dev/null
+++ b/benchmarks/yaml/request_yaml/eb45-128k.yaml
@@ -0,0 +1,8 @@
+top_p: 0.8
+temperature: 0.8
+metadata:
+ min_tokens: 1
+max_tokens: 131071
+repetition_penalty: 1.0
+frequency_penalty: 0
+presence_penalty: 0
diff --git a/benchmarks/yaml/request_yaml/eb45-32k.yaml b/benchmarks/yaml/request_yaml/eb45-32k.yaml
new file mode 100644
index 000000000..07753d410
--- /dev/null
+++ b/benchmarks/yaml/request_yaml/eb45-32k.yaml
@@ -0,0 +1,8 @@
+top_p: 0.8
+temperature: 0.8
+metadata:
+ min_tokens: 1
+max_tokens: 12288
+repetition_penalty: 1.0
+frequency_penalty: 0
+presence_penalty: 0
diff --git a/benchmarks/yaml/request_yaml/qwen2-32k.yaml b/benchmarks/yaml/request_yaml/qwen2-32k.yaml
new file mode 100644
index 000000000..464277942
--- /dev/null
+++ b/benchmarks/yaml/request_yaml/qwen2-32k.yaml
@@ -0,0 +1,8 @@
+top_p: 0.8
+temperature: 0.7
+metadata:
+ min_tokens: 1
+max_tokens: 12288
+repetition_penalty: 1.05
+frequency_penalty: 0
+presence_penalty: 0
\ No newline at end of file
diff --git a/benchmarks/yaml/request_yaml/qwen3-32k.yaml b/benchmarks/yaml/request_yaml/qwen3-32k.yaml
new file mode 100644
index 000000000..8f1fc1fd7
--- /dev/null
+++ b/benchmarks/yaml/request_yaml/qwen3-32k.yaml
@@ -0,0 +1,8 @@
+top_p: 0.8
+temperature: 0.7
+metadata:
+ min_tokens: 1
+max_tokens: 12288
+repetition_penalty: 1.0
+frequency_penalty: 0
+presence_penalty: 1.5
\ No newline at end of file
diff --git a/benchmarks/yaml/request_yaml/x1-32k.yaml b/benchmarks/yaml/request_yaml/x1-32k.yaml
new file mode 100644
index 000000000..7cec615c4
--- /dev/null
+++ b/benchmarks/yaml/request_yaml/x1-32k.yaml
@@ -0,0 +1,8 @@
+top_p: 0.95
+temperature: 0.6
+metadata:
+ min_tokens: 1
+max_tokens: 32767
+repetition_penalty: 1.0
+frequency_penalty: 0
+presence_penalty: 0
diff --git a/benchmarks/yaml/x1-32k-wint4-h800-tp8.yaml b/benchmarks/yaml/x1-32k-wint4-h800-tp8.yaml
new file mode 100644
index 000000000..b2cbce4a6
--- /dev/null
+++ b/benchmarks/yaml/x1-32k-wint4-h800-tp8.yaml
@@ -0,0 +1,6 @@
+tensor_parallel_size: 8
+max_model_len: 32768
+max_num_seqs: 32
+num_gpu_blocks_override: 4096
+kv_cache_ratio: 0.5
+reasoning_parser: ernie-x1
diff --git a/benchmarks/yaml/x1-32k-wint4-p800-tp4.yaml b/benchmarks/yaml/x1-32k-wint4-p800-tp4.yaml
new file mode 100644
index 000000000..f6b593889
--- /dev/null
+++ b/benchmarks/yaml/x1-32k-wint4-p800-tp4.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 32
+gpu_memory_utilization: 0.9
+tensor_parallel_size: 4
+quantization: wint4
+reasoning_parser: ernie-x1
diff --git a/benchmarks/yaml/x1-32k-wint4-p800-tp8.yaml b/benchmarks/yaml/x1-32k-wint4-p800-tp8.yaml
new file mode 100644
index 000000000..25a2e89a2
--- /dev/null
+++ b/benchmarks/yaml/x1-32k-wint4-p800-tp8.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 128
+gpu_memory_utilization: 0.9
+tensor_parallel_size: 8
+quantization: wint4
+reasoning_parser: ernie-x1
diff --git a/benchmarks/yaml/x1-32k-wint4-prefixcache-h800-tp8.yaml b/benchmarks/yaml/x1-32k-wint4-prefixcache-h800-tp8.yaml
new file mode 100644
index 000000000..a6f522578
--- /dev/null
+++ b/benchmarks/yaml/x1-32k-wint4-prefixcache-h800-tp8.yaml
@@ -0,0 +1,10 @@
+enable_prefix_caching: True
+num_gpu_blocks_override: 8000
+max_model_len: 32768
+max_num_seqs: 64
+gpu_memory_utilization: 0.85
+kv_cache_ratio: 0.5
+tensor_parallel_size: 8
+swap_space: 200
+cache_queue_port: 55664
+reasoning_parser: ernie-x1
diff --git a/benchmarks/yaml/x1-32k-wint8-h800-tp8.yaml b/benchmarks/yaml/x1-32k-wint8-h800-tp8.yaml
new file mode 100644
index 000000000..b2cbce4a6
--- /dev/null
+++ b/benchmarks/yaml/x1-32k-wint8-h800-tp8.yaml
@@ -0,0 +1,6 @@
+tensor_parallel_size: 8
+max_model_len: 32768
+max_num_seqs: 32
+num_gpu_blocks_override: 4096
+kv_cache_ratio: 0.5
+reasoning_parser: ernie-x1
diff --git a/benchmarks/yaml/x1-32k-wint8-p800-tp4.yaml b/benchmarks/yaml/x1-32k-wint8-p800-tp4.yaml
new file mode 100644
index 000000000..df01844d1
--- /dev/null
+++ b/benchmarks/yaml/x1-32k-wint8-p800-tp4.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 8
+gpu_memory_utilization: 0.9
+tensor_parallel_size: 4
+quantization: wint8
+reasoning_parser: ernie-x1
diff --git a/benchmarks/yaml/x1-32k-wint8-p800-tp8.yaml b/benchmarks/yaml/x1-32k-wint8-p800-tp8.yaml
new file mode 100644
index 000000000..376177602
--- /dev/null
+++ b/benchmarks/yaml/x1-32k-wint8-p800-tp8.yaml
@@ -0,0 +1,6 @@
+max_model_len: 32768
+max_num_seqs: 64
+gpu_memory_utilization: 0.9
+tensor_parallel_size: 8
+quantization: wint8
+reasoning_parser: ernie-x1
\ No newline at end of file
diff --git a/benchmarks/yaml/x1-32k-wint8-prefixcache-h800-tp8.yaml b/benchmarks/yaml/x1-32k-wint8-prefixcache-h800-tp8.yaml
new file mode 100644
index 000000000..a6f522578
--- /dev/null
+++ b/benchmarks/yaml/x1-32k-wint8-prefixcache-h800-tp8.yaml
@@ -0,0 +1,10 @@
+enable_prefix_caching: True
+num_gpu_blocks_override: 8000
+max_model_len: 32768
+max_num_seqs: 64
+gpu_memory_utilization: 0.85
+kv_cache_ratio: 0.5
+tensor_parallel_size: 8
+swap_space: 200
+cache_queue_port: 55664
+reasoning_parser: ernie-x1
diff --git a/build.sh b/build.sh
index 8591a52f2..4e4098559 100644
--- a/build.sh
+++ b/build.sh
@@ -17,8 +17,9 @@
BUILD_WHEEL=${1:-1}
PYTHON_VERSION=${2:-"python"}
export python=$PYTHON_VERSION
-CPU_USE_BF16=${3:-"false"}
-BUILDING_ARCS=${4:-""}
+FD_CPU_USE_BF16=${3:-"false"}
+FD_BUILDING_ARCS=${4:-""}
+
# paddle distributed use to set archs
unset PADDLE_CUDA_ARCH_LIST
@@ -30,13 +31,9 @@ EGG_DIR="fastdeploy.egg-info"
# custom_ops directory config
OPS_SRC_DIR="custom_ops"
-OPS_BUILD_DIR="build"
-OPS_EGG_DIR="efficitentllm_ops.egg-info"
OPS_TMP_DIR_BASE="tmp_base"
OPS_TMP_DIR="tmp"
-TEST_DIR="tests"
-
# command line log config
RED='\033[0;31m'
BLUE='\033[0;34m'
@@ -44,13 +41,14 @@ GREEN='\033[1;32m'
BOLD='\033[1m'
NONE='\033[0m'
+DEVICE_TYPE="gpu"
function python_version_check() {
PY_MAIN_VERSION=`${python} -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $1}'`
PY_SUB_VERSION=`${python} -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $2}'`
echo -e "find python version ${PY_MAIN_VERSION}.${PY_SUB_VERSION}"
- if [ $PY_MAIN_VERSION -ne "3" -o $PY_SUB_VERSION -lt "8" ]; then
- echo -e "${RED}FAIL:${NONE} please use Python >= 3.8"
+ if [ $PY_MAIN_VERSION -ne "3" -o $PY_SUB_VERSION -lt "9" ]; then
+ echo -e "${RED}FAIL:${NONE} please use Python >= 3.9"
exit 1
fi
}
@@ -75,6 +73,7 @@ function copy_ops(){
WHEEL_CPU_NAME="fastdeploy_cpu_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
if [ "$is_rocm" = "True" ]; then
+ DEVICE_TYPE="rocm"
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "ROCM ops have been copy to fastdeploy"
return
@@ -82,6 +81,7 @@ function copy_ops(){
mkdir -p ../fastdeploy/model_executor/ops/base
is_cuda=`$python -c "import paddle; print(paddle.is_compiled_with_cuda())"`
if [ "$is_cuda" = "True" ]; then
+ DEVICE_TYPE="gpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "BASE and CUDA ops have been copy to fastdeploy"
@@ -90,6 +90,7 @@ function copy_ops(){
is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
if [ "$is_xpu" = "True" ]; then
+ DEVICE_TYPE="xpu"
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/xpu
echo -e "xpu ops have been copy to fastdeploy"
return
@@ -97,20 +98,14 @@ function copy_ops(){
is_npu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('npu'))"`
if [ "$is_npu" = "True" ]; then
+ DEVICE_TYPE="npu"
cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/npu
echo -e "npu ops have been copy to fastdeploy"
return
fi
+ DEVICE_TYPE="cpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
- cd ${OPS_TMP_DIR}/${WHEEL_CPU_NAME}/xFasterTransformer/build/
- for file in *_pd_.so; do
- mv "$file" "${file/_pd_/}"
- done
- cd ../../x86-simd-sort/builddir/
- for file in *_pd_.so; do
- mv "$file" "${file/_pd_/}"
- done
cd ../../../../
cp -r ${OPS_TMP_DIR}/${WHEEL_CPU_NAME}/* ../fastdeploy/model_executor/ops/cpu
echo -e "BASE and CPU ops have been copy to fastdeploy"
@@ -122,15 +117,30 @@ function build_and_install_ops() {
export no_proxy=bcebos.com,paddlepaddle.org.cn,${no_proxy}
echo -e "${BLUE}[build]${NONE} build and install fastdeploy_base_ops..."
${python} setup_ops_base.py install --install-lib ${OPS_TMP_DIR_BASE}
+ find ${OPS_TMP_DIR_BASE} -type f -name "*.o" -exec rm -f {} \;
echo -e "${BLUE}[build]${NONE} build and install fastdeploy_ops..."
- if [ "$CPU_USE_BF16" == "true" ]; then
- CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
- :
- elif [ "$CPU_USE_BF16" == "false" ]; then
+ TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
+ is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
+ if [ "$is_xpu" = "True" ]; then
+ cd xpu_ops/src
+ bash build.sh ${TMP_DIR_REAL_PATH}
+ cd ../..
+ elif [ "$FD_CPU_USE_BF16" == "true" ]; then
+ if [ "$FD_BUILDING_ARCS" == "" ]; then
+ FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
+ else
+ FD_BUILDING_ARCS=${FD_BUILDING_ARCS} FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
+ fi
+ find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
+ elif [ "$FD_CPU_USE_BF16" == "false" ]; then
+ if [ "$FD_BUILDING_ARCS" == "" ]; then
${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
- :
+ else
+ FD_BUILDING_ARCS=${FD_BUILDING_ARCS} ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
+ fi
+ find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
else
- echo "Error: Invalid parameter '$CPU_USE_BF16'. Please use true or false."
+ echo "Error: Invalid parameter '$FD_CPU_USE_BF16'. Please use true or false."
exit 1
fi
if [ $? -ne 0 ]; then
@@ -146,11 +156,7 @@ function build_and_install_ops() {
function build_and_install() {
echo -e "${BLUE}[build]${NONE} building fastdeploy wheel..."
- if [ "$BUILDING_ARCS" == "" ]; then
- ${python} setup.py bdist_wheel --python-tag py3
- else
- BUILDING_ARCS=${BUILDING_ARCS} ${python} setup.py bdist_wheel --python-tag py3
- fi
+ ${python} setup.py bdist_wheel --python-tag=py3
if [ $? -ne 0 ]; then
echo -e "${RED}[FAIL]${NONE} build fastdeploy wheel failed"
@@ -174,10 +180,12 @@ function cleanup() {
rm -rf $BUILD_DIR $EGG_DIR
if [ `${python} -m pip list | grep fastdeploy | wc -l` -gt 0 ]; then
echo -e "${BLUE}[init]${NONE} uninstalling fastdeploy..."
- ${python} -m pip uninstall -y fastdeploy
+ ${python} -m pip uninstall -y fastdeploy-${DEVICE_TYPE}
fi
rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
+ rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR_BASE
+ rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR
}
function abort() {
@@ -187,7 +195,7 @@ function abort() {
cur_dir=`basename "$pwd"`
rm -rf $BUILD_DIR $EGG_DIR $DIST_DIR
- ${python} -m pip uninstall -y fastdeploy
+ ${python} -m pip uninstall -y fastdeploy-${DEVICE_TYPE}
rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
}
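
build.sh now records a DEVICE_TYPE ("rocm", "gpu", "xpu", "npu" or "cpu") while copying ops, so that cleanup/abort can uninstall the device-specific wheel (fastdeploy-${DEVICE_TYPE}) instead of a generic fastdeploy package. The selection mirrors the compile-time probes the script already runs via `$python -c "import paddle; ..."`; a standalone sketch of the same decision:

```python
import paddle


def detect_device_type() -> str:
    """Reproduce the DEVICE_TYPE choice made in build.sh copy_ops():
    rocm, then cuda (gpu), then xpu, then npu, falling back to cpu."""
    if paddle.is_compiled_with_rocm():
        return "rocm"
    if paddle.is_compiled_with_cuda():
        return "gpu"
    if paddle.is_compiled_with_xpu():
        return "xpu"
    if paddle.is_compiled_with_custom_device("npu"):
        return "npu"
    return "cpu"


# cleanup() then effectively runs: pip uninstall -y fastdeploy-<device_type>
```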
diff --git a/custom_ops/0001-DeepGEMM-95e81b3.patch b/custom_ops/0001-DeepGEMM-95e81b3.patch
new file mode 100644
index 000000000..e62972cec
--- /dev/null
+++ b/custom_ops/0001-DeepGEMM-95e81b3.patch
@@ -0,0 +1,643 @@
+From 5112002c155dceecc5e5983cdb67157e4f5400e2 Mon Sep 17 00:00:00 2001
+From: minghaipeng
+Date: Wed, 25 Jun 2025 15:05:24 +0800
+Subject: [PATCH] DeepGEMM 95e81b3
+
+---
+ deep_gemm/__init__.py | 2 +-
+ deep_gemm/include/deep_gemm/scheduler.cuh | 2 +-
+ deep_gemm/jit/compiler.py | 2 +-
+ deep_gemm/jit/interleave_ffma.py | 2 +-
+ deep_gemm/jit/runtime.py | 4 +-
+ deep_gemm/jit/template.py | 34 ++++----
+ deep_gemm/jit_kernels/gemm.py | 44 +++++------
+ deep_gemm/jit_kernels/m_grouped_gemm.py | 96 +++++++++++------------
+ deep_gemm/jit_kernels/tuner.py | 10 +--
+ deep_gemm/jit_kernels/utils.py | 18 +++--
+ deep_gemm/paddle_utils.py | 20 +++++
+ deep_gemm/utils.py | 30 +++----
+ 12 files changed, 143 insertions(+), 121 deletions(-)
+ create mode 100644 deep_gemm/paddle_utils.py
+
+diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py
+index 15b22ca..63e7fb7 100644
+--- a/deep_gemm/__init__.py
++++ b/deep_gemm/__init__.py
+@@ -1,4 +1,4 @@
+-import torch
++import paddle
+
+ from . import jit
+ from .jit_kernels import (
+diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
+index 9743871..6c97152 100644
+--- a/deep_gemm/include/deep_gemm/scheduler.cuh
++++ b/deep_gemm/include/deep_gemm/scheduler.cuh
+@@ -102,7 +102,7 @@ struct Scheduler {
+ if constexpr (kGemmType == GemmType::Normal) {
+ return block_idx * block_size;
+ } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
+- auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
++ auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M));
+ return offset * shape_dim + block_idx * block_size;
+ } else if constexpr (kGemmType == GemmType::GroupedMasked) {
+ return curr_group_idx * shape_dim + block_idx * block_size;
+diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py
+index c17d466..6fdc52f 100644
+--- a/deep_gemm/jit/compiler.py
++++ b/deep_gemm/jit/compiler.py
+@@ -4,7 +4,7 @@ import os
+ import re
+ import subprocess
+ import uuid
+-from torch.utils.cpp_extension import CUDA_HOME
++from ..paddle_utils import CUDA_HOME
+ from typing import Tuple
+
+ from . import interleave_ffma
+diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
+index fcb377e..db9d6f3 100644
+--- a/deep_gemm/jit/interleave_ffma.py
++++ b/deep_gemm/jit/interleave_ffma.py
+@@ -3,7 +3,7 @@ import mmap
+ import os
+ import re
+ import subprocess
+-from torch.utils.cpp_extension import CUDA_HOME
++from ..paddle_utils import CUDA_HOME
+
+
+ def run_cuobjdump(file_path):
+diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
+index 66c370a..4761426 100644
+--- a/deep_gemm/jit/runtime.py
++++ b/deep_gemm/jit/runtime.py
+@@ -1,6 +1,6 @@
+ import ctypes
+ import os
+-import torch
++import paddle
+ from typing import Optional
+
+ from .template import map_ctype
+@@ -35,7 +35,7 @@ class Runtime:
+ assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
+ cargs = []
+ for arg, (name, dtype) in zip(args, self.args):
+- if isinstance(arg, torch.Tensor):
++ if isinstance(arg, paddle.Tensor):
+ assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
+ else:
+ assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
+diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py
+index ead37f5..51b02c1 100644
+--- a/deep_gemm/jit/template.py
++++ b/deep_gemm/jit/template.py
+@@ -1,24 +1,24 @@
+ import copy
+ import ctypes
+ import os
+-import torch
++import paddle
+ from typing import Any, Dict, Iterable, Tuple
+
+
+ # Name map for Python `eval`
+ typename_map: Dict[Any, str] = {
+ **{t: t.__name__ for t in (bool, int, float)},
+- torch.int: 'torch.int',
+- torch.float: 'torch.float',
+- torch.bfloat16: 'torch.bfloat16',
+- torch.float8_e4m3fn: 'torch.float8_e4m3fn',
+- torch.cuda.Stream: 'torch.cuda.Stream',
++ paddle.int32: 'paddle.int32',
++ paddle.float32: 'paddle.float32',
++ paddle.bfloat16: 'paddle.bfloat16',
++ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
++ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
+ }
+
+ # `ctype` map for Python casting
+ ctype_map: Dict[Any, Any] = {
+ **{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
+- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
++ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
+ }
+
+
+@@ -27,25 +27,25 @@ genc_map = {
+ bool: ('bool', 'bool'),
+ int: ('int', 'int'),
+ float: ('float', 'float'),
+- torch.int: ('void*', 'int*'),
+- torch.float: ('void*', 'float*'),
+- torch.bfloat16: ('void*', '__nv_bfloat16*'),
+- torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
+- torch.cuda.Stream: ('void*', 'cudaStream_t'),
++ paddle.int32: ('void*', 'int*'),
++ paddle.float32: ('void*', 'float*'),
++ paddle.bfloat16: ('void*', '__nv_bfloat16*'),
++ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
++ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
+ }
+
+
+ def map_ctype(value: Any) -> Any:
+ if hasattr(value, 'data_ptr'):
+- if value.dtype == torch.int:
++ if value.dtype == paddle.int32:
+ return ctypes.c_void_p(value.data_ptr())
+- elif value.dtype == torch.float:
++ elif value.dtype == paddle.float32:
+ return ctypes.c_void_p(value.data_ptr())
+- elif value.dtype == torch.bfloat16:
++ elif value.dtype == paddle.bfloat16:
+ return ctypes.c_void_p(value.data_ptr())
+- elif value.dtype == torch.float16:
++ elif value.dtype == paddle.float16:
+ return ctypes.c_void_p(value.data_ptr())
+- elif value.dtype == torch.float8_e4m3fn:
++ elif value.dtype == paddle.float8_e4m3fn:
+ return ctypes.c_void_p(value.data_ptr())
+ else:
+ return ctypes.c_void_p(value.data_ptr())
+diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
+index cb438b7..44aa0ed 100644
+--- a/deep_gemm/jit_kernels/gemm.py
++++ b/deep_gemm/jit_kernels/gemm.py
+@@ -1,5 +1,5 @@
+ import math
+-import torch
++import paddle
+ from functools import lru_cache
+ from typing import Tuple
+
+@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
+ return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
+
+
+-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
+- rhs: Tuple[torch.Tensor, torch.Tensor],
+- out: torch.Tensor) -> None:
++def gemm_fp8_fp8_bf16_nt(lhs: Tuple[paddle.Tensor, paddle.Tensor],
++ rhs: Tuple[paddle.Tensor, paddle.Tensor],
++ out: paddle.Tensor) -> None:
+ """
+ Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+ LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
+ RHS and RHS scaling factors are required to be transposed.
+ The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
+- this function will do a transposing with a set of slow PyTorch operations.
++ this function will do a transposing with a set of slow paddle operations.
+
+ Arguments:
+- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
++ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
+ the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
+- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`.
++ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[n, k]`.
+ the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
+ out: the BF16 output tensor of shape `[m, n]`, representing the result.
+ """
+@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
+ n, k_ = rhs.shape
+ m_, n_ = out.shape
+
+- assert n % 64 == 0 and k % 128 == 0
++ # assert n % 64 == 0 and k % 128 == 0
+
+ # Type and shape checks
+- assert m == m_ and n == n_ and k == k_
+- assert n > 0 and k > 0
+- assert lhs_scales.shape == (m, (k + 127) // 128)
+- assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
+- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
+- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
+- assert out.dtype == torch.bfloat16
+- assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
++ # assert m == m_ and n == n_ and k == k_
++ # assert n > 0 and k > 0
++ # assert lhs_scales.shape == (m, (k + 127) // 128)
++ # assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
++ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
++ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
++ # assert out.dtype == paddle.bfloat16
++ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
+
+ # LHS scales must be transposed for TMA load, but not for RHS scales
+ # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
+ lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
+- assert rhs_scales.is_contiguous()
++ # assert rhs_scales.is_contiguous()
+
+ # Do nothing if `m` is zero
+ if m == 0:
+@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
+ global includes, template
+ num_sms = get_num_sms()
+ num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
+- args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0])
++ args = (lhs, lhs_scales, rhs, rhs_scales, out, m, paddle.device.cuda.current_stream(), num_sms, smem_config[0])
+ runtime = jit_tuner.compile_and_tune(
+ name='gemm_fp8_fp8_bf16_nt',
+ keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+@@ -225,10 +225,10 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
+ 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
+ space=(),
+ includes=includes,
+- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
+- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
+- ('out', torch.bfloat16), ('m', int),
+- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
++ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
++ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
++ ('out', paddle.bfloat16), ('m', int),
++ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+ template=template,
+ args=args
+ )
+diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
+index 3b518c9..ba776bd 100644
+--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
++++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
+@@ -1,4 +1,4 @@
+-import torch
++import paddle
+ from typing import Tuple
+
+ from .gemm import get_best_configs, get_block_n_padding_for_smem_d
+@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
+ """
+
+
+-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
+- rhs: Tuple[torch.Tensor, torch.Tensor],
+- out: torch.Tensor, m_indices: torch.Tensor) -> None:
++def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[paddle.Tensor, paddle.Tensor],
++ rhs: Tuple[paddle.Tensor, paddle.Tensor],
++ out: paddle.Tensor, m_indices: paddle.Tensor) -> None:
+ """
+ Do a grouped GEMM (contiguous format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+ LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
+ RHS and RHS scaling factors are required to be transposed.
+ The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
+- this function will do a transposing with a set of slow PyTorch operations.
++    this function will do a transposing with a set of slow paddle operations.
+ On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
+ `get_m_alignment_for_contiguous_layout()` (128).
+
+ Arguments:
+- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
++ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
+ the second element is an FP32 1x128 scaling tensor for LHS of shape `[m_sum, ⌈k / 128⌉]`.
+- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
++ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
+ the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
+ out: the BF16 output tensor of shape `[m_sum, n]`, representing the result.
+- m_indices: a tensor of shape `[m_sum]` with type `torch.int`.
++ m_indices: a tensor of shape `[m_sum]` with type `paddle.int`.
+ `m_indices[i]` records the group which the i-th row of the LHS belong to,
+ which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`.
+ Values of `m_indices` in every-m-alignment-block must also be the same.
+@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
+ m__ = m_indices.numel()
+
+ # Type and shape checks
+- assert m == m_ == m__ and k == k_ and n == n_
+- assert lhs_scales.shape == (m, (k + 127) // 128)
+- assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
+- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
+- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
+- assert out.dtype == torch.bfloat16
+- assert m_indices.dtype == torch.int32
+- assert lhs.is_contiguous() and rhs.is_contiguous()
+- assert out.is_contiguous() and m_indices.is_contiguous()
++ # assert m == m_ == m__ and k == k_ and n == n_
++ # assert lhs_scales.shape == (m, (k + 127) // 128)
++ # assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
++ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
++ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
++ # assert out.dtype == paddle.bfloat16
++ # assert m_indices.dtype == paddle.int32
++ # assert lhs.is_contiguous() and rhs.is_contiguous()
++ # assert out.is_contiguous() and m_indices.is_contiguous()
+
+ # LHS scales must be transposed for TMA load, but not for RHS scales
+ lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
+- assert rhs_scales.is_contiguous()
++ # assert rhs_scales.is_contiguous()
+
+ # Do nothing if `m` is zero
+ if m == 0:
+@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
+ num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
+ args = (lhs, lhs_scales, rhs, rhs_scales, out,
+ m_indices, m, num_groups,
+- torch.cuda.current_stream(), num_sms, smem_config[0])
++ paddle.device.cuda.current_stream(), num_sms, smem_config[0])
+ runtime = jit_tuner.compile_and_tune(
+ name='m_grouped_gemm_fp8_fp8_bf16_nt',
+ keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+@@ -105,11 +105,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
+ 'GEMM_TYPE': 'GroupedContiguous'},
+ space=(),
+ includes=includes,
+- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
+- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
+- ('out', torch.bfloat16),
+- ('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
+- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
++ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
++ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
++ ('out', paddle.bfloat16),
++ ('grouped_layout', paddle.int32), ('m', int), ('num_groups', int),
++ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+ template=template,
+ args=args
+ )
+@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
+ runtime(*args)
+
+
+-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
+- rhs: Tuple[torch.Tensor, torch.Tensor],
+- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
++def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[paddle.Tensor, paddle.Tensor],
++ rhs: Tuple[paddle.Tensor, paddle.Tensor],
++ out: paddle.Tensor, masked_m: paddle.Tensor, expected_m: int) -> None:
+ """
+ Do a grouped GEMM (masked format) with FP8 inputs and BF16 output, with 1x128 LHS scaling and 128x128 RHS scaling.
+ LHS, RHS, RHS scaling factors, and output tensors must be in contiguous format.
+ RHS and RHS scaling factors are required to be transposed.
+ The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
+- this function will do a transposing with a set of slow PyTorch operations.
++ this function will do a transposing with a set of slow paddle operations.
+ Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
+ should be separately transposed.
+
+ Arguments:
+- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
++ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
+ the second element is an FP32 1x128 scaling tensor for LHS of shape `[num_groups, m_max, ⌈k / 128⌉]`.
+- rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, n, k]`.
++ rhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, n, k]`.
+ the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
+ out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result.
+ masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
+@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
+ num_groups___ = masked_m.numel()
+
+ # Type and shape checks
+- assert num_groups == num_groups_ == num_groups__ == num_groups___
+- assert m == m_ and n == n_ and k == k_
+- assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
+- assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
+- assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
+- assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
+- assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
+- assert out.dtype == torch.bfloat16
+- assert masked_m.dtype == torch.int32
+- assert lhs.is_contiguous() and rhs.is_contiguous()
+- assert out.is_contiguous() and masked_m.is_contiguous()
++ # assert num_groups == num_groups_ == num_groups__ == num_groups___
++ # assert m == m_ and n == n_ and k == k_
++ # assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
++ # assert lhs_scales.shape == (num_groups, m, (k + 127) // 128)
++ # assert rhs_scales.shape == (num_groups, (n + 127) // 128, (k + 127) // 128)
++ # assert lhs.dtype == paddle.float8_e4m3fn and lhs_scales.dtype == paddle.float32
++ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
++ # assert out.dtype == paddle.bfloat16
++ # assert masked_m.dtype == paddle.int32
++ # assert lhs.is_contiguous() and rhs.is_contiguous()
++ # assert out.is_contiguous() and masked_m.is_contiguous()
+
+ # LHS scales must be transposed for TMA load, but not for RHS scales
+ lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
+- assert rhs_scales.is_contiguous()
++ # assert rhs_scales.is_contiguous()
+
+ # Auto-tuning with compilation
+ global includes, template
+@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
+
+ args = (lhs, lhs_scales, rhs, rhs_scales, out,
+ masked_m, m,
+- torch.cuda.current_stream(), num_sms, smem_config[0])
++ paddle.device.cuda.current_stream(), num_sms, smem_config[0])
+ runtime = jit_tuner.compile_and_tune(
+ name='m_grouped_gemm_fp8_fp8_bf16_nt',
+ keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
+@@ -189,11 +189,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
+ 'GEMM_TYPE': 'GroupedMasked'},
+ space=(),
+ includes=includes,
+- arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
+- ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
+- ('out', torch.bfloat16),
+- ('grouped_layout', torch.int32), ('m', int),
+- ('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
++ arg_defs=(('lhs', paddle.float8_e4m3fn), ('lhs_scales', paddle.float32),
++ ('rhs', paddle.float8_e4m3fn), ('rhs_scales', paddle.float32),
++ ('out', paddle.bfloat16),
++ ('grouped_layout', paddle.int32), ('m', int),
++ ('stream', paddle.device.cuda.Stream), ('num_sms', int), ('smem_size', int)),
+ template=template,
+ args=args
+ )
+diff --git a/deep_gemm/jit_kernels/tuner.py b/deep_gemm/jit_kernels/tuner.py
+index 6ed6749..9e1d70f 100644
+--- a/deep_gemm/jit_kernels/tuner.py
++++ b/deep_gemm/jit_kernels/tuner.py
+@@ -1,6 +1,6 @@
+ import copy
+ import os
+-import torch
++import paddle
+ from typing import Any, Dict
+
+ from ..jit import build, cpp_format, generate, Runtime
+@@ -51,10 +51,10 @@ class JITTuner:
+ continue
+
+ # Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
+- start_event = torch.cuda.Event(enable_timing=True)
+- end_event = torch.cuda.Event(enable_timing=True)
+- torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
+- torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
++ start_event = paddle.device.cuda.Event(enable_timing=True)
++ end_event = paddle.device.cuda.Event(enable_timing=True)
++ paddle.empty((int(256e6 // 4)), dtype=paddle.int32).zero_()
++ paddle.randn((8192, 8192), dtype=paddle.float32) @ paddle.randn((8192, 8192), dtype=paddle.float32)
+ start_event.record()
+ for i in range(20):
+ assert runtime(*args) == 0
+diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py
+index c6da56b..a17b1b1 100644
+--- a/deep_gemm/jit_kernels/utils.py
++++ b/deep_gemm/jit_kernels/utils.py
+@@ -1,4 +1,4 @@
+-import torch
++import paddle
+
+ _num_sms = None
+
+@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
+ num_sms: the desired maximum SM count for all GEMM kernels to use.
+ """
+ global _num_sms
+- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
++ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
+ _num_sms = num_sms
+
+
+@@ -25,7 +25,7 @@ def get_num_sms() -> int:
+ """
+ global _num_sms
+ if _num_sms is None:
+- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
++ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
+ return _num_sms
+
+
+@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
+ return ceil_div(x, alignment) * alignment
+
+
+-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
++def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
+ """
+- Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary.
++ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
+ If the input tensor is already column-major layout and 16-byte aligned along the M axis
+ (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
+
+@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+ m, n = x.shape[-2], x.shape[-1]
+ aligned_m = get_tma_aligned_size(m, x.element_size())
+ if x.dim() == 2:
+- if x.stride(0) == 1 and x.stride(1) == aligned_m:
++ if x.strides[0] == 1 and x.strides[1] == aligned_m:
+ return x
+ x, remove_dim = x.unsqueeze(0), True
+
+ b = x.shape[0]
+
+ # The last kernel gives a column-major TMA aligned layout
+- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
++ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
+ return x.squeeze(0) if remove_dim else x
+
+ # Normal layout requires transposing
+- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
++ aligned_x = paddle.transpose(
++ paddle.empty((b, n, aligned_m), dtype=x.dtype), perm=[0, 2, 1]
++ )
+ aligned_x[:, :m, :] = x
+ aligned_x = aligned_x[:, :m, :]
+ return aligned_x.squeeze(0) if remove_dim else aligned_x
+diff --git a/deep_gemm/paddle_utils.py b/deep_gemm/paddle_utils.py
+new file mode 100644
+index 0000000..2326807
+--- /dev/null
++++ b/deep_gemm/paddle_utils.py
+@@ -0,0 +1,20 @@
++import os
++
++def get_cuda_home():
++ """Get Cuda home directory"""
++ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
++ if cuda_home:
++ return cuda_home
++
++ try:
++ which_cmd = "which nvcc"
++
++ nvcc_path = os.popen(which_cmd).read().strip()
++ if nvcc_path:
++ return os.path.dirname(os.path.dirname(nvcc_path))
++ except Exception:
++ pass
++
++ return None
++
++CUDA_HOME = get_cuda_home()
+\ No newline at end of file
+diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py
+index d5cdd01..5237f09 100644
+--- a/deep_gemm/utils.py
++++ b/deep_gemm/utils.py
+@@ -1,15 +1,15 @@
+ import os
+ import sys
+ import time
+-import torch
+-import torch.distributed as dist
++import paddle
++import paddle.distributed as dist
+
+
+ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
+ high_precision: bool = False):
+ # Flush L2 cache with 256 MB data
+- torch.cuda.synchronize()
+- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
++ paddle.device.cuda.synchronize()
++ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
+ cache.zero_()
+
+ # Warmup
+@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
+
+ # Add a large kernel to eliminate the CPU launch overhead
+ if high_precision:
+- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+- y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
++ x = paddle.randn((8192, 8192), dtype=paddle.float32)
++ y = paddle.randn((8192, 8192), dtype=paddle.float32)
+ x @ y
+
+ # Testing
+- start_event = torch.cuda.Event(enable_timing=True)
+- end_event = torch.cuda.Event(enable_timing=True)
++ start_event = paddle.device.cuda.Event(enable_timing=True)
++ end_event = paddle.device.cuda.Event(enable_timing=True)
+ start_event.record()
+ for i in range(num_tests):
+ fn()
+ end_event.record()
+- torch.cuda.synchronize()
++ paddle.device.synchronize()
+
+ return start_event.elapsed_time(end_event) / num_tests
+
+@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
+ # Profile
+ suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
+ with suppress():
+- schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
+- profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
++ scheduler = paddle.profiler.make_scheduler(closed=0, ready=1, record=1, repeat=1) if not using_nsys else None
++ profiler = paddle.profiler.Profiler(targets=[paddle.profiler.ProfilerTarget.CPU, paddle.profiler.ProfilerTarget.GPU], scheduler=scheduler) if not using_nsys else empty_suppress()
+ with profiler:
+ for i in range(2):
+ # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
+ if barrier_comm_profiling:
+- lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
+- rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
++ lhs = paddle.randn((8192, 8192), dtype=paddle.float32)
++ rhs = paddle.randn((8192, 8192), dtype=paddle.float32)
+ lhs @ rhs
+- dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
++ dist.all_reduce(paddle.ones(1, dtype=paddle.float32))
+ for _ in range(num_tests):
+ if sleep_between_tests > 0.0:
+ time.sleep(sleep_between_tests)
+ if flush_l2:
+- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
++ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
+ fn()
+
+ if not using_nsys:
+--
+2.43.0
+
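
The bundled 0001-DeepGEMM-95e81b3.patch ports DeepGEMM's JIT tooling from torch to paddle (presumably applied to a vendored DeepGEMM checkout during the custom-ops build). The substitutions are mechanical; the mapping below summarizes the correspondence the hunks rely on (a summary of the patch, not a runtime shim):

```python
# torch symbol                         -> paddle equivalent used by the patch
TORCH_TO_PADDLE = {
    "torch.int": "paddle.int32",
    "torch.float": "paddle.float32",
    "torch.bfloat16": "paddle.bfloat16",
    "torch.float8_e4m3fn": "paddle.float8_e4m3fn",
    "torch.cuda.Stream": "paddle.device.cuda.Stream",
    "torch.cuda.Event": "paddle.device.cuda.Event",
    "torch.cuda.current_stream()": "paddle.device.cuda.current_stream()",
    "torch.cuda.synchronize()": "paddle.device.cuda.synchronize()",
    "Tensor.stride(i)": "Tensor.strides[i]",
    "torch.utils.cpp_extension.CUDA_HOME": "deep_gemm.paddle_utils.CUDA_HOME",
}
```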
diff --git a/custom_ops/cpu_ops/avx_weight_only.cc b/custom_ops/cpu_ops/avx_weight_only.cc
deleted file mode 100644
index 1d410156e..000000000
--- a/custom_ops/cpu_ops/avx_weight_only.cc
+++ /dev/null
@@ -1,188 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-#include "dtype.h"
-#include "matmul_helper.h"
-#include "my_types.h"
-#include "paddle/extension.h"
-#include "paddle/phi/core/kernel_registry.h"
-template <typename T>
-void AvxCompute(const paddle::Tensor &x,
- const paddle::Tensor &weight,
- const paddle::Tensor &w_bias,
- bool trans,
- const std::string alog,
- paddle::Tensor &out,
- xft::Matrix<T> &quantizedWeight,
- xft::Vector<float> &WeightScale,
- xft::Vector<float> &WeightZero,
- xft::Vector<float> &WeightSum,
- MMHelper *mmHelper) {
- auto out_data = out.data<float>();
- const float *x_data = reinterpret_cast<const float *>(x.data<float>());
- const float *bias_data = nullptr;
- if (w_bias.initialized()) {
- bias_data = reinterpret_cast<const float *>(w_bias.data<float>());
- }
- int m = 1;
- for (int i = 0; i < x.shape().size() - 1; i++) {
- m = m * x.shape()[i];
- }
- int k = x.shape()[x.shape().size() - 1];
- int l = weight.shape()[1];
- int n = weight.shape()[1];
- if (w_bias.initialized()) {
- mmHelper->compute_bias(false,
- m,
- n,
- k,
- 1.0f,
- x_data,
- k,
- quantizedWeight.Data(),
- WeightScale.Data(),
- WeightZero.Data(),
- WeightSum.Data(),
- 0.0f,
- out_data,
- l,
- bias_data);
- } else {
- mmHelper->compute(false,
- m,
- n,
- k,
- 1.0f,
- x_data,
- k,
- quantizedWeight.Data(),
- WeightScale.Data(),
- WeightZero.Data(),
- WeightSum.Data(),
- 0.0,
- out_data,
- l);
- }
-};
-template <typename T>
-void AvxWeightOnly(const paddle::Tensor &x,
- const paddle::Tensor &weight,
- const paddle::Tensor &w_bias,
- bool trans,
- const std::string alog,
- paddle::Tensor &out) {
- static std::unordered_map<std::string,
- std::tuple<xft::Matrix<T> *,
- xft::Vector<float> *,
- xft::Vector<float> *,
- xft::Vector<float> *>>
- weight_only_hub;
- std::stringstream weights_addr;
- weights_addr << weight.data() << alog;
- std::string weight_only_key = weights_addr.str();
- auto it_created = weight_only_hub.find(weight_only_key);
- static MMHelper *mmHelper;
- int rows = weight.shape()[0], cols = weight.shape()[1];
- xft::Vector<float> *WeightScale =
- new xft::Vector<float>(); // if weight is int8
- xft::Vector<float> *WeightZero =
- new xft::Vector<float>(); // if weight is int8
- xft::Vector<float> *WeightSum =
- new xft::Vector<float>(); // if weight is int8
- xft::Matrix<T> *quantizedWeight = new xft::Matrix<T>();
- if (it_created == weight_only_hub.end()) {
- auto weight_ptr = reinterpret_cast<const float *>(weight.data<float>());
- xft::Matrix<T> convertedWeight;
- mmHelper = new MMHelper(xft::DeviceKind::iCPU, 0);
- mmHelper->convertWeight(trans,
- rows,
- cols,
- weight_ptr,
- nullptr,
- nullptr,
- convertedWeight,
- *WeightScale,
- *WeightZero,
- *WeightSum);
- quantizedWeight->Resize(rows, cols);
- mmHelper->packWeight(trans, convertedWeight, *quantizedWeight);
- weight_only_hub[weight_only_key] = std::make_tuple(
- quantizedWeight, WeightScale, WeightZero, WeightSum);
- AvxCompute(x,
- weight,
- w_bias,
- trans,
- alog,
- out,
- *quantizedWeight,
- *WeightScale,
- *WeightZero,
- *WeightSum,
- mmHelper);
- } else {
- AvxCompute(x,
- weight,
- w_bias,
- trans,
- alog,
- out,
- *(std::get<0>(it_created->second)),
- *(std::get<1>(it_created->second)),
- *(std::get<2>(it_created->second)),
- *(std::get<3>(it_created->second)),
- mmHelper);
- }
-}
-std::vector<paddle::Tensor> InvokeAvxWeightOnly(const paddle::Tensor &x,
- const paddle::Tensor &weight,
- const paddle::Tensor &w_bias,
- const std::string &alog,
- bool trans) {
- auto out_shape = x.shape();
- out_shape[out_shape.size() - 1] = weight.shape()[1];
- auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace());
- if (alog == "int8") {
- AvxWeightOnly(x, weight, w_bias, trans, alog, out);
- } else if (alog == "fp16") {
- AvxWeightOnly(x, weight, w_bias, trans, alog, out);
- } else {
- AvxWeightOnly(x, weight, w_bias, trans, alog, out);
- }
- return {out};
-}
-
-std::vector<std::vector<int64_t>> AvxWeightOnlyInferShape(
- std::vector<int64_t> x_shape,
- std::vector<int64_t> weigh_shape,
- std::vector<int64_t> weigh_bias_shape) {
- int m = 1;
- for (int i = 0; i < x_shape.size() - 1; i++) {
- m = m * x_shape[i];
- }
- return {std::vector{m, weigh_shape[1]}};
-}
-
-std::vector<paddle::DataType> AvxWeightOnlyInferDtype(
- paddle::DataType x_dtype,
- paddle::DataType weight_dtype,
- paddle::DataType weight_bias_dtype) {
- return {x_dtype};
-}
-
-PD_BUILD_STATIC_OP(avx_weight_only)
- .Inputs({"x", "weight", "w_bias"})
- .Outputs({"out"})
- .Attrs({"alog: std::string", "trans:bool"})
- .SetKernelFn(PD_KERNEL(InvokeAvxWeightOnly))
- .SetInferShapeFn(PD_INFER_SHAPE(AvxWeightOnlyInferShape))
- .SetInferDtypeFn(PD_INFER_DTYPE(AvxWeightOnlyInferDtype));
diff --git a/custom_ops/cpu_ops/rebuild_padding.cc b/custom_ops/cpu_ops/rebuild_padding.cc
new file mode 100644
index 000000000..8ce533d04
--- /dev/null
+++ b/custom_ops/cpu_ops/rebuild_padding.cc
@@ -0,0 +1,268 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include "paddle/extension.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+template <typename T>
+void RebuildPaddingCPUImpl(T *output_data,
+ const T *input_data,
+ const int *cum_offsets_data,
+ const int *seq_len_this_time_data,
+ const int *seq_lens_decoder_data,
+ const int *seq_lens_encoder_data,
+ int max_input_length,
+ int dim_embed,
+ const int elem_nums) {
+ for (int i = 0; i < elem_nums; ++i) {
+ const int bi = i / dim_embed;
+ const int bias_idx = i % dim_embed;
+ int seq_id = 0;
+
+ if (seq_len_this_time_data[bi] == 0) {
+ continue;
+ }
+ if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
+ continue;
+ }
+ if (seq_lens_encoder_data[bi] > 0) {
+ seq_id = seq_lens_encoder_data[bi] - 1;
+ }
+ const int ori_token_idx =
+ bi * max_input_length - cum_offsets_data[bi] + seq_id;
+ const int src_offset = ori_token_idx * dim_embed + bias_idx;
+
+ output_data[i] = input_data[src_offset];
+ }
+}
+
+template <typename T>
+void RebuildAppendPaddingCPUImpl(T *output_data,
+ const T *input_data,
+ const int *cum_offsets_data,
+ const int *seq_len_this_time_data,
+ const int *seq_lens_decoder_data,
+ const int *seq_lens_encoder_data,
+ const int *output_padding_offset_data,
+ const int max_input_length,
+ const int dim_embed,
+ const int64_t output_elem_nums) {
+ for (int i = 0; i < output_elem_nums; ++i) {
+ int out_token_id = i / dim_embed;
+ int ori_token_id =
+ out_token_id + output_padding_offset_data[out_token_id];
+ int bi = ori_token_id / max_input_length;
+ if (seq_len_this_time_data[bi] == 0 ||
+ (seq_lens_decoder_data[bi] == 0 &&
+ seq_lens_encoder_data[bi] == 0)) {
+ continue;
+ }
+ int seq_id = 0;
+ if (seq_lens_encoder_data[bi] > 0) {
+ seq_id = seq_lens_encoder_data[bi] - 1;
+ }
+ int input_token_id = ori_token_id - cum_offsets_data[bi] + seq_id;
+ int bias_idx = i % dim_embed;
+ int src_offset = input_token_id * dim_embed + bias_idx;
+ output_data[i] = input_data[src_offset];
+ }
+}
+
+std::vector<paddle::Tensor> RebuildPaddingCPU(
+ const paddle::Tensor &tmp_out,
+ const paddle::Tensor &cum_offsets,
+ const paddle::Tensor &seq_len_this_time,
+ const paddle::Tensor &seq_lens_decoder,
+ const paddle::Tensor &seq_lens_encoder,
+    const paddle::optional<paddle::Tensor> &output_padding_offset,
+ int max_input_length) {
+ auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
+ auto cum_offsets_cpu = cum_offsets.copy_to(paddle::CPUPlace(), true);
+ auto seq_len_this_time_cpu =
+ seq_len_this_time.copy_to(paddle::CPUPlace(), true);
+ auto seq_lens_decoder_cpu =
+ seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
+ auto seq_lens_encoder_cpu =
+ seq_lens_encoder.copy_to(paddle::CPUPlace(), true);
+    paddle::optional<paddle::Tensor> output_padding_offset_cpu;
+ if (output_padding_offset) {
+ output_padding_offset_cpu =
+ output_padding_offset->copy_to(paddle::CPUPlace(), true);
+ }
+
+ int token_num = tmp_out_cpu.shape()[0];
+ int dim_embed = tmp_out_cpu.shape()[1];
+ int bsz = cum_offsets_cpu.shape()[0];
+
+ paddle::Tensor out;
+ if (output_padding_offset_cpu) {
+ int need_delete_token_num = 0;
+ for (int i = 0; i < bsz; ++i) {
+            if (seq_lens_encoder_cpu.data<int>()[i] > 0) {
+                need_delete_token_num +=
+                    seq_lens_encoder_cpu.data<int>()[i] - 1;
+ }
+ }
+ int output_token_num = token_num - need_delete_token_num;
+ out = paddle::full({output_token_num, dim_embed},
+ 0,
+ tmp_out_cpu.dtype(),
+ paddle::CPUPlace());
+ } else {
+ out = paddle::full(
+ {bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
+ }
+
+    const int *cum_offsets_data = cum_offsets_cpu.data<int>();
+    const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
+    const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
+    const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
+ int elem_nums = out.numel();
+
+ if (output_padding_offset_cpu) {
+ const int *output_padding_offset_data =
+            output_padding_offset_cpu->data<int>();
+ switch (tmp_out_cpu.dtype()) {
+ case paddle::DataType::FLOAT32:
+                RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
+                                                   tmp_out_cpu.data<float>(),
+ cum_offsets_data,
+ seq_len_this_time_data,
+ seq_lens_decoder_data,
+ seq_lens_encoder_data,
+ output_padding_offset_data,
+ max_input_length,
+ dim_embed,
+ elem_nums);
+ break;
+ case paddle::DataType::FLOAT16:
+                RebuildAppendPaddingCPUImpl<paddle::float16>(
+                    out.data<paddle::float16>(),
+                    tmp_out_cpu.data<paddle::float16>(),
+ cum_offsets_data,
+ seq_len_this_time_data,
+ seq_lens_decoder_data,
+ seq_lens_encoder_data,
+ output_padding_offset_data,
+ max_input_length,
+ dim_embed,
+ elem_nums);
+ break;
+ case paddle::DataType::BFLOAT16:
+                RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
+                    out.data<paddle::bfloat16>(),
+                    tmp_out_cpu.data<paddle::bfloat16>(),
+ cum_offsets_data,
+ seq_len_this_time_data,
+ seq_lens_decoder_data,
+ seq_lens_encoder_data,
+ output_padding_offset_data,
+ max_input_length,
+ dim_embed,
+ elem_nums);
+ break;
+ default:
+ PD_THROW(
+ "Unsupported data type for rebuild_padding_cpu. "
+ "Only float32, float16, and bfloat16 are supported.");
+ }
+ } else {
+ switch (tmp_out_cpu.dtype()) {
+ case paddle::DataType::FLOAT32:
+                RebuildPaddingCPUImpl<float>(out.data<float>(),
+                                             tmp_out_cpu.data<float>(),
+ cum_offsets_data,
+ seq_len_this_time_data,
+ seq_lens_decoder_data,
+ seq_lens_encoder_data,
+ max_input_length,
+ dim_embed,
+ elem_nums);
+ break;
+ case paddle::DataType::FLOAT16:
+                RebuildPaddingCPUImpl<paddle::float16>(
+                    out.data<paddle::float16>(),
+                    tmp_out_cpu.data<paddle::float16>(),
+ cum_offsets_data,
+ seq_len_this_time_data,
+ seq_lens_decoder_data,
+ seq_lens_encoder_data,
+ max_input_length,
+ dim_embed,
+ elem_nums);
+ break;
+ case paddle::DataType::BFLOAT16:
+
+                RebuildPaddingCPUImpl<paddle::bfloat16>(
+                    out.data<paddle::bfloat16>(),
+                    tmp_out_cpu.data<paddle::bfloat16>(),
+ cum_offsets_data,
+ seq_len_this_time_data,
+ seq_lens_decoder_data,
+ seq_lens_encoder_data,
+ max_input_length,
+ dim_embed,
+ elem_nums);
+ break;
+ default:
+ PD_THROW(
+ "Unsupported data type for rebuild_padding_cpu. "
+ "Only float32, float16, and bfloat16 are supported.");
+ }
+ }
+ return {out};
+}
+
+std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
+    const std::vector<int64_t> &tmp_out_shape,
+    const std::vector<int64_t> &cum_offsets_shape,
+    const std::vector<int64_t> &seq_len_this_time_shape,
+    const std::vector<int64_t> &seq_lens_decoder_shape,
+    const std::vector<int64_t> &seq_lens_encoder_shape,
+    const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
+ int64_t dim_embed = tmp_out_shape[1];
+ if (output_padding_offset_shape) {
+ return {{-1, dim_embed}};
+ } else {
+ int64_t bsz = cum_offsets_shape[0];
+ return {{bsz, dim_embed}};
+ }
+}
+
+std::vector<paddle::DataType> RebuildPaddingInferDtype(
+ const paddle::DataType &tmp_out_dtype,
+ const paddle::DataType &cum_offsets_dtype,
+ const paddle::DataType &seq_len_this_time_dtype,
+ const paddle::DataType &seq_lens_decoder_dtype,
+ const paddle::DataType &seq_lens_encoder_dtype,
+    const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
+ return {tmp_out_dtype};
+}
+
+PD_BUILD_STATIC_OP(rebuild_padding_cpu)
+ .Inputs({"tmp_out",
+ "cum_offsets",
+ "seq_len_this_time",
+ "seq_lens_decoder",
+ "seq_lens_encoder",
+ paddle::Optional("output_padding_offset")})
+ .Outputs({"out"})
+ .Attrs({"max_input_length: int"})
+ .SetKernelFn(PD_KERNEL(RebuildPaddingCPU))
+ .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape))
+ .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype));
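
The flat-index arithmetic in RebuildPaddingCPUImpl is easier to read row-wise. Below is a NumPy reference of the non-offset path (batch-sized output); the array arguments mirror the tensors above, and this is a readability aid, not the shipped kernel:

```python
import numpy as np


def rebuild_padding_cpu_ref(tmp_out, cum_offsets, seq_len_this_time,
                            seq_lens_decoder, seq_lens_encoder,
                            max_input_length):
    """Select, per batch element, the row of its current token from the packed
    [token_num, dim_embed] buffer: the last prefill token while the request is
    still in the encoder phase, otherwise its single decode token."""
    bsz, dim_embed = cum_offsets.shape[0], tmp_out.shape[1]
    out = np.zeros((bsz, dim_embed), dtype=tmp_out.dtype)
    for bi in range(bsz):
        if seq_len_this_time[bi] == 0:
            continue
        if seq_lens_decoder[bi] == 0 and seq_lens_encoder[bi] == 0:
            continue
        seq_id = seq_lens_encoder[bi] - 1 if seq_lens_encoder[bi] > 0 else 0
        ori_token_idx = bi * max_input_length - cum_offsets[bi] + seq_id
        out[bi] = tmp_out[ori_token_idx]
    return out
```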
diff --git a/custom_ops/cpu_ops/xft_all_layer.cc b/custom_ops/cpu_ops/xft_all_layer.cc
deleted file mode 100644
index 7b24e0b8e..000000000
--- a/custom_ops/cpu_ops/xft_all_layer.cc
+++ /dev/null
@@ -1,201 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "layers_decoder.h"
-#include "paddle/extension.h"
-#include "paddle/phi/core/kernel_registry.h"
-
-std::vector<paddle::Tensor> InvokeAllLLaMALayer(
- const paddle::Tensor &input,
- const std::vector<paddle::Tensor> &ln1Gamma,
- const std::vector<paddle::Tensor> &ln1Beta,
- const std::vector<paddle::Tensor> &qkvWeight,
- const std::vector<paddle::Tensor> &qkvBiasWeight,
- const std::vector<paddle::Tensor> &attnOutWeight,
- const std::vector<paddle::Tensor> &attnOutBias,
- const std::vector<paddle::Tensor> &ln2Gamma,
- const std::vector<paddle::Tensor> &ln2Beta,
- const std::vector<paddle::Tensor> &gateWeight,
- const std::vector<paddle::Tensor> &gateBias,
- const std::vector<paddle::Tensor> &upWeight,
- const std::vector<paddle::Tensor> &upBias,
- const std::vector<paddle::Tensor> &downWeight,
- const std::vector<paddle::Tensor> &downBias,
- const paddle::Tensor &pastSeqLen,
- const paddle::Tensor ¤tSeqLen,
- const paddle::Tensor &step,
- int hiddensize,
- int totalLayer,
- const std::string &computeType,
- const std::string &activation,
- const std::string &normType,
- int attHeadDim,
- int attHeadNum,
- int kvHeadNum,
- int maxPositions,
- int maxPosEmbed,
- int intermediateSize) {
- auto out = paddle::empty_like(input);
- auto batchSize = input.shape()[0];
- auto inputSeqLen = input.shape()[1];
- auto past_seq_len = pastSeqLen.data()[0];
- auto cur_seq_len = static_cast(currentSeqLen.data()[0]);
- auto step_id = step.data()[0];
- auto output_ptr = reinterpret_cast(out.data());
- auto xft_data_type = xft::DataType::fp16;
- if (computeType == "bf16") {
- xft_data_type = xft::DataType::bf16;
- } else if (computeType == "bf16_int8") {
- xft_data_type = xft::DataType::bf16_int8;
- }
- auto xft_act_type = xft::ActivationType::SILU;
- if (activation == "relu") {
- xft_act_type = xft::ActivationType::RELU;
- } else if (activation == "gelu") {
- xft_act_type = xft::ActivationType::GELU;
- } else if (activation == "swiglu") {
- xft_act_type = xft::ActivationType::SWIGLU;
- }
- auto xft_norm_type = xft::NormType::RMS;
- if (normType == "layernorm") {
- xft_norm_type = xft::NormType::LN;
- }
- auto input_ptr = reinterpret_cast<const float *>(input.data<float>());
- for (int i = 0; i < totalLayer; ++i) {
- auto ln1Gamma_ptr =
- reinterpret_cast<const float *>(ln1Gamma[i].data<float>());
- auto ln1Beta_ptr =
- reinterpret_cast<const float *>(ln1Beta[i].data<float>());
- auto qkvWeight_ptr =
- reinterpret_cast<const float *>(qkvWeight[i].data<float>());
- auto qkvBiasWeight_ptr =
- reinterpret_cast<const float *>(qkvBiasWeight[i].data<float>());
- auto attnOutWeight_ptr =
- reinterpret_cast<const float *>(attnOutWeight[i].data<float>());
- auto ln2Gamma_ptr =
- reinterpret_cast<const float *>(ln2Gamma[i].data<float>());
- auto ln2Beta_ptr =
- reinterpret_cast<const float *>(ln2Beta[i].data<float>());
- auto gate_weight_ptr =
- reinterpret_cast<const float *>(gateWeight[i].data<float>());
- auto up_weight_ptr =
- reinterpret_cast<const float *>(upWeight[i].data<float>());
- auto down_weight_ptr =
- reinterpret_cast<const float *>(downWeight[i].data<float>());
- auto gate_bias_ptr =
- reinterpret_cast<const float *>(gateBias[i].data<float>());
- auto up_bias_ptr =
- reinterpret_cast<const float *>(upBias[i].data<float>());
- auto down_bias_ptr =
- reinterpret_cast<const float *>(downBias[i].data<float>());
- auto attnOutBias_ptr =
- reinterpret_cast<const float *>(attnOutBias[i].data<float>());
- invokeLayerLLaMA(
- xft_data_type, // dt
- xft_act_type, // at
- xft_norm_type, // nt
- i, // layerId
- totalLayer, // totalLayers
- batchSize, // batchSize
- inputSeqLen, // inputSeqLen
- attHeadDim, // attHeadDim
- attHeadNum, // attHeadNum
- kvHeadNum, // kvHeadNum
- maxPositions, // maxPositions
- maxPosEmbed, // maxPosEmbed
- past_seq_len, // pastSeqLen
- cur_seq_len, // currentSeqLen
- step_id, // step
- hiddensize, // hiddenSize
- intermediateSize, // intermediateSize
- reinterpret_cast(output_ptr), // output
- hiddensize, // outputStride
- input_ptr, // input
- hiddensize, // inputStride
- ln1Gamma_ptr, // ln1Gamma
- ln1Beta_ptr, // ln1Beta
- qkvWeight_ptr, // queryWeight
- qkvWeight_ptr + hiddensize, // keyWeight
- qkvWeight_ptr + hiddensize + kvHeadNum * attHeadDim, // valueWeight
- attnOutWeight_ptr, // attnOutWeight
- ln2Gamma_ptr, // ln2Gamma
- ln2Beta_ptr, // ln2Beta
- gate_weight_ptr,
- up_weight_ptr,
- down_weight_ptr,
- qkvBiasWeight_ptr, // queryBias
- qkvBiasWeight_ptr + hiddensize, // keyBias
- qkvBiasWeight_ptr + hiddensize +
- kvHeadNum * attHeadDim, // valueBias
- attnOutBias_ptr, // attnOutBias
- qkvWeight_ptr, // myqkvWeight
- gate_bias_ptr,
- up_bias_ptr,
- down_bias_ptr,
- qkvBiasWeight_ptr);
- if (i < totalLayer - 1) {
- memcpy(const_cast<float *>(input_ptr),
- output_ptr,
- batchSize * inputSeqLen * hiddensize * sizeof(float));
- }
- }
- return {out};
-}
-
-std::vector<std::vector<int64_t>> AllLLaMALayerInferShape(
- std::vector<int64_t> x_shape) {
- return {x_shape};
-}
-
-std::vector<paddle::DataType> AllLLaMALayerInferDtype(
- paddle::DataType x_dtype) {
- return {x_dtype};
-}
-
-PD_BUILD_STATIC_OP(xft_llama_all_layer)
- .Inputs({
- "x",
- paddle::Vec("ln1Gamma"),
- paddle::Vec("ln1Beta"),
- paddle::Vec("qkvWeight"),
- paddle::Vec("qkvBiasWeight"),
- paddle::Vec("attnOutWeight"),
- paddle::Vec("attnOutBias"),
- paddle::Vec("ln2Gamma"),
- paddle::Vec("ln2Beta"),
- paddle::Vec("gateWeight"),
- paddle::Vec("gateBias"),
- paddle::Vec("upWeight"),
- paddle::Vec("upBias"),
- paddle::Vec("downWeight"),
- paddle::Vec("downBias"),
- "pastSeqLen",
- "currentSeqLen",
- "step",
- })
- .Outputs({"out"})
- .Attrs({"hiddensize :int",
- "totalLayer :int",
- "computeType : std::string",
- "activation :std::string",
- "normType :std::string",
- "attHeadDim: int",
- "attHeadNum: int",
- "kvHeadNum: int",
- "maxPositions: int",
- "maxPosEmbed: int",
- "intermediateSize: int"})
- .SetKernelFn(PD_KERNEL(InvokeAllLLaMALayer))
- .SetInferShapeFn(PD_INFER_SHAPE(AllLLaMALayerInferShape))
- .SetInferDtypeFn(PD_INFER_DTYPE(AllLLaMALayerInferDtype));
diff --git a/custom_ops/cpu_ops/xft_greedy_search.cc b/custom_ops/cpu_ops/xft_greedy_search.cc
deleted file mode 100644
index 4ee78a768..000000000
--- a/custom_ops/cpu_ops/xft_greedy_search.cc
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-#include <math.h>
-#include <omp.h>
-#include <stdint.h>
-#include "paddle/extension.h"
-
-void greedy_search(const float *probs,
- int64_t *next_token_ids,
- int bsz,
- int vocab_size) {
- int numThreads = 0;
-#pragma omp parallel
- {
- int tid = omp_get_thread_num();
- if (tid == 0) {
- numThreads = omp_get_num_threads();
- }
- }
- float maxVals[bsz];
-
- // Small batch size (each sample can have at least 2 threads)
- if (numThreads / bsz >= 2) {
- int thrPerSample = numThreads / bsz;
- int sizePerThr = (vocab_size + thrPerSample - 1) / thrPerSample;
- int maxIndices[bsz * thrPerSample];
- float maxValues[bsz * thrPerSample];
-
- // TODO: if size is small, possible to cause out of boundary
-#pragma omp parallel for collapse(2)
- for (int b = 0; b < bsz; ++b) {
- for (int t = 0; t < thrPerSample; ++t) {
- int start = t * sizePerThr;
- int end = (start + sizePerThr) > vocab_size
- ? vocab_size
- : (start + sizePerThr);
- const float *p = probs + b * vocab_size;
- int maxIdx = start;
- float maxVal = p[start];
- for (int off = start + 1; off < end; ++off) {
- if (p[off] > maxVal) {
- maxVal = p[off];
- maxIdx = off;
- }
- }
-
- // False sharing happens, but since only one time, not avoided
- maxIndices[b * thrPerSample + t] = maxIdx;
- maxValues[b * thrPerSample + t] = maxVal;
- }
- }
-
- // Local reduction
- for (int i = 0; i < bsz; ++i) {
- int *pIndices = maxIndices + i * thrPerSample;
- float *pValues = maxValues + i * thrPerSample;
- int maxIdx = pIndices[0];
- float maxVal = pValues[0];
- for (int j = 1; j < thrPerSample; ++j) {
- if (pValues[j] > maxVal) {
- maxVal = pValues[j];
- maxIdx = pIndices[j];
- }
- }
- next_token_ids[i] = maxIdx;
- maxVals[i] = maxVal;
- }
- }
-
- // Each thread handle one sample (one row)
- else {
-#pragma omp parallel for
- for (int i = 0; i < bsz; ++i) {
- int maxId = 0;
- const float *p = probs + i * vocab_size;
- float maxVal = p[0];
- for (int j = 1; j < vocab_size; ++j) {
- if (p[j] > maxVal) {
- maxVal = p[j];
- maxId = j;
- }
- }
- next_token_ids[i] = maxId;
- maxVals[i] = maxVal;
- }
- }
- return;
-}
-std::vector<paddle::Tensor> XftGreedySearch(const paddle::Tensor &probs) {
- const int bsz = probs.shape()[0];
- const int vocab_size = probs.shape()[1];
- auto next_tokens =
- paddle::empty({bsz, 1}, paddle::DataType::INT64, probs.place());
-
- greedy_search(probs.data<float>(),
- const_cast<int64_t *>(next_tokens.data<int64_t>()),
- bsz,
- vocab_size);
- return {next_tokens};
-}
-std::vector<std::vector<int64_t>> XftGreedySearchInferShape(
- const std::vector<int64_t> &probs_shape) {
- int64_t bsz = probs_shape[0];
- return {{bsz, 1}};
-}
-std::vector<paddle::DataType> XftGreedySearchInferDtype(
- const paddle::DataType &probs_dtype) {
- return {paddle::DataType::INT64};
-}
-PD_BUILD_STATIC_OP(xft_greedy_search)
- .Inputs({"probs"})
- .Outputs({"next_tokens_ids"})
- .SetInferShapeFn(PD_INFER_SHAPE(XftGreedySearchInferShape))
- .SetInferDtypeFn(PD_INFER_DTYPE(XftGreedySearchInferDtype))
- .SetKernelFn(PD_KERNEL(XftGreedySearch));
diff --git a/custom_ops/gpu_ops/air_topp_sampling.cu b/custom_ops/gpu_ops/air_topp_sampling.cu
deleted file mode 100644
index 92318b38d..000000000
--- a/custom_ops/gpu_ops/air_topp_sampling.cu
+++ /dev/null
@@ -1,1612 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cooperative_groups.h>
-#include <cooperative_groups/reduce.h>
-#include <cub/cub.cuh>
-#include <curand_kernel.h>
-
-#include "helper.h"
-#include "paddle/phi/common/memory_utils.h"
-#include "paddle/phi/backends/context_pool.h"
-#include "paddle/phi/core/stream.h"
-
-#define CHECK_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
-
-#define FINAL_MASK 0xFFFFFFFF
-
-#define FIXED_BLOCK_DIM_BASE(dim, ...) \
- case (dim): { \
- constexpr auto kBlockDim = (dim); \
- __VA_ARGS__; \
- } break
-
-
-#define FIXED_BLOCK_DIM(...) \
- FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
- FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
- FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
- FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
- FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
- FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
-
-template <typename T, typename IdxT, typename AccT>
-struct alignas(128) Counter
-{
- T const* in;
- IdxT const* inIdx;
-
- IdxT oriLen;
-
- AccT sum;
- IdxT len;
- float p;
- IdxT previousLen;
- typename cub::Traits<T>::UnsignedBits kthValueBits;
-
- alignas(128) IdxT filterCnt;
- alignas(128) uint32_t finishedBlockCnt;
-};
-
-template <typename IntType>
-constexpr __host__ __device__ IntType ceilDiv(IntType a, IntType b)
-{
- return (a + b - 1) / b;
-}
-
-template <typename IntType>
-constexpr __host__ __device__ IntType alignTo(IntType a, IntType b)
-{
- return ceilDiv(a, b) * b;
-}
-
-/**
- * This function calculate the bufLen, which is the size of buffer.
- * When the number of candidates for next pass exceeds the bufLen, we choose not to store the candidates. Otherwise, we
- * will load candidates from the original input data.
- */
-template <typename T, typename IdxT>
-__host__ __device__ IdxT calcBufLen(IdxT len)
-{
- IdxT constexpr ratio = 2 + sizeof(IdxT) * 2 / sizeof(T);
- IdxT bufLen = len / (ratio * 8);
- bufLen = alignTo(bufLen, 256);
- return bufLen;
-}
-
-template <typename T, int BitsPerPass>
-__host__ __device__ constexpr int calcNumPasses()
-{
- return ceilDiv(sizeof(T) * 8, BitsPerPass);
-}
-
-template <typename T>
-__device__ typename cub::Traits<T>::UnsignedBits twiddleIn(T key, bool selectMin)
-{
- auto bits = reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(key);
- bits = cub::Traits<T>::TwiddleIn(bits);
- if (!selectMin)
- {
- bits = ~bits;
- }
- return bits;
-}
-
-template <typename T>
-__device__ T twiddleOut(typename cub::Traits<T>::UnsignedBits bits, bool selectMin)
-{
- if (!selectMin)
- {
- bits = ~bits;
- }
- bits = cub::Traits<T>::TwiddleOut(bits);
- return reinterpret_cast<T&>(bits);
-}
-
-template <int BitsPerPass>
-__host__ __device__ constexpr int calcNumBuckets()
-{
- return 1 << BitsPerPass;
-}
-
-template <typename T, int BitsPerPass, int Pass>
-__device__ constexpr int calcStartBit()
-{
- constexpr int tmpBit = sizeof(T) * 8 - (Pass + 1) * BitsPerPass;
-
- constexpr int startBit = tmpBit < 0 ? 0 : tmpBit;
- return startBit;
-}
-
-template <typename T, int BitsPerPass, int Pass>
-__device__ constexpr uint32_t calcMask()
-{
- static_assert(BitsPerPass <= 31);
- constexpr int numBits = calcStartBit<T, BitsPerPass, Pass - 1>() - calcStartBit<T, BitsPerPass, Pass>();
- return (1 << numBits) - 1;
-}
-
-/**
- * Find the bucket based on the radix
- */
-template <typename T, int BitsPerPass>
-__device__ int calcBucket(T x, int startBit, uint32_t mask, bool selectMin)
-{
- return (twiddleIn(x, selectMin) >> startBit) & mask;
-}
-
-/**
- * Replace histogram with its own prefix sum (step 2 in `airTopPSampling` description)
- */
-template <typename IdxT, int BitsPerPass, int BlockSize>
-__device__ void scan(IdxT volatile* histogram, IdxT* histogramOut)
-{
- int constexpr numBuckets = calcNumBuckets<BitsPerPass>();
- if constexpr (numBuckets >= BlockSize)
- {
- static_assert(numBuckets % BlockSize == 0);
- int constexpr itemsPerThread = numBuckets / BlockSize;
- typedef cub::BlockLoad<IdxT, BlockSize, itemsPerThread, cub::BLOCK_LOAD_TRANSPOSE> BlockLoad;
- typedef cub::BlockStore<IdxT, BlockSize, itemsPerThread, cub::BLOCK_STORE_TRANSPOSE> BlockStore;
- typedef cub::BlockScan<IdxT, BlockSize> BlockScan;
-
- __shared__ union
- {
- typename BlockLoad::TempStorage load;
- typename BlockScan::TempStorage scan;
- typename BlockStore::TempStorage store;
- } tempStorage;
-
- IdxT threadData[itemsPerThread];
-
- BlockLoad(tempStorage.load).Load(histogram, threadData);
- __syncthreads();
-
- BlockScan(tempStorage.scan).InclusiveSum(threadData, threadData);
- __syncthreads();
-
- BlockStore(tempStorage.store).Store(histogramOut, threadData);
- }
- else
- {
- typedef cub::BlockScan<IdxT, BlockSize> BlockScan;
- __shared__ typename BlockScan::TempStorage tempStorage;
-
- IdxT threadData = 0;
- if (threadIdx.x < numBuckets)
- {
- threadData = histogram[threadIdx.x];
- }
-
- BlockScan(tempStorage).InclusiveSum(threadData, threadData);
- __syncthreads();
-
- if (threadIdx.x < numBuckets)
- {
- histogramOut[threadIdx.x] = threadData;
- }
- }
-}
-
-template <typename T, int BitsPerPass, int NumBuckets, int Pass>
-__device__ __forceinline__ void filterAndHistogram(const T *in_buffer,
- const int *in_idx_buffer,
- T *out_buffer,
- int *out_idx_buffer,
- T *out_scores,
- int64_t *out_ids,
- int previous_len,
- Counter<T, int, T> *counter,
- T *histogram,
- int *count_histogram,
- T *histogram_shm,
- int *count_histogram_shm,
- const bool early_stop) {
- // scan and filter
- constexpr int start_bit = calcStartBit<T, BitsPerPass, Pass>();
- const uint32_t mask = calcMask<T, BitsPerPass, Pass>();
- constexpr int VecSize = 16 / sizeof(T);
- const int bid = blockIdx.y, tid = threadIdx.x;
- using VecT = uint4;
- union {
- VecT v;
- T array[VecSize];
- } vec;
- for (int i = (blockIdx.x * blockDim.x + threadIdx.x) ; i < ceilDiv(previous_len, VecSize); i += blockDim.x * gridDim.x) {
- vec.v = reinterpret_cast<const VecT*>(in_buffer)[i];
- if constexpr (Pass == 0) {
-#pragma unroll
- for (int j = 0; j < VecSize; j++) {
- if (i * VecSize + j < previous_len) {
- int bucket = calcBucket(vec.array[j], start_bit, mask, false);
- atomicAdd(histogram_shm + bucket, vec.array[j]);
- atomicAdd(count_histogram_shm + bucket, 1);
- }
- }
- } else {
- int *filter_cnt = &counter->filterCnt;
- const auto kthValueBits = counter->kthValueBits;
- constexpr int previousStartBit = calcStartBit<T, BitsPerPass, Pass - 1>();
-#pragma unroll
- for (int j = 0; j < VecSize; j++) {
- const int idx = i * VecSize + j;
- if (idx < previous_len) {
- const auto previousBits = (twiddleIn(vec.array[j], false) >> previousStartBit) << previousStartBit;
- if (previousBits == kthValueBits) {
- if (early_stop) {
- const int pos = in_idx_buffer ? in_idx_buffer[idx] : idx;
- out_scores[bid] = vec.array[j];
- out_ids[bid] = pos;
- }
- if (out_buffer) {
- int pos = atomicAdd(filter_cnt, 1);
- out_buffer[pos] = vec.array[j];
- out_idx_buffer[pos] = in_idx_buffer ? in_idx_buffer[idx] : idx;
- }
- int bucket = calcBucket(vec.array[j], start_bit, mask, false);
- atomicAdd(histogram_shm + bucket, vec.array[j]);
- atomicAdd(count_histogram_shm + bucket, 1);
- }
- }
- }
- }
- }
- __syncthreads();
- if (early_stop) {
- return;
- }
- for (int i = tid; i < NumBuckets; i += blockDim.x) {
- if (count_histogram_shm[i] > 0) {
- atomicAdd(histogram + i, histogram_shm[i]);
- atomicAdd(count_histogram + i, count_histogram_shm[i]);
- }
- }
-}
-
-template <typename T, int BitsPerPass, int BlockSize, int NumBuckets, int Pass>
-__global__ void air_topp_sampling(Counter<T, int, T> *counters,
- T *histograms,
- int *count_histograms,
- T *out,
- int64_t *ids,
- T *buf1,
- int *idx_buf1,
- T *buf2,
- int *idx_buf2,
- int* count_iter,
- int* count_iter_begin,
- const int buf_len) {
-
- /***
- * calc - filter - scan -find
- * TODO: calc - scan - find - filter
- ***/
- const int bid = blockIdx.y;
- if (count_iter_begin[bid] == count_iter[bid + 1]) {
- // topk
- return;
- }
-
- const int tid = threadIdx.x;
- auto counter = counters + bid;
-
- T current_sum;
- int previous_len, current_len;
- if constexpr (Pass == 0) {
- current_sum = 0;
- previous_len = counter->len;
- current_len = counter->len;
- } else {
- current_sum = counter->sum;
- previous_len = counter->previousLen;
- current_len = counter->len;
- }
- if (current_len == 0) {
- return;
- }
- const bool early_stop = (current_len == 1);
- const T *in_buf = nullptr;
- const int *in_idx_buf = nullptr;
- T *out_buf = nullptr;
- int *out_idx_buf = nullptr;
- const int buf_offset = bid * buf_len;
- if constexpr (Pass == 0) {
- in_buf = counter->in;
- in_idx_buf = nullptr;
- out_buf = nullptr;
- out_idx_buf = nullptr;
- } else if constexpr (Pass == 1) {
- in_buf = counter->in;
- in_idx_buf = nullptr;
- out_buf = buf1 + buf_offset;
- out_idx_buf = idx_buf1 + buf_offset;
- } else {
- in_buf = buf1 + buf_offset;
- in_idx_buf = idx_buf1 + buf_offset;
- out_buf = buf2 + buf_offset;
- out_idx_buf = idx_buf2 + buf_offset;
- }
-
- if (Pass == 0 || Pass == 1 || previous_len > buf_len) {
- previous_len = counter->oriLen;
- in_buf = counter->in;
- in_idx_buf = nullptr;
- }
- if (Pass == 0 || current_len > buf_len) {
- out_buf = nullptr;
- out_idx_buf = nullptr;
- }
-
- auto histogram = histograms + bid * NumBuckets;
- auto count_histogram = count_histograms + bid * NumBuckets;
- __shared__ T histogram_shm[NumBuckets];
- __shared__ int count_histogram_shm[NumBuckets];
- for (int i = tid; i < NumBuckets; i += blockDim.x) {
- histogram_shm[i] = 0;
- count_histogram_shm[i] = 0;
- }
- __syncthreads();
-
- filterAndHistogram<T, BitsPerPass, NumBuckets, Pass>(
- in_buf,
- in_idx_buf,
- out_buf,
- out_idx_buf,
- out,
- ids,
- previous_len,
- counter,
- histogram,
- count_histogram,
- histogram_shm,
- count_histogram_shm,
- early_stop
- );
- __syncthreads();
- __threadfence();
-
- // find last block
- bool isLastBlock = false;
- if (threadIdx.x == 0) {
- uint32_t finished = atomicInc(&counter->finishedBlockCnt, gridDim.x - 1);
- isLastBlock = (finished == (gridDim.x - 1));
- }
-
- if (__syncthreads_or(isLastBlock)) {
- if (early_stop) {
- if (threadIdx.x == 0) {
- counter->previousLen = 0;
- counter->len = 0;
- }
- return;
- }
-
- // scan/find
- constexpr int WARP_SIZE = 32;
- constexpr int WARP_COUNT = NumBuckets / WARP_SIZE;
- namespace cg = cooperative_groups;
- cg::thread_block block = cg::this_thread_block();
- cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
- __shared__ T warpSum[WARP_COUNT];
- __shared__ cuda::atomic<T, cuda::thread_scope_block> blockSum;
- for (int i = tid; i < WARP_COUNT; i += BlockSize) {
- warpSum[i] = 0;
- }
- if (tid == 0) {
- blockSum = 0;
- }
- __syncthreads();
- // Acquire the summation of each 32 buckets
- for (int i = threadIdx.x; i < NumBuckets; i += BlockSize) {
- reduce_store_async(warp, warpSum + i / WARP_SIZE, histogram[i], cg::plus<T>{});
- }
- __syncthreads();
- // Acquire the summation of all the 2048 buckets
- if (threadIdx.x < WARP_SIZE) {
- reduce_store_async(warp, blockSum, warpSum[threadIdx.x], cg::plus<T>{});
- reduce_update_async(warp, blockSum, warpSum[threadIdx.x + WARP_SIZE], cg::plus<T>{});
- }
- __syncthreads();
-
- if constexpr (Pass == 0) {
- current_sum = blockSum * counter->p;
- }
-
- if (tid == 0) {
- T prev = 0;
-
- // Add 32 elements each step
- int iStep = 0;
- int targetStep = 0;
- for (; iStep < WARP_COUNT; iStep++) {
- if (warpSum[iStep]) {
- targetStep = iStep;
- if ((prev + warpSum[iStep]) >= current_sum) {
- break;
- }
- prev += warpSum[iStep];
- }
- }
-
- int targetIdx = 0;
- for (int i = targetStep * WARP_SIZE; i < NumBuckets; i++) {
- if (count_histogram[i]) {
- targetIdx = i;
- if ((prev + histogram[i]) >= current_sum) {
- break;
- }
- prev += histogram[i];
- }
- }
- counter->sum = current_sum - prev; // how many values still are there to find
- counter->len = count_histogram[targetIdx]; // cur - prev; // number of values in next pass
- typename cub::Traits<T>::UnsignedBits bucket = targetIdx;
- int startBit = calcStartBit<T, BitsPerPass, Pass>();
- counter->kthValueBits |= bucket << startBit;
- }
- __syncthreads();
- constexpr int numPasses = calcNumPasses<T, BitsPerPass>();
- if constexpr (Pass != numPasses - 1) {
- for (int i = tid; i < NumBuckets; i += BlockSize) {
- histogram[i] = 0;
- count_histogram[i] = 0;
- }
- }
- if (tid == 0) {
- // recover
- counter->previousLen = current_len;
- counter->filterCnt = 0;
- }
- if constexpr (Pass == numPasses - 1) {
- const auto kthValueBits = counter->kthValueBits;
- const auto equal_value = twiddleOut<T>(kthValueBits, false);
-
- const T *last_data = out_buf ? out_buf : in_buf;
- const int *last_idx_data = out_idx_buf ? out_idx_buf : in_idx_buf;
- const int last_len = out_buf ? current_len : counter->oriLen;
- for (int i = tid; i < last_len; i += BlockSize) {
- if (last_data[i] == equal_value) {
- out[bid] = equal_value;
- ids[bid] = last_idx_data ? last_idx_data[i] : i;
- }
- }
- }
- }
-}
-
-template <typename T>
-__global__ void air_topp_init(Counter<T, int, T> *counters,
- T *histograms,
- int *count_histograms,
- const T *in,
- const T *ps,
- curandState_t* curandstate,
- const int bsz,
- const int vocab_size,
- const int buf_len,
- const int num_buckets) {
- const int bid = blockIdx.x;
- const int tid = threadIdx.x;
- Counter<T, int, T> *counter_now = counters + bid;
- T *histogram_now = histograms + bid * num_buckets;
- int *count_histogram_now = count_histograms + bid * num_buckets;
- const int offset = bid * vocab_size;
- if (tid == 0) {
- counter_now->in = in + offset;
-
- counter_now->len = vocab_size;
- counter_now->oriLen = vocab_size;
- counter_now->previousLen = vocab_size;
-
- const T p = ps[bid];
- const T rand_p = curand_uniform(curandstate + bid) * p;
- counter_now->p = rand_p;
-
- counter_now->sum = 0;
-
- counter_now->kthValueBits = 0;
- counter_now->filterCnt = 0;
- counter_now->finishedBlockCnt = 0;
- }
- for (int i = tid; i < num_buckets; i += blockDim.x) {
- histogram_now[i] = 0;
- count_histogram_now[i] = 0;
- }
-}
-
-struct SegmentOffsetIter {
- explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
-
- __host__ __device__ __forceinline__ int operator()(int idx) const {
- return idx * num_cols_;
- }
-
- int num_cols_;
-};
-
-template <typename T>
-struct Pair {
- __device__ __forceinline__ Pair() {}
- __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
-
- __device__ __forceinline__ void set(T value, int id) {
- this->v = value;
- this->id = id;
- }
-
- __device__ __forceinline__ void operator=(const Pair& in) {
- v = in.v;
- id = in.id;
- }
-
- __device__ __forceinline__ bool operator<(const T value) const {
- return (static_cast<float>(v) < static_cast<float>(value));
- }
-
- __device__ __forceinline__ bool operator>(const T value) const {
- return (static_cast<float>(v) > static_cast<float>(value));
- }
- __device__ __forceinline__ bool operator<(const Pair& in) const {
- return (static_cast<float>(v) < static_cast<float>(in.v)) ||
- ((static_cast<float>(v) == static_cast<float>(in.v)) &&
- (id > in.id));
- }
-
- __device__ __forceinline__ bool operator>(const Pair& in) const {
- return (static_cast<float>(v) > static_cast<float>(in.v)) ||
- ((static_cast<float>(v) == static_cast<float>(in.v)) &&
- (id < in.id));
- }
-
- T v;
- int id;
-};
-
-inline int div_up(int a, int n) { return (a + n - 1) / n; }
-
-template <typename T>
-__device__ __forceinline__ void AddTo(Pair<T> topk[],
- const Pair<T>& p,
- int beam_size) {
- for (int k = beam_size - 2; k >= 0; k--) {
- if (topk[k] < p) {
- topk[k + 1] = topk[k];
- } else {
- topk[k + 1] = p;
- return;
- }
- }
- topk[0] = p;
-}
-
-template <typename T, int BlockSize>
-__device__ __forceinline__ void GetTopK(Pair<T> topk[],
- const T* src,
- int idx,
- int dim,
- int beam_size) {
- while (idx < dim) {
- if (topk[beam_size - 1] < src[idx]) {
- Pair<T> tmp(src[idx], idx);
- AddTo(topk, tmp, beam_size);
- }
- idx += BlockSize;
- }
-}
-
-template <typename T, int BlockSize>
-__device__ __forceinline__ void GetTopK(Pair<T> topk[],
- const T* src,
- int idx,
- int dim,
- const Pair<T>& max,
- int beam_size) {
- while (idx < dim) {
- if (topk[beam_size - 1] < src[idx]) {
- Pair<T> tmp(src[idx], idx);
- if (tmp < max) {
- AddTo(topk, tmp, beam_size);
- }
- }
- idx += BlockSize;
- }
-}
-
-template <typename T, int MaxLength, int BlockSize>
-__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
- int* beam,
- int beam_size,
- const T* src,
- bool* firstStep,
- bool* is_empty,
- Pair<T>* max,
- int dim,
- const int tid) {
- if (*beam > 0) {
- int length = (*beam) < beam_size ? *beam : beam_size;
- if (*firstStep) {
- *firstStep = false;
- GetTopK<T, BlockSize>(topk, src, tid, dim, length);
- } else {
- for (int k = 0; k < MaxLength; k++) {
- if (k < MaxLength - (*beam)) {
- topk[k] = topk[k + *beam];
- } else {
- topk[k].set(std::numeric_limits<T>::min(), -1);
- }
- }
- if (!(*is_empty)) {
- GetTopK<T, BlockSize>(
- topk + MaxLength - *beam, src, tid, dim, *max, length);
- }
- }
-
- *max = topk[MaxLength - 1];
- if ((*max).id == -1) *is_empty = true;
- *beam = 0;
- }
-}
-
-template <typename T>
-__forceinline__ __device__ T
-CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) {
- return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
-}
-
-template <typename T>
-__forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input) {
-#pragma unroll
- for (int offset = 16; offset > 0; offset >>= 1) {
- T tmp_val =
- CudaShuffleDownSync(FINAL_MASK, input.v, offset, 32);
- int tmp_id =
- CudaShuffleDownSync(FINAL_MASK, input.id, offset, 32);
- if (static_cast<float>(input.v) < static_cast<float>(tmp_val)) {
- input.v = tmp_val;
- input.id = tmp_id;
- }
- }
- return input;
-}
-
-template <typename T, int MaxLength, int BlockSize>
-__device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
- Pair<T> topk[],
- Pair<T> beam_max[],
- int* beam,
- int* k,
- int* count,
- const int tid,
- const int wid,
- const int lane) {
- while (true) {
- __syncthreads();
- Pair<T> input_now = topk[0];
- input_now = WarpReduce(input_now);
-
- if (lane == 0) {
- shared_max[wid] = input_now;
- }
- __syncthreads();
- input_now = (tid < BlockSize / 32)
- ? shared_max[lane]
- : Pair<T>(std::numeric_limits<T>::min(), -1);
- if (wid == 0) {
- input_now = WarpReduce(input_now);
- if (lane == 0) shared_max[0] = input_now;
- }
- __syncthreads();
- if (tid == 0) {
- beam_max[*count] = shared_max[0];
- (*count)++;
- }
- int tid_max = shared_max[0].id % BlockSize;
- if (tid == tid_max) {
- (*beam)++;
- }
- if (--(*k) == 0) break;
- __syncthreads();
-
- if (tid == tid_max) {
- if (*beam < MaxLength) {
- topk[0] = topk[*beam];
- }
- }
-
- if (MaxLength < 5) {
- if (*beam >= MaxLength) break;
- } else {
- unsigned mask = 0u;
- mask = __ballot_sync(FINAL_MASK, true);
- if (tid_max / 32 == wid) {
- if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) == MaxLength)
- break;
- }
- }
- }
-}
-
-template <typename T>
-__device__ inline T exponential_transform(T val, T lambda) {
-#if defined(__NVCC__) || defined(__HIPCC__)
- T log = -std::numeric_limits<T>::epsilon() / 2;
- if (val < static_cast<T>(1.) - std::numeric_limits<T>::epsilon() / 2) {
- if (std::is_same<T, double>::value) {
- log = logf(val);
- } else {
- log = __logf(val);
- }
- }
- return static_cast<T>(-1.0) / lambda * log;
-#else
- return static_cast