mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-01 06:42:23 +08:00
Compare commits
137 Commits
v2.0.0
...
release/2.
Author | SHA1 | Date | |
---|---|---|---|
![]() |
5224f6c434 | ||
![]() |
bfef09dd73 | ||
![]() |
1d46420c49 | ||
![]() |
fb0f284e67 | ||
![]() |
5d1788c7b5 | ||
![]() |
abd238fc12 | ||
![]() |
e5804b1d98 | ||
![]() |
8c43bc8176 | ||
![]() |
b0f1e0eef4 | ||
![]() |
69be77c8c0 | ||
![]() |
535a15ab8f | ||
![]() |
580460046f | ||
![]() |
4dbc483713 | ||
![]() |
4ead15822c | ||
![]() |
f941124402 | ||
![]() |
b89f083004 | ||
![]() |
4d05ed596c | ||
![]() |
bc1866af58 | ||
![]() |
fe237fe92b | ||
![]() |
3a480abcbb | ||
![]() |
335609efb6 | ||
![]() |
3464f75f98 | ||
![]() |
09d0073fdc | ||
![]() |
63d6e7ce06 | ||
![]() |
aa76085d1f | ||
![]() |
42b80182e0 | ||
![]() |
dda4a9f848 | ||
![]() |
a83a3eea5f | ||
![]() |
0d0340392f | ||
![]() |
0253381fb9 | ||
![]() |
2d1184aefe | ||
![]() |
17314ee126 | ||
![]() |
101ad33332 | ||
![]() |
0fad10b35a | ||
![]() |
61b3997b85 | ||
![]() |
e7bcbbab52 | ||
![]() |
5fc659b900 | ||
![]() |
33db137d0b | ||
![]() |
9d6a42b334 | ||
![]() |
1b712bba82 | ||
![]() |
fd91da7b41 | ||
![]() |
15c8c240b5 | ||
![]() |
7cdd8d290d | ||
![]() |
4c7b8bc458 | ||
![]() |
2e81792d64 | ||
![]() |
b7858c22d9 | ||
![]() |
09bbac6de0 | ||
![]() |
7f64d408a9 | ||
![]() |
ece88596ed | ||
![]() |
bad53c6b6e | ||
![]() |
16940822a7 | ||
![]() |
d48c03413f | ||
![]() |
e9e8443ea8 | ||
![]() |
749b2e9c89 | ||
![]() |
f6ad26fc08 | ||
![]() |
c08561c13a | ||
![]() |
2c3607407f | ||
![]() |
b5e4288704 | ||
![]() |
abbbd0cddc | ||
![]() |
e98937cbba | ||
![]() |
240d6236bc | ||
![]() |
59071268b6 | ||
![]() |
8c660a0dfb | ||
![]() |
ce5adec877 | ||
![]() |
36571fd2d9 | ||
![]() |
830de5a925 | ||
![]() |
d33105baeb | ||
![]() |
24f934f1f9 | ||
![]() |
1e2319cbef | ||
![]() |
e45050cae3 | ||
![]() |
b0f525955c | ||
![]() |
2ea267f624 | ||
![]() |
1d8af7ab73 | ||
![]() |
54affdc44b | ||
![]() |
a4fdb3970b | ||
![]() |
2a86928657 | ||
![]() |
b1c53fa779 | ||
![]() |
da20cf681e | ||
![]() |
4ccd1696ab | ||
![]() |
888780ffde | ||
![]() |
e3768c5a83 | ||
![]() |
1f28bdf994 | ||
![]() |
03a74995b8 | ||
![]() |
b89180f1cd | ||
![]() |
be21ef5047 | ||
![]() |
771e71a24d | ||
![]() |
0350831c2b | ||
![]() |
fee544e808 | ||
![]() |
c4718fd693 | ||
![]() |
f7cad30a38 | ||
![]() |
6b10c19482 | ||
![]() |
f4f1d8de44 | ||
![]() |
6610aa29d0 | ||
![]() |
f72c4de539 | ||
![]() |
f6ffbc3cbd | ||
![]() |
e8bbe7244b | ||
![]() |
57b086dc6b | ||
![]() |
525be243e7 | ||
![]() |
d0f4d6ba3a | ||
![]() |
26d5d737dd | ||
![]() |
fefbd65cf8 | ||
![]() |
1eb8ea7328 | ||
![]() |
ef6649a577 | ||
![]() |
1b54a2831e | ||
![]() |
2579e8fea8 | ||
![]() |
91528f1af9 | ||
![]() |
4e293e50fa | ||
![]() |
66b321d9ec | ||
![]() |
68b4755587 | ||
![]() |
04a8e1ef2b | ||
![]() |
a6e9161045 | ||
![]() |
90ef28d982 | ||
![]() |
b37585e693 | ||
![]() |
9cb08e71e8 | ||
![]() |
dacc46f04c | ||
![]() |
09ded7715f | ||
![]() |
11cfdf5d89 | ||
![]() |
e7fa57ebae | ||
![]() |
a5ae88ded9 | ||
![]() |
87e638498c | ||
![]() |
667547be59 | ||
![]() |
b38823bc66 | ||
![]() |
050d9658a5 | ||
![]() |
be5cabaf80 | ||
![]() |
240bdac2a4 | ||
![]() |
00863c43fd | ||
![]() |
3d3bccdf79 | ||
![]() |
9fd74f75bd | ||
![]() |
05c670e593 | ||
![]() |
d222248d00 | ||
![]() |
e5b94d4117 | ||
![]() |
87e2e58a22 | ||
![]() |
de20e5a992 | ||
![]() |
2f9c0618f0 | ||
![]() |
9a14ab6572 | ||
![]() |
d1cb3ed571 | ||
![]() |
b8a8a19689 |
8
.github/workflows/ci.yml
vendored
8
.github/workflows/ci.yml
vendored
@@ -2,7 +2,9 @@ name: CI
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ develop ]
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
@@ -27,9 +29,11 @@ jobs:
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
-e "BASE_BRANCH=${BASE_BRANCH}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
@@ -38,7 +42,7 @@ jobs:
|
||||
'
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git clone ${REPO} ${REPO_NAME}
|
||||
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
|
87
.github/workflows/ci_xpu.yml
vendored
Normal file
87
.github/workflows/ci_xpu.yml
vendored
Normal file
@@ -0,0 +1,87 @@
|
||||
name: CI_XPU
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.event.pull_request.number }}-xpu-ci
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
CI_XPU:
|
||||
runs-on: [self-hosted, XPU-P800-8Card]
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
echo "Current runner name: ${{ runner.name }}"
|
||||
# Because the system version is lower than 2.23, the checkout cannot be used.
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
|
||||
- name: Code Checkout
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
|
||||
run: |
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
-e "BASE_BRANCH=${BASE_BRANCH}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}
|
||||
fi
|
||||
'
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
git merge pr/${{ github.event.pull_request.number }}
|
||||
git log -n 3 --oneline
|
||||
else
|
||||
git checkout ${{ github.sha }}
|
||||
git log -n 3 --oneline
|
||||
fi
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
last_char="${runner_name: -1}"
|
||||
|
||||
if [[ "$last_char" =~ [0-3] ]]; then
|
||||
gpu_id="$last_char"
|
||||
else
|
||||
gpu_id="0"
|
||||
fi
|
||||
FD_API_PORT=$((9180 + gpu_id * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
|
||||
FD_METRICS_PORT=$((9170 + gpu_id * 100))
|
||||
|
||||
PARENT_DIR=$(dirname "$WORKSPACE")
|
||||
echo "PARENT_DIR:$PARENT_DIR"
|
||||
docker run --rm --net=host --cap-add=SYS_PTRACE --privileged --shm-size=64G \
|
||||
-v $(pwd):/workspace -w /workspace \
|
||||
-v "/ssd3:/ssd3" \
|
||||
-e "MODEL_PATH=/ssd3/model" \
|
||||
-e "http_proxy=$(git config --global --get http.proxy)" \
|
||||
-e "https_proxy=$(git config --global --get https.proxy)" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
${docker_image} /bin/bash -c "
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
bash scripts/run_ci_xpu.sh
|
||||
"
|
6
.github/workflows/gh-pages.yml
vendored
6
.github/workflows/gh-pages.yml
vendored
@@ -3,8 +3,6 @@ name: Deploy GitHub Pages
|
||||
on:
|
||||
push:
|
||||
branches: [ develop ]
|
||||
pull_request:
|
||||
branches: [ develop ]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -21,4 +19,6 @@ jobs:
|
||||
- name: Deploy to GitHub Pages
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: mkdocs gh-deploy --force --remote-name origin
|
||||
run: |
|
||||
git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}.git
|
||||
mkdocs gh-deploy --force --remote-name origin
|
||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@@ -162,3 +162,5 @@ custom_ops/tmp*
|
||||
build
|
||||
|
||||
.ccls-cache
|
||||
|
||||
third_party
|
||||
|
@@ -5,12 +5,6 @@ default_stages:
|
||||
- pre-commit # Run locally
|
||||
# - manual # Run in CI
|
||||
repos:
|
||||
# 格式化
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.43.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: [--in-place, --verbose]
|
||||
# 代码检查
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.7
|
||||
@@ -29,15 +23,6 @@ repos:
|
||||
rev: 6.0.1
|
||||
hooks:
|
||||
- id: isort
|
||||
# # 格式化
|
||||
# - repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
# rev: v20.1.3
|
||||
# hooks:
|
||||
# - id: clang-format
|
||||
# # exclude: '.*'
|
||||
# types_or: [c++, cuda]
|
||||
# args: [--style=file, --verbose]
|
||||
|
||||
# markdown
|
||||
- repo: https://github.com/jackdewinter/pymarkdown
|
||||
rev: v0.9.29
|
||||
|
@@ -8,14 +8,17 @@
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
|
||||
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
|
||||
<a href="https://paddlepaddle.github.io/FastDeploy/get_started/installation/nvidia_gpu/"><b> Installation </b></a>
|
||||
|
|
||||
<a href="https://paddlepaddle.github.io/FastDeploy/get_started/quick_start"><b> Quick Start </b></a>
|
||||
|
|
||||
<a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
|
||||
<a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
|
||||
|
||||
</p>
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
@@ -105,3 +105,30 @@ python benchmark_serving.py \
|
||||
--save-result > infer_log.txt 2>&1 &
|
||||
```
|
||||
|
||||
### 投机解码性能测试工具
|
||||
|
||||
#### 使用方式:
|
||||
|
||||
```bash
|
||||
python benchmarks/benchmark_mtp.py \
|
||||
--host 127.0.0.1 --port 8000 \
|
||||
--max-concurrency 16 32 64 96 --num-prompts 256 \
|
||||
--acceptance-rate 0.8 --draft-token-steps 1 2 3 \
|
||||
--s_itl-base-model 15.88 22.84 16.47 16.93 \
|
||||
--dataset-name EBChat \
|
||||
--dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
|
||||
```
|
||||
|
||||
#### 参数说明
|
||||
|
||||
```bash
|
||||
--host:服务ip地址,用于组url
|
||||
--port:服务HTTP端口,用于组url
|
||||
--max-concurrency:测试并发数
|
||||
--num-prompts:总计发送多少条请求
|
||||
--acceptance-rate:投机解码的模拟接受率
|
||||
--draft-token-steps:投机解码的步数
|
||||
--s_itl-base-model:主模型的解码延迟,可由上述的性能压测工具获得,与batch-size一一对应
|
||||
--dataset-name:指定数据集类,指定为"EBChat"可读取转存的FD格式数据集
|
||||
--dataset-path:测试数据集路径
|
||||
```
|
@@ -36,6 +36,7 @@ AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
@dataclass
|
||||
class RequestFuncInput:
|
||||
"""Input for requesting LLMs via API"""
|
||||
no: int
|
||||
prompt: str
|
||||
history_QA: Optional[dict]
|
||||
hyper_parameters: dict
|
||||
@@ -54,6 +55,7 @@ class RequestFuncInput:
|
||||
@dataclass
|
||||
class RequestFuncOutput:
|
||||
"""Output for requesting LLMs via API"""
|
||||
no: int = 0
|
||||
generated_text: str = ""
|
||||
reasoning_content: str = ""
|
||||
success: bool = False
|
||||
@@ -84,7 +86,7 @@ async def async_request_eb_openai_chat_completions(
|
||||
if request_func_input.multi_modal_content:
|
||||
content.append(request_func_input.multi_modal_content)
|
||||
payload = {
|
||||
"model": "default",
|
||||
"model": request_func_input.model,
|
||||
"messages": request_func_input.history_QA,
|
||||
"stream": True,
|
||||
"stream_options": {
|
||||
@@ -97,6 +99,9 @@ async def async_request_eb_openai_chat_completions(
|
||||
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
|
||||
print("payload:{}".format(json.dumps(payload, ensure_ascii=False)))
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
@@ -104,6 +109,7 @@ async def async_request_eb_openai_chat_completions(
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = 0
|
||||
output.no = request_func_input.no
|
||||
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
@@ -132,7 +138,8 @@ async def async_request_eb_openai_chat_completions(
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
# cached_tokens
|
||||
output.prompt_len = data["usage"]["prompt_tokens_details"]["cached_tokens"]
|
||||
output.prompt_len = data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
|
||||
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
@@ -141,12 +148,12 @@ async def async_request_eb_openai_chat_completions(
|
||||
|
||||
output.generated_text += content or ""
|
||||
output.reasoning_content += reason_content or ""
|
||||
output.arrival_time.append(choices[0].get("arrival_time"))
|
||||
elif usage := data.get("usage"):
|
||||
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
|
||||
elif usage := data.get("usage", {}):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
"completion_tokens", 0)
|
||||
output.prompt_tokens = usage.get(
|
||||
"prompt_tokens")
|
||||
"prompt_tokens", 0)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
@@ -173,6 +180,7 @@ async def async_request_eb_openai_chat_completions(
|
||||
f.write(str(output) + "\n")
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
print("#####final_output:", output)
|
||||
return output
|
||||
|
||||
|
||||
@@ -189,7 +197,7 @@ async def async_request_eb_openai_completions(
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"model": "default",
|
||||
"model": request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"stream": True,
|
||||
"stream_options": {
|
||||
@@ -202,14 +210,20 @@ async def async_request_eb_openai_completions(
|
||||
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
|
||||
print("payload:", json.dumps(payload, ensure_ascii=False))
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
output.no = request_func_input.no
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
@@ -226,6 +240,7 @@ async def async_request_eb_openai_completions(
|
||||
"data: ")
|
||||
if chunk != "[DONE]":
|
||||
# print("####chunk:", chunk, chunk.usage)
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
# NOTE: Some completion API might have a last
|
||||
@@ -235,21 +250,22 @@ async def async_request_eb_openai_completions(
|
||||
# Note that text could be empty here
|
||||
# e.g. for special tokens
|
||||
text = choices[0].get("text")
|
||||
timestamp = time.perf_counter()
|
||||
|
||||
# First token
|
||||
if not first_chunk_received:
|
||||
first_chunk_received = True
|
||||
ttft = time.perf_counter() - st
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
generated_text += text or ""
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
output.arrival_time.append(choices[0].get("arrival_time"))
|
||||
generated_text += text or ""
|
||||
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
|
||||
elif usage := data.get("usage"):
|
||||
output.prompt_tokens = usage.get(
|
||||
"prompt_tokens")
|
||||
@@ -262,8 +278,15 @@ async def async_request_eb_openai_completions(
|
||||
output.error = (
|
||||
"Never received a valid chunk to calculate TTFT."
|
||||
"This response will be marked as failed!")
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.latency = most_recent_timestamp - st
|
||||
|
||||
if output.generated_text == "":
|
||||
output.success = False
|
||||
output.error = "No generated text found!"
|
||||
else:
|
||||
output.success = True
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
@@ -271,6 +294,8 @@ async def async_request_eb_openai_completions(
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
print("final_output:{}".format(output))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
|
@@ -38,7 +38,7 @@ class SampleRequest:
|
||||
"""
|
||||
Represents a single inference request for benchmarking.
|
||||
"""
|
||||
|
||||
no: int
|
||||
prompt: Union[str, Any]
|
||||
history_QA: Union[str, Any]
|
||||
json_data: Optional[dict]
|
||||
@@ -229,6 +229,7 @@ class EBDataset(BenchmarkDataset):
|
||||
**kwargs,
|
||||
) -> list:
|
||||
samples: list = []
|
||||
cnt = 1
|
||||
for entry in self.data:
|
||||
if len(samples) >= num_requests:
|
||||
break
|
||||
@@ -246,16 +247,17 @@ class EBDataset(BenchmarkDataset):
|
||||
prompt, None)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
no=cnt,
|
||||
prompt=prompt,
|
||||
prompt_len=self.prompt_len,
|
||||
history_QA=[],
|
||||
expected_output_len=new_output_len,
|
||||
))
|
||||
cnt += 1
|
||||
|
||||
self.maybe_oversample_requests(samples, num_requests)
|
||||
return samples
|
||||
|
||||
|
||||
class EBChatDataset(BenchmarkDataset):
|
||||
"""
|
||||
Implements the ShareGPT dataset. Loads data from a JSON file and generates
|
||||
@@ -284,6 +286,7 @@ class EBChatDataset(BenchmarkDataset):
|
||||
**kwargs,
|
||||
) -> list:
|
||||
samples: list = []
|
||||
cnt = 1
|
||||
for entry in self.data:
|
||||
if len(samples) >= num_requests:
|
||||
break
|
||||
@@ -297,12 +300,14 @@ class EBChatDataset(BenchmarkDataset):
|
||||
prompt, None)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
no=cnt,
|
||||
json_data=json_data,
|
||||
prompt=prompt,
|
||||
prompt_len=0,
|
||||
history_QA=history_QA,
|
||||
expected_output_len=new_output_len,
|
||||
))
|
||||
cnt += 1
|
||||
|
||||
self.maybe_oversample_requests(samples, num_requests)
|
||||
return samples
|
||||
|
191
benchmarks/benchmark_mtp.py
Normal file
191
benchmarks/benchmark_mtp.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Union
|
||||
|
||||
import openai
|
||||
import yaml
|
||||
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
|
||||
from benchmark_serving import benchmark
|
||||
|
||||
|
||||
def prepare_input_requests(
|
||||
num_prompts: int, dataset_name: str, dataset_path: str
|
||||
) -> Union[EBDataset, EBChatDataset]:
|
||||
dataset_mapping = {
|
||||
"EB": lambda: EBDataset(dataset_path=dataset_path).sample(
|
||||
num_requests=num_prompts
|
||||
),
|
||||
"EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
|
||||
num_requests=num_prompts
|
||||
),
|
||||
}
|
||||
|
||||
try:
|
||||
input_requests = dataset_mapping[dataset_name]()
|
||||
except KeyError as err:
|
||||
raise ValueError(f"Unknown dataset: {dataset_name}") from err
|
||||
|
||||
return input_requests
|
||||
|
||||
|
||||
class FakeTokenizer:
|
||||
def encode(self, text: str, add_special_tokens: bool = False):
|
||||
return []
|
||||
|
||||
|
||||
def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
|
||||
selected_percentile_metrics = ["s_itl"]
|
||||
selected_percentiles = []
|
||||
# Run benchmark
|
||||
results = asyncio.run(
|
||||
benchmark(
|
||||
backend="openai-chat",
|
||||
api_url=f"{base_url}/v1/chat/completions",
|
||||
base_url=base_url,
|
||||
model_id="default",
|
||||
model_name="default",
|
||||
input_requests=input_requests,
|
||||
hyper_parameters={},
|
||||
logprobs=None,
|
||||
request_rate=float("inf"),
|
||||
burstiness=1.0,
|
||||
disable_tqdm=disable_tqdm,
|
||||
profile=False,
|
||||
selected_percentile_metrics=selected_percentile_metrics,
|
||||
selected_percentiles=selected_percentiles,
|
||||
ignore_eos=False,
|
||||
goodput_config_dict=None,
|
||||
max_concurrency=max_concurrency,
|
||||
lora_modules=None,
|
||||
extra_body=None,
|
||||
)
|
||||
)
|
||||
|
||||
record = {
|
||||
"mean_s_itl_ms": results["mean_s_itl_ms"],
|
||||
}
|
||||
|
||||
return record
|
||||
|
||||
|
||||
def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
|
||||
|
||||
tmp = 0.0
|
||||
for i in range(draft_token_step):
|
||||
tmp += pow(acceptance_rate, i + 1)
|
||||
|
||||
r_ac = tmp / (1 + tmp)
|
||||
|
||||
return t_ori / ((1 - r_ac) * t_mtp)
|
||||
|
||||
|
||||
def main(args):
|
||||
base_url = f"http://{args.host}:{args.port}"
|
||||
|
||||
input_requests = prepare_input_requests(
|
||||
args.num_prompts, args.dataset_name, args.dataset_path
|
||||
)
|
||||
|
||||
if len(args.max_concurrency) != len(args.s_itl_base_model):
|
||||
raise ValueError(f"--max_concurrency should be same length as --s_itl_base_model")
|
||||
|
||||
for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
|
||||
# Wramup
|
||||
print("Starting warmup...")
|
||||
with open(os.devnull, "w") as f:
|
||||
with contextlib.redirect_stdout(f):
|
||||
send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)
|
||||
|
||||
# Benchmark
|
||||
record = send_one_batch(base_url, max_concurrency, input_requests, False)
|
||||
|
||||
metric_header = f"Speed up"
|
||||
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
|
||||
for draft_token_step in args.draft_token_steps:
|
||||
speedup = calculate_speedup(
|
||||
args.acceptance_rate,
|
||||
draft_token_step,
|
||||
s_itl,
|
||||
record["mean_s_itl_ms"],
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
f"Speed up on {draft_token_step} steps draft", speedup
|
||||
)
|
||||
)
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=str,
|
||||
default="8000",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-concurrency",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(1, 2, 4, 8, 16, 32),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
type=int,
|
||||
default=128,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--acceptance-rate",
|
||||
type=float,
|
||||
default=0.8,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--draft-token-steps",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=(1, 2),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--s_itl-base-model",
|
||||
type=float,
|
||||
nargs="+",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="EBChat",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-path",
|
||||
type=str,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
@@ -182,6 +182,7 @@ def calculate_metrics(
|
||||
# len(outputs[i].itl) since multiple output tokens may be
|
||||
# bundled together
|
||||
# Note : this may inflate the output token count slightly
|
||||
continue
|
||||
|
||||
actual_output_lens.append(output_len)
|
||||
input_lens.append(outputs[i].prompt_len)
|
||||
@@ -209,6 +210,8 @@ def calculate_metrics(
|
||||
if len(outputs[i].arrival_time) > 2:
|
||||
s_decodes.append((outputs[i].output_tokens - 1) /
|
||||
(outputs[i].arrival_time[-1] - outputs[i].arrival_time[1]))
|
||||
else:
|
||||
print("len(outputs[i].arrival_time) <= 2")
|
||||
completed += 1
|
||||
else:
|
||||
actual_output_lens.append(0)
|
||||
@@ -341,15 +344,16 @@ async def benchmark(
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
print("Starting initial single prompt test run...")
|
||||
test_prompt, test_output_len = \
|
||||
test_prompt, test_output_len, test_no = \
|
||||
input_requests[0].prompt, \
|
||||
input_requests[0].expected_output_len
|
||||
input_requests[0].expected_output_len, input_requests[0].no
|
||||
test_history_QA = input_requests[0].history_QA
|
||||
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
model_name=model_name,
|
||||
prompt=test_prompt,
|
||||
no=test_no,
|
||||
prompt_len=0,
|
||||
history_QA=test_history_QA,
|
||||
hyper_parameters=hyper_parameters,
|
||||
@@ -384,6 +388,7 @@ async def benchmark(
|
||||
profile_input = RequestFuncInput(model=model_id,
|
||||
model_name=model_name,
|
||||
prompt=test_prompt,
|
||||
no=test_no,
|
||||
api_url=base_url + "/start_profile",
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
@@ -422,7 +427,7 @@ async def benchmark(
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: list[asyncio.Task] = []
|
||||
async for request in get_request(input_requests, request_rate, burstiness):
|
||||
prompt, output_len = request.prompt, request.expected_output_len
|
||||
prompt, output_len, no = request.prompt, request.expected_output_len, request.no
|
||||
history_QA = request.history_QA
|
||||
|
||||
req_model_id, req_model_name = model_id, model_name
|
||||
@@ -433,6 +438,7 @@ async def benchmark(
|
||||
request_func_input = RequestFuncInput(model=req_model_id,
|
||||
model_name=req_model_name,
|
||||
prompt=prompt,
|
||||
no=no,
|
||||
prompt_len=0,
|
||||
history_QA=history_QA,
|
||||
hyper_parameters=hyper_parameters,
|
||||
@@ -452,6 +458,7 @@ async def benchmark(
|
||||
profile_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=test_prompt,
|
||||
no=test_no,
|
||||
api_url=base_url + "/stop_profile",
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
@@ -464,6 +471,8 @@ async def benchmark(
|
||||
pbar.close()
|
||||
|
||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||
print("benchmark_duration:", benchmark_duration)
|
||||
|
||||
|
||||
metrics, actual_output_lens = calculate_metrics(
|
||||
input_requests=input_requests,
|
||||
@@ -594,6 +603,155 @@ async def benchmark(
|
||||
return result
|
||||
|
||||
|
||||
def benchmark_metrics(
|
||||
benchmark_duration: float,
|
||||
result_file: str,
|
||||
selected_percentiles: list[float],
|
||||
selected_percentile_metrics: list[str],
|
||||
goodput_config_dict: dict[str, float],
|
||||
):
|
||||
"""Benchmark metrics statistics,generate benchmark result"""
|
||||
outputs = []
|
||||
case_no_list = []
|
||||
with open(result_file) as f:
|
||||
for line in f.readlines():
|
||||
if "RequestFuncOutput" in line:
|
||||
start = line.find("RequestFuncOutput")
|
||||
end = line.rfind(")")
|
||||
para_str = line[start:end + 1]
|
||||
|
||||
output = eval(para_str)
|
||||
outputs.append(output)
|
||||
|
||||
input_requests = [[]] * len(outputs)
|
||||
goodput_config_dict = check_goodput_args(args)
|
||||
|
||||
metrics, actual_output_lens = calculate_metrics(
|
||||
input_requests=input_requests,
|
||||
outputs=outputs,
|
||||
dur_s=benchmark_duration,
|
||||
selected_percentiles=selected_percentiles,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
)
|
||||
|
||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
||||
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
|
||||
benchmark_duration))
|
||||
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
||||
print("{:<40} {:<10}".format("Total generated tokens:",
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
if goodput_config_dict:
|
||||
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
|
||||
metrics.request_goodput))
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
metrics.output_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
|
||||
metrics.total_token_throughput))
|
||||
|
||||
result = {
|
||||
"duration": benchmark_duration,
|
||||
"completed": metrics.completed,
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"request_goodput:":
|
||||
metrics.request_goodput if goodput_config_dict else None,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
"output_lens": actual_output_lens,
|
||||
"ttfts": [output.ttft for output in outputs],
|
||||
"itls": [output.itl for output in outputs],
|
||||
"input_texts": ["" for input in input_requests],
|
||||
"generated_texts": [output.generated_text for output in outputs],
|
||||
"errors": [output.error for output in outputs],
|
||||
}
|
||||
|
||||
def process_one_metric(
|
||||
# E.g., "ttft"
|
||||
metric_attribute_name: str,
|
||||
# E.g., "TTFT"
|
||||
metric_name: str,
|
||||
# E.g., "Time to First Token"
|
||||
metric_header: str,
|
||||
):
|
||||
# This function prints and adds statistics of the specified
|
||||
# metric.
|
||||
if metric_attribute_name not in selected_percentile_metrics:
|
||||
return
|
||||
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Mean {metric_name} (ms):",
|
||||
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Median {metric_name} (ms):",
|
||||
getattr(metrics, f"median_{metric_attribute_name}_ms")))
|
||||
result[f"mean_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"mean_{metric_attribute_name}_ms")
|
||||
result[f"median_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"median_{metric_attribute_name}_ms")
|
||||
result[f"std_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"std_{metric_attribute_name}_ms")
|
||||
for p, value in getattr(metrics,
|
||||
f"percentiles_{metric_attribute_name}_ms"):
|
||||
p_word = str(int(p)) if int(p) == p else str(p)
|
||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
|
||||
value))
|
||||
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
|
||||
|
||||
def process_one_length(
|
||||
# E.g., "ttft"
|
||||
metric_attribute_name: str,
|
||||
# E.g., "TTFT"
|
||||
metric_name: str,
|
||||
# E.g., "Time to First Token"
|
||||
metric_header: str,
|
||||
):
|
||||
# This function prints and adds statistics of the specified
|
||||
# metric.
|
||||
if metric_attribute_name not in selected_percentile_metrics:
|
||||
return
|
||||
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Mean {metric_name}:",
|
||||
getattr(metrics, f"mean_{metric_attribute_name}")))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Median {metric_name}:",
|
||||
getattr(metrics, f"median_{metric_attribute_name}")))
|
||||
result[f"mean_{metric_attribute_name}"] = getattr(
|
||||
metrics, f"mean_{metric_attribute_name}")
|
||||
result[f"median_{metric_attribute_name}"] = getattr(
|
||||
metrics, f"median_{metric_attribute_name}")
|
||||
result[f"std_{metric_attribute_name}"] = getattr(
|
||||
metrics, f"std_{metric_attribute_name}")
|
||||
for p, value in getattr(metrics,
|
||||
f"percentiles_{metric_attribute_name}"):
|
||||
p_word = str(int(p)) if int(p) == p else str(p)
|
||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:",
|
||||
value))
|
||||
result[f"p{p_word}_{metric_attribute_name}"] = value
|
||||
|
||||
process_one_length("s_decode", "Decode", "解码速度(tok/s)")
|
||||
process_one_metric("ttft", "TTFT", "Time to First Token")
|
||||
process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
|
||||
process_one_metric("tpot", "TPOT",
|
||||
"Time per Output Token (excl. 1st token)")
|
||||
process_one_metric("itl", "ITL", "Inter-token Latency")
|
||||
process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
|
||||
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
||||
process_one_metric("s_e2el", "S_E2EL", "Infer End-to-end Latency")
|
||||
process_one_length("input_len", "Input Length", "Input Length")
|
||||
process_one_length("s_input_len", "Input Length", "Infer Input Length")
|
||||
process_one_length("output_len", "Output Length", "Output Length")
|
||||
|
||||
print("=" * 50)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def check_goodput_args(args):
|
||||
"""Check whether the given argument has valid goodput configuration or not"""
|
||||
# Check and parse goodput arguments
|
||||
@@ -759,6 +917,16 @@ def main(args: argparse.Namespace):
|
||||
lora_modules=args.lora_modules,
|
||||
extra_body=sampling_params,
|
||||
))
|
||||
|
||||
# benchmark_result = benchmark_metrics(
|
||||
# benchmark_duration=3600,
|
||||
# result_file="your result file",
|
||||
# selected_percentile_metrics=args.percentile_metrics.split(","),
|
||||
# selected_percentiles=[
|
||||
# float(p) for p in args.metric_percentiles.split(",")
|
||||
# ],
|
||||
# goodput_config_dict=goodput_config_dict,
|
||||
# )
|
||||
|
||||
# Save config and results to json
|
||||
if args.save_result:
|
||||
|
1180
benchmarks/quick_benchmark.py
Normal file
1180
benchmarks/quick_benchmark.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,3 +3,4 @@ tqdm
|
||||
numpy
|
||||
Pillow
|
||||
pyyaml
|
||||
requests
|
||||
|
@@ -2,4 +2,5 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -2,4 +2,5 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -3,4 +3,5 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wint8
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -3,4 +3,5 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wint8
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -2,4 +2,5 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -3,4 +3,5 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wint4
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -3,4 +3,5 @@ max_num_seqs: 96
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 4
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -2,4 +2,5 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -2,4 +2,5 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -3,4 +3,5 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wfp8afp8
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -2,4 +2,5 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -2,4 +2,5 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -3,4 +3,5 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wint8
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -3,4 +3,5 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wint8
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -2,4 +2,5 @@ max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
@@ -3,4 +3,5 @@ max_num_seqs: 128
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 1
|
||||
quantization: wint4
|
||||
enable_static_graph_inference: True
|
||||
graph_optimization_config:
|
||||
graph_opt_level: 1
|
||||
|
3
benchmarks/yaml/request_yaml/quick_benchmark.yaml
Normal file
3
benchmarks/yaml/request_yaml/quick_benchmark.yaml
Normal file
@@ -0,0 +1,3 @@
|
||||
metadata:
|
||||
min_tokens: 32
|
||||
max_tokens: 33
|
11
benchmarks/yaml/request_yaml/vLLM_default.yaml
Normal file
11
benchmarks/yaml/request_yaml/vLLM_default.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
top_p: 1.0
|
||||
temperature: 1.0
|
||||
metadata:
|
||||
min_tokens: 1
|
||||
max_tokens: 30721
|
||||
repetition_penalty: 1.0
|
||||
frequency_penalty: 0
|
||||
presence_penalty: 0
|
||||
skip_special_tokens: false
|
||||
chat_template_kwargs:
|
||||
enable_thinking: true
|
51
build.sh
51
build.sh
@@ -18,6 +18,9 @@ BUILD_WHEEL=${1:-1}
|
||||
PYTHON_VERSION=${2:-"python"}
|
||||
export python=$PYTHON_VERSION
|
||||
FD_CPU_USE_BF16=${3:-"false"}
|
||||
# FD_BUILDING_ARCS: Specify target CUDA architectures for custom ops, e.g., "[80, 90, 100]".
|
||||
# For SM90 (Hopper), use 90. For SM100 (Blackwell), use 100.
|
||||
# These will be translated to 90a / 100a in setup_ops.py for specific features.
|
||||
FD_BUILDING_ARCS=${4:-""}
|
||||
|
||||
|
||||
@@ -74,8 +77,10 @@ function copy_ops(){
|
||||
is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
|
||||
if [ "$is_rocm" = "True" ]; then
|
||||
DEVICE_TYPE="rocm"
|
||||
mkdir -p ../fastdeploy/model_executor/ops/base
|
||||
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
|
||||
echo -e "ROCM ops have been copy to fastdeploy"
|
||||
echo -e "BASE and ROCM ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
mkdir -p ../fastdeploy/model_executor/ops/base
|
||||
@@ -104,6 +109,23 @@ function copy_ops(){
|
||||
return
|
||||
fi
|
||||
|
||||
if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
|
||||
if [ "$if_corex" = "True" ]; then
|
||||
DEVICE_TYPE="iluvatar-gpu"
|
||||
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
|
||||
echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
|
||||
is_gcu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"`
|
||||
if [ "$is_gcu" = "True" ]; then
|
||||
DEVICE_TYPE="gcu"
|
||||
cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gcu
|
||||
echo -e "gcu ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
|
||||
DEVICE_TYPE="cpu"
|
||||
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||
cd ../../../../
|
||||
@@ -163,17 +185,24 @@ function build_and_install() {
|
||||
exit 1
|
||||
fi
|
||||
echo -e "${BLUE}[build]${NONE} ${GREEN}build fastdeploy wheel success${NONE}\n"
|
||||
}
|
||||
|
||||
echo -e "${BLUE}[install]${NONE} installing fastdeploy..."
|
||||
cd $DIST_DIR
|
||||
find . -name "fastdeploy*.whl" | xargs ${python} -m pip install
|
||||
if [ $? -ne 0 ]; then
|
||||
cd ..
|
||||
echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed"
|
||||
exit 1
|
||||
function version_info() {
|
||||
output_file="fastdeploy/version.txt"
|
||||
fastdeploy_git_commit_id=$(git rev-parse HEAD)
|
||||
paddle_version=$(${python} -c "import paddle; print(paddle.__version__)")
|
||||
paddle_git_commit_id=$(${python} -c "import paddle; print(paddle.__git_commit__)")
|
||||
cuda_version="nvcc-not-installed"
|
||||
if command -v nvcc &> /dev/null; then
|
||||
cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
|
||||
fi
|
||||
echo -e "${BLUE}[install]${NONE} ${GREEN}fastdeploy install success${NONE}\n"
|
||||
cd ..
|
||||
cxx_version=$(g++ --version | head -n 1 | grep -Po "(?<=\) )[\d.]+")
|
||||
|
||||
echo "fastdeploy GIT COMMIT ID: $fastdeploy_git_commit_id" > $output_file
|
||||
echo "Paddle version: $paddle_version" >> $output_file
|
||||
echo "Paddle GIT COMMIT ID: $paddle_git_commit_id" >> $output_file
|
||||
echo "CUDA version: $cuda_version" >> $output_file
|
||||
echo "CXX compiler version: $cxx_version" >> $output_file
|
||||
}
|
||||
|
||||
function cleanup() {
|
||||
@@ -207,6 +236,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
|
||||
set -e
|
||||
|
||||
init
|
||||
version_info
|
||||
build_and_install_ops
|
||||
build_and_install
|
||||
cleanup
|
||||
@@ -237,6 +267,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
|
||||
else
|
||||
init
|
||||
build_and_install_ops
|
||||
version_info
|
||||
rm -rf $BUILD_DIR $EGG_DIR $DIST_DIR
|
||||
rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
|
||||
fi
|
||||
|
@@ -26,7 +26,7 @@ index 15b22ca..63e7fb7 100644
|
||||
@@ -1,4 +1,4 @@
|
||||
-import torch
|
||||
+import paddle
|
||||
|
||||
|
||||
from . import jit
|
||||
from .jit_kernels import (
|
||||
diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
|
||||
@@ -53,7 +53,7 @@ index c17d466..6fdc52f 100644
|
||||
-from torch.utils.cpp_extension import CUDA_HOME
|
||||
+from ..paddle_utils import CUDA_HOME
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
from . import interleave_ffma
|
||||
diff --git a/deep_gemm/jit/interleave_ffma.py b/deep_gemm/jit/interleave_ffma.py
|
||||
index fcb377e..db9d6f3 100644
|
||||
@@ -65,8 +65,8 @@ index fcb377e..db9d6f3 100644
|
||||
import subprocess
|
||||
-from torch.utils.cpp_extension import CUDA_HOME
|
||||
+from ..paddle_utils import CUDA_HOME
|
||||
|
||||
|
||||
|
||||
|
||||
def run_cuobjdump(file_path):
|
||||
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
|
||||
index 66c370a..4761426 100644
|
||||
@@ -78,7 +78,7 @@ index 66c370a..4761426 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from .template import map_ctype
|
||||
@@ -35,7 +35,7 @@ class Runtime:
|
||||
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
|
||||
@@ -100,8 +100,8 @@ index ead37f5..51b02c1 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Any, Dict, Iterable, Tuple
|
||||
|
||||
|
||||
|
||||
|
||||
# Name map for Python `eval`
|
||||
typename_map: Dict[Any, str] = {
|
||||
**{t: t.__name__ for t in (bool, int, float)},
|
||||
@@ -116,15 +116,15 @@ index ead37f5..51b02c1 100644
|
||||
+ paddle.float8_e4m3fn: 'paddle.float8_e4m3fn',
|
||||
+ paddle.device.cuda.Stream: "paddle.device.cuda.Stream",
|
||||
}
|
||||
|
||||
|
||||
# `ctype` map for Python casting
|
||||
ctype_map: Dict[Any, Any] = {
|
||||
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
|
||||
- **{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
|
||||
+ **{t: ctypes.c_void_p for t in (paddle.int32, paddle.float32, paddle.bfloat16, paddle.float8_e4m3fn, paddle.device.cuda.Stream)},
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -27,25 +27,25 @@ genc_map = {
|
||||
bool: ('bool', 'bool'),
|
||||
int: ('int', 'int'),
|
||||
@@ -140,8 +140,8 @@ index ead37f5..51b02c1 100644
|
||||
+ paddle.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
|
||||
+ paddle.device.cuda.Stream: ('void*', 'cudaStream_t'),
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
def map_ctype(value: Any) -> Any:
|
||||
if hasattr(value, 'data_ptr'):
|
||||
- if value.dtype == torch.int:
|
||||
@@ -171,11 +171,11 @@ index cb438b7..44aa0ed 100644
|
||||
+import paddle
|
||||
from functools import lru_cache
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
@@ -166,20 +166,20 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
|
||||
|
||||
|
||||
|
||||
|
||||
-def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor) -> None:
|
||||
@@ -189,7 +189,7 @@ index cb438b7..44aa0ed 100644
|
||||
The LHS scaling tensor requires TMA-aligned transposed format, if your input does not match the requirement,
|
||||
- this function will do a transposing with a set of slow PyTorch operations.
|
||||
+ this function will do a transposing with a set of slow paddle operations.
|
||||
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m, k]`,
|
||||
@@ -202,10 +202,10 @@ index cb438b7..44aa0ed 100644
|
||||
@@ -189,22 +189,22 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
n, k_ = rhs.shape
|
||||
m_, n_ = out.shape
|
||||
|
||||
|
||||
- assert n % 64 == 0 and k % 128 == 0
|
||||
+ # assert n % 64 == 0 and k % 128 == 0
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
- assert m == m_ and n == n_ and k == k_
|
||||
- assert n > 0 and k > 0
|
||||
@@ -223,13 +223,13 @@ index cb438b7..44aa0ed 100644
|
||||
+ # assert rhs.dtype == paddle.float8_e4m3fn and rhs_scales.dtype == paddle.float32
|
||||
+ # assert out.dtype == paddle.bfloat16
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
|
||||
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
# NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
@@ -214,7 +214,7 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
@@ -264,12 +264,12 @@ index 3b518c9..ba776bd 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
|
||||
@@ -37,25 +37,25 @@ gemm_t::run(out, rhs_scales, grouped_layout,
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
-def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor, m_indices: torch.Tensor) -> None:
|
||||
@@ -285,7 +285,7 @@ index 3b518c9..ba776bd 100644
|
||||
+ this function will do a transposing with a set of slow Pypaddle operations.
|
||||
On the M axis, inputs are grouped into several batches, of which batch sizes aligned to
|
||||
`get_m_alignment_for_contiguous_layout()` (128).
|
||||
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[m_sum, k]`,
|
||||
@@ -301,7 +301,7 @@ index 3b518c9..ba776bd 100644
|
||||
Values of `m_indices` in every-m-alignment-block must also be the same.
|
||||
@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
m__ = m_indices.numel()
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
- assert m == m_ == m__ and k == k_ and n == n_
|
||||
- assert lhs_scales.shape == (m, (k + 127) // 128)
|
||||
@@ -321,12 +321,12 @@ index 3b518c9..ba776bd 100644
|
||||
+ # assert m_indices.dtype == paddle.int32
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
+ # assert out.is_contiguous() and m_indices.is_contiguous()
|
||||
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
|
||||
# Do nothing if `m` is zero
|
||||
if m == 0:
|
||||
@@ -92,7 +92,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
@@ -357,8 +357,8 @@ index 3b518c9..ba776bd 100644
|
||||
)
|
||||
@@ -118,22 +118,22 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
runtime(*args)
|
||||
|
||||
|
||||
|
||||
|
||||
-def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
- out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
|
||||
@@ -374,7 +374,7 @@ index 3b518c9..ba776bd 100644
|
||||
+ this function will do a transposing with a set of slow paddle operations.
|
||||
Moreover, this alignment requirement is different with the contiguous-format kernel, as we require that each batch
|
||||
should be separately transposed.
|
||||
|
||||
|
||||
Arguments:
|
||||
- lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
+ lhs: the first element is an FP8 tensor (typed `paddle.float8_e4m3fn`) of shape `[num_groups, m_max, k]`,
|
||||
@@ -386,7 +386,7 @@ index 3b518c9..ba776bd 100644
|
||||
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute
|
||||
@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
num_groups___ = masked_m.numel()
|
||||
|
||||
|
||||
# Type and shape checks
|
||||
- assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
- assert m == m_ and n == n_ and k == k_
|
||||
@@ -410,16 +410,16 @@ index 3b518c9..ba776bd 100644
|
||||
+ # assert masked_m.dtype == paddle.int32
|
||||
+ # assert lhs.is_contiguous() and rhs.is_contiguous()
|
||||
+ # assert out.is_contiguous() and masked_m.is_contiguous()
|
||||
|
||||
|
||||
# LHS scales must be transposed for TMA load, but not for RHS scales
|
||||
lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
|
||||
- assert rhs_scales.is_contiguous()
|
||||
+ # assert rhs_scales.is_contiguous()
|
||||
|
||||
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
@@ -176,7 +176,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
args = (lhs, lhs_scales, rhs, rhs_scales, out,
|
||||
masked_m, m,
|
||||
- torch.cuda.current_stream(), num_sms, smem_config[0])
|
||||
@@ -454,11 +454,11 @@ index 6ed6749..9e1d70f 100644
|
||||
-import torch
|
||||
+import paddle
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
from ..jit import build, cpp_format, generate, Runtime
|
||||
@@ -51,10 +51,10 @@ class JITTuner:
|
||||
continue
|
||||
|
||||
|
||||
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
|
||||
- start_event = torch.cuda.Event(enable_timing=True)
|
||||
- end_event = torch.cuda.Event(enable_timing=True)
|
||||
@@ -478,9 +478,9 @@ index c6da56b..a17b1b1 100644
|
||||
@@ -1,4 +1,4 @@
|
||||
-import torch
|
||||
+import paddle
|
||||
|
||||
|
||||
_num_sms = None
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ def set_num_sms(num_sms: int) -> None:
|
||||
num_sms: the desired maximum SM count for all GEMM kernels to use.
|
||||
"""
|
||||
@@ -488,8 +488,8 @@ index c6da56b..a17b1b1 100644
|
||||
- assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
+ assert 0 < num_sms <= paddle.device.cuda.get_device_properties().multi_processor_count
|
||||
_num_sms = num_sms
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ def get_num_sms() -> int:
|
||||
"""
|
||||
global _num_sms
|
||||
@@ -497,12 +497,12 @@ index c6da56b..a17b1b1 100644
|
||||
- _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
|
||||
+ _num_sms = paddle.device.cuda.get_device_properties().multi_processor_count
|
||||
return _num_sms
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -74,9 +74,9 @@ def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
return ceil_div(x, alignment) * alignment
|
||||
|
||||
|
||||
|
||||
|
||||
-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
+def get_col_major_tma_aligned_tensor(x: paddle.Tensor) -> paddle.Tensor:
|
||||
"""
|
||||
@@ -510,7 +510,7 @@ index c6da56b..a17b1b1 100644
|
||||
+ Returns TMA-aligned transposed format of the input tensor. `paddle.transpose` will be called if necessary.
|
||||
If the input tensor is already column-major layout and 16-byte aligned along the M axis
|
||||
(thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing.
|
||||
|
||||
|
||||
@@ -92,18 +92,20 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
m, n = x.shape[-2], x.shape[-1]
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
@@ -519,14 +519,14 @@ index c6da56b..a17b1b1 100644
|
||||
+ if x.strides[0] == 1 and x.strides[1] == aligned_m:
|
||||
return x
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
|
||||
b = x.shape[0]
|
||||
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
- if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
|
||||
+ if x.strides[0] == aligned_m * n and x.strides[1] == 1 and x.strides[2] == aligned_m:
|
||||
return x.squeeze(0) if remove_dim else x
|
||||
|
||||
|
||||
# Normal layout requires transposing
|
||||
- aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||
+ aligned_x = paddle.transpose(
|
||||
@@ -574,20 +574,20 @@ index d5cdd01..5237f09 100644
|
||||
-import torch.distributed as dist
|
||||
+import paddle
|
||||
+import paddle.distributed as dist
|
||||
|
||||
|
||||
|
||||
|
||||
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
high_precision: bool = False):
|
||||
# Flush L2 cache with 256 MB data
|
||||
- torch.cuda.synchronize()
|
||||
- cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
|
||||
+ paddle.device.cuda.synchronize()
|
||||
+ paddle.device.synchronize()
|
||||
+ cache = paddle.empty((int(256e6 // 4)), dtype=paddle.int32)
|
||||
cache.zero_()
|
||||
|
||||
|
||||
# Warmup
|
||||
@@ -18,18 +18,18 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
|
||||
|
||||
|
||||
# Add a large kernel to eliminate the CPU launch overhead
|
||||
if high_precision:
|
||||
- x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
|
||||
@@ -595,7 +595,7 @@ index d5cdd01..5237f09 100644
|
||||
+ x = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
+ y = paddle.randn((8192, 8192), dtype=paddle.float32)
|
||||
x @ y
|
||||
|
||||
|
||||
# Testing
|
||||
- start_event = torch.cuda.Event(enable_timing=True)
|
||||
- end_event = torch.cuda.Event(enable_timing=True)
|
||||
@@ -607,9 +607,9 @@ index d5cdd01..5237f09 100644
|
||||
end_event.record()
|
||||
- torch.cuda.synchronize()
|
||||
+ paddle.device.synchronize()
|
||||
|
||||
|
||||
return start_event.elapsed_time(end_event) / num_tests
|
||||
|
||||
|
||||
@@ -106,21 +106,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
|
||||
# Profile
|
||||
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
|
||||
@@ -636,8 +636,7 @@ index d5cdd01..5237f09 100644
|
||||
- torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
|
||||
+ paddle.empty(flush_l2_size, dtype=paddle.int32).zero_()
|
||||
fn()
|
||||
|
||||
if not using_nsys:
|
||||
--
|
||||
2.43.0
|
||||
|
||||
if not using_nsys:
|
||||
--
|
||||
2.43.0
|
||||
|
@@ -47,7 +47,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
@@ -166,7 +166,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
lambda_batch_ids,
|
||||
lambda_tile_ids_per_batch,
|
||||
@@ -203,7 +203,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
@@ -275,7 +275,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
@@ -298,7 +298,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
@@ -323,7 +323,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
@@ -347,7 +347,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
@@ -404,7 +404,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
@@ -462,7 +462,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = key_cache.dims()[2];
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
meta_data.batch_size = seq_lens_this_time.dims()[0];
|
||||
|
||||
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
|
||||
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
|
||||
@@ -474,7 +474,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
@@ -551,7 +551,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
const std::vector<int64_t>& seq_lens_decoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_this_time_shape,
|
||||
const std::vector<int64_t>& padding_offsets_shape,
|
||||
const std::vector<int64_t>& cum_offsets_shape,
|
||||
const std::vector<int64_t>& cu_seqlens_q_shape,
|
||||
const std::vector<int64_t>& block_tables_shape,
|
||||
const std::vector<int64_t>& encoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
|
||||
@@ -611,7 +611,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
const paddle::DataType& seq_lens_decoder_dtype,
|
||||
const paddle::DataType& seq_lens_this_time_dtype,
|
||||
const paddle::DataType& padding_offsets_dtype,
|
||||
const paddle::DataType& cum_offsets_dtype,
|
||||
const paddle::DataType& cu_seqlens_q_dtype,
|
||||
const paddle::DataType& block_tables_dtype,
|
||||
const paddle::DataType& encoder_batch_ids_dtype,
|
||||
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
|
||||
@@ -689,7 +689,7 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"cu_seqlens_q",
|
||||
"block_tables",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
|
@@ -41,7 +41,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
@@ -114,8 +114,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t kv_b_stride = HEAD_DIM;
|
||||
const uint32_t q_start_seq_id =
|
||||
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
|
||||
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
|
||||
const uint32_t q_base_seq_id_this_block =
|
||||
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
|
||||
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
|
||||
@@ -405,7 +404,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
@@ -477,8 +476,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t kv_b_stride = HEAD_DIM;
|
||||
const uint32_t q_start_seq_id =
|
||||
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
|
||||
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
|
||||
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
|
||||
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
|
||||
q_head_idx * HEAD_DIM +
|
||||
@@ -776,7 +774,7 @@ void MultiQueryAppendAttention(
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
const paddle::Tensor &batch_ids,
|
||||
const paddle::Tensor &tile_ids_per_batch,
|
||||
@@ -882,7 +880,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -939,7 +937,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -974,7 +972,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1103,7 +1101,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1171,7 +1169,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1207,7 +1205,7 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1290,7 +1288,7 @@ void CascadeAppendAttentionC16Kernel(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -1353,7 +1351,7 @@ void CascadeAppendAttentionC16Kernel(
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
|
@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
@@ -144,8 +144,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
|
||||
const uint32_t kv_b_stride = HEAD_DIM / 2;
|
||||
const uint32_t kv_d_stride = BLOCK_SIZE / 2;
|
||||
const uint32_t q_start_seq_id =
|
||||
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
|
||||
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
|
||||
const uint32_t q_base_seq_id_this_block =
|
||||
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
|
||||
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
|
||||
@@ -504,7 +503,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
@@ -601,8 +600,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
|
||||
const uint32_t kv_b_stride = HEAD_DIM / 2;
|
||||
const uint32_t kv_d_stride = BLOCK_SIZE / 2;
|
||||
const uint32_t q_start_seq_id =
|
||||
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
|
||||
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
|
||||
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
|
||||
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
|
||||
q_head_idx * HEAD_DIM +
|
||||
@@ -963,7 +961,7 @@ void MultiQueryAppendC4Attention(
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
const paddle::Tensor &batch_ids,
|
||||
const paddle::Tensor &tile_ids_per_batch,
|
||||
@@ -1088,7 +1086,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1151,7 +1149,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1186,7 +1184,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1333,7 +1331,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1409,7 +1407,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1444,7 +1442,7 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1527,7 +1525,7 @@ void CascadeAppendAttentionC4Kernel(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -1594,7 +1592,7 @@ void CascadeAppendAttentionC4Kernel(
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
|
@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
@@ -151,8 +151,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t kv_b_stride = HEAD_DIM;
|
||||
const uint32_t kv_d_stride = BLOCK_SIZE;
|
||||
const uint32_t q_start_seq_id =
|
||||
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
|
||||
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
|
||||
const uint32_t q_base_seq_id_this_block =
|
||||
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
|
||||
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
|
||||
@@ -473,7 +472,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
@@ -575,8 +574,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t kv_b_stride = HEAD_DIM;
|
||||
const uint32_t kv_d_stride = BLOCK_SIZE;
|
||||
const uint32_t q_start_seq_id =
|
||||
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
|
||||
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
|
||||
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
|
||||
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
|
||||
q_head_idx * HEAD_DIM +
|
||||
@@ -900,7 +898,7 @@ void MultiQueryAppendC8Attention(
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
const paddle::Tensor &batch_ids,
|
||||
const paddle::Tensor &tile_ids_per_batch,
|
||||
@@ -1054,7 +1052,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1111,7 +1109,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1146,7 +1144,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1317,7 +1315,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1387,7 +1385,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_kv.data<int>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
@@ -1417,7 +1415,7 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1500,7 +1498,7 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -1565,7 +1563,7 @@ void CascadeAppendAttentionC8Kernel(
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
|
@@ -2111,7 +2111,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
OutT *__restrict__ out,
|
||||
@@ -2127,7 +2127,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const int bid = blockIdx.x, hid = blockIdx.y;
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int seq_len_q = seq_lens_q[bid];
|
||||
if (seq_len_q == 0) return;
|
||||
int seq_len_kv = seq_lens_kv[bid];
|
||||
|
@@ -41,7 +41,7 @@ void CascadeAppendAttentionC16Kernel(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -86,7 +86,7 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -131,7 +131,7 @@ void CascadeAppendAttentionC4Kernel(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -176,7 +176,7 @@ void CascadeAppendAttentionKernel(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -212,7 +212,7 @@ void CascadeAppendAttentionKernel(
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
@@ -247,7 +247,7 @@ void CascadeAppendAttentionKernel(
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
@@ -282,7 +282,7 @@ void CascadeAppendAttentionKernel(
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
@@ -317,7 +317,7 @@ void CascadeAppendAttentionKernel(
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
|
236
custom_ops/gpu_ops/append_attn/decode_attention_func.cuh
Normal file
236
custom_ops/gpu_ops/append_attn/decode_attention_func.cuh
Normal file
@@ -0,0 +1,236 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "multi_head_latent_attention_kernel.h"
|
||||
|
||||
template <size_t vec_size, typename T>
|
||||
struct softmax_state_t {
|
||||
AlignedVector<T, vec_size> o;
|
||||
T m;
|
||||
T d;
|
||||
|
||||
__device__ __forceinline__ void init() {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&o) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = __float2half(-5e4f);
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = __float2bfloat16(-3.38953e38f);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ softmax_state_t() {
|
||||
init();
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
|
||||
T other_m,
|
||||
T other_d) {
|
||||
// using kType = typename cascade_attn_nv_type2_traits<T>::type;
|
||||
T m_prev = m, d_prev = d;
|
||||
m = m_prev > other_m ? m_prev : other_m;
|
||||
T scale1 = hexp(m_prev - m), scale2 = hexp(other_m - m);
|
||||
|
||||
d = d_prev * scale1 + other_d * scale2;
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] = o[i] * scale1 + other_o[i] * scale2;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void normalize() {
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] /= d;
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <size_t vec_size, typename T, uint32_t num_tiles = 0>
|
||||
struct softmax_state_ts {
|
||||
uint32_t num_tiles_ = num_tiles;
|
||||
AlignedVector<T, vec_size> o[num_tiles];
|
||||
float m;
|
||||
float d;
|
||||
|
||||
__device__ __forceinline__ void init() {
|
||||
#pragma unroll
|
||||
for (uint32_t tile_id = 0; tile_id < num_tiles_; ++tile_id) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&o[tile_id]) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&o[tile_id]) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = -5e4f;
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = -3.38953e38f;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ softmax_state_ts() {
|
||||
init();
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void normalize(const uint32_t tile_id) {
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; i++) {
|
||||
o[tile_id][i] /= d;
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <SharedMemFillMode fill_mode, uint32_t HEAD_DIM_QK, uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t BLOCK_SIZE, uint32_t CACHE_VEC_SIZE, typename CacheT>
|
||||
__device__ __forceinline__ void produce_kv(CacheT *smem,
|
||||
CacheT *kv_base_gptr,
|
||||
const int * block_table_smem,
|
||||
const uint32_t seq_offset_gmem,
|
||||
const uint32_t seq_offset_smem,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t tidx,
|
||||
const uint32_t chunk_start,
|
||||
const uint32_t chunk_end) {
|
||||
int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]);
|
||||
if (block_id < 0) {
|
||||
block_id = 0;
|
||||
}
|
||||
const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE;
|
||||
// 8/16 T/int8 each time
|
||||
const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK;
|
||||
const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK;
|
||||
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
|
||||
pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>(
|
||||
smem + smem_offset_base + vid * CACHE_VEC_SIZE,
|
||||
kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE,
|
||||
seq_offset_gmem < chunk_end
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t bdy, uint32_t HEAD_DIM, uint32_t DEAL_EACH_TIME, uint32_t num_tile_v, typename T, typename CacheT>
|
||||
__device__ __forceinline__ void compute_qk(const T* cu_q_smem,
|
||||
const CacheT* k_smem,
|
||||
const uint32_t kv_idx_base,
|
||||
const uint32_t stage_idx,
|
||||
const uint32_t iter_base,
|
||||
const uint32_t iter_bound,
|
||||
const uint32_t tidx,
|
||||
const uint32_t gid,
|
||||
const float scale,
|
||||
float *s,
|
||||
softmax_state_ts<vec_size, T, num_tile_v>& st) {
|
||||
const CacheT* smem;
|
||||
AlignedVector<T, vec_size> q_vec;
|
||||
AlignedVector<T, vec_size> k_vec;
|
||||
float m_prev = st.m;
|
||||
// smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * HEAD_DIM;
|
||||
smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM;
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
|
||||
if (iter_base + j < iter_bound) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
s[j] = 0.f;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
s[j] = 0.f;
|
||||
}
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
|
||||
Load<T, vec_size>(cu_q_smem + vid * vec_size, &q_vec);
|
||||
Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
|
||||
for (uint32_t i = 0; i < vec_size; ++i) {
|
||||
s[j] += static_cast<float>(q_vec[i] * k_vec[i]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
|
||||
s[j] += __shfl_xor_sync(-1, s[j], offset, 32);
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
s[j] = -5e4f;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
s[j] = -3.38953e38f;
|
||||
}
|
||||
}
|
||||
st.m = st.m > s[j] ? st.m : s[j];
|
||||
}
|
||||
|
||||
// T o_scale = hexp(m_prev - st.m);
|
||||
float o_scale = __expf(m_prev - st.m);
|
||||
st.d *= o_scale;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
|
||||
// s[j] = hexp(s[j] - st.m);
|
||||
s[j] = __expf(s[j] - st.m);
|
||||
st.d += s[j];
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t tile_id = 0; tile_id < num_tile_v; ++tile_id) {
|
||||
for (uint32_t i = 0; i < vec_size; ++i) {
|
||||
st.o[tile_id][i] *= o_scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t DEAL_EACH_TIME, uint32_t HEAD_DIM_QK, uint32_t num_tile, typename T, typename CacheT>
|
||||
__device__ __forceinline__ void compute_sv(const float *s,
|
||||
const CacheT *base_v_smem,
|
||||
const uint32_t stage_idx,
|
||||
const uint32_t iter_base,
|
||||
const uint32_t iter_bound,
|
||||
const uint32_t tidx,
|
||||
softmax_state_ts<vec_size, T, num_tile>& st) {
|
||||
const CacheT* v_smem;
|
||||
AlignedVector<T, vec_size> v_vec;
|
||||
#pragma unroll
|
||||
for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) {
|
||||
v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK;
|
||||
for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
|
||||
Load<T, vec_size>(v_smem + vid * vec_size, &v_vec);
|
||||
uint32_t tile_id = vid / bdx;
|
||||
#pragma unroll
|
||||
for (int reg_id = 0; reg_id < vec_size; ++reg_id) {
|
||||
st.o[tile_id][reg_id] += static_cast<T>(s[j]) * v_vec[reg_id];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
560
custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu
Normal file
560
custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu
Normal file
@@ -0,0 +1,560 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "decode_attention_func.cuh"
|
||||
|
||||
#define CHECK(call) \
|
||||
do \
|
||||
{ \
|
||||
const cudaError_t error_code = call; \
|
||||
if (error_code != cudaSuccess) \
|
||||
{ \
|
||||
printf("CUDA Error:\n"); \
|
||||
printf(" File: %s\n", __FILE__); \
|
||||
printf(" Line %d:\n", __LINE__); \
|
||||
printf(" Error code:%d\n", error_code); \
|
||||
printf(" Error text:%s\n", cudaGetErrorString(error_code)); \
|
||||
exit(1); \
|
||||
} \
|
||||
}while(0)
|
||||
|
||||
template <typename T, typename OutT, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
|
||||
__global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim]
|
||||
const T * __restrict__ multi_m, // [bsz, num_chunks, num_heads]
|
||||
const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads]
|
||||
const int * __restrict__ seq_lens_q,
|
||||
const int * __restrict__ seq_lens_kv,
|
||||
const int * __restrict__ cu_seqlens_q,
|
||||
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
OutT * __restrict__ out, // [token_num, num_heads, head_dim]
|
||||
const float in_scale,
|
||||
const int num_chunks,
|
||||
const int chunk_size,
|
||||
const int max_seq_len,
|
||||
const int num_heads,
|
||||
const int head_dim) {
|
||||
const int vid = threadIdx.x, ty = threadIdx.y;
|
||||
const int qid = blockIdx.x, hid = blockIdx.y;
|
||||
const int seq_len_q = seq_lens_q[qid];
|
||||
if (seq_len_q == 0) return;
|
||||
int seq_len_kv = seq_lens_kv[qid];
|
||||
if (seq_len_kv == 0) return;
|
||||
seq_len_kv += seq_len_q;
|
||||
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
|
||||
if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) {
|
||||
return;
|
||||
}
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ T md_smem[bdy * 2];
|
||||
|
||||
const int start_token_ids = cu_seqlens_q[qid];
|
||||
using LoadT = AlignedVector<T, vec_size>;
|
||||
LoadT load_vec;
|
||||
LoadT res_vec;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&res_vec) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
T m;
|
||||
T d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = __float2half(-5e4f);
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = __float2bfloat16(-3.38953e38f);
|
||||
}
|
||||
// merge per ty
|
||||
#pragma unroll 2
|
||||
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
||||
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
|
||||
T m_prev = m;
|
||||
T d_prev = d;
|
||||
const T m_now = multi_m[offset];
|
||||
const T d_now = multi_d[offset];
|
||||
m = m_prev > m_now ? m_prev : m_now;
|
||||
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size;
|
||||
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
||||
const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m);
|
||||
d = d * scale1 + d_now * scale2;
|
||||
#pragma once
|
||||
for (int j = 0; j < vec_size; j++) {
|
||||
res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2;
|
||||
}
|
||||
}
|
||||
// store ty res
|
||||
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
|
||||
md_smem[2 * ty] = m;
|
||||
md_smem[2 * ty + 1] = d;
|
||||
__syncthreads();
|
||||
|
||||
// merge bdy
|
||||
softmax_state_t<vec_size, T> st{};
|
||||
const uint32_t iter_num = min(num_chunks_this_seq, bdy);
|
||||
#pragma once
|
||||
for (int i = 0; i < iter_num; i++) {
|
||||
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
|
||||
const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
|
||||
AlignedVector<OutT, vec_size> out_vec;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size; ++i) {
|
||||
out_vec[i] = static_cast<OutT>(st.o[i]);
|
||||
}
|
||||
Store<OutT, vec_size>(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]);
|
||||
}
|
||||
|
||||
template <bool partition_kv, typename T, typename OutT, typename CacheT, uint32_t NUM_STAGES, uint32_t DEAL_EACH_TIME, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V,
|
||||
uint32_t BLOCK_SIZE, uint32_t VEC_SIZE, uint32_t CACHE_VEC_SIZE, uint32_t bdx, uint32_t bdy>
|
||||
__global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [token_num, num_heads, head_dim]
|
||||
CacheT * __restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||
CacheT * __restrict__ cache_v,
|
||||
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const int * __restrict__ seq_lens_q,
|
||||
const int * __restrict__ seq_lens_kv,
|
||||
const int * __restrict__ cu_seqlens_q,
|
||||
const int * __restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
const float scale,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
T * __restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, head_dim]
|
||||
T * __restrict__ tmp_m, // [batch_size, num_chunks, num_heads]
|
||||
T * __restrict__ tmp_d, // [batch_size, num_chunks, num_heads]
|
||||
OutT * __restrict__ out) {
|
||||
const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z;
|
||||
const uint32_t bid = bidx, gid = threadIdx.y;
|
||||
const uint32_t tidx = threadIdx.x;
|
||||
constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE;
|
||||
constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE;
|
||||
constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx;
|
||||
|
||||
const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid;
|
||||
const uint32_t kv_num_heads = gridDim.z;
|
||||
const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE;
|
||||
|
||||
const int *block_table_now = block_table + bid * max_block_num_per_seq;
|
||||
|
||||
const uint32_t num_chunks = gridDim.y;
|
||||
const uint32_t chunk_id = blockIdx.y;
|
||||
const uint32_t q_len = seq_lens_q[bid];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
}
|
||||
uint32_t kv_len = seq_lens_kv[bid]; // !!!!!!!!
|
||||
if (kv_len <= 0) {
|
||||
return;
|
||||
}
|
||||
kv_len += q_len;
|
||||
const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
|
||||
const uint32_t q_start_idx = cu_seqlens_q[bid];
|
||||
const uint32_t q_write_idx = cu_seqlens_q[bid];
|
||||
if (chunk_id >= num_chunk_this_seq) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0;
|
||||
const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len;
|
||||
const uint32_t chunk_len = chunk_end - chunk_start;
|
||||
|
||||
extern __shared__ uint8_t smem[];
|
||||
const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK;
|
||||
T *q_smem = reinterpret_cast<T*>(smem); // [HEAD_DIM_QK * sizeof(T)]
|
||||
T *cu_q_smem = q_smem + gid * HEAD_DIM_QK;
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
|
||||
((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0];
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
using VecT = AlignedVector<T, VEC_SIZE>;
|
||||
VecT q_vec;
|
||||
#pragma unroll
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) {
|
||||
Load<T, VEC_SIZE>(cu_q_smem + vid * VEC_SIZE, &q_vec);
|
||||
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
|
||||
q_vec[i] *= scale;
|
||||
}
|
||||
Store<T, VEC_SIZE>(q_vec, cu_q_smem + vid * VEC_SIZE);
|
||||
}
|
||||
|
||||
|
||||
CacheT *kv_smem = reinterpret_cast<CacheT*>(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT));
|
||||
uint32_t stage_idx = 0;
|
||||
constexpr int loop_times = DEAL_EACH_TIME / bdy;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_STAGES; ++i) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < loop_times; ++j) {
|
||||
const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid;
|
||||
const uint32_t k_seq_id = chunk_start + k_seq_offset;
|
||||
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
|
||||
kv_smem,
|
||||
cache_k,
|
||||
block_table_now,
|
||||
k_seq_id,
|
||||
k_seq_offset,
|
||||
kv_head_idx,
|
||||
kv_num_heads,
|
||||
tidx,
|
||||
chunk_start,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
stage_idx = (stage_idx + 1) % NUM_STAGES;
|
||||
}
|
||||
|
||||
|
||||
softmax_state_ts<VEC_SIZE, T, num_tile_v> st;
|
||||
float s[DEAL_EACH_TIME];
|
||||
|
||||
const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME);
|
||||
for (int iter = 0; iter < num_iters; ++iter) {
|
||||
wait_group<NUM_STAGES - 1>();
|
||||
__syncthreads();
|
||||
// compute qk
|
||||
compute_qk<VEC_SIZE, num_vec_per_head_qk, bdx, bdy, HEAD_DIM_QK, DEAL_EACH_TIME, num_tile_v>(
|
||||
cu_q_smem,
|
||||
kv_smem,
|
||||
chunk_start + iter * DEAL_EACH_TIME,
|
||||
stage_idx,
|
||||
iter * DEAL_EACH_TIME,
|
||||
chunk_len,
|
||||
tidx,
|
||||
gid,
|
||||
scale,
|
||||
s,
|
||||
st
|
||||
);
|
||||
__syncthreads();
|
||||
|
||||
// compute sv
|
||||
compute_sv<VEC_SIZE, num_vec_per_head_v, bdx, DEAL_EACH_TIME, HEAD_DIM_QK, num_tile_v>(
|
||||
s,
|
||||
kv_smem,
|
||||
stage_idx,
|
||||
iter * DEAL_EACH_TIME,
|
||||
chunk_len,
|
||||
tidx,
|
||||
st
|
||||
);
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < loop_times; ++j) {
|
||||
const uint32_t k_seq_offset = j * bdy + gid;
|
||||
produce_kv<SharedMemFillMode::kNoFill, HEAD_DIM_QK, VEC_SIZE, num_vec_per_head_qk, bdx, BLOCK_SIZE, CACHE_VEC_SIZE>(
|
||||
kv_smem,
|
||||
cache_k,
|
||||
block_table_now,
|
||||
chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME,
|
||||
stage_idx * DEAL_EACH_TIME + k_seq_offset,
|
||||
kv_head_idx,
|
||||
kv_num_heads,
|
||||
tidx,
|
||||
chunk_start,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
stage_idx = (stage_idx + 1) % NUM_STAGES;
|
||||
}
|
||||
wait_group<0>();
|
||||
__syncthreads();
|
||||
|
||||
// normize if not partition_kv
|
||||
for(uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) {
|
||||
const uint32_t tile_id = vid / bdx;
|
||||
if (!partition_kv || num_chunk_this_seq == 1) {
|
||||
st.normalize(tile_id);
|
||||
}
|
||||
if (partition_kv && num_chunk_this_seq > 1) {
|
||||
const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx;
|
||||
Store<T, VEC_SIZE>(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE);
|
||||
tmp_m[head_idx] = st.m;
|
||||
tmp_d[head_idx] = st.d;
|
||||
} else {
|
||||
Store<OutT, VEC_SIZE>(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V, uint32_t BLOCK_SIZE, bool CAUSAL, uint32_t NUM_STAGE, uint32_t cache_bytes, uint32_t DEAL_EACH_TIME>
|
||||
void MultiQueryDecoderAttention(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
cudaStream_t &stream,
|
||||
const paddle::Tensor &q,
|
||||
const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim]
|
||||
const paddle::Tensor &cache_v, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q,
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float rope_scale,
|
||||
const float rope_theta,
|
||||
const float softmax_scale,
|
||||
const float in_scale,
|
||||
paddle::Tensor *out) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
|
||||
auto num_heads = meta_data.q_num_heads;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
auto token_num = meta_data.token_nums;
|
||||
auto bsz = meta_data.batch_size;
|
||||
auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
|
||||
constexpr int num_stages = NUM_STAGE;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T); // 8 16 32
|
||||
constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32
|
||||
constexpr int blockxc = HEAD_DIM_QK / cache_vec_size;
|
||||
constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size;
|
||||
constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32;
|
||||
|
||||
constexpr int blocky = GROUP_SIZE;
|
||||
const int gridx = bsz;
|
||||
|
||||
constexpr int num_threads = blockx * blocky;
|
||||
|
||||
auto splitkv_kernel = multi_query_decode_attention_kernel<true, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V,
|
||||
BLOCK_SIZE, vec_size, cache_vec_size, blockx, blocky>;
|
||||
uint32_t cache_smem_bytes = 0;
|
||||
|
||||
const T *shift_bias_ptr = shift_bias ? shift_bias.get().data<T>() : nullptr;
|
||||
const T *smooth_weight_ptr = smooth_weight ? smooth_weight.get().data<T>() : nullptr;
|
||||
cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T);
|
||||
|
||||
const uint32_t chunk_size = get_max_partition_size(bsz);
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T);
|
||||
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
const int dev_id = 0;
|
||||
int sm_count;
|
||||
int act_blocks_per_sm;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, splitkv_kernel, num_threads, smem_size);
|
||||
assert(act_blocks_per_sm > 1);
|
||||
|
||||
const int num_blocks_per_wave = sm_count * act_blocks_per_sm;
|
||||
const int num_blocks_need = gridx * num_chunks * kv_num_heads;
|
||||
const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need);
|
||||
const float ratio = static_cast<float>(num_blocks_need) / static_cast<float>(num_blocks_per_wave);
|
||||
|
||||
dim3 grids(gridx, num_chunks, kv_num_heads);
|
||||
dim3 blocks(blockx, blocky);
|
||||
if (num_chunks <= 1) {
|
||||
auto no_splitkv_kernel = multi_query_decode_attention_kernel<false, NV_TYPE, NV_TYPE, NV_TYPE, num_stages, DEAL_EACH_TIME, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, vec_size,
|
||||
cache_vec_size, blockx, blocky>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
no_splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
no_splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
softmax_scale,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
|
||||
);
|
||||
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
} else {
|
||||
auto *allocator = paddle::GetAllocator(q.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
tmp_workspace = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM_V));
|
||||
tmp_m = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||
tmp_d = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||
|
||||
splitkv_kernel<<<grids, blocks, smem_size, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(q.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
softmax_scale,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>()))
|
||||
);
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
|
||||
constexpr int mblockx = HEAD_DIM_V / vec_size;
|
||||
constexpr int bdy = 256 / mblockx;
|
||||
dim3 grids_merge(bsz, num_heads);
|
||||
dim3 blocks_merge(mblockx, bdy);
|
||||
merge_varlen_multi_chunks_v2_kernel<NV_TYPE, NV_TYPE, vec_size, bdy, HEAD_DIM_V><<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_m->ptr()),
|
||||
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
|
||||
seq_lens_q.data<int>(),
|
||||
seq_lens_kv.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
|
||||
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
|
||||
in_scale,
|
||||
num_chunks,
|
||||
chunk_size,
|
||||
max_seq_len,
|
||||
num_heads,
|
||||
HEAD_DIM_V
|
||||
);
|
||||
}
|
||||
// CHECK(cudaGetLastError());
|
||||
// CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DecodeMLAAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out) {
|
||||
const auto token_num = meta_data.token_nums;
|
||||
const auto block_size = meta_data.block_size;
|
||||
const auto bsz = meta_data.batch_size;
|
||||
const auto num_heads = meta_data.q_num_heads;
|
||||
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
|
||||
const auto head_dim_qk = meta_data.head_dims;
|
||||
const auto head_dim_v = meta_data.head_dims_v;
|
||||
const float rope_scale = 0.0;
|
||||
const float rope_theta = 0.0;
|
||||
const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
|
||||
const uint32_t num_stage = get_cascade_attention_num_stages();
|
||||
const uint32_t num_threads = get_cascade_attention_num_threads();
|
||||
|
||||
DISPATCH_CAUSAL(causal, CAUSAL,
|
||||
{DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
|
||||
{DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
|
||||
{DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
|
||||
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
|
||||
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
|
||||
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
|
||||
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cu_seqlens_q,
|
||||
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
|
||||
}
|
||||
|
||||
template void DecodeMLAAttentionKernel<paddle::bfloat16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
||||
|
||||
template void DecodeMLAAttentionKernel<paddle::float16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
@@ -29,7 +29,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -65,7 +65,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
const int h_bias = bias % head_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
@@ -135,7 +135,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -177,7 +177,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
const int h_bias = bias % head_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
@@ -255,7 +255,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -293,7 +293,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const int bias = linear_index % half_hidden_size;
|
||||
const int hi = bias / half_head_size; // q + k + v
|
||||
const int h_bias = bias % half_head_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
@@ -367,7 +367,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -409,7 +409,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const int bias = linear_index % half_hidden_size;
|
||||
const int hi = bias / half_head_size; // q + k + v
|
||||
const int h_bias = bias % half_head_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
@@ -499,7 +499,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -523,7 +523,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
int q_head_idx, k_head_idx, v_idx;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -746,7 +746,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -775,7 +775,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
int q_head_idx, k_head_idx, v_idx;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -1048,7 +1048,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -1073,7 +1073,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
int q_head_idx, k_head_idx, v_idx;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -1347,7 +1347,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -1377,7 +1377,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -1740,7 +1740,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -1766,7 +1766,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int half_block_size = block_size / 2;
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -2035,7 +2035,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -2066,7 +2066,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int half_block_size = block_size / 2;
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -2363,7 +2363,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -2389,7 +2389,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int half_block_size = block_size / 2;
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
@@ -2733,7 +2733,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -2764,7 +2764,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int half_block_size = block_size / 2;
|
||||
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
|
@@ -22,7 +22,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* padding_offsets,
|
||||
const int* cum_offsets,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
@@ -58,7 +58,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -80,7 +80,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -103,7 +103,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -126,7 +126,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -150,7 +150,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* padding_offsets,
|
||||
const int* cum_offsets,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
@@ -183,7 +183,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -208,7 +208,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -233,7 +233,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -258,7 +258,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -283,7 +283,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* padding_offsets,
|
||||
const int* cum_offsets,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
@@ -318,7 +318,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -345,7 +345,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -372,7 +372,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -399,7 +399,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -425,7 +425,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -472,7 +472,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -504,7 +504,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -537,7 +537,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -571,7 +571,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -604,7 +604,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -651,7 +651,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -678,7 +678,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -704,7 +704,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -730,7 +730,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
|
@@ -24,7 +24,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
|
@@ -879,7 +879,7 @@ __global__ void append_write_cache_kv_c8_qkv(
|
||||
const int *__restrict__ seq_lens_this_time,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
const int *__restrict__ padding_offsets,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_tables,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
@@ -911,8 +911,7 @@ __global__ void append_write_cache_kv_c8_qkv(
|
||||
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
|
||||
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
|
||||
|
||||
const uint32_t start_token_idx =
|
||||
batch_id * max_seq_len - cum_offsets[batch_id];
|
||||
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
|
||||
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
|
||||
const uint32_t kv_h_stride = HEAD_DIM;
|
||||
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
@@ -1119,7 +1118,7 @@ __global__ void append_write_cache_kv_c4_qkv(
|
||||
const int *__restrict__ seq_lens_this_time,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
const int *__restrict__ padding_offsets,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_tables,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
@@ -1148,8 +1147,7 @@ __global__ void append_write_cache_kv_c4_qkv(
|
||||
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
|
||||
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
|
||||
|
||||
const uint32_t start_token_idx =
|
||||
batch_id * max_seq_len - cum_offsets[batch_id];
|
||||
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
|
||||
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
|
||||
const uint32_t kv_h_stride = HEAD_DIM;
|
||||
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
@@ -1750,7 +1748,7 @@ void CascadeAppendWriteCacheKVC8QKV(
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
const paddle::Tensor &batch_ids,
|
||||
const paddle::Tensor &tile_ids_per_batch,
|
||||
@@ -1815,7 +1813,7 @@ void CascadeAppendWriteCacheKVC8QKV(
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
@@ -1838,7 +1836,7 @@ void CascadeAppendWriteCacheKVC4QKV(
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
const paddle::Tensor &batch_ids,
|
||||
const paddle::Tensor &tile_ids_per_batch,
|
||||
@@ -1885,7 +1883,7 @@ void CascadeAppendWriteCacheKVC4QKV(
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
|
@@ -26,7 +26,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids,
|
||||
@@ -143,7 +143,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
batch_ids,
|
||||
tile_ids,
|
||||
@@ -170,7 +170,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
batch_ids,
|
||||
tile_ids,
|
||||
|
@@ -422,7 +422,6 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids,
|
||||
@@ -450,7 +449,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
const int token_num = qkv_dims[0];
|
||||
const int max_blocks_per_seq = block_tables.dims()[1];
|
||||
const int block_size = key_cache.dims()[2];
|
||||
const int batch_size = cum_offsets.dims()[0];
|
||||
const int batch_size = seq_lens_this_time.dims()[0];
|
||||
const int kv_num_heads = key_cache_dims[1];
|
||||
const int head_dim = key_cache_dims[3];
|
||||
const int num_heads = qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads;
|
||||
@@ -463,7 +462,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
meta_data.q_num_heads = num_heads;
|
||||
meta_data.max_blocks_per_seq = max_blocks_per_seq;
|
||||
meta_data.block_size = block_size;
|
||||
meta_data.batch_size = cum_offsets.dims()[0];
|
||||
meta_data.batch_size = seq_lens_this_time.dims()[0];
|
||||
|
||||
phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(qkv.place()));
|
||||
|
||||
@@ -528,7 +527,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids,
|
||||
@@ -595,7 +594,6 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"padding_offsets",
|
||||
"cum_offsets",
|
||||
"block_tables",
|
||||
"kv_batch_ids",
|
||||
"kv_tile_ids_per_batch",
|
||||
|
291
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu
Normal file
291
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu
Normal file
@@ -0,0 +1,291 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "mla_cache_kernel.cuh"
|
||||
|
||||
template <paddle::DataType T>
|
||||
std::vector<paddle::Tensor> PrefillMLAWriteCache(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* kv_cache) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||
auto num_tokens = meta_data.token_nums;
|
||||
auto block_size = meta_data.block_size;
|
||||
auto nope_size = meta_data.head_dims_v;
|
||||
auto all_size = meta_data.head_dims;
|
||||
int pe_size = all_size - nope_size;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
const uint32_t elem_nums = num_tokens * kv_num_heads * all_size;
|
||||
|
||||
constexpr int PackSize = 16 / sizeof(DataType_);
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
|
||||
prefill_absorb_cache_kernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
nope_size,
|
||||
pe_size,
|
||||
block_size,
|
||||
elem_nums);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len) {
|
||||
cudaStream_t stream = kv_pe.stream();
|
||||
AppendAttnMetaData meta_data;
|
||||
const auto& kv_nope_dims = kv_nope.dims();
|
||||
const auto& kv_pe_dims = kv_pe.dims();
|
||||
const auto& kv_cache_dims = kv_cache.dims();
|
||||
meta_data.kv_num_heads = kv_cache_dims[1];
|
||||
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
|
||||
meta_data.token_nums = kv_nope_dims[0];
|
||||
meta_data.head_dims = kv_cache_dims[3];
|
||||
meta_data.head_dims_v = nope_size;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = kv_cache_dims[2];
|
||||
meta_data.batch_size = cu_seqlens_q.dims()[0];
|
||||
switch (kv_pe.dtype()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
padding_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
template <paddle::DataType T>
|
||||
std::vector<paddle::Tensor> DecodeMLAWriteCache(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* kv_cache) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||
auto bsz = meta_data.batch_size;
|
||||
auto token_num = meta_data.token_nums;
|
||||
auto block_size = meta_data.block_size;
|
||||
auto nope_size = meta_data.head_dims_v;
|
||||
auto all_size = meta_data.head_dims;
|
||||
int pe_size = all_size - nope_size;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
constexpr int PackSize = 16 / sizeof(DataType_);
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
|
||||
|
||||
if (speculate_decoder) {
|
||||
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
nope_size,
|
||||
pe_size,
|
||||
block_size,
|
||||
elem_nums);
|
||||
} else {
|
||||
const uint32_t elem_nums = bsz * kv_num_heads * all_size;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
decode_absorb_cache_kernel<DataType_, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
|
||||
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
nope_size,
|
||||
pe_size,
|
||||
block_size,
|
||||
elem_nums);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder) {
|
||||
cudaStream_t stream = kv_pe.stream();
|
||||
AppendAttnMetaData meta_data;
|
||||
const auto& kv_nope_dims = kv_nope.dims();
|
||||
const auto& kv_pe_dims = kv_pe.dims();
|
||||
const auto& kv_cache_dims = kv_cache.dims();
|
||||
meta_data.kv_num_heads = kv_cache_dims[1];
|
||||
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
|
||||
meta_data.token_nums = kv_nope_dims[0];
|
||||
meta_data.head_dims = kv_cache_dims[3];
|
||||
meta_data.head_dims_v = nope_size;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = kv_cache_dims[2];
|
||||
meta_data.batch_size = cu_seqlens_q.dims()[0];
|
||||
switch (kv_pe.dtype()) {
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
speculate_decoder,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
padding_offsets,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
speculate_decoder,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(prefill_mla_write_cache)
|
||||
.Inputs({"kv_nope",
|
||||
"kv_pe",
|
||||
"kv_cache",
|
||||
"seq_lens",
|
||||
"seq_lens_decoder",
|
||||
"padding_offsets",
|
||||
"cu_seqlens_q",
|
||||
"block_tables"})
|
||||
.Outputs({"kv_cache_out"})
|
||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||
.Attrs({"cache_quant_type_str: std::string",
|
||||
"max_seq_len: int"})
|
||||
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
|
||||
|
||||
PD_BUILD_OP(decode_mla_write_cache)
|
||||
.Inputs({"kv_nope",
|
||||
"kv_pe",
|
||||
"kv_cache",
|
||||
"seq_lens",
|
||||
"seq_lens_encoder",
|
||||
"padding_offsets",
|
||||
"cu_seqlens_q",
|
||||
"block_tables"})
|
||||
.Outputs({"kv_cache_out"})
|
||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||
.Attrs({"cache_quant_type_str: std::string",
|
||||
"max_seq_len: int",
|
||||
"speculate_decoder: bool"})
|
||||
.SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel));
|
242
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh
Normal file
242
custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh
Normal file
@@ -0,0 +1,242 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
#include "mem_util.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void decode_absorb_cache_kernel(
|
||||
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512
|
||||
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const int nope_size,
|
||||
const int pe_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
LoadT src_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
|
||||
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
|
||||
const uint32_t all_size = nope_size + pe_size;
|
||||
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int ori_bi = linear_index / hidden_size;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
start_token_idx * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
start_token_idx * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void speculate_decode_absorb_cache_kernel(
|
||||
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512
|
||||
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const int nope_size,
|
||||
const int pe_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
LoadT src_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
|
||||
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
|
||||
const uint32_t all_size = nope_size + pe_size;
|
||||
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_id = linear_index / hidden_size;
|
||||
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
if (block_idx < 0) {
|
||||
printf(
|
||||
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
|
||||
"%d %d %d %d\n",
|
||||
block_idx,
|
||||
write_seq_id,
|
||||
ori_bi,
|
||||
seq_lens[ori_bi],
|
||||
token_id,
|
||||
cu_seqlens_q[ori_bi]);
|
||||
}
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_id * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_id * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void prefill_absorb_cache_kernel(
|
||||
const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512
|
||||
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int kv_num_heads,
|
||||
const int nope_size,
|
||||
const int pe_size,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const uint32_t nope_hidden_size = kv_num_heads * nope_size;
|
||||
const uint32_t pe_hidden_size = kv_num_heads * pe_size;
|
||||
const uint32_t all_size = nope_size + pe_size;
|
||||
const int64_t hidden_size = nope_hidden_size + pe_hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const uint32_t token_idx = linear_index / hidden_size;
|
||||
const uint32_t bias = linear_index % hidden_size;
|
||||
const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
|
||||
const uint32_t ori_bi = ori_token_idx / max_seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const uint32_t ori_seq_id =
|
||||
ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const uint32_t block_idx = block_table_now[ori_seq_id / block_size];
|
||||
const uint32_t block_offset = ori_seq_id % block_size;
|
||||
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
const uint32_t inner_bias = bias;
|
||||
const uint32_t hi = inner_bias / nope_size;
|
||||
const uint32_t h_bias = inner_bias % nope_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_idx * nope_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_nope[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
} else {
|
||||
const uint32_t inner_bias = bias - nope_hidden_size;
|
||||
const uint32_t hi = inner_bias / pe_size;
|
||||
const uint32_t h_bias = inner_bias % pe_size;
|
||||
const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size +
|
||||
hi * block_size * all_size +
|
||||
block_offset * all_size + nope_size + h_bias;
|
||||
const uint32_t ori_idx =
|
||||
token_idx * pe_hidden_size + inner_bias;
|
||||
Load<T, VecSize>(&kv_pe[ori_idx], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &kv_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include "helper.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T>
|
||||
void DecodeMLAAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||
const paddle::Tensor &cache_k,
|
||||
const paddle::Tensor &cache_v,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||
const paddle::Tensor &seq_lens_kv,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_table,
|
||||
int max_seq_len,
|
||||
int max_dec_len,
|
||||
float softmax_scale,
|
||||
float in_scale,
|
||||
bool causal,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *out);
|
@@ -27,7 +27,7 @@ __global__ void append_clear_cache_int8_block(
|
||||
const int* __restrict__ seq_lens,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
@@ -44,7 +44,7 @@ __global__ void append_clear_cache_int8_block(
|
||||
const int ori_token_id = token_id + padding_offsets[token_id];
|
||||
const int bid = ori_token_id / max_seq_len;
|
||||
|
||||
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
@@ -101,7 +101,7 @@ __global__ void append_clear_cache_int4_block(
|
||||
const int* __restrict__ seq_lens,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
@@ -118,7 +118,7 @@ __global__ void append_clear_cache_int4_block(
|
||||
const int ori_token_id = token_id + padding_offsets[token_id];
|
||||
const int bid = ori_token_id / max_seq_len;
|
||||
|
||||
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
@@ -179,7 +179,7 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
T* __restrict__ q_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
const float* __restrict__ sin_emb,
|
||||
@@ -219,7 +219,7 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
const int h_bias = bias % head_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
@@ -235,7 +235,7 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
ori_bi,
|
||||
seq_lens_decoder[ori_bi],
|
||||
token_id,
|
||||
cum_offsets[ori_bi]);
|
||||
cu_seqlens_q[ori_bi]);
|
||||
}
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
@@ -312,7 +312,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
const float* __restrict__ sin_emb,
|
||||
@@ -352,7 +352,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
const int bias = linear_index % half_hidden_size;
|
||||
const int hi = bias / half_head_size; // q + k + v
|
||||
const int h_bias = bias % half_head_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
@@ -368,7 +368,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
ori_bi,
|
||||
seq_lens_decoder[ori_bi],
|
||||
token_id,
|
||||
cum_offsets[ori_bi]);
|
||||
cu_seqlens_q[ori_bi]);
|
||||
}
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
@@ -459,7 +459,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -487,7 +487,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
const int ori_token_id = token_id + padding_offsets[token_id];
|
||||
const int bid = ori_token_id / max_seq_len;
|
||||
|
||||
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||
int q_head_idx, k_head_idx, v_idx;
|
||||
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
|
||||
@@ -691,7 +691,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -719,7 +719,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
const int ori_token_id = token_id + padding_offsets[token_id];
|
||||
const int bid = ori_token_id / max_seq_len;
|
||||
|
||||
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||
int q_head_idx, k_head_idx, v_idx;
|
||||
|
||||
@@ -1069,7 +1069,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -1100,7 +1100,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
const int ori_token_id = token_id + padding_offsets[token_id];
|
||||
const int bid = ori_token_id / max_seq_len;
|
||||
|
||||
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
|
||||
@@ -1375,7 +1375,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets, // [num_tokens]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
@@ -1406,7 +1406,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
const int ori_token_id = token_id + padding_offsets[token_id];
|
||||
const int bid = ori_token_id / max_seq_len;
|
||||
|
||||
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
|
||||
|
@@ -23,7 +23,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* padding_offsets,
|
||||
const int* cum_offsets,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
@@ -60,7 +60,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
@@ -83,7 +83,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
@@ -107,7 +107,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* padding_offsets,
|
||||
const int* cum_offsets,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
@@ -137,7 +137,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
seq_lens,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens_encoder,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
@@ -152,7 +152,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -176,7 +176,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -202,7 +202,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* padding_offsets,
|
||||
const int* cum_offsets,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
@@ -234,7 +234,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
seq_lens,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens_encoder,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
@@ -249,7 +249,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -275,7 +275,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -302,7 +302,7 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -350,7 +350,7 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -377,7 +377,7 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -410,7 +410,7 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -443,7 +443,7 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
padding_offsets.data<int>(),
|
||||
cum_offsets.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
@@ -489,7 +489,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -515,7 +515,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -540,7 +540,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
@@ -567,7 +567,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
|
@@ -24,7 +24,7 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
|
@@ -38,7 +38,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -38,7 +38,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -39,7 +39,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -86,7 +86,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -81,7 +81,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -83,7 +83,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -83,7 +83,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
@@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
|
||||
const paddle::Tensor& seq_lens_kv,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
|
@@ -23,7 +23,7 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids,
|
||||
|
@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids,
|
||||
|
@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids,
|
||||
|
@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids,
|
||||
|
@@ -25,6 +25,7 @@ struct AppendAttnMetaData {
|
||||
int kv_num_heads;
|
||||
int token_nums;
|
||||
int head_dims;
|
||||
int head_dims_v;
|
||||
int max_blocks_per_seq;
|
||||
};
|
||||
|
||||
@@ -309,10 +310,56 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
|
||||
if (num_stage == 2) { \
|
||||
constexpr size_t NUM_STAGE = 2; \
|
||||
__VA_ARGS__ \
|
||||
#define DISPATCH_GQA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
switch (head_dim) { \
|
||||
case 128: { \
|
||||
constexpr size_t HEAD_DIM = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 192: { \
|
||||
constexpr size_t HEAD_DIM = 192; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("not support the head_dim: ", head_dim); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_MLA_HEAD_DIM(head_dim, HEAD_DIM, ...) \
|
||||
switch (head_dim) { \
|
||||
case 128: { \
|
||||
constexpr size_t HEAD_DIM = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 192: { \
|
||||
constexpr size_t HEAD_DIM = 192; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 512: { \
|
||||
constexpr size_t HEAD_DIM = 512; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 576: { \
|
||||
constexpr size_t HEAD_DIM = 576; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("not support the head_dim: ", head_dim); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define DISPATCH_NUM_STAGE(num_stage, NUM_STAGE, ...) \
|
||||
if (num_stage == 2) { \
|
||||
constexpr size_t NUM_STAGE = 2; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the num_stage: ", num_stage); \
|
||||
}
|
||||
|
||||
#define DISPATCH_CACHE_TYPE(cache_type, cache_type_now, cache_bytes, ...) \
|
||||
@@ -328,10 +375,13 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
constexpr CacheType cache_type_now = CacheType::CacheInt4CwZp; \
|
||||
constexpr size_t cache_bytes = 4; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the cache_type: ", cache_type); \
|
||||
}
|
||||
|
||||
|
||||
#define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) \
|
||||
if (deal_each_time == 32) { \
|
||||
if (deal_each_time == 32) { \
|
||||
constexpr size_t DEAL_EACH_TIME = 32; \
|
||||
__VA_ARGS__ \
|
||||
} else if (deal_each_time == 64) { \
|
||||
@@ -387,6 +437,20 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
PD_THROW("not support the group_size", group_size); \
|
||||
}
|
||||
|
||||
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
constexpr size_t GROUP_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 128) { \
|
||||
constexpr size_t GROUP_SIZE = 128; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the group_size: ", group_size); \
|
||||
}
|
||||
|
||||
#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \
|
||||
if (block_shape_q <= 16) { \
|
||||
constexpr size_t BLOCK_SHAPE_Q = 16; \
|
||||
|
@@ -54,7 +54,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &padding_offsets, const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &padding_offsets, const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
|
||||
const paddle::Tensor &encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor &encoder_num_blocks,
|
||||
@@ -94,7 +94,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &padding_offsets, const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &padding_offsets,
|
||||
const paddle::Tensor &block_tables, const paddle::Tensor &kv_batch_ids,
|
||||
const paddle::Tensor &kv_tile_ids, const paddle::Tensor &kv_num_blocks,
|
||||
const paddle::Tensor &cache_batch_ids, const paddle::Tensor &cache_tile_ids,
|
||||
@@ -116,11 +116,11 @@ PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
|
||||
|
||||
paddle::Tensor FusedExpertMoeFunc(
|
||||
const paddle::Tensor &input, const paddle::Tensor &gate_weight,
|
||||
const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor> &ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor> &ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor> &ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor> &ffn2_scale,
|
||||
const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor> &up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_bias,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_scale,
|
||||
const std::string &quant_method, const int moe_topk,
|
||||
const bool norm_topk_prob, const bool group_moe);
|
||||
|
||||
@@ -149,7 +149,7 @@ MoERedundantTopKSelectKernel(const paddle::Tensor &gating_logits,
|
||||
std::vector<paddle::Tensor>
|
||||
EPMoeExpertDispatch(const paddle::Tensor &input, const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &topk_weights,
|
||||
const paddle::optional<paddle::Tensor> &ffn1_in_scale,
|
||||
const paddle::optional<paddle::Tensor> &up_gate_proj_in_scale,
|
||||
const std::vector<int> &token_nums_per_expert,
|
||||
const int token_nums_this_rank,
|
||||
const std::string &moe_quant_type);
|
||||
@@ -158,7 +158,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
|
||||
const paddle::Tensor &input, const paddle::Tensor &scale,
|
||||
const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights,
|
||||
const paddle::Tensor &token_nums_per_expert,
|
||||
const paddle::Tensor &token_nums_per_expert_padded);
|
||||
const paddle::Tensor &token_nums_per_expert_padded,
|
||||
const bool use_in_ep, const int token_nums_this_rank_padded);
|
||||
|
||||
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
|
||||
const int block_size);
|
||||
@@ -172,7 +173,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
|
||||
const paddle::Tensor &ffn_out, const paddle::Tensor &expert_scales_float,
|
||||
const paddle::Tensor &permute_indices_per_token,
|
||||
const paddle::Tensor &top_k_indices,
|
||||
const paddle::optional<paddle::Tensor> &ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_bias,
|
||||
const bool norm_topk_prob, const float routed_scaling_factor);
|
||||
|
||||
std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
|
||||
@@ -181,35 +182,35 @@ std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
|
||||
paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight, const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
|
||||
const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency);
|
||||
|
||||
paddle::Tensor MoeExpertFFNWint2Func(
|
||||
const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
|
||||
const bool used_in_ep_low_latency);
|
||||
|
||||
paddle::Tensor MoeExpertReduceFunc(
|
||||
const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
|
||||
const paddle::Tensor &permute_indices_per_token,
|
||||
const paddle::Tensor &top_k_indices,
|
||||
const paddle::optional<paddle::Tensor> &ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor> &down_proj_bias,
|
||||
const bool norm_topk_prob, const float routed_scaling_factor);
|
||||
|
||||
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
|
||||
@@ -316,6 +317,95 @@ void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
|
||||
|
||||
paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
|
||||
int64_t num_experts);
|
||||
void GetPositionIdsAndMaskEncoderBatch(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& position_ids,
|
||||
const paddle::Tensor& mask_encoder_batch);
|
||||
|
||||
std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder);
|
||||
|
||||
std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len);
|
||||
|
||||
|
||||
void FusedRotaryPositionEncoding(
|
||||
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
|
||||
// [num_tokens, num_heads * head_size]
|
||||
paddle::Tensor& key,
|
||||
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
|
||||
// head_size]
|
||||
const paddle::Tensor& position_ids, // [num_tokens]
|
||||
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
int head_size,
|
||||
bool is_neox);
|
||||
|
||||
std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& query_bias,
|
||||
const paddle::optional<paddle::Tensor>& query_out_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int nope_size,
|
||||
const int max_input_length,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder);
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M);
|
||||
@@ -370,6 +460,270 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out,
|
||||
paddle::Tensor const &input,
|
||||
paddle::Tensor &scales, float scale_ub);
|
||||
|
||||
std::vector<paddle::Tensor> NoauxTc(
|
||||
paddle::Tensor& scores,
|
||||
paddle::Tensor& scores_with_bias,
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
float routed_scaling_factor);
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& y,
|
||||
const paddle::optional<paddle::Tensor>& bias,
|
||||
bool trans_x,
|
||||
bool trans_y,
|
||||
float scale, // only support per-tensor quantization
|
||||
std::string output_dtype,
|
||||
std::string activation_type);
|
||||
|
||||
paddle::Tensor MoeFusedHadamardQuantFp8Func(
|
||||
const paddle::Tensor &input,
|
||||
const paddle::Tensor &scale,
|
||||
const paddle::Tensor &topk_ids,
|
||||
const int top_k,
|
||||
const int intermediate_size,
|
||||
const bool tiled);
|
||||
|
||||
paddle::Tensor FusedHadamardQuantFp8Func(
|
||||
const paddle::Tensor &input,
|
||||
const float scale);
|
||||
#endif
|
||||
|
||||
int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
|
||||
|
||||
void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
|
||||
int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
|
||||
|
||||
void dispose(int64_t _fa);
|
||||
|
||||
int64_t meta_size();
|
||||
|
||||
void register_buffer(int64_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
|
||||
|
||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(int64_t _fa);
|
||||
|
||||
void register_graph_buffers(int64_t _fa,
|
||||
const std::vector<std::vector<int64_t>>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets);
|
||||
|
||||
std::tuple<int64_t, paddle::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size);
|
||||
|
||||
int64_t open_mem_handle(paddle::Tensor& mem_handle);
|
||||
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
// speculative decoding Kernel
|
||||
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& cum_offsets,
|
||||
const paddle::Tensor& token_num,
|
||||
const paddle::Tensor& seq_len,
|
||||
const paddle::Tensor& seq_lens_encoder);
|
||||
|
||||
std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
|
||||
const paddle::Tensor& output_cum_offsets_tmp,
|
||||
const paddle::Tensor& out_token_num,
|
||||
const paddle::Tensor& seq_lens_output,
|
||||
const int max_seq_len);
|
||||
|
||||
|
||||
void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &logits,
|
||||
const paddle::Tensor &penalty_scores,
|
||||
const paddle::Tensor &frequency_scores,
|
||||
const paddle::Tensor &presence_scores,
|
||||
const paddle::Tensor &temperatures,
|
||||
const paddle::Tensor &bad_tokens,
|
||||
const paddle::Tensor &cur_len,
|
||||
const paddle::Tensor &min_len,
|
||||
const paddle::Tensor &eos_token_id,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &output_padding_offset,
|
||||
const paddle::Tensor &output_cum_offsets,
|
||||
const int max_seq_len);
|
||||
|
||||
void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens,
|
||||
const paddle::Tensor &stop_seqs,
|
||||
const paddle::Tensor &stop_seqs_len,
|
||||
const paddle::Tensor &end_ids);
|
||||
|
||||
|
||||
void SpeculateVerify(
|
||||
const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &verify_tokens, const paddle::Tensor &verify_scores,
|
||||
const paddle::Tensor &max_dec_len, const paddle::Tensor &end_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor &output_cum_offsets,
|
||||
const paddle::Tensor &actual_candidate_len,
|
||||
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
|
||||
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
|
||||
|
||||
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &actual_draft_token_nums,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor &stop_nums);
|
||||
|
||||
void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_idx);
|
||||
|
||||
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
int64_t rank_id,
|
||||
bool save_each_rank);
|
||||
|
||||
|
||||
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
void NgramMatch(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &input_ids_len,
|
||||
const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &draft_token_num,
|
||||
const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &max_dec_len,
|
||||
const int max_ngram_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
|
||||
// MTP
|
||||
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||
const paddle::Tensor& base_model_stop_flags);
|
||||
|
||||
|
||||
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||
const paddle::Tensor& base_model_seq_lens_decoder,
|
||||
const paddle::Tensor& base_model_step_idx,
|
||||
const paddle::Tensor& base_model_stop_flags,
|
||||
const paddle::Tensor& base_model_is_block_step,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int max_draft_token,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill);
|
||||
|
||||
|
||||
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& output_cum_offsets,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& max_dec_len,
|
||||
const paddle::Tensor& end_ids,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int max_seq_len,
|
||||
const int substep);
|
||||
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> EagleGetHiddenStates(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& accept_nums,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||
const int actual_draft_token_num);
|
||||
|
||||
void MTPStepPaddle(
|
||||
const paddle::Tensor &base_model_stop_flags,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &batch_drop,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
|
||||
const paddle::Tensor &encoder_block_lens,
|
||||
const paddle::Tensor &used_list_len,
|
||||
const paddle::Tensor &free_list,
|
||||
const paddle::Tensor &free_list_len,
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
void SpeculateStepPaddle(
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &ori_seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
|
||||
const paddle::Tensor &encoder_block_lens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor &step_block_list,
|
||||
const paddle::Tensor &step_lens,
|
||||
const paddle::Tensor &recover_block_list,
|
||||
const paddle::Tensor &recover_lens,
|
||||
const paddle::Tensor &need_block_list,
|
||||
const paddle::Tensor &need_block_len,
|
||||
const paddle::Tensor &used_list_len,
|
||||
const paddle::Tensor &free_list,
|
||||
const paddle::Tensor &free_list_len,
|
||||
const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &first_token_ids,
|
||||
const paddle::Tensor &accept_num,
|
||||
const int block_size,
|
||||
const int encoder_decoder_block_num,
|
||||
const int max_draft_tokens);
|
||||
|
||||
PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
|
||||
@@ -461,7 +815,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
* ep_moe_dispatch
|
||||
*/
|
||||
m.def("ep_moe_expert_dispatch", &EPMoeExpertDispatch, py::arg("input"),
|
||||
py::arg("topk_ids"), py::arg("topk_weights"), py::arg("ffn1_in_scale"),
|
||||
py::arg("topk_ids"), py::arg("topk_weights"), py::arg("up_gate_proj_in_scale"),
|
||||
py::arg("token_nums_per_expert"), py::arg("token_nums_this_rank"),
|
||||
py::arg("moe_quant_type"), "ep moe export dispatch function");
|
||||
|
||||
@@ -469,7 +823,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("ep_moe_expert_combine", &EPMoeExpertCombine, py::arg("ffn_out"),
|
||||
py::arg("expert_scales_float"), py::arg("permute_indices_per_token"),
|
||||
py::arg("top_k_indices"), py::arg("ffn2_bias"),
|
||||
py::arg("top_k_indices"), py::arg("down_proj_bias"),
|
||||
py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"),
|
||||
"ep moe export combine function");
|
||||
|
||||
@@ -511,7 +865,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
*/
|
||||
m.def("moe_expert_reduce", &MoeExpertReduceFunc, py::arg("ffn_out"),
|
||||
py::arg("top_k_weight"), py::arg("permute_indices_per_token"),
|
||||
py::arg("top_k_indices"), py::arg("ffn2_bias"),
|
||||
py::arg("top_k_indices"), py::arg("down_proj_bias"),
|
||||
py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"),
|
||||
"moe export reduce function");
|
||||
|
||||
@@ -539,9 +893,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
* append_attn/get_block_shape_and_split_kv_block.cu
|
||||
* get_block_shape_and_split_kv_block
|
||||
*/
|
||||
// m.def("f_get_block_shape_and_split_kv_block",
|
||||
// &GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block
|
||||
// function");
|
||||
m.def("get_block_shape_and_split_kv_block",
|
||||
&GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block function");
|
||||
|
||||
/**
|
||||
* get_padding_offset.cu
|
||||
@@ -602,32 +955,16 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);
|
||||
|
||||
m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
|
||||
py::arg("a"),
|
||||
py::arg("c_or_none"),
|
||||
py::arg("b_q_weight"),
|
||||
py::arg("b_scales"),
|
||||
py::arg("global_scale_or_none"),
|
||||
py::arg("b_zeros_or_none"),
|
||||
py::arg("g_idx_or_none"),
|
||||
py::arg("perm_or_none"),
|
||||
py::arg("workspace"),
|
||||
py::arg("sorted_token_ids"),
|
||||
py::arg("expert_ids"),
|
||||
py::arg("num_tokens_post_padded"),
|
||||
py::arg("topk_weights"),
|
||||
py::arg("moe_block_size"),
|
||||
py::arg("top_k"),
|
||||
py::arg("mul_topk_weights"),
|
||||
py::arg("is_ep"),
|
||||
py::arg("b_q_type_str"),
|
||||
py::arg("size_m"),
|
||||
py::arg("size_n"),
|
||||
py::arg("size_k"),
|
||||
py::arg("is_k_full"),
|
||||
py::arg("use_atomic_add"),
|
||||
py::arg("use_fp32_reduce"),
|
||||
py::arg("is_zp_float"));
|
||||
py::arg("a"), py::arg("c_or_none"), py::arg("b_q_weight"),
|
||||
py::arg("b_scales"), py::arg("global_scale_or_none"), py::arg("b_zeros_or_none"),
|
||||
py::arg("g_idx_or_none"), py::arg("perm_or_none"), py::arg("workspace"), py::arg("sorted_token_ids"),
|
||||
py::arg("expert_ids"), py::arg("num_tokens_post_padded"), py::arg("topk_weights"), py::arg("moe_block_size"),
|
||||
py::arg("top_k"), py::arg("mul_topk_weights"), py::arg("is_ep"), py::arg("b_q_type_str"),
|
||||
py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"), py::arg("use_atomic_add"),
|
||||
py::arg("use_fp32_reduce"), py::arg("is_zp_float"));
|
||||
|
||||
m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
|
||||
"get_position_ids_and_mask_encoder_batch function");
|
||||
|
||||
/**
|
||||
* cutlass_scaled_mm.cu
|
||||
@@ -653,4 +990,80 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
|
||||
"dynamic_per_token_scaled_fp8_quant function",
|
||||
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
|
||||
m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");
|
||||
|
||||
m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
|
||||
|
||||
m.def("fused_rotary_position_encoding", &FusedRotaryPositionEncoding, "fused_rotary_position_encoding function");
|
||||
|
||||
m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");
|
||||
|
||||
m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
|
||||
py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
|
||||
py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
|
||||
py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function");
|
||||
m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func,
|
||||
py::arg("input"), py::arg("scale"), py::arg("topk_ids"),
|
||||
py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function");
|
||||
m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
|
||||
py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
|
||||
#endif
|
||||
|
||||
m.def("init_custom_all_reduce", &init_custom_all_reduce, "init all reduce class function");
|
||||
|
||||
m.def("all_reduce", &all_reduce, "all reduce function");
|
||||
|
||||
m.def("dispose", &dispose, "del function for python");
|
||||
|
||||
m.def("meta_size", &meta_size, "meta_size function for Signal struct");
|
||||
|
||||
m.def("register_buffer", ®ister_buffer, "register ipc buffer");
|
||||
|
||||
m.def("register_graph_buffers", ®ister_graph_buffers, "register_graph_buffers");
|
||||
|
||||
m.def("allocate_shared_buffer_and_handle", &allocate_shared_buffer_and_handle, "allocate_shared_buffer_and_handle");
|
||||
|
||||
m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");
|
||||
|
||||
m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
|
||||
|
||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");
|
||||
|
||||
// speculative decoding Kernel
|
||||
m.def("speculate_get_padding_offset", &SpeculateGetPaddingOffset, "speculate_get_padding_offset function");
|
||||
|
||||
m.def("speculate_get_seq_lens_output", &SpeculateGetSeqLensOutput, "speculate_get_seq_lens_output function");
|
||||
|
||||
m.def("speculate_get_output_padding_offset",&SpeculateGetOutputPaddingOffset, "speculate_get_output_padding_offset function");
|
||||
|
||||
m.def("speculate_get_token_penalty_multi_scores",&SpecTokenPenaltyMultiScores, "speculate_get_token_penalty_multi_scores function");
|
||||
|
||||
m.def("speculate_set_stop_value_multi_seqs",&SpecGetStopFlagsMultiSeqs, "speculate_set_stop_value_multi_seqs function");
|
||||
|
||||
m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");
|
||||
|
||||
m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
|
||||
|
||||
m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
|
||||
|
||||
m.def("speculate_save_output", &SpeculateSaveWithOutputMsgStatic, "speculate_save_output function");
|
||||
|
||||
m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
|
||||
|
||||
m.def("ngram_match", &NgramMatch, "ngram_match function");
|
||||
|
||||
m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");
|
||||
|
||||
m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");
|
||||
|
||||
m.def("draft_model_update",&DraftModelUpdate, "draft_model_update function");
|
||||
|
||||
m.def("eagle_get_hidden_states",&EagleGetHiddenStates, "eagle_get_hidden_states function");
|
||||
|
||||
m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");
|
||||
|
||||
m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");
|
||||
}
|
||||
|
165
custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu
Normal file
165
custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu
Normal file
@@ -0,0 +1,165 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
|
||||
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "all_reduce.cuh"
|
||||
|
||||
// Fake pointer type, must match fptr_t type in ops.h.
|
||||
// We use this type alias to indicate when pointers are passed in as int64_t.
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
paddle::Tensor& rank_data, int64_t rank,
|
||||
bool full_nvlink) {
|
||||
int world_size = fake_ipc_ptrs.size();
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size % 2 != 0)
|
||||
throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (rank < 0 || rank >= world_size)
|
||||
throw std::invalid_argument("invalid rank passed in");
|
||||
|
||||
paddle::Signal* ipc_ptrs[8];
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
ipc_ptrs[i] = reinterpret_cast<paddle::Signal*>(fake_ipc_ptrs[i]);
|
||||
}
|
||||
return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs, rank_data.data(),
|
||||
rank_data.numel(), rank, world_size,
|
||||
full_nvlink);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs an out-of-place allreduce and stores result in out.
|
||||
*
|
||||
* If _reg_buffer is null, assumes inp.data() is already IPC-registered.
|
||||
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
|
||||
* copied into _reg_buffer.
|
||||
*/
|
||||
void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
|
||||
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
|
||||
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
|
||||
auto stream = inp.stream();
|
||||
|
||||
auto input_size = inp.numel() * 2;
|
||||
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
|
||||
if (reg_buffer) {
|
||||
cudaMemcpyAsync(reg_buffer, inp.data(), input_size,
|
||||
cudaMemcpyDeviceToDevice, stream);
|
||||
} else {
|
||||
reg_buffer = inp.data();
|
||||
}
|
||||
switch (out.dtype()) {
|
||||
case phi::DataType::FLOAT32: {
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
|
||||
reinterpret_cast<float*>(out.data()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case phi::DataType::FLOAT16: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
|
||||
reinterpret_cast<half*>(out.data()), out.numel());
|
||||
break;
|
||||
}
|
||||
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
|
||||
case phi::DataType::BFLOAT16: {
|
||||
fa->allreduce<nv_bfloat16>(
|
||||
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
|
||||
reinterpret_cast<nv_bfloat16*>(out.data()), out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports float32, float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
delete reinterpret_cast<paddle::CustomAllreduce*>(_fa);
|
||||
}
|
||||
|
||||
int64_t meta_size() { return sizeof(paddle::Signal); }
|
||||
|
||||
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
|
||||
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
|
||||
void* ipc_ptrs[8];
|
||||
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
|
||||
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
|
||||
}
|
||||
fa->register_buffer(ipc_ptrs);
|
||||
}
|
||||
|
||||
// Use vector<int64_t> to represent byte data for python binding compatibility.
|
||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
|
||||
get_graph_buffer_ipc_meta(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
|
||||
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
|
||||
std::vector<int64_t> bytes(handle.begin(), handle.end());
|
||||
return std::make_tuple(bytes, offsets);
|
||||
}
|
||||
|
||||
// Use vector<int64_t> to represent byte data for python binding compatibility.
|
||||
void register_graph_buffers(fptr_t _fa,
|
||||
const std::vector<std::vector<int64_t>>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
|
||||
std::vector<std::string> bytes;
|
||||
bytes.reserve(handles.size());
|
||||
for (int i = 0; i < handles.size(); i++) {
|
||||
bytes.emplace_back(handles[i].begin(), handles[i].end());
|
||||
}
|
||||
bytes.reserve(handles.size());
|
||||
fa->register_graph_buffers(bytes, offsets);
|
||||
}
|
||||
|
||||
std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size) {
|
||||
|
||||
auto device_index = phi::backends::gpu::GetCurrentDeviceId();
|
||||
void* buffer;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
auto stream = paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream();
|
||||
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
// Allocate buffer
|
||||
CUDACHECK(cudaMalloc((void**)&buffer, size));
|
||||
CUDACHECK(cudaMemsetAsync(buffer, 0, size, stream));
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
// Create IPC memhandle for the allocated buffer.
|
||||
// Will use it in open_mem_handle.
|
||||
auto handle =
|
||||
paddle::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, paddle::DataType::UINT8, paddle::GPUPlace(device_index));
|
||||
CUDACHECK(
|
||||
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer));
|
||||
|
||||
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
|
||||
}
|
||||
|
||||
|
||||
fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
|
||||
void* ipc_ptr;
|
||||
CUDACHECK(cudaIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data()),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
return reinterpret_cast<fptr_t>(ipc_ptr);
|
||||
}
|
||||
|
||||
void free_shared_buffer(fptr_t buffer) {
|
||||
CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
|
||||
}
|
526
custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh
Normal file
526
custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh
Normal file
@@ -0,0 +1,526 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <array>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
cudaError_t e = cmd; \
|
||||
if (e != cudaSuccess) { \
|
||||
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||
cudaGetErrorString(e)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace paddle {
|
||||
|
||||
constexpr int kMaxBlocks = 36;
|
||||
// Counter may overflow, but it's fine since unsigned int overflow is
|
||||
// well-defined behavior.
|
||||
using FlagType = uint32_t;
|
||||
struct Signal {
|
||||
alignas(128) FlagType self_counter[kMaxBlocks][8];
|
||||
// Two sets of peer counters are needed for two syncs. The reason is that
|
||||
// it's possible for peer GPU block to arrive at the second sync point while
|
||||
// the current GPU block haven't passed the first sync point. Thus, peer GPU
|
||||
// may write counter+1 while current GPU is busy waiting for counter. We use
|
||||
// alternating counter array to avoid this possibility.
|
||||
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankData {
|
||||
const void* __restrict__ ptrs[8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankSignals {
|
||||
Signal* signals[8];
|
||||
};
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
struct __align__(alignof(T) * sz) array_t {
|
||||
T data[sz];
|
||||
using type = T;
|
||||
static constexpr int size = sz;
|
||||
};
|
||||
|
||||
// use packed type to maximize memory efficiency
|
||||
// goal: generate ld.128 and st.128 instructions
|
||||
template <typename T>
|
||||
struct packed_t {
|
||||
// the (P)acked type for load/store
|
||||
using P = array_t<T, 16 / sizeof(T)>;
|
||||
// the (A)ccumulator type for reduction
|
||||
using A = array_t<float, 16 / sizeof(T)>;
|
||||
};
|
||||
|
||||
#define DINLINE __device__ __forceinline__
|
||||
|
||||
// scalar cast functions
|
||||
DINLINE float upcast_s(half val) { return __half2float(val); }
|
||||
|
||||
template <typename T>
|
||||
DINLINE T downcast_s(float val);
|
||||
template <>
|
||||
DINLINE half downcast_s(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
// scalar add functions
|
||||
// for some reason when compiling with Paddle, the + operator for half and
|
||||
// bfloat is disabled so we call the intrinsics directly
|
||||
DINLINE half& assign_add(half& a, half b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
DINLINE float& assign_add(float& a, float b) { return a += b; }
|
||||
|
||||
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
|
||||
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
|
||||
template <>
|
||||
DINLINE nv_bfloat16 downcast_s(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
assign_add(a.data[i], b.data[i]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
array_t<float, N> out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
out.data[i] = upcast_s(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
DINLINE O downcast(array_t<float, O::size> val) {
|
||||
if constexpr (std::is_same<typename O::type, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
O out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < O::size; i++) {
|
||||
out.data[i] = downcast_s<typename O::type>(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
|
||||
"l"(flag_addr));
|
||||
#else
|
||||
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
|
||||
"l"(flag_addr));
|
||||
#endif
|
||||
}
|
||||
|
||||
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
|
||||
FlagType flag;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
|
||||
: "=r"(flag)
|
||||
: "l"(flag_addr));
|
||||
#else
|
||||
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
|
||||
: "=r"(flag)
|
||||
: "l"(flag_addr));
|
||||
#endif
|
||||
return flag;
|
||||
}
|
||||
|
||||
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
|
||||
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
|
||||
}
|
||||
|
||||
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
|
||||
FlagType flag;
|
||||
asm volatile("ld.volatile.global.u32 %0, [%1];"
|
||||
: "=r"(flag)
|
||||
: "l"(flag_addr));
|
||||
return flag;
|
||||
}
|
||||
|
||||
// is_start: whether this is the very first synchronization barrier.
|
||||
// need_fence: whether a memory fence is needed. If true, a release-acquire
|
||||
// semantic is used to enforce memory access order before and after this
|
||||
// barrier.
|
||||
template <int ngpus, bool is_start, bool need_fence = false>
|
||||
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
|
||||
int rank) {
|
||||
if constexpr (!is_start) __syncthreads();
|
||||
static_assert(
|
||||
!(is_start && need_fence)); // Start barrier shouldn't need fence.
|
||||
if (threadIdx.x < ngpus) {
|
||||
// Increment the counter. Technically we only need one counter, but we use
|
||||
// multiple per block to eliminate the need to share the counter via smem.
|
||||
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
|
||||
// Write the expected counter value to peer and wait for correct value from
|
||||
// peer.
|
||||
auto peer_counter_ptr =
|
||||
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
|
||||
auto self_counter_ptr =
|
||||
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
|
||||
if constexpr (need_fence) {
|
||||
st_flag_release(peer_counter_ptr, val);
|
||||
while (ld_flag_acquire(self_counter_ptr) != val);
|
||||
} else {
|
||||
st_flag_volatile(peer_counter_ptr, val);
|
||||
while (ld_flag_volatile(self_counter_ptr) != val);
|
||||
}
|
||||
}
|
||||
if constexpr (is_start || need_fence) __syncthreads();
|
||||
}
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
#pragma unroll
|
||||
for (int i = 1; i < ngpus; i++) {
|
||||
packed_assign_add(tmp, upcast(ptrs[i][idx]));
|
||||
}
|
||||
return downcast<P>(tmp);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
|
||||
T* __restrict__ result, int rank, int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
||||
}
|
||||
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
DINLINE P* get_tmp_buf(Signal* sg) {
|
||||
return (P*)(((Signal*)sg) + 1);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
|
||||
T* __restrict__ result, int rank, int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
int part = size / ngpus;
|
||||
int start = rank * part;
|
||||
int end = rank == ngpus - 1 ? size : start + part;
|
||||
int largest_part = part + size % ngpus;
|
||||
const P* ptrs[ngpus];
|
||||
P* tmps[ngpus];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int target = (rank + i) % ngpus;
|
||||
ptrs[i] = (const P*)_dp->ptrs[target];
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
|
||||
|
||||
// stage 2: allgather. Note: it's important to match the tid between
|
||||
// the two stages, because visibility across devices is only guaranteed
|
||||
// between threads that have the same tid. If thread i computes the sum of
|
||||
// start + i in the first stage, then thread i also gathers start + i from all
|
||||
// ranks.
|
||||
for (int idx = tid; idx < largest_part; idx += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int gather_from_rank = ((rank + i) % ngpus);
|
||||
if (gather_from_rank == ngpus - 1 || idx < part) {
|
||||
int dst_idx = gather_from_rank * part + idx;
|
||||
((P*)result)[dst_idx] = tmps[i][idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
|
||||
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
|
||||
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
|
||||
|
||||
class CustomAllreduce {
|
||||
public:
|
||||
int rank_;
|
||||
int world_size_;
|
||||
bool full_nvlink_;
|
||||
|
||||
RankSignals sg_;
|
||||
// Stores an map from a pointer to its peer pointters from all ranks.
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
|
||||
// For cuda graph to work, all kernel arguments must be fixed during graph
|
||||
// capture time. However, the peer pointers are not known during graph capture
|
||||
// time. Therefore, during capture, we increment the rank data pointer and use
|
||||
// that as the argument to the kernel. The kernel arguments are stored in
|
||||
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
|
||||
// memory pointed to by the pointers in graph_unreg_buffers_ when
|
||||
// the IPC handles are exchanged between ranks.
|
||||
//
|
||||
// The overall process looks like this:
|
||||
// 1. Graph capture.
|
||||
// 2. Each rank obtains the IPC handles for each addresses used during cuda
|
||||
// graph capture using get_graph_buffer_ipc_meta.
|
||||
// 3. (In Python) all gather the IPC handles.
|
||||
// 4. Obtain the peer pointers by opening the IPC handles, and store them in
|
||||
// the rank data array at corresponding positions.
|
||||
RankData *d_rank_data_base_, *d_rank_data_end_;
|
||||
std::vector<void*> graph_unreg_buffers_;
|
||||
// a map from IPC handles to opened IPC pointers
|
||||
std::map<IPC_KEY, char*> ipc_handles_;
|
||||
|
||||
/**
|
||||
* Signals are an array of ipc-enabled buffers from all ranks.
|
||||
* For each of the buffer, the layout is as follows:
|
||||
* | -- sizeof(Signal) -- | ------ a few MB ----- |
|
||||
* The first section is for allreduce synchronization, and the second section
|
||||
* is for storing the intermediate results required by some allreduce algos.
|
||||
*
|
||||
* Note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor.
|
||||
*/
|
||||
CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
|
||||
int rank, int world_size, bool full_nvlink = true)
|
||||
: rank_(rank),
|
||||
world_size_(world_size),
|
||||
full_nvlink_(full_nvlink),
|
||||
self_sg_(signals[rank]),
|
||||
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
sg_.signals[i] = signals[i];
|
||||
}
|
||||
}
|
||||
|
||||
char* open_ipc_handle(const void* ipc_handle) {
|
||||
auto [it, new_handle] =
|
||||
ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
|
||||
if (new_handle) {
|
||||
char* ipc_ptr;
|
||||
CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
|
||||
*((const cudaIpcMemHandle_t*)ipc_handle),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
it->second = ipc_ptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
auto handle_sz = sizeof(cudaIpcMemHandle_t);
|
||||
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = graph_unreg_buffers_[i];
|
||||
void* base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (cuPointerGetAttribute(&base_ptr,
|
||||
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
(CUdeviceptr)ptr) != CUDA_SUCCESS)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CUDACHECK(cudaIpcGetMemHandle(
|
||||
(cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
|
||||
}
|
||||
return std::make_pair(handles, offsets);
|
||||
}
|
||||
|
||||
void check_rank_data_capacity(size_t num = 1) {
|
||||
if (d_rank_data_base_ + num > d_rank_data_end_)
|
||||
throw std::runtime_error(
|
||||
"Rank data buffer is overflowed by " +
|
||||
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
|
||||
}
|
||||
|
||||
/**
|
||||
* Register already-shared IPC pointers.
|
||||
*/
|
||||
void register_buffer(void** ptrs) {
|
||||
check_rank_data_capacity();
|
||||
RankData data;
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
data.ptrs[i] = ptrs[i];
|
||||
}
|
||||
auto d_data = d_rank_data_base_++;
|
||||
CUDACHECK(
|
||||
cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
|
||||
buffers_[ptrs[rank_]] = d_data;
|
||||
}
|
||||
|
||||
// Note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void register_graph_buffers(
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
check_rank_data_capacity(num_buffers);
|
||||
std::vector<RankData> rank_data(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto self_ptr = graph_unreg_buffers_[i];
|
||||
auto& rd = rank_data[i];
|
||||
for (int j = 0; j < world_size_; j++) {
|
||||
if (j != rank_) {
|
||||
char* handle =
|
||||
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
|
||||
handle += offsets[j][i];
|
||||
rd.ptrs[j] = handle;
|
||||
} else {
|
||||
rd.ptrs[j] = self_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
|
||||
sizeof(RankData) * num_buffers,
|
||||
cudaMemcpyHostToDevice));
|
||||
d_rank_data_base_ += num_buffers;
|
||||
graph_unreg_buffers_.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs allreduce, assuming input has already been registered.
|
||||
*
|
||||
* Block and grid default configs are results after careful grid search. Using
|
||||
* 36 blocks give the best or close to the best runtime on the devices I
|
||||
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
|
||||
* take a small amount of SMs. Not quite sure the underlying reason, but my
|
||||
* guess is that too many SMs will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T* input, T* output, int size,
|
||||
int threads = 512, int block_limit = 36) {
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
"custom allreduce currently requires input length to be multiple "
|
||||
"of " +
|
||||
std::to_string(d));
|
||||
if (block_limit > kMaxBlocks)
|
||||
throw std::runtime_error("max supported block limit is " +
|
||||
std::to_string(kMaxBlocks) + ". Got " +
|
||||
std::to_string(block_limit));
|
||||
|
||||
RankData* ptrs;
|
||||
cudaStreamCaptureStatus status;
|
||||
CUDACHECK(cudaStreamIsCapturing(stream, &status));
|
||||
if (status == cudaStreamCaptureStatusActive) {
|
||||
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
|
||||
graph_unreg_buffers_.push_back(input);
|
||||
} else {
|
||||
auto it = buffers_.find(input);
|
||||
if (it == buffers_.end())
|
||||
throw std::runtime_error(
|
||||
"buffer address " +
|
||||
std::to_string(reinterpret_cast<uint64_t>(input)) +
|
||||
" is not registered!");
|
||||
ptrs = it->second;
|
||||
}
|
||||
|
||||
size /= d;
|
||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||
int blocks = std::min(block_limit, (size + threads - 1) / threads);
|
||||
#define KL(ngpus, name) \
|
||||
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
|
||||
rank_, size);
|
||||
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (full_nvlink_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
||||
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (world_size_) {
|
||||
REDUCE_CASE(2)
|
||||
REDUCE_CASE(4)
|
||||
REDUCE_CASE(6)
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
#undef REDUCE_CASE
|
||||
#undef KL
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CUDACHECK(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace paddle
|
@@ -76,6 +76,34 @@ enum class SplitKStyle
|
||||
// SPLIT_K_PARALLEL // Not supported yet
|
||||
};
|
||||
|
||||
// New enum for SM100 (Blackwell) Tile Configs
|
||||
// Placeholder values - actual optimal values need research
|
||||
enum class CutlassTileConfigSM100
|
||||
{
|
||||
// Signals that we should run heuristics do choose a config
|
||||
Undefined,
|
||||
|
||||
// Signals that we should run heuristics do choose a config
|
||||
ChooseWithHeuristic,
|
||||
|
||||
// Actual SM100 tile configs based on user input (K-tile is 128B)
|
||||
CtaShape64x64x128B,
|
||||
CtaShape64x128x128B,
|
||||
CtaShape64x256x128B,
|
||||
CtaShape128x64x128B,
|
||||
CtaShape128x128x128B,
|
||||
CtaShape128x256x128B,
|
||||
CtaShape256x64x128B,
|
||||
CtaShape256x128x128B,
|
||||
CtaShape256x256x128B
|
||||
// Note: The user-provided list for get_candidate_tiles_sm100 also includes
|
||||
// CtaShape128x64x128B and CtaShape256x64x128B for specific FP4 grouped gemm cases.
|
||||
// These are already covered by the list above if general suffices.
|
||||
// If they need distinct enum values, they should be added.
|
||||
// For now, keeping the enum concise with unique shapes mentioned for general use.
|
||||
};
|
||||
|
||||
|
||||
enum class CutlassTileConfigSM90
|
||||
{
|
||||
// Signals that we should run heuristics do choose a config
|
||||
@@ -132,9 +160,11 @@ struct CutlassGemmConfig
|
||||
WEIGHT_ONLY = 1u << 0,
|
||||
SIMT_ONLY = 1u << 1,
|
||||
INT8_ONLY = 1u << 2,
|
||||
HOPPER = 1u << 3,
|
||||
HOPPER = 1u << 3, // SM90
|
||||
GROUPED_GEMM = 1u << 4,
|
||||
FP8_ONLY = 1u << 5,
|
||||
BLACKWELL = 1u << 6, // SM100
|
||||
FP4_ONLY = 1u << 7, // For Blackwell FP4/MXFP4 paths
|
||||
};
|
||||
|
||||
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
|
||||
@@ -149,7 +179,17 @@ struct CutlassGemmConfig
|
||||
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
|
||||
bool is_sm90 = false;
|
||||
|
||||
CutlassGemmConfig() {}
|
||||
// config options for sm100 (Blackwell)
|
||||
// Assuming SM100 might use similar schedule/cluster types as SM90 for now.
|
||||
// These might need to become SM100-specific if Blackwell introduces new concepts.
|
||||
CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic;
|
||||
// MainloopScheduleType mainloop_schedule_sm100 = MainloopScheduleType::AUTO; // Example if SM100 has different types
|
||||
// EpilogueScheduleType epilogue_schedule_sm100 = EpilogueScheduleType::AUTO; // Example
|
||||
// ClusterShape cluster_shape_sm100 = ClusterShape::ClusterShape_1x1x1; // Example
|
||||
bool is_sm100 = false;
|
||||
|
||||
|
||||
CutlassGemmConfig() : is_sm90(false), is_sm100(false) {}
|
||||
|
||||
CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
|
||||
: tile_config(tile_config)
|
||||
@@ -157,37 +197,64 @@ struct CutlassGemmConfig
|
||||
, split_k_factor(split_k_factor)
|
||||
, stages(stages)
|
||||
, is_sm90(false)
|
||||
, is_sm100(false)
|
||||
{
|
||||
}
|
||||
|
||||
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
|
||||
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
|
||||
: tile_config_sm90(tile_config_sm90)
|
||||
, mainloop_schedule(mainloop_schedule)
|
||||
, epilogue_schedule(epilogue_schedule)
|
||||
, cluster_shape(cluster_shape)
|
||||
// Constructor for SM90
|
||||
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90_in, MainloopScheduleType mainloop_schedule_in,
|
||||
EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
|
||||
: tile_config_sm90(tile_config_sm90_in)
|
||||
, mainloop_schedule(mainloop_schedule_in)
|
||||
, epilogue_schedule(epilogue_schedule_in)
|
||||
, cluster_shape(cluster_shape_in)
|
||||
, is_sm90(true)
|
||||
, is_sm100(false)
|
||||
{
|
||||
}
|
||||
|
||||
// Constructor for SM100 (Blackwell)
|
||||
// Using existing MainloopScheduleType, EpilogueScheduleType, ClusterShape for now.
|
||||
// These might need to be new SM100-specific types if Blackwell's TMA differs significantly.
|
||||
CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100_in, MainloopScheduleType mainloop_schedule_in,
|
||||
EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
|
||||
: tile_config_sm100(tile_config_sm100_in)
|
||||
, mainloop_schedule(mainloop_schedule_in) // Potentially use mainloop_schedule_sm100 if types diverge
|
||||
, epilogue_schedule(epilogue_schedule_in) // Potentially use epilogue_schedule_sm100
|
||||
, cluster_shape(cluster_shape_in) // Potentially use cluster_shape_sm100
|
||||
, is_sm90(false) // Explicitly false
|
||||
, is_sm100(true)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
std::string toString() const
|
||||
{
|
||||
std::stringstream tactic;
|
||||
tactic << "Cutlass GEMM Tactic";
|
||||
if (tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
|
||||
if (is_sm100 && tile_config_sm100 != cutlass_extensions::CutlassTileConfigSM100::ChooseWithHeuristic)
|
||||
{
|
||||
assert(is_sm90 && "Invalid cutlass GEMM config");
|
||||
tactic << "\n\tstyle=TMA"
|
||||
<< "\n\ttile shape ID: " << (int) tile_config_sm90
|
||||
assert(is_sm100 && !is_sm90 && "Invalid cutlass GEMM config: SM100");
|
||||
tactic << "\n\tstyle=TMA_SM100" // Indicate SM100 specific TMA if applicable
|
||||
<< "\n\ttile shape ID: " << (int) tile_config_sm100
|
||||
<< "\n\tcluster shape ID: " << (int) cluster_shape
|
||||
<< "\n\tmainloop sched: " << (int) mainloop_schedule
|
||||
<< "\n\tmainloop sched: " << (int) mainloop_schedule
|
||||
<< "\n\tepi sched: " << (int) epilogue_schedule;
|
||||
}
|
||||
else if (is_sm90 && tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
|
||||
{
|
||||
assert(is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: SM90");
|
||||
tactic << "\n\tstyle=TMA_SM90"
|
||||
<< "\n\ttile shape ID: " << (int) tile_config_sm90
|
||||
<< "\n\tcluster shape ID: " << (int) cluster_shape
|
||||
<< "\n\tmainloop sched: " << (int) mainloop_schedule
|
||||
<< "\n\tepi sched: " << (int) epilogue_schedule;
|
||||
}
|
||||
else if (tile_config != cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
|
||||
{
|
||||
assert(!is_sm90 && "Invalid cutlass GEMM config");
|
||||
assert(!is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: Compatible");
|
||||
tactic << "\n\tstyle=compatible"
|
||||
<< "\n\ttile shape ID: " << (int) tile_config
|
||||
<< "\n\ttile shape ID: " << (int) tile_config
|
||||
<< "\n\tstages: " << (int) stages
|
||||
<< "\n\tsplit_k_style: " << (int) split_k_style
|
||||
<< "\n\tsplit k: " << (int) split_k_factor;
|
||||
@@ -204,9 +271,24 @@ struct CutlassGemmConfig
|
||||
std::istringstream stream(str);
|
||||
std::string line;
|
||||
|
||||
is_sm90 = false; // Reset flags
|
||||
is_sm100 = false;
|
||||
|
||||
while (std::getline(stream, line)) {
|
||||
if (line.find("style=TMA") != std::string::npos) {
|
||||
if (line.find("style=TMA_SM100") != std::string::npos) {
|
||||
is_sm100 = true;
|
||||
is_sm90 = false;
|
||||
std::getline(stream, line);
|
||||
tile_config_sm100 = static_cast<cutlass_extensions::CutlassTileConfigSM100>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
cluster_shape = static_cast<cutlass_extensions::ClusterShape>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
mainloop_schedule = static_cast<cutlass_extensions::MainloopScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
} else if (line.find("style=TMA_SM90") != std::string::npos) { // Check for SM90 specific first
|
||||
is_sm90 = true;
|
||||
is_sm100 = false;
|
||||
std::getline(stream, line);
|
||||
tile_config_sm90 = static_cast<cutlass_extensions::CutlassTileConfigSM90>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
@@ -217,6 +299,7 @@ struct CutlassGemmConfig
|
||||
epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
} else if (line.find("style=compatible") != std::string::npos) {
|
||||
is_sm90 = false;
|
||||
is_sm100 = false;
|
||||
std::getline(stream, line);
|
||||
tile_config = static_cast<cutlass_extensions::CutlassTileConfig>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
@@ -233,7 +316,14 @@ struct CutlassGemmConfig
|
||||
inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
|
||||
{
|
||||
// clang-format off
|
||||
if (config.is_sm90)
|
||||
if (config.is_sm100)
|
||||
{
|
||||
out << "tile_config_sm100_enum: " << int(config.tile_config_sm100)
|
||||
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule) // Assuming same schedule types for now
|
||||
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule) // Assuming same schedule types for now
|
||||
<< ", cluster_shape_enum: " << int(config.cluster_shape); // Assuming same cluster types for now
|
||||
}
|
||||
else if (config.is_sm90)
|
||||
{
|
||||
out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
|
||||
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
|
||||
|
@@ -245,6 +245,88 @@ bool supports_mcast_along_n(CutlassTileConfigSM90 const tile)
|
||||
#endif
|
||||
}
|
||||
|
||||
// SM100 (Blackwell) candidate tile configurations
|
||||
std::vector<CutlassTileConfigSM100> get_candidate_tiles_sm100(
|
||||
int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config)
|
||||
{
|
||||
#ifdef FAST_BUILD
|
||||
return {CutlassTileConfigSM100::CtaShape128x128x128B};
|
||||
#else
|
||||
/* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM “cluster” kernels) */
|
||||
if (config & CutlassGemmConfig::GROUPED_GEMM)
|
||||
{
|
||||
if (config & CutlassGemmConfig::FP4_ONLY) // nvfp4 / mx_fp4
|
||||
{
|
||||
return {
|
||||
/* 1 SM (M=128) */
|
||||
CutlassTileConfigSM100::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x256x128B,
|
||||
/* 2 SM (M=256) */
|
||||
CutlassTileConfigSM100::CtaShape256x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x256x128B,
|
||||
/* slim tiles for very tall matrices */
|
||||
CutlassTileConfigSM100::CtaShape128x64x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x64x128B};
|
||||
}
|
||||
|
||||
/* Fp8 / Fp16 grouped-GEMM */
|
||||
return {
|
||||
CutlassTileConfigSM100::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x256x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x256x128B};
|
||||
}
|
||||
|
||||
/* Non-grouped path (plain GEMM or weight-only) */
|
||||
return {
|
||||
/* 1 SM tiles */
|
||||
CutlassTileConfigSM100::CtaShape64x64x128B,
|
||||
CutlassTileConfigSM100::CtaShape64x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape64x256x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x64x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x256x128B,
|
||||
/* 2 SM tiles */
|
||||
CutlassTileConfigSM100::CtaShape256x64x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x256x128B};
|
||||
#endif
|
||||
}
|
||||
|
||||
// M-multicast support for SM100.
|
||||
bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile)
|
||||
{
|
||||
#ifdef FAST_BUILD
|
||||
return false;
|
||||
#else
|
||||
std::set<CutlassTileConfigSM100> m_tiles{
|
||||
CutlassTileConfigSM100::CtaShape128x64x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x256x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x64x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x256x128B};
|
||||
return m_tiles.count(tile) == 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
// N-multicast support for SM100.
|
||||
bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile)
|
||||
{
|
||||
#ifdef FAST_BUILD
|
||||
return false;
|
||||
#else
|
||||
std::set<CutlassTileConfigSM100> n_tiles{
|
||||
CutlassTileConfigSM100::CtaShape64x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape64x256x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM100::CtaShape128x256x128B,
|
||||
CutlassTileConfigSM100::CtaShape256x128x128B};
|
||||
return n_tiles.count(tile) == 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
std::vector<CutlassGemmConfig> get_candidate_configs(
|
||||
int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
|
||||
{
|
||||
@@ -284,9 +366,50 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
|
||||
}
|
||||
return candidate_configs;
|
||||
}
|
||||
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
|
||||
else if (sm == 100 && (config_type_param & CutlassGemmConfig::BLACKWELL)) // Assuming SM100 for Blackwell
|
||||
{
|
||||
std::vector<CutlassTileConfigSM100> tiles = get_candidate_tiles_sm100(sm, config_type_param);
|
||||
std::vector<CutlassGemmConfig> candidate_configs;
|
||||
|
||||
std::vector<CutlassGemmConfig> candidate_configs;
|
||||
for (auto const& tile_config_sm100 : tiles)
|
||||
{
|
||||
// SM100 uses MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO similar to SM90.
|
||||
// Cluster shapes are also handled similarly.
|
||||
CutlassGemmConfig config(
|
||||
tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
|
||||
candidate_configs.push_back(config);
|
||||
|
||||
bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100);
|
||||
bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100);
|
||||
|
||||
if (has_m_mcast)
|
||||
{
|
||||
CutlassGemmConfig mcast_m_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_2x1x1);
|
||||
candidate_configs.push_back(mcast_m_config);
|
||||
}
|
||||
|
||||
if (has_n_mcast)
|
||||
{
|
||||
CutlassGemmConfig mcast_n_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_1x2x1);
|
||||
candidate_configs.push_back(mcast_n_config);
|
||||
}
|
||||
|
||||
if (has_m_mcast && has_n_mcast)
|
||||
{
|
||||
CutlassGemmConfig mcast_mn_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_2x2x1);
|
||||
candidate_configs.push_back(mcast_mn_config);
|
||||
}
|
||||
}
|
||||
return candidate_configs;
|
||||
}
|
||||
|
||||
// Fallback to older architecture configurations
|
||||
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
|
||||
std::vector<CutlassGemmConfig> candidate_configs; //Already declared above for SM90 path, ensure scope is correct or redeclare if necessary.
|
||||
// It's fine here as it's within an else if / else block.
|
||||
bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
|
||||
int const min_stages = int8_configs_only ? 3 : 2;
|
||||
int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
|
||||
|
@@ -12,21 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
ffn1_n=7168
|
||||
ffn1_k=8192
|
||||
up_gate_proj_n=7168
|
||||
up_gate_proj_k=8192
|
||||
|
||||
ffn2_n=8192
|
||||
ffn2_k=3584
|
||||
rm -rf ffn1_7168_8192.log
|
||||
rm -rf ffn2_8192_3584.log
|
||||
down_proj_n=8192
|
||||
down_proj_k=3584
|
||||
rm -rf up_gate_proj_7168_8192.log
|
||||
rm -rf down_proj_8192_3584.log
|
||||
num_experts=8
|
||||
|
||||
for tokens_per_expert in 12
|
||||
|
||||
do
|
||||
wait
|
||||
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 1 0 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
|
||||
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 1 0 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
|
||||
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
|
||||
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
|
||||
done
|
||||
wait
|
||||
echo "#### finish ####"
|
||||
|
64
custom_ops/gpu_ops/env.h
Normal file
64
custom_ops/gpu_ops/env.h
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
inline uint32_t get_decoder_block_shape_q() {
|
||||
static const char* decoder_block_shape_q_env = std::getenv("FLAGS_dec_block_shape_q");
|
||||
static const uint32_t decoder_block_shape_q =
|
||||
decoder_block_shape_q_env == nullptr ? 16 : std::stoi(std::string(decoder_block_shape_q_env));
|
||||
return decoder_block_shape_q;
|
||||
}
|
||||
|
||||
inline uint32_t get_encoder_block_shape_q() {
|
||||
static const char* encoder_block_shape_q_env = std::getenv("FLAGS_enc_block_shape_q");
|
||||
static const uint32_t encoder_block_shape_q =
|
||||
encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env));
|
||||
return encoder_block_shape_q;
|
||||
}
|
||||
|
||||
inline uint32_t get_max_partition_size(int bsz) {
|
||||
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
|
||||
static const uint32_t max_partition_size =
|
||||
max_partition_size_env == nullptr ? 32768 : std::stoul(std::string(max_partition_size_env));
|
||||
return max_partition_size;
|
||||
}
|
||||
|
||||
inline uint32_t get_cascade_attention_deal_each_time() {
|
||||
static const char* cascade_attention_deal_each_time_env = std::getenv("FLAGS_cascade_attention_deal_each_time");
|
||||
static const uint32_t cascade_attention_deal_each_time =
|
||||
cascade_attention_deal_each_time_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_deal_each_time_env));
|
||||
return (cascade_attention_deal_each_time != 0 ? cascade_attention_deal_each_time : 32);
|
||||
}
|
||||
|
||||
inline uint32_t get_cascade_attention_num_stages() {
|
||||
static const char* cascade_attention_num_stages_env = std::getenv("FLAGS_cascade_attention_num_stages");
|
||||
static const uint32_t cascade_attention_num_stages =
|
||||
cascade_attention_num_stages_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_stages_env));
|
||||
return cascade_attention_num_stages != 0 ? cascade_attention_num_stages : 2;
|
||||
}
|
||||
|
||||
inline uint32_t get_cascade_attention_num_threads() {
|
||||
static const char* cascade_attention_num_threads_env = std::getenv("FLAGS_cascade_attention_num_threads");
|
||||
static const uint32_t cascade_attention_num_threads =
|
||||
cascade_attention_num_threads_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_threads_env));
|
||||
return cascade_attention_num_threads != 0 ? cascade_attention_num_threads : 128;
|
||||
}
|
||||
|
||||
inline bool get_mla_use_tensorcore() {
|
||||
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
|
||||
static const uint32_t mla_use_tensorcore =
|
||||
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
|
||||
return mla_use_tensorcore != 0 ? true : false;
|
||||
}
|
@@ -19,7 +19,7 @@
|
||||
#include "fp8_fp8_half_cuda_core_gemm.h"
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
|
||||
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& y,
|
||||
const paddle::optional<paddle::Tensor>& bias,
|
||||
@@ -142,7 +142,7 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
|
||||
{
|
||||
if(output_dtype == "bfloat16") {
|
||||
cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(params);
|
||||
|
||||
|
||||
} else {
|
||||
cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
|
||||
}
|
||||
@@ -174,7 +174,21 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
|
||||
fuse_gemm_config};
|
||||
fp8_fp8_gemm_scale_bias_act(params);
|
||||
}
|
||||
return {out};
|
||||
return out;
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::Tensor& y,
|
||||
const paddle::optional<paddle::Tensor>& bias,
|
||||
bool trans_x,
|
||||
bool trans_y,
|
||||
float scale, // only support per-tensor quantization
|
||||
std::string output_dtype,
|
||||
std::string activation_type) {
|
||||
return {cutlass_fp8_fp8_half_gemm_func(
|
||||
x, y, bias, trans_x, trans_y, scale,
|
||||
output_dtype, activation_type)};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfGemmFusedInferShape(
|
||||
|
198
custom_ops/gpu_ops/fused_hadamard_quant_fp8.cu
Normal file
198
custom_ops/gpu_ops/fused_hadamard_quant_fp8.cu
Normal file
@@ -0,0 +1,198 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <algorithm>
|
||||
#include "helper.h"
|
||||
|
||||
__device__ __forceinline__ void hadamard32_warp(__nv_bfloat16& x) {
|
||||
int lane_id = threadIdx.x % 32;
|
||||
#pragma unroll
|
||||
for (int step = 0; step < 5; ++step) {
|
||||
const int lane_mask = 1 << step;
|
||||
const __nv_bfloat16 sign = (lane_id & lane_mask) ? -1.f : 1.f;
|
||||
__nv_bfloat16 x_val_other = __shfl_xor_sync(0xffffffff, x, lane_mask);
|
||||
x = sign * x + x_val_other;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void MoeFusedHadamardQuantFp8Kernel(
|
||||
const __nv_bfloat16* __restrict__ input,
|
||||
const float* __restrict__ scale,
|
||||
const int64_t* __restrict__ topk_ids,
|
||||
__nv_fp8_e4m3* out,
|
||||
const int top_k,
|
||||
const int intermediate_size,
|
||||
const int64_t numel
|
||||
) {
|
||||
int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (out_idx >= numel) return;
|
||||
|
||||
int64_t token_idx = out_idx / (top_k * intermediate_size);
|
||||
int64_t topk_idx = (out_idx / intermediate_size) % top_k;
|
||||
int64_t inter_idx = out_idx % intermediate_size;
|
||||
|
||||
int64_t input_idx = token_idx * intermediate_size + inter_idx;
|
||||
if (input_idx >= numel / top_k) return;
|
||||
|
||||
int64_t expert_id = topk_ids[token_idx * top_k + topk_idx];
|
||||
float scale_value = scale[expert_id];
|
||||
|
||||
__nv_bfloat16 x = input[input_idx];
|
||||
hadamard32_warp(x);
|
||||
|
||||
float x_fp32 = __bfloat162float(x);
|
||||
float quantized = x_fp32 / scale_value;
|
||||
out[out_idx] = static_cast<__nv_fp8_e4m3>(quantized);
|
||||
}
|
||||
|
||||
__global__ void MoeFusedHadamardQuantFp8TiledKernel(
|
||||
const __nv_bfloat16* __restrict__ input,
|
||||
const float* __restrict__ scale,
|
||||
const int64_t* __restrict__ topk_ids,
|
||||
__nv_fp8_e4m3* out,
|
||||
const int top_k,
|
||||
const int intermediate_size,
|
||||
const int64_t numel
|
||||
) {
|
||||
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= numel) return;
|
||||
|
||||
int64_t token_idx = idx / intermediate_size;
|
||||
int64_t expert_id = topk_ids[token_idx];
|
||||
float scale_value = scale[expert_id];
|
||||
|
||||
__nv_bfloat16 x = input[idx];
|
||||
hadamard32_warp(x);
|
||||
|
||||
float x_fp32 = __bfloat162float(x);
|
||||
float quantized = x_fp32 / scale_value;
|
||||
out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MoeFusedHadamardQuantFp8(
|
||||
const paddle::Tensor &input,
|
||||
const paddle::Tensor &scale,
|
||||
const paddle::Tensor &topk_ids,
|
||||
const int top_k,
|
||||
const int intermediate_size,
|
||||
const bool tiled) {
|
||||
int64_t numel = input.numel();
|
||||
if (!tiled) numel *= top_k;
|
||||
paddle::Tensor out = GetEmptyTensor(
|
||||
{numel / intermediate_size, intermediate_size},
|
||||
paddle::DataType::FLOAT8_E4M3FN,
|
||||
input.place());
|
||||
constexpr int64_t thread_per_block = 256;
|
||||
int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
|
||||
auto stream = input.stream();
|
||||
if (tiled) {
|
||||
MoeFusedHadamardQuantFp8TiledKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
|
||||
top_k,
|
||||
intermediate_size,
|
||||
numel
|
||||
);
|
||||
} else {
|
||||
MoeFusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const __nv_bfloat16*>(input.data<phi::dtype::bfloat16>()),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
|
||||
top_k,
|
||||
intermediate_size,
|
||||
numel
|
||||
);
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(moe_fused_hadamard_quant_fp8)
|
||||
.Inputs({"input", "scale", "topk_ids"})
|
||||
.Outputs({"output"})
|
||||
.Attrs({"top_k: int",
|
||||
"intermediate_size: int",
|
||||
"tiled: bool"})
|
||||
.SetKernelFn(PD_KERNEL(MoeFusedHadamardQuantFp8));
|
||||
|
||||
|
||||
paddle::Tensor MoeFusedHadamardQuantFp8Func(
|
||||
const paddle::Tensor &input,
|
||||
const paddle::Tensor &scale,
|
||||
const paddle::Tensor &topk_ids,
|
||||
const int top_k,
|
||||
const int intermediate_size,
|
||||
const bool tiled) {
|
||||
return MoeFusedHadamardQuantFp8(input, scale, topk_ids, top_k, intermediate_size, tiled)[0];
|
||||
}
|
||||
|
||||
|
||||
__global__ void FusedHadamardQuantFp8Kernel(
|
||||
const __nv_bfloat16* __restrict__ input,
|
||||
__nv_fp8_e4m3* out,
|
||||
const float scale,
|
||||
const int64_t numel) {
|
||||
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= numel) return;
|
||||
|
||||
__nv_bfloat16 x = input[idx];
|
||||
hadamard32_warp(x);
|
||||
|
||||
float x_fp32 = __bfloat162float(x);
|
||||
float quantized = x_fp32 / scale;
|
||||
out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> FusedHadamardQuantFp8(
|
||||
const paddle::Tensor &input,
|
||||
const float scale) {
|
||||
int64_t numel = input.numel();
|
||||
paddle::Tensor out = GetEmptyTensor(
|
||||
input.dims(),
|
||||
paddle::DataType::FLOAT8_E4M3FN,
|
||||
input.place());
|
||||
constexpr int64_t thread_per_block = 256;
|
||||
int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
|
||||
auto stream = input.stream();
|
||||
FusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
|
||||
scale,
|
||||
numel
|
||||
);
|
||||
return {out};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(fused_hadamard_quant_fp8)
|
||||
.Inputs({"input"})
|
||||
.Outputs({"output"})
|
||||
.Attrs({"scale: float"})
|
||||
.SetKernelFn(PD_KERNEL(FusedHadamardQuantFp8));
|
||||
|
||||
|
||||
paddle::Tensor FusedHadamardQuantFp8Func(
|
||||
const paddle::Tensor &input,
|
||||
const float scale) {
|
||||
return FusedHadamardQuantFp8(input, scale)[0];
|
||||
}
|
146
custom_ops/gpu_ops/fused_rotary_position_encoding.cu
Normal file
146
custom_ops/gpu_ops/fused_rotary_position_encoding.cu
Normal file
@@ -0,0 +1,146 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename T, bool IS_NEOX>
|
||||
inline __device__ void apply_token_rotary_embedding_kernel(
|
||||
T* __restrict__ arr,
|
||||
const T* __restrict__ cos_ptr,
|
||||
const T* __restrict__ sin_ptr,
|
||||
int rot_offset,
|
||||
int embed_dim) {
|
||||
int x_index, y_index;
|
||||
T cos, sin;
|
||||
if (IS_NEOX) {
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = cos_ptr[x_index];
|
||||
sin = sin_ptr[x_index];
|
||||
} else {
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = cos_ptr[x_index / 2];
|
||||
sin = sin_ptr[x_index / 2];
|
||||
}
|
||||
|
||||
const T x = arr[x_index];
|
||||
const T y = arr[y_index];
|
||||
arr[x_index] = x * cos - y * sin;
|
||||
arr[y_index] = y * cos + x * sin;
|
||||
}
|
||||
|
||||
|
||||
template <typename T, bool IS_NEOX>
|
||||
__global__ void apply_rotary_embedding_kernel(
|
||||
T* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||
T* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
||||
const int* __restrict__ position_ids, // [num_tokens]
|
||||
const T* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int pos = position_ids[token_idx];
|
||||
const T* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const T* cos_ptr = cache_ptr;
|
||||
const T* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
|
||||
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding_kernel<T, IS_NEOX>(
|
||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void FusedRotaryPositionEncoding(
|
||||
paddle::Tensor& query, // [num_tokens, num_heads, head_size] or
|
||||
// [num_tokens, num_heads * head_size]
|
||||
paddle::Tensor& key,
|
||||
// [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
|
||||
// head_size]
|
||||
const paddle::Tensor& position_ids, // [num_tokens]
|
||||
const paddle::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
int head_size,
|
||||
bool is_neox) {
|
||||
int64_t num_tokens = query.dims()[0];
|
||||
int num_heads = query.numel() / num_tokens / head_size;
|
||||
int num_kv_heads = key.numel() / num_tokens / head_size;
|
||||
int rot_dim = cos_sin_cache.dims()[1];
|
||||
int64_t query_stride = num_heads * head_size;
|
||||
int64_t key_stride = num_kv_heads * head_size;
|
||||
|
||||
if (num_tokens > 65535) {
|
||||
PD_THROW(
|
||||
"apply_rotary_embedding_kernel launch failed when num_tokens > 65535.");
|
||||
}
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
|
||||
query.dtype(), "apply_rotary_embedding_kernel", [&] {
|
||||
if (is_neox) {
|
||||
apply_rotary_embedding_kernel<data_t, true>
|
||||
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
|
||||
key.data<data_t>(),
|
||||
position_ids.data<int>(),
|
||||
cos_sin_cache.data<data_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
} else {
|
||||
apply_rotary_embedding_kernel<data_t, false>
|
||||
<<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
|
||||
key.data<data_t>(),
|
||||
position_ids.data<int>(),
|
||||
cos_sin_cache.data<data_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
PD_BUILD_OP(fused_rotary_position_encoding)
|
||||
.Inputs({"query", "key", "position_ids", "cos_sin_cache"})
|
||||
.Outputs({"query_out", "key_out"})
|
||||
.Attrs({"head_size: int", "is_neox: bool"})
|
||||
.SetInplaceMap({{"query", "query_out"}, {"key", "key_out"}})
|
||||
.SetKernelFn(PD_KERNEL(FusedRotaryPositionEncoding));
|
@@ -24,16 +24,18 @@
|
||||
#endif
|
||||
|
||||
#define MAX_BSZ 512
|
||||
#define K 10
|
||||
#define K 20
|
||||
|
||||
struct msgdata {
|
||||
long mtype;
|
||||
int mtext[MAX_BSZ * (K + 1) + 2]; // stop_flag, bsz, tokens
|
||||
float mtext_f[MAX_BSZ * (K + 1)]; // score
|
||||
int mtext_ranks[MAX_BSZ]; // ranks
|
||||
};
|
||||
|
||||
void GetOutputTopK(const paddle::Tensor& x,
|
||||
const paddle::Tensor& scores,
|
||||
const paddle::Tensor& ranks,
|
||||
int k,
|
||||
int64_t rank_id,
|
||||
bool wait_flag) {
|
||||
@@ -66,17 +68,18 @@ void GetOutputTopK(const paddle::Tensor& x,
|
||||
|
||||
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
|
||||
float* scores_data = const_cast<float*>(scores.data<float>());
|
||||
int64_t* ranks_data = const_cast<int64_t*>(ranks.data<int64_t>());
|
||||
int ret = -1;
|
||||
if (!wait_flag) {
|
||||
ret = msgrcv(msgid,
|
||||
&msg_rcv,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
|
||||
0,
|
||||
IPC_NOWAIT);
|
||||
} else {
|
||||
ret = msgrcv(msgid,
|
||||
&msg_rcv,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
|
||||
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
@@ -97,13 +100,14 @@ void GetOutputTopK(const paddle::Tensor& x,
|
||||
out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2];
|
||||
scores_data[offset] = msg_rcv.mtext_f[offset];
|
||||
}
|
||||
ranks_data[i] = (int64_t)msg_rcv.mtext_ranks[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_output_topk)
|
||||
.Inputs({"x", "scores"})
|
||||
.Inputs({"x", "scores", "ranks"})
|
||||
.Attrs({"k: int", "rank_id: int64_t", "wait_flag: bool"})
|
||||
.Outputs({"x_out", "scores_out"})
|
||||
.SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}})
|
||||
.Outputs({"x_out", "scores_out", "ranks_out"})
|
||||
.SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}, {"ranks", "ranks_out"}})
|
||||
.SetKernelFn(PD_KERNEL(GetOutputTopK));
|
||||
|
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
#include "helper.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
@@ -59,7 +60,12 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &token_num,
|
||||
const paddle::Tensor &seq_len) {
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
|
||||
auto cu_stream = dev_ctx->stream();
|
||||
#else
|
||||
auto cu_stream = input_ids.stream();
|
||||
#endif
|
||||
std::vector<int64_t> input_ids_shape = input_ids.shape();
|
||||
const int bsz = seq_len.shape()[0];
|
||||
const int seq_length = input_ids_shape[1];
|
||||
@@ -75,7 +81,11 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
auto cu_seqlens_k =
|
||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
int blockSize = std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
|
||||
#else
|
||||
int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
|
||||
#endif
|
||||
GetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
|
||||
padding_offset.data<int>(),
|
||||
cum_offsets_out.data<int>(),
|
||||
|
@@ -0,0 +1,86 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
__global__ void GetPositionIdsAndMaskEncoderBatchKernel(
|
||||
const int* seq_lens_encoder, // [bsz] 每个批次的 encoder 长度
|
||||
const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度
|
||||
const int* seq_lens_this_time,
|
||||
int* position_ids, // 输出的一维 position_ids
|
||||
int* mask_encoder_batch,
|
||||
const int bsz) { // 批次大小
|
||||
// 当前线程索引(每个线程对应一个批次)
|
||||
int tid = threadIdx.x;
|
||||
if (tid >= bsz) return;
|
||||
|
||||
// 动态计算当前批次的偏移量
|
||||
int offset = 0;
|
||||
for (int i = 0; i < tid; i++) {
|
||||
offset += seq_lens_encoder[i];
|
||||
if (seq_lens_decoder[i] > 0) {
|
||||
offset += seq_lens_this_time[i];
|
||||
}
|
||||
}
|
||||
|
||||
// 当前批次的 encoder 和 decoder 长度
|
||||
int encoder_len = seq_lens_encoder[tid];
|
||||
int decoder_len = seq_lens_decoder[tid];
|
||||
int seq_len_this_time = seq_lens_this_time[tid];
|
||||
|
||||
// 写入 encoder 的 position_ids
|
||||
for (int i = 0; i < encoder_len; i++) {
|
||||
position_ids[offset + i] = i;
|
||||
mask_encoder_batch[offset + i] = 1;
|
||||
}
|
||||
offset += encoder_len;
|
||||
|
||||
// 写入 decoder 的 position_ids
|
||||
if (decoder_len > 0) {
|
||||
for (int i = 0; i < seq_len_this_time; i++) {
|
||||
position_ids[offset + i] = decoder_len + i; // 使用 decoder 长度本身
|
||||
mask_encoder_batch[offset + i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void GetPositionIdsAndMaskEncoderBatch(
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& position_ids,
|
||||
const paddle::Tensor& mask_encoder_batch) {
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
|
||||
GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>(
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
const_cast<int*>(position_ids.data<int>()),
|
||||
const_cast<int*>(mask_encoder_batch.data<int>()),
|
||||
bsz);
|
||||
}
|
||||
|
||||
PD_BUILD_OP(get_position_ids_and_mask_encoder_batch)
|
||||
.Inputs({"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"position_ids",
|
||||
"mask_encoder_batch"})
|
||||
.Outputs({"position_ids_out", "mask_encoder_batch_out"})
|
||||
.SetInplaceMap({{"position_ids", "position_ids_out"},
|
||||
{"mask_encoder_batch", "mask_encoder_batch_out"}})
|
||||
.SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch));
|
@@ -14,7 +14,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef PADDLE_WITH_COREX
|
||||
#include "glog/logging.h"
|
||||
#endif
|
||||
#include <fcntl.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
@@ -35,20 +37,35 @@ namespace cub = hipcub;
|
||||
#else
|
||||
#include <cub/cub.cuh>
|
||||
#endif
|
||||
#ifndef PADDLE_WITH_COREX
|
||||
#include "nlohmann/json.hpp"
|
||||
#endif
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
|
||||
#include "env.h"
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/allocator.h"
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||
#include "paddle/phi/backends/custom/custom_context.h"
|
||||
#else
|
||||
#include "paddle/phi/core/cuda_stream.h"
|
||||
#endif
|
||||
#include "paddle/phi/core/dense_tensor.h"
|
||||
#include "paddle/phi/backends/gpu/gpu_info.h"
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
#define WARP_SIZE 64
|
||||
#else
|
||||
#define WARP_SIZE 32
|
||||
#endif
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
#ifndef PADDLE_WITH_COREX
|
||||
using json = nlohmann::json;
|
||||
#endif
|
||||
|
||||
#define CUDA_CHECK(call) \
|
||||
do { \
|
||||
@@ -197,11 +214,19 @@ HOSTDEVICE inline void Store(const AlignedVector<T, Size> &vec, T *addr) {
|
||||
*addr_vec = vec;
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_HIP
|
||||
template <int Size>
|
||||
HOSTDEVICE inline void Store(const AlignedVector<hip_bfloat16, Size> &vec,
|
||||
int8_t *addr) {
|
||||
printf("Error: Store hip_bfloat16 to int8_t is not supported!");
|
||||
}
|
||||
#else
|
||||
template <int Size>
|
||||
HOSTDEVICE inline void Store(const AlignedVector<__nv_bfloat16, Size> &vec,
|
||||
int8_t *addr) {
|
||||
printf("Error: Store __nv_bfloat16 to int8_t is not supported!");
|
||||
}
|
||||
#endif
|
||||
|
||||
template <int Size>
|
||||
HOSTDEVICE inline void Store(const AlignedVector<half, Size> &vec,
|
||||
@@ -235,6 +260,7 @@ inline int GetBlockSize(int vocab_size) {
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef PADDLE_WITH_COREX
|
||||
inline json readJsonFromFile(const std::string &filePath) {
|
||||
std::ifstream file(filePath);
|
||||
if (!file.is_open()) {
|
||||
@@ -245,6 +271,7 @@ inline json readJsonFromFile(const std::string &filePath) {
|
||||
file >> j;
|
||||
return j;
|
||||
}
|
||||
#endif
|
||||
|
||||
#define cudaCheckError() \
|
||||
{ \
|
||||
@@ -416,6 +443,7 @@ inline std::string base64_decode(const std::string &encoded_string) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
#ifndef PADDLE_WITH_COREX
|
||||
template <typename T>
|
||||
inline T get_relative_best(nlohmann::json *json_data,
|
||||
const std::string &target_key,
|
||||
@@ -428,6 +456,7 @@ inline T get_relative_best(nlohmann::json *json_data,
|
||||
return default_value;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
__device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids,
|
||||
int length) {
|
||||
@@ -457,7 +486,12 @@ template <typename T>
|
||||
static void PrintMatrix3(const T *mat_d, int num, std::string name) {
|
||||
|
||||
std::vector<T> tmp(num);
|
||||
#ifdef PADDLE_WITH_HIP
|
||||
hipMemcpy(tmp.data(), mat_d, sizeof(T) * num, hipMemcpyDeviceToHost);
|
||||
#else
|
||||
cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost);
|
||||
#endif
|
||||
|
||||
|
||||
std::ofstream outfile;
|
||||
outfile.open(name + ".txt", std::ios::out);
|
||||
@@ -474,6 +508,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
|
||||
outfile.close();
|
||||
}
|
||||
|
||||
#ifndef PADDLE_WITH_HIP
|
||||
__forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
|
||||
int mode = 0) {
|
||||
uint32_t flag;
|
||||
@@ -513,3 +548,11 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
return max_shared_mem_per_block_opt_in;
|
||||
}
|
||||
#endif
|
||||
|
||||
inline int GetSMVersion() {
|
||||
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
|
||||
phi::backends::gpu::GetCurrentDeviceId());
|
||||
return sm_version;
|
||||
|
||||
}
|
||||
|
255
custom_ops/gpu_ops/mla_attn/attention_updater.cuh
Normal file
255
custom_ops/gpu_ops/mla_attn/attention_updater.cuh
Normal file
@@ -0,0 +1,255 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cutlass/detail/helper_macros.hpp>
|
||||
|
||||
#include "utils.cuh"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T>
|
||||
struct MaxOp {
|
||||
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x > y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxOp<float> {
|
||||
// This is slightly faster
|
||||
__device__ __forceinline__ float operator()(float const& x, float const& y) { return max(x, y); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct SumOp {
|
||||
__device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; }
|
||||
};
|
||||
|
||||
template <int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template <typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator& op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Allreduce<2> {
|
||||
template <typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator& op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Operator>
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& summary, Operator& op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0>& dst,
|
||||
Tensor<Engine1, Layout1>& src, Operator& op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++) {
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Operator>
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& summary, Operator& op) {
|
||||
thread_reduce_<init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template <bool init, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& max) {
|
||||
MaxOp<float> max_op;
|
||||
reduce_<init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template <bool init, bool warp_reduce = true, typename Engine0, typename Layout0, typename Engine1,
|
||||
typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor,
|
||||
Tensor<Engine1, Layout1>& sum) {
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<init>(tensor, sum, sum_op);
|
||||
if constexpr (warp_reduce) {
|
||||
quad_allreduce_(sum, sum, sum_op);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void apply_exp2(Tensor<Engine0, Layout0>& tensor,
|
||||
Tensor<Engine1, Layout1> const& max) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
auto row_max = max(mi);
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = __expf(tensor(mi, ni) - row_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tensor,
|
||||
Tensor<Engine1, Layout1> const& max,
|
||||
const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
auto row_max = max(mi);
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// row_max * scale is a constant for each row, so we can use fma here
|
||||
tensor(mi, ni) = __expf(tensor(mi, ni) * scale - row_max * scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int NUM_ROWS_PER_THREAD, bool WITH_SCALE>
|
||||
struct OnlineSoftmax {
|
||||
constexpr static float fill_value = -5e4;
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<NUM_ROWS_PER_THREAD>>{}));
|
||||
TensorT row_max, row_sum, scores_scale;
|
||||
float sm_scale_log2;
|
||||
|
||||
CUTLASS_DEVICE OnlineSoftmax(float sm_scale_log2) : sm_scale_log2(sm_scale_log2) {
|
||||
clear(scores_scale);
|
||||
};
|
||||
|
||||
__forceinline__ __device__ TensorT get_lse() const { return row_sum; }
|
||||
|
||||
template <bool init, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT update(Tensor0& acc_s) {
|
||||
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
|
||||
static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
|
||||
if constexpr (init) {
|
||||
reduce_max</*init=*/true>(scores, row_max);
|
||||
if constexpr (WITH_SCALE) {
|
||||
scale_apply_exp2(scores, row_max, sm_scale_log2);
|
||||
} else {
|
||||
apply_exp2(scores, row_max);
|
||||
}
|
||||
reduce_sum</*init=*/true, /*warp_reduce=*/false>(scores, row_sum);
|
||||
} else {
|
||||
// update row_max
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
reduce_max</*init=*/false>(scores, row_max);
|
||||
// update scores_scale and scale row_sum
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = row_max(mi);
|
||||
if constexpr (WITH_SCALE) {
|
||||
scores_scale(mi) = __expf((scores_max_prev(mi) - scores_max_cur) * sm_scale_log2);
|
||||
} else {
|
||||
scores_scale(mi) = __expf(scores_max_prev(mi) - scores_max_cur);
|
||||
}
|
||||
row_sum(mi) *= scores_scale(mi);
|
||||
}
|
||||
// perform exp2 on scores
|
||||
if constexpr (WITH_SCALE) {
|
||||
scale_apply_exp2(scores, row_max, sm_scale_log2);
|
||||
} else {
|
||||
apply_exp2(scores, row_max);
|
||||
}
|
||||
// update row_sum
|
||||
reduce_sum</*init=*/false, /*warp_reduce=*/false>(scores, row_sum);
|
||||
return scores_scale;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tensor0>
|
||||
__forceinline__ __device__ TensorT finalize(Tensor0& acc_s) {
|
||||
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == NUM_ROWS_PER_THREAD);
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = 1.f / sum;
|
||||
scores_scale(mi) = inv_sum;
|
||||
row_max(mi) *= sm_scale_log2;
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
template <typename Tensor1>
|
||||
__forceinline__ __device__ void rescale_o(Tensor1& acc_o) {
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
|
||||
acc_o_rowcol(mi, ni) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tensor1, typename Tensor2>
|
||||
__forceinline__ __device__ void rescale_o(Tensor1& acc_o, Tensor2& scores_scale_input) {
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == NUM_ROWS_PER_THREAD);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
|
||||
acc_o_rowcol(mi, ni) *= scores_scale_input(mi);
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
232
custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu
Normal file
232
custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu
Normal file
@@ -0,0 +1,232 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_device_runtime_api.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include "cute/tensor.hpp"
|
||||
#include "mla_hopper.cuh"
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
#include "batch_mla_with_paged_kv_cache.h"
|
||||
#include "env.h"
|
||||
|
||||
using namespace cute;
|
||||
using namespace mla_attn;
|
||||
using namespace std;
|
||||
|
||||
template <typename T>
|
||||
struct cascade_type_traits {
|
||||
using type = T;
|
||||
using cutlass_type = T;
|
||||
};
|
||||
template <>
|
||||
struct cascade_type_traits<phi::dtype::bfloat16> {
|
||||
using type = __nv_bfloat16;
|
||||
using cutlass_type = cutlass::bfloat16_t;;
|
||||
};
|
||||
template <>
|
||||
struct cascade_type_traits<phi::dtype::float16> {
|
||||
using type = half;
|
||||
using cutlass_type = cutlass::half_t;
|
||||
};
|
||||
template <>
|
||||
struct cascade_type_traits<phi::dtype::float8_e4m3fn> {
|
||||
using type = __nv_fp8_e4m3;
|
||||
using cutlass_type = cutlass::float_e4m3_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void BatchMLAWithPagedKVCacheKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
|
||||
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int draft_token_num,
|
||||
const bool causal,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out) {
|
||||
using NV_TYPE = typename cascade_type_traits<T>::type;
|
||||
using CUTLASS_TYPE = typename cascade_type_traits<T>::cutlass_type;
|
||||
const auto token_num = meta_data.token_nums;
|
||||
const auto block_size = meta_data.block_size;
|
||||
const auto bsz = meta_data.batch_size;
|
||||
const auto q_head_num = meta_data.q_num_heads;
|
||||
const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
|
||||
const auto max_block_num = bsz * max_block_num_per_seq;
|
||||
const uint32_t chunk_size = get_max_partition_size(bsz);
|
||||
|
||||
|
||||
int q_head_dim = meta_data.head_dims;
|
||||
int k_head_dim = meta_data.head_dims;
|
||||
int v_head_dim = meta_data.head_dims_v;
|
||||
// int num_chunks = max_dec_len / chunk_size;
|
||||
int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
|
||||
auto *allocator = paddle::GetAllocator(q.place());
|
||||
phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;
|
||||
O_tmp = allocator->Allocate(
|
||||
phi::SizeOf(q.dtype()) *
|
||||
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num * v_head_dim));
|
||||
m_tmp = allocator->Allocate(
|
||||
sizeof(float) *
|
||||
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
|
||||
d_tmp = allocator->Allocate(
|
||||
sizeof(float) *
|
||||
static_cast<size_t>(num_chunks * bsz * draft_token_num * q_head_num));
|
||||
|
||||
Params<CUTLASS_TYPE, CUTLASS_TYPE, CUTLASS_TYPE, int> params = {};
|
||||
params.Q = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(q.data<T>()));
|
||||
params.KV = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(latent_cache.data<T>()));
|
||||
params.O = reinterpret_cast<CUTLASS_TYPE*>(const_cast<T*>(out->data<T>()));
|
||||
params.O_tmp = reinterpret_cast<CUTLASS_TYPE*>(O_tmp->ptr());
|
||||
params.m = reinterpret_cast<float*>(m_tmp->ptr());
|
||||
params.d = reinterpret_cast<float*>(d_tmp->ptr());
|
||||
params.block_tables = const_cast<int*>(block_tables.data<int>());
|
||||
params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
|
||||
params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
|
||||
params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
|
||||
params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
|
||||
params.padding_offsets = const_cast<int*>(padding_offsets.data<int>());
|
||||
params.batch_ids = const_cast<int*>(batch_ids.data<int>());
|
||||
params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
|
||||
params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
|
||||
params.num_blocks_x_int = num_blocks_x;
|
||||
params.q_stride_bsz = q_head_num * q_head_dim;
|
||||
params.q_stride_head_num = q_head_dim;
|
||||
params.kv_stride_block_num = block_size * k_head_dim;
|
||||
params.kv_stride_block_size = k_head_dim;
|
||||
params.o_stride_bsz = q_head_num * v_head_dim;
|
||||
params.o_stride_head_num = v_head_dim;
|
||||
params.bsz = bsz;
|
||||
params.token_num = token_num;
|
||||
params.max_seq_len = max_seq_len;
|
||||
params.max_block_num = max_block_num;
|
||||
params.max_block_num_per_seq = max_block_num_per_seq;
|
||||
params.q_num_head = q_head_num;
|
||||
params.qk_head_dim = q_head_dim;
|
||||
params.vo_head_dim = v_head_dim;
|
||||
params.block_size = block_size;
|
||||
params.max_draft_token_num = draft_token_num;
|
||||
params.sm_scale = softmax_scale;
|
||||
params.chunk_size = chunk_size;
|
||||
params.chunk_num = num_chunks;
|
||||
|
||||
if (q_head_dim == 576) {
|
||||
BatchMLAWithPagedKVCacheDispatched<576, 512, NV_TYPE>(
|
||||
params, stream
|
||||
);
|
||||
} else {
|
||||
PD_THROW("error!!! q_head_dim must be 576 !!!\n");
|
||||
}
|
||||
}
|
||||
|
||||
template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
|
||||
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int draft_token_num,
|
||||
const bool causal,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
|
||||
template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
|
||||
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int draft_token_num,
|
||||
const bool causal,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
68
custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h
Normal file
68
custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2023 by FlashInfer team.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/dense_tensor.h"
|
||||
#include "paddle/phi/core/allocator.h"
|
||||
#include "append_attn/utils.cuh"
|
||||
|
||||
template <typename T>
|
||||
void BatchMLAWithPagedKVCacheKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& q, // [token_num, q_head_num, head_dim]
|
||||
const paddle::Tensor& latent_cache, // [max_block_num, q_head_num, block_size, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& shift_bias, // [num_kv_heads, head_dim]
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& padding_offsets,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& batch_ids,
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const int draft_token_num,
|
||||
const bool causal,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
175
custom_ops/gpu_ops/mla_attn/epilogue.cuh
Normal file
175
custom_ops/gpu_ops/mla_attn/epilogue.cuh
Normal file
@@ -0,0 +1,175 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
|
||||
#ifndef ATTENTION_HOPPER_EPILOGUE_CUH_
|
||||
#define ATTENTION_HOPPER_EPILOGUE_CUH_
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "named_barrier.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#ifdef DEBUG_MLA
|
||||
#undef DEBUG_MLA
|
||||
#endif
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename Ktraits>
|
||||
struct CollectiveEpilogue {
|
||||
using DTypeO = typename Ktraits::DTypeO;
|
||||
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
|
||||
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
|
||||
static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
|
||||
using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
|
||||
|
||||
static constexpr int NUM_WARPS = Ktraits::NUM_WARPS;
|
||||
static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp;
|
||||
|
||||
static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup;
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
|
||||
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
|
||||
decltype(cute::get<1>(TileShape_PDV{}))>());
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
|
||||
|
||||
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeO>;
|
||||
using SharedStorage = cute::array_aligned<DTypeO, cute::cosize_v<SmemLayoutO>>;
|
||||
|
||||
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
|
||||
using StrideT = cute::Shape<int32_t, _1, int32_t>;
|
||||
using LayoutT = cute::Layout<ShapeT, StrideT>;
|
||||
|
||||
using ShapeTmpT = cute::Shape<int32_t, int32_t, int32_t, int32_t>;
|
||||
using StrideTmpT = cute::Shape<int32_t, _1, int32_t, int32_t>;
|
||||
using LayoutTmpT = cute::Layout<ShapeTmpT, StrideTmpT>;
|
||||
|
||||
using ShapeNTMAT = cute::Shape<int32_t, int32_t>;
|
||||
using StrideNTMAT = cute::Shape<int32_t, _1>;
|
||||
using LayoutNTMAT = cute::Layout<ShapeNTMAT, StrideNTMAT>;
|
||||
|
||||
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
|
||||
using TMA_O = decltype(make_tma_copy(
|
||||
GmemTiledCopyOTMA{},
|
||||
make_tensor(make_gmem_ptr(static_cast<DTypeO*>(nullptr)), ShapeT{}, StrideT{}), SmemLayoutO{},
|
||||
select<0, 1>(TileShape_PDV{}), _1{})); // no mcast for O
|
||||
|
||||
static constexpr int VEC_SIZE = cute::ceil_div(128, sizeof_bits_v<DTypeO>); // 8
|
||||
static_assert(HEAD_DIM_VO % VEC_SIZE == 0);
|
||||
static constexpr int NUM_THREADS_PER_ROW = HEAD_DIM_VO / VEC_SIZE; // 64
|
||||
static_assert(NUM_MMA_THREADS % NUM_THREADS_PER_ROW == 0);
|
||||
static constexpr int NUM_ROWS = NUM_MMA_THREADS / NUM_THREADS_PER_ROW;
|
||||
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, DTypeO>;
|
||||
using TiledCopyOThrLayout = decltype(cute::make_layout(
|
||||
cute::make_shape(Int<NUM_ROWS>{}, Int<NUM_THREADS_PER_ROW>{}), LayoutRight{}));
|
||||
using TiledCopyOValLayout =
|
||||
decltype(cute::make_layout(cute::make_shape(_1{}, Int<VEC_SIZE>{}), LayoutRight{}));
|
||||
using TiledCopyO =
|
||||
decltype(make_tiled_copy(TiledCopyOAtom{}, TiledCopyOThrLayout{}, // Thr layout
|
||||
TiledCopyOValLayout{} // Val layout
|
||||
));
|
||||
struct Arguments {
|
||||
DTypeO* O_ptr;
|
||||
LayoutNTMAT const layout_O;
|
||||
DTypeO* O_ptr_tmp;
|
||||
LayoutNTMAT const layout_O_tmp;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
DTypeO* O_ptr;
|
||||
LayoutNTMAT const layout_O;
|
||||
DTypeO* O_ptr_tmp;
|
||||
LayoutNTMAT const layout_O_tmp;
|
||||
};
|
||||
|
||||
static Params to_underlying_arguments_ntma(Arguments const& args) {
|
||||
return {args.O_ptr, args.layout_O, args.O_ptr_tmp, args.layout_O_tmp};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& epilogue_params) {}
|
||||
|
||||
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE,
|
||||
typename TiledMma>
|
||||
CUTLASS_DEVICE void store(Params const& epilogue_params,
|
||||
FrgTensorO const& tOrO,
|
||||
FrgTensorLSE const& lse,
|
||||
SharedStorage& shared_storage,
|
||||
TiledMma tiled_mma,
|
||||
const int thread_idx,
|
||||
const int bid,
|
||||
const int bsz,
|
||||
const int seq_len_now,
|
||||
const int start_token_idx,
|
||||
const int tile_idx,
|
||||
const int kv_len,
|
||||
const int chunk_size,
|
||||
const int max_draft_token_num,
|
||||
const int o_stride_bsz) {
|
||||
const int num_chunks = cute::ceil_div(kv_len, chunk_size);
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
|
||||
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tOrO_out = convert_type<DTypeO>(tOrO);
|
||||
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
// make sure gemm done
|
||||
cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
|
||||
// r2s
|
||||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
||||
// make sure r2s done
|
||||
cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kValueEmpty));
|
||||
TiledCopyO gmem_tiled_copy_O;
|
||||
auto O_ptr = num_chunks == 1 ? epilogue_params.O_ptr + start_token_idx * o_stride_bsz : epilogue_params.O_ptr_tmp + (tile_idx * bsz + bid) * max_draft_token_num * o_stride_bsz;
|
||||
Tensor mO = make_tensor(make_gmem_ptr(O_ptr), epilogue_params.layout_O);
|
||||
Tensor gO = local_tile(mO, select<0, 1>(TileShape_PDV{}), make_coord(_, _0{}))(_, _, _0{});
|
||||
Tensor cO = make_identity_tensor(gO.shape()); // (O, D) -> (o_idx, d_idx)
|
||||
ThrCopy thr_copy_O = gmem_tiled_copy_O.get_slice(thread_idx);
|
||||
Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY, CPY_O, CPY_D)
|
||||
Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY, CPY_O, CPY_D)
|
||||
Tensor tOcO = thr_copy_O.partition_D(cO); // (CPY, CPY_O, CPY_D)
|
||||
Tensor tOgOGroup = flatten_1(tOgO); // (CPY, (CPY_O, CPY_D))
|
||||
Tensor tOsOGroup = flatten_1(tOsO); // (CPY, (CPY_O, CPY_D))
|
||||
Tensor tOcOGroup = flatten_1(tOcO); // (CPY, (CPY_O, CPY_D))
|
||||
|
||||
// copy if not out of bound
|
||||
auto predicate_fn = [&](auto coords) {
|
||||
auto s_coords = tOcOGroup(_0{}, coords);
|
||||
return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, seq_len_now);
|
||||
};
|
||||
copy_if(gmem_tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_EPILOGUE_CUH_
|
163
custom_ops/gpu_ops/mla_attn/kernel_traits.cuh
Normal file
163
custom_ops/gpu_ops/mla_attn/kernel_traits.cuh
Normal file
@@ -0,0 +1,163 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#ifndef ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
|
||||
#define ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename MainloopPipeline, typename MainloopPipelineQ, class DTypeQ, class DTypeKV, class DTypeQKAccum, class DTypeOut, class IdType,
|
||||
int BLOCK_SHAPE_KV, class SmemLayoutQ, class SmemLayoutK, class SmemLayoutP, class SmemLayoutRow, class SmemLayoutO>
|
||||
struct alignas(16) SharedStorageQKVO {
|
||||
alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutQ>> smem_q;
|
||||
alignas(16) cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutP>> smem_p;
|
||||
alignas(16) cute::array_aligned<DTypeQKAccum, cute::cosize_v<SmemLayoutRow>> smem_scale;
|
||||
union {
|
||||
alignas(16) cute::array_aligned<DTypeKV, cute::cosize_v<SmemLayoutK>> smem_kv;
|
||||
alignas(16) cute::array_aligned<DTypeOut, cute::cosize_v<SmemLayoutO>> smem_o;
|
||||
};
|
||||
struct {
|
||||
alignas(16) typename MainloopPipelineQ::SharedStorage pipeline_q;
|
||||
alignas(16) typename MainloopPipeline::SharedStorage pipeline_kv;
|
||||
};
|
||||
};
|
||||
|
||||
template <bool USE_TMA_LOAD_KV_, int HEAD_DIM_QK_, int HEAD_DIM_VO_, int GROUP_SIZE_, int BLOCK_SHAPE_Q_, int BLOCK_SHAPE_KV_,
|
||||
int NUM_STAGES_, typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_, typename NV_TYPE_>
|
||||
struct AttentionKernelTraits {
|
||||
|
||||
using DTypeQ = DTypeQ_;
|
||||
using DTypeKV = DTypeKV_;
|
||||
using DTypeO = DTypeO_;
|
||||
using IdType = IdType_;
|
||||
using DTypeQKAccum = float;
|
||||
using DTypePVAccum = float;
|
||||
using NV_TYPE = NV_TYPE_;
|
||||
|
||||
|
||||
static constexpr bool USE_TMA_LOAD_KV = USE_TMA_LOAD_KV_;
|
||||
static constexpr int GROUP_SIZE = GROUP_SIZE_;
|
||||
static constexpr int BLOCK_SHAPE_Q = BLOCK_SHAPE_Q_;
|
||||
static_assert(BLOCK_SHAPE_Q % 64 == 0);
|
||||
static constexpr int BLOCK_SHAPE_KV = BLOCK_SHAPE_KV_;
|
||||
static constexpr int HEAD_DIM_QK = HEAD_DIM_QK_;
|
||||
static constexpr int HEAD_DIM_VO = HEAD_DIM_VO_;
|
||||
static constexpr int NUM_PER_STAGE = BLOCK_SHAPE_KV * HEAD_DIM_QK;
|
||||
static_assert(HEAD_DIM_QK % 32 == 0);
|
||||
static_assert(HEAD_DIM_VO % 32 == 0);
|
||||
|
||||
static constexpr int NUM_WARPS = 12;
|
||||
static constexpr int NUM_THREADS = 384;
|
||||
static constexpr int NUM_PRODUCER_THREADS = 128;
|
||||
|
||||
using TileShape_QKD = Shape<Int<BLOCK_SHAPE_Q>, Int<BLOCK_SHAPE_KV>, Int<HEAD_DIM_QK>>;
|
||||
using TileShape_PDV = Shape<Int<BLOCK_SHAPE_Q>, Int<HEAD_DIM_VO>, Int<BLOCK_SHAPE_KV>>;
|
||||
|
||||
static constexpr int NUM_STAGES = NUM_STAGES_;
|
||||
|
||||
using AtomLayoutQKD = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _1, _1>>;
|
||||
using AtomLayoutPV = Layout<Shape<Int<BLOCK_SHAPE_Q / 64>, _2, _1>>;
|
||||
using TiledMmaQK = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<DTypeQ, DTypeKV, DTypeQKAccum, TileShape_QKD>(), AtomLayoutQKD{}));
|
||||
using TiledMmaPV = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
AtomLayoutPV{}));
|
||||
using TiledMmaPVSS = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<DTypeKV, DTypeKV, /*ElementAccum=*/DTypePVAccum, TileShape_PDV,
|
||||
GMMA::Major::K, GMMA::Major::MN>(),
|
||||
AtomLayoutPV{}));
|
||||
|
||||
static constexpr int NUM_MMA_THREADS = size(TiledMmaPV{});
|
||||
|
||||
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
|
||||
decltype(cute::get<2>(TileShape_QKD{}))>());
|
||||
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_QKD{})));
|
||||
|
||||
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})),
|
||||
decltype(cute::get<2>(TileShape_QKD{}))>());
|
||||
using SmemLayoutK = decltype(tile_to_shape(
|
||||
SmemLayoutAtomK{},
|
||||
make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int<NUM_STAGES>{})));
|
||||
using SmemLayoutVt = decltype(composition(
|
||||
SmemLayoutK{}, make_ordered_layout(make_shape(get<2>(TileShape_QKD{}),
|
||||
get<1>(TileShape_QKD{}), Int<NUM_STAGES>{}),
|
||||
Step<_2, _1, _3>{})));
|
||||
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeKV, decltype(cute::get<2>(TileShape_PDV{})),
|
||||
decltype(cute::get<1>(TileShape_PDV{}))>());
|
||||
using SmemLayoutV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomV{},
|
||||
make_shape(get<2>(TileShape_PDV{}), get<1>(TileShape_PDV{}), Int<1>{})));
|
||||
|
||||
// Note this is the transpose in terms of the view, not in terms of memory.
|
||||
using SmemLayoutVtOneStage = decltype(composition(
|
||||
SmemLayoutV{}, make_ordered_layout(make_shape(get<1>(TileShape_PDV{}),
|
||||
get<2>(TileShape_PDV{}), Int<1>{}),
|
||||
Step<_2, _1, _3>{})));
|
||||
|
||||
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
|
||||
decltype(cute::get<1>(TileShape_PDV{}))>());
|
||||
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));
|
||||
|
||||
using SmemCopyAtom = Copy_Atom<cute::SM90_U32x4_STSM_N, DTypeQ>;
|
||||
|
||||
static constexpr bool IS_CTA_32 = (BLOCK_SHAPE_KV == 32);
|
||||
using SmemLayoutRowOneStage = Layout<Shape<_2, Int<128>>, Stride<_1, _2>>;
|
||||
using SmemLayoutRowTwoStage = Layout<Shape<_2, Int<128>, _2>, Stride<_1, _2, _256>>;
|
||||
using SmemLayoutRow = std::conditional_t<IS_CTA_32, SmemLayoutRowTwoStage, SmemLayoutRowOneStage>;
|
||||
|
||||
using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
|
||||
GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
|
||||
decltype(cute::get<1>(TileShape_QKD{}))>());
|
||||
using SmemLayoutPSSOneStage = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_QKD{})));
|
||||
using SmemLayoutPSSTwoStage = decltype(tile_to_shape(SmemLayoutAtomP{}, make_shape(Int<BLOCK_SHAPE_Q>{}, Int<BLOCK_SHAPE_KV>{}, Int<2>{})));
|
||||
using SmemLayoutP = std::conditional_t<IS_CTA_32, SmemLayoutPSSTwoStage, SmemLayoutPSSOneStage>;
|
||||
|
||||
using MainloopPipelineQ = typename cutlass::PipelineAsync<1>;
|
||||
using PipelineStateQ = typename cutlass::PipelineState<1>;
|
||||
using MainloopPipeline =
|
||||
std::conditional_t<USE_TMA_LOAD_KV, typename cutlass::PipelineTmaAsync<NUM_STAGES>,
|
||||
typename cutlass::PipelineAsync<NUM_STAGES>>;
|
||||
using PipelineState = typename cutlass::PipelineState<NUM_STAGES>;
|
||||
|
||||
using SharedStorage = SharedStorageQKVO<MainloopPipeline, MainloopPipelineQ, DTypeQ, DTypeKV, DTypeQKAccum, DTypeO, IdType, BLOCK_SHAPE_KV,
|
||||
SmemLayoutQ, SmemLayoutK, SmemLayoutP, SmemLayoutRow, SmemLayoutO>;
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif
|
348
custom_ops/gpu_ops/mla_attn/mainloop_load.cuh
Normal file
348
custom_ops/gpu_ops/mla_attn/mainloop_load.cuh
Normal file
@@ -0,0 +1,348 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
|
||||
#define ATTENTION_HOPPER_MAINLOOP_LOAD_CUH_
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "named_barrier.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#ifdef DEBUG_MLA
|
||||
#undef DEBUG_MLA
|
||||
#endif
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename Ktraits, bool CAUSAL>
|
||||
struct CollectiveMainloop {
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeMD = float;
|
||||
using IdType = typename Ktraits::IdType;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
using TileShape_PDV = typename Ktraits::TileShape_PDV;
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
|
||||
static constexpr int NUM_STAGES = Ktraits::NUM_STAGES;
|
||||
static constexpr int HEAD_DIM_QK = Ktraits::HEAD_DIM_QK;
|
||||
static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
|
||||
|
||||
using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(DTypeQ); // 8
|
||||
static_assert(HEAD_DIM_QK % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); // 576 512
|
||||
static constexpr int kGmemThreadsPerRow = 64 / kGmemElemsPerLoad; // 8
|
||||
using AlignmentTypeQ = cute::uint_byte_t<static_cast<int>(sizeof(DTypeQ)) * kGmemElemsPerLoad>;
|
||||
using GmemCopyAtomQ = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<AlignmentTypeQ>, DTypeQ>;
|
||||
static constexpr int kNThreadsLoad = Ktraits::NUM_PRODUCER_THREADS;
|
||||
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<
|
||||
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // 32, 8
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
using GmemTiledCopy = decltype(make_tiled_copy(
|
||||
GmemCopyAtomQ{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
|
||||
using GmemLayoutAtomQ = Layout<
|
||||
Shape<Int<Ktraits::NUM_PRODUCER_THREADS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, // 32, 8
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
using GmemTiledCopyQ = decltype(make_tiled_copy(
|
||||
GmemCopyAtomQ{},
|
||||
GmemLayoutAtomQ{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
|
||||
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
|
||||
using SmemLayoutAtomQ = typename Ktraits::SmemLayoutAtomQ;
|
||||
|
||||
using SmemLayoutK = typename Ktraits::SmemLayoutK;
|
||||
using SmemLayoutV = typename Ktraits::SmemLayoutV;
|
||||
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
|
||||
|
||||
using ShapeQT = cute::Shape<int32_t, int32_t>;
|
||||
using StrideQT = cute::Shape<int32_t, _1>;
|
||||
using LayoutQT = cute::Layout<ShapeQT, StrideQT>;
|
||||
|
||||
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
|
||||
using StrideT = cute::Shape<int32_t, _1, int32_t>;
|
||||
using LayoutT = cute::Layout<ShapeT, StrideT>;
|
||||
|
||||
using ShapeMDT = cute::Shape<int32_t, int32_t>;
|
||||
using StrideMDT = cute::Shape<int32_t, _1>;
|
||||
using LayoutMDT = cute::Layout<ShapeMDT, StrideMDT>;
|
||||
|
||||
using TMA_KV = decltype(make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
make_tensor(
|
||||
make_gmem_ptr(static_cast<DTypeKV const*>(nullptr)),
|
||||
repeat_like(StrideT{}, int32_t(0)), StrideT{}
|
||||
),
|
||||
take<0, 2>(SmemLayoutK{}),
|
||||
select<1, 2>(TileShape_QKD{}),
|
||||
_1{})); // no mcast for KV
|
||||
|
||||
static constexpr bool USE_TMA_LOAD_KV = Ktraits::USE_TMA_LOAD_KV;
|
||||
using MainloopPipeline = typename Ktraits::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
|
||||
using MainloopPipelineQ = typename Ktraits::MainloopPipelineQ;
|
||||
using PipelineParamsQ = typename MainloopPipelineQ::Params;
|
||||
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
|
||||
|
||||
static constexpr uint32_t TmaTransactionBytesQ =
|
||||
static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<DTypeQ> / 8);
|
||||
static constexpr uint32_t TmaTransactionBytesKV =
|
||||
static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<DTypeKV> / 8);
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
LayoutQT layout_Q;
|
||||
LayoutT layout_KV;
|
||||
LayoutMDT layout_MD;
|
||||
DTypeQ const* Q_ptr;
|
||||
DTypeKV const* KV_ptr;
|
||||
DTypeMD const* m_ptr;
|
||||
DTypeMD const* d_ptr;
|
||||
IdType const* kv_block_tables;
|
||||
IdType const* seq_lens_this_time;
|
||||
IdType const* seq_lens_encoder;
|
||||
IdType const* seq_lens_decoder;
|
||||
IdType const* cumsum_q_seqlens;
|
||||
IdType const* batch_ids;
|
||||
IdType const* tile_ids_per_batch;
|
||||
IdType const* num_blocks_x;
|
||||
float sm_scale;
|
||||
int bsz;
|
||||
int max_block_num;
|
||||
int max_block_num_per_seq;
|
||||
int q_stride_bsz;
|
||||
int q_stride_head_num;
|
||||
int kv_stride_block_num;
|
||||
int kv_stride_block_size;
|
||||
int o_stride_bsz;
|
||||
int o_stride_head_num;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int max_draft_token_num;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
LayoutQT layout_Q;
|
||||
LayoutT layout_KV;
|
||||
LayoutMDT layout_MD;
|
||||
DTypeQ *Q_ptr;
|
||||
DTypeKV* KV_ptr;
|
||||
DTypeMD* m_ptr;
|
||||
DTypeMD* d_ptr;
|
||||
IdType* kv_block_tables;
|
||||
IdType* seq_lens_this_time;
|
||||
IdType* seq_lens_encoder;
|
||||
IdType* seq_lens_decoder;
|
||||
IdType* cumsum_q_seqlens;
|
||||
IdType* batch_ids;
|
||||
IdType* tile_ids_per_batch;
|
||||
IdType* num_blocks_x;
|
||||
float sm_scale;
|
||||
int bsz;
|
||||
int max_block_num;
|
||||
int max_block_num_per_seq;
|
||||
int q_stride_bsz;
|
||||
int q_stride_head_num;
|
||||
int kv_stride_block_num;
|
||||
int kv_stride_block_size;
|
||||
int o_stride_bsz;
|
||||
int o_stride_head_num;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int max_draft_token_num;
|
||||
TMA_KV tma_load_KV;
|
||||
};
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args) {
|
||||
TMA_KV tma_load_KV;
|
||||
if constexpr (USE_TMA_LOAD_KV) {
|
||||
Tensor mKV = make_tensor(make_gmem_ptr(args.KV_ptr), args.layout_KV);
|
||||
tma_load_KV =
|
||||
make_tma_copy(GmemTiledCopyKV{}, mKV, SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_QKD{}), _1{});
|
||||
}
|
||||
return {args.layout_Q,
|
||||
args.layout_KV,
|
||||
args.layout_MD,
|
||||
const_cast<DTypeQ*>(args.Q_ptr),
|
||||
const_cast<DTypeKV*>(args.KV_ptr),
|
||||
const_cast<DTypeMD*>(args.m_ptr),
|
||||
const_cast<DTypeMD*>(args.d_ptr),
|
||||
const_cast<IdType*>(args.kv_block_tables),
|
||||
const_cast<IdType*>(args.seq_lens_this_time),
|
||||
const_cast<IdType*>(args.seq_lens_encoder),
|
||||
const_cast<IdType*>(args.seq_lens_decoder),
|
||||
const_cast<IdType*>(args.cumsum_q_seqlens),
|
||||
const_cast<IdType*>(args.batch_ids),
|
||||
const_cast<IdType*>(args.tile_ids_per_batch),
|
||||
const_cast<IdType*>(args.num_blocks_x),
|
||||
args.sm_scale,
|
||||
args.bsz,
|
||||
args.max_block_num,
|
||||
args.max_block_num_per_seq,
|
||||
args.q_stride_bsz,
|
||||
args.q_stride_head_num,
|
||||
args.kv_stride_block_num,
|
||||
args.kv_stride_block_size,
|
||||
args.o_stride_bsz,
|
||||
args.o_stride_head_num,
|
||||
args.chunk_size,
|
||||
args.chunk_num,
|
||||
args.max_draft_token_num,
|
||||
tma_load_KV
|
||||
};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||
if constexpr (USE_TMA_LOAD_KV) {
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_KV.get_tma_descriptor());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SharedStorage>
|
||||
CUTLASS_DEVICE void load_q(Params const& mainloop_params,
|
||||
MainloopPipelineQ pipeline_q,
|
||||
PipelineStateQ& smem_pipe_write_q,
|
||||
SharedStorage& shared_storage,
|
||||
const int thread_idx,
|
||||
const int bid) {
|
||||
int start_q_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
int offset_Q = mainloop_params.q_stride_bsz * start_q_token_idx;
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(mainloop_params.Q_ptr + offset_Q), mainloop_params.layout_Q);
|
||||
Tensor gQ =
|
||||
local_tile(mQ, select<0, 2>(TileShape_QKD{}), make_coord(_, _0{}))(_, _, _0{});
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor cQ = cute::make_identity_tensor(gQ.shape());
|
||||
|
||||
GmemTiledCopyQ gmem_tiled_copy_q;
|
||||
auto gmem_thr_copy_q = gmem_tiled_copy_q.get_slice(thread_idx);
|
||||
Tensor tQgQ = gmem_thr_copy_q.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_q.partition_D(sQ);
|
||||
Tensor tQcQ = gmem_thr_copy_q.partition_D(cQ);
|
||||
Tensor tQcQGroup = flatten_1(tQcQ);
|
||||
|
||||
int valid_q_size = mainloop_params.seq_lens_this_time[bid];
|
||||
auto q_predicate_fn = [&](auto coords) {
|
||||
auto s_coords = tQcQGroup(_0{}, coords);
|
||||
return elem_less(get<0>(s_coords) / Ktraits::GROUP_SIZE, valid_q_size);
|
||||
};
|
||||
Tensor tQgQiGroup = flatten_1(tQgQ);
|
||||
Tensor tQsQiGroup = flatten_1(tQsQ);
|
||||
|
||||
pipeline_q.producer_acquire(smem_pipe_write_q);
|
||||
copy_if(gmem_tiled_copy_q, q_predicate_fn, tQgQiGroup, tQsQiGroup);
|
||||
pipeline_q.producer_commit(smem_pipe_write_q, cutlass::arch::cpasync_barrier_arrive);
|
||||
++smem_pipe_write_q;
|
||||
}
|
||||
|
||||
template <typename SharedStorage>
|
||||
CUTLASS_DEVICE void load_kv(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_kv,
|
||||
PipelineState& smem_pipe_write_kv,
|
||||
SharedStorage& shared_storage,
|
||||
const int bid,
|
||||
const int kv_len,
|
||||
const int tile_idx) {
|
||||
int thread_idx = threadIdx.x;
|
||||
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0);
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
|
||||
Tensor mKV = make_tensor(make_gmem_ptr(mainloop_params.KV_ptr), mainloop_params.layout_KV);
|
||||
Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
|
||||
GmemTiledCopy gmem_tiled_copy_kv;
|
||||
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
|
||||
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
|
||||
|
||||
Tensor tKgK = gmem_thr_copy_kv.partition_S(gKV);
|
||||
Tensor tKsK = gmem_thr_copy_kv.partition_S(sK);
|
||||
|
||||
for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
const int block_idx = kv_block_tables(bid, kv_tile_idx);
|
||||
pipeline_kv.producer_acquire(smem_pipe_write_kv);
|
||||
Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, block_idx));
|
||||
Tensor tKsKiGroup =
|
||||
flatten_1(tKsK(_, _, _, smem_pipe_write_kv.index()));
|
||||
copy(gmem_tiled_copy_kv, tKgKiGroup, tKsKiGroup);
|
||||
pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
|
||||
++smem_pipe_write_kv;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SharedStorage>
|
||||
CUTLASS_DEVICE void load_kv_tma(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_kv,
|
||||
PipelineState& smem_pipe_write_kv,
|
||||
SharedStorage& shared_storage,
|
||||
const int bid,
|
||||
const int kv_len,
|
||||
const int tile_idx) {
|
||||
int thread_idx = threadIdx.x;
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
|
||||
|
||||
Tensor mKV = mainloop_params.tma_load_KV.get_tma_tensor(mainloop_params.layout_KV.shape());
|
||||
|
||||
// Prepare the TMA loads
|
||||
Tensor gKV = local_tile(mKV, make_shape(get<1>(TileShape_QKD{}), get<2>(TileShape_QKD{})), make_coord(_, _))(_, _, _0{}, _0{}, _);
|
||||
auto [tKgK, tKsK] =
|
||||
tma_partition(mainloop_params.tma_load_KV, _0{}, Layout<_1>{},
|
||||
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));
|
||||
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
|
||||
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
|
||||
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
#pragma unroll 2
|
||||
for (int kv_tile_idx = end_tile_idx; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
const int block_idx = kv_block_tables(bid, kv_tile_idx);
|
||||
pipeline_kv.producer_acquire(smem_pipe_write_kv);
|
||||
copy(mainloop_params.tma_load_KV.with(*pipeline_kv.producer_get_barrier(smem_pipe_write_kv), /*mcast_mask=*/0),
|
||||
tKgK(_, block_idx), tKsK(_, smem_pipe_write_kv.index()));
|
||||
++smem_pipe_write_kv;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_
|
500
custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh
Normal file
500
custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh
Normal file
@@ -0,0 +1,500 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
|
||||
#define ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include "named_barrier.cuh"
|
||||
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
|
||||
typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
|
||||
CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
|
||||
MainloopPipelineQ pipeline_q,
|
||||
PipelineStateQ& smem_pipe_read_q,
|
||||
MainloopPipeline pipeline_kv,
|
||||
PipelineState& smem_pipe_read_kv,
|
||||
FrgTensorO& tOrO,
|
||||
AttentionUpdater& attention_updater,
|
||||
const int thread_idx,
|
||||
const int bid,
|
||||
const int kv_len,
|
||||
const int qo_len,
|
||||
const int tile_idx,
|
||||
SharedStorage& shared_storage) {
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeMD = typename Ktraits::DTypeO;
|
||||
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
|
||||
using IdType = typename Ktraits::IdType;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
|
||||
using SmemLayoutK = typename Ktraits::SmemLayoutK;
|
||||
using SmemLayoutV = typename Ktraits::SmemLayoutV;
|
||||
using SmemLayoutP = typename Ktraits::SmemLayoutP;
|
||||
using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
|
||||
using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
|
||||
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
|
||||
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
|
||||
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
|
||||
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
|
||||
Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
|
||||
Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
|
||||
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
|
||||
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _); // (bsz * draft_token_num * num_head)
|
||||
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
|
||||
|
||||
typename Ktraits::TiledMmaQK tiled_mma_qk;
|
||||
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
|
||||
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
|
||||
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
|
||||
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
|
||||
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup);
|
||||
|
||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
|
||||
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
|
||||
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
|
||||
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
|
||||
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
|
||||
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
int kv_tile_idx = end_tile_idx;
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
};
|
||||
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
if (warp_group_idx == 1) {
|
||||
// consumer 0, compute qk
|
||||
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
|
||||
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
|
||||
|
||||
constexpr int n_masking_steps = !CAUSAL ? 1 : cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) + 1;
|
||||
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
|
||||
bool is_first_step = true;
|
||||
// wait q
|
||||
consumer_wait(pipeline_q, smem_pipe_read_q);
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
|
||||
#pragma unroll 1
|
||||
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
|
||||
// wait kv
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
// gemm qk
|
||||
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
|
||||
tSrS);
|
||||
// mask
|
||||
if (masking_step > 0) {
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
|
||||
Tensor tScS = threadMmaQK.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
|
||||
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
|
||||
if constexpr (!CAUSAL) { // Just masking based on col
|
||||
if (kv_idx >= kv_len) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
} else {
|
||||
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update s (exp(s - m))
|
||||
Tensor scale_o = is_first_step ? attention_updater.update</*init=*/true>(tSrS) : attention_updater.update</*init=*/false>(tSrS);
|
||||
is_first_step = false;
|
||||
|
||||
Tensor convert_tSrS = convert_type<DTypeKV>(tSrS);
|
||||
Tensor tPrP = smem_thr_copy_P.retile_S(convert_tSrS);
|
||||
|
||||
// gather qk gemm res
|
||||
cute::copy(smem_tiled_copy_P, tPrP, tPsP);
|
||||
cute::copy(scale_o, tScalesScale);
|
||||
// r2s fence wgmma
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
// make sure r2s all done
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
|
||||
// pv gemm
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
// sync WG1 WG2
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
|
||||
}
|
||||
// release q
|
||||
pipeline_q.consumer_release(smem_pipe_read_q);
|
||||
++smem_pipe_read_q;
|
||||
|
||||
// normalize
|
||||
Tensor scale_o = attention_updater.finalize(tSrS); // warp reduce row sum
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cute::copy(scale_o, tScalesScale);
|
||||
|
||||
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
}
|
||||
|
||||
// WG1 write m,d back to gmem
|
||||
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // 16 rows per warp, eg. t0->row0 row8,t4->row1 row9
|
||||
const int warp_idx = thread_idx / 32;
|
||||
#pragma unroll
|
||||
for (int w_i = 0; w_i < 2; ++w_i) {
|
||||
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
|
||||
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
|
||||
|
||||
if (token_idx < qo_len) {
|
||||
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
|
||||
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
|
||||
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
|
||||
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
|
||||
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (warp_group_idx == 2) {
|
||||
// consumer 1, compute pv
|
||||
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
|
||||
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
// wait kv
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
|
||||
// A: tPsP
|
||||
cute::copy(tScalesScale, scale_o);
|
||||
|
||||
// rescale
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2,
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
// sync WG1 WG2
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2Sync));
|
||||
}
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG2));
|
||||
cute::copy(tScalesScale, scale_o);
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename Ktraits, bool CAUSAL, typename Params, typename MainloopPipeline, typename MainloopPipelineQ,
|
||||
typename PipelineState, typename PipelineStateQ, typename SharedStorage, typename FrgTensorO, typename AttentionUpdater>
|
||||
CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
|
||||
MainloopPipelineQ pipeline_q,
|
||||
PipelineStateQ& smem_pipe_read_q,
|
||||
MainloopPipeline pipeline_kv,
|
||||
PipelineState& smem_pipe_read_kv,
|
||||
FrgTensorO& tOrO,
|
||||
AttentionUpdater& attention_updater,
|
||||
const int thread_idx,
|
||||
const int bid,
|
||||
const int kv_len,
|
||||
const int qo_len,
|
||||
const int tile_idx,
|
||||
SharedStorage& shared_storage) {
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeMD = typename Ktraits::DTypeO; // !!! bf16
|
||||
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
|
||||
using IdType = typename Ktraits::IdType;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
|
||||
using SmemLayoutK = typename Ktraits::SmemLayoutK;
|
||||
using SmemLayoutV = typename Ktraits::SmemLayoutV;
|
||||
using SmemLayoutP = typename Ktraits::SmemLayoutP;
|
||||
using SmemLayoutRow = typename Ktraits::SmemLayoutRow;
|
||||
using SmemCopyAtom = typename Ktraits::SmemCopyAtom;
|
||||
using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
|
||||
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
|
||||
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
|
||||
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutK{});
|
||||
Tensor sVt_s1 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data()), SmemLayoutVtOneStage{});
|
||||
Tensor sVt_s2 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sVt_s3 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 2 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sVt_s4 = make_tensor(make_smem_ptr(shared_storage.smem_kv.data() + 3 * Ktraits::NUM_PER_STAGE), SmemLayoutVtOneStage{});
|
||||
Tensor sPSS = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), SmemLayoutP{});
|
||||
Tensor mM = make_tensor(make_gmem_ptr(mainloop_params.m_ptr), mainloop_params.layout_MD)(tile_idx, _);
|
||||
Tensor mD = make_tensor(make_gmem_ptr(mainloop_params.d_ptr), mainloop_params.layout_MD)(tile_idx, _);
|
||||
|
||||
Tensor s_scale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutRow{});
|
||||
|
||||
typename Ktraits::TiledMmaQK tiled_mma_qk;
|
||||
auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
|
||||
auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtom{}, tiled_mma_qk);
|
||||
auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);
|
||||
Tensor tPsP = smem_thr_copy_P.partition_D(sPSS);
|
||||
Tensor tScalesScale = s_scale(_, thread_idx % cutlass::NumThreadsPerWarpGroup, _);
|
||||
|
||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv_ss;
|
||||
auto threadMmaPVSS = tiled_mma_pv_ss.get_thread_slice(thread_idx);
|
||||
Tensor tOrV1 = threadMmaPVSS.partition_fragment_B(sVt_s1);
|
||||
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
|
||||
Tensor tOrV3 = threadMmaPVSS.partition_fragment_B(sVt_s3);
|
||||
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
|
||||
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
|
||||
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
int kv_tile_idx = end_tile_idx;
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
};
|
||||
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
if (warp_group_idx == 1) {
|
||||
// consumer 0, compute qk
|
||||
Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
|
||||
Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
|
||||
auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
|
||||
// wait q
|
||||
consumer_wait(pipeline_q, smem_pipe_read_q);
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
|
||||
// wait k
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
// first qk gemm
|
||||
gemm</*init=*/true, /*wg_wait=*/0>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
|
||||
tSrS);
|
||||
// mask
|
||||
{
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
|
||||
Tensor tScS = threadMmaQK.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
|
||||
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
|
||||
if constexpr (!CAUSAL) { // Just masking based on col
|
||||
if (kv_idx >= kv_len) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
} else {
|
||||
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Tensor scale_o = attention_updater.update</*init=*/true>(tSrS);
|
||||
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
|
||||
// gather qk gemm res
|
||||
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
|
||||
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
|
||||
// r2s fence wgmma
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
|
||||
constexpr int n_masking_steps = CAUSAL ? cute::ceil_div(BLOCK_SHAPE_Q, BLOCK_SHAPE_KV) : 0;
|
||||
--kv_tile_idx;
|
||||
for (int masking_step = n_masking_steps; kv_tile_idx >= start_tile_idx; --masking_step, --kv_tile_idx) {
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
|
||||
PipelineState smem_pipe_read_kv_cur = smem_pipe_read_kv;
|
||||
++smem_pipe_read_kv;
|
||||
// wait next kv
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
|
||||
// gemm next qk
|
||||
gemm</*init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read_kv.index()),
|
||||
tSrS);
|
||||
attention_updater.rescale_o(tOrO);
|
||||
// last pv gemm
|
||||
if (smem_pipe_read_kv_cur.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv_cur.index() == 1) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv_cur.index() == 2) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV3(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv_cur.index() % 2),
|
||||
tOrV4(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
// wait cur qk gemm
|
||||
warpgroup_wait<1>();
|
||||
// mask p
|
||||
if (masking_step > 0) {
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
|
||||
Tensor tScS = threadMmaQK.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
int qo_idx = get<0>(tScS(i)) / Ktraits::GROUP_SIZE;
|
||||
int kv_idx = get<1>(tScS(i)) + kv_tile_idx * BLOCK_SHAPE_KV;
|
||||
if constexpr (!CAUSAL) { // Just masking based on col
|
||||
if (kv_idx >= kv_len) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
} else {
|
||||
if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
|
||||
tSrS(i) = AttentionUpdater::fill_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// update s (exp(s - m))
|
||||
Tensor scale_o = attention_updater.update</*init=*/false>(tSrS);
|
||||
Tensor tPrP = smem_thr_copy_P.retile_S(convert_type<DTypeKV>(tSrS));
|
||||
|
||||
// gather qk gemm res
|
||||
cute::copy(smem_tiled_copy_P, tPrP, tPsP(_, _, _, smem_pipe_read_kv.index() % 2));
|
||||
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
|
||||
// r2s fence wgmma
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
// make sure tSrS r2s done
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
// wait last pv gemm
|
||||
warpgroup_wait<0>();
|
||||
// release last kv
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv_cur);
|
||||
}
|
||||
// release q
|
||||
pipeline_q.consumer_release(smem_pipe_read_q);
|
||||
++smem_pipe_read_q;
|
||||
// compute last pv
|
||||
attention_updater.rescale_o(tOrO);
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 1) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 2) {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV3(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV4(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
scale_o = attention_updater.finalize(tSrS);
|
||||
warpgroup_wait<0>();
|
||||
// release last kv
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cute::copy(scale_o, tScalesScale(_, smem_pipe_read_kv.index() % 2));
|
||||
|
||||
cutlass::arch::NamedBarrier::arrive(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
|
||||
attention_updater.rescale_o(tOrO);
|
||||
}
|
||||
// WG1 write m,d back to gmem
|
||||
if (chunk_num_this_seq > 1 && thread_idx % 4 == 0) { // 16 rows per warp, eg. t0->row0 row8,t4->row1 row9
|
||||
const int warp_idx = thread_idx / 32;
|
||||
#pragma unroll
|
||||
for (int w_i = 0; w_i < 2; ++w_i) {
|
||||
const int token_group_idx = warp_idx * 16 + (thread_idx % 32) / 4 + 8 * w_i;
|
||||
const int token_idx = token_group_idx / Ktraits::GROUP_SIZE;
|
||||
|
||||
if (token_idx < qo_len) {
|
||||
const int head_idx = token_group_idx % Ktraits::GROUP_SIZE;
|
||||
const int bid_offset = mainloop_params.max_draft_token_num * Ktraits::GROUP_SIZE;
|
||||
const int write_idx = bid * bid_offset + token_idx * Ktraits::GROUP_SIZE + head_idx;
|
||||
mM(write_idx) = static_cast<DTypeMD>(attention_updater.row_max(w_i));
|
||||
mD(write_idx) = static_cast<DTypeMD>(attention_updater.row_sum(w_i));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (warp_group_idx == 2) {
|
||||
// consumer 1, compute pv
|
||||
Tensor scale_o = make_tensor<DTypeQKAccum>(Shape<_2>{});
|
||||
for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
|
||||
consumer_wait(pipeline_kv, smem_pipe_read_kv);
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWarpSchedulerWG1));
|
||||
// A: tPsP
|
||||
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
|
||||
// rescale
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
if (smem_pipe_read_kv.index() == 0) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV1(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 1) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV2(_, _, _, _0{}), tOrO);
|
||||
} else if (smem_pipe_read_kv.index() == 2) {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV3(_, _, _, _0{}), tOrO);
|
||||
} else {
|
||||
gemm</*init=*/false, /*wg_wait=*/0>(tiled_mma_pv_ss, tOrP_CS2(_, _, _, smem_pipe_read_kv.index() % 2),
|
||||
tOrV4(_, _, _, _0{}), tOrO);
|
||||
}
|
||||
pipeline_kv.consumer_release(smem_pipe_read_kv);
|
||||
++smem_pipe_read_kv;
|
||||
}
|
||||
if (chunk_num_this_seq == 1) {
|
||||
// norm
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_MMA_THREADS, static_cast<int>(NamedBarriers::kWG1WG2LastSync));
|
||||
cute::copy(tScalesScale(_, smem_pipe_read_kv.index() % 2), scale_o);
|
||||
attention_updater.rescale_o(tOrO, scale_o);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
|
575
custom_ops/gpu_ops/mla_attn/mla_hopper.cuh
Normal file
575
custom_ops/gpu_ops/mla_attn/mla_hopper.cuh
Normal file
@@ -0,0 +1,575 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#ifndef ATTENTION_HOPPER_PREFILL_SM90_CUH_
|
||||
#define ATTENTION_HOPPER_PREFILL_SM90_CUH_
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_device_runtime_api.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "attention_updater.cuh"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "epilogue.cuh"
|
||||
#include "helper.h"
|
||||
#include "kernel_traits.cuh"
|
||||
#include "mainloop_mma.cuh"
|
||||
#include "mainloop_load.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#ifdef DEBUG_MLA
|
||||
#undef DEBUG_MLA
|
||||
#endif
|
||||
// #define DEBUG_MLA
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
|
||||
struct Params {
|
||||
using DTypeQ = DTypeQ_;
|
||||
using DTypeKV = DTypeKV_;
|
||||
using DTypeO = DTypeO_;
|
||||
using IdType = IdType_;
|
||||
|
||||
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
|
||||
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
|
||||
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
|
||||
|
||||
alignas(16) IdType *block_tables;
|
||||
alignas(16) IdType *seq_lens_this_time;
|
||||
alignas(16) IdType *seq_lens_encoder;
|
||||
alignas(16) IdType *seq_lens_decoder;
|
||||
alignas(16) IdType *cumsum_q_seqlens;
|
||||
alignas(16) IdType *padding_offsets;
|
||||
|
||||
alignas(16) IdType *batch_ids;
|
||||
alignas(16) IdType *tile_ids_per_batch;
|
||||
alignas(16) IdType *num_blocks_x;
|
||||
|
||||
|
||||
uint32_t q_stride_bsz;
|
||||
uint32_t q_stride_head_num;
|
||||
|
||||
uint32_t kv_stride_block_num;
|
||||
uint32_t kv_stride_block_size;
|
||||
|
||||
uint32_t o_stride_bsz;
|
||||
uint32_t o_stride_head_num;
|
||||
|
||||
int bsz;
|
||||
int token_num;
|
||||
int max_seq_len;
|
||||
int max_block_num;
|
||||
int max_block_num_per_seq;
|
||||
int q_num_head;
|
||||
int qk_head_dim;
|
||||
int vo_head_dim;
|
||||
int block_size;
|
||||
int max_draft_token_num;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int num_blocks_x_int;
|
||||
|
||||
float sm_scale;
|
||||
};
|
||||
|
||||
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
constexpr size_t GROUP_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 64) { \
|
||||
constexpr size_t GROUP_SIZE = 64; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the group_size: ", group_size); \
|
||||
return cudaErrorNotSupported; \
|
||||
}
|
||||
|
||||
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
|
||||
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
typename CollectiveMainloop::Params const mainloop_params,
|
||||
CUTE_GRID_CONSTANT
|
||||
typename CollectiveEpilogue::Params const epilogue_params) {
|
||||
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeO = typename Ktraits::DTypeO;
|
||||
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
|
||||
using TileShape_QKD = typename Ktraits::TileShape_QKD;
|
||||
using TileShape_PDV = typename Ktraits::TileShape_PDV;
|
||||
|
||||
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
|
||||
static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS;
|
||||
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
|
||||
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
|
||||
const int num_blocks_x = mainloop_params.num_blocks_x[0];
|
||||
|
||||
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
|
||||
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
using PipelineState = typename MainloopPipeline::PipelineState;
|
||||
|
||||
using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
|
||||
using PipelineParamsQ = typename MainloopPipelineQ::Params;
|
||||
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
|
||||
|
||||
extern __shared__ char shared_memory[];
|
||||
auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
|
||||
|
||||
int const lane_predicate = cute::elect_one_sync();
|
||||
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
|
||||
if (warp_idx == 0 && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
|
||||
}
|
||||
|
||||
// Obtain warp index
|
||||
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParams pipeline_params;
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
|
||||
: MainloopPipeline::ThreadCategory::Consumer;
|
||||
if constexpr (use_tma_load_kv) {
|
||||
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
pipeline_params.num_consumers = NUM_MMA_THREADS;
|
||||
} else {
|
||||
pipeline_params.producer_arv_count = NUM_COPY_THREADS;
|
||||
pipeline_params.consumer_arv_count = NUM_MMA_THREADS;
|
||||
}
|
||||
|
||||
PipelineParamsQ pipeline_params_q;
|
||||
pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
|
||||
: MainloopPipelineQ::ThreadCategory::Consumer;
|
||||
pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
|
||||
pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // just one wg qk
|
||||
|
||||
|
||||
MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
|
||||
MainloopPipeline pipeline_kv = [&] {
|
||||
if constexpr (use_tma_load_kv) {
|
||||
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
|
||||
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
|
||||
/*cluster_shape=*/Shape<_1, _1, _1>{});
|
||||
} else {
|
||||
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
|
||||
}
|
||||
}();
|
||||
__syncthreads();
|
||||
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue;
|
||||
|
||||
if (warp_group_idx == 0) {
|
||||
// producer
|
||||
if constexpr(USE_REG_EALLOC) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<72>();
|
||||
}
|
||||
const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);
|
||||
|
||||
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
|
||||
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
// load Q
|
||||
collective_mainloop.load_q(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
|
||||
if constexpr (!use_tma_load_kv) {
|
||||
// load kv
|
||||
collective_mainloop.load_kv(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
} else {
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
// load kv tma
|
||||
collective_mainloop.load_kv_tma(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
// load Q
|
||||
collective_mainloop.load_q(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
|
||||
if constexpr (!use_tma_load_kv) {
|
||||
// load kv
|
||||
collective_mainloop.load_kv(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
} else {
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
// load kv tma
|
||||
collective_mainloop.load_kv_tma(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// consumer
|
||||
if constexpr(USE_REG_EALLOC) {
|
||||
cutlass::arch::warpgroup_reg_alloc<216>();
|
||||
}
|
||||
PipelineStateQ smem_pipe_read_q;
|
||||
PipelineState smem_pipe_read_kv;
|
||||
|
||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv;
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
|
||||
|
||||
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||
mma_f16<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
} else if (BLOCK_SHAPE_KV == 32) {
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
}
|
||||
|
||||
collective_epilogue.store(
|
||||
epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||
mma_f16<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
} else if (BLOCK_SHAPE_KV == 32) {
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
}
|
||||
|
||||
collective_epilogue.store(
|
||||
epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
cudaStream_t stream) {
|
||||
using DTypeQ = typename KernelTraits::DTypeQ;
|
||||
using DTypeKV = typename KernelTraits::DTypeKV;
|
||||
using DTypeO = typename KernelTraits::DTypeO;
|
||||
using IdType = typename KernelTraits::IdType;
|
||||
using NV_TYPE = typename KernelTraits::NV_TYPE;
|
||||
|
||||
using CollectiveMainloop =
|
||||
CollectiveMainloop<KernelTraits, CAUSAL>;
|
||||
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
|
||||
|
||||
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
|
||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})), // layout q
|
||||
make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),
|
||||
make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),
|
||||
params.Q,
|
||||
params.KV,
|
||||
params.m,
|
||||
params.d,
|
||||
params.block_tables,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_encoder,
|
||||
params.seq_lens_decoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_ids,
|
||||
params.tile_ids_per_batch,
|
||||
params.num_blocks_x,
|
||||
params.sm_scale,
|
||||
params.bsz,
|
||||
params.max_block_num,
|
||||
params.max_block_num_per_seq,
|
||||
params.q_stride_bsz,
|
||||
params.q_stride_head_num,
|
||||
params.kv_stride_block_num,
|
||||
params.kv_stride_block_size,
|
||||
params.o_stride_bsz,
|
||||
params.o_stride_head_num,
|
||||
params.chunk_size,
|
||||
params.chunk_num,
|
||||
params.max_draft_token_num
|
||||
});
|
||||
typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
|
||||
params.O,
|
||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})), // layout O
|
||||
params.O_tmp,
|
||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})) // layout O_tmp
|
||||
});
|
||||
|
||||
// Get the ptr to kernel function.
|
||||
auto kernel =
|
||||
MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL, 132>;
|
||||
int smem_size = sizeof(typename KernelTraits::SharedStorage);
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int multiprocessor_count;
|
||||
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
|
||||
int act_blocks_per_sm;
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
|
||||
|
||||
int gridx;
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
gridx = multiprocessor_count;
|
||||
} else {
|
||||
gridx = params.num_blocks_x_int;
|
||||
}
|
||||
dim3 grid_dims = {gridx, 1, 1};
|
||||
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
|
||||
dim3 block_dims(ctaSize, 1, 1);
|
||||
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
|
||||
mainloop_params, epilogue_params
|
||||
);
|
||||
if (params.chunk_num > 1) {
|
||||
constexpr int vec_size = 16 / sizeof(DTypeO);
|
||||
constexpr int merge_block_size = 256;
|
||||
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
|
||||
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
|
||||
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large
|
||||
dim3 blocks_merge(blockx, blocky);
|
||||
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(params.O_tmp),
|
||||
params.m,
|
||||
params.d,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_decoder,
|
||||
params.seq_lens_encoder,
|
||||
params.padding_offsets,
|
||||
reinterpret_cast<NV_TYPE*>(params.O),
|
||||
params.max_seq_len,
|
||||
params.chunk_num,
|
||||
params.q_num_head,
|
||||
params.chunk_size,
|
||||
params.vo_head_dim,
|
||||
params.token_num,
|
||||
params.bsz,
|
||||
params.max_draft_token_num
|
||||
);
|
||||
}
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
|
||||
constexpr bool CAUSAL = true;
|
||||
if constexpr (HEAD_DIM_QK == 576) {
|
||||
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
|
||||
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
|
||||
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
|
||||
HEAD_DIM_QK,
|
||||
HEAD_DIM_VO,
|
||||
GROUP_SIZE,
|
||||
/*BLOCK_SHAPE_Q_=*/64,
|
||||
/*BLOCK_SHAPE_KV_=*/64,
|
||||
/*NUM_STAGES_=*/2,
|
||||
typename Params::DTypeQ,
|
||||
typename Params::DTypeKV,
|
||||
typename Params::DTypeO,
|
||||
typename Params::IdType,
|
||||
NV_TYPE>,
|
||||
CAUSAL,
|
||||
Params,
|
||||
USE_REG_EALLOC,
|
||||
USE_FIXED_BLOCK>(params, stream);)
|
||||
} else {
|
||||
return cudaErrorNotSupported;
|
||||
}
|
||||
return cudaSuccess;
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_PREFILL_SM90_CUH_
|
47
custom_ops/gpu_ops/mla_attn/named_barrier.cuh
Normal file
47
custom_ops/gpu_ops/mla_attn/named_barrier.cuh
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
|
||||
#ifndef ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
|
||||
#define ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "cutlass/arch/barrier.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
enum class NamedBarriers {
|
||||
kQueryEmpty = 0,
|
||||
kValueEmpty = 1,
|
||||
kWarpSchedulerWG1 = 2,
|
||||
kWarpSchedulerWG2 = 3,
|
||||
kWarpSchedulerWG3 = 4,
|
||||
kPrefetchIndices = 5,
|
||||
kOdone = 6,
|
||||
kWG1WG2Sync = 7,
|
||||
kWG0WG1WG2Sync = 8,
|
||||
kWG1WG2LastSync = 9,
|
||||
};
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_NAMED_BARRIERS_CUH_
|
351
custom_ops/gpu_ops/mla_attn/utils.cuh
Normal file
351
custom_ops/gpu_ops/mla_attn/utils.cuh
Normal file
@@ -0,0 +1,351 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef ATTENTION_HOPPER_UTILS_CUH_
|
||||
#define ATTENTION_HOPPER_UTILS_CUH_
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/device/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include <assert.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/tensor.hpp>
|
||||
#include "cutlass/fast_math.h"
|
||||
|
||||
namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename TensorT>
|
||||
CUTLASS_HOST_DEVICE auto flatten_1(TensorT tensor) {
|
||||
Tensor tensor_flatten = cute::flatten(tensor);
|
||||
return cute::group_modes<1, rank(tensor_flatten)>(tensor_flatten);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE auto get_gmem_layout(int nnz, int num_heads, int head_dim, int64_t n_stride,
|
||||
int64_t h_stride) {
|
||||
return make_layout(make_shape(nnz, head_dim, num_heads),
|
||||
make_stride(n_stride, cute::_1{}, h_stride));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(int nnz, int num_heads) {
|
||||
return make_layout(make_shape(num_heads, nnz), make_stride(cute::_1{}, int64_t(num_heads)));
|
||||
}
|
||||
|
||||
template <typename MTensor, typename Shape>
|
||||
CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
|
||||
int head_idx, int offset, int seq_len) {
|
||||
auto g_offset = local_tile(m_tensor(_, _, head_idx), cute::make_shape(1, get<1>(tile_shape)),
|
||||
make_coord(offset, _0{}));
|
||||
auto g_sequence =
|
||||
make_tensor(g_offset.data(),
|
||||
make_layout(cute::make_shape(seq_len, get<1>(tile_shape)), g_offset.stride()));
|
||||
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
|
||||
return g_tensor;
|
||||
}
|
||||
|
||||
template <typename MTensor, typename Shape>
|
||||
CUTLASS_DEVICE auto get_lse_local_tile_tensor(const MTensor& m_tensor, const Shape& tile_shape,
|
||||
int head_idx, int offset, int seq_len) {
|
||||
auto g_offset = local_tile(m_tensor(head_idx, _), cute::make_shape(_1{}), make_coord(offset));
|
||||
|
||||
auto g_sequence = make_tensor(g_offset.data(), make_layout(cute::make_shape(seq_len),
|
||||
cute::make_shape(shape<0>(m_tensor))));
|
||||
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
|
||||
return g_tensor;
|
||||
}
|
||||
|
||||
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V,
|
||||
// MMA_N))
|
||||
template <typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = acc_layout;
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
|
||||
make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
|
||||
};
|
||||
|
||||
// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16,
|
||||
// MMA_N))
|
||||
template <typename MMA_traits, typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
|
||||
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
|
||||
auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16)))
|
||||
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout),
|
||||
make_layout(get<2, 1>(l), get<2>(acc_layout)));
|
||||
};
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const& tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
|
||||
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||
}
|
||||
|
||||
template <bool init = false, int wg_wait = 0, typename TensorA, typename TensorB, typename TensorC,
|
||||
typename TiledMma>
|
||||
__forceinline__ __device__ void gemm(TiledMma& tiled_mma, TensorA const& tCrA, TensorB const& tCrB,
|
||||
TensorC& tCrC) {
|
||||
constexpr bool Is_RS =
|
||||
!cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
|
||||
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
|
||||
if constexpr (Is_RS) {
|
||||
warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
|
||||
}
|
||||
warpgroup_fence_operand(tCrC);
|
||||
warpgroup_arrive();
|
||||
if constexpr (init) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
} else {
|
||||
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
if constexpr (wg_wait >= 0) {
|
||||
warpgroup_wait<wg_wait>();
|
||||
}
|
||||
warpgroup_fence_operand(tCrC);
|
||||
if constexpr (Is_RS) {
|
||||
warpgroup_fence_operand(const_cast<TensorA&>(tCrA));
|
||||
}
|
||||
}
|
||||
|
||||
#define HOSTDEVICE __host__ __device__
|
||||
|
||||
template <typename T, int Size>
|
||||
struct alignas(sizeof(T) * Size) AlignedVector {
|
||||
T val[Size];
|
||||
|
||||
HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
|
||||
HOSTDEVICE inline T& operator[](int i) { return val[i]; }
|
||||
};
|
||||
|
||||
template <typename T, int Size>
|
||||
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
|
||||
const AlignedVector<T, Size>* addr_vec =
|
||||
reinterpret_cast<const AlignedVector<T, Size>*>(addr);
|
||||
*vec = *addr_vec;
|
||||
}
|
||||
|
||||
template <typename T, int Size>
|
||||
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
|
||||
AlignedVector<T, Size>* addr_vec =
|
||||
reinterpret_cast<AlignedVector<T, Size>*>(addr);
|
||||
*addr_vec = vec;
|
||||
}
|
||||
|
||||
template <size_t vec_size, typename T>
|
||||
struct prefill_softmax_state_t {
|
||||
AlignedVector<T, vec_size> o;
|
||||
float m;
|
||||
float d;
|
||||
|
||||
__device__ __forceinline__ void init() {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&o) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = -5e4f;
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
m = -3.38953e38f;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
|
||||
const float other_m,
|
||||
const float other_d) {
|
||||
float m_prev = m, d_prev = d;
|
||||
m = max(m_prev, other_m);
|
||||
const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m);
|
||||
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
|
||||
d = d_prev * scale1 + other_d * scale2;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] = o[i] * scale1_T + other_o[i] * scale2_T;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void normalize() {
|
||||
const T d_t = static_cast<T>(d);
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
o[i] /= d_t;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
|
||||
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
|
||||
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
|
||||
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
|
||||
const int * __restrict__ seq_lens_this_time,
|
||||
const int * __restrict__ seq_lens_decoder,
|
||||
const int * __restrict__ seq_lens_encoder,
|
||||
const int * __restrict__ padding_offsets,
|
||||
T * __restrict__ out, // [token_num, num_heads, head_dim]
|
||||
const int max_seq_len,
|
||||
const int num_chunks,
|
||||
const int num_heads,
|
||||
const int chunk_size,
|
||||
const int head_dim,
|
||||
const int token_num,
|
||||
const int bsz,
|
||||
const int max_draft_token_num) {
|
||||
const int vid = threadIdx.x, ty = threadIdx.y;
|
||||
const int hid = blockIdx.y;
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
|
||||
const uint32_t ori_token_id = qid + padding_offsets[qid];
|
||||
const uint32_t bid = ori_token_id / max_seq_len;
|
||||
const int seq_len_q = seq_lens_this_time[bid];
|
||||
if (seq_len_q == 0) continue;
|
||||
const uint32_t local_seq_id = ori_token_id % max_seq_len;
|
||||
int seq_len_kv = seq_lens_decoder[bid];
|
||||
if (seq_len_kv == 0) continue;
|
||||
seq_len_kv += seq_len_q;
|
||||
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
|
||||
if (num_chunks_this_seq <= 1) {
|
||||
// not need merge
|
||||
continue;
|
||||
}
|
||||
|
||||
using LoadT = AlignedVector<T, vec_size>;
|
||||
LoadT load_vec;
|
||||
LoadT res_vec;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((half2*)(&res_vec) + i) = make_half2(0, 0);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vec_size / 2; ++i) {
|
||||
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
|
||||
}
|
||||
}
|
||||
float m;
|
||||
float d = 1.f;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
m = -5e4f;
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
m = -3.0e+30f;
|
||||
}
|
||||
|
||||
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
||||
uint32_t offset;
|
||||
offset = ((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid;
|
||||
float m_prev = m;
|
||||
float d_prev = d;
|
||||
const float m_now = multi_m[offset];
|
||||
const float d_now = multi_d[offset];
|
||||
m = max(m_prev, m_now);
|
||||
offset = (((i * bsz + bid) * max_draft_token_num + local_seq_id) * num_heads + hid) * head_dim + vid * vec_size;
|
||||
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
||||
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
|
||||
const T scale1_T = static_cast<T>(scale1), scale2_T = static_cast<T>(scale2);
|
||||
d = d * scale1 + d_now * scale2;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vec_size; j++) {
|
||||
res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T;
|
||||
}
|
||||
}
|
||||
// store ty res
|
||||
Store<T, vec_size>(res_vec, &smem[ty * head_dim + vid * vec_size]);
|
||||
md_smem[2 * ty] = m;
|
||||
md_smem[2 * ty + 1] = d;
|
||||
__syncthreads();
|
||||
if (ty == 0) {
|
||||
// merge bdy
|
||||
prefill_softmax_state_t<vec_size, T> st;
|
||||
st.init();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < bdy; i++) {
|
||||
Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
|
||||
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||
st.merge(load_vec, m_tmp, d_tmp);
|
||||
}
|
||||
st.normalize();
|
||||
Store<T, vec_size>(st.o, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mla_attn
|
||||
|
||||
#endif // ATTENTION_HOPPER_UTILS_CUH_
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user