Compare commits


5 Commits

Author SHA1 Message Date
Jiang-Jia-Jun
3ec126dc02 Update setup.py 2025-07-15 14:57:40 +08:00
gaoziyuan
337d76f094 [sync fix] (#2759)
* add rl qwen model support

* fix

* fix

* add_commit_config

* fix
2025-07-08 19:29:23 +08:00
gaoziyuan
ae2f78184d 【Sync develop】 add commit info (#2755)
* add rl qwen model support

* fix

* fix

* add_commit_config
2025-07-08 17:02:50 +08:00
gaoziyuan
6851489425 【Sync】Release/2.0.1 (#2745)
* add rl qwen model support

* fix

* fix
2025-07-08 14:38:18 +08:00
Jiang-Jia-Jun
ea787d8f62 fix bug. (#2718) (#2720)
Co-authored-by: Ting <wtmlon@foxmail.com>
2025-07-05 09:00:01 +08:00
298 changed files with 7879 additions and 19321 deletions

View File

@@ -2,9 +2,7 @@ name: CI
on:
pull_request:
branches:
- develop
- 'release/*'
branches: [ develop ]
workflow_dispatch:
concurrency:
@@ -29,11 +27,9 @@ jobs:
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
-e "BASE_BRANCH=${BASE_BRANCH}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
@@ -42,7 +38,7 @@ jobs:
'
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
git clone ${REPO} ${REPO_NAME}
cd FastDeploy
if [ "${{ github.event_name }}" = "pull_request" ]; then
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}

View File

@@ -2,9 +2,7 @@ name: CI_XPU
on:
pull_request:
branches:
- develop
- 'release/*'
branches: [ develop ]
workflow_dispatch:
concurrency:
@@ -12,7 +10,7 @@ concurrency:
cancel-in-progress: true
jobs:
CI_XPU:
build:
runs-on: [self-hosted, XPU-P800-8Card]
steps:
- name: Print current runner name
@@ -29,11 +27,9 @@ jobs:
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
-e "BASE_BRANCH=${BASE_BRANCH}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
@@ -42,7 +38,7 @@ jobs:
'
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
git clone ${REPO} ${REPO_NAME}
cd FastDeploy
if [ "${{ github.event_name }}" = "pull_request" ]; then
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}

.gitignore vendored (2 lines changed)
View File

@@ -162,5 +162,3 @@ custom_ops/tmp*
build
.ccls-cache
third_party

View File

@@ -8,17 +8,14 @@
<a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
</p>
<p align="center">
<a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
<a href="https://paddlepaddle.github.io/FastDeploy/get_started/installation/nvidia_gpu/"><b> Installation </b></a>
|
<a href="https://paddlepaddle.github.io/FastDeploy/get_started/quick_start"><b> Quick Start </b></a>
|
<a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
<a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
</p>
--------------------------------------------------------------------------------

View File

@@ -105,30 +105,3 @@ python benchmark_serving.py \
--save-result > infer_log.txt 2>&1 &
```
### Speculative Decoding Performance Benchmark Tool
#### Usage:
```bash
python benchmarks/benchmark_mtp.py \
--host 127.0.0.1 --port 8000 \
--max-concurrency 16 32 64 96 --num-prompts 256 \
--acceptance-rate 0.8 --draft-token-steps 1 2 3 \
--s_itl-base-model 15.88 22.84 16.47 16.93 \
--dataset-name EBChat \
--dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
```
#### Parameter description
```bash
--host: server IP address, used to build the request URL
--port: server HTTP port, used to build the request URL
--max-concurrency: concurrency levels to benchmark
--num-prompts: total number of requests to send
--acceptance-rate: simulated acceptance rate for speculative decoding
--draft-token-steps: number of speculative decoding draft steps
--s_itl-base-model: decode inter-token latency of the base model, obtainable with the benchmark tool above; values correspond one-to-one with the batch sizes
--dataset-name: dataset class to use; set to "EBChat" to read a dumped FD-format dataset
--dataset-path: path to the test dataset
```
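For reference, the sketch below mirrors the speedup estimate implemented by the removed `benchmarks/benchmark_mtp.py` (its `calculate_speedup` function appears later in this diff): the acceptance rate and number of draft steps give an expected accepted-token ratio, which scales the measured inter-token latencies. Only the example MTP latency value is illustrative.

```python
def estimate_speedup(acceptance_rate: float, draft_token_steps: int,
                     s_itl_base_model: float, s_itl_mtp: float) -> float:
    """Estimated speedup of speculative decoding over the base model.

    Mirrors calculate_speedup() from the removed benchmarks/benchmark_mtp.py:
    sum the geometric acceptance over the draft steps, convert it to an
    expected accepted ratio r_ac, then compare per-token latencies.
    """
    accepted = sum(acceptance_rate ** (i + 1) for i in range(draft_token_steps))
    r_ac = accepted / (1 + accepted)
    return s_itl_base_model / ((1 - r_ac) * s_itl_mtp)


# Example: 0.8 acceptance, 2 draft steps, 15.88 ms base ITL,
# 10.0 ms MTP ITL (illustrative value only).
print(estimate_speedup(0.8, 2, 15.88, 10.0))
```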

View File

@@ -36,7 +36,6 @@ AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@dataclass
class RequestFuncInput:
"""Input for requesting LLMs via API"""
no: int
prompt: str
history_QA: Optional[dict]
hyper_parameters: dict
@@ -55,7 +54,6 @@ class RequestFuncInput:
@dataclass
class RequestFuncOutput:
"""Output for requesting LLMs via API"""
no: int = 0
generated_text: str = ""
reasoning_content: str = ""
success: bool = False
@@ -86,7 +84,7 @@ async def async_request_eb_openai_chat_completions(
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
payload = {
"model": request_func_input.model,
"model": "default",
"messages": request_func_input.history_QA,
"stream": True,
"stream_options": {
@@ -99,9 +97,6 @@ async def async_request_eb_openai_chat_completions(
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
print("payload:{}".format(json.dumps(payload, ensure_ascii=False)))
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
@@ -109,7 +104,6 @@ async def async_request_eb_openai_chat_completions(
output = RequestFuncOutput()
output.prompt_len = 0
output.no = request_func_input.no
ttft = 0.0
st = time.perf_counter()
@@ -138,8 +132,7 @@ async def async_request_eb_openai_chat_completions(
ttft = timestamp - st
output.ttft = ttft
# cached_tokens
output.prompt_len = data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
output.prompt_len = data["usage"]["prompt_tokens_details"]["cached_tokens"]
# Decoding phase
else:
@@ -148,12 +141,12 @@ async def async_request_eb_openai_chat_completions(
output.generated_text += content or ""
output.reasoning_content += reason_content or ""
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
elif usage := data.get("usage", {}):
output.arrival_time.append(choices[0].get("arrival_time"))
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens", 0)
"completion_tokens")
output.prompt_tokens = usage.get(
"prompt_tokens", 0)
"prompt_tokens")
most_recent_timestamp = timestamp
@@ -180,7 +173,6 @@ async def async_request_eb_openai_chat_completions(
f.write(str(output) + "\n")
if pbar:
pbar.update(1)
print("#####final_output:", output)
return output
@@ -197,7 +189,7 @@ async def async_request_eb_openai_completions(
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"model": request_func_input.model,
"model": "default",
"prompt": request_func_input.prompt,
"stream": True,
"stream_options": {
@@ -210,20 +202,14 @@ async def async_request_eb_openai_completions(
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
print("payload:", json.dumps(payload, ensure_ascii=False))
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"Content-Type": "application/json"
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
output.no = request_func_input.no
generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
@@ -240,7 +226,6 @@ async def async_request_eb_openai_completions(
"data: ")
if chunk != "[DONE]":
# print("####chunk:", chunk, chunk.usage)
timestamp = time.perf_counter()
data = json.loads(chunk)
# NOTE: Some completion API might have a last
@@ -250,22 +235,21 @@ async def async_request_eb_openai_completions(
# Note that text could be empty here
# e.g. for special tokens
text = choices[0].get("text")
timestamp = time.perf_counter()
# First token
if not first_chunk_received:
first_chunk_received = True
ttft = timestamp - st
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp -
most_recent_timestamp)
generated_text += text or ""
most_recent_timestamp = timestamp
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
output.arrival_time.append(choices[0].get("arrival_time"))
generated_text += text or ""
elif usage := data.get("usage"):
output.prompt_tokens = usage.get(
"prompt_tokens")
@@ -278,15 +262,8 @@ async def async_request_eb_openai_completions(
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!")
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
if output.generated_text == "":
output.success = False
output.error = "No generated text found!"
else:
output.success = True
else:
output.error = response.reason or ""
output.success = False
@@ -294,8 +271,6 @@ async def async_request_eb_openai_completions(
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
print("final_output:{}".format(output))
if pbar:
pbar.update(1)
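A recurring change in the hunks above is how the streamed `usage` block is read: one side indexes it directly (`data["usage"]["prompt_tokens_details"]["cached_tokens"]`), the other falls back to defaults via `.get()`. Below is a minimal sketch of the defaulted lookup, assuming the OpenAI-compatible usage schema shown in these hunks; the helper name is illustrative.

```python
def cached_prompt_tokens(data: dict) -> int:
    """Cached prompt tokens from a streamed usage chunk, defaulting to 0.

    Assumes the OpenAI-compatible usage schema that appears in the hunks
    above; missing keys yield 0 instead of raising KeyError.
    """
    usage = data.get("usage") or {}
    details = usage.get("prompt_tokens_details") or {}
    return details.get("cached_tokens", 0)


# Example chunk without prompt_tokens_details: returns 0 rather than raising.
print(cached_prompt_tokens({"usage": {"completion_tokens": 12}}))
```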

View File

@@ -38,7 +38,7 @@ class SampleRequest:
"""
Represents a single inference request for benchmarking.
"""
no: int
prompt: Union[str, Any]
history_QA: Union[str, Any]
json_data: Optional[dict]
@@ -229,7 +229,6 @@ class EBDataset(BenchmarkDataset):
**kwargs,
) -> list:
samples: list = []
cnt = 1
for entry in self.data:
if len(samples) >= num_requests:
break
@@ -247,17 +246,16 @@ class EBDataset(BenchmarkDataset):
prompt, None)
samples.append(
SampleRequest(
no=cnt,
prompt=prompt,
prompt_len=self.prompt_len,
history_QA=[],
expected_output_len=new_output_len,
))
cnt += 1
self.maybe_oversample_requests(samples, num_requests)
return samples
class EBChatDataset(BenchmarkDataset):
"""
Implements the ShareGPT dataset. Loads data from a JSON file and generates
@@ -286,7 +284,6 @@ class EBChatDataset(BenchmarkDataset):
**kwargs,
) -> list:
samples: list = []
cnt = 1
for entry in self.data:
if len(samples) >= num_requests:
break
@@ -300,14 +297,12 @@ class EBChatDataset(BenchmarkDataset):
prompt, None)
samples.append(
SampleRequest(
no=cnt,
json_data=json_data,
prompt=prompt,
prompt_len=0,
history_QA=history_QA,
expected_output_len=new_output_len,
))
cnt += 1
self.maybe_oversample_requests(samples, num_requests)
return samples

View File

@@ -1,191 +0,0 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import argparse
import asyncio
import contextlib
import os
import signal
import socket
import subprocess
import time
from typing import Union
import openai
import yaml
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_serving import benchmark
def prepare_input_requests(
num_prompts: int, dataset_name: str, dataset_path: str
) -> Union[EBDataset, EBChatDataset]:
dataset_mapping = {
"EB": lambda: EBDataset(dataset_path=dataset_path).sample(
num_requests=num_prompts
),
"EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
num_requests=num_prompts
),
}
try:
input_requests = dataset_mapping[dataset_name]()
except KeyError as err:
raise ValueError(f"Unknown dataset: {dataset_name}") from err
return input_requests
class FakeTokenizer:
def encode(self, text: str, add_special_tokens: bool = False):
return []
def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
selected_percentile_metrics = ["s_itl"]
selected_percentiles = []
# Run benchmark
results = asyncio.run(
benchmark(
backend="openai-chat",
api_url=f"{base_url}/v1/chat/completions",
base_url=base_url,
model_id="default",
model_name="default",
input_requests=input_requests,
hyper_parameters={},
logprobs=None,
request_rate=float("inf"),
burstiness=1.0,
disable_tqdm=disable_tqdm,
profile=False,
selected_percentile_metrics=selected_percentile_metrics,
selected_percentiles=selected_percentiles,
ignore_eos=False,
goodput_config_dict=None,
max_concurrency=max_concurrency,
lora_modules=None,
extra_body=None,
)
)
record = {
"mean_s_itl_ms": results["mean_s_itl_ms"],
}
return record
def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
tmp = 0.0
for i in range(draft_token_step):
tmp += pow(acceptance_rate, i + 1)
r_ac = tmp / (1 + tmp)
return t_ori / ((1 - r_ac) * t_mtp)
def main(args):
base_url = f"http://{args.host}:{args.port}"
input_requests = prepare_input_requests(
args.num_prompts, args.dataset_name, args.dataset_path
)
if len(args.max_concurrency) != len(args.s_itl_base_model):
raise ValueError(f"--max_concurrency should be same length as --s_itl_base_model")
for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
# Warmup
print("Starting warmup...")
with open(os.devnull, "w") as f:
with contextlib.redirect_stdout(f):
send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True)
# Benchmark
record = send_one_batch(base_url, max_concurrency, input_requests, False)
metric_header = f"Speed up"
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
for draft_token_step in args.draft_token_steps:
speedup = calculate_speedup(
args.acceptance_rate,
draft_token_step,
s_itl,
record["mean_s_itl_ms"],
)
print(
"{:<40} {:<10.2f}".format(
f"Speed up on {draft_token_step} steps draft", speedup
)
)
print("=" * 50)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
)
parser.add_argument(
"--port",
type=str,
default="8000",
)
parser.add_argument(
"--max-concurrency",
type=int,
nargs="+",
default=(1, 2, 4, 8, 16, 32),
)
parser.add_argument(
"--num-prompts",
type=int,
default=128,
)
parser.add_argument(
"--acceptance-rate",
type=float,
default=0.8,
)
parser.add_argument(
"--draft-token-steps",
type=int,
nargs="+",
default=(1, 2),
)
parser.add_argument(
"--s_itl-base-model",
type=float,
nargs="+",
)
parser.add_argument(
"--dataset-name",
type=str,
default="EBChat",
)
parser.add_argument(
"--dataset-path",
type=str,
)
args = parser.parse_args()
main(args)

View File

@@ -182,7 +182,6 @@ def calculate_metrics(
# len(outputs[i].itl) since multiple output tokens may be
# bundled together
# Note : this may inflate the output token count slightly
continue
actual_output_lens.append(output_len)
input_lens.append(outputs[i].prompt_len)
@@ -210,8 +209,6 @@ def calculate_metrics(
if len(outputs[i].arrival_time) > 2:
s_decodes.append((outputs[i].output_tokens - 1) /
(outputs[i].arrival_time[-1] - outputs[i].arrival_time[1]))
else:
print("len(outputs[i].arrival_time) <= 2")
completed += 1
else:
actual_output_lens.append(0)
@@ -344,16 +341,15 @@ async def benchmark(
raise ValueError(f"Unknown backend: {backend}")
print("Starting initial single prompt test run...")
test_prompt, test_output_len, test_no = \
test_prompt, test_output_len = \
input_requests[0].prompt, \
input_requests[0].expected_output_len, input_requests[0].no
input_requests[0].expected_output_len
test_history_QA = input_requests[0].history_QA
test_input = RequestFuncInput(
model=model_id,
model_name=model_name,
prompt=test_prompt,
no=test_no,
prompt_len=0,
history_QA=test_history_QA,
hyper_parameters=hyper_parameters,
@@ -388,7 +384,6 @@ async def benchmark(
profile_input = RequestFuncInput(model=model_id,
model_name=model_name,
prompt=test_prompt,
no=test_no,
api_url=base_url + "/start_profile",
output_len=test_output_len,
logprobs=logprobs,
@@ -427,7 +422,7 @@ async def benchmark(
benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness):
prompt, output_len, no = request.prompt, request.expected_output_len, request.no
prompt, output_len = request.prompt, request.expected_output_len
history_QA = request.history_QA
req_model_id, req_model_name = model_id, model_name
@@ -438,7 +433,6 @@ async def benchmark(
request_func_input = RequestFuncInput(model=req_model_id,
model_name=req_model_name,
prompt=prompt,
no=no,
prompt_len=0,
history_QA=history_QA,
hyper_parameters=hyper_parameters,
@@ -458,7 +452,6 @@ async def benchmark(
profile_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
no=test_no,
api_url=base_url + "/stop_profile",
output_len=test_output_len,
logprobs=logprobs,
@@ -471,8 +464,6 @@ async def benchmark(
pbar.close()
benchmark_duration = time.perf_counter() - benchmark_start_time
print("benchmark_duration:", benchmark_duration)
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
@@ -603,155 +594,6 @@ async def benchmark(
return result
def benchmark_metrics(
benchmark_duration: float,
result_file: str,
selected_percentiles: list[float],
selected_percentile_metrics: list[str],
goodput_config_dict: dict[str, float],
):
"""Benchmark metrics statisticsgenerate benchmark result"""
outputs = []
case_no_list = []
with open(result_file) as f:
for line in f.readlines():
if "RequestFuncOutput" in line:
start = line.find("RequestFuncOutput")
end = line.rfind(")")
para_str = line[start:end + 1]
output = eval(para_str)
outputs.append(output)
input_requests = [[]] * len(outputs)
goodput_config_dict = check_goodput_args(args)
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
selected_percentiles=selected_percentiles,
goodput_config_dict=goodput_config_dict,
)
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput:":
metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"input_texts": ["" for input in input_requests],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
}
def process_one_metric(
# E.g., "ttft"
metric_attribute_name: str,
# E.g., "TTFT"
metric_name: str,
# E.g., "Time to First Token"
metric_header: str,
):
# This function prints and adds statistics of the specified
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
print("{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms")))
result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms")
result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
def process_one_length(
# E.g., "ttft"
metric_attribute_name: str,
# E.g., "TTFT"
metric_name: str,
# E.g., "Time to First Token"
metric_header: str,
):
# This function prints and adds statistics of the specified
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
print("{:<40} {:<10.2f}".format(
f"Mean {metric_name}:",
getattr(metrics, f"mean_{metric_attribute_name}")))
print("{:<40} {:<10.2f}".format(
f"Median {metric_name}:",
getattr(metrics, f"median_{metric_attribute_name}")))
result[f"mean_{metric_attribute_name}"] = getattr(
metrics, f"mean_{metric_attribute_name}")
result[f"median_{metric_attribute_name}"] = getattr(
metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr(
metrics, f"std_{metric_attribute_name}")
for p, value in getattr(metrics,
f"percentiles_{metric_attribute_name}"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:",
value))
result[f"p{p_word}_{metric_attribute_name}"] = value
process_one_length("s_decode", "Decode", "解码速度(tok/s)")
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_one_metric("tpot", "TPOT",
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
process_one_metric("s_e2el", "S_E2EL", "Infer End-to-end Latency")
process_one_length("input_len", "Input Length", "Input Length")
process_one_length("s_input_len", "Input Length", "Infer Input Length")
process_one_length("output_len", "Output Length", "Output Length")
print("=" * 50)
return result
def check_goodput_args(args):
"""Check whether the given argument has valid goodput configuration or not"""
# Check and parse goodput arguments
@@ -917,16 +759,6 @@ def main(args: argparse.Namespace):
lora_modules=args.lora_modules,
extra_body=sampling_params,
))
# benchmark_result = benchmark_metrics(
# benchmark_duration=3600,
# result_file="your result file",
# selected_percentile_metrics=args.percentile_metrics.split(","),
# selected_percentiles=[
# float(p) for p in args.metric_percentiles.split(",")
# ],
# goodput_config_dict=goodput_config_dict,
# )
# Save config and results to json
if args.save_result:

View File

@@ -2,5 +2,4 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -2,5 +2,4 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -3,5 +3,4 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -3,5 +3,4 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -2,5 +2,4 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -3,5 +3,4 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint4
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -3,5 +3,4 @@ max_num_seqs: 96
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 4
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -2,5 +2,4 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -2,5 +2,4 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -3,5 +3,4 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wfp8afp8
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -2,5 +2,4 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -2,5 +2,4 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -3,5 +3,4 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -3,5 +3,4 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -2,5 +2,4 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -3,5 +3,4 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint4
graph_optimization_config:
graph_opt_level: 1
enable_static_graph_inference: True

View File

@@ -1,11 +0,0 @@
top_p: 1.0
temperature: 1.0
metadata:
min_tokens: 1
max_tokens: 30721
repetition_penalty: 1.0
frequency_penalty: 0
presence_penalty: 0
skip_special_tokens: false
chat_template_kwargs:
enable_thinking: true

View File

@@ -18,9 +18,6 @@ BUILD_WHEEL=${1:-1}
PYTHON_VERSION=${2:-"python"}
export python=$PYTHON_VERSION
FD_CPU_USE_BF16=${3:-"false"}
# FD_BUILDING_ARCS: Specify target CUDA architectures for custom ops, e.g., "[80, 90, 100]".
# For SM90 (Hopper), use 90. For SM100 (Blackwell), use 100.
# These will be translated to 90a / 100a in setup_ops.py for specific features.
FD_BUILDING_ARCS=${4:-""}
@@ -77,10 +74,8 @@ function copy_ops(){
is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
if [ "$is_rocm" = "True" ]; then
DEVICE_TYPE="rocm"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "BASE and ROCM ops have been copy to fastdeploy"
echo -e "ROCM ops have been copy to fastdeploy"
return
fi
mkdir -p ../fastdeploy/model_executor/ops/base
@@ -109,23 +104,6 @@ function copy_ops(){
return
fi
if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
if [ "$if_corex" = "True" ]; then
DEVICE_TYPE="iluvatar-gpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
return
fi
is_gcu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"`
if [ "$is_gcu" = "True" ]; then
DEVICE_TYPE="gcu"
cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gcu
echo -e "gcu ops have been copy to fastdeploy"
return
fi
DEVICE_TYPE="cpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cd ../../../../
@@ -185,6 +163,17 @@ function build_and_install() {
exit 1
fi
echo -e "${BLUE}[build]${NONE} ${GREEN}build fastdeploy wheel success${NONE}\n"
echo -e "${BLUE}[install]${NONE} installing fastdeploy..."
cd $DIST_DIR
find . -name "fastdeploy*.whl" | xargs ${python} -m pip install --force-reinstall --no-cache-dir
if [ $? -ne 0 ]; then
cd ..
echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed"
exit 1
fi
echo -e "${BLUE}[install]${NONE} ${GREEN}fastdeploy install success${NONE}\n"
cd ..
}
function version_info() {
@@ -192,10 +181,7 @@ function version_info() {
fastdeploy_git_commit_id=$(git rev-parse HEAD)
paddle_version=$(${python} -c "import paddle; print(paddle.__version__)")
paddle_git_commit_id=$(${python} -c "import paddle; print(paddle.__git_commit__)")
cuda_version="nvcc-not-installed"
if command -v nvcc &> /dev/null; then
cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
fi
cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
cxx_version=$(g++ --version | head -n 1 | grep -Po "(?<=\) )[\d.]+")
echo "fastdeploy GIT COMMIT ID: $fastdeploy_git_commit_id" > $output_file

View File

@@ -47,7 +47,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -166,7 +166,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
lambda_batch_ids,
lambda_tile_ids_per_batch,
@@ -203,7 +203,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_encoder,
seq_lens_decoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
kv_batch_ids,
kv_tile_ids_per_batch,
@@ -275,7 +275,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
rotary_embs,
qkv_out_scales,
@@ -298,7 +298,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
rotary_embs,
qkv_out_scales,
@@ -323,7 +323,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
rotary_embs,
qkv_out_scales,
@@ -347,7 +347,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
rotary_embs,
qkv_out_scales,
@@ -404,7 +404,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -462,7 +462,7 @@ std::vector<paddle::Tensor> AppendAttention(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = seq_lens_this_time.dims()[0];
meta_data.batch_size = cum_offsets.dims()[0];
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
@@ -474,7 +474,7 @@ std::vector<paddle::Tensor> AppendAttention(
seq_lens_decoder,
seq_lens_this_time,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
@@ -551,7 +551,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& padding_offsets_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
@@ -611,7 +611,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& padding_offsets_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
@@ -689,7 +689,7 @@ PD_BUILD_STATIC_OP(append_attention)
"seq_lens_decoder",
"seq_lens_this_time",
"padding_offsets",
"cu_seqlens_q",
"cum_offsets",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",

View File

@@ -41,7 +41,7 @@ __global__ void multi_query_append_attention_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -114,7 +114,8 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_b_stride = HEAD_DIM;
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_base_seq_id_this_block =
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
@@ -404,7 +405,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -476,7 +477,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_b_stride = HEAD_DIM;
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
q_head_idx * HEAD_DIM +
@@ -774,7 +776,7 @@ void MultiQueryAppendAttention(
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -880,7 +882,7 @@ void MultiQueryAppendAttention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -937,7 +939,7 @@ void MultiQueryAppendAttention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -972,7 +974,7 @@ void MultiQueryAppendAttention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1101,7 +1103,7 @@ void MultiQueryAppendAttention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1169,7 +1171,7 @@ void MultiQueryAppendAttention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1205,7 +1207,7 @@ void MultiQueryAppendAttention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1288,7 +1290,7 @@ void CascadeAppendAttentionC16Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -1351,7 +1353,7 @@ void CascadeAppendAttentionC16Kernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_table,
batch_ids,
tile_ids_per_batch,
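The kernel hunks above, and the C4/C8 and decode variants that follow, switch between two ways of locating a batch's first query token: reading `cu_seqlens_q[batch_id]` from a prefix-sum array, or computing `batch_id * max_seq_len - cum_offsets[batch_id]` for a padded layout. A minimal sketch of the equivalence, assuming `cum_offsets[b]` counts the padding slots before batch `b` (an interpretation inferred from the formula, not stated in the diff):

```python
def start_token_from_cu_seqlens(cu_seqlens_q: list[int], batch_id: int) -> int:
    """Ragged layout: cu_seqlens_q is a prefix sum of query lengths, so the
    start index of batch_id is read directly."""
    return cu_seqlens_q[batch_id]


def start_token_from_cum_offsets(cum_offsets: list[int], batch_id: int,
                                 max_seq_len: int) -> int:
    """Padded layout: each batch nominally occupies max_seq_len slots;
    cum_offsets[b] is assumed to count the padding slots before batch b,
    as in the kernel hunks above."""
    return batch_id * max_seq_len - cum_offsets[batch_id]


# Example: 3 sequences of lengths 5, 3, 7 padded to max_seq_len = 8.
cu_seqlens_q = [0, 5, 8]   # prefix sums of [5, 3, 7]
cum_offsets = [0, 3, 8]    # padding accumulated before each batch
for b in range(3):
    assert start_token_from_cu_seqlens(cu_seqlens_q, b) == \
           start_token_from_cum_offsets(cum_offsets, b, max_seq_len=8)
```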

View File

@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -144,7 +144,8 @@ __global__ void multi_query_append_attention_c4_kernel(
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
const uint32_t kv_b_stride = HEAD_DIM / 2;
const uint32_t kv_d_stride = BLOCK_SIZE / 2;
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_base_seq_id_this_block =
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
@@ -503,7 +504,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -600,7 +601,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
const uint32_t kv_b_stride = HEAD_DIM / 2;
const uint32_t kv_d_stride = BLOCK_SIZE / 2;
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
q_head_idx * HEAD_DIM +
@@ -961,7 +963,7 @@ void MultiQueryAppendC4Attention(
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1086,7 +1088,7 @@ void MultiQueryAppendC4Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1149,7 +1151,7 @@ void MultiQueryAppendC4Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1184,7 +1186,7 @@ void MultiQueryAppendC4Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1331,7 +1333,7 @@ void MultiQueryAppendC4Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1407,7 +1409,7 @@ void MultiQueryAppendC4Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1442,7 +1444,7 @@ void MultiQueryAppendC4Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1525,7 +1527,7 @@ void CascadeAppendAttentionC4Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -1592,7 +1594,7 @@ void CascadeAppendAttentionC4Kernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_table,
batch_ids,
tile_ids_per_batch,

View File

@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -151,7 +151,8 @@ __global__ void multi_query_append_attention_c8_kernel(
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_b_stride = HEAD_DIM;
const uint32_t kv_d_stride = BLOCK_SIZE;
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_base_seq_id_this_block =
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
@@ -472,7 +473,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -574,7 +575,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_b_stride = HEAD_DIM;
const uint32_t kv_d_stride = BLOCK_SIZE;
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
q_head_idx * HEAD_DIM +
@@ -898,7 +900,7 @@ void MultiQueryAppendC8Attention(
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1052,7 +1054,7 @@ void MultiQueryAppendC8Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1109,7 +1111,7 @@ void MultiQueryAppendC8Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1144,7 +1146,7 @@ void MultiQueryAppendC8Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1315,7 +1317,7 @@ void MultiQueryAppendC8Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1385,7 +1387,7 @@ void MultiQueryAppendC8Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1415,7 +1417,7 @@ void MultiQueryAppendC8Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1498,7 +1500,7 @@ void CascadeAppendAttentionC8Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -1563,7 +1565,7 @@ void CascadeAppendAttentionC8Kernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_table,
batch_ids,
tile_ids_per_batch,

View File

@@ -2111,7 +2111,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ cum_offsets,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
OutT *__restrict__ out,
@@ -2127,7 +2127,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
const int bid = blockIdx.x, hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) return;
int seq_len_kv = seq_lens_kv[bid];

View File

@@ -41,7 +41,7 @@ void CascadeAppendAttentionC16Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -86,7 +86,7 @@ void CascadeAppendAttentionC8Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -131,7 +131,7 @@ void CascadeAppendAttentionC4Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -176,7 +176,7 @@ void CascadeAppendAttentionKernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -212,7 +212,7 @@ void CascadeAppendAttentionKernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_table,
batch_ids,
tile_ids_per_batch,
@@ -247,7 +247,7 @@ void CascadeAppendAttentionKernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_table,
batch_ids,
tile_ids_per_batch,
@@ -282,7 +282,7 @@ void CascadeAppendAttentionKernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_table,
batch_ids,
tile_ids_per_batch,
@@ -317,7 +317,7 @@ void CascadeAppendAttentionKernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_table,
batch_ids,
tile_ids_per_batch,

View File

@@ -35,7 +35,7 @@ __global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi
const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads]
const int * __restrict__ seq_lens_q,
const int * __restrict__ seq_lens_kv,
const int * __restrict__ cu_seqlens_q,
const int * __restrict__ cum_offsets,
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
OutT * __restrict__ out, // [token_num, num_heads, head_dim]
@@ -59,7 +59,7 @@ __global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi
__shared__ T smem[bdy * HEAD_DIM];
__shared__ T md_smem[bdy * 2];
const int start_token_ids = cu_seqlens_q[qid];
const int start_token_ids = qid * max_seq_len - __ldg(&cum_offsets[qid]);
using LoadT = AlignedVector<T, vec_size>;
LoadT load_vec;
LoadT res_vec;
@@ -134,7 +134,7 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int * __restrict__ seq_lens_q,
const int * __restrict__ seq_lens_kv,
const int * __restrict__ cu_seqlens_q,
const int * __restrict__ cum_offsets,
const int * __restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -171,8 +171,8 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke
}
kv_len += q_len;
const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
const uint32_t q_start_idx = cu_seqlens_q[bid];
const uint32_t q_write_idx = cu_seqlens_q[bid];
const uint32_t q_start_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const uint32_t q_write_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
if (chunk_id >= num_chunk_this_seq) {
return;
}
@@ -318,7 +318,7 @@ void MultiQueryDecoderAttention(
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
const int max_seq_len,
const int max_dec_len,
@@ -393,7 +393,7 @@ void MultiQueryDecoderAttention(
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -430,7 +430,7 @@ void MultiQueryDecoderAttention(
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -456,7 +456,7 @@ void MultiQueryDecoderAttention(
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
@@ -484,7 +484,7 @@ void DecodeMLAAttentionKernel(
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
@@ -513,7 +513,7 @@ void DecodeMLAAttentionKernel(
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cu_seqlens_q,
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cum_offsets,
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
}
@@ -528,7 +528,7 @@ template void DecodeMLAAttentionKernel<paddle::bfloat16>(
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
@@ -549,7 +549,7 @@ template void DecodeMLAAttentionKernel<paddle::float16>(
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,

View File

@@ -29,7 +29,7 @@ __global__ void append_decode_cache_T_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -65,7 +65,7 @@ __global__ void append_decode_cache_T_rope_kernel(
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
@@ -135,7 +135,7 @@ __global__ void append_decode_cache_T_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -177,7 +177,7 @@ __global__ void append_decode_cache_T_rope_kernel(
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
@@ -255,7 +255,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -293,7 +293,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
@@ -367,7 +367,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -409,7 +409,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
@@ -499,7 +499,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -523,7 +523,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -746,7 +746,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -775,7 +775,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -1048,7 +1048,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1073,7 +1073,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -1347,7 +1347,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1377,7 +1377,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -1740,7 +1740,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1766,7 +1766,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int half_block_size = block_size / 2;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -2035,7 +2035,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -2066,7 +2066,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int half_block_size = block_size / 2;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -2363,7 +2363,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -2389,7 +2389,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int half_block_size = block_size / 2;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -2733,7 +2733,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -2764,7 +2764,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int half_block_size = block_size / 2;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
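The hunks in this file all touch the same bookkeeping: a grid-stride loop over token_num * hidden_size elements, recovery of the padded token id through padding_offsets, derivation of the owning batch, and then the compact start index via cum_offsets. A stripped-down sketch of just that decomposition, under the same layout assumptions as above (placeholder kernel; the real variants go on to apply rope and write the, possibly quantized, cache):

```cuda
#include <cuda_runtime.h>

// Placeholder kernel: only the index bookkeeping shared by the
// append_decode_cache_*_rope_kernel variants above is kept.
__global__ void decode_cache_index_demo(
    const int* __restrict__ padding_offsets,   // [num_tokens] compact -> padded shift
    const int* __restrict__ cum_offsets,       // [bsz] exclusive prefix sum of padding
    const int* __restrict__ seq_lens,          // [bsz] decoded length so far
    const int* __restrict__ seq_lens_encoder,  // [bsz] > 0 means the row is in prefill
    int* __restrict__ out_start_token_idx,     // [num_tokens] where the batch's qkv starts
    int* __restrict__ out_write_seq_id,        // [num_tokens] logical cache slot to append
    const int max_seq_len,
    const int hidden_size,
    const int elem_cnt) {
  const int step = gridDim.x * blockDim.x;
  for (int linear_index = blockIdx.x * blockDim.x + threadIdx.x;
       linear_index < elem_cnt; linear_index += step) {
    const int token_id = linear_index / hidden_size;                // compact token index
    const int ori_token_id = token_id + padding_offsets[token_id];  // padded token index
    const int ori_bi = ori_token_id / max_seq_len;                  // owning batch
    if (seq_lens_encoder[ori_bi] > 0) return;                       // decode-only kernel
    // first compact token of this batch; same value cu_seqlens_q[ori_bi] used to supply
    const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
    const int write_seq_id = seq_lens[ori_bi];  // single-token decode appends here
    if (write_seq_id == 0) continue;
    // every thread of a token stores the same values; fine for a sketch
    out_start_token_idx[token_id] = start_token_idx;
    out_write_seq_id[token_id] = write_seq_id;
  }
}
```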

View File

@@ -22,7 +22,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cu_seqlens_q,
const int* cum_offsets,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -58,7 +58,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -80,7 +80,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -103,7 +103,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -126,7 +126,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -150,7 +150,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cu_seqlens_q,
const int* cum_offsets,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -183,7 +183,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -208,7 +208,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -233,7 +233,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -258,7 +258,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -283,7 +283,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cu_seqlens_q,
const int* cum_offsets,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -318,7 +318,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -345,7 +345,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -372,7 +372,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -399,7 +399,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -425,7 +425,7 @@ void DecoderWriteCacheWithRoPEKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -472,7 +472,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -504,7 +504,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -537,7 +537,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -571,7 +571,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -604,7 +604,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -651,7 +651,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -678,7 +678,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -704,7 +704,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -730,7 +730,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,

View File

@@ -24,7 +24,7 @@ void DecoderWriteCacheWithRoPEKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,

View File

@@ -879,7 +879,7 @@ __global__ void append_write_cache_kv_c8_qkv(
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ padding_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_tables,
const int max_seq_len,
const int max_blocks_per_seq,
@@ -911,7 +911,8 @@ __global__ void append_write_cache_kv_c8_qkv(
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
const uint32_t start_token_idx =
batch_id * max_seq_len - cum_offsets[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
@@ -1118,7 +1119,7 @@ __global__ void append_write_cache_kv_c4_qkv(
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ padding_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ cum_offsets,
const int *__restrict__ block_tables,
const int max_seq_len,
const int max_blocks_per_seq,
@@ -1147,7 +1148,8 @@ __global__ void append_write_cache_kv_c4_qkv(
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
const uint32_t start_token_idx =
batch_id * max_seq_len - cum_offsets[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
@@ -1748,7 +1750,7 @@ void CascadeAppendWriteCacheKVC8QKV(
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1813,7 +1815,7 @@ void CascadeAppendWriteCacheKVC8QKV(
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
@@ -1836,7 +1838,7 @@ void CascadeAppendWriteCacheKVC4QKV(
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1883,7 +1885,7 @@ void CascadeAppendWriteCacheKVC4QKV(
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
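append_write_cache_kv_c8_qkv and append_write_cache_kv_c4_qkv walk precomputed (batch, tile) work items, with `tile_start = start_len_pad + tile_id * num_rows_per_block` as shown above. A hedged host-side sketch of how such batch_ids / tile_ids_per_batch lists can be enumerated; the repo builds them in a separate helper, so num_rows_per_block and the padding rule here are purely illustrative:

```cuda
#include <utility>
#include <vector>

// Illustrative only: split the tokens each sequence writes this step into
// fixed-size tiles and record one (batch_id, tile_id) pair per tile, so a
// kernel can map blockIdx.x to a work item.
std::pair<std::vector<int>, std::vector<int>> build_tile_lists(
    const std::vector<int>& seq_lens_this_time,
    const std::vector<int>& seq_lens_decoder,
    int num_rows_per_block) {
  std::vector<int> batch_ids, tile_ids;
  for (int b = 0; b < static_cast<int>(seq_lens_this_time.size()); ++b) {
    const int start_len = seq_lens_decoder[b];             // tokens already cached
    const int end_len = start_len + seq_lens_this_time[b]; // after this step
    // round the start down to a tile boundary, mirroring start_len_pad above
    const int start_len_pad = (start_len / num_rows_per_block) * num_rows_per_block;
    const int num_tiles =
        (end_len - start_len_pad + num_rows_per_block - 1) / num_rows_per_block;
    for (int t = 0; t < num_tiles; ++t) {
      batch_ids.push_back(b);
      tile_ids.push_back(t);
    }
  }
  return {batch_ids, tile_ids};
}
```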

View File

@@ -26,7 +26,7 @@ void EncoderWriteCacheWithRopeKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,
@@ -143,7 +143,7 @@ void EncoderWriteCacheWithRopeKernel(
seq_lens_this_time,
seq_lens_decoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
batch_ids,
tile_ids,
@@ -170,7 +170,7 @@ void EncoderWriteCacheWithRopeKernel(
seq_lens_this_time,
seq_lens_decoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
batch_ids,
tile_ids,

View File

@@ -422,6 +422,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids,
@@ -449,7 +450,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const int token_num = qkv_dims[0];
const int max_blocks_per_seq = block_tables.dims()[1];
const int block_size = key_cache.dims()[2];
const int batch_size = seq_lens_this_time.dims()[0];
const int batch_size = cum_offsets.dims()[0];
const int kv_num_heads = key_cache_dims[1];
const int head_dim = key_cache_dims[3];
const int num_heads = qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads;
@@ -462,7 +463,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
meta_data.q_num_heads = num_heads;
meta_data.max_blocks_per_seq = max_blocks_per_seq;
meta_data.block_size = block_size;
meta_data.batch_size = seq_lens_this_time.dims()[0];
meta_data.batch_size = cum_offsets.dims()[0];
phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(qkv.place()));
@@ -527,7 +528,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
seq_lens_this_time,
seq_lens_decoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
kv_batch_ids,
kv_tile_ids,
@@ -594,6 +595,7 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
"seq_lens_encoder",
"seq_lens_decoder",
"padding_offsets",
"cum_offsets",
"block_tables",
"kv_batch_ids",
"kv_tile_ids_per_batch",

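The shape bookkeeping in this file (`num_heads = qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads`) follows from the fused GQA QKV layout, where the last dimension packs num_heads query heads plus kv_num_heads key heads plus kv_num_heads value heads. A standalone sketch of that arithmetic with made-up sizes:

```cuda
#include <cassert>
#include <cstdio>

int main() {
  // Fused QKV last dim = (num_heads + 2 * kv_num_heads) * head_dim.
  const int head_dim = 128;
  const int kv_num_heads = 8;              // e.g. read from the key cache dims
  const int qkv_last_dim = 40 * head_dim;  // pretend shape read from the qkv tensor

  const int num_heads = qkv_last_dim / head_dim - 2 * kv_num_heads;  // 40 - 16 = 24
  assert((num_heads + 2 * kv_num_heads) * head_dim == qkv_last_dim);
  std::printf("q heads: %d, kv heads: %d, head_dim: %d\n",
              num_heads, kv_num_heads, head_dim);
  return 0;
}
```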
View File

@@ -23,7 +23,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const int max_seq_len,
cudaStream_t& stream,
@@ -54,7 +54,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_decoder.data<int>(),
max_seq_len,
@@ -74,7 +74,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len) {
@@ -91,7 +91,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = cu_seqlens_q.dims()[0];
meta_data.batch_size = cum_offsets.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
@@ -100,7 +100,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
seq_lens,
seq_lens_decoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
max_seq_len,
stream,
@@ -113,7 +113,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
seq_lens,
seq_lens_decoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
max_seq_len,
stream,
@@ -131,7 +131,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
@@ -165,7 +165,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
@@ -185,7 +185,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
@@ -206,7 +206,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
@@ -224,7 +224,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = cu_seqlens_q.dims()[0];
meta_data.batch_size = cum_offsets.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
@@ -233,7 +233,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
seq_lens,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
max_seq_len,
speculate_decoder,
@@ -247,7 +247,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
seq_lens,
seq_lens_encoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
max_seq_len,
speculate_decoder,
@@ -266,7 +266,7 @@ PD_BUILD_OP(prefill_mla_write_cache)
"seq_lens",
"seq_lens_decoder",
"padding_offsets",
"cu_seqlens_q",
"cum_offsets",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
@@ -281,7 +281,7 @@ PD_BUILD_OP(decode_mla_write_cache)
"seq_lens",
"seq_lens_encoder",
"padding_offsets",
"cu_seqlens_q",
"cum_offsets",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
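PrefillMLAWriteCacheKernel and DecodeMLAWriteCacheKernel both branch on kv_pe.dtype() to pick a templated implementation. The same pattern, reduced to a standalone sketch with a stand-in enum and a dummy implementation (not the PR's types):

```cuda
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstdio>
#include <stdexcept>

enum class DType { BF16, FP16 };  // stand-in for the runtime dtype enum

template <typename T>
void write_cache_impl(int max_seq_len) {
  // the real code would launch the templated kernel here
  std::printf("launch with sizeof(T)=%zu, max_seq_len=%d\n", sizeof(T), max_seq_len);
}

void write_cache_dispatch(DType dtype, int max_seq_len) {
  switch (dtype) {
    case DType::BF16:
      write_cache_impl<__nv_bfloat16>(max_seq_len);
      break;
    case DType::FP16:
      write_cache_impl<__half>(max_seq_len);
      break;
    default:
      throw std::runtime_error("unsupported dtype");
  }
}
```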

View File

@@ -24,7 +24,7 @@ __global__ void decode_absorb_cache_kernel(
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
@@ -50,7 +50,7 @@ __global__ void decode_absorb_cache_kernel(
linear_index += step) {
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
@@ -96,7 +96,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
@@ -124,7 +124,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int write_seq_id =
seq_lens[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
@@ -143,7 +143,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
ori_bi,
seq_lens[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
cum_offsets[ori_bi]);
}
if (bias < nope_hidden_size) { // pe
const uint32_t inner_bias = bias;
@@ -179,7 +179,7 @@ __global__ void prefill_absorb_cache_kernel(
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_decoder, // [bsz]
const int max_seq_len,
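In speculate_decode_absorb_cache_kernel a batch can append several draft tokens in one step, so the slot is seq_lens[ori_bi] plus the token's offset from start_token_idx. How such a logical position typically maps to a physical slot through the per-batch block table, as a hedged host-side sketch (layout simplified; the real kernels additionally split the nope and pe parts of the MLA cache):

```cuda
#include <cstdio>
#include <vector>

// Illustrative: logical position in a sequence -> (physical block, in-block offset)
// via a per-batch block table, as the paged KV cache kernels in this diff do.
struct CacheSlot { int physical_block; int offset; };

CacheSlot locate_slot(const std::vector<std::vector<int>>& block_tables,  // [bsz][max_blocks_per_seq]
                      int bid, int write_seq_id, int block_size) {
  const int block_idx = write_seq_id / block_size;
  const int block_offset = write_seq_id % block_size;
  return {block_tables[bid][block_idx], block_offset};
}

int main() {
  const int block_size = 64;
  std::vector<std::vector<int>> block_tables = {{7, 12, 3}};  // one sequence, three blocks
  const int seq_len_decoder = 70;  // tokens already cached for batch 0
  for (int draft = 0; draft < 3; ++draft) {            // three speculative tokens this step
    const int write_seq_id = seq_len_decoder + draft;  // seq_lens[bid] + offset in batch
    const CacheSlot s = locate_slot(block_tables, 0, write_seq_id, block_size);
    std::printf("draft %d -> block %d, offset %d\n", draft, s.physical_block, s.offset);
  }
  return 0;
}
```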

View File

@@ -27,7 +27,7 @@ void DecodeMLAAttentionKernel(
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,

View File

@@ -27,7 +27,7 @@ __global__ void append_clear_cache_int8_block(
const int* __restrict__ seq_lens,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
@@ -44,7 +44,7 @@ __global__ void append_clear_cache_int8_block(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
if (seq_lens_encoder[bid] > 0) return;
@@ -101,7 +101,7 @@ __global__ void append_clear_cache_int4_block(
const int* __restrict__ seq_lens,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
@@ -118,7 +118,7 @@ __global__ void append_clear_cache_int4_block(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
if (seq_lens_encoder[bid] > 0) return;
@@ -179,7 +179,7 @@ __global__ void append_speculate_cache_rope_kernel(
T* __restrict__ q_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
@@ -219,7 +219,7 @@ __global__ void append_speculate_cache_rope_kernel(
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
@@ -235,7 +235,7 @@ __global__ void append_speculate_cache_rope_kernel(
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
cum_offsets[ori_bi]);
}
const int block_offset = write_seq_id % block_size;
@@ -312,7 +312,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
@@ -352,7 +352,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
@@ -368,7 +368,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
cum_offsets[ori_bi]);
}
const int block_offset = write_seq_id % block_size;
@@ -459,7 +459,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -487,7 +487,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -691,7 +691,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -719,7 +719,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
@@ -1069,7 +1069,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1100,7 +1100,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -1375,7 +1375,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ cum_offsets,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1406,7 +1406,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = cu_seqlens_q[bid];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
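This file keeps two kernel families, *_rope_kernel and *_neox_rope_kernel. The half_hidden_size / half_head_size indexing above is the usual tell: the NeoX-style variant rotates pairs (i, i + head_dim/2), while the default variant is commonly interleaved over (2i, 2i+1). A generic sketch of the two pairings on one head vector (host code, not taken from these kernels; the frequency schedule is the standard one and is an assumption here):

```cuda
#include <cmath>
#include <vector>

// Apply rotary embedding to one head vector of size head_dim at position `pos`.
// interleaved == true pairs (2i, 2i+1); otherwise NeoX-style pairs (i, i + half).
void apply_rope(std::vector<float>& x, int pos, bool interleaved, float base = 10000.f) {
  const int head_dim = static_cast<int>(x.size());
  const int half = head_dim / 2;
  for (int i = 0; i < half; ++i) {
    const float theta = pos * std::pow(base, -2.0f * i / head_dim);
    const float c = std::cos(theta), s = std::sin(theta);
    const int a = interleaved ? 2 * i : i;
    const int b = interleaved ? 2 * i + 1 : i + half;
    const float xa = x[a], xb = x[b];
    x[a] = xa * c - xb * s;  // rotate the pair by theta
    x[b] = xa * s + xb * c;
  }
}
```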

View File

@@ -23,7 +23,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cu_seqlens_q,
const int* cum_offsets,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -60,7 +60,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
cos_emb,
sin_emb,
@@ -83,7 +83,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
cos_emb,
sin_emb,
@@ -107,7 +107,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cu_seqlens_q,
const int* cum_offsets,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -137,7 +137,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
seq_lens,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens_encoder,
max_seq_len,
max_blocks_per_seq,
@@ -152,7 +152,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -176,7 +176,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -202,7 +202,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cu_seqlens_q,
const int* cum_offsets,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -234,7 +234,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
seq_lens,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens_encoder,
max_seq_len,
max_blocks_per_seq,
@@ -249,7 +249,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -275,7 +275,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cu_seqlens_q,
cum_offsets,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -302,7 +302,7 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -350,7 +350,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -377,7 +377,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -410,7 +410,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -443,7 +443,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
padding_offsets.data<int>(),
cu_seqlens_q.data<int>(),
cum_offsets.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -489,7 +489,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -515,7 +515,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -540,7 +540,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -567,7 +567,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
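The int8 and int4 paths above carry cache key/value scales alongside the quantized cache. As a generic illustration of the round trip such cache layouts rely on, independent of this PR's exact scheme (symmetric int8 with a single placeholder scale; the real kernels use per-channel scales and, for int4, zero points):

```cuda
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Generic symmetric int8 round trip: q = clamp(round(x / scale)), x' = q * scale.
int8_t quantize_int8(float x, float scale) {
  const float q = std::nearbyint(x / scale);
  return static_cast<int8_t>(std::max(-127.f, std::min(127.f, q)));
}

float dequantize_int8(int8_t q, float scale) { return q * scale; }

int main() {
  const float scale = 0.05f;  // stand-in for a cache_k_scale / cache_v_scale entry
  const float k_val = 1.2345f;
  const int8_t stored = quantize_int8(k_val, scale);
  const float recovered = dequantize_int8(stored, scale);
  std::printf("stored %d, recovered %.4f (abs error %.4f)\n",
              stored, recovered, std::fabs(k_val - recovered));
  return 0;
}
```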

View File

@@ -24,7 +24,7 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,

View File

@@ -38,7 +38,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -38,7 +38,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -39,7 +39,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -86,7 +86,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
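These *_impl translation units contain nothing but explicit instantiations of the CascadeAppendAttention*Kernel templates, one (T, OutT) combination per file, plus the IsFP8 true/false pair for the C8 variants, so the heavy template bodies compile once per combination and the .cu files build in parallel. The mechanism in miniature, with illustrative names:

```cuda
#include <cuda_bf16.h>

// A shared .cuh would carry the full template definition; body elided here.
template <typename T, typename OutT, bool IsFP8 = false>
void cascade_attention_kernel_launcher(const T* q, OutT* out, int n) {
  (void)q; (void)out; (void)n;  // real code launches the templated kernel
}

// Each *_impl .cu file then pins exactly the combinations it owns, so the
// definition above is compiled in this translation unit for these arguments:
template void cascade_attention_kernel_launcher<__nv_bfloat16, __nv_bfloat16, false>(
    const __nv_bfloat16*, __nv_bfloat16*, int);
template void cascade_attention_kernel_launcher<__nv_bfloat16, __nv_bfloat16, true>(
    const __nv_bfloat16*, __nv_bfloat16*, int);
```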

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -81,7 +81,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -83,7 +83,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -83,7 +83,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -23,7 +23,7 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,

View File

@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,

View File

@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,

View File

@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,

View File

@@ -54,7 +54,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &padding_offsets, const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &padding_offsets, const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
const paddle::Tensor &encoder_tile_ids_per_batch,
const paddle::Tensor &encoder_num_blocks,
@@ -94,7 +94,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &padding_offsets, const paddle::Tensor &cum_offsets,
const paddle::Tensor &block_tables, const paddle::Tensor &kv_batch_ids,
const paddle::Tensor &kv_tile_ids, const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &cache_batch_ids, const paddle::Tensor &cache_tile_ids,
@@ -116,11 +116,11 @@ PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
paddle::Tensor FusedExpertMoeFunc(
const paddle::Tensor &input, const paddle::Tensor &gate_weight,
const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight,
const paddle::optional<paddle::Tensor> &up_gate_proj_bias,
const paddle::optional<paddle::Tensor> &up_gate_proj_scale,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &down_proj_scale,
const paddle::Tensor &ffn1_weight, const paddle::Tensor &ffn2_weight,
const paddle::optional<paddle::Tensor> &ffn1_bias,
const paddle::optional<paddle::Tensor> &ffn1_scale,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const paddle::optional<paddle::Tensor> &ffn2_scale,
const std::string &quant_method, const int moe_topk,
const bool norm_topk_prob, const bool group_moe);
@@ -149,7 +149,7 @@ MoERedundantTopKSelectKernel(const paddle::Tensor &gating_logits,
std::vector<paddle::Tensor>
EPMoeExpertDispatch(const paddle::Tensor &input, const paddle::Tensor &topk_ids,
const paddle::Tensor &topk_weights,
const paddle::optional<paddle::Tensor> &up_gate_proj_in_scale,
const paddle::optional<paddle::Tensor> &ffn1_in_scale,
const std::vector<int> &token_nums_per_expert,
const int token_nums_this_rank,
const std::string &moe_quant_type);
@@ -158,8 +158,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const paddle::Tensor &input, const paddle::Tensor &scale,
const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights,
const paddle::Tensor &token_nums_per_expert,
const paddle::Tensor &token_nums_per_expert_padded,
const bool use_in_ep, const int token_nums_this_rank_padded);
const paddle::Tensor &token_nums_per_expert_padded);
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
const int block_size);
@@ -173,7 +172,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
const paddle::Tensor &ffn_out, const paddle::Tensor &expert_scales_float,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const bool norm_topk_prob, const float routed_scaling_factor);
std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
@@ -182,35 +181,35 @@ std::vector<std::vector<int>> GetExpertTokenNum(const paddle::Tensor &topk_ids,
paddle::Tensor MoeExpertFFNFunc(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::Tensor& ffn1_weight, const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency);
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
const bool used_in_ep_low_latency);
paddle::Tensor MoeExpertReduceFunc(
const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const bool norm_topk_prob, const float routed_scaling_factor);
void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
@@ -331,7 +330,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
@@ -344,7 +343,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len);
@@ -370,6 +369,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -468,262 +468,6 @@ std::vector<paddle::Tensor> NoauxTc(
int topk,
float routed_scaling_factor);
#ifdef ENABLE_FP8
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
const paddle::Tensor& x,
const paddle::Tensor& y,
const paddle::optional<paddle::Tensor>& bias,
bool trans_x,
bool trans_y,
float scale, // only support per-tensor quantization
std::string output_dtype,
std::string activation_type);
paddle::Tensor MoeFusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const paddle::Tensor &scale,
const paddle::Tensor &topk_ids,
const int top_k,
const int intermediate_size,
const bool tiled);
paddle::Tensor FusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const float scale);
#endif
int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(int64_t _fa);
int64_t meta_size();
void register_buffer(int64_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(int64_t _fa);
void register_graph_buffers(int64_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
std::tuple<int64_t, paddle::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(paddle::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
// speculative decoding Kernel
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
const paddle::Tensor& input_ids,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& token_num,
const paddle::Tensor& seq_len,
const paddle::Tensor& seq_lens_encoder);
std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder);
std::vector<paddle::Tensor> SpeculateGetOutputPaddingOffset(
const paddle::Tensor& output_cum_offsets_tmp,
const paddle::Tensor& out_token_num,
const paddle::Tensor& seq_lens_output,
const int max_seq_len);
void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_scores,
const paddle::Tensor &presence_scores,
const paddle::Tensor &temperatures,
const paddle::Tensor &bad_tokens,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &output_padding_offset,
const paddle::Tensor &output_cum_offsets,
const int max_seq_len);
void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &end_ids);
void SpeculateVerify(
const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num,
const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &verify_tokens, const paddle::Tensor &verify_scores,
const paddle::Tensor &max_dec_len, const paddle::Tensor &end_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &actual_draft_token_nums,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &is_block_step,
const paddle::Tensor &stop_nums);
void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &accept_num,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_idx);
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
bool save_each_rank);
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder);
void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int max_draft_tokens);
// MTP
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_stop_flags);
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_flags,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& batch_drop,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
const paddle::Tensor& base_model_stop_flags,
const paddle::Tensor& base_model_is_block_step,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const bool truncate_first_token,
const bool splitwise_prefill);
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const paddle::Tensor& draft_tokens,
const paddle::Tensor& pre_ids,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& output_cum_offsets,
const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& max_dec_len,
const paddle::Tensor& end_ids,
const paddle::Tensor& base_model_draft_tokens,
const int max_seq_len,
const int substep);
std::vector<paddle::Tensor> EagleGetHiddenStates(
const paddle::Tensor& input,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& accept_nums,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const int actual_draft_token_num);
void MTPStepPaddle(
const paddle::Tensor &base_model_stop_flags,
const paddle::Tensor &stop_flags,
const paddle::Tensor &batch_drop,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const int block_size,
const int max_draft_tokens);
void SpeculateStepPaddle(
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &ori_seq_lens_encoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
const paddle::Tensor &encoder_block_lens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &step_block_list,
const paddle::Tensor &step_lens,
const paddle::Tensor &recover_block_list,
const paddle::Tensor &recover_lens,
const paddle::Tensor &need_block_list,
const paddle::Tensor &need_block_len,
const paddle::Tensor &used_list_len,
const paddle::Tensor &free_list,
const paddle::Tensor &free_list_len,
const paddle::Tensor &input_ids,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &next_tokens,
const paddle::Tensor &first_token_ids,
const paddle::Tensor &accept_num,
const int block_size,
const int encoder_decoder_block_num,
const int max_draft_tokens);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -815,7 +559,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
* ep_moe_dispatch
*/
m.def("ep_moe_expert_dispatch", &EPMoeExpertDispatch, py::arg("input"),
py::arg("topk_ids"), py::arg("topk_weights"), py::arg("up_gate_proj_in_scale"),
py::arg("topk_ids"), py::arg("topk_weights"), py::arg("ffn1_in_scale"),
py::arg("token_nums_per_expert"), py::arg("token_nums_this_rank"),
py::arg("moe_quant_type"), "ep moe export dispatch function");
@@ -823,7 +567,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("ep_moe_expert_combine", &EPMoeExpertCombine, py::arg("ffn_out"),
py::arg("expert_scales_float"), py::arg("permute_indices_per_token"),
py::arg("top_k_indices"), py::arg("down_proj_bias"),
py::arg("top_k_indices"), py::arg("ffn2_bias"),
py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"),
"ep moe export combine function");
@@ -865,7 +609,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
*/
m.def("moe_expert_reduce", &MoeExpertReduceFunc, py::arg("ffn_out"),
py::arg("top_k_weight"), py::arg("permute_indices_per_token"),
py::arg("top_k_indices"), py::arg("down_proj_bias"),
py::arg("top_k_indices"), py::arg("ffn2_bias"),
py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"),
"moe export reduce function");
@@ -893,8 +637,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
* append_attn/get_block_shape_and_split_kv_block.cu
* get_block_shape_and_split_kv_block
*/
m.def("get_block_shape_and_split_kv_block",
&GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block function");
// m.def("f_get_block_shape_and_split_kv_block",
// &GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block
// function");
/**
* get_padding_offset.cu
@@ -955,17 +700,35 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);
m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
py::arg("a"), py::arg("c_or_none"), py::arg("b_q_weight"),
py::arg("b_scales"), py::arg("global_scale_or_none"), py::arg("b_zeros_or_none"),
py::arg("g_idx_or_none"), py::arg("perm_or_none"), py::arg("workspace"), py::arg("sorted_token_ids"),
py::arg("expert_ids"), py::arg("num_tokens_post_padded"), py::arg("topk_weights"), py::arg("moe_block_size"),
py::arg("top_k"), py::arg("mul_topk_weights"), py::arg("is_ep"), py::arg("b_q_type_str"),
py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"), py::arg("use_atomic_add"),
py::arg("use_fp32_reduce"), py::arg("is_zp_float"));
py::arg("a"),
py::arg("c_or_none"),
py::arg("b_q_weight"),
py::arg("b_scales"),
py::arg("global_scale_or_none"),
py::arg("b_zeros_or_none"),
py::arg("g_idx_or_none"),
py::arg("perm_or_none"),
py::arg("workspace"),
py::arg("sorted_token_ids"),
py::arg("expert_ids"),
py::arg("num_tokens_post_padded"),
py::arg("topk_weights"),
py::arg("moe_block_size"),
py::arg("top_k"),
py::arg("mul_topk_weights"),
py::arg("is_ep"),
py::arg("b_q_type_str"),
py::arg("size_m"),
py::arg("size_n"),
py::arg("size_k"),
py::arg("is_k_full"),
py::arg("use_atomic_add"),
py::arg("use_fp32_reduce"),
py::arg("is_zp_float"));
m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
"get_position_ids_and_mask_encoder_batch function");
/**
* cutlass_scaled_mm.cu
* cutlass_scaled_mm
@@ -999,71 +762,4 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");
m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
#ifdef ENABLE_FP8
m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function");
m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func,
py::arg("input"), py::arg("scale"), py::arg("topk_ids"),
py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function");
m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
#endif
m.def("init_custom_all_reduce", &init_custom_all_reduce, "init all reduce class function");
m.def("all_reduce", &all_reduce, "all reduce function");
m.def("dispose", &dispose, "del function for python");
m.def("meta_size", &meta_size, "meta_size function for Signal struct");
m.def("register_buffer", &register_buffer, "register ipc buffer");
m.def("register_graph_buffers", &register_graph_buffers, "register_graph_buffers");
m.def("allocate_shared_buffer_and_handle", &allocate_shared_buffer_and_handle, "allocate_shared_buffer_and_handle");
m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");
m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");
// speculative decoding Kernel
m.def("speculate_get_padding_offset", &SpeculateGetPaddingOffset, "speculate_get_padding_offset function");
m.def("speculate_get_seq_lens_output", &SpeculateGetSeqLensOutput, "speculate_get_seq_lens_output function");
m.def("speculate_get_output_padding_offset",&SpeculateGetOutputPaddingOffset, "speculate_get_output_padding_offset function");
m.def("speculate_get_token_penalty_multi_scores",&SpecTokenPenaltyMultiScores, "speculate_get_token_penalty_multi_scores function");
m.def("speculate_set_stop_value_multi_seqs",&SpecGetStopFlagsMultiSeqs, "speculate_set_stop_value_multi_seqs function");
m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");
m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
m.def("speculate_save_output", &SpeculateSaveWithOutputMsgStatic, "speculate_save_output function");
m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
m.def("ngram_match", &NgramMatch, "ngram_match function");
m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");
m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");
m.def("draft_model_update",&DraftModelUpdate, "draft_model_update function");
m.def("eagle_get_hidden_states",&EagleGetHiddenStates, "eagle_get_hidden_states function");
m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");
m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");
}

View File

@@ -1,165 +0,0 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "all_reduce.cuh"
// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
paddle::Tensor& rank_data, int64_t rank,
bool full_nvlink) {
int world_size = fake_ipc_ptrs.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
paddle::Signal* ipc_ptrs[8];
for (int i = 0; i < world_size; i++) {
ipc_ptrs[i] = reinterpret_cast<paddle::Signal*>(fake_ipc_ptrs[i]);
}
return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs, rank_data.data(),
rank_data.numel(), rank, world_size,
full_nvlink);
}
/**
* Performs an out-of-place allreduce and stores result in out.
*
* If _reg_buffer is null, assumes inp.data() is already IPC-registered.
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
auto stream = inp.stream();
auto input_size = inp.numel() * 2;
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
if (reg_buffer) {
cudaMemcpyAsync(reg_buffer, inp.data(), input_size,
cudaMemcpyDeviceToDevice, stream);
} else {
reg_buffer = inp.data();
}
switch (out.dtype()) {
case phi::DataType::FLOAT32: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
reinterpret_cast<float*>(out.data()),
out.numel());
break;
}
case phi::DataType::FLOAT16: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
reinterpret_cast<half*>(out.data()), out.numel());
break;
}
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
case phi::DataType::BFLOAT16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
reinterpret_cast<nv_bfloat16*>(out.data()), out.numel());
break;
}
#endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
}
}
void dispose(fptr_t _fa) {
delete reinterpret_cast<paddle::CustomAllreduce*>(_fa);
}
int64_t meta_size() { return sizeof(paddle::Signal); }
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
void* ipc_ptrs[8];
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
}
fa->register_buffer(ipc_ptrs);
}
// Use vector<int64_t> to represent byte data for python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
std::vector<int64_t> bytes(handle.begin(), handle.end());
return std::make_tuple(bytes, offsets);
}
// Use vector<int64_t> to represent byte data for python binding compatibility.
void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
std::vector<std::string> bytes;
bytes.reserve(handles.size());
for (int i = 0; i < handles.size(); i++) {
bytes.emplace_back(handles[i].begin(), handles[i].end());
}
bytes.reserve(handles.size());
fa->register_graph_buffers(bytes, offsets);
}
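// The int64 vectors above are only a transport format: each byte of a
// cudaIpcMemHandle_t travels to and from Python as one int64_t and is narrowed
// back to a char here. A minimal sketch of the two conversions (helper names are
// illustrative; <string> and <vector> already come in through the headers above):
inline std::vector<int64_t> handle_bytes_to_int64(const std::string& s) {
  return std::vector<int64_t>(s.begin(), s.end());  // widen: one int64 per byte
}
inline std::string int64_to_handle_bytes(const std::vector<int64_t>& v) {
  return std::string(v.begin(), v.end());  // narrow back to raw handle bytes
}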
std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
int64_t size) {
auto device_index = phi::backends::gpu::GetCurrentDeviceId();
void* buffer;
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
auto stream = paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream();
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Allocate buffer
CUDACHECK(cudaMalloc((void**)&buffer, size));
CUDACHECK(cudaMemsetAsync(buffer, 0, size, stream));
CUDACHECK(cudaStreamSynchronize(stream));
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Create IPC memhandle for the allocated buffer.
// Will use it in open_mem_handle.
auto handle =
paddle::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, paddle::DataType::UINT8, paddle::GPUPlace(device_index));
CUDACHECK(
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer));
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
}
fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
void* ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data()),
cudaIpcMemLazyEnablePeerAccess));
return reinterpret_cast<fptr_t>(ipc_ptr);
}
void free_shared_buffer(fptr_t buffer) {
CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}

View File

@@ -1,526 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <array>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
namespace paddle {
constexpr int kMaxBlocks = 36;
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
struct Signal {
alignas(128) FlagType self_counter[kMaxBlocks][8];
// Two sets of peer counters are needed for two syncs. The reason is that
// it's possible for a peer GPU block to arrive at the second sync point while
// the current GPU block hasn't passed the first sync point. Thus, the peer GPU
// may write counter+1 while the current GPU is still busy waiting for counter.
// We use alternating counter arrays to avoid this possibility.
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};
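// A minimal host-side sketch of why two alternating counter sets are needed,
// assuming just two participants and <atomic>/<thread> from the standard library
// (names below are illustrative; std::thread stands in for a GPU block and
// g_inbox for peer_counter). Each participant publishes its round counter into
// the peer's slot for the current parity and spins on its own slot; rounds k and
// k+1 use different parities, so a fast peer can never overwrite the flag a slow
// peer is still spinning on.
static std::atomic<FlagType> g_inbox[2][2];  // [round parity][owner]
static void barrier_round_sketch(int me, int peer, int rounds) {
  FlagType counter = 0;
  for (int r = 0; r < rounds; ++r) {
    const int parity = static_cast<int>(++counter % 2);
    g_inbox[parity][peer].store(counter, std::memory_order_release);  // signal peer
    while (g_inbox[parity][me].load(std::memory_order_acquire) != counter) {
    }  // wait for the peer's signal of the same round
  }
}
// Usage sketch: std::thread a(barrier_round_sketch, 0, 1, 1000);
//               std::thread b(barrier_round_sketch, 1, 0, 1000); a.join(); b.join();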
struct __align__(16) RankData {
const void* __restrict__ ptrs[8];
};
struct __align__(16) RankSignals {
Signal* signals[8];
};
// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
T data[sz];
using type = T;
static constexpr int size = sz;
};
// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
// the (P)acked type for load/store
using P = array_t<T, 16 / sizeof(T)>;
// the (A)ccumulator type for reduction
using A = array_t<float, 16 / sizeof(T)>;
};
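// A quick sanity check of the packing above, using the definitions in this file:
// with T = half the packed load/store type is array_t<half, 8>, i.e. exactly
// 16 bytes with 16-byte alignment, which is what allows the compiler to emit the
// ld.128 / st.128 accesses mentioned above.
static_assert(sizeof(packed_t<half>::P) == 16, "one packed access spans 128 bits");
static_assert(alignof(packed_t<half>::P) == 16, "packed type is 16-byte aligned");
static_assert(packed_t<half>::P::size == 8, "eight half values per 128-bit access");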
#define DINLINE __device__ __forceinline__
// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }
template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
return __float2half(val);
}
// scalar add functions
// for some reason when compiling with Paddle, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b);
return a;
}
DINLINE float& assign_add(float& a, float b) { return a += b; }
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b);
return a;
}
#endif
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]);
}
return a;
}
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
if constexpr (std::is_same<T, float>::value) {
return val;
} else {
array_t<float, N> out;
#pragma unroll
for (int i = 0; i < N; i++) {
out.data[i] = upcast_s(val.data[i]);
}
return out;
}
}
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
if constexpr (std::is_same<typename O::type, float>::value) {
return val;
} else {
O out;
#pragma unroll
for (int i = 0; i < O::size; i++) {
out.data[i] = downcast_s<typename O::type>(val.data[i]);
}
return out;
}
}
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
"l"(flag_addr));
#else
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
"l"(flag_addr));
#endif
}
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
FlagType flag;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
: "=r"(flag)
: "l"(flag_addr));
#else
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
: "=r"(flag)
: "l"(flag_addr));
#endif
return flag;
}
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
FlagType flag;
asm volatile("ld.volatile.global.u32 %0, [%1];"
: "=r"(flag)
: "l"(flag_addr));
return flag;
}
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, a release-acquire
// semantic is used to enforce memory access order before and after this
// barrier.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
int rank) {
if constexpr (!is_start) __syncthreads();
static_assert(
!(is_start && need_fence)); // Start barrier shouldn't need fence.
if (threadIdx.x < ngpus) {
// Increment the counter. Technically we only need one counter, but we use
// multiple per block to eliminate the need to share the counter via smem.
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
// Write the expected counter value to peer and wait for correct value from
// peer.
auto peer_counter_ptr =
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
auto self_counter_ptr =
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
if constexpr (need_fence) {
st_flag_release(peer_counter_ptr, val);
while (ld_flag_acquire(self_counter_ptr) != val);
} else {
st_flag_volatile(peer_counter_ptr, val);
while (ld_flag_volatile(self_counter_ptr) != val);
}
}
if constexpr (is_start || need_fence) __syncthreads();
}
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]);
#pragma unroll
for (int i = 1; i < ngpus; i++) {
packed_assign_add(tmp, upcast(ptrs[i][idx]));
}
return downcast<P>(tmp);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}
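// The fixed accumulation order matters because floating-point addition is not
// associative: summing the same values in a different order can change the result.
// A small compile-time illustration of that effect:
static_assert((1e8f + -1e8f) + 1.0f == 1.0f,
              "cancelling the large terms first keeps the +1");
static_assert(1e8f + (-1e8f + 1.0f) == 0.0f,
              "the +1 is lost to rounding when added to -1e8f first");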
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
return (P*)(((Signal*)sg) + 1);
}
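// Equivalently, the scratch area handed out above starts right after the Signal
// header of each rank's IPC buffer, i.e. at byte offset sizeof(Signal); the overall
// layout is described in the CustomAllreduce comment further below. The same
// address written as byte arithmetic (helper name is illustrative):
DINLINE char* tmp_buf_bytes(Signal* sg) {
  return reinterpret_cast<char*>(sg) + sizeof(Signal);
}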
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P* ptrs[ngpus];
P* tmps[ngpus];
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P*)result)[dst_idx] = tmps[i][idx];
}
}
}
}
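// A worked example of the split above for size = 1002 packed elements on 8 GPUs:
// every rank reduce-scatters a contiguous slice of 1002 / 8 = 125 elements, except
// the last rank, which also absorbs the remainder; the allgather stage then loops
// up to largest_part so the biggest slice is copied from every rank.
static_assert(1002 / 8 == 125, "slice owned by ranks 0..6");
static_assert(1002 - 7 * (1002 / 8) == 127, "last rank owns its slice plus the remainder");
static_assert(1002 / 8 + 1002 % 8 == 127, "largest_part bound used by the allgather loop");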
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
class CustomAllreduce {
public:
int rank_;
int world_size_;
bool full_nvlink_;
RankSignals sg_;
// Stores a map from a pointer to its peer pointers from all ranks.
std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_;
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// For cuda graph to work, all kernel arguments must be fixed during graph
// capture time. However, the peer pointers are not known during graph capture
// time. Therefore, during capture, we increment the rank data pointer and use
// that as the argument to the kernel. The kernel arguments are stored in
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
// memory pointed to by the pointers in graph_unreg_buffers_ when
// the IPC handles are exchanged between ranks.
//
// The overall process looks like this:
// 1. Graph capture.
// 2. Each rank obtains the IPC handles for each address used during cuda
// graph capture using get_graph_buffer_ipc_meta.
// 3. (In Python) all gather the IPC handles.
// 4. Obtain the peer pointers by opening the IPC handles, and store them in
// the rank data array at corresponding positions.
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char*> ipc_handles_;
/**
* Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows:
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* The first section is for allreduce synchronization, and the second section
* is for storing the intermediate results required by some allreduce algos.
*
* Note: this class does not own any device memory. Any required buffers
* are passed in from the constructor.
*/
CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
int rank, int world_size, bool full_nvlink = true)
: rank_(rank),
world_size_(world_size),
full_nvlink_(full_nvlink),
self_sg_(signals[rank]),
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
sg_.signals[i] = signals[i];
}
}
char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] =
ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
*((const cudaIpcMemHandle_t*)ipc_handle),
cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}
std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
auto handle_sz = sizeof(cudaIpcMemHandle_t);
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i];
void* base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
if (cuPointerGetAttribute(&base_ptr,
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
(CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
}
return std::make_pair(handles, offsets);
}
void check_rank_data_capacity(size_t num = 1) {
if (d_rank_data_base_ + num > d_rank_data_end_)
throw std::runtime_error(
"Rank data buffer is overflowed by " +
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
/**
* Register already-shared IPC pointers.
*/
void register_buffer(void** ptrs) {
check_rank_data_capacity();
RankData data;
for (int i = 0; i < world_size_; i++) {
data.ptrs[i] = ptrs[i];
}
auto d_data = d_rank_data_base_++;
CUDACHECK(
cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
buffers_[ptrs[rank_]] = d_data;
}
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have an internal reference counting
// mechanism so overhead should be small.
void register_graph_buffers(
const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i];
auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) {
if (j != rank_) {
char* handle =
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
rd.ptrs[j] = self_ptr;
}
}
}
CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
sizeof(RankData) * num_buffers,
cudaMemcpyHostToDevice));
d_rank_data_base_ += num_buffers;
graph_unreg_buffers_.clear();
}
/**
* Performs allreduce, assuming input has already been registered.
*
* Block and grid default configs are results after careful grid search. Using
* 36 blocks gives the best or close to the best runtime on the devices I
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
* take a small number of SMs. Not quite sure of the underlying reason, but my
* guess is that too many SMs will cause contention on NVLink bus.
*/
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = 36) {
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error("max supported block limit is " +
std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit));
RankData* ptrs;
cudaStreamCaptureStatus status;
CUDACHECK(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) {
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
graph_unreg_buffers_.push_back(input);
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
throw std::runtime_error(
"buffer address " +
std::to_string(reinterpret_cast<uint64_t>(input)) +
" is not registered!");
ptrs = it->second;
}
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full_nvlink_) { \
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
(world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} \
break; \
}
switch (world_size_) {
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
#undef REDUCE_CASE
#undef KL
}
~CustomAllreduce() {
for (auto [_, ptr] : ipc_handles_) {
CUDACHECK(cudaIpcCloseMemHandle(ptr));
}
}
};
} // namespace paddle

View File

@@ -76,34 +76,6 @@ enum class SplitKStyle
// SPLIT_K_PARALLEL // Not supported yet
};
// New enum for SM100 (Blackwell) Tile Configs
// Placeholder values - actual optimal values need research
enum class CutlassTileConfigSM100
{
// Signals that we should run heuristics to choose a config
Undefined,
// Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
// Actual SM100 tile configs based on user input (K-tile is 128B)
CtaShape64x64x128B,
CtaShape64x128x128B,
CtaShape64x256x128B,
CtaShape128x64x128B,
CtaShape128x128x128B,
CtaShape128x256x128B,
CtaShape256x64x128B,
CtaShape256x128x128B,
CtaShape256x256x128B
// Note: The user-provided list for get_candidate_tiles_sm100 also includes
// CtaShape128x64x128B and CtaShape256x64x128B for specific FP4 grouped gemm cases.
// These are already covered by the list above if the general list suffices.
// If they need distinct enum values, they should be added.
// For now, keeping the enum concise with unique shapes mentioned for general use.
};
enum class CutlassTileConfigSM90
{
// Signals that we should run heuristics to choose a config
@@ -160,11 +132,9 @@ struct CutlassGemmConfig
WEIGHT_ONLY = 1u << 0,
SIMT_ONLY = 1u << 1,
INT8_ONLY = 1u << 2,
HOPPER = 1u << 3, // SM90
HOPPER = 1u << 3,
GROUPED_GEMM = 1u << 4,
FP8_ONLY = 1u << 5,
BLACKWELL = 1u << 6, // SM100
FP4_ONLY = 1u << 7, // For Blackwell FP4/MXFP4 paths
};
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
@@ -179,17 +149,7 @@ struct CutlassGemmConfig
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
bool is_sm90 = false;
// config options for sm100 (Blackwell)
// Assuming SM100 might use similar schedule/cluster types as SM90 for now.
// These might need to become SM100-specific if Blackwell introduces new concepts.
CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic;
// MainloopScheduleType mainloop_schedule_sm100 = MainloopScheduleType::AUTO; // Example if SM100 has different types
// EpilogueScheduleType epilogue_schedule_sm100 = EpilogueScheduleType::AUTO; // Example
// ClusterShape cluster_shape_sm100 = ClusterShape::ClusterShape_1x1x1; // Example
bool is_sm100 = false;
CutlassGemmConfig() : is_sm90(false), is_sm100(false) {}
CutlassGemmConfig() {}
CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
: tile_config(tile_config)
@@ -197,64 +157,37 @@ struct CutlassGemmConfig
, split_k_factor(split_k_factor)
, stages(stages)
, is_sm90(false)
, is_sm100(false)
{
}
// Constructor for SM90
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90_in, MainloopScheduleType mainloop_schedule_in,
EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
: tile_config_sm90(tile_config_sm90_in)
, mainloop_schedule(mainloop_schedule_in)
, epilogue_schedule(epilogue_schedule_in)
, cluster_shape(cluster_shape_in)
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
: tile_config_sm90(tile_config_sm90)
, mainloop_schedule(mainloop_schedule)
, epilogue_schedule(epilogue_schedule)
, cluster_shape(cluster_shape)
, is_sm90(true)
, is_sm100(false)
{
}
// Constructor for SM100 (Blackwell)
// Using existing MainloopScheduleType, EpilogueScheduleType, ClusterShape for now.
// These might need to be new SM100-specific types if Blackwell's TMA differs significantly.
CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100_in, MainloopScheduleType mainloop_schedule_in,
EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in)
: tile_config_sm100(tile_config_sm100_in)
, mainloop_schedule(mainloop_schedule_in) // Potentially use mainloop_schedule_sm100 if types diverge
, epilogue_schedule(epilogue_schedule_in) // Potentially use epilogue_schedule_sm100
, cluster_shape(cluster_shape_in) // Potentially use cluster_shape_sm100
, is_sm90(false) // Explicitly false
, is_sm100(true)
{
}
std::string toString() const
{
std::stringstream tactic;
tactic << "Cutlass GEMM Tactic";
if (is_sm100 && tile_config_sm100 != cutlass_extensions::CutlassTileConfigSM100::ChooseWithHeuristic)
if (tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
{
assert(is_sm100 && !is_sm90 && "Invalid cutlass GEMM config: SM100");
tactic << "\n\tstyle=TMA_SM100" // Indicate SM100 specific TMA if applicable
<< "\n\ttile shape ID: " << (int) tile_config_sm100
assert(is_sm90 && "Invalid cutlass GEMM config");
tactic << "\n\tstyle=TMA"
<< "\n\ttile shape ID: " << (int) tile_config_sm90
<< "\n\tcluster shape ID: " << (int) cluster_shape
<< "\n\tmainloop sched: " << (int) mainloop_schedule
<< "\n\tepi sched: " << (int) epilogue_schedule;
}
else if (is_sm90 && tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
{
assert(is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: SM90");
tactic << "\n\tstyle=TMA_SM90"
<< "\n\ttile shape ID: " << (int) tile_config_sm90
<< "\n\tcluster shape ID: " << (int) cluster_shape
<< "\n\tmainloop sched: " << (int) mainloop_schedule
<< "\n\tmainloop sched: " << (int) mainloop_schedule
<< "\n\tepi sched: " << (int) epilogue_schedule;
}
else if (tile_config != cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
{
assert(!is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: Compatible");
assert(!is_sm90 && "Invalid cutlass GEMM config");
tactic << "\n\tstyle=compatible"
<< "\n\ttile shape ID: " << (int) tile_config
<< "\n\ttile shape ID: " << (int) tile_config
<< "\n\tstages: " << (int) stages
<< "\n\tsplit_k_style: " << (int) split_k_style
<< "\n\tsplit k: " << (int) split_k_factor;
@@ -271,24 +204,9 @@ struct CutlassGemmConfig
std::istringstream stream(str);
std::string line;
is_sm90 = false; // Reset flags
is_sm100 = false;
while (std::getline(stream, line)) {
if (line.find("style=TMA_SM100") != std::string::npos) {
is_sm100 = true;
is_sm90 = false;
std::getline(stream, line);
tile_config_sm100 = static_cast<cutlass_extensions::CutlassTileConfigSM100>(std::stoi(line.substr(line.find(':') + 1)));
std::getline(stream, line);
cluster_shape = static_cast<cutlass_extensions::ClusterShape>(std::stoi(line.substr(line.find(':') + 1)));
std::getline(stream, line);
mainloop_schedule = static_cast<cutlass_extensions::MainloopScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
std::getline(stream, line);
epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
} else if (line.find("style=TMA_SM90") != std::string::npos) { // Check for SM90 specific first
if (line.find("style=TMA") != std::string::npos) {
is_sm90 = true;
is_sm100 = false;
std::getline(stream, line);
tile_config_sm90 = static_cast<cutlass_extensions::CutlassTileConfigSM90>(std::stoi(line.substr(line.find(':') + 1)));
std::getline(stream, line);
@@ -299,7 +217,6 @@ struct CutlassGemmConfig
epilogue_schedule = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
} else if (line.find("style=compatible") != std::string::npos) {
is_sm90 = false;
is_sm100 = false;
std::getline(stream, line);
tile_config = static_cast<cutlass_extensions::CutlassTileConfig>(std::stoi(line.substr(line.find(':') + 1)));
std::getline(stream, line);
@@ -316,14 +233,7 @@ struct CutlassGemmConfig
inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
{
// clang-format off
if (config.is_sm100)
{
out << "tile_config_sm100_enum: " << int(config.tile_config_sm100)
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule) // Assuming same schedule types for now
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule) // Assuming same schedule types for now
<< ", cluster_shape_enum: " << int(config.cluster_shape); // Assuming same cluster types for now
}
else if (config.is_sm90)
if (config.is_sm90)
{
out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)

View File

@@ -245,88 +245,6 @@ bool supports_mcast_along_n(CutlassTileConfigSM90 const tile)
#endif
}
// SM100 (Blackwell) candidate tile configurations
std::vector<CutlassTileConfigSM100> get_candidate_tiles_sm100(
int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config)
{
#ifdef FAST_BUILD
return {CutlassTileConfigSM100::CtaShape128x128x128B};
#else
/* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM “cluster” kernels) */
if (config & CutlassGemmConfig::GROUPED_GEMM)
{
if (config & CutlassGemmConfig::FP4_ONLY) // nvfp4 / mx_fp4
{
return {
/* 1 SM (M=128) */
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
/* 2 SM (M=256) */
CutlassTileConfigSM100::CtaShape256x128x128B,
CutlassTileConfigSM100::CtaShape256x256x128B,
/* slim tiles for very tall matrices */
CutlassTileConfigSM100::CtaShape128x64x128B,
CutlassTileConfigSM100::CtaShape256x64x128B};
}
/* Fp8 / Fp16 grouped-GEMM */
return {
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
CutlassTileConfigSM100::CtaShape256x128x128B,
CutlassTileConfigSM100::CtaShape256x256x128B};
}
/* Non-grouped path (plain GEMM or weight-only) */
return {
/* 1 SM tiles */
CutlassTileConfigSM100::CtaShape64x64x128B,
CutlassTileConfigSM100::CtaShape64x128x128B,
CutlassTileConfigSM100::CtaShape64x256x128B,
CutlassTileConfigSM100::CtaShape128x64x128B,
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
/* 2 SM tiles */
CutlassTileConfigSM100::CtaShape256x64x128B,
CutlassTileConfigSM100::CtaShape256x128x128B,
CutlassTileConfigSM100::CtaShape256x256x128B};
#endif
}
// M-multicast support for SM100.
bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile)
{
#ifdef FAST_BUILD
return false;
#else
std::set<CutlassTileConfigSM100> m_tiles{
CutlassTileConfigSM100::CtaShape128x64x128B,
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
CutlassTileConfigSM100::CtaShape256x64x128B,
CutlassTileConfigSM100::CtaShape256x128x128B,
CutlassTileConfigSM100::CtaShape256x256x128B};
return m_tiles.count(tile) == 1;
#endif
}
// N-multicast support for SM100.
bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile)
{
#ifdef FAST_BUILD
return false;
#else
std::set<CutlassTileConfigSM100> n_tiles{
CutlassTileConfigSM100::CtaShape64x128x128B,
CutlassTileConfigSM100::CtaShape64x256x128B,
CutlassTileConfigSM100::CtaShape128x128x128B,
CutlassTileConfigSM100::CtaShape128x256x128B,
CutlassTileConfigSM100::CtaShape256x128x128B};
return n_tiles.count(tile) == 1;
#endif
}
std::vector<CutlassGemmConfig> get_candidate_configs(
int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param)
{
@@ -366,50 +284,9 @@ std::vector<CutlassGemmConfig> get_candidate_configs(
}
return candidate_configs;
}
else if (sm == 100 && (config_type_param & CutlassGemmConfig::BLACKWELL)) // Assuming SM100 for Blackwell
{
std::vector<CutlassTileConfigSM100> tiles = get_candidate_tiles_sm100(sm, config_type_param);
std::vector<CutlassGemmConfig> candidate_configs;
for (auto const& tile_config_sm100 : tiles)
{
// SM100 uses MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO similar to SM90.
// Cluster shapes are also handled similarly.
CutlassGemmConfig config(
tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(config);
bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100);
bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100);
if (has_m_mcast)
{
CutlassGemmConfig mcast_m_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_2x1x1);
candidate_configs.push_back(mcast_m_config);
}
if (has_n_mcast)
{
CutlassGemmConfig mcast_n_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_1x2x1);
candidate_configs.push_back(mcast_n_config);
}
if (has_m_mcast && has_n_mcast)
{
CutlassGemmConfig mcast_mn_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_2x2x1);
candidate_configs.push_back(mcast_mn_config);
}
}
return candidate_configs;
}
// Fallback to older architecture configurations
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, config_type_param);
std::vector<CutlassGemmConfig> candidate_configs; //Already declared above for SM90 path, ensure scope is correct or redeclare if necessary.
// It's fine here as it's within an else if / else block.
std::vector<CutlassGemmConfig> candidate_configs;
bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY;
int const min_stages = int8_configs_only ? 3 : 2;
int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);

View File

@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
up_gate_proj_n=7168
up_gate_proj_k=8192
ffn1_n=7168
ffn1_k=8192
down_proj_n=8192
down_proj_k=3584
rm -rf up_gate_proj_7168_8192.log
rm -rf down_proj_8192_3584.log
ffn2_n=8192
ffn2_k=3584
rm -rf ffn1_7168_8192.log
rm -rf ffn2_8192_3584.log
num_experts=8
for tokens_per_expert in 12
do
wait
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 &
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 &
CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 1 0 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 &
# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 1 0 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 &
done
wait
echo "#### finish ####"

View File

@@ -19,7 +19,7 @@
#include "fp8_fp8_half_cuda_core_gemm.h"
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
const paddle::Tensor& x,
const paddle::Tensor& y,
const paddle::optional<paddle::Tensor>& bias,
@@ -142,7 +142,7 @@ paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
{
if(output_dtype == "bfloat16") {
cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(params);
} else {
cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
}
@@ -174,21 +174,7 @@ paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
fuse_gemm_config};
fp8_fp8_gemm_scale_bias_act(params);
}
return out;
}
std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
const paddle::Tensor& x,
const paddle::Tensor& y,
const paddle::optional<paddle::Tensor>& bias,
bool trans_x,
bool trans_y,
float scale, // only support per-tensor quantization
std::string output_dtype,
std::string activation_type) {
return {cutlass_fp8_fp8_half_gemm_func(
x, y, bias, trans_x, trans_y, scale,
output_dtype, activation_type)};
return {out};
}
std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfGemmFusedInferShape(

View File

@@ -1,198 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include "helper.h"
__device__ __forceinline__ void hadamard32_warp(__nv_bfloat16& x) {
int lane_id = threadIdx.x % 32;
#pragma unroll
for (int step = 0; step < 5; ++step) {
const int lane_mask = 1 << step;
const __nv_bfloat16 sign = (lane_id & lane_mask) ? -1.f : 1.f;
__nv_bfloat16 x_val_other = __shfl_xor_sync(0xffffffff, x, lane_mask);
x = sign * x + x_val_other;
}
}
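// hadamard32_warp above is a warp-level fast Walsh-Hadamard transform: five
// butterfly steps over lanes whose ids differ in exactly one bit, where the lane
// with the bit clear keeps the sum and the lane with the bit set keeps the
// difference. A plain host-side reference of the same butterfly, written against a
// dense array of 32 values instead of one value per lane (name is illustrative):
inline void hadamard32_reference(float x[32]) {
  for (int step = 0; step < 5; ++step) {
    const int mask = 1 << step;
    for (int i = 0; i < 32; ++i) {
      if ((i & mask) == 0) {
        const float a = x[i];         // "lane" with the bit clear
        const float b = x[i | mask];  // partner "lane" with the bit set
        x[i] = a + b;
        x[i | mask] = a - b;
      }
    }
  }
}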
__global__ void MoeFusedHadamardQuantFp8Kernel(
const __nv_bfloat16* __restrict__ input,
const float* __restrict__ scale,
const int64_t* __restrict__ topk_ids,
__nv_fp8_e4m3* out,
const int top_k,
const int intermediate_size,
const int64_t numel
) {
int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (out_idx >= numel) return;
int64_t token_idx = out_idx / (top_k * intermediate_size);
int64_t topk_idx = (out_idx / intermediate_size) % top_k;
int64_t inter_idx = out_idx % intermediate_size;
int64_t input_idx = token_idx * intermediate_size + inter_idx;
if (input_idx >= numel / top_k) return;
int64_t expert_id = topk_ids[token_idx * top_k + topk_idx];
float scale_value = scale[expert_id];
__nv_bfloat16 x = input[input_idx];
hadamard32_warp(x);
float x_fp32 = __bfloat162float(x);
float quantized = x_fp32 / scale_value;
out[out_idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
__global__ void MoeFusedHadamardQuantFp8TiledKernel(
const __nv_bfloat16* __restrict__ input,
const float* __restrict__ scale,
const int64_t* __restrict__ topk_ids,
__nv_fp8_e4m3* out,
const int top_k,
const int intermediate_size,
const int64_t numel
) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= numel) return;
int64_t token_idx = idx / intermediate_size;
int64_t expert_id = topk_ids[token_idx];
float scale_value = scale[expert_id];
__nv_bfloat16 x = input[idx];
hadamard32_warp(x);
float x_fp32 = __bfloat162float(x);
float quantized = x_fp32 / scale_value;
out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
std::vector<paddle::Tensor> MoeFusedHadamardQuantFp8(
const paddle::Tensor &input,
const paddle::Tensor &scale,
const paddle::Tensor &topk_ids,
const int top_k,
const int intermediate_size,
const bool tiled) {
int64_t numel = input.numel();
if (!tiled) numel *= top_k;
paddle::Tensor out = GetEmptyTensor(
{numel / intermediate_size, intermediate_size},
paddle::DataType::FLOAT8_E4M3FN,
input.place());
constexpr int64_t thread_per_block = 256;
int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
auto stream = input.stream();
if (tiled) {
MoeFusedHadamardQuantFp8TiledKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
scale.data<float>(),
topk_ids.data<int64_t>(),
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
top_k,
intermediate_size,
numel
);
} else {
MoeFusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(input.data<phi::dtype::bfloat16>()),
scale.data<float>(),
topk_ids.data<int64_t>(),
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
top_k,
intermediate_size,
numel
);
}
return {out};
}
PD_BUILD_STATIC_OP(moe_fused_hadamard_quant_fp8)
.Inputs({"input", "scale", "topk_ids"})
.Outputs({"output"})
.Attrs({"top_k: int",
"intermediate_size: int",
"tiled: bool"})
.SetKernelFn(PD_KERNEL(MoeFusedHadamardQuantFp8));
paddle::Tensor MoeFusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const paddle::Tensor &scale,
const paddle::Tensor &topk_ids,
const int top_k,
const int intermediate_size,
const bool tiled) {
return MoeFusedHadamardQuantFp8(input, scale, topk_ids, top_k, intermediate_size, tiled)[0];
}
__global__ void FusedHadamardQuantFp8Kernel(
const __nv_bfloat16* __restrict__ input,
__nv_fp8_e4m3* out,
const float scale,
const int64_t numel) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= numel) return;
__nv_bfloat16 x = input[idx];
hadamard32_warp(x);
float x_fp32 = __bfloat162float(x);
float quantized = x_fp32 / scale;
out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
std::vector<paddle::Tensor> FusedHadamardQuantFp8(
const paddle::Tensor &input,
const float scale) {
int64_t numel = input.numel();
paddle::Tensor out = GetEmptyTensor(
input.dims(),
paddle::DataType::FLOAT8_E4M3FN,
input.place());
constexpr int64_t thread_per_block = 256;
int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
auto stream = input.stream();
FusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
scale,
numel
);
return {out};
}
PD_BUILD_STATIC_OP(fused_hadamard_quant_fp8)
.Inputs({"input"})
.Outputs({"output"})
.Attrs({"scale: float"})
.SetKernelFn(PD_KERNEL(FusedHadamardQuantFp8));
paddle::Tensor FusedHadamardQuantFp8Func(
const paddle::Tensor &input,
const float scale) {
return FusedHadamardQuantFp8(input, scale)[0];
}
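
For reference, hadamard32_warp above is a 32-point Walsh-Hadamard butterfly carried out with warp XOR shuffles, and the surrounding kernels then divide the transformed value by a per-expert (or global) scale before casting to float8_e4m3. A minimal host-side C++ sketch of the same recursion, for illustration only (the FP8 cast is left as a plain float here):

#include <array>
#include <cstdio>

// Sequential equivalent of hadamard32_warp: at step s, element i combines
// with element i ^ (1 << s) and negates its own value when bit s of i is set.
std::array<float, 32> hadamard32_reference(std::array<float, 32> x) {
    for (int step = 0; step < 5; ++step) {
        const int mask = 1 << step;
        std::array<float, 32> y;
        for (int i = 0; i < 32; ++i) {
            const float sign = (i & mask) ? -1.f : 1.f;
            y[i] = sign * x[i] + x[i ^ mask];
        }
        x = y;
    }
    return x;
}

int main() {
    std::array<float, 32> v{};
    v[0] = 1.f;  // an impulse transforms into an all-ones vector
    const float scale = 0.5f;
    const auto h = hadamard32_reference(v);
    for (int i = 0; i < 32; ++i) {
        std::printf("%.1f ", h[i] / scale);  // the kernels store this value as FP8
    }
    std::printf("\n");
    return 0;
}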

View File

@@ -24,18 +24,16 @@
#endif
#define MAX_BSZ 512
#define K 20
#define K 10
struct msgdata {
long mtype;
int mtext[MAX_BSZ * (K + 1) + 2]; // stop_flag, bsz, tokens
float mtext_f[MAX_BSZ * (K + 1)]; // score
int mtext_ranks[MAX_BSZ]; // ranks
};
void GetOutputTopK(const paddle::Tensor& x,
const paddle::Tensor& scores,
const paddle::Tensor& ranks,
int k,
int64_t rank_id,
bool wait_flag) {
@@ -68,18 +66,17 @@ void GetOutputTopK(const paddle::Tensor& x,
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
float* scores_data = const_cast<float*>(scores.data<float>());
int64_t* ranks_data = const_cast<int64_t*>(ranks.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid,
&msg_rcv,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
0,
IPC_NOWAIT);
} else {
ret = msgrcv(msgid,
&msg_rcv,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
(MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
0,
0);
}
@@ -100,14 +97,13 @@ void GetOutputTopK(const paddle::Tensor& x,
out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2];
scores_data[offset] = msg_rcv.mtext_f[offset];
}
ranks_data[i] = (int64_t)msg_rcv.mtext_ranks[i];
}
return;
}
PD_BUILD_STATIC_OP(get_output_topk)
.Inputs({"x", "scores", "ranks"})
.Inputs({"x", "scores"})
.Attrs({"k: int", "rank_id: int64_t", "wait_flag: bool"})
.Outputs({"x_out", "scores_out", "ranks_out"})
.SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}, {"ranks", "ranks_out"}})
.Outputs({"x_out", "scores_out"})
.SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}})
.SetKernelFn(PD_KERNEL(GetOutputTopK));
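
The msgrcv length arguments above are simply the byte size of the payload fields after mtype, assuming 4-byte int and float: one side of the diff uses K = 20 plus the mtext_ranks field, the other K = 10 without it. A small C++ check of that arithmetic:

#include <cstdio>

int main() {
    constexpr long MAX_BSZ = 512;

    // Variant with ranks (K = 20): mtext + mtext_f + mtext_ranks.
    constexpr long K_with_ranks = 20;
    constexpr long bytes_with_ranks =
        (MAX_BSZ * (K_with_ranks + 1) + 2) * 4   // int mtext[MAX_BSZ * (K + 1) + 2]
        + MAX_BSZ * (K_with_ranks + 1) * 4       // float mtext_f[MAX_BSZ * (K + 1)]
        + MAX_BSZ * 4;                           // int mtext_ranks[MAX_BSZ]

    // Variant without ranks (K = 10): mtext + mtext_f only.
    constexpr long K_without_ranks = 10;
    constexpr long bytes_without_ranks =
        (MAX_BSZ * (K_without_ranks + 1) + 2) * 4
        + MAX_BSZ * (K_without_ranks + 1) * 4;

    std::printf("payload with ranks:    %ld bytes\n", bytes_with_ranks);     // 88072
    std::printf("payload without ranks: %ld bytes\n", bytes_without_ranks);  // 45064
    return 0;
}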

View File

@@ -13,7 +13,6 @@
// limitations under the License.
#include "paddle/extension.h"
#include "helper.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -60,12 +59,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &token_num,
const paddle::Tensor &seq_len) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = input_ids.stream();
#endif
std::vector<int64_t> input_ids_shape = input_ids.shape();
const int bsz = seq_len.shape()[0];
const int seq_length = input_ids_shape[1];
@@ -81,11 +75,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
#ifdef PADDLE_WITH_COREX
int blockSize = std::min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#else
int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
#endif
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
GetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
padding_offset.data<int>(),
cum_offsets_out.data<int>(),

View File

@@ -14,9 +14,7 @@
#pragma once
#ifndef PADDLE_WITH_COREX
#include "glog/logging.h"
#endif
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
@@ -37,35 +35,22 @@ namespace cub = hipcub;
#else
#include <cub/cub.cuh>
#endif
#ifndef PADDLE_WITH_COREX
#include "nlohmann/json.hpp"
#endif
#include <fstream>
#include <iostream>
#include "env.h"
#include "paddle/extension.h"
#include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/custom/custom_context.h"
#else
#include "paddle/phi/core/cuda_stream.h"
#endif
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#ifdef PADDLE_WITH_COREX
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#ifndef PADDLE_WITH_COREX
using json = nlohmann::json;
#endif
#define CUDA_CHECK(call) \
do { \
@@ -214,19 +199,11 @@ HOSTDEVICE inline void Store(const AlignedVector<T, Size> &vec, T *addr) {
*addr_vec = vec;
}
#ifdef PADDLE_WITH_HIP
template <int Size>
HOSTDEVICE inline void Store(const AlignedVector<hip_bfloat16, Size> &vec,
int8_t *addr) {
printf("Error: Store hip_bfloat16 to int8_t is not supported!");
}
#else
template <int Size>
HOSTDEVICE inline void Store(const AlignedVector<__nv_bfloat16, Size> &vec,
int8_t *addr) {
printf("Error: Store __nv_bfloat16 to int8_t is not supported!");
}
#endif
template <int Size>
HOSTDEVICE inline void Store(const AlignedVector<half, Size> &vec,
@@ -260,7 +237,6 @@ inline int GetBlockSize(int vocab_size) {
}
}
#ifndef PADDLE_WITH_COREX
inline json readJsonFromFile(const std::string &filePath) {
std::ifstream file(filePath);
if (!file.is_open()) {
@@ -271,7 +247,6 @@ inline json readJsonFromFile(const std::string &filePath) {
file >> j;
return j;
}
#endif
#define cudaCheckError() \
{ \
@@ -443,7 +418,6 @@ inline std::string base64_decode(const std::string &encoded_string) {
return ret;
}
#ifndef PADDLE_WITH_COREX
template <typename T>
inline T get_relative_best(nlohmann::json *json_data,
const std::string &target_key,
@@ -456,7 +430,6 @@ inline T get_relative_best(nlohmann::json *json_data,
return default_value;
}
}
#endif
__device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids,
int length) {
@@ -486,12 +459,7 @@ template <typename T>
static void PrintMatrix3(const T *mat_d, int num, std::string name) {
std::vector<T> tmp(num);
#ifdef PADDLE_WITH_HIP
hipMemcpy(tmp.data(), mat_d, sizeof(T) * num, hipMemcpyDeviceToHost);
#else
cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost);
#endif
std::ofstream outfile;
outfile.open(name + ".txt", std::ios::out);
@@ -508,7 +476,6 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
outfile.close();
}
#ifndef PADDLE_WITH_HIP
__forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
int mode = 0) {
uint32_t flag;
@@ -548,7 +515,6 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
return max_shared_mem_per_block_opt_in;
}
#endif
inline int GetSMVersion() {
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(

View File

@@ -73,6 +73,7 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -180,6 +181,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -214,6 +216,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -50,6 +50,7 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -161,7 +161,7 @@ __global__ void combine_prmt_back_kernel(
expanded_permuted_rows + expanded_permuted_row * cols; // value at the permuted position
Load<T, VEC_SIZE>(expanded_permuted_rows_row_ptr + tid * VEC_SIZE, &load_vec);
const int expert_idx = expert_for_source_row[k_offset]; // expert assigned to the current position
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // down_proj bias for the current expert
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // ffn2 bias for the current expert
if (bias_ptr) {
Load<T, VEC_SIZE>(bias_ptr + tid * VEC_SIZE, &bias_vec);
#pragma unroll
@@ -188,7 +188,7 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out,
const paddle::Tensor& expert_scales_float,
const paddle::Tensor& permute_indices_per_token,
const paddle::Tensor& top_k_indices,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const bool norm_topk_prob,
const float routed_scaling_factor,
const int num_rows,
@@ -206,7 +206,7 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out,
combine_prmt_back_kernel<<<gridx, threads, 0, stream>>>(
ffn_out.data<data_t>(),
output->data<data_t>(),
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
expert_scales_float.data<float>(),
permute_indices_per_token.data<int32_t>(),
top_k_indices.data<int>(),
@@ -223,7 +223,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
const paddle::Tensor& expert_scales_float, // dst_weights
const paddle::Tensor& permute_indices_per_token, // permute_indices_per_token
const paddle::Tensor& top_k_indices, // dst_indices
const paddle::optional<paddle::Tensor>& down_proj_bias,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const bool norm_topk_prob,
const float routed_scaling_factor) {
@@ -242,7 +242,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
expert_scales_float,
permute_indices_per_token,
top_k_indices,
down_proj_bias,
ffn2_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
@@ -255,7 +255,7 @@ std::vector<paddle::Tensor> EPMoeExpertCombine(
expert_scales_float,
permute_indices_per_token,
top_k_indices,
down_proj_bias,
ffn2_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
@@ -274,7 +274,7 @@ __global__ void permute_x_kernel(const T *src_x,
const int64_t *topk_idx,
const float *topk_weights,
const int *token_nums_per_expert,
const float *up_gate_proj_in_scale,
const float *ffn1_in_scale,
const int moe_topk,
const int num_rows,
const int token_nums_this_rank,
@@ -327,9 +327,9 @@ __global__ void permute_x_kernel(const T *src_x,
// cp x
for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) {
Load<T, vec_size>(&src_x[s_token_idx * hidden_size + v_id * vec_size], &src_vec);
if (up_gate_proj_in_scale) {
if (ffn1_in_scale) {
for (int i = 0; i < vec_size; i++) {
float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast<float>(src_vec[i]);
float quant_value = max_bound * ffn1_in_scale[expert_now] * static_cast<float>(src_vec[i]);
if (RoundType == 0) {
res_vec[i] = static_cast<OutT>(ClipFunc<float>(rint(quant_value), min_bound, max_bound));
} else {
@@ -353,7 +353,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_weights,
const paddle::Tensor& token_nums_per_expert,
const paddle::optional<paddle::Tensor>& up_gate_proj_in_scale,
const paddle::optional<paddle::Tensor>& ffn1_in_scale,
const std::string& moe_quant_type,
const int moe_topk,
const int num_rows,
@@ -383,7 +383,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
@@ -404,7 +404,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
@@ -427,7 +427,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
@@ -448,7 +448,7 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
ffn1_in_scale ? ffn1_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
@@ -472,7 +472,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
const paddle::Tensor& input,
const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_weights,
const paddle::optional<paddle::Tensor>& up_gate_proj_in_scale,
const paddle::optional<paddle::Tensor>& ffn1_in_scale,
const std::vector<int>& token_nums_per_expert,
const int token_nums_this_rank,
const std::string& moe_quant_type) {
@@ -516,7 +516,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
topk_ids,
topk_weights,
num_experts_per_rank_tensor,
up_gate_proj_in_scale,
ffn1_in_scale,
moe_quant_type,
moe_topk,
num_rows,
@@ -536,7 +536,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
topk_ids,
topk_weights,
num_experts_per_rank_tensor,
up_gate_proj_in_scale,
ffn1_in_scale,
moe_quant_type,
moe_topk,
num_rows,
@@ -568,7 +568,7 @@ std::vector<std::vector<int64_t>> EPMoeExpertDispatchInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& topk_ids_shape,
const std::vector<int64_t>& topk_weights_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_in_scale_dtype,
const paddle::optional<std::vector<int64_t>>& ffn1_in_scale_dtype,
const std::vector<int>& token_nums_per_expert,
const int token_nums_this_rank) {
int token_rows = -1;
@@ -610,7 +610,7 @@ std::vector<paddle::DataType> EPMoeExpertDispatchInferDtype(
PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
.Inputs({"input", "topk_ids", "topk_weights",
paddle::Optional("up_gate_proj_in_scale")})
paddle::Optional("ffn1_in_scale")})
.Outputs({"permute_input",
"permute_indices_per_token",
"token_nums_per_expert_cumsum",
@@ -870,9 +870,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_weights,
const paddle::Tensor& num_experts_per_rank_tensor,
const paddle::Tensor& num_experts_per_rank_padded_tensor,
const bool use_in_ep,
const int token_nums_this_rank_padded) {
const paddle::Tensor& num_experts_per_rank_padded_tensor) {
const auto input_type = input.dtype();
const int moe_topk = topk_ids.dims()[1];
auto place = input.place();
@@ -888,21 +886,22 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
const int hidden_size = input.dims()[input_dims.size() - 1];
const int num_experts_per_rank = num_experts_per_rank_tensor.dims()[0];
int32_t token_nums_feed_to_ffn = use_in_ep ? token_nums_this_rank_padded : token_rows * moe_topk + num_experts_per_rank * (128-1);
int32_t token_nums_this_rank_padded = token_rows * moe_topk + num_experts_per_rank * (128-1);
// token_nums_this_rank_padded = token_nums_this_rank_padded_useless;
auto permute_input = GetEmptyTensor(
{token_nums_feed_to_ffn, hidden_size},
{token_nums_this_rank_padded, hidden_size},
input_type,
place);
auto permute_scale = GetEmptyTensor(
{token_nums_feed_to_ffn, hidden_size / 128},
{token_nums_this_rank_padded, hidden_size / 128},
paddle::DataType::FLOAT32,
place);
auto m_indices = paddle::full({token_nums_feed_to_ffn}, -1, paddle::DataType::INT32, place);
auto m_indices = paddle::full({token_nums_this_rank_padded}, -1, paddle::DataType::INT32, place);
auto token_nums_per_expert_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
auto token_nums_per_expert_padded_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
auto dst_weights = GetEmptyTensor({token_nums_feed_to_ffn}, paddle::DataType::FLOAT32, place);
auto dst_weights = GetEmptyTensor({token_nums_this_rank_padded}, paddle::DataType::FLOAT32, place);
auto dst_indices = GetEmptyTensor({num_rows, num_experts_per_rank}, paddle::DataType::INT32, place);
auto permute_indices_per_token = paddle::full({num_experts_per_rank, num_rows}, -1, paddle::DataType::INT32, place);
auto cumsum_idx_gpu = paddle::full({num_experts_per_rank}, 0, paddle::DataType::INT32, place);
@@ -950,5 +949,4 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8)
"dst_indices",
"cumsum_idx_gpu",
"m_indices"})
.Attrs({"use_in_ep:bool", "token_nums_this_rank_padded:int"})
.SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8));
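
The buffer bound used in the dispatch above, token_rows * moe_topk + num_experts_per_rank * (128 - 1), reads as a worst case for padding every expert's token group up to a multiple of 128; that reading is an interpretation of the expression, not something stated in the diff. A short C++ sanity check of the bound, using made-up sizes:

#include <cstdio>
#include <vector>

int main() {
    constexpr int align = 128;
    const int token_rows = 300, moe_topk = 8, num_experts_per_rank = 16;

    // Hypothetical per-expert token counts after top-k routing; together they
    // sum to token_rows * moe_topk.
    std::vector<long> per_expert(num_experts_per_rank, 0);
    for (int t = 0; t < token_rows * moe_topk; ++t)
        per_expert[t % num_experts_per_rank] += 1;

    long padded_total = 0;
    for (long n : per_expert)
        padded_total += (n + align - 1) / align * align;  // round up to a multiple of 128

    const long bound = static_cast<long>(token_rows) * moe_topk
                     + static_cast<long>(num_experts_per_rank) * (align - 1);
    // Padding adds at most align - 1 tokens per expert, so the bound always holds.
    std::printf("padded_total = %ld, bound = %ld\n", padded_total, bound);
    return 0;
}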

View File

@@ -54,12 +54,12 @@ void compute_total_rows_before_expert(int* sorted_indices,
template <paddle::DataType T>
void FusedMoeKernel(const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& up_gate_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const paddle::Tensor& ffn1_weight,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const std::string& quant_method,
const int moe_topk,
const bool group_moe,
@@ -84,12 +84,12 @@ void FusedMoeKernel(const paddle::Tensor& input,
moe_compute.ComputeFFN(&input,
&gate_weight,
&up_gate_proj_weight,
up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr,
up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr,
&down_proj_weight,
down_proj_scale ? down_proj_scale.get_ptr() : nullptr,
down_proj_bias ? down_proj_bias.get_ptr() : nullptr,
&ffn1_weight,
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
&ffn2_weight,
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
nullptr,
moe_topk,
group_moe,
@@ -102,12 +102,12 @@ void FusedMoeKernel(const paddle::Tensor& input,
paddle::Tensor FusedExpertMoeFunc(
const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const std::string& quant_method,
const int moe_topk,
const bool norm_topk_prob,
@@ -119,12 +119,12 @@ paddle::Tensor FusedExpertMoeFunc(
case paddle::DataType::BFLOAT16:
FusedMoeKernel<paddle::DataType::BFLOAT16>(input,
gate_weight,
up_gate_proj_weight,
up_gate_proj_scale,
up_gate_proj_bias,
down_proj_weight,
down_proj_scale,
down_proj_bias,
ffn1_weight,
ffn1_scale,
ffn1_bias,
ffn2_weight,
ffn2_scale,
ffn2_bias,
quant_method,
moe_topk,
group_moe,
@@ -134,12 +134,12 @@ paddle::Tensor FusedExpertMoeFunc(
case paddle::DataType::FLOAT16:
FusedMoeKernel<paddle::DataType::FLOAT16>(input,
gate_weight,
up_gate_proj_weight,
up_gate_proj_scale,
up_gate_proj_bias,
down_proj_weight,
down_proj_scale,
down_proj_bias,
ffn1_weight,
ffn1_scale,
ffn1_bias,
ffn2_weight,
ffn2_scale,
ffn2_bias,
quant_method,
moe_topk,
group_moe,
@@ -155,24 +155,24 @@ paddle::Tensor FusedExpertMoeFunc(
std::vector<paddle::Tensor> FusedExpertMoe(
const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const std::string& quant_method,
const int moe_topk,
const bool norm_topk_prob,
const bool group_moe) {
return {FusedExpertMoeFunc(input,
gate_weight,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_bias,
down_proj_scale,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_bias,
ffn2_scale,
quant_method,
moe_topk,
norm_topk_prob,
@@ -182,30 +182,30 @@ std::vector<paddle::Tensor> FusedExpertMoe(
std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& gate_weight_shape,
const std::vector<int64_t>& up_gate_proj_weight_shape,
const std::vector<int64_t>& down_proj_weight_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape) {
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
return {input_shape};
}
std::vector<paddle::DataType> FusedExpertMoeInferDtype(
const paddle::DataType& input_dtype,
const paddle::DataType& gate_weight_dtype,
const paddle::DataType& up_gate_proj_weight_dtype,
const paddle::DataType& down_proj_weight_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType>& down_proj_bias_dtype,
const paddle::optional<paddle::DataType>& down_proj_scale_dtype) {
const paddle::DataType& ffn1_weight_dtype,
const paddle::DataType& ffn2_weight_dtype,
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
return {input_dtype};
}
/**
* @brief Fused Mixture-of-Experts (MoE) Operator
*
*
* This operator combines three key MoE operations into a single optimized kernel:
* 1. moe_dispatch - Routes tokens to top-k experts using gating network
* 2. moe_ffn - Processes tokens through parallel expert FFNs
@@ -230,12 +230,12 @@ std::vector<paddle::DataType> FusedExpertMoeInferDtype(
PD_BUILD_STATIC_OP(fused_expert_moe)
.Inputs({"input",
"gate_weight",
"up_gate_proj_weight",
"down_proj_weight",
paddle::Optional("up_gate_proj_bias"),
paddle::Optional("up_gate_proj_scale"),
paddle::Optional("down_proj_bias"),
paddle::Optional("down_proj_scale")})
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn1_scale"),
paddle::Optional("ffn2_bias"),
paddle::Optional("ffn2_scale")})
.Outputs({"output"})
.Attrs({"quant_method:std::string",
"moe_topk:int",

View File

@@ -117,18 +117,18 @@ public:
void
ComputeFFN(const paddle::Tensor *input, const paddle::Tensor *gate_weight,
const paddle::Tensor *up_gate_proj_weight,
const paddle::Tensor *up_gate_proj_scale, const paddle::Tensor *up_gate_proj_bias,
const paddle::Tensor *down_proj_weight,
const paddle::Tensor *down_proj_scale, const paddle::Tensor *down_proj_bias,
const paddle::Tensor *ffn1_weight,
const paddle::Tensor *ffn1_scale, const paddle::Tensor *ffn1_bias,
const paddle::Tensor *ffn2_weight,
const paddle::Tensor *ffn2_scale, const paddle::Tensor *ffn2_bias,
const paddle::Tensor *moe_token_type_ids, const int moe_topk,
const bool group_moe, const bool norm_topk_prob,
const float routed_scaling_factor, const std::string moe_type,
paddle::Tensor *output) {
auto *input_activations = input->data<T>();
auto *gating_weights = gate_weight->data<float>();
const T *fc1_expert_biases = up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
const T *fc2_expert_biases = down_proj_bias ? down_proj_bias->data<T>() : nullptr;
const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
auto *output_ = output->data<T>();
auto stream = input->stream();
@@ -136,7 +136,7 @@ public:
auto input_type = input->dtype();
auto input_dims = input->dims();
auto up_gate_proj_dims = up_gate_proj_weight->dims();
auto ffn1_dims = ffn1_weight->dims();
int64_t token_num = 0;
if (input_dims.size() == 3) {
token_num = input_dims[0] * input_dims[1];
@@ -145,12 +145,12 @@ public:
}
const int64_t num_rows = token_num;
const int64_t hidden_size = up_gate_proj_dims[1];
const int64_t hidden_size = ffn1_dims[1];
int64_t inter_dim = 0;
if (moe_type == "qkv") {
inter_dim = up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
} else {
inter_dim = up_gate_proj_dims[2];
inter_dim = ffn1_dims[2];
}
if (gemm_method_ == "weight_only_int4") {
@@ -158,7 +158,7 @@ public:
}
const int64_t inter_size = inter_dim;
const int64_t num_experts = up_gate_proj_dims[0];
const int64_t num_experts = ffn1_dims[0];
const int64_t k = moe_topk;
int64_t bytes =
@@ -260,38 +260,38 @@ public:
total_rows_before_expert_, stream);
if (gemm_method_ == "weight_only_int8") {
typename Int8Traits::Arguments up_gate_proj_quant_args;
typename Int8Traits::Arguments ffn1_quant_args;
int8_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const uint8_t *>(up_gate_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
reinterpret_cast<const uint8_t *>(ffn1_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(ffn1_scale->data<T>()),
reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
up_gate_proj_quant_args, "none", stream);
ffn1_quant_args, "none", stream);
} else if (gemm_method_ == "weight_only_int4") {
typename Int4Traits::Arguments up_gate_proj_quant_args;
typename Int4Traits::Arguments ffn1_quant_args;
int4_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const cutlass::uint4b_t *>(
up_gate_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
ffn1_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(ffn1_scale->data<T>()),
reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
up_gate_proj_quant_args, "none", stream);
ffn1_quant_args, "none", stream);
} else {
typename Fp16Traits::Arguments up_gate_proj_quant_args;
typename Fp16Traits::Arguments ffn1_quant_args;
fp16_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()), nullptr,
reinterpret_cast<const NvType *>(ffn1_weight->data<T>()), nullptr,
reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
up_gate_proj_quant_args, "none", stream);
ffn1_quant_args, "none", stream);
}
if (moe_type == "ffn") {
@@ -304,35 +304,35 @@ public:
T *fc2_result = fc2_output_tensor.data<T>();
if (gemm_method_ == "weight_only_int8") {
typename Int8Traits::Arguments down_proj_quant_args;
typename Int8Traits::Arguments ffn2_quant_args;
int8_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const uint8_t *>(down_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
reinterpret_cast<const uint8_t *>(ffn2_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(ffn2_scale->data<T>()),
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, down_proj_quant_args, stream);
num_experts, ffn2_quant_args, stream);
} else if (gemm_method_ == "weight_only_int4") {
typename Int4Traits::Arguments down_proj_quant_args;
typename Int4Traits::Arguments ffn2_quant_args;
int4_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const cutlass::uint4b_t *>(
down_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
ffn2_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(ffn2_scale->data<T>()),
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, down_proj_quant_args, stream);
num_experts, ffn2_quant_args, stream);
} else {
typename Fp16Traits::Arguments down_proj_quant_args;
typename Fp16Traits::Arguments ffn2_quant_args;
fp16_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const NvType *>(down_proj_weight->data<T>()), nullptr,
reinterpret_cast<const NvType *>(ffn2_weight->data<T>()), nullptr,
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, down_proj_quant_args, stream);
num_experts, ffn2_quant_args, stream);
}
finalize_moe_routing_kernelLauncher<T>::run(

View File

@@ -24,12 +24,12 @@
template <paddle::DataType T>
void MoeFFNKernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method,
paddle::Tensor ffn_out,
@@ -51,11 +51,11 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
const int num_experts = up_gate_proj_weight.dims()[0];
const int num_experts = ffn1_weight.dims()[0];
const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];
assert(up_gate_proj_weight.dims().size() == 3);
int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size;
assert(ffn1_weight.dims().size() == 3);
int inter_dim = ffn1_weight.dims()[1] * ffn1_weight.dims()[2] / hidden_size;
constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024; // for nf4 stream-k
Allocator* allocator = paddle::GetAllocator(place);
@@ -96,8 +96,8 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
using NvType = typename traits_::DataType;
auto fc1_expert_biases =
up_gate_proj_bias
? const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr())->data<data_t>()
ffn1_bias
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
: nullptr;
// This is a trick.
@@ -112,9 +112,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
int8_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const uint8_t*>(up_gate_proj_weight.data<int8_t>()),
reinterpret_cast<const uint8_t*>(ffn1_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<const NvType*>(fc1_expert_biases),
reinterpret_cast<NvType*>(fc1_out),
@@ -132,9 +132,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
int4_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const cutlass::uint4b_t*>(
up_gate_proj_weight.data<int8_t>()),
ffn1_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<const NvType*>(fc1_expert_biases),
reinterpret_cast<NvType*>(fc1_out),
@@ -151,12 +151,12 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
w4a8_moe_gemm_runner.moe_gemm(
reinterpret_cast<const int8_t *>(permute_input.data<int8_t>()),
reinterpret_cast<const cutlass::uint4b_t *>(
up_gate_proj_weight.data<int8_t>()),
ffn1_weight.data<int8_t>()),
quant_mode,
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())
->data<data_t>()),
nullptr, // up_gate_proj_scale_dyquant
nullptr, // ffn1_scale_dyquant
nullptr, // nf4_look_up_table
reinterpret_cast<NvType *>(fc1_out),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -172,7 +172,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
reinterpret_cast<const NvType*>(up_gate_proj_weight.data<data_t>()),
reinterpret_cast<const NvType*>(ffn1_weight.data<data_t>()),
nullptr,
reinterpret_cast<const NvType*>(fc1_expert_biases),
reinterpret_cast<NvType*>(fc1_out),
@@ -199,9 +199,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments quant_args;
int8_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const uint8_t*>(down_proj_weight.data<int8_t>()),
reinterpret_cast<const uint8_t*>(ffn2_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -218,9 +218,9 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
int4_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const cutlass::uint4b_t*>(
down_proj_weight.data<int8_t>()),
ffn2_weight.data<int8_t>()),
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())
->data<data_t>()),
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -232,17 +232,17 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
quant_args,
stream);
} else if (quant_method == "w4a8") {
data_t *down_proj_shift = nullptr;
data_t *down_proj_smooth = nullptr;
data_t *ffn2_shift = nullptr;
data_t *ffn2_smooth = nullptr;
Allocator::AllocationPtr int8_act_out;
int8_act_out = allocator->Allocate(
SizeOf(paddle::DataType::INT8) * act_out_tensor.numel());
MoeFastHardamardWrapper<data_t, int8_t>(
act_out_tensor.data<data_t>(),
expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>() : nullptr,
down_proj_shift, // down_proj_shift->data<T>(),
down_proj_smooth, // down_proj_smooth->data<T>(),
down_proj_in_scale ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())->data<float>() : nullptr,
ffn2_shift, // ffn2_shift->data<T>(),
ffn2_smooth, // ffn2_smooth->data<T>(),
ffn2_in_scale ? const_cast<paddle::Tensor*>(ffn2_in_scale.get_ptr())->data<float>() : nullptr,
1,
127.0,
-127.0,
@@ -254,12 +254,12 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
w4a8_moe_gemm_runner.moe_gemm(
reinterpret_cast<int8_t *>(int8_act_out->ptr()),
reinterpret_cast<const cutlass::uint4b_t *>(
down_proj_weight.data<int8_t>()),
ffn2_weight.data<int8_t>()),
quant_mode,
reinterpret_cast<const NvType*>(
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())
->data<data_t>()),
nullptr, // down_proj_scale_dyquant
nullptr, // ffn2_scale_dyquant
nullptr, // reinterpret_cast<const int32_t*>(d_nf4_look_up_table), // nf4_look_up_table
reinterpret_cast<NvType *>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -275,7 +275,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
typename cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
fp16_moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out),
reinterpret_cast<const NvType*>(down_proj_weight.data<data_t>()),
reinterpret_cast<const NvType*>(ffn2_weight.data<data_t>()),
nullptr,
reinterpret_cast<NvType*>(ffn_out_data),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
@@ -292,29 +292,29 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
paddle::Tensor MoeExpertFFNFunc(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency) {
cudaCheckError();
const auto t_type = quant_method == "w4a8" ? up_gate_proj_scale.get().dtype() : permute_input.dtype();
const auto t_type = quant_method == "w4a8" ? ffn1_scale.get().dtype() : permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, t_type);
switch (t_type) {
case paddle::DataType::BFLOAT16:
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
down_proj_in_scale,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn2_in_scale,
expert_idx_per_token,
quant_method,
ffn_out, used_in_ep_low_latency);
@@ -322,12 +322,12 @@ paddle::Tensor MoeExpertFFNFunc(
case paddle::DataType::FLOAT16:
MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
down_proj_in_scale,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn2_in_scale,
expert_idx_per_token,
quant_method,
ffn_out, used_in_ep_low_latency);
@@ -341,22 +341,22 @@ paddle::Tensor MoeExpertFFNFunc(
std::vector<paddle::Tensor> MoeExpertFFN(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency) {
return {MoeExpertFFNFunc(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
down_proj_in_scale,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn2_in_scale,
expert_idx_per_token,
quant_method, used_in_ep_low_latency)};
}
@@ -364,12 +364,12 @@ std::vector<paddle::Tensor> MoeExpertFFN(
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
const std::vector<int64_t>& permute_input_shape,
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
const std::vector<int64_t>& up_gate_proj_weight_shape,
const std::vector<int64_t>& down_proj_weight_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_in_scale_shape,
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_in_scale_shape,
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
const std::string& quant_method,
const bool used_in_ep_low_latency) {
@@ -379,15 +379,15 @@ std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
std::vector<paddle::DataType> MoeExpertFFNInferDtype(
const paddle::DataType &permute_input_dtype,
const paddle::DataType &tokens_expert_prefix_sum_dtype,
const paddle::DataType &up_gate_proj_weight_dtype,
const paddle::DataType &down_proj_weight_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_in_scale_dtype,
const paddle::DataType &ffn1_weight_dtype,
const paddle::DataType &ffn2_weight_dtype,
const paddle::optional<paddle::DataType> &ffn1_bias_dtype,
const paddle::optional<paddle::DataType> &ffn1_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_in_scale_dtype,
const std::string &quant_method, const bool used_in_ep_low_latency) {
if (quant_method == "w4a8") {
return {up_gate_proj_scale_dtype.get()};
return {ffn1_scale_dtype.get()};
} else {
return {permute_input_dtype};
}
@@ -397,9 +397,9 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
* @brief Mixture of Experts (MoE) Feed-Forward Network Operator
*
* This operator performs the expert computation in MoE architecture, including:
* 1. First linear transformation (up_gate_proj) with optional quantization
* 1. First linear transformation (FFN1) with optional quantization
* 2. SwiGLU activation function
* 3. Second linear transformation (down_proj) with optional quantization
* 3. Second linear transformation (FFN2) with optional quantization
*
* Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization.
*
@@ -410,22 +410,22 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
* Shape: [num_experts]
* dtype: int64
* - up_gate_proj_weight: First FFN layer weights
* - ffn1_weight: First FFN layer weights
* Shape: [num_experts, inter_size * 2, hidden_size]
* dtype: Same as input (unquantized) or int8 (quantized)
* - down_proj_weight: Second FFN layer weights
* - ffn2_weight: Second FFN layer weights
* Shape: [num_experts, hidden_size, inter_size]
* dtype: Same as input (unquantized) or int8 (quantized)
* - up_gate_proj_bias: Optional bias for first FFN layer
* - ffn1_bias: Optional bias for first FFN layer
* Shape: [num_experts, inter_size * 2]
* dtype: Same as input
* - up_gate_proj_scale: Quantization scales for first FFN layer
* - ffn1_scale: Quantization scales for first FFN layer
* Shape: [num_experts, inter_size * 2]
* dtype: Same as input
* - down_proj_scale: Quantization scales for second FFN layer
* - ffn2_scale: Quantization scales for second FFN layer
* Shape: [num_experts, hidden_size]
* dtype: Same as input
* - down_proj_in_scale: Optional input scales for second FFN layer (w4a8 only)
* - ffn2_in_scale: Optional input scales for second FFN layer (w4a8 only)
* dtype: float32
* - expert_idx_per_token: Optional expert indices per token (w4a8 only)
* Shape: [total_tokens]
@@ -434,7 +434,7 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
* Outputs:
* - output_tensor: Output tensor after MoE FFN computation
* Shape: Same as permute_input
* dtype: Same as input (or up_gate_proj_scale dtype for w4a8)
* dtype: Same as input (or ffn1_scale dtype for w4a8)
*
* Attributes:
* - quant_method: Quantization method to use
@@ -449,12 +449,12 @@ std::vector<paddle::DataType> MoeExpertFFNInferDtype(
PD_BUILD_STATIC_OP(moe_expert_ffn)
.Inputs({"permute_input",
"tokens_expert_prefix_sum",
"up_gate_proj_weight",
"down_proj_weight",
paddle::Optional("up_gate_proj_bias"),
paddle::Optional("up_gate_proj_scale"),
paddle::Optional("down_proj_scale"),
paddle::Optional("down_proj_in_scale"),
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn1_scale"),
paddle::Optional("ffn2_scale"),
paddle::Optional("ffn2_in_scale"),
paddle::Optional("expert_idx_per_token")})
.Outputs({"output_tensor"})
.Attrs({"quant_method:std::string", "used_in_ep_low_latency:bool"})

View File

@@ -23,17 +23,17 @@
template <typename DataT, typename NvType, typename WeightSavedT, cutlass::WintQuantMethod QuantMethod>
void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::Tensor* up_gate_proj_bias,
const paddle::Tensor* up_gate_proj_super_scale,
const paddle::Tensor* down_proj_super_scale,
const paddle::Tensor* up_gate_proj_local_scale,
const paddle::Tensor* up_gate_proj_code_scale,
const paddle::Tensor* up_gate_proj_code_zp,
const paddle::Tensor* down_proj_local_scale,
const paddle::Tensor* down_proj_code_scale,
const paddle::Tensor* down_proj_code_zp,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::Tensor* ffn1_bias,
const paddle::Tensor* ffn1_super_scale,
const paddle::Tensor* ffn2_super_scale,
const paddle::Tensor* ffn1_local_scale,
const paddle::Tensor* ffn1_code_scale,
const paddle::Tensor* ffn1_code_zp,
const paddle::Tensor* ffn2_local_scale,
const paddle::Tensor* ffn2_code_scale,
const paddle::Tensor* ffn2_code_zp,
paddle::Tensor fc1_out,
paddle::Tensor ffn_out,
const int64_t total_rows_in_ll_else_minus1,
@@ -46,15 +46,15 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
using WeightOnlyTraits = cutlass::WintQuantTraits<NvType, QuantMethod>;
using WeightType = typename WeightOnlyTraits::WeightType;
typename WeightOnlyTraits::Arguments up_gate_proj_quant_args;
typename WeightOnlyTraits::Arguments down_proj_quant_args;
typename WeightOnlyTraits::Arguments ffn1_quant_args;
typename WeightOnlyTraits::Arguments ffn2_quant_args;
if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
up_gate_proj_quant_args.local_scale_ptr = up_gate_proj_local_scale->data<uint8_t>();
up_gate_proj_quant_args.code_scale_ptr = up_gate_proj_code_scale->data<float>();
up_gate_proj_quant_args.code_zp_ptr = up_gate_proj_code_zp->data<float>();
down_proj_quant_args.local_scale_ptr = down_proj_local_scale->data<uint8_t>();
down_proj_quant_args.code_scale_ptr = down_proj_code_scale->data<float>();
down_proj_quant_args.code_zp_ptr = down_proj_code_zp->data<float>();
ffn1_quant_args.local_scale_ptr = ffn1_local_scale->data<uint8_t>();
ffn1_quant_args.code_scale_ptr = ffn1_code_scale->data<float>();
ffn1_quant_args.code_zp_ptr = ffn1_code_zp->data<float>();
ffn2_quant_args.local_scale_ptr = ffn2_local_scale->data<uint8_t>();
ffn2_quant_args.code_scale_ptr = ffn2_code_scale->data<float>();
ffn2_quant_args.code_zp_ptr = ffn2_code_zp->data<float>();
}
auto moe_gemm_runner = MoeGemmRunner<NvType, WeightOnlyTraits>();
@@ -62,9 +62,9 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
moe_gemm_runner.moe_gemm_bias_act(
reinterpret_cast<const NvType*>(permute_input.data<DataT>()),
reinterpret_cast<const WeightType*>(up_gate_proj_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(up_gate_proj_super_scale ? up_gate_proj_super_scale->data<DataT>() : nullptr),
reinterpret_cast<const NvType*>(up_gate_proj_bias ? up_gate_proj_bias->data<DataT>() : nullptr),
reinterpret_cast<const WeightType*>(ffn1_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(ffn1_super_scale ? ffn1_super_scale->data<DataT>() : nullptr),
reinterpret_cast<const NvType*>(ffn1_bias ? ffn1_bias->data<DataT>() : nullptr),
reinterpret_cast<NvType*>(fc1_out.data<DataT>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
@@ -72,7 +72,7 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
inter_size,
hidden_size,
num_experts,
up_gate_proj_quant_args,
ffn1_quant_args,
"none",
stream);
@@ -85,8 +85,8 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
moe_gemm_runner.moe_gemm(
reinterpret_cast<const NvType*>(act_out.data<DataT>()),
reinterpret_cast<const WeightType*>(down_proj_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(down_proj_super_scale ? down_proj_super_scale->data<DataT>() : nullptr),
reinterpret_cast<const WeightType*>(ffn2_weight.data<WeightSavedT>()),
reinterpret_cast<const NvType*>(ffn2_super_scale ? ffn2_super_scale->data<DataT>() : nullptr),
reinterpret_cast<NvType*>(ffn_out.data<DataT>()),
const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
total_rows_in_ll_else_minus1,
@@ -94,24 +94,24 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
hidden_size,
inter_size / 2,
num_experts,
down_proj_quant_args,
ffn2_quant_args,
stream);
}
template <paddle::DataType T>
void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
paddle::Tensor ffn_out,
bool used_in_ep_low_latency) {
using namespace phi;
@@ -121,12 +121,12 @@ void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
auto place = permute_input.place();
assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
assert(up_gate_proj_weight.dims().size() == 3);
assert(ffn1_weight.dims().size() == 3);
const int num_experts = up_gate_proj_weight.dims()[0];
const int num_experts = ffn1_weight.dims()[0];
const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];
int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size;
int inter_dim = ffn1_weight.dims()[1] * ffn1_weight.dims()[2] / hidden_size;
const int64_t inter_size = inter_dim * 4;
@@ -160,17 +160,17 @@ void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
WeightOnlyMoeFFNKernel<data_t, NvType, uint8_t, cutlass::WintQuantMethod::kWeightOnlyInt2>(
permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(up_gate_proj_code_zp.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(down_proj_code_zp.get_ptr()),
ffn1_weight,
ffn2_weight,
const_cast<paddle::Tensor*>(ffn1_bias.get_ptr()),
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn1_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn1_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn1_code_zp.get_ptr()),
const_cast<paddle::Tensor*>(ffn2_local_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn2_code_scale.get_ptr()),
const_cast<paddle::Tensor*>(ffn2_code_zp.get_ptr()),
fc1_out_tensor,
ffn_out,
total_rows_in_ll_else_minus1,
@@ -184,17 +184,17 @@ void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
const bool used_in_ep_low_latency) {
const auto dtype = permute_input.dtype();
@@ -204,34 +204,34 @@ paddle::Tensor MoeExpertFFNWint2Func(
case paddle::DataType::BFLOAT16:
MoeFFNWint2Kernel<paddle::DataType::BFLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
up_gate_proj_local_scale,
up_gate_proj_code_scale,
up_gate_proj_code_zp,
down_proj_local_scale,
down_proj_code_scale,
down_proj_code_zp,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn1_local_scale,
ffn1_code_scale,
ffn1_code_zp,
ffn2_local_scale,
ffn2_code_scale,
ffn2_code_zp,
ffn_out,
used_in_ep_low_latency);
break;
case paddle::DataType::FLOAT16:
MoeFFNWint2Kernel<paddle::DataType::FLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
up_gate_proj_local_scale,
up_gate_proj_code_scale,
up_gate_proj_code_zp,
down_proj_local_scale,
down_proj_code_scale,
down_proj_code_zp,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn1_local_scale,
ffn1_code_scale,
ffn1_code_zp,
ffn2_local_scale,
ffn2_code_scale,
ffn2_code_zp,
ffn_out,
used_in_ep_low_latency);
break;
@@ -244,49 +244,49 @@ paddle::Tensor MoeExpertFFNWint2Func(
std::vector<paddle::Tensor> MoeExpertFFNWint2(
const paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
const paddle::optional<paddle::Tensor>& down_proj_local_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_scale,
const paddle::optional<paddle::Tensor>& down_proj_code_zp,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn1_local_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_scale,
const paddle::optional<paddle::Tensor>& ffn1_code_zp,
const paddle::optional<paddle::Tensor>& ffn2_local_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_scale,
const paddle::optional<paddle::Tensor>& ffn2_code_zp,
const bool used_in_ep_low_latency) {
return {MoeExpertFFNWint2Func(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
up_gate_proj_local_scale,
up_gate_proj_code_scale,
up_gate_proj_code_zp,
down_proj_local_scale,
down_proj_code_scale,
down_proj_code_zp,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
ffn1_local_scale,
ffn1_code_scale,
ffn1_code_zp,
ffn2_local_scale,
ffn2_code_scale,
ffn2_code_zp,
used_in_ep_low_latency)};
}
std::vector<std::vector<int64_t>> MoeExpertFFNWint2InferShape(
const std::vector<int64_t>& permute_input_shape,
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
const std::vector<int64_t>& up_gate_proj_weight_shape,
const std::vector<int64_t>& down_proj_weight_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_local_scale_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_code_scale_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_code_zp_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_local_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_code_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_code_zp_shape,
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_local_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_code_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_code_zp_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_local_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_code_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_code_zp_shape,
const bool used_in_ep_low_latency) {
return {permute_input_shape};
@@ -295,17 +295,17 @@ std::vector<std::vector<int64_t>> MoeExpertFFNWint2InferShape(
std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
const paddle::DataType &permute_input_dtype,
const paddle::DataType &tokens_expert_prefix_sum_dtype,
const paddle::DataType &up_gate_proj_weight_dtype,
const paddle::DataType &down_proj_weight_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_local_scale_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_code_scale_dtype,
const paddle::optional<paddle::DataType> &up_gate_proj_code_zp_dtype,
const paddle::optional<paddle::DataType> &down_proj_local_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_code_scale_dtype,
const paddle::optional<paddle::DataType> &down_proj_code_zp_dtype,
const paddle::DataType &ffn1_weight_dtype,
const paddle::DataType &ffn2_weight_dtype,
const paddle::optional<paddle::DataType> &ffn1_bias_dtype,
const paddle::optional<paddle::DataType> &ffn1_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_scale_dtype,
const paddle::optional<paddle::DataType> &ffn1_local_scale_dtype,
const paddle::optional<paddle::DataType> &ffn1_code_scale_dtype,
const paddle::optional<paddle::DataType> &ffn1_code_zp_dtype,
const paddle::optional<paddle::DataType> &ffn2_local_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_code_scale_dtype,
const paddle::optional<paddle::DataType> &ffn2_code_zp_dtype,
const bool used_in_ep_low_latency) {
return {permute_input_dtype};
@@ -315,9 +315,9 @@ std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
* @brief Weight-Only Quantized Mixture of Experts (MoE) Feed-Forward Network Operator
*
* This operator performs the expert computation in MoE architecture, including:
* 1. First linear transformation (up_gate_proj) with optional quantization
* 1. First linear transformation (FFN1) with optional quantization
* 2. SwiGLU activation function
* 3. Second linear transformation (down_proj) with optional quantization
* 3. Second linear transformation (FFN2) with optional quantization
*
* Supports multiple quantization methods including weight-only int4/int8 and w4a8 quantization.
*
@@ -328,26 +328,26 @@ std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
* Shape: [num_experts]
* dtype: int64
* - up_gate_proj_weight: First FFN layer weights
* - ffn1_weight: First FFN layer weights
* Shape: [num_experts, inter_size * 2, hidden_size]
* dtype: Same as input (unquantized) or int8 (quantized)
* - down_proj_weight: Second FFN layer weights
* - ffn2_weight: Second FFN layer weights
* Shape: [num_experts, hidden_size, inter_size]
* dtype: Same as input (unquantized) or int8 (quantized)
* - up_gate_proj_bias: Optional bias for first FFN layer
* - ffn1_bias: Optional bias for first FFN layer
* Shape: [num_experts, inter_size * 2]
* dtype: Same as input
* - up_gate_proj_scale: Quantization scales for first FFN layer
* - ffn1_scale: Quantization scales for first FFN layer
* Shape: [num_experts, inter_size * 2]
* dtype: Same as input
* - down_proj_scale: Quantization scales for second FFN layer
* - ffn2_scale: Quantization scales for second FFN layer
* Shape: [num_experts, hidden_size]
* dtype: Same as input
*
* Outputs:
* - output_tensor: Output tensor after MoE FFN computation
* Shape: Same as permute_input
* dtype: Same as input (or up_gate_proj_scale dtype for w4a8)
* dtype: Same as input (or ffn1_scale dtype for w4a8)
*
* Attributes:
* - used_in_ep_low_latency: Whether running in low latency mode
@@ -359,17 +359,17 @@ std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
PD_BUILD_STATIC_OP(moe_expert_ffn_wint2)
.Inputs({"permute_input",
"tokens_expert_prefix_sum",
"up_gate_proj_weight",
"down_proj_weight",
paddle::Optional("up_gate_proj_bias"),
paddle::Optional("up_gate_proj_scale"),
paddle::Optional("down_proj_scale"),
paddle::Optional("up_gate_proj_local_scale"),
paddle::Optional("up_gate_proj_code_scale"),
paddle::Optional("up_gate_proj_code_zp"),
paddle::Optional("down_proj_local_scale"),
paddle::Optional("down_proj_code_scale"),
paddle::Optional("down_proj_code_zp")})
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn1_scale"),
paddle::Optional("ffn2_scale"),
paddle::Optional("ffn1_local_scale"),
paddle::Optional("ffn1_code_scale"),
paddle::Optional("ffn1_code_zp"),
paddle::Optional("ffn2_local_scale"),
paddle::Optional("ffn2_code_scale"),
paddle::Optional("ffn2_code_zp")})
.Outputs({"output_tensor"})
.Attrs({"used_in_ep_low_latency:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertFFNWint2))
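The registration above wires up the operator described in the doc comment: a grouped FFN that applies up_gate_proj, SwiGLU, and down_proj per expert over tokens already permuted into expert order. As a rough illustration only, the sketch below is a plain CPU reference of that computation under assumed layouts (row-major weights; the first `inter` rows of the fused weight taken as the gate projection and the last `inter` rows as the up projection). The function name and signature are hypothetical, and this is not the quantized CUDA path.

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Rows routed to expert e occupy the half-open range [prefix_sum[e-1], prefix_sum[e])
// of the permuted input, as produced by the dispatch step; prefix_sum has one
// entry per expert (matching tokens_expert_prefix_sum in the op above).
void moe_ffn_reference(const std::vector<float>& permuted_in,   // [total_rows, hidden]
                       const std::vector<int64_t>& prefix_sum,  // [num_experts]
                       const std::vector<float>& up_gate_w,     // [num_experts, 2*inter, hidden]
                       const std::vector<float>& down_w,        // [num_experts, hidden, inter]
                       std::vector<float>& out,                 // [total_rows, hidden]
                       int64_t hidden, int64_t inter) {
  const int64_t num_experts = static_cast<int64_t>(prefix_sum.size());
  for (int64_t e = 0; e < num_experts; ++e) {
    const float* w1 = &up_gate_w[e * 2 * inter * hidden];
    const float* w2 = &down_w[e * hidden * inter];
    const int64_t row_begin = (e == 0) ? 0 : prefix_sum[e - 1];
    for (int64_t r = row_begin; r < prefix_sum[e]; ++r) {
      const float* x = &permuted_in[r * hidden];
      std::vector<float> act(inter);
      // 1) up_gate_proj produces 2*inter values per token (gate half, up half
      //    -- an assumed split).  2) SwiGLU: silu(gate) * up.
      for (int64_t i = 0; i < inter; ++i) {
        float gate = 0.f, up = 0.f;
        for (int64_t h = 0; h < hidden; ++h) {
          gate += w1[i * hidden + h] * x[h];
          up   += w1[(inter + i) * hidden + h] * x[h];
        }
        act[i] = (gate / (1.f + std::exp(-gate))) * up;
      }
      // 3) down_proj maps back to hidden size.
      for (int64_t h = 0; h < hidden; ++h) {
        float acc = 0.f;
        for (int64_t i = 0; i < inter; ++i) acc += w2[h * inter + i] * act[i];
        out[r * hidden + h] = acc;
      }
    }
  }
}
```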


@@ -25,7 +25,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
const paddle::Tensor &top_k_weight,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const bool norm_topk_prob,
const float routed_scaling_factor, const int num_rows,
const int hidden_size, const int topk,
@@ -38,7 +38,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
finalize_moe_routing_kernelLauncher<data_t>::run(
ffn_out.data<data_t>(), output->data<data_t>(),
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
top_k_weight.data<float>(), permute_indices_per_token.data<int32_t>(),
top_k_indices.data<int>(), num_rows, hidden_size, topk,
static_cast<int>(1), norm_topk_prob, routed_scaling_factor, stream);
@@ -48,7 +48,7 @@ paddle::Tensor MoeExpertReduceFunc(
const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const bool norm_topk_prob, const float routed_scaling_factor) {
const auto input_type = ffn_out.dtype();
auto place = ffn_out.place();
@@ -63,13 +63,13 @@ paddle::Tensor MoeExpertReduceFunc(
case paddle::DataType::BFLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(
ffn_out, top_k_weight, permute_indices_per_token, top_k_indices,
down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
ffn2_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
topk, &output);
break;
case paddle::DataType::FLOAT16:
MoeReduceKernel<paddle::DataType::BFLOAT16>(
ffn_out, top_k_weight, permute_indices_per_token, top_k_indices,
down_proj_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
ffn2_bias, norm_topk_prob, routed_scaling_factor, num_rows, hidden_size,
topk, &output);
break;
default:
@@ -83,10 +83,10 @@ MoeExpertReduce(const paddle::Tensor &ffn_out,
const paddle::Tensor &top_k_weight,
const paddle::Tensor &permute_indices_per_token,
const paddle::Tensor &top_k_indices,
const paddle::optional<paddle::Tensor> &down_proj_bias,
const paddle::optional<paddle::Tensor> &ffn2_bias,
const bool norm_topk_prob, const float routed_scaling_factor) {
return {MoeExpertReduceFunc(ffn_out, top_k_weight, permute_indices_per_token,
top_k_indices, down_proj_bias, norm_topk_prob,
top_k_indices, ffn2_bias, norm_topk_prob,
routed_scaling_factor)};
}
@@ -95,7 +95,7 @@ std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
const std::vector<int64_t> &top_k_weight_shape,
const std::vector<int64_t> &permute_indices_per_token_shape,
const std::vector<int64_t> &top_k_indices_shape,
const paddle::optional<std::vector<int64_t>> &down_proj_bias_shape) {
const paddle::optional<std::vector<int64_t>> &ffn2_bias_shape) {
const int moe_topk = top_k_indices_shape[1];
auto out_shape = ffn_out_shape;
if (out_shape[0] != -1) out_shape[0] /= moe_topk;
@@ -107,19 +107,19 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
const paddle::DataType &top_k_weight_dtype,
const paddle::DataType &permute_indices_per_token_dtype,
const paddle::DataType &top_k_indices_dtype,
const paddle::optional<paddle::DataType> &down_proj_bias_dtype) {
const paddle::optional<paddle::DataType> &ffn2_bias_dtype) {
return {ffn_out_dtype};
}
/**
* @brief Mixture of Experts (MoE) Expert Reduce Operator
*
*
* This operator performs the following key functions:
* 1. Combines outputs from multiple experts based on routing weights
* 2. Applies optional bias and scaling to the combined output
* 3. Restores the original token order from permuted expert outputs
*
*
* Inputs:
* - ffn_out: Outputs from all expert networks (permuted)
* Shape: [total_tokens * moe_topk, hidden_size]
@@ -133,19 +133,19 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
* - top_k_indices: Indices of selected top-k experts for each token
* Shape: [total_tokens, moe_topk]
* dtype: int32
* - down_proj_bias: Optional bias term for expert outputs (hidden_size)
*
* - ffn2_bias: Optional bias term for expert outputs (hidden_size)
*
* Outputs:
* - output: Combined expert outputs in original token order
* Shape: [total_tokens, hidden_size]
* dtype: Same as ffn_out
*
*
* Attributes:
* - norm_topk_prob: Whether to normalize top-k probabilities
* (true: weights sum to 1 for each token,
* false: use raw weights)
* - routed_scaling_factor: Scaling factor applied to top-k probabilities
*
*
* Note:
* - The operator expects permuted expert outputs from moe_expert_dispatch
* - When norm_topk_prob is true, weights are normalized per token
@@ -154,7 +154,7 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
*/
PD_BUILD_STATIC_OP(moe_expert_reduce)
.Inputs({"ffn_out", "top_k_weight", "permute_indices_per_token",
"top_k_indices", paddle::Optional("down_proj_bias")})
"top_k_indices", paddle::Optional("ffn2_bias")})
.Outputs({"output"})
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
.SetKernelFn(PD_KERNEL(MoeExpertReduce))
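To make the doc comment above concrete, here is a small CPU sketch of the reduce step: for every token, gather its top-k expert outputs from the permuted buffer, optionally add the bias, optionally normalize the top-k weights, scale, and accumulate in original token order. The indexing of `permuted_row` and the point at which `routed_scaling_factor` is applied are assumptions for illustration, not the kernel's exact conventions.

```cpp
#include <cstdint>
#include <vector>

void moe_expert_reduce_reference(
    const std::vector<float>& ffn_out,        // [total_tokens * topk, hidden] (permuted)
    const std::vector<float>& top_k_weight,   // [total_tokens, topk]
    const std::vector<int>& permuted_row,     // assumed: row in ffn_out for (token, k)
    const float* bias,                        // optional [hidden], may be nullptr
    std::vector<float>& out,                  // [total_tokens, hidden]
    int64_t total_tokens, int64_t hidden, int64_t topk,
    bool norm_topk_prob, float routed_scaling_factor) {
  for (int64_t t = 0; t < total_tokens; ++t) {
    float wsum = 0.f;
    if (norm_topk_prob) {
      for (int64_t k = 0; k < topk; ++k) wsum += top_k_weight[t * topk + k];
      if (wsum <= 0.f) wsum = 1.f;
    }
    for (int64_t h = 0; h < hidden; ++h) out[t * hidden + h] = 0.f;
    for (int64_t k = 0; k < topk; ++k) {
      float w = top_k_weight[t * topk + k] * routed_scaling_factor;
      if (norm_topk_prob) w /= wsum;                      // weights sum to 1 per token
      const int64_t src = permuted_row[t * topk + k];     // restores original order
      for (int64_t h = 0; h < hidden; ++h) {
        const float v = ffn_out[src * hidden + h] + (bias ? bias[h] : 0.f);
        out[t * hidden + h] += w * v;
      }
    }
  }
}
```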


@@ -26,6 +26,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -98,6 +99,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
seq_lens_encoder,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
decoder_batch_ids,
decoder_tile_ids_per_batch,
@@ -126,7 +128,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
seq_lens_this_time, // q_seq_len is 1
seq_lens_decoder,
padding_offsets,
cu_seqlens_q,
cum_offsets,
block_tables,
max_input_length,
max_len_kv_data,
@@ -149,6 +151,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -198,7 +201,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = seq_lens_this_time.dims()[0];
meta_data.batch_size = cum_offsets.dims()[0];
switch (query.dtype()) {
case paddle::DataType::BFLOAT16: {
@@ -212,6 +215,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
seq_lens_this_time,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
@@ -258,6 +262,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
seq_lens_this_time,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
@@ -311,6 +316,7 @@ std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& padding_offsets_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
@@ -365,6 +371,7 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& padding_offsets_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
@@ -419,6 +426,7 @@ PD_BUILD_OP(multi_head_latent_attention)
"seq_lens_this_time",
"cu_seqlens_q",
"padding_offsets",
"cum_offsets",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",


@@ -18,6 +18,7 @@
#include <algorithm>
#include <optional>
#include "helper.h"
#include "noauxtc_kernel.h"
std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,


@@ -17,11 +17,11 @@
#pragma once
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "helper.h"
namespace cg = cooperative_groups;
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;


@@ -91,12 +91,7 @@ std::vector<paddle::Tensor> rebuild_padding(
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(tmp_out.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = tmp_out.stream();
#endif
std::vector<int64_t> tmp_out_shape = tmp_out.shape();
const int token_num = tmp_out_shape[0];
const int dim_embed = tmp_out_shape[1];
@@ -130,7 +125,7 @@ std::vector<paddle::Tensor> rebuild_padding(
if (output_padding_offset) {
RebuildAppendPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
<<<grid_size, blocksize, 0, tmp_out.stream()>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
cum_offsets.data<int>(),
@@ -143,7 +138,7 @@ std::vector<paddle::Tensor> rebuild_padding(
elem_nums);
} else {
RebuildPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
<<<grid_size, blocksize, 0, tmp_out.stream()>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(tmp_out.data<data_t>())),


@@ -376,6 +376,7 @@ __global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
}
// scan/find
constexpr int WARP_SIZE = 32;
constexpr int WARP_COUNT = NumBuckets / WARP_SIZE;
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();


@@ -18,7 +18,6 @@
std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
const paddle::Tensor &top_p,
const paddle::optional<paddle::Tensor> &top_k,
int seed) {
std::vector<int64_t> probs_shape = probs.shape();
unsigned int batch_size = probs_shape[0];
@@ -41,18 +40,10 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
cudaError_t status;
if (top_k) {
status = sampling::TopKTopPSamplingFromProb<float, int64_t>(
const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
batch_size, top_p.data<float>(), top_k.get().data<int64_t>(), vocab_size,
true, philox_seed, philox_offset, cu_stream);
}
else {
status = sampling::TopPSamplingFromProb<float, int64_t>(
const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
batch_size, top_p.data<float>(), vocab_size,
true, philox_seed, philox_offset, cu_stream);
}
status = sampling::TopKTopPSamplingFromProb<float, int64_t>(
const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
batch_size, top_p.data<float>(), vocab_size,
true, philox_seed, philox_offset, cu_stream);
PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));
@@ -62,21 +53,19 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
std::vector<std::vector<int64_t>>
TopPSamplingRejectInferShape(const std::vector<int64_t> &probs_shape,
const std::vector<int64_t> &top_p_shape,
const paddle::optional<std::vector<int64_t>> &top_k_shape) {
const std::vector<int64_t> &top_p_shape) {
int64_t bs = probs_shape[0];
return {{bs, 1}};
}
std::vector<paddle::DataType>
TopPSamplingRejectInferDtype(const paddle::DataType &probs_dtype,
const paddle::DataType &top_p_dtype,
const paddle::optional<paddle::DataType> &top_k_dtype) {
const paddle::DataType &top_p_shape) {
return {paddle::DataType::INT64};
}
PD_BUILD_STATIC_OP(rejection_top_p_sampling)
.Inputs({"probs", "top_p", paddle::Optional("top_k")})
.Inputs({"probs", "top_p"})
.Outputs({"samples"})
.Attrs({"seed: int"})
.SetKernelFn(PD_KERNEL(TopPSamplingReject))
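For reference, the sketch below shows straightforward top-p (nucleus) sampling for one row on the CPU: sort the probabilities, keep the smallest prefix whose mass reaches `top_p`, and sample inside that prefix. It only illustrates the intended behaviour of the op registered above; the fused GPU kernel in `sampling.cuh` avoids sorting by rejecting against successive pivots, and the RNG and tie handling here are simplifications. The function name is hypothetical.

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

int64_t top_p_sample_reference(const std::vector<float>& probs, float top_p,
                               std::mt19937_64& rng) {
  std::vector<size_t> idx(probs.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(),
            [&](size_t a, size_t b) { return probs[a] > probs[b]; });
  // Keep the smallest prefix whose cumulative mass reaches top_p.
  float mass = 0.f;
  size_t cut = 0;
  while (cut < idx.size() && mass < top_p) mass += probs[idx[cut++]];
  // Sample proportionally within that prefix.
  std::uniform_real_distribution<float> uni(0.f, mass);
  float r = uni(rng), acc = 0.f;
  for (size_t i = 0; i < cut; ++i) {
    acc += probs[idx[i]];
    if (r <= acc) return static_cast<int64_t>(idx[i]);
  }
  return static_cast<int64_t>(idx[cut - 1]);
}
```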


@@ -279,8 +279,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType, typename IdType>
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
float* top_p_arr, IdType* top_k_arr,
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, float* top_p_arr,
uint32_t d, uint64_t philox_seed,
uint64_t philox_offset) {
const uint32_t batch_size = gridDim.x;
@@ -288,8 +287,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
curandStatePhilox4_32_10_t state;
curand_init(philox_seed, bx, philox_offset, &state);
const uint32_t row_idx = bx;
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
const float p = top_p_arr[row_idx];
const uint32_t k = top_p_arr[row_idx] == 0 ? 1 : 20;
const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];
extern __shared__ __align__(
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
@@ -480,7 +479,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
if (aggregate_gt_pivot_0 < top_p) {
// case 1: pivot_0 accepted
break;
}
}
if (aggregate_gt_pivot_1 < top_p) {
// case 2: pivot_0 rejected, pivot_1 accepted
low = pivot_0;
@@ -498,183 +497,6 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
}
}
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
typename TempStorage>
__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
TempStorage& temp_storage) {
const uint32_t tx = threadIdx.x;
vec_t<float, VEC_SIZE> in_data_vec;
float max_val = 0;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
in_data_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
float in_data_[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
in_data_[j] = in_data_vec[j];
}
max_val = max(
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
__syncthreads();
}
if (tx == 0) {
temp_storage.max_val = max_val;
}
__syncthreads();
return temp_storage.max_val;
}
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
struct RenormTempStorage {
union {
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
reduce_value_count;
} block_prim;
struct {
float max_val;
float min_val;
union {
struct {
float values[2];
};
struct {
int counts[2];
};
struct {
ValueCount<float> pairs[2];
};
} block_aggregate;
};
};
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
typename DType, typename IdType>
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = bx;
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
vec_t<float, VEC_SIZE> probs_vec;
if (k < d) {
extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
uint8_t smem_renorm[];
auto& temp_storage =
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
temp_storage.max_val = 0;
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);
double low = 0, high = max_val;
float min_gt_low, max_le_high;
float sum_low = 1;
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
// loop invariant:
// - f(low) >= k, f(high) < k
// - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
// stopping condition: min_gt_low == max_le_high
// - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
do {
double pivot_0 = (high + 2 * low) / 3;
double pivot_1 = (2 * high + low) / 3;
ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
min_gt_low = high;
max_le_high = low;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot_0_pair[j] = {
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
probs_gt_pivot_1_pair[j] = {
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
min_gt_low = min(min_gt_low, probs_vec[j]);
}
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
max_le_high = max(max_le_high, probs_vec[j]);
}
}
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
__syncthreads();
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
__syncthreads();
}
min_gt_low =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(min_gt_low, cub::Min());
__syncthreads();
max_le_high =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(max_le_high, cub::Max());
if (tx == 0) {
temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
temp_storage.min_val = min_gt_low;
temp_storage.max_val = max_le_high;
}
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0];
aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1];
min_gt_low = temp_storage.min_val;
max_le_high = temp_storage.max_val;
if (aggregate_gt_pivot_1.count >= k) {
low = pivot_1;
sum_low = float(aggregate_gt_pivot_1.value);
} else if (aggregate_gt_pivot_0.count >= k) {
low = pivot_0;
high = min(pivot_1, max_le_high);
sum_low = float(aggregate_gt_pivot_0.value);
} else {
high = min(pivot_0, max_le_high);
}
} while (min_gt_low != max_le_high);
normalizer = ptx_rcp(max(sum_low, 1e-8));
pivot = low;
}
// normalize
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
}
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
}
}
template <typename T, typename IdType>
cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
uint32_t batch_size, const T *top_p_val,
@@ -707,7 +529,7 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
uint32_t batch_size, const T *top_p_val,
uint32_t d, bool deterministic,
uint64_t philox_seed, uint64_t philox_offset,
cudaStream_t stream = 0) {
@@ -718,7 +540,7 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &output, &top_p_val, &top_k_val,
void* args[] = {&probs, &output, &top_p_val,
&d, &philox_seed, &philox_offset};
DISPATCH_ALIGNED_VEC_SIZE(
@@ -734,26 +556,4 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
});
}
template <typename DType, typename IdType>
cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
uint32_t batch_size, uint32_t d,
cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
auto compute_capacity = GetCudaComputeCapability();
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});
return cudaSuccess;
});
}
} // namespace sampling
} // namespace sampling
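The `TopKRenormProbKernel` in the hunk above searches for a pivot so that exactly the k largest probabilities remain non-zero, then renormalizes them to sum to one. A simple host-side equivalent, using a partial sort instead of the pivot search (so ties may be resolved slightly differently), could look like the following; the function name is hypothetical.

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <vector>

void topk_renorm_reference(const float* probs, float* renormed,
                           uint32_t d, uint32_t k) {
  if (d == 0) return;
  if (k == 0 || k >= d) k = d;               // k == 0 means "keep everything"
  std::vector<float> sorted(probs, probs + d);
  std::nth_element(sorted.begin(), sorted.begin() + (k - 1), sorted.end(),
                   std::greater<float>());
  const float pivot = sorted[k - 1];         // k-th largest probability
  float sum = 0.f;
  for (uint32_t i = 0; i < d; ++i)
    if (probs[i] >= pivot) sum += probs[i];
  const float inv = 1.f / std::max(sum, 1e-8f);
  for (uint32_t i = 0; i < d; ++i)
    renormed[i] = (probs[i] >= pivot) ? probs[i] * inv : 0.f;
}
```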


@@ -1,61 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/phi/backends/context_pool.h"
#include "sample_kernels/sampling.cuh"
std::vector<paddle::Tensor> TopKRenorm(const paddle::Tensor &probs,
const paddle::Tensor &top_k) {
std::vector<int64_t> probs_shape = probs.shape();
uint32_t batch_size = probs_shape[0];
uint32_t vocab_size = probs_shape[1];
auto cu_stream = probs.stream();
auto renorm_probs =
GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());
cudaError_t status;
status = sampling::TopKRenormProb<float>(
const_cast<float *>(probs.data<float>()),
renorm_probs.data<float>(),
const_cast<int64_t *>(top_k.data<int64_t>()),
batch_size, vocab_size, cu_stream);
PD_CHECK(status == cudaSuccess, "TopKRenormProb failed with error code " +
std::string(cudaGetErrorString(status)));
return {renorm_probs};
}
std::vector<std::vector<int64_t>>
TopKRenormInferShape(const std::vector<int64_t> &probs_shape,
const std::vector<int64_t> &top_k_shape) {
return {probs_shape};
}
std::vector<paddle::DataType>
TopKRenormInferDtype(const paddle::DataType &probs_dtype,
const paddle::DataType &top_k_shape) {
return {probs_dtype};
}
PD_BUILD_STATIC_OP(top_k_renorm_probs)
.Inputs({"probs", "top_k"})
.Outputs({"renorm_probs"})
.SetKernelFn(PD_KERNEL(TopKRenorm))
.SetInferShapeFn(PD_INFER_SHAPE(TopKRenormInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(TopKRenormInferDtype));


@@ -23,34 +23,34 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
#define MAX_BSZ 512
#define K 20
#define MAX_BSZ 128
#define K 10
// #define SAVE_WITH_OUTPUT_DEBUG
struct msgdata {
long mtype;
int mtext[MAX_BSZ * (K + 1) + 2]; // stop_flag, bsz, tokens
float mtext_f[MAX_BSZ * (K + 1)]; // score
int mtext_ranks[MAX_BSZ]; // ranks
};
void SaveOutMmsgTopK(const paddle::Tensor& x,
const paddle::Tensor& logprob_token_ids, // [bsz, k+1]
const paddle::Tensor& logprob_scores, // [bsz, k+1]
const paddle::Tensor& ranks,
const paddle::Tensor& scores,
const paddle::Tensor& topk_ids,
const paddle::Tensor& topk_scores, // [bsz, k]
const paddle::Tensor& not_need_stop,
int k,
int64_t rank_id) {
if (rank_id > 0) {
return;
}
auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
auto logprob_token_ids_cpu = logprob_token_ids.copy_to(paddle::CPUPlace(), false);
auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false);
auto ranks_cpu = ranks.copy_to(paddle::CPUPlace(), false);
auto scores_cpu = scores.copy_to(paddle::CPUPlace(), false);
auto topk_ids_cpu = topk_ids.copy_to(paddle::CPUPlace(), false);
auto topk_scores_cpu = topk_scores.copy_to(paddle::CPUPlace(), false);
int64_t* x_data = x_cpu.data<int64_t>();
int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
float* logprob_scores_data = logprob_scores_cpu.data<float>();
int64_t* ranks_data = ranks_cpu.data<int64_t>();
float* scores_data = scores_cpu.data<float>();
int64_t* topk_ids_data = topk_ids_cpu.data<int64_t>();
float* topk_scores_data = topk_scores_cpu.data<float>();
static struct msgdata msg_sed;
int msg_queue_id = 1;
if (const char* inference_msg_queue_id_env_p =
@@ -106,23 +106,21 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
: -inference_msg_id_from_env;
int bsz = x.shape()[0];
int max_num_logprobs = logprob_token_ids.shape()[1];
msg_sed.mtext[1] = bsz;
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < K + 1; j++) {
for (int j = 0; j < k + 1; j++) {
const int64_t offset = i * (K + 1) + j;
if (j == 0) {
msg_sed.mtext[offset + 2] = (int)x_data[i];
msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j];
} else if (j < max_num_logprobs) {
msg_sed.mtext[offset + 2] = (int)logprob_token_ids_data[i * max_num_logprobs + j];
msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j];
msg_sed.mtext_f[offset] = scores_data[i];
} else if (j <= k + 1) {
msg_sed.mtext[offset + 2] = (int)topk_ids_data[i * k + j - 1];
msg_sed.mtext_f[offset] = topk_scores_data[i * k + j - 1];
} else {
msg_sed.mtext[offset + 2] = -1;
msg_sed.mtext_f[offset] = 0.0;
}
}
msg_sed.mtext_ranks[i] = (int)ranks_data[i];
}
#ifdef SAVE_WITH_OUTPUT_DEBUG
std::cout << "msg data: ";
@@ -133,7 +131,7 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
#endif
if ((msgsnd(msgid,
&msg_sed,
(MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4 + MAX_BSZ * 4,
(MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4,
0)) == -1) {
printf("full msg buffer\n");
}
@@ -141,8 +139,8 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
}
PD_BUILD_STATIC_OP(save_output_topk)
.Inputs({"x", "topk_ids", "logprob_scores", "ranks", "not_need_stop"})
.Attrs({"rank_id: int64_t"})
.Inputs({"x", "scores", "topk_ids", "topk_scores", "not_need_stop"})
.Attrs({"k: int", "rank_id: int64_t"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(SaveOutMmsgTopK));
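The msgsnd length in this file is spelled out by hand as `(MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4 + MAX_BSZ * 4`, i.e. the byte sizes of `mtext`, `mtext_f` and `mtext_ranks` with 4-byte int/float, excluding the mandatory `long mtype` header. A small compile-time check of that arithmetic, under those assumptions and with hypothetical constant names, might look like:

```cpp
constexpr int kMaxBsz = 512;  // mirrors MAX_BSZ in the newer layout
constexpr int kK = 20;        // mirrors K

struct MsgData {
  long  mtype;
  int   mtext[kMaxBsz * (kK + 1) + 2];  // stop_flag, bsz, then token ids
  float mtext_f[kMaxBsz * (kK + 1)];    // log-prob / top-k scores
  int   mtext_ranks[kMaxBsz];           // per-token ranks
};

// Assumes sizeof(int) == sizeof(float) == 4, as on the targeted platforms.
static_assert(sizeof(MsgData::mtext) + sizeof(MsgData::mtext_f) +
                  sizeof(MsgData::mtext_ranks) ==
                  (kMaxBsz * (kK + 1) + 2) * 4 + (kMaxBsz * (kK + 1)) * 4 +
                  kMaxBsz * 4,
              "msgsnd length must cover every payload member");
```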


@@ -91,12 +91,7 @@ void set_data_ipc(const paddle::Tensor& tmp_input,
memset((void *)shm, 0, sizeof(*shm));
void *data_ptr_now = reinterpret_cast<void*>(const_cast<data_t*>(tmp_input.data<data_t>()));
#ifdef PADDLE_WITH_HIP
checkCudaErrors(hipIpcGetMemHandle((hipIpcMemHandle_t *)&shm->memHandle, data_ptr_now));
#else
checkCudaErrors(cudaIpcGetMemHandle((cudaIpcMemHandle_t *)&shm->memHandle, data_ptr_now));
#endif
}


@@ -13,7 +13,6 @@
// limitations under the License.
#include "paddle/extension.h"
#include "helper.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -52,18 +51,13 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_flags) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(stop_flags.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = stop_flags.stream();
#endif
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
int bs = seq_lens_this_time.shape()[0];
int length = pre_ids_all_shape[1];
int length_input_ids = input_ids.shape()[1];
int block_size = (bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
int block_size = (bs + 32 - 1) / 32 * 32;
set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(
stop_flags.data<bool>(),
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),


@@ -37,18 +37,10 @@ std::vector<paddle::Tensor> ShareExternalData(paddle::Tensor& input,
}
shm = (volatile shmStruct *)info.addr;
void *ptr = nullptr;
#ifdef PADDLE_WITH_HIP
checkCudaErrors(
hipIpcOpenMemHandle(&ptr,
*(hipIpcMemHandle_t *)&shm->memHandle, // NOLINT
hipIpcMemLazyEnablePeerAccess));
#else
checkCudaErrors(
cudaIpcOpenMemHandle(&ptr,
*(cudaIpcMemHandle_t *)&shm->memHandle, // NOLINT
cudaIpcMemLazyEnablePeerAccess));
#endif
paddle::Tensor tmp_tensor = paddle::from_blob(
ptr,
shape,


@@ -246,7 +246,7 @@ void token_penalty_multi_scores_kernel(
max_seq_len);
}
void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_scores,
@@ -338,4 +338,4 @@ PD_BUILD_STATIC_OP(speculate_get_token_penalty_multi_scores)
.Outputs({"logits_out"})
.Attrs({"max_seq_len: int"})
.SetInplaceMap({{"logits", "logits_out"}})
.SetKernelFn(PD_KERNEL(SpecTokenPenaltyMultiScores));
.SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores));

Some files were not shown because too many files have changed in this diff.