polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

7
.flake8 Normal file
View File

@@ -0,0 +1,7 @@
[flake8]
ignore = E203, E402, E501, E731, E741, W503, W605, E722
max-line-length = 119
# E402: module level import not at top of file
per-file-ignores =
__init__.py:F401,F403,E402

View File

@@ -2,7 +2,7 @@ name: CI
on: on:
pull_request: pull_request:
branches: branches:
- develop - develop
- 'release/*' - 'release/*'
workflow_dispatch: workflow_dispatch:
@@ -86,4 +86,4 @@ jobs:
git config --global --add safe.directory /workspace/FastDeploy git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy cd FastDeploy
bash scripts/run_ci.sh bash scripts/run_ci.sh
" "

View File

@@ -2,7 +2,7 @@ name: CI_XPU
on: on:
pull_request: pull_request:
branches: branches:
- develop - develop
- 'release/*' - 'release/*'
workflow_dispatch: workflow_dispatch:
@@ -63,7 +63,7 @@ jobs:
if [[ "$last_char" =~ [0-3] ]]; then if [[ "$last_char" =~ [0-3] ]]; then
gpu_id="$last_char" gpu_id="$last_char"
else else
gpu_id="0" gpu_id="0"
fi fi
FD_API_PORT=$((9180 + gpu_id * 100)) FD_API_PORT=$((9180 + gpu_id * 100))
FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100)) FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100))
@@ -84,4 +84,4 @@ jobs:
git config --global --add safe.directory /workspace/FastDeploy git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy cd FastDeploy
bash scripts/run_ci_xpu.sh bash scripts/run_ci_xpu.sh
" "

View File

@@ -5,12 +5,27 @@ default_stages:
- pre-commit # Run locally - pre-commit # Run locally
# - manual # Run in CI # - manual # Run in CI
repos: repos:
- repo: https://github.com/psf/black.git
rev: 22.8.0
hooks:
- id: black
files: \.(py|pyi)$
additional_dependencies: [toml]
# 自动排序
- repo: https://github.com/PyCQA/isort
rev: 5.11.5
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
# 代码检查 # 代码检查
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7 rev: v0.11.7
hooks: hooks:
- id: ruff - id: ruff
args: [--output-format, github, --fix, --line-length=120] args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml]
# # 拼写检查 # # 拼写检查
# - repo: https://github.com/codespell-project/codespell # - repo: https://github.com/codespell-project/codespell
# rev: v2.4.1 # rev: v2.4.1
@@ -18,17 +33,13 @@ repos:
# - id: codespell # - id: codespell
# additional_dependencies: ['tomli'] # additional_dependencies: ['tomli']
# args: ['--toml', 'pyproject.toml'] # args: ['--toml', 'pyproject.toml']
# 自动排序
- repo: https://github.com/PyCQA/isort
rev: 6.0.1
hooks:
- id: isort
# markdown # markdown
- repo: https://github.com/jackdewinter/pymarkdown - repo: https://github.com/jackdewinter/pymarkdown
rev: v0.9.29 rev: v0.9.29
hooks: hooks:
- id: pymarkdown - id: pymarkdown
args: [fix] args: ["-d", "MD029,MD031", fix]
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0 rev: v5.0.0
hooks: hooks:

View File

@@ -8,7 +8,7 @@
<a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a> <a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a> <a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a> <a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
</p> </p>
<p align="center"> <p align="center">
@@ -17,8 +17,8 @@
| |
<a href="https://paddlepaddle.github.io/FastDeploy/get_started/quick_start"><b> Quick Start </b></a> <a href="https://paddlepaddle.github.io/FastDeploy/get_started/quick_start"><b> Quick Start </b></a>
| |
<a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a> <a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
</p> </p>
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------

View File

@@ -131,4 +131,4 @@ python benchmarks/benchmark_mtp.py \
--s_itl-base-model主模型的解码延迟可由上述的性能压测工具获得与batch-size一一对应 --s_itl-base-model主模型的解码延迟可由上述的性能压测工具获得与batch-size一一对应
--dataset-name指定数据集类指定为"EBChat"可读取转存的FD格式数据集 --dataset-name指定数据集类指定为"EBChat"可读取转存的FD格式数据集
--dataset-path测试数据集路径 --dataset-path测试数据集路径
``` ```

View File

@@ -29,13 +29,13 @@ from typing import Optional
import aiohttp import aiohttp
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
@dataclass @dataclass
class RequestFuncInput: class RequestFuncInput:
"""Input for requesting LLMs via API""" """Input for requesting LLMs via API"""
no: int no: int
prompt: str prompt: str
history_QA: Optional[dict] history_QA: Optional[dict]
@@ -55,6 +55,7 @@ class RequestFuncInput:
@dataclass @dataclass
class RequestFuncOutput: class RequestFuncOutput:
"""Output for requesting LLMs via API""" """Output for requesting LLMs via API"""
no: int = 0 no: int = 0
generated_text: str = "" generated_text: str = ""
reasoning_content: str = "" reasoning_content: str = ""
@@ -66,7 +67,7 @@ class RequestFuncOutput:
itl: list = field(default_factory=list) # list of inter-token latencies itl: list = field(default_factory=list) # list of inter-token latencies
tpot: float = 0.0 # avg next-token latencies tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0 prompt_len: int = 0
prompt_tokens: int = 0 # 推理侧返回输入token数 prompt_tokens: int = 0 # 推理侧返回输入token数
error: str = "" error: str = ""
@@ -76,12 +77,9 @@ async def async_request_eb_openai_chat_completions(
) -> RequestFuncOutput: ) -> RequestFuncOutput:
"""Request an LLM using EB OpenAI""" """Request an LLM using EB OpenAI"""
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith( assert api_url.endswith(("completions", "profile")), "OpenAI Chat Completions API URL must end with 'completions'."
("completions", "profile")
), "OpenAI Chat Completions API URL must end with 'completions'."
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}] content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content: if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content) content.append(request_func_input.multi_modal_content)
@@ -91,7 +89,7 @@ async def async_request_eb_openai_chat_completions(
"stream": True, "stream": True,
"stream_options": { "stream_options": {
"include_usage": True, "include_usage": True,
"continuous_usage_stats": True "continuous_usage_stats": True,
}, },
} }
# 超参由yaml传入 # 超参由yaml传入
@@ -99,8 +97,8 @@ async def async_request_eb_openai_chat_completions(
if request_func_input.ignore_eos: if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos payload["ignore_eos"] = request_func_input.ignore_eos
print("payload:{}".format(json.dumps(payload, ensure_ascii=False))) print(f"payload:{json.dumps(payload, ensure_ascii=False)}")
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
@@ -115,16 +113,14 @@ async def async_request_eb_openai_chat_completions(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(url=api_url, json=payload, headers=headers) as response:
headers=headers) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
# print("####chunk:", chunk, type(chunk)) # print("####chunk:", chunk, type(chunk))
timestamp = time.perf_counter() timestamp = time.perf_counter()
@@ -138,22 +134,20 @@ async def async_request_eb_openai_chat_completions(
ttft = timestamp - st ttft = timestamp - st
output.ttft = ttft output.ttft = ttft
# cached_tokens # cached_tokens
output.prompt_len = data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0) output.prompt_len = (
data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
)
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
output.generated_text += content or "" output.generated_text += content or ""
output.reasoning_content += reason_content or "" output.reasoning_content += reason_content or ""
output.arrival_time.append(choices[0].get("arrival_time", timestamp)) output.arrival_time.append(choices[0].get("arrival_time", timestamp))
elif usage := data.get("usage", {}): elif usage := data.get("usage", {}):
output.output_tokens = usage.get( output.output_tokens = usage.get("completion_tokens", 0)
"completion_tokens", 0) output.prompt_tokens = usage.get("prompt_tokens", 0)
output.prompt_tokens = usage.get(
"prompt_tokens", 0)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
@@ -166,7 +160,12 @@ async def async_request_eb_openai_chat_completions(
output.latency = most_recent_timestamp - st output.latency = most_recent_timestamp - st
else: else:
error_text = await response.text() error_text = await response.text()
print("####error response:", error_text, "####payload:", payload) print(
"####error response:",
error_text,
"####payload:",
payload,
)
output.error = error_text or "" output.error = error_text or ""
output.success = False output.success = False
except Exception: except Exception:
@@ -194,15 +193,14 @@ async def async_request_eb_openai_completions(
("completions", "profile") ("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'." ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
timeout=AIOHTTP_TIMEOUT) as session:
payload = { payload = {
"model": request_func_input.model, "model": request_func_input.model,
"prompt": request_func_input.prompt, "prompt": request_func_input.prompt,
"stream": True, "stream": True,
"stream_options": { "stream_options": {
"include_usage": True, "include_usage": True,
"continuous_usage_stats": True "continuous_usage_stats": True,
}, },
} }
# 超参由yaml传入 # 超参由yaml传入
@@ -210,12 +208,12 @@ async def async_request_eb_openai_completions(
if request_func_input.ignore_eos: if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos payload["ignore_eos"] = request_func_input.ignore_eos
print("payload:", json.dumps(payload, ensure_ascii=False)) print("payload:", json.dumps(payload, ensure_ascii=False))
headers = { headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"Content-Type": "application/json" "Content-Type": "application/json",
} }
output = RequestFuncOutput() output = RequestFuncOutput()
@@ -227,8 +225,7 @@ async def async_request_eb_openai_completions(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(url=api_url, json=payload, headers=headers) as response:
headers=headers) as response:
if response.status == 200: if response.status == 200:
first_chunk_received = False first_chunk_received = False
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
@@ -236,8 +233,7 @@ async def async_request_eb_openai_completions(
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
# print("####chunk:", chunk, chunk.usage) # print("####chunk:", chunk, chunk.usage)
timestamp = time.perf_counter() timestamp = time.perf_counter()
@@ -250,7 +246,7 @@ async def async_request_eb_openai_completions(
# Note that text could be empty here # Note that text could be empty here
# e.g. for special tokens # e.g. for special tokens
text = choices[0].get("text") text = choices[0].get("text")
# First token # First token
if not first_chunk_received: if not first_chunk_received:
first_chunk_received = True first_chunk_received = True
@@ -259,26 +255,23 @@ async def async_request_eb_openai_completions(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
generated_text += text or "" generated_text += text or ""
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
output.arrival_time.append(choices[0].get("arrival_time", timestamp)) output.arrival_time.append(choices[0].get("arrival_time", timestamp))
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.prompt_tokens = usage.get( output.prompt_tokens = usage.get("prompt_tokens")
"prompt_tokens") output.output_tokens = usage.get("completion_tokens")
output.output_tokens = usage.get(
"completion_tokens")
if first_chunk_received: if first_chunk_received:
output.success = True output.success = True
else: else:
output.success = False output.success = False
output.error = ( output.error = (
"Never received a valid chunk to calculate TTFT." "Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
"This response will be marked as failed!") )
output.generated_text = generated_text output.generated_text = generated_text
output.latency = most_recent_timestamp - st output.latency = most_recent_timestamp - st
@@ -294,8 +287,8 @@ async def async_request_eb_openai_completions(
output.success = False output.success = False
exc_info = sys.exc_info() exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info)) output.error = "".join(traceback.format_exception(*exc_info))
print("final_output:{}".format(output)) print(f"final_output:{output}")
if pbar: if pbar:
pbar.update(1) pbar.update(1)
@@ -310,8 +303,7 @@ async def async_request_tgi(
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
timeout=AIOHTTP_TIMEOUT) as session:
params = { params = {
"max_new_tokens": request_func_input.output_len, "max_new_tokens": request_func_input.output_len,
"do_sample": True, "do_sample": True,
@@ -358,8 +350,7 @@ async def async_request_tgi(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
output.arrival_time.append(data["arrival_time"]) output.arrival_time.append(data["arrival_time"])
@@ -388,8 +379,7 @@ async def async_request_trt_llm(
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
timeout=AIOHTTP_TIMEOUT) as session:
payload = { payload = {
"accumulate_tokens": True, "accumulate_tokens": True,
"text_input": request_func_input.prompt, "text_input": request_func_input.prompt,
@@ -414,8 +404,7 @@ async def async_request_trt_llm(
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
"data:")
data = json.loads(chunk) data = json.loads(chunk)
output.generated_text += data["text_output"] output.generated_text += data["text_output"]
@@ -427,8 +416,7 @@ async def async_request_trt_llm(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
@@ -453,8 +441,7 @@ async def async_request_deepspeed_mii(
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
"""Request an LLM using Deepspeed MII""" """Request an LLM using Deepspeed MII"""
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
timeout=AIOHTTP_TIMEOUT) as session:
payload = { payload = {
"prompt": request_func_input.prompt, "prompt": request_func_input.prompt,
@@ -472,19 +459,16 @@ async def async_request_deepspeed_mii(
st = time.perf_counter() st = time.perf_counter()
try: try:
async with session.post(url=request_func_input.api_url, async with session.post(url=request_func_input.api_url, json=payload) as response:
json=payload) as response:
if response.status == 200: if response.status == 200:
parsed_resp = await response.json() parsed_resp = await response.json()
output.latency = time.perf_counter() - st output.latency = time.perf_counter() - st
if "choices" in parsed_resp: if "choices" in parsed_resp:
output.generated_text = parsed_resp["choices"][0][ output.generated_text = parsed_resp["choices"][0]["text"]
"text"]
elif "text" in parsed_resp: elif "text" in parsed_resp:
output.generated_text = parsed_resp["text"][0] output.generated_text = parsed_resp["text"][0]
else: else:
output.error = ("Unexpected response format: " output.error = "Unexpected response format: " "neither 'choices' nor 'text' found"
"neither 'choices' nor 'text' found")
output.success = False output.success = False
output.success = True output.success = True
else: else:
@@ -510,26 +494,22 @@ async def async_request_openai_completions(
("completions", "profile") ("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'." ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
timeout=AIOHTTP_TIMEOUT) as session:
payload = { payload = {
"model": request_func_input.model_name \ "model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
if request_func_input.model_name else request_func_input.model,
"prompt": request_func_input.prompt, "prompt": request_func_input.prompt,
# "temperature": 0.0, # "temperature": 0.0,
"max_tokens": request_func_input.output_len, "max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs, "logprobs": request_func_input.logprobs,
"stream": True, "stream": True,
#"stream_options": { # "stream_options": {
# "include_usage": True, # "include_usage": True,
#}, # },
} }
if request_func_input.ignore_eos: if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos payload["ignore_eos"] = request_func_input.ignore_eos
headers = { headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
@@ -538,8 +518,7 @@ async def async_request_openai_completions(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(url=api_url, json=payload, headers=headers) as response:
headers=headers) as response:
if response.status == 200: if response.status == 200:
first_chunk_received = False first_chunk_received = False
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
@@ -547,8 +526,7 @@ async def async_request_openai_completions(
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
# print("####chunk:", chunk, type(chunk)) # print("####chunk:", chunk, type(chunk))
data = json.loads(chunk) data = json.loads(chunk)
@@ -569,21 +547,19 @@ async def async_request_openai_completions(
# Decoding phase # Decoding phase
else: else:
output.itl.append(timestamp - output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
generated_text += text or "" generated_text += text or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.output_tokens = usage.get( output.output_tokens = usage.get("completion_tokens")
"completion_tokens")
if first_chunk_received: if first_chunk_received:
output.success = True output.success = True
else: else:
output.success = False output.success = False
output.error = ( output.error = (
"Never received a valid chunk to calculate TTFT." "Never received a valid chunk to calculate TTFT." "This response will be marked as failed!"
"This response will be marked as failed!") )
output.generated_text = generated_text output.generated_text = generated_text
output.latency = most_recent_timestamp - st output.latency = most_recent_timestamp - st
else: else:
@@ -606,25 +582,24 @@ async def async_request_openai_audio(
"""Request an LLM using OpenAI""" """Request an LLM using OpenAI"""
# Lazy import without PlaceholderModule to avoid vllm dep. # Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile import soundfile
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith( assert api_url.endswith(
("transcriptions", "translations" ("transcriptions", "translations")
)), "OpenAI Chat Completions API URL must end with 'transcriptions' " ), "OpenAI Chat Completions API URL must end with 'transcriptions' "
"or `translations`." "or `translations`."
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session:
timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}] content = [{"type": "text", "text": request_func_input.prompt}]
payload = { payload = {
"model": request_func_input.model_name \ "model": (request_func_input.model_name if request_func_input.model_name else request_func_input.model),
if request_func_input.model_name else request_func_input.model,
"temperature": 0.0, "temperature": 0.0,
"max_completion_tokens": request_func_input.output_len, "max_completion_tokens": request_func_input.output_len,
"stream": True, "stream": True,
"language": "en", "language": "en",
# Flattened due to multipart/form-data # Flattened due to multipart/form-data
"stream_include_usage": True, "stream_include_usage": True,
"stream_continuous_usage_stats": True "stream_continuous_usage_stats": True,
} }
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
@@ -639,9 +614,9 @@ async def async_request_openai_audio(
buffer.seek(0) buffer.seek(0)
return buffer return buffer
with to_bytes(*request_func_input.multi_modal_content['audio']) as f: with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
form = aiohttp.FormData() form = aiohttp.FormData()
form.add_field('file', f, content_type='audio/wav') form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items(): for key, value in payload.items():
form.add_field(key, str(value)) form.add_field(key, str(value))
@@ -653,24 +628,20 @@ async def async_request_openai_audio(
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, async with session.post(url=api_url, data=form, headers=headers) as response:
data=form,
headers=headers) as response:
if response.status == 200: if response.status == 200:
async for chunk_bytes in response.content: async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip() chunk_bytes = chunk_bytes.strip()
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
"data: ")
if chunk != "[DONE]": if chunk != "[DONE]":
timestamp = time.perf_counter() timestamp = time.perf_counter()
data = json.loads(chunk) data = json.loads(chunk)
if choices := data.get("choices"): if choices := data.get("choices"):
content = choices[0]["delta"].get( content = choices[0]["delta"].get("content")
"content")
# First token # First token
if ttft == 0.0: if ttft == 0.0:
ttft = timestamp - st ttft = timestamp - st
@@ -678,13 +649,11 @@ async def async_request_openai_audio(
# Decoding phase # Decoding phase
else: else:
output.itl.append( output.itl.append(timestamp - most_recent_timestamp)
timestamp - most_recent_timestamp)
generated_text += content or "" generated_text += content or ""
elif usage := data.get("usage"): elif usage := data.get("usage"):
output.output_tokens = usage.get( output.output_tokens = usage.get("completion_tokens")
"completion_tokens")
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
@@ -718,8 +687,11 @@ ASYNC_REQUEST_FUNCS = {
} }
OPENAI_COMPATIBLE_BACKENDS = [ OPENAI_COMPATIBLE_BACKENDS = [
k for k, v in ASYNC_REQUEST_FUNCS.items() k
if v in (async_request_openai_completions, for k, v in ASYNC_REQUEST_FUNCS.items()
async_request_eb_openai_chat_completions) if v
in (
async_request_openai_completions,
async_request_eb_openai_chat_completions,
)
] ]

View File

@@ -26,9 +26,9 @@ from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from io import BytesIO
from typing import Any, Callable, Optional, Union from typing import Any, Optional, Union
from PIL import Image
from PIL import Image
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -38,6 +38,7 @@ class SampleRequest:
""" """
Represents a single inference request for benchmarking. Represents a single inference request for benchmarking.
""" """
no: int no: int
prompt: Union[str, Any] prompt: Union[str, Any]
history_QA: Union[str, Any] history_QA: Union[str, Any]
@@ -48,6 +49,7 @@ class SampleRequest:
class BenchmarkDataset(ABC): class BenchmarkDataset(ABC):
"""BenchmarkDataset""" """BenchmarkDataset"""
DEFAULT_SEED = 0 DEFAULT_SEED = 0
IS_MULTIMODAL = False IS_MULTIMODAL = False
@@ -68,8 +70,7 @@ class BenchmarkDataset(ABC):
self.dataset_path = dataset_path self.dataset_path = dataset_path
# Set the random seed, ensuring that a None value is replaced with the # Set the random seed, ensuring that a None value is replaced with the
# default seed. # default seed.
self.random_seed = (random_seed self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED
if random_seed is not None else self.DEFAULT_SEED)
self.data = None self.data = None
self.hyperparameter_path = hyperparameter_path self.hyperparameter_path = hyperparameter_path
self.hyperparameters = {} self.hyperparameters = {}
@@ -85,8 +86,7 @@ class BenchmarkDataset(ABC):
NotImplementedError: If a subclass does not implement this method. NotImplementedError: If a subclass does not implement this method.
""" """
# TODO (jenniferzhao): add support for downloading data # TODO (jenniferzhao): add support for downloading data
raise NotImplementedError( raise NotImplementedError("load_data must be implemented in subclasses.")
"load_data must be implemented in subclasses.")
@abstractmethod @abstractmethod
def sample(self, num_requests: int) -> list[SampleRequest]: def sample(self, num_requests: int) -> list[SampleRequest]:
@@ -105,8 +105,7 @@ class BenchmarkDataset(ABC):
""" """
raise NotImplementedError("sample must be implemented in subclasses.") raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(self, requests: list[SampleRequest], def maybe_oversample_requests(self, requests: list[SampleRequest], num_requests: int) -> None:
num_requests: int) -> None:
""" """
Oversamples the list of requests if its size is less than the desired Oversamples the list of requests if its size is less than the desired
number. number.
@@ -117,11 +116,9 @@ class BenchmarkDataset(ABC):
""" """
if len(requests) < num_requests: if len(requests) < num_requests:
random.seed(self.random_seed) random.seed(self.random_seed)
additional = random.choices(requests, additional = random.choices(requests, k=num_requests - len(requests))
k=num_requests - len(requests))
requests.extend(additional) requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.", logger.info("Oversampled requests to reach %d total samples.", num_requests)
num_requests)
def is_valid_sequence( def is_valid_sequence(
@@ -141,14 +138,12 @@ def is_valid_sequence(
""" """
# Check for invalid conditions # Check for invalid conditions
prompt_too_short = prompt_len < min_len prompt_too_short = prompt_len < min_len
output_too_short = (not skip_min_output_len_check) and (output_len output_too_short = (not skip_min_output_len_check) and (output_len < min_len)
< min_len)
prompt_too_long = prompt_len > max_prompt_len prompt_too_long = prompt_len > max_prompt_len
combined_too_long = (prompt_len + output_len) > max_total_len combined_too_long = (prompt_len + output_len) > max_total_len
# Return True if none of the invalid conditions are met # Return True if none of the invalid conditions are met
return not (prompt_too_short or output_too_short or prompt_too_long return not (prompt_too_short or output_too_short or prompt_too_long or combined_too_long)
or combined_too_long)
def process_image(image: Any) -> Mapping[str, Any]: def process_image(image: Any) -> Mapping[str, Any]:
@@ -171,28 +166,25 @@ def process_image(image: Any) -> Mapping[str, Any]:
Raises: Raises:
ValueError: If the input is not a supported type. ValueError: If the input is not a supported type.
""" """
if isinstance(image, dict) and 'bytes' in image: if isinstance(image, dict) and "bytes" in image:
image = Image.open(BytesIO(image['bytes'])) image = Image.open(BytesIO(image["bytes"]))
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
image = image.convert("RGB") image = image.convert("RGB")
with io.BytesIO() as image_data: with io.BytesIO() as image_data:
image.save(image_data, format="JPEG") image.save(image_data, format="JPEG")
image_base64 = base64.b64encode( image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
image_data.getvalue()).decode("utf-8")
return { return {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
"url": f"data:image/jpeg;base64,{image_base64}"
},
} }
if isinstance(image, str): if isinstance(image, str):
image_url = (image if image.startswith( image_url = image if image.startswith(("http://", "file://")) else f"file://{image}"
("http://", "file://")) else f"file://{image}")
return {"type": "image_url", "image_url": {"url": image_url}} return {"type": "image_url", "image_url": {"url": image_url}}
raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" raise ValueError(
" or str or dictionary with raw image bytes.") f"Invalid image input {image}. Must be a PIL.Image.Image" " or str or dictionary with raw image bytes."
)
class EBDataset(BenchmarkDataset): class EBDataset(BenchmarkDataset):
@@ -243,8 +235,7 @@ class EBDataset(BenchmarkDataset):
new_output_len = int(entry["max_dec_len"]) new_output_len = int(entry["max_dec_len"])
if enable_multimodal_chat: if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation( prompt = self.apply_multimodal_chat_transformation(prompt, None)
prompt, None)
samples.append( samples.append(
SampleRequest( SampleRequest(
no=cnt, no=cnt,
@@ -252,17 +243,20 @@ class EBDataset(BenchmarkDataset):
prompt_len=self.prompt_len, prompt_len=self.prompt_len,
history_QA=[], history_QA=[],
expected_output_len=new_output_len, expected_output_len=new_output_len,
)) )
)
cnt += 1 cnt += 1
self.maybe_oversample_requests(samples, num_requests) self.maybe_oversample_requests(samples, num_requests)
return samples return samples
class EBChatDataset(BenchmarkDataset): class EBChatDataset(BenchmarkDataset):
""" """
Implements the ShareGPT dataset. Loads data from a JSON file and generates Implements the ShareGPT dataset. Loads data from a JSON file and generates
sample requests based on conversation turns. sample requests based on conversation turns.
""" """
prompt_len: int prompt_len: int
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
@@ -296,8 +290,7 @@ class EBChatDataset(BenchmarkDataset):
new_output_len = int(entry.get("max_tokens", 12288)) new_output_len = int(entry.get("max_tokens", 12288))
if enable_multimodal_chat: if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation( prompt = self.apply_multimodal_chat_transformation(prompt, None)
prompt, None)
samples.append( samples.append(
SampleRequest( SampleRequest(
no=cnt, no=cnt,
@@ -306,9 +299,9 @@ class EBChatDataset(BenchmarkDataset):
prompt_len=0, prompt_len=0,
history_QA=history_QA, history_QA=history_QA,
expected_output_len=new_output_len, expected_output_len=new_output_len,
)) )
)
cnt += 1 cnt += 1
self.maybe_oversample_requests(samples, num_requests) self.maybe_oversample_requests(samples, num_requests)
return samples return samples

View File

@@ -18,28 +18,16 @@ import argparse
import asyncio import asyncio
import contextlib import contextlib
import os import os
import signal
import socket
import subprocess
import time
from typing import Union from typing import Union
import openai from benchmark_dataset import EBChatDataset, EBDataset
import yaml
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_serving import benchmark from benchmark_serving import benchmark
def prepare_input_requests( def prepare_input_requests(num_prompts: int, dataset_name: str, dataset_path: str) -> Union[EBDataset, EBChatDataset]:
num_prompts: int, dataset_name: str, dataset_path: str
) -> Union[EBDataset, EBChatDataset]:
dataset_mapping = { dataset_mapping = {
"EB": lambda: EBDataset(dataset_path=dataset_path).sample( "EB": lambda: EBDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
num_requests=num_prompts "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(num_requests=num_prompts),
),
"EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
num_requests=num_prompts
),
} }
try: try:
@@ -104,24 +92,27 @@ def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
def main(args): def main(args):
base_url = f"http://{args.host}:{args.port}" base_url = f"http://{args.host}:{args.port}"
input_requests = prepare_input_requests( input_requests = prepare_input_requests(args.num_prompts, args.dataset_name, args.dataset_path)
args.num_prompts, args.dataset_name, args.dataset_path
)
if len(args.max_concurrency) != len(args.s_itl_base_model): if len(args.max_concurrency) != len(args.s_itl_base_model):
raise ValueError(f"--max_concurrency should be same length as --s_itl_base_model") raise ValueError("--max_concurrency should be same length as --s_itl_base_model")
for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model): for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
# Wramup # Wramup
print("Starting warmup...") print("Starting warmup...")
with open(os.devnull, "w") as f: with open(os.devnull, "w") as f:
with contextlib.redirect_stdout(f): with contextlib.redirect_stdout(f):
send_one_batch(base_url, max_concurrency, input_requests[0:max_concurrency], True) send_one_batch(
base_url,
max_concurrency,
input_requests[0:max_concurrency],
True,
)
# Benchmark # Benchmark
record = send_one_batch(base_url, max_concurrency, input_requests, False) record = send_one_batch(base_url, max_concurrency, input_requests, False)
metric_header = f"Speed up" metric_header = "Speed up"
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
for draft_token_step in args.draft_token_steps: for draft_token_step in args.draft_token_steps:
speedup = calculate_speedup( speedup = calculate_speedup(
@@ -130,11 +121,7 @@ def main(args):
s_itl, s_itl,
record["mean_s_itl_ms"], record["mean_s_itl_ms"],
) )
print( print("{:<40} {:<10.2f}".format(f"Speed up on {draft_token_step} steps draft", speedup))
"{:<40} {:<10.2f}".format(
f"Speed up on {draft_token_step} steps draft", speedup
)
)
print("=" * 50) print("=" * 50)

File diff suppressed because it is too large Load Diff

View File

@@ -24,9 +24,11 @@ import os
from typing import Any from typing import Any
def convert_to_pytorch_benchmark_format(args: argparse.Namespace, def convert_to_pytorch_benchmark_format(
metrics: dict[str, list], args: argparse.Namespace,
extra_info: dict[str, Any]) -> list: metrics: dict[str, list],
extra_info: dict[str, Any],
) -> list:
""" """
Save the benchmark results in the format used by PyTorch OSS benchmark with Save the benchmark results in the format used by PyTorch OSS benchmark with
on metric per record on metric per record
@@ -54,12 +56,10 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
}, },
} }
tp = record["benchmark"]["extra_info"]["args"].get( tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
"tensor_parallel_size")
# Save tensor_parallel_size parameter if it's part of the metadata # Save tensor_parallel_size parameter if it's part of the metadata
if not tp and "tensor_parallel_size" in extra_info: if not tp and "tensor_parallel_size" in extra_info:
record["benchmark"]["extra_info"]["args"][ record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = extra_info["tensor_parallel_size"]
"tensor_parallel_size"] = extra_info["tensor_parallel_size"]
records.append(record) records.append(record)
@@ -68,6 +68,7 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
class InfEncoder(json.JSONEncoder): class InfEncoder(json.JSONEncoder):
"""InfEncoder""" """InfEncoder"""
def clear_inf(self, o: Any): def clear_inf(self, o: Any):
"""clear_inf""" """clear_inf"""
if isinstance(o, dict): if isinstance(o, dict):
@@ -87,4 +88,3 @@ def write_to_json(filename: str, records: list) -> None:
"""write_to_json""" """write_to_json"""
with open(filename, "w") as f: with open(filename, "w") as f:
json.dump(records, f, cls=InfEncoder) json.dump(records, f, cls=InfEncoder)

View File

@@ -25,32 +25,32 @@ import os
import random import random
import time import time
import warnings import warnings
import yaml from argparse import ArgumentParser as FlexibleArgumentParser
import requests
import copy
from collections.abc import AsyncGenerator, Iterable from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any, Optional
import numpy as np import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, import requests
OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput, import yaml
RequestFuncOutput) from backend_request_func import (
ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS,
RequestFuncInput,
RequestFuncOutput,
)
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from argparse import ArgumentParser as FlexibleArgumentParser
from benchmark_dataset import (SampleRequest, EBDataset, EBChatDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
MILLISECONDS_TO_SECONDS_CONVERSION = 1000 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@dataclass @dataclass
class BenchmarkMetrics: class BenchmarkMetrics:
"""Class containing all metrics that are used in this script""" """Class containing all metrics that are used in this script"""
completed: int completed: int
total_input: int total_input: int
total_output: int total_output: int
@@ -133,8 +133,7 @@ async def get_request(
input_requests: Iterable[SampleRequest] = iter(input_requests) input_requests: Iterable[SampleRequest] = iter(input_requests)
# Calculate scale parameter theta to maintain the desired request_rate. # Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, ( assert burstiness > 0, f"A positive burstiness factor is expected, but given {burstiness}."
f"A positive burstiness factor is expected, but given {burstiness}.")
theta = 1.0 / (request_rate * burstiness) theta = 1.0 / (request_rate * burstiness)
for request in input_requests: for request in input_requests:
@@ -160,7 +159,7 @@ def calculate_metrics(
) -> tuple[BenchmarkMetrics, list[int]]: ) -> tuple[BenchmarkMetrics, list[int]]:
"""Calculates various performance metrics based on the inputs and outputs.""" """Calculates various performance metrics based on the inputs and outputs."""
input_lens: list[int] = [] input_lens: list[int] = []
infer_input_lens: list[int] = [] # 推理侧输入token数 infer_input_lens: list[int] = [] # 推理侧输入token数
actual_output_lens: list[int] = [] actual_output_lens: list[int] = []
total_input = 0 total_input = 0
completed = 0 completed = 0
@@ -210,8 +209,9 @@ def calculate_metrics(
s_e2els.append(outputs[i].arrival_time[-1]) s_e2els.append(outputs[i].arrival_time[-1])
# 解码速度去掉首token # 解码速度去掉首token
if len(outputs[i].arrival_time) > 2: if len(outputs[i].arrival_time) > 2:
s_decodes.append((outputs[i].output_tokens - 1) / s_decodes.append(
(outputs[i].arrival_time[-1] - outputs[i].arrival_time[1])) (outputs[i].output_tokens - 1) / (outputs[i].arrival_time[-1] - outputs[i].arrival_time[1])
)
completed += 1 completed += 1
else: else:
actual_output_lens.append(0) actual_output_lens.append(0)
@@ -224,16 +224,13 @@ def calculate_metrics(
if "ttft" in goodput_config_dict: if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts) valid_metrics.append(ttfts)
slo_values.append(goodput_config_dict["ttft"] / slo_values.append(goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION)
MILLISECONDS_TO_SECONDS_CONVERSION)
if "tpot" in goodput_config_dict: if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots) valid_metrics.append(all_tpots)
slo_values.append(goodput_config_dict["tpot"] / slo_values.append(goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION)
MILLISECONDS_TO_SECONDS_CONVERSION)
if "e2el" in goodput_config_dict: if "e2el" in goodput_config_dict:
valid_metrics.append(e2els) valid_metrics.append(e2els)
slo_values.append(goodput_config_dict["e2el"] / slo_values.append(goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION)
MILLISECONDS_TO_SECONDS_CONVERSION)
for req_metric in zip(*valid_metrics): for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
@@ -242,9 +239,9 @@ def calculate_metrics(
if completed == 0: if completed == 0:
warnings.warn( warnings.warn(
"All requests failed. This is likely due to a misconfiguration " "All requests failed. This is likely due to a misconfiguration " "on the benchmark arguments.",
"on the benchmark arguments.", stacklevel=2,
stacklevel=2) )
metrics = BenchmarkMetrics( metrics = BenchmarkMetrics(
completed=completed, completed=completed,
total_input=total_input, total_input=total_input,
@@ -253,64 +250,50 @@ def calculate_metrics(
request_goodput=good_completed / dur_s, request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s, output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_s_decode=np.mean(s_decodes or 0) * mean_s_decode=np.mean(s_decodes or 0) * 1, # ttfts is empty if streaming is not supported by backend
1, # ttfts is empty if streaming is not supported by backend
std_s_decode=np.std(s_decodes or 0) * 1, std_s_decode=np.std(s_decodes or 0) * 1,
median_s_decode=np.median(s_decodes or 0) * 1, median_s_decode=np.median(s_decodes or 0) * 1,
percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1) percentiles_s_decode=[(p, np.percentile(s_decodes or 0, p) * 1) for p in selected_percentiles],
for p in selected_percentiles], mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend
mean_ttft_ms=np.mean(ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles], mean_s_ttft_ms=np.mean(s_ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend
mean_s_ttft_ms=np.mean(s_ttfts or 0) *
1000, # ttfts is empty if streaming is not supported by backend
std_s_ttft_ms=np.std(s_ttfts or 0) * 1000, std_s_ttft_ms=np.std(s_ttfts or 0) * 1000,
median_s_ttft_ms=np.median(s_ttfts or 0) * 1000, median_s_ttft_ms=np.median(s_ttfts or 0) * 1000,
percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000) percentiles_s_ttft_ms=[(p, np.percentile(s_ttfts or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles],
mean_tpot_ms=np.mean(tpots or 0) * 1000, mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles],
mean_itl_ms=np.mean(itls or 0) * 1000, mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles],
mean_s_itl_ms=np.mean(s_itls or 0) * 1000, mean_s_itl_ms=np.mean(s_itls or 0) * 1000,
std_s_itl_ms=np.std(s_itls or 0) * 1000, std_s_itl_ms=np.std(s_itls or 0) * 1000,
median_s_itl_ms=np.median(s_itls or 0) * 1000, median_s_itl_ms=np.median(s_itls or 0) * 1000,
percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000) percentiles_s_itl_ms=[(p, np.percentile(s_itls or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles],
mean_e2el_ms=np.mean(e2els or 0) * 1000, mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles],
mean_s_e2el_ms=np.mean(s_e2els or 0) * 1000, mean_s_e2el_ms=np.mean(s_e2els or 0) * 1000,
std_s_e2el_ms=np.std(s_e2els or 0) * 1000, std_s_e2el_ms=np.std(s_e2els or 0) * 1000,
median_s_e2el_ms=np.median(s_e2els or 0) * 1000, median_s_e2el_ms=np.median(s_e2els or 0) * 1000,
percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000) percentiles_s_e2el_ms=[(p, np.percentile(s_e2els or 0, p) * 1000) for p in selected_percentiles],
for p in selected_percentiles],
mean_input_len=np.mean(input_lens or 0) * 1, mean_input_len=np.mean(input_lens or 0) * 1,
std_input_len=np.std(input_lens or 0) * 1, std_input_len=np.std(input_lens or 0) * 1,
median_input_len=np.median(input_lens or 0) * 1, median_input_len=np.median(input_lens or 0) * 1,
percentiles_input_len=[(p, np.percentile(input_lens or 0, p)) percentiles_input_len=[(p, np.percentile(input_lens or 0, p)) for p in selected_percentiles],
for p in selected_percentiles],
mean_s_input_len=np.mean(infer_input_lens or 0) * 1, mean_s_input_len=np.mean(infer_input_lens or 0) * 1,
std_s_input_len=np.std(infer_input_lens or 0) * 1, std_s_input_len=np.std(infer_input_lens or 0) * 1,
median_s_input_len=np.median(infer_input_lens or 0) * 1, median_s_input_len=np.median(infer_input_lens or 0) * 1,
percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p)) percentiles_s_input_len=[(p, np.percentile(infer_input_lens or 0, p)) for p in selected_percentiles],
for p in selected_percentiles],
mean_output_len=np.mean(actual_output_lens or 0) * 1, mean_output_len=np.mean(actual_output_lens or 0) * 1,
std_output_len=np.std(actual_output_lens or 0) * 1, std_output_len=np.std(actual_output_lens or 0) * 1,
median_output_len=np.median(actual_output_lens or 0) * 1, median_output_len=np.median(actual_output_lens or 0) * 1,
percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p)) percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p)) for p in selected_percentiles],
for p in selected_percentiles],
) )
return metrics, actual_output_lens return metrics, actual_output_lens
@@ -351,20 +334,22 @@ async def benchmark(
if lora_modules: if lora_modules:
# For each input request, choose a LoRA module at random. # For each input request, choose a LoRA module at random.
lora_modules = iter( lora_modules = iter([random.choice(lora_modules) for _ in range(len(input_requests))])
[random.choice(lora_modules) \
for _ in range(len(input_requests))])
if profile: if profile:
print("Starting profiler...") print("Starting profiler...")
profile_input = RequestFuncInput(model=model_id, test_prompt = None
model_name=model_name, test_output_len = None
prompt=test_prompt, profile_input = RequestFuncInput(
api_url=base_url + "/start_profile", model=model_id,
output_len=test_output_len, model_name=model_name,
logprobs=logprobs, prompt=test_prompt,
ignore_eos=ignore_eos, api_url=base_url + "/start_profile",
extra_body=extra_body) output_len=test_output_len,
logprobs=logprobs,
ignore_eos=ignore_eos,
extra_body=extra_body,
)
profile_output = await request_func(request_func_input=profile_input) profile_output = await request_func(request_func_input=profile_input)
if profile_output.success: if profile_output.success:
print("Profiler started") print("Profiler started")
@@ -384,19 +369,16 @@ async def benchmark(
# and it will simplify the code in limited_request_func. # and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency) # semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext()) # if max_concurrency else contextlib.nullcontext())
semaphore = (asyncio.Semaphore(max_concurrency) semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
if max_concurrency else None)
async def limited_request_func(request_func_input, pbar): async def limited_request_func(request_func_input, pbar):
if semaphore is None: if semaphore is None:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar)
async with semaphore: async with semaphore:
return await request_func(request_func_input=request_func_input, return await request_func(request_func_input=request_func_input, pbar=pbar)
pbar=pbar)
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
print(f"开始时间:{datetime.now()}") print(f"开始时间:{datetime.now()}")
tasks: list[asyncio.Task] = [] tasks: list[asyncio.Task] = []
async for request in get_request(input_requests, request_rate, burstiness): async for request in get_request(input_requests, request_rate, burstiness):
@@ -409,25 +391,26 @@ async def benchmark(
req_lora_module = next(lora_modules) req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module req_model_id, req_model_name = req_lora_module, req_lora_module
request_func_input = RequestFuncInput(model=req_model_id, request_func_input = RequestFuncInput(
model_name=req_model_name, model=req_model_id,
prompt=prompt, model_name=req_model_name,
prompt_len=0, prompt=prompt,
history_QA=history_QA, prompt_len=0,
hyper_parameters=hyper_parameters, history_QA=history_QA,
api_url=api_url, hyper_parameters=hyper_parameters,
output_len=output_len, api_url=api_url,
logprobs=logprobs, output_len=output_len,
ignore_eos=ignore_eos, logprobs=logprobs,
extra_body=extra_body) ignore_eos=ignore_eos,
tasks.append( extra_body=extra_body,
asyncio.create_task( )
limited_request_func(request_func_input=request_func_input, tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
pbar=pbar)))
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
print(f"完成时间:{datetime.now()}") print(f"完成时间:{datetime.now()}")
if profile: if profile:
print("Stopping profiler...") print("Stopping profiler...")
test_output_len = None
test_output_len = None
profile_input = RequestFuncInput( profile_input = RequestFuncInput(
model=model_id, model=model_id,
prompt=test_prompt, prompt=test_prompt,
@@ -454,22 +437,16 @@ async def benchmark(
) )
print("Benchmark complete!!!") print("Benchmark complete!!!")
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:", print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
metrics.total_output)) print("{:<40} {:<10.3f}".format("Request throughput (req/s):", metrics.request_throughput))
print("{:<40} {:<10.3f}".format("Request throughput (req/s):",
metrics.request_throughput))
if goodput_config_dict: if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):", print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput))
metrics.request_goodput)) print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput))
metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
metrics.total_token_throughput))
result = { result = {
"duration": benchmark_duration, "duration": benchmark_duration,
@@ -477,8 +454,7 @@ async def benchmark(
"total_input_tokens": metrics.total_input, "total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output, "total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput, "request_throughput": metrics.request_throughput,
"request_goodput:": "request_goodput:": (metrics.request_goodput if goodput_config_dict else None),
metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput, "output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput, "total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs], "input_lens": [output.prompt_len for output in outputs],
@@ -491,7 +467,6 @@ async def benchmark(
"reasoning_contents": [output.reasoning_content for output in outputs], "reasoning_contents": [output.reasoning_content for output in outputs],
"errors": [output.error for output in outputs], "errors": [output.error for output in outputs],
} }
quick_result = copy.deepcopy(result)
def process_one_metric( def process_one_metric(
# E.g., "ttft" # E.g., "ttft"
@@ -505,24 +480,25 @@ async def benchmark(
# metric. # metric.
if metric_attribute_name not in selected_percentile_metrics: if metric_attribute_name not in selected_percentile_metrics:
return return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print("{:<40} {:<10.2f}".format( print(
f"Mean {metric_name} (ms):", "{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}_ms"))) f"Mean {metric_name} (ms):",
print("{:<40} {:<10.2f}".format( getattr(metrics, f"mean_{metric_attribute_name}_ms"),
f"Median {metric_name} (ms):", )
getattr(metrics, f"median_{metric_attribute_name}_ms"))) )
result[f"mean_{metric_attribute_name}_ms"] = getattr( print(
metrics, f"mean_{metric_attribute_name}_ms") "{:<40} {:<10.2f}".format(
result[f"median_{metric_attribute_name}_ms"] = getattr( f"Median {metric_name} (ms):",
metrics, f"median_{metric_attribute_name}_ms") getattr(metrics, f"median_{metric_attribute_name}_ms"),
result[f"std_{metric_attribute_name}_ms"] = getattr( )
metrics, f"std_{metric_attribute_name}_ms") )
for p, value in getattr(metrics, result[f"mean_{metric_attribute_name}_ms"] = getattr(metrics, f"mean_{metric_attribute_name}_ms")
f"percentiles_{metric_attribute_name}_ms"): result[f"median_{metric_attribute_name}_ms"] = getattr(metrics, f"median_{metric_attribute_name}_ms")
result[f"std_{metric_attribute_name}_ms"] = getattr(metrics, f"std_{metric_attribute_name}_ms")
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p) p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value result[f"p{p_word}_{metric_attribute_name}_ms"] = value
def process_one_length( def process_one_length(
@@ -537,31 +513,31 @@ async def benchmark(
# metric. # metric.
if metric_attribute_name not in selected_percentile_metrics: if metric_attribute_name not in selected_percentile_metrics:
return return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print("{:<40} {:<10.2f}".format( print(
f"Mean {metric_name}:", "{:<40} {:<10.2f}".format(
getattr(metrics, f"mean_{metric_attribute_name}"))) f"Mean {metric_name}:",
print("{:<40} {:<10.2f}".format( getattr(metrics, f"mean_{metric_attribute_name}"),
f"Median {metric_name}:", )
getattr(metrics, f"median_{metric_attribute_name}"))) )
result[f"mean_{metric_attribute_name}"] = getattr( print(
metrics, f"mean_{metric_attribute_name}") "{:<40} {:<10.2f}".format(
result[f"median_{metric_attribute_name}"] = getattr( f"Median {metric_name}:",
metrics, f"median_{metric_attribute_name}") getattr(metrics, f"median_{metric_attribute_name}"),
result[f"std_{metric_attribute_name}"] = getattr( )
metrics, f"std_{metric_attribute_name}") )
for p, value in getattr(metrics, result[f"mean_{metric_attribute_name}"] = getattr(metrics, f"mean_{metric_attribute_name}")
f"percentiles_{metric_attribute_name}"): result[f"median_{metric_attribute_name}"] = getattr(metrics, f"median_{metric_attribute_name}")
result[f"std_{metric_attribute_name}"] = getattr(metrics, f"std_{metric_attribute_name}")
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}"):
p_word = str(int(p)) if int(p) == p else str(p) p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name}:", value))
value))
result[f"p{p_word}_{metric_attribute_name}"] = value result[f"p{p_word}_{metric_attribute_name}"] = value
process_one_length("s_decode", "Decode", "解码速度(tok/s)") process_one_length("s_decode", "Decode", "解码速度(tok/s)")
process_one_metric("ttft", "TTFT", "Time to First Token") process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token") process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_one_metric("tpot", "TPOT", process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
"Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency") process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency")
@@ -581,6 +557,7 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
""" """
快速评估 快速评估
""" """
def process_quick_metric( def process_quick_metric(
metric_attribute_name: str, metric_attribute_name: str,
metric_name: str, metric_name: str,
@@ -588,7 +565,7 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
): ):
if metric_attribute_name not in selected_percentile_metrics: if metric_attribute_name not in selected_percentile_metrics:
return return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
mean_value = getattr(metrics, f"mean_{metric_attribute_name}_ms") mean_value = getattr(metrics, f"mean_{metric_attribute_name}_ms")
print("{:<40} {:<10.2f}".format(f"Mean {metric_name} (ms):", mean_value)) print("{:<40} {:<10.2f}".format(f"Mean {metric_name} (ms):", mean_value))
quick_result[f"mean_{metric_attribute_name}_ms"] = mean_value quick_result[f"mean_{metric_attribute_name}_ms"] = mean_value
@@ -600,17 +577,17 @@ def quick_summary(quick_result, selected_percentile_metrics, metrics):
): ):
if metric_attribute_name not in selected_percentile_metrics: if metric_attribute_name not in selected_percentile_metrics:
return return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
mean_value = getattr(metrics, f"mean_{metric_attribute_name}") mean_value = getattr(metrics, f"mean_{metric_attribute_name}")
print("{:<40} {:<10.2f}".format(f"Mean {metric_name}:", mean_value)) print("{:<40} {:<10.2f}".format(f"Mean {metric_name}:", mean_value))
quick_result[f"mean_{metric_attribute_name}"] = mean_value quick_result[f"mean_{metric_attribute_name}"] = mean_value
print("\n\n\n") print("\n\n\n")
print("{s:{c}^{n}}".format(s=' Benchmark Quick Summary ', n=50, c='=')) print("{s:{c}^{n}}".format(s=" Benchmark Quick Summary ", n=50, c="="))
process_quick_length("s_decode", "Decode", "解码速度(tok/s)") process_quick_length("s_decode", "Decode", "解码速度(tok/s)")
process_quick_metric("ttft", "TTFT", "Time to First Token") process_quick_metric("ttft", "TTFT", "Time to First Token")
process_quick_metric("s_ttft", "S_TTFT", "Infer Time to First Token") process_quick_metric("s_ttft", "S_TTFT", "Infer Time to First Token")
process_quick_metric("tpot", "TPOT", process_quick_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
"Time per Output Token (excl. 1st token)")
process_quick_metric("itl", "ITL", "Inter-token Latency") process_quick_metric("itl", "ITL", "Inter-token Latency")
process_quick_metric("s_itl", "S_ITL", "Infer Inter-token Latency") process_quick_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
process_quick_metric("e2el", "E2EL", "End-to-end Latency") process_quick_metric("e2el", "E2EL", "End-to-end Latency")
@@ -633,12 +610,14 @@ def check_goodput_args(args):
raise ValueError( raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. " f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of " "The service level objective name should be one of "
f"{str(VALID_NAMES)}. ") f"{VALID_NAMES!s}. "
)
if slo_val < 0: if slo_val < 0:
raise ValueError( raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. " f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be " "The service level objective value should be "
"non-negative.") "non-negative."
)
return goodput_config_dict return goodput_config_dict
@@ -652,37 +631,43 @@ def parse_goodput(slo_pairs):
except ValueError as err: except ValueError as err:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. " "Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" " 'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a " "pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err "number in milliseconds."
) from err
return goodput_config_dict return goodput_config_dict
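(A minimal standalone sketch, not part of the benchmark script, of how the "KEY:VALUE" goodput pairs handled above are parsed and validated; the helper name parse_goodput_pairs is illustrative only.)

VALID_NAMES = ["ttft", "tpot", "e2el"]  # request-level metrics accepted for goodput

def parse_goodput_pairs(slo_pairs):
    # Turn ["ttft:500", "tpot:50"] into {"ttft": 500.0, "tpot": 50.0}, rejecting
    # unknown metric names and negative values, as the script above does.
    goodput_config = {}
    for pair in slo_pairs:
        name, value = pair.split(":")
        name, value = name.strip(), float(value)
        if name not in VALID_NAMES:
            raise ValueError(f"Invalid metric name found, {name}: {value}.")
        if value < 0:
            raise ValueError(f"Invalid value found, {name}: {value}.")
        goodput_config[name] = value
    return goodput_config

print(parse_goodput_pairs(["ttft:500", "tpot:50"]))  # {'ttft': 500.0, 'tpot': 50.0}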
def save_to_pytorch_benchmark_format(args: argparse.Namespace, def save_to_pytorch_benchmark_format(args: argparse.Namespace, results: dict[str, Any], file_name: str) -> None:
results: dict[str, Any],
file_name: str) -> None:
"""Save the benchmarking results to PyTorch Benchmark Format JSON file""" """Save the benchmarking results to PyTorch Benchmark Format JSON file"""
metrics = [ metrics = [
"median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", "median_ttft_ms",
"mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", "mean_ttft_ms",
"median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" "std_ttft_ms",
"p99_ttft_ms",
"mean_tpot_ms",
"median_tpot_ms",
"std_tpot_ms",
"p99_tpot_ms",
"median_itl_ms",
"mean_itl_ms",
"std_itl_ms",
"p99_itl_ms",
] ]
# These raw data might be useful, but they are rather big. They can be added # These raw data might be useful, but they are rather big. They can be added
# later if needed # later if needed
ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
pt_records = convert_to_pytorch_benchmark_format( pt_records = convert_to_pytorch_benchmark_format(
args=args, args=args,
metrics={k: [results[k]] metrics={k: [results[k]] for k in metrics},
for k in metrics}, extra_info={k: results[k] for k in results if k not in metrics and k not in ignored_metrics},
extra_info={ )
k: results[k]
for k in results if k not in metrics and k not in ignored_metrics
})
if pt_records: if pt_records:
# Don't use json suffix here as we don't want CI to pick it up # Don't use json suffix here as we don't want CI to pick it up
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
write_to_json(pt_file, pt_records) write_to_json(pt_file, pt_records)
def check_health(api_base_url: str) -> bool: def check_health(api_base_url: str) -> bool:
health_url = api_base_url.rstrip("/") + "/health" health_url = api_base_url.rstrip("/") + "/health"
try: try:
@@ -697,6 +682,7 @@ def check_health(api_base_url: str) -> bool:
print(f"[HEALTH] Failed to connect to {health_url}: {e}") print(f"[HEALTH] Failed to connect to {health_url}: {e}")
return False return False
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
"""Main entry point""" """Main entry point"""
print(args) print(args)
@@ -707,7 +693,6 @@ def main(args: argparse.Namespace):
model_id = args.model model_id = args.model
model_name = args.served_model_name model_name = args.served_model_name
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer_mode = args.tokenizer_mode
if args.base_url is not None: if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}" api_url = f"{args.base_url}{args.endpoint}"
@@ -717,23 +702,17 @@ def main(args: argparse.Namespace):
base_url = f"http://{args.host}:{args.port}" base_url = f"http://{args.host}:{args.port}"
if args.dataset_name is None: if args.dataset_name is None:
raise ValueError( raise ValueError("Please specify '--dataset-name' and the corresponding " "'--dataset-path' if required.")
"Please specify '--dataset-name' and the corresponding "
"'--dataset-path' if required.")
# For datasets that follow a similar structure, use a mapping. # For datasets that follow a similar structure, use a mapping.
dataset_mapping = { dataset_mapping = {
"EB": "EB": lambda: EBDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
lambda: EBDataset(random_seed=args.seed, num_requests=args.num_prompts,
dataset_path=args.dataset_path).sample( output_len=args.sharegpt_output_len,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
), ),
"EBChat": "EBChat": lambda: EBChatDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample(
lambda: EBChatDataset(random_seed=args.seed, num_requests=args.num_prompts,
dataset_path=args.dataset_path).sample( output_len=args.sharegpt_output_len,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
), ),
} }
@@ -751,15 +730,14 @@ def main(args: argparse.Namespace):
"top_p": args.top_p, "top_p": args.top_p,
"top_k": args.top_k, "top_k": args.top_k,
"min_p": args.min_p, "min_p": args.min_p,
"temperature": args.temperature "temperature": args.temperature,
}.items() if v is not None }.items()
if v is not None
} }
# Sampling parameters are only supported by openai-compatible backend. # Sampling parameters are only supported by openai-compatible backend.
if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
raise ValueError( raise ValueError("Sampling parameters are only supported by openai-compatible " "backends.")
"Sampling parameters are only supported by openai-compatible "
"backends.")
if "temperature" not in sampling_params: if "temperature" not in sampling_params:
sampling_params["temperature"] = 0.0 # Default to greedy decoding. sampling_params["temperature"] = 0.0 # Default to greedy decoding.
@@ -790,15 +768,14 @@ def main(args: argparse.Namespace):
disable_tqdm=args.disable_tqdm, disable_tqdm=args.disable_tqdm,
profile=args.profile, profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","), selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[ selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
float(p) for p in args.metric_percentiles.split(",")
],
ignore_eos=args.ignore_eos, ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict, goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency, max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules, lora_modules=args.lora_modules,
extra_body=sampling_params, extra_body=sampling_params,
)) )
)
# Save config and results to json # Save config and results to json
if args.save_result: if args.save_result:
@@ -819,22 +796,23 @@ def main(args: argparse.Namespace):
kvstring = item.split("=") kvstring = item.split("=")
result_json[kvstring[0].strip()] = kvstring[1].strip() result_json[kvstring[0].strip()] = kvstring[1].strip()
else: else:
raise ValueError( raise ValueError("Invalid metadata format. Please use KEY=VALUE format.")
"Invalid metadata format. Please use KEY=VALUE format."
)
if not args.save_detailed: if not args.save_detailed:
# Remove fields with too many data points # Remove fields with too many data points
for field in [ for field in [
"input_lens", "output_lens", "ttfts", "itls", "input_lens",
"generated_texts", "errors" "output_lens",
"ttfts",
"itls",
"generated_texts",
"errors",
]: ]:
if field in result_json: if field in result_json:
del result_json[field] del result_json[field]
# Traffic # Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate result_json["request_rate"] = args.request_rate if args.request_rate < float("inf") else "inf"
< float("inf") else "inf")
result_json["burstiness"] = args.burstiness result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency result_json["max_concurrency"] = args.max_concurrency
@@ -843,21 +821,19 @@ def main(args: argparse.Namespace):
# Save to file # Save to file
base_model_id = model_id.split("/")[-1] base_model_id = model_id.split("/")[-1]
max_concurrency_str = (f"-concurrency{args.max_concurrency}" max_concurrency_str = f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else ""
if args.max_concurrency is not None else "") file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa
if args.result_filename: if args.result_filename:
file_name = args.result_filename file_name = args.result_filename
if args.result_dir: if args.result_dir:
file_name = os.path.join(args.result_dir, file_name) file_name = os.path.join(args.result_dir, file_name)
with open(file_name, "w", encoding='utf-8') as outfile: with open(file_name, "w", encoding="utf-8") as outfile:
json.dump(result_json, outfile) json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name) save_to_pytorch_benchmark_format(args, result_json, file_name)
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(description="Benchmark the online serving throughput.")
description="Benchmark the online serving throughput.")
parser.add_argument( parser.add_argument(
"--backend", "--backend",
type=str, type=str,
@@ -883,18 +859,29 @@ if __name__ == "__main__":
"--dataset-name", "--dataset-name",
type=str, type=str,
default="sharegpt", default="sharegpt",
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "EB", "EBChat"], choices=[
"sharegpt",
"burstgpt",
"sonnet",
"random",
"hf",
"EB",
"EBChat",
],
help="Name of the dataset to benchmark on.", help="Name of the dataset to benchmark on.",
) )
parser.add_argument("--dataset-path", parser.add_argument(
type=str, "--dataset-path",
default=None, type=str,
help="Path to the sharegpt/sonnet dataset. " default=None,
"Or the huggingface dataset ID if using HF dataset.") help="Path to the sharegpt/sonnet dataset. " "Or the huggingface dataset ID if using HF dataset.",
parser.add_argument("--hyperparameter-path", )
type=str, parser.add_argument(
default=None, "--hyperparameter-path",
help="Path to the hyperparameter. ") type=str,
default=None,
help="Path to the hyperparameter. ",
)
parser.add_argument( parser.add_argument(
"--max-concurrency", "--max-concurrency",
type=int, type=int,
@@ -906,7 +893,8 @@ if __name__ == "__main__":
"initiated, this argument will control how many are actually allowed " "initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the " "to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, " "actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.") "if the server is not processing requests fast enough to keep up.",
)
parser.add_argument( parser.add_argument(
"--model", "--model",
@@ -917,7 +905,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--tokenizer", "--tokenizer",
type=str, type=str,
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 help="Name or path of the tokenizer, if not using the default tokenizer.",
) )
parser.add_argument("--use-beam-search", action="store_true") parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument( parser.add_argument(
@@ -930,11 +918,13 @@ if __name__ == "__main__":
"--logprobs", "--logprobs",
type=int, type=int,
default=None, default=None,
help=("Number of logprobs-per-token to compute & return as part of " help=(
"the request. If unspecified, then either (1) if beam search " "Number of logprobs-per-token to compute & return as part of "
"is disabled, no logprobs are computed & a single dummy " "the request. If unspecified, then either (1) if beam search "
"logprob is returned for each token; or (2) if beam search " "is disabled, no logprobs are computed & a single dummy "
"is enabled 1 logprob per token is computed"), "logprob is returned for each token; or (2) if beam search "
"is enabled 1 logprob per token is computed"
),
) )
parser.add_argument( parser.add_argument(
"--request-rate", "--request-rate",
@@ -971,8 +961,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--profile", "--profile",
action="store_true", action="store_true",
help="Use Torch Profiler. The endpoint must be launched with " help="Use Torch Profiler. The endpoint must be launched with " "VLLM_TORCH_PROFILER_DIR to enable profiler.",
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
) )
parser.add_argument( parser.add_argument(
"--save-result", "--save-result",
@@ -1013,35 +1002,38 @@ if __name__ == "__main__":
"--ignore-eos", "--ignore-eos",
action="store_true", action="store_true",
help="Set ignore_eos flag when sending the benchmark request." help="Set ignore_eos flag when sending the benchmark request."
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.") "Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
)
parser.add_argument( parser.add_argument(
"--percentile-metrics", "--percentile-metrics",
type=str, type=str,
default="ttft,tpot,itl", default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentils. " help="Comma-separated list of selected metrics to report percentils. "
"This argument specifies the metrics to report percentiles. " "This argument specifies the metrics to report percentiles. "
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " 'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
"Default value is \"ttft,tpot,itl\".") 'Default value is "ttft,tpot,itl".',
)
parser.add_argument( parser.add_argument(
"--metric-percentiles", "--metric-percentiles",
type=str, type=str,
default="99", default="99",
help="Comma-separated list of percentiles for selected metrics. " help="Comma-separated list of percentiles for selected metrics. "
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
"Default value is \"99\". " 'Default value is "99". '
"Use \"--percentile-metrics\" to select metrics.", 'Use "--percentile-metrics" to select metrics.',
) )
parser.add_argument( parser.add_argument(
"--goodput", "--goodput",
nargs="+", nargs="+",
required=False, required=False,
help="Specify service level objectives for goodput as \"KEY:VALUE\" " help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in " "pairs, where the key is a metric name, and the value is in "
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are " "separated by spaces. Allowed request level metric names are "
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " '"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve") "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
# group for dataset specific arguments # group for dataset specific arguments
sonnet_group = parser.add_argument_group("sonnet dataset options") sonnet_group = parser.add_argument_group("sonnet dataset options")
@@ -1069,8 +1061,8 @@ if __name__ == "__main__":
"--sharegpt-output-len", "--sharegpt-output-len",
type=int, type=int,
default=None, default=None,
help="Output length for each request. Overrides the output length " help="Output length for each request. Overrides the output length " "from the ShareGPT dataset.",
"from the ShareGPT dataset.") )
random_group = parser.add_argument_group("random dataset options") random_group = parser.add_argument_group("random dataset options")
random_group.add_argument( random_group.add_argument(
@@ -1098,29 +1090,24 @@ if __name__ == "__main__":
"--random-prefix-len", "--random-prefix-len",
type=int, type=int,
default=0, default=0,
help=("Number of fixed prefix tokens before the random context " help=(
"in a request. " "Number of fixed prefix tokens before the random context "
"The total input length is the sum of `random-prefix-len` and " "in a request. "
"a random " "The total input length is the sum of `random-prefix-len` and "
"context length sampled from [input_len * (1 - range_ratio), " "a random "
"input_len * (1 + range_ratio)]."), "context length sampled from [input_len * (1 - range_ratio), "
"input_len * (1 + range_ratio)]."
),
) )
hf_group = parser.add_argument_group("hf dataset options") hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset", hf_group.add_argument("--hf-subset", type=str, default=None, help="Subset of the HF dataset.")
type=str, hf_group.add_argument("--hf-split", type=str, default=None, help="Split of the HF dataset.")
default=None,
help="Subset of the HF dataset.")
hf_group.add_argument("--hf-split",
type=str,
default=None,
help="Split of the HF dataset.")
hf_group.add_argument( hf_group.add_argument(
"--hf-output-len", "--hf-output-len",
type=int, type=int,
default=None, default=None,
help="Output length for each request. Overrides the output lengths " help="Output length for each request. Overrides the output lengths " "from the sampled HF dataset.",
"from the sampled HF dataset.",
) )
sampling_group = parser.add_argument_group("sampling parameters") sampling_group = parser.add_argument_group("sampling parameters")
@@ -1128,52 +1115,58 @@ if __name__ == "__main__":
"--top-p", "--top-p",
type=float, type=float,
default=None, default=None,
help="Top-p sampling parameter. Only has effect on openai-compatible " help="Top-p sampling parameter. Only has effect on openai-compatible " "backends.",
"backends.") )
sampling_group.add_argument( sampling_group.add_argument(
"--top-k", "--top-k",
type=int, type=int,
default=None, default=None,
help="Top-k sampling parameter. Only has effect on openai-compatible " help="Top-k sampling parameter. Only has effect on openai-compatible " "backends.",
"backends.") )
sampling_group.add_argument( sampling_group.add_argument(
"--min-p", "--min-p",
type=float, type=float,
default=None, default=None,
help="Min-p sampling parameter. Only has effect on openai-compatible " help="Min-p sampling parameter. Only has effect on openai-compatible " "backends.",
"backends.") )
sampling_group.add_argument( sampling_group.add_argument(
"--temperature", "--temperature",
type=float, type=float,
default=None, default=None,
help="Temperature sampling parameter. Only has effect on " help="Temperature sampling parameter. Only has effect on "
"openai-compatible backends. If not specified, default to greedy " "openai-compatible backends. If not specified, default to greedy "
"decoding (i.e. temperature==0.0).") "decoding (i.e. temperature==0.0).",
)
parser.add_argument( parser.add_argument(
'--tokenizer-mode', "--tokenizer-mode",
type=str, type=str,
default="auto", default="auto",
choices=['auto', 'slow', 'mistral', 'custom'], choices=["auto", "slow", "mistral", "custom"],
help='The tokenizer mode.\n\n* "auto" will use the ' help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will ' 'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* ' "always use the slow tokenizer. \n* "
'"mistral" will always use the `mistral_common` tokenizer. \n*' '"mistral" will always use the `mistral_common` tokenizer. \n*'
'"custom" will use --tokenizer to select the preregistered tokenizer.') '"custom" will use --tokenizer to select the preregistered tokenizer.',
)
parser.add_argument("--served-model-name", parser.add_argument(
type=str, "--served-model-name",
default=None, type=str,
help="The model name used in the API. " default=None,
"If not specified, the model name will be the " help="The model name used in the API. "
"same as the ``--model`` argument. ") "If not specified, the model name will be the "
"same as the ``--model`` argument. ",
)
parser.add_argument("--lora-modules", parser.add_argument(
nargs='+', "--lora-modules",
default=None, nargs="+",
help="A subset of LoRA module names passed in when " default=None,
"launching the server. For each request, the " help="A subset of LoRA module names passed in when "
"script chooses a LoRA module at random.") "launching the server. For each request, the "
"script chooses a LoRA module at random.",
)
args = parser.parse_args() args = parser.parse_args()
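(A brief illustration, standalone and not taken from the commit, of how the --percentile-metrics and --metric-percentiles strings defined above are split into the lists that benchmark() consumes.)

percentile_metrics = "ttft,tpot,itl"   # value of --percentile-metrics (its default)
metric_percentiles = "50,90,99"        # value of --metric-percentiles
selected_percentile_metrics = percentile_metrics.split(",")               # ['ttft', 'tpot', 'itl']
selected_percentiles = [float(p) for p in metric_percentiles.split(",")]  # [50.0, 90.0, 99.0]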


@@ -7,4 +7,4 @@ tensor_parallel_size: 1
enable_chunked_prefill: True enable_chunked_prefill: True
max_num_batched_tokens: 384 max_num_batched_tokens: 384
quantization: wint4 quantization: wint4
reasoning_parser: ernie-45-vl reasoning_parser: ernie-45-vl


@@ -12,4 +12,4 @@ rdma_comm_ports: "7671,7672,7673,7674"
pd_comm_port: "2334" pd_comm_port: "2334"
max_num_batched_tokens: 384 max_num_batched_tokens: 384
max_num_partial_prefills: 3 max_num_partial_prefills: 3
max_long_partial_prefills: 3 max_long_partial_prefills: 3


@@ -9,4 +9,4 @@ cache_queue_port: 55664
engine_worker_queue_port: 6677 engine_worker_queue_port: 6677
cache_transfer_protocol: "rdma,ipc" cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7675,7676,7677,7678" rdma_comm_ports: "7675,7676,7677,7678"
pd_comm_port: "2333" pd_comm_port: "2333"


@@ -10,4 +10,4 @@ engine_worker_queue_port: 6677
num_gpu_blocks_override: 1024 num_gpu_blocks_override: 1024
cache_transfer_protocol: "rdma" cache_transfer_protocol: "rdma"
rdma_comm_ports: "7671,7672,7673,7674,7675,7676,7677,7678" rdma_comm_ports: "7671,7672,7673,7674,7675,7676,7677,7678"
pd_comm_port: "2334" pd_comm_port: "2334"


@@ -10,4 +10,4 @@ splitwise_role: decode
engine_worker_queue_port: 6678 engine_worker_queue_port: 6678
cache_transfer_protocol: "rdma,ipc" cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7671,7672,7673,7674" rdma_comm_ports: "7671,7672,7673,7674"
pd_comm_port: "2334" pd_comm_port: "2334"


@@ -9,4 +9,4 @@ cache_queue_port: 55664
engine_worker_queue_port: 6677 engine_worker_queue_port: 6677
cache_transfer_protocol: "rdma,ipc" cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7675,7676,7677,7678" rdma_comm_ports: "7675,7676,7677,7678"
pd_comm_port: "2333" pd_comm_port: "2333"


@@ -12,4 +12,4 @@ rdma_comm_ports: "7671,7672,7673,7674"
pd_comm_port: "2334" pd_comm_port: "2334"
max_num_batched_tokens: 384 max_num_batched_tokens: 384
max_num_partial_prefills: 3 max_num_partial_prefills: 3
max_long_partial_prefills: 3 max_long_partial_prefills: 3


@@ -9,4 +9,4 @@ cache_queue_port: 55664
engine_worker_queue_port: 6677 engine_worker_queue_port: 6677
cache_transfer_protocol: "rdma,ipc" cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7675,7676,7677,7678" rdma_comm_ports: "7675,7676,7677,7678"
pd_comm_port: "2333" pd_comm_port: "2333"


@@ -3,4 +3,4 @@ max_num_seqs: 75
gpu_memory_utilization: 0.85 gpu_memory_utilization: 0.85
kv_cache_ratio: 0.75 kv_cache_ratio: 0.75
quantization: wint4 quantization: wint4
tensor_parallel_size: 4 tensor_parallel_size: 4


@@ -3,4 +3,4 @@ max_num_seqs: 25
gpu_memory_utilization: 0.9 gpu_memory_utilization: 0.9
kv_cache_ratio: 0.75 kv_cache_ratio: 0.75
quantization: wint8 quantization: wint8
tensor_parallel_size: 4 tensor_parallel_size: 4


@@ -1,3 +1,3 @@
metadata: metadata:
min_tokens: 32 min_tokens: 32
max_tokens: 33 max_tokens: 33


@@ -5,4 +5,4 @@ metadata:
max_tokens: 12288 max_tokens: 12288
repetition_penalty: 1.05 repetition_penalty: 1.05
frequency_penalty: 0 frequency_penalty: 0
presence_penalty: 0 presence_penalty: 0


@@ -5,4 +5,4 @@ metadata:
max_tokens: 12288 max_tokens: 12288
repetition_penalty: 1.0 repetition_penalty: 1.0
frequency_penalty: 0 frequency_penalty: 0
presence_penalty: 1.5 presence_penalty: 1.5


@@ -8,4 +8,4 @@ frequency_penalty: 0
presence_penalty: 0 presence_penalty: 0
skip_special_tokens: false skip_special_tokens: false
chat_template_kwargs: chat_template_kwargs:
enable_thinking: true enable_thinking: true


@@ -3,4 +3,4 @@ max_num_seqs: 64
gpu_memory_utilization: 0.9 gpu_memory_utilization: 0.9
tensor_parallel_size: 8 tensor_parallel_size: 8
quantization: wint8 quantization: wint8
reasoning_parser: ernie-x1 reasoning_parser: ernie-x1


@@ -40,4 +40,4 @@ void DecoderWriteCacheWithRoPEKernel(
cudaStream_t& stream, cudaStream_t& stream,
paddle::Tensor* qkv_out, paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out, paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out); paddle::Tensor* value_cache_out);


@@ -216,7 +216,7 @@ __global__ void append_dequant_cache_kv_c8(
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>( uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM +
tid % 8 * num_elems_per_128b<CacheT>(); tid % 8 * num_elems_per_128b<CacheT>();
@@ -330,7 +330,7 @@ __global__ void append_dequant_cache_kv_c8(
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale; v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale;
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale; v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale;
convert_c8<T,IS_FP8>(frag_dq_T + 4, v_frag[2 * i + 1]); // 4个uint8/fp8 -> 4个T convert_c8<T,IS_FP8>(frag_dq_T + 4, v_frag[2 * i + 1]); // 4个uint8/fp8 -> 4个T
#ifdef C8_DEBUG #ifdef C8_DEBUG
if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) { if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) {
@@ -373,14 +373,14 @@ void AppendDequantCache(
paddle::Tensor *k_out, paddle::Tensor *k_out,
paddle::Tensor *v_out, paddle::Tensor *v_out,
const cudaStream_t& stream const cudaStream_t& stream
) { ) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type; using NV_TYPE = typename cascade_attn_type_traits<T>::type;
if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
constexpr int NUM_WARPS = 4; constexpr int NUM_WARPS = 4;
int block_num = cache_num_blocks_x.data<int>()[0]; int block_num = cache_num_blocks_x.data<int>()[0];
dim3 grids(block_num, 1, kv_num_heads); dim3 grids(block_num, 1, kv_num_heads);
dim3 blocks(32, NUM_WARPS); dim3 blocks(32, NUM_WARPS);
const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2; const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2;
auto kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>; auto kernel_func = append_dequant_cache_kv_c8<NV_TYPE, uint8_t, HEAD_DIM, BLOCK_SIZE, NUM_WARPS, false>;


@@ -41,7 +41,7 @@ __global__ void append_clear_cache_int8_block(
const int wid = tid / 32; const int wid = tid / 32;
const int lane_id = tid % 32; const int lane_id = tid % 32;
const int token_id = blockIdx.x; const int token_id = blockIdx.x;
const int bid = batch_id_per_token[token_id]; const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid]; const int start_token_idx = cu_seqlens_q[bid];
@@ -115,7 +115,7 @@ __global__ void append_clear_cache_int4_block(
const int wid = tid / 32; const int wid = tid / 32;
const int lane_id = tid % 32; const int lane_id = tid % 32;
const int token_id = blockIdx.x; const int token_id = blockIdx.x;
const int bid = batch_id_per_token[token_id]; const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid]; const int start_token_idx = cu_seqlens_q[bid];
@@ -484,7 +484,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const int wid = tid / 32; const int wid = tid / 32;
const int lane_id = tid % 32; const int lane_id = tid % 32;
const int token_id = blockIdx.x; const int token_id = blockIdx.x;
const int bid = batch_id_per_token[token_id]; const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid]; const int start_token_idx = cu_seqlens_q[bid];
@@ -716,7 +716,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
const int wid = tid / 32; const int wid = tid / 32;
const int lane_id = tid % 32; const int lane_id = tid % 32;
const int token_id = blockIdx.x; const int token_id = blockIdx.x;
const int bid = batch_id_per_token[token_id]; const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid]; const int start_token_idx = cu_seqlens_q[bid];
@@ -1097,7 +1097,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
const int lane_id = tid % 32; const int lane_id = tid % 32;
const int token_id = blockIdx.x; const int token_id = blockIdx.x;
const int bid = batch_id_per_token[token_id]; const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid]; const int start_token_idx = cu_seqlens_q[bid];
@@ -1403,7 +1403,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
const int lane_id = tid % 32; const int lane_id = tid % 32;
const int token_id = blockIdx.x; const int token_id = blockIdx.x;
const int bid = batch_id_per_token[token_id]; const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid]; const int start_token_idx = cu_seqlens_q[bid];
@@ -1792,4 +1792,4 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
(uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F);
} }
} }
} }


@@ -582,4 +582,4 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream, cudaStream_t& stream,
paddle::Tensor* qkv_out, paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out, paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out); paddle::Tensor* value_cache_out);


@@ -39,4 +39,4 @@ void SpeculateWriteCacheWithRoPEKernel(
cudaStream_t& stream, cudaStream_t& stream,
paddle::Tensor* qkv_out, paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out, paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out); paddle::Tensor* value_cache_out);


@@ -30,4 +30,4 @@ inline int getSMVersion()
return sm_major * 10 + sm_minor; return sm_major * 10 + sm_minor;
} }
} }


@@ -136,4 +136,4 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
ElementAccumulator, DefaultScaleMode>; ElementAccumulator, DefaultScaleMode>;
}; };
} // namespace cutlass_extensions } // namespace cutlass_extensions


@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -54,7 +54,7 @@
///////////////////////////////////FP8 Accumulation/////////////////////////// ///////////////////////////////////FP8 Accumulation///////////////////////////
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
/// This class provides API to promote (add) or scale (multiply_add) the results /// This class provides API to promote (add) or scale (multiply_add) the results
/// from the tensor core accumulators to the main accumulators when the number /// from the tensor core accumulators to the main accumulators when the number
/// of MMAs reaches the max number of MMA interval specified by user, after that /// of MMAs reaches the max number of MMA interval specified by user, after that
/// the tensor core accumulators are zeroed. /// the tensor core accumulators are zeroed.
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@@ -64,7 +64,7 @@ namespace cutlass::gemm::collective {
template < template <
class EngineAccum, class EngineAccum,
class LayoutAccum> class LayoutAccum>
struct GmmaFP8AccumulationWithScale { struct GmmaFP8AccumulationWithScale {
using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>; using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>;
using ElementAccumulator = typename EngineAccum::value_type; using ElementAccumulator = typename EngineAccum::value_type;
@@ -78,7 +78,7 @@ private:
uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted.
uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop
uint32_t mma_count_; // current executed MMAs uint32_t mma_count_; // current executed MMAs
uint32_t reset_accum_flag_; // accum needs to be zeroed or not. uint32_t reset_accum_flag_; // accum needs to be zeroed or not.
// promote or `add` the partial accumulators to main accumulator (FADD). // promote or `add` the partial accumulators to main accumulator (FADD).
CUTLASS_DEVICE CUTLASS_DEVICE
@@ -116,11 +116,11 @@ public:
TensorAccum &accum, TensorAccum &accum,
uint32_t accum_promotion_interval, uint32_t accum_promotion_interval,
uint32_t mma_count_per_mainloop_iteration) uint32_t mma_count_per_mainloop_iteration)
: accum_(accum), : accum_(accum),
accum_promotion_interval_(accum_promotion_interval), accum_promotion_interval_(accum_promotion_interval),
mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration),
mma_count_(0), mma_count_(0),
reset_accum_flag_(0) reset_accum_flag_(0)
{ {
accum_temp_ = cute::make_fragment_like(accum); accum_temp_ = cute::make_fragment_like(accum);
} }
@@ -129,14 +129,14 @@ public:
// Methods (Common) // Methods (Common)
// //
CUTLASS_DEVICE CUTLASS_DEVICE
TensorAccum& operator()() { TensorAccum& operator()() {
return accum_temp_; return accum_temp_;
} }
/// prepare the MMA accumulators when initialization or zeroing is required. /// prepare the MMA accumulators when initialization or zeroing is required.
CUTLASS_DEVICE CUTLASS_DEVICE
bool prepare_if_needed() { bool prepare_if_needed() {
return reset_accum_flag_; return reset_accum_flag_;
} }


@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -137,7 +137,7 @@ struct CollectiveMma<
using PipelineParams = typename MainloopPipeline::Params; using PipelineParams = typename MainloopPipeline::Params;
// Two threads per CTA are producers (1 for operand tile and 32 for scales) // Two threads per CTA are producers (1 for operand tile and 32 for scales)
static constexpr int NumProducerThreadEvents = 33; static constexpr int NumProducerThreadEvents = 33;
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
@@ -161,11 +161,11 @@ struct CollectiveMma<
SmemLayoutAtomB{}, SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
// Block scaling gmem-to-smem copy atom // Block scaling gmem-to-smem copy atom
using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>; using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>; using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
// Block scaling smem layout // Block scaling smem layout
using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>; using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1. using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.
@@ -202,7 +202,7 @@ struct CollectiveMma<
StrideA dA; StrideA dA;
ElementB const* ptr_B; ElementB const* ptr_B;
StrideB dB; StrideB dB;
ElementBlockScale const* ptr_scale_A; ElementBlockScale const* ptr_scale_A;
ElementBlockScale const* ptr_scale_B; ElementBlockScale const* ptr_scale_B;
}; };
@@ -228,7 +228,7 @@ struct CollectiveMma<
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
// Block scaling factors for A and B // Block scaling factors for A and B
ElementBlockScale const* ptr_scale_A; ElementBlockScale const* ptr_scale_A;
ElementBlockScale const* ptr_scale_B; ElementBlockScale const* ptr_scale_B;
}; };
@@ -285,7 +285,7 @@ struct CollectiveMma<
constexpr int tma_alignment_bits = 128; constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1); auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL; auto [M,N,K,L] = problem_shape_MNKL;
bool implementable = true; bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value; constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{}); implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
@@ -346,7 +346,7 @@ struct CollectiveMma<
auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l) auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l)
auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{}); auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
// gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l) Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l) Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)
@@ -406,26 +406,26 @@ struct CollectiveMma<
Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape()); Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
Tensor gScaleA = local_tile( Tensor gScaleA = local_tile(
mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}), mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
Tensor cScaleA = local_tile( Tensor cScaleA = local_tile(
cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}), cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
make_coord(m_coord,_,l_coord)); make_coord(m_coord,_,l_coord));
Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1) Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128 // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // (1,1,1) Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // (1,1,1)
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1) Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA); Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA); Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA); Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB); Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
@@ -455,7 +455,7 @@ struct CollectiveMma<
} }
} }
// Allocate predicate tensors for a_scales (since we can't guarantee that // Allocate predicate tensors for a_scales (since we can't guarantee that
// all scales are valid, since we could have a partial tiles along M) // all scales are valid, since we could have a partial tiles along M)
Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0))); Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
#pragma unroll #pragma unroll
@@ -536,7 +536,7 @@ struct CollectiveMma<
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
// Block scaling // Block scaling
Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()),
Layout< Layout<
@@ -548,17 +548,17 @@ struct CollectiveMma<
// //
// Define C accumulators and A/B partitioning // Define C accumulators and A/B partitioning
// //
// Layout of warp group to thread mapping // Layout of warp group to thread mapping
static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
stride<0>(typename TiledMma::BLayout{}) == 0 and stride<0>(typename TiledMma::BLayout{}) == 0 and
size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{}, Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{}); Int<NumThreadsPerWarpGroup>{});
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
@@ -590,7 +590,7 @@ struct CollectiveMma<
// We release buffers to producer warps(dma load) with some mmas in flight // We release buffers to producer warps(dma load) with some mmas in flight
PipelineState smem_pipe_release = smem_pipe_read; PipelineState smem_pipe_release = smem_pipe_read;
// Per block scale values for operand A and B // Per block scale values for operand A and B
using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.
@@ -618,7 +618,7 @@ struct CollectiveMma<
} }
int read_stage = smem_pipe_read.index(); int read_stage = smem_pipe_read.index();
// Load per block scale values from shared memory to registers. // Load per block scale values from shared memory to registers.
scale_b = sScaleB[read_stage]; scale_b = sScaleB[read_stage];
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
@@ -668,7 +668,7 @@ struct CollectiveMma<
int read_stage = smem_pipe_read.index(); int read_stage = smem_pipe_read.index();
// Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N) // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
scale_b = sScaleB[read_stage]; scale_b = sScaleB[read_stage];
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
@@ -712,7 +712,7 @@ struct CollectiveMma<
++smem_pipe_read; ++smem_pipe_read;
++smem_pipe_release; ++smem_pipe_release;
} }
accumulation.scale_residue_if_needed(tCrScaleAViewAsC); accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
warpgroup_fence_operand(accumulation()); warpgroup_fence_operand(accumulation());


@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -50,4 +50,4 @@ struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm } // namespace cutlass::gemm


@@ -90,4 +90,4 @@ struct GemmMoeProblemVisitor
} // namespace gemm } // namespace gemm
} // namespace cutlass } // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////


@@ -90,7 +90,7 @@ template <
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization /// Used for partial specialization
typename Enable = bool> typename Enable = bool>
class Wint2xMmaMultistage : class Wint2xMmaMultistage :
public Wint2xMmaBase<Shape_, Policy_, Stages> { public Wint2xMmaBase<Shape_, Policy_, Stages> {
public: public:
///< Base class ///< Base class


@@ -57,7 +57,7 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){
hasbias, hasbias,
ElementD, ElementD,
void>; void>;
constexpr int ScaleMsPerTile = size<0>(TileShape{}); constexpr int ScaleMsPerTile = size<0>(TileShape{});
constexpr int ScaleGranularityM = size<0>(TileShape{}) / ScaleMsPerTile; constexpr int ScaleGranularityM = size<0>(TileShape{}) / ScaleMsPerTile;
@@ -161,7 +161,7 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){
arguments.scheduler.decomposition_mode = DecompositionMode::StreamK; arguments.scheduler.decomposition_mode = DecompositionMode::StreamK;
arguments.scheduler.reduction_mode = ReductionMode::Nondeterministic; arguments.scheduler.reduction_mode = ReductionMode::Nondeterministic;
} }
Gemm gemm_op; Gemm gemm_op;


@@ -170,4 +170,4 @@ bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) {
return false; return false;
} }
return true; return true;
} }


@@ -148,4 +148,4 @@ bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) {
return false; return false;
} }
return true; return true;
} }


@@ -54,7 +54,7 @@ public:
virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0; virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0;
virtual std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int k) const = 0; virtual std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int k) const = 0;
protected: protected:
static constexpr int SPLIT_K_LIMIT = 7; static constexpr int SPLIT_K_LIMIT = 7;
static constexpr int MIN_M_TILE = 16; static constexpr int MIN_M_TILE = 16;


@@ -93,8 +93,8 @@ std::vector<paddle::DataType> ExtractTextTokenOutputInferDtype(const paddle::Dat
PD_BUILD_STATIC_OP(extract_text_token_output) PD_BUILD_STATIC_OP(extract_text_token_output)
.Inputs({"max_seq_len", .Inputs({"max_seq_len",
"max_seq_len_index", "max_seq_len_index",
"mm_token_num_len", "mm_token_num_len",
"seq_lens_this_time", "seq_lens_this_time",
"cu_seqlens_q", "cu_seqlens_q",
"score_text"}) "score_text"})


@@ -105,7 +105,7 @@ __global__ void cudaCoreGemm(InputType const* __restrict__ act,
} }
} }
} }
__syncthreads(); __syncthreads();
for (int32_t ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) { for (int32_t ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) {
int32_t mid = ii / TILE_N, nid = ii % TILE_N; int32_t mid = ii / TILE_N, nid = ii % TILE_N;
@@ -188,4 +188,4 @@ bool cuda_core_gemm_launcher(GemmParams const& params) {
template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(GemmParams const&); template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(GemmParams const&); template bool cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(GemmParams const&); template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, __nv_bfloat16>(GemmParams const&);
template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(GemmParams const&); template bool cuda_core_gemm_launcher<__nv_fp8_e5m2, half>(GemmParams const&);


@@ -61,7 +61,7 @@ std::vector<paddle::Tensor> GetMmSplitFuse(const paddle::Tensor& task_input_ids,
st_idx += cur_st_len; st_idx += cur_st_len;
} }
} }
while (idx < seq_lens_origin) { while (idx < seq_lens_origin) {
idx = idx + split_fuse_text_size; idx = idx + split_fuse_text_size;
if (idx >= seq_lens_origin) { if (idx >= seq_lens_origin) {
@@ -116,7 +116,7 @@ std::vector<paddle::Tensor> GetMmSplitFuse(const paddle::Tensor& task_input_ids,
while (ib < img_total && cur_img_len < chunk_image_token_number) { while (ib < img_total && cur_img_len < chunk_image_token_number) {
int token_times = 4; int token_times = 4;
cur_img_len += (grid_thw_cpu[ib * 3 + 1] * grid_thw_cpu[ib * 3 + 2]) / token_times; cur_img_len += (grid_thw_cpu[ib * 3 + 1] * grid_thw_cpu[ib * 3 + 2]) / token_times;
ib ++; ib ++;
chunk_image_number ++; chunk_image_number ++;
} }
image_chunk_selections_vector.emplace_back(chunk_image_number); image_chunk_selections_vector.emplace_back(chunk_image_number);


@@ -88,7 +88,7 @@ void sent_key_value_by_remote_ptr(
#ifdef DEBUG_IPC_SENT #ifdef DEBUG_IPC_SENT
std::cout<<"remote_key_tensor_sent_ptr:"<<(int64_t)remote_key_tensor_sent_ptr std::cout<<"remote_key_tensor_sent_ptr:"<<(int64_t)remote_key_tensor_sent_ptr
<<" local_key_tensor_sent_ptr:"<<(int64_t)local_key_tensor_sent_ptr <<" local_key_tensor_sent_ptr:"<<(int64_t)local_key_tensor_sent_ptr
<<" local_device_id:" << local_device_id <<" local_device_id:" << local_device_id
<<" remote_device_id:" << remote_device_id <<" remote_device_id:" << remote_device_id
<<" block_idx_stride:" << block_idx_stride <<" block_idx_stride:" << block_idx_stride
<<" block_size_byte:" << block_size_byte <<" block_size_byte:" << block_size_byte
@@ -107,25 +107,25 @@ void sent_key_value_by_remote_ptr(
#endif #endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT #ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeerAsync( cudaMemcpyPeerAsync(
reinterpret_cast<void*>(remote_key_tensor_sent_ptr), reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id, remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr), reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id, local_device_id,
block_size_byte, block_size_byte,
stream); stream);
#endif #endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT #ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeer( cudaMemcpyPeer(
reinterpret_cast<void*>(remote_key_tensor_sent_ptr), reinterpret_cast<void*>(remote_key_tensor_sent_ptr),
remote_device_id, remote_device_id,
reinterpret_cast<const void*>(local_key_tensor_sent_ptr), reinterpret_cast<const void*>(local_key_tensor_sent_ptr),
local_device_id, local_device_id,
block_size_byte); block_size_byte);
#endif #endif
cudaError_t err = cudaGetLastError(); cudaError_t err = cudaGetLastError();
if ( err != cudaSuccess ) if ( err != cudaSuccess )
{ {
printf("CUDA Error: %s\n", cudaGetErrorString(err)); printf("CUDA Error: %s\n", cudaGetErrorString(err));
} }
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT #ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaDeviceSynchronize(); cudaDeviceSynchronize();
@@ -140,7 +140,7 @@ void sent_key_value_by_remote_ptr(
#ifdef DEBUG_IPC_SENT #ifdef DEBUG_IPC_SENT
std::cout<<"remote_value_tensor_sent_ptr:"<<(int64_t)remote_value_tensor_sent_ptr std::cout<<"remote_value_tensor_sent_ptr:"<<(int64_t)remote_value_tensor_sent_ptr
<<" local_value_tensor_sent_ptr:"<<(int64_t)local_value_tensor_sent_ptr <<" local_value_tensor_sent_ptr:"<<(int64_t)local_value_tensor_sent_ptr
<<" local_device_id:" << local_device_id <<" local_device_id:" << local_device_id
<<" remote_device_id:" << remote_device_id <<" remote_device_id:" << remote_device_id
<<" block_idx_stride:" << block_idx_stride <<" block_idx_stride:" << block_idx_stride
<<" block_size_byte:" << block_size_byte <<" block_size_byte:" << block_size_byte
@@ -159,26 +159,26 @@ void sent_key_value_by_remote_ptr(
#endif #endif
#ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT #ifndef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeerAsync( cudaMemcpyPeerAsync(
reinterpret_cast<void*>(remote_value_tensor_sent_ptr), reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id, remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr), reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id, local_device_id,
block_size_byte, block_size_byte,
stream); stream);
#endif #endif
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT #ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
cudaMemcpyPeer( cudaMemcpyPeer(
reinterpret_cast<void*>(remote_value_tensor_sent_ptr), reinterpret_cast<void*>(remote_value_tensor_sent_ptr),
remote_device_id, remote_device_id,
reinterpret_cast<const void*>(local_value_tensor_sent_ptr), reinterpret_cast<const void*>(local_value_tensor_sent_ptr),
local_device_id, local_device_id,
block_size_byte); block_size_byte);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
#endif #endif
err = cudaGetLastError(); err = cudaGetLastError();
if ( err != cudaSuccess ) if ( err != cudaSuccess )
{ {
printf("CUDA Error: %s\n", cudaGetErrorString(err)); printf("CUDA Error: %s\n", cudaGetErrorString(err));
} }
#ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT #ifdef DEBUG_IPC_SENT_SYNC_AND_PRINT
PrintMatrix<T>(reinterpret_cast<T*>(remote_value_tensor_sent_ptr), PrintMatrix<T>(reinterpret_cast<T*>(remote_value_tensor_sent_ptr),
@@ -316,11 +316,11 @@ void SentKeyValueByRemotePtrBlockSync(const paddle::Tensor& local_key_tensor,
cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw; cudaStream_t cuda_stream = (cudaStream_t)cuda_stream_raw;
cudaStreamSynchronize(cuda_stream); cudaStreamSynchronize(cuda_stream);
} }
PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr) PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr)
.Inputs({"local_key_tensor", "local_value_tensor", "local_block_ids", "remote_block_ids", "remote_key_tensor", "remote_value_tensor"}) .Inputs({"local_key_tensor", "local_value_tensor", "local_block_ids", "remote_block_ids", "remote_key_tensor", "remote_value_tensor"})
.Attrs({ "block_num: int", .Attrs({ "block_num: int",
"local_device_id: int", "local_device_id: int",
"remote_device_id: int", "remote_device_id: int",
"cuda_stream_raw: int64_t"}) "cuda_stream_raw: int64_t"})
.Outputs({"local_key_tensor_out", "local_value_tensor_out"}) .Outputs({"local_key_tensor_out", "local_value_tensor_out"})
@@ -332,4 +332,4 @@ PD_BUILD_STATIC_OP(ipc_sent_key_value_cache_by_remote_ptr_block_sync)
.Attrs({"cuda_stream_raw: int64_t"}) .Attrs({"cuda_stream_raw: int64_t"})
.Outputs({"local_key_tensor_out", "local_value_tensor_out"}) .Outputs({"local_key_tensor_out", "local_value_tensor_out"})
.SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},{"local_value_tensor","local_value_tensor_out"}}) .SetInplaceMap({{"local_key_tensor", "local_key_tensor_out"},{"local_value_tensor","local_value_tensor_out"}})
.SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtrBlockSync)); .SetKernelFn(PD_KERNEL(SentKeyValueByRemotePtrBlockSync));


@@ -57,5 +57,3 @@ paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids,
num_experts); num_experts);
return token_nums_per_expert; return token_nums_per_expert;
} }


@@ -737,7 +737,7 @@ void MoeFastHardamardWrapper(const T *x_data,
bool FLAGS_hardamard_use_diagonal_block_matrix = true; bool FLAGS_hardamard_use_diagonal_block_matrix = true;
static const char* FLAGS_hardamard_moe_block_size = std::getenv("FLAGS_hardamard_moe_block_size"); static const char* FLAGS_hardamard_moe_block_size = std::getenv("FLAGS_hardamard_moe_block_size");
static const int32_t hardamard_moe_block_size = FLAGS_hardamard_moe_block_size != nullptr ? static const int32_t hardamard_moe_block_size = FLAGS_hardamard_moe_block_size != nullptr ?
stoi(std::string(FLAGS_hardamard_moe_block_size)) : 512; stoi(std::string(FLAGS_hardamard_moe_block_size)) : 512;
constexpr int kThreads = 128; constexpr int kThreads = 128;
if (FLAGS_hardamard_use_diagonal_block_matrix) { if (FLAGS_hardamard_use_diagonal_block_matrix) {


@@ -124,4 +124,4 @@ class CubKeyValueSorter {
int num_bits_; int num_bits_;
}; };
} // namespace phi } // namespace phi


@@ -360,10 +360,10 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
normalizing_factor = 1.f / Z; normalizing_factor = 1.f / Z;
} }
__syncthreads(); __syncthreads();
T val = T(threadDataExp * normalizing_factor); T val = T(threadDataExp * normalizing_factor);
// top_k // top_k
using cub_kvp = cub::KeyValuePair<int, T>; using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>; using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduceP::TempStorage tmpStorageP; __shared__ typename BlockReduceP::TempStorage tmpStorageP;
@@ -374,10 +374,10 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
for (int k_idx = 0; k_idx < k; ++k_idx) { for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0; thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
if (threadIdx.x < num_experts) { if (threadIdx.x < num_experts) {
cub_kvp inp_kvp; cub_kvp inp_kvp;
int expert = threadIdx.x; int expert = threadIdx.x;
inp_kvp.key = expert; inp_kvp.key = expert;
inp_kvp.value = bias ? val + bias[expert] : val; inp_kvp.value = bias ? val + bias[expert] : val;
@@ -518,12 +518,12 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z; normalizing_factor = 1.f / Z;
} }
__syncthreads(); __syncthreads();
T val = T(threadDataExp * normalizing_factor); T val = T(threadDataExp * normalizing_factor);
// top_k // top_k
using cub_kvp = cub::KeyValuePair<int, T>; using cub_kvp = cub::KeyValuePair<int, T>;
using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>; using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduceP::TempStorage tmpStorageP; __shared__ typename BlockReduceP::TempStorage tmpStorageP;
@@ -541,7 +541,7 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* i
if (threadIdx.x < num_experts) { if (threadIdx.x < num_experts) {
cub_kvp inp_kvp; cub_kvp inp_kvp;
int expert = threadIdx.x; int expert = threadIdx.x;
inp_kvp.key = expert; inp_kvp.key = expert;
inp_kvp.value = bias ? val + bias[expert] : val; inp_kvp.value = bias ? val + bias[expert] : val;
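For reference, the routing math that the fused kernels above implement per token can be modeled in a few lines of NumPy: a softmax over the expert logits, an optional bias added only when ranking experts, and a top-k selection. This is an illustrative sketch of the computation, not the kernel itself; the exact output layout (and the extra normalization done by the `_normed` variant) is determined by surrounding code not shown in this diff.

    import numpy as np

    def softmax_topk(logits, k, bias=None):
        # Softmax over the expert logits (Z and normalizing_factor in the kernel).
        z = np.exp(logits - logits.max())
        probs = z / z.sum()
        # The optional bias is added only for ranking experts, mirroring
        # `inp_kvp.value = bias ? val + bias[expert] : val` in the kernel above.
        scores = probs + bias if bias is not None else probs
        topk_idx = np.argsort(-scores)[:k]
        # Returning the unbiased softmax values here is an assumption for
        # illustration; the kernel's actual outputs are defined elsewhere.
        return topk_idx, probs[topk_idx]

    idx, w = softmax_topk(np.array([0.1, 2.0, -1.0, 0.5]), k=2)
    print(idx, w)  # selected expert ids and their softmax weights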
@@ -1065,7 +1065,7 @@ __global__ void initialize_moe_routing_kernel(
const T* unpermuted_input, const T* unpermuted_input,
OutT* permuted_output, OutT* permuted_output,
const int* expanded_dest_row_to_expanded_source_row, const int* expanded_dest_row_to_expanded_source_row,
const int *expert_idx_per_token, const int *expert_idx_per_token,
const float *w4a8_in_scale, const float *w4a8_in_scale,
int* expanded_source_row_to_expanded_dest_row, int* expanded_source_row_to_expanded_dest_row,
const int64_t num_rows, const int64_t num_rows,
@@ -1088,7 +1088,7 @@ __global__ void initialize_moe_routing_kernel(
expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_source_row_to_expanded_dest_row[expanded_source_row] =
expanded_dest_row; expanded_dest_row;
} }
if (expanded_dest_row < active_rows) { if (expanded_dest_row < active_rows) {
const int expert_idx = expert_idx_per_token[expanded_dest_row]; const int expert_idx = expert_idx_per_token[expanded_dest_row];
@@ -1130,7 +1130,7 @@ static void run(
const T* unpermuted_input, const T* unpermuted_input,
OutT* permuted_output, OutT* permuted_output,
const int* expanded_dest_row_to_expanded_source_row, const int* expanded_dest_row_to_expanded_source_row,
const int *expert_idx_per_token, const int *expert_idx_per_token,
const float *w4a8_in_scale, const float *w4a8_in_scale,
int* expanded_source_row_to_expanded_dest_row, int* expanded_source_row_to_expanded_dest_row,
const int64_t num_rows, const int64_t num_rows,


@@ -17,7 +17,7 @@
// topk warps // topk warps
template<typename T, int VecSize> template<typename T, int VecSize>
__global__ void MoEDeepGEMMPermuteKernel(T* out, int* token_nums_per_expert, int* permute_indices_per_token, const T* x, const int64_t* topk_idx, const int token_num, const int topk, const int num_vecs, const int hidden, const int max_tokens_per_expert) { __global__ void MoEDeepGEMMPermuteKernel(T* out, int* token_nums_per_expert, int* permute_indices_per_token, const T* x, const int64_t* topk_idx, const int token_num, const int topk, const int num_vecs, const int hidden, const int max_tokens_per_expert) {
AlignedVector<T, VecSize> in_vec; AlignedVector<T, VecSize> in_vec;
const int bid = blockIdx.x; const int bid = blockIdx.x;
@@ -32,7 +32,7 @@ __global__ void MoEDeepGEMMPermuteKernel(T* out, int* token_nums_per_expert, int
} }
tgt_expert_token = __shfl_sync(0xFFFFFFFF, tgt_expert_token, 0); tgt_expert_token = __shfl_sync(0xFFFFFFFF, tgt_expert_token, 0);
for (int hidden_vec_id = tid; hidden_vec_id < num_vecs; hidden_vec_id += 32) { for (int hidden_vec_id = tid; hidden_vec_id < num_vecs; hidden_vec_id += 32) {
Load<T, VecSize>(x + token_idx * hidden + hidden_vec_id * VecSize, &in_vec); Load<T, VecSize>(x + token_idx * hidden + hidden_vec_id * VecSize, &in_vec);
Store<T, VecSize>(in_vec, out + tgt_expert_id * max_tokens_per_expert * hidden + tgt_expert_token * hidden + hidden_vec_id * VecSize); Store<T, VecSize>(in_vec, out + tgt_expert_id * max_tokens_per_expert * hidden + tgt_expert_token * hidden + hidden_vec_id * VecSize);
@@ -81,7 +81,7 @@ std::vector<paddle::Tensor> MoEDeepGEMMPermuteDispatch(
permute_indices_per_token.data<int32_t>(), permute_indices_per_token.data<int32_t>(),
reinterpret_cast<const DataType_ *>(x.data<data_t>()), reinterpret_cast<const DataType_ *>(x.data<data_t>()),
topk_idx.data<int64_t>(), topk_idx.data<int64_t>(),
token_num, topk, num_vecs, token_num, topk, num_vecs,
hidden, max_tokens_per_expert hidden, max_tokens_per_expert
); );
@@ -112,4 +112,4 @@ PD_BUILD_STATIC_OP(moe_deepgemm_permute)
.Inputs({"x", "topk_idx"}) .Inputs({"x", "topk_idx"})
.Outputs({"permute_output", "token_nums_per_expert", "permute_indices_per_token"}) .Outputs({"permute_output", "token_nums_per_expert", "permute_indices_per_token"})
.Attrs({"num_experts: int", "max_tokens_per_expert: int"}) .Attrs({"num_experts: int", "max_tokens_per_expert: int"})
.SetKernelFn(PD_KERNEL(MoEDeepGEMMPermute)); .SetKernelFn(PD_KERNEL(MoEDeepGEMMPermute));


@@ -232,12 +232,12 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
/** /**
* @brief Mixture of Experts (MoE) Expert Dispatch Operator * @brief Mixture of Experts (MoE) Expert Dispatch Operator
* *
* This operator performs the following key functions: * This operator performs the following key functions:
* 1. Computes top-k experts for each input token based on gating scores * 1. Computes top-k experts for each input token based on gating scores
* 2. Permutes input tokens according to their selected experts for efficient expert processing * 2. Permutes input tokens according to their selected experts for efficient expert processing
* 3. Computes prefix sums of tokens per expert for group_gemm optimization * 3. Computes prefix sums of tokens per expert for group_gemm optimization
* *
* Inputs: * Inputs:
* - input: The input tensor to be routed to experts * - input: The input tensor to be routed to experts
* Shape: [total_tokens, hidden_size] * Shape: [total_tokens, hidden_size]
@@ -246,7 +246,7 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
* Shape: [total_tokens, expert_num] * Shape: [total_tokens, expert_num]
* dtype: must be float32 * dtype: must be float32
* - gating_correction_bias: Optional bias term for gating correction (expert_num) * - gating_correction_bias: Optional bias term for gating correction (expert_num)
* *
* Outputs: * Outputs:
* - permute_input: Permuted input tensor organized by expert * - permute_input: Permuted input tensor organized by expert
* Shape: [moe_topk * total_tokens, hidden_size] * Shape: [moe_topk * total_tokens, hidden_size]
@@ -263,7 +263,7 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
* - top_k_indices: Indices of selected top-k experts for each token * - top_k_indices: Indices of selected top-k experts for each token
* Shape: [total_tokens, moe_topk] * Shape: [total_tokens, moe_topk]
* dtype: int32 * dtype: int32
* *
* Attributes: * Attributes:
* - moe_topk: Number of experts to select for each token (k value in top-k routing) * - moe_topk: Number of experts to select for each token (k value in top-k routing)
* - group_moe: Whether to perform group softmax within the operator * - group_moe: Whether to perform group softmax within the operator
@@ -272,7 +272,7 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
* - topk_only_mode: Operation mode selector * - topk_only_mode: Operation mode selector
* (true: only performs topk selection without softmax, * (true: only performs topk selection without softmax,
* false: performs full softmax+topk computation) * false: performs full softmax+topk computation)
* *
* Note: * Note:
* - The operator requires 2D input format [total_tokens, hidden_size] * - The operator requires 2D input format [total_tokens, hidden_size]
* - For optimal performance, expert_num should be a power of 2 when possible * - For optimal performance, expert_num should be a power of 2 when possible
@@ -283,7 +283,7 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch)
paddle::Optional("gating_correction_bias"), paddle::Optional("gating_correction_bias"),
paddle::Optional("w4a8_in_scale")}) paddle::Optional("w4a8_in_scale")})
.Outputs({"permute_input", "tokens_expert_prefix_sum", .Outputs({"permute_input", "tokens_expert_prefix_sum",
"permute_indices_per_token", "topk_weight", "topk_idx", "permute_indices_per_token", "topk_weight", "topk_idx",
"expert_idx_per_token"}) "expert_idx_per_token"})
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"}) .Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertDispatch)) .SetKernelFn(PD_KERNEL(MoeExpertDispatch))
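The comment block above documents the operator's contract (2-D input, float32 gating scores, top-k routing plus permutation and per-expert prefix sums). A minimal call sketch follows, assuming the op has been built and is importable as a Paddle custom op; the module path and the positional argument order are assumptions for illustration, not taken from this diff.

    import paddle

    # Assumption: the compiled custom op is exposed under this module path;
    # adjust the import to wherever the built ops are installed.
    from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch

    total_tokens, hidden_size, expert_num, moe_topk = 8, 4096, 64, 2
    x = paddle.randn([total_tokens, hidden_size], dtype="float32").cast("bfloat16")
    gating_output = paddle.randn([total_tokens, expert_num], dtype="float32")

    # Assumed positional layout, following the registration above:
    # inputs (input, gating_output, gating_correction_bias, w4a8_in_scale)
    # then attrs (moe_topk, group_moe, topk_only_mode).
    (permute_input, tokens_expert_prefix_sum, permute_indices_per_token,
     topk_weight, topk_idx, expert_idx_per_token) = moe_expert_dispatch(
        x, gating_output, None, None, moe_topk, False, False)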


@@ -263,4 +263,4 @@ PD_BUILD_OP(moe_redundant_topk_select)
.SetInplaceMap({{"tokens_per_expert_stats_list", "tokens_per_expert_stats_list_out"}}) .SetInplaceMap({{"tokens_per_expert_stats_list", "tokens_per_expert_stats_list_out"}})
.SetKernelFn(PD_KERNEL(MoERedundantTopKSelectKernel)) .SetKernelFn(PD_KERNEL(MoERedundantTopKSelectKernel))
.SetInferShapeFn(PD_INFER_SHAPE(MoERedundantTopKSelectKernelInferShape)) .SetInferShapeFn(PD_INFER_SHAPE(MoERedundantTopKSelectKernelInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoERedundantTopKSelectKernelInferDtype)); .SetInferDtypeFn(PD_INFER_DTYPE(MoERedundantTopKSelectKernelInferDtype));


@@ -106,4 +106,4 @@ template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 256, 4,
template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS ); template __global__ void Marlin<half, MARLIN_NAMESPACE_NAME::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, 8, false>( MARLIN_KERNEL_PARAMS );
} }


@@ -36,4 +36,4 @@ struct msgdata {
struct msgdatakv { struct msgdatakv {
long mtype; long mtype;
int mtext[MAX_BSZ * 3 + 2]; // encoder_count, layer_id, bid- pair int mtext[MAX_BSZ * 3 + 2]; // encoder_count, layer_id, bid- pair
}; };


@@ -14,9 +14,10 @@
"""read_ids""" """read_ids"""
import os import os
import numpy as np
import struct import struct
import numpy as np
def deserialize_from_file(fp): def deserialize_from_file(fp):
"""deserialize from file""" """deserialize from file"""


@@ -13,9 +13,10 @@
# limitations under the License. # limitations under the License.
"""read temp_ids from file""" """read temp_ids from file"""
import os import os
import numpy as np
import struct import struct
import numpy as np
def deserialize_from_file(fp): def deserialize_from_file(fp):
""" """


@@ -15,7 +15,7 @@
#include "remote_cache_kv_ipc.h" #include "remote_cache_kv_ipc.h"
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data RemoteCacheKvIpc::kv_complete_signal_meta_data; RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data RemoteCacheKvIpc::kv_complete_signal_meta_data;
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query
RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query; RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query;
void* RemoteCacheKvIpc::kv_complete_signal_identity_ptr = nullptr; void* RemoteCacheKvIpc::kv_complete_signal_identity_ptr = nullptr;
bool RemoteCacheKvIpc::kv_complete_signal_shmem_opened = false; bool RemoteCacheKvIpc::kv_complete_signal_shmem_opened = false;
@@ -118,4 +118,3 @@ void CUDART_CB RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_que
RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query.send_signal(); RemoteCacheKvIpc::kv_complete_signal_meta_data_per_query.send_signal();
// std::printf("#### save_cache_kv_complete_signal_layerwise_per_query); // std::printf("#### save_cache_kv_complete_signal_layerwise_per_query);
} }


@@ -71,7 +71,7 @@ struct RemoteCacheKvIpc {
} }
} }
msg_sed.mtext[0] = encoder_count; msg_sed.mtext[0] = encoder_count;
if (!inited) { if (!inited) {
// just init once // just init once
const int msg_id = 1024 + rank; const int msg_id = 1024 + rank;
@@ -90,7 +90,7 @@ struct RemoteCacheKvIpc {
assert(layer_id_ <= num_layers_); assert(layer_id_ <= num_layers_);
} }
}; };
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data kv_complete_signal_meta_data; static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data kv_complete_signal_meta_data;
static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query kv_complete_signal_meta_data_per_query; static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data_per_query kv_complete_signal_meta_data_per_query;
static void* kv_complete_signal_identity_ptr; static void* kv_complete_signal_identity_ptr;


@@ -125,7 +125,7 @@ void group_wise_scale(ScaleT* scale,
} }
} }
std::vector<paddle::Tensor> Fp8Int4WeightQuantizeKernel(const paddle::Tensor &input, std::vector<paddle::Tensor> Fp8Int4WeightQuantizeKernel(const paddle::Tensor &input,
int groupsize, int groupsize,
std::string scale_dtype) { std::string scale_dtype) {
auto input_cpu = input.copy_to(paddle::CPUPlace(), false); auto input_cpu = input.copy_to(paddle::CPUPlace(), false);
@@ -139,47 +139,47 @@ std::vector<paddle::Tensor> Fp8Int4WeightQuantizeKernel(const paddle::Tensor &in
if (groupsize > 0) { if (groupsize > 0) {
scale = paddle::full({shape[0] / groupsize * shape[1]}, 1.0, paddle::DataType::BFLOAT16, paddle::CPUPlace()); scale = paddle::full({shape[0] / groupsize * shape[1]}, 1.0, paddle::DataType::BFLOAT16, paddle::CPUPlace());
group_wise_scale(scale.data<phi::dtype::bfloat16>(), input_cpu.data<float>(), k, n, 7.0f, groupsize); group_wise_scale(scale.data<phi::dtype::bfloat16>(), input_cpu.data<float>(), k, n, 7.0f, groupsize);
group_wise_quant(packed_int4.data<int8_t>(), group_wise_quant(packed_int4.data<int8_t>(),
input_cpu.data<float>(), input_cpu.data<float>(),
scale.data<phi::dtype::bfloat16>(), scale.data<phi::dtype::bfloat16>(),
k, k,
n, n,
groupsize); groupsize);
} else { } else {
scale = paddle::full({shape[1]}, 1.0, paddle::DataType::BFLOAT16, paddle::CPUPlace()); scale = paddle::full({shape[1]}, 1.0, paddle::DataType::BFLOAT16, paddle::CPUPlace());
per_channel_scale(scale.data<phi::dtype::bfloat16>(), input_cpu.data<float>(), k, n, 7.0f); per_channel_scale(scale.data<phi::dtype::bfloat16>(), input_cpu.data<float>(), k, n, 7.0f);
per_channel_quant(packed_int4.data<int8_t>(), per_channel_quant(packed_int4.data<int8_t>(),
input_cpu.data<float>(), input_cpu.data<float>(),
scale.data<phi::dtype::bfloat16>(), scale.data<phi::dtype::bfloat16>(),
k, k,
n); n);
} }
} else if (scale_dtype == "float16") { } else if (scale_dtype == "float16") {
if (groupsize > 0) { if (groupsize > 0) {
scale = paddle::full({shape[0] / groupsize * shape[1]}, 1.0, paddle::DataType::FLOAT16, paddle::CPUPlace()); scale = paddle::full({shape[0] / groupsize * shape[1]}, 1.0, paddle::DataType::FLOAT16, paddle::CPUPlace());
group_wise_scale(scale.data<phi::dtype::float16>(), input_cpu.data<float>(), k, n, 7.0f, groupsize); group_wise_scale(scale.data<phi::dtype::float16>(), input_cpu.data<float>(), k, n, 7.0f, groupsize);
group_wise_quant(packed_int4.data<int8_t>(), group_wise_quant(packed_int4.data<int8_t>(),
input_cpu.data<float>(), input_cpu.data<float>(),
scale.data<phi::dtype::float16>(), scale.data<phi::dtype::float16>(),
k, k,
n, n,
groupsize); groupsize);
} else { } else {
scale = paddle::full({shape[1]}, 1.0, paddle::DataType::FLOAT16, paddle::CPUPlace()); scale = paddle::full({shape[1]}, 1.0, paddle::DataType::FLOAT16, paddle::CPUPlace());
per_channel_scale(scale.data<phi::dtype::float16>(), input_cpu.data<float>(), k, n, 7.0f); per_channel_scale(scale.data<phi::dtype::float16>(), input_cpu.data<float>(), k, n, 7.0f);
per_channel_quant(packed_int4.data<int8_t>(), per_channel_quant(packed_int4.data<int8_t>(),
input_cpu.data<float>(), input_cpu.data<float>(),
scale.data<phi::dtype::float16>(), scale.data<phi::dtype::float16>(),
k, k,
n); n);
} }
} }
auto out = paddle::full({shape[1] / 2, shape[0]}, 0, paddle::DataType::INT8, paddle::CPUPlace()); auto out = paddle::full({shape[1] / 2, shape[0]}, 0, paddle::DataType::INT8, paddle::CPUPlace());
preprocess_weights_for_mixed_gemm( preprocess_weights_for_mixed_gemm(
out.data<int8_t>(), out.data<int8_t>(),
packed_int4.data<int8_t>(), packed_int4.data<int8_t>(),
{k, n}, {k, n},
kernels::cutlass_kernels::QuantType::W4_AFP8, kernels::cutlass_kernels::QuantType::W4_AFP8,
false); false);
return {out, scale}; return {out, scale};


@@ -1,11 +1,11 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -27,7 +27,7 @@
std::vector<paddle::Tensor> ShareExternalData(paddle::Tensor& input, std::vector<paddle::Tensor> ShareExternalData(paddle::Tensor& input,
const std::string shm_name, const std::string shm_name,
const std::vector<int>& shape) { const std::vector<int>& shape) {
volatile shmStruct *shm = NULL; volatile shmStruct *shm = NULL;
sharedMemoryInfo info; sharedMemoryInfo info;
if (sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info) != 0) { if (sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info) != 0) {
@@ -62,4 +62,4 @@ PD_BUILD_STATIC_OP(share_external_data)
.Inputs({"input"}) .Inputs({"input"})
.Outputs({"output"}) .Outputs({"output"})
.Attrs({"shm_name: std::string", "shape: std::vector<int>"}) .Attrs({"shm_name: std::string", "shape: std::vector<int>"})
.SetKernelFn(PD_KERNEL(ShareExternalData)); .SetKernelFn(PD_KERNEL(ShareExternalData));


@@ -19,7 +19,7 @@
// #define DEBUG_EAGLE_KERNEL // #define DEBUG_EAGLE_KERNEL
__global__ void ComputeOrderKernel( __global__ void ComputeOrderKernel(
const int* seq_lens_this_time, const int* seq_lens_this_time,
const int* seq_lens_encoder, const int* seq_lens_encoder,
const int* base_model_seq_lens_this_time, const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder, const int* base_model_seq_lens_encoder,
@@ -47,7 +47,7 @@ __global__ void ComputeOrderKernel(
printf("batch %d: cur_seq_lens_encoder > 0 \n", i); printf("batch %d: cur_seq_lens_encoder > 0 \n", i);
#endif #endif
for (int j = 0; j < cur_seq_lens_encoder; j++) { for (int j = 0; j < cur_seq_lens_encoder; j++) {
position_map[in_offset++] = out_offset++; position_map[in_offset++] = out_offset++;
} }
// 2. base model encoder. Base step=0 // 2. base model encoder. Base step=0
} else if (cur_base_model_seq_lens_encoder != 0) { } else if (cur_base_model_seq_lens_encoder != 0) {
@@ -69,13 +69,13 @@ __global__ void ComputeOrderKernel(
in_offset += cur_base_model_seq_lens_this_time; in_offset += cur_base_model_seq_lens_this_time;
} else /*Accept all draft tokens*/ { } else /*Accept all draft tokens*/ {
#ifdef DEBUG_EAGLE_KERNEL #ifdef DEBUG_EAGLE_KERNEL
printf("batch %d: accept_num > actual_draft_token_num \n", i); printf("batch %d: accept_num > actual_draft_token_num \n", i);
#endif #endif
position_map[in_offset + accept_num - 2] = out_offset++; position_map[in_offset + accept_num - 2] = out_offset++;
position_map[in_offset + accept_num - 1] = out_offset++; position_map[in_offset + accept_num - 1] = out_offset++;
in_offset += cur_base_model_seq_lens_this_time; in_offset += cur_base_model_seq_lens_this_time;
} }
} }
} }
output_token_num[0] = out_offset; output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL #ifdef DEBUG_EAGLE_KERNEL
@@ -208,7 +208,7 @@ std::vector<paddle::Tensor> EagleGetHiddenStates(
} }
case paddle::DataType::BFLOAT16: { case paddle::DataType::BFLOAT16: {
return DispatchDtype<paddle::DataType::BFLOAT16>( return DispatchDtype<paddle::DataType::BFLOAT16>(
input, input,
seq_lens_this_time, seq_lens_this_time,
seq_lens_encoder, seq_lens_encoder,
seq_lens_decoder, seq_lens_decoder,


@@ -72,7 +72,7 @@ __global__ void computeOrderKernel(
output_token_num[0] = out_offset; output_token_num[0] = out_offset;
#ifdef DEBUG_EAGLE_KERNEL #ifdef DEBUG_EAGLE_KERNEL
printf("position map output_token_num%d:\n", output_token_num[0]); printf("position map output_token_num%d:\n", output_token_num[0]);
for (int i = 0; i < output_token_num[0]; i++) { for (int i = 0; i < output_token_num[0]; i++) {
printf("%d ", src_map[i]); printf("%d ", src_map[i]);
} }
printf("\n"); printf("\n");
@@ -187,4 +187,4 @@ PD_BUILD_STATIC_OP(eagle_get_self_hidden_states)
"seq_lens_this_time", "seq_lens_this_time",
"step_idx"}) "step_idx"})
.Outputs({"out"}) .Outputs({"out"})
.SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates)); .SetKernelFn(PD_KERNEL(EagleGetSelfHiddenStates));


@@ -26,7 +26,7 @@ __global__ void RebuildAppendPaddingKernel(
const int seq_len, const int seq_len,
const int dim_embed, const int dim_embed,
const size_t elem_nums) { const size_t elem_nums) {
using LoadT = AlignedVector<T, VecSize>; using LoadT = AlignedVector<T, VecSize>;
LoadT src_vec; LoadT src_vec;
const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = global_idx * VecSize; i < elem_nums; i += gridDim.x * blockDim.x * VecSize) { for (int64_t i = global_idx * VecSize; i < elem_nums; i += gridDim.x * blockDim.x * VecSize) {
@@ -42,7 +42,7 @@ __global__ void RebuildAppendPaddingKernel(
const int input_token_id = ori_token_id - cum_offset[bi] + seq_id; const int input_token_id = ori_token_id - cum_offset[bi] + seq_id;
const int bias_idx = i % dim_embed; const int bias_idx = i % dim_embed;
Load<T, VecSize>(&full_hidden_states[input_token_id * dim_embed + bias_idx], &src_vec); Load<T, VecSize>(&full_hidden_states[input_token_id * dim_embed + bias_idx], &src_vec);
Store<T, VecSize>(src_vec, &out[i]); Store<T, VecSize>(src_vec, &out[i]);
} }
@@ -78,14 +78,14 @@ std::vector<paddle::Tensor> DispatchDtype(
GetNumBlocks(pack_num, &grid_size); GetNumBlocks(pack_num, &grid_size);
RebuildAppendPaddingKernel<DataType_, PackSize><<<grid_size, threads_per_block, 0, full_hidden_states.stream()>>>( RebuildAppendPaddingKernel<DataType_, PackSize><<<grid_size, threads_per_block, 0, full_hidden_states.stream()>>>(
reinterpret_cast<DataType_*>(out.data<data_t>()), reinterpret_cast<DataType_*>(out.data<data_t>()),
reinterpret_cast<const DataType_*>(full_hidden_states.data<data_t>()), reinterpret_cast<const DataType_*>(full_hidden_states.data<data_t>()),
cum_offsets.data<int32_t>(), cum_offsets.data<int32_t>(),
seq_len_encoder.data<int32_t>(), seq_len_encoder.data<int32_t>(),
seq_len_decoder.data<int32_t>(), seq_len_decoder.data<int32_t>(),
output_padding_offset.data<int32_t>(), output_padding_offset.data<int32_t>(),
max_seq_len, max_seq_len,
dim_embed, dim_embed,
elem_nums); elem_nums);
return {out}; return {out};
} }
@@ -99,7 +99,7 @@ std::vector<paddle::Tensor> RebuildAppendPadding(
const paddle::Tensor& output_padding_offset, const paddle::Tensor& output_padding_offset,
const int max_seq_len) { const int max_seq_len) {
switch (full_hidden_states.dtype()) { switch (full_hidden_states.dtype()) {
case paddle::DataType::BFLOAT16: case paddle::DataType::BFLOAT16:
return DispatchDtype<paddle::DataType::BFLOAT16>( return DispatchDtype<paddle::DataType::BFLOAT16>(
@@ -137,7 +137,7 @@ std::vector<paddle::DataType> RebuildAppendPaddingInferDtype(
PD_BUILD_STATIC_OP(speculate_rebuild_append_padding) PD_BUILD_STATIC_OP(speculate_rebuild_append_padding)
.Inputs({"full_hidden_states", .Inputs({"full_hidden_states",
"cum_offsets", "cum_offsets",
"seq_len_encoder", "seq_len_encoder",
"seq_len_decoder", "seq_len_decoder",
@@ -146,4 +146,4 @@ PD_BUILD_STATIC_OP(speculate_rebuild_append_padding)
.Outputs({"out"}) .Outputs({"out"})
.SetKernelFn(PD_KERNEL(RebuildAppendPadding)) .SetKernelFn(PD_KERNEL(RebuildAppendPadding))
.SetInferShapeFn(PD_INFER_SHAPE(RebuildAppendPaddingInferShape)) .SetInferShapeFn(PD_INFER_SHAPE(RebuildAppendPaddingInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(RebuildAppendPaddingInferDtype)); .SetInferDtypeFn(PD_INFER_DTYPE(RebuildAppendPaddingInferDtype));


@@ -93,7 +93,7 @@ __global__ void speculate_free_and_reschedule(bool *stop_flags,
used_list_len[tid] = 0; used_list_len[tid] = 0;
} }
} else if (seq_lens_this_time[tid] != 0 && max_possible_block_idx < block_num_per_seq && } else if (seq_lens_this_time[tid] != 0 && max_possible_block_idx < block_num_per_seq &&
block_table_now[(seq_lens_decoder[tid] + max_draft_tokens + block_table_now[(seq_lens_decoder[tid] + max_draft_tokens +
1) / 1) /
block_size] == -1) { block_size] == -1) {
// Count the positions and total number of blocks that need to be allocated // Count the positions and total number of blocks that need to be allocated
@@ -347,7 +347,7 @@ PD_BUILD_STATIC_OP(speculate_step_reschedule)
"next_tokens", "next_tokens",
"first_token_ids", "first_token_ids",
"accept_num"}) "accept_num"})
.Attrs({"block_size: int", .Attrs({"block_size: int",
"encoder_decoder_block_num: int", "encoder_decoder_block_num: int",
"max_draft_tokens: int"}) "max_draft_tokens: int"})
.Outputs({"stop_flags_out", .Outputs({"stop_flags_out",


@@ -60,7 +60,7 @@ __global__ void recover_block_system_cache(int *recover_block_list, // [bsz]
const int ori_free_list_len_tid0 = atomicSub(free_list_len, decoder_used_len); const int ori_free_list_len_tid0 = atomicSub(free_list_len, decoder_used_len);
ori_free_list_len = ori_free_list_len_tid0; ori_free_list_len = ori_free_list_len_tid0;
#ifdef DEBUG_STEP #ifdef DEBUG_STEP
printf("seq_id: %d, ori_seq_len_encoder: %d, step_idx_now: %d, seq_len: %d, ori_free_list_len_tid0: %d, ori_free_list_len: %d\n", printf("seq_id: %d, ori_seq_len_encoder: %d, step_idx_now: %d, seq_len: %d, ori_free_list_len_tid0: %d, ori_free_list_len: %d\n",
recover_id, ori_seq_len_encoder, step_idx_now, seq_len, ori_free_list_len_tid0, ori_free_list_len); recover_id, ori_seq_len_encoder, step_idx_now, seq_len, ori_free_list_len_tid0, ori_free_list_len);
#endif #endif
} }
@@ -95,7 +95,7 @@ void StepSystemCache(const paddle::Tensor& stop_flags,
const paddle::Tensor& recover_lens, const paddle::Tensor& recover_lens,
const paddle::Tensor& need_block_list, const paddle::Tensor& need_block_list,
const paddle::Tensor& need_block_len, const paddle::Tensor& need_block_len,
const paddle::Tensor& used_list_len, const paddle::Tensor& used_list_len,
const paddle::Tensor& free_list, const paddle::Tensor& free_list,
const paddle::Tensor& free_list_len, const paddle::Tensor& free_list_len,
const paddle::Tensor& input_ids, const paddle::Tensor& input_ids,
@@ -178,7 +178,7 @@ void StepSystemCache(const paddle::Tensor& stop_flags,
} }
PD_BUILD_STATIC_OP(step_system_cache) PD_BUILD_STATIC_OP(step_system_cache)
.Inputs({"stop_flags", .Inputs({"stop_flags",
"seq_lens_this_time", "seq_lens_this_time",
"ori_seq_lens_encoder", "ori_seq_lens_encoder",
"ori_seq_lens_decoder", "ori_seq_lens_decoder",


@@ -68,26 +68,26 @@ void SwapCache(const paddle::Tensor& cache_gpu, // gpu
switch (cache_gpu.dtype()) { switch (cache_gpu.dtype()) {
case paddle::DataType::BFLOAT16: case paddle::DataType::BFLOAT16:
return SwapCacheImpl<paddle::DataType::BFLOAT16>( return SwapCacheImpl<paddle::DataType::BFLOAT16>(
cache_gpu, cache_gpu,
cache_cpu_ptr, cache_cpu_ptr,
max_block_num_cpu, max_block_num_cpu,
swap_block_ids_gpu, swap_block_ids_gpu,
swap_block_ids_cpu, swap_block_ids_cpu,
mode); mode);
case paddle::DataType::FLOAT16: case paddle::DataType::FLOAT16:
return SwapCacheImpl<paddle::DataType::FLOAT16>( return SwapCacheImpl<paddle::DataType::FLOAT16>(
cache_gpu, cache_gpu,
cache_cpu_ptr, cache_cpu_ptr,
max_block_num_cpu, max_block_num_cpu,
swap_block_ids_gpu, swap_block_ids_gpu,
swap_block_ids_cpu, swap_block_ids_cpu,
mode); mode);
case paddle::DataType::UINT8: case paddle::DataType::UINT8:
return SwapCacheImpl<paddle::DataType::UINT8>( return SwapCacheImpl<paddle::DataType::UINT8>(
cache_gpu, cache_gpu,
cache_cpu_ptr, cache_cpu_ptr,
max_block_num_cpu, max_block_num_cpu,
swap_block_ids_gpu, swap_block_ids_gpu,
swap_block_ids_cpu, swap_block_ids_cpu,
mode); mode);
default: default:


@@ -47,7 +47,7 @@ inline cudaError_t GetGridSize(int64_t n, int block_size, int num_waves, int* nu
template<typename T, int VecSize> template<typename T, int VecSize>
__global__ void text_image_scatter_kernel( __global__ void text_image_scatter_kernel(
T* input_ptr, T* input_ptr,
T* text_gather_ptr, T* text_gather_ptr,
T* image_gather_ptr, T* image_gather_ptr,
int32_t* token_type_ids, int32_t* token_type_ids,
@@ -72,8 +72,8 @@ __global__ void text_image_scatter_kernel(
int32_t token_type_ids_num = token_type_ids[token_idx]; int32_t token_type_ids_num = token_type_ids[token_idx];
int64_t input_load_offset = token_idx * hidden_size + hidden_offset; int64_t input_load_offset = token_idx * hidden_size + hidden_offset;
Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec); Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
#pragma unroll #pragma unroll
for(int vi = 0; vi < VecSize; ++vi) { for(int vi = 0; vi < VecSize; ++vi) {
text_imgaes_vec[vi] = input_ptr_vec[vi]; text_imgaes_vec[vi] = input_ptr_vec[vi];
@@ -92,7 +92,7 @@ __global__ void text_image_scatter_kernel(
template<typename T, int VecSize> template<typename T, int VecSize>
__global__ void text_image_gather_kernel( __global__ void text_image_gather_kernel(
T* output_ptr, T* output_ptr,
T* text_gather_ptr, T* text_gather_ptr,
T* image_gather_ptr, T* image_gather_ptr,
int32_t* token_type_ids, int32_t* token_type_ids,
@@ -131,8 +131,8 @@ __global__ void text_image_gather_kernel(
} }
int64_t input_load_offset = token_idx * hidden_size + hidden_offset; int64_t input_load_offset = token_idx * hidden_size + hidden_offset;
Store<T, VecSize>(output_ptr_vec, output_ptr + input_load_offset); Store<T, VecSize>(output_ptr_vec, output_ptr + input_load_offset);
} }
} }
@@ -159,7 +159,7 @@ void LaunchTextImageGatherScatter(
const int64_t tot_element_num = token_num * hidden_size; const int64_t tot_element_num = token_num * hidden_size;
int64_t tot_pack_num = (tot_element_num + VecSize - 1) / VecSize; int64_t tot_pack_num = (tot_element_num + VecSize - 1) / VecSize;
const int block_size = 128; const int block_size = 128;
int grid_index = (token_num + block_size - 1) / block_size; int grid_index = (token_num + block_size - 1) / block_size;
constexpr int32_t kNumWaves = 16; constexpr int32_t kNumWaves = 16;
@@ -170,8 +170,8 @@ void LaunchTextImageGatherScatter(
if (is_scatter) { if (is_scatter) {
text_image_scatter_kernel<DataType_, 8><<<grid_dim, block_size>>>( text_image_scatter_kernel<DataType_, 8><<<grid_dim, block_size>>>(
reinterpret_cast<DataType_*>(input.data<data_t>()), reinterpret_cast<DataType_*>(input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()), reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()), reinterpret_cast<DataType_*>(image_input.data<data_t>()),
reinterpret_cast<int32_t*>(token_type_ids.data<int32_t>()), reinterpret_cast<int32_t*>(token_type_ids.data<int32_t>()),
reinterpret_cast<int32_t*>(text_index.data<int32_t>()), reinterpret_cast<int32_t*>(text_index.data<int32_t>()),
reinterpret_cast<int32_t*>(image_index.data<int32_t>()), reinterpret_cast<int32_t*>(image_index.data<int32_t>()),
@@ -181,8 +181,8 @@ void LaunchTextImageGatherScatter(
} else { } else {
text_image_gather_kernel<DataType_, 8><<<grid_dim, block_size>>>( text_image_gather_kernel<DataType_, 8><<<grid_dim, block_size>>>(
reinterpret_cast<DataType_*>(input.data<data_t>()), reinterpret_cast<DataType_*>(input.data<data_t>()),
reinterpret_cast<DataType_*>(text_input.data<data_t>()), reinterpret_cast<DataType_*>(text_input.data<data_t>()),
reinterpret_cast<DataType_*>(image_input.data<data_t>()), reinterpret_cast<DataType_*>(image_input.data<data_t>()),
reinterpret_cast<int32_t*>(token_type_ids.data<int32_t>()), reinterpret_cast<int32_t*>(token_type_ids.data<int32_t>()),
reinterpret_cast<int32_t*>(text_index.data<int32_t>()), reinterpret_cast<int32_t*>(text_index.data<int32_t>()),
reinterpret_cast<int32_t*>(image_index.data<int32_t>()), reinterpret_cast<int32_t*>(image_index.data<int32_t>()),
@@ -216,8 +216,8 @@ void TextImageGatherScatter(
PD_BUILD_STATIC_OP(text_image_gather_scatter) PD_BUILD_STATIC_OP(text_image_gather_scatter)
.Inputs({"input", .Inputs({"input",
"text_input", "text_input",
"image_input", "image_input",
"token_type_ids", "token_type_ids",
"text_index", "text_index",
"image_index"}) "image_index"})
@@ -229,5 +229,5 @@ PD_BUILD_STATIC_OP(text_image_gather_scatter)
.SetInplaceMap({{"text_input", "text_input_out"}, .SetInplaceMap({{"text_input", "text_input_out"},
{"image_input", "image_input_out"}, {"image_input", "image_input_out"},
{"text_index", "text_index_out"}, {"text_index", "text_index_out"},
{"image_index", "image_index_out"}}) {"image_index", "image_index_out"}})
.SetKernelFn(PD_KERNEL(TextImageGatherScatter)); .SetKernelFn(PD_KERNEL(TextImageGatherScatter));


@@ -16,7 +16,7 @@
template <int VecSize> template <int VecSize>
__global__ void text_image_index_out_kernel( __global__ void text_image_index_out_kernel(
int32_t* token_type_ids, int32_t* token_type_ids,
int32_t* text_index, int32_t* text_index,
int32_t* image_index, int32_t* image_index,
const int64_t token_num const int64_t token_num
@@ -25,7 +25,7 @@ __global__ void text_image_index_out_kernel(
if (global_thread_idx >= 1) return; if (global_thread_idx >= 1) return;
int text_count = 0; int text_count = 0;
int images_count = 0; int images_count = 0;
for (int i = 0; i < token_num; ++i) { for (int i = 0; i < token_num; ++i) {
// printf(" %d %d %d %d \n", text_index[i], text_count, images_count, i); // printf(" %d %d %d %d \n", text_index[i], text_count, images_count, i);
if (token_type_ids[i] == 0) { if (token_type_ids[i] == 0) {
@@ -60,5 +60,5 @@ PD_BUILD_STATIC_OP(text_image_index_out)
.Outputs({"text_index_out", .Outputs({"text_index_out",
"image_index_out"}) "image_index_out"})
.SetInplaceMap({{"text_index", "text_index_out"}, .SetInplaceMap({{"text_index", "text_index_out"},
{"image_index", "image_index_out"}}) {"image_index", "image_index_out"}})
.SetKernelFn(PD_KERNEL(TextImageIndexOut)); .SetKernelFn(PD_KERNEL(TextImageIndexOut));


@@ -810,4 +810,4 @@ PD_BUILD_STATIC_OP(tune_cublaslt_gemm)
"is_test: bool", "is_test: bool",
"is_read_from_file: bool", "is_read_from_file: bool",
"path: std::string"}) "path: std::string"})
.SetKernelFn(PD_KERNEL(TuneCublasltGemm)); .SetKernelFn(PD_KERNEL(TuneCublasltGemm));


@@ -33,7 +33,7 @@ __global__ void update_inputs_beam_kernel(
if (block_idx == 0) { if (block_idx == 0) {
seq_lens_this_time[thread_idx] = seq_lens_this_time[bsz_index]; seq_lens_this_time[thread_idx] = seq_lens_this_time[bsz_index];
seq_lens_encoder[thread_idx] = seq_lens_encoder[bsz_index]; seq_lens_encoder[thread_idx] = seq_lens_encoder[bsz_index];
} }
if (block_idx < seq_len) { if (block_idx < seq_len) {
input_ids[thread_idx * seq_len + block_idx] = input_ids[bsz_index * seq_len + block_idx]; input_ids[thread_idx * seq_len + block_idx] = input_ids[bsz_index * seq_len + block_idx];
} }
@@ -74,8 +74,8 @@ void UpdateInputesBeam(
PD_BUILD_STATIC_OP(update_inputs_beam) PD_BUILD_STATIC_OP(update_inputs_beam)
.Inputs({"beam_width", .Inputs({"beam_width",
"seq_lens_this_time", "seq_lens_this_time",
"seq_lens_encoder", "seq_lens_encoder",
"input_ids", "input_ids",
"logits"}) "logits"})
.Outputs({"seq_lens_this_time_out", .Outputs({"seq_lens_this_time_out",
@@ -86,4 +86,4 @@ PD_BUILD_STATIC_OP(update_inputs_beam)
{"seq_lens_encoder", "seq_lens_encoder_out"}, {"seq_lens_encoder", "seq_lens_encoder_out"},
{"input_ids", "input_ids_out"}, {"input_ids", "input_ids_out"},
{"logits", "logits_out"}}) {"logits", "logits_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputesBeam)); .SetKernelFn(PD_KERNEL(UpdateInputesBeam));


@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" setup for FastDeploy custom ops """ """setup for FastDeploy custom ops"""
import importlib import importlib
import json import json
import os import os
@@ -41,8 +41,7 @@ ROOT_DIR = Path(__file__).parent.parent
# cannot import envs directly because it depends on fastdeploy, # cannot import envs directly because it depends on fastdeploy,
# which is not installed yet # which is not installed yet
envs = load_module_from_path('envs', envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.py"))
os.path.join(ROOT_DIR, 'fastdeploy', 'envs.py'))
archs = json.loads(envs.FD_BUILDING_ARCS) archs = json.loads(envs.FD_BUILDING_ARCS)
use_bf16 = envs.FD_CPU_USE_BF16 == "True" use_bf16 = envs.FD_CPU_USE_BF16 == "True"
@@ -143,8 +142,7 @@ def get_nvcc_version():
""" """
Get cuda version of nvcc. Get cuda version of nvcc.
""" """
nvcc_output = subprocess.check_output(["nvcc", "--version"], nvcc_output = subprocess.check_output(["nvcc", "--version"], universal_newlines=True)
universal_newlines=True)
output = nvcc_output.split() output = nvcc_output.split()
release_idx = output.index("release") + 1 release_idx = output.index("release") + 1
nvcc_cuda_version = float(output[release_idx].split(",")[0]) nvcc_cuda_version = float(output[release_idx].split(",")[0])
@@ -160,13 +158,19 @@ def get_gencode_flags(archs):
for cc_val in cc_s: for cc_val in cc_s:
if cc_val == 90: if cc_val == 90:
arch_code = "90a" arch_code = "90a"
flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"] flags += [
elif cc_val == 100: # Assuming 100 is the code for Blackwell SM10.x "-gencode",
f"arch=compute_{arch_code},code=sm_{arch_code}",
]
elif cc_val == 100: # Assuming 100 is the code for Blackwell SM10.x
# Per NVIDIA dev blog, for CUTLASS and architecture-specific features on CC 10.0, use '100a' # Per NVIDIA dev blog, for CUTLASS and architecture-specific features on CC 10.0, use '100a'
# https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/ # https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/
# "The CUTLASS build instructions specify using the a flag when building for devices of CC 9.0 and 10.0" # "The CUTLASS build instructions specify using the a flag when building for devices of CC 9.0 and 10.0"
arch_code = "100a" arch_code = "100a"
flags += ["-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}"] flags += [
"-gencode",
f"arch=compute_{arch_code},code=sm_{arch_code}",
]
else: else:
flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"] flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"]
return flags return flags
@@ -194,7 +198,7 @@ if paddle.is_compiled_with_rocm():
clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir) clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir)
if not os.listdir(json_dir): if not os.listdir(json_dir):
raise ValueError("Git clone nlohmann_json failed!") raise ValueError("Git clone nlohmann_json failed!")
sources=[ sources = [
"gpu_ops/set_value_by_flags.cu", "gpu_ops/set_value_by_flags.cu",
"gpu_ops/token_penalty_multi_scores.cu", "gpu_ops/token_penalty_multi_scores.cu",
"gpu_ops/stop_generation.cu", "gpu_ops/stop_generation.cu",
@@ -302,8 +306,7 @@ elif paddle.is_compiled_with_cuda():
if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir): if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
if not os.path.exists(cutlass_dir): if not os.path.exists(cutlass_dir):
os.makedirs(cutlass_dir) os.makedirs(cutlass_dir)
clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git", clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
cutlass_dir)
if not os.listdir(cutlass_dir): if not os.listdir(cutlass_dir):
raise ValueError("Git clone cutlass failed!") raise ValueError("Git clone cutlass failed!")
@@ -312,8 +315,7 @@ elif paddle.is_compiled_with_cuda():
if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir): if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir):
if not os.path.exists(deep_gemm_dir): if not os.path.exists(deep_gemm_dir):
os.makedirs(deep_gemm_dir) os.makedirs(deep_gemm_dir)
clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git", clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git", deep_gemm_dir)
deep_gemm_dir)
if not os.listdir(deep_gemm_dir): if not os.listdir(deep_gemm_dir):
raise ValueError("Git clone DeepGEMM failed!") raise ValueError("Git clone DeepGEMM failed!")
cur_path = os.path.dirname(os.path.abspath(__file__)) cur_path = os.path.dirname(os.path.abspath(__file__))
@@ -347,15 +349,13 @@ elif paddle.is_compiled_with_cuda():
try: try:
shutil.copytree(src_dir, dst_dir) shutil.copytree(src_dir, dst_dir)
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")
f"Failed to copy from {src_dir} to {dst_dir}: {e}")
json_dir = "third_party/nlohmann_json" json_dir = "third_party/nlohmann_json"
if not os.path.exists(json_dir) or not os.listdir(json_dir): if not os.path.exists(json_dir) or not os.listdir(json_dir):
if not os.path.exists(json_dir): if not os.path.exists(json_dir):
os.makedirs(json_dir) os.makedirs(json_dir)
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir)
json_dir)
if not os.listdir(json_dir): if not os.listdir(json_dir):
raise ValueError("Git clone nlohmann_json failed!") raise ValueError("Git clone nlohmann_json failed!")
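
The hunks above repeat one pattern: clone a pinned third-party dependency only when its directory is missing or empty, then confirm the clone actually produced files. A self-contained sketch of that pattern; the `git clone` invocation is an assumption, since the real `clone_git_repo` helper is defined elsewhere in `setup_ops.py`:

```python
import os
import subprocess


def clone_git_repo(tag, url, dst):
    # Assumed implementation; the repository's own helper may differ.
    subprocess.check_call(["git", "clone", "--depth", "1", "--branch", tag, url, dst])


def ensure_third_party(tag, url, dst):
    # Clone only if the target directory is missing or empty, then verify.
    if not os.path.exists(dst) or not os.listdir(dst):
        os.makedirs(dst, exist_ok=True)
        clone_git_repo(tag, url, dst)
        if not os.listdir(dst):
            raise ValueError(f"Git clone {url} failed!")


# e.g. ensure_third_party("v3.8.0", "https://github.com/NVIDIA/cutlass.git", "third_party/cutlass")
```
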
@@ -372,7 +372,7 @@ elif paddle.is_compiled_with_cuda():
"-Ithird_party/nlohmann_json/include", "-Ithird_party/nlohmann_json/include",
] ]
nvcc_version = get_nvcc_version() nvcc_version = get_nvcc_version()
print(f'nvcc_version = {nvcc_version}') print(f"nvcc_version = {nvcc_version}")
if nvcc_version >= 12.0: if nvcc_version >= 12.0:
sources += ["gpu_ops/sample_kernels/air_top_p_sampling.cu"] sources += ["gpu_ops/sample_kernels/air_top_p_sampling.cu"]
cc = max(get_sm_version(archs)) cc = max(get_sm_version(archs))
@@ -414,31 +414,24 @@ elif paddle.is_compiled_with_cuda():
# Running generate fp8 gemm codes. # Running generate fp8 gemm codes.
# Common for SM89, SM90, SM100 (Blackwell) # Common for SM89, SM90, SM100 (Blackwell)
nvcc_compile_args += ["-DENABLE_FP8"] nvcc_compile_args += ["-DENABLE_FP8"]
nvcc_compile_args += [ nvcc_compile_args += ["-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"]
"-Igpu_ops/cutlass_kernels/fp8_gemm_fused/autogen"
]
# This script seems general enough for different SM versions, specific templates are chosen by CUTLASS. # This script seems general enough for different SM versions, specific templates are chosen by CUTLASS.
os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py") os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")
if cc >= 90: # Hopper and newer if cc >= 90: # Hopper and newer
# SM90 (Hopper) specific auto-generation and flags # SM90 (Hopper) specific auto-generation and flags
if cc == 90: # Only for SM90 if cc == 90: # Only for SM90
nvcc_compile_args += [ nvcc_compile_args += [
# The gencode for 90a is added in get_gencode_flags now # The gencode for 90a is added in get_gencode_flags now
# "-gencode", # "-gencode",
# "arch=compute_90a,code=compute_90a", # "arch=compute_90a,code=compute_90a",
"-O3", "-O3",
"-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a "-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a
] ]
print("SM90: Running SM90-specific FP8 kernel auto-generation.") print("SM90: Running SM90-specific FP8 kernel auto-generation.")
os.system( os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
"python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py") os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py")
os.system( os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py")
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py"
)
os.system(
"python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py"
)
nvcc_compile_args += [ nvcc_compile_args += [
"-DENABLE_SCALED_MM_SM90=1", "-DENABLE_SCALED_MM_SM90=1",
@@ -450,14 +443,14 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu", "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu",
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu", "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
] ]
elif cc == 100 and nvcc_version >= 12.9: # Blackwell SM100 specifics elif cc == 100 and nvcc_version >= 12.9: # Blackwell SM100 specifics
print("SM100 (Blackwell): Applying SM100 configurations.") print("SM100 (Blackwell): Applying SM100 configurations.")
nvcc_compile_args += [ nvcc_compile_args += [
# The gencode for 100a is added in get_gencode_flags # The gencode for 100a is added in get_gencode_flags
# "-gencode", # "-gencode",
# "arch=compute_100a,code=compute_100a", # "arch=compute_100a,code=compute_100a",
"-O3", # Common optimization flag "-O3", # Common optimization flag
"-DNDEBUG", # Common debug flag "-DNDEBUG", # Common debug flag
# Potentially add -DENABLE_SM100_FEATURES if specific macros are identified # Potentially add -DENABLE_SM100_FEATURES if specific macros are identified
] ]
# Placeholder for SM100-specific kernel auto-generation scripts # Placeholder for SM100-specific kernel auto-generation scripts
@@ -469,18 +462,16 @@ elif paddle.is_compiled_with_cuda():
# Add SM100 specific sources if any, e.g., for new hardware intrinsics # Add SM100 specific sources if any, e.g., for new hardware intrinsics
# sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example # sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example
pass # No SM100 specific sources identified yet beyond what CUTLASS handles pass # No SM100 specific sources identified yet beyond what CUTLASS handles
else: # For cc >= 89 but not 90 or 100 (e.g. SM89) else: # For cc >= 89 but not 90 or 100 (e.g. SM89)
print(f"SM{cc}: Running generic FP8 kernel auto-generation.") print(f"SM{cc}: Running generic FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system( os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
else: # For cc == 89 (Ada) else: # For cc == 89 (Ada)
print("SM89: Running generic FP8 kernel auto-generation.") print("SM89: Running generic FP8 kernel auto-generation.")
os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
os.system( os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
"python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
# Common FP8 sources for SM89+ # Common FP8 sources for SM89+
sources += [ sources += [
@@ -493,7 +484,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu", "gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
"gpu_ops/cutlass_kernels/cutlass_heuristic.cu", "gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
"gpu_ops/cutlass_kernels/cutlass_preprocessors.cu", "gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
"gpu_ops/fused_hadamard_quant_fp8.cu" "gpu_ops/fused_hadamard_quant_fp8.cu",
] ]
sources += find_end_files(fp8_auto_gen_directory, ".cu") sources += find_end_files(fp8_auto_gen_directory, ".cu")
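
Taken together, the FP8 hunks above amount to a per-architecture dispatch over the auto-generation scripts. A condensed sketch, assuming `cc` is the highest target SM version and `nvcc_version` comes from `get_nvcc_version()`; the control flow is simplified, not a verbatim copy:

```python
import os


def run_fp8_autogen(cc, nvcc_version, run=os.system):
    # Visitor-pattern FP8 GEMM generation is common to every SM >= 89.
    run("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py")
    if cc >= 90:
        if cc == 90:  # Hopper
            run("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py")
            run("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py")
            run("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py")
        elif cc == 100 and nvcc_version >= 12.9:  # Blackwell
            pass  # no SM100-specific generation scripts in this commit yet
        else:  # cc >= 89 but not 90/100
            run("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
            run("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
    else:  # cc == 89 (Ada)
        run("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py")
        run("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py")
```
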

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" setup for FASTDEPLOY base ops """ """setup for FASTDEPLOY base ops"""
from paddle.utils.cpp_extension import CppExtension, setup from paddle.utils.cpp_extension import CppExtension, setup
@@ -27,7 +27,8 @@ setup(
"cpu_ops/rebuild_padding.cc", "cpu_ops/rebuild_padding.cc",
], ],
extra_compile_args=[ extra_compile_args=[
"-DPy_LIMITED_API=0x03090000", "-DPADDLE_ON_INFERENCE" "-DPy_LIMITED_API=0x03090000",
"-DPADDLE_ON_INFERENCE",
], ],
), ),
) )

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" setup for FASTDEPLOY custom cpu ops """ """setup for FASTDEPLOY custom cpu ops"""
import os import os
import subprocess import subprocess
import tarfile import tarfile
@@ -26,8 +26,7 @@ ROOT_DIR = Path(__file__).parent.parent
# which is not installed yet # which is not installed yet
from .setup_ops import load_module_from_path from .setup_ops import load_module_from_path
envs = load_module_from_path('envs', envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.py"))
os.path.join(ROOT_DIR, 'fastdeploy', 'envs.py'))
BUILDING_ARCS = [] BUILDING_ARCS = []
use_bf16 = envs.FD_CPU_USE_BF16 == "True" use_bf16 = envs.FD_CPU_USE_BF16 == "True"

View File

@@ -48,17 +48,26 @@ def get_candidate_configs(sm):
candidate_configs = list() candidate_configs = list()
hasbias = ("false", "true") hasbias = ("false", "true")
KernelSchedule = ( KernelSchedule = ("KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>",)
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>", ) EpilogueSchedule = ("TmaWarpSpecializedCooperative",)
EpilogueSchedule = ("TmaWarpSpecializedCooperative", )
TileSchedule = ("PersistentScheduler", "StreamKScheduler") TileSchedule = ("PersistentScheduler", "StreamKScheduler")
for act_tag in [ for act_tag in [
("noact", "Identity"), ("noact", "Identity"),
# ("relu", "ReLu"), # ("relu", "ReLu"),
# ("gelu", "GELU"), # ("gelu", "GELU"),
]: ]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, candidate_configs.extend(
EpilogueSchedule, TileSchedule)]) [
(
hasbias,
act_tag,
tiles,
KernelSchedule,
EpilogueSchedule,
TileSchedule,
)
]
)
return candidate_configs return candidate_configs
@@ -66,16 +75,13 @@ def get_shape_str(tile_shape):
""" """
return tile_shape string. return tile_shape string.
""" """
blocks, clusters = [ blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
blocks = [elem.strip("_") for elem in blocks] blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters] clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters return blocks, clusters
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule, def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule, tile_schedule):
tile_schedule):
""" """
check the cutlass config valid. check the cutlass config valid.
""" """
@@ -304,13 +310,10 @@ def SubstituteTemplate(template, values_base):
SubstituteTemplate SubstituteTemplate
""" """
values = copy.deepcopy(values_base) values = copy.deepcopy(values_base)
if values.get("KernelSchedule" if values.get("KernelSchedule") is not None and "Auto" in values["KernelSchedule"]:
) is not None and "Auto" in values["KernelSchedule"]:
values["KernelSchedule"] = "collective::" + values["KernelSchedule"] values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
if values.get("EpilogueSchedule" if values.get("EpilogueSchedule") is not None and "Auto" in values["EpilogueSchedule"]:
) is not None and "Auto" in values["EpilogueSchedule"]: values["EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
values[
"EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
text = template text = template
changed = True changed = True
while changed: while changed:
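
The `while changed:` loop in `SubstituteTemplate` re-applies substitutions until the text stops changing, so values that themselves contain placeholders also get expanded. A sketch of that fixed-point loop; the `${Name}` placeholder syntax is an assumption, since the actual token format is not visible in this diff:

```python
import re


def substitute_template_sketch(template, values):
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = r"\$\{%s\}" % re.escape(key)
            new_text = re.sub(regex, value, text)
            if new_text != text:
                changed = True
            text = new_text
    return text


print(substitute_template_sketch("gemm<${TileShape}>", {"TileShape": "<_128, _128, _128>"}))
```
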
@@ -329,8 +332,7 @@ def parse_args():
parse_args parse_args
""" """
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description= description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
) )
parser.add_argument( parser.add_argument(
"--cuda_arch", "--cuda_arch",
@@ -346,15 +348,15 @@ def parse_args():
# generate source .cu # generate source .cu
def generate_source_cu( def generate_source_cu(
inputs_type: (str), inputs_type: str,
outputs_type: (str), outputs_type: str,
hasbiases: (str), hasbiases: str,
act_tag: (str), act_tag: str,
tiles: (str), tiles: str,
KernelSchedule: (str), KernelSchedule: str,
EpilogueSchedule: (str), EpilogueSchedule: str,
TileSchedule: (str), TileSchedule: str,
sm: str, sm: str,
): ):
""" """
generate_source_cu generate_source_cu
@@ -369,8 +371,11 @@ def generate_source_cu(
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
for tile_schedule in TileSchedule: for tile_schedule in TileSchedule:
if not check_config_valid( if not check_config_valid(
tile_config, kernel_schedule, tile_config,
epilogue_schedule, tile_schedule): kernel_schedule,
epilogue_schedule,
tile_schedule,
):
continue continue
value_dict = { value_dict = {
"input_type": input_type, "input_type": input_type,
@@ -385,30 +390,32 @@ def generate_source_cu(
"SM": sm, "SM": sm,
"sm": sm[-2:], "sm": sm[-2:],
} }
all_code += SubstituteTemplate( all_code += SubstituteTemplate(GemmDeclare, value_dict)
GemmDeclare, value_dict)
return all_code return all_code
# generate gemm launch .cu # generate gemm launch .cu
def generate_launch_gemm_cus( def generate_launch_gemm_cus(
generate_dir: (str), inputs_type: (str), outputs_type: (str), generate_dir: str,
fuse_gemm_configs: tuple, sm: str): inputs_type: str,
outputs_type: str,
fuse_gemm_configs: tuple,
sm: str,
):
""" """
generate_launch_gemm_cus generate_launch_gemm_cus
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs] act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0] single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0] hasbiases: str = single_config[0]
tiles: (str) = single_config[2] tiles: str = single_config[2]
KernelSchedule: (str) = single_config[3] KernelSchedule: str = single_config[3]
EpilogueSchedule: (str) = single_config[4] EpilogueSchedule: str = single_config[4]
TileSchedule: (str) = single_config[5] TileSchedule: str = single_config[5]
code_map = {} code_map = {}
head_path = os.path.join(generate_dir, head_path = os.path.join(generate_dir, f"launch_block_gemm_kernel_sm{sm[-2:]}.h")
f"launch_block_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead head_all_code = LaunchGemmHead
for tile_config in tiles: for tile_config in tiles:
blocks, clusters = get_shape_str(tile_config) blocks, clusters = get_shape_str(tile_config)
@@ -418,19 +425,19 @@ def generate_launch_gemm_cus(
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule: for tile_schedule in TileSchedule:
if not check_config_valid(tile_config, kernel_schedule, if not check_config_valid(
epilogue_schedule, tile_config,
tile_schedule): kernel_schedule,
epilogue_schedule,
tile_schedule,
):
continue continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}" gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = { value_dict = {
"sm": "sm": sm[-2:],
sm[-2:], "gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
} }
head_all_code += SubstituteTemplate( head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
LaunchGemmDeclare, value_dict)
os.makedirs(generate_dir, exist_ok=True) os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f: with open(head_path, "w") as f:
f.write(head_all_code) f.write(head_all_code)
@@ -444,19 +451,19 @@ def generate_launch_gemm_cus(
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule: for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(
epilogue_schedule, tile_shape,
tile_schedule): kernel_schedule,
epilogue_schedule,
tile_schedule,
):
continue continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}" gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = { value_dict = {
"sm": "sm": sm[-2:],
sm[-2:], "gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
} }
source_all_code = SubstituteTemplate( source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
LaunchGemmPart0, value_dict)
type_id = 0 type_id = 0
for input_type in inputs_type: for input_type in inputs_type:
for output_type in outputs_type: for output_type in outputs_type:
@@ -476,16 +483,14 @@ def generate_launch_gemm_cus(
"SM": sm, "SM": sm,
"sm": sm[-2:], "sm": sm[-2:],
} }
source_all_code += SubstituteTemplate( source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
LaunchGemmPart1, value_dict)
type_id += 1 type_id += 1
source_all_code += LaunchGemmPart2 source_all_code += LaunchGemmPart2
gemm_config_str = gemm_config_str.replace("<", "").replace( gemm_config_str = gemm_config_str.replace("<", "").replace(">", "")
">", "")
code_map[gemm_config_str] = source_all_code code_map[gemm_config_str] = source_all_code
source_path = os.path.join( source_path = os.path.join(
generate_dir, generate_dir,
f"launch_block_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu" f"launch_block_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
) )
with open(source_path, "w") as f: with open(source_path, "w") as f:
f.write(source_all_code) f.write(source_all_code)
@@ -495,19 +500,18 @@ def generate_launch_gemm_cus(
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu # generate fp8_fp8_gemm_scale_bias_act_sm90.cu
def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str), def generate_dispatch_gemm_cu(inputs_type: str, outputs_type: str, fuse_gemm_configs: tuple, sm: str):
fuse_gemm_configs: tuple, sm: str):
""" """
generate_dispatch_gemm_cu generate_dispatch_gemm_cu
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs] act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0] single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0] hasbiases: str = single_config[0]
tiles: (str) = single_config[2] tiles: str = single_config[2]
KernelSchedule: (str) = single_config[3] KernelSchedule: str = single_config[3]
EpilogueSchedule: (str) = single_config[4] EpilogueSchedule: str = single_config[4]
TileSchedule: (str) = single_config[5] TileSchedule: str = single_config[5]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]}) all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0 type_id = 0
for input_type in inputs_type: for input_type in inputs_type:
@@ -530,9 +534,12 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
for tile_schedule in TileSchedule: for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(
epilogue_schedule, tile_shape,
tile_schedule): kernel_schedule,
epilogue_schedule,
tile_schedule,
):
continue continue
value_dict = { value_dict = {
"TileShape": tile_shape[0], "TileShape": tile_shape[0],
@@ -554,18 +561,18 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str_2 = gemm_config_str_1 + f"_{epilogue_schedule}"
for tile_schedule in TileSchedule: for tile_schedule in TileSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(
epilogue_schedule, tile_shape,
tile_schedule): kernel_schedule,
epilogue_schedule,
tile_schedule,
):
continue continue
gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}" gemm_config_str = gemm_config_str_2 + f"_{tile_schedule}"
value_dict = { value_dict = {
"sm": "sm": sm[-2:],
sm[-2:], "tile_id": str(tile_id),
"tile_id": "gemm_config": gemm_config_str.replace("<", "").replace(">", ""),
str(tile_id),
"gemm_config":
gemm_config_str.replace("<", "").replace(">", ""),
} }
all_code += SubstituteTemplate(code_part5, value_dict) all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1 tile_id += 1
@@ -610,12 +617,17 @@ if __name__ == "__main__":
f.close() f.close()
# Compile parallelization # Compile parallelization
generate_launch_gemm_cus( generate_launch_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type, "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
outputs_type, fuse_gemm_configs, sm_dict[sm]) inputs_type,
outputs_type,
fuse_gemm_configs,
sm_dict[sm],
)
# hard code for act_tag # hard code for act_tag
file_name = (f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/" file_name = (
f"fp8_fp8_block_gemm_scale_bias_act_sm{sm}.cu") f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/" f"fp8_fp8_block_gemm_scale_bias_act_sm{sm}.cu"
)
all_code = generate_dispatch_gemm_cu( all_code = generate_dispatch_gemm_cu(
inputs_type, inputs_type,
outputs_type, outputs_type,

View File

@@ -24,27 +24,28 @@ def get_candidate_tiles():
""" """
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")] base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([ base_configs.extend(
("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"), [
("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"), ("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"), ("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"), ("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"), ("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"), ("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"), ("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"), ("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"), ("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"), ("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"), ("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"), ("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"), ("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
]) ("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
]
)
return base_configs return base_configs
def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, def get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
max_stages):
""" """
get_dual_gemm_candidate_configs returns a list of candidate configs for the dual_gemm_fused_kernel. get_dual_gemm_candidate_configs returns a list of candidate configs for the dual_gemm_fused_kernel.
""" """
@@ -299,8 +300,7 @@ def check_min_split_k(value):
""" """
ivalue = int(value) ivalue = int(value)
if ivalue > 1: if ivalue > 1:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support.")
"Dual gemm split_k mode is not support.")
return ivalue return ivalue
@@ -310,8 +310,7 @@ def check_max_split_k(value):
""" """
ivalue = int(value) ivalue = int(value)
if ivalue > 1: if ivalue > 1:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError("Dual gemm split_k mode is not support..")
"Dual gemm split_k mode is not support..")
return ivalue return ivalue
@@ -320,8 +319,7 @@ def parse_args():
parse_args parse_args
""" """
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description= description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
) )
parser.add_argument( parser.add_argument(
"--cuda_arch", "--cuda_arch",
@@ -421,8 +419,7 @@ def generate_dual_gemm_source_cu(
"hasbias": hasbias, "hasbias": hasbias,
"SM": sm, "SM": sm,
} }
all_code += SubstituteTemplate( all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)
GemmSplitKDeclare, value_dict)
all_code += CommonTail all_code += CommonTail
return all_code return all_code
@@ -449,12 +446,12 @@ def generate_launch_dual_gemm_cus(
head_path = os.path.join(generate_dir, "launch_dual_gemm_kernel.h") head_path = os.path.join(generate_dir, "launch_dual_gemm_kernel.h")
head_all_code = LaunchGemmHead head_all_code = LaunchGemmHead
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile gemm_config = (
] f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_" f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_" f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}") )
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = { value_dict = {
@@ -467,12 +464,12 @@ def generate_launch_dual_gemm_cus(
f.close() f.close()
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile gemm_config = (
] f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_" f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_" f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}") )
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = { value_dict = {
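
The block/warp/MMA strings assembled above end up in every generated launcher and file name. A worked sketch of that naming scheme, using one tile tuple from `get_candidate_tiles`:

```python
def gemm_config_name(tile, stage):
    # tile is (block, warp, mma) shapes, e.g. ("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>").
    blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
    return (
        f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
        f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
        f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
        f"_stage{stage}"
    )


print(gemm_config_name(("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>"), 3))
# block64x64x64_warp32x32x64_mma16x8x32_stage3
```
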
@@ -498,16 +495,14 @@ def generate_launch_dual_gemm_cus(
"num_stages": str(stage), "num_stages": str(stage),
"SM": sm, "SM": sm,
} }
source_all_code += SubstituteTemplate( source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
LaunchGemmPart1, value_dict)
# split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict) # split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
type_id += 1 type_id += 1
source_all_code += LaunchGemmPart2 source_all_code += LaunchGemmPart2
# source_all_code += split_k_code # source_all_code += split_k_code
# source_all_code += LaunchGemmPart4 # source_all_code += LaunchGemmPart4
code_map[gemm_config_str] = source_all_code code_map[gemm_config_str] = source_all_code
source_path = os.path.join( source_path = os.path.join(generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
generate_dir, f"launch_dual_gemm_kernel_{gemm_config_str}.cu")
with open(source_path, "w") as f: with open(source_path, "w") as f:
f.write(source_all_code) f.write(source_all_code)
f.close() f.close()
@@ -566,12 +561,12 @@ def generate_dispatch_dual_gemm_cu(
tile_id = 0 tile_id = 0
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile gemm_config = (
] f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_"
gemm_config = (f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_" f"warp{warps[0]}x{warps[1]}x{warps[2]}_"
f"warp{warps[0]}x{warps[1]}x{warps[2]}_" f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
f"mma{mmas[0]}x{mmas[1]}x{mmas[2]}") )
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
value_dict = { value_dict = {
@@ -580,10 +575,12 @@ def generate_dispatch_dual_gemm_cu(
} }
all_code += SubstituteTemplate(code_part5, value_dict) all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1 tile_id += 1
value_dict.update({ value_dict.update(
"min_split_k": str(min_split_k), {
"max_split_k": str(max_split_k), "min_split_k": str(min_split_k),
}) "max_split_k": str(max_split_k),
}
)
all_code += SubstituteTemplate(code_part6, value_dict) all_code += SubstituteTemplate(code_part6, value_dict)
return all_code return all_code
@@ -602,8 +599,7 @@ if __name__ == "__main__":
for sm in archs: for sm in archs:
if sm == "89": if sm == "89":
fuse_gemm_configs = get_dual_gemm_candidate_configs( fuse_gemm_configs = get_dual_gemm_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages)
sm, min_split_k, max_split_k, min_stages, max_stages)
for fuse_gemm_config in fuse_gemm_configs: for fuse_gemm_config in fuse_gemm_configs:
file_name = ( file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/" f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"

View File

@@ -19,8 +19,7 @@ import re
def get_candidate_tiles(): def get_candidate_tiles():
""" """ """
"""
cta_shape = [ cta_shape = [
("<_64, _16, _128>"), ("<_64, _16, _128>"),
("<_64, _32, _128>"), ("<_64, _32, _128>"),
@@ -45,8 +44,7 @@ def get_candidate_tiles():
def get_dual_gemm_candidate_configs(sm): def get_dual_gemm_candidate_configs(sm):
""" """ """
"""
tiles = get_candidate_tiles() tiles = get_candidate_tiles()
candidate_configs = list() candidate_configs = list()
@@ -64,35 +62,27 @@ def get_dual_gemm_candidate_configs(sm):
("swiglu", "SiLu"), ("swiglu", "SiLu"),
("geglu", "GELU"), ("geglu", "GELU"),
]: ]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, EpilogueSchedule)])
EpilogueSchedule)])
return candidate_configs return candidate_configs
def get_shape_str(tile_shape): def get_shape_str(tile_shape):
""" """ """
""" blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
blocks = [elem.strip("_") for elem in blocks] blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters] clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters return blocks, clusters
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule): def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
""" """ """
"""
blocks, clusters = get_shape_str(tile_shape) blocks, clusters = get_shape_str(tile_shape)
if int( if int(blocks[0]) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
blocks[0]
) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
return False return False
if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule: if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
return False return False
if tile_shape[ if tile_shape[0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
return False return False
return True return True
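
The validity filter above prunes schedule/tile combinations before any source file is emitted. A self-contained sketch of the SM90 dual-GEMM rules as they appear in this hunk:

```python
def config_is_valid(tile_shape, kernel_schedule, epilogue_schedule):
    blocks = [e.strip("_") for e in tile_shape[0].replace(" ", "").strip("<>").split(",")]
    # Small CTA tiles cannot use the cooperative FP8 fast-accum schedule.
    if int(blocks[0]) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
        return False
    # Kernel and epilogue schedules must agree on being cooperative.
    if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
        return False
    # This particular tile is excluded for the ping-pong fast-accum schedule.
    if tile_shape[0] == "<_128, _128, _128>" and kernel_schedule == "KernelTmaWarpSpecializedPingpongFP8FastAccum":
        return False
    return True
```
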
@@ -302,8 +292,7 @@ bool fp8_fp8_dual_gemm_scale_bias_act(DualGemmEpilogueAllParams params) {
def SubstituteTemplate(template, values): def SubstituteTemplate(template, values):
""" """ """
"""
text = template text = template
changed = True changed = True
while changed: while changed:
@@ -318,10 +307,8 @@ def SubstituteTemplate(template, values):
def parse_args(): def parse_args():
""" """ """
""" parser = argparse.ArgumentParser(description="auto generate the fp8_fp8_dual_gemm_fused_kernels_sm90.")
parser = argparse.ArgumentParser(
description="auto generate the fp8_fp8_dual_gemm_fused_kernels_sm90.")
parser.add_argument( parser.add_argument(
"--cuda_arch", "--cuda_arch",
type=str, type=str,
@@ -336,17 +323,16 @@ def parse_args():
# generate source .cu # generate source .cu
def generate_dual_gemm_source_cu( def generate_dual_gemm_source_cu(
inputs_type: (str), inputs_type: str,
biases_type: (str), biases_type: str,
hasbiases: (str), hasbiases: str,
act_tag: (str), act_tag: str,
tiles: (str), tiles: str,
KernelSchedule: (str), KernelSchedule: str,
EpilogueSchedule: (str), EpilogueSchedule: str,
sm: str, sm: str,
): ):
""" """ """
"""
all_code = CommonHead all_code = CommonHead
for input_type in inputs_type: for input_type in inputs_type:
for bias_type in biases_type: for bias_type in biases_type:
@@ -354,9 +340,7 @@ def generate_dual_gemm_source_cu(
for tile_config in tiles: for tile_config in tiles:
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
kernel_schedule,
epilogue_schedule):
continue continue
value_dict = { value_dict = {
"input_type": input_type, "input_type": input_type,
@@ -370,28 +354,29 @@ def generate_dual_gemm_source_cu(
"SM": sm, "SM": sm,
"sm": sm[-2:], "sm": sm[-2:],
} }
all_code += SubstituteTemplate( all_code += SubstituteTemplate(GemmDeclare, value_dict)
GemmDeclare, value_dict)
return all_code return all_code
# generate gemm launch .cu # generate gemm launch .cu
def generate_launch_dual_gemm_cus( def generate_launch_dual_gemm_cus(
generate_dir: (str), inputs_type: (str), biases_type: (str), generate_dir: str,
fuse_gemm_configs: tuple, sm: str): inputs_type: str,
""" biases_type: str,
""" fuse_gemm_configs: tuple,
sm: str,
):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs] act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0] single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0] hasbiases: str = single_config[0]
tiles: (str) = single_config[2] tiles: str = single_config[2]
KernelSchedule: (str) = single_config[3] KernelSchedule: str = single_config[3]
EpilogueSchedule: (str) = single_config[4] EpilogueSchedule: str = single_config[4]
code_map = {} code_map = {}
head_path = os.path.join(generate_dir, head_path = os.path.join(generate_dir, f"launch_dual_gemm_kernel_sm{sm[-2:]}.h")
f"launch_dual_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead head_all_code = LaunchGemmHead
for tile_config in tiles: for tile_config in tiles:
blocks, clusters = get_shape_str(tile_config) blocks, clusters = get_shape_str(tile_config)
@@ -401,16 +386,14 @@ def generate_launch_dual_gemm_cus(
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}" gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, kernel_schedule, if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = { value_dict = {
"sm": sm[-2:], "sm": sm[-2:],
"gemm_config": gemm_config_str, "gemm_config": gemm_config_str,
} }
head_all_code += SubstituteTemplate(LaunchGemmDeclare, head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
value_dict)
os.makedirs(generate_dir, exist_ok=True) os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f: with open(head_path, "w") as f:
f.write(head_all_code) f.write(head_all_code)
@@ -422,16 +405,14 @@ def generate_launch_dual_gemm_cus(
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}" gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = { value_dict = {
"sm": sm[-2:], "sm": sm[-2:],
"gemm_config": gemm_config_str, "gemm_config": gemm_config_str,
} }
source_all_code = SubstituteTemplate(LaunchGemmPart0, source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
value_dict)
type_id = 0 type_id = 0
for input_type in inputs_type: for input_type in inputs_type:
for bias_type in biases_type: for bias_type in biases_type:
@@ -450,14 +431,13 @@ def generate_launch_dual_gemm_cus(
"SM": sm, "SM": sm,
"sm": sm[-2:], "sm": sm[-2:],
} }
source_all_code += SubstituteTemplate( source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
LaunchGemmPart1, value_dict)
type_id += 1 type_id += 1
source_all_code += LaunchGemmPart2 source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code code_map[gemm_config_str] = source_all_code
source_path = os.path.join( source_path = os.path.join(
generate_dir, generate_dir,
f"launch_dual_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu" f"launch_dual_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
) )
with open(source_path, "w") as f: with open(source_path, "w") as f:
f.write(source_all_code) f.write(source_all_code)
@@ -467,16 +447,14 @@ def generate_launch_dual_gemm_cus(
# generate fp8_fp8_gemm_scale_bias_act.cu # generate fp8_fp8_gemm_scale_bias_act.cu
def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str), def generate_dispatch_dual_gemm_cu(inputs_type: str, biases_type: str, fuse_gemm_configs: tuple, sm: str):
fuse_gemm_configs: tuple, sm: str): """ """
"""
"""
act_tags = [single_config[1] for single_config in fuse_gemm_configs] act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0] single_config = fuse_gemm_configs[0]
hasbiases: (str) = single_config[0] hasbiases: str = single_config[0]
tiles: (str) = single_config[2] tiles: str = single_config[2]
KernelSchedule: (str) = single_config[3] KernelSchedule: str = single_config[3]
EpilogueSchedule: (str) = single_config[4] EpilogueSchedule: str = single_config[4]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]}) all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0 type_id = 0
@@ -500,8 +478,7 @@ def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
for tile_shape in tiles: for tile_shape in tiles:
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
value_dict = { value_dict = {
"TileShape": tile_shape[0], "TileShape": tile_shape[0],
@@ -520,8 +497,7 @@ def generate_dispatch_dual_gemm_cu(inputs_type: (str), biases_type: (str),
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}" gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = { value_dict = {
@@ -570,12 +546,15 @@ if __name__ == "__main__":
f.close() f.close()
# Compile parallelization # Compile parallelization
generate_launch_dual_gemm_cus( generate_launch_dual_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type, "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
biases_type, fuse_gemm_configs, sm_dict[sm]) inputs_type,
biases_type,
fuse_gemm_configs,
sm_dict[sm],
)
# hard code for act_tag # hard code for act_tag
file_name = ( file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/" f"gpu_ops/cutlass_kernels/fp8_gemm_fused/" f"autogen/fp8_fp8_dual_gemm_scale_bias_act_sm{sm}.cu"
f"autogen/fp8_fp8_dual_gemm_scale_bias_act_sm{sm}.cu"
) )
all_code = generate_dispatch_dual_gemm_cu( all_code = generate_dispatch_dual_gemm_cu(
inputs_type, inputs_type,

View File

@@ -31,25 +31,26 @@ def get_candidate_tiles():
""" """
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")] base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend([ base_configs.extend(
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"), [
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"), ("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"), ("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"), ("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"), ("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"), ("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"), ("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"), ("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"), ("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"), ("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"), ("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
]) ("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
]
)
return base_configs return base_configs
def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, def get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages):
max_stages):
""" """
Get the list of candidate gemm operator configurations. Get the list of candidate gemm operator configurations.
@@ -353,8 +354,7 @@ def parse_args():
Parse the command-line arguments. Parse the command-line arguments.
""" """
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description= description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
"The argument for generating the generic_mixed_gemm_kernelLauncher instance."
) )
parser.add_argument( parser.add_argument(
"--cuda_arch", "--cuda_arch",
@@ -448,8 +448,7 @@ def generate_source_cu(
"hasbias": hasbias, "hasbias": hasbias,
"SM": sm, "SM": sm,
} }
all_code += SubstituteTemplate(GemmSplitKDeclare, all_code += SubstituteTemplate(GemmSplitKDeclare, value_dict)
value_dict)
all_code += CommonTail all_code += CommonTail
return all_code return all_code
@@ -473,9 +472,7 @@ def generate_launch_gemm_cus(
head_path = os.path.join(generate_dir, "launch_gemm_kernel.h") head_path = os.path.join(generate_dir, "launch_gemm_kernel.h")
head_all_code = LaunchGemmHead head_all_code = LaunchGemmHead
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
@@ -489,9 +486,7 @@ def generate_launch_gemm_cus(
f.close() f.close()
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
@@ -517,17 +512,14 @@ def generate_launch_gemm_cus(
"num_stages": str(stage), "num_stages": str(stage),
"SM": sm, "SM": sm,
} }
source_all_code += SubstituteTemplate( source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
LaunchGemmPart1, value_dict) split_k_code += SubstituteTemplate(LaunchGemmPart3, value_dict)
split_k_code += SubstituteTemplate(
LaunchGemmPart3, value_dict)
type_id += 1 type_id += 1
source_all_code += LaunchGemmPart2 source_all_code += LaunchGemmPart2
source_all_code += split_k_code source_all_code += split_k_code
source_all_code += LaunchGemmPart4 source_all_code += LaunchGemmPart4
code_map[gemm_config_str] = source_all_code code_map[gemm_config_str] = source_all_code
source_path = os.path.join( source_path = os.path.join(generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
generate_dir, f"launch_gemm_kernel_{gemm_config_str}.cu")
with open(source_path, "w") as f: with open(source_path, "w") as f:
f.write(source_all_code) f.write(source_all_code)
f.close() f.close()
@@ -581,9 +573,7 @@ def generate_dispatch_gemm_cu(
all_code += code_part4 all_code += code_part4
tile_id = 0 tile_id = 0
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
@@ -593,10 +583,12 @@ def generate_dispatch_gemm_cu(
} }
all_code += SubstituteTemplate(code_part5, value_dict) all_code += SubstituteTemplate(code_part5, value_dict)
tile_id += 1 tile_id += 1
value_dict.update({ value_dict.update(
"min_split_k": str(min_split_k), {
"max_split_k": str(max_split_k), "min_split_k": str(min_split_k),
}) "max_split_k": str(max_split_k),
}
)
all_code += SubstituteTemplate(code_part6, value_dict) all_code += SubstituteTemplate(code_part6, value_dict)
return all_code return all_code
@@ -614,9 +606,7 @@ if __name__ == "__main__":
for sm in archs: for sm in archs:
if sm == "89": if sm == "89":
fuse_gemm_configs = get_candidate_configs(sm, min_split_k, fuse_gemm_configs = get_candidate_configs(sm, min_split_k, max_split_k, min_stages, max_stages)
max_split_k, min_stages,
max_stages)
for fuse_gemm_config in fuse_gemm_configs: for fuse_gemm_config in fuse_gemm_configs:
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu" file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[3][0]}.cu"
all_code = generate_source_cu( all_code = generate_source_cu(
@@ -654,9 +644,7 @@ if __name__ == "__main__":
# hard code for act_tag # hard code for act_tag
file_name = ( file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu"
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act.cu"
)
all_code = generate_dispatch_gemm_cu( all_code = generate_dispatch_gemm_cu(
inputs_type, inputs_type,
outputs_type, outputs_type,

View File

@@ -20,44 +20,44 @@ import re
def get_candidate_tiles(): def get_candidate_tiles():
""" """ """
"""
base_configs = [ base_configs = [
("<_64, _64, _128>", "<_1, _8, _1>"), ("<_64, _64, _128>", "<_1, _8, _1>"),
("<_64, _128, _128>", "<_2, _1, _1>"), ("<_64, _128, _128>", "<_2, _1, _1>"),
("<_128, _128, _128>", "<_2, _1, _1>"), ("<_128, _128, _128>", "<_2, _1, _1>"),
] ]
base_configs.extend([ base_configs.extend(
("<_64, _64, _128>", "<_1, _1, _1>"), [
("<_64, _64, _128>", "<_1, _2, _1>"), ("<_64, _64, _128>", "<_1, _1, _1>"),
("<_64, _64, _128>", "<_2, _1, _1>"), ("<_64, _64, _128>", "<_1, _2, _1>"),
("<_64, _64, _64>", "<_1, _1, _1>"), ("<_64, _64, _128>", "<_2, _1, _1>"),
("<_64, _64, _64>", "<_1, _2, _1>"), ("<_64, _64, _64>", "<_1, _1, _1>"),
("<_64, _64, _64>", "<_2, _1, _1>"), ("<_64, _64, _64>", "<_1, _2, _1>"),
("<_64, _128, _128>", "<_1, _2, _1>"), ("<_64, _64, _64>", "<_2, _1, _1>"),
("<_64, _128, _128>", "<_1, _1, _1>"), ("<_64, _128, _128>", "<_1, _2, _1>"),
("<_128, _128, _64>", "<_2, _1, _1>"), ("<_64, _128, _128>", "<_1, _1, _1>"),
("<_256, _128, _128>", "<_1, _2, _1>"), ("<_128, _128, _64>", "<_2, _1, _1>"),
("<_256, _128, _128>", "<_1, _1, _1>"), ("<_256, _128, _128>", "<_1, _2, _1>"),
# The following configurations are rarely selected in Qwen2-7B-model. ("<_256, _128, _128>", "<_1, _1, _1>"),
# ("<_256, _128, _128>", "<_4, _1, _1>"), # The following configurations are rarely selected in Qwen2-7B-model.
# ("<_256, _128, _128>", "<_1, _4, _1>"), # ("<_256, _128, _128>", "<_4, _1, _1>"),
# ("<_256, _128, _128>", "<_2, _4, _1>"), # ("<_256, _128, _128>", "<_1, _4, _1>"),
# ("<_128, _128, _256>", "<_1, _2, _1>"), # ("<_256, _128, _128>", "<_2, _4, _1>"),
# ("<_128, _128, _128>", "<_4, _1, _1>"), # ("<_128, _128, _256>", "<_1, _2, _1>"),
# ("<_128, _128, _128>", "<_2, _4, _1>"), # ("<_128, _128, _128>", "<_4, _1, _1>"),
# ("<_128, _128, _128>", "<_1, _2, _1>"), # ("<_128, _128, _128>", "<_2, _4, _1>"),
# ("<_128, _128, _128>", "<_1, _1, _1>"), # ("<_128, _128, _128>", "<_1, _2, _1>"),
# ("<_128, _128, _128>", "<_1, _4, _1>"), # ("<_128, _128, _128>", "<_1, _1, _1>"),
# ("<_128, _128, _64>", "<_2, _2, _1>"), # ("<_128, _128, _128>", "<_1, _4, _1>"),
]) # ("<_128, _128, _64>", "<_2, _2, _1>"),
]
)
return base_configs return base_configs
def get_candidate_configs(sm): def get_candidate_configs(sm):
""" """ """
"""
tiles = get_candidate_tiles() tiles = get_candidate_tiles()
candidate_configs = list() candidate_configs = list()
@@ -73,36 +73,31 @@ def get_candidate_configs(sm):
("relu", "ReLu"), ("relu", "ReLu"),
("gelu", "GELU"), ("gelu", "GELU"),
]: ]:
candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, candidate_configs.extend([(hasbias, act_tag, tiles, KernelSchedule, EpilogueSchedule)])
EpilogueSchedule)])
return candidate_configs return candidate_configs
def get_shape_str(tile_shape): def get_shape_str(tile_shape):
""" """ """
""" blocks, clusters = [s.replace(" ", "").strip("<>").split(",") for s in tile_shape]
blocks, clusters = [
s.replace(" ", "").strip("<>").split(",") for s in tile_shape
]
blocks = [elem.strip("_") for elem in blocks] blocks = [elem.strip("_") for elem in blocks]
clusters = [elem.strip("_") for elem in clusters] clusters = [elem.strip("_") for elem in clusters]
return blocks, clusters return blocks, clusters
def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule): def check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
""" """ """
"""
blocks, clusters = get_shape_str(tile_shape) blocks, clusters = get_shape_str(tile_shape)
if int( if int(blocks[0]) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
blocks[0]
) < 128 and kernel_schedule == "KernelTmaWarpSpecializedCooperativeFP8FastAccum":
return False return False
if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule: if "Cooperative" in kernel_schedule and "Cooperative" not in epilogue_schedule:
return False return False
if (tile_shape[0] == "<_256, _128, _128>" if (
and "Cooperative" not in kernel_schedule tile_shape[0] == "<_256, _128, _128>"
and "Cooperative" not in epilogue_schedule): and "Cooperative" not in kernel_schedule
and "Cooperative" not in epilogue_schedule
):
return False return False
return True return True
@@ -321,16 +316,12 @@ bool fp8_fp8_gemm_scale_bias_act(GemmEpilogueAllParams params) {
def SubstituteTemplate(template, values_base): def SubstituteTemplate(template, values_base):
""" """ """
"""
values = copy.deepcopy(values_base) values = copy.deepcopy(values_base)
if values.get("KernelSchedule" if values.get("KernelSchedule") is not None and "Auto" in values["KernelSchedule"]:
) is not None and "Auto" in values["KernelSchedule"]:
values["KernelSchedule"] = "collective::" + values["KernelSchedule"] values["KernelSchedule"] = "collective::" + values["KernelSchedule"]
if values.get("EpilogueSchedule" if values.get("EpilogueSchedule") is not None and "Auto" in values["EpilogueSchedule"]:
) is not None and "Auto" in values["EpilogueSchedule"]: values["EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
values[
"EpilogueSchedule"] = "collective::" + values["EpilogueSchedule"]
text = template text = template
changed = True changed = True
while changed: while changed:
@@ -345,10 +336,8 @@ def SubstituteTemplate(template, values_base):
def parse_args(): def parse_args():
""" """ """
""" parser = argparse.ArgumentParser(description="auto generate fp8_fp8_gemm_fused_kernels_sm90.")
parser = argparse.ArgumentParser(
description="auto generate fp8_fp8_gemm_fused_kernels_sm90.")
parser.add_argument( parser.add_argument(
"--cuda_arch", "--cuda_arch",
type=str, type=str,
@@ -363,17 +352,16 @@ def parse_args():
# generate source .cu # generate source .cu
def generate_source_cu( def generate_source_cu(
inputs_type: (str), inputs_type: str,
outputs_type: (str), outputs_type: str,
hasbiases: (str), hasbiases: str,
act_tag: (str), act_tag: str,
tiles: (str), tiles: str,
KernelSchedule: (str), KernelSchedule: str,
EpilogueSchedule: (str), EpilogueSchedule: str,
sm: str, sm: str,
): ):
""" """ """
"""
all_code = CommonHead all_code = CommonHead
for input_type in inputs_type: for input_type in inputs_type:
@@ -382,9 +370,7 @@ def generate_source_cu(
for tile_config in tiles: for tile_config in tiles:
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
kernel_schedule,
epilogue_schedule):
continue continue
value_dict = { value_dict = {
"input_type": input_type, "input_type": input_type,
@@ -398,25 +384,27 @@ def generate_source_cu(
"SM": sm, "SM": sm,
"sm": sm[-2:], "sm": sm[-2:],
} }
all_code += SubstituteTemplate( all_code += SubstituteTemplate(GemmDeclare, value_dict)
GemmDeclare, value_dict)
return all_code return all_code
# generate gemm launch .cu
def generate_launch_gemm_cus(
generate_dir: str,
inputs_type: str,
outputs_type: str,
fuse_gemm_configs: tuple,
sm: str,
):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
code_map = {} code_map = {}
head_path = os.path.join(generate_dir, f"launch_gemm_kernel_sm{sm[-2:]}.h") head_path = os.path.join(generate_dir, f"launch_gemm_kernel_sm{sm[-2:]}.h")
head_all_code = LaunchGemmHead head_all_code = LaunchGemmHead
@@ -426,16 +414,14 @@ def generate_launch_gemm_cus(
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}" gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_config, kernel_schedule, if not check_config_valid(tile_config, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = { value_dict = {
"sm": sm[-2:], "sm": sm[-2:],
"gemm_config": gemm_config_str, "gemm_config": gemm_config_str,
} }
head_all_code += SubstituteTemplate(LaunchGemmDeclare, head_all_code += SubstituteTemplate(LaunchGemmDeclare, value_dict)
value_dict)
os.makedirs(generate_dir, exist_ok=True) os.makedirs(generate_dir, exist_ok=True)
with open(head_path, "w") as f: with open(head_path, "w") as f:
f.write(head_all_code) f.write(head_all_code)
@@ -447,16 +433,14 @@ def generate_launch_gemm_cus(
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}" gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = { value_dict = {
"sm": sm[-2:], "sm": sm[-2:],
"gemm_config": gemm_config_str, "gemm_config": gemm_config_str,
} }
source_all_code = SubstituteTemplate(LaunchGemmPart0, source_all_code = SubstituteTemplate(LaunchGemmPart0, value_dict)
value_dict)
type_id = 0 type_id = 0
for input_type in inputs_type: for input_type in inputs_type:
for output_type in outputs_type: for output_type in outputs_type:
@@ -475,14 +459,14 @@ def generate_launch_gemm_cus(
"SM": sm, "SM": sm,
"sm": sm[-2:], "sm": sm[-2:],
} }
source_all_code += SubstituteTemplate( source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
LaunchGemmPart1, value_dict)
type_id += 1 type_id += 1
source_all_code += LaunchGemmPart2 source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code code_map[gemm_config_str] = source_all_code
source_path = os.path.join( source_path = os.path.join(
generate_dir, generate_dir,
f"launch_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu") f"launch_gemm_kernel_sm{sm[-2:]}_{gemm_config_str}.cu",
)
with open(source_path, "w") as f: with open(source_path, "w") as f:
f.write(source_all_code) f.write(source_all_code)
f.close() f.close()
@@ -491,17 +475,15 @@ def generate_launch_gemm_cus(
# generate fp8_fp8_gemm_scale_bias_act_sm90.cu
def generate_dispatch_gemm_cu(inputs_type: str, outputs_type: str, fuse_gemm_configs: tuple, sm: str):
""" """
act_tags = [single_config[1] for single_config in fuse_gemm_configs]
single_config = fuse_gemm_configs[0]
hasbiases: str = single_config[0]
tiles: str = single_config[2]
KernelSchedule: str = single_config[3]
EpilogueSchedule: str = single_config[4]
all_code = SubstituteTemplate(code_part0, {"sm": sm[-2:]})
type_id = 0
@@ -524,8 +506,7 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for tile_shape in tiles: for tile_shape in tiles:
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
value_dict = { value_dict = {
"TileShape": tile_shape[0], "TileShape": tile_shape[0],
@@ -544,8 +525,7 @@ def generate_dispatch_gemm_cu(inputs_type: (str), outputs_type: (str),
for kernel_schedule in KernelSchedule: for kernel_schedule in KernelSchedule:
gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}" gemm_config_str_1 = gemm_config_str_0 + f"_{kernel_schedule}"
for epilogue_schedule in EpilogueSchedule: for epilogue_schedule in EpilogueSchedule:
if not check_config_valid(tile_shape, kernel_schedule, if not check_config_valid(tile_shape, kernel_schedule, epilogue_schedule):
epilogue_schedule):
continue continue
gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}" gemm_config_str = gemm_config_str_1 + f"_{epilogue_schedule}"
value_dict = { value_dict = {
@@ -576,7 +556,8 @@ if __name__ == "__main__":
for fuse_gemm_config in fuse_gemm_configs: for fuse_gemm_config in fuse_gemm_configs:
file_name = ( file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/" f"gpu_ops/cutlass_kernels/fp8_gemm_fused/"
f"autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu") f"autogen/generic_gemm_kernel_sm{sm}_{fuse_gemm_config[1][0]}.cu"
)
all_code = generate_source_cu( all_code = generate_source_cu(
inputs_type, inputs_type,
outputs_type, outputs_type,
@@ -594,8 +575,12 @@ if __name__ == "__main__":
f.close() f.close()
# Compile parallelization # Compile parallelization
generate_launch_gemm_cus( generate_launch_gemm_cus(
"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen", inputs_type, "gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen",
outputs_type, fuse_gemm_configs, sm_dict[sm]) inputs_type,
outputs_type,
fuse_gemm_configs,
sm_dict[sm],
)
# hard code for act_tag # hard code for act_tag
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act_sm{sm}.cu" file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/fp8_fp8_gemm_scale_bias_act_sm{sm}.cu"

View File

@@ -30,22 +30,24 @@ def get_candidate_tiles():
"""
base_configs = [("<64, 64, 64>", "<32, 32, 64>", "<16, 8, 32>")]
base_configs.extend(
[
("<16, 32, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<16, 64, 64>", "<16, 32, 64>", "<16, 8, 32>"),
("<32, 128, 64>", "<32, 32, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 64, 128>", "<32, 64, 64>", "<16, 8, 32>"),
("<64, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 64, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 32, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 128, 64>", "<128, 32, 64>", "<16, 8, 32>"),
("<128, 256, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<256, 128, 64>", "<64, 64, 64>", "<16, 8, 32>"),
("<128, 64, 128>", "<64, 32, 128>", "<16, 8, 32>"),
("<16, 256, 128>", "<16, 64, 128>", "<16, 8, 32>"),
]
)
return base_configs
@@ -278,8 +280,7 @@ def parse_args():
代码参数解析 代码参数解析
""" """
parser = argparse.ArgumentParser(
description="The argument for generating the generic_mixed_gemm_kernelLauncher instance."
)
parser.add_argument( parser.add_argument(
"--cuda_arch", "--cuda_arch",
@@ -370,13 +371,10 @@ def generate_launch_gemm_cus(
- dict (code_map) - 包含每个Gemm配置对应的源代码的字典格式为{"gemm_config": source_code}。 - dict (code_map) - 包含每个Gemm配置对应的源代码的字典格式为{"gemm_config": source_code}。
""" """
code_map = {} code_map = {}
head_path = os.path.join(generate_dir, head_path = os.path.join(generate_dir, "launch_visitor_gemm_fused_kernel.h")
"launch_visitor_gemm_fused_kernel.h")
head_all_code = LaunchGemmHead head_all_code = LaunchGemmHead
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
@@ -390,9 +388,7 @@ def generate_launch_gemm_cus(
f.close() f.close()
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
@@ -415,14 +411,14 @@ def generate_launch_gemm_cus(
"num_stages": str(stage), "num_stages": str(stage),
"SM": sm, "SM": sm,
} }
source_all_code += SubstituteTemplate( source_all_code += SubstituteTemplate(LaunchGemmPart1, value_dict)
LaunchGemmPart1, value_dict)
type_id += 1 type_id += 1
source_all_code += LaunchGemmPart2 source_all_code += LaunchGemmPart2
code_map[gemm_config_str] = source_all_code code_map[gemm_config_str] = source_all_code
source_path = os.path.join( source_path = os.path.join(
generate_dir, generate_dir,
f"launch_visitor_gemm_fused_kernel_{gemm_config_str}.cu") f"launch_visitor_gemm_fused_kernel_{gemm_config_str}.cu",
)
with open(source_path, "w") as f: with open(source_path, "w") as f:
f.write(source_all_code) f.write(source_all_code)
f.close() f.close()
@@ -485,9 +481,7 @@ def generate_dispatch_gemm_cu(
all_code += code_part4 all_code += code_part4
tile_id = 0 tile_id = 0
for tile in tiles: for tile in tiles:
blocks, warps, mmas = [ blocks, warps, mmas = [s.replace(" ", "").strip("<>").split(",") for s in tile]
s.replace(" ", "").strip("<>").split(",") for s in tile
]
gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}" gemm_config = f"block{blocks[0]}x{blocks[1]}x{blocks[2]}_warp{warps[0]}x{warps[1]}x{warps[2]}_mma{mmas[0]}x{mmas[1]}x{mmas[2]}"
for stage in stages: for stage in stages:
gemm_config_str = gemm_config + f"_stage{stage}" gemm_config_str = gemm_config + f"_stage{stage}"
@@ -512,10 +506,11 @@ if __name__ == "__main__":
for sm in archs: for sm in archs:
if sm == "89": if sm == "89":
fuse_gemm_configs = get_candidate_configs(sm, min_stages, fuse_gemm_configs = get_candidate_configs(sm, min_stages, max_stages)
max_stages)
for fuse_gemm_config in fuse_gemm_configs: for fuse_gemm_config in fuse_gemm_configs:
file_name = f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_visitor_gemm_fused_kernel_sm{sm}.cu" file_name = (
f"gpu_ops/cutlass_kernels/fp8_gemm_fused/autogen/generic_visitor_gemm_fused_kernel_sm{sm}.cu"
)
all_code = generate_source_cu( all_code = generate_source_cu(
inputs_type, inputs_type,
outputs_type, outputs_type,
@@ -544,9 +539,7 @@ if __name__ == "__main__":
sm_dict[sm], sm_dict[sm],
) )
file_name = ( file_name = "gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu"
"gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused.cu"
)
all_code = generate_dispatch_gemm_cu( all_code = generate_dispatch_gemm_cu(
inputs_type, inputs_type,
outputs_type, outputs_type,

View File

@@ -113,7 +113,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
vsl.kv_lod_vp = { vsl.kv_lod_vp = {
const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()), const_cast<int32_t *>(encoder_seq_lod_cpu.data<int32_t>()),
enc_batch + 1, nullptr}; enc_batch + 1, nullptr};
baidu::xpu::api::VectorParam<int32_t> prefix_lens_vp{ baidu::xpu::api::VectorParam<int32_t> prefix_lens_vp{
nullptr, nullptr,
0, 0,

View File

@@ -30,8 +30,7 @@ current_file = Path(__file__).resolve()
base_dir = current_file.parent base_dir = current_file.parent
def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, XDNN_LIB_DIR):
""" """
build xpu plugin build xpu plugin
""" """
@@ -49,7 +48,10 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
# 删除指定目录 # 删除指定目录
dirs_to_remove = [ dirs_to_remove = [
"dist", "fastdeploy_ops.egg-info", "build", "plugin/build" "dist",
"fastdeploy_ops.egg-info",
"build",
"plugin/build",
] ]
for dir_name in dirs_to_remove: for dir_name in dirs_to_remove:
if os.path.exists(dir_name): if os.path.exists(dir_name):
@@ -58,8 +60,7 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
# 在 plugin 目录中执行构建脚本 # 在 plugin 目录中执行构建脚本
plugin_dir = "plugin" plugin_dir = "plugin"
build_script = os.path.join(current_working_directory, plugin_dir, "build.sh")
print("build_script: ", build_script) print("build_script: ", build_script)
@@ -74,14 +75,16 @@ def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR,
# 执行构建脚本 # 执行构建脚本
try: try:
print("Running build script...") print("Running build script...")
subprocess.run(
[build_script],
check=True,
cwd=os.path.join(current_working_directory, plugin_dir),
)
print("Build completed successfully.") print("Build completed successfully.")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print(f"Build failed with error: {e}") print(f"Build failed with error: {e}")
except Exception as e: except Exception as e:
print(f"Unexpected error: {e!s}")
def xpu_setup_ops(): def xpu_setup_ops():
@@ -124,17 +127,14 @@ def xpu_setup_ops():
XVLLM_PATH = os.getenv("XVLLM_PATH") XVLLM_PATH = os.getenv("XVLLM_PATH")
assert XVLLM_PATH is not None, "XVLLM_PATH is not set." assert XVLLM_PATH is not None, "XVLLM_PATH is not set."
XVLLM_KERNEL_INC_PATH = os.path.join(XVLLM_PATH, "infer_ops", "include") XVLLM_KERNEL_INC_PATH = os.path.join(XVLLM_PATH, "infer_ops", "include")
XVLLM_KERNEL_LIB_PATH = os.path.join(XVLLM_PATH, "infer_ops", "so", "libapiinfer.so")
XVLLM_KERNEL_LIB_DIR = os.path.join(XVLLM_PATH, "infer_ops", "so") XVLLM_KERNEL_LIB_DIR = os.path.join(XVLLM_PATH, "infer_ops", "so")
XVLLM_OP_INC_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "include") XVLLM_OP_INC_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "include")
XVLLM_OP_LIB_PATH = os.path.join(XVLLM_PATH, "xft_blocks", "so", "libxft_blocks.so")
XVLLM_OP_LIB_DIR = os.path.join(XVLLM_PATH, "xft_blocks", "so") XVLLM_OP_LIB_DIR = os.path.join(XVLLM_PATH, "xft_blocks", "so")
# build plugin # build plugin
build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH, XDNN_LIB_DIR)
ops = [ ops = [
# custom ops # custom ops
@@ -152,7 +152,6 @@ def xpu_setup_ops():
"./ops/block_attn.cc", "./ops/block_attn.cc",
"./ops/moe_layer.cc", "./ops/moe_layer.cc",
"./ops/weight_quantize_xpu.cc", "./ops/weight_quantize_xpu.cc",
# device manage ops # device manage ops
"./ops/device/get_context_gm_max_mem_demand.cc", "./ops/device/get_context_gm_max_mem_demand.cc",
"./ops/device/get_free_global_memory.cc", "./ops/device/get_free_global_memory.cc",

View File

@@ -29,7 +29,7 @@ for i in range(bs):
ids_len = seq_lens[i, 0] ids_len = seq_lens[i, 0]
input_ids[i, 0:ids_len] = np.random.randint(1, 10, seq_lens[i, 0], "int64") input_ids[i, 0:ids_len] = np.random.randint(1, 10, seq_lens[i, 0], "int64")
(x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k,) = get_padding_offset(
paddle.to_tensor(input_ids), paddle.to_tensor(input_ids),
paddle.to_tensor(cum_offset), paddle.to_tensor(cum_offset),
paddle.to_tensor(token_num), paddle.to_tensor(token_num),
@@ -46,19 +46,14 @@ print("padding_offset:\n", padding_offset)
print("cu_seqlens_q:\n", cu_seqlens_q) print("cu_seqlens_q:\n", cu_seqlens_q)
print("cu_seqlens_k:\n", cu_seqlens_k) print("cu_seqlens_k:\n", cu_seqlens_k)
ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")
ref_cum_offsets_out = np.array([0, 6, 13], "int32")
ref_padding_offset = np.array([0, 0, 0, 0, 6, 6, 6, 13, 13, 13, 13, 13, 13], "int32")
ref_cu_seqlens_q = np.array([0, 4, 7, 13], "int32")
ref_cu_seqlens_k = np.array([0, 4, 7, 13], "int32")
assert sum(ref_x_remove_padding - x_remove_padding) == 0, "Check x_remove_padding failed."
assert sum(ref_cum_offsets_out - cum_offsets_out) == 0, "Check cum_offsets_out failed."
assert sum(ref_padding_offset - padding_offset) == 0, "Check padding_offset failed."
assert sum(ref_cu_seqlens_q - cu_seqlens_q) == 0, "Check cu_seqlens_q failed."
assert sum(ref_cu_seqlens_k - cu_seqlens_k) == 0, "Check cu_seqlens_k failed."
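# Sanity note on the reference values above (derived from them; the per-row lengths
# implied by ref_cu_seqlens_q are [4, 3, 6] with a padded length of 10):
# cu_seqlens_q/k accumulate the true lengths (0, 4, 7, 13), and cum_offsets_out holds
# the padding skipped before each row (0, 10 - 4 = 6, 6 + 10 - 3 = 13).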

View File

@@ -21,10 +21,15 @@ paddle.seed(2023)
pre_ids = paddle.to_tensor(
[[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]],
"int64",
)
logits = paddle.to_tensor(
[
[0.1, 0.9, 0.3, 0.4, 0.5, 0.6, 0.7, 0.1, 0.1, 0.1],
[0.1, 0.9, 0.7, 0.6, 0.5, 0.4, 0.1, 0.1, 0.1, 0.1],
],
"float32",
)
penalty_scores = paddle.to_tensor([1.0, 1.0], "float32") penalty_scores = paddle.to_tensor([1.0, 1.0], "float32")
frequency_scores = paddle.to_tensor([0.1, 0.1], "float32") frequency_scores = paddle.to_tensor([0.1, 0.1], "float32")
presence_scores = paddle.to_tensor([0.0, 0.0], "float32") presence_scores = paddle.to_tensor([0.0, 0.0], "float32")
@@ -88,78 +93,536 @@ ref_logits = np.array(
) )
diff_logits = np.sum(np.abs(ref_logits - logits.numpy()))
print("diff_logits\n", diff_logits)
assert diff_logits < 1e-6, "Check failed."
pre_ids = paddle.to_tensor( pre_ids = paddle.to_tensor(
[[ [
2, 3, 3, 5, 8, 9, 3, 9, 1, 8, 9, 2, 3, 8, 8, 9, 9, 1, 4, 2, 6, 2, 6, 8, [
7, 2, 2, 3, 8, 1, 5, 7, 9, 2, 2, 9, 1, 4, 9, 8, 5, 8, 5, 7, 3, 6, 4, 4, 2,
9, 9, 8, 5, 5, 2, 2, 9, 4, 8, 1, 9, 6, 9, 2, 2, 7, 2, 2, 9, 4, 6, 4, 6, 3,
1, 4, 1, 9, 1, 8, 8, 5, 7, 9, 4, 2, 5, 1, 1, 4, 1, 5, 5, 4, 4, 2, 1, 8, 3,
7, 1, 2, 9, 6, 7, 9, 6, 7, 7, 4, 9, 9, 7, 5, 1, 8, 9, 8, 8, 5, 4, 6, 4, 5,
7, 5, 5, 7, 6, 9, 3, 9 8,
], 9,
[ 3,
7, 8, 1, 3, 1, 7, 6, 3, 5, 3, 8, 3, 1, 9, 7, 1, 1, 9, 5, 4, 9, 6, 1, 9,
9, 3, 8, 3, 9, 9, 6, 4, 2, 8, 5, 3, 1, 6, 9, 1, 3, 9, 8, 1, 7, 5, 1, 1,
5, 1, 8, 7, 4, 5, 9, 8, 7, 4, 7, 3, 6, 4, 6, 6, 5, 5, 2, 9, 9, 5, 8, 8,
8, 4, 8, 2, 8, 1, 3, 9, 1, 8, 5, 8, 3, 8, 8, 2, 7, 3, 7, 5, 7, 2, 6, 9,
3, 5, 1, 4, 6, 1, 9, 8, 2, 2, 3, 6, 7, 6, 2, 6, 5, 1, 5, 6, 2, 1, 6, 2,
4, 7, 7, 3, 8, 5, 1, 9, 1, 2, 8, 6, 8 3,
]]) 8,
8,
9,
9,
1,
4,
2,
6,
2,
6,
8,
7,
2,
2,
3,
8,
1,
5,
7,
9,
2,
2,
9,
1,
4,
9,
8,
5,
8,
5,
7,
3,
6,
4,
4,
9,
9,
8,
5,
5,
2,
2,
9,
4,
8,
1,
9,
6,
9,
2,
2,
7,
2,
2,
9,
4,
6,
4,
6,
1,
4,
1,
9,
1,
8,
8,
5,
7,
9,
4,
2,
5,
1,
1,
4,
1,
5,
5,
4,
4,
2,
1,
8,
7,
1,
2,
9,
6,
7,
9,
6,
7,
7,
4,
9,
9,
7,
5,
1,
8,
9,
8,
8,
5,
4,
6,
4,
7,
5,
5,
7,
6,
9,
3,
9,
],
[
7,
8,
1,
3,
1,
7,
6,
3,
5,
3,
8,
3,
1,
9,
7,
1,
1,
9,
5,
4,
9,
6,
1,
9,
3,
8,
3,
9,
9,
6,
4,
2,
8,
5,
3,
1,
6,
9,
1,
3,
9,
8,
1,
7,
5,
1,
5,
1,
8,
7,
4,
5,
9,
8,
7,
4,
7,
3,
6,
4,
6,
6,
5,
5,
2,
9,
9,
5,
8,
8,
4,
8,
2,
8,
1,
3,
9,
1,
8,
5,
8,
3,
8,
8,
2,
7,
3,
7,
5,
7,
2,
6,
3,
5,
1,
4,
6,
1,
9,
8,
2,
2,
3,
6,
7,
6,
2,
6,
5,
1,
5,
6,
2,
1,
6,
4,
7,
7,
3,
8,
5,
1,
9,
1,
2,
8,
6,
8,
],
]
)
logits = paddle.to_tensor( logits = paddle.to_tensor(
[[ [
0.16274983, 0.61470598, 0.94366980, 0.82005417, 0.50752640, 0.38316748, [
0.92648441, 0.24050158, 0.05461595, 0.42218581, 0.36270225, 0.15464807, 0.16274983,
0.13614719, 0.67509544, 0.40315166, 0.10671722, 0.24832056, 0.76091218, 0.61470598,
0.11598995, 0.10962527, 0.04688513, 0.81536716, 0.72259802, 0.60476679, 0.94366980,
0.16701800, 0.84160781, 0.79649884, 0.78021604, 0.75329530, 0.98587888, 0.82005417,
0.13421868, 0.16027625, 0.15269397, 0.06228730, 0.73856270, 0.34721911, 0.50752640,
0.73683006, 0.78178608, 0.32068327, 0.79906309, 0.44214272, 0.63330448, 0.38316748,
0.08016958, 0.63367140, 0.19788943, 0.55346787, 0.11142531, 0.90518415, 0.92648441,
0.21236691, 0.81587470, 0.83752930, 0.70979482, 0.35684183, 0.28715104, 0.24050158,
0.87162822, 0.17679396, 0.98725849, 0.76129991, 0.04090235, 0.37181064, 0.05461595,
0.63317049, 0.24689502, 0.21126501, 0.57617670, 0.74346697, 0.40613672, 0.42218581,
0.56907010, 0.68556929, 0.29032683, 0.17866278, 0.35165095, 0.97015840, 0.36270225,
0.70785582, 0.54259878, 0.14712237, 0.90483177, 0.02094105, 0.36411613, 0.15464807,
0.02495066, 0.88874054, 0.88895452, 0.86216462, 0.58062190, 0.95583254, 0.13614719,
0.20553111, 0.29870346, 0.69652933, 0.36861244, 0.85316223, 0.50240189, 0.67509544,
0.17566244, 0.61080140, 0.88203174, 0.98675215, 0.24344546, 0.17213407, 0.40315166,
0.78160852, 0.25165486, 0.48188508, 0.82812423, 0.10199814, 0.90475923, 0.10671722,
0.66907483, 0.71910626, 0.40660757, 0.59460294, 0.70212913, 0.90841550, 0.24832056,
0.00329034, 0.11290466, 0.89654654, 0.69114941, 0.29473618, 0.62027222, 0.76091218,
0.37333879, 0.98911142, 0.46510187, 0.65914583, 0.73022646, 0.12790845, 0.11598995,
0.12817244, 0.43015456, 0.75011456, 0.43562204, 0.48086026, 0.75587070, 0.10962527,
0.98481447, 0.77367836 0.04688513,
], 0.81536716,
[ 0.72259802,
0.12336024, 0.74152875, 0.09191196, 0.99301219, 0.44764417, 0.60476679,
0.01848883, 0.78326035, 0.99228370, 0.81447607, 0.02627683, 0.16701800,
0.51033205, 0.98703283, 0.15247856, 0.77640921, 0.60799915, 0.84160781,
0.87518770, 0.76818430, 0.86542630, 0.31795895, 0.04829503, 0.79649884,
0.85567141, 0.30271924, 0.67515039, 0.59728831, 0.78710967, 0.78021604,
0.75111693, 0.56837374, 0.49085775, 0.91510201, 0.59545547, 0.75329530,
0.99482232, 0.59036905, 0.58267909, 0.28770933, 0.53237396, 0.98587888,
0.95318258, 0.93987304, 0.61142951, 0.26737869, 0.52285451, 0.13421868,
0.03479086, 0.61631846, 0.66777998, 0.15736090, 0.00447258, 0.16027625,
0.37035006, 0.15281211, 0.95372260, 0.25963321, 0.61036694, 0.15269397,
0.15020694, 0.19171195, 0.55252832, 0.00391038, 0.31052542, 0.06228730,
0.96495175, 0.42586124, 0.05630261, 0.99728668, 0.01856293, 0.73856270,
0.83201504, 0.10701843, 0.56434178, 0.38009524, 0.51095045, 0.34721911,
0.13202040, 0.07133843, 0.75313550, 0.17111187, 0.80716974, 0.73683006,
0.00172165, 0.83906764, 0.73240769, 0.85843354, 0.11042888, 0.78178608,
0.07912333, 0.33689004, 0.22334915, 0.59059596, 0.52789515, 0.32068327,
0.29831955, 0.39515004, 0.55602801, 0.83818001, 0.05865780, 0.79906309,
0.25654668, 0.76624149, 0.35190639, 0.04158346, 0.59157544, 0.44214272,
0.30779791, 0.94609004, 0.10759670, 0.65575141, 0.37828529, 0.63330448,
0.29571742, 0.76361233, 0.72476572, 0.18568406, 0.85430276, 0.08016958,
0.02057583, 0.76195669, 0.65507215, 0.69129735, 0.25084621, 0.63367140,
0.75223947, 0.06064088, 0.20287007, 0.35887691, 0.75043523, 0.19788943,
0.47575447, 0.40021798, 0.44464844, 0.67975360, 0.40443239, 0.55346787,
0.71052992, 0.21782248, 0.50568426, 0.89037591, 0.06661721, 0.11142531,
0.28788096, 0.70773387, 0.42428264, 0.80419677, 0.42710736, 0.90518415,
0.87317258, 0.88229448, 0.79217333 0.21236691,
]]) 0.81587470,
0.83752930,
0.70979482,
0.35684183,
0.28715104,
0.87162822,
0.17679396,
0.98725849,
0.76129991,
0.04090235,
0.37181064,
0.63317049,
0.24689502,
0.21126501,
0.57617670,
0.74346697,
0.40613672,
0.56907010,
0.68556929,
0.29032683,
0.17866278,
0.35165095,
0.97015840,
0.70785582,
0.54259878,
0.14712237,
0.90483177,
0.02094105,
0.36411613,
0.02495066,
0.88874054,
0.88895452,
0.86216462,
0.58062190,
0.95583254,
0.20553111,
0.29870346,
0.69652933,
0.36861244,
0.85316223,
0.50240189,
0.17566244,
0.61080140,
0.88203174,
0.98675215,
0.24344546,
0.17213407,
0.78160852,
0.25165486,
0.48188508,
0.82812423,
0.10199814,
0.90475923,
0.66907483,
0.71910626,
0.40660757,
0.59460294,
0.70212913,
0.90841550,
0.00329034,
0.11290466,
0.89654654,
0.69114941,
0.29473618,
0.62027222,
0.37333879,
0.98911142,
0.46510187,
0.65914583,
0.73022646,
0.12790845,
0.12817244,
0.43015456,
0.75011456,
0.43562204,
0.48086026,
0.75587070,
0.98481447,
0.77367836,
],
[
0.12336024,
0.74152875,
0.09191196,
0.99301219,
0.44764417,
0.01848883,
0.78326035,
0.99228370,
0.81447607,
0.02627683,
0.51033205,
0.98703283,
0.15247856,
0.77640921,
0.60799915,
0.87518770,
0.76818430,
0.86542630,
0.31795895,
0.04829503,
0.85567141,
0.30271924,
0.67515039,
0.59728831,
0.78710967,
0.75111693,
0.56837374,
0.49085775,
0.91510201,
0.59545547,
0.99482232,
0.59036905,
0.58267909,
0.28770933,
0.53237396,
0.95318258,
0.93987304,
0.61142951,
0.26737869,
0.52285451,
0.03479086,
0.61631846,
0.66777998,
0.15736090,
0.00447258,
0.37035006,
0.15281211,
0.95372260,
0.25963321,
0.61036694,
0.15020694,
0.19171195,
0.55252832,
0.00391038,
0.31052542,
0.96495175,
0.42586124,
0.05630261,
0.99728668,
0.01856293,
0.83201504,
0.10701843,
0.56434178,
0.38009524,
0.51095045,
0.13202040,
0.07133843,
0.75313550,
0.17111187,
0.80716974,
0.00172165,
0.83906764,
0.73240769,
0.85843354,
0.11042888,
0.07912333,
0.33689004,
0.22334915,
0.59059596,
0.52789515,
0.29831955,
0.39515004,
0.55602801,
0.83818001,
0.05865780,
0.25654668,
0.76624149,
0.35190639,
0.04158346,
0.59157544,
0.30779791,
0.94609004,
0.10759670,
0.65575141,
0.37828529,
0.29571742,
0.76361233,
0.72476572,
0.18568406,
0.85430276,
0.02057583,
0.76195669,
0.65507215,
0.69129735,
0.25084621,
0.75223947,
0.06064088,
0.20287007,
0.35887691,
0.75043523,
0.47575447,
0.40021798,
0.44464844,
0.67975360,
0.40443239,
0.71052992,
0.21782248,
0.50568426,
0.89037591,
0.06661721,
0.28788096,
0.70773387,
0.42428264,
0.80419677,
0.42710736,
0.87317258,
0.88229448,
0.79217333,
],
]
)
# pre_ids = paddle.to_tensor(np.float32(np.random.random([2, 1024]))) # pre_ids = paddle.to_tensor(np.float32(np.random.random([2, 1024])))
# logits = paddle.to_tensor(np.float32(np.random.random([2, 1024]))) # logits = paddle.to_tensor(np.float32(np.random.random([2, 1024])))
penalty_scores = paddle.to_tensor([1.0, 1.0], "float32") penalty_scores = paddle.to_tensor([1.0, 1.0], "float32")
@@ -195,60 +658,270 @@ print("min_len\n", min_len)
print("eos_token_id\n", eos_token_id) print("eos_token_id\n", eos_token_id)
ref_logits = np.array( ref_logits = np.array(
[[ [
-10000000000., -10000000000., 1.88733959, 1.64010835, 1.01505280, [
0.76633495, 1.85296881, 0.48100317, 0.10923190, 0.84437162, 0.72540450, -10000000000.0,
0.30929613, 0.27229437, 1.35019088, 0.80630332, 0.21343444, 0.49664113, -10000000000.0,
1.52182436, 0.23197991, 0.21925054, 0.09377026, 1.63073432, 1.44519603, 1.88733959,
1.20953357, 0.33403599, 1.68321562, 1.59299767, 1.56043208, 1.50659060, 1.64010835,
1.97175777, 0.26843736, 0.32055250, 0.30538794, 0.12457460, 1.47712541, 1.01505280,
0.69443822, 1.47366011, 1.56357217, 0.64136654, 1.59812617, 0.88428545, 0.76633495,
1.26660895, 0.16033916, 1.26734281, 0.39577886, 1.10693574, 0.22285062, 1.85296881,
1.81036830, 0.42473382, 1.63174939, 1.67505860, 1.41958964, 0.71368366, 0.48100317,
0.57430208, 1.74325645, 0.35358793, 1.97451699, 1.52259982, 0.08180470, 0.10923190,
0.74362129, 1.26634097, 0.49379003, 0.42253003, 1.15235341, 1.48693395, 0.84437162,
0.81227344, 1.13814020, 1.37113857, 0.58065367, 0.35732555, 0.70330191, 0.72540450,
1.94031680, 1.41571164, 1.08519757, 0.29424474, 1.80966353, 0.04188210, 0.30929613,
0.72823226, 0.04990132, 1.77748108, 1.77790904, 1.72432923, 1.16124380, 0.27229437,
1.91166508, 0.41106221, 0.59740692, 1.39305866, 0.73722488, 1.70632446, 1.35019088,
1.00480378, 0.35132489, 1.22160280, 1.76406348, 1.97350430, 0.48689091, 0.80630332,
0.34426814, 1.56321704, 0.50330973, 0.96377015, 1.65624845, 0.20399629, 0.21343444,
1.80951846, 1.33814967, 1.43821251, 0.81321514, 1.18920588, 1.40425825, 0.49664113,
1.81683099, 0.00658068, 0.22580932, 1.79309309, 1.38229883, 0.58947235, 1.52182436,
1.24054444, 0.74667758, 1.97822285, 0.93020374, 1.31829166, 1.46045291, 0.23197991,
0.25581691, 0.25634488, 0.86030912, 1.50022912, 0.87124407, 0.96172053, 0.21925054,
1.51174140, 1.96962893, 1.54735672 0.09377026,
1.63073432,
1.44519603,
1.20953357,
0.33403599,
1.68321562,
1.59299767,
1.56043208,
1.50659060,
1.97175777,
0.26843736,
0.32055250,
0.30538794,
0.12457460,
1.47712541,
0.69443822,
1.47366011,
1.56357217,
0.64136654,
1.59812617,
0.88428545,
1.26660895,
0.16033916,
1.26734281,
0.39577886,
1.10693574,
0.22285062,
1.81036830,
0.42473382,
1.63174939,
1.67505860,
1.41958964,
0.71368366,
0.57430208,
1.74325645,
0.35358793,
1.97451699,
1.52259982,
0.08180470,
0.74362129,
1.26634097,
0.49379003,
0.42253003,
1.15235341,
1.48693395,
0.81227344,
1.13814020,
1.37113857,
0.58065367,
0.35732555,
0.70330191,
1.94031680,
1.41571164,
1.08519757,
0.29424474,
1.80966353,
0.04188210,
0.72823226,
0.04990132,
1.77748108,
1.77790904,
1.72432923,
1.16124380,
1.91166508,
0.41106221,
0.59740692,
1.39305866,
0.73722488,
1.70632446,
1.00480378,
0.35132489,
1.22160280,
1.76406348,
1.97350430,
0.48689091,
0.34426814,
1.56321704,
0.50330973,
0.96377015,
1.65624845,
0.20399629,
1.80951846,
1.33814967,
1.43821251,
0.81321514,
1.18920588,
1.40425825,
1.81683099,
0.00658068,
0.22580932,
1.79309309,
1.38229883,
0.58947235,
1.24054444,
0.74667758,
1.97822285,
0.93020374,
1.31829166,
1.46045291,
0.25581691,
0.25634488,
0.86030912,
1.50022912,
0.87124407,
0.96172053,
1.51174140,
1.96962893,
1.54735672,
],
[
-10000000000.0,
-10000000000.0,
-40000.0,
3.97204876,
1.79057670,
0.07395532,
3.13304138,
3.96913481,
3.25790429,
-40000.0,
2.04132819,
3.94813132,
0.60991424,
3.10563684,
2.43199658,
3.50075078,
3.07273722,
3.46170521,
1.27183580,
0.19318011,
3.42268562,
1.21087694,
2.70060158,
2.38915324,
3.14843869,
3.00446773,
2.27349496,
1.96343100,
3.66040802,
2.38182187,
3.97928929,
2.36147618,
2.33071637,
1.15083730,
2.12949586,
3.81273031,
3.75949216,
2.44571805,
1.06951475,
2.09141803,
0.13916343,
2.46527386,
2.67111993,
0.62944359,
0.01789032,
1.48140025,
0.61124843,
3.81489038,
1.03853285,
2.44146776,
0.60082775,
0.76684779,
2.21011329,
0.01564152,
1.24210167,
3.85980701,
1.70344496,
0.22521044,
3.98914671,
0.07425172,
3.32806015,
0.42807373,
2.25736713,
1.52038097,
2.04380178,
0.52808160,
0.28535372,
3.01254201,
0.68444747,
3.22867894,
0.00688660,
3.35627055,
2.92963076,
3.43373418,
0.44171551,
0.31649333,
1.34756017,
0.89339662,
2.36238384,
2.11158061,
1.19327819,
1.58060014,
2.22411203,
3.35272002,
0.23463120,
1.02618670,
3.06496596,
1.40762556,
0.16633384,
2.36630177,
1.23119164,
3.78436017,
0.43038681,
2.62300563,
1.51314116,
1.18286967,
3.05444932,
2.89906287,
0.74273622,
3.41721106,
0.08230332,
3.04782677,
2.62028861,
2.76518941,
1.00338483,
3.00895786,
0.24256352,
0.81148028,
1.43550766,
3.00174093,
1.90301788,
1.60087192,
1.77859378,
2.71901441,
1.61772954,
2.84211969,
0.87128991,
2.02273703,
3.56150365,
0.26646885,
1.15152383,
2.83093548,
1.69713056,
3.21678710,
1.70842946,
3.49269032,
3.52917790,
3.16869330,
],
], ],
[
-10000000000., -10000000000., -40000., 3.97204876, 1.79057670,
0.07395532, 3.13304138, 3.96913481, 3.25790429, -40000., 2.04132819,
3.94813132, 0.60991424, 3.10563684, 2.43199658, 3.50075078,
3.07273722, 3.46170521, 1.27183580, 0.19318011, 3.42268562,
1.21087694, 2.70060158, 2.38915324, 3.14843869, 3.00446773,
2.27349496, 1.96343100, 3.66040802, 2.38182187, 3.97928929,
2.36147618, 2.33071637, 1.15083730, 2.12949586, 3.81273031,
3.75949216, 2.44571805, 1.06951475, 2.09141803, 0.13916343,
2.46527386, 2.67111993, 0.62944359, 0.01789032, 1.48140025,
0.61124843, 3.81489038, 1.03853285, 2.44146776, 0.60082775,
0.76684779, 2.21011329, 0.01564152, 1.24210167, 3.85980701,
1.70344496, 0.22521044, 3.98914671, 0.07425172, 3.32806015,
0.42807373, 2.25736713, 1.52038097, 2.04380178, 0.52808160,
0.28535372, 3.01254201, 0.68444747, 3.22867894, 0.00688660,
3.35627055, 2.92963076, 3.43373418, 0.44171551, 0.31649333,
1.34756017, 0.89339662, 2.36238384, 2.11158061, 1.19327819,
1.58060014, 2.22411203, 3.35272002, 0.23463120, 1.02618670,
3.06496596, 1.40762556, 0.16633384, 2.36630177, 1.23119164,
3.78436017, 0.43038681, 2.62300563, 1.51314116, 1.18286967,
3.05444932, 2.89906287, 0.74273622, 3.41721106, 0.08230332,
3.04782677, 2.62028861, 2.76518941, 1.00338483, 3.00895786,
0.24256352, 0.81148028, 1.43550766, 3.00174093, 1.90301788,
1.60087192, 1.77859378, 2.71901441, 1.61772954, 2.84211969,
0.87128991, 2.02273703, 3.56150365, 0.26646885, 1.15152383,
2.83093548, 1.69713056, 3.21678710, 1.70842946, 3.49269032,
3.52917790, 3.16869330
]],
"float32", "float32",
) )
diff_logits = np.sum(np.abs(ref_logits - logits.numpy()))
print("diff_logits\n", diff_logits)
assert diff_logits < 1e-6, "Check failed."

View File

@@ -21,19 +21,30 @@ paddle.seed(2023)
pre_ids_all = paddle.to_tensor(
[[1, 9, 3, 4, 5, 6, 7, -1, -1, -1], [1, 9, 7, 6, 5, 4, -1, -1, -1, -1]],
"int64",
)
input_ids = paddle.to_tensor(
[
[1, 9, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1],
[1, 9, 7, 6, 5, 4, -1, -1, -1, -1, -1, -1, -1],
],
"int64",
)
seq_lens_this_time = paddle.to_tensor([1, 1], "int32") seq_lens_this_time = paddle.to_tensor([1, 1], "int32")
seq_lens_encoder = paddle.to_tensor([1, 1], "int32") seq_lens_encoder = paddle.to_tensor([1, 1], "int32")
seq_lens_decoder = paddle.to_tensor([1, 1], "int32") seq_lens_decoder = paddle.to_tensor([1, 1], "int32")
step_idx = paddle.to_tensor([1, 1], "int64") step_idx = paddle.to_tensor([1, 1], "int64")
stop_flags = paddle.to_tensor([0, 1], "bool") stop_flags = paddle.to_tensor([0, 1], "bool")
print("pre_ids_all\n", pre_ids_all) print("pre_ids_all\n", pre_ids_all)
set_value_by_flags_and_idx(
pre_ids_all,
input_ids,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
stop_flags,
)
print("pre_ids_all\n", pre_ids_all) print("pre_ids_all\n", pre_ids_all)
print("input_ids\n", input_ids) print("input_ids\n", input_ids)
print("seq_lens_this_time\n", seq_lens_this_time) print("seq_lens_this_time\n", seq_lens_this_time)
@@ -73,4 +84,4 @@ ref_pre_ids_all = np.array(
) )
diff_pre_ids_all = np.sum(np.abs(ref_pre_ids_all - pre_ids_all.numpy()))
print("diff_pre_ids_all\n", diff_pre_ids_all)
assert diff_pre_ids_all == 0, "Check failed."

View File

@@ -41,10 +41,7 @@ step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64")
max_block_num = block_bs * max_seq_len // block_size max_block_num = block_bs * max_seq_len // block_size
free_list_len = int(max_block_num * (1 - block_ratio)) free_list_len = int(max_block_num * (1 - block_ratio))
free_list_len = np.full([1], free_list_len, "int32") free_list_len = np.full([1], free_list_len, "int32")
free_list = np.arange(max_block_num - 1, free_list = np.arange(max_block_num - 1, max_block_num - free_list_len - 1, -1, dtype="int32")
max_block_num - free_list_len - 1,
-1,
dtype="int32")
encoder_block_lens = np.zeros([max_bs], "int32") encoder_block_lens = np.zeros([max_bs], "int32")
used_list_len = np.zeros([max_bs], "int32") used_list_len = np.zeros([max_bs], "int32")
@@ -53,19 +50,15 @@ encoder_block_id = 0
for i in range(bs): for i in range(bs):
enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size
encoder_block_lens[i] = enc_block_num encoder_block_lens[i] = enc_block_num
dec_block_num = (seq_lens_decoder[i] + block_size - dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num
1) // block_size - enc_block_num
used_list_len[i] = dec_block_num used_list_len[i] = dec_block_num
block_tables[i, :enc_block_num] = np.arange( block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
encoder_block_id += enc_block_num encoder_block_id += enc_block_num
if dec_block_num > 0: if dec_block_num > 0:
block_tables[ block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[
i, enc_block_num:enc_block_num + free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1
dec_block_num] = free_list[free_list_len[0] - 1 - ]
dec_block_num:free_list_len[0] - 1] free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1
free_list[free_list_len[0] - 1 - dec_block_num:free_list_len[0] -
1] = -1
free_list_len[0] -= dec_block_num free_list_len[0] -= dec_block_num
assert free_list_len[0] >= 0
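# In short: each request takes its encoder blocks contiguously from the front
# (encoder_block_id), while its decoder blocks are carved from the tail of free_list;
# the consumed tail entries are marked -1 and free_list_len shrinks by dec_block_num.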
@@ -137,13 +130,32 @@ first_token_ids = paddle.to_tensor(first_token_ids)
# print("step_idx: ", step_idx) # print("step_idx: ", step_idx)
# print("next_tokens: ", next_tokens) # print("next_tokens: ", next_tokens)
step_paddle(
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_lens,
recover_block_list,
recover_lens,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
input_ids,
pre_ids,
step_idx,
next_tokens,
first_token_ids,
block_size,
encoder_decoder_block_num,
)
print("-" * 50 + "after step op" + "-" * 50) print("-" * 50 + "after step op" + "-" * 50)
print("stop_flags: ", stop_flags) print("stop_flags: ", stop_flags)

View File

@@ -30,8 +30,7 @@ end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64")
print("topk_ids\n", topk_ids) print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens) print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags) print("stop_flags\n", stop_flags)
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, False)
print("topk_ids\n", topk_ids) print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens) print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags) print("stop_flags\n", stop_flags)
@@ -40,44 +39,220 @@ print("end_ids\n", end_ids)
ref_topk_ids = np.array( ref_topk_ids = np.array(
[ [
0, 0, 2, 3, -1, 0, 0, 0, 0, 9, 10, 0, 12, 0, -1, 15, 16, 0, 18, 19, 20, 0,
0, 22, 23, 0, 25, 26, 27, -1, 29, 30, 31, 0, 0, 0, -1, -1, 37, 38, 39, 0,
-1, -1, 0, 0, 0, 0, 46, -1, 0, 49, 50, 0, 52, 53, 0, -1, 0, 57, -1, 59, 2,
60, 0, 0, 63 3,
-1,
0,
0,
0,
0,
9,
10,
0,
12,
0,
-1,
15,
16,
0,
18,
19,
20,
0,
22,
23,
0,
25,
26,
27,
-1,
29,
30,
31,
0,
0,
0,
-1,
-1,
37,
38,
39,
-1,
-1,
0,
0,
0,
0,
46,
-1,
0,
49,
50,
0,
52,
53,
0,
-1,
0,
57,
-1,
59,
60,
0,
0,
63,
], ],
"int64", "int64",
) )
ref_next_tokens = np.array( ref_next_tokens = np.array(
[ [
0, 0, 2, 3, 0, 0, 0, 0, 0, 9, 10, 0, 12, 0, 0, 15, 16, 0, 18, 19, 20, 0,
0, 22, 23, 0, 25, 26, 27, 0, 29, 30, 31, 0, 0, 0, 0, 0, 37, 38, 39, 0, 0,
0, 0, 0, 0, 0, 46, 0, 0, 49, 50, 0, 52, 53, 0, 0, 0, 57, 0, 59, 60, 0, 2,
0, 63 3,
0,
0,
0,
0,
0,
9,
10,
0,
12,
0,
0,
15,
16,
0,
18,
19,
20,
0,
22,
23,
0,
25,
26,
27,
0,
29,
30,
31,
0,
0,
0,
0,
0,
37,
38,
39,
0,
0,
0,
0,
0,
0,
46,
0,
0,
49,
50,
0,
52,
53,
0,
0,
0,
57,
0,
59,
60,
0,
0,
63,
], ],
"int64", "int64",
) )
ref_stop_flags = np.array( ref_stop_flags = np.array(
[ [
True, True, True, True, True, True, True, True, True, False, False, True,
True, False, True, True, False, False, True, False, False, False, True, True,
False, False, True, False, False, False, True, False, False, False, True,
True, True, True, True, True, False, False, False, True, True, True, True,
True, True, True, False, True, True, False, False, True, False, False, True,
True, True, True, False, True, False, False, True, True, False True,
True,
True,
True,
False,
False,
True,
False,
True,
True,
False,
False,
True,
False,
False,
False,
True,
False,
False,
True,
False,
False,
False,
True,
False,
False,
False,
True,
True,
True,
True,
True,
False,
False,
False,
True,
True,
True,
True,
True,
True,
False,
True,
True,
False,
False,
True,
False,
False,
True,
True,
True,
False,
True,
False,
False,
True,
True,
False,
], ],
"bool", "bool",
) )
diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
print("diff_topk_ids\n", diff_topk_ids)
assert diff_topk_ids == 0, "Check failed."
diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
print("diff_next_tokens\n", diff_next_tokens)
assert diff_next_tokens == 0, "Check failed."
diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
print("diff_stop_flags\n", diff_stop_flags)
assert diff_stop_flags == 0, "Check failed."
# test beam_search=True # test beam_search=True
topk_ids = paddle.arange(0, bs, dtype="int64") topk_ids = paddle.arange(0, bs, dtype="int64")
@@ -88,8 +263,7 @@ end_ids = paddle.to_tensor([0, 1, 2, 3, 4, 5], "int64")
print("topk_ids\n", topk_ids) print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens) print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags) print("stop_flags\n", stop_flags)
set_stop_value_multi_ends(topk_ids, stop_flags, seq_lens, end_ids, next_tokens, True)
print("topk_ids\n", topk_ids) print("topk_ids\n", topk_ids)
print("next_tokens\n", next_tokens) print("next_tokens\n", next_tokens)
print("stop_flags\n", stop_flags) print("stop_flags\n", stop_flags)
@@ -98,42 +272,217 @@ print("end_ids\n", end_ids)
ref_topk_ids = np.array( ref_topk_ids = np.array(
[ [
0, 1, 2, 3, 4, 0, 6, 7, -1, 9, 10, 0, -1, 13, 14, 15, 0, 17, 18, 19, 0,
20, 0, 22, 23, 24, 25, -1, -1, 28, 29, 0, 0, -1, 33, 34, 35, 36, 37, 0, 1,
-1, 0, 41, -1, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, 0, 0, 0, 0, 58, 2,
-1, 60, 61, -1, 63 3,
4,
0,
6,
7,
-1,
9,
10,
0,
-1,
13,
14,
15,
0,
17,
18,
19,
20,
0,
22,
23,
24,
25,
-1,
-1,
28,
29,
0,
0,
-1,
33,
34,
35,
36,
37,
0,
-1,
0,
41,
-1,
0,
44,
45,
46,
0,
0,
49,
0,
0,
0,
53,
0,
0,
0,
0,
58,
-1,
60,
61,
-1,
63,
], ],
"int64", "int64",
) )
ref_next_tokens = np.array( ref_next_tokens = np.array(
[ [
0, 1, 2, 3, 4, 0, 6, 7, 0, 9, 10, 0, 0, 13, 14, 15, 0, 17, 18, 19, 20, 0,
0, 22, 23, 24, 25, 0, 0, 28, 29, 0, 0, 0, 33, 34, 35, 36, 37, 0, 0, 0, 1,
41, 0, 0, 44, 45, 46, 0, 0, 49, 0, 0, 0, 53, 0, 0, 0, 0, 58, 0, 60, 61, 2,
0, 63 3,
4,
0,
6,
7,
0,
9,
10,
0,
0,
13,
14,
15,
0,
17,
18,
19,
20,
0,
22,
23,
24,
25,
0,
0,
28,
29,
0,
0,
0,
33,
34,
35,
36,
37,
0,
0,
0,
41,
0,
0,
44,
45,
46,
0,
0,
49,
0,
0,
0,
53,
0,
0,
0,
0,
58,
0,
60,
61,
0,
63,
], ],
"int64", "int64",
) )
ref_stop_flags = np.array( ref_stop_flags = np.array(
[ [
False, False, False, False, False, True, False, False, True, False, False,
False, True, True, False, False, False, True, False, False, False, False,
False, True, False, False, False, False, True, True, False, False, False,
True, True, True, False, False, False, False, False, True, True, True, False,
False, True, True, False, False, False, True, True, False, True, True, False,
True, False, True, True, True, True, False, True, False, False, True, True,
False False,
False,
True,
False,
False,
True,
True,
False,
False,
False,
True,
False,
False,
False,
False,
True,
False,
False,
False,
False,
True,
True,
False,
False,
True,
True,
True,
False,
False,
False,
False,
False,
True,
True,
True,
False,
True,
True,
False,
False,
False,
True,
True,
False,
True,
True,
True,
False,
True,
True,
True,
True,
False,
True,
False,
False,
True,
False,
], ],
"bool", "bool",
) )
diff_topk_ids = np.sum(np.abs(ref_topk_ids - topk_ids.numpy()))
print("diff_topk_ids\n", diff_topk_ids)
assert diff_topk_ids == 0, "Check failed."
diff_next_tokens = np.sum(np.abs(ref_next_tokens - next_tokens.numpy()))
print("diff_next_tokens\n", diff_next_tokens)
assert diff_next_tokens == 0, "Check failed."
diff_stop_flags = np.sum(np.abs(ref_stop_flags.astype(np.int32) - stop_flags.numpy().astype(np.int32)))
print("diff_stop_flags\n", diff_stop_flags)
assert diff_stop_flags == 0, "Check failed."

View File

@@ -60,9 +60,17 @@ print("stop_nums:\n", stop_nums)
print("next_tokens:\n", next_tokens) print("next_tokens:\n", next_tokens)
print("is_block_step:\n", is_block_step) print("is_block_step:\n", is_block_step)
update_inputs(
stop_flags,
not_need_stop,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
input_ids,
stop_nums,
next_tokens,
is_block_step,
)
print("-" * 50) print("-" * 50)
print("stop_flags:\n", stop_flags) print("stop_flags:\n", stop_flags)
@@ -75,32 +83,269 @@ print("stop_nums:\n", stop_nums)
print("next_tokens:\n", next_tokens) print("next_tokens:\n", next_tokens)
ref_not_need_stop_out = np.array([True]) ref_not_need_stop_out = np.array([True])
ref_seq_lens_this_time_out = np.array([ ref_seq_lens_this_time_out = np.array(
0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, [
0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1 0,
], "int32") 0,
ref_seq_lens_encoder_out = np.array([ 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 1,
], "int32") 0,
ref_seq_lens_decoder_out = np.array([ 1,
0, 0, 2, 0, 0, 6, 0, 8, 8, 10, 0, 12, 12, 0, 0, 0, 0, 0, 0, 0, 20, 22, 0, 1,
24, 24, 0, 26, 28, 0, 0, 0, 32, 32, 0, 34, 0, 0, 38, 0, 40, 0, 0, 42, 0, 0, 1,
46, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 0,
], "int32") 1,
input_ids_np[:, 0] = np.array([ 1,
6, 5, 9, 8, 6, 2, 8, 1, 3, 1, 3, 6, 9, 8, 1, 9, 1, 8, 8, 6, 7, 6, 5, 3, 5, 0,
9, 3, 6, 3, 9, 8, 8, 8, 8, 4, 8, 7, 4, 2, 3, 5, 8, 4, 2, 5, 6, 8, 9, 6, 7, 0,
4, 2, 4, 6, 2, 3, 4, 9, 7, 2, 1, 8, 7, 8 0,
], "int64") 0,
0,
0,
0,
1,
1,
0,
1,
1,
0,
1,
1,
0,
0,
0,
1,
1,
0,
1,
0,
0,
1,
0,
1,
0,
0,
1,
0,
0,
1,
1,
1,
],
"int32",
)
ref_seq_lens_encoder_out = np.array(
[
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
],
"int32",
)
ref_seq_lens_decoder_out = np.array(
[
0,
0,
2,
0,
0,
6,
0,
8,
8,
10,
0,
12,
12,
0,
0,
0,
0,
0,
0,
0,
20,
22,
0,
24,
24,
0,
26,
28,
0,
0,
0,
32,
32,
0,
34,
0,
0,
38,
0,
40,
0,
0,
42,
0,
0,
46,
46,
48,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
],
"int32",
)
input_ids_np[:, 0] = np.array(
[
6,
5,
9,
8,
6,
2,
8,
1,
3,
1,
3,
6,
9,
8,
1,
9,
1,
8,
8,
6,
7,
6,
5,
3,
5,
9,
3,
6,
3,
9,
8,
8,
8,
8,
4,
8,
7,
4,
2,
3,
5,
8,
4,
2,
5,
6,
8,
9,
6,
7,
4,
2,
4,
6,
2,
3,
4,
9,
7,
2,
1,
8,
7,
8,
],
"int64",
)
assert not_need_stop.numpy() == ref_not_need_stop_out, "Check not_need_stop failed."
assert np.all(seq_lens_this_time.numpy() == ref_seq_lens_this_time_out), "Check seq_lens_this_time failed."
assert np.all(seq_lens_encoder.numpy() == ref_seq_lens_encoder_out), "Check seq_lens_encoder failed."
assert np.all(seq_lens_decoder.numpy() == ref_seq_lens_decoder_out), "Check seq_lens_decoder failed."
assert np.all(input_ids.numpy() == input_ids_np), "Check input_ids failed."

View File

@@ -29,16 +29,15 @@ def np_quant_weight_int4(weight_np):
weight = np.transpose(weight_np, [1, 0])  # n,k
max_value = np.max(np.abs(weight), axis=1).reshape(-1, 1)  # k => k,1
quanted_weight = np_clip_and_round(weight / max_value * 7.0, 7)  # n,k
quanted_weight = (quanted_weight[:, 1::2] & 0xF) << 4 | (quanted_weight[:, ::2] & 0xF)  # pack int4, [n,k//2]
weight_scales = (max_value).astype(weight_np.dtype).reshape(-1)
return quanted_weight, weight_scales.astype(np.float32)
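# For intuition, the packing line above folds each pair of int4 values into one
# byte: the even-indexed element fills the low nibble and the odd-indexed element
# the high nibble. A tiny worked example (values chosen purely for illustration):
#   a, b = 3, -2
#   packed = ((b & 0xF) << 4) | (a & 0xF)  # -> 0xE3: high nibble 0xE (-2), low nibble 0x3 (3)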
def np_quant_weight(weight_np, algo="weight_only_int8"):
assert weight_np.dtype == np.float32
if algo == "weight_only_int4":
return np_quant_weight_int4(weight_np)
weight = np.transpose(weight_np, [1, 0])
@@ -56,7 +55,7 @@ def int8_to_bin_np(value):
def int8_to_bin(value): def int8_to_bin(value):
if not -128 <= value <= 127: if not -128 <= value <= 127:
raise ValueError("int8 值必须在 -128 到 127 之间") raise ValueError("int8 值必须在 -128 到 127 之间")
return format(value & 0xFF, '08b') # '08b' 表示 8 位二进制,高位补零 return format(value & 0xFF, "08b") # '08b' 表示 8 位二进制,高位补零
# 1) preparation # 1) preparation
@@ -70,7 +69,7 @@ w_np = (np.random.random((k, n)).astype(np.float32) - 0.5) * 10
qw_np, wscale_np = np_quant_weight(w_np, algo) qw_np, wscale_np = np_quant_weight(w_np, algo)
# 3) xpu calculation # 3) xpu calculation
dtype = 'float32' dtype = "float32"
x_pd = paddle.to_tensor(w_np, dtype=dtype) x_pd = paddle.to_tensor(w_np, dtype=dtype)
qw_pd, wscale_pd = weight_quantize_xpu(x_pd, algo, -1, -1) qw_pd, wscale_pd = weight_quantize_xpu(x_pd, algo, -1, -1)
qw_pd_trans = paddle.transpose(qw_pd, [1, 0]) qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
@@ -83,12 +82,7 @@ qw_pd_trans = paddle.transpose(qw_pd, [1, 0])
# comparation # comparation
print(f"wscale_pd, mean={wscale_pd.mean()}, std={wscale_pd.std()}") print(f"wscale_pd, mean={wscale_pd.mean()}, std={wscale_pd.std()}")
print(f"wscale_np, mean={wscale_np.mean()}, std={wscale_np.std()}") print(f"wscale_np, mean={wscale_np.mean()}, std={wscale_np.std()}")
print(f"qw_np, mean={qw_np.astype(np.float32).mean()}, std={qw_np.astype(np.float32).std()}")
print(f"qw_pd_trans, mean={qw_pd_trans.astype('float32').mean()}, std={qw_pd_trans.astype('float32').std()}")
sum_diff = np.sum(np.abs(qw_pd_trans.astype("float32").numpy() - qw_np.astype("float32")))
print(f"sum_diff: {sum_diff}") print(f"sum_diff: {sum_diff}")

View File

@@ -37,4 +37,4 @@ python benchmark_serving.py \
--num-prompts 1 \ --num-prompts 1 \
--max-concurrency 1 \ --max-concurrency 1 \
--save-result --save-result
``` ```

View File

@@ -15,7 +15,7 @@ We provide two transmission methods for KV Cache, targeting intra-machine and in
Uses cudaMemcpyPeer for KV Cache transmission between two GPUs within a single machine, offering low latency and high throughput.
### Inter-machine Transmission
For transmission between multiple machines, uses high-speed RDMA network for KV Cache transmission. We provide the `rdma_comm` high-speed transmission network library for cross-machine KV Cache transmission.
## PD Disaggregated Scheduling
![Splitwise Scheduler](./images/disaggregated.png)
@@ -60,7 +60,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--cache-queue-port 8187 \ --cache-queue-port 8187 \
--tensor-parallel-size 4 \ --tensor-parallel-size 4 \
--quantization wint4 \ --quantization wint4 \
--innode-prefill-ports 8182 \ --innode-prefill-ports 8182 \
--splitwise-role "decode" --splitwise-role "decode"
``` ```
@@ -72,7 +72,8 @@ Refer to the example code `offline_disaggregated_demo.py` in the `fastdeploy/dem
### Multi-machine Disaggregated Deployment ### Multi-machine Disaggregated Deployment
#### Prerequisite: Redis #### Prerequisite: Redis
- Installation via `conda` * Installation via `conda`
```bash ```bash
# Install # Install
conda install redis conda install redis
@@ -80,7 +81,8 @@ conda install redis
nohup redis-server > redis.log 2>&1 & nohup redis-server > redis.log 2>&1 &
``` ```
- Installation via `apt` * Installation via `apt`
```bash ```bash
# Install # Install
sudo apt install redis-server -y sudo apt install redis-server -y
@@ -88,7 +90,8 @@ sudo apt install redis-server -y
sudo systemctl start redis-server sudo systemctl start redis-server
``` ```
- Installation via `yum` * Installation via `yum`
```bash ```bash
# Install # Install
sudo yum install redis -y sudo yum install redis -y

View File

@@ -38,6 +38,7 @@ conda install redis
# Launch # Launch
nohup redis-server > redis.log 2>&1 & nohup redis-server > redis.log 2>&1 &
``` ```
### apt installation (Debian/Ubuntu) ### apt installation (Debian/Ubuntu)
```bash ```bash
@@ -57,6 +58,7 @@ sudo systemctl start redis
``` ```
## Launching FastDeploy ## Launching FastDeploy
```bash ```bash
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--port 8801 \ --port 8801 \
@@ -72,6 +74,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-min-load_score 3 \ --scheduler-min-load_score 3 \
--scheduler-load-shards-num 1 --scheduler-load-shards-num 1
``` ```
[Scheduler Launching Parameter](../online_serving/scheduler.md) [Scheduler Launching Parameter](../online_serving/scheduler.md)
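After the service starts, a quick sanity check is to query the OpenAI-compatible endpoint from Python; a minimal sketch, assuming the port 8801 used above and that the standard `/v1/models` route is exposed (the API key is a placeholder):

```python
from openai import OpenAI

# Placeholder key; the port matches the launch command above.
client = OpenAI(base_url="http://127.0.0.1:8801/v1", api_key="EMPTY")
for model in client.models.list():
    print(model.id)
```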
### Deployment notes: ### Deployment notes:

View File

@@ -36,4 +36,4 @@ python -m fastdeploy.entrypoints.openai.api_server \
Set `enable_prefix_caching=True` when launching FastDeploy. Enable CPU caching via `swap_space` based on available machine memory. Set `enable_prefix_caching=True` when launching FastDeploy. Enable CPU caching via `swap_space` based on available machine memory.
A test example is provided: `demo/offline_prefix_caching_demo.py` A test example is provided: `demo/offline_prefix_caching_demo.py`
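As a rough illustration of what the offline demo does (the import path and constructor keywords below are assumptions based on the text above; see `demo/offline_prefix_caching_demo.py` for the actual usage):

```python
# Sketch only: import path and keyword names are assumptions, not verified API.
from fastdeploy import LLM, SamplingParams

llm = LLM(
    model="/path/to/your/model",
    enable_prefix_caching=True,  # reuse KV Cache blocks across shared prompt prefixes
    swap_space=50,               # GB of CPU memory used as a secondary cache tier
)
prompts = [
    "You are a helpful assistant. Summarize the following text: ...",
    "You are a helpful assistant. Translate the following text: ...",
]
outputs = llm.generate(prompts, SamplingParams(max_tokens=128))
```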

View File

@@ -18,8 +18,9 @@ Interfaces that support toggling the reasoning mode:
For reasoning models, the length of the reasoning content can be controlled via `reasoning_max_tokens`. Add `metadata={"reasoning_max_tokens": 1024}` to the request. For reasoning models, the length of the reasoning content can be controlled via `reasoning_max_tokens`. Add `metadata={"reasoning_max_tokens": 1024}` to the request.
### Quick Start ### Quick Start
When launching the model service, specify the parser name using the `--reasoning-parser` argument. When launching the model service, specify the parser name using the `--reasoning-parser` argument.
This parser will process the model's output and extract the `reasoning_content` field. This parser will process the model's output and extract the `reasoning_content` field.
```bash ```bash
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
--model /path/to/your/model \ --model /path/to/your/model \
@@ -29,7 +30,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
--quantization wint4 \ --quantization wint4 \
--reasoning-parser ernie-45-vl --reasoning-parser ernie-45-vl
``` ```
Next, send a request to the model; the response should include the reasoning content. Next, send a request to the model; the response should include the reasoning content.
```bash ```bash
curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
@@ -43,10 +46,12 @@ curl -X POST "http://0.0.0.0:8192/v1/chat/completions" \
"metadata": {"enable_thinking": true} "metadata": {"enable_thinking": true}
}' }'
``` ```
The `reasoning_content` field contains the reasoning steps to reach the final conclusion, while the `content` field holds the conclusion itself. The `reasoning_content` field contains the reasoning steps to reach the final conclusion, while the `content` field holds the conclusion itself.
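The same request can also be issued from Python with the OpenAI client; a non-streaming sketch, reusing the port and model path from the launch command above and assuming FastDeploy attaches `reasoning_content` to the returned message:

```python
from openai import OpenAI

# Port and model path mirror the launch example above; the API key is a placeholder.
client = OpenAI(base_url="http://0.0.0.0:8192/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="/path/to/your/model",
    messages=[{"role": "user", "content": "How many 'r's are in the word 'strawberry'?"}],
    extra_body={"metadata": {"enable_thinking": True, "reasoning_max_tokens": 1024}},
)
message = response.choices[0].message
print(getattr(message, "reasoning_content", None))  # reasoning steps (assumed field)
print(message.content)                              # final answer
```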
### Streaming chat completions ### Streaming chat completions
Streaming chat completions are also supported for reasoning models. The `reasoning_content` field is available in the `delta` field in `chat completion response chunks` Streaming chat completions are also supported for reasoning models. The `reasoning_content` field is available in the `delta` field in `chat completion response chunks`
```python ```python
from openai import OpenAI from openai import OpenAI
# Set OpenAI's API key and API base to use vLLM's API server. # Set OpenAI's API key and API base to use vLLM's API server.
@@ -69,4 +74,4 @@ for chunk in chat_response:
if chunk.choices[0].delta is not None: if chunk.choices[0].delta is not None:
print(chunk.choices[0].delta, end='') print(chunk.choices[0].delta, end='')
print("\n") print("\n")
``` ```

View File

@@ -10,22 +10,22 @@ This project implements an efficient **Speculative Decoding** inference framewor
- **Ngram** - **Ngram**
- **MTP (Multi-Token Prediction)** - **MTP (Multi-Token Prediction)**
- ✅ Supported: TP Sharding - ✅ Supported: TP Sharding
- ✅ Supported: Shared Prefix - ✅ Supported: Shared Prefix
- ✅ Supported: TP Sharding + PD Separation - ✅ Supported: TP Sharding + PD Separation
- ⏳ Coming Soon: EP + DP + PD Separation - ⏳ Coming Soon: EP + DP + PD Separation
- ⏳ Coming Soon: Support Chunk-prefill - ⏳ Coming Soon: Support Chunk-prefill
- ⏳ Coming Soon: Multi-layer MTP Layer - ⏳ Coming Soon: Multi-layer MTP Layer
--- ---
### Coming Soon ### Coming Soon
- Draft Model - Draft Model
- Eagle - Eagle
- Hydra - Hydra
- Medusa - Medusa
- ... - ...
--- ---
@@ -54,7 +54,7 @@ This project implements an efficient **Speculative Decoding** inference framewor
## 🚀 Using Multi-Token Prediction (MTP) ## 🚀 Using Multi-Token Prediction (MTP)
For detailed theory, refer to: For detailed theory, refer to:
📄 [DeepSeek-V3 Paper](https://arxiv.org/pdf/2412.19437) 📄 [DeepSeek-V3 Paper](https://arxiv.org/pdf/2412.19437)
### TP Sharding Mode ### TP Sharding Mode
@@ -147,4 +147,4 @@ python -m fastdeploy.entrypoints.openai.api_server \
--config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \ --config ${path_to_FastDeploy}benchmarks/yaml/eb45t-32k-wint4-mtp-h100-tp4.yaml \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}' --speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${mtp_model_path}"}'
``` ```

View File

@@ -132,4 +132,3 @@ Upon completion, accuracy results are saved in ```result.jsonl```, e.g.:
```json ```json
{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}} {"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
``` ```

View File

@@ -6,4 +6,4 @@ FastDeploy currently supports installation on the following hardware platforms:
- [Kunlun XPU Installation](kunlunxin_xpu.md) - [Kunlun XPU Installation](kunlunxin_xpu.md)
- [Enflame S60 GCU Installation](Enflame_gcu.md) - [Enflame S60 GCU Installation](Enflame_gcu.md)
- [Iluvatar GPU Installation](iluvatar_gpu.md) - [Iluvatar GPU Installation](iluvatar_gpu.md)
- [Hygon DCU Installation](hygon_dcu.md) - [Hygon DCU Installation](hygon_dcu.md)

View File

@@ -37,6 +37,7 @@ image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04
``` ```
## 2. Start service ## 2. Start service
```bash ```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN" export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \ python -m fastdeploy.entrypoints.openai.api_server \
@@ -47,7 +48,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--gpu-memory-utilization=0.8 --gpu-memory-utilization=0.8
``` ```
#### Send requests ### Send requests
Send requests using either curl or Python Send requests using either curl or Python
@@ -78,4 +79,4 @@ response = client.chat.completions.create(
stream=False, stream=False,
) )
print(response) print(response)
``` ```

Some files were not shown because too many files have changed in this diff.