Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-30 14:22:27 +08:00)

Compare commits (109 commits)
Commits in this comparison (abbreviated SHA1s, in the order listed):

1d46420c49, fb0f284e67, 5d1788c7b5, abd238fc12, e5804b1d98, 8c43bc8176, b0f1e0eef4, 69be77c8c0,
535a15ab8f, 580460046f, 4dbc483713, 4ead15822c, f941124402, b89f083004, 4d05ed596c, bc1866af58,
fe237fe92b, 3a480abcbb, 335609efb6, 3464f75f98, 09d0073fdc, 63d6e7ce06, aa76085d1f, 42b80182e0,
dda4a9f848, a83a3eea5f, 0d0340392f, 0253381fb9, 2d1184aefe, 17314ee126, 101ad33332, 0fad10b35a,
61b3997b85, e7bcbbab52, 5fc659b900, 33db137d0b, 9d6a42b334, 1b712bba82, fd91da7b41, 15c8c240b5,
7cdd8d290d, 4c7b8bc458, 2e81792d64, b7858c22d9, 09bbac6de0, 7f64d408a9, ece88596ed, bad53c6b6e,
16940822a7, d48c03413f, e9e8443ea8, 749b2e9c89, f6ad26fc08, c08561c13a, 2c3607407f, b5e4288704,
abbbd0cddc, e98937cbba, 240d6236bc, 59071268b6, 8c660a0dfb, ce5adec877, 36571fd2d9, 830de5a925,
d33105baeb, 24f934f1f9, 1e2319cbef, e45050cae3, b0f525955c, 2ea267f624, 1d8af7ab73, 54affdc44b,
a4fdb3970b, 2a86928657, b1c53fa779, da20cf681e, 4ccd1696ab, 888780ffde, e3768c5a83, 1f28bdf994,
03a74995b8, b89180f1cd, be21ef5047, 771e71a24d, 0350831c2b, fee544e808, c4718fd693, f7cad30a38,
6b10c19482, f4f1d8de44, 6610aa29d0, f72c4de539, f6ffbc3cbd, e8bbe7244b, 57b086dc6b, 525be243e7,
d0f4d6ba3a, 26d5d737dd, fefbd65cf8, 1eb8ea7328, ef6649a577, 1b54a2831e, 2579e8fea8, 91528f1af9,
4e293e50fa, 66b321d9ec, 68b4755587, 04a8e1ef2b, a6e9161045
.github/workflows/ci.yml (vendored, 8 changed lines)
@@ -2,7 +2,9 @@ name: CI
 
 on:
   pull_request:
-    branches: [ develop ]
+    branches:
+      - develop
+      - 'release/*'
   workflow_dispatch:
 
 concurrency:
@@ -27,9 +29,11 @@ jobs:
           REPO="https://github.com/${{ github.repository }}.git"
           FULL_REPO="${{ github.repository }}"
           REPO_NAME="${FULL_REPO##*/}"
+          BASE_BRANCH="${{ github.base_ref }}"
           # Clean the repository directory before starting
           docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
           -e "REPO_NAME=${REPO_NAME}" \
+          -e "BASE_BRANCH=${BASE_BRANCH}" \
           ${docker_image} /bin/bash -c '
           if [ -d ${REPO_NAME} ]; then
             echo "Directory ${REPO_NAME} exists, removing it..."
@@ -38,7 +42,7 @@ jobs:
           '
           git config --global user.name "FastDeployCI"
           git config --global user.email "fastdeploy_ci@example.com"
-          git clone ${REPO} ${REPO_NAME}
+          git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
           cd FastDeploy
           if [ "${{ github.event_name }}" = "pull_request" ]; then
             git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
.github/workflows/ci_xpu.yml (vendored, 10 changed lines)
@@ -2,7 +2,9 @@ name: CI_XPU
 
 on:
   pull_request:
-    branches: [ develop ]
+    branches:
+      - develop
+      - 'release/*'
   workflow_dispatch:
 
 concurrency:
@@ -10,7 +12,7 @@ concurrency:
   cancel-in-progress: true
 
 jobs:
-  build:
+  CI_XPU:
     runs-on: [self-hosted, XPU-P800-8Card]
     steps:
       - name: Print current runner name
@@ -27,9 +29,11 @@ jobs:
          REPO="https://github.com/${{ github.repository }}.git"
          FULL_REPO="${{ github.repository }}"
          REPO_NAME="${FULL_REPO##*/}"
+         BASE_BRANCH="${{ github.base_ref }}"
          # Clean the repository directory before starting
          docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
          -e "REPO_NAME=${REPO_NAME}" \
+         -e "BASE_BRANCH=${BASE_BRANCH}" \
          ${docker_image} /bin/bash -c '
          if [ -d ${REPO_NAME} ]; then
            echo "Directory ${REPO_NAME} exists, removing it..."
@@ -38,7 +42,7 @@ jobs:
          '
          git config --global user.name "FastDeployCI"
          git config --global user.email "fastdeploy_ci@example.com"
-         git clone ${REPO} ${REPO_NAME}
+         git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
          cd FastDeploy
          if [ "${{ github.event_name }}" = "pull_request" ]; then
            git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
.gitignore (vendored, 2 changed lines)
@@ -162,3 +162,5 @@ custom_ops/tmp*
 build
 
 .ccls-cache
+
+third_party
@@ -8,14 +8,17 @@
     <a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
     <a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
     <a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
 </p>
 
 <p align="center">
+    <a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
     <a href="https://paddlepaddle.github.io/FastDeploy/get_started/installation/nvidia_gpu/"><b> Installation </b></a>
     |
     <a href="https://paddlepaddle.github.io/FastDeploy/get_started/quick_start"><b> Quick Start </b></a>
     |
-    <a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
+    <a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
 </p>
 
 --------------------------------------------------------------------------------
@@ -105,3 +105,30 @@ python benchmark_serving.py \
     --save-result > infer_log.txt 2>&1 &
 ```
+
+### Speculative decoding benchmark tool
+
+#### Usage:
+
+```bash
+python benchmarks/benchmark_mtp.py \
+  --host 127.0.0.1 --port 8000 \
+  --max-concurrency 16 32 64 96 --num-prompts 256 \
+  --acceptance-rate 0.8 --draft-token-steps 1 2 3 \
+  --s_itl-base-model 15.88 22.84 16.47 16.93 \
+  --dataset-name EBChat \
+  --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
+```
+
+#### Parameters
+
+```bash
+--host: server IP address, used to build the request URL
+--port: server HTTP port, used to build the request URL
+--max-concurrency: concurrency levels to benchmark
+--num-prompts: total number of requests to send
+--acceptance-rate: simulated acceptance rate for speculative decoding
+--draft-token-steps: number of draft-token steps for speculative decoding
+--s_itl-base-model: inter-token decode latency of the base model, obtainable with the load-testing tool above; one value per --max-concurrency entry
+--dataset-name: dataset class; set to "EBChat" to read a dataset converted to the FD format
+--dataset-path: path to the test dataset
+```
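The `--acceptance-rate` and `--draft-token-steps` values feed a closed-form speedup estimate, implemented by `calculate_speedup` in `benchmarks/benchmark_mtp.py` (shown later in this diff). A sketch of the model, writing α for the acceptance rate and k for the draft-step count:

```latex
% S: expected number of accepted draft tokens per step (geometric mass).
% r_ac: fraction of emitted tokens that come from accepted draft tokens.
\[
  S = \sum_{i=1}^{k} \alpha^{i}, \qquad
  r_{\mathrm{ac}} = \frac{S}{1 + S}, \qquad
  \text{speedup} = \frac{t_{\mathrm{ori}}}{(1 - r_{\mathrm{ac}})\, t_{\mathrm{mtp}}}
\]
```

Here t_ori is the base-model inter-token latency passed via `--s_itl-base-model` and t_mtp is the latency measured with speculative decoding enabled; when the two are equal, the estimate reduces to 1 + S, the expected number of tokens emitted per decode step.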
@@ -36,6 +36,7 @@ AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
 @dataclass
 class RequestFuncInput:
     """Input for requesting LLMs via API"""
+    no: int
     prompt: str
     history_QA: Optional[dict]
     hyper_parameters: dict
@@ -54,6 +55,7 @@ class RequestFuncInput:
 @dataclass
 class RequestFuncOutput:
     """Output for requesting LLMs via API"""
+    no: int = 0
     generated_text: str = ""
     reasoning_content: str = ""
     success: bool = False
@@ -84,7 +86,7 @@ async def async_request_eb_openai_chat_completions(
         if request_func_input.multi_modal_content:
             content.append(request_func_input.multi_modal_content)
         payload = {
-            "model": "default",
+            "model": request_func_input.model,
             "messages": request_func_input.history_QA,
             "stream": True,
             "stream_options": {
@@ -97,6 +99,9 @@ async def async_request_eb_openai_chat_completions(
 
         if request_func_input.ignore_eos:
             payload["ignore_eos"] = request_func_input.ignore_eos
+
+        print("payload:{}".format(json.dumps(payload, ensure_ascii=False)))
+
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
@@ -104,6 +109,7 @@ async def async_request_eb_openai_chat_completions(
 
         output = RequestFuncOutput()
         output.prompt_len = 0
+        output.no = request_func_input.no
 
         ttft = 0.0
         st = time.perf_counter()
@@ -132,7 +138,8 @@ async def async_request_eb_openai_chat_completions(
                             ttft = timestamp - st
                             output.ttft = ttft
                             # cached_tokens
-                            output.prompt_len = data["usage"]["prompt_tokens_details"]["cached_tokens"]
+                            output.prompt_len = data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
+
 
                         # Decoding phase
                         else:
@@ -141,12 +148,12 @@ async def async_request_eb_openai_chat_completions(
 
                             output.generated_text += content or ""
                             output.reasoning_content += reason_content or ""
-                            output.arrival_time.append(choices[0].get("arrival_time"))
-                        elif usage := data.get("usage"):
+                            output.arrival_time.append(choices[0].get("arrival_time", timestamp))
+                        elif usage := data.get("usage", {}):
                             output.output_tokens = usage.get(
-                                "completion_tokens")
+                                "completion_tokens", 0)
                             output.prompt_tokens = usage.get(
-                                "prompt_tokens")
+                                "prompt_tokens", 0)
 
                 most_recent_timestamp = timestamp
 
@@ -173,6 +180,7 @@ async def async_request_eb_openai_chat_completions(
             f.write(str(output) + "\n")
     if pbar:
         pbar.update(1)
+    print("#####final_output:", output)
     return output
 
 
@@ -189,7 +197,7 @@ async def async_request_eb_openai_completions(
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
-            "model": "default",
+            "model": request_func_input.model,
             "prompt": request_func_input.prompt,
             "stream": True,
             "stream_options": {
@@ -202,14 +210,20 @@ async def async_request_eb_openai_completions(
 
         if request_func_input.ignore_eos:
             payload["ignore_eos"] = request_func_input.ignore_eos
+
+        print("payload:", json.dumps(payload, ensure_ascii=False))
+
         headers = {
-            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
+            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+            "Content-Type": "application/json"
         }
 
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len
+        output.no = request_func_input.no
 
         generated_text = ""
         ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
@@ -226,6 +240,7 @@ async def async_request_eb_openai_completions(
                         "data: ")
                     if chunk != "[DONE]":
                         # print("####chunk:", chunk, chunk.usage)
+                        timestamp = time.perf_counter()
                         data = json.loads(chunk)
 
                         # NOTE: Some completion API might have a last
@@ -235,21 +250,22 @@ async def async_request_eb_openai_completions(
                             # Note that text could be empty here
                             # e.g. for special tokens
                             text = choices[0].get("text")
-                            timestamp = time.perf_counter()
 
                             # First token
                             if not first_chunk_received:
                                 first_chunk_received = True
-                                ttft = time.perf_counter() - st
+                                ttft = timestamp - st
                                 output.ttft = ttft
 
                             # Decoding phase
                             else:
                                 output.itl.append(timestamp -
                                                   most_recent_timestamp)
 
-                            generated_text += text or ""
-
                             most_recent_timestamp = timestamp
-                            output.arrival_time.append(choices[0].get("arrival_time"))
+                            generated_text += text or ""
+                            output.arrival_time.append(choices[0].get("arrival_time", timestamp))
                         elif usage := data.get("usage"):
                             output.prompt_tokens = usage.get(
                                 "prompt_tokens")
@@ -262,8 +278,15 @@ async def async_request_eb_openai_completions(
                     output.error = (
                         "Never received a valid chunk to calculate TTFT."
                         "This response will be marked as failed!")
+
             output.generated_text = generated_text
             output.latency = most_recent_timestamp - st
+
+            if output.generated_text == "":
+                output.success = False
+                output.error = "No generated text found!"
+            else:
+                output.success = True
         else:
             output.error = response.reason or ""
             output.success = False
@@ -271,6 +294,8 @@ async def async_request_eb_openai_completions(
         output.success = False
         exc_info = sys.exc_info()
         output.error = "".join(traceback.format_exception(*exc_info))
+
+    print("final_output:{}".format(output))
 
     if pbar:
         pbar.update(1)
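The recurring change in these hunks replaces direct dictionary indexing with chained `.get(..., default)` calls, so a streaming chunk that omits `usage` or `prompt_tokens_details` no longer raises `KeyError`. A minimal standalone sketch of the pattern; the example payload here is made up for illustration:

```python
import json

# A usage-only chunk, as some OpenAI-compatible servers send at end of stream.
chunk = '{"choices": [], "usage": {"completion_tokens": 42}}'
data = json.loads(chunk)

# Brittle: raises KeyError when any nested key is absent.
# cached = data["usage"]["prompt_tokens_details"]["cached_tokens"]

# Defensive: chained .get() with defaults, as the diff adopts.
cached = data.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0)
completion = data.get("usage", {}).get("completion_tokens", 0)
print(cached, completion)  # -> 0 42
```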
@@ -38,7 +38,7 @@ class SampleRequest:
     """
     Represents a single inference request for benchmarking.
     """
-
+    no: int
     prompt: Union[str, Any]
     history_QA: Union[str, Any]
     json_data: Optional[dict]
@@ -229,6 +229,7 @@ class EBDataset(BenchmarkDataset):
         **kwargs,
     ) -> list:
         samples: list = []
+        cnt = 1
         for entry in self.data:
             if len(samples) >= num_requests:
                 break
@@ -246,16 +247,17 @@ class EBDataset(BenchmarkDataset):
                 prompt, None)
             samples.append(
                 SampleRequest(
+                    no=cnt,
                    prompt=prompt,
                     prompt_len=self.prompt_len,
                     history_QA=[],
                     expected_output_len=new_output_len,
                 ))
+            cnt += 1
 
         self.maybe_oversample_requests(samples, num_requests)
         return samples
 
 
 class EBChatDataset(BenchmarkDataset):
     """
     Implements the ShareGPT dataset. Loads data from a JSON file and generates
@@ -284,6 +286,7 @@ class EBChatDataset(BenchmarkDataset):
         **kwargs,
     ) -> list:
         samples: list = []
+        cnt = 1
         for entry in self.data:
             if len(samples) >= num_requests:
                 break
@@ -297,12 +300,14 @@ class EBChatDataset(BenchmarkDataset):
                 prompt, None)
             samples.append(
                 SampleRequest(
+                    no=cnt,
                     json_data=json_data,
                     prompt=prompt,
                     prompt_len=0,
                     history_QA=history_QA,
                     expected_output_len=new_output_len,
                 ))
+            cnt += 1
 
         self.maybe_oversample_requests(samples, num_requests)
         return samples
benchmarks/benchmark_mtp.py (new file, 191 lines)
@@ -0,0 +1,191 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import argparse
import asyncio
import contextlib
import os
import signal
import socket
import subprocess
import time
from typing import Union

import openai
import yaml
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_serving import benchmark


def prepare_input_requests(
    num_prompts: int, dataset_name: str, dataset_path: str
) -> Union[EBDataset, EBChatDataset]:
    dataset_mapping = {
        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(
            num_requests=num_prompts
        ),
        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
            num_requests=num_prompts
        ),
    }

    try:
        input_requests = dataset_mapping[dataset_name]()
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {dataset_name}") from err

    return input_requests


class FakeTokenizer:
    def encode(self, text: str, add_special_tokens: bool = False):
        return []


def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
    selected_percentile_metrics = ["s_itl"]
    selected_percentiles = []
    # Run benchmark
    results = asyncio.run(
        benchmark(
            backend="openai-chat",
            api_url=f"{base_url}/v1/chat/completions",
            base_url=base_url,
            model_id="default",
            model_name="default",
            input_requests=input_requests,
            hyper_parameters={},
            logprobs=None,
            request_rate=float("inf"),
            burstiness=1.0,
            disable_tqdm=disable_tqdm,
            profile=False,
            selected_percentile_metrics=selected_percentile_metrics,
            selected_percentiles=selected_percentiles,
            ignore_eos=False,
            goodput_config_dict=None,
            max_concurrency=max_concurrency,
            lora_modules=None,
            extra_body=None,
        )
    )

    record = {
        "mean_s_itl_ms": results["mean_s_itl_ms"],
    }

    return record


def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
    tmp = 0.0
    for i in range(draft_token_step):
        tmp += pow(acceptance_rate, i + 1)

    r_ac = tmp / (1 + tmp)

    return t_ori / ((1 - r_ac) * t_mtp)


def main(args):
    base_url = f"http://{args.host}:{args.port}"

    input_requests = prepare_input_requests(
        args.num_prompts, args.dataset_name, args.dataset_path
    )

    if len(args.max_concurrency) != len(args.s_itl_base_model):
        raise ValueError("--max-concurrency must have the same length as --s_itl-base-model")

    for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
        # Warmup
        print("Starting warmup...")
        with open(os.devnull, "w") as f:
            with contextlib.redirect_stdout(f):
                send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)

        # Benchmark
        record = send_one_batch(base_url, max_concurrency, input_requests, False)

        metric_header = "Speed up"
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        for draft_token_step in args.draft_token_steps:
            speedup = calculate_speedup(
                args.acceptance_rate,
                draft_token_step,
                s_itl,
                record["mean_s_itl_ms"],
            )
            print(
                "{:<40} {:<10.2f}".format(
                    f"Speed up on {draft_token_step} steps draft", speedup
                )
            )
        print("=" * 50)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=str, default="8000")
    parser.add_argument("--max-concurrency", type=int, nargs="+", default=(1, 2, 4, 8, 16, 32))
    parser.add_argument("--num-prompts", type=int, default=128)
    parser.add_argument("--acceptance-rate", type=float, default=0.8)
    parser.add_argument("--draft-token-steps", type=int, nargs="+", default=(1, 2))
    parser.add_argument("--s_itl-base-model", type=float, nargs="+")
    parser.add_argument("--dataset-name", type=str, default="EBChat")
    parser.add_argument("--dataset-path", type=str)
    args = parser.parse_args()

    main(args)
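A quick sanity check of `calculate_speedup` with the script's defaults (`--acceptance-rate 0.8`, two draft steps); the latency values are illustrative only:

```python
def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
    # Mirrors the function above: geometric acceptance mass over draft steps.
    tmp = sum(pow(acceptance_rate, i + 1) for i in range(draft_token_step))
    r_ac = tmp / (1 + tmp)
    return t_ori / ((1 - r_ac) * t_mtp)

# alpha=0.8, k=2: tmp = 0.8 + 0.64 = 1.44, r_ac ~= 0.59.
# With equal base/MTP inter-token latency the estimate is 1 + tmp = 2.44x.
print(round(calculate_speedup(0.8, 2, 16.0, 16.0), 2))  # 2.44
```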
@@ -182,6 +182,7 @@ def calculate_metrics(
             # len(outputs[i].itl) since multiple output tokens may be
             # bundled together
             # Note : this may inflate the output token count slightly
+            continue
 
         actual_output_lens.append(output_len)
         input_lens.append(outputs[i].prompt_len)
@@ -209,6 +210,8 @@ def calculate_metrics(
             if len(outputs[i].arrival_time) > 2:
                 s_decodes.append((outputs[i].output_tokens - 1) /
                                  (outputs[i].arrival_time[-1] - outputs[i].arrival_time[1]))
+            else:
+                print("len(outputs[i].arrival_time) <= 2")
             completed += 1
         else:
             actual_output_lens.append(0)
@@ -341,15 +344,16 @@ async def benchmark(
         raise ValueError(f"Unknown backend: {backend}")
 
     print("Starting initial single prompt test run...")
-    test_prompt, test_output_len = \
-        input_requests[0].prompt, \
-        input_requests[0].expected_output_len
+    test_prompt, test_output_len, test_no = \
+        input_requests[0].prompt, \
+        input_requests[0].expected_output_len, input_requests[0].no
     test_history_QA = input_requests[0].history_QA
 
     test_input = RequestFuncInput(
         model=model_id,
         model_name=model_name,
         prompt=test_prompt,
+        no=test_no,
         prompt_len=0,
         history_QA=test_history_QA,
         hyper_parameters=hyper_parameters,
@@ -384,6 +388,7 @@ async def benchmark(
         profile_input = RequestFuncInput(model=model_id,
                                          model_name=model_name,
                                          prompt=test_prompt,
+                                         no=test_no,
                                          api_url=base_url + "/start_profile",
                                          output_len=test_output_len,
                                          logprobs=logprobs,
@@ -422,7 +427,7 @@ async def benchmark(
     benchmark_start_time = time.perf_counter()
     tasks: list[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate, burstiness):
-        prompt, output_len = request.prompt, request.expected_output_len
+        prompt, output_len, no = request.prompt, request.expected_output_len, request.no
         history_QA = request.history_QA
 
         req_model_id, req_model_name = model_id, model_name
@@ -433,6 +438,7 @@ async def benchmark(
         request_func_input = RequestFuncInput(model=req_model_id,
                                               model_name=req_model_name,
                                               prompt=prompt,
+                                              no=no,
                                               prompt_len=0,
                                               history_QA=history_QA,
                                               hyper_parameters=hyper_parameters,
@@ -452,6 +458,7 @@ async def benchmark(
         profile_input = RequestFuncInput(
             model=model_id,
             prompt=test_prompt,
+            no=test_no,
             api_url=base_url + "/stop_profile",
             output_len=test_output_len,
             logprobs=logprobs,
@@ -464,6 +471,8 @@ async def benchmark(
         pbar.close()
 
     benchmark_duration = time.perf_counter() - benchmark_start_time
+    print("benchmark_duration:", benchmark_duration)
+
 
     metrics, actual_output_lens = calculate_metrics(
         input_requests=input_requests,
@@ -594,6 +603,155 @@ async def benchmark(
     return result
 
 
+def benchmark_metrics(
+    benchmark_duration: float,
+    result_file: str,
+    selected_percentiles: list[float],
+    selected_percentile_metrics: list[str],
+    goodput_config_dict: dict[str, float],
+):
+    """Benchmark metrics statistics, generate benchmark result"""
+    outputs = []
+    case_no_list = []
+    with open(result_file) as f:
+        for line in f.readlines():
+            if "RequestFuncOutput" in line:
+                start = line.find("RequestFuncOutput")
+                end = line.rfind(")")
+                para_str = line[start:end + 1]
+
+                output = eval(para_str)
+                outputs.append(output)
+
+    input_requests = [[]] * len(outputs)
+    goodput_config_dict = check_goodput_args(args)
+
+    metrics, actual_output_lens = calculate_metrics(
+        input_requests=input_requests,
+        outputs=outputs,
+        dur_s=benchmark_duration,
+        selected_percentiles=selected_percentiles,
+        goodput_config_dict=goodput_config_dict,
+    )
+
+    print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
+    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
+    print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
+                                    benchmark_duration))
+    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
+    print("{:<40} {:<10}".format("Total generated tokens:",
+                                 metrics.total_output))
+    print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
+                                    metrics.request_throughput))
+    if goodput_config_dict:
+        print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
+                                        metrics.request_goodput))
+    print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
+                                    metrics.output_throughput))
+    print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
+                                    metrics.total_token_throughput))
+
+    result = {
+        "duration": benchmark_duration,
+        "completed": metrics.completed,
+        "total_input_tokens": metrics.total_input,
+        "total_output_tokens": metrics.total_output,
+        "request_throughput": metrics.request_throughput,
+        "request_goodput:":
+        metrics.request_goodput if goodput_config_dict else None,
+        "output_throughput": metrics.output_throughput,
+        "total_token_throughput": metrics.total_token_throughput,
+        "input_lens": [output.prompt_len for output in outputs],
+        "output_lens": actual_output_lens,
+        "ttfts": [output.ttft for output in outputs],
+        "itls": [output.itl for output in outputs],
+        "input_texts": ["" for input in input_requests],
+        "generated_texts": [output.generated_text for output in outputs],
+        "errors": [output.error for output in outputs],
+    }
+
+    def process_one_metric(
+        # E.g., "ttft"
+        metric_attribute_name: str,
+        # E.g., "TTFT"
+        metric_name: str,
+        # E.g., "Time to First Token"
+        metric_header: str,
+    ):
+        # This function prints and adds statistics of the specified
+        # metric.
+        if metric_attribute_name not in selected_percentile_metrics:
+            return
+        print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
+        print("{:<40} {:<10.2f}".format(
+            f"Mean {metric_name} (ms):",
+            getattr(metrics, f"mean_{metric_attribute_name}_ms")))
+        print("{:<40} {:<10.2f}".format(
+            f"Median {metric_name} (ms):",
+            getattr(metrics, f"median_{metric_attribute_name}_ms")))
+        result[f"mean_{metric_attribute_name}_ms"] = getattr(
+            metrics, f"mean_{metric_attribute_name}_ms")
+        result[f"median_{metric_attribute_name}_ms"] = getattr(
+            metrics, f"median_{metric_attribute_name}_ms")
+        result[f"std_{metric_attribute_name}_ms"] = getattr(
+            metrics, f"std_{metric_attribute_name}_ms")
+        for p, value in getattr(metrics,
+                                f"percentiles_{metric_attribute_name}_ms"):
+            p_word = str(int(p)) if int(p) == p else str(p)
+            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
+                                            value))
+            result[f"p{p_word}_{metric_attribute_name}_ms"] = value
+
+    def process_one_length(
+        # E.g., "ttft"
+        metric_attribute_name: str,
+        # E.g., "TTFT"
+        metric_name: str,
+        # E.g., "Time to First Token"
+        metric_header: str,
+    ):
+        # This function prints and adds statistics of the specified
+        # metric.
+        if metric_attribute_name not in selected_percentile_metrics:
+            return
+        print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
+        print("{:<40} {:<10.2f}".format(
+            f"Mean {metric_name}:",
+            getattr(metrics, f"mean_{metric_attribute_name}")))
+        print("{:<40} {:<10.2f}".format(
+            f"Median {metric_name}:",
+            getattr(metrics, f"median_{metric_attribute_name}")))
+        result[f"mean_{metric_attribute_name}"] = getattr(
+            metrics, f"mean_{metric_attribute_name}")
+        result[f"median_{metric_attribute_name}"] = getattr(
+            metrics, f"median_{metric_attribute_name}")
+        result[f"std_{metric_attribute_name}"] = getattr(
+            metrics, f"std_{metric_attribute_name}")
+        for p, value in getattr(metrics,
+                                f"percentiles_{metric_attribute_name}"):
+            p_word = str(int(p)) if int(p) == p else str(p)
+            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:",
+                                            value))
+            result[f"p{p_word}_{metric_attribute_name}"] = value
+
+    process_one_length("s_decode", "Decode", "Decode speed (tok/s)")
+    process_one_metric("ttft", "TTFT", "Time to First Token")
+    process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
+    process_one_metric("tpot", "TPOT",
+                       "Time per Output Token (excl. 1st token)")
+    process_one_metric("itl", "ITL", "Inter-token Latency")
+    process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
+    process_one_metric("e2el", "E2EL", "End-to-end Latency")
+    process_one_metric("s_e2el", "S_E2EL", "Infer End-to-end Latency")
+    process_one_length("input_len", "Input Length", "Input Length")
+    process_one_length("s_input_len", "Input Length", "Infer Input Length")
+    process_one_length("output_len", "Output Length", "Output Length")
+
+    print("=" * 50)
+
+    return result
+
+
 def check_goodput_args(args):
     """Check whether the given argument has valid goodput configuration or not"""
     # Check and parse goodput arguments
@@ -759,6 +917,16 @@ def main(args: argparse.Namespace):
             lora_modules=args.lora_modules,
             extra_body=sampling_params,
         ))
 
+    # benchmark_result = benchmark_metrics(
+    #     benchmark_duration=3600,
+    #     result_file="your result file",
+    #     selected_percentile_metrics=args.percentile_metrics.split(","),
+    #     selected_percentiles=[
+    #         float(p) for p in args.metric_percentiles.split(",")
+    #     ],
+    #     goodput_config_dict=goodput_config_dict,
+    # )
+
     # Save config and results to json
     if args.save_result:
@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wint8
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wint8
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wint4
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 96
 gpu_memory_utilization: 0.9
 kv_cache_ratio: 0.71
 tensor_parallel_size: 4
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wfp8afp8
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wint8
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wint8
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
 max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
 kv_cache_ratio: 0.75
 tensor_parallel_size: 1
 quantization: wint4
-enable_static_graph_inference: True
+graph_optimization_config:
+  graph_opt_level: 1
benchmarks/yaml/request_yaml/vLLM_default.yaml (new file, 11 lines)
@@ -0,0 +1,11 @@
top_p: 1.0
temperature: 1.0
metadata:
  min_tokens: 1
max_tokens: 30721
repetition_penalty: 1.0
frequency_penalty: 0
presence_penalty: 0
skip_special_tokens: false
chat_template_kwargs:
  enable_thinking: true
build.sh (40 changed lines)
@@ -18,6 +18,9 @@ BUILD_WHEEL=${1:-1}
 PYTHON_VERSION=${2:-"python"}
 export python=$PYTHON_VERSION
 FD_CPU_USE_BF16=${3:-"false"}
+# FD_BUILDING_ARCS: Specify target CUDA architectures for custom ops, e.g., "[80, 90, 100]".
+# For SM90 (Hopper), use 90. For SM100 (Blackwell), use 100.
+# These will be translated to 90a / 100a in setup_ops.py for specific features.
 FD_BUILDING_ARCS=${4:-""}
 
 
@@ -74,8 +77,10 @@ function copy_ops(){
     is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
     if [ "$is_rocm" = "True" ]; then
         DEVICE_TYPE="rocm"
+        mkdir -p ../fastdeploy/model_executor/ops/base
+        cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
         cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
-        echo -e "ROCM ops have been copy to fastdeploy"
+        echo -e "BASE and ROCM ops have been copy to fastdeploy"
         return
     fi
     mkdir -p ../fastdeploy/model_executor/ops/base
@@ -104,6 +109,23 @@ function copy_ops(){
         return
     fi
 
+    if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
+    if [ "$if_corex" = "True" ]; then
+        DEVICE_TYPE="iluvatar-gpu"
+        cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
+        cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
+        echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
+        return
+    fi
+
+    is_gcu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"`
+    if [ "$is_gcu" = "True" ]; then
+        DEVICE_TYPE="gcu"
+        cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gcu
+        echo -e "gcu ops have been copy to fastdeploy"
+        return
+    fi
+
     DEVICE_TYPE="cpu"
     cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
     cd ../../../../
@@ -163,17 +185,6 @@ function build_and_install() {
         exit 1
     fi
     echo -e "${BLUE}[build]${NONE} ${GREEN}build fastdeploy wheel success${NONE}\n"
-
-    echo -e "${BLUE}[install]${NONE} installing fastdeploy..."
-    cd $DIST_DIR
-    find . -name "fastdeploy*.whl" | xargs ${python} -m pip install --force-reinstall --no-cache-dir
-    if [ $? -ne 0 ]; then
-        cd ..
-        echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed"
-        exit 1
-    fi
-    echo -e "${BLUE}[install]${NONE} ${GREEN}fastdeploy install success${NONE}\n"
-    cd ..
 }
 
 function version_info() {
@@ -181,7 +192,10 @@ function version_info() {
     fastdeploy_git_commit_id=$(git rev-parse HEAD)
     paddle_version=$(${python} -c "import paddle; print(paddle.__version__)")
     paddle_git_commit_id=$(${python} -c "import paddle; print(paddle.__git_commit__)")
-    cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
+    cuda_version="nvcc-not-installed"
+    if command -v nvcc &> /dev/null; then
+        cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
+    fi
     cxx_version=$(g++ --version | head -n 1 | grep -Po "(?<=\) )[\d.]+")
 
     echo "fastdeploy GIT COMMIT ID: $fastdeploy_git_commit_id" > $output_file
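The `version_info` fix guards the `nvcc` probe so version collection no longer fails on machines without CUDA. A rough Python equivalent of the same guard pattern, for illustration only (not part of the repo):

```python
import re
import shutil
import subprocess

def cuda_version() -> str:
    # Mirror build.sh: fall back to a sentinel when nvcc is absent.
    if shutil.which("nvcc") is None:
        return "nvcc-not-installed"
    out = subprocess.run(["nvcc", "-V"], capture_output=True, text=True).stdout
    m = re.search(r"release ([\d.]+), V", out)
    return m.group(1) if m else "unknown"

print(cuda_version())
```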
@@ -47,7 +47,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_this_time,
     const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& encoder_batch_ids,
     const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -166,7 +166,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         seq_lens_decoder,
         seq_lens_encoder,
         padding_offsets,
-        cum_offsets,
+        cu_seqlens_q,
         block_tables,
         lambda_batch_ids,
         lambda_tile_ids_per_batch,
@@ -203,7 +203,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         seq_lens_encoder,
         seq_lens_decoder,
         padding_offsets,
-        cum_offsets,
+        cu_seqlens_q,
         block_tables,
         kv_batch_ids,
         kv_tile_ids_per_batch,
@@ -275,7 +275,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           seq_lens_decoder,
           seq_lens_encoder,
           padding_offsets,
-          cum_offsets,
+          cu_seqlens_q,
           block_tables,
           rotary_embs,
           qkv_out_scales,
@@ -298,7 +298,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           seq_lens_decoder,
           seq_lens_encoder,
           padding_offsets,
-          cum_offsets,
+          cu_seqlens_q,
           block_tables,
           rotary_embs,
           qkv_out_scales,
@@ -323,7 +323,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           seq_lens_decoder,
           seq_lens_encoder,
           padding_offsets,
-          cum_offsets,
+          cu_seqlens_q,
           block_tables,
           rotary_embs,
           qkv_out_scales,
@@ -347,7 +347,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           seq_lens_decoder,
           seq_lens_encoder,
           padding_offsets,
-          cum_offsets,
+          cu_seqlens_q,
           block_tables,
           rotary_embs,
           qkv_out_scales,
@@ -404,7 +404,7 @@ std::vector<paddle::Tensor> AppendAttention(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_this_time,
     const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& encoder_batch_ids,
     const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -462,7 +462,7 @@ std::vector<paddle::Tensor> AppendAttention(
 
   meta_data.max_blocks_per_seq = block_tables.dims()[1];
   meta_data.block_size = key_cache.dims()[2];
-  meta_data.batch_size = cum_offsets.dims()[0];
+  meta_data.batch_size = seq_lens_this_time.dims()[0];
 
   auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
     return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
@@ -474,7 +474,7 @@ std::vector<paddle::Tensor> AppendAttention(
         seq_lens_decoder,
         seq_lens_this_time,
         padding_offsets,
-        cum_offsets,
+        cu_seqlens_q,
         block_tables,
         encoder_batch_ids,
         encoder_tile_ids_per_batch,
@@ -551,7 +551,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const std::vector<int64_t>& seq_lens_decoder_shape,
     const std::vector<int64_t>& seq_lens_this_time_shape,
     const std::vector<int64_t>& padding_offsets_shape,
-    const std::vector<int64_t>& cum_offsets_shape,
+    const std::vector<int64_t>& cu_seqlens_q_shape,
     const std::vector<int64_t>& block_tables_shape,
     const std::vector<int64_t>& encoder_batch_ids_shape,
     const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
@@ -611,7 +611,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const paddle::DataType& seq_lens_decoder_dtype,
     const paddle::DataType& seq_lens_this_time_dtype,
     const paddle::DataType& padding_offsets_dtype,
-    const paddle::DataType& cum_offsets_dtype,
+    const paddle::DataType& cu_seqlens_q_dtype,
     const paddle::DataType& block_tables_dtype,
     const paddle::DataType& encoder_batch_ids_dtype,
     const paddle::DataType& encoder_tile_ids_per_batch_dtype,
@@ -689,7 +689,7 @@ PD_BUILD_STATIC_OP(append_attention)
              "seq_lens_decoder",
              "seq_lens_this_time",
              "padding_offsets",
-             "cum_offsets",
+             "cu_seqlens_q",
              "block_tables",
              "encoder_batch_ids",
              "encoder_tile_ids_per_batch",
@@ -41,7 +41,7 @@ __global__ void multi_query_append_attention_kernel(
     const int *__restrict__ seq_lens_kv,
     const int *__restrict__ batch_ids,
     const int *__restrict__ tile_ids_per_batch,
-    const int *__restrict__ cum_offsets,
+    const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table, // [bsz, block_num_per_seq]
     const int max_seq_len,
     const int max_dec_len,
@@ -114,8 +114,7 @@ __global__ void multi_query_append_attention_kernel(
   const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
   const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
   const uint32_t kv_b_stride = HEAD_DIM;
-  const uint32_t q_start_seq_id =
-      batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
+  const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
   const uint32_t q_base_seq_id_this_block =
       (tile_id * NUM_WARPS + wid) * num_frags_x * 16;
   const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
@@ -405,7 +404,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
     const int *__restrict__ seq_lens_kv,
     const int *__restrict__ batch_ids,
     const int *__restrict__ tile_ids_per_batch,
-    const int *__restrict__ cum_offsets,
+    const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table, // [bsz, block_num_per_seq]
     const int max_seq_len,
     const int max_dec_len,
@@ -477,8 +476,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
   const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
   const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
   const uint32_t kv_b_stride = HEAD_DIM;
-  const uint32_t q_start_seq_id =
-      batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
+  const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
   const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
   const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
                             q_head_idx * HEAD_DIM +
@@ -776,7 +774,7 @@ void MultiQueryAppendAttention(
     const paddle::Tensor &seq_lens_kv,
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &padding_offsets,
-    const paddle::Tensor &cum_offsets,
+    const paddle::Tensor &cu_seqlens_q,
     const paddle::Tensor &block_table,
     const paddle::Tensor &batch_ids,
     const paddle::Tensor &tile_ids_per_batch,
@@ -882,7 +880,7 @@ void MultiQueryAppendAttention(
             seq_lens_kv.data<int>(),
             batch_ids.data<int>(),
             tile_ids_per_batch.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             block_table.data<int>(),
             max_seq_len,
             max_dec_len,
@@ -939,7 +937,7 @@ void MultiQueryAppendAttention(
             seq_lens_kv.data<int>(),
             batch_ids.data<int>(),
             tile_ids_per_batch.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             block_table.data<int>(),
             max_seq_len,
             max_dec_len,
@@ -974,7 +972,7 @@ void MultiQueryAppendAttention(
             seq_lens_q.data<int>(),
             seq_lens_kv.data<int>(),
             seq_lens_encoder.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             shift_bias ? reinterpret_cast<NV_TYPE *>(
                              const_cast<T *>(shift_bias.get().data<T>()))
                        : nullptr,
@@ -1103,7 +1101,7 @@ void MultiQueryAppendAttention(
             seq_lens_kv.data<int>(),
             batch_ids.data<int>(),
             tile_ids_per_batch.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             block_table.data<int>(),
             max_seq_len,
             max_dec_len,
@@ -1171,7 +1169,7 @@ void MultiQueryAppendAttention(
             seq_lens_kv.data<int>(),
             batch_ids.data<int>(),
             tile_ids_per_batch.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             block_table.data<int>(),
             max_seq_len,
             max_dec_len,
@@ -1207,7 +1205,7 @@ void MultiQueryAppendAttention(
             seq_lens_q.data<int>(),
             seq_lens_kv.data<int>(),
             seq_lens_encoder.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             shift_bias ? reinterpret_cast<NV_TYPE *>(
                              const_cast<T *>(shift_bias.get().data<T>()))
                        : nullptr,
@@ -1290,7 +1288,7 @@ void CascadeAppendAttentionC16Kernel(
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -1353,7 +1351,7 @@ void CascadeAppendAttentionC16Kernel(
       seq_lens_kv,
       seq_lens_encoder,
      padding_offsets,
-      cum_offsets,
+      cu_seqlens_q,
       block_table,
       batch_ids,
       tile_ids_per_batch,
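The indexing change repeated across these kernels swaps a padded-layout computation, `batch_id * max_seq_len - cum_offsets[batch_id]`, for a direct prefix-sum lookup, `cu_seqlens_q[batch_id]`. Both yield the first token slot of a batch entry in the packed (unpadded) token buffer. A small sketch with illustrative sequence lengths shows the equivalence; the real `cum_offsets` and `cu_seqlens_q` tensors are produced elsewhere in FastDeploy:

```python
import itertools

max_seq_len = 8
seq_lens_q = [3, 1, 5]  # illustrative per-batch query lengths

# cu_seqlens_q: exclusive prefix sum of sequence lengths.
cu_seqlens_q = [0] + list(itertools.accumulate(seq_lens_q))

# cum_offsets[b]: padding slots skipped before batch b in a padded layout.
cum_offsets = [0] + list(itertools.accumulate(max_seq_len - l for l in seq_lens_q))

for b in range(len(seq_lens_q)):
    old = b * max_seq_len - cum_offsets[b]  # padded-layout arithmetic
    new = cu_seqlens_q[b]                   # direct lookup
    assert old == new, (b, old, new)
print(cu_seqlens_q[:-1])  # [0, 3, 4]: packed start offsets per batch
```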
@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c4_kernel(
     const int *__restrict__ seq_lens_kv,
     const int *__restrict__ batch_ids,
     const int *__restrict__ tile_ids_per_batch,
-    const int *__restrict__ cum_offsets,
+    const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table, // [bsz, block_num_per_seq]
     const int max_seq_len,
     const int max_dec_len,
@@ -144,8 +144,7 @@ __global__ void multi_query_append_attention_c4_kernel(
   const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
   const uint32_t kv_b_stride = HEAD_DIM / 2;
   const uint32_t kv_d_stride = BLOCK_SIZE / 2;
-  const uint32_t q_start_seq_id =
-      batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
+  const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
   const uint32_t q_base_seq_id_this_block =
       (tile_id * NUM_WARPS + wid) * num_frags_x * 16;
   const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
@@ -504,7 +503,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
     const int *__restrict__ seq_lens_kv,
     const int *__restrict__ batch_ids,
     const int *__restrict__ tile_ids_per_batch,
-    const int *__restrict__ cum_offsets,
+    const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table, // [bsz, block_num_per_seq]
     const int max_seq_len,
     const int max_dec_len,
@@ -601,8 +600,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
   const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
   const uint32_t kv_b_stride = HEAD_DIM / 2;
   const uint32_t kv_d_stride = BLOCK_SIZE / 2;
-  const uint32_t q_start_seq_id =
-      batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
+  const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
   const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
   const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
                             q_head_idx * HEAD_DIM +
@@ -963,7 +961,7 @@ void MultiQueryAppendC4Attention(
     const paddle::Tensor &seq_lens_kv,
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &padding_offsets,
-    const paddle::Tensor &cum_offsets,
+    const paddle::Tensor &cu_seqlens_q,
     const paddle::Tensor &block_table,
     const paddle::Tensor &batch_ids,
     const paddle::Tensor &tile_ids_per_batch,
@@ -1088,7 +1086,7 @@ void MultiQueryAppendC4Attention(
             seq_lens_kv.data<int>(),
             batch_ids.data<int>(),
             tile_ids_per_batch.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             block_table.data<int>(),
             max_seq_len,
             max_dec_len,
@@ -1151,7 +1149,7 @@ void MultiQueryAppendC4Attention(
             seq_lens_kv.data<int>(),
             batch_ids.data<int>(),
             tile_ids_per_batch.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             block_table.data<int>(),
             max_seq_len,
             max_dec_len,
@@ -1186,7 +1184,7 @@ void MultiQueryAppendC4Attention(
             seq_lens_q.data<int>(),
             seq_lens_kv.data<int>(),
             seq_lens_encoder.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             shift_bias ? reinterpret_cast<NV_TYPE *>(
                              const_cast<T *>(shift_bias.get().data<T>()))
                        : nullptr,
@@ -1333,7 +1331,7 @@ void MultiQueryAppendC4Attention(
             seq_lens_kv.data<int>(),
             batch_ids.data<int>(),
             tile_ids_per_batch.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             block_table.data<int>(),
             max_seq_len,
             max_dec_len,
@@ -1409,7 +1407,7 @@ void MultiQueryAppendC4Attention(
             seq_lens_kv.data<int>(),
             batch_ids.data<int>(),
             tile_ids_per_batch.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             block_table.data<int>(),
             max_seq_len,
             max_dec_len,
@@ -1444,7 +1442,7 @@ void MultiQueryAppendC4Attention(
             seq_lens_q.data<int>(),
             seq_lens_kv.data<int>(),
             seq_lens_encoder.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             shift_bias ? reinterpret_cast<NV_TYPE *>(
                              const_cast<T *>(shift_bias.get().data<T>()))
                        : nullptr,
@@ -1527,7 +1525,7 @@ void CascadeAppendAttentionC4Kernel(
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -1594,7 +1592,7 @@ void CascadeAppendAttentionC4Kernel(
       seq_lens_kv,
      seq_lens_encoder,
       padding_offsets,
-      cum_offsets,
+      cu_seqlens_q,
       block_table,
       batch_ids,
       tile_ids_per_batch,
@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c8_kernel(
     const int *__restrict__ seq_lens_kv,
     const int *__restrict__ batch_ids,
     const int *__restrict__ tile_ids_per_batch,
-    const int *__restrict__ cum_offsets,
+    const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table, // [bsz, block_num_per_seq]
     const int max_seq_len,
     const int max_dec_len,
@@ -151,8 +151,7 @@ __global__ void multi_query_append_attention_c8_kernel(
   const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
   const uint32_t kv_b_stride = HEAD_DIM;
   const uint32_t kv_d_stride = BLOCK_SIZE;
-  const uint32_t q_start_seq_id =
-      batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
+  const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
   const uint32_t q_base_seq_id_this_block =
       (tile_id * NUM_WARPS + wid) * num_frags_x * 16;
   const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
@@ -473,7 +472,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
     const int *__restrict__ seq_lens_kv,
     const int *__restrict__ batch_ids,
     const int *__restrict__ tile_ids_per_batch,
-    const int *__restrict__ cum_offsets,
+    const int *__restrict__ cu_seqlens_q,
     const int *__restrict__ block_table, // [bsz, block_num_per_seq]
     const int max_seq_len,
     const int max_dec_len,
@@ -575,8 +574,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
   const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
   const uint32_t kv_b_stride = HEAD_DIM;
   const uint32_t kv_d_stride = BLOCK_SIZE;
-  const uint32_t q_start_seq_id =
-      batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
+  const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
   const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
   const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
                             q_head_idx * HEAD_DIM +
@@ -900,7 +898,7 @@ void MultiQueryAppendC8Attention(
     const paddle::Tensor &seq_lens_kv,
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &padding_offsets,
-    const paddle::Tensor &cum_offsets,
+    const paddle::Tensor &cu_seqlens_q,
     const paddle::Tensor &block_table,
     const paddle::Tensor &batch_ids,
     const paddle::Tensor &tile_ids_per_batch,
@@ -1054,7 +1052,7 @@ void MultiQueryAppendC8Attention(
             seq_lens_kv.data<int>(),
             batch_ids.data<int>(),
             tile_ids_per_batch.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             block_table.data<int>(),
             max_seq_len,
             max_dec_len,
@@ -1111,7 +1109,7 @@ void MultiQueryAppendC8Attention(
             seq_lens_kv.data<int>(),
             batch_ids.data<int>(),
             tile_ids_per_batch.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             block_table.data<int>(),
             max_seq_len,
             max_dec_len,
@@ -1146,7 +1144,7 @@ void MultiQueryAppendC8Attention(
             seq_lens_q.data<int>(),
             seq_lens_kv.data<int>(),
             seq_lens_encoder.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             shift_bias ? reinterpret_cast<NV_TYPE *>(
                              const_cast<T *>(shift_bias.get().data<T>()))
                        : nullptr,
@@ -1317,7 +1315,7 @@ void MultiQueryAppendC8Attention(
             seq_lens_kv.data<int>(),
             batch_ids.data<int>(),
             tile_ids_per_batch.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             block_table.data<int>(),
             max_seq_len,
             max_dec_len,
@@ -1387,7 +1385,7 @@ void MultiQueryAppendC8Attention(
             seq_lens_kv.data<int>(),
             batch_ids.data<int>(),
             tile_ids_per_batch.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             block_table.data<int>(),
             max_seq_len,
             max_dec_len,
@@ -1417,7 +1415,7 @@ void MultiQueryAppendC8Attention(
             seq_lens_q.data<int>(),
             seq_lens_kv.data<int>(),
             seq_lens_encoder.data<int>(),
-            cum_offsets.data<int>(),
+            cu_seqlens_q.data<int>(),
             shift_bias ? reinterpret_cast<NV_TYPE *>(
                              const_cast<T *>(shift_bias.get().data<T>()))
                        : nullptr,
@@ -1500,7 +1498,7 @@ void CascadeAppendAttentionC8Kernel(
     const paddle::Tensor& seq_lens_kv,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& padding_offsets,
-    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& block_table,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -1565,7 +1563,7 @@ void CascadeAppendAttentionC8Kernel(
      seq_lens_kv,
       seq_lens_encoder,
       padding_offsets,
-      cum_offsets,
+      cu_seqlens_q,
       block_table,
       batch_ids,
       tile_ids_per_batch,
@@ -2111,7 +2111,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_kv,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ cum_offsets,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
OutT *__restrict__ out,
|
||||
@@ -2127,7 +2127,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const int bid = blockIdx.x, hid = blockIdx.y;
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int seq_len_q = seq_lens_q[bid];
|
||||
if (seq_len_q == 0) return;
|
||||
int seq_len_kv = seq_lens_kv[bid];
|
||||
|
@@ -41,7 +41,7 @@ void CascadeAppendAttentionC16Kernel(
 const paddle::Tensor& seq_lens_kv,
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_table,
 const paddle::Tensor& batch_ids,
 const paddle::Tensor& tile_ids_per_batch,
@@ -86,7 +86,7 @@ void CascadeAppendAttentionC8Kernel(
 const paddle::Tensor& seq_lens_kv,
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_table,
 const paddle::Tensor& batch_ids,
 const paddle::Tensor& tile_ids_per_batch,
@@ -131,7 +131,7 @@ void CascadeAppendAttentionC4Kernel(
 const paddle::Tensor& seq_lens_kv,
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_table,
 const paddle::Tensor& batch_ids,
 const paddle::Tensor& tile_ids_per_batch,
@@ -176,7 +176,7 @@ void CascadeAppendAttentionKernel(
 const paddle::Tensor& seq_lens_kv,
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_table,
 const paddle::Tensor& batch_ids,
 const paddle::Tensor& tile_ids_per_batch,
@@ -212,7 +212,7 @@ void CascadeAppendAttentionKernel(
 seq_lens_kv,
 seq_lens_encoder,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 block_table,
 batch_ids,
 tile_ids_per_batch,
@@ -247,7 +247,7 @@ void CascadeAppendAttentionKernel(
 seq_lens_kv,
 seq_lens_encoder,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 block_table,
 batch_ids,
 tile_ids_per_batch,
@@ -282,7 +282,7 @@ void CascadeAppendAttentionKernel(
 seq_lens_kv,
 seq_lens_encoder,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 block_table,
 batch_ids,
 tile_ids_per_batch,
@@ -317,7 +317,7 @@ void CascadeAppendAttentionKernel(
 seq_lens_kv,
 seq_lens_encoder,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 block_table,
 batch_ids,
 tile_ids_per_batch,
@@ -35,7 +35,7 @@ __global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi
 const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads]
 const int * __restrict__ seq_lens_q,
 const int * __restrict__ seq_lens_kv,
-const int * __restrict__ cum_offsets,
+const int * __restrict__ cu_seqlens_q,
 const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
 const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
 OutT * __restrict__ out, // [token_num, num_heads, head_dim]
@@ -59,7 +59,7 @@ __global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi
 __shared__ T smem[bdy * HEAD_DIM];
 __shared__ T md_smem[bdy * 2];

-const int start_token_ids = qid * max_seq_len - __ldg(&cum_offsets[qid]);
+const int start_token_ids = cu_seqlens_q[qid];
 using LoadT = AlignedVector<T, vec_size>;
 LoadT load_vec;
 LoadT res_vec;
@@ -134,7 +134,7 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke
 const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
 const int * __restrict__ seq_lens_q,
 const int * __restrict__ seq_lens_kv,
-const int * __restrict__ cum_offsets,
+const int * __restrict__ cu_seqlens_q,
 const int * __restrict__ block_table, // [bsz, block_num_per_seq]
 const int max_seq_len,
 const int max_dec_len,
@@ -171,8 +171,8 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke
 }
 kv_len += q_len;
 const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
-const uint32_t q_start_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
-const uint32_t q_write_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+const uint32_t q_start_idx = cu_seqlens_q[bid];
+const uint32_t q_write_idx = cu_seqlens_q[bid];
 if (chunk_id >= num_chunk_this_seq) {
   return;
 }
@@ -318,7 +318,7 @@ void MultiQueryDecoderAttention(
 const paddle::Tensor &seq_lens_q,
 const paddle::Tensor &seq_lens_kv,
 const paddle::Tensor &padding_offsets,
-const paddle::Tensor &cum_offsets,
+const paddle::Tensor &cu_seqlens_q,
 const paddle::Tensor &block_table,
 const int max_seq_len,
 const int max_dec_len,
@@ -393,7 +393,7 @@ void MultiQueryDecoderAttention(
 reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
 seq_lens_q.data<int>(),
 seq_lens_kv.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 block_table.data<int>(),
 max_seq_len,
 max_dec_len,
@@ -430,7 +430,7 @@ void MultiQueryDecoderAttention(
 reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
 seq_lens_q.data<int>(),
 seq_lens_kv.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 block_table.data<int>(),
 max_seq_len,
 max_dec_len,
@@ -456,7 +456,7 @@ void MultiQueryDecoderAttention(
 reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
 seq_lens_q.data<int>(),
 seq_lens_kv.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
 reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
 reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
@@ -484,7 +484,7 @@ void DecodeMLAAttentionKernel(
 const paddle::Tensor &seq_lens_q, // q_seq_len is 1
 const paddle::Tensor &seq_lens_kv,
 const paddle::Tensor &padding_offsets,
-const paddle::Tensor &cum_offsets,
+const paddle::Tensor &cu_seqlens_q,
 const paddle::Tensor &block_table,
 int max_seq_len,
 int max_dec_len,
@@ -513,7 +513,7 @@ void DecodeMLAAttentionKernel(
 {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
 {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
 {MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
-meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cum_offsets,
+meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cu_seqlens_q,
 block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
 }

@@ -528,7 +528,7 @@ template void DecodeMLAAttentionKernel<paddle::bfloat16>(
 const paddle::Tensor &seq_lens_q, // q_seq_len is 1
 const paddle::Tensor &seq_lens_kv,
 const paddle::Tensor &padding_offsets,
-const paddle::Tensor &cum_offsets,
+const paddle::Tensor &cu_seqlens_q,
 const paddle::Tensor &block_table,
 int max_seq_len,
 int max_dec_len,
@@ -549,7 +549,7 @@ template void DecodeMLAAttentionKernel<paddle::float16>(
 const paddle::Tensor &seq_lens_q, // q_seq_len is 1
 const paddle::Tensor &seq_lens_kv,
 const paddle::Tensor &padding_offsets,
-const paddle::Tensor &cum_offsets,
+const paddle::Tensor &cu_seqlens_q,
 const paddle::Tensor &block_table,
 int max_seq_len,
 int max_dec_len,
@@ -29,7 +29,7 @@ __global__ void append_decode_cache_T_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -65,7 +65,7 @@ __global__ void append_decode_cache_T_rope_kernel(
 const int bias = linear_index % hidden_size;
 const int hi = bias / head_size; // q + k + v
 const int h_bias = bias % head_size;
-const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+const int start_token_idx = cu_seqlens_q[ori_bi];
 if (seq_lens_encoder[ori_bi] > 0) return;
 const int write_seq_id = seq_lens[ori_bi];
 if (write_seq_id == 0) continue;
@@ -135,7 +135,7 @@ __global__ void append_decode_cache_T_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -177,7 +177,7 @@ __global__ void append_decode_cache_T_rope_kernel(
 const int bias = linear_index % hidden_size;
 const int hi = bias / head_size; // q + k + v
 const int h_bias = bias % head_size;
-const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+const int start_token_idx = cu_seqlens_q[ori_bi];
 if (seq_lens_encoder[ori_bi] > 0) return;
 const int write_seq_id = seq_lens[ori_bi];
 if (write_seq_id == 0) continue;
@@ -255,7 +255,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -293,7 +293,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
 const int bias = linear_index % half_hidden_size;
 const int hi = bias / half_head_size; // q + k + v
 const int h_bias = bias % half_head_size;
-const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+const int start_token_idx = cu_seqlens_q[ori_bi];
 if (seq_lens_encoder[ori_bi] > 0) return;
 const int write_seq_id = seq_lens[ori_bi];
 if (write_seq_id == 0) continue;
@@ -367,7 +367,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -409,7 +409,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
 const int bias = linear_index % half_hidden_size;
 const int hi = bias / half_head_size; // q + k + v
 const int h_bias = bias % half_head_size;
-const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+const int start_token_idx = cu_seqlens_q[ori_bi];
 if (seq_lens_encoder[ori_bi] > 0) return;
 const int write_seq_id = seq_lens[ori_bi];
 if (write_seq_id == 0) continue;
@@ -499,7 +499,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -523,7 +523,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
 int q_head_idx, k_head_idx, v_idx;
 const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
 constexpr int half_head_size = HeadDim / 2;
-const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+const int start_token_idx = cu_seqlens_q[bid];
 if (seq_lens_encoder[bid] > 0) return;
 const int write_seq_id = seq_lens[bid];
 if (write_seq_id == 0) return;
@@ -746,7 +746,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -775,7 +775,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
 int q_head_idx, k_head_idx, v_idx;
 const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
 constexpr int half_head_size = HeadDim / 2;
-const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+const int start_token_idx = cu_seqlens_q[bid];
 if (seq_lens_encoder[bid] > 0) return;
 const int write_seq_id = seq_lens[bid];
 if (write_seq_id == 0) return;
@@ -1048,7 +1048,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -1073,7 +1073,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
 int q_head_idx, k_head_idx, v_idx;
 const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
 constexpr int half_head_size = HeadDim / 2;
-const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+const int start_token_idx = cu_seqlens_q[bid];
 if (seq_lens_encoder[bid] > 0) return;
 const int write_seq_id = seq_lens[bid];
 if (write_seq_id == 0) return;
@@ -1347,7 +1347,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -1377,7 +1377,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(

 const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
 constexpr int half_head_size = HeadDim / 2;
-const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+const int start_token_idx = cu_seqlens_q[bid];
 if (seq_lens_encoder[bid] > 0) return;
 const int write_seq_id = seq_lens[bid];
 if (write_seq_id == 0) return;
@@ -1740,7 +1740,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -1766,7 +1766,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
 const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
 constexpr int half_head_size = HeadDim / 2;
 const int half_block_size = block_size / 2;
-const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+const int start_token_idx = cu_seqlens_q[bid];
 if (seq_lens_encoder[bid] > 0) return;
 const int write_seq_id = seq_lens[bid];
 if (write_seq_id == 0) return;
@@ -2035,7 +2035,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -2066,7 +2066,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
 const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
 constexpr int half_head_size = HeadDim / 2;
 const int half_block_size = block_size / 2;
-const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+const int start_token_idx = cu_seqlens_q[bid];
 if (seq_lens_encoder[bid] > 0) return;
 const int write_seq_id = seq_lens[bid];
 if (write_seq_id == 0) return;
@@ -2363,7 +2363,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -2389,7 +2389,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
 const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
 constexpr int half_head_size = HeadDim / 2;
 const int half_block_size = block_size / 2;
-const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+const int start_token_idx = cu_seqlens_q[bid];
 if (seq_lens_encoder[bid] > 0) return;
 const int write_seq_id = seq_lens[bid];
 if (write_seq_id == 0) return;
@@ -2733,7 +2733,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -2764,7 +2764,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
 const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
 constexpr int half_head_size = HeadDim / 2;
 const int half_block_size = block_size / 2;
-const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
+const int start_token_idx = cu_seqlens_q[bid];
 if (seq_lens_encoder[bid] > 0) return;
 const int write_seq_id = seq_lens[bid];
 if (write_seq_id == 0) return;
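All of the decode-path cache kernels above now share the same per-batch prologue: resolve the batch's packed start token from `cu_seqlens_q`, then bail out for prefill batches and empty slots. A hedged CUDA sketch of that pattern follows (the kernel name and output parameter are invented for illustration; the guards mirror the diff):

```cuda
// Illustrative only: the shared indexing prologue, one block per batch.
__global__ void decode_cache_prologue_sketch(
    const int* __restrict__ cu_seqlens_q,      // packed start per batch
    const int* __restrict__ seq_lens,          // [bsz] decode lengths
    const int* __restrict__ seq_lens_encoder,  // [bsz] prefill lengths
    int* __restrict__ out_start_idx) {         // hypothetical output
  const int bid = blockIdx.x;
  // Direct lookup replaces bid * max_seq_len - cum_offsets[bid].
  const int start_token_idx = cu_seqlens_q[bid];
  if (seq_lens_encoder[bid] > 0) return;  // batch is in prefill, skip
  if (seq_lens[bid] == 0) return;         // empty slot, skip
  if (threadIdx.x == 0) out_start_idx[bid] = start_token_idx;
}
```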
@@ -22,7 +22,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
 T* qkv_out,
 const int* block_tables,
 const int* padding_offsets,
-const int* cum_offsets,
+const int* cu_seqlens_q,
 const int* seq_lens,
 const int* seq_lens_encoder,
 const float* cos_emb,
@@ -58,7 +58,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
 qkv_out,
 block_tables,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 seq_lens,
 seq_lens_encoder,
 cos_emb,
@@ -80,7 +80,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
 qkv_out,
 block_tables,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 seq_lens,
 seq_lens_encoder,
 cos_emb,
@@ -103,7 +103,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
 qkv_out,
 block_tables,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 seq_lens,
 seq_lens_encoder,
 cos_emb,
@@ -126,7 +126,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
 qkv_out,
 block_tables,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 seq_lens,
 seq_lens_encoder,
 cos_emb,
@@ -150,7 +150,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
 T* qkv_out,
 const int* block_tables,
 const int* padding_offsets,
-const int* cum_offsets,
+const int* cu_seqlens_q,
 const int* seq_lens,
 const int* seq_lens_encoder,
 const float* cos_emb,
@@ -183,7 +183,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
 qkv_out,
 block_tables,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 seq_lens,
 seq_lens_encoder,
 cos_emb,
@@ -208,7 +208,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
 qkv_out,
 block_tables,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 seq_lens,
 seq_lens_encoder,
 cos_emb,
@@ -233,7 +233,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
 qkv_out,
 block_tables,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 seq_lens,
 seq_lens_encoder,
 cos_emb,
@@ -258,7 +258,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
 qkv_out,
 block_tables,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 seq_lens,
 seq_lens_encoder,
 cos_emb,
@@ -283,7 +283,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
 T* qkv_out,
 const int* block_tables,
 const int* padding_offsets,
-const int* cum_offsets,
+const int* cu_seqlens_q,
 const int* seq_lens,
 const int* seq_lens_encoder,
 const float* cos_emb,
@@ -318,7 +318,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
 qkv_out,
 block_tables,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 seq_lens,
 seq_lens_encoder,
 cos_emb,
@@ -345,7 +345,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
 qkv_out,
 block_tables,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 seq_lens,
 seq_lens_encoder,
 cos_emb,
@@ -372,7 +372,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
 qkv_out,
 block_tables,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 seq_lens,
 seq_lens_encoder,
 cos_emb,
@@ -399,7 +399,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
 qkv_out,
 block_tables,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 seq_lens,
 seq_lens_encoder,
 cos_emb,
@@ -425,7 +425,7 @@ void DecoderWriteCacheWithRoPEKernel(
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const paddle::optional<paddle::Tensor>& rotary_embs,
 const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -472,7 +472,7 @@ void DecoderWriteCacheWithRoPEKernel(
 reinterpret_cast<DataType_*>(qkv_out->data<T>()),
 block_tables.data<int>(),
 padding_offsets.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 seq_lens.data<int>(),
 seq_lens_encoder.data<int>(),
 cos_emb,
@@ -504,7 +504,7 @@ void DecoderWriteCacheWithRoPEKernel(
 reinterpret_cast<DataType_*>(qkv_out->data<T>()),
 block_tables.data<int>(),
 padding_offsets.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 seq_lens.data<int>(),
 seq_lens_encoder.data<int>(),
 cos_emb,
@@ -537,7 +537,7 @@ void DecoderWriteCacheWithRoPEKernel(
 reinterpret_cast<DataType_*>(qkv_out->data<T>()),
 block_tables.data<int>(),
 padding_offsets.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 seq_lens.data<int>(),
 seq_lens_encoder.data<int>(),
 cos_emb,
@@ -571,7 +571,7 @@ void DecoderWriteCacheWithRoPEKernel(
 reinterpret_cast<DataType_*>(qkv_out->data<T>()),
 block_tables.data<int>(),
 padding_offsets.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 seq_lens.data<int>(),
 seq_lens_encoder.data<int>(),
 cos_emb,
@@ -604,7 +604,7 @@ void DecoderWriteCacheWithRoPEKernel(
 reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
 block_tables.data<int>(),
 padding_offsets.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 seq_lens.data<int>(),
 seq_lens_encoder.data<int>(),
 cos_emb,
@@ -651,7 +651,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const paddle::optional<paddle::Tensor>& rotary_embs,
 const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -678,7 +678,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const paddle::optional<paddle::Tensor>& rotary_embs,
 const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -704,7 +704,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const paddle::optional<paddle::Tensor>& rotary_embs,
 const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -730,7 +730,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const paddle::optional<paddle::Tensor>& rotary_embs,
 const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -24,7 +24,7 @@ void DecoderWriteCacheWithRoPEKernel(
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const paddle::optional<paddle::Tensor>& rotary_embs,
 const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -879,7 +879,7 @@ __global__ void append_write_cache_kv_c8_qkv(
 const int *__restrict__ seq_lens_this_time,
 const int *__restrict__ seq_lens_decoder,
 const int *__restrict__ padding_offsets,
-const int *__restrict__ cum_offsets,
+const int *__restrict__ cu_seqlens_q,
 const int *__restrict__ block_tables,
 const int max_seq_len,
 const int max_blocks_per_seq,
@@ -911,8 +911,7 @@ __global__ void append_write_cache_kv_c8_qkv(
 const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
 uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;

-const uint32_t start_token_idx =
-    batch_id * max_seq_len - cum_offsets[batch_id];
+const uint32_t start_token_idx = cu_seqlens_q[batch_id];
 const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
 const uint32_t kv_h_stride = HEAD_DIM;
 __shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
@@ -1119,7 +1118,7 @@ __global__ void append_write_cache_kv_c4_qkv(
 const int *__restrict__ seq_lens_this_time,
 const int *__restrict__ seq_lens_decoder,
 const int *__restrict__ padding_offsets,
-const int *__restrict__ cum_offsets,
+const int *__restrict__ cu_seqlens_q,
 const int *__restrict__ block_tables,
 const int max_seq_len,
 const int max_blocks_per_seq,
@@ -1148,8 +1147,7 @@ __global__ void append_write_cache_kv_c4_qkv(
 const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
 uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;

-const uint32_t start_token_idx =
-    batch_id * max_seq_len - cum_offsets[batch_id];
+const uint32_t start_token_idx = cu_seqlens_q[batch_id];
 const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
 const uint32_t kv_h_stride = HEAD_DIM;
 __shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
@@ -1750,7 +1748,7 @@ void CascadeAppendWriteCacheKVC8QKV(
 const paddle::Tensor &seq_lens_this_time,
 const paddle::Tensor &seq_lens_decoder,
 const paddle::Tensor &padding_offsets,
-const paddle::Tensor &cum_offsets,
+const paddle::Tensor &cu_seqlens_q,
 const paddle::Tensor &block_table,
 const paddle::Tensor &batch_ids,
 const paddle::Tensor &tile_ids_per_batch,
@@ -1815,7 +1813,7 @@ void CascadeAppendWriteCacheKVC8QKV(
 seq_lens_this_time.data<int>(),
 seq_lens_decoder.data<int>(),
 padding_offsets.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 block_table.data<int>(),
 max_seq_len,
 max_blocks_per_seq,
@@ -1838,7 +1836,7 @@ void CascadeAppendWriteCacheKVC4QKV(
 const paddle::Tensor &seq_lens_this_time,
 const paddle::Tensor &seq_lens_decoder,
 const paddle::Tensor &padding_offsets,
-const paddle::Tensor &cum_offsets,
+const paddle::Tensor &cu_seqlens_q,
 const paddle::Tensor &block_table,
 const paddle::Tensor &batch_ids,
 const paddle::Tensor &tile_ids_per_batch,
@@ -1885,7 +1883,7 @@ void CascadeAppendWriteCacheKVC4QKV(
 seq_lens_this_time.data<int>(),
 seq_lens_decoder.data<int>(),
 padding_offsets.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 block_table.data<int>(),
 max_seq_len,
 max_blocks_per_seq,
@@ -26,7 +26,7 @@ void EncoderWriteCacheWithRopeKernel(
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& seq_lens_decoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const paddle::Tensor& batch_ids,
 const paddle::Tensor& tile_ids,
@@ -143,7 +143,7 @@ void EncoderWriteCacheWithRopeKernel(
 seq_lens_this_time,
 seq_lens_decoder,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 block_tables,
 batch_ids,
 tile_ids,
@@ -170,7 +170,7 @@ void EncoderWriteCacheWithRopeKernel(
 seq_lens_this_time,
 seq_lens_decoder,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 block_tables,
 batch_ids,
 tile_ids,
@@ -422,7 +422,6 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& seq_lens_decoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
 const paddle::Tensor& block_tables,
 const paddle::Tensor& kv_batch_ids,
 const paddle::Tensor& kv_tile_ids,
@@ -450,7 +449,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
 const int token_num = qkv_dims[0];
 const int max_blocks_per_seq = block_tables.dims()[1];
 const int block_size = key_cache.dims()[2];
-const int batch_size = cum_offsets.dims()[0];
+const int batch_size = seq_lens_this_time.dims()[0];
 const int kv_num_heads = key_cache_dims[1];
 const int head_dim = key_cache_dims[3];
 const int num_heads = qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads;
@@ -463,7 +462,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
 meta_data.q_num_heads = num_heads;
 meta_data.max_blocks_per_seq = max_blocks_per_seq;
 meta_data.block_size = block_size;
-meta_data.batch_size = cum_offsets.dims()[0];
+meta_data.batch_size = seq_lens_this_time.dims()[0];

 phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(qkv.place()));

@@ -528,7 +527,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
 seq_lens_this_time,
 seq_lens_decoder,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 block_tables,
 kv_batch_ids,
 kv_tile_ids,
@@ -595,7 +594,6 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
 "seq_lens_encoder",
 "seq_lens_decoder",
 "padding_offsets",
-"cum_offsets",
 "block_tables",
 "kv_batch_ids",
 "kv_tile_ids_per_batch",
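Dropping `cum_offsets` from the `gqa_rope_write_cache` inputs also removes the tensor this op used to read the batch size from; the hunks above take `bsz` from `seq_lens_this_time.dims()[0]` instead, and the MLA write-cache hunks below use `cu_seqlens_q.dims()[0]` the same way. A small stand-alone illustration (the `DimsOnly` struct is a mock, not a paddle type):

```cpp
// Sketch: deriving batch size from the leading dimension of a
// per-batch tensor once cum_offsets is no longer an input.
#include <cassert>
#include <cstdint>
#include <vector>

struct DimsOnly {  // minimal stand-in for paddle::Tensor::dims()
  std::vector<int64_t> d;
  const std::vector<int64_t>& dims() const { return d; }
};

int main() {
  DimsOnly seq_lens_this_time{{8}};  // shape [bsz]
  const int batch_size = static_cast<int>(seq_lens_this_time.dims()[0]);
  assert(batch_size == 8);
  return 0;
}
```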
@@ -23,7 +23,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_decoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const int max_seq_len,
 cudaStream_t& stream,
@@ -54,7 +54,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
 reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
 block_tables.data<int>(),
 padding_offsets.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 seq_lens.data<int>(),
 seq_lens_decoder.data<int>(),
 max_seq_len,
@@ -74,7 +74,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_decoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const std::string& cache_quant_type_str,
 const int max_seq_len) {
@@ -91,7 +91,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(

 meta_data.max_blocks_per_seq = block_tables.dims()[1];
 meta_data.block_size = kv_cache_dims[2];
-meta_data.batch_size = cum_offsets.dims()[0];
+meta_data.batch_size = cu_seqlens_q.dims()[0];
 switch (kv_pe.dtype()) {
   case paddle::DataType::BFLOAT16: {
     return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
@@ -100,7 +100,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
 seq_lens,
 seq_lens_decoder,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 block_tables,
 max_seq_len,
 stream,
@@ -113,7 +113,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
 seq_lens,
 seq_lens_decoder,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 block_tables,
 max_seq_len,
 stream,
@@ -131,7 +131,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const int max_seq_len,
 const bool speculate_decoder,
@@ -165,7 +165,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
 reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
 block_tables.data<int>(),
 padding_offsets.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 seq_lens.data<int>(),
 seq_lens_encoder.data<int>(),
 max_seq_len,
@@ -185,7 +185,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
 reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
 reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
 block_tables.data<int>(),
-cum_offsets.data<int>(),
+cu_seqlens_q.data<int>(),
 seq_lens.data<int>(),
 seq_lens_encoder.data<int>(),
 max_seq_len,
@@ -206,7 +206,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
 const paddle::Tensor& seq_lens,
 const paddle::Tensor& seq_lens_encoder,
 const paddle::Tensor& padding_offsets,
-const paddle::Tensor& cum_offsets,
+const paddle::Tensor& cu_seqlens_q,
 const paddle::Tensor& block_tables,
 const std::string& cache_quant_type_str,
 const int max_seq_len,
@@ -224,7 +224,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(

 meta_data.max_blocks_per_seq = block_tables.dims()[1];
 meta_data.block_size = kv_cache_dims[2];
-meta_data.batch_size = cum_offsets.dims()[0];
+meta_data.batch_size = cu_seqlens_q.dims()[0];
 switch (kv_pe.dtype()) {
   case paddle::DataType::BFLOAT16: {
     return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
@@ -233,7 +233,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
 seq_lens,
 seq_lens_encoder,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 block_tables,
 max_seq_len,
 speculate_decoder,
@@ -247,7 +247,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
 seq_lens,
 seq_lens_encoder,
 padding_offsets,
-cum_offsets,
+cu_seqlens_q,
 block_tables,
 max_seq_len,
 speculate_decoder,
@@ -266,7 +266,7 @@ PD_BUILD_OP(prefill_mla_write_cache)
 "seq_lens",
 "seq_lens_decoder",
 "padding_offsets",
-"cum_offsets",
+"cu_seqlens_q",
 "block_tables"})
 .Outputs({"kv_cache_out"})
 .SetInplaceMap({{"kv_cache", "kv_cache_out"}})
@@ -281,7 +281,7 @@ PD_BUILD_OP(decode_mla_write_cache)
 "seq_lens",
 "seq_lens_encoder",
 "padding_offsets",
-"cum_offsets",
+"cu_seqlens_q",
 "block_tables"})
 .Outputs({"kv_cache_out"})
 .SetInplaceMap({{"kv_cache", "kv_cache_out"}})
@@ -24,7 +24,7 @@ __global__ void decode_absorb_cache_kernel(
 T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
                           // nope_size]
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const int max_seq_len,
@@ -50,7 +50,7 @@ __global__ void decode_absorb_cache_kernel(
      linear_index += step) {
 const int ori_bi = linear_index / hidden_size;
 const int bias = linear_index % hidden_size;
-const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+const int start_token_idx = cu_seqlens_q[ori_bi];
 if (seq_lens_encoder[ori_bi] > 0) return;
 const int write_seq_id = seq_lens[ori_bi];

@@ -96,7 +96,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
 // nope_size]
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets,
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const int max_seq_len,
@@ -124,7 +124,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
 const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
 if (seq_lens[ori_bi] == 0) continue;
 const int bias = linear_index % hidden_size;
-const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+const int start_token_idx = cu_seqlens_q[ori_bi];
 const int write_seq_id =
     seq_lens[ori_bi] + token_id - start_token_idx;
 if (write_seq_id == 0) continue;
@@ -143,7 +143,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
 ori_bi,
 seq_lens[ori_bi],
 token_id,
-cum_offsets[ori_bi]);
+cu_seqlens_q[ori_bi]);
 }
 if (bias < nope_hidden_size) { // pe
   const uint32_t inner_bias = bias;
@@ -179,7 +179,7 @@ __global__ void prefill_absorb_cache_kernel(
 // nope_size]
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets,
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_decoder, // [bsz]
 const int max_seq_len,
@@ -27,7 +27,7 @@ void DecodeMLAAttentionKernel(
 const paddle::Tensor &seq_lens_q, // q_seq_len is 1
 const paddle::Tensor &seq_lens_kv,
 const paddle::Tensor &padding_offsets,
-const paddle::Tensor &cum_offsets,
+const paddle::Tensor &cu_seqlens_q,
 const paddle::Tensor &block_table,
 int max_seq_len,
 int max_dec_len,
@@ -27,7 +27,7 @@ __global__ void append_clear_cache_int8_block(
 const int* __restrict__ seq_lens,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const int max_seq_len,
 const int max_blocks_per_seq,
@@ -44,7 +44,7 @@ __global__ void append_clear_cache_int8_block(
 const int ori_token_id = token_id + padding_offsets[token_id];
 const int bid = ori_token_id / max_seq_len;

-const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
+const int start_token_idx = cu_seqlens_q[bid];
 const int head_idx = blockIdx.y * NUM_WARPS + wid;

 if (seq_lens_encoder[bid] > 0) return;
@@ -101,7 +101,7 @@ __global__ void append_clear_cache_int4_block(
 const int* __restrict__ seq_lens,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const int max_seq_len,
 const int max_blocks_per_seq,
@@ -118,7 +118,7 @@ __global__ void append_clear_cache_int4_block(
 const int ori_token_id = token_id + padding_offsets[token_id];
 const int bid = ori_token_id / max_seq_len;

-const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
+const int start_token_idx = cu_seqlens_q[bid];
 const int head_idx = blockIdx.y * NUM_WARPS + wid;

 if (seq_lens_encoder[bid] > 0) return;
@@ -179,7 +179,7 @@ __global__ void append_speculate_cache_rope_kernel(
 T* __restrict__ q_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens_decoder, // [bsz]
 const float* __restrict__ cos_emb,
 const float* __restrict__ sin_emb,
@@ -219,7 +219,7 @@ __global__ void append_speculate_cache_rope_kernel(
 const int bias = linear_index % hidden_size;
 const int hi = bias / head_size; // q + k + v
 const int h_bias = bias % head_size;
-const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+const int start_token_idx = cu_seqlens_q[ori_bi];
 const int write_seq_id =
     seq_lens_decoder[ori_bi] + token_id - start_token_idx;
 if (write_seq_id == 0) continue;
@@ -235,7 +235,7 @@ __global__ void append_speculate_cache_rope_kernel(
 ori_bi,
 seq_lens_decoder[ori_bi],
 token_id,
-cum_offsets[ori_bi]);
+cu_seqlens_q[ori_bi]);
 }
 const int block_offset = write_seq_id % block_size;

@@ -312,7 +312,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens_decoder, // [bsz]
 const float* __restrict__ cos_emb,
 const float* __restrict__ sin_emb,
@@ -352,7 +352,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
 const int bias = linear_index % half_hidden_size;
 const int hi = bias / half_head_size; // q + k + v
 const int h_bias = bias % half_head_size;
-const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
+const int start_token_idx = cu_seqlens_q[ori_bi];
 const int write_seq_id =
     seq_lens_decoder[ori_bi] + token_id - start_token_idx;
 if (write_seq_id == 0) continue;
@@ -368,7 +368,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
 ori_bi,
 seq_lens_decoder[ori_bi],
 token_id,
-cum_offsets[ori_bi]);
+cu_seqlens_q[ori_bi]);
 }
 const int block_offset = write_seq_id % block_size;

@@ -459,7 +459,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -487,7 +487,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
 const int ori_token_id = token_id + padding_offsets[token_id];
 const int bid = ori_token_id / max_seq_len;

-const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
+const int start_token_idx = cu_seqlens_q[bid];
 const int head_idx = blockIdx.y * NUM_WARPS + wid;
 int q_head_idx, k_head_idx, v_idx;
 const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -691,7 +691,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -719,7 +719,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
 const int ori_token_id = token_id + padding_offsets[token_id];
 const int bid = ori_token_id / max_seq_len;

-const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
+const int start_token_idx = cu_seqlens_q[bid];
 const int head_idx = blockIdx.y * NUM_WARPS + wid;
 int q_head_idx, k_head_idx, v_idx;

@@ -1069,7 +1069,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -1100,7 +1100,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
 const int ori_token_id = token_id + padding_offsets[token_id];
 const int bid = ori_token_id / max_seq_len;

-const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
+const int start_token_idx = cu_seqlens_q[bid];
 const int head_idx = blockIdx.y * NUM_WARPS + wid;

 const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -1375,7 +1375,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
 T* __restrict__ qkv_out,
 const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
 const int* __restrict__ padding_offsets, // [num_tokens]
-const int* __restrict__ cum_offsets,
+const int* __restrict__ cu_seqlens_q,
 const int* __restrict__ seq_lens, // [bsz]
 const int* __restrict__ seq_lens_encoder, // [bsz]
 const float* __restrict__ cos_emb,
@@ -1406,7 +1406,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
 const int ori_token_id = token_id + padding_offsets[token_id];
 const int bid = ori_token_id / max_seq_len;

-const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
+const int start_token_idx = cu_seqlens_q[bid];
 const int head_idx = blockIdx.y * NUM_WARPS + wid;

 const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -23,7 +23,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* padding_offsets,
|
||||
const int* cum_offsets,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
@@ -60,7 +60,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
@@ -83,7 +83,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
@@ -107,7 +107,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* padding_offsets,
|
||||
const int* cum_offsets,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
@@ -137,7 +137,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
seq_lens,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens_encoder,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
@@ -152,7 +152,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -176,7 +176,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
padding_offsets,
|
||||
cum_offsets,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
@@ -202,7 +202,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
                                      T* qkv_out,
                                      const int* block_tables,
                                      const int* padding_offsets,
-                                     const int* cum_offsets,
+                                     const int* cu_seqlens_q,
                                      const int* seq_lens,
                                      const int* seq_lens_encoder,
                                      const float* cos_emb,
@@ -234,7 +234,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
        seq_lens,
        block_tables,
        padding_offsets,
-       cum_offsets,
+       cu_seqlens_q,
        seq_lens_encoder,
        max_seq_len,
        max_blocks_per_seq,
@@ -249,7 +249,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
        qkv_out,
        block_tables,
        padding_offsets,
-       cum_offsets,
+       cu_seqlens_q,
        seq_lens,
        seq_lens_encoder,
        cos_emb,
@@ -275,7 +275,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
        qkv_out,
        block_tables,
        padding_offsets,
-       cum_offsets,
+       cu_seqlens_q,
        seq_lens,
        seq_lens_encoder,
        cos_emb,
@@ -302,7 +302,7 @@ void SpeculateWriteCacheWithRoPEKernel(
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -350,7 +350,7 @@ void SpeculateWriteCacheWithRoPEKernel(
        reinterpret_cast<DataType_*>(qkv_out->data<T>()),
        block_tables.data<int>(),
        padding_offsets.data<int>(),
-       cum_offsets.data<int>(),
+       cu_seqlens_q.data<int>(),
        seq_lens.data<int>(),
        seq_lens_encoder.data<int>(),
        cos_emb,
@@ -377,7 +377,7 @@ void SpeculateWriteCacheWithRoPEKernel(
        reinterpret_cast<DataType_*>(qkv_out->data<T>()),
        block_tables.data<int>(),
        padding_offsets.data<int>(),
-       cum_offsets.data<int>(),
+       cu_seqlens_q.data<int>(),
        seq_lens.data<int>(),
        seq_lens_encoder.data<int>(),
        cos_emb,
@@ -410,7 +410,7 @@ void SpeculateWriteCacheWithRoPEKernel(
        reinterpret_cast<DataType_*>(qkv_out->data<T>()),
        block_tables.data<int>(),
        padding_offsets.data<int>(),
-       cum_offsets.data<int>(),
+       cu_seqlens_q.data<int>(),
        seq_lens.data<int>(),
        seq_lens_encoder.data<int>(),
        cos_emb,
@@ -443,7 +443,7 @@ void SpeculateWriteCacheWithRoPEKernel(
        reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
        block_tables.data<int>(),
        padding_offsets.data<int>(),
-       cum_offsets.data<int>(),
+       cu_seqlens_q.data<int>(),
        seq_lens.data<int>(),
        seq_lens_encoder.data<int>(),
        cos_emb,
@@ -489,7 +489,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -515,7 +515,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -540,7 +540,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -567,7 +567,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -24,7 +24,7 @@ void SpeculateWriteCacheWithRoPEKernel(
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -38,7 +38,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -38,7 +38,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -39,7 +39,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -86,7 +86,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -81,7 +81,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -83,7 +83,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -83,7 +83,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
    const paddle::Tensor& seq_lens_kv,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_table,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -23,7 +23,7 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids,
@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids,
@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids,
@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids,
@@ -54,7 +54,7 @@ std::vector<paddle::Tensor> AppendAttention(
    const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &seq_lens_this_time,
-   const paddle::Tensor &padding_offsets, const paddle::Tensor &cum_offsets,
+   const paddle::Tensor &padding_offsets, const paddle::Tensor &cu_seqlens_q,
    const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
    const paddle::Tensor &encoder_tile_ids_per_batch,
    const paddle::Tensor &encoder_num_blocks,
@@ -94,7 +94,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
-   const paddle::Tensor &padding_offsets, const paddle::Tensor &cum_offsets,
+   const paddle::Tensor &padding_offsets,
    const paddle::Tensor &block_tables, const paddle::Tensor &kv_batch_ids,
    const paddle::Tensor &kv_tile_ids, const paddle::Tensor &kv_num_blocks,
    const paddle::Tensor &cache_batch_ids, const paddle::Tensor &cache_tile_ids,
@@ -116,11 +116,11 @@ PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,

 paddle::Tensor FusedExpertMoeFunc(
     const paddle::Tensor &input, const paddle::Tensor &gate_weight,
-    const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight,
-    const paddle::optional<paddle::Tensor> &ffn1_bias,
-    const paddle::optional<paddle::Tensor> &ffn1_scale,
-    const paddle::optional<paddle::Tensor> &ffn2_bias,
-    const paddle::optional<paddle::Tensor> &ffn2_scale,
+    const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight,
+    const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
+    const paddle::optional<paddle::Tensor> &up_gate_proj_scale,
+    const paddle::optional<paddle::Tensor> &down_proj_bias,
+    const paddle::optional<paddle::Tensor> &down_proj_scale,
     const std::string &quant_method, const int moe_topk,
     const bool norm_topk_prob, const bool group_moe);
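The other recurring rename in this diff, `ffn1_*`/`ffn2_*` to `up_gate_proj_*`/`down_proj_*`, reads naturally if the expert FFN follows the common fused gate/up layout. Below is a scalar sketch of that structure; the SwiGLU-style activation and the [2*inter, hidden] weight layout are our assumptions about the convention, not something stated in the diff.

#include <cmath>
#include <vector>

// Minimal scalar sketch of one expert's FFN, illustrating the names:
// up_gate_proj ("ffn1") fuses the gate and up projections; down_proj
// ("ffn2") maps the intermediate activation back to the hidden size.
std::vector<float> expert_ffn(
    const std::vector<float>& x,                           // [hidden]
    const std::vector<std::vector<float>>& up_gate_proj,   // [2*inter][hidden], assumed row split
    const std::vector<std::vector<float>>& down_proj) {    // [hidden][inter]
  const size_t inter = up_gate_proj.size() / 2;
  std::vector<float> h(inter);
  for (size_t i = 0; i < inter; ++i) {
    float gate = 0.f, up = 0.f;
    for (size_t j = 0; j < x.size(); ++j) {
      gate += up_gate_proj[i][j] * x[j];          // gate half of the fused weight
      up   += up_gate_proj[inter + i][j] * x[j];  // up half of the fused weight
    }
    h[i] = (gate / (1.f + std::exp(-gate))) * up;  // SiLU(gate) * up
  }
  std::vector<float> y(x.size(), 0.f);
  for (size_t j = 0; j < x.size(); ++j)
    for (size_t i = 0; i < inter; ++i)
      y[j] += down_proj[j][i] * h[i];
  return y;
}

The same substitution repeats mechanically through the dispatch, combine, FFN, and reduce declarations that follow.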
@@ -149,7 +149,7 @@ MoERedundantTopKSelectKernel(const paddle::Tensor &gating_logits,
 std::vector<paddle::Tensor>
 EPMoeExpertDispatch(const paddle::Tensor &input, const paddle::Tensor &topk_ids,
                     const paddle::Tensor &topk_weights,
-                    const paddle::optional<paddle::Tensor> &ffn1_in_scale,
+                    const paddle::optional<paddle::Tensor> &up_gate_proj_in_scale,
                     const std::vector<int> &token_nums_per_expert,
                     const int token_nums_this_rank,
                     const std::string &moe_quant_type);
@@ -158,7 +158,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
     const paddle::Tensor &input, const paddle::Tensor &scale,
     const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights,
     const paddle::Tensor &token_nums_per_expert,
-    const paddle::Tensor &token_nums_per_expert_padded);
+    const paddle::Tensor &token_nums_per_expert_padded,
+    const bool use_in_ep, const int token_nums_this_rank_padded);

 std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
                                           const int block_size);
@@ -172,7 +173,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
     const paddle::Tensor &ffn_out, const paddle::Tensor &expert_scales_float,
     const paddle::Tensor &permute_indices_per_token,
     const paddle::Tensor &top_k_indices,
-    const paddle::optional<paddle::Tensor> &ffn2_bias,
+    const paddle::optional<paddle::Tensor> &down_proj_bias,
     const bool norm_topk_prob, const float routed_scaling_factor);

 std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
@@ -181,35 +182,35 @@ std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
 paddle::Tensor MoeExpertFFNFunc(
     const paddle::Tensor& permute_input,
     const paddle::Tensor& tokens_expert_prefix_sum,
-    const paddle::Tensor& ffn1_weight, const paddle::Tensor& ffn2_weight,
-    const paddle::optional<paddle::Tensor>& ffn1_bias,
-    const paddle::optional<paddle::Tensor>& ffn1_scale,
-    const paddle::optional<paddle::Tensor>& ffn2_scale,
-    const paddle::optional<paddle::Tensor>& ffn2_in_scale,
+    const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
+    const paddle::optional<paddle::Tensor>& down_proj_scale,
+    const paddle::optional<paddle::Tensor>& down_proj_in_scale,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
     const std::string& quant_method, const bool used_in_ep_low_latency);

 paddle::Tensor MoeExpertFFNWint2Func(
     const paddle::Tensor& permute_input,
     const paddle::Tensor& tokens_expert_prefix_sum,
-    const paddle::Tensor& ffn1_weight,
-    const paddle::Tensor& ffn2_weight,
-    const paddle::optional<paddle::Tensor>& ffn1_bias,
-    const paddle::optional<paddle::Tensor>& ffn1_scale,
-    const paddle::optional<paddle::Tensor>& ffn2_scale,
-    const paddle::optional<paddle::Tensor>& ffn1_local_scale,
-    const paddle::optional<paddle::Tensor>& ffn1_code_scale,
-    const paddle::optional<paddle::Tensor>& ffn1_code_zp,
-    const paddle::optional<paddle::Tensor>& ffn2_local_scale,
-    const paddle::optional<paddle::Tensor>& ffn2_code_scale,
-    const paddle::optional<paddle::Tensor>& ffn2_code_zp,
+    const paddle::Tensor& up_gate_proj_weight,
+    const paddle::Tensor& down_proj_weight,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
+    const paddle::optional<paddle::Tensor>& down_proj_scale,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
+    const paddle::optional<paddle::Tensor>& down_proj_local_scale,
+    const paddle::optional<paddle::Tensor>& down_proj_code_scale,
+    const paddle::optional<paddle::Tensor>& down_proj_code_zp,
     const bool used_in_ep_low_latency);

 paddle::Tensor MoeExpertReduceFunc(
     const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
     const paddle::Tensor &permute_indices_per_token,
     const paddle::Tensor &top_k_indices,
-    const paddle::optional<paddle::Tensor> &ffn2_bias,
+    const paddle::optional<paddle::Tensor> &down_proj_bias,
     const bool norm_topk_prob, const float routed_scaling_factor);

 void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
@@ -330,7 +331,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const std::string& cache_quant_type_str,
    const int max_seq_len,
@@ -343,7 +344,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
+   const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const std::string& cache_quant_type_str,
    const int max_seq_len);
@@ -369,7 +370,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets,
-   const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& encoder_batch_ids,
    const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -468,6 +468,262 @@ std::vector<paddle::Tensor> NoauxTc(
    int topk,
    float routed_scaling_factor);

#ifdef ENABLE_FP8
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
    const paddle::Tensor& x,
    const paddle::Tensor& y,
    const paddle::optional<paddle::Tensor>& bias,
    bool trans_x,
    bool trans_y,
    float scale,  // only support per-tensor quantization
    std::string output_dtype,
    std::string activation_type);

paddle::Tensor MoeFusedHadamardQuantFp8Func(
    const paddle::Tensor &input,
    const paddle::Tensor &scale,
    const paddle::Tensor &topk_ids,
    const int top_k,
    const int intermediate_size,
    const bool tiled);

paddle::Tensor FusedHadamardQuantFp8Func(
    const paddle::Tensor &input,
    const float scale);
#endif

int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
                               paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);

void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
                int64_t reg_buffer, int64_t reg_buffer_sz_bytes);

void dispose(int64_t _fa);

int64_t meta_size();

void register_buffer(int64_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);

std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(int64_t _fa);

void register_graph_buffers(int64_t _fa,
                            const std::vector<std::vector<int64_t>>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);

std::tuple<int64_t, paddle::Tensor> allocate_shared_buffer_and_handle(
    int64_t size);

int64_t open_mem_handle(paddle::Tensor& mem_handle);

void free_shared_buffer(int64_t buffer);

// speculative decoding Kernel
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
    const paddle::Tensor& input_ids,
    const paddle::Tensor& draft_tokens,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& token_num,
    const paddle::Tensor& seq_len,
    const paddle::Tensor& seq_lens_encoder);

std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder);

std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
    const paddle::Tensor& output_cum_offsets_tmp,
    const paddle::Tensor& out_token_num,
    const paddle::Tensor& seq_lens_output,
    const int max_seq_len);

void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
                                 const paddle::Tensor &logits,
                                 const paddle::Tensor &penalty_scores,
                                 const paddle::Tensor &frequency_scores,
                                 const paddle::Tensor &presence_scores,
                                 const paddle::Tensor &temperatures,
                                 const paddle::Tensor &bad_tokens,
                                 const paddle::Tensor &cur_len,
                                 const paddle::Tensor &min_len,
                                 const paddle::Tensor &eos_token_id,
                                 const paddle::Tensor &seq_lens_this_time,
                                 const paddle::Tensor &output_padding_offset,
                                 const paddle::Tensor &output_cum_offsets,
                                 const int max_seq_len);

void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
                               const paddle::Tensor &accept_num,
                               const paddle::Tensor &pre_ids,
                               const paddle::Tensor &step_idx,
                               const paddle::Tensor &stop_flags,
                               const paddle::Tensor &seq_lens,
                               const paddle::Tensor &stop_seqs,
                               const paddle::Tensor &stop_seqs_len,
                               const paddle::Tensor &end_ids);

void SpeculateVerify(
    const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num,
    const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &draft_tokens,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &verify_tokens, const paddle::Tensor &verify_scores,
    const paddle::Tensor &max_dec_len, const paddle::Tensor &end_tokens,
    const paddle::Tensor &is_block_step,
    const paddle::Tensor &output_cum_offsets,
    const paddle::Tensor &actual_candidate_len,
    const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);

void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
                       const paddle::Tensor &seq_lens_decoder,
                       const paddle::Tensor &not_need_stop,
                       const paddle::Tensor &draft_tokens,
                       const paddle::Tensor &actual_draft_token_nums,
                       const paddle::Tensor &accept_tokens,
                       const paddle::Tensor &accept_num,
                       const paddle::Tensor &stop_flags,
                       const paddle::Tensor &seq_lens_this_time,
                       const paddle::Tensor &is_block_step,
                       const paddle::Tensor &stop_nums);

void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
                                    const paddle::Tensor &accept_tokens,
                                    const paddle::Tensor &accept_num,
                                    const paddle::Tensor &stop_flags,
                                    const paddle::Tensor &seq_lens_this_time,
                                    const paddle::Tensor &seq_lens_encoder,
                                    const paddle::Tensor &seq_lens_decoder,
                                    const paddle::Tensor &step_idx);

void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
                                      const paddle::Tensor& accept_num,
                                      const paddle::Tensor& not_need_stop,
                                      int64_t rank_id,
                                      bool save_each_rank);

void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
                              const paddle::Tensor& seq_lens_decoder);

void NgramMatch(const paddle::Tensor &input_ids,
                const paddle::Tensor &input_ids_len,
                const paddle::Tensor &pre_ids,
                const paddle::Tensor &step_idx,
                const paddle::Tensor &draft_token_num,
                const paddle::Tensor &draft_tokens,
                const paddle::Tensor &seq_lens_this_time,
                const paddle::Tensor &seq_lens_encoder,
                const paddle::Tensor &seq_lens_decoder,
                const paddle::Tensor &max_dec_len,
                const int max_ngram_size,
                const int max_draft_tokens);

// MTP
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
                           const paddle::Tensor& base_model_seq_lens_this_time,
                           const paddle::Tensor& base_model_seq_lens_encoder,
                           const paddle::Tensor& base_model_stop_flags);

void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                          const paddle::Tensor& input_ids,
                          const paddle::Tensor& stop_flags,
                          const paddle::Tensor& seq_lens_this_time,
                          const paddle::Tensor& seq_lens_encoder,
                          const paddle::Tensor& seq_lens_decoder,
                          const paddle::Tensor& step_idx,
                          const paddle::Tensor& not_need_stop,
                          const paddle::Tensor& batch_drop,
                          const paddle::Tensor& accept_tokens,
                          const paddle::Tensor& accept_num,
                          const paddle::Tensor& base_model_seq_lens_encoder,
                          const paddle::Tensor& base_model_seq_lens_decoder,
                          const paddle::Tensor& base_model_step_idx,
                          const paddle::Tensor& base_model_stop_flags,
                          const paddle::Tensor& base_model_is_block_step,
                          const paddle::Tensor& base_model_draft_tokens,
                          const int max_draft_token,
                          const bool truncate_first_token,
                          const bool splitwise_prefill);

void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
                      const paddle::Tensor& draft_tokens,
                      const paddle::Tensor& pre_ids,
                      const paddle::Tensor& seq_lens_this_time,
                      const paddle::Tensor& seq_lens_encoder,
                      const paddle::Tensor& seq_lens_decoder,
                      const paddle::Tensor& step_idx,
                      const paddle::Tensor& output_cum_offsets,
                      const paddle::Tensor& stop_flags,
                      const paddle::Tensor& not_need_stop,
                      const paddle::Tensor& max_dec_len,
                      const paddle::Tensor& end_ids,
                      const paddle::Tensor& base_model_draft_tokens,
                      const int max_seq_len,
                      const int substep);

std::vector<paddle::Tensor> EagleGetHiddenStates(
    const paddle::Tensor& input,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& stop_flags,
    const paddle::Tensor& accept_nums,
    const paddle::Tensor& base_model_seq_lens_this_time,
    const paddle::Tensor& base_model_seq_lens_encoder,
    const int actual_draft_token_num);

void MTPStepPaddle(
    const paddle::Tensor &base_model_stop_flags,
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &batch_drop,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
    const paddle::Tensor &encoder_block_lens,
    const paddle::Tensor &used_list_len,
    const paddle::Tensor &free_list,
    const paddle::Tensor &free_list_len,
    const int block_size,
    const int max_draft_tokens);

void SpeculateStepPaddle(
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &ori_seq_lens_encoder,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
    const paddle::Tensor &encoder_block_lens,
    const paddle::Tensor &is_block_step,
    const paddle::Tensor &step_block_list,
    const paddle::Tensor &step_lens,
    const paddle::Tensor &recover_block_list,
    const paddle::Tensor &recover_lens,
    const paddle::Tensor &need_block_list,
    const paddle::Tensor &need_block_len,
    const paddle::Tensor &used_list_len,
    const paddle::Tensor &free_list,
    const paddle::Tensor &free_list_len,
    const paddle::Tensor &input_ids,
    const paddle::Tensor &pre_ids,
    const paddle::Tensor &step_idx,
    const paddle::Tensor &next_tokens,
    const paddle::Tensor &first_token_ids,
    const paddle::Tensor &accept_num,
    const int block_size,
    const int encoder_decoder_block_num,
    const int max_draft_tokens);

PYBIND11_MODULE(fastdeploy_ops, m) {

 m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -559,7 +815,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   * ep_moe_dispatch
   */
  m.def("ep_moe_expert_dispatch", &EPMoeExpertDispatch, py::arg("input"),
-       py::arg("topk_ids"), py::arg("topk_weights"), py::arg("ffn1_in_scale"),
+       py::arg("topk_ids"), py::arg("topk_weights"), py::arg("up_gate_proj_in_scale"),
        py::arg("token_nums_per_expert"), py::arg("token_nums_this_rank"),
        py::arg("moe_quant_type"), "ep moe export dispatch function");

@@ -567,7 +823,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

  m.def("ep_moe_expert_combine", &EPMoeExpertCombine, py::arg("ffn_out"),
        py::arg("expert_scales_float"), py::arg("permute_indices_per_token"),
-       py::arg("top_k_indices"), py::arg("ffn2_bias"),
+       py::arg("top_k_indices"), py::arg("down_proj_bias"),
        py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"),
        "ep moe export combine function");

@@ -609,7 +865,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   */
  m.def("moe_expert_reduce", &MoeExpertReduceFunc, py::arg("ffn_out"),
        py::arg("top_k_weight"), py::arg("permute_indices_per_token"),
-       py::arg("top_k_indices"), py::arg("ffn2_bias"),
+       py::arg("top_k_indices"), py::arg("down_proj_bias"),
        py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"),
        "moe export reduce function");
@@ -637,9 +893,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   * append_attn/get_block_shape_and_split_kv_block.cu
   * get_block_shape_and_split_kv_block
   */
- // m.def("f_get_block_shape_and_split_kv_block",
- //       &GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block
- //       function");
  m.def("get_block_shape_and_split_kv_block",
        &GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block function");

  /**
   * get_padding_offset.cu
@@ -700,35 +955,17 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
  m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);

  m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
-       py::arg("a"),
-       py::arg("c_or_none"),
-       py::arg("b_q_weight"),
-       py::arg("b_scales"),
-       py::arg("global_scale_or_none"),
-       py::arg("b_zeros_or_none"),
-       py::arg("g_idx_or_none"),
-       py::arg("perm_or_none"),
-       py::arg("workspace"),
-       py::arg("sorted_token_ids"),
-       py::arg("expert_ids"),
-       py::arg("num_tokens_post_padded"),
-       py::arg("topk_weights"),
-       py::arg("moe_block_size"),
-       py::arg("top_k"),
-       py::arg("mul_topk_weights"),
-       py::arg("is_ep"),
-       py::arg("b_q_type_str"),
-       py::arg("size_m"),
-       py::arg("size_n"),
-       py::arg("size_k"),
-       py::arg("is_k_full"),
-       py::arg("use_atomic_add"),
-       py::arg("use_fp32_reduce"),
-       py::arg("is_zp_float"));
+       py::arg("a"), py::arg("c_or_none"), py::arg("b_q_weight"),
+       py::arg("b_scales"), py::arg("global_scale_or_none"), py::arg("b_zeros_or_none"),
+       py::arg("g_idx_or_none"), py::arg("perm_or_none"), py::arg("workspace"), py::arg("sorted_token_ids"),
+       py::arg("expert_ids"), py::arg("num_tokens_post_padded"), py::arg("topk_weights"), py::arg("moe_block_size"),
+       py::arg("top_k"), py::arg("mul_topk_weights"), py::arg("is_ep"), py::arg("b_q_type_str"),
+       py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"), py::arg("use_atomic_add"),
+       py::arg("use_fp32_reduce"), py::arg("is_zp_float"));

  m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
        "get_position_ids_and_mask_encoder_batch function");

  /**
   * cutlass_scaled_mm.cu
   * cutlass_scaled_mm
@@ -762,4 +999,71 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
  m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");

  m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute");

#ifdef ENABLE_FP8
  m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
        py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
        py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
        py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function");
  m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func,
        py::arg("input"), py::arg("scale"), py::arg("topk_ids"),
        py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function");
  m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
        py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
#endif

  m.def("init_custom_all_reduce", &init_custom_all_reduce, "init all reduce class function");

  m.def("all_reduce", &all_reduce, "all reduce function");

  m.def("dispose", &dispose, "del function for python");

  m.def("meta_size", &meta_size, "meta_size function for Signal struct");

  m.def("register_buffer", &register_buffer, "register ipc buffer");

  m.def("register_graph_buffers", &register_graph_buffers, "register_graph_buffers");

  m.def("allocate_shared_buffer_and_handle", &allocate_shared_buffer_and_handle, "allocate_shared_buffer_and_handle");

  m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");

  m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");

  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");

  // speculative decoding Kernel
  m.def("speculate_get_padding_offset", &SpeculateGetPaddingOffset, "speculate_get_padding_offset function");

  m.def("speculate_get_seq_lens_output", &SpeculateGetSeqLensOutput, "speculate_get_seq_lens_output function");

  m.def("speculate_get_output_padding_offset", &SpeculateGetOutputPaddingOffset, "speculate_get_output_padding_offset function");

  m.def("speculate_get_token_penalty_multi_scores", &SpecTokenPenaltyMultiScores, "speculate_get_token_penalty_multi_scores function");

  m.def("speculate_set_stop_value_multi_seqs", &SpecGetStopFlagsMultiSeqs, "speculate_set_stop_value_multi_seqs function");

  m.def("speculate_verify", &SpeculateVerify, "speculate_verify function");

  m.def("speculate_update_v3", &SpeculateUpdateV3, "speculate_update_v3 function");

  m.def("speculate_set_value_by_flags_and_idx", &SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");

  m.def("speculate_save_output", &SpeculateSaveWithOutputMsgStatic, "speculate_save_output function");

  m.def("speculate_clear_accept_nums", &SpeculateClearAcceptNums, "speculate_clear_accept_nums function");

  m.def("ngram_match", &NgramMatch, "ngram_match function");

  m.def("draft_model_postprocess", &DraftModelPostprocess, "draft_model_postprocess function");

  m.def("draft_model_preprocess", &DraftModelPreprocess, "draft_model_preprocess function");

  m.def("draft_model_update", &DraftModelUpdate, "draft_model_update function");

  m.def("eagle_get_hidden_states", &EagleGetHiddenStates, "eagle_get_hidden_states function");

  m.def("mtp_step_paddle", &MTPStepPaddle, "mtp_step_paddle function");

  m.def("speculate_step_paddle", &SpeculateStepPaddle, "speculate_step_paddle function");
}
custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu (new file, 165 lines)
@@ -0,0 +1,165 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu

// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "all_reduce.cuh"

// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
                              paddle::Tensor& rank_data, int64_t rank,
                              bool full_nvlink) {
  int world_size = fake_ipc_ptrs.size();
  if (world_size > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");

  paddle::Signal* ipc_ptrs[8];
  for (int i = 0; i < world_size; i++) {
    ipc_ptrs[i] = reinterpret_cast<paddle::Signal*>(fake_ipc_ptrs[i]);
  }
  return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs, rank_data.data(),
                                              rank_data.numel(), rank,
                                              world_size, full_nvlink);
}

/**
 * Performs an out-of-place allreduce and stores result in out.
 *
 * If _reg_buffer is null, assumes inp.data() is already IPC-registered.
 * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
 * copied into _reg_buffer.
 */
void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
                fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  auto stream = inp.stream();

  auto input_size = inp.numel() * 2;  // NOTE: assumes 2-byte elements
                                      // (half/bfloat16); the FLOAT32 case
                                      // below would need numel * 4.
  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    cudaMemcpyAsync(reg_buffer, inp.data(), input_size,
                    cudaMemcpyDeviceToDevice, stream);
  } else {
    reg_buffer = inp.data();
  }
  switch (out.dtype()) {
    case phi::DataType::FLOAT32: {
      fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
                           reinterpret_cast<float*>(out.data()),
                           out.numel());
      break;
    }
    case phi::DataType::FLOAT16: {
      fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
                          reinterpret_cast<half*>(out.data()), out.numel());
      break;
    }
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
    case phi::DataType::BFLOAT16: {
      fa->allreduce<nv_bfloat16>(
          stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
          reinterpret_cast<nv_bfloat16*>(out.data()), out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}
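For orientation, the function above supports two calling conventions; the sketch below is our illustration of them (the function and parameter names `demo_all_reduce_modes`, `fa`, `reg_buf` are hypothetical stand-ins, not part of this file):

// Hypothetical usage sketch of the two calling modes of all_reduce().
void demo_all_reduce_modes(fptr_t fa, paddle::Tensor& inp, paddle::Tensor& out,
                           paddle::Tensor& registered_inp, fptr_t reg_buf,
                           int64_t reg_buf_bytes) {
  // Mode 1: inp is an ordinary device tensor; it is staged through the
  // IPC-registered buffer reg_buf before the reduction.
  all_reduce(fa, inp, out, reg_buf, reg_buf_bytes);
  // Mode 2: registered_inp.data() was itself IPC-registered via
  // register_buffer(), so no staging copy is needed.
  all_reduce(fa, registered_inp, out, /*_reg_buffer=*/0,
             /*reg_buffer_sz_bytes=*/0);
}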
void dispose(fptr_t _fa) {
  delete reinterpret_cast<paddle::CustomAllreduce*>(_fa);
}

int64_t meta_size() { return sizeof(paddle::Signal); }

void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  void* ipc_ptrs[8];
  for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
    ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
  }
  fa->register_buffer(ipc_ptrs);
}

// Use vector<int64_t> to represent byte data for python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
  std::vector<int64_t> bytes(handle.begin(), handle.end());
  return std::make_tuple(bytes, offsets);
}

// Use vector<int64_t> to represent byte data for python binding compatibility.
void register_graph_buffers(fptr_t _fa,
                            const std::vector<std::vector<int64_t>>& handles,
                            const std::vector<std::vector<int64_t>>& offsets) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  std::vector<std::string> bytes;
  bytes.reserve(handles.size());
  for (int i = 0; i < handles.size(); i++) {
    bytes.emplace_back(handles[i].begin(), handles[i].end());
  }
  fa->register_graph_buffers(bytes, offsets);
}

std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
    int64_t size) {
  auto device_index = phi::backends::gpu::GetCurrentDeviceId();
  void* buffer;
  cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
  auto stream = paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream();
  CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));

  // Allocate buffer
  CUDACHECK(cudaMalloc((void**)&buffer, size));
  CUDACHECK(cudaMemsetAsync(buffer, 0, size, stream));
  CUDACHECK(cudaStreamSynchronize(stream));
  CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));

  // Create IPC memhandle for the allocated buffer.
  // Will use it in open_mem_handle.
  auto handle =
      paddle::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, paddle::DataType::UINT8, paddle::GPUPlace(device_index));
  CUDACHECK(
      cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer));

  return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
}

fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
  void* ipc_ptr;
  CUDACHECK(cudaIpcOpenMemHandle(
      (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data()),
      cudaIpcMemLazyEnablePeerAccess));
  return reinterpret_cast<fptr_t>(ipc_ptr);
}

void free_shared_buffer(fptr_t buffer) {
  CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}
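Taken together, these helpers form a small cross-process setup protocol. The sketch below is our reading of the intended per-rank flow; `exchange_handles`, `kBufBytes`, and `setup_custom_allreduce` are hypothetical placeholders (the handle exchange transport is outside this file), not APIs from this diff.

// Hypothetical per-rank setup flow, assuming exchange_handles(handle, r)
// delivers rank r's IPC handle tensor via some out-of-band transport.
std::pair<fptr_t, fptr_t> setup_custom_allreduce(paddle::Tensor& rank_data,
                                                 int64_t rank, int world_size) {
  // 1. Every rank allocates a signal+data region and gets an IPC handle.
  auto [local_buf, handle] =
      allocate_shared_buffer_and_handle(meta_size() + kBufBytes);
  // 2. Handles are exchanged out of band; each rank opens its peers' handles.
  std::vector<fptr_t> ptrs(world_size);
  for (int r = 0; r < world_size; ++r) {
    paddle::Tensor peer_handle = exchange_handles(handle, r);  // hypothetical
    ptrs[r] = (r == rank) ? local_buf : open_mem_handle(peer_handle);
  }
  // 3. The reduce object is created over the signal regions...
  fptr_t fa = init_custom_all_reduce(ptrs, rank_data, rank, /*full_nvlink=*/true);
  // 4. ...and the same regions are registered for zero-copy data access.
  register_buffer(fa, ptrs);
  return {fa, local_buf};
}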
custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh (new file, 526 lines)
@@ -0,0 +1,526 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <iostream>
#include <array>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>

#define CUDACHECK(cmd)                                              \
  do {                                                              \
    cudaError_t e = cmd;                                            \
    if (e != cudaSuccess) {                                         \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(e));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

namespace paddle {

constexpr int kMaxBlocks = 36;
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
struct Signal {
  alignas(128) FlagType self_counter[kMaxBlocks][8];
  // Two sets of peer counters are needed for two syncs. The reason is that
  // it's possible for peer GPU block to arrive at the second sync point while
  // the current GPU block haven't passed the first sync point. Thus, peer GPU
  // may write counter+1 while current GPU is busy waiting for counter. We use
  // alternating counter array to avoid this possibility.
  alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};

struct __align__(16) RankData {
  const void* __restrict__ ptrs[8];
};

struct __align__(16) RankSignals {
  Signal* signals[8];
};

// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};
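To make the "ld.128/st.128" goal concrete, here is a compile-time illustration of what the packed types work out to for 16-bit elements (our addition, not part of the header):

// Illustration: for T = half, P is array_t<half, 8> -- 16 bytes, 16-byte
// aligned -- so one P copy compiles to a single 128-bit global access, and
// A widens the same 8 lanes to float for accumulation.
static_assert(sizeof(packed_t<half>::P) == 16, "one packet = 128 bits");
static_assert(packed_t<half>::P::size == 8, "8 half lanes per packet");
static_assert(sizeof(packed_t<half>::A) == 32, "float accumulator lanes");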
|
||||
|
||||
#define DINLINE __device__ __forceinline__
|
||||
|
||||
// scalar cast functions
|
||||
DINLINE float upcast_s(half val) { return __half2float(val); }
|
||||
|
||||
template <typename T>
|
||||
DINLINE T downcast_s(float val);
|
||||
template <>
|
||||
DINLINE half downcast_s(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
// scalar add functions
|
||||
// for some reason when compiling with Paddle, the + operator for half and
|
||||
// bfloat is disabled so we call the intrinsics directly
|
||||
DINLINE half& assign_add(half& a, half b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
DINLINE float& assign_add(float& a, float b) { return a += b; }
|
||||
|
||||
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
|
||||
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
|
||||
template <>
|
||||
DINLINE nv_bfloat16 downcast_s(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
assign_add(a.data[i], b.data[i]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
array_t<float, N> out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
out.data[i] = upcast_s(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
DINLINE O downcast(array_t<float, O::size> val) {
|
||||
if constexpr (std::is_same<typename O::type, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
O out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < O::size; i++) {
|
||||
out.data[i] = downcast_s<typename O::type>(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
|
||||
"l"(flag_addr));
|
||||
#else
|
||||
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
|
||||
"l"(flag_addr));
|
||||
#endif
|
||||
}
|
||||
|
||||
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
|
||||
FlagType flag;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
|
||||
: "=r"(flag)
|
||||
: "l"(flag_addr));
|
||||
#else
|
||||
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
|
||||
: "=r"(flag)
|
||||
: "l"(flag_addr));
|
||||
#endif
|
||||
return flag;
|
||||
}
|
||||
|
||||
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
|
||||
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
|
||||
}
|
||||
|
||||
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
|
||||
FlagType flag;
|
||||
asm volatile("ld.volatile.global.u32 %0, [%1];"
|
||||
: "=r"(flag)
|
||||
: "l"(flag_addr));
|
||||
return flag;
|
||||
}
|
||||
|
||||
// is_start: whether this is the very first synchronization barrier.
|
||||
// need_fence: whether a memory fence is needed. If true, a release-acquire
|
||||
// semantic is used to enforce memory access order before and after this
|
||||
// barrier.
|
||||
template <int ngpus, bool is_start, bool need_fence = false>
|
||||
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
|
||||
int rank) {
|
||||
if constexpr (!is_start) __syncthreads();
|
||||
static_assert(
|
||||
!(is_start && need_fence)); // Start barrier shouldn't need fence.
|
||||
if (threadIdx.x < ngpus) {
|
||||
// Increment the counter. Technically we only need one counter, but we use
|
||||
// multiple per block to eliminate the need to share the counter via smem.
|
||||
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
|
||||
// Write the expected counter value to peer and wait for correct value from
|
||||
// peer.
|
||||
auto peer_counter_ptr =
|
||||
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
|
||||
auto self_counter_ptr =
|
||||
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
|
||||
if constexpr (need_fence) {
|
||||
st_flag_release(peer_counter_ptr, val);
|
||||
while (ld_flag_acquire(self_counter_ptr) != val);
|
||||
} else {
|
||||
st_flag_volatile(peer_counter_ptr, val);
|
||||
while (ld_flag_volatile(self_counter_ptr) != val);
|
||||
}
|
||||
}
|
||||
if constexpr (is_start || need_fence) __syncthreads();
|
||||
}
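
The `val % 2` slot selection above is what lets consecutive barriers coexist without ever resetting flags: odd and even rounds use disjoint counter slots, so a fast rank that has already entered round N+1 cannot overwrite a flag a slow rank is still spinning on for round N. A minimal host-side sketch of the same parity idea using C++ atomics (ToyBarrier and kPeers are illustrative names, not part of this file; the kernel writes into each peer's buffer instead of its own, but the parity logic is identical):

#include <atomic>

constexpr int kPeers = 4;

struct ToyBarrier {
  // Two slots per peer, selected by round parity, mirroring peer_counter[val % 2].
  std::atomic<unsigned> slots[2][kPeers] = {};

  void arrive_and_wait(int self, unsigned round) {  // round starts at 1 and increments
    slots[round % 2][self].store(round, std::memory_order_release);
    for (int p = 0; p < kPeers; ++p) {
      // Spin until every peer has published this round into the same parity slot.
      while (slots[round % 2][p].load(std::memory_order_acquire) != round) {
      }
    }
  }
};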

template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the addresses, so the accumulation order is the
  // same for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}

template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
  return (P*)(((Signal*)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from
  // all ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}
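
To make the stage-1 partitioning concrete, a worked example with illustrative numbers: for size = 1030 packed elements and ngpus = 4, part = 1030 / 4 = 257, so ranks 0-2 own [0, 257), [257, 514) and [514, 771), while the last rank absorbs the remainder with [771, 1030). largest_part = 257 + 1030 % 4 = 259, so the stage-2 gather loop runs to 259 on every rank, and the `idx < part` guard skips the two-element tail for every source rank except the last one.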

using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  RankSignals sg_;
  // Stores a map from a pointer to its peer pointers from all ranks.
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // Stores rank data from all ranks. This is mainly for cuda graph purposes.
  // For cuda graph to work, all kernel arguments must be fixed during graph
  // capture time. However, the peer pointers are not known during graph
  // capture time. Therefore, during capture, we increment the rank data
  // pointer and use that as the argument to the kernel. The kernel arguments
  // are stored in graph_unreg_buffers_. The actual peer pointers will be
  // filled in at the memory pointed to by the pointers in
  // graph_unreg_buffers_ when the IPC handles are exchanged between ranks.
  //
  // The overall process looks like this:
  // 1. Graph capture.
  // 2. Each rank obtains the IPC handles for each address used during cuda
  //    graph capture using get_graph_buffer_ipc_meta.
  // 3. (In Python) all gather the IPC handles.
  // 4. Obtain the peer pointers by opening the IPC handles, and store them in
  //    the rank data array at corresponding positions.
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * Signals are an array of ipc-enabled buffers from all ranks.
   * For each buffer, the layout is as follows:
   * | -- sizeof(Signal) -- | ------ a few MB ----- |
   * The first section is for allreduce synchronization, and the second section
   * is for storing the intermediate results required by some allreduce algos.
   *
   * Note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor.
   */
  CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
                  int rank, int world_size, bool full_nvlink = true)
      : rank_(rank),
        world_size_(world_size),
        full_nvlink_(full_nvlink),
        self_sg_(signals[rank]),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      sg_.signals[i] = signals[i];
    }
  }

  char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t*)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::string handles(handle_sz * num_buffers, static_cast<char>(0));
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get the
      // wrong address
      if (cuPointerGetAttribute(&base_ptr,
                                CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  /**
   * Register already-shared IPC pointers.
   */
  void register_buffer(void** ptrs) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      data.ptrs[i] = ptrs[i];
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[ptrs[rank_]] = d_data;
  }

  // Note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the
  // remote possibility of different allocation patterns between ranks. For
  // example, rank 1 may get the same input address for the second allreduce,
  // but rank 2 may get a different address. IPC handles have an internal
  // reference-counting mechanism, so the overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string>& handles,
      const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * Performs allreduce, assuming input has already been registered.
   *
   * Block and grid default configs are the result of a careful grid search.
   * Using 36 blocks gives the best or close to the best runtime on the
   * devices I tried: A100, A10, A30, T4, V100. You'll notice that NCCL
   * kernels also only take a small number of SMs. Not quite sure of the
   * underlying reason, but my guess is that too many SMs will cause
   * contention on the NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
                 int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error("max supported block limit is " +
                               std::to_string(kMaxBlocks) + ". Got " +
                               std::to_string(block_limit));

    RankData* ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size);

#define REDUCE_CASE(ngpus)                            \
  case ngpus: {                                       \
    if (world_size_ == 2) {                           \
      KL(ngpus, cross_device_reduce_1stage);          \
    } else if (full_nvlink_) {                        \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || \
          (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage);        \
      } else {                                        \
        KL(ngpus, cross_device_reduce_2stage);        \
      }                                               \
    }                                                 \
    break;                                            \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
  }
};
}  // namespace paddle
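
Taken together, a host-side driver uses the class roughly as follows. A minimal sketch, assuming the IPC exchange of signal buffers and tensor addresses has already happened out of band; `signals`, `rank_data`, `peer_ptrs`, `rank`, `stream`, `output` and `num_elems` are illustrative placeholders:

  Signal* signals[8];      // one IPC-mapped Signal buffer per rank
  void* rank_data;         // device scratch to hold RankData entries
  size_t rank_data_sz = 8 * 1024 * 1024;
  paddle::CustomAllreduce ar(signals, rank_data, rank_data_sz,
                             rank, /*world_size=*/8);

  void* peer_ptrs[8];      // the same tensor's device address on every rank
  ar.register_buffer(peer_ptrs);  // one-time registration of the input
  ar.allreduce<half>(stream, static_cast<half*>(peer_ptrs[rank]),
                     output, num_elems);

With the default thresholds, an 8-GPU fp16 input under 256 KB takes the one-shot cross_device_reduce_1stage kernel on full-NVLink systems; anything larger goes through the two-stage reduce-scatter/allgather path.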

@@ -76,6 +76,34 @@ enum class SplitKStyle
     // SPLIT_K_PARALLEL // Not supported yet
 };
 
+// New enum for SM100 (Blackwell) tile configs
+// Placeholder values - actual optimal values need research
+enum class CutlassTileConfigSM100
+{
+    // Signals that we should run heuristics to choose a config
+    Undefined,
+
+    // Signals that we should run heuristics to choose a config
+    ChooseWithHeuristic,
+
+    // Actual SM100 tile configs based on user input (K-tile is 128B)
+    CtaShape64x64x128B,
+    CtaShape64x128x128B,
+    CtaShape64x256x128B,
+    CtaShape128x64x128B,
+    CtaShape128x128x128B,
+    CtaShape128x256x128B,
+    CtaShape256x64x128B,
+    CtaShape256x128x128B,
+    CtaShape256x256x128B
+    // Note: The user-provided list for get_candidate_tiles_sm100 also includes
+    // CtaShape128x64x128B and CtaShape256x64x128B for specific FP4 grouped gemm
+    // cases. These are already covered by the list above if the general list
+    // suffices; if they ever need distinct enum values, they should be added.
+    // For now, the enum stays concise, listing only the unique shapes needed
+    // for general use.
+};
 
 
 enum class CutlassTileConfigSM90
 {
     // Signals that we should run heuristics to choose a config
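
A note on the naming: the trailing `128B` in these tile names is the K-extent of the CTA tile in bytes, so the number of elements it spans depends on the data type. A small illustration of the convention (not code from this header):

  template <typename T>
  constexpr int elems_per_128B_ktile = 128 / sizeof(T);

  static_assert(elems_per_128B_ktile<uint16_t> == 64);   // fp16/bf16: 64 elements
  static_assert(elems_per_128B_ktile<float> == 32);      // fp32: 32 elements
  static_assert(elems_per_128B_ktile<uint8_t> == 128);   // fp8/int8: 128 elements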

@@ -132,9 +160,11 @@ struct CutlassGemmConfig
         WEIGHT_ONLY = 1u << 0,
         SIMT_ONLY = 1u << 1,
         INT8_ONLY = 1u << 2,
-        HOPPER = 1u << 3,
+        HOPPER = 1u << 3,  // SM90
         GROUPED_GEMM = 1u << 4,
         FP8_ONLY = 1u << 5,
+        BLACKWELL = 1u << 6,  // SM100
+        FP4_ONLY = 1u << 7,   // For Blackwell FP4/MXFP4 paths
     };
 
     CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
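
Since each flag above occupies its own bit, callers build a candidate-config query by OR-ing flags together, and the heuristics test membership with a bitwise AND. An illustrative combination (not a line from this file):

  auto params = CutlassGemmConfig::GROUPED_GEMM | CutlassGemmConfig::FP4_ONLY
              | CutlassGemmConfig::BLACKWELL;
  bool fp4_grouped_gemm = (params & CutlassGemmConfig::GROUPED_GEMM)
                       && (params & CutlassGemmConfig::FP4_ONLY);  // true

This is exactly the test get_candidate_tiles_sm100 below performs to pick the FP4 grouped-GEMM tile list.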

@@ -149,7 +179,17 @@ struct CutlassGemmConfig
     ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
     bool is_sm90 = false;
 
-    CutlassGemmConfig() {}
+    // config options for sm100 (Blackwell)
+    // Assuming SM100 might use similar schedule/cluster types as SM90 for now.
+    // These might need to become SM100-specific if Blackwell introduces new concepts.
+    CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic;
+    // MainloopScheduleType mainloop_schedule_sm100 = MainloopScheduleType::AUTO;  // Example if SM100 has different types
+    // EpilogueScheduleType epilogue_schedule_sm100 = EpilogueScheduleType::AUTO;  // Example
+    // ClusterShape cluster_shape_sm100 = ClusterShape::ClusterShape_1x1x1;  // Example
+    bool is_sm100 = false;
+
+    CutlassGemmConfig() : is_sm90(false), is_sm100(false) {}
 
     CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
         : tile_config(tile_config)

@@ -157,37 +197,64 @@ struct CutlassGemmConfig
         , split_k_factor(split_k_factor)
         , stages(stages)
         , is_sm90(false)
+        , is_sm100(false)
     {
     }
 
-    CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
-        EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
-        : tile_config_sm90(tile_config_sm90)
-        , mainloop_schedule(mainloop_schedule)
-        , epilogue_schedule(epilogue_schedule)
-        , cluster_shape(cluster_shape)
+    // Constructor for SM90
+    CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90_in, MainloopScheduleType mainloop_schedule_in,
+        EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
+        : tile_config_sm90(tile_config_sm90_in)
+        , mainloop_schedule(mainloop_schedule_in)
+        , epilogue_schedule(epilogue_schedule_in)
+        , cluster_shape(cluster_shape_in)
         , is_sm90(true)
+        , is_sm100(false)
     {
     }
 
+    // Constructor for SM100 (Blackwell)
+    // Using the existing MainloopScheduleType, EpilogueScheduleType, ClusterShape for now.
+    // These might need to be new SM100-specific types if Blackwell's TMA differs significantly.
+    CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100_in, MainloopScheduleType mainloop_schedule_in,
+        EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
+        : tile_config_sm100(tile_config_sm100_in)
+        , mainloop_schedule(mainloop_schedule_in)  // Potentially use mainloop_schedule_sm100 if types diverge
+        , epilogue_schedule(epilogue_schedule_in)  // Potentially use epilogue_schedule_sm100
+        , cluster_shape(cluster_shape_in)          // Potentially use cluster_shape_sm100
+        , is_sm90(false)  // Explicitly false
+        , is_sm100(true)
+    {
+    }
+
     std::string toString() const
     {
         std::stringstream tactic;
         tactic << "Cutlass GEMM Tactic";
-        if (tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
+        if (is_sm100 && tile_config_sm100 != cutlass_extensions::CutlassTileConfigSM100::ChooseWithHeuristic)
         {
-            assert(is_sm90 && "Invalid cutlass GEMM config");
-            tactic << "\n\tstyle=TMA"
-                   << "\n\ttile shape ID: " << (int) tile_config_sm90
+            assert(is_sm100 && !is_sm90 && "Invalid cutlass GEMM config: SM100");
+            tactic << "\n\tstyle=TMA_SM100"  // Indicate SM100-specific TMA if applicable
+                   << "\n\ttile shape ID: " << (int) tile_config_sm100
                    << "\n\tcluster shape ID: " << (int) cluster_shape
                    << "\n\tmainloop sched: " << (int) mainloop_schedule
                    << "\n\tepi sched: " << (int) epilogue_schedule;
         }
+        else if (is_sm90 && tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
+        {
+            assert(is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: SM90");
+            tactic << "\n\tstyle=TMA_SM90"
+                   << "\n\ttile shape ID: " << (int) tile_config_sm90
+                   << "\n\tcluster shape ID: " << (int) cluster_shape
+                   << "\n\tmainloop sched: " << (int) mainloop_schedule
+                   << "\n\tepi sched: " << (int) epilogue_schedule;
+        }
         else if (tile_config != cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
         {
-            assert(!is_sm90 && "Invalid cutlass GEMM config");
+            assert(!is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: Compatible");
             tactic << "\n\tstyle=compatible"
                    << "\n\ttile shape ID: " << (int) tile_config
                    << "\n\tstages: " << (int) stages
                   << "\n\tsplit_k_style: " << (int) split_k_style
                   << "\n\tsplit k: " << (int) split_k_factor;
|
||||
std::istringstream stream(str);
|
||||
std::string line;
|
||||
|
||||
is_sm90 = false; // Reset flags
|
||||
is_sm100 = false;
|
||||
|
||||
while (std::getline(stream, line)) {
|
||||
if (line.find("style=TMA") != std::string::npos) {
|
||||
if (line.find("style=TMA_SM100") != std::string::npos) {
|
||||
is_sm100 = true;
|
||||
is_sm90 = false;
|
||||
std::getline(stream, line);
|
||||
tile_config_sm100 = static_cast<cutlass_extensions::CutlassTileConfigSM100>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
cluster_shape = static_cast<cutlass_extensions::ClusterShape>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
mainloop_schedule = static_cast<cutlass_extensions::MainloopScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||
epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
} else if (line.find("style=TMA_SM90") != std::string::npos) { // Check for SM90 specific first
|
||||
is_sm90 = true;
|
||||
is_sm100 = false;
|
||||
std::getline(stream, line);
|
||||
tile_config_sm90 = static_cast<cutlass_extensions::CutlassTileConfigSM90>(std::stoi(line.substr(line.find(':') + 1)));
|
||||
std::getline(stream, line);
|
||||

@@ -217,6 +299,7 @@ struct CutlassGemmConfig
                 epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
             } else if (line.find("style=compatible") != std::string::npos) {
                 is_sm90 = false;
+                is_sm100 = false;
                 std::getline(stream, line);
                 tile_config = static_cast<cutlass_extensions::CutlassTileConfig>(std::stoi(line.substr(line.find(':') + 1)));
                 std::getline(stream, line);

@@ -233,7 +316,14 @@ struct CutlassGemmConfig
 inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
 {
     // clang-format off
-    if (config.is_sm90)
+    if (config.is_sm100)
+    {
+        out << "tile_config_sm100_enum: " << int(config.tile_config_sm100)
+            << ", mainloop_schedule_enum: " << int(config.mainloop_schedule)  // Assuming same schedule types for now
+            << ", epilogue_schedule_enum: " << int(config.epilogue_schedule)  // Assuming same schedule types for now
+            << ", cluster_shape_enum: " << int(config.cluster_shape);  // Assuming same cluster types for now
+    }
+    else if (config.is_sm90)
     {
         out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
             << ", mainloop_schedule_enum: " << int(config.mainloop_schedule)

@@ -245,6 +245,88 @@ bool supports_mcast_along_n(CutlassTileConfigSM90 const tile)
 #endif
 }
 
+// SM100 (Blackwell) candidate tile configurations
+std::vector<CutlassTileConfigSM100> get_candidate_tiles_sm100(
+    int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config)
+{
+#ifdef FAST_BUILD
+    return {CutlassTileConfigSM100::CtaShape128x128x128B};
+#else
+    /* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM "cluster" kernels) */
+    if (config & CutlassGemmConfig::GROUPED_GEMM)
+    {
+        if (config & CutlassGemmConfig::FP4_ONLY)  // nvfp4 / mx_fp4
+        {
+            return {
+                /* 1 SM (M=128) */
+                CutlassTileConfigSM100::CtaShape128x128x128B,
+                CutlassTileConfigSM100::CtaShape128x256x128B,
+                /* 2 SM (M=256) */
+                CutlassTileConfigSM100::CtaShape256x128x128B,
+                CutlassTileConfigSM100::CtaShape256x256x128B,
+                /* slim tiles for very tall matrices */
+                CutlassTileConfigSM100::CtaShape128x64x128B,
+                CutlassTileConfigSM100::CtaShape256x64x128B};
+        }
+
+        /* Fp8 / Fp16 grouped-GEMM */
+        return {
+            CutlassTileConfigSM100::CtaShape128x128x128B,
+            CutlassTileConfigSM100::CtaShape128x256x128B,
+            CutlassTileConfigSM100::CtaShape256x128x128B,
+            CutlassTileConfigSM100::CtaShape256x256x128B};
+    }
+
+    /* Non-grouped path (plain GEMM or weight-only) */
+    return {
+        /* 1 SM tiles */
+        CutlassTileConfigSM100::CtaShape64x64x128B,
+        CutlassTileConfigSM100::CtaShape64x128x128B,
+        CutlassTileConfigSM100::CtaShape64x256x128B,
+        CutlassTileConfigSM100::CtaShape128x64x128B,
+        CutlassTileConfigSM100::CtaShape128x128x128B,
+        CutlassTileConfigSM100::CtaShape128x256x128B,
+        /* 2 SM tiles */
+        CutlassTileConfigSM100::CtaShape256x64x128B,
+        CutlassTileConfigSM100::CtaShape256x128x128B,
+        CutlassTileConfigSM100::CtaShape256x256x128B};
+#endif
+}
+
+// M-multicast support for SM100.
+bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile)
+{
+#ifdef FAST_BUILD
+    return false;
+#else
+    std::set<CutlassTileConfigSM100> m_tiles{
+        CutlassTileConfigSM100::CtaShape128x64x128B,
+        CutlassTileConfigSM100::CtaShape128x128x128B,
+        CutlassTileConfigSM100::CtaShape128x256x128B,
+        CutlassTileConfigSM100::CtaShape256x64x128B,
+        CutlassTileConfigSM100::CtaShape256x128x128B,
+        CutlassTileConfigSM100::CtaShape256x256x128B};
+    return m_tiles.count(tile) == 1;
+#endif
+}
+
+// N-multicast support for SM100.
+bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile)
+{
+#ifdef FAST_BUILD
+    return false;
+#else
+    std::set<CutlassTileConfigSM100> n_tiles{
+        CutlassTileConfigSM100::CtaShape64x128x128B,
+        CutlassTileConfigSM100::CtaShape64x256x128B,
+        CutlassTileConfigSM100::CtaShape128x128x128B,
+        CutlassTileConfigSM100::CtaShape128x256x128B,
+        CutlassTileConfigSM100::CtaShape256x128x128B};
+    return n_tiles.count(tile) == 1;
+#endif
+}
+
 
 std::vector<CutlassGemmConfig> get_candidate_configs(
     int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
 {
@@ -284,9 +366,50 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
         }
         return candidate_configs;
     }
-    std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
-    std::vector<CutlassGemmConfig> candidate_configs;
+    else if (sm == 100 && (config_type_param & CutlassGemmConfig::BLACKWELL))  // Assuming SM100 for Blackwell
+    {
+        std::vector<CutlassTileConfigSM100> tiles = get_candidate_tiles_sm100(sm, config_type_param);
+        std::vector<CutlassGemmConfig> candidate_configs;
+
+        for (auto const& tile_config_sm100 : tiles)
+        {
+            // SM100 uses MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO similar to SM90.
+            // Cluster shapes are also handled similarly.
+            CutlassGemmConfig config(
+                tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
+            candidate_configs.push_back(config);
+
+            bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100);
+            bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100);
+
+            if (has_m_mcast)
+            {
+                CutlassGemmConfig mcast_m_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                    ClusterShape::ClusterShape_2x1x1);
+                candidate_configs.push_back(mcast_m_config);
+            }
+
+            if (has_n_mcast)
+            {
+                CutlassGemmConfig mcast_n_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                    ClusterShape::ClusterShape_1x2x1);
+                candidate_configs.push_back(mcast_n_config);
+            }
+
+            if (has_m_mcast && has_n_mcast)
+            {
+                CutlassGemmConfig mcast_mn_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
+                    ClusterShape::ClusterShape_2x2x1);
+                candidate_configs.push_back(mcast_mn_config);
+            }
+        }
+        return candidate_configs;
+    }
+
+    // Fallback to older architecture configurations
+    std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
+    std::vector<CutlassGemmConfig> candidate_configs;  // Also declared in the SM90/SM100 branches above;
+                                                       // that is fine, since those are separate scopes.
     bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
     int const min_stages = int8_configs_only ? 3 : 2;
     int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
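
For a tile that sits in both multicast sets, the SM100 loop above emits four candidates. Taking CtaShape128x128x128B (a member of both the M- and N-multicast sets) as an example, the expansion is:

  // {tile shape,    mainloop, epilogue, cluster shape}
  //  128x128x128B   AUTO      AUTO      1x1x1           baseline
  //  128x128x128B   AUTO      AUTO      2x1x1           M-multicast
  //  128x128x128B   AUTO      AUTO      1x2x1           N-multicast
  //  128x128x128B   AUTO      AUTO      2x2x1           both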

@@ -12,21 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ffn1_n=7168
-ffn1_k=8192
+up_gate_proj_n=7168
+up_gate_proj_k=8192
 
-ffn2_n=8192
-ffn2_k=3584
-rm -rf ffn1_7168_8192.log
-rm -rf ffn2_8192_3584.log
+down_proj_n=8192
+down_proj_k=3584
+rm -rf up_gate_proj_7168_8192.log
+rm -rf down_proj_8192_3584.log
 num_experts=8
 
 for tokens_per_expert in 12
 do
     wait
-    CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 1 0 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
-    # CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 1 0 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
+    CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
+    # CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
 done
 wait
 echo "#### finish ####"

@@ -19,7 +19,7 @@
 #include "fp8_fp8_half_cuda_core_gemm.h"
 
-std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
+paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
     const paddle::Tensor& x,
     const paddle::Tensor& y,
     const paddle::optional<paddle::Tensor>& bias,
@@ -142,7 +142,7 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
     {
       if(output_dtype == "bfloat16") {
         cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(params);
-
       } else {
         cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
       }
@@ -174,7 +174,21 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
         fuse_gemm_config};
     fp8_fp8_gemm_scale_bias_act(params);
   }
-  return {out};
+  return out;
 }
 
+std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
+    const paddle::Tensor& x,
+    const paddle::Tensor& y,
+    const paddle::optional<paddle::Tensor>& bias,
+    bool trans_x,
+    bool trans_y,
+    float scale,  // only support per-tensor quantization
+    std::string output_dtype,
+    std::string activation_type) {
+  return {cutlass_fp8_fp8_half_gemm_func(
+      x, y, bias, trans_x, trans_y, scale,
+      output_dtype, activation_type)};
+}
+
 std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfGemmFusedInferShape(
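
The split introduced here follows a recurring pattern in these custom ops: a *_func variant returning a bare paddle::Tensor for direct C++ callers, plus a thin std::vector wrapper whose signature matches what the op registry expects. Sketched generically (my_op is a hypothetical name, not from this file):

  paddle::Tensor my_op_func(const paddle::Tensor& x);           // for C++ callers

  std::vector<paddle::Tensor> my_op(const paddle::Tensor& x) {  // for PD_BUILD_STATIC_OP
    return {my_op_func(x)};
  }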

custom_ops/gpu_ops/fused_hadamard_quant_fp8.cu (new file, 198 lines)
@@ -0,0 +1,198 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include "helper.h"

__device__ __forceinline__ void hadamard32_warp(__nv_bfloat16& x) {
  int lane_id = threadIdx.x % 32;
#pragma unroll
  for (int step = 0; step < 5; ++step) {
    const int lane_mask = 1 << step;
    const __nv_bfloat16 sign = (lane_id & lane_mask) ? -1.f : 1.f;
    __nv_bfloat16 x_val_other = __shfl_xor_sync(0xffffffff, x, lane_mask);
    x = sign * x + x_val_other;
  }
}
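
// How the butterfly above works: after step s, each lane holds one
// coefficient of the Walsh-Hadamard transform of its group of 2^(s+1) lanes.
// Worked 4-point illustration (steps 0 and 1) for values [a, b, c, d] in
// lanes 0..3:
//   step 0 (mask 1): exchange with the XOR-1 partner -> [a+b, a-b, c+d, c-d]
//   step 1 (mask 2): exchange with the XOR-2 partner ->
//     [a+b+c+d, a-b+c-d, a+b-c-d, a-b-c+d]
// which is the 4-point transform; five steps extend this to all 32 lanes.
// `sign` is -1 exactly on the lane that owns the "high" half of its pair,
// giving the usual (lo + hi, lo - hi) butterfly.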

__global__ void MoeFusedHadamardQuantFp8Kernel(
    const __nv_bfloat16* __restrict__ input,
    const float* __restrict__ scale,
    const int64_t* __restrict__ topk_ids,
    __nv_fp8_e4m3* out,
    const int top_k,
    const int intermediate_size,
    const int64_t numel
) {
  int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (out_idx >= numel) return;

  int64_t token_idx = out_idx / (top_k * intermediate_size);
  int64_t topk_idx = (out_idx / intermediate_size) % top_k;
  int64_t inter_idx = out_idx % intermediate_size;

  int64_t input_idx = token_idx * intermediate_size + inter_idx;
  if (input_idx >= numel / top_k) return;

  int64_t expert_id = topk_ids[token_idx * top_k + topk_idx];
  float scale_value = scale[expert_id];

  __nv_bfloat16 x = input[input_idx];
  hadamard32_warp(x);

  float x_fp32 = __bfloat162float(x);
  float quantized = x_fp32 / scale_value;
  out[out_idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
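
// Worked index example for the kernel above: with top_k = 2 and
// intermediate_size = 4, out_idx = 13 decomposes into
//   token_idx = 13 / (2 * 4) = 1,  topk_idx = (13 / 4) % 2 = 1,
//   inter_idx = 13 % 4 = 1,
// so output element 13 is feature 1 of token 1 routed through its second
// expert, and it reads input_idx = 1 * 4 + 1 = 5. Each token's activations
// are thus read once per selected expert but quantized with that expert's
// own scale.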

__global__ void MoeFusedHadamardQuantFp8TiledKernel(
    const __nv_bfloat16* __restrict__ input,
    const float* __restrict__ scale,
    const int64_t* __restrict__ topk_ids,
    __nv_fp8_e4m3* out,
    const int top_k,
    const int intermediate_size,
    const int64_t numel
) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= numel) return;

  int64_t token_idx = idx / intermediate_size;
  int64_t expert_id = topk_ids[token_idx];
  float scale_value = scale[expert_id];

  __nv_bfloat16 x = input[idx];
  hadamard32_warp(x);

  float x_fp32 = __bfloat162float(x);
  float quantized = x_fp32 / scale_value;
  out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
}

std::vector<paddle::Tensor> MoeFusedHadamardQuantFp8(
    const paddle::Tensor &input,
    const paddle::Tensor &scale,
    const paddle::Tensor &topk_ids,
    const int top_k,
    const int intermediate_size,
    const bool tiled) {
  int64_t numel = input.numel();
  if (!tiled) numel *= top_k;
  paddle::Tensor out = GetEmptyTensor(
      {numel / intermediate_size, intermediate_size},
      paddle::DataType::FLOAT8_E4M3FN,
      input.place());
  constexpr int64_t thread_per_block = 256;
  int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
  auto stream = input.stream();
  if (tiled) {
    MoeFusedHadamardQuantFp8TiledKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
        reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
        scale.data<float>(),
        topk_ids.data<int64_t>(),
        reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
        top_k,
        intermediate_size,
        numel
    );
  } else {
    MoeFusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
        reinterpret_cast<const __nv_bfloat16*>(input.data<phi::dtype::bfloat16>()),
        scale.data<float>(),
        topk_ids.data<int64_t>(),
        reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
        top_k,
        intermediate_size,
        numel
    );
  }
  return {out};
}

PD_BUILD_STATIC_OP(moe_fused_hadamard_quant_fp8)
    .Inputs({"input", "scale", "topk_ids"})
    .Outputs({"output"})
    .Attrs({"top_k: int",
            "intermediate_size: int",
            "tiled: bool"})
    .SetKernelFn(PD_KERNEL(MoeFusedHadamardQuantFp8));


paddle::Tensor MoeFusedHadamardQuantFp8Func(
    const paddle::Tensor &input,
    const paddle::Tensor &scale,
    const paddle::Tensor &topk_ids,
    const int top_k,
    const int intermediate_size,
    const bool tiled) {
  return MoeFusedHadamardQuantFp8(input, scale, topk_ids, top_k, intermediate_size, tiled)[0];
}


__global__ void FusedHadamardQuantFp8Kernel(
    const __nv_bfloat16* __restrict__ input,
    __nv_fp8_e4m3* out,
    const float scale,
    const int64_t numel) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= numel) return;

  __nv_bfloat16 x = input[idx];
  hadamard32_warp(x);

  float x_fp32 = __bfloat162float(x);
  float quantized = x_fp32 / scale;
  out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
}

std::vector<paddle::Tensor> FusedHadamardQuantFp8(
    const paddle::Tensor &input,
    const float scale) {
  int64_t numel = input.numel();
  paddle::Tensor out = GetEmptyTensor(
      input.dims(),
      paddle::DataType::FLOAT8_E4M3FN,
      input.place());
  constexpr int64_t thread_per_block = 256;
  int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
  auto stream = input.stream();
  FusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
      reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
      reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
      scale,
      numel
  );
  return {out};
}

PD_BUILD_STATIC_OP(fused_hadamard_quant_fp8)
    .Inputs({"input"})
    .Outputs({"output"})
    .Attrs({"scale: float"})
    .SetKernelFn(PD_KERNEL(FusedHadamardQuantFp8));


paddle::Tensor FusedHadamardQuantFp8Func(
    const paddle::Tensor &input,
    const float scale) {
  return FusedHadamardQuantFp8(input, scale)[0];
}
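
Numerically, hadamard32_warp is a 32-point Walsh-Hadamard transform laid out one element per lane, so a plain host-side reference is useful when testing these kernels. A minimal sketch (hadamard32_reference is an illustrative name, not part of the repository):

  // In-place 32-point Walsh-Hadamard transform; x[k] plays the role of lane k.
  void hadamard32_reference(float* x) {
    for (int step = 0; step < 5; ++step) {
      const int mask = 1 << step;
      for (int i = 0; i < 32; ++i) {
        if ((i & mask) == 0) {  // visit each butterfly pair once
          const float lo = x[i];
          const float hi = x[i ^ mask];
          x[i] = lo + hi;          // what the low lane computes
          x[i ^ mask] = lo - hi;   // what the high lane computes
        }
      }
    }
  }

Apart from bf16 rounding inside the kernel, applying this to each 32-element block should reproduce the pre-quantization values the device code feeds into the FP8 cast.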
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/extension.h"
+#include "helper.h"
 
 #ifndef PD_BUILD_STATIC_OP
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -59,7 +60,12 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
                                              const paddle::Tensor &cum_offsets,
                                              const paddle::Tensor &token_num,
                                              const paddle::Tensor &seq_len) {
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
+  auto cu_stream = dev_ctx->stream();
+#else
   auto cu_stream = input_ids.stream();
+#endif
   std::vector<int64_t> input_ids_shape = input_ids.shape();
   const int bsz = seq_len.shape()[0];
   const int seq_length = input_ids_shape[1];
@@ -75,7 +81,11 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
   auto cu_seqlens_k =
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
-  int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
+#ifdef PADDLE_WITH_COREX
+  int blockSize = std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
+#else
+  int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
+#endif
   GetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
       padding_offset.data<int>(),
       cum_offsets_out.data<int>(),
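
The blockSize expression above is a round-up-to-warp-multiple followed by a cap. With illustrative numbers:

  (70 + 32 - 1) / 32 * 32 = 96     // WARP_SIZE 32: 70 tokens round up to 96
  (70 + 64 - 1) / 64 * 64 = 128    // WARP_SIZE 64 (PADDLE_WITH_COREX): round up to 128

and min(..., 128) caps the result at 128 threads in either case.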
@@ -14,7 +14,9 @@
 
 #pragma once
 
+#ifndef PADDLE_WITH_COREX
 #include "glog/logging.h"
+#endif
 #include <fcntl.h>
 #include <stdio.h>
 #include <stdlib.h>
@@ -35,22 +37,35 @@ namespace cub = hipcub;
 #else
 #include <cub/cub.cuh>
 #endif
+#ifndef PADDLE_WITH_COREX
 #include "nlohmann/json.hpp"
+#endif
 #include <fstream>
 #include <iostream>
 
 #include "env.h"
 #include "paddle/extension.h"
 #include "paddle/phi/core/allocator.h"
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+#include "paddle/phi/backends/custom/custom_context.h"
+#else
 #include "paddle/phi/core/cuda_stream.h"
+#endif
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
 
+#ifdef PADDLE_WITH_COREX
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif
 #ifndef PD_BUILD_STATIC_OP
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
 #endif
 
+#ifndef PADDLE_WITH_COREX
 using json = nlohmann::json;
+#endif
 
 #define CUDA_CHECK(call) \
   do { \
@@ -199,11 +214,19 @@ HOSTDEVICE inline void Store(const AlignedVector<T, Size> &vec, T *addr) {
   *addr_vec = vec;
 }
 
+#ifdef PADDLE_WITH_HIP
+template <int Size>
+HOSTDEVICE inline void Store(const AlignedVector<hip_bfloat16, Size> &vec,
+                             int8_t *addr) {
+  printf("Error: Store hip_bfloat16 to int8_t is not supported!");
+}
+#else
 template <int Size>
 HOSTDEVICE inline void Store(const AlignedVector<__nv_bfloat16, Size> &vec,
                              int8_t *addr) {
   printf("Error: Store __nv_bfloat16 to int8_t is not supported!");
 }
+#endif
 
 template <int Size>
 HOSTDEVICE inline void Store(const AlignedVector<half, Size> &vec,
@@ -237,6 +260,7 @@ inline int GetBlockSize(int vocab_size) {
   }
 }
 
+#ifndef PADDLE_WITH_COREX
 inline json readJsonFromFile(const std::string &filePath) {
   std::ifstream file(filePath);
   if (!file.is_open()) {
@@ -247,6 +271,7 @@ inline json readJsonFromFile(const std::string &filePath) {
   file >> j;
   return j;
 }
+#endif
 
 #define cudaCheckError() \
   { \
@@ -418,6 +443,7 @@ inline std::string base64_decode(const std::string &encoded_string) {
   return ret;
 }
 
+#ifndef PADDLE_WITH_COREX
 template <typename T>
 inline T get_relative_best(nlohmann::json *json_data,
                            const std::string &target_key,
@@ -430,6 +456,7 @@ inline T get_relative_best(nlohmann::json *json_data,
     return default_value;
   }
 }
+#endif
 
 __device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids,
                                  int length) {
@@ -459,7 +486,12 @@ template <typename T>
 static void PrintMatrix3(const T *mat_d, int num, std::string name) {
   std::vector<T> tmp(num);
+#ifdef PADDLE_WITH_HIP
+  hipMemcpy(tmp.data(), mat_d, sizeof(T) * num, hipMemcpyDeviceToHost);
+#else
   cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost);
+#endif
 
   std::ofstream outfile;
   outfile.open(name + ".txt", std::ios::out);
@@ -476,6 +508,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
   outfile.close();
 }
 
+#ifndef PADDLE_WITH_HIP
 __forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
                                                     int mode = 0) {
   uint32_t flag;
@@ -515,6 +548,7 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
       cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
   return max_shared_mem_per_block_opt_in;
 }
+#endif
 
 inline int GetSMVersion() {
   static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
@@ -73,7 +73,6 @@ void BatchMLAWithPagedKVCacheKernel(
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& cu_seqlens_q,
     const paddle::Tensor& padding_offsets,
     const paddle::Tensor& cum_offsets,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -181,7 +180,6 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -216,7 +214,6 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,

@@ -50,7 +50,6 @@ void BatchMLAWithPagedKVCacheKernel(
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& batch_ids,
    const paddle::Tensor& tile_ids_per_batch,
@@ -161,7 +161,7 @@ __global__ void combine_prmt_back_kernel(
|
||||
expanded_permuted_rows + expanded_permuted_row * cols; // prmt后的位置对应的值
|
||||
Load<T, VEC_SIZE>(expanded_permuted_rows_row_ptr + tid * VEC_SIZE, &load_vec);
|
||||
const int expert_idx = expert_for_source_row[k_offset]; // 当前位置对应的专家
|
||||
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // 当前专家对应的ffn2的bias
|
||||
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // 当前专家对应的down_proj的bias
|
||||
if (bias_ptr) {
|
||||
Load<T, VEC_SIZE>(bias_ptr + tid * VEC_SIZE, &bias_vec);
|
||||
#pragma unroll
|
||||
@@ -188,7 +188,7 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out,
|
||||
const paddle::Tensor& expert_scales_float,
|
||||
const paddle::Tensor& permute_indices_per_token,
|
||||
const paddle::Tensor& top_k_indices,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const int num_rows,
|
||||
@@ -206,7 +206,7 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out,
|
||||
combine_prmt_back_kernel<<<gridx, threads, 0, stream>>>(
|
||||
ffn_out.data<data_t>(),
|
||||
output->data<data_t>(),
|
||||
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
|
||||
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
|
||||
expert_scales_float.data<float>(),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
top_k_indices.data<int>(),
|
||||
@@ -223,7 +223,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
|
||||
const paddle::Tensor& expert_scales_float, // dst_weights
|
||||
const paddle::Tensor& permute_indices_per_token, // permute_indices_per_token
|
||||
const paddle::Tensor& top_k_indices, // dst_indices
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor) {
|
||||
|
||||
@@ -242,7 +242,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
|
||||
expert_scales_float,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
ffn2_bias,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows,
|
||||
@@ -255,7 +255,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
|
||||
expert_scales_float,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
ffn2_bias,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows,
|
||||
@@ -274,7 +274,7 @@ __global__ void permute_x_kernel(const T *src_x,
|
||||
const int64_t *topk_idx,
|
||||
const float *topk_weights,
|
||||
const int *token_nums_per_expert,
|
||||
const float *ffn1_in_scale,
|
||||
const float *up_gate_proj_in_scale,
|
||||
const int moe_topk,
|
||||
const int num_rows,
|
||||
const int token_nums_this_rank,
|
||||
@@ -327,9 +327,9 @@ __global__ void permute_x_kernel(const T *src_x,
|
||||
// cp x
|
||||
for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) {
|
||||
Load<T, vec_size>(&src_x[s_token_idx * hidden_size + v_id * vec_size], &src_vec);
|
||||
if (ffn1_in_scale) {
|
||||
if (up_gate_proj_in_scale) {
|
||||
for (int i = 0; i < vec_size; i++) {
|
||||
float quant_value = max_bound * ffn1_in_scale[expert_now] * static_cast<float>(src_vec[i]);
|
||||
float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
|
||||
if (RoundType == 0) {
|
||||
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
|
||||
} else {
|
||||
@@ -353,7 +353,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const paddle::Tensor& token_nums_per_expert,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_in_scale,
|
||||
const std::string& moe_quant_type,
|
||||
const int moe_topk,
|
||||
const int num_rows,
|
||||
@@ -383,7 +383,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
@@ -404,7 +404,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
@@ -427,7 +427,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
@@ -448,7 +448,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
@@ -472,7 +472,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_in_scale,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const int token_nums_this_rank,
|
||||
const std::string& moe_quant_type) {
|
||||
@@ -516,7 +516,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
num_experts_per_rank_tensor,
|
||||
ffn1_in_scale,
|
||||
up_gate_proj_in_scale,
|
||||
moe_quant_type,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
@@ -536,7 +536,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
         topk_ids,
         topk_weights,
         num_experts_per_rank_tensor,
-        ffn1_in_scale,
+        up_gate_proj_in_scale,
         moe_quant_type,
         moe_topk,
         num_rows,
@@ -568,7 +568,7 @@ std::vector<std::vector<int64_t>> EPMoeExpertDispatchInferShape(
     const std::vector<int64_t>& input_shape,
     const std::vector<int64_t>& topk_ids_shape,
     const std::vector<int64_t>& topk_weights_shape,
-    const paddle::optional<std::vector<int64_t>>& ffn1_in_scale_dtype,
+    const paddle::optional<std::vector<int64_t>>& up_gate_proj_in_scale_dtype,
     const std::vector<int>& token_nums_per_expert,
     const int token_nums_this_rank) {
   int token_rows = -1;
@@ -610,7 +610,7 @@ std::vector<paddle::DataType> EPMoeExpertDispatchInferDtype(
 
 PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
     .Inputs({"input", "topk_ids", "topk_weights",
-             paddle::Optional("ffn1_in_scale")})
+             paddle::Optional("up_gate_proj_in_scale")})
     .Outputs({"permute_input",
               "permute_indices_per_token",
               "token_nums_per_expert_cumsum",
@@ -870,7 +870,9 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
     const paddle::Tensor& topk_ids,
     const paddle::Tensor& topk_weights,
     const paddle::Tensor& num_experts_per_rank_tensor,
-    const paddle::Tensor& num_experts_per_rank_padded_tensor) {
+    const paddle::Tensor& num_experts_per_rank_padded_tensor,
+    const bool use_in_ep,
+    const int token_nums_this_rank_padded) {
   const auto input_type = input.dtype();
   const int moe_topk = topk_ids.dims()[1];
   auto place = input.place();
@@ -886,22 +888,21 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
   const int hidden_size = input.dims()[input_dims.size() - 1];
   const int num_experts_per_rank = num_experts_per_rank_tensor.dims()[0];
 
-  int32_t token_nums_this_rank_padded = token_rows * moe_topk + num_experts_per_rank * (128-1);
-  // token_nums_this_rank_padded = token_nums_this_rank_padded_useless;
+  int32_t token_nums_feed_to_ffn = use_in_ep ? token_nums_this_rank_padded : token_rows * moe_topk + num_experts_per_rank * (128-1);
 
   auto permute_input = GetEmptyTensor(
-      {token_nums_this_rank_padded, hidden_size},
+      {token_nums_feed_to_ffn, hidden_size},
       input_type,
       place);
   auto permute_scale = GetEmptyTensor(
-      {token_nums_this_rank_padded, hidden_size / 128},
+      {token_nums_feed_to_ffn, hidden_size / 128},
       paddle::DataType::FLOAT32,
       place);
 
-  auto m_indices = paddle::full({token_nums_this_rank_padded}, -1, paddle::DataType::INT32, place);
+  auto m_indices = paddle::full({token_nums_feed_to_ffn}, -1, paddle::DataType::INT32, place);
   auto token_nums_per_expert_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
   auto token_nums_per_expert_padded_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
-  auto dst_weights = GetEmptyTensor({token_nums_this_rank_padded}, paddle::DataType::FLOAT32, place);
+  auto dst_weights = GetEmptyTensor({token_nums_feed_to_ffn}, paddle::DataType::FLOAT32, place);
   auto dst_indices = GetEmptyTensor({num_rows, num_experts_per_rank}, paddle::DataType::INT32, place);
   auto permute_indices_per_token = paddle::full({num_experts_per_rank, num_rows}, -1, paddle::DataType::INT32, place);
   auto cumsum_idx_gpu = paddle::full({num_experts_per_rank}, 0, paddle::DataType::INT32, place);
@@ -949,4 +950,5 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8)
               "dst_indices",
               "cumsum_idx_gpu",
               "m_indices"})
+    .Attrs({"use_in_ep:bool", "token_nums_this_rank_padded:int"})
     .SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8));
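A note on the dispatch buffer sizing above: with use_in_ep false, the FP8 dispatch falls back to a worst-case bound of token_rows * moe_topk routed rows plus up to 127 padding rows per expert, since each expert's segment is rounded up to a 128-row block for the grouped GEMM; with use_in_ep true, the caller supplies the exact token_nums_this_rank_padded instead. A minimal sketch of that bound (names mirror the kernel, the sizes are purely illustrative):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical sizes: 256 tokens routed to top-8 experts,
  // 16 experts resident on this rank.
  const int32_t token_rows = 256;
  const int32_t moe_topk = 8;
  const int32_t num_experts_per_rank = 16;
  // Worst case mirrors the kernel's fallback: every routed row survives
  // and every expert segment needs the maximal 128-1 rows of padding.
  const int32_t worst_case =
      token_rows * moe_topk + num_experts_per_rank * (128 - 1);
  std::printf("routed rows: %d, worst-case buffer rows: %d\n",
              (int)(token_rows * moe_topk), (int)worst_case);  // 2048, 4080
  return 0;
}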
@@ -54,12 +54,12 @@ void compute_total_rows_before_expert(int* sorted_indices,
 template <paddle::DataType T>
 void FusedMoeKernel(const paddle::Tensor& input,
                     const paddle::Tensor& gate_weight,
-                    const paddle::Tensor& ffn1_weight,
-                    const paddle::optional<paddle::Tensor>& ffn1_scale,
-                    const paddle::optional<paddle::Tensor>& ffn1_bias,
-                    const paddle::Tensor& ffn2_weight,
-                    const paddle::optional<paddle::Tensor>& ffn2_scale,
-                    const paddle::optional<paddle::Tensor>& ffn2_bias,
+                    const paddle::Tensor& up_gate_proj_weight,
+                    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
+                    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
+                    const paddle::Tensor& down_proj_weight,
+                    const paddle::optional<paddle::Tensor>& down_proj_scale,
+                    const paddle::optional<paddle::Tensor>& down_proj_bias,
                     const std::string& quant_method,
                     const int moe_topk,
                     const bool group_moe,
@@ -84,12 +84,12 @@ void FusedMoeKernel(const paddle::Tensor& input,
 
   moe_compute.ComputeFFN(&input,
                          &gate_weight,
-                         &ffn1_weight,
-                         ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
-                         ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
-                         &ffn2_weight,
-                         ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
-                         ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
+                         &up_gate_proj_weight,
+                         up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr,
+                         up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr,
+                         &down_proj_weight,
+                         down_proj_scale ? down_proj_scale.get_ptr() : nullptr,
+                         down_proj_bias ? down_proj_bias.get_ptr() : nullptr,
                          nullptr,
                          moe_topk,
                          group_moe,
@@ -102,12 +102,12 @@ void FusedMoeKernel(const paddle::Tensor& input,
 paddle::Tensor FusedExpertMoeFunc(
     const paddle::Tensor& input,
     const paddle::Tensor& gate_weight,
-    const paddle::Tensor& ffn1_weight,
-    const paddle::Tensor& ffn2_weight,
-    const paddle::optional<paddle::Tensor>& ffn1_bias,
-    const paddle::optional<paddle::Tensor>& ffn1_scale,
-    const paddle::optional<paddle::Tensor>& ffn2_bias,
-    const paddle::optional<paddle::Tensor>& ffn2_scale,
+    const paddle::Tensor& up_gate_proj_weight,
+    const paddle::Tensor& down_proj_weight,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
+    const paddle::optional<paddle::Tensor>& down_proj_bias,
+    const paddle::optional<paddle::Tensor>& down_proj_scale,
     const std::string& quant_method,
     const int moe_topk,
     const bool norm_topk_prob,
@@ -119,12 +119,12 @@ paddle::Tensor FusedExpertMoeFunc(
     case paddle::DataType::BFLOAT16:
       FusedMoeKernel<paddle::DataType::BFLOAT16>(input,
                                                  gate_weight,
-                                                 ffn1_weight,
-                                                 ffn1_scale,
-                                                 ffn1_bias,
-                                                 ffn2_weight,
-                                                 ffn2_scale,
-                                                 ffn2_bias,
+                                                 up_gate_proj_weight,
+                                                 up_gate_proj_scale,
+                                                 up_gate_proj_bias,
+                                                 down_proj_weight,
+                                                 down_proj_scale,
+                                                 down_proj_bias,
                                                  quant_method,
                                                  moe_topk,
                                                  group_moe,
@@ -134,12 +134,12 @@ paddle::Tensor FusedExpertMoeFunc(
     case paddle::DataType::FLOAT16:
       FusedMoeKernel<paddle::DataType::FLOAT16>(input,
                                                 gate_weight,
-                                                ffn1_weight,
-                                                ffn1_scale,
-                                                ffn1_bias,
-                                                ffn2_weight,
-                                                ffn2_scale,
-                                                ffn2_bias,
+                                                up_gate_proj_weight,
+                                                up_gate_proj_scale,
+                                                up_gate_proj_bias,
+                                                down_proj_weight,
+                                                down_proj_scale,
+                                                down_proj_bias,
                                                 quant_method,
                                                 moe_topk,
                                                 group_moe,
@@ -155,24 +155,24 @@ paddle::Tensor FusedExpertMoeFunc(
 std::vector<paddle::Tensor> FusedExpertMoe(
     const paddle::Tensor& input,
     const paddle::Tensor& gate_weight,
-    const paddle::Tensor& ffn1_weight,
-    const paddle::Tensor& ffn2_weight,
-    const paddle::optional<paddle::Tensor>& ffn1_bias,
-    const paddle::optional<paddle::Tensor>& ffn1_scale,
-    const paddle::optional<paddle::Tensor>& ffn2_bias,
-    const paddle::optional<paddle::Tensor>& ffn2_scale,
+    const paddle::Tensor& up_gate_proj_weight,
+    const paddle::Tensor& down_proj_weight,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
+    const paddle::optional<paddle::Tensor>& down_proj_bias,
+    const paddle::optional<paddle::Tensor>& down_proj_scale,
     const std::string& quant_method,
     const int moe_topk,
     const bool norm_topk_prob,
     const bool group_moe) {
   return {FusedExpertMoeFunc(input,
                              gate_weight,
-                             ffn1_weight,
-                             ffn2_weight,
-                             ffn1_bias,
-                             ffn1_scale,
-                             ffn2_bias,
-                             ffn2_scale,
+                             up_gate_proj_weight,
+                             down_proj_weight,
+                             up_gate_proj_bias,
+                             up_gate_proj_scale,
+                             down_proj_bias,
+                             down_proj_scale,
                              quant_method,
                              moe_topk,
                              norm_topk_prob,
@@ -182,30 +182,30 @@ std::vector<paddle::Tensor> FusedExpertMoe(
 std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
     const std::vector<int64_t>& input_shape,
     const std::vector<int64_t>& gate_weight_shape,
-    const std::vector<int64_t>& ffn1_weight_shape,
-    const std::vector<int64_t>& ffn2_weight_shape,
-    const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
-    const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
-    const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
-    const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
+    const std::vector<int64_t>& up_gate_proj_weight_shape,
+    const std::vector<int64_t>& down_proj_weight_shape,
+    const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
+    const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
+    const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape,
+    const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape) {
   return {input_shape};
 }
 
 std::vector<paddle::DataType> FusedExpertMoeInferDtype(
     const paddle::DataType& input_dtype,
     const paddle::DataType& gate_weight_dtype,
-    const paddle::DataType& ffn1_weight_dtype,
-    const paddle::DataType& ffn2_weight_dtype,
-    const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
-    const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
-    const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
-    const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
+    const paddle::DataType& up_gate_proj_weight_dtype,
+    const paddle::DataType& down_proj_weight_dtype,
+    const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
+    const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
+    const paddle::optional<paddle::DataType>& down_proj_bias_dtype,
+    const paddle::optional<paddle::DataType>& down_proj_scale_dtype) {
   return {input_dtype};
 }
 
 /**
  * @brief Fused Mixture-of-Experts (MoE) Operator
  *
  *
  * This operator combines three key MoE operations into a single optimized kernel:
  * 1. moe_dispatch - Routes tokens to top-k experts using gating network
  * 2. moe_ffn - Processes tokens through parallel expert FFNs
@@ -230,12 +230,12 @@ std::vector<paddle::DataType> FusedExpertMoeInferDtype(
 PD_BUILD_STATIC_OP(fused_expert_moe)
     .Inputs({"input",
              "gate_weight",
-             "ffn1_weight",
-             "ffn2_weight",
-             paddle::Optional("ffn1_bias"),
-             paddle::Optional("ffn1_scale"),
-             paddle::Optional("ffn2_bias"),
-             paddle::Optional("ffn2_scale")})
+             "up_gate_proj_weight",
+             "down_proj_weight",
+             paddle::Optional("up_gate_proj_bias"),
+             paddle::Optional("up_gate_proj_scale"),
+             paddle::Optional("down_proj_bias"),
+             paddle::Optional("down_proj_scale")})
     .Outputs({"output"})
     .Attrs({"quant_method:std::string",
             "moe_topk:int",
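For readers tracking the rename: the fused op's data flow is unchanged. Gate scores pick the top-k experts per token, each token runs through its experts' up_gate_proj/down_proj FFNs, and the expert outputs are combined weighted by the (optionally renormalized) gate probabilities. A CPU reference of that routing-and-combine logic, with a toy stand-in expert; run_expert and moe_forward are illustrative names, not the kernel's code:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

// Toy "expert": identity scaled by (expert_id + 1), standing in for
// the real up_gate_proj -> SwiGLU -> down_proj pipeline.
std::vector<float> run_expert(int e, const std::vector<float>& x) {
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[i] = (e + 1) * x[i];
  return y;
}

std::vector<float> moe_forward(const std::vector<float>& token,
                               const std::vector<float>& gate_logits,
                               int moe_topk, bool norm_topk_prob) {
  const int num_experts = static_cast<int>(gate_logits.size());
  // Softmax over the gate logits.
  std::vector<float> prob(num_experts);
  float mx = *std::max_element(gate_logits.begin(), gate_logits.end());
  float denom = 0.f;
  for (int e = 0; e < num_experts; ++e) {
    prob[e] = std::exp(gate_logits[e] - mx);
    denom += prob[e];
  }
  for (int e = 0; e < num_experts; ++e) prob[e] /= denom;
  // Pick the top-k experts by probability.
  std::vector<int> idx(num_experts);
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + moe_topk, idx.end(),
                    [&](int a, int b) { return prob[a] > prob[b]; });
  // Optionally renormalize the selected weights to sum to one.
  float wsum = 0.f;
  for (int k = 0; k < moe_topk; ++k) wsum += prob[idx[k]];
  // Weighted combine of the selected experts' outputs.
  std::vector<float> out(token.size(), 0.f);
  for (int k = 0; k < moe_topk; ++k) {
    float w = norm_topk_prob ? prob[idx[k]] / wsum : prob[idx[k]];
    auto y = run_expert(idx[k], token);
    for (size_t i = 0; i < out.size(); ++i) out[i] += w * y[i];
  }
  return out;
}

int main() {
  std::vector<float> token = {1.f, 2.f, 3.f};
  std::vector<float> gate_logits = {0.1f, 2.0f, -1.0f, 0.5f};
  auto out = moe_forward(token, gate_logits, /*moe_topk=*/2,
                         /*norm_topk_prob=*/true);
  std::printf("%f %f %f\n", out[0], out[1], out[2]);
  return 0;
}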
@@ -117,18 +117,18 @@ public:
 
   void
   ComputeFFN(const paddle::Tensor *input, const paddle::Tensor *gate_weight,
-             const paddle::Tensor *ffn1_weight,
-             const paddle::Tensor *ffn1_scale, const paddle::Tensor *ffn1_bias,
-             const paddle::Tensor *ffn2_weight,
-             const paddle::Tensor *ffn2_scale, const paddle::Tensor *ffn2_bias,
+             const paddle::Tensor *up_gate_proj_weight,
+             const paddle::Tensor *up_gate_proj_scale, const paddle::Tensor *up_gate_proj_bias,
+             const paddle::Tensor *down_proj_weight,
+             const paddle::Tensor *down_proj_scale, const paddle::Tensor *down_proj_bias,
              const paddle::Tensor *moe_token_type_ids, const int moe_topk,
              const bool group_moe, const bool norm_topk_prob,
              const float routed_scaling_factor, const std::string moe_type,
              paddle::Tensor *output) {
     auto *input_activations = input->data<T>();
     auto *gating_weights = gate_weight->data<float>();
-    const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
-    const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
+    const T *fc1_expert_biases = up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
+    const T *fc2_expert_biases = down_proj_bias ? down_proj_bias->data<T>() : nullptr;
 
     auto *output_ = output->data<T>();
     auto stream = input->stream();
@@ -136,7 +136,7 @@ public:
     auto input_type = input->dtype();
 
     auto input_dims = input->dims();
-    auto ffn1_dims = ffn1_weight->dims();
+    auto up_gate_proj_dims = up_gate_proj_weight->dims();
     int64_t token_num = 0;
     if (input_dims.size() == 3) {
       token_num = input_dims[0] * input_dims[1];
@@ -145,12 +145,12 @@ public:
     }
     const int64_t num_rows = token_num;
 
-    const int64_t hidden_size = ffn1_dims[1];
+    const int64_t hidden_size = up_gate_proj_dims[1];
     int64_t inter_dim = 0;
     if (moe_type == "qkv") {
-      inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
+      inter_dim = up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
     } else {
-      inter_dim = ffn1_dims[2];
+      inter_dim = up_gate_proj_dims[2];
     }
 
     if (gemm_method_ == "weight_only_int4") {
@@ -158,7 +158,7 @@ public:
     }
 
     const int64_t inter_size = inter_dim;
-    const int64_t num_experts = ffn1_dims[0];
+    const int64_t num_experts = up_gate_proj_dims[0];
     const int64_t k = moe_topk;
 
     int64_t bytes =
@@ -260,38 +260,38 @@ public:
         total_rows_before_expert_, stream);
 
     if (gemm_method_ == "weight_only_int8") {
-      typename Int8Traits::Arguments ffn1_quant_args;
+      typename Int8Traits::Arguments up_gate_proj_quant_args;
       int8_moe_gemm_runner_->moe_gemm_bias_act(
           reinterpret_cast<NvType *>(permuted_data_),
-          reinterpret_cast<const uint8_t *>(ffn1_weight->data<int8_t>()),
-          reinterpret_cast<const NvType *>(ffn1_scale->data<T>()),
+          reinterpret_cast<const uint8_t *>(up_gate_proj_weight->data<int8_t>()),
+          reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
          reinterpret_cast<const NvType *>(fc1_expert_biases),
          reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
          -1, // useless
          expanded_active_expert_rows, inter_size, hidden_size, num_experts,
-          ffn1_quant_args, "none", stream);
+          up_gate_proj_quant_args, "none", stream);
     } else if (gemm_method_ == "weight_only_int4") {
-      typename Int4Traits::Arguments ffn1_quant_args;
+      typename Int4Traits::Arguments up_gate_proj_quant_args;
       int4_moe_gemm_runner_->moe_gemm_bias_act(
          reinterpret_cast<NvType *>(permuted_data_),
          reinterpret_cast<const cutlass::uint4b_t *>(
-              ffn1_weight->data<int8_t>()),
-          reinterpret_cast<const NvType *>(ffn1_scale->data<T>()),
+              up_gate_proj_weight->data<int8_t>()),
+          reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
          reinterpret_cast<const NvType *>(fc1_expert_biases),
          reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
          -1, // useless
          expanded_active_expert_rows, inter_size, hidden_size, num_experts,
-          ffn1_quant_args, "none", stream);
+          up_gate_proj_quant_args, "none", stream);
     } else {
-      typename Fp16Traits::Arguments ffn1_quant_args;
+      typename Fp16Traits::Arguments up_gate_proj_quant_args;
       fp16_moe_gemm_runner_->moe_gemm_bias_act(
          reinterpret_cast<NvType *>(permuted_data_),
-          reinterpret_cast<const NvType *>(ffn1_weight->data<T>()), nullptr,
+          reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()), nullptr,
          reinterpret_cast<const NvType *>(fc1_expert_biases),
          reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
          -1, // useless
          expanded_active_expert_rows, inter_size, hidden_size, num_experts,
-          ffn1_quant_args, "none", stream);
+          up_gate_proj_quant_args, "none", stream);
     }
 
     if (moe_type == "ffn") {
@@ -304,35 +304,35 @@ public:
       T *fc2_result = fc2_output_tensor.data<T>();
 
       if (gemm_method_ == "weight_only_int8") {
-        typename Int8Traits::Arguments ffn2_quant_args;
+        typename Int8Traits::Arguments down_proj_quant_args;
         int8_moe_gemm_runner_->moe_gemm(
            reinterpret_cast<NvType *>(act_out),
-            reinterpret_cast<const uint8_t *>(ffn2_weight->data<int8_t>()),
-            reinterpret_cast<const NvType *>(ffn2_scale->data<T>()),
+            reinterpret_cast<const uint8_t *>(down_proj_weight->data<int8_t>()),
+            reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
            reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
            -1, // useless
            expanded_active_expert_rows, hidden_size, inter_size / 2,
-            num_experts, ffn2_quant_args, stream);
+            num_experts, down_proj_quant_args, stream);
       } else if (gemm_method_ == "weight_only_int4") {
-        typename Int4Traits::Arguments ffn2_quant_args;
+        typename Int4Traits::Arguments down_proj_quant_args;
         int4_moe_gemm_runner_->moe_gemm(
            reinterpret_cast<NvType *>(act_out),
            reinterpret_cast<const cutlass::uint4b_t *>(
-                ffn2_weight->data<int8_t>()),
-            reinterpret_cast<const NvType *>(ffn2_scale->data<T>()),
+                down_proj_weight->data<int8_t>()),
+            reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
            reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
            -1, // useless
            expanded_active_expert_rows, hidden_size, inter_size / 2,
-            num_experts, ffn2_quant_args, stream);
+            num_experts, down_proj_quant_args, stream);
       } else {
-        typename Fp16Traits::Arguments ffn2_quant_args;
+        typename Fp16Traits::Arguments down_proj_quant_args;
         fp16_moe_gemm_runner_->moe_gemm(
            reinterpret_cast<NvType *>(act_out),
-            reinterpret_cast<const NvType *>(ffn2_weight->data<T>()), nullptr,
+            reinterpret_cast<const NvType *>(down_proj_weight->data<T>()), nullptr,
            reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
            -1, // useless
            expanded_active_expert_rows, hidden_size, inter_size / 2,
-            num_experts, ffn2_quant_args, stream);
+            num_experts, down_proj_quant_args, stream);
       }
 
       finalize_moe_routing_kernelLauncher<T>::run(
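The grouped GEMMs above are all driven by total_rows_before_expert_: an inclusive prefix sum over per-expert row counts, so expert e's rows occupy the half-open range [prefix[e-1], prefix[e]) of the permuted buffer. A host-side sketch of that bookkeeping, with hypothetical counts (the real array is built on device by compute_total_rows_before_expert):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical per-expert token counts after dispatch.
  std::vector<int64_t> rows_per_expert = {5, 0, 9, 3};
  std::vector<int64_t> total_rows_before_expert(rows_per_expert.size());
  int64_t acc = 0;
  for (size_t e = 0; e < rows_per_expert.size(); ++e) {
    acc += rows_per_expert[e];
    total_rows_before_expert[e] = acc;  // inclusive prefix sum
  }
  // Expert e's GEMM operates on permuted rows [start, end).
  for (size_t e = 0; e < rows_per_expert.size(); ++e) {
    int64_t start = (e == 0) ? 0 : total_rows_before_expert[e - 1];
    int64_t end = total_rows_before_expert[e];
    std::printf("expert %zu: rows [%lld, %lld)\n",
                e, (long long)start, (long long)end);
  }
  return 0;
}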
@@ -24,12 +24,12 @@
 template <paddle::DataType T>
 void MoeFFNKernel(const paddle::Tensor& permute_input,
                   const paddle::Tensor& tokens_expert_prefix_sum,
-                  const paddle::Tensor& ffn1_weight,
-                  const paddle::Tensor& ffn2_weight,
-                  const paddle::optional<paddle::Tensor>& ffn1_bias,
-                  const paddle::optional<paddle::Tensor>& ffn1_scale,
-                  const paddle::optional<paddle::Tensor>& ffn2_scale,
-                  const paddle::optional<paddle::Tensor>& ffn2_in_scale,
+                  const paddle::Tensor& up_gate_proj_weight,
+                  const paddle::Tensor& down_proj_weight,
+                  const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
+                  const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
+                  const paddle::optional<paddle::Tensor>& down_proj_scale,
+                  const paddle::optional<paddle::Tensor>& down_proj_in_scale,
                   const paddle::optional<paddle::Tensor>& expert_idx_per_token,
                   const std::string& quant_method,
                   paddle::Tensor ffn_out,
@@ -51,11 +51,11 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
 
   assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
 
-  const int num_experts = ffn1_weight.dims()[0];
+  const int num_experts = up_gate_proj_weight.dims()[0];
   const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];
 
-  assert(ffn1_weight.dims().size() == 3);
-  int inter_dim = ffn1_weight.dims()[1] * ffn1_weight.dims()[2] / hidden_size;
+  assert(up_gate_proj_weight.dims().size() == 3);
+  int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size;
 
   constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k
   Allocator* allocator = paddle::GetAllocator(place);
@@ -96,8 +96,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
   using NvType = typename traits_::DataType;
 
   auto fc1_expert_biases =
-      ffn1_bias
-          ? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
+      up_gate_proj_bias
+          ? const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr())->data<data_t>()
           : nullptr;
 
   // This is a trick.
@@ -112,9 +112,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
     typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
     int8_moe_gemm_runner.moe_gemm_bias_act(
         reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
-        reinterpret_cast<const uint8_t*>(ffn1_weight.data<int8_t>()),
+        reinterpret_cast<const uint8_t*>(up_gate_proj_weight.data<int8_t>()),
         reinterpret_cast<const NvType*>(
-            const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())
+            const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
                ->data<data_t>()),
         reinterpret_cast<const NvType*>(fc1_expert_biases),
         reinterpret_cast<NvType*>(fc1_out),
@@ -132,9 +132,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
     int4_moe_gemm_runner.moe_gemm_bias_act(
         reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
         reinterpret_cast<const cutlass::uint4b_t*>(
-            ffn1_weight.data<int8_t>()),
+            up_gate_proj_weight.data<int8_t>()),
         reinterpret_cast<const NvType*>(
-            const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())
+            const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
                ->data<data_t>()),
         reinterpret_cast<const NvType*>(fc1_expert_biases),
         reinterpret_cast<NvType*>(fc1_out),
@@ -151,12 +151,12 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
     w4a8_moe_gemm_runner.moe_gemm(
         reinterpret_cast<const int8_t *>(permute_input.data<int8_t>()),
         reinterpret_cast<const cutlass::uint4b_t *>(
-            ffn1_weight.data<int8_t>()),
+            up_gate_proj_weight.data<int8_t>()),
         quant_mode,
         reinterpret_cast<const NvType*>(
-            const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())
+            const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
                ->data<data_t>()),
-        nullptr,  // ffn1_scale_dyquant
+        nullptr,  // up_gate_proj_scale_dyquant
         nullptr,  // nf4_look_up_table
         reinterpret_cast<NvType *>(fc1_out),
         const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -172,7 +172,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
     typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
     fp16_moe_gemm_runner.moe_gemm_bias_act(
         reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
-        reinterpret_cast<const NvType*>(ffn1_weight.data<data_t>()),
+        reinterpret_cast<const NvType*>(up_gate_proj_weight.data<data_t>()),
         nullptr,
         reinterpret_cast<const NvType*>(fc1_expert_biases),
         reinterpret_cast<NvType*>(fc1_out),
@@ -199,9 +199,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
     typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
     int8_moe_gemm_runner.moe_gemm(
         reinterpret_cast<const NvType*>(act_out),
-        reinterpret_cast<const uint8_t*>(ffn2_weight.data<int8_t>()),
+        reinterpret_cast<const uint8_t*>(down_proj_weight.data<int8_t>()),
         reinterpret_cast<const NvType*>(
-            const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())
+            const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
                ->data<data_t>()),
         reinterpret_cast<NvType*>(ffn_out_data),
         const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -218,9 +218,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
     int4_moe_gemm_runner.moe_gemm(
         reinterpret_cast<const NvType*>(act_out),
         reinterpret_cast<const cutlass::uint4b_t*>(
-            ffn2_weight.data<int8_t>()),
+            down_proj_weight.data<int8_t>()),
         reinterpret_cast<const NvType*>(
-            const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())
+            const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
                ->data<data_t>()),
         reinterpret_cast<NvType*>(ffn_out_data),
         const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -232,17 +232,17 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         quant_args,
         stream);
   } else if (quant_method == "w4a8") {
-    data_t *ffn2_shift = nullptr;
-    data_t *ffn2_smooth = nullptr;
+    data_t *down_proj_shift = nullptr;
+    data_t *down_proj_smooth = nullptr;
     Allocator::AllocationPtr int8_act_out;
     int8_act_out = allocator->Allocate(
         SizeOf(paddle::DataType::INT8) * act_out_tensor.numel());
     MoeFastHardamardWrapper<data_t, int8_t>(
         act_out_tensor.data<data_t>(),
         expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
-        ffn2_shift,   // ffn2_shift->data<T>(),
-        ffn2_smooth,  // ffn2_smooth->data<T>(),
-        ffn2_in_scale ? const_cast<paddle::Tensor*>(ffn2_in_scale.get_ptr())->data<float>() : nullptr,
+        down_proj_shift,   // down_proj_shift->data<T>(),
+        down_proj_smooth,  // down_proj_smooth->data<T>(),
+        down_proj_in_scale ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())->data<float>() : nullptr,
         1,
         127.0,
         -127.0,
@@ -254,12 +254,12 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
     w4a8_moe_gemm_runner.moe_gemm(
         reinterpret_cast<int8_t *>(int8_act_out->ptr()),
         reinterpret_cast<const cutlass::uint4b_t *>(
-            ffn2_weight.data<int8_t>()),
+            down_proj_weight.data<int8_t>()),
         quant_mode,
         reinterpret_cast<const NvType*>(
-            const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())
+            const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
                ->data<data_t>()),
-        nullptr,  // ffn2_scale_dyquant
+        nullptr,  // down_proj_scale_dyquant
         nullptr,  // reinterpret_cast<const int32_t*>(d_nf4_look_up_table), // nf4_look_up_table
         reinterpret_cast<NvType *>(ffn_out_data),
         const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -275,7 +275,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
     typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
     fp16_moe_gemm_runner.moe_gemm(
         reinterpret_cast<const NvType*>(act_out),
-        reinterpret_cast<const NvType*>(ffn2_weight.data<data_t>()),
+        reinterpret_cast<const NvType*>(down_proj_weight.data<data_t>()),
         nullptr,
         reinterpret_cast<NvType*>(ffn_out_data),
         const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -292,29 +292,29 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
 paddle::Tensor MoeExpertFFNFunc(
     const paddle::Tensor& permute_input,
     const paddle::Tensor& tokens_expert_prefix_sum,
-    const paddle::Tensor& ffn1_weight,
-    const paddle::Tensor& ffn2_weight,
-    const paddle::optional<paddle::Tensor>& ffn1_bias,
-    const paddle::optional<paddle::Tensor>& ffn1_scale,
-    const paddle::optional<paddle::Tensor>& ffn2_scale,
-    const paddle::optional<paddle::Tensor>& ffn2_in_scale,
+    const paddle::Tensor& up_gate_proj_weight,
+    const paddle::Tensor& down_proj_weight,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
+    const paddle::optional<paddle::Tensor>& down_proj_scale,
+    const paddle::optional<paddle::Tensor>& down_proj_in_scale,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
     const std::string& quant_method, const bool used_in_ep_low_latency) {
 
   cudaCheckError();
-  const auto t_type = quant_method == "w4a8" ? ffn1_scale.get().dtype() : permute_input.dtype();
+  const auto t_type = quant_method == "w4a8" ? up_gate_proj_scale.get().dtype() : permute_input.dtype();
   auto ffn_out = paddle::empty_like(permute_input, t_type);
 
   switch (t_type) {
     case paddle::DataType::BFLOAT16:
       MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
                                                tokens_expert_prefix_sum,
-                                               ffn1_weight,
-                                               ffn2_weight,
-                                               ffn1_bias,
-                                               ffn1_scale,
-                                               ffn2_scale,
-                                               ffn2_in_scale,
+                                               up_gate_proj_weight,
+                                               down_proj_weight,
+                                               up_gate_proj_bias,
+                                               up_gate_proj_scale,
+                                               down_proj_scale,
+                                               down_proj_in_scale,
                                                expert_idx_per_token,
                                                quant_method,
                                                ffn_out, used_in_ep_low_latency);
@@ -322,12 +322,12 @@ paddle::Tensor MoeExpertFFNFunc(
     case paddle::DataType::FLOAT16:
       MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
                                               tokens_expert_prefix_sum,
-                                              ffn1_weight,
-                                              ffn2_weight,
-                                              ffn1_bias,
-                                              ffn1_scale,
-                                              ffn2_scale,
-                                              ffn2_in_scale,
+                                              up_gate_proj_weight,
+                                              down_proj_weight,
+                                              up_gate_proj_bias,
+                                              up_gate_proj_scale,
+                                              down_proj_scale,
+                                              down_proj_in_scale,
                                               expert_idx_per_token,
                                               quant_method,
                                               ffn_out, used_in_ep_low_latency);
@@ -341,22 +341,22 @@ paddle::Tensor MoeExpertFFNFunc(
 std::vector<paddle::Tensor> MoeExpertFFN(
     const paddle::Tensor& permute_input,
     const paddle::Tensor& tokens_expert_prefix_sum,
-    const paddle::Tensor& ffn1_weight,
-    const paddle::Tensor& ffn2_weight,
-    const paddle::optional<paddle::Tensor>& ffn1_bias,
-    const paddle::optional<paddle::Tensor>& ffn1_scale,
-    const paddle::optional<paddle::Tensor>& ffn2_scale,
-    const paddle::optional<paddle::Tensor>& ffn2_in_scale,
+    const paddle::Tensor& up_gate_proj_weight,
+    const paddle::Tensor& down_proj_weight,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
+    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
+    const paddle::optional<paddle::Tensor>& down_proj_scale,
+    const paddle::optional<paddle::Tensor>& down_proj_in_scale,
     const paddle::optional<paddle::Tensor>& expert_idx_per_token,
     const std::string& quant_method, const bool used_in_ep_low_latency) {
   return {MoeExpertFFNFunc(permute_input,
                            tokens_expert_prefix_sum,
-                           ffn1_weight,
-                           ffn2_weight,
-                           ffn1_bias,
-                           ffn1_scale,
-                           ffn2_scale,
-                           ffn2_in_scale,
+                           up_gate_proj_weight,
+                           down_proj_weight,
+                           up_gate_proj_bias,
+                           up_gate_proj_scale,
+                           down_proj_scale,
+                           down_proj_in_scale,
                            expert_idx_per_token,
                            quant_method, used_in_ep_low_latency)};
 }
@@ -364,12 +364,12 @@ std::vector<paddle::Tensor> MoeExpertFFN(
 std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
     const std::vector<int64_t>& permute_input_shape,
     const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
-    const std::vector<int64_t>& ffn1_weight_shape,
-    const std::vector<int64_t>& ffn2_weight_shape,
-    const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
-    const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
-    const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape,
-    const paddle::optional<std::vector<int64_t>>& ffn2_in_scale_shape,
+    const std::vector<int64_t>& up_gate_proj_weight_shape,
+    const std::vector<int64_t>& down_proj_weight_shape,
+    const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
+    const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
+    const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
+    const paddle::optional<std::vector<int64_t>>& down_proj_in_scale_shape,
     const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
     const std::string& quant_method,
     const bool used_in_ep_low_latency) {
@@ -379,15 +379,15 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
 std::vector<paddle::DataType> MoeExpertFFNInferDtype(
     const paddle::DataType &permute_input_dtype,
     const paddle::DataType &tokens_expert_prefix_sum_dtype,
-    const paddle::DataType &ffn1_weight_dtype,
-    const paddle::DataType &ffn2_weight_dtype,
-    const paddle::optional<paddle::DataType> &ffn1_bias_dtype,
-    const paddle::optional<paddle::DataType> &ffn1_scale_dtype,
-    const paddle::optional<paddle::DataType> &ffn2_scale_dtype,
-    const paddle::optional<paddle::DataType> &ffn2_in_scale_dtype,
+    const paddle::DataType &up_gate_proj_weight_dtype,
+    const paddle::DataType &down_proj_weight_dtype,
+    const paddle::optional<paddle::DataType> &up_gate_proj_bias_dtype,
+    const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
+    const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
+    const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
     const std::string &quant_method, const bool used_in_ep_low_latency) {
   if (quant_method == "w4a8") {
-    return {ffn1_scale_dtype.get()};
+    return {up_gate_proj_scale_dtype.get()};
   } else {
     return {permute_input_dtype};
   }
@@ -397,9 +397,9 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
 * @brief Mixture of Experts (MoE) Feed-Forward Network Operator
 *
 * This operator performs the expert computation in MoE architecture, including:
- * 1. First linear transformation (FFN1) with optional quantization
+ * 1. First linear transformation (up_gate_proj) with optional quantization
 * 2. SwiGLU activation function
- * 3. Second linear transformation (FFN2) with optional quantization
+ * 3. Second linear transformation (down_proj) with optional quantization
 *
 * Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization.
 *
@@ -410,22 +410,22 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
 * - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
 *                             Shape: [num_experts]
 *                             dtype: int64
- * - ffn1_weight: First FFN layer weights
+ * - up_gate_proj_weight: First FFN layer weights
 *                Shape: [num_experts, inter_size * 2, hidden_size]
 *                dtype: Same as input (unquantized) or int8 (quantized)
- * - ffn2_weight: Second FFN layer weights
+ * - down_proj_weight: Second FFN layer weights
 *                Shape: [num_experts, hidden_size, inter_size]
 *                dtype: Same as input (unquantized) or int8 (quantized)
- * - ffn1_bias: Optional bias for first FFN layer
+ * - up_gate_proj_bias: Optional bias for first FFN layer
 *              Shape: [num_experts, inter_size * 2]
 *              dtype: Same as input
- * - ffn1_scale: Quantization scales for first FFN layer
+ * - up_gate_proj_scale: Quantization scales for first FFN layer
 *               Shape: [num_experts, inter_size * 2]
 *               dtype: Same as input
- * - ffn2_scale: Quantization scales for second FFN layer
+ * - down_proj_scale: Quantization scales for second FFN layer
 *               Shape: [num_experts, hidden_size]
 *               dtype: Same as input
- * - ffn2_in_scale: Optional input scales for second FFN layer (w4a8 only)
+ * - down_proj_in_scale: Optional input scales for second FFN layer (w4a8 only)
 *                  dtype: float32
 * - expert_idx_per_token: Optional expert indices per token (w4a8 only)
 *                         Shape: [total_tokens]
@@ -434,7 +434,7 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
 * Outputs:
 * - output_tensor: Output tensor after MoE FFN computation
 *                  Shape: Same as permute_input
- *                  dtype: Same as input (or ffn1_scale dtype for w4a8)
+ *                  dtype: Same as input (or up_gate_proj_scale dtype for w4a8)
 *
 * Attributes:
 * - quant_method: Quantization method to use
@@ -449,12 +449,12 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
 PD_BUILD_STATIC_OP(moe_expert_ffn)
     .Inputs({"permute_input",
              "tokens_expert_prefix_sum",
-             "ffn1_weight",
-             "ffn2_weight",
-             paddle::Optional("ffn1_bias"),
-             paddle::Optional("ffn1_scale"),
-             paddle::Optional("ffn2_scale"),
-             paddle::Optional("ffn2_in_scale"),
+             "up_gate_proj_weight",
+             "down_proj_weight",
+             paddle::Optional("up_gate_proj_bias"),
+             paddle::Optional("up_gate_proj_scale"),
+             paddle::Optional("down_proj_scale"),
+             paddle::Optional("down_proj_in_scale"),
             paddle::Optional("expert_idx_per_token")})
     .Outputs({"output_tensor"})
     .Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool"})
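The doc comment above explains the shape bookkeeping: the first projection's output packs the gate and up halves (hence the * 2 in its weight shape), the SwiGLU step halves that width, and down_proj maps the result back to hidden_size, which is why the second GEMM's reduction dimension appears as inter_size / 2 in the kernels. A scalar CPU sketch of the activation step; whether the gate and up halves are interleaved or stored contiguously is layout-dependent, and this sketch assumes contiguous halves:

#include <cmath>
#include <cstdio>
#include <vector>

// SwiGLU over a packed [gate | up] activation row:
// out[i] = silu(gate[i]) * up[i], halving the row width.
std::vector<float> swiglu(const std::vector<float>& gate_up) {
  const size_t inter = gate_up.size() / 2;
  std::vector<float> out(inter);
  for (size_t i = 0; i < inter; ++i) {
    float g = gate_up[i];          // gate half
    float u = gate_up[inter + i];  // up half
    float silu = g / (1.f + std::exp(-g));
    out[i] = silu * u;
  }
  return out;
}

int main() {
  std::vector<float> fc1_out = {1.f, -2.f, 0.5f, 3.f};  // width 4 -> 2
  auto act = swiglu(fc1_out);
  std::printf("%f %f\n", act[0], act[1]);
  return 0;
}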
@@ -23,17 +23,17 @@
|
||||
template <typename DataT, typename NvType, typename WeightSavedT, cutlass::WintQuantMethod QuantMethod>
|
||||
void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::Tensor* ffn1_bias,
|
||||
const paddle::Tensor* ffn1_super_scale,
|
||||
const paddle::Tensor* ffn2_super_scale,
|
||||
const paddle::Tensor* ffn1_local_scale,
|
||||
const paddle::Tensor* ffn1_code_scale,
|
||||
const paddle::Tensor* ffn1_code_zp,
|
||||
const paddle::Tensor* ffn2_local_scale,
|
||||
const paddle::Tensor* ffn2_code_scale,
|
||||
const paddle::Tensor* ffn2_code_zp,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::Tensor* up_gate_proj_bias,
|
||||
const paddle::Tensor* up_gate_proj_super_scale,
|
||||
const paddle::Tensor* down_proj_super_scale,
|
||||
const paddle::Tensor* up_gate_proj_local_scale,
|
||||
const paddle::Tensor* up_gate_proj_code_scale,
|
||||
const paddle::Tensor* up_gate_proj_code_zp,
|
||||
const paddle::Tensor* down_proj_local_scale,
|
||||
const paddle::Tensor* down_proj_code_scale,
|
||||
const paddle::Tensor* down_proj_code_zp,
|
||||
paddle::Tensor fc1_out,
|
||||
paddle::Tensor ffn_out,
|
||||
const int64_t total_rows_in_ll_else_minus1,
|
||||
@@ -46,15 +46,15 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
using WeightOnlyTraits = cutlass::WintQuantTraits<NvType, QuantMethod>;
|
||||
using WeightType = typename WeightOnlyTraits::WeightType;
|
||||
|
||||
typename WeightOnlyTraits::Arguments ffn1_quant_args;
|
||||
typename WeightOnlyTraits::Arguments ffn2_quant_args;
|
||||
typename WeightOnlyTraits::Arguments up_gate_proj_quant_args;
|
||||
typename WeightOnlyTraits::Arguments down_proj_quant_args;
|
||||
if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
|
||||
ffn1_quant_args.local_scale_ptr = ffn1_local_scale->data<uint8_t>();
|
||||
ffn1_quant_args.code_scale_ptr = ffn1_code_scale->data<float>();
|
||||
ffn1_quant_args.code_zp_ptr = ffn1_code_zp->data<float>();
|
||||
ffn2_quant_args.local_scale_ptr = ffn2_local_scale->data<uint8_t>();
|
||||
ffn2_quant_args.code_scale_ptr = ffn2_code_scale->data<float>();
|
||||
ffn2_quant_args.code_zp_ptr = ffn2_code_zp->data<float>();
|
||||
up_gate_proj_quant_args.local_scale_ptr = up_gate_proj_local_scale->data<uint8_t>();
|
||||
up_gate_proj_quant_args.code_scale_ptr = up_gate_proj_code_scale->data<float>();
|
||||
up_gate_proj_quant_args.code_zp_ptr = up_gate_proj_code_zp->data<float>();
|
||||
down_proj_quant_args.local_scale_ptr = down_proj_local_scale->data<uint8_t>();
|
||||
down_proj_quant_args.code_scale_ptr = down_proj_code_scale->data<float>();
|
||||
down_proj_quant_args.code_zp_ptr = down_proj_code_zp->data<float>();
|
||||
}
|
||||
|
||||
auto moe_gemm_runner = MoeGemmRunner<NvType, WeightOnlyTraits>();
|
||||
@@ -62,9 +62,9 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
|
||||
moe_gemm_runner.moe_gemm_bias_act(
|
||||
reinterpret_cast<const NvType*>(permute_input.data<DataT>()),
|
||||
reinterpret_cast<const WeightType*>(ffn1_weight.data<WeightSavedT>()),
|
||||
reinterpret_cast<const NvType*>(ffn1_super_scale ? ffn1_super_scale->data<DataT>() : nullptr),
|
||||
reinterpret_cast<const NvType*>(ffn1_bias ? ffn1_bias->data<DataT>() : nullptr),
|
||||
reinterpret_cast<const WeightType*>(up_gate_proj_weight.data<WeightSavedT>()),
|
||||
reinterpret_cast<const NvType*>(up_gate_proj_super_scale ? up_gate_proj_super_scale->data<DataT>() : nullptr),
|
||||
reinterpret_cast<const NvType*>(up_gate_proj_bias ? up_gate_proj_bias->data<DataT>() : nullptr),
|
||||
reinterpret_cast<NvType*>(fc1_out.data<DataT>()),
|
||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||
total_rows_in_ll_else_minus1,
|
||||
@@ -72,7 +72,7 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
ffn1_quant_args,
|
||||
up_gate_proj_quant_args,
|
||||
"none",
|
||||
stream);
|
||||
|
||||
@@ -85,8 +85,8 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
|
||||
moe_gemm_runner.moe_gemm(
|
||||
reinterpret_cast<const NvType*>(act_out.data<DataT>()),
|
||||
reinterpret_cast<const WeightType*>(ffn2_weight.data<WeightSavedT>()),
|
||||
reinterpret_cast<const NvType*>(ffn2_super_scale ? ffn2_super_scale->data<DataT>() : nullptr),
|
||||
reinterpret_cast<const WeightType*>(down_proj_weight.data<WeightSavedT>()),
|
||||
reinterpret_cast<const NvType*>(down_proj_super_scale ? down_proj_super_scale->data<DataT>() : nullptr),
|
||||
reinterpret_cast<NvType*>(ffn_out.data<DataT>()),
|
||||
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
|
||||
total_rows_in_ll_else_minus1,
|
||||
@@ -94,24 +94,24 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
num_experts,
|
||||
ffn2_quant_args,
|
||||
down_proj_quant_args,
|
||||
stream);
|
||||
}
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
|
||||
paddle::Tensor ffn_out,
|
||||
bool used_in_ep_low_latency) {
|
||||
using namespace phi;
|
||||
@@ -121,12 +121,12 @@ void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
|
||||
auto place = permute_input.place();
|
||||
|
||||
assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
|
||||
assert(ffn1_weight.dims().size() == 3);
|
||||
assert(up_gate_proj_weight.dims().size() == 3);
|
||||
|
||||
const int num_experts = ffn1_weight.dims()[0];
|
||||
const int num_experts = up_gate_proj_weight.dims()[0];
|
||||
const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];
|
||||
|
||||
int inter_dim = ffn1_weight.dims()[1] * ffn1_weight.dims()[2] / hidden_size;
|
||||
int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size;
|
||||
|
||||
const int64_t inter_size = inter_dim * 4;
|
||||
|
||||
@@ -160,17 +160,17 @@ void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
|
||||
WeightOnlyMoeFFNKernel<data_t, NvType, uint8_t, cutlass::WintQuantMethod::kWeightOnlyInt2>(
|
||||
permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
ffn1_weight,
|
||||
ffn2_weight,
|
||||
const_cast<paddle::Tensor*>(ffn1_bias.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(ffn1_local_scale.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(ffn1_code_scale.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(ffn1_code_zp.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(ffn2_local_scale.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(ffn2_code_scale.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(ffn2_code_zp.get_ptr()),
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight,
|
||||
const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(up_gate_proj_local_scale.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(up_gate_proj_code_scale.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(up_gate_proj_code_zp.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(down_proj_local_scale.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(down_proj_code_scale.get_ptr()),
|
||||
const_cast<paddle::Tensor*>(down_proj_code_zp.get_ptr()),
|
||||
fc1_out_tensor,
|
||||
ffn_out,
|
||||
total_rows_in_ll_else_minus1,
|
||||
@@ -184,17 +184,17 @@ void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
|
||||
paddle::Tensor MoeExpertFFNWint2Func(
|
||||
const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
|
||||
const bool used_in_ep_low_latency) {
|
||||
|
||||
const auto dtype = permute_input.dtype();
|
||||
@@ -204,34 +204,34 @@ paddle::Tensor MoeExpertFFNWint2Func(
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeFFNWint2Kernel<paddle::DataType::BFLOAT16>(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
ffn1_weight,
|
||||
ffn2_weight,
|
||||
ffn1_bias,
|
||||
ffn1_scale,
|
||||
ffn2_scale,
|
||||
ffn1_local_scale,
|
||||
ffn1_code_scale,
|
||||
ffn1_code_zp,
|
||||
ffn2_local_scale,
|
||||
ffn2_code_scale,
|
||||
ffn2_code_zp,
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight,
|
||||
up_gate_proj_bias,
|
||||
up_gate_proj_scale,
|
||||
down_proj_scale,
|
||||
up_gate_proj_local_scale,
|
||||
up_gate_proj_code_scale,
|
||||
up_gate_proj_code_zp,
|
||||
down_proj_local_scale,
|
||||
down_proj_code_scale,
|
||||
down_proj_code_zp,
|
||||
ffn_out,
|
||||
used_in_ep_low_latency);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MoeFFNWint2Kernel<paddle::DataType::FLOAT16>(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
ffn1_weight,
|
||||
ffn2_weight,
|
||||
ffn1_bias,
|
||||
ffn1_scale,
|
||||
ffn2_scale,
|
||||
ffn1_local_scale,
|
||||
ffn1_code_scale,
|
||||
ffn1_code_zp,
|
||||
ffn2_local_scale,
|
||||
ffn2_code_scale,
|
||||
ffn2_code_zp,
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight,
|
||||
up_gate_proj_bias,
|
||||
up_gate_proj_scale,
|
||||
down_proj_scale,
|
||||
up_gate_proj_local_scale,
|
||||
up_gate_proj_code_scale,
|
||||
up_gate_proj_code_zp,
|
||||
down_proj_local_scale,
|
||||
down_proj_code_scale,
|
||||
down_proj_code_zp,
|
||||
ffn_out,
|
||||
used_in_ep_low_latency);
|
||||
break;
|
||||
@@ -244,49 +244,49 @@ paddle::Tensor MoeExpertFFNWint2Func(
|
||||
std::vector<paddle::Tensor> MoeExpertFFNWint2(
|
||||
const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
|
||||
const bool used_in_ep_low_latency) {
|
||||
|
||||
return {MoeExpertFFNWint2Func(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
ffn1_weight,
|
||||
ffn2_weight,
|
||||
ffn1_bias,
|
||||
ffn1_scale,
|
||||
ffn2_scale,
|
||||
ffn1_local_scale,
|
||||
ffn1_code_scale,
|
||||
ffn1_code_zp,
|
||||
ffn2_local_scale,
|
||||
ffn2_code_scale,
|
||||
ffn2_code_zp,
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight,
|
||||
up_gate_proj_bias,
|
||||
up_gate_proj_scale,
|
||||
down_proj_scale,
|
||||
up_gate_proj_local_scale,
|
||||
up_gate_proj_code_scale,
|
||||
up_gate_proj_code_zp,
|
||||
down_proj_local_scale,
|
||||
down_proj_code_scale,
|
||||
down_proj_code_zp,
|
||||
used_in_ep_low_latency)};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertFFNWint2InferShape(
|
||||
const std::vector<int64_t>& permute_input_shape,
|
||||
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
|
||||
const std::vector<int64_t>& ffn1_weight_shape,
|
||||
const std::vector<int64_t>& ffn2_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_local_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_code_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_code_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_local_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_code_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_code_zp_shape,
|
||||
const std::vector<int64_t>& up_gate_proj_weight_shape,
|
||||
const std::vector<int64_t>& down_proj_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_local_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_code_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_code_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& down_proj_local_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& down_proj_code_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& down_proj_code_zp_shape,
|
||||
const bool used_in_ep_low_latency) {
|
||||
|
||||
return {permute_input_shape};
|
||||
@@ -295,17 +295,17 @@ std::vector<std::vector<int64_t>> MoeExpertFFNWint2InferShape(
|
||||
std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
|
||||
const paddle::DataType &permute_input_dtype,
|
||||
const paddle::DataType &tokens_expert_prefix_sum_dtype,
|
||||
const paddle::DataType &ffn1_weight_dtype,
|
||||
const paddle::DataType &ffn2_weight_dtype,
|
||||
const paddle::optional<paddle::DataType> &ffn1_bias_dtype,
|
||||
const paddle::optional<paddle::DataType> &ffn1_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &ffn2_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &ffn1_local_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &ffn1_code_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &ffn1_code_zp_dtype,
|
||||
const paddle::optional<paddle::DataType> &ffn2_local_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &ffn2_code_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &ffn2_code_zp_dtype,
|
||||
const paddle::DataType &up_gate_proj_weight_dtype,
|
||||
const paddle::DataType &down_proj_weight_dtype,
|
||||
const paddle::optional<paddle::DataType> &up_gate_proj_bias_dtype,
|
||||
const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &up_gate_proj_local_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &up_gate_proj_code_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &up_gate_proj_code_zp_dtype,
|
||||
const paddle::optional<paddle::DataType> &down_proj_local_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &down_proj_code_scale_dtype,
|
||||
const paddle::optional<paddle::DataType> &down_proj_code_zp_dtype,
|
||||
const bool used_in_ep_low_latency) {
|
||||
|
return {permute_input_dtype};

@@ -315,9 +315,9 @@ std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
 * @brief Weight-Only Quantized Mixture of Experts (MoE) Feed-Forward Network Operator
 *
 * This operator performs the expert computation in MoE architecture, including:
 * 1. First linear transformation (FFN1) with optional quantization
 * 1. First linear transformation (up_gate_proj) with optional quantization
 * 2. SwiGLU activation function
 * 3. Second linear transformation (FFN2) with optional quantization
 * 3. Second linear transformation (down_proj) with optional quantization
 *
 * Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization.
 *
@@ -328,26 +328,26 @@ std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
 * - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
 *                             Shape: [num_experts]
 *                             dtype: int64
 * - ffn1_weight: First FFN layer weights
 * - up_gate_proj_weight: First FFN layer weights
 *                Shape: [num_experts, inter_size * 2, hidden_size]
 *                dtype: Same as input (unquantized) or int8 (quantized)
 * - ffn2_weight: Second FFN layer weights
 * - down_proj_weight: Second FFN layer weights
 *                Shape: [num_experts, hidden_size, inter_size]
 *                dtype: Same as input (unquantized) or int8 (quantized)
 * - ffn1_bias: Optional bias for first FFN layer
 * - up_gate_proj_bias: Optional bias for first FFN layer
 *              Shape: [num_experts, inter_size * 2]
 *              dtype: Same as input
 * - ffn1_scale: Quantization scales for first FFN layer
 * - up_gate_proj_scale: Quantization scales for first FFN layer
 *               Shape: [num_experts, inter_size * 2]
 *               dtype: Same as input
 * - ffn2_scale: Quantization scales for second FFN layer
 * - down_proj_scale: Quantization scales for second FFN layer
 *               Shape: [num_experts, hidden_size]
 *               dtype: Same as input
 *
 * Outputs:
 * - output_tensor: Output tensor after MoE FFN computation
 *                  Shape: Same as permute_input
 *                  dtype: Same as input (or ffn1_scale dtype for w4a8)
 *                  dtype: Same as input (or up_gate_proj_scale dtype for w4a8)
 *
 * Attributes:
 * - used_in_ep_low_latency: Whether running in low latency mode
@@ -359,17 +359,17 @@ std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
PD_BUILD_STATIC_OP(moe_expert_ffn_wint2)
    .Inputs({"permute_input",
             "tokens_expert_prefix_sum",
             "ffn1_weight",
             "ffn2_weight",
             paddle::Optional("ffn1_bias"),
             paddle::Optional("ffn1_scale"),
             paddle::Optional("ffn2_scale"),
             paddle::Optional("ffn1_local_scale"),
             paddle::Optional("ffn1_code_scale"),
             paddle::Optional("ffn1_code_zp"),
             paddle::Optional("ffn2_local_scale"),
             paddle::Optional("ffn2_code_scale"),
             paddle::Optional("ffn2_code_zp")})
             "up_gate_proj_weight",
             "down_proj_weight",
             paddle::Optional("up_gate_proj_bias"),
             paddle::Optional("up_gate_proj_scale"),
             paddle::Optional("down_proj_scale"),
             paddle::Optional("up_gate_proj_local_scale"),
             paddle::Optional("up_gate_proj_code_scale"),
             paddle::Optional("up_gate_proj_code_zp"),
             paddle::Optional("down_proj_local_scale"),
             paddle::Optional("down_proj_code_scale"),
             paddle::Optional("down_proj_code_zp")})
    .Outputs({"output_tensor"})
    .Attrs({"used_in_ep_low_latency:bool"})
    .SetKernelFn(PD_KERNEL(MoeExpertFFNWint2))

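For orientation on the renaming above: up_gate_proj produces the packed gate/up halves, SwiGLU combines them, and down_proj maps back to hidden size. A minimal unquantized C++ reference for one token through one expert; expert_ffn_reference and the [gate | up] packing order are illustrative assumptions, not taken from the kernel:

#include <cmath>
#include <vector>

// One token through one expert (unquantized path).
// up_gate_w packs [gate | up] rows; SwiGLU is silu(gate) * up.
std::vector<float> expert_ffn_reference(const std::vector<float>& x,          // [hidden]
                                        const std::vector<float>& up_gate_w,  // [2 * inter, hidden]
                                        const std::vector<float>& down_w,     // [hidden, inter]
                                        int hidden, int inter) {
  std::vector<float> act(inter);
  for (int i = 0; i < inter; ++i) {
    float gate = 0.f, up = 0.f;
    for (int h = 0; h < hidden; ++h) {
      gate += up_gate_w[i * hidden + h] * x[h];            // gate half
      up += up_gate_w[(inter + i) * hidden + h] * x[h];    // up half
    }
    act[i] = gate / (1.f + std::exp(-gate)) * up;          // SwiGLU activation
  }
  std::vector<float> y(hidden, 0.f);
  for (int h = 0; h < hidden; ++h)
    for (int i = 0; i < inter; ++i)
      y[h] += down_w[h * inter + i] * act[i];              // down_proj
  return y;
}
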
@@ -25,7 +25,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
                     const paddle::Tensor &top_k_weight,
                     const paddle::Tensor &permute_indices_per_token,
                     const paddle::Tensor &top_k_indices,
                     const paddle::optional<paddle::Tensor> &ffn2_bias,
                     const paddle::optional<paddle::Tensor> &down_proj_bias,
                     const bool norm_topk_prob,
                     const float routed_scaling_factor, const int num_rows,
                     const int hidden_size, const int topk,
@@ -38,7 +38,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,

  finalize_moe_routing_kernelLauncher<data_t>::run(
      ffn_out.data<data_t>(), output->data<data_t>(),
      ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
      down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
      top_k_weight.data<float>(), permute_indices_per_token.data<int32_t>(),
      top_k_indices.data<int>(), num_rows, hidden_size, topk,
      static_cast<int>(1), norm_topk_prob, routed_scaling_factor, stream);
@@ -48,7 +48,7 @@ paddle::Tensor MoeExpertReduceFunc(
    const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
    const paddle::Tensor &permute_indices_per_token,
    const paddle::Tensor &top_k_indices,
    const paddle::optional<paddle::Tensor> &ffn2_bias,
    const paddle::optional<paddle::Tensor> &down_proj_bias,
    const bool norm_topk_prob, const float routed_scaling_factor) {
  const auto input_type = ffn_out.dtype();
  auto place = ffn_out.place();
@@ -63,13 +63,13 @@ paddle::Tensor MoeExpertReduceFunc(
  case paddle::DataType::BFLOAT16:
    MoeReduceKernel<paddle::DataType::BFLOAT16>(
        ffn_out, top_k_weight, permute_indices_per_token, top_k_indices,
        ffn2_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
        down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
        topk, &output);
    break;
  case paddle::DataType::FLOAT16:
    MoeReduceKernel<paddle::DataType::BFLOAT16>(
        ffn_out, top_k_weight, permute_indices_per_token, top_k_indices,
        ffn2_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
        down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
        topk, &output);
    break;
  default:
@@ -83,10 +83,10 @@ MoeExpertReduce(const paddle::Tensor &ffn_out,
                const paddle::Tensor &top_k_weight,
                const paddle::Tensor &permute_indices_per_token,
                const paddle::Tensor &top_k_indices,
                const paddle::optional<paddle::Tensor> &ffn2_bias,
                const paddle::optional<paddle::Tensor> &down_proj_bias,
                const bool norm_topk_prob, const float routed_scaling_factor) {
  return {MoeExpertReduceFunc(ffn_out, top_k_weight, permute_indices_per_token,
                              top_k_indices, ffn2_bias, norm_topk_prob,
                              top_k_indices, down_proj_bias, norm_topk_prob,
                              routed_scaling_factor)};
}

@@ -95,7 +95,7 @@ std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
    const std::vector<int64_t> &top_k_weight_shape,
    const std::vector<int64_t> &permute_indices_per_token_shape,
    const std::vector<int64_t> &top_k_indices_shape,
    const paddle::optional<std::vector<int64_t>> &ffn2_bias_shape) {
    const paddle::optional<std::vector<int64_t>> &down_proj_bias_shape) {
  const int moe_topk = top_k_indices_shape[1];
  auto out_shape = ffn_out_shape;
  if (out_shape[0] != -1) out_shape[0] /= moe_topk;
@@ -107,19 +107,19 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
    const paddle::DataType &top_k_weight_dtype,
    const paddle::DataType &permute_indices_per_token_dtype,
    const paddle::DataType &top_k_indices_dtype,
    const paddle::optional<paddle::DataType> &ffn2_bias_dtype) {
    const paddle::optional<paddle::DataType> &down_proj_bias_dtype) {
  return {ffn_out_dtype};
}

/**
 * @brief Mixture of Experts (MoE) Expert Reduce Operator
 *
 * This operator performs the following key functions:
 * 1. Combines outputs from multiple experts based on routing weights
 * 2. Applies optional bias and scaling to the combined output
 * 3. Restores the original token order from permuted expert outputs
 *
 * Inputs:
 * - ffn_out: Outputs from all expert networks (permuted)
 *            Shape: [total_tokens * moe_topk, hidden_size]
@@ -133,19 +133,19 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
 * - top_k_indices: Indices of selected top-k experts for each token
 *                  Shape: [total_tokens, moe_topk]
 *                  dtype: int32
 * - ffn2_bias: Optional bias term for expert outputs (hidden_size)
 * - down_proj_bias: Optional bias term for expert outputs (hidden_size)
 *
 * Outputs:
 * - output: Combined expert outputs in original token order
 *           Shape: [total_tokens, hidden_size]
 *           dtype: Same as ffn_out
 *
 * Attributes:
 * - norm_topk_prob: Whether to normalize top-k probabilities
 *                   (true: weights sum to 1 for each token,
 *                    false: use raw weights)
 * - routed_scaling_factor: Scaling factor applied to top-k probabilities
 *
 * Note:
 * - The operator expects permuted expert outputs from moe_expert_dispatch
 * - When norm_topk_prob is true, weights are normalized per token
@@ -154,7 +154,7 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
 */
PD_BUILD_STATIC_OP(moe_expert_reduce)
    .Inputs({"ffn_out", "top_k_weight", "permute_indices_per_token",
             "top_k_indices", paddle::Optional("ffn2_bias")})
             "top_k_indices", paddle::Optional("down_proj_bias")})
    .Outputs({"output"})
    .Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
    .SetKernelFn(PD_KERNEL(MoeExpertReduce))

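The combine step this operator documents reduces, for each token, its top-k permuted expert rows into one output row. A plain C++ sketch of that reduction; moe_reduce_reference and the per-expert bias placement are illustrative assumptions, not the finalize_moe_routing kernel itself:

#include <cstddef>

// Per-token reduction over top-k expert outputs, mirroring the doc above.
void moe_reduce_reference(const float* ffn_out,        // [tokens * topk, hidden]
                          const float* topk_weight,    // [tokens, topk]
                          const int* permuted_row,     // [tokens, topk] -> row in ffn_out
                          const float* down_proj_bias, // [hidden] or nullptr
                          bool norm_topk_prob, float routed_scaling_factor,
                          int tokens, int topk, int hidden, float* out) {
  for (int t = 0; t < tokens; ++t) {
    float wsum = 0.f;
    for (int k = 0; k < topk; ++k) wsum += topk_weight[t * topk + k];
    for (int h = 0; h < hidden; ++h) {
      float acc = 0.f;
      for (int k = 0; k < topk; ++k) {
        float w = topk_weight[t * topk + k];
        if (norm_topk_prob && wsum > 0.f) w /= wsum;  // weights sum to 1 per token
        w *= routed_scaling_factor;                   // extra routed scaling
        const float* row = ffn_out + (size_t)permuted_row[t * topk + k] * hidden;
        acc += w * (row[h] + (down_proj_bias ? down_proj_bias[h] : 0.f));
      }
      out[t * hidden + h] = acc;  // original token order restored via permuted_row
    }
  }
}
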
@@ -26,7 +26,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& encoder_batch_ids,
    const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -99,7 +98,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
        seq_lens_encoder,
        cu_seqlens_q,
        padding_offsets,
        cum_offsets,
        block_tables,
        decoder_batch_ids,
        decoder_tile_ids_per_batch,
@@ -128,7 +126,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
        seq_lens_this_time,  // q_seq_len is 1
        seq_lens_decoder,
        padding_offsets,
        cum_offsets,
        cu_seqlens_q,
        block_tables,
        max_input_length,
        max_len_kv_data,
@@ -151,7 +149,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& padding_offsets,
    const paddle::Tensor& cum_offsets,
    const paddle::Tensor& block_tables,
    const paddle::Tensor& encoder_batch_ids,
    const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -201,7 +198,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(

  meta_data.max_blocks_per_seq = block_tables.dims()[1];
  meta_data.block_size = key_cache.dims()[2];
  meta_data.batch_size = cum_offsets.dims()[0];
  meta_data.batch_size = seq_lens_this_time.dims()[0];

  switch (query.dtype()) {
    case paddle::DataType::BFLOAT16: {
@@ -215,7 +212,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
          seq_lens_this_time,
          cu_seqlens_q,
          padding_offsets,
          cum_offsets,
          block_tables,
          encoder_batch_ids,
          encoder_tile_ids_per_batch,
@@ -262,7 +258,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
          seq_lens_this_time,
          cu_seqlens_q,
          padding_offsets,
          cum_offsets,
          block_tables,
          encoder_batch_ids,
          encoder_tile_ids_per_batch,
@@ -316,7 +311,6 @@ std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
    const std::vector<int64_t>& seq_lens_this_time_shape,
    const std::vector<int64_t>& cu_seqlens_q_shape,
    const std::vector<int64_t>& padding_offsets_shape,
    const std::vector<int64_t>& cum_offsets_shape,
    const std::vector<int64_t>& block_tables_shape,
    const std::vector<int64_t>& encoder_batch_ids_shape,
    const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
@@ -371,7 +365,6 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
    const paddle::DataType& seq_lens_this_time_dtype,
    const paddle::DataType& cu_seqlens_q_dtype,
    const paddle::DataType& padding_offsets_dtype,
    const paddle::DataType& cum_offsets_dtype,
    const paddle::DataType& block_tables_dtype,
    const paddle::DataType& encoder_batch_ids_dtype,
    const paddle::DataType& encoder_tile_ids_per_batch_dtype,
@@ -426,7 +419,6 @@ PD_BUILD_OP(multi_head_latent_attention)
            "seq_lens_this_time",
            "cu_seqlens_q",
            "padding_offsets",
            "cum_offsets",
            "block_tables",
            "encoder_batch_ids",
            "encoder_tile_ids_per_batch",

@@ -18,7 +18,6 @@
#include <algorithm>
#include <optional>

#include "helper.h"
#include "noauxtc_kernel.h"

std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,

@@ -17,11 +17,11 @@
#pragma once
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "helper.h"

namespace cg = cooperative_groups;

constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

@@ -91,7 +91,12 @@ std::vector<paddle::Tensor> rebuild_padding(
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

#ifdef PADDLE_WITH_CUSTOM_DEVICE
  auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(tmp_out.place()));
  auto cu_stream = dev_ctx->stream();
#else
  auto cu_stream = tmp_out.stream();
#endif
  std::vector<int64_t> tmp_out_shape = tmp_out.shape();
  const int token_num = tmp_out_shape[0];
  const int dim_embed = tmp_out_shape[1];
@@ -125,7 +130,7 @@ std::vector<paddle::Tensor> rebuild_padding(

  if (output_padding_offset) {
    RebuildAppendPaddingKernel<DataType_, PackSize>
        <<<grid_size, blocksize, 0, tmp_out.stream()>>>(
        <<<grid_size, blocksize, 0, cu_stream>>>(
            reinterpret_cast<DataType_ *>(out.data<data_t>()),
            reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
            cum_offsets.data<int>(),
@@ -138,7 +143,7 @@ std::vector<paddle::Tensor> rebuild_padding(
            elem_nums);
  } else {
    RebuildPaddingKernel<DataType_, PackSize>
        <<<grid_size, blocksize, 0, tmp_out.stream()>>>(
        <<<grid_size, blocksize, 0, cu_stream>>>(
            reinterpret_cast<DataType_ *>(out.data<data_t>()),
            reinterpret_cast<DataType_ *>(
                const_cast<data_t *>(tmp_out.data<data_t>())),

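The same #ifdef PADDLE_WITH_CUSTOM_DEVICE stream lookup recurs in several ops further down (SetValueByFlagsAndIdx, StepPaddle, GetStopFlagsMulti, GetStopFlagsMultiSeqs). A sketch of how it could be factored into one helper; get_kernel_stream is a hypothetical name, not part of this patch:

// Hypothetical helper factoring out the repeated stream lookup.
// On custom devices the stream must come from the device context pool;
// on CUDA/HIP builds the tensor itself carries its stream.
static inline auto get_kernel_stream(const paddle::Tensor& t) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  auto* dev_ctx = static_cast<const phi::CustomContext*>(
      paddle::experimental::DeviceContextPool::Instance().Get(t.place()));
  return dev_ctx->stream();
#else
  return t.stream();
#endif
}

// Usage at each call site would then collapse to one line:
//   auto cu_stream = get_kernel_stream(tmp_out);
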
@@ -376,7 +376,6 @@ __global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
  }

  // scan/find
  constexpr int WARP_SIZE = 32;
  constexpr int WARP_COUNT = NumBuckets / WARP_SIZE;
  namespace cg = cooperative_groups;
  cg::thread_block block = cg::this_thread_block();

@@ -289,7 +289,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
  curand_init(philox_seed, bx, philox_offset, &state);
  const uint32_t row_idx = bx;
  const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
  const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];
  const float p = top_p_arr[row_idx];

  extern __shared__ __align__(
      alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))

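For context on the two thresholds above: k == 0 is treated as "no top-k cap" (k becomes the vocab size d), and the old p == 0 -> 1e-6 clamp is dropped so p is used as-is. A host-side C++ sketch of the joint filter these thresholds describe; topk_topp_filter is an illustrative helper, not the CUDA kernel:

#include <algorithm>
#include <cstdint>
#include <vector>

// Keep the k most probable tokens, then the smallest prefix of them whose
// cumulative mass reaches p; everything else is zeroed before sampling.
std::vector<float> topk_topp_filter(const std::vector<float>& probs,
                                    uint32_t k, float p) {
  std::vector<size_t> order(probs.size());
  for (size_t i = 0; i < order.size(); ++i) order[i] = i;
  std::sort(order.begin(), order.end(),
            [&](size_t a, size_t b) { return probs[a] > probs[b]; });
  std::vector<float> kept(probs.size(), 0.f);
  float mass = 0.f;
  for (size_t r = 0; r < order.size() && r < k; ++r) {
    kept[order[r]] = probs[order[r]];
    mass += probs[order[r]];
    if (mass >= p) break;  // top-p cutoff inside the top-k set
  }
  return kept;  // renormalize, then sample
}
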
@@ -91,7 +91,12 @@ void set_data_ipc(const paddle::Tensor& tmp_input,
  memset((void *)shm, 0, sizeof(*shm));

  void *data_ptr_now = reinterpret_cast<void*>(const_cast<data_t*>(tmp_input.data<data_t>()));
#ifdef PADDLE_WITH_HIP
  checkCudaErrors(hipIpcGetMemHandle((hipIpcMemHandle_t *)&shm->memHandle, data_ptr_now));
#else
  checkCudaErrors(cudaIpcGetMemHandle((cudaIpcMemHandle_t *)&shm->memHandle, data_ptr_now));
#endif

}

@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/extension.h"
#include "helper.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -51,13 +52,18 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
                           const paddle::Tensor &seq_lens_decoder,
                           const paddle::Tensor &step_idx,
                           const paddle::Tensor &stop_flags) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(stop_flags.place()));
  auto cu_stream = dev_ctx->stream();
#else
  auto cu_stream = stop_flags.stream();
#endif
  std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();

  int bs = seq_lens_this_time.shape()[0];
  int length = pre_ids_all_shape[1];
  int length_input_ids = input_ids.shape()[1];
  int block_size = (bs + 32 - 1) / 32 * 32;
  int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
  set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(
      stop_flags.data<bool>(),
      const_cast<int64_t *>(pre_ids_all.data<int64_t>()),

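The block_size expression that now uses WARP_SIZE (from helper.h) is the standard round-up-to-a-multiple idiom: it pads the batch size up so the single-block launch covers every sequence with whole warps. A compilable illustration of the arithmetic:

#include <cstdio>

constexpr int WARP_SIZE = 32;

// Round n up to the next multiple of WARP_SIZE: the (n + m - 1) / m * m trick.
constexpr int round_up_to_warp(int n) {
  return (n + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
}

int main() {
  printf("%d %d %d\n",
         round_up_to_warp(1),    // 32
         round_up_to_warp(32),   // 32
         round_up_to_warp(33));  // 64
  return 0;
}
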
@@ -37,10 +37,18 @@ std::vector<paddle::Tensor> ShareExternalData(paddle::Tensor& input,
  }
  shm = (volatile shmStruct *)info.addr;
  void *ptr = nullptr;
#ifdef PADDLE_WITH_HIP
  checkCudaErrors(
      hipIpcOpenMemHandle(&ptr,
                          *(hipIpcMemHandle_t *)&shm->memHandle,  // NOLINT
                          hipIpcMemLazyEnablePeerAccess));
#else
  checkCudaErrors(
      cudaIpcOpenMemHandle(&ptr,
                           *(cudaIpcMemHandle_t *)&shm->memHandle,  // NOLINT
                           cudaIpcMemLazyEnablePeerAccess));
#endif

  paddle::Tensor tmp_tensor = paddle::from_blob(
      ptr,
      shape,

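Taken together with set_data_ipc above, this is the standard CUDA IPC handshake: one process exports a device pointer as a memory handle stored in POSIX shared memory, another reopens it and wraps it as a tensor. A sketch of the two halves as standalone functions; export_buffer and import_buffer are illustrative names, and cudaIpcOpenMemHandle must run in a different process from the exporter:

#include <cuda_runtime.h>
#include <cstddef>

// Producer process: allocate and derive a handle (what set_data_ipc stores in shm).
cudaIpcMemHandle_t export_buffer(float** d_buf, size_t n) {
  cudaMalloc(d_buf, n * sizeof(float));
  cudaIpcMemHandle_t handle;
  cudaIpcGetMemHandle(&handle, *d_buf);
  return handle;
}

// Consumer process: map the same device memory (what ShareExternalData wraps
// via paddle::from_blob). The handle bytes arrive via shared memory or a pipe.
void* import_buffer(const cudaIpcMemHandle_t& handle) {
  void* mapped = nullptr;
  cudaIpcOpenMemHandle(&mapped, handle, cudaIpcMemLazyEnablePeerAccess);
  return mapped;  // release later with cudaIpcCloseMemHandle(mapped)
}
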
@@ -246,7 +246,7 @@ void token_penalty_multi_scores_kernel(
      max_seq_len);
}

void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
                             const paddle::Tensor &logits,
                             const paddle::Tensor &penalty_scores,
                             const paddle::Tensor &frequency_scores,
@@ -338,4 +338,4 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
    .Outputs({"logits_out"})
    .Attrs({"max_seq_len: int"})
    .SetInplaceMap({{"logits", "logits_out"}})
    .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));
    .SetKernelFn(PD_KERNEL(SpecTokenPenaltyMultiScores));

@@ -73,7 +73,7 @@ __global__ void speculate_verify(
    const int *output_cum_offsets, const int *actual_candidate_len,
    const int real_bsz, const int max_draft_tokens, const int end_length,
    const int max_seq_len, const int max_candidate_len, const int verify_window,
    const bool prefill_one_step_stop) {
    const bool prefill_one_step_stop, const bool benchmark_mode) {
  const int bid = threadIdx.x;
  // verify and set stop flags
  int accept_num_now = 1;
@@ -95,6 +95,9 @@ __global__ void speculate_verify(
  // printf("seq_lens_this_time[%d]-1: %d \n",bid,
  // seq_lens_this_time[bid]-1);
  for (; i < seq_lens_this_time[bid] - 1; i++) {
    if (benchmark_mode) {
      break;
    }
    if (seq_lens_encoder[bid] != 0) {
      break;
    }
@@ -246,7 +249,7 @@ void SpeculateVerify(
    const paddle::Tensor &output_cum_offsets,
    const paddle::Tensor &actual_candidate_len,
    const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
    int max_seq_len, int verify_window, bool enable_topp) {
    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) {
  // printf("Enter speculate update\n");
  auto bsz = accept_tokens.shape()[0];
  int real_bsz = seq_lens_this_time.shape()[0];
@@ -263,18 +266,6 @@ void SpeculateVerify(
  seed++;
  offset++;

  auto err = cudaDeviceSynchronize();
  if (err != 0) {
    printf("err %d\n", err);
  }

  err = cudaGetLastError();

  if (err != 0) {
    printf("err %d\n", err);
  }

  // printf("inited curand\n");
  bool use_topk = false;
  char *env_var = getenv("SPECULATE_VERIFY_USE_TOPK");
  if (env_var) {
@@ -301,7 +292,7 @@ void SpeculateVerify(
          is_block_step.data<bool>(), output_cum_offsets.data<int>(),
          actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
          end_length, max_seq_len, max_candidate_len, verify_window,
          prefill_one_step_stop);
          prefill_one_step_stop, benchmark_mode);
    } else {
      speculate_verify<false, true>
          <<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -317,7 +308,7 @@ void SpeculateVerify(
          end_tokens.data<int64_t>(), is_block_step.data<bool>(),
          output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
          real_bsz, max_draft_tokens, end_length, max_seq_len,
          max_candidate_len, verify_window, prefill_one_step_stop);
          max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
    }
  } else {
    if (enable_topp) {
@@ -335,7 +326,7 @@ void SpeculateVerify(
          end_tokens.data<int64_t>(), is_block_step.data<bool>(),
          output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
          real_bsz, max_draft_tokens, end_length, max_seq_len,
          max_candidate_len, verify_window, prefill_one_step_stop);
          max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
    } else {
      speculate_verify<false, false>
          <<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -351,7 +342,7 @@ void SpeculateVerify(
          end_tokens.data<int64_t>(), is_block_step.data<bool>(),
          output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
          real_bsz, max_draft_tokens, end_length, max_seq_len,
          max_candidate_len, verify_window, prefill_one_step_stop);
          max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
    }
  }

@@ -366,7 +357,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
    "actual_candidate_len", "actual_draft_token_nums", "topp"})
    .Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
              "stop_flags_out"})
    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"})
    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"})
    .SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
                    {"accept_num", "accept_num_out"},
                    {"step_idx", "step_idx_out"},

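The new benchmark_mode attribute breaks out of the draft-token verification loop on its first iteration, so the accepted count stays pinned at the single guaranteed token regardless of draft quality, which gives a fixed-accept-rate baseline for benchmarking. An illustrative reduction of that control flow; count_accepted, draft_len, and token_matches are stand-ins, not kernel symbols:

// Condensed view of the verify loop's accept logic with the new flag.
int count_accepted(int draft_len, bool benchmark_mode, bool is_prefill,
                   const bool* token_matches) {
  int accept_num = 1;                  // the first target token is always accepted
  for (int i = 0; i < draft_len; i++) {
    if (benchmark_mode) break;         // pin accept_num to 1 when benchmarking
    if (is_prefill) break;             // prefill steps skip draft verification
    if (!token_matches[i]) break;      // first mismatch ends the accepted prefix
    accept_num++;
  }
  return accept_num;
}
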
@@ -189,7 +189,7 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
                                   ? tmp_used_len + 1
                                   : max_decoder_block_num_this_seq;
#ifdef DEBUG_STEP
        printf("#### ori_step_len:%d, ori_free_list_len:%d, used_len:%d \n",
               ori_step_len, ori_free_list_len, used_len);
#endif
        while (ori_step_len > 0 && ori_free_list_len >= used_len) {
@@ -323,7 +323,12 @@ void StepPaddle(const paddle::Tensor &stop_flags,
                const paddle::Tensor &first_token_ids,
                const int block_size,
                const int encoder_decoder_block_num) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
  auto cu_stream = dev_ctx->stream();
#else
  auto cu_stream = seq_lens_this_time.stream();
#endif
  const int bsz = seq_lens_this_time.shape()[0];
  const int block_num_per_seq = block_tables.shape()[1];
  const int length = input_ids.shape()[1];

@@ -74,11 +74,16 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
    }
  }

#ifdef PADDLE_WITH_CUSTOM_DEVICE
  auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
  auto cu_stream = dev_ctx->stream();
#else
  auto cu_stream = topk_ids.stream();
#endif
  std::vector<int64_t> shape = topk_ids.shape();
  int64_t bs_now = shape[0];
  int64_t end_length = end_ids.shape()[0];
  int block_size = (bs_now + 32 - 1) / 32 * 32;
  int block_size = (bs_now + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
  set_value_by_flags<<<1, block_size, 0, cu_stream>>>(
      const_cast<bool *>(stop_flags.data<bool>()),
      const_cast<int64_t *>(topk_ids.data<int64_t>()),

@@ -21,6 +21,7 @@
#include <sys/types.h>
#include <unistd.h>
#include "paddle/extension.h"
#include "helper.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -88,7 +89,12 @@ void GetStopFlagsMultiSeqs(const paddle::Tensor &topk_ids,
  PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
  PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);

#ifdef PADDLE_WITH_CUSTOM_DEVICE
  auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
  auto cu_stream = dev_ctx->stream();
#else
  auto cu_stream = topk_ids.stream();
#endif
  std::vector<int64_t> shape = topk_ids.shape();
  std::vector<int64_t> stop_seqs_shape = stop_seqs.shape();
  int bs_now = shape[0];
@@ -96,7 +102,7 @@ void GetStopFlagsMultiSeqs(const paddle::Tensor &topk_ids,
  int stop_seqs_max_len = stop_seqs_shape[1];
  int pre_ids_len = pre_ids.shape()[1];

  int block_size = (stop_seqs_bs + 31) / 32 * 32;
  int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
  set_value_by_stop_seqs<<<bs_now, block_size, 0, cu_stream>>>(
      const_cast<bool *>(stop_flags.data<bool>()),
      const_cast<int64_t *>(topk_ids.data<int64_t>()),

Some files were not shown because too many files have changed in this diff.